├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── images
│   │   └── ref.png
│   └── videos
│       └── dance.mp4
├── configs
│   ├── inference_v2.yaml
│   ├── test_stage_1.yaml
│   ├── test_stage_2.yaml
│   ├── train_stage_1.yaml
│   └── train_stage_2.yaml
├── downloading_weights.py
├── draw_dwpose.py
├── extract_dwpose_keypoints.py
├── extract_meta_info_multiple_dataset.py
├── musepose
│   ├── __init__.py
│   ├── dataset
│   │   ├── dance_image.py
│   │   └── dance_video.py
│   ├── models
│   │   ├── attention.py
│   │   ├── motion_module.py
│   │   ├── mutual_self_attention.py
│   │   ├── pose_guider.py
│   │   ├── resnet.py
│   │   ├── transformer_2d.py
│   │   ├── transformer_3d.py
│   │   ├── unet_2d_blocks.py
│   │   ├── unet_2d_condition.py
│   │   ├── unet_3d.py
│   │   └── unet_3d_blocks.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── context.py
│   │   ├── pipeline_pose2img.py
│   │   ├── pipeline_pose2vid.py
│   │   ├── pipeline_pose2vid_long.py
│   │   └── utils.py
│   └── utils
│       └── util.py
├── pose
│   ├── config
│   │   ├── dwpose-l_384x288.py
│   │   └── yolox_l_8xb8-300e_coco.py
│   └── script
│       ├── dwpose.py
│       ├── tool.py
│       ├── util.py
│       └── wholebody.py
├── pose_align.py
├── pretrained_weights
│   └── put_models_here.txt
├── requirements.txt
├── test_stage_1.py
├── test_stage_2.py
├── train_stage_1_multiGPU.py
└── train_stage_2_multiGPU.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | .DS_Store
7 | pretrained_weights
8 | output
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | MIT License
3 |
4 | Copyright (c) 2024 Tencent Music Entertainment Group
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
24 |
25 | Other dependencies and licenses:
26 |
27 |
28 | Open Source Software Licensed under the MIT License:
29 | --------------------------------------------------------------------
30 | 1. sd-vae-ft-mse
31 | Files:https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main
32 | License:MIT license
33 | For details:https://choosealicense.com/licenses/mit/
34 |
35 |
36 |
37 |
38 | Open Source Software Licensed under the Apache License Version 2.0:
39 | --------------------------------------------------------------------
40 | 1. DWpose
41 | Files:https://huggingface.co/yzd-v/DWPose/tree/main
42 | License:Apache-2.0
43 | For details:https://choosealicense.com/licenses/apache-2.0/
44 |
45 | 2. Moore-AnimateAnyone
46 | Files:https://github.com/MooreThreads/Moore-AnimateAnyone
47 | License:Apache-2.0
48 | For details:https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/LICENSE
49 |
50 | Terms of the Apache License Version 2.0:
51 | --------------------------------------------------------------------
52 | Apache License
53 |
54 | Version 2.0, January 2004
55 |
56 | http://www.apache.org/licenses/
57 |
58 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
59 | 1. Definitions.
60 |
61 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
62 |
63 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
64 |
65 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
66 |
67 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
68 |
69 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
70 |
71 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
72 |
73 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
74 |
75 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
76 |
77 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
78 |
79 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
80 |
81 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
82 |
83 | 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
84 |
85 | 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
86 |
87 | You must give any other recipients of the Work or Derivative Works a copy of this License; and
88 |
89 | You must cause any modified files to carry prominent notices stating that You changed the files; and
90 |
91 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
92 |
93 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
94 |
95 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
96 |
97 | 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
98 |
99 | 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
100 |
101 | 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
102 |
103 | 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
104 |
105 | 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
106 |
107 | END OF TERMS AND CONDITIONS
108 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MusePose
2 |
3 | MusePose: a Pose-Driven Image-to-Video Framework for Virtual Human Generation.
4 |
5 | Zhengyan Tong,
6 | Chao Li,
7 | Zhaokang Chen,
8 | Bin Wu†,
9 | Wenjiang Zhou
10 | (†Corresponding Author, benbinwu@tencent.com)
11 |
12 | Lyra Lab, Tencent Music Entertainment
13 |
14 |
15 | **[github](https://github.com/TMElyralab/MusePose)** **[huggingface](https://huggingface.co/TMElyralab/MusePose)** **space (coming soon)** **Project (coming soon)** **Technical report (coming soon)**
16 |
17 | [MusePose](https://github.com/TMElyralab/MusePose) is an image-to-video generation framework for virtual humans under control signals such as pose. The currently released model is an implementation of [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) built by optimizing [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone).
18 |
19 | `MusePose` is the last building block of **the Muse open-source series**. Together with [MuseV](https://github.com/TMElyralab/MuseV) and [MuseTalk](https://github.com/TMElyralab/MuseTalk), we hope the community can join us and march towards the vision where a virtual human can be generated end to end with the native ability of full-body movement and interaction. Please stay tuned for our next milestone!
20 |
21 | We really appreciate [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) for their academic paper and [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone) for their code base, which have significantly expedited the development of the AIGC community and [MusePose](https://github.com/TMElyralab/MusePose).
22 |
23 | Update:
24 | 1. We have released the training code of MusePose!
25 |
26 | ## Overview
27 | [MusePose](https://github.com/TMElyralab/MusePose) is a diffusion-based and pose-guided virtual human video generation framework.
28 | Our main contributions could be summarized as follows:
29 | 1. The released model can generate dance videos of the human character in a reference image under a given pose sequence. The result quality exceeds almost all current open-source models on the same task.
30 | 2. We release the `pose align` algorithm so that users can align arbitrary dance videos to arbitrary reference images, which **SIGNIFICANTLY** improves inference performance and enhances model usability.
31 | 3. We have fixed several important bugs and made some improvements to the code of [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone).
32 |
33 | ## Demos
34 |
76 | ## News
77 | - [05/27/2024] Released `MusePose` and pretrained models.
78 | - [05/31/2024] Added support for [Comfyui-MusePose](https://github.com/TMElyralab/Comfyui-MusePose).
79 | - [06/14/2024] Fixed a bug in `inference_v2.yaml`.
80 | - [03/04/2025] Released the training code.
81 |
82 | ## Todo:
83 | - [x] release our trained models and inference code of MusePose.
84 | - [x] release the pose align algorithm.
85 | - [x] Comfyui-MusePose
86 | - [x] training guidelines.
87 | - [ ] Huggingface Gradio demo.
88 | - [ ] an improved architecture and model (may take longer).
89 |
90 |
91 | # Getting Started
92 | We provide a detailed tutorial about the installation and the basic usage of MusePose for new users:
93 |
94 | ## Installation
95 | To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
96 |
97 | ### Build environment
98 |
99 | We recommend Python >= 3.10 and CUDA 11.7. Then build the environment as follows:
100 |
101 | ```shell
102 | pip install -r requirements.txt
103 | ```
104 |
105 | ### mmlab packages
106 | ```bash
107 | pip install --no-cache-dir -U openmim
108 | mim install mmengine
109 | mim install "mmcv>=2.0.1"
110 | mim install "mmdet>=3.1.0"
111 | mim install "mmpose>=1.1.0"
112 | ```
113 |
114 |
115 | ### Download weights
116 | You can download weights manually as follows:
117 |
118 | 1. Download our trained [weights](https://huggingface.co/TMElyralab/MusePose).
119 |
120 | 2. Download the weights of other components:
121 | - [sd-image-variations-diffusers](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/unet)
122 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
123 | - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
124 | - [yolox](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth) - Make sure to rename to `yolox_l_8x8_300e_coco.pth`
125 | - [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder)
126 | - [control_v11p_sd15_openpose](https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/blob/main/diffusion_pytorch_model.bin) (for training only)
127 | - [animatediff](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt) (for training only)
128 |
129 | Finally, these weights should be organized in `pretrained_weights` as follows:
130 | ```
131 | ./pretrained_weights/
132 | |-- MusePose
133 | |   |-- denoising_unet.pth
134 | |   |-- motion_module.pth
135 | |   |-- pose_guider.pth
136 | |   └── reference_unet.pth
137 | |-- dwpose
138 | |   |-- dw-ll_ucoco_384.pth
139 | |   └── yolox_l_8x8_300e_coco.pth
140 | |-- sd-image-variations-diffusers
141 | |   └── unet
142 | |       |-- config.json
143 | |       └── diffusion_pytorch_model.bin
144 | |-- image_encoder
145 | |   |-- config.json
146 | |   └── pytorch_model.bin
147 | |-- sd-vae-ft-mse
148 | |   |-- config.json
149 | |   └── diffusion_pytorch_model.bin
150 | |-- control_v11p_sd15_openpose
151 | |   └── diffusion_pytorch_model.bin
152 | └── animatediff
153 |     └── mm_sd_v15_v2.ckpt
154 | ```
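
As a quick sanity check before running inference, a short script along these lines can confirm that the layout above is complete (a minimal sketch; the file list mirrors the tree above, training-only weights are omitted, and `check_weights` is an illustrative helper, not part of this repo):

```python
import os

# Inference weights expected under ./pretrained_weights, as listed in the tree above.
EXPECTED_FILES = [
    "MusePose/denoising_unet.pth",
    "MusePose/motion_module.pth",
    "MusePose/pose_guider.pth",
    "MusePose/reference_unet.pth",
    "dwpose/dw-ll_ucoco_384.pth",
    "dwpose/yolox_l_8x8_300e_coco.pth",
    "sd-image-variations-diffusers/unet/config.json",
    "sd-image-variations-diffusers/unet/diffusion_pytorch_model.bin",
    "image_encoder/config.json",
    "image_encoder/pytorch_model.bin",
    "sd-vae-ft-mse/config.json",
    "sd-vae-ft-mse/diffusion_pytorch_model.bin",
]


def check_weights(root="./pretrained_weights"):
    missing = [f for f in EXPECTED_FILES if not os.path.isfile(os.path.join(root, f))]
    if missing:
        print("Missing weight files:")
        for f in missing:
            print("  -", f)
    else:
        print("All expected inference weights are in place.")


if __name__ == "__main__":
    check_weights()
```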
155 |
156 | ## Quickstart
157 | ### Inference
158 | #### Preparation
159 | Prepare your reference images and dance videos in the folder ```./assets```, organized as in the example:
160 | ```
161 | ./assets/
162 | |-- images
163 | | └── ref.png
164 | └── videos
165 | └── dance.mp4
166 | ```
167 |
168 | #### Pose Alignment
169 | Get the aligned dwpose of the reference image:
170 | ```
171 | python pose_align.py --imgfn_refer ./assets/images/ref.png --vidfn ./assets/videos/dance.mp4
172 | ```
173 | After this, you can see the pose alignment results in ```./assets/poses```, where ```./assets/poses/align/img_ref_video_dance.mp4``` is the aligned dwpose video and ```./assets/poses/align_demo/img_ref_video_dance.mp4``` is for debugging.
174 |
175 | #### Inferring MusePose
176 | Add the path of the reference image and the aligned dwpose video to the test config file ```./configs/test_stage_2.yaml``` as in the example:
177 | ```
178 | test_cases:
179 | "./assets/images/ref.png":
180 | - "./assets/poses/align/img_ref_video_dance.mp4"
181 | ```
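
For reference, `test_cases` is simply a mapping from a reference image to the list of aligned pose videos that should drive it. A minimal sketch of reading such a config (assuming PyYAML is installed; `iterate_test_cases` is an illustrative helper, not the repo's API):

```python
import yaml  # PyYAML


def iterate_test_cases(config_path="./configs/test_stage_2.yaml"):
    """Yield (reference image, pose video) pairs from the test config."""
    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)
    for ref_image, pose_videos in cfg["test_cases"].items():
        for pose_video in pose_videos:
            yield ref_image, pose_video


if __name__ == "__main__":
    for ref_image, pose_video in iterate_test_cases():
        print(f"{ref_image} -> {pose_video}")
```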
182 |
183 | Then, simply run
184 | ```
185 | python test_stage_2.py --config ./configs/test_stage_2.yaml
186 | ```
187 | ```./configs/test_stage_2.yaml``` is the path to the inference configuration file.
188 |
189 | Finally, you can see the output results in ```./output/```
190 |
191 | ##### Reducing VRAM cost
192 | If you want to reduce the VRAM cost, you could set the width and height for inference. For example,
193 | ```
194 | python test_stage_2.py --config ./configs/test_stage_2.yaml -W 512 -H 512
195 | ```
196 | It will generate the video at 512 x 512 first, and then resize it back to the original size of the pose video.
197 |
198 | Currently, it takes 16 GB of VRAM to run at 512 x 512 x 48 and 28 GB of VRAM to run at 768 x 768 x 48. Note, however, that the inference resolution affects the final results (especially the face region).
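
The resize-back step itself is conceptually simple; the sketch below illustrates the idea with OpenCV (this is not the repo's actual implementation, and the file names are placeholders):

```python
import cv2


def resize_video(src_path, dst_path, target_w, target_h):
    """Resize every frame of src_path to (target_w, target_h) and write dst_path."""
    cap = cv2.VideoCapture(src_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    writer = cv2.VideoWriter(
        dst_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (target_w, target_h)
    )
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        writer.write(cv2.resize(frame, (target_w, target_h), interpolation=cv2.INTER_LANCZOS4))
    cap.release()
    writer.release()


# e.g. upscale a 512 x 512 result back to a 768 x 768 pose video:
# resize_video("./output/result_512.mp4", "./output/result_768.mp4", 768, 768)
```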
199 |
200 | #### Face Enhancement
201 |
202 | If you want to enhance the face region for better facial consistency, you could use [FaceFusion](https://github.com/facefusion/facefusion). Its `face-swap` function can swap the face from the reference image onto the generated video.
203 |
204 | ### Training
205 | 1. Prepare
206 | First, put all your dance videos in a folder such as `./xxx`.
207 | Next, run `python extract_dwpose_keypoints.py --video_dir ./xxx`. The extracted dwpose keypoints will be saved in `./xxx_dwpose_keypoints`.
208 | Then, run `python draw_dwpose.py --video_dir ./xxx`. The rendered dwpose videos will be saved in `./xxx_dwpose_without_face` if `draw_face=False`, or in `./xxx_dwpose` if `draw_face=True`.
209 | Finally, run `python extract_meta_info_multiple_dataset.py --video_dirs ./xxx --dataset_name xxx`.
210 | You will get a JSON file, `./meta/xxx.json`, recording the paths of all data (see the sanity-check sketch after this list).
211 |
212 | 2. Configure accelerate and DeepSpeed
213 | `pip install accelerate`
214 | Run `accelerate config` to configure DeepSpeed for your machine.
215 | We use ZeRO stage 2 without any offload; our machine has 8x 80GB GPUs.
216 |
217 | 3. Configure the YAML files for training
218 | stage 1
219 | `./configs/train_stage_1.yaml`
220 | stage 2
221 | `./configs/train_stage_2.yaml`
222 |
223 | 4. Launch Training
224 | stage 1
225 | `accelerate launch train_stage_1_multiGPU.py --config configs/train_stage_1.yaml`
226 | stage 2
227 | `accelerate launch train_stage_2_multiGPU.py --config configs/train_stage_2.yaml`
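
If you later want to re-run the consistency check that `extract_meta_info_multiple_dataset.py` performs after writing the meta file (every video must have a rendered pose video with the same frame count), here is a minimal sketch, assuming the meta file from step 1 (`check_meta` is an illustrative helper name):

```python
import json

import cv2


def check_meta(meta_path="./meta/xxx.json"):
    """Verify every (video, rendered-pose) pair exists and has matching frame counts."""
    with open(meta_path, "r") as f:
        meta_infos = json.load(f)
    for item in meta_infos:
        video = cv2.VideoCapture(item["video_path"])
        kps = cv2.VideoCapture(item["kps_path"])
        n_video = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        n_kps = int(kps.get(cv2.CAP_PROP_FRAME_COUNT))
        video.release()
        kps.release()
        assert n_video == n_kps, f"{n_video} != {n_kps} in {item['video_path']}"
    print(f"{len(meta_infos)} pairs checked, no problem")


if __name__ == "__main__":
    check_meta()
```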
228 |
229 |
230 |
231 |
232 | # Acknowledgement
233 | 1. We thank [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) for their technical report, and have referred extensively to [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone) and [diffusers](https://github.com/huggingface/diffusers).
234 | 1. We thank open-source components such as [AnimateDiff](https://animatediff.github.io/), [dwpose](https://github.com/IDEA-Research/DWPose), [Stable Diffusion](https://github.com/CompVis/stable-diffusion), etc.
235 |
236 | Thanks for open-sourcing!
237 |
238 | # Limitations
239 | - Detail consistency: some details of the original character are not well preserved (e.g. the face region and complex clothing).
240 | - Noise and flickering: we observe noise and flickering in complex backgrounds.
241 |
242 | # Citation
243 | ```bib
244 | @article{musepose,
245 | title={MusePose: a Pose-Driven Image-to-Video Framework for Virtual Human Generation},
246 | author={Tong, Zhengyan and Li, Chao and Chen, Zhaokang and Wu, Bin and Zhou, Wenjiang},
247 | journal={arxiv},
248 | year={2024}
249 | }
250 | ```
251 | # Disclaimer/License
252 | 1. `code`: The code of MusePose is released under the MIT License. There is no limitation on either academic or commercial usage.
253 | 1. `model`: The trained models are available for non-commercial research purposes only.
254 | 1. `other opensource model`: Other open-source models used must comply with their licenses, such as `ft-mse-vae`, `dwpose`, etc.
255 | 1. The test data are collected from the internet and are available for non-commercial research purposes only.
256 | 1. `AIGC`: This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users.
257 |
--------------------------------------------------------------------------------
/assets/images/ref.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TMElyralab/MusePose/61c52bd937224a614b3951419b735b639397cb62/assets/images/ref.png
--------------------------------------------------------------------------------
/assets/videos/dance.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TMElyralab/MusePose/61c52bd937224a614b3951419b735b639397cb62/assets/videos/dance.mp4
--------------------------------------------------------------------------------
/configs/inference_v2.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_inflated_groupnorm: true
3 | unet_use_cross_frame_attention: false
4 | unet_use_temporal_attention: false
5 | use_motion_module: true
6 | motion_module_resolutions:
7 | - 1
8 | - 2
9 | - 4
10 | - 8
11 | motion_module_mid_block: true
12 | motion_module_decoder_only: false
13 | motion_module_type: Vanilla
14 | motion_module_kwargs:
15 | num_attention_heads: 8
16 | num_transformer_block: 1
17 | attention_block_types:
18 | - Temporal_Self
19 | - Temporal_Self
20 | temporal_position_encoding: true
21 | temporal_position_encoding_max_len: 128
22 | temporal_attention_dim_div: 1
23 |
24 | noise_scheduler_kwargs:
25 | beta_start: 0.00085
26 | beta_end: 0.012
27 | beta_schedule: "scaled_linear"
28 | clip_sample: false
29 | steps_offset: 1
30 | ### Zero-SNR params
31 | prediction_type: "v_prediction"
32 | rescale_betas_zero_snr: True
33 | timestep_spacing: "trailing"
34 |
35 | sampler: DDIM
--------------------------------------------------------------------------------
/configs/test_stage_1.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: './pretrained_weights/sd-image-variations-diffusers'
2 | pretrained_vae_path: './pretrained_weights/sd-vae-ft-mse'
3 | image_encoder_path: './pretrained_weights/image_encoder'
4 |
5 |
6 |
7 | denoising_unet_path: "./pretrained_weights/MusePose/denoising_unet.pth"
8 | reference_unet_path: "./pretrained_weights/MusePose/reference_unet.pth"
9 | pose_guider_path: "./pretrained_weights/MusePose/pose_guider.pth"
10 |
11 |
12 |
13 |
14 | inference_config: "./configs/inference_v2.yaml"
15 | weight_dtype: 'fp16'
16 |
17 |
18 |
19 | test_cases:
20 | "./assets/images/ref.png":
21 | - "./assets/poses/align/img_ref_video_dance.mp4"
22 |
23 |
24 |
25 |
26 |
27 |
--------------------------------------------------------------------------------
/configs/test_stage_2.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: './pretrained_weights/sd-image-variations-diffusers'
2 | pretrained_vae_path: './pretrained_weights/sd-vae-ft-mse'
3 | image_encoder_path: './pretrained_weights/image_encoder'
4 |
5 |
6 |
7 | denoising_unet_path: "./pretrained_weights/MusePose/denoising_unet.pth"
8 | reference_unet_path: "./pretrained_weights/MusePose/reference_unet.pth"
9 | pose_guider_path: "./pretrained_weights/MusePose/pose_guider.pth"
10 | motion_module_path: "./pretrained_weights/MusePose/motion_module.pth"
11 |
12 |
13 |
14 | inference_config: "./configs/inference_v2.yaml"
15 | weight_dtype: 'fp16'
16 |
17 |
18 |
19 | test_cases:
20 | "./assets/images/ref.png":
21 | - "./assets/poses/align/img_ref_video_dance.mp4"
22 |
--------------------------------------------------------------------------------
/configs/train_stage_1.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_bs: 8
3 | train_width: 768
4 | train_height: 768
5 | meta_paths:
6 | - "./meta/xxx.json"
7 | # Margin of frame indexes between ref and tgt images
8 | sample_margin: 128
9 |
10 | solver:
11 | gradient_accumulation_steps: 1
12 | mixed_precision: 'fp16'
13 | enable_xformers_memory_efficient_attention: True
14 | gradient_checkpointing: False
15 | max_train_steps: 400000
16 | max_grad_norm: 1.0
17 | # lr
18 | learning_rate: 1.0e-5
19 | scale_lr: False
20 | lr_warmup_steps: 1
21 | lr_scheduler: 'constant'
22 |
23 | # optimizer
24 | use_8bit_adam: False
25 | adam_beta1: 0.9
26 | adam_beta2: 0.999
27 | adam_weight_decay: 1.0e-2
28 | adam_epsilon: 1.0e-8
29 |
30 | val:
31 | validation_steps: 20000000000000000
32 |
33 |
34 | noise_scheduler_kwargs:
35 | num_train_timesteps: 1000
36 | beta_start: 0.00085
37 | beta_end: 0.012
38 | beta_schedule: "scaled_linear"
39 | steps_offset: 1
40 | clip_sample: false
41 |
42 | base_model_path: './pretrained_weights/sd-image-variations-diffusers'
43 | vae_model_path: './pretrained_weights/sd-vae-ft-mse'
44 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder'
45 | controlnet_openpose_path: './pretrained_weights/control_v11p_sd15_openpose/diffusion_pytorch_model.bin'
46 |
47 | weight_dtype: 'fp16' # [fp16, fp32]
48 | uncond_ratio: 0.1
49 | noise_offset: 0.05
50 | snr_gamma: 5.0
51 | # snr_gamma: 0
52 | enable_zero_snr: True
53 | pose_guider_pretrain: True
54 |
55 | seed: 12580
56 | resume_from_checkpoint: ''
57 | checkpointing_steps: 2250000000000000000000
58 | save_model_epoch_interval: 25
59 | exp_name: 'stage_1'
60 | output_dir: './exp_output'
61 |
62 | # load pretrained weights
63 | load_pth: True
64 | denoising_unet_path: "denoising_unet.pth"
65 | reference_unet_path: "reference_unet.pth"
66 | pose_guider_path: "pose_guider.pth"
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
/configs/train_stage_2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_bs: 1
3 | train_width: 768
4 | train_height: 768
5 | meta_paths:
6 | - "./meta/xxx.json"
7 | sample_rate: 2
8 | n_sample_frames: 48
9 |
10 | solver:
11 | gradient_accumulation_steps: 1
12 | mixed_precision: 'fp16'
13 | enable_xformers_memory_efficient_attention: True
14 | gradient_checkpointing: True
15 | max_train_steps: 2000000
16 | max_grad_norm: 1.0
17 | # lr
18 | learning_rate: 1e-5
19 | scale_lr: False
20 | lr_warmup_steps: 1
21 | lr_scheduler: 'constant'
22 |
23 | # optimizer
24 | use_8bit_adam: False
25 | adam_beta1: 0.9
26 | adam_beta2: 0.999
27 | adam_weight_decay: 1.0e-2
28 | adam_epsilon: 1.0e-8
29 |
30 | val:
31 | validation_steps: 20000000000000
32 |
33 |
34 | noise_scheduler_kwargs:
35 | num_train_timesteps: 1000
36 | beta_start: 0.00085
37 | beta_end: 0.012
38 | beta_schedule: "scaled_linear"
39 | steps_offset: 1
40 | clip_sample: false
41 |
42 | base_model_path: './pretrained_weights/stable-diffusion-v1-5'
43 | vae_model_path: './pretrained_weights/sd-vae-ft-mse'
44 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder'
45 | mm_path: './pretrained_weights/mm_sd_v15_v2.ckpt'
46 |
47 | weight_dtype: 'fp16' # [fp16, fp32]
48 | uncond_ratio: 0.1
49 | noise_offset: 0.05
50 | snr_gamma: 5.0
51 | # snr_gamma: 0
52 | enable_zero_snr: True
53 | stage1_ckpt_dir: './exp_output/stage_1'
54 | stage1_ckpt_step: 30000
55 |
56 |
57 |
58 | seed: 12580
59 | resume_from_checkpoint: ''
60 | checkpointing_steps: 2000000000000000
61 | exp_name: 'stage_2'
62 | output_dir: './exp_output'
63 |
64 |
65 |
--------------------------------------------------------------------------------
/downloading_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wget
3 | from tqdm import tqdm
4 |
5 | os.makedirs('pretrained_weights', exist_ok=True)
6 |
7 | urls = ['https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth',
8 | 'https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.pth',
9 | 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/denoising_unet.pth',
10 | 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/motion_module.pth',
11 | 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/pose_guider.pth',
12 | 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/reference_unet.pth',
13 | 'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/unet/diffusion_pytorch_model.bin',
14 | 'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/image_encoder/pytorch_model.bin',
15 | 'https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin'
16 | ]
17 |
18 | paths = ['dwpose', 'dwpose', 'MusePose', 'MusePose', 'MusePose', 'MusePose', 'sd-image-variations-diffusers/unet', 'image_encoder', 'sd-vae-ft-mse']
19 |
20 | for path in paths:
21 | os.makedirs(f'pretrained_weights/{path}', exist_ok=True)
22 |
23 | # saving weights
24 | for url, path in tqdm(zip(urls, paths)):
25 | filename = wget.download(url, f'pretrained_weights/{path}')
26 |
27 | config_urls = ['https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/unet/config.json',
28 | 'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/image_encoder/config.json',
29 | 'https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/config.json']
30 |
31 | config_paths = ['sd-image-variations-diffusers/unet', 'image_encoder', 'sd-vae-ft-mse']
32 |
33 | # saving config files
34 | for url, path in tqdm(zip(config_urls, config_paths)):
35 | filename = wget.download(url, f'pretrained_weights/{path}')
36 |
37 | # renaming model name as given in readme
38 | os.rename('pretrained_weights/dwpose/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth', 'pretrained_weights/dwpose/yolox_l_8x8_300e_coco.pth')
39 |
--------------------------------------------------------------------------------
/draw_dwpose.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import argparse
4 | import numpy as np
5 | from tqdm import tqdm
6 | from PIL import Image
7 |
8 | from pose.script.tool import save_videos_from_pil
9 | from pose.script.dwpose import draw_pose
10 |
11 |
12 |
13 | def draw_dwpose(video_path, pose_path, out_path, draw_face):
14 |
15 | # capture video info
16 | cap = cv2.VideoCapture(video_path)
17 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
18 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
19 | fps = cap.get(cv2.CAP_PROP_FPS)
20 | fps = int(np.around(fps))
21 | # fps = get_fps(video_path)
22 | cap.release()
23 |
24 | # render resolution, short edge = 1024
25 | k = float(1024) / min(width, height)
26 | h_render = int(k*height//2 * 2)
27 | w_render = int(k*width//2 * 2)
28 |
29 | # save resolution, short edge = 768
30 | k = float(768) / min(width, height)
31 | h_save = int(k*height//2 * 2)
32 | w_save = int(k*width//2 * 2)
33 |
34 | poses = np.load(pose_path, allow_pickle=True)
35 | poses = poses.tolist()
36 |
37 | frames = []
38 | for pose in tqdm(poses):
39 | detected_map = draw_pose(pose, h_render, w_render, draw_face)
40 | detected_map = cv2.resize(detected_map, (w_save, h_save), interpolation=cv2.INTER_AREA)
41 | # cv2.imshow('', detected_map)
42 | # cv2.waitKey(0)
43 | detected_map = cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB)
44 | detected_map = Image.fromarray(detected_map)
45 | frames.append(detected_map)
46 |
47 | save_videos_from_pil(frames, out_path, fps)
48 |
49 |
50 |
51 | if __name__ == "__main__":
52 |
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("--video_dir", type=str, default="./UBC_fashion/test", help='dance video dir')
55 | parser.add_argument("--pose_dir", type=str, default=None, help='auto makedir')
56 | parser.add_argument("--save_dir", type=str, default=None, help='auto makedir')
57 | parser.add_argument("--draw_face", type=bool, default=False, help='whether draw face or not')
58 | args = parser.parse_args()
59 |
60 |
61 | # video dir
62 | video_dir = args.video_dir
63 |
64 | # pose dir
65 | if args.pose_dir is None:
66 | pose_dir = args.video_dir + "_dwpose_keypoints"
67 | else:
68 | pose_dir = args.pose_dir
69 |
70 | # save dir
71 | if args.save_dir is None:
72 | if args.draw_face == True:
73 | save_dir = args.video_dir + "_dwpose"
74 | else:
75 | save_dir = args.video_dir + "_dwpose_without_face"
76 | else:
77 | save_dir = args.save_dir
78 | if not os.path.exists(save_dir):
79 | os.makedirs(save_dir)
80 |
81 |
82 | # collect all video_folder paths
83 | video_mp4_paths = set()
84 | for root, dirs, files in os.walk(args.video_dir):
85 | for name in files:
86 | if name.endswith(".mp4"):
87 | video_mp4_paths.add(os.path.join(root, name))
88 | video_mp4_paths = list(video_mp4_paths)
89 | # random.shuffle(video_mp4_paths)
90 | video_mp4_paths.sort()
91 | print("Num of videos:", len(video_mp4_paths))
92 |
93 |
94 | # draw dwpose
95 | for i in range(len(video_mp4_paths)):
96 | video_path = video_mp4_paths[i]
97 | video_name = os.path.relpath(video_path, video_dir)
98 | base_name = os.path.splitext(video_name)[0]
99 |
100 | pose_path = os.path.join(pose_dir, base_name + '.npy')
101 | if not os.path.exists(pose_path):
102 | print('no keypoint file:', pose_path)
103 | continue  # skip this video; its keypoints have not been extracted yet
104 | out_path = os.path.join(save_dir, base_name + '.mp4')
105 | if os.path.exists(out_path):
106 | print('already have rendered pose:', out_path)
107 | continue
108 |
109 | draw_dwpose(video_path, pose_path, out_path, args.draw_face)
110 | print(f"Process {i+1}/{len(video_mp4_paths)} video")
111 |
112 | print('all done!')
--------------------------------------------------------------------------------
/extract_dwpose_keypoints.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
4 | import argparse
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | from pose.script.dwpose import DWposeDetector
9 | from pose.script.tool import read_frames
10 |
11 |
12 |
13 |
14 | def process_single_video(video_path, detector, root_dir, save_dir):
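# Run DWpose on every frame of one video and save the per-frame keypoints to <save_dir>/<relative video name>.npy; skip videos whose keypoint file already exists.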
15 | # print(video_path)
16 | video_name = os.path.relpath(video_path, root_dir)
17 | base_name=os.path.splitext(video_name)[0]
18 | out_path = os.path.join(save_dir, base_name + '.npy')
19 | if os.path.exists(out_path):
20 | return
21 |
22 | frames = read_frames(video_path)
23 | keypoints = []
24 | for frame in tqdm(frames):
25 | keypoint = detector(frame)
26 | keypoints.append(keypoint)
27 |
28 | result = np.array(keypoints)
29 | np.save(out_path, result)
30 |
31 |
32 |
33 | def process_batch_videos(video_list, detector, root_dir, save_dir):
34 | for i, video_path in enumerate(video_list):
35 | process_single_video(video_path, detector, root_dir, save_dir)
36 | print(f"Process {i+1}/{len(video_list)} video")
37 |
38 |
39 |
40 | if __name__ == "__main__":
41 |
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument("--video_dir", type=str, default="./UBC_fashion/test")
44 | parser.add_argument("--save_dir", type=str, default=None)
45 | parser.add_argument("--yolox_config", type=str, default="./pose/config/yolox_l_8xb8-300e_coco.py")
46 | parser.add_argument("--dwpose_config", type=str, default="./pose/config/dwpose-l_384x288.py")
47 | parser.add_argument("--yolox_ckpt", type=str, default="./pretrained_weights/dwpose/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth")
48 | parser.add_argument("--dwpose_ckpt", type=str, default="./pretrained_weights/dwpose/dw-ll_ucoco_384.pth")
49 | args = parser.parse_args()
50 |
51 | # make save dir
52 | if args.save_dir is None:
53 | save_dir = args.video_dir + "_dwpose_keypoints"
54 | else:
55 | save_dir = args.save_dir
56 | if not os.path.exists(save_dir):
57 | os.makedirs(save_dir)
58 |
59 | # collect all video_folder paths
60 | video_mp4_paths = set()
61 | for root, dirs, files in os.walk(args.video_dir):
62 | for name in files:
63 | if name.endswith(".mp4"):
64 | video_mp4_paths.add(os.path.join(root, name))
65 | video_mp4_paths = list(video_mp4_paths)
66 | video_mp4_paths.sort()
67 | print("Num of videos:", len(video_mp4_paths))
68 |
69 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
70 | detector = DWposeDetector(
71 | det_config = args.yolox_config,
72 | det_ckpt = args.yolox_ckpt,
73 | pose_config = args.dwpose_config,
74 | pose_ckpt = args.dwpose_ckpt,
75 | keypoints_only=True
76 | )
77 | detector = detector.to(device)
78 |
79 | process_batch_videos(video_mp4_paths, detector, args.video_dir, save_dir)
80 | print('all done!')
81 |
--------------------------------------------------------------------------------
/extract_meta_info_multiple_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | # python extract_meta_info_multiple_dataset.py --video_dirs /path/to/video_dir --dataset_name fashion
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("--video_dirs", type=str, nargs="+", default=[  # accept one or more dataset directories
8 | "./UBC_fashion/test",
9 | # "path_of_dataset_1",
10 | # "path_of_dataset_2",
11 | ])
12 | parser.add_argument("--save_dir", type=str, default="./meta")
13 | parser.add_argument("--dataset_name", type=str, default="my_dataset")
14 | parser.add_argument("--meta_info_name", type=str, default=None)
15 | parser.add_argument("--draw_face", type=bool, default=False)
16 | args = parser.parse_args()
17 |
18 | if args.meta_info_name is None:
19 | args.meta_info_name = args.dataset_name
20 |
21 | # collect all video_folder paths
22 | meta_infos = []
23 |
24 | for dataset_path in args.video_dirs:
25 | video_mp4_paths = set()
26 | if args.draw_face == True:
27 | pose_dir = dataset_path + "_dwpose"
28 | else:
29 | pose_dir = dataset_path + "_dwpose_without_face"
30 |
31 | for root, dirs, files in os.walk(dataset_path):
32 | for name in files:
33 | if name.endswith(".mp4"):
34 | video_mp4_paths.add(os.path.join(root, name))
35 |
36 | video_mp4_paths = list(video_mp4_paths)
37 | print(dataset_path)
38 | print("video num:", len(video_mp4_paths))
39 |
40 |
41 | for video_mp4_path in video_mp4_paths:
42 | relative_video_name = os.path.relpath(video_mp4_path, dataset_path)
43 | kps_path = os.path.join(pose_dir, relative_video_name)
44 | meta_infos.append({"video_path": video_mp4_path, "kps_path": kps_path})
45 |
46 | save_path = os.path.join(args.save_dir, f"{args.meta_info_name}_meta.json")
47 | json.dump(meta_infos, open(save_path, "w"))
48 | print('data dumped')
49 | print('total pieces of data', len(meta_infos))
50 |
51 |
52 | import cv2
53 | # check data (cannot read or damaged)
54 | for index, video_meta in enumerate(meta_infos):
55 |
56 | video_path = video_meta["video_path"]
57 | kps_path = video_meta["kps_path"]
58 |
59 | video = cv2.VideoCapture(video_path)
60 | kps = cv2.VideoCapture(kps_path)
61 | frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
62 | frame_count_2 = int(kps.get(cv2.CAP_PROP_FRAME_COUNT))
63 | assert frame_count == frame_count_2, f"{frame_count} != {frame_count_2} in {video_path}"
64 |
65 | if (index+1) % 100 == 0: print(index+1)
66 |
67 | print('data checked, no problem')
68 |
--------------------------------------------------------------------------------
/musepose/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TMElyralab/MusePose/61c52bd937224a614b3951419b735b639397cb62/musepose/__init__.py
--------------------------------------------------------------------------------
/musepose/dataset/dance_image.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | import torch
5 | import torchvision.transforms as transforms
6 | from decord import VideoReader
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 | from transformers import CLIPImageProcessor
10 |
11 |
12 | class HumanDanceDataset(Dataset):
13 | def __init__(
14 | self,
15 | img_size,
16 | img_scale=(1.0, 1.0),
17 | img_ratio=(0.9, 1.0),
18 | drop_ratio=0.1,
19 | data_meta_paths=["./data/fahsion_meta.json"],
20 | sample_margin=30,
21 | ):
22 | super().__init__()
23 |
24 | self.img_size = img_size
25 | self.img_scale = img_scale
26 | self.img_ratio = img_ratio
27 | self.sample_margin = sample_margin
28 |
29 | # -----
30 | # vid_meta format:
31 | # [{'video_path': , 'kps_path': , 'other':},
32 | # {'video_path': , 'kps_path': , 'other':}]
33 | # -----
34 | vid_meta = []
35 | for data_meta_path in data_meta_paths:
36 | vid_meta.extend(json.load(open(data_meta_path, "r")))
37 | self.vid_meta = vid_meta
38 |
39 | self.clip_image_processor = CLIPImageProcessor()
40 |
41 | self.transform = transforms.Compose(
42 | [
43 | # transforms.RandomResizedCrop(
44 | # self.img_size,
45 | # scale=self.img_scale,
46 | # ratio=self.img_ratio,
47 | # interpolation=transforms.InterpolationMode.BILINEAR,
48 | # ),
49 | transforms.Resize(
50 | self.img_size,
51 | ),
52 | transforms.ToTensor(),
53 | transforms.Normalize([0.5], [0.5]),
54 | ]
55 | )
56 |
57 | self.cond_transform = transforms.Compose(
58 | [
59 | # transforms.RandomResizedCrop(
60 | # self.img_size,
61 | # scale=self.img_scale,
62 | # ratio=self.img_ratio,
63 | # interpolation=transforms.InterpolationMode.BILINEAR,
64 | # ),
65 | transforms.Resize(
66 | self.img_size,
67 | ),
68 | transforms.ToTensor(),
69 | ]
70 | )
71 |
72 | self.drop_ratio = drop_ratio
73 |
74 | def augmentation(self, image, transform, state=None):
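# Restoring a shared RNG state keeps any random transforms (e.g. the commented-out RandomResizedCrop) identical across the target image, pose image and reference image of one sample.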
75 | if state is not None:
76 | torch.set_rng_state(state)
77 | return transform(image)
78 |
79 | def __getitem__(self, index):
80 | video_meta = self.vid_meta[index]
81 | video_path = video_meta["video_path"]
82 | kps_path = video_meta["kps_path"]
83 |
84 | video_reader = VideoReader(video_path)
85 | kps_reader = VideoReader(kps_path)
86 |
87 | assert len(video_reader) == len(
88 | kps_reader
89 | ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
90 |
91 | video_length = len(video_reader)
92 |
93 | margin = min(self.sample_margin, video_length)
94 |
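# Sample a random reference frame, then a target frame at least `margin` frames away when possible; fall back to an unconstrained pick if the video is too short.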
95 | ref_img_idx = random.randint(0, video_length - 1)
96 | if ref_img_idx + margin < video_length:
97 | tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1)
98 | elif ref_img_idx - margin > 0:
99 | tgt_img_idx = random.randint(0, ref_img_idx - margin)
100 | else:
101 | tgt_img_idx = random.randint(0, video_length - 1)
102 |
103 | ref_img = video_reader[ref_img_idx]
104 | ref_img_pil = Image.fromarray(ref_img.asnumpy())
105 | tgt_img = video_reader[tgt_img_idx]
106 | tgt_img_pil = Image.fromarray(tgt_img.asnumpy())
107 |
108 | tgt_pose = kps_reader[tgt_img_idx]
109 | tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy())
110 |
111 | state = torch.get_rng_state()
112 | tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
113 | tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
114 | ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
115 | clip_image = self.clip_image_processor(
116 | images=ref_img_pil, return_tensors="pt"
117 | ).pixel_values[0]
118 |
119 | sample = dict(
120 | video_dir=video_path,
121 | img=tgt_img,
122 | tgt_pose=tgt_pose_img,
123 | ref_img=ref_img_vae,
124 | clip_images=clip_image,
125 | )
126 |
127 | return sample
128 |
129 | def __len__(self):
130 | return len(self.vid_meta)
131 |
--------------------------------------------------------------------------------
/musepose/dataset/dance_video.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from typing import List
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | import torchvision.transforms as transforms
9 | from decord import VideoReader
10 | from PIL import Image
11 | from torch.utils.data import Dataset
12 | from transformers import CLIPImageProcessor
13 |
14 |
15 | class HumanDanceVideoDataset(Dataset):
16 | def __init__(
17 | self,
18 | sample_rate,
19 | n_sample_frames,
20 | width,
21 | height,
22 | img_scale=(1.0, 1.0),
23 | img_ratio=(0.9, 1.0),
24 | drop_ratio=0.1,
25 | data_meta_paths=["./data/fashion_meta.json"],
26 | ):
27 | super().__init__()
28 | self.sample_rate = sample_rate
29 | self.n_sample_frames = n_sample_frames
30 | self.width = width
31 | self.height = height
32 | self.img_scale = img_scale
33 | self.img_ratio = img_ratio
34 |
35 | vid_meta = []
36 | for data_meta_path in data_meta_paths:
37 | vid_meta.extend(json.load(open(data_meta_path, "r")))
38 | self.vid_meta = vid_meta
39 |
40 | self.clip_image_processor = CLIPImageProcessor()
41 |
42 | self.pixel_transform = transforms.Compose(
43 | [
44 | # transforms.RandomResizedCrop(
45 | # (height, width),
46 | # scale=self.img_scale,
47 | # ratio=self.img_ratio,
48 | # interpolation=transforms.InterpolationMode.BILINEAR,
49 | # ),
50 | transforms.Resize(
51 | (height, width),
52 | ),
53 | transforms.ToTensor(),
54 | transforms.Normalize([0.5], [0.5]),
55 | ]
56 | )
57 |
58 | self.cond_transform = transforms.Compose(
59 | [
60 | # transforms.RandomResizedCrop(
61 | # (height, width),
62 | # scale=self.img_scale,
63 | # ratio=self.img_ratio,
64 | # interpolation=transforms.InterpolationMode.BILINEAR,
65 | # ),
66 | transforms.Resize(
67 | (height, width),
68 | ),
69 | transforms.ToTensor(),
70 | ]
71 | )
72 |
73 | self.drop_ratio = drop_ratio
74 |
75 | def augmentation(self, images, transform, state=None):
76 | if state is not None:
77 | torch.set_rng_state(state)
78 | if isinstance(images, List):
79 | transformed_images = [transform(img) for img in images]
80 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
81 | else:
82 | ret_tensor = transform(images) # (c, h, w)
83 | return ret_tensor
84 |
85 | def __getitem__(self, index):
86 | video_meta = self.vid_meta[index]
87 | video_path = video_meta["video_path"]
88 | kps_path = video_meta["kps_path"]
89 |
90 | video_reader = VideoReader(video_path)
91 | kps_reader = VideoReader(kps_path)
92 |
93 | assert len(video_reader) == len(
94 | kps_reader
95 | ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
96 |
97 | video_length = len(video_reader)
98 | video_fps = video_reader.get_avg_fps()
99 | # print("fps", video_fps)
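# Double the frame stride for high-fps (30-60 fps) videos so a sampled clip spans roughly the same duration as at <= 30 fps.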
100 | if video_fps > 30: # 30-60
101 | sample_rate = self.sample_rate*2
102 | else:
103 | sample_rate = self.sample_rate
104 |
105 |
106 | clip_length = min(
107 | video_length, (self.n_sample_frames - 1) * sample_rate + 1
108 | )
109 | start_idx = random.randint(0, video_length - clip_length)
110 | batch_index = np.linspace(
111 | start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
112 | ).tolist()
113 |
114 | # read frames and kps
115 | vid_pil_image_list = []
116 | pose_pil_image_list = []
117 | for index in batch_index:
118 | img = video_reader[index]
119 | vid_pil_image_list.append(Image.fromarray(img.asnumpy()))
120 | img = kps_reader[index]
121 | pose_pil_image_list.append(Image.fromarray(img.asnumpy()))
122 |
123 | ref_img_idx = random.randint(0, video_length - 1)
124 | ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy())
125 |
126 | # transform
127 | state = torch.get_rng_state()
128 | pixel_values_vid = self.augmentation(
129 | vid_pil_image_list, self.pixel_transform, state
130 | )
131 | pixel_values_pose = self.augmentation(
132 | pose_pil_image_list, self.cond_transform, state
133 | )
134 | pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
135 | clip_ref_img = self.clip_image_processor(
136 | images=ref_img, return_tensors="pt"
137 | ).pixel_values[0]
138 |
139 | sample = dict(
140 | video_dir=video_path,
141 | pixel_values_vid=pixel_values_vid,
142 | pixel_values_pose=pixel_values_pose,
143 | pixel_values_ref_img=pixel_values_ref_img,
144 | clip_ref_img=clip_ref_img,
145 | )
146 |
147 | return sample
148 |
149 | def __len__(self):
150 | return len(self.vid_meta)
151 |
--------------------------------------------------------------------------------
/musepose/models/attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2 |
3 | from typing import Any, Dict, Optional
4 |
5 | import torch
6 | from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, Attention, FeedForward, GatedSelfAttentionDense  # also import the classes used in the ada_norm_zero and gated-attention branches below
7 | from diffusers.models.embeddings import SinusoidalPositionalEmbedding
8 | from einops import rearrange
9 | from torch import nn
10 |
11 |
12 | class BasicTransformerBlock(nn.Module):
13 | r"""
14 | A basic Transformer block.
15 |
16 | Parameters:
17 | dim (`int`): The number of channels in the input and output.
18 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
19 | attention_head_dim (`int`): The number of channels in each head.
20 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
21 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
22 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
23 | num_embeds_ada_norm (:
24 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
25 | attention_bias (:
26 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
27 | only_cross_attention (`bool`, *optional*):
28 | Whether to use only cross-attention layers. In this case two cross attention layers are used.
29 | double_self_attention (`bool`, *optional*):
30 | Whether to use two self-attention layers. In this case no cross attention layers are used.
31 | upcast_attention (`bool`, *optional*):
32 | Whether to upcast the attention computation to float32. This is useful for mixed precision training.
33 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
34 | Whether to use learnable elementwise affine parameters for normalization.
35 | norm_type (`str`, *optional*, defaults to `"layer_norm"`):
36 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
37 | final_dropout (`bool` *optional*, defaults to False):
38 | Whether to apply a final dropout after the last feed-forward layer.
39 | attention_type (`str`, *optional*, defaults to `"default"`):
40 | The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
41 | positional_embeddings (`str`, *optional*, defaults to `None`):
42 | The type of positional embeddings to apply to.
43 | num_positional_embeddings (`int`, *optional*, defaults to `None`):
44 | The maximum number of positional embeddings to apply.
45 | """
46 |
47 | def __init__(
48 | self,
49 | dim: int,
50 | num_attention_heads: int,
51 | attention_head_dim: int,
52 | dropout=0.0,
53 | cross_attention_dim: Optional[int] = None,
54 | activation_fn: str = "geglu",
55 | num_embeds_ada_norm: Optional[int] = None,
56 | attention_bias: bool = False,
57 | only_cross_attention: bool = False,
58 | double_self_attention: bool = False,
59 | upcast_attention: bool = False,
60 | norm_elementwise_affine: bool = True,
61 | norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
62 | norm_eps: float = 1e-5,
63 | final_dropout: bool = False,
64 | attention_type: str = "default",
65 | positional_embeddings: Optional[str] = None,
66 | num_positional_embeddings: Optional[int] = None,
67 | ):
68 | super().__init__()
69 | self.only_cross_attention = only_cross_attention
70 |
71 | self.use_ada_layer_norm_zero = (
72 | num_embeds_ada_norm is not None
73 | ) and norm_type == "ada_norm_zero"
74 | self.use_ada_layer_norm = (
75 | num_embeds_ada_norm is not None
76 | ) and norm_type == "ada_norm"
77 | self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
78 | self.use_layer_norm = norm_type == "layer_norm"
79 |
80 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
81 | raise ValueError(
82 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
83 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
84 | )
85 |
86 | if positional_embeddings and (num_positional_embeddings is None):
87 | raise ValueError(
88 |                 "If `positional_embeddings` is defined, `num_positional_embeddings` must also be defined."
89 | )
90 |
91 | if positional_embeddings == "sinusoidal":
92 | self.pos_embed = SinusoidalPositionalEmbedding(
93 | dim, max_seq_length=num_positional_embeddings
94 | )
95 | else:
96 | self.pos_embed = None
97 |
98 | # Define 3 blocks. Each block has its own normalization layer.
99 | # 1. Self-Attn
100 | if self.use_ada_layer_norm:
101 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
102 | elif self.use_ada_layer_norm_zero:
103 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
104 | else:
105 | self.norm1 = nn.LayerNorm(
106 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
107 | )
108 |
109 | self.attn1 = Attention(
110 | query_dim=dim,
111 | heads=num_attention_heads,
112 | dim_head=attention_head_dim,
113 | dropout=dropout,
114 | bias=attention_bias,
115 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
116 | upcast_attention=upcast_attention,
117 | )
118 |
119 | # 2. Cross-Attn
120 | if cross_attention_dim is not None or double_self_attention:
121 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
122 |             # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
123 | # the second cross attention block.
124 | self.norm2 = (
125 | AdaLayerNorm(dim, num_embeds_ada_norm)
126 | if self.use_ada_layer_norm
127 | else nn.LayerNorm(
128 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
129 | )
130 | )
131 | self.attn2 = Attention(
132 | query_dim=dim,
133 | cross_attention_dim=cross_attention_dim
134 | if not double_self_attention
135 | else None,
136 | heads=num_attention_heads,
137 | dim_head=attention_head_dim,
138 | dropout=dropout,
139 | bias=attention_bias,
140 | upcast_attention=upcast_attention,
141 | ) # is self-attn if encoder_hidden_states is none
142 | else:
143 | self.norm2 = None
144 | self.attn2 = None
145 |
146 | # 3. Feed-forward
147 | if not self.use_ada_layer_norm_single:
148 | self.norm3 = nn.LayerNorm(
149 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
150 | )
151 |
152 | self.ff = FeedForward(
153 | dim,
154 | dropout=dropout,
155 | activation_fn=activation_fn,
156 | final_dropout=final_dropout,
157 | )
158 |
159 | # 4. Fuser
160 | if attention_type == "gated" or attention_type == "gated-text-image":
161 | self.fuser = GatedSelfAttentionDense(
162 | dim, cross_attention_dim, num_attention_heads, attention_head_dim
163 | )
164 |
165 | # 5. Scale-shift for PixArt-Alpha.
166 | if self.use_ada_layer_norm_single:
167 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
168 |
169 | # let chunk size default to None
170 | self._chunk_size = None
171 | self._chunk_dim = 0
172 |
173 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
174 |         # Store the feed-forward chunking configuration: the chunk size and the dimension to chunk along.
175 | self._chunk_size = chunk_size
176 | self._chunk_dim = dim
177 |
178 | def forward(
179 | self,
180 | hidden_states: torch.FloatTensor,
181 | attention_mask: Optional[torch.FloatTensor] = None,
182 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
183 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
184 | timestep: Optional[torch.LongTensor] = None,
185 | cross_attention_kwargs: Dict[str, Any] = None,
186 | class_labels: Optional[torch.LongTensor] = None,
187 | ) -> torch.FloatTensor:
188 | # Notice that normalization is always applied before the real computation in the following blocks.
189 | # 0. Self-Attention
190 | batch_size = hidden_states.shape[0]
191 |
192 | if self.use_ada_layer_norm:
193 | norm_hidden_states = self.norm1(hidden_states, timestep)
194 | elif self.use_ada_layer_norm_zero:
195 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
196 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
197 | )
198 | elif self.use_layer_norm:
199 | norm_hidden_states = self.norm1(hidden_states)
200 | elif self.use_ada_layer_norm_single:
201 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
202 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
203 | ).chunk(6, dim=1)
204 | norm_hidden_states = self.norm1(hidden_states)
205 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
206 | norm_hidden_states = norm_hidden_states.squeeze(1)
207 | else:
208 | raise ValueError("Incorrect norm used")
209 |
210 | if self.pos_embed is not None:
211 | norm_hidden_states = self.pos_embed(norm_hidden_states)
212 |
213 | # 1. Retrieve lora scale.
214 | lora_scale = (
215 | cross_attention_kwargs.get("scale", 1.0)
216 | if cross_attention_kwargs is not None
217 | else 1.0
218 | )
219 |
220 | # 2. Prepare GLIGEN inputs
221 | cross_attention_kwargs = (
222 | cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
223 | )
224 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
225 |
226 | attn_output = self.attn1(
227 | norm_hidden_states,
228 | encoder_hidden_states=encoder_hidden_states
229 | if self.only_cross_attention
230 | else None,
231 | attention_mask=attention_mask,
232 | **cross_attention_kwargs,
233 | )
234 | if self.use_ada_layer_norm_zero:
235 | attn_output = gate_msa.unsqueeze(1) * attn_output
236 | elif self.use_ada_layer_norm_single:
237 | attn_output = gate_msa * attn_output
238 |
239 | hidden_states = attn_output + hidden_states
240 | if hidden_states.ndim == 4:
241 | hidden_states = hidden_states.squeeze(1)
242 |
243 | # 2.5 GLIGEN Control
244 | if gligen_kwargs is not None:
245 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
246 |
247 | # 3. Cross-Attention
248 | if self.attn2 is not None:
249 | if self.use_ada_layer_norm:
250 | norm_hidden_states = self.norm2(hidden_states, timestep)
251 | elif self.use_ada_layer_norm_zero or self.use_layer_norm:
252 | norm_hidden_states = self.norm2(hidden_states)
253 | elif self.use_ada_layer_norm_single:
254 | # For PixArt norm2 isn't applied here:
255 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
256 | norm_hidden_states = hidden_states
257 | else:
258 | raise ValueError("Incorrect norm")
259 |
260 | if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
261 | norm_hidden_states = self.pos_embed(norm_hidden_states)
262 |
263 | attn_output = self.attn2(
264 | norm_hidden_states,
265 | encoder_hidden_states=encoder_hidden_states,
266 | attention_mask=encoder_attention_mask,
267 | **cross_attention_kwargs,
268 | )
269 | hidden_states = attn_output + hidden_states
270 |
271 | # 4. Feed-forward
272 | if not self.use_ada_layer_norm_single:
273 | norm_hidden_states = self.norm3(hidden_states)
274 |
275 | if self.use_ada_layer_norm_zero:
276 | norm_hidden_states = (
277 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
278 | )
279 |
280 | if self.use_ada_layer_norm_single:
281 | norm_hidden_states = self.norm2(hidden_states)
282 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
283 |
284 | ff_output = self.ff(norm_hidden_states, scale=lora_scale)
285 |
286 | if self.use_ada_layer_norm_zero:
287 | ff_output = gate_mlp.unsqueeze(1) * ff_output
288 | elif self.use_ada_layer_norm_single:
289 | ff_output = gate_mlp * ff_output
290 |
291 | hidden_states = ff_output + hidden_states
292 | if hidden_states.ndim == 4:
293 | hidden_states = hidden_states.squeeze(1)
294 |
295 | return hidden_states
296 |
297 |
298 | class TemporalBasicTransformerBlock(nn.Module):
299 | def __init__(
300 | self,
301 | dim: int,
302 | num_attention_heads: int,
303 | attention_head_dim: int,
304 | dropout=0.0,
305 | cross_attention_dim: Optional[int] = None,
306 | activation_fn: str = "geglu",
307 | num_embeds_ada_norm: Optional[int] = None,
308 | attention_bias: bool = False,
309 | only_cross_attention: bool = False,
310 | upcast_attention: bool = False,
311 | unet_use_cross_frame_attention=None,
312 | unet_use_temporal_attention=None,
313 | ):
314 | super().__init__()
315 | self.only_cross_attention = only_cross_attention
316 | self.use_ada_layer_norm = num_embeds_ada_norm is not None
317 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
318 | self.unet_use_temporal_attention = unet_use_temporal_attention
319 |
320 | # SC-Attn
321 | self.attn1 = Attention(
322 | query_dim=dim,
323 | heads=num_attention_heads,
324 | dim_head=attention_head_dim,
325 | dropout=dropout,
326 | bias=attention_bias,
327 | upcast_attention=upcast_attention,
328 | )
329 | self.norm1 = (
330 | AdaLayerNorm(dim, num_embeds_ada_norm)
331 | if self.use_ada_layer_norm
332 | else nn.LayerNorm(dim)
333 | )
334 |
335 | # Cross-Attn
336 | if cross_attention_dim is not None:
337 | self.attn2 = Attention(
338 | query_dim=dim,
339 | cross_attention_dim=cross_attention_dim,
340 | heads=num_attention_heads,
341 | dim_head=attention_head_dim,
342 | dropout=dropout,
343 | bias=attention_bias,
344 | upcast_attention=upcast_attention,
345 | )
346 | else:
347 | self.attn2 = None
348 |
349 | if cross_attention_dim is not None:
350 | self.norm2 = (
351 | AdaLayerNorm(dim, num_embeds_ada_norm)
352 | if self.use_ada_layer_norm
353 | else nn.LayerNorm(dim)
354 | )
355 | else:
356 | self.norm2 = None
357 |
358 | # Feed-forward
359 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
360 | self.norm3 = nn.LayerNorm(dim)
361 | self.use_ada_layer_norm_zero = False
362 |
363 | # Temp-Attn
364 | assert unet_use_temporal_attention is not None
365 | if unet_use_temporal_attention:
366 | self.attn_temp = Attention(
367 | query_dim=dim,
368 | heads=num_attention_heads,
369 | dim_head=attention_head_dim,
370 | dropout=dropout,
371 | bias=attention_bias,
372 | upcast_attention=upcast_attention,
373 | )
374 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
375 | self.norm_temp = (
376 | AdaLayerNorm(dim, num_embeds_ada_norm)
377 | if self.use_ada_layer_norm
378 | else nn.LayerNorm(dim)
379 | )
380 |
381 | def forward(
382 | self,
383 | hidden_states,
384 | encoder_hidden_states=None,
385 | timestep=None,
386 | attention_mask=None,
387 | video_length=None,
388 | ):
389 | norm_hidden_states = (
390 | self.norm1(hidden_states, timestep)
391 | if self.use_ada_layer_norm
392 | else self.norm1(hidden_states)
393 | )
394 |
395 | if self.unet_use_cross_frame_attention:
396 | hidden_states = (
397 | self.attn1(
398 | norm_hidden_states,
399 | attention_mask=attention_mask,
400 | video_length=video_length,
401 | )
402 | + hidden_states
403 | )
404 | else:
405 | hidden_states = (
406 | self.attn1(norm_hidden_states, attention_mask=attention_mask)
407 | + hidden_states
408 | )
409 |
410 | if self.attn2 is not None:
411 | # Cross-Attention
412 | norm_hidden_states = (
413 | self.norm2(hidden_states, timestep)
414 | if self.use_ada_layer_norm
415 | else self.norm2(hidden_states)
416 | )
417 | hidden_states = (
418 | self.attn2(
419 | norm_hidden_states,
420 | encoder_hidden_states=encoder_hidden_states,
421 | attention_mask=attention_mask,
422 | )
423 | + hidden_states
424 | )
425 |
426 | # Feed-forward
427 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
428 |
429 | # Temporal-Attention
430 | if self.unet_use_temporal_attention:
431 | d = hidden_states.shape[1]
432 | hidden_states = rearrange(
433 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
434 | )
435 | norm_hidden_states = (
436 | self.norm_temp(hidden_states, timestep)
437 | if self.use_ada_layer_norm
438 | else self.norm_temp(hidden_states)
439 | )
440 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
441 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
442 |
443 | return hidden_states
444 |
--------------------------------------------------------------------------------
/musepose/models/motion_module.py:
--------------------------------------------------------------------------------
1 | # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2 | import math
3 | from dataclasses import dataclass
4 | from typing import Callable, Optional
5 |
6 | import torch
7 | from diffusers.models.attention import FeedForward
8 | from diffusers.models.attention_processor import Attention, AttnProcessor
9 | from diffusers.utils import BaseOutput
10 | from diffusers.utils.import_utils import is_xformers_available
11 | from einops import rearrange, repeat
12 | from torch import nn
13 |
14 |
15 | def zero_module(module):
16 | # Zero out the parameters of a module and return it.
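    # Zero-initialising the output projection of a newly added branch (e.g. the temporal
    # transformer's proj_out in this file, or the pose guider's conv_out) makes that branch
    # contribute nothing at the start of training, so the pretrained UNet is left undisturbed.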
17 | for p in module.parameters():
18 | p.detach().zero_()
19 | return module
20 |
21 |
22 | @dataclass
23 | class TemporalTransformer3DModelOutput(BaseOutput):
24 | sample: torch.FloatTensor
25 |
26 |
27 | if is_xformers_available():
28 | import xformers
29 | import xformers.ops
30 | else:
31 | xformers = None
32 |
33 |
34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35 | if motion_module_type == "Vanilla":
36 | return VanillaTemporalModule(
37 | in_channels=in_channels,
38 | **motion_module_kwargs,
39 | )
40 | else:
41 | raise ValueError
42 |
43 |
44 | class VanillaTemporalModule(nn.Module):
45 | def __init__(
46 | self,
47 | in_channels,
48 | num_attention_heads=8,
49 | num_transformer_block=2,
50 | attention_block_types=("Temporal_Self", "Temporal_Self"),
51 | cross_frame_attention_mode=None,
52 | temporal_position_encoding=False,
53 | temporal_position_encoding_max_len=24,
54 | temporal_attention_dim_div=1,
55 | zero_initialize=True,
56 | ):
57 | super().__init__()
58 |
59 | self.temporal_transformer = TemporalTransformer3DModel(
60 | in_channels=in_channels,
61 | num_attention_heads=num_attention_heads,
62 | attention_head_dim=in_channels
63 | // num_attention_heads
64 | // temporal_attention_dim_div,
65 | num_layers=num_transformer_block,
66 | attention_block_types=attention_block_types,
67 | cross_frame_attention_mode=cross_frame_attention_mode,
68 | temporal_position_encoding=temporal_position_encoding,
69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70 | )
71 |
72 | if zero_initialize:
73 | self.temporal_transformer.proj_out = zero_module(
74 | self.temporal_transformer.proj_out
75 | )
76 |
77 | def forward(
78 | self,
79 | input_tensor,
80 | temb,
81 | encoder_hidden_states,
82 | attention_mask=None,
83 | anchor_frame_idx=None,
84 | ):
85 | hidden_states = input_tensor
86 | hidden_states = self.temporal_transformer(
87 | hidden_states, encoder_hidden_states, attention_mask
88 | )
89 |
90 | output = hidden_states
91 | return output
92 |
93 |
94 | class TemporalTransformer3DModel(nn.Module):
95 | def __init__(
96 | self,
97 | in_channels,
98 | num_attention_heads,
99 | attention_head_dim,
100 | num_layers,
101 | attention_block_types=(
102 | "Temporal_Self",
103 | "Temporal_Self",
104 | ),
105 | dropout=0.0,
106 | norm_num_groups=32,
107 | cross_attention_dim=768,
108 | activation_fn="geglu",
109 | attention_bias=False,
110 | upcast_attention=False,
111 | cross_frame_attention_mode=None,
112 | temporal_position_encoding=False,
113 | temporal_position_encoding_max_len=24,
114 | ):
115 | super().__init__()
116 |
117 | inner_dim = num_attention_heads * attention_head_dim
118 |
119 | self.norm = torch.nn.GroupNorm(
120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121 | )
122 | self.proj_in = nn.Linear(in_channels, inner_dim)
123 |
124 | self.transformer_blocks = nn.ModuleList(
125 | [
126 | TemporalTransformerBlock(
127 | dim=inner_dim,
128 | num_attention_heads=num_attention_heads,
129 | attention_head_dim=attention_head_dim,
130 | attention_block_types=attention_block_types,
131 | dropout=dropout,
132 | norm_num_groups=norm_num_groups,
133 | cross_attention_dim=cross_attention_dim,
134 | activation_fn=activation_fn,
135 | attention_bias=attention_bias,
136 | upcast_attention=upcast_attention,
137 | cross_frame_attention_mode=cross_frame_attention_mode,
138 | temporal_position_encoding=temporal_position_encoding,
139 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140 | )
141 | for d in range(num_layers)
142 | ]
143 | )
144 | self.proj_out = nn.Linear(inner_dim, in_channels)
145 |
146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147 | assert (
148 | hidden_states.dim() == 5
149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150 | video_length = hidden_states.shape[2]
151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152 |
153 | batch, channel, height, weight = hidden_states.shape
154 | residual = hidden_states
155 |
156 | hidden_states = self.norm(hidden_states)
157 | inner_dim = hidden_states.shape[1]
158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159 | batch, height * weight, inner_dim
160 | )
161 | hidden_states = self.proj_in(hidden_states)
162 |
163 | # Transformer Blocks
164 | for block in self.transformer_blocks:
165 | hidden_states = block(
166 | hidden_states,
167 | encoder_hidden_states=encoder_hidden_states,
168 | video_length=video_length,
169 | )
170 |
171 | # output
172 | hidden_states = self.proj_out(hidden_states)
173 | hidden_states = (
174 | hidden_states.reshape(batch, height, weight, inner_dim)
175 | .permute(0, 3, 1, 2)
176 | .contiguous()
177 | )
178 |
179 | output = hidden_states + residual
180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181 |
182 | return output
183 |
184 |
185 | class TemporalTransformerBlock(nn.Module):
186 | def __init__(
187 | self,
188 | dim,
189 | num_attention_heads,
190 | attention_head_dim,
191 | attention_block_types=(
192 | "Temporal_Self",
193 | "Temporal_Self",
194 | ),
195 | dropout=0.0,
196 | norm_num_groups=32,
197 | cross_attention_dim=768,
198 | activation_fn="geglu",
199 | attention_bias=False,
200 | upcast_attention=False,
201 | cross_frame_attention_mode=None,
202 | temporal_position_encoding=False,
203 | temporal_position_encoding_max_len=24,
204 | ):
205 | super().__init__()
206 |
207 | attention_blocks = []
208 | norms = []
209 |
210 | for block_name in attention_block_types:
211 | attention_blocks.append(
212 | VersatileAttention(
213 | attention_mode=block_name.split("_")[0],
214 | cross_attention_dim=cross_attention_dim
215 | if block_name.endswith("_Cross")
216 | else None,
217 | query_dim=dim,
218 | heads=num_attention_heads,
219 | dim_head=attention_head_dim,
220 | dropout=dropout,
221 | bias=attention_bias,
222 | upcast_attention=upcast_attention,
223 | cross_frame_attention_mode=cross_frame_attention_mode,
224 | temporal_position_encoding=temporal_position_encoding,
225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226 | )
227 | )
228 | norms.append(nn.LayerNorm(dim))
229 |
230 | self.attention_blocks = nn.ModuleList(attention_blocks)
231 | self.norms = nn.ModuleList(norms)
232 |
233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234 | self.ff_norm = nn.LayerNorm(dim)
235 |
236 | def forward(
237 | self,
238 | hidden_states,
239 | encoder_hidden_states=None,
240 | attention_mask=None,
241 | video_length=None,
242 | ):
243 | for attention_block, norm in zip(self.attention_blocks, self.norms):
244 | norm_hidden_states = norm(hidden_states)
245 | hidden_states = (
246 | attention_block(
247 | norm_hidden_states,
248 | encoder_hidden_states=encoder_hidden_states
249 | if attention_block.is_cross_attention
250 | else None,
251 | video_length=video_length,
252 | )
253 | + hidden_states
254 | )
255 |
256 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257 |
258 | output = hidden_states
259 | return output
260 |
261 |
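# Standard sinusoidal positional encoding (as in "Attention Is All You Need"), added along the
# frame axis so that temporal attention can distinguish frame order; `max_len` bounds the number
# of frames a context window may contain.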
262 | class PositionalEncoding(nn.Module):
263 | def __init__(self, d_model, dropout=0.0, max_len=24):
264 | super().__init__()
265 | self.dropout = nn.Dropout(p=dropout)
266 | position = torch.arange(max_len).unsqueeze(1)
267 | div_term = torch.exp(
268 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269 | )
270 | pe = torch.zeros(1, max_len, d_model)
271 | pe[0, :, 0::2] = torch.sin(position * div_term)
272 | pe[0, :, 1::2] = torch.cos(position * div_term)
273 | self.register_buffer("pe", pe)
274 |
275 | def forward(self, x):
276 | x = x + self.pe[:, : x.size(1)]
277 | return self.dropout(x)
278 |
279 |
280 | class VersatileAttention(Attention):
281 | def __init__(
282 | self,
283 | attention_mode=None,
284 | cross_frame_attention_mode=None,
285 | temporal_position_encoding=False,
286 | temporal_position_encoding_max_len=24,
287 | *args,
288 | **kwargs,
289 | ):
290 | super().__init__(*args, **kwargs)
291 | assert attention_mode == "Temporal"
292 |
293 | self.attention_mode = attention_mode
294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295 |
296 | self.pos_encoder = (
297 | PositionalEncoding(
298 | kwargs["query_dim"],
299 | dropout=0.0,
300 | max_len=temporal_position_encoding_max_len,
301 | )
302 | if (temporal_position_encoding and attention_mode == "Temporal")
303 | else None
304 | )
305 |
306 | def extra_repr(self):
307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308 |
309 | def set_use_memory_efficient_attention_xformers(
310 | self,
311 | use_memory_efficient_attention_xformers: bool,
312 | attention_op: Optional[Callable] = None,
313 | ):
314 | if use_memory_efficient_attention_xformers:
315 | if not is_xformers_available():
316 | raise ModuleNotFoundError(
317 | (
318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319 | " xformers"
320 | ),
321 | name="xformers",
322 | )
323 | elif not torch.cuda.is_available():
324 | raise ValueError(
325 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326 | " only available for GPU "
327 | )
328 | else:
329 | try:
330 | # Make sure we can run the memory efficient attention
331 | _ = xformers.ops.memory_efficient_attention(
332 | torch.randn((1, 2, 40), device="cuda"),
333 | torch.randn((1, 2, 40), device="cuda"),
334 | torch.randn((1, 2, 40), device="cuda"),
335 | )
336 | except Exception as e:
337 | raise e
338 |
339 |             # XFormersAttnProcessor corrupts video generation; it was meant for PyTorch 1.13.
340 |             # PyTorch 2.0.1's AttnProcessor behaves the same as XFormersAttnProcessor did in PyTorch 1.13,
341 |             # so XFormersAttnProcessor is not needed here.
342 | # processor = XFormersAttnProcessor(
343 | # attention_op=attention_op,
344 | # )
345 | processor = AttnProcessor()
346 | else:
347 | processor = AttnProcessor()
348 |
349 | self.set_processor(processor)
350 |
351 | def forward(
352 | self,
353 | hidden_states,
354 | encoder_hidden_states=None,
355 | attention_mask=None,
356 | video_length=None,
357 | **cross_attention_kwargs,
358 | ):
359 | if self.attention_mode == "Temporal":
360 | d = hidden_states.shape[1] # d means HxW
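            # Fold the spatial positions into the batch dimension so attention below runs
            # across the f frames at each spatial location independently.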
361 | hidden_states = rearrange(
362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
363 | )
364 |
365 | if self.pos_encoder is not None:
366 | hidden_states = self.pos_encoder(hidden_states)
367 |
368 | encoder_hidden_states = (
369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370 | if encoder_hidden_states is not None
371 | else encoder_hidden_states
372 | )
373 |
374 | else:
375 | raise NotImplementedError
376 |
377 | hidden_states = self.processor(
378 | self,
379 | hidden_states,
380 | encoder_hidden_states=encoder_hidden_states,
381 | attention_mask=attention_mask,
382 | **cross_attention_kwargs,
383 | )
384 |
385 | if self.attention_mode == "Temporal":
386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387 |
388 | return hidden_states
389 |
--------------------------------------------------------------------------------
/musepose/models/mutual_self_attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2 | from typing import Any, Dict, Optional
3 |
4 | import torch
5 | from einops import rearrange
6 |
7 | from musepose.models.attention import TemporalBasicTransformerBlock
8 |
9 | from .attention import BasicTransformerBlock
10 |
11 |
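# Depth-first traversal that returns a module together with all of its sub-modules;
# used below to collect every transformer block inside a UNet.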
12 | def torch_dfs(model: torch.nn.Module):
13 | result = [model]
14 | for child in model.children():
15 | result += torch_dfs(child)
16 | return result
17 |
18 |
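# Reference-attention controller: in "write" mode the reference UNet stores the normalized
# self-attention input of each transformer block in that block's `bank`; in "read" mode the
# denoising UNet concatenates the banked reference features to its own tokens and uses the
# result as keys/values for self-attention, so generated frames inherit the reference image's
# appearance.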
19 | class ReferenceAttentionControl:
20 | def __init__(
21 | self,
22 | unet,
23 | mode="write",
24 | do_classifier_free_guidance=False,
25 | attention_auto_machine_weight=float("inf"),
26 | gn_auto_machine_weight=1.0,
27 | style_fidelity=1.0,
28 | reference_attn=True,
29 | reference_adain=False,
30 | fusion_blocks="midup",
31 | batch_size=1,
32 | ) -> None:
33 |         # Modify self-attention and group norm
34 | self.unet = unet
35 | assert mode in ["read", "write"]
36 | assert fusion_blocks in ["midup", "full"]
37 | self.reference_attn = reference_attn
38 | self.reference_adain = reference_adain
39 | self.fusion_blocks = fusion_blocks
40 | self.register_reference_hooks(
41 | mode,
42 | do_classifier_free_guidance,
43 | attention_auto_machine_weight,
44 | gn_auto_machine_weight,
45 | style_fidelity,
46 | reference_attn,
47 | reference_adain,
48 | fusion_blocks,
49 | batch_size=batch_size,
50 | )
51 |
52 | def register_reference_hooks(
53 | self,
54 | mode,
55 | do_classifier_free_guidance,
56 | attention_auto_machine_weight,
57 | gn_auto_machine_weight,
58 | style_fidelity,
59 | reference_attn,
60 | reference_adain,
61 | dtype=torch.float16,
62 | batch_size=1,
63 | num_images_per_prompt=1,
64 | device=torch.device("cpu"),
65 | fusion_blocks="midup",
66 | ):
67 | MODE = mode
68 | do_classifier_free_guidance = do_classifier_free_guidance
69 | attention_auto_machine_weight = attention_auto_machine_weight
70 | gn_auto_machine_weight = gn_auto_machine_weight
71 | style_fidelity = style_fidelity
72 | reference_attn = reference_attn
73 | reference_adain = reference_adain
74 | fusion_blocks = fusion_blocks
75 | num_images_per_prompt = num_images_per_prompt
76 | dtype = dtype
77 | if do_classifier_free_guidance:
78 | uc_mask = (
79 | torch.Tensor(
80 | [1] * batch_size * num_images_per_prompt * 16
81 | + [0] * batch_size * num_images_per_prompt * 16
82 | )
83 | .to(device)
84 | .bool()
85 | )
86 | else:
87 | uc_mask = (
88 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89 | .to(device)
90 | .bool()
91 | )
92 |
93 | def hacked_basic_transformer_inner_forward(
94 | self,
95 | hidden_states: torch.FloatTensor,
96 | attention_mask: Optional[torch.FloatTensor] = None,
97 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
98 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
99 | timestep: Optional[torch.LongTensor] = None,
100 | cross_attention_kwargs: Dict[str, Any] = None,
101 | class_labels: Optional[torch.LongTensor] = None,
102 | video_length=None,
103 | ):
104 | if self.use_ada_layer_norm: # False
105 | norm_hidden_states = self.norm1(hidden_states, timestep)
106 | elif self.use_ada_layer_norm_zero:
107 | (
108 | norm_hidden_states,
109 | gate_msa,
110 | shift_mlp,
111 | scale_mlp,
112 | gate_mlp,
113 | ) = self.norm1(
114 | hidden_states,
115 | timestep,
116 | class_labels,
117 | hidden_dtype=hidden_states.dtype,
118 | )
119 | else:
120 | norm_hidden_states = self.norm1(hidden_states)
121 |
122 | # 1. Self-Attention
123 | # self.only_cross_attention = False
124 | cross_attention_kwargs = (
125 | cross_attention_kwargs if cross_attention_kwargs is not None else {}
126 | )
127 | if self.only_cross_attention:
128 | attn_output = self.attn1(
129 | norm_hidden_states,
130 | encoder_hidden_states=encoder_hidden_states
131 | if self.only_cross_attention
132 | else None,
133 | attention_mask=attention_mask,
134 | **cross_attention_kwargs,
135 | )
136 | else:
137 | if MODE == "write":
138 | self.bank.append(norm_hidden_states.clone())
139 | attn_output = self.attn1(
140 | norm_hidden_states,
141 | encoder_hidden_states=encoder_hidden_states
142 | if self.only_cross_attention
143 | else None,
144 | attention_mask=attention_mask,
145 | **cross_attention_kwargs,
146 | )
147 | if MODE == "read":
148 | bank_fea = [
149 | rearrange(
150 | d.unsqueeze(1).repeat(1, video_length, 1, 1),
151 | "b t l c -> (b t) l c",
152 | )
153 | for d in self.bank
154 | ]
155 | modify_norm_hidden_states = torch.cat(
156 | [norm_hidden_states] + bank_fea, dim=1
157 | )
158 | hidden_states_uc = (
159 | self.attn1(
160 | norm_hidden_states,
161 | encoder_hidden_states=modify_norm_hidden_states,
162 | attention_mask=attention_mask,
163 | )
164 | + hidden_states
165 | )
166 | if do_classifier_free_guidance:
167 | hidden_states_c = hidden_states_uc.clone()
168 | _uc_mask = uc_mask.clone()
169 | if hidden_states.shape[0] != _uc_mask.shape[0]:
170 | _uc_mask = (
171 | torch.Tensor(
172 | [1] * (hidden_states.shape[0] // 2)
173 | + [0] * (hidden_states.shape[0] // 2)
174 | )
175 | .to(device)
176 | .bool()
177 | )
178 | hidden_states_c[_uc_mask] = (
179 | self.attn1(
180 | norm_hidden_states[_uc_mask],
181 | encoder_hidden_states=norm_hidden_states[_uc_mask],
182 | attention_mask=attention_mask,
183 | )
184 | + hidden_states[_uc_mask]
185 | )
186 | hidden_states = hidden_states_c.clone()
187 | else:
188 | hidden_states = hidden_states_uc
189 |
190 | # self.bank.clear()
191 | if self.attn2 is not None:
192 | # Cross-Attention
193 | norm_hidden_states = (
194 | self.norm2(hidden_states, timestep)
195 | if self.use_ada_layer_norm
196 | else self.norm2(hidden_states)
197 | )
198 | hidden_states = (
199 | self.attn2(
200 | norm_hidden_states,
201 | encoder_hidden_states=encoder_hidden_states,
202 | attention_mask=attention_mask,
203 | )
204 | + hidden_states
205 | )
206 |
207 | # Feed-forward
208 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
209 |
210 | # Temporal-Attention
211 | if self.unet_use_temporal_attention:
212 | d = hidden_states.shape[1]
213 | hidden_states = rearrange(
214 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
215 | )
216 | norm_hidden_states = (
217 | self.norm_temp(hidden_states, timestep)
218 | if self.use_ada_layer_norm
219 | else self.norm_temp(hidden_states)
220 | )
221 | hidden_states = (
222 | self.attn_temp(norm_hidden_states) + hidden_states
223 | )
224 | hidden_states = rearrange(
225 | hidden_states, "(b d) f c -> (b f) d c", d=d
226 | )
227 |
228 | return hidden_states
229 |
230 | if self.use_ada_layer_norm_zero:
231 | attn_output = gate_msa.unsqueeze(1) * attn_output
232 | hidden_states = attn_output + hidden_states
233 |
234 | if self.attn2 is not None:
235 | norm_hidden_states = (
236 | self.norm2(hidden_states, timestep)
237 | if self.use_ada_layer_norm
238 | else self.norm2(hidden_states)
239 | )
240 |
241 | # 2. Cross-Attention
242 | attn_output = self.attn2(
243 | norm_hidden_states,
244 | encoder_hidden_states=encoder_hidden_states,
245 | attention_mask=encoder_attention_mask,
246 | **cross_attention_kwargs,
247 | )
248 | hidden_states = attn_output + hidden_states
249 |
250 | # 3. Feed-forward
251 | norm_hidden_states = self.norm3(hidden_states)
252 |
253 | if self.use_ada_layer_norm_zero:
254 | norm_hidden_states = (
255 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
256 | )
257 |
258 | ff_output = self.ff(norm_hidden_states)
259 |
260 | if self.use_ada_layer_norm_zero:
261 | ff_output = gate_mlp.unsqueeze(1) * ff_output
262 |
263 | hidden_states = ff_output + hidden_states
264 |
265 | return hidden_states
266 |
267 | if self.reference_attn:
268 | if self.fusion_blocks == "midup":
269 | attn_modules = [
270 | module
271 | for module in (
272 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
273 | )
274 | if isinstance(module, BasicTransformerBlock)
275 | or isinstance(module, TemporalBasicTransformerBlock)
276 | ]
277 | elif self.fusion_blocks == "full":
278 | attn_modules = [
279 | module
280 | for module in torch_dfs(self.unet)
281 | if isinstance(module, BasicTransformerBlock)
282 | or isinstance(module, TemporalBasicTransformerBlock)
283 | ]
284 | attn_modules = sorted(
285 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
286 | )
287 |
288 | for i, module in enumerate(attn_modules):
289 | module._original_inner_forward = module.forward
290 | if isinstance(module, BasicTransformerBlock):
291 | module.forward = hacked_basic_transformer_inner_forward.__get__(
292 | module, BasicTransformerBlock
293 | )
294 | if isinstance(module, TemporalBasicTransformerBlock):
295 | module.forward = hacked_basic_transformer_inner_forward.__get__(
296 | module, TemporalBasicTransformerBlock
297 | )
298 |
299 | module.bank = []
300 | module.attn_weight = float(i) / float(len(attn_modules))
301 |
302 | def update(self, writer, dtype=torch.float16):
303 | if self.reference_attn:
304 | if self.fusion_blocks == "midup":
305 | reader_attn_modules = [
306 | module
307 | for module in (
308 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
309 | )
310 | if isinstance(module, TemporalBasicTransformerBlock)
311 | ]
312 | writer_attn_modules = [
313 | module
314 | for module in (
315 | torch_dfs(writer.unet.mid_block)
316 | + torch_dfs(writer.unet.up_blocks)
317 | )
318 | if isinstance(module, BasicTransformerBlock)
319 | ]
320 | elif self.fusion_blocks == "full":
321 | reader_attn_modules = [
322 | module
323 | for module in torch_dfs(self.unet)
324 | if isinstance(module, TemporalBasicTransformerBlock)
325 | ]
326 | writer_attn_modules = [
327 | module
328 | for module in torch_dfs(writer.unet)
329 | if isinstance(module, BasicTransformerBlock)
330 | ]
331 | reader_attn_modules = sorted(
332 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
333 | )
334 | writer_attn_modules = sorted(
335 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
336 | )
337 | for r, w in zip(reader_attn_modules, writer_attn_modules):
338 | r.bank = [v.clone().to(dtype) for v in w.bank]
339 | # w.bank.clear()
340 |
341 | def clear(self):
342 | if self.reference_attn:
343 | if self.fusion_blocks == "midup":
344 | reader_attn_modules = [
345 | module
346 | for module in (
347 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
348 | )
349 | if isinstance(module, BasicTransformerBlock)
350 | or isinstance(module, TemporalBasicTransformerBlock)
351 | ]
352 | elif self.fusion_blocks == "full":
353 | reader_attn_modules = [
354 | module
355 | for module in torch_dfs(self.unet)
356 | if isinstance(module, BasicTransformerBlock)
357 | or isinstance(module, TemporalBasicTransformerBlock)
358 | ]
359 | reader_attn_modules = sorted(
360 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
361 | )
362 | for r in reader_attn_modules:
363 | r.bank.clear()
364 |
--------------------------------------------------------------------------------
/musepose/models/pose_guider.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.nn.init as init
6 | from diffusers.models.modeling_utils import ModelMixin
7 |
8 | from musepose.models.motion_module import zero_module
9 | from musepose.models.resnet import InflatedConv3d
10 |
11 |
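# Small convolutional encoder that maps the rendered pose sequence (3-channel images by default)
# to a feature map with `conditioning_embedding_channels` channels; with the default
# block_out_channels, three stride-2 convolutions reduce the spatial resolution by 8x, and the
# zero-initialised output convolution makes the pose condition a no-op at the start of training.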
12 | class PoseGuider(ModelMixin):
13 | def __init__(
14 | self,
15 | conditioning_embedding_channels: int,
16 | conditioning_channels: int = 3,
17 | block_out_channels: Tuple[int] = (16, 32, 64, 128),
18 | ):
19 | super().__init__()
20 | self.conv_in = InflatedConv3d(
21 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
22 | )
23 |
24 | self.blocks = nn.ModuleList([])
25 |
26 | for i in range(len(block_out_channels) - 1):
27 | channel_in = block_out_channels[i]
28 | channel_out = block_out_channels[i + 1]
29 | self.blocks.append(
30 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
31 | )
32 | self.blocks.append(
33 | InflatedConv3d(
34 | channel_in, channel_out, kernel_size=3, padding=1, stride=2
35 | )
36 | )
37 |
38 | self.conv_out = zero_module(
39 | InflatedConv3d(
40 | block_out_channels[-1],
41 | conditioning_embedding_channels,
42 | kernel_size=3,
43 | padding=1,
44 | )
45 | )
46 |
47 | def forward(self, conditioning):
48 | embedding = self.conv_in(conditioning)
49 | embedding = F.silu(embedding)
50 |
51 | for block in self.blocks:
52 | embedding = block(embedding)
53 | embedding = F.silu(embedding)
54 |
55 | embedding = self.conv_out(embedding)
56 |
57 | return embedding
58 |
--------------------------------------------------------------------------------
/musepose/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 |
8 |
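# "Inflated" layers apply their 2-D operation to every frame of a 5-D (b, c, f, h, w) tensor by
# folding the frame axis into the batch axis, so ordinary 2-D convolution / group-norm weights
# can be reused directly on video inputs.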
9 | class InflatedConv3d(nn.Conv2d):
10 | def forward(self, x):
11 | video_length = x.shape[2]
12 |
13 | x = rearrange(x, "b c f h w -> (b f) c h w")
14 | x = super().forward(x)
15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16 |
17 | return x
18 |
19 |
20 | class InflatedGroupNorm(nn.GroupNorm):
21 | def forward(self, x):
22 | video_length = x.shape[2]
23 |
24 | x = rearrange(x, "b c f h w -> (b f) c h w")
25 | x = super().forward(x)
26 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27 |
28 | return x
29 |
30 |
31 | class Upsample3D(nn.Module):
32 | def __init__(
33 | self,
34 | channels,
35 | use_conv=False,
36 | use_conv_transpose=False,
37 | out_channels=None,
38 | name="conv",
39 | ):
40 | super().__init__()
41 | self.channels = channels
42 | self.out_channels = out_channels or channels
43 | self.use_conv = use_conv
44 | self.use_conv_transpose = use_conv_transpose
45 | self.name = name
46 |
47 | conv = None
48 | if use_conv_transpose:
49 | raise NotImplementedError
50 | elif use_conv:
51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52 |
53 | def forward(self, hidden_states, output_size=None):
54 | assert hidden_states.shape[1] == self.channels
55 |
56 | if self.use_conv_transpose:
57 | raise NotImplementedError
58 |
59 |         # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
60 | dtype = hidden_states.dtype
61 | if dtype == torch.bfloat16:
62 | hidden_states = hidden_states.to(torch.float32)
63 |
64 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65 | if hidden_states.shape[0] >= 64:
66 | hidden_states = hidden_states.contiguous()
67 |
68 | # if `output_size` is passed we force the interpolation output
69 | # size and do not make use of `scale_factor=2`
70 | if output_size is None:
71 | hidden_states = F.interpolate(
72 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73 | )
74 | else:
75 | hidden_states = F.interpolate(
76 | hidden_states, size=output_size, mode="nearest"
77 | )
78 |
79 | # If the input is bfloat16, we cast back to bfloat16
80 | if dtype == torch.bfloat16:
81 | hidden_states = hidden_states.to(dtype)
82 |
83 | # if self.use_conv:
84 | # if self.name == "conv":
85 | # hidden_states = self.conv(hidden_states)
86 | # else:
87 | # hidden_states = self.Conv2d_0(hidden_states)
88 | hidden_states = self.conv(hidden_states)
89 |
90 | return hidden_states
91 |
92 |
93 | class Downsample3D(nn.Module):
94 | def __init__(
95 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96 | ):
97 | super().__init__()
98 | self.channels = channels
99 | self.out_channels = out_channels or channels
100 | self.use_conv = use_conv
101 | self.padding = padding
102 | stride = 2
103 | self.name = name
104 |
105 | if use_conv:
106 | self.conv = InflatedConv3d(
107 | self.channels, self.out_channels, 3, stride=stride, padding=padding
108 | )
109 | else:
110 | raise NotImplementedError
111 |
112 | def forward(self, hidden_states):
113 | assert hidden_states.shape[1] == self.channels
114 | if self.use_conv and self.padding == 0:
115 | raise NotImplementedError
116 |
117 | assert hidden_states.shape[1] == self.channels
118 | hidden_states = self.conv(hidden_states)
119 |
120 | return hidden_states
121 |
122 |
123 | class ResnetBlock3D(nn.Module):
124 | def __init__(
125 | self,
126 | *,
127 | in_channels,
128 | out_channels=None,
129 | conv_shortcut=False,
130 | dropout=0.0,
131 | temb_channels=512,
132 | groups=32,
133 | groups_out=None,
134 | pre_norm=True,
135 | eps=1e-6,
136 | non_linearity="swish",
137 | time_embedding_norm="default",
138 | output_scale_factor=1.0,
139 | use_in_shortcut=None,
140 | use_inflated_groupnorm=None,
141 | ):
142 | super().__init__()
143 | self.pre_norm = pre_norm
144 | self.pre_norm = True
145 | self.in_channels = in_channels
146 | out_channels = in_channels if out_channels is None else out_channels
147 | self.out_channels = out_channels
148 | self.use_conv_shortcut = conv_shortcut
149 | self.time_embedding_norm = time_embedding_norm
150 | self.output_scale_factor = output_scale_factor
151 |
152 | if groups_out is None:
153 | groups_out = groups
154 |
155 |         assert use_inflated_groupnorm is not None
156 | if use_inflated_groupnorm:
157 | self.norm1 = InflatedGroupNorm(
158 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159 | )
160 | else:
161 | self.norm1 = torch.nn.GroupNorm(
162 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163 | )
164 |
165 | self.conv1 = InflatedConv3d(
166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
167 | )
168 |
169 | if temb_channels is not None:
170 | if self.time_embedding_norm == "default":
171 | time_emb_proj_out_channels = out_channels
172 | elif self.time_embedding_norm == "scale_shift":
173 | time_emb_proj_out_channels = out_channels * 2
174 | else:
175 | raise ValueError(
176 | f"unknown time_embedding_norm : {self.time_embedding_norm} "
177 | )
178 |
179 | self.time_emb_proj = torch.nn.Linear(
180 | temb_channels, time_emb_proj_out_channels
181 | )
182 | else:
183 | self.time_emb_proj = None
184 |
185 | if use_inflated_groupnorm:
186 | self.norm2 = InflatedGroupNorm(
187 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188 | )
189 | else:
190 | self.norm2 = torch.nn.GroupNorm(
191 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192 | )
193 | self.dropout = torch.nn.Dropout(dropout)
194 | self.conv2 = InflatedConv3d(
195 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
196 | )
197 |
198 | if non_linearity == "swish":
199 | self.nonlinearity = lambda x: F.silu(x)
200 | elif non_linearity == "mish":
201 | self.nonlinearity = Mish()
202 | elif non_linearity == "silu":
203 | self.nonlinearity = nn.SiLU()
204 |
205 | self.use_in_shortcut = (
206 | self.in_channels != self.out_channels
207 | if use_in_shortcut is None
208 | else use_in_shortcut
209 | )
210 |
211 | self.conv_shortcut = None
212 | if self.use_in_shortcut:
213 | self.conv_shortcut = InflatedConv3d(
214 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
215 | )
216 |
217 | def forward(self, input_tensor, temb):
218 | hidden_states = input_tensor
219 |
220 | hidden_states = self.norm1(hidden_states)
221 | hidden_states = self.nonlinearity(hidden_states)
222 |
223 | hidden_states = self.conv1(hidden_states)
224 |
225 | if temb is not None:
226 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
227 |
228 | if temb is not None and self.time_embedding_norm == "default":
229 | hidden_states = hidden_states + temb
230 |
231 | hidden_states = self.norm2(hidden_states)
232 |
233 | if temb is not None and self.time_embedding_norm == "scale_shift":
234 | scale, shift = torch.chunk(temb, 2, dim=1)
235 | hidden_states = hidden_states * (1 + scale) + shift
236 |
237 | hidden_states = self.nonlinearity(hidden_states)
238 |
239 | hidden_states = self.dropout(hidden_states)
240 | hidden_states = self.conv2(hidden_states)
241 |
242 | if self.conv_shortcut is not None:
243 | input_tensor = self.conv_shortcut(input_tensor)
244 |
245 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
246 |
247 | return output_tensor
248 |
249 |
250 | class Mish(torch.nn.Module):
251 | def forward(self, hidden_states):
252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
253 |
--------------------------------------------------------------------------------
/musepose/models/transformer_3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models import ModelMixin
7 | from diffusers.utils import BaseOutput
8 | from diffusers.utils.import_utils import is_xformers_available
9 | from einops import rearrange, repeat
10 | from torch import nn
11 |
12 | from .attention import TemporalBasicTransformerBlock
13 |
14 |
15 | @dataclass
16 | class Transformer3DModelOutput(BaseOutput):
17 | sample: torch.FloatTensor
18 |
19 |
20 | if is_xformers_available():
21 | import xformers
22 | import xformers.ops
23 | else:
24 | xformers = None
25 |
26 |
27 | class Transformer3DModel(ModelMixin, ConfigMixin):
28 | _supports_gradient_checkpointing = True
29 |
30 | @register_to_config
31 | def __init__(
32 | self,
33 | num_attention_heads: int = 16,
34 | attention_head_dim: int = 88,
35 | in_channels: Optional[int] = None,
36 | num_layers: int = 1,
37 | dropout: float = 0.0,
38 | norm_num_groups: int = 32,
39 | cross_attention_dim: Optional[int] = None,
40 | attention_bias: bool = False,
41 | activation_fn: str = "geglu",
42 | num_embeds_ada_norm: Optional[int] = None,
43 | use_linear_projection: bool = False,
44 | only_cross_attention: bool = False,
45 | upcast_attention: bool = False,
46 | unet_use_cross_frame_attention=None,
47 | unet_use_temporal_attention=None,
48 | ):
49 | super().__init__()
50 | self.use_linear_projection = use_linear_projection
51 | self.num_attention_heads = num_attention_heads
52 | self.attention_head_dim = attention_head_dim
53 | inner_dim = num_attention_heads * attention_head_dim
54 |
55 | # Define input layers
56 | self.in_channels = in_channels
57 |
58 | self.norm = torch.nn.GroupNorm(
59 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60 | )
61 | if use_linear_projection:
62 | self.proj_in = nn.Linear(in_channels, inner_dim)
63 | else:
64 | self.proj_in = nn.Conv2d(
65 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66 | )
67 |
68 | # Define transformers blocks
69 | self.transformer_blocks = nn.ModuleList(
70 | [
71 | TemporalBasicTransformerBlock(
72 | inner_dim,
73 | num_attention_heads,
74 | attention_head_dim,
75 | dropout=dropout,
76 | cross_attention_dim=cross_attention_dim,
77 | activation_fn=activation_fn,
78 | num_embeds_ada_norm=num_embeds_ada_norm,
79 | attention_bias=attention_bias,
80 | only_cross_attention=only_cross_attention,
81 | upcast_attention=upcast_attention,
82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83 | unet_use_temporal_attention=unet_use_temporal_attention,
84 | )
85 | for d in range(num_layers)
86 | ]
87 | )
88 |
89 | # 4. Define output layers
90 | if use_linear_projection:
91 | self.proj_out = nn.Linear(in_channels, inner_dim)
92 | else:
93 | self.proj_out = nn.Conv2d(
94 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95 | )
96 |
97 | self.gradient_checkpointing = False
98 |
99 | def _set_gradient_checkpointing(self, module, value=False):
100 | if hasattr(module, "gradient_checkpointing"):
101 | module.gradient_checkpointing = value
102 |
103 | def forward(
104 | self,
105 | hidden_states,
106 | encoder_hidden_states=None,
107 | timestep=None,
108 | return_dict: bool = True,
109 | ):
110 | # Input
111 | assert (
112 | hidden_states.dim() == 5
113 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114 | video_length = hidden_states.shape[2]
115 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
117 | encoder_hidden_states = repeat(
118 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length
119 | )
120 |
121 | batch, channel, height, weight = hidden_states.shape
122 | residual = hidden_states
123 |
124 | hidden_states = self.norm(hidden_states)
125 | if not self.use_linear_projection:
126 | hidden_states = self.proj_in(hidden_states)
127 | inner_dim = hidden_states.shape[1]
128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
129 | batch, height * weight, inner_dim
130 | )
131 | else:
132 | inner_dim = hidden_states.shape[1]
133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134 | batch, height * weight, inner_dim
135 | )
136 | hidden_states = self.proj_in(hidden_states)
137 |
138 | # Blocks
139 | for i, block in enumerate(self.transformer_blocks):
140 | hidden_states = block(
141 | hidden_states,
142 | encoder_hidden_states=encoder_hidden_states,
143 | timestep=timestep,
144 | video_length=video_length,
145 | )
146 |
147 | # Output
148 | if not self.use_linear_projection:
149 | hidden_states = (
150 | hidden_states.reshape(batch, height, weight, inner_dim)
151 | .permute(0, 3, 1, 2)
152 | .contiguous()
153 | )
154 | hidden_states = self.proj_out(hidden_states)
155 | else:
156 | hidden_states = self.proj_out(hidden_states)
157 | hidden_states = (
158 | hidden_states.reshape(batch, height, weight, inner_dim)
159 | .permute(0, 3, 1, 2)
160 | .contiguous()
161 | )
162 |
163 | output = hidden_states + residual
164 |
165 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166 | if not return_dict:
167 | return (output,)
168 |
169 | return Transformer3DModelOutput(sample=output)
170 |
--------------------------------------------------------------------------------
/musepose/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TMElyralab/MusePose/61c52bd937224a614b3951419b735b639397cb62/musepose/pipelines/__init__.py
--------------------------------------------------------------------------------
/musepose/pipelines/context.py:
--------------------------------------------------------------------------------
1 | # TODO: Adapted from cli
2 | from typing import Callable, List, Optional
3 |
4 | import numpy as np
5 |
6 |
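# Bit-reversal (van der Corput) fraction: reversing the 64-bit binary representation of `val`
# and dividing by 2**64 maps step indices to well-spread offsets in [0, 1), used below to
# stagger the context windows across denoising steps.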
7 | def ordered_halving(val):
8 | bin_str = f"{val:064b}"
9 | bin_flip = bin_str[::-1]
10 | as_int = int(bin_flip, 2)
11 |
12 | return as_int / (1 << 64)
13 |
14 |
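# Context scheduler for long videos: yields lists of frame indices (windows of `context_size`
# frames with `context_overlap` overlap, at strides 1, 2, 4, ... up to `context_stride`) so the
# denoiser can process overlapping chunks instead of the whole clip at once; indices are taken
# modulo `num_frames`, letting windows near the end wrap back to the start.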
15 | def uniform(
16 | step: int = ...,
17 | num_steps: Optional[int] = None,
18 | num_frames: int = ...,
19 | context_size: Optional[int] = None,
20 | context_stride: int = 3,
21 | context_overlap: int = 4,
22 | closed_loop: bool = False,
23 | ):
24 | if num_frames <= context_size:
25 | yield list(range(num_frames))
26 | return
27 |
28 | context_stride = min(
29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30 | )
31 |
32 | for context_step in 1 << np.arange(context_stride):
33 | pad = int(round(num_frames * ordered_halving(step)))
34 | for j in range(
35 | int(ordered_halving(step) * context_step) + pad,
36 | num_frames + pad + (0 if closed_loop else -context_overlap),
37 | (context_size * context_step - context_overlap),
38 | ):
39 | yield [
40 | e % num_frames
41 | for e in range(j, j + context_size * context_step, context_step)
42 | ]
43 |
44 |
45 | def get_context_scheduler(name: str) -> Callable:
46 | if name == "uniform":
47 | return uniform
48 | else:
49 |         raise ValueError(f"Unknown context scheduler {name}")
50 |
51 |
52 | def get_total_steps(
53 | scheduler,
54 | timesteps: List[int],
55 | num_steps: Optional[int] = None,
56 | num_frames: int = ...,
57 | context_size: Optional[int] = None,
58 | context_stride: int = 3,
59 | context_overlap: int = 4,
60 | closed_loop: bool = True,
61 | ):
62 | return sum(
63 | len(
64 | list(
65 | scheduler(
66 | i,
67 | num_steps,
68 | num_frames,
69 | context_size,
70 | context_stride,
71 | context_overlap,
72 | )
73 | )
74 | )
75 | for i in range(len(timesteps))
76 | )
77 |
--------------------------------------------------------------------------------
/musepose/pipelines/pipeline_pose2img.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import dataclass
3 | from typing import Callable, List, Optional, Union
4 |
5 | import numpy as np
6 | import torch
7 | from diffusers import DiffusionPipeline
8 | from diffusers.image_processor import VaeImageProcessor
9 | from diffusers.schedulers import (
10 | DDIMScheduler,
11 | DPMSolverMultistepScheduler,
12 | EulerAncestralDiscreteScheduler,
13 | EulerDiscreteScheduler,
14 | LMSDiscreteScheduler,
15 | PNDMScheduler,
16 | )
17 | from diffusers.utils import BaseOutput, is_accelerate_available
18 | from diffusers.utils.torch_utils import randn_tensor
19 | from einops import rearrange
20 | from tqdm import tqdm
21 | from transformers import CLIPImageProcessor
22 |
23 | from musepose.models.mutual_self_attention import ReferenceAttentionControl
24 |
25 |
26 | @dataclass
27 | class Pose2ImagePipelineOutput(BaseOutput):
28 | images: Union[torch.Tensor, np.ndarray]
29 |
30 |
31 | class Pose2ImagePipeline(DiffusionPipeline):
32 | _optional_components = []
33 |
34 | def __init__(
35 | self,
36 | vae,
37 | image_encoder,
38 | reference_unet,
39 | denoising_unet,
40 | pose_guider,
41 | scheduler: Union[
42 | DDIMScheduler,
43 | PNDMScheduler,
44 | LMSDiscreteScheduler,
45 | EulerDiscreteScheduler,
46 | EulerAncestralDiscreteScheduler,
47 | DPMSolverMultistepScheduler,
48 | ],
49 | ):
50 | super().__init__()
51 |
52 | self.register_modules(
53 | vae=vae,
54 | image_encoder=image_encoder,
55 | reference_unet=reference_unet,
56 | denoising_unet=denoising_unet,
57 | pose_guider=pose_guider,
58 | scheduler=scheduler,
59 | )
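        # Spatial downsampling factor of the VAE (2**(n_blocks - 1); 8 for the standard Stable
        # Diffusion VAE, whose config has four entries in block_out_channels).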
60 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
61 | self.clip_image_processor = CLIPImageProcessor()
62 | self.ref_image_processor = VaeImageProcessor(
63 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
64 | )
65 | self.cond_image_processor = VaeImageProcessor(
66 | vae_scale_factor=self.vae_scale_factor,
67 | do_convert_rgb=True,
68 | do_normalize=False,
69 | )
70 |
71 | def enable_vae_slicing(self):
72 | self.vae.enable_slicing()
73 |
74 | def disable_vae_slicing(self):
75 | self.vae.disable_slicing()
76 |
77 | def enable_sequential_cpu_offload(self, gpu_id=0):
78 | if is_accelerate_available():
79 | from accelerate import cpu_offload
80 | else:
81 | raise ImportError("Please install accelerate via `pip install accelerate`")
82 |
83 | device = torch.device(f"cuda:{gpu_id}")
84 |
85 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
86 | if cpu_offloaded_model is not None:
87 | cpu_offload(cpu_offloaded_model, device)
88 |
89 | @property
90 | def _execution_device(self):
91 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
92 | return self.device
93 | for module in self.unet.modules():
94 | if (
95 | hasattr(module, "_hf_hook")
96 | and hasattr(module._hf_hook, "execution_device")
97 | and module._hf_hook.execution_device is not None
98 | ):
99 | return torch.device(module._hf_hook.execution_device)
100 | return self.device
101 |
102 | def decode_latents(self, latents):
103 | video_length = latents.shape[2]
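        # 0.18215 is the Stable Diffusion VAE latent scaling factor; undo it before decoding.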
104 | latents = 1 / 0.18215 * latents
105 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
106 | # video = self.vae.decode(latents).sample
107 | video = []
108 | for frame_idx in tqdm(range(latents.shape[0])):
109 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
110 | video = torch.cat(video)
111 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
112 | video = (video / 2 + 0.5).clamp(0, 1)
113 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
114 | video = video.cpu().float().numpy()
115 | return video
116 |
117 | def prepare_extra_step_kwargs(self, generator, eta):
118 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
119 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
120 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
121 | # and should be between [0, 1]
122 |
123 | accepts_eta = "eta" in set(
124 | inspect.signature(self.scheduler.step).parameters.keys()
125 | )
126 | extra_step_kwargs = {}
127 | if accepts_eta:
128 | extra_step_kwargs["eta"] = eta
129 |
130 | # check if the scheduler accepts generator
131 | accepts_generator = "generator" in set(
132 | inspect.signature(self.scheduler.step).parameters.keys()
133 | )
134 | if accepts_generator:
135 | extra_step_kwargs["generator"] = generator
136 | return extra_step_kwargs
137 |
138 | def prepare_latents(
139 | self,
140 | batch_size,
141 | num_channels_latents,
142 | width,
143 | height,
144 | dtype,
145 | device,
146 | generator,
147 | latents=None,
148 | ):
149 | shape = (
150 | batch_size,
151 | num_channels_latents,
152 | height // self.vae_scale_factor,
153 | width // self.vae_scale_factor,
154 | )
155 | if isinstance(generator, list) and len(generator) != batch_size:
156 | raise ValueError(
157 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
158 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
159 | )
160 |
161 | if latents is None:
162 | latents = randn_tensor(
163 | shape, generator=generator, device=device, dtype=dtype
164 | )
165 | else:
166 | latents = latents.to(device)
167 |
168 | # scale the initial noise by the standard deviation required by the scheduler
169 | latents = latents * self.scheduler.init_noise_sigma
170 | return latents
171 |
172 | def prepare_condition(
173 | self,
174 | cond_image,
175 | width,
176 | height,
177 | device,
178 | dtype,
179 |         do_classifier_free_guidance=False,
180 | ):
181 | image = self.cond_image_processor.preprocess(
182 | cond_image, height=height, width=width
183 | ).to(dtype=torch.float32)
184 |
185 | image = image.to(device=device, dtype=dtype)
186 |
187 |         if do_classifier_free_guidance:
188 | image = torch.cat([image] * 2)
189 |
190 | return image
191 |
192 | @torch.no_grad()
193 | def __call__(
194 | self,
195 | ref_image,
196 | pose_image,
197 | width,
198 | height,
199 | num_inference_steps,
200 | guidance_scale,
201 | num_images_per_prompt=1,
202 | eta: float = 0.0,
203 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
204 | output_type: Optional[str] = "tensor",
205 | return_dict: bool = True,
206 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
207 | callback_steps: Optional[int] = 1,
208 | **kwargs,
209 | ):
210 | # Default height and width to unet
211 | height = height or self.unet.config.sample_size * self.vae_scale_factor
212 | width = width or self.unet.config.sample_size * self.vae_scale_factor
213 |
214 | device = self._execution_device
215 |
216 | do_classifier_free_guidance = guidance_scale > 1.0
217 |
218 | # Prepare timesteps
219 | self.scheduler.set_timesteps(num_inference_steps, device=device)
220 | timesteps = self.scheduler.timesteps
221 |
222 | batch_size = 1
223 |
224 | # Prepare clip image embeds
225 | clip_image = self.clip_image_processor.preprocess(
226 | ref_image.resize((224, 224)), return_tensors="pt"
227 | ).pixel_values
228 | clip_image_embeds = self.image_encoder(
229 | clip_image.to(device, dtype=self.image_encoder.dtype)
230 | ).image_embeds
231 | image_prompt_embeds = clip_image_embeds.unsqueeze(1)
232 | uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
233 |
234 | if do_classifier_free_guidance:
235 | image_prompt_embeds = torch.cat(
236 | [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
237 | )
238 |
239 | reference_control_writer = ReferenceAttentionControl(
240 | self.reference_unet,
241 | do_classifier_free_guidance=do_classifier_free_guidance,
242 | mode="write",
243 | batch_size=batch_size,
244 | fusion_blocks="full",
245 | )
246 | reference_control_reader = ReferenceAttentionControl(
247 | self.denoising_unet,
248 | do_classifier_free_guidance=do_classifier_free_guidance,
249 | mode="read",
250 | batch_size=batch_size,
251 | fusion_blocks="full",
252 | )
253 |
254 | num_channels_latents = self.denoising_unet.in_channels
255 | latents = self.prepare_latents(
256 | batch_size * num_images_per_prompt,
257 | num_channels_latents,
258 | width,
259 | height,
260 | clip_image_embeds.dtype,
261 | device,
262 | generator,
263 | )
264 | latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
265 | latents_dtype = latents.dtype
266 |
267 | # Prepare extra step kwargs.
268 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
269 |
270 | # Prepare ref image latents
271 | ref_image_tensor = self.ref_image_processor.preprocess(
272 | ref_image, height=height, width=width
273 |         ) # (bs, c, h, w)
274 | ref_image_tensor = ref_image_tensor.to(
275 | dtype=self.vae.dtype, device=self.vae.device
276 | )
277 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
278 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
279 |
280 | # Prepare pose condition image
281 | pose_cond_tensor = self.cond_image_processor.preprocess(
282 | pose_image, height=height, width=width
283 | )
284 | pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
285 | pose_cond_tensor = pose_cond_tensor.to(
286 | device=device, dtype=self.pose_guider.dtype
287 | )
288 | pose_fea = self.pose_guider(pose_cond_tensor)
289 | pose_fea = (
290 | torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
291 | )
292 |
293 | # denoising loop
294 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
295 | with self.progress_bar(total=num_inference_steps) as progress_bar:
296 | for i, t in enumerate(timesteps):
297 | # 1. Forward reference image
298 | if i == 0:
299 | self.reference_unet(
300 | ref_image_latents.repeat(
301 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
302 | ),
303 | torch.zeros_like(t),
304 | encoder_hidden_states=image_prompt_embeds,
305 | return_dict=False,
306 | )
307 |
308 |                 # 2. Inject the reference unet features into the denoising unet
309 | reference_control_reader.update(reference_control_writer)
310 |
311 | # 3.1 expand the latents if we are doing classifier free guidance
312 | latent_model_input = (
313 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents
314 | )
315 | latent_model_input = self.scheduler.scale_model_input(
316 | latent_model_input, t
317 | )
318 |
319 | noise_pred = self.denoising_unet(
320 | latent_model_input,
321 | t,
322 | encoder_hidden_states=image_prompt_embeds,
323 | pose_cond_fea=pose_fea,
324 | return_dict=False,
325 | )[0]
326 |
327 | # perform guidance
328 | if do_classifier_free_guidance:
329 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
330 | noise_pred = noise_pred_uncond + guidance_scale * (
331 | noise_pred_text - noise_pred_uncond
332 | )
333 |
334 | # compute the previous noisy sample x_t -> x_t-1
335 | latents = self.scheduler.step(
336 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False
337 | )[0]
338 |
339 | # call the callback, if provided
340 | if i == len(timesteps) - 1 or (
341 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
342 | ):
343 | progress_bar.update()
344 | if callback is not None and i % callback_steps == 0:
345 | step_idx = i // getattr(self.scheduler, "order", 1)
346 | callback(step_idx, t, latents)
347 | reference_control_reader.clear()
348 | reference_control_writer.clear()
349 |
350 | # Post-processing
351 | image = self.decode_latents(latents) # (b, c, 1, h, w)
352 |
353 | # Convert to tensor
354 | if output_type == "tensor":
355 | image = torch.from_numpy(image)
356 |
357 | if not return_dict:
358 | return image
359 |
360 | return Pose2ImagePipelineOutput(images=image)
361 |
--------------------------------------------------------------------------------
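
A minimal usage sketch of Pose2ImagePipeline follows, for orientation only. It assumes the vae, image_encoder, reference_unet, denoising_unet and pose_guider modules have already been instantiated and loaded elsewhere (their loading code is not part of this file), and the image paths are placeholders.

# Hedged sketch, not part of the repository: wiring Pose2ImagePipeline
# together and generating one pose-conditioned image. `vae`, `image_encoder`,
# `reference_unet`, `denoising_unet` and `pose_guider` are assumed to be
# pre-loaded modules; image paths are placeholders.
import torch
from PIL import Image
from diffusers import DDIMScheduler

from musepose.pipelines.pipeline_pose2img import Pose2ImagePipeline

scheduler = DDIMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
pipe = Pose2ImagePipeline(
    vae=vae,
    image_encoder=image_encoder,
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    pose_guider=pose_guider,
    scheduler=scheduler,
).to("cuda")

ref_image = Image.open("ref.png").convert("RGB")
pose_image = Image.open("pose.png").convert("RGB")
output = pipe(
    ref_image,
    pose_image,
    width=512,
    height=768,
    num_inference_steps=25,
    guidance_scale=3.5,
    generator=torch.manual_seed(42),
)
image = output.images  # tensor of shape (b, c, 1, h, w) in [0, 1]
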
/musepose/pipelines/pipeline_pose2vid.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import dataclass
3 | from typing import Callable, List, Optional, Union
4 |
5 | import numpy as np
6 | import torch
7 | from diffusers import DiffusionPipeline
8 | from diffusers.image_processor import VaeImageProcessor
9 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
10 | EulerAncestralDiscreteScheduler,
11 | EulerDiscreteScheduler, LMSDiscreteScheduler,
12 | PNDMScheduler)
13 | from diffusers.utils import BaseOutput, is_accelerate_available
14 | from diffusers.utils.torch_utils import randn_tensor
15 | from einops import rearrange
16 | from tqdm import tqdm
17 | from transformers import CLIPImageProcessor
18 |
19 | from musepose.models.mutual_self_attention import ReferenceAttentionControl
20 |
21 |
22 | @dataclass
23 | class Pose2VideoPipelineOutput(BaseOutput):
24 | videos: Union[torch.Tensor, np.ndarray]
25 |
26 |
27 | class Pose2VideoPipeline(DiffusionPipeline):
28 | _optional_components = []
29 |
30 | def __init__(
31 | self,
32 | vae,
33 | image_encoder,
34 | reference_unet,
35 | denoising_unet,
36 | pose_guider,
37 | scheduler: Union[
38 | DDIMScheduler,
39 | PNDMScheduler,
40 | LMSDiscreteScheduler,
41 | EulerDiscreteScheduler,
42 | EulerAncestralDiscreteScheduler,
43 | DPMSolverMultistepScheduler,
44 | ],
45 | image_proj_model=None,
46 | tokenizer=None,
47 | text_encoder=None,
48 | ):
49 | super().__init__()
50 |
51 | self.register_modules(
52 | vae=vae,
53 | image_encoder=image_encoder,
54 | reference_unet=reference_unet,
55 | denoising_unet=denoising_unet,
56 | pose_guider=pose_guider,
57 | scheduler=scheduler,
58 | image_proj_model=image_proj_model,
59 | tokenizer=tokenizer,
60 | text_encoder=text_encoder,
61 | )
62 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
63 | self.clip_image_processor = CLIPImageProcessor()
64 | self.ref_image_processor = VaeImageProcessor(
65 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
66 | )
67 | self.cond_image_processor = VaeImageProcessor(
68 | vae_scale_factor=self.vae_scale_factor,
69 | do_convert_rgb=True,
70 | do_normalize=False,
71 | )
72 |
73 | def enable_vae_slicing(self):
74 | self.vae.enable_slicing()
75 |
76 | def disable_vae_slicing(self):
77 | self.vae.disable_slicing()
78 |
79 | def enable_sequential_cpu_offload(self, gpu_id=0):
80 | if is_accelerate_available():
81 | from accelerate import cpu_offload
82 | else:
83 | raise ImportError("Please install accelerate via `pip install accelerate`")
84 |
85 | device = torch.device(f"cuda:{gpu_id}")
86 |
87 |         for cpu_offloaded_model in [self.denoising_unet, self.reference_unet, self.text_encoder, self.vae]:
88 | if cpu_offloaded_model is not None:
89 | cpu_offload(cpu_offloaded_model, device)
90 |
91 | @property
92 | def _execution_device(self):
93 |         if self.device != torch.device("meta") or not hasattr(self.denoising_unet, "_hf_hook"):
94 |             return self.device
95 |         for module in self.denoising_unet.modules():
96 | if (
97 | hasattr(module, "_hf_hook")
98 | and hasattr(module._hf_hook, "execution_device")
99 | and module._hf_hook.execution_device is not None
100 | ):
101 | return torch.device(module._hf_hook.execution_device)
102 | return self.device
103 |
104 | def decode_latents(self, latents):
105 | video_length = latents.shape[2]
106 | latents = 1 / 0.18215 * latents
107 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
108 | # video = self.vae.decode(latents).sample
109 | video = []
110 | for frame_idx in tqdm(range(latents.shape[0])):
111 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
112 | video = torch.cat(video)
113 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
114 | video = (video / 2 + 0.5).clamp(0, 1)
115 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
116 | video = video.cpu().float().numpy()
117 | return video
118 |
119 | def prepare_extra_step_kwargs(self, generator, eta):
120 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
121 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
122 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
123 | # and should be between [0, 1]
124 |
125 | accepts_eta = "eta" in set(
126 | inspect.signature(self.scheduler.step).parameters.keys()
127 | )
128 | extra_step_kwargs = {}
129 | if accepts_eta:
130 | extra_step_kwargs["eta"] = eta
131 |
132 | # check if the scheduler accepts generator
133 | accepts_generator = "generator" in set(
134 | inspect.signature(self.scheduler.step).parameters.keys()
135 | )
136 | if accepts_generator:
137 | extra_step_kwargs["generator"] = generator
138 | return extra_step_kwargs
139 |
140 | def prepare_latents(
141 | self,
142 | batch_size,
143 | num_channels_latents,
144 | width,
145 | height,
146 | video_length,
147 | dtype,
148 | device,
149 | generator,
150 | latents=None,
151 | ):
152 | shape = (
153 | batch_size,
154 | num_channels_latents,
155 | video_length,
156 | height // self.vae_scale_factor,
157 | width // self.vae_scale_factor,
158 | )
159 | if isinstance(generator, list) and len(generator) != batch_size:
160 | raise ValueError(
161 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
162 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
163 | )
164 |
165 | if latents is None:
166 | latents = randn_tensor(
167 | shape, generator=generator, device=device, dtype=dtype
168 | )
169 | else:
170 | latents = latents.to(device)
171 |
172 | # scale the initial noise by the standard deviation required by the scheduler
173 | latents = latents * self.scheduler.init_noise_sigma
174 | return latents
175 |
176 | def _encode_prompt(
177 | self,
178 | prompt,
179 | device,
180 | num_videos_per_prompt,
181 | do_classifier_free_guidance,
182 | negative_prompt,
183 | ):
184 | batch_size = len(prompt) if isinstance(prompt, list) else 1
185 |
186 | text_inputs = self.tokenizer(
187 | prompt,
188 | padding="max_length",
189 | max_length=self.tokenizer.model_max_length,
190 | truncation=True,
191 | return_tensors="pt",
192 | )
193 | text_input_ids = text_inputs.input_ids
194 | untruncated_ids = self.tokenizer(
195 | prompt, padding="longest", return_tensors="pt"
196 | ).input_ids
197 |
198 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
199 | text_input_ids, untruncated_ids
200 | ):
201 | removed_text = self.tokenizer.batch_decode(
202 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
203 | )
204 |
205 | if (
206 | hasattr(self.text_encoder.config, "use_attention_mask")
207 | and self.text_encoder.config.use_attention_mask
208 | ):
209 | attention_mask = text_inputs.attention_mask.to(device)
210 | else:
211 | attention_mask = None
212 |
213 | text_embeddings = self.text_encoder(
214 | text_input_ids.to(device),
215 | attention_mask=attention_mask,
216 | )
217 | text_embeddings = text_embeddings[0]
218 |
219 | # duplicate text embeddings for each generation per prompt, using mps friendly method
220 | bs_embed, seq_len, _ = text_embeddings.shape
221 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
222 | text_embeddings = text_embeddings.view(
223 | bs_embed * num_videos_per_prompt, seq_len, -1
224 | )
225 |
226 | # get unconditional embeddings for classifier free guidance
227 | if do_classifier_free_guidance:
228 | uncond_tokens: List[str]
229 | if negative_prompt is None:
230 | uncond_tokens = [""] * batch_size
231 | elif type(prompt) is not type(negative_prompt):
232 | raise TypeError(
233 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
234 | f" {type(prompt)}."
235 | )
236 | elif isinstance(negative_prompt, str):
237 | uncond_tokens = [negative_prompt]
238 | elif batch_size != len(negative_prompt):
239 | raise ValueError(
240 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
241 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
242 | " the batch size of `prompt`."
243 | )
244 | else:
245 | uncond_tokens = negative_prompt
246 |
247 | max_length = text_input_ids.shape[-1]
248 | uncond_input = self.tokenizer(
249 | uncond_tokens,
250 | padding="max_length",
251 | max_length=max_length,
252 | truncation=True,
253 | return_tensors="pt",
254 | )
255 |
256 | if (
257 | hasattr(self.text_encoder.config, "use_attention_mask")
258 | and self.text_encoder.config.use_attention_mask
259 | ):
260 | attention_mask = uncond_input.attention_mask.to(device)
261 | else:
262 | attention_mask = None
263 |
264 | uncond_embeddings = self.text_encoder(
265 | uncond_input.input_ids.to(device),
266 | attention_mask=attention_mask,
267 | )
268 | uncond_embeddings = uncond_embeddings[0]
269 |
270 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
271 | seq_len = uncond_embeddings.shape[1]
272 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
273 | uncond_embeddings = uncond_embeddings.view(
274 | batch_size * num_videos_per_prompt, seq_len, -1
275 | )
276 |
277 | # For classifier free guidance, we need to do two forward passes.
278 | # Here we concatenate the unconditional and text embeddings into a single batch
279 | # to avoid doing two forward passes
280 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
281 |
282 | return text_embeddings
283 |
284 | @torch.no_grad()
285 | def __call__(
286 | self,
287 | ref_image,
288 | pose_images,
289 | width,
290 | height,
291 | video_length,
292 | num_inference_steps,
293 | guidance_scale,
294 | num_images_per_prompt=1,
295 | eta: float = 0.0,
296 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
297 | output_type: Optional[str] = "tensor",
298 | return_dict: bool = True,
299 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
300 | callback_steps: Optional[int] = 1,
301 | **kwargs,
302 | ):
303 | # Default height and width to unet
304 | height = height or self.unet.config.sample_size * self.vae_scale_factor
305 | width = width or self.unet.config.sample_size * self.vae_scale_factor
306 |
307 | device = self._execution_device
308 |
309 | do_classifier_free_guidance = guidance_scale > 1.0
310 |
311 | # Prepare timesteps
312 | self.scheduler.set_timesteps(num_inference_steps, device=device)
313 | timesteps = self.scheduler.timesteps
314 |
315 | batch_size = 1
316 |
317 | # Prepare clip image embeds
318 | clip_image = self.clip_image_processor.preprocess(
319 | ref_image, return_tensors="pt"
320 | ).pixel_values
321 | clip_image_embeds = self.image_encoder(
322 | clip_image.to(device, dtype=self.image_encoder.dtype)
323 | ).image_embeds
324 | encoder_hidden_states = clip_image_embeds.unsqueeze(1)
325 | uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
326 |
327 | if do_classifier_free_guidance:
328 | encoder_hidden_states = torch.cat(
329 | [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
330 | )
331 | reference_control_writer = ReferenceAttentionControl(
332 | self.reference_unet,
333 | do_classifier_free_guidance=do_classifier_free_guidance,
334 | mode="write",
335 | batch_size=batch_size,
336 | fusion_blocks="full",
337 | )
338 | reference_control_reader = ReferenceAttentionControl(
339 | self.denoising_unet,
340 | do_classifier_free_guidance=do_classifier_free_guidance,
341 | mode="read",
342 | batch_size=batch_size,
343 | fusion_blocks="full",
344 | )
345 |
346 | num_channels_latents = self.denoising_unet.in_channels
347 | latents = self.prepare_latents(
348 | batch_size * num_images_per_prompt,
349 | num_channels_latents,
350 | width,
351 | height,
352 | video_length,
353 | clip_image_embeds.dtype,
354 | device,
355 | generator,
356 | )
357 |
358 | # Prepare extra step kwargs.
359 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
360 |
361 | # Prepare ref image latents
362 | ref_image_tensor = self.ref_image_processor.preprocess(
363 | ref_image, height=height, width=width
364 |         ) # (bs, c, h, w)
365 | ref_image_tensor = ref_image_tensor.to(
366 | dtype=self.vae.dtype, device=self.vae.device
367 | )
368 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
369 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
370 |
371 | # Prepare a list of pose condition images
372 | pose_cond_tensor_list = []
373 | for pose_image in pose_images:
374 | pose_cond_tensor = (
375 | torch.from_numpy(np.array(pose_image.resize((width, height)))) / 255.0
376 | )
377 | pose_cond_tensor = pose_cond_tensor.permute(2, 0, 1).unsqueeze(
378 | 1
379 | ) # (c, 1, h, w)
380 | pose_cond_tensor_list.append(pose_cond_tensor)
381 | pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w)
382 | pose_cond_tensor = pose_cond_tensor.unsqueeze(0)
383 | pose_cond_tensor = pose_cond_tensor.to(
384 | device=device, dtype=self.pose_guider.dtype
385 | )
386 | pose_fea = self.pose_guider(pose_cond_tensor)
387 | pose_fea = (
388 | torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
389 | )
390 |
391 | # denoising loop
392 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
393 | with self.progress_bar(total=num_inference_steps) as progress_bar:
394 | for i, t in enumerate(timesteps):
395 | # 1. Forward reference image
396 | if i == 0:
397 | self.reference_unet(
398 | ref_image_latents.repeat(
399 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
400 | ),
401 | torch.zeros_like(t),
402 | # t,
403 | encoder_hidden_states=encoder_hidden_states,
404 | return_dict=False,
405 | )
406 | reference_control_reader.update(reference_control_writer)
407 |
408 | # 3.1 expand the latents if we are doing classifier free guidance
409 | latent_model_input = (
410 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents
411 | )
412 | latent_model_input = self.scheduler.scale_model_input(
413 | latent_model_input, t
414 | )
415 |
416 | noise_pred = self.denoising_unet(
417 | latent_model_input,
418 | t,
419 | encoder_hidden_states=encoder_hidden_states,
420 | pose_cond_fea=pose_fea,
421 | return_dict=False,
422 | )[0]
423 |
424 | # perform guidance
425 | if do_classifier_free_guidance:
426 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
427 | noise_pred = noise_pred_uncond + guidance_scale * (
428 | noise_pred_text - noise_pred_uncond
429 | )
430 |
431 | # compute the previous noisy sample x_t -> x_t-1
432 | latents = self.scheduler.step(
433 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False
434 | )[0]
435 |
436 | # call the callback, if provided
437 | if i == len(timesteps) - 1 or (
438 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
439 | ):
440 | progress_bar.update()
441 | if callback is not None and i % callback_steps == 0:
442 | step_idx = i // getattr(self.scheduler, "order", 1)
443 | callback(step_idx, t, latents)
444 |
445 | reference_control_reader.clear()
446 | reference_control_writer.clear()
447 |
448 | # Post-processing
449 | images = self.decode_latents(latents) # (b, c, f, h, w)
450 |
451 | # Convert to tensor
452 | if output_type == "tensor":
453 | images = torch.from_numpy(images)
454 |
455 | if not return_dict:
456 | return images
457 |
458 | return Pose2VideoPipelineOutput(videos=images)
459 |
--------------------------------------------------------------------------------
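
The video pipeline is called like the image pipeline, but takes a list of pose frames plus an explicit video_length. A minimal sketch, assuming `pipe` is an already-constructed Pose2VideoPipeline and the file paths are placeholders:

# Hedged sketch: invoking an already-constructed Pose2VideoPipeline (`pipe`).
# Each pose frame drives one output frame; paths are placeholders.
import torch
from PIL import Image

ref_image = Image.open("ref.png").convert("RGB")
pose_frames = [
    Image.open(f"pose_{i:04d}.png").convert("RGB") for i in range(24)
]
output = pipe(
    ref_image,
    pose_frames,
    width=512,
    height=768,
    video_length=len(pose_frames),
    num_inference_steps=20,
    guidance_scale=3.5,
    generator=torch.manual_seed(42),
)
video = output.videos  # tensor of shape (b, c, f, h, w) in [0, 1]
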
/musepose/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | tensor_interpolation = None
4 |
5 |
6 | def get_tensor_interpolation_method():
7 | return tensor_interpolation
8 |
9 |
10 | def set_tensor_interpolation_method(is_slerp):
11 | global tensor_interpolation
12 | tensor_interpolation = slerp if is_slerp else linear
13 |
14 |
15 | def linear(v1, v2, t):
16 | return (1.0 - t) * v1 + t * v2
17 |
18 |
19 | def slerp(
20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
21 | ) -> torch.Tensor:
22 | u0 = v0 / v0.norm()
23 | u1 = v1 / v1.norm()
24 | dot = (u0 * u1).sum()
25 | if dot.abs() > DOT_THRESHOLD:
26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
27 | return (1.0 - t) * v0 + t * v1
28 | omega = dot.acos()
29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
30 |
--------------------------------------------------------------------------------
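
A short sketch of how the interpolation switch above is used: select slerp (or linear) once, then blend two tensors at a chosen interpolation weight.

# Hedged sketch of the helpers above: choose slerp globally, then blend two
# random latent tensors at t = 0.25. slerp falls back to linear interpolation
# when the inputs are nearly parallel (dot product above DOT_THRESHOLD).
import torch

from musepose.pipelines.utils import (get_tensor_interpolation_method,
                                      set_tensor_interpolation_method)

set_tensor_interpolation_method(is_slerp=True)
interp = get_tensor_interpolation_method()

v0 = torch.randn(4, 64, 64)
v1 = torch.randn(4, 64, 64)
blended = interp(v0, v1, 0.25)  # same shape as v0 / v1
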
/musepose/utils/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import os.path as osp
4 | import shutil
5 | import sys
6 | from pathlib import Path
7 |
8 | import av
9 | import numpy as np
10 | import torch
11 | import torchvision
12 | from einops import rearrange
13 | from PIL import Image
14 |
15 |
16 | def seed_everything(seed):
17 | import random
18 |
19 | import numpy as np
20 |
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 | np.random.seed(seed % (2**32))
24 | random.seed(seed)
25 |
26 |
27 | def import_filename(filename):
28 | spec = importlib.util.spec_from_file_location("mymodule", filename)
29 | module = importlib.util.module_from_spec(spec)
30 | sys.modules[spec.name] = module
31 | spec.loader.exec_module(module)
32 | return module
33 |
34 |
35 | def delete_additional_ckpt(base_path, num_keep):
36 | dirs = []
37 | for d in os.listdir(base_path):
38 | if d.startswith("checkpoint-"):
39 | dirs.append(d)
40 | num_tot = len(dirs)
41 | if num_tot <= num_keep:
42 | return
43 |     # keep checkpoints sorted by step and delete the earliest ones
44 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
45 | for d in del_dirs:
46 | path_to_dir = osp.join(base_path, d)
47 | if osp.exists(path_to_dir):
48 | shutil.rmtree(path_to_dir)
49 |
50 |
51 | def save_videos_from_pil(pil_images, path, fps=8):
52 | import av
53 |
54 | save_fmt = Path(path).suffix
55 | os.makedirs(os.path.dirname(path), exist_ok=True)
56 | width, height = pil_images[0].size
57 |
58 | if save_fmt == ".mp4":
59 | codec = "libx264"
60 | container = av.open(path, "w")
61 | stream = container.add_stream(codec, rate=fps)
62 |
63 | stream.width = width
64 | stream.height = height
65 | stream.pix_fmt = 'yuv420p'
66 | stream.bit_rate = 10000000
67 | stream.options["crf"] = "18"
68 |
69 |
70 |
71 | for pil_image in pil_images:
72 | # pil_image = Image.fromarray(image_arr).convert("RGB")
73 | av_frame = av.VideoFrame.from_image(pil_image)
74 | container.mux(stream.encode(av_frame))
75 | container.mux(stream.encode())
76 | container.close()
77 |
78 | elif save_fmt == ".gif":
79 | pil_images[0].save(
80 | fp=path,
81 | format="GIF",
82 | append_images=pil_images[1:],
83 | save_all=True,
84 | duration=(1 / fps * 1000),
85 | loop=0,
86 | )
87 | else:
88 | raise ValueError("Unsupported file type. Use .mp4 or .gif.")
89 |
90 |
91 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
92 | videos = rearrange(videos, "b c t h w -> t b c h w")
93 | height, width = videos.shape[-2:]
94 | outputs = []
95 |
96 | for x in videos:
97 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
98 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
99 | if rescale:
100 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
101 | x = (x * 255).numpy().astype(np.uint8)
102 | x = Image.fromarray(x)
103 |
104 | outputs.append(x)
105 |
106 | os.makedirs(os.path.dirname(path), exist_ok=True)
107 |
108 | save_videos_from_pil(outputs, path, fps)
109 |
110 |
111 | def read_frames(video_path):
112 | container = av.open(video_path)
113 |
114 | video_stream = next(s for s in container.streams if s.type == "video")
115 | frames = []
116 | for packet in container.demux(video_stream):
117 | for frame in packet.decode():
118 | image = Image.frombytes(
119 | "RGB",
120 | (frame.width, frame.height),
121 | frame.to_rgb().to_ndarray(),
122 | )
123 | frames.append(image)
124 |
125 | return frames
126 |
127 |
128 | def get_fps(video_path):
129 | container = av.open(video_path)
130 | video_stream = next(s for s in container.streams if s.type == "video")
131 | fps = video_stream.average_rate
132 | container.close()
133 | return fps
134 |
--------------------------------------------------------------------------------
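
A minimal sketch of the video helpers above, re-encoding a clip at its original frame rate; "input.mp4" is a placeholder path.

# Hedged sketch: read a clip with the helpers above and write it back out as
# an .mp4 at its original frame rate. "input.mp4" is a placeholder path.
from musepose.utils.util import get_fps, read_frames, save_videos_from_pil

frames = read_frames("input.mp4")   # list of RGB PIL images
fps = get_fps("input.mp4")          # average frame rate reported by PyAV
save_videos_from_pil(frames, "output/clip.mp4", fps=int(fps))
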
/pose/config/dwpose-l_384x288.py:
--------------------------------------------------------------------------------
1 | # runtime
2 | max_epochs = 270
3 | stage2_num_epochs = 30
4 | base_lr = 4e-3
5 |
6 | train_cfg = dict(max_epochs=max_epochs, val_interval=10)
7 | randomness = dict(seed=21)
8 |
9 | # optimizer
10 | optim_wrapper = dict(
11 | type='OptimWrapper',
12 | optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
13 | paramwise_cfg=dict(
14 | norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
15 |
16 | # learning rate
17 | param_scheduler = [
18 | dict(
19 | type='LinearLR',
20 | start_factor=1.0e-5,
21 | by_epoch=False,
22 | begin=0,
23 | end=1000),
24 | dict(
25 |         # use cosine lr from max_epochs // 2 to max_epochs
26 | type='CosineAnnealingLR',
27 | eta_min=base_lr * 0.05,
28 | begin=max_epochs // 2,
29 | end=max_epochs,
30 | T_max=max_epochs // 2,
31 | by_epoch=True,
32 | convert_to_iter_based=True),
33 | ]
34 |
35 | # automatically scaling LR based on the actual training batch size
36 | auto_scale_lr = dict(base_batch_size=512)
37 |
38 | # codec settings
39 | codec = dict(
40 | type='SimCCLabel',
41 | input_size=(288, 384),
42 | sigma=(6., 6.93),
43 | simcc_split_ratio=2.0,
44 | normalize=False,
45 | use_dark=False)
46 |
47 | # model settings
48 | model = dict(
49 | type='TopdownPoseEstimator',
50 | data_preprocessor=dict(
51 | type='PoseDataPreprocessor',
52 | mean=[123.675, 116.28, 103.53],
53 | std=[58.395, 57.12, 57.375],
54 | bgr_to_rgb=True),
55 | backbone=dict(
56 | _scope_='mmdet',
57 | type='CSPNeXt',
58 | arch='P5',
59 | expand_ratio=0.5,
60 | deepen_factor=1.,
61 | widen_factor=1.,
62 | out_indices=(4, ),
63 | channel_attention=True,
64 | norm_cfg=dict(type='SyncBN'),
65 | act_cfg=dict(type='SiLU'),
66 | init_cfg=dict(
67 | type='Pretrained',
68 | prefix='backbone.',
69 | checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
70 | 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa
71 | )),
72 | head=dict(
73 | type='RTMCCHead',
74 | in_channels=1024,
75 | out_channels=133,
76 | input_size=codec['input_size'],
77 | in_featuremap_size=(9, 12),
78 | simcc_split_ratio=codec['simcc_split_ratio'],
79 | final_layer_kernel_size=7,
80 | gau_cfg=dict(
81 | hidden_dims=256,
82 | s=128,
83 | expansion_factor=2,
84 | dropout_rate=0.,
85 | drop_path=0.,
86 | act_fn='SiLU',
87 | use_rel_bias=False,
88 | pos_enc=False),
89 | loss=dict(
90 | type='KLDiscretLoss',
91 | use_target_weight=True,
92 | beta=10.,
93 | label_softmax=True),
94 | decoder=codec),
95 | test_cfg=dict(flip_test=True, ))
96 |
97 | # base dataset settings
98 | dataset_type = 'CocoWholeBodyDataset'
99 | data_mode = 'topdown'
100 | data_root = '/data/'
101 |
102 | backend_args = dict(backend='local')
103 | # backend_args = dict(
104 | # backend='petrel',
105 | # path_mapping=dict({
106 | # f'{data_root}': 's3://openmmlab/datasets/detection/coco/',
107 | # f'{data_root}': 's3://openmmlab/datasets/detection/coco/'
108 | # }))
109 |
110 | # pipelines
111 | train_pipeline = [
112 | dict(type='LoadImage', backend_args=backend_args),
113 | dict(type='GetBBoxCenterScale'),
114 | dict(type='RandomFlip', direction='horizontal'),
115 | dict(type='RandomHalfBody'),
116 | dict(
117 | type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
118 | dict(type='TopdownAffine', input_size=codec['input_size']),
119 | dict(type='mmdet.YOLOXHSVRandomAug'),
120 | dict(
121 | type='Albumentation',
122 | transforms=[
123 | dict(type='Blur', p=0.1),
124 | dict(type='MedianBlur', p=0.1),
125 | dict(
126 | type='CoarseDropout',
127 | max_holes=1,
128 | max_height=0.4,
129 | max_width=0.4,
130 | min_holes=1,
131 | min_height=0.2,
132 | min_width=0.2,
133 | p=1.0),
134 | ]),
135 | dict(type='GenerateTarget', encoder=codec),
136 | dict(type='PackPoseInputs')
137 | ]
138 | val_pipeline = [
139 | dict(type='LoadImage', backend_args=backend_args),
140 | dict(type='GetBBoxCenterScale'),
141 | dict(type='TopdownAffine', input_size=codec['input_size']),
142 | dict(type='PackPoseInputs')
143 | ]
144 |
145 | train_pipeline_stage2 = [
146 | dict(type='LoadImage', backend_args=backend_args),
147 | dict(type='GetBBoxCenterScale'),
148 | dict(type='RandomFlip', direction='horizontal'),
149 | dict(type='RandomHalfBody'),
150 | dict(
151 | type='RandomBBoxTransform',
152 | shift_factor=0.,
153 | scale_factor=[0.75, 1.25],
154 | rotate_factor=60),
155 | dict(type='TopdownAffine', input_size=codec['input_size']),
156 | dict(type='mmdet.YOLOXHSVRandomAug'),
157 | dict(
158 | type='Albumentation',
159 | transforms=[
160 | dict(type='Blur', p=0.1),
161 | dict(type='MedianBlur', p=0.1),
162 | dict(
163 | type='CoarseDropout',
164 | max_holes=1,
165 | max_height=0.4,
166 | max_width=0.4,
167 | min_holes=1,
168 | min_height=0.2,
169 | min_width=0.2,
170 | p=0.5),
171 | ]),
172 | dict(type='GenerateTarget', encoder=codec),
173 | dict(type='PackPoseInputs')
174 | ]
175 |
176 | datasets = []
177 | dataset_coco=dict(
178 | type=dataset_type,
179 | data_root=data_root,
180 | data_mode=data_mode,
181 | ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
182 | data_prefix=dict(img='coco/train2017/'),
183 | pipeline=[],
184 | )
185 | datasets.append(dataset_coco)
186 |
187 | scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class',
188 | 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow',
189 | 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference']
190 |
191 | for i in range(len(scene)):
192 | datasets.append(
193 | dict(
194 | type=dataset_type,
195 | data_root=data_root,
196 | data_mode=data_mode,
197 | ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json',
198 | data_prefix=dict(img='UBody/images/'+scene[i]+'/'),
199 | pipeline=[],
200 | )
201 | )
202 |
203 | # data loaders
204 | train_dataloader = dict(
205 | batch_size=32,
206 | num_workers=10,
207 | persistent_workers=True,
208 | sampler=dict(type='DefaultSampler', shuffle=True),
209 | dataset=dict(
210 | type='CombinedDataset',
211 | metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
212 | datasets=datasets,
213 | pipeline=train_pipeline,
214 | test_mode=False,
215 | ))
216 | val_dataloader = dict(
217 | batch_size=32,
218 | num_workers=10,
219 | persistent_workers=True,
220 | drop_last=False,
221 | sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
222 | dataset=dict(
223 | type=dataset_type,
224 | data_root=data_root,
225 | data_mode=data_mode,
226 | ann_file='coco/annotations/coco_wholebody_val_v1.0.json',
227 | bbox_file=f'{data_root}coco/person_detection_results/'
228 | 'COCO_val2017_detections_AP_H_56_person.json',
229 | data_prefix=dict(img='coco/val2017/'),
230 | test_mode=True,
231 | pipeline=val_pipeline,
232 | ))
233 | test_dataloader = val_dataloader
234 |
235 | # hooks
236 | default_hooks = dict(
237 | checkpoint=dict(
238 | save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
239 |
240 | custom_hooks = [
241 | dict(
242 | type='EMAHook',
243 | ema_type='ExpMomentumEMA',
244 | momentum=0.0002,
245 | update_buffers=True,
246 | priority=49),
247 | dict(
248 | type='mmdet.PipelineSwitchHook',
249 | switch_epoch=max_epochs - stage2_num_epochs,
250 | switch_pipeline=train_pipeline_stage2)
251 | ]
252 |
253 | # evaluators
254 | val_evaluator = dict(
255 | type='CocoWholeBodyMetric',
256 | ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json')
257 | test_evaluator = val_evaluator
--------------------------------------------------------------------------------
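
The config above is consumed by mmpose's model builder; a minimal sketch follows. The config path and checkpoint URL are the ones used in pose/script/wholebody.py, and "person.jpg" is a placeholder image path.

# Hedged sketch: building the DWPose estimator from the config above via
# mmpose's public API.
from mmpose.apis import inference_topdown, init_model

pose_estimator = init_model(
    "pose/config/dwpose-l_384x288.py",
    "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth",
    device="cpu",
)
pose_results = inference_topdown(pose_estimator, "person.jpg")
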
/pose/config/yolox_l_8xb8-300e_coco.py:
--------------------------------------------------------------------------------
1 | img_scale = (640, 640) # width, height
2 |
3 | # model settings
4 | model = dict(
5 | type='YOLOX',
6 | data_preprocessor=dict(
7 | type='DetDataPreprocessor',
8 | pad_size_divisor=32,
9 | batch_augments=[
10 | dict(
11 | type='BatchSyncRandomResize',
12 | random_size_range=(480, 800),
13 | size_divisor=32,
14 | interval=10)
15 | ]),
16 | backbone=dict(
17 | type='CSPDarknet',
18 | deepen_factor=1.0,
19 | widen_factor=1.0,
20 | out_indices=(2, 3, 4),
21 | use_depthwise=False,
22 | spp_kernal_sizes=(5, 9, 13),
23 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
24 | act_cfg=dict(type='Swish'),
25 | ),
26 | neck=dict(
27 | type='YOLOXPAFPN',
28 | in_channels=[256, 512, 1024],
29 | out_channels=256,
30 | num_csp_blocks=3,
31 | use_depthwise=False,
32 | upsample_cfg=dict(scale_factor=2, mode='nearest'),
33 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
34 | act_cfg=dict(type='Swish')),
35 | bbox_head=dict(
36 | type='YOLOXHead',
37 | num_classes=80,
38 | in_channels=256,
39 | feat_channels=256,
40 | stacked_convs=2,
41 | strides=(8, 16, 32),
42 | use_depthwise=False,
43 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
44 | act_cfg=dict(type='Swish'),
45 | loss_cls=dict(
46 | type='CrossEntropyLoss',
47 | use_sigmoid=True,
48 | reduction='sum',
49 | loss_weight=1.0),
50 | loss_bbox=dict(
51 | type='IoULoss',
52 | mode='square',
53 | eps=1e-16,
54 | reduction='sum',
55 | loss_weight=5.0),
56 | loss_obj=dict(
57 | type='CrossEntropyLoss',
58 | use_sigmoid=True,
59 | reduction='sum',
60 | loss_weight=1.0),
61 | loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
62 | train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
63 | # In order to align the source code, the threshold of the val phase is
64 | # 0.01, and the threshold of the test phase is 0.001.
65 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
66 |
67 | # dataset settings
68 | data_root = 'data/coco/'
69 | dataset_type = 'CocoDataset'
70 |
71 | # Example to use different file client
72 | # Method 1: simply set the data root and let the file I/O module
73 | # automatically infer from prefix (not support LMDB and Memcache yet)
74 |
75 | # data_root = 's3://openmmlab/datasets/detection/coco/'
76 |
77 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
78 | # backend_args = dict(
79 | # backend='petrel',
80 | # path_mapping=dict({
81 | # './data/': 's3://openmmlab/datasets/detection/',
82 | # 'data/': 's3://openmmlab/datasets/detection/'
83 | # }))
84 | backend_args = None
85 |
86 | train_pipeline = [
87 | dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
88 | dict(
89 | type='RandomAffine',
90 | scaling_ratio_range=(0.1, 2),
91 | # img_scale is (width, height)
92 | border=(-img_scale[0] // 2, -img_scale[1] // 2)),
93 | dict(
94 | type='MixUp',
95 | img_scale=img_scale,
96 | ratio_range=(0.8, 1.6),
97 | pad_val=114.0),
98 | dict(type='YOLOXHSVRandomAug'),
99 | dict(type='RandomFlip', prob=0.5),
100 | # According to the official implementation, multi-scale
101 | # training is not considered here but in the
102 | # 'mmdet/models/detectors/yolox.py'.
103 | # Resize and Pad are for the last 15 epochs when Mosaic,
104 | # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook.
105 | dict(type='Resize', scale=img_scale, keep_ratio=True),
106 | dict(
107 | type='Pad',
108 | pad_to_square=True,
109 | # If the image is three-channel, the pad value needs
110 | # to be set separately for each channel.
111 | pad_val=dict(img=(114.0, 114.0, 114.0))),
112 | dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
113 | dict(type='PackDetInputs')
114 | ]
115 |
116 | train_dataset = dict(
117 | # use MultiImageMixDataset wrapper to support mosaic and mixup
118 | type='MultiImageMixDataset',
119 | dataset=dict(
120 | type=dataset_type,
121 | data_root=data_root,
122 | ann_file='annotations/instances_train2017.json',
123 | data_prefix=dict(img='train2017/'),
124 | pipeline=[
125 | dict(type='LoadImageFromFile', backend_args=backend_args),
126 | dict(type='LoadAnnotations', with_bbox=True)
127 | ],
128 | filter_cfg=dict(filter_empty_gt=False, min_size=32),
129 | backend_args=backend_args),
130 | pipeline=train_pipeline)
131 |
132 | test_pipeline = [
133 | dict(type='LoadImageFromFile', backend_args=backend_args),
134 | dict(type='Resize', scale=img_scale, keep_ratio=True),
135 | dict(
136 | type='Pad',
137 | pad_to_square=True,
138 | pad_val=dict(img=(114.0, 114.0, 114.0))),
139 | dict(type='LoadAnnotations', with_bbox=True),
140 | dict(
141 | type='PackDetInputs',
142 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
143 | 'scale_factor'))
144 | ]
145 |
146 | train_dataloader = dict(
147 | batch_size=8,
148 | num_workers=4,
149 | persistent_workers=True,
150 | sampler=dict(type='DefaultSampler', shuffle=True),
151 | dataset=train_dataset)
152 | val_dataloader = dict(
153 | batch_size=8,
154 | num_workers=4,
155 | persistent_workers=True,
156 | drop_last=False,
157 | sampler=dict(type='DefaultSampler', shuffle=False),
158 | dataset=dict(
159 | type=dataset_type,
160 | data_root=data_root,
161 | ann_file='annotations/instances_val2017.json',
162 | data_prefix=dict(img='val2017/'),
163 | test_mode=True,
164 | pipeline=test_pipeline,
165 | backend_args=backend_args))
166 | test_dataloader = val_dataloader
167 |
168 | val_evaluator = dict(
169 | type='CocoMetric',
170 | ann_file=data_root + 'annotations/instances_val2017.json',
171 | metric='bbox',
172 | backend_args=backend_args)
173 | test_evaluator = val_evaluator
174 |
175 | # training settings
176 | max_epochs = 300
177 | num_last_epochs = 15
178 | interval = 10
179 |
180 | train_cfg = dict(max_epochs=max_epochs, val_interval=interval)
181 |
182 | # optimizer
183 | # default 8 gpu
184 | base_lr = 0.01
185 | optim_wrapper = dict(
186 | type='OptimWrapper',
187 | optimizer=dict(
188 | type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4,
189 | nesterov=True),
190 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
191 |
192 | # learning rate
193 | param_scheduler = [
194 | dict(
195 | # use quadratic formula to warm up 5 epochs
196 | # and lr is updated by iteration
197 | # TODO: fix default scope in get function
198 | type='mmdet.QuadraticWarmupLR',
199 | by_epoch=True,
200 | begin=0,
201 | end=5,
202 | convert_to_iter_based=True),
203 | dict(
204 | # use cosine lr from 5 to 285 epoch
205 | type='CosineAnnealingLR',
206 | eta_min=base_lr * 0.05,
207 | begin=5,
208 | T_max=max_epochs - num_last_epochs,
209 | end=max_epochs - num_last_epochs,
210 | by_epoch=True,
211 | convert_to_iter_based=True),
212 | dict(
213 | # use fixed lr during last 15 epochs
214 | type='ConstantLR',
215 | by_epoch=True,
216 | factor=1,
217 | begin=max_epochs - num_last_epochs,
218 | end=max_epochs,
219 | )
220 | ]
221 |
222 | default_hooks = dict(
223 | checkpoint=dict(
224 | interval=interval,
225 | max_keep_ckpts=3 # only keep latest 3 checkpoints
226 | ))
227 |
228 | custom_hooks = [
229 | dict(
230 | type='YOLOXModeSwitchHook',
231 | num_last_epochs=num_last_epochs,
232 | priority=48),
233 | dict(type='SyncNormHook', priority=48),
234 | dict(
235 | type='EMAHook',
236 | ema_type='ExpMomentumEMA',
237 | momentum=0.0001,
238 | update_buffers=True,
239 | priority=49)
240 | ]
241 |
242 | # NOTE: `auto_scale_lr` is for automatically scaling LR,
243 | # USER SHOULD NOT CHANGE ITS VALUES.
244 | # base_batch_size = (8 GPUs) x (8 samples per GPU)
245 | auto_scale_lr = dict(base_batch_size=64)
246 |
--------------------------------------------------------------------------------
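
Likewise, the detector config above can be fed straight to mmdet; a minimal sketch using the checkpoint URL from pose/script/wholebody.py, with a placeholder image path.

# Hedged sketch: building the YOLOX person detector from the config above via
# mmdet's public API.
from mmdet.apis import inference_detector, init_detector

detector = init_detector(
    "pose/config/yolox_l_8xb8-300e_coco.py",
    "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/"
    "yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth",
    device="cpu",
)
det_result = inference_detector(detector, "person.jpg")  # placeholder image path
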
/pose/script/dwpose.py:
--------------------------------------------------------------------------------
1 | # Openpose
2 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
3 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
4 | # 3rd Edited by ControlNet
5 | # 4th Edited by ControlNet (added face and correct hands)
6 |
7 | import os
8 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
9 |
10 | import cv2
11 | import torch
12 | import numpy as np
13 | from PIL import Image
14 |
15 |
16 | import pose.script.util as util
17 |
18 | def resize_image(input_image, resolution):
19 | H, W, C = input_image.shape
20 | H = float(H)
21 | W = float(W)
22 | k = float(resolution) / min(H, W)
23 | H *= k
24 | W *= k
25 | H = int(np.round(H / 64.0)) * 64
26 | W = int(np.round(W / 64.0)) * 64
27 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
28 | return img
29 |
30 | def HWC3(x):
31 | assert x.dtype == np.uint8
32 | if x.ndim == 2:
33 | x = x[:, :, None]
34 | assert x.ndim == 3
35 | H, W, C = x.shape
36 | assert C == 1 or C == 3 or C == 4
37 | if C == 3:
38 | return x
39 | if C == 1:
40 | return np.concatenate([x, x, x], axis=2)
41 | if C == 4:
42 | color = x[:, :, 0:3].astype(np.float32)
43 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0
44 | y = color * alpha + 255.0 * (1.0 - alpha)
45 | y = y.clip(0, 255).astype(np.uint8)
46 | return y
47 |
48 | def draw_pose(pose, H, W, draw_face):
49 | bodies = pose['bodies']
50 | faces = pose['faces']
51 | hands = pose['hands']
52 | candidate = bodies['candidate']
53 | subset = bodies['subset']
54 |
55 | # only the most significant person
56 | faces = pose['faces'][:1]
57 | hands = pose['hands'][:2]
58 | candidate = bodies['candidate'][:18]
59 | subset = bodies['subset'][:1]
60 |
61 | # draw
62 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
63 | canvas = util.draw_bodypose(canvas, candidate, subset)
64 | canvas = util.draw_handpose(canvas, hands)
65 |     if draw_face:
66 | canvas = util.draw_facepose(canvas, faces)
67 |
68 | return canvas
69 |
70 | class DWposeDetector:
71 | def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu", keypoints_only=False):
72 | from pose.script.wholebody import Wholebody
73 |
74 | self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device)
75 | self.keypoints_only = keypoints_only
76 | def to(self, device):
77 | self.pose_estimation.to(device)
78 | return self
79 |     '''
80 |     detect_resolution: resize the short edge to this size; this is the raw rendering resolution when drawing the pose. Recommended: 1024
81 |     image_resolution: resize the short edge to this size; this is the file resolution when saving the pose. Recommended: 768
82 |
83 |     Actual detection resolutions:
84 |     yolox: (640, 640)
85 |     dwpose: (288, 384)
86 |     '''
87 |
88 | def __call__(self, input_image, detect_resolution=1024, image_resolution=768, output_type="pil", **kwargs):
89 |
90 | input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
91 | # cv2.imshow('', input_image)
92 | # cv2.waitKey(0)
93 |
94 | input_image = HWC3(input_image)
95 | input_image = resize_image(input_image, detect_resolution)
96 | H, W, C = input_image.shape
97 |
98 | with torch.no_grad():
99 | candidate, subset = self.pose_estimation(input_image)
100 | nums, keys, locs = candidate.shape
101 | candidate[..., 0] /= float(W)
102 | candidate[..., 1] /= float(H)
103 | body = candidate[:,:18].copy()
104 | body = body.reshape(nums*18, locs)
105 | score = subset[:,:18]
106 |
107 | for i in range(len(score)):
108 | for j in range(len(score[i])):
109 | if score[i][j] > 0.3:
110 | score[i][j] = int(18*i+j)
111 | else:
112 | score[i][j] = -1
113 |
114 | un_visible = subset<0.3
115 | candidate[un_visible] = -1
116 |
117 | foot = candidate[:,18:24]
118 |
119 | faces = candidate[:,24:92]
120 |
121 | hands = candidate[:,92:113]
122 | hands = np.vstack([hands, candidate[:,113:]])
123 |
124 | bodies = dict(candidate=body, subset=score)
125 | pose = dict(bodies=bodies, hands=hands, faces=faces)
126 |
127 |         if self.keypoints_only:
128 | return pose
129 | else:
130 | detected_map = draw_pose(pose, H, W, draw_face=False)
131 | detected_map = HWC3(detected_map)
132 | img = resize_image(input_image, image_resolution)
133 | H, W, C = img.shape
134 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
135 | # cv2.imshow('detected_map',detected_map)
136 | # cv2.waitKey(0)
137 |
138 | if output_type == "pil":
139 | detected_map = cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB)
140 | detected_map = Image.fromarray(detected_map)
141 |
142 | return detected_map, pose
143 |
144 |
--------------------------------------------------------------------------------
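
A minimal sketch of running the detector defined above on one image; the default config/checkpoint resolution happens inside Wholebody (see pose/script/wholebody.py), and "person.jpg" is a placeholder path.

# Hedged sketch: one forward pass through DWposeDetector as defined above.
# With keypoints_only=False it returns both the rendered skeleton (PIL) and
# the raw keypoint dict.
from PIL import Image

from pose.script.dwpose import DWposeDetector

detector = DWposeDetector(device="cuda", keypoints_only=False)
image = Image.open("person.jpg").convert("RGB")
detected_map, pose = detector(
    image, detect_resolution=1024, image_resolution=768, output_type="pil"
)
detected_map.save("pose_vis.png")
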
/pose/script/tool.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import os.path as osp
4 | import shutil
5 | import sys
6 | from pathlib import Path
7 |
8 | import av
9 | import numpy as np
10 | import torch
11 | import torchvision
12 | from einops import rearrange
13 | from PIL import Image
14 |
15 |
16 | def seed_everything(seed):
17 | import random
18 |
19 | import numpy as np
20 |
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 | np.random.seed(seed % (2**32))
24 | random.seed(seed)
25 |
26 |
27 | def import_filename(filename):
28 | spec = importlib.util.spec_from_file_location("mymodule", filename)
29 | module = importlib.util.module_from_spec(spec)
30 | sys.modules[spec.name] = module
31 | spec.loader.exec_module(module)
32 | return module
33 |
34 |
35 | def delete_additional_ckpt(base_path, num_keep):
36 | dirs = []
37 | for d in os.listdir(base_path):
38 | if d.startswith("checkpoint-"):
39 | dirs.append(d)
40 | num_tot = len(dirs)
41 | if num_tot <= num_keep:
42 | return
43 |     # keep checkpoints sorted by step and delete the earliest ones
44 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
45 | for d in del_dirs:
46 | path_to_dir = osp.join(base_path, d)
47 | if osp.exists(path_to_dir):
48 | shutil.rmtree(path_to_dir)
49 |
50 |
51 | def save_videos_from_pil(pil_images, path, fps):
52 |
53 | save_fmt = Path(path).suffix
54 | os.makedirs(os.path.dirname(path), exist_ok=True)
55 | width, height = pil_images[0].size
56 |
57 | if save_fmt == ".mp4":
58 | codec = "libx264"
59 | container = av.open(path, "w")
60 | stream = container.add_stream(codec, rate=fps)
61 |
62 | stream.width = width
63 | stream.height = height
64 | stream.pix_fmt = 'yuv420p'
65 | stream.bit_rate = 10000000
66 | stream.options["crf"] = "18"
67 |
68 | for pil_image in pil_images:
69 | # pil_image = Image.fromarray(image_arr).convert("RGB")
70 | av_frame = av.VideoFrame.from_image(pil_image)
71 | container.mux(stream.encode(av_frame))
72 | container.mux(stream.encode())
73 | container.close()
74 |
75 | elif save_fmt == ".gif":
76 | pil_images[0].save(
77 | fp=path,
78 | format="GIF",
79 | append_images=pil_images[1:],
80 | save_all=True,
81 | duration=(1 / fps * 1000),
82 | loop=0,
83 | )
84 | else:
85 | raise ValueError("Unsupported file type. Use .mp4 or .gif.")
86 |
87 |
88 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
89 | videos = rearrange(videos, "b c t h w -> t b c h w")
90 | height, width = videos.shape[-2:]
91 | outputs = []
92 |
93 | for x in videos:
94 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
95 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
96 | if rescale:
97 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
98 | x = (x * 255).numpy().astype(np.uint8)
99 | x = Image.fromarray(x)
100 |
101 | outputs.append(x)
102 |
103 | os.makedirs(os.path.dirname(path), exist_ok=True)
104 |
105 | save_videos_from_pil(outputs, path, fps)
106 |
107 |
108 | def read_frames(video_path):
109 | container = av.open(video_path)
110 |
111 | video_stream = next(s for s in container.streams if s.type == "video")
112 | frames = []
113 | for packet in container.demux(video_stream):
114 | for frame in packet.decode():
115 | image = Image.frombytes(
116 | "RGB",
117 | (frame.width, frame.height),
118 | frame.to_rgb().to_ndarray(),
119 | )
120 | frames.append(image)
121 |
122 | return frames
123 |
124 |
125 | def get_fps(video_path):
126 | container = av.open(video_path)
127 | video_stream = next(s for s in container.streams if s.type == "video")
128 | fps = video_stream.average_rate
129 | container.close()
130 | return fps
131 |
--------------------------------------------------------------------------------
/pose/script/util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | eps = 0.01
7 |
8 | def smart_width(d):
9 | if d<5:
10 | return 1
11 | elif d<10:
12 | return 2
13 | elif d<20:
14 | return 3
15 | elif d<40:
16 | return 4
17 | elif d<80:
18 | return 5
19 | elif d<160:
20 | return 6
21 | elif d<320:
22 | return 7
23 | else:
24 | return 8
25 |
26 |
27 |
28 | def draw_bodypose(canvas, candidate, subset):
29 | H, W, C = canvas.shape
30 | candidate = np.array(candidate)
31 | subset = np.array(subset)
32 |
33 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
34 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
35 | [1, 16], [16, 18], [3, 17], [6, 18]]
36 |
37 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
38 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
39 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
40 |
41 | for i in range(17):
42 | for n in range(len(subset)):
43 | index = subset[n][np.array(limbSeq[i]) - 1]
44 | if -1 in index:
45 | continue
46 | Y = candidate[index.astype(int), 0] * float(W)
47 | X = candidate[index.astype(int), 1] * float(H)
48 | mX = np.mean(X)
49 | mY = np.mean(Y)
50 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
51 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
52 |
53 | width = smart_width(length)
54 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), width), int(angle), 0, 360, 1)
55 | cv2.fillConvexPoly(canvas, polygon, colors[i])
56 |
57 | canvas = (canvas * 0.6).astype(np.uint8)
58 |
59 | for i in range(18):
60 | for n in range(len(subset)):
61 | index = int(subset[n][i])
62 | if index == -1:
63 | continue
64 | x, y = candidate[index][0:2]
65 | x = int(x * W)
66 | y = int(y * H)
67 | radius = 4
68 | cv2.circle(canvas, (int(x), int(y)), radius, colors[i], thickness=-1)
69 |
70 | return canvas
71 |
72 |
73 | def draw_handpose(canvas, all_hand_peaks):
74 | import matplotlib
75 |
76 | H, W, C = canvas.shape
77 |
78 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
79 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
80 |
81 | # (person_number*2, 21, 2)
82 | for i in range(len(all_hand_peaks)):
83 | peaks = all_hand_peaks[i]
84 | peaks = np.array(peaks)
85 |
86 | for ie, e in enumerate(edges):
87 |
88 | x1, y1 = peaks[e[0]]
89 | x2, y2 = peaks[e[1]]
90 |
91 | x1 = int(x1 * W)
92 | y1 = int(y1 * H)
93 | x2 = int(x2 * W)
94 | y2 = int(y2 * H)
95 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
96 | length = ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
97 | width = smart_width(length)
98 | cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=width)
99 |
100 |         for keypoint in peaks:
101 |             x, y = keypoint
102 |
103 | x = int(x * W)
104 | y = int(y * H)
105 | if x > eps and y > eps:
106 | radius = 3
107 | cv2.circle(canvas, (x, y), radius, (0, 0, 255), thickness=-1)
108 | return canvas
109 |
110 |
111 | def draw_facepose(canvas, all_lmks):
112 | H, W, C = canvas.shape
113 | for lmks in all_lmks:
114 | lmks = np.array(lmks)
115 | for lmk in lmks:
116 | x, y = lmk
117 | x = int(x * W)
118 | y = int(y * H)
119 | if x > eps and y > eps:
120 | radius = 3
121 | cv2.circle(canvas, (x, y), radius, (255, 255, 255), thickness=-1)
122 | return canvas
123 |
124 |
125 |
126 |
127 | # Calculate the resolution
128 | def size_calculate(h, w, resolution):
129 |
130 | H = float(h)
131 | W = float(w)
132 |
133 | # resize the short edge to the resolution
134 | k = float(resolution) / min(H, W) # short edge
135 | H *= k
136 | W *= k
137 |
138 | # resize to the nearest integer multiple of 64
139 | H = int(np.round(H / 64.0)) * 64
140 | W = int(np.round(W / 64.0)) * 64
141 | return H, W
142 |
143 |
144 |
145 | def warpAffine_kps(kps, M):
146 | a = M[:,:2]
147 | t = M[:,2]
148 | kps = np.dot(kps, a.T) + t
149 | return kps
150 |
151 |
152 |
153 |
154 |
--------------------------------------------------------------------------------
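
A small self-contained sketch of draw_bodypose, drawing a single nose-neck limb. Coordinates are normalized to [0, 1], and subset stores per-person indices into candidate, with -1 marking missing joints (the same convention dwpose.py produces); the output filename is a placeholder.

# Hedged sketch: drawing one nose-neck limb with draw_bodypose. Keypoints are
# normalized to [0, 1]; subset holds per-person indices into `candidate`,
# with -1 for joints that were not detected.
import cv2
import numpy as np

from pose.script.util import draw_bodypose

canvas = np.zeros((384, 288, 3), dtype=np.uint8)
candidate = np.zeros((18, 2))
candidate[0] = [0.5, 0.20]   # nose (x, y), normalized
candidate[1] = [0.5, 0.35]   # neck (x, y), normalized
subset = -np.ones((1, 18))
subset[0, 0] = 0             # this person's nose is candidate[0]
subset[0, 1] = 1             # this person's neck is candidate[1]

canvas = draw_bodypose(canvas, candidate, subset)
cv2.imwrite("body_pose.png", canvas)
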
/pose/script/wholebody.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 | import numpy as np
4 | import warnings
5 |
6 | try:
7 | import mmcv
8 | except ImportError:
9 | warnings.warn(
10 | "The module 'mmcv' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmcv>=2.0.1'"
11 | )
12 |
13 | try:
14 | from mmpose.apis import inference_topdown
15 | from mmpose.apis import init_model as init_pose_estimator
16 | from mmpose.evaluation.functional import nms
17 | from mmpose.utils import adapt_mmdet_pipeline
18 | from mmpose.structures import merge_data_samples
19 | except ImportError:
20 | warnings.warn(
21 | "The module 'mmpose' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmpose>=1.1.0'"
22 | )
23 |
24 | try:
25 | from mmdet.apis import inference_detector, init_detector
26 | except ImportError:
27 | warnings.warn(
28 | "The module 'mmdet' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmdet>=3.1.0'"
29 | )
30 |
31 |
32 | class Wholebody:
33 | def __init__(self,
34 | det_config=None, det_ckpt=None,
35 | pose_config=None, pose_ckpt=None,
36 | device="cpu"):
37 |
38 | if det_config is None:
39 | det_config = os.path.join(os.path.dirname(__file__), "yolox_config/yolox_l_8xb8-300e_coco.py")
40 |
41 | if pose_config is None:
42 | pose_config = os.path.join(os.path.dirname(__file__), "dwpose_config/dwpose-l_384x288.py")
43 |
44 | if det_ckpt is None:
45 | det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'
46 |
47 | if pose_ckpt is None:
48 | pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth"
49 |
50 | # build detector
51 | self.detector = init_detector(det_config, det_ckpt, device=device)
52 | self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)
53 |
54 | # build pose estimator
55 | self.pose_estimator = init_pose_estimator(
56 | pose_config,
57 | pose_ckpt,
58 | device=device)
59 |
60 | def to(self, device):
61 | self.detector.to(device)
62 | self.pose_estimator.to(device)
63 | return self
64 |
65 | def __call__(self, oriImg):
66 | # predict bbox
67 | det_result = inference_detector(self.detector, oriImg)
68 | pred_instance = det_result.pred_instances.cpu().numpy()
69 | bboxes = np.concatenate(
70 | (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
71 | bboxes = bboxes[np.logical_and(pred_instance.labels == 0,
72 | pred_instance.scores > 0.5)]
73 |
74 | # set NMS threshold
75 | bboxes = bboxes[nms(bboxes, 0.7), :4]
76 |
77 | # predict keypoints
78 | if len(bboxes) == 0:
79 | pose_results = inference_topdown(self.pose_estimator, oriImg)
80 | else:
81 | pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes)
82 | preds = merge_data_samples(pose_results)
83 | preds = preds.pred_instances
84 |
85 | # preds = pose_results[0].pred_instances
86 | keypoints = preds.get('transformed_keypoints',
87 | preds.keypoints)
88 | if 'keypoint_scores' in preds:
89 | scores = preds.keypoint_scores
90 | else:
91 | scores = np.ones(keypoints.shape[:-1])
92 |
93 | if 'keypoints_visible' in preds:
94 | visible = preds.keypoints_visible
95 | else:
96 | visible = np.ones(keypoints.shape[:-1])
97 | keypoints_info = np.concatenate(
98 | (keypoints, scores[..., None], visible[..., None]),
99 | axis=-1)
100 | # compute neck joint
101 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
102 | # neck score when visualizing pred
103 | neck[:, 2:4] = np.logical_and(
104 | keypoints_info[:, 5, 2:4] > 0.3,
105 | keypoints_info[:, 6, 2:4] > 0.3).astype(int)
106 | new_keypoints_info = np.insert(
107 | keypoints_info, 17, neck, axis=1)
108 | mmpose_idx = [
109 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
110 | ]
111 | openpose_idx = [
112 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
113 | ]
114 | new_keypoints_info[:, openpose_idx] = \
115 | new_keypoints_info[:, mmpose_idx]
116 | keypoints_info = new_keypoints_info
117 |
118 | keypoints, scores, visible = keypoints_info[
119 | ..., :2], keypoints_info[..., 2], keypoints_info[..., 3]
120 |
121 | return keypoints, scores
122 |
--------------------------------------------------------------------------------
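
The Wholebody class above chains a YOLOX person detector with the DWPose whole-body estimator: class-0 detections scoring above 0.5 are kept, NMS at 0.7 is applied, and the predicted keypoints are remapped from the mmpose layout to OpenPose ordering with a neck joint synthesized from the two shoulders. A usage sketch, assuming mmcv, mmdet and mmpose are installed and the default config/checkpoint paths resolve; the import path and image path are placeholders:

import cv2
from pose.script.wholebody import Wholebody  # assumed import path

model = Wholebody(device="cuda:0")        # falls back to the bundled YOLOX + DWPose defaults
img = cv2.imread("path/to/frame.png")     # BGR array, as the mmdet/mmpose APIs expect
keypoints, scores = model(img)            # keypoints: (num_people, num_keypoints, 2)
                                          # scores:    (num_people, num_keypoints)
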
/pretrained_weights/put_models_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TMElyralab/MusePose/61c52bd937224a614b3951419b735b639397cb62/pretrained_weights/put_models_here.txt
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | torchdiffeq==0.2.3
3 | torchmetrics==1.2.1
4 | torchsde==0.2.5
5 | torchvision==0.15.2
6 | accelerate==0.29.3
7 | av==11.0.0
8 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a
9 | decord==0.6.0
10 | diffusers>=0.24.0,<=0.27.2
11 | einops==0.4.1
12 | imageio==2.33.0
13 | imageio-ffmpeg==0.4.9
14 | ffmpeg-python==0.2.0
15 | omegaconf==2.2.3
16 | open-clip-torch==2.20.0
17 | opencv-contrib-python==4.8.1.78
18 | opencv-python==4.8.1.78
19 | scikit-image==0.21.0
20 | scikit-learn==1.3.2
21 | transformers==4.33.1
22 | xformers==0.0.22
23 | moviepy==1.0.3
24 | wget==3.2
25 |
--------------------------------------------------------------------------------
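
Note that requirements.txt pins only the core Python dependencies. The pose-extraction path additionally needs the OpenMMLab stack, which the import guards in pose/script/wholebody.py above suggest installing separately with mim install 'mmcv>=2.0.1', mim install 'mmdet>=3.1.0' and mim install 'mmpose>=1.1.0'.
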
/test_stage_1.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | from datetime import datetime
6 | from pathlib import Path
7 | from typing import List
8 | import glob
9 |
10 | import numpy as np
11 | import torch
12 | import torchvision
13 | from diffusers import AutoencoderKL, DDIMScheduler
14 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
15 | from einops import repeat
16 | from omegaconf import OmegaConf
17 | from PIL import Image
18 | from torchvision import transforms
19 | from transformers import CLIPVisionModelWithProjection
20 |
21 |
22 | from musepose.models.pose_guider import PoseGuider
23 | from musepose.models.unet_2d_condition import UNet2DConditionModel
24 | from musepose.models.unet_3d import UNet3DConditionModel
25 | from musepose.pipelines.pipeline_pose2img import Pose2ImagePipeline
26 | from musepose.utils.util import get_fps, read_frames, save_videos_grid
27 |
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument("--config",default="./configs/test_stage_1.yaml")
32 | parser.add_argument("-W", type=int, default=768)
33 | parser.add_argument("-H", type=int, default=768)
34 | parser.add_argument("--seed", type=int, default=42)
35 | parser.add_argument("--cnt", type=int, default=1)
36 | parser.add_argument("--cfg", type=float, default=7)
37 | parser.add_argument("--steps", type=int, default=20)
38 | parser.add_argument("--fps", type=int)
39 | args = parser.parse_args()
40 |
41 | return args
42 |
43 |
44 |
45 | def main():
46 | args = parse_args()
47 |
48 | config = OmegaConf.load(args.config)
49 |
50 | if config.weight_dtype == "fp16":
51 | weight_dtype = torch.float16
52 | else:
53 | weight_dtype = torch.float32
54 |
55 | vae = AutoencoderKL.from_pretrained(
56 | config.pretrained_vae_path,
57 | ).to("cuda", dtype=weight_dtype)
58 |
59 | reference_unet = UNet2DConditionModel.from_pretrained(
60 | config.pretrained_base_model_path,
61 | subfolder="unet",
62 | ).to(dtype=weight_dtype, device="cuda")
63 |
64 | inference_config_path = config.inference_config
65 | infer_config = OmegaConf.load(inference_config_path)
66 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
67 | config.pretrained_base_model_path,
68 | # config.motion_module_path,
69 | "",
70 | subfolder="unet",
71 | unet_additional_kwargs={
72 | "use_motion_module": False,
73 | "unet_use_temporal_attention": False,
74 | },
75 | ).to(dtype=weight_dtype, device="cuda")
76 |
77 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
78 | dtype=weight_dtype, device="cuda"
79 | )
80 |
81 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
82 | config.image_encoder_path
83 | ).to(dtype=weight_dtype, device="cuda")
84 |
85 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
86 | scheduler = DDIMScheduler(**sched_kwargs)
87 |
88 |
89 | width, height = args.W, args.H
90 |
91 | # load pretrained weights
92 | denoising_unet.load_state_dict(
93 | torch.load(config.denoising_unet_path, map_location="cpu"),
94 | strict=False,
95 | )
96 | reference_unet.load_state_dict(
97 | torch.load(config.reference_unet_path, map_location="cpu"),
98 | )
99 | pose_guider.load_state_dict(
100 | torch.load(config.pose_guider_path, map_location="cpu"),
101 | )
102 |
103 | pipe = Pose2ImagePipeline(
104 | vae=vae,
105 | image_encoder=image_enc,
106 | reference_unet=reference_unet,
107 | denoising_unet=denoising_unet,
108 | pose_guider=pose_guider,
109 | scheduler=scheduler,
110 | )
111 |
112 | pipe = pipe.to("cuda", dtype=weight_dtype)
113 |
114 | date_str = datetime.now().strftime("%Y%m%d")
115 | time_str = datetime.now().strftime("%H%M")
116 |
117 | m1 = config.pose_guider_path.split('.')[0].split('/')[-1]
118 | save_dir_name = f"{time_str}-{m1}"
119 |
120 | save_dir = Path(f"./output/image-{date_str}/{save_dir_name}")
121 | save_dir.mkdir(exist_ok=True, parents=True)
122 |
123 | def handle_single(ref_image_path, pose_path,seed):
124 | generator = torch.manual_seed(seed)
125 | ref_name = Path(ref_image_path).stem
126 | # pose_name = Path(pose_image_path).stem.replace("_kps", "")
127 | pose_name = Path(pose_path).stem
128 |
129 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
130 | pose_image = Image.open(pose_path).convert("RGB")
131 |
132 | original_width, original_height = pose_image.size
133 |
134 | pose_transform = transforms.Compose(
135 | [transforms.Resize((height, width)), transforms.ToTensor()]
136 | )
137 |
138 | pose_image_tensor = pose_transform(pose_image)
139 | pose_image_tensor = pose_image_tensor.unsqueeze(0) # (1, c, h, w)
140 |
141 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
142 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
143 |
144 | image = pipe(
145 | ref_image_pil,
146 | pose_image,
147 | width,
148 | height,
149 | args.steps,
150 | args.cfg,
151 | generator=generator,
152 | ).images
153 |
154 | image = image.squeeze(2).squeeze(0) # (c, h, w)
155 | image = image.transpose(0, 1).transpose(1, 2) # (h w c)
156 | #image = (image + 1.0) / 2.0 # -1,1 -> 0,1
157 |
158 | image = (image * 255).numpy().astype(np.uint8)
159 | image = Image.fromarray(image, 'RGB')
160 | # image.save(os.path.join(save_dir, f"{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}.png"))
161 |
162 | image_grid = Image.new('RGB',(original_width*3,original_height))
163 | imgs = [ref_image_pil,pose_image,image]
164 | x_offset = 0
165 | for img in imgs:
166 | img = img.resize((original_width*2, original_height*2))
167 | img.save(os.path.join(save_dir, f"res_{ref_name}_{pose_name}_{args.cfg}_{seed}.jpg"))
168 | img = img.resize((original_width,original_height))
169 | image_grid.paste(img, (x_offset,0))
170 | x_offset += img.size[0]
171 | image_grid.save(os.path.join(save_dir, f"grid_{ref_name}_{pose_name}_{args.cfg}_{seed}.jpg"))
172 |
173 |
174 | for ref_image_path_dir in config["test_cases"].keys():
175 | if os.path.isdir(ref_image_path_dir):
176 | ref_image_paths = glob.glob(os.path.join(ref_image_path_dir, '*.jpg'))
177 | else:
178 | ref_image_paths = [ref_image_path_dir]
179 | for ref_image_path in ref_image_paths:
180 | for pose_image_path_dir in config["test_cases"][ref_image_path_dir]:
181 | if os.path.isdir(pose_image_path_dir):
182 | pose_image_paths = glob.glob(os.path.join(pose_image_path_dir, '*.jpg'))
183 | else:
184 | pose_image_paths = [pose_image_path_dir]
185 | for pose_image_path in pose_image_paths:
186 | for i in range(args.cnt):
187 | handle_single(ref_image_path, pose_image_path, args.seed + i)
188 |
189 |
190 | if __name__ == "__main__":
191 | main()
192 |
--------------------------------------------------------------------------------
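
test_stage_1.py above runs pose-to-image inference without the motion module. It walks config["test_cases"], a mapping from a reference image (or a directory of *.jpg files) to a list of pose images (or directories), and writes res_* and grid_* composites under ./output/image-<date>/. A hypothetical test_cases entry (paths are placeholders) would look like:

test_cases:
  ./data/refs/person1.jpg:
    - ./data/poses/           # a directory is expanded to its *.jpg files

and an invocation that spells out the defaults exposed by parse_args:

python test_stage_1.py --config ./configs/test_stage_1.yaml -W 768 -H 768 --cfg 7 --steps 20 --seed 42 --cnt 1
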
/test_stage_2.py:
--------------------------------------------------------------------------------
1 | import os,sys
2 | import argparse
3 | from datetime import datetime
4 | from pathlib import Path
5 | from typing import List
6 |
7 | import av
8 | import numpy as np
9 | import torch
10 | import torchvision
11 | from diffusers import AutoencoderKL, DDIMScheduler
12 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
13 | from einops import repeat
14 | from omegaconf import OmegaConf
15 | from PIL import Image
16 | from torchvision import transforms
17 | from transformers import CLIPVisionModelWithProjection
18 | import glob
19 | import torch.nn.functional as F
20 |
21 | from musepose.models.pose_guider import PoseGuider
22 | from musepose.models.unet_2d_condition import UNet2DConditionModel
23 | from musepose.models.unet_3d import UNet3DConditionModel
24 | from musepose.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
25 | from musepose.utils.util import get_fps, read_frames, save_videos_grid
26 |
27 |
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument("--config", type=str, default="./configs/test_stage_2.yaml")
32 | parser.add_argument("-W", type=int, default=768, help="Width")
33 | parser.add_argument("-H", type=int, default=768, help="Height")
34 | parser.add_argument("-L", type=int, default=300, help="video frame length")
35 | parser.add_argument("-S", type=int, default=48, help="video slice frame number")
36 | parser.add_argument("-O", type=int, default=4, help="video slice overlap frame number")
37 |
38 | parser.add_argument("--cfg", type=float, default=3.5, help="Classifier free guidance")
39 | parser.add_argument("--seed", type=int, default=99)
40 | parser.add_argument("--steps", type=int, default=20, help="DDIM sampling steps")
41 | parser.add_argument("--fps", type=int)
42 |
43 | parser.add_argument("--skip", type=int, default=1, help="frame sample rate = (skip+1)")
44 | args = parser.parse_args()
45 |
46 | print('Width:', args.W)
47 | print('Height:', args.H)
48 | print('Length:', args.L)
49 | print('Slice:', args.S)
50 | print('Overlap:', args.O)
51 | print('Classifier free guidance:', args.cfg)
52 | print('DDIM sampling steps :', args.steps)
53 | print("skip", args.skip)
54 |
55 | return args
56 |
57 |
58 | def scale_video(video,width,height):
59 | video_reshaped = video.view(-1, *video.shape[2:]) # [batch*frames, channels, height, width]
60 | scaled_video = F.interpolate(video_reshaped, size=(height, width), mode='bilinear', align_corners=False)
61 | scaled_video = scaled_video.view(*video.shape[:2], scaled_video.shape[1], height, width) # [batch, frames, channels, height, width]
62 |
63 | return scaled_video
64 |
65 |
66 | def main():
67 | args = parse_args()
68 |
69 | config = OmegaConf.load(args.config)
70 |
71 | if config.weight_dtype == "fp16":
72 | weight_dtype = torch.float16
73 | else:
74 | weight_dtype = torch.float32
75 |
76 | vae = AutoencoderKL.from_pretrained(
77 | config.pretrained_vae_path,
78 | ).to("cuda", dtype=weight_dtype)
79 |
80 | reference_unet = UNet2DConditionModel.from_pretrained(
81 | config.pretrained_base_model_path,
82 | subfolder="unet",
83 | ).to(dtype=weight_dtype, device="cuda")
84 |
85 | inference_config_path = config.inference_config
86 | infer_config = OmegaConf.load(inference_config_path)
87 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
88 | config.pretrained_base_model_path,
89 | config.motion_module_path,
90 | subfolder="unet",
91 | unet_additional_kwargs=infer_config.unet_additional_kwargs,
92 | ).to(dtype=weight_dtype, device="cuda")
93 |
94 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
95 | dtype=weight_dtype, device="cuda"
96 | )
97 |
98 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
99 | config.image_encoder_path
100 | ).to(dtype=weight_dtype, device="cuda")
101 |
102 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
103 | scheduler = DDIMScheduler(**sched_kwargs)
104 |
105 | generator = torch.manual_seed(args.seed)
106 |
107 | width, height = args.W, args.H
108 |
109 | # load pretrained weights
110 | denoising_unet.load_state_dict(
111 | torch.load(config.denoising_unet_path, map_location="cpu"),
112 | strict=False,
113 | )
114 | reference_unet.load_state_dict(
115 | torch.load(config.reference_unet_path, map_location="cpu"),
116 | )
117 | pose_guider.load_state_dict(
118 | torch.load(config.pose_guider_path, map_location="cpu"),
119 | )
120 |
121 | pipe = Pose2VideoPipeline(
122 | vae=vae,
123 | image_encoder=image_enc,
124 | reference_unet=reference_unet,
125 | denoising_unet=denoising_unet,
126 | pose_guider=pose_guider,
127 | scheduler=scheduler,
128 | )
129 | pipe = pipe.to("cuda", dtype=weight_dtype)
130 |
131 | date_str = datetime.now().strftime("%Y%m%d")
132 | time_str = datetime.now().strftime("%H%M")
133 |
134 | def handle_single(ref_image_path,pose_video_path):
135 | print ('handle===',ref_image_path, pose_video_path)
136 | ref_name = Path(ref_image_path).stem
137 | pose_name = Path(pose_video_path).stem.replace("_kps", "")
138 |
139 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
140 |
141 | pose_list = []
142 | pose_tensor_list = []
143 | pose_images = read_frames(pose_video_path)
144 | src_fps = get_fps(pose_video_path)
145 | print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
146 | L = min(args.L, len(pose_images))
147 | pose_transform = transforms.Compose(
148 | [transforms.Resize((height, width)), transforms.ToTensor()]
149 | )
150 | original_width,original_height = 0,0
151 |
152 | pose_images = pose_images[::args.skip+1]
153 | print("processing length:", len(pose_images))
154 | src_fps = src_fps // (args.skip + 1)
155 | print("fps", src_fps)
156 | L = L // ((args.skip + 1))
157 |
158 | for pose_image_pil in pose_images[: L]:
159 | pose_tensor_list.append(pose_transform(pose_image_pil))
160 | pose_list.append(pose_image_pil)
161 | original_width, original_height = pose_image_pil.size
162 | pose_image_pil = pose_image_pil.resize((width,height))
163 |
164 |         # repeat the last segment so its length fits the slice/overlap scheme
165 |         last_segment_frame_num = (L - args.S) % (args.S - args.O)
166 |         repeat_frame_num = (args.S - args.O - last_segment_frame_num) % (args.S - args.O)
167 |         for i in range(repeat_frame_num):
168 | pose_list.append(pose_list[-1])
169 | pose_tensor_list.append(pose_tensor_list[-1])
170 |
171 |
172 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
173 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
174 | ref_image_tensor = repeat(ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=L)
175 |
176 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
177 | pose_tensor = pose_tensor.transpose(0, 1)
178 | pose_tensor = pose_tensor.unsqueeze(0)
179 |
180 | video = pipe(
181 | ref_image_pil,
182 | pose_list,
183 | width,
184 | height,
185 | len(pose_list),
186 | args.steps,
187 | args.cfg,
188 | generator=generator,
189 | context_frames=args.S,
190 | context_stride=1,
191 | context_overlap=args.O,
192 | ).videos
193 |
194 |
195 | m1 = config.pose_guider_path.split('.')[0].split('/')[-1]
196 | m2 = config.motion_module_path.split('.')[0].split('/')[-1]
197 |
198 | save_dir_name = f"{time_str}-{args.cfg}-{m1}-{m2}"
199 | save_dir = Path(f"./output/video-{date_str}/{save_dir_name}")
200 | save_dir.mkdir(exist_ok=True, parents=True)
201 |
202 | result = scale_video(video[:,:,:L], original_width, original_height)
203 | save_videos_grid(
204 | result,
205 | f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.steps}_{args.skip}.mp4",
206 | n_rows=1,
207 | fps=src_fps if args.fps is None else args.fps,
208 | )
209 |
210 | video = torch.cat([ref_image_tensor, pose_tensor[:,:,:L], video[:,:,:L]], dim=0)
211 | video = scale_video(video, original_width, original_height)
212 | save_videos_grid(
213 | video,
214 | f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.steps}_{args.skip}_{m1}_{m2}.mp4",
215 | n_rows=3,
216 | fps=src_fps if args.fps is None else args.fps,
217 | )
218 |
219 | for ref_image_path_dir in config["test_cases"].keys():
220 | if os.path.isdir(ref_image_path_dir):
221 | ref_image_paths = glob.glob(os.path.join(ref_image_path_dir, '*.jpg'))
222 | else:
223 | ref_image_paths = [ref_image_path_dir]
224 | for ref_image_path in ref_image_paths:
225 | for pose_video_path_dir in config["test_cases"][ref_image_path_dir]:
226 | if os.path.isdir(pose_video_path_dir):
227 | pose_video_paths = glob.glob(os.path.join(pose_video_path_dir, '*.mp4'))
228 | else:
229 | pose_video_paths = [pose_video_path_dir]
230 | for pose_video_path in pose_video_paths:
231 | handle_single(ref_image_path, pose_video_path)
232 |
233 |
234 |
235 |
236 | if __name__ == "__main__":
237 | main()
238 |
--------------------------------------------------------------------------------
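
test_stage_2.py above generates videos with the motion module enabled. Pose frames are subsampled by a factor of (skip + 1), at most L of them are used, and the tail is padded by repeating the last pose so that the padded length minus S divides evenly into windows of S frames overlapping by O; with the defaults S=48 and O=4, a run left with 148 usable frames gives (148 - 48) % 44 = 12, so (44 - 12) % 44 = 32 repeated frames are appended. Results land under ./output/video-<date>/ as a scaled output video plus a three-row grid (reference, pose, result). An example invocation spelling out the defaults from parse_args:

python test_stage_2.py --config ./configs/test_stage_2.yaml -W 768 -H 768 -L 300 -S 48 -O 4 --cfg 3.5 --steps 20 --skip 1
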