├── .gitignore ├── LICENSE ├── README.md ├── assets ├── images │ └── ref.png └── videos │ └── dance.mp4 ├── configs ├── inference_v2.yaml ├── test_stage_1.yaml ├── test_stage_2.yaml ├── train_stage_1.yaml └── train_stage_2.yaml ├── downloading_weights.py ├── draw_dwpose.py ├── extract_dwpose_keypoints.py ├── extract_meta_info_multiple_dataset.py ├── musepose ├── __init__.py ├── dataset │ ├── dance_image.py │ └── dance_video.py ├── models │ ├── attention.py │ ├── motion_module.py │ ├── mutual_self_attention.py │ ├── pose_guider.py │ ├── resnet.py │ ├── transformer_2d.py │ ├── transformer_3d.py │ ├── unet_2d_blocks.py │ ├── unet_2d_condition.py │ ├── unet_3d.py │ └── unet_3d_blocks.py ├── pipelines │ ├── __init__.py │ ├── context.py │ ├── pipeline_pose2img.py │ ├── pipeline_pose2vid.py │ ├── pipeline_pose2vid_long.py │ └── utils.py └── utils │ └── util.py ├── pose ├── config │ ├── dwpose-l_384x288.py │ └── yolox_l_8xb8-300e_coco.py └── script │ ├── dwpose.py │ ├── tool.py │ ├── util.py │ └── wholebody.py ├── pose_align.py ├── pretrained_weights └── put_models_here.txt ├── requirements.txt ├── test_stage_1.py ├── test_stage_2.py ├── train_stage_1_multiGPU.py └── train_stage_2_multiGPU.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | .DS_Store 7 | pretrained_weights 8 | output 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2024 Tencent Music Entertainment Group 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | 24 | 25 | Other dependencies and licenses: 26 | 27 | 28 | Open Source Software Licensed under the MIT License: 29 | -------------------------------------------------------------------- 30 | 1. sd-vae-ft-mse 31 | Files:https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main 32 | License:MIT license 33 | For details:https://choosealicense.com/licenses/mit/ 34 | 35 | 36 | 37 | 38 | Open Source Software Licensed under the Apache License Version 2.0: 39 | -------------------------------------------------------------------- 40 | 1. DWpose 41 | Files:https://huggingface.co/yzd-v/DWPose/tree/main 42 | License:Apache-2.0 43 | For details:https://choosealicense.com/licenses/apache-2.0/ 44 | 45 | 2. 
Moore-AnimateAnyone 46 | Files:https://github.com/MooreThreads/Moore-AnimateAnyone 47 | License:Apache-2.0 48 | For details:https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/LICENSE 49 | 50 | Terms of the Apache License Version 2.0: 51 | -------------------------------------------------------------------- 52 | Apache License 53 | 54 | Version 2.0, January 2004 55 | 56 | http://www.apache.org/licenses/ 57 | 58 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 59 | 1. Definitions. 60 | 61 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 62 | 63 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 64 | 65 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 66 | 67 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 68 | 69 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 70 | 71 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 72 | 73 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 74 | 75 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 76 | 77 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
78 | 79 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 80 | 81 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 82 | 83 | 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 84 | 85 | 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 86 | 87 | You must give any other recipients of the Work or Derivative Works a copy of this License; and 88 | 89 | You must cause any modified files to carry prominent notices stating that You changed the files; and 90 | 91 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 92 | 93 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 
94 | 95 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 96 | 97 | 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 98 | 99 | 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 100 | 101 | 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 102 | 103 | 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 104 | 105 | 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 106 | 107 | END OF TERMS AND CONDITIONS 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MusePose 2 | 3 | MusePose: a Pose-Driven Image-to-Video Framework for Virtual Human Generation. 
4 | 5 | Zhengyan Tong, 6 | Chao Li, 7 | Zhaokang Chen, 8 | Bin Wu, 9 | Wenjiang Zhou 10 | (Corresponding Author, benbinwu@tencent.com) 11 | 12 | Lyra Lab, Tencent Music Entertainment 13 | 14 | 15 | **[github](https://github.com/TMElyralab/MusePose)** **[huggingface](https://huggingface.co/TMElyralab/MusePose)** **space (coming soon)** **Project (coming soon)** **Technical report (coming soon)** 16 | 17 | [MusePose](https://github.com/TMElyralab/MusePose) is an image-to-video generation framework for virtual humans under control signals such as pose. The currently released model is an implementation of [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) built by optimizing [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone). 18 | 19 | `MusePose` is the last building block of **the Muse open-source series**. Together with [MuseV](https://github.com/TMElyralab/MuseV) and [MuseTalk](https://github.com/TMElyralab/MuseTalk), we hope the community can join us and march towards the vision where a virtual human can be generated end-to-end with the native ability of full-body movement and interaction. Please stay tuned for our next milestone! 20 | 21 | We really appreciate [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) for their academic paper and [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone) for their code base, which have significantly expedited the development of the AIGC community and [MusePose](https://github.com/TMElyralab/MusePose). 22 | 23 | Update: 24 | 1. We have released the training code of MusePose! 25 | 26 | ## Overview 27 | [MusePose](https://github.com/TMElyralab/MusePose) is a diffusion-based and pose-guided virtual human video generation framework. 28 | Our main contributions can be summarized as follows: 29 | 1. The released model can generate dance videos of the human character in a reference image under a given pose sequence. The result quality exceeds almost all current open-source models on the same topic. 30 | 2. We release the `pose align` algorithm so that users can align arbitrary dance videos to arbitrary reference images, which **SIGNIFICANTLY** improves inference performance and enhances model usability. 31 | 3. We have fixed several important bugs and made some improvements based on the code of [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone). 32 | 33 | ## Demos
74 | 75 | 76 | ## News 77 | - [05/27/2024] Release `MusePose` and pretrained models. 78 | - [05/31/2024] Support [Comfyui-MusePose](https://github.com/TMElyralab/Comfyui-MusePose) 79 | - [06/14/2024] Bug fixed in `inference_v2.yaml`. 80 | - [03/04/2025] Release training code. 81 | 82 | ## Todo: 83 | - [x] release our trained models and inference code of MusePose. 84 | - [x] release pose align algorithm. 85 | - [x] Comfyui-MusePose 86 | - [x] training guidelines. 87 | - [ ] Huggingface Gradio demo. 88 | - [ ] an improved architecture and model (may take longer). 89 | 90 | 91 | # Getting Started 92 | We provide a detailed tutorial on the installation and basic usage of MusePose for new users: 93 | 94 | ## Installation 95 | To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below: 96 | 97 | ### Build environment 98 | 99 | We recommend Python >= 3.10 and CUDA 11.7. Then build the environment as follows: 100 | 101 | ```shell 102 | pip install -r requirements.txt 103 | ``` 104 | 105 | ### mmlab packages 106 | ```bash 107 | pip install --no-cache-dir -U openmim 108 | mim install mmengine 109 | mim install "mmcv>=2.0.1" 110 | mim install "mmdet>=3.1.0" 111 | mim install "mmpose>=1.1.0" 112 | ``` 113 | 114 | 115 | ### Download weights 116 | You can download weights manually as follows: 117 | 118 | 1. Download our trained [weights](https://huggingface.co/TMElyralab/MusePose). 119 | 120 | 2. Download the weights of other components: 121 | - [sd-image-variations-diffusers](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/unet) 122 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) 123 | - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main) 124 | - [yolox](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth) - make sure to rename it to `yolox_l_8x8_300e_coco.pth` 125 | - [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder) 126 | - [control_v11p_sd15_openpose](https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/blob/main/diffusion_pytorch_model.bin) (for training only) 127 | - [animatediff](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt) (for training only) 128 | 129 | Finally, these weights should be organized in `pretrained_weights` as follows: 130 | ``` 131 | ./pretrained_weights/ 132 | |-- MusePose 133 | | |-- denoising_unet.pth 134 | | |-- motion_module.pth 135 | | |-- pose_guider.pth 136 | | └── reference_unet.pth 137 | |-- dwpose 138 | | |-- dw-ll_ucoco_384.pth 139 | | └── yolox_l_8x8_300e_coco.pth 140 | |-- sd-image-variations-diffusers 141 | | └── unet 142 | | |-- config.json 143 | | └── diffusion_pytorch_model.bin 144 | |-- image_encoder 145 | | |-- config.json 146 | | └── pytorch_model.bin 147 | |-- sd-vae-ft-mse 148 | | |-- config.json 149 | | └── diffusion_pytorch_model.bin 150 | |-- control_v11p_sd15_openpose 151 | | └── diffusion_pytorch_model.bin 152 | └── animatediff 153 | └── mm_sd_v15_v2.ckpt 154 | ``` 155 | 156 | ## Quickstart 157 | ### Inference 158 | #### Preparation 159 | Prepare your reference images and dance videos in the folder ```./assets``` and organize them as in the example: 160 | ``` 161 | ./assets/ 162 | |-- images 163 | | └── ref.png 164 | └── videos 165 | └── dance.mp4 166 | ``` 167 | 168 | #### Pose Alignment 169 | Get the aligned dwpose of the reference image: 170 | ``` 171 |
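# arguments: --imgfn_refer is the reference image, --vidfn is the driving dance video (results are written to ./assets/poses)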
python pose_align.py --imgfn_refer ./assets/images/ref.png --vidfn ./assets/videos/dance.mp4 172 | ``` 173 | After this, you can see the pose-alignment results in ```./assets/poses```, where ```./assets/poses/align/img_ref_video_dance.mp4``` is the aligned dwpose and ```./assets/poses/align_demo/img_ref_video_dance.mp4``` is for debugging. 174 | 175 | #### Inferring MusePose 176 | Add the path of the reference image and the aligned dwpose to the test config file ```./configs/test_stage_2.yaml``` as in the example: 177 | ``` 178 | test_cases: 179 | "./assets/images/ref.png": 180 | - "./assets/poses/align/img_ref_video_dance.mp4" 181 | ``` 182 | 183 | Then, simply run 184 | ``` 185 | python test_stage_2.py --config ./configs/test_stage_2.yaml 186 | ``` 187 | ```./configs/test_stage_2.yaml``` is the path to the inference configuration file. 188 | 189 | Finally, you can see the output results in ```./output/```. 190 | 191 | ##### Reducing VRAM cost 192 | If you want to reduce the VRAM cost, you can set the width and height for inference. For example, 193 | ``` 194 | python test_stage_2.py --config ./configs/test_stage_2.yaml -W 512 -H 512 195 | ``` 196 | It will generate the video at 512 x 512 first, and then resize it back to the original size of the pose video. 197 | 198 | Currently, it takes 16GB VRAM to run on 512 x 512 x 48 and 28GB VRAM to run on 768 x 768 x 48. Note, however, that the inference resolution affects the final results (especially the face region). 199 | 200 | #### Face Enhancement 201 | 202 | If you want better consistency of the face region, you can use [FaceFusion](https://github.com/facefusion/facefusion): its `face-swap` function can swap the face from the reference image into the generated video. 203 | 204 | ### Training 205 | 1. Prepare 206 | First, put all your dance videos in a folder such as `./xxx`. 207 | Next, run `python extract_dwpose_keypoints.py --video_dir ./xxx`. The extracted dwpose keypoints will be saved in `./xxx_dwpose_keypoints`. 208 | Then, run `python draw_dwpose.py --video_dir ./xxx`. The rendered dwpose videos will be saved in `./xxx_dwpose_without_face` if `draw_face=False`, or in `./xxx_dwpose` if `draw_face=True`. 209 | Finally, run `python extract_meta_info_multiple_dataset.py --video_dirs ./xxx --dataset_name xxx`. 210 | You will get a JSON file, `./meta/xxx.json`, recording the paths of all data. 211 | 212 | 2. Configure accelerate and deepspeed 213 | `pip install accelerate` 214 | Use `accelerate config` to configure DeepSpeed according to your machine. 215 | We use ZeRO stage 2 without any offload, and our machine has 8x 80GB GPUs (see the example config sketch after this list). 216 | 217 | 3. Configure the yaml file for training 218 | stage 1 219 | `./configs/train_stage_1.yaml` 220 | stage 2 221 | `./configs/train_stage_2.yaml` 222 | 223 | 4. Launch Training 224 | stage 1 225 | `accelerate launch train_stage_1_multiGPU.py --config configs/train_stage_1.yaml` 226 | stage 2 227 | `accelerate launch train_stage_2_multiGPU.py --config configs/train_stage_2.yaml` 228 | 229 | 230 | 231 |
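For step 2, the snippet below is a minimal sketch of what `accelerate config` might write for the setup described above (DeepSpeed ZeRO stage 2, no offload, fp16, 8 GPUs). It is illustrative only; the exact keys produced by `accelerate config` depend on your `accelerate` and DeepSpeed versions, and values such as `num_processes` should match your own machine.

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none  # "without any offload"
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2                   # ZeRO stage 2
mixed_precision: fp16
num_machines: 1
num_processes: 8                  # one process per GPU
```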
232 | # Acknowledgement 233 | 1. We thank [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) for their technical report, and have referred extensively to [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone) and [diffusers](https://github.com/huggingface/diffusers). 234 | 1. We thank open-source components like [AnimateDiff](https://animatediff.github.io/), [dwpose](https://github.com/IDEA-Research/DWPose), [Stable Diffusion](https://github.com/CompVis/stable-diffusion), etc. 235 | 236 | Thanks for open-sourcing! 237 | 238 | # Limitations 239 | - Detail consistency: some details of the original character are not well preserved (e.g., the face region and complex clothing). 240 | - Noise and flickering: we observe noise and flickering in complex backgrounds. 241 | 242 | # Citation 243 | ```bib 244 | @article{musepose, 245 | title={MusePose: a Pose-Driven Image-to-Video Framework for Virtual Human Generation}, 246 | author={Tong, Zhengyan and Li, Chao and Chen, Zhaokang and Wu, Bin and Zhou, Wenjiang}, 247 | journal={arxiv}, 248 | year={2024} 249 | } 250 | ``` 251 | # Disclaimer/License 252 | 1. `code`: The code of MusePose is released under the MIT License. There is no limitation for either academic or commercial usage. 253 | 1. `model`: The trained models are available for non-commercial research purposes only. 254 | 1. `other opensource model`: Other open-source models used must comply with their licenses, such as `ft-mse-vae`, `dwpose`, etc. 255 | 1. The test data are collected from the internet and are available for non-commercial research purposes only. 256 | 1. `AIGC`: This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users. 257 | -------------------------------------------------------------------------------- /assets/images/ref.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MusePose/61c52bd937224a614b3951419b735b639397cb62/assets/images/ref.png -------------------------------------------------------------------------------- /assets/videos/dance.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MusePose/61c52bd937224a614b3951419b735b639397cb62/assets/videos/dance.mp4 -------------------------------------------------------------------------------- /configs/inference_v2.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | motion_module_resolutions: 7 | - 1 8 | - 2 9 | - 4 10 | - 8 11 | motion_module_mid_block: true 12 | motion_module_decoder_only: false 13 | motion_module_type: Vanilla 14 | motion_module_kwargs: 15 | num_attention_heads: 8 16 | num_transformer_block: 1 17 | attention_block_types: 18 | - Temporal_Self 19 | - Temporal_Self 20 | temporal_position_encoding: true 21 | temporal_position_encoding_max_len: 128 22 | temporal_attention_dim_div: 1 23 | 24 | noise_scheduler_kwargs: 25 | beta_start: 0.00085 26 | beta_end: 0.012 27 | beta_schedule: "scaled_linear" 28 | clip_sample: false 29 | steps_offset: 1 30 | ### Zero-SNR params 31 | prediction_type: "v_prediction" 32 | rescale_betas_zero_snr: True 33 | timestep_spacing: "trailing" 34 | 35 | sampler: DDIM -------------------------------------------------------------------------------- /configs/test_stage_1.yaml: -------------------------------------------------------------------------------- 1 |
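# Test config for stage 1 (pose-to-image); unlike ./configs/test_stage_2.yaml, it specifies no motion_module_path.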
pretrained_base_model_path: './pretrained_weights/sd-image-variations-diffusers' 2 | pretrained_vae_path: './pretrained_weights/sd-vae-ft-mse' 3 | image_encoder_path: './pretrained_weights/image_encoder' 4 | 5 | 6 | 7 | denoising_unet_path: "./pretrained_weights/MusePose/denoising_unet.pth" 8 | reference_unet_path: "./pretrained_weights/MusePose/reference_unet.pth" 9 | pose_guider_path: "./pretrained_weights/MusePose/pose_guider.pth" 10 | 11 | 12 | 13 | 14 | inference_config: "./configs/inference_v2.yaml" 15 | weight_dtype: 'fp16' 16 | 17 | 18 | 19 | test_cases: 20 | "./assets/images/ref.png": 21 | - "./assets/poses/align/img_ref_video_dance.mp4" 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /configs/test_stage_2.yaml: -------------------------------------------------------------------------------- 1 | pretrained_base_model_path: './pretrained_weights/sd-image-variations-diffusers' 2 | pretrained_vae_path: './pretrained_weights/sd-vae-ft-mse' 3 | image_encoder_path: './pretrained_weights/image_encoder' 4 | 5 | 6 | 7 | denoising_unet_path: "./pretrained_weights/MusePose/denoising_unet.pth" 8 | reference_unet_path: "./pretrained_weights/MusePose/reference_unet.pth" 9 | pose_guider_path: "./pretrained_weights/MusePose/pose_guider.pth" 10 | motion_module_path: "./pretrained_weights/MusePose/motion_module.pth" 11 | 12 | 13 | 14 | inference_config: "./configs/inference_v2.yaml" 15 | weight_dtype: 'fp16' 16 | 17 | 18 | 19 | test_cases: 20 | "./assets/images/ref.png": 21 | - "./assets/poses/align/img_ref_video_dance.mp4" 22 | -------------------------------------------------------------------------------- /configs/train_stage_1.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_bs: 8 3 | train_width: 768 4 | train_height: 768 5 | meta_paths: 6 | - "./meta/xxx.json" 7 | # Margin of frame indexes between ref and tgt images 8 | sample_margin: 128 9 | 10 | solver: 11 | gradient_accumulation_steps: 1 12 | mixed_precision: 'fp16' 13 | enable_xformers_memory_efficient_attention: True 14 | gradient_checkpointing: False 15 | max_train_steps: 400000 16 | max_grad_norm: 1.0 17 | # lr 18 | learning_rate: 1.0e-5 19 | scale_lr: False 20 | lr_warmup_steps: 1 21 | lr_scheduler: 'constant' 22 | 23 | # optimizer 24 | use_8bit_adam: False 25 | adam_beta1: 0.9 26 | adam_beta2: 0.999 27 | adam_weight_decay: 1.0e-2 28 | adam_epsilon: 1.0e-8 29 | 30 | val: 31 | validation_steps: 20000000000000000 32 | 33 | 34 | noise_scheduler_kwargs: 35 | num_train_timesteps: 1000 36 | beta_start: 0.00085 37 | beta_end: 0.012 38 | beta_schedule: "scaled_linear" 39 | steps_offset: 1 40 | clip_sample: false 41 | 42 | base_model_path: './pretrained_weights/sd-image-variations-diffusers' 43 | vae_model_path: './pretrained_weights/sd-vae-ft-mse' 44 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder' 45 | controlnet_openpose_path: './pretrained_weights/control_v11p_sd15_openpose/diffusion_pytorch_model.bin' 46 | 47 | weight_dtype: 'fp16' # [fp16, fp32] 48 | uncond_ratio: 0.1 49 | noise_offset: 0.05 50 | snr_gamma: 5.0 51 | # snr_gamma: 0 52 | enable_zero_snr: True 53 | pose_guider_pretrain: True 54 | 55 | seed: 12580 56 | resume_from_checkpoint: '' 57 | checkpointing_steps: 2250000000000000000000 58 | save_model_epoch_interval: 25 59 | exp_name: 'stage_1' 60 | output_dir: './exp_output' 61 | 62 | # load pretrained weights 63 | load_pth: True 64 | denoising_unet_path: 
"denoising_unet.pth" 65 | reference_unet_path: "reference_unet.pth" 66 | pose_guider_path: "pose_guider.pth" 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /configs/train_stage_2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_bs: 1 3 | train_width: 768 4 | train_height: 768 5 | meta_paths: 6 | - "./meta/xxx.json" 7 | sample_rate: 2 8 | n_sample_frames: 48 9 | 10 | solver: 11 | gradient_accumulation_steps: 1 12 | mixed_precision: 'fp16' 13 | enable_xformers_memory_efficient_attention: True 14 | gradient_checkpointing: True 15 | max_train_steps: 2000000 16 | max_grad_norm: 1.0 17 | # lr 18 | learning_rate: 1e-5 19 | scale_lr: False 20 | lr_warmup_steps: 1 21 | lr_scheduler: 'constant' 22 | 23 | # optimizer 24 | use_8bit_adam: False 25 | adam_beta1: 0.9 26 | adam_beta2: 0.999 27 | adam_weight_decay: 1.0e-2 28 | adam_epsilon: 1.0e-8 29 | 30 | val: 31 | validation_steps: 20000000000000 32 | 33 | 34 | noise_scheduler_kwargs: 35 | num_train_timesteps: 1000 36 | beta_start: 0.00085 37 | beta_end: 0.012 38 | beta_schedule: "scaled_linear" 39 | steps_offset: 1 40 | clip_sample: false 41 | 42 | base_model_path: './pretrained_weights/stable-diffusion-v1-5' 43 | vae_model_path: './pretrained_weights/sd-vae-ft-mse' 44 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder' 45 | mm_path: './pretrained_weights/mm_sd_v15_v2.ckpt' 46 | 47 | weight_dtype: 'fp16' # [fp16, fp32] 48 | uncond_ratio: 0.1 49 | noise_offset: 0.05 50 | snr_gamma: 5.0 51 | # snr_gamma: 0 52 | enable_zero_snr: True 53 | stage1_ckpt_dir: './exp_output/stage_1' 54 | stage1_ckpt_step: 30000 55 | 56 | 57 | 58 | seed: 12580 59 | resume_from_checkpoint: '' 60 | checkpointing_steps: 2000000000000000 61 | exp_name: 'stage_2' 62 | output_dir: './exp_output' 63 | 64 | 65 | -------------------------------------------------------------------------------- /downloading_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | from tqdm import tqdm 4 | 5 | os.makedirs('pretrained_weights', exist_ok=True) 6 | 7 | urls = ['https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth', 8 | 'https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.pth', 9 | 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/denoising_unet.pth', 10 | 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/motion_module.pth', 11 | 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/pose_guider.pth', 12 | 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/reference_unet.pth', 13 | 'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/unet/diffusion_pytorch_model.bin', 14 | 'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/image_encoder/pytorch_model.bin', 15 | 'https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin' 16 | ] 17 | 18 | paths = ['dwpose', 'dwpose', 'MusePose', 'MusePose', 'MusePose', 'MusePose', 'sd-image-variations-diffusers/unet', 'image_encoder', 'sd-vae-ft-mse'] 19 | 20 | for path in paths: 21 | os.makedirs(f'pretrained_weights/{path}', exist_ok=True) 22 | 23 | # saving weights 24 | for url, path in tqdm(zip(urls, paths)): 25 | filename = wget.download(url, f'pretrained_weights/{path}') 26 | 27 | config_urls = 
['https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/unet/config.json', 28 | 'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/image_encoder/config.json', 29 | 'https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/config.json'] 30 | 31 | config_paths = ['sd-image-variations-diffusers/unet', 'image_encoder', 'sd-vae-ft-mse'] 32 | 33 | # saving config files 34 | for url, path in tqdm(zip(config_urls, config_paths)): 35 | filename = wget.download(url, f'pretrained_weights/{path}') 36 | 37 | # renaming model name as given in readme 38 | os.rename('pretrained_weights/dwpose/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth', 'pretrained_weights/dwpose/yolox_l_8x8_300e_coco.pth') 39 | -------------------------------------------------------------------------------- /draw_dwpose.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import argparse 4 | import numpy as np 5 | from tqdm import tqdm 6 | from PIL import Image 7 | 8 | from pose.script.tool import save_videos_from_pil 9 | from pose.script.dwpose import draw_pose 10 | 11 | 12 | 13 | def draw_dwpose(video_path, pose_path, out_path, draw_face): 14 | 15 | # capture video info 16 | cap = cv2.VideoCapture(video_path) 17 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 18 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 19 | fps = cap.get(cv2.CAP_PROP_FPS) 20 | fps = int(np.around(fps)) 21 | # fps = get_fps(video_path) 22 | cap.release() 23 | 24 | # render resolution, short edge = 1024 25 | k = float(1024) / min(width, height) 26 | h_render = int(k*height//2 * 2) 27 | w_render = int(k*width//2 * 2) 28 | 29 | # save resolution, short edge = 768 30 | k = float(768) / min(width, height) 31 | h_save = int(k*height//2 * 2) 32 | w_save = int(k*width//2 * 2) 33 | 34 | poses = np.load(pose_path, allow_pickle=True) 35 | poses = poses.tolist() 36 | 37 | frames = [] 38 | for pose in tqdm(poses): 39 | detected_map = draw_pose(pose, h_render, w_render, draw_face) 40 | detected_map = cv2.resize(detected_map, (w_save, h_save), interpolation=cv2.INTER_AREA) 41 | # cv2.imshow('', detected_map) 42 | # cv2.waitKey(0) 43 | detected_map = cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB) 44 | detected_map = Image.fromarray(detected_map) 45 | frames.append(detected_map) 46 | 47 | save_videos_from_pil(frames, out_path, fps) 48 | 49 | 50 | 51 | if __name__ == "__main__": 52 | 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--video_dir", type=str, default="./UBC_fashion/test", help='dance video dir') 55 | parser.add_argument("--pose_dir", type=str, default=None, help='auto makedir') 56 | parser.add_argument("--save_dir", type=str, default=None, help='auto makedir') 57 | parser.add_argument("--draw_face", type=bool, default=False, help='whether draw face or not') 58 | args = parser.parse_args() 59 | 60 | 61 | # video dir 62 | video_dir = args.video_dir 63 | 64 | # pose dir 65 | if args.pose_dir is None: 66 | pose_dir = args.video_dir + "_dwpose_keypoints" 67 | else: 68 | pose_dir = args.pose_dir 69 | 70 | # save dir 71 | if args.save_dir is None: 72 | if args.draw_face == True: 73 | save_dir = args.video_dir + "_dwpose" 74 | else: 75 | save_dir = args.video_dir + "_dwpose_without_face" 76 | else: 77 | save_dir = args.save_dir 78 | if not os.path.exists(save_dir): 79 | os.makedirs(save_dir) 80 | 81 | 82 | # collect all video_folder paths 83 | video_mp4_paths = set() 84 | for root, dirs, files in os.walk(args.video_dir): 85 | for 
name in files: 86 | if name.endswith(".mp4"): 87 | video_mp4_paths.add(os.path.join(root, name)) 88 | video_mp4_paths = list(video_mp4_paths) 89 | # random.shuffle(video_mp4_paths) 90 | video_mp4_paths.sort() 91 | print("Num of videos:", len(video_mp4_paths)) 92 | 93 | 94 | # draw dwpose 95 | for i in range(len(video_mp4_paths)): 96 | video_path = video_mp4_paths[i] 97 | video_name = os.path.relpath(video_path, video_dir) 98 | base_name = os.path.splitext(video_name)[0] 99 | 100 | pose_path = os.path.join(pose_dir, base_name + '.npy') 101 | if not os.path.exists(pose_path): 102 | print('no keypoint file:', pose_path) 103 | 104 | out_path = os.path.join(save_dir, base_name + '.mp4') 105 | if os.path.exists(out_path): 106 | print('already have rendered pose:', out_path) 107 | continue 108 | 109 | draw_dwpose(video_path, pose_path, out_path, args.draw_face) 110 | print(f"Process {i+1}/{len(video_mp4_paths)} video") 111 | 112 | print('all done!') -------------------------------------------------------------------------------- /extract_dwpose_keypoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 4 | import argparse 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from pose.script.dwpose import DWposeDetector 9 | from pose.script.tool import read_frames 10 | 11 | 12 | 13 | 14 | def process_single_video(video_path, detector, root_dir, save_dir): 15 | # print(video_path) 16 | video_name = os.path.relpath(video_path, root_dir) 17 | base_name=os.path.splitext(video_name)[0] 18 | out_path = os.path.join(save_dir, base_name + '.npy') 19 | if os.path.exists(out_path): 20 | return 21 | 22 | frames = read_frames(video_path) 23 | keypoints = [] 24 | for frame in tqdm(frames): 25 | keypoint = detector(frame) 26 | keypoints.append(keypoint) 27 | 28 | result = np.array(keypoints) 29 | np.save(out_path, result) 30 | 31 | 32 | 33 | def process_batch_videos(video_list, detector, root_dir, save_dir): 34 | for i, video_path in enumerate(video_list): 35 | process_single_video(video_path, detector, root_dir, save_dir) 36 | print(f"Process {i+1}/{len(video_list)} video") 37 | 38 | 39 | 40 | if __name__ == "__main__": 41 | 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--video_dir", type=str, default="./UBC_fashion/test") 44 | parser.add_argument("--save_dir", type=str, default=None) 45 | parser.add_argument("--yolox_config", type=str, default="./pose/config/yolox_l_8xb8-300e_coco.py") 46 | parser.add_argument("--dwpose_config", type=str, default="./pose/config/dwpose-l_384x288.py") 47 | parser.add_argument("--yolox_ckpt", type=str, default="./pretrained_weights/dwpose/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth") 48 | parser.add_argument("--dwpose_ckpt", type=str, default="./pretrained_weights/dwpose/dw-ll_ucoco_384.pth") 49 | args = parser.parse_args() 50 | 51 | # make save dir 52 | if args.save_dir is None: 53 | save_dir = args.video_dir + "_dwpose_keypoints" 54 | else: 55 | save_dir = args.save_dir 56 | if not os.path.exists(save_dir): 57 | os.makedirs(save_dir) 58 | 59 | # collect all video_folder paths 60 | video_mp4_paths = set() 61 | for root, dirs, files in os.walk(args.video_dir): 62 | for name in files: 63 | if name.endswith(".mp4"): 64 | video_mp4_paths.add(os.path.join(root, name)) 65 | video_mp4_paths = list(video_mp4_paths) 66 | video_mp4_paths.sort() 67 | print("Num of videos:", len(video_mp4_paths)) 68 | 69 | device = 
torch.device('cuda' if torch.cuda.is_available() else 'cpu') 70 | detector = DWposeDetector( 71 | det_config = args.yolox_config, 72 | det_ckpt = args.yolox_ckpt, 73 | pose_config = args.dwpose_config, 74 | pose_ckpt = args.dwpose_ckpt, 75 | keypoints_only=True 76 | ) 77 | detector = detector.to(device) 78 | 79 | process_batch_videos(video_mp4_paths, detector, args.video_dir, save_dir) 80 | print('all done!') 81 | -------------------------------------------------------------------------------- /extract_meta_info_multiple_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | # python tools/extract_meta_info.py --video_dirs /path/to/video_dir --dataset_name fashion 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--video_dirs", type=str, default=[ 8 | "./UBC_fashion/test", 9 | # "path_of_dataset_1", 10 | # "path_of_dataset_2", 11 | ]) 12 | parser.add_argument("--save_dir", type=str, default="./meta") 13 | parser.add_argument("--dataset_name", type=str, default="my_dataset") 14 | parser.add_argument("--meta_info_name", type=str, default=None) 15 | parser.add_argument("--draw_face", type=bool, default=False) 16 | args = parser.parse_args() 17 | 18 | if args.meta_info_name is None: 19 | args.meta_info_name = args.dataset_name 20 | 21 | # collect all video_folder paths 22 | meta_infos = [] 23 | 24 | for dataset_path in args.video_dirs: 25 | video_mp4_paths = set() 26 | if args.draw_face == True: 27 | pose_dir = dataset_path + "_dwpose" 28 | else: 29 | pose_dir = dataset_path + "_dwpose_without_face" 30 | 31 | for root, dirs, files in os.walk(dataset_path): 32 | for name in files: 33 | if name.endswith(".mp4"): 34 | video_mp4_paths.add(os.path.join(root, name)) 35 | 36 | video_mp4_paths = list(video_mp4_paths) 37 | print(dataset_path) 38 | print("video num:", len(video_mp4_paths)) 39 | 40 | 41 | for video_mp4_path in video_mp4_paths: 42 | relative_video_name = os.path.relpath(video_mp4_path, dataset_path) 43 | kps_path = os.path.join(pose_dir, relative_video_name) 44 | meta_infos.append({"video_path": video_mp4_path, "kps_path": kps_path}) 45 | 46 | save_path = os.path.join(args.save_dir, f"{args.meta_info_name}_meta.json") 47 | json.dump(meta_infos, open(save_path, "w")) 48 | print('data dumped') 49 | print('total pieces of data', len(meta_infos)) 50 | 51 | 52 | import cv2 53 | # check data (cannot read or damaged) 54 | for index, video_meta in enumerate(meta_infos): 55 | 56 | video_path = video_meta[ "video_path"] 57 | kps_path = video_meta[ "kps_path"] 58 | 59 | video = cv2.VideoCapture(video_path) 60 | kps = cv2.VideoCapture(kps_path) 61 | frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 62 | frame_count_2 = int(kps.get(cv2.CAP_PROP_FRAME_COUNT)) 63 | assert(frame_count) == (frame_count_2), f"{frame_count} != {frame_count_2} in {video_path}" 64 | 65 | if (index+1) % 100 == 0: print(index+1) 66 | 67 | print('data checked, no problem') 68 | -------------------------------------------------------------------------------- /musepose/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MusePose/61c52bd937224a614b3951419b735b639397cb62/musepose/__init__.py -------------------------------------------------------------------------------- /musepose/dataset/dance_image.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | import 
torch 5 | import torchvision.transforms as transforms 6 | from decord import VideoReader 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from transformers import CLIPImageProcessor 10 | 11 | 12 | class HumanDanceDataset(Dataset): 13 | def __init__( 14 | self, 15 | img_size, 16 | img_scale=(1.0, 1.0), 17 | img_ratio=(0.9, 1.0), 18 | drop_ratio=0.1, 19 | data_meta_paths=["./data/fahsion_meta.json"], 20 | sample_margin=30, 21 | ): 22 | super().__init__() 23 | 24 | self.img_size = img_size 25 | self.img_scale = img_scale 26 | self.img_ratio = img_ratio 27 | self.sample_margin = sample_margin 28 | 29 | # ----- 30 | # vid_meta format: 31 | # [{'video_path': , 'kps_path': , 'other':}, 32 | # {'video_path': , 'kps_path': , 'other':}] 33 | # ----- 34 | vid_meta = [] 35 | for data_meta_path in data_meta_paths: 36 | vid_meta.extend(json.load(open(data_meta_path, "r"))) 37 | self.vid_meta = vid_meta 38 | 39 | self.clip_image_processor = CLIPImageProcessor() 40 | 41 | self.transform = transforms.Compose( 42 | [ 43 | # transforms.RandomResizedCrop( 44 | # self.img_size, 45 | # scale=self.img_scale, 46 | # ratio=self.img_ratio, 47 | # interpolation=transforms.InterpolationMode.BILINEAR, 48 | # ), 49 | transforms.Resize( 50 | self.img_size, 51 | ), 52 | transforms.ToTensor(), 53 | transforms.Normalize([0.5], [0.5]), 54 | ] 55 | ) 56 | 57 | self.cond_transform = transforms.Compose( 58 | [ 59 | # transforms.RandomResizedCrop( 60 | # self.img_size, 61 | # scale=self.img_scale, 62 | # ratio=self.img_ratio, 63 | # interpolation=transforms.InterpolationMode.BILINEAR, 64 | # ), 65 | transforms.Resize( 66 | self.img_size, 67 | ), 68 | transforms.ToTensor(), 69 | ] 70 | ) 71 | 72 | self.drop_ratio = drop_ratio 73 | 74 | def augmentation(self, image, transform, state=None): 75 | if state is not None: 76 | torch.set_rng_state(state) 77 | return transform(image) 78 | 79 | def __getitem__(self, index): 80 | video_meta = self.vid_meta[index] 81 | video_path = video_meta["video_path"] 82 | kps_path = video_meta["kps_path"] 83 | 84 | video_reader = VideoReader(video_path) 85 | kps_reader = VideoReader(kps_path) 86 | 87 | assert len(video_reader) == len( 88 | kps_reader 89 | ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}" 90 | 91 | video_length = len(video_reader) 92 | 93 | margin = min(self.sample_margin, video_length) 94 | 95 | ref_img_idx = random.randint(0, video_length - 1) 96 | if ref_img_idx + margin < video_length: 97 | tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1) 98 | elif ref_img_idx - margin > 0: 99 | tgt_img_idx = random.randint(0, ref_img_idx - margin) 100 | else: 101 | tgt_img_idx = random.randint(0, video_length - 1) 102 | 103 | ref_img = video_reader[ref_img_idx] 104 | ref_img_pil = Image.fromarray(ref_img.asnumpy()) 105 | tgt_img = video_reader[tgt_img_idx] 106 | tgt_img_pil = Image.fromarray(tgt_img.asnumpy()) 107 | 108 | tgt_pose = kps_reader[tgt_img_idx] 109 | tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy()) 110 | 111 | state = torch.get_rng_state() 112 | tgt_img = self.augmentation(tgt_img_pil, self.transform, state) 113 | tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state) 114 | ref_img_vae = self.augmentation(ref_img_pil, self.transform, state) 115 | clip_image = self.clip_image_processor( 116 | images=ref_img_pil, return_tensors="pt" 117 | ).pixel_values[0] 118 | 119 | sample = dict( 120 | video_dir=video_path, 121 | img=tgt_img, 122 | tgt_pose=tgt_pose_img, 123 | ref_img=ref_img_vae, 124 | 
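# clip_images: CLIP-preprocessed reference image (pixel values intended for the CLIP image encoder)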
clip_images=clip_image, 125 | ) 126 | 127 | return sample 128 | 129 | def __len__(self): 130 | return len(self.vid_meta) 131 | -------------------------------------------------------------------------------- /musepose/dataset/dance_video.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from typing import List 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torchvision.transforms as transforms 9 | from decord import VideoReader 10 | from PIL import Image 11 | from torch.utils.data import Dataset 12 | from transformers import CLIPImageProcessor 13 | 14 | 15 | class HumanDanceVideoDataset(Dataset): 16 | def __init__( 17 | self, 18 | sample_rate, 19 | n_sample_frames, 20 | width, 21 | height, 22 | img_scale=(1.0, 1.0), 23 | img_ratio=(0.9, 1.0), 24 | drop_ratio=0.1, 25 | data_meta_paths=["./data/fashion_meta.json"], 26 | ): 27 | super().__init__() 28 | self.sample_rate = sample_rate 29 | self.n_sample_frames = n_sample_frames 30 | self.width = width 31 | self.height = height 32 | self.img_scale = img_scale 33 | self.img_ratio = img_ratio 34 | 35 | vid_meta = [] 36 | for data_meta_path in data_meta_paths: 37 | vid_meta.extend(json.load(open(data_meta_path, "r"))) 38 | self.vid_meta = vid_meta 39 | 40 | self.clip_image_processor = CLIPImageProcessor() 41 | 42 | self.pixel_transform = transforms.Compose( 43 | [ 44 | # transforms.RandomResizedCrop( 45 | # (height, width), 46 | # scale=self.img_scale, 47 | # ratio=self.img_ratio, 48 | # interpolation=transforms.InterpolationMode.BILINEAR, 49 | # ), 50 | transforms.Resize( 51 | (height, width), 52 | ), 53 | transforms.ToTensor(), 54 | transforms.Normalize([0.5], [0.5]), 55 | ] 56 | ) 57 | 58 | self.cond_transform = transforms.Compose( 59 | [ 60 | # transforms.RandomResizedCrop( 61 | # (height, width), 62 | # scale=self.img_scale, 63 | # ratio=self.img_ratio, 64 | # interpolation=transforms.InterpolationMode.BILINEAR, 65 | # ), 66 | transforms.Resize( 67 | (height, width), 68 | ), 69 | transforms.ToTensor(), 70 | ] 71 | ) 72 | 73 | self.drop_ratio = drop_ratio 74 | 75 | def augmentation(self, images, transform, state=None): 76 | if state is not None: 77 | torch.set_rng_state(state) 78 | if isinstance(images, List): 79 | transformed_images = [transform(img) for img in images] 80 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) 81 | else: 82 | ret_tensor = transform(images) # (c, h, w) 83 | return ret_tensor 84 | 85 | def __getitem__(self, index): 86 | video_meta = self.vid_meta[index] 87 | video_path = video_meta["video_path"] 88 | kps_path = video_meta["kps_path"] 89 | 90 | video_reader = VideoReader(video_path) 91 | kps_reader = VideoReader(kps_path) 92 | 93 | assert len(video_reader) == len( 94 | kps_reader 95 | ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}" 96 | 97 | video_length = len(video_reader) 98 | video_fps = video_reader.get_avg_fps() 99 | # print("fps", video_fps) 100 | if video_fps > 30: # 30-60 101 | sample_rate = self.sample_rate*2 102 | else: 103 | sample_rate = self.sample_rate 104 | 105 | 106 | clip_length = min( 107 | video_length, (self.n_sample_frames - 1) * sample_rate + 1 108 | ) 109 | start_idx = random.randint(0, video_length - clip_length) 110 | batch_index = np.linspace( 111 | start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int 112 | ).tolist() 113 | 114 | # read frames and kps 115 | vid_pil_image_list = [] 116 | pose_pil_image_list = [] 117 | for index in 
batch_index: 118 | img = video_reader[index] 119 | vid_pil_image_list.append(Image.fromarray(img.asnumpy())) 120 | img = kps_reader[index] 121 | pose_pil_image_list.append(Image.fromarray(img.asnumpy())) 122 | 123 | ref_img_idx = random.randint(0, video_length - 1) 124 | ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy()) 125 | 126 | # transform 127 | state = torch.get_rng_state() 128 | pixel_values_vid = self.augmentation( 129 | vid_pil_image_list, self.pixel_transform, state 130 | ) 131 | pixel_values_pose = self.augmentation( 132 | pose_pil_image_list, self.cond_transform, state 133 | ) 134 | pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state) 135 | clip_ref_img = self.clip_image_processor( 136 | images=ref_img, return_tensors="pt" 137 | ).pixel_values[0] 138 | 139 | sample = dict( 140 | video_dir=video_path, 141 | pixel_values_vid=pixel_values_vid, 142 | pixel_values_pose=pixel_values_pose, 143 | pixel_values_ref_img=pixel_values_ref_img, 144 | clip_ref_img=clip_ref_img, 145 | ) 146 | 147 | return sample 148 | 149 | def __len__(self): 150 | return len(self.vid_meta) 151 | -------------------------------------------------------------------------------- /musepose/models/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from typing import Any, Dict, Optional 4 | 5 | import torch 6 | from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward 7 | from diffusers.models.embeddings import SinusoidalPositionalEmbedding 8 | from einops import rearrange 9 | from torch import nn 10 | 11 | 12 | class BasicTransformerBlock(nn.Module): 13 | r""" 14 | A basic Transformer block. 15 | 16 | Parameters: 17 | dim (`int`): The number of channels in the input and output. 18 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 19 | attention_head_dim (`int`): The number of channels in each head. 20 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 21 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 22 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 23 | num_embeds_ada_norm (: 24 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 25 | attention_bias (: 26 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 27 | only_cross_attention (`bool`, *optional*): 28 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 29 | double_self_attention (`bool`, *optional*): 30 | Whether to use two self-attention layers. In this case no cross attention layers are used. 31 | upcast_attention (`bool`, *optional*): 32 | Whether to upcast the attention computation to float32. This is useful for mixed precision training. 33 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`): 34 | Whether to use learnable elementwise affine parameters for normalization. 35 | norm_type (`str`, *optional*, defaults to `"layer_norm"`): 36 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. 37 | final_dropout (`bool` *optional*, defaults to False): 38 | Whether to apply a final dropout after the last feed-forward layer. 
39 | attention_type (`str`, *optional*, defaults to `"default"`): 40 | The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. 41 | positional_embeddings (`str`, *optional*, defaults to `None`): 42 | The type of positional embeddings to apply to. 43 | num_positional_embeddings (`int`, *optional*, defaults to `None`): 44 | The maximum number of positional embeddings to apply. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | dim: int, 50 | num_attention_heads: int, 51 | attention_head_dim: int, 52 | dropout=0.0, 53 | cross_attention_dim: Optional[int] = None, 54 | activation_fn: str = "geglu", 55 | num_embeds_ada_norm: Optional[int] = None, 56 | attention_bias: bool = False, 57 | only_cross_attention: bool = False, 58 | double_self_attention: bool = False, 59 | upcast_attention: bool = False, 60 | norm_elementwise_affine: bool = True, 61 | norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' 62 | norm_eps: float = 1e-5, 63 | final_dropout: bool = False, 64 | attention_type: str = "default", 65 | positional_embeddings: Optional[str] = None, 66 | num_positional_embeddings: Optional[int] = None, 67 | ): 68 | super().__init__() 69 | self.only_cross_attention = only_cross_attention 70 | 71 | self.use_ada_layer_norm_zero = ( 72 | num_embeds_ada_norm is not None 73 | ) and norm_type == "ada_norm_zero" 74 | self.use_ada_layer_norm = ( 75 | num_embeds_ada_norm is not None 76 | ) and norm_type == "ada_norm" 77 | self.use_ada_layer_norm_single = norm_type == "ada_norm_single" 78 | self.use_layer_norm = norm_type == "layer_norm" 79 | 80 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 81 | raise ValueError( 82 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 83 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 84 | ) 85 | 86 | if positional_embeddings and (num_positional_embeddings is None): 87 | raise ValueError( 88 | "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." 89 | ) 90 | 91 | if positional_embeddings == "sinusoidal": 92 | self.pos_embed = SinusoidalPositionalEmbedding( 93 | dim, max_seq_length=num_positional_embeddings 94 | ) 95 | else: 96 | self.pos_embed = None 97 | 98 | # Define 3 blocks. Each block has its own normalization layer. 99 | # 1. Self-Attn 100 | if self.use_ada_layer_norm: 101 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 102 | elif self.use_ada_layer_norm_zero: 103 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 104 | else: 105 | self.norm1 = nn.LayerNorm( 106 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps 107 | ) 108 | 109 | self.attn1 = Attention( 110 | query_dim=dim, 111 | heads=num_attention_heads, 112 | dim_head=attention_head_dim, 113 | dropout=dropout, 114 | bias=attention_bias, 115 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 116 | upcast_attention=upcast_attention, 117 | ) 118 | 119 | # 2. Cross-Attn 120 | if cross_attention_dim is not None or double_self_attention: 121 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 122 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 123 | # the second cross attention block. 
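# norm2 is AdaLayerNorm when timestep conditioning ("ada_norm") is enabled, otherwise a plain LayerNorm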
124 | self.norm2 = ( 125 | AdaLayerNorm(dim, num_embeds_ada_norm) 126 | if self.use_ada_layer_norm 127 | else nn.LayerNorm( 128 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps 129 | ) 130 | ) 131 | self.attn2 = Attention( 132 | query_dim=dim, 133 | cross_attention_dim=cross_attention_dim 134 | if not double_self_attention 135 | else None, 136 | heads=num_attention_heads, 137 | dim_head=attention_head_dim, 138 | dropout=dropout, 139 | bias=attention_bias, 140 | upcast_attention=upcast_attention, 141 | ) # is self-attn if encoder_hidden_states is none 142 | else: 143 | self.norm2 = None 144 | self.attn2 = None 145 | 146 | # 3. Feed-forward 147 | if not self.use_ada_layer_norm_single: 148 | self.norm3 = nn.LayerNorm( 149 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps 150 | ) 151 | 152 | self.ff = FeedForward( 153 | dim, 154 | dropout=dropout, 155 | activation_fn=activation_fn, 156 | final_dropout=final_dropout, 157 | ) 158 | 159 | # 4. Fuser 160 | if attention_type == "gated" or attention_type == "gated-text-image": 161 | self.fuser = GatedSelfAttentionDense( 162 | dim, cross_attention_dim, num_attention_heads, attention_head_dim 163 | ) 164 | 165 | # 5. Scale-shift for PixArt-Alpha. 166 | if self.use_ada_layer_norm_single: 167 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) 168 | 169 | # let chunk size default to None 170 | self._chunk_size = None 171 | self._chunk_dim = 0 172 | 173 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): 174 | # Sets chunk feed-forward 175 | self._chunk_size = chunk_size 176 | self._chunk_dim = dim 177 | 178 | def forward( 179 | self, 180 | hidden_states: torch.FloatTensor, 181 | attention_mask: Optional[torch.FloatTensor] = None, 182 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 183 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 184 | timestep: Optional[torch.LongTensor] = None, 185 | cross_attention_kwargs: Dict[str, Any] = None, 186 | class_labels: Optional[torch.LongTensor] = None, 187 | ) -> torch.FloatTensor: 188 | # Notice that normalization is always applied before the real computation in the following blocks. 189 | # 0. Self-Attention 190 | batch_size = hidden_states.shape[0] 191 | 192 | if self.use_ada_layer_norm: 193 | norm_hidden_states = self.norm1(hidden_states, timestep) 194 | elif self.use_ada_layer_norm_zero: 195 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 196 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 197 | ) 198 | elif self.use_layer_norm: 199 | norm_hidden_states = self.norm1(hidden_states) 200 | elif self.use_ada_layer_norm_single: 201 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 202 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 203 | ).chunk(6, dim=1) 204 | norm_hidden_states = self.norm1(hidden_states) 205 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 206 | norm_hidden_states = norm_hidden_states.squeeze(1) 207 | else: 208 | raise ValueError("Incorrect norm used") 209 | 210 | if self.pos_embed is not None: 211 | norm_hidden_states = self.pos_embed(norm_hidden_states) 212 | 213 | # 1. Retrieve lora scale. 214 | lora_scale = ( 215 | cross_attention_kwargs.get("scale", 1.0) 216 | if cross_attention_kwargs is not None 217 | else 1.0 218 | ) 219 | 220 | # 2. 
Prepare GLIGEN inputs 221 | cross_attention_kwargs = ( 222 | cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 223 | ) 224 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 225 | 226 | attn_output = self.attn1( 227 | norm_hidden_states, 228 | encoder_hidden_states=encoder_hidden_states 229 | if self.only_cross_attention 230 | else None, 231 | attention_mask=attention_mask, 232 | **cross_attention_kwargs, 233 | ) 234 | if self.use_ada_layer_norm_zero: 235 | attn_output = gate_msa.unsqueeze(1) * attn_output 236 | elif self.use_ada_layer_norm_single: 237 | attn_output = gate_msa * attn_output 238 | 239 | hidden_states = attn_output + hidden_states 240 | if hidden_states.ndim == 4: 241 | hidden_states = hidden_states.squeeze(1) 242 | 243 | # 2.5 GLIGEN Control 244 | if gligen_kwargs is not None: 245 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 246 | 247 | # 3. Cross-Attention 248 | if self.attn2 is not None: 249 | if self.use_ada_layer_norm: 250 | norm_hidden_states = self.norm2(hidden_states, timestep) 251 | elif self.use_ada_layer_norm_zero or self.use_layer_norm: 252 | norm_hidden_states = self.norm2(hidden_states) 253 | elif self.use_ada_layer_norm_single: 254 | # For PixArt norm2 isn't applied here: 255 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 256 | norm_hidden_states = hidden_states 257 | else: 258 | raise ValueError("Incorrect norm") 259 | 260 | if self.pos_embed is not None and self.use_ada_layer_norm_single is False: 261 | norm_hidden_states = self.pos_embed(norm_hidden_states) 262 | 263 | attn_output = self.attn2( 264 | norm_hidden_states, 265 | encoder_hidden_states=encoder_hidden_states, 266 | attention_mask=encoder_attention_mask, 267 | **cross_attention_kwargs, 268 | ) 269 | hidden_states = attn_output + hidden_states 270 | 271 | # 4. 
Feed-forward 272 | if not self.use_ada_layer_norm_single: 273 | norm_hidden_states = self.norm3(hidden_states) 274 | 275 | if self.use_ada_layer_norm_zero: 276 | norm_hidden_states = ( 277 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 278 | ) 279 | 280 | if self.use_ada_layer_norm_single: 281 | norm_hidden_states = self.norm2(hidden_states) 282 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 283 | 284 | ff_output = self.ff(norm_hidden_states, scale=lora_scale) 285 | 286 | if self.use_ada_layer_norm_zero: 287 | ff_output = gate_mlp.unsqueeze(1) * ff_output 288 | elif self.use_ada_layer_norm_single: 289 | ff_output = gate_mlp * ff_output 290 | 291 | hidden_states = ff_output + hidden_states 292 | if hidden_states.ndim == 4: 293 | hidden_states = hidden_states.squeeze(1) 294 | 295 | return hidden_states 296 | 297 | 298 | class TemporalBasicTransformerBlock(nn.Module): 299 | def __init__( 300 | self, 301 | dim: int, 302 | num_attention_heads: int, 303 | attention_head_dim: int, 304 | dropout=0.0, 305 | cross_attention_dim: Optional[int] = None, 306 | activation_fn: str = "geglu", 307 | num_embeds_ada_norm: Optional[int] = None, 308 | attention_bias: bool = False, 309 | only_cross_attention: bool = False, 310 | upcast_attention: bool = False, 311 | unet_use_cross_frame_attention=None, 312 | unet_use_temporal_attention=None, 313 | ): 314 | super().__init__() 315 | self.only_cross_attention = only_cross_attention 316 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 317 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention 318 | self.unet_use_temporal_attention = unet_use_temporal_attention 319 | 320 | # SC-Attn 321 | self.attn1 = Attention( 322 | query_dim=dim, 323 | heads=num_attention_heads, 324 | dim_head=attention_head_dim, 325 | dropout=dropout, 326 | bias=attention_bias, 327 | upcast_attention=upcast_attention, 328 | ) 329 | self.norm1 = ( 330 | AdaLayerNorm(dim, num_embeds_ada_norm) 331 | if self.use_ada_layer_norm 332 | else nn.LayerNorm(dim) 333 | ) 334 | 335 | # Cross-Attn 336 | if cross_attention_dim is not None: 337 | self.attn2 = Attention( 338 | query_dim=dim, 339 | cross_attention_dim=cross_attention_dim, 340 | heads=num_attention_heads, 341 | dim_head=attention_head_dim, 342 | dropout=dropout, 343 | bias=attention_bias, 344 | upcast_attention=upcast_attention, 345 | ) 346 | else: 347 | self.attn2 = None 348 | 349 | if cross_attention_dim is not None: 350 | self.norm2 = ( 351 | AdaLayerNorm(dim, num_embeds_ada_norm) 352 | if self.use_ada_layer_norm 353 | else nn.LayerNorm(dim) 354 | ) 355 | else: 356 | self.norm2 = None 357 | 358 | # Feed-forward 359 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 360 | self.norm3 = nn.LayerNorm(dim) 361 | self.use_ada_layer_norm_zero = False 362 | 363 | # Temp-Attn 364 | assert unet_use_temporal_attention is not None 365 | if unet_use_temporal_attention: 366 | self.attn_temp = Attention( 367 | query_dim=dim, 368 | heads=num_attention_heads, 369 | dim_head=attention_head_dim, 370 | dropout=dropout, 371 | bias=attention_bias, 372 | upcast_attention=upcast_attention, 373 | ) 374 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 375 | self.norm_temp = ( 376 | AdaLayerNorm(dim, num_embeds_ada_norm) 377 | if self.use_ada_layer_norm 378 | else nn.LayerNorm(dim) 379 | ) 380 | 381 | def forward( 382 | self, 383 | hidden_states, 384 | encoder_hidden_states=None, 385 | timestep=None, 386 | attention_mask=None, 387 | video_length=None, 388 | ): 389 | 
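# (Descriptive summary of the steps below, added for clarity:
#  1. self-attention via attn1 on the (b*f, h*w, c) tokens, forwarding video_length
#     when cross-frame attention is enabled;
#  2. optional cross-attention against encoder_hidden_states when attn2 is configured;
#  3. feed-forward over norm3;
#  4. temporal attention: tokens are rearranged "(b f) d c -> (b d) f c" so attention
#     runs across the f frames, then rearranged back.)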
norm_hidden_states = ( 390 | self.norm1(hidden_states, timestep) 391 | if self.use_ada_layer_norm 392 | else self.norm1(hidden_states) 393 | ) 394 | 395 | if self.unet_use_cross_frame_attention: 396 | hidden_states = ( 397 | self.attn1( 398 | norm_hidden_states, 399 | attention_mask=attention_mask, 400 | video_length=video_length, 401 | ) 402 | + hidden_states 403 | ) 404 | else: 405 | hidden_states = ( 406 | self.attn1(norm_hidden_states, attention_mask=attention_mask) 407 | + hidden_states 408 | ) 409 | 410 | if self.attn2 is not None: 411 | # Cross-Attention 412 | norm_hidden_states = ( 413 | self.norm2(hidden_states, timestep) 414 | if self.use_ada_layer_norm 415 | else self.norm2(hidden_states) 416 | ) 417 | hidden_states = ( 418 | self.attn2( 419 | norm_hidden_states, 420 | encoder_hidden_states=encoder_hidden_states, 421 | attention_mask=attention_mask, 422 | ) 423 | + hidden_states 424 | ) 425 | 426 | # Feed-forward 427 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 428 | 429 | # Temporal-Attention 430 | if self.unet_use_temporal_attention: 431 | d = hidden_states.shape[1] 432 | hidden_states = rearrange( 433 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 434 | ) 435 | norm_hidden_states = ( 436 | self.norm_temp(hidden_states, timestep) 437 | if self.use_ada_layer_norm 438 | else self.norm_temp(hidden_states) 439 | ) 440 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 441 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 442 | 443 | return hidden_states 444 | -------------------------------------------------------------------------------- /musepose/models/motion_module.py: -------------------------------------------------------------------------------- 1 | # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py 2 | import math 3 | from dataclasses import dataclass 4 | from typing import Callable, Optional 5 | 6 | import torch 7 | from diffusers.models.attention import FeedForward 8 | from diffusers.models.attention_processor import Attention, AttnProcessor 9 | from diffusers.utils import BaseOutput 10 | from diffusers.utils.import_utils import is_xformers_available 11 | from einops import rearrange, repeat 12 | from torch import nn 13 | 14 | 15 | def zero_module(module): 16 | # Zero out the parameters of a module and return it. 
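# (Added note: a zeroed Linear/Conv layer outputs zeros for any input, so branches wrapped
# with zero_module, e.g. the temporal transformer's proj_out below or PoseGuider.conv_out,
# contribute nothing at initialization and the network starts from the behaviour of the
# pretrained base model.)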
17 | for p in module.parameters(): 18 | p.detach().zero_() 19 | return module 20 | 21 | 22 | @dataclass 23 | class TemporalTransformer3DModelOutput(BaseOutput): 24 | sample: torch.FloatTensor 25 | 26 | 27 | if is_xformers_available(): 28 | import xformers 29 | import xformers.ops 30 | else: 31 | xformers = None 32 | 33 | 34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict): 35 | if motion_module_type == "Vanilla": 36 | return VanillaTemporalModule( 37 | in_channels=in_channels, 38 | **motion_module_kwargs, 39 | ) 40 | else: 41 | raise ValueError 42 | 43 | 44 | class VanillaTemporalModule(nn.Module): 45 | def __init__( 46 | self, 47 | in_channels, 48 | num_attention_heads=8, 49 | num_transformer_block=2, 50 | attention_block_types=("Temporal_Self", "Temporal_Self"), 51 | cross_frame_attention_mode=None, 52 | temporal_position_encoding=False, 53 | temporal_position_encoding_max_len=24, 54 | temporal_attention_dim_div=1, 55 | zero_initialize=True, 56 | ): 57 | super().__init__() 58 | 59 | self.temporal_transformer = TemporalTransformer3DModel( 60 | in_channels=in_channels, 61 | num_attention_heads=num_attention_heads, 62 | attention_head_dim=in_channels 63 | // num_attention_heads 64 | // temporal_attention_dim_div, 65 | num_layers=num_transformer_block, 66 | attention_block_types=attention_block_types, 67 | cross_frame_attention_mode=cross_frame_attention_mode, 68 | temporal_position_encoding=temporal_position_encoding, 69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 70 | ) 71 | 72 | if zero_initialize: 73 | self.temporal_transformer.proj_out = zero_module( 74 | self.temporal_transformer.proj_out 75 | ) 76 | 77 | def forward( 78 | self, 79 | input_tensor, 80 | temb, 81 | encoder_hidden_states, 82 | attention_mask=None, 83 | anchor_frame_idx=None, 84 | ): 85 | hidden_states = input_tensor 86 | hidden_states = self.temporal_transformer( 87 | hidden_states, encoder_hidden_states, attention_mask 88 | ) 89 | 90 | output = hidden_states 91 | return output 92 | 93 | 94 | class TemporalTransformer3DModel(nn.Module): 95 | def __init__( 96 | self, 97 | in_channels, 98 | num_attention_heads, 99 | attention_head_dim, 100 | num_layers, 101 | attention_block_types=( 102 | "Temporal_Self", 103 | "Temporal_Self", 104 | ), 105 | dropout=0.0, 106 | norm_num_groups=32, 107 | cross_attention_dim=768, 108 | activation_fn="geglu", 109 | attention_bias=False, 110 | upcast_attention=False, 111 | cross_frame_attention_mode=None, 112 | temporal_position_encoding=False, 113 | temporal_position_encoding_max_len=24, 114 | ): 115 | super().__init__() 116 | 117 | inner_dim = num_attention_heads * attention_head_dim 118 | 119 | self.norm = torch.nn.GroupNorm( 120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 121 | ) 122 | self.proj_in = nn.Linear(in_channels, inner_dim) 123 | 124 | self.transformer_blocks = nn.ModuleList( 125 | [ 126 | TemporalTransformerBlock( 127 | dim=inner_dim, 128 | num_attention_heads=num_attention_heads, 129 | attention_head_dim=attention_head_dim, 130 | attention_block_types=attention_block_types, 131 | dropout=dropout, 132 | norm_num_groups=norm_num_groups, 133 | cross_attention_dim=cross_attention_dim, 134 | activation_fn=activation_fn, 135 | attention_bias=attention_bias, 136 | upcast_attention=upcast_attention, 137 | cross_frame_attention_mode=cross_frame_attention_mode, 138 | temporal_position_encoding=temporal_position_encoding, 139 | 
temporal_position_encoding_max_len=temporal_position_encoding_max_len, 140 | ) 141 | for d in range(num_layers) 142 | ] 143 | ) 144 | self.proj_out = nn.Linear(inner_dim, in_channels) 145 | 146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 147 | assert ( 148 | hidden_states.dim() == 5 149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 150 | video_length = hidden_states.shape[2] 151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 152 | 153 | batch, channel, height, weight = hidden_states.shape 154 | residual = hidden_states 155 | 156 | hidden_states = self.norm(hidden_states) 157 | inner_dim = hidden_states.shape[1] 158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 159 | batch, height * weight, inner_dim 160 | ) 161 | hidden_states = self.proj_in(hidden_states) 162 | 163 | # Transformer Blocks 164 | for block in self.transformer_blocks: 165 | hidden_states = block( 166 | hidden_states, 167 | encoder_hidden_states=encoder_hidden_states, 168 | video_length=video_length, 169 | ) 170 | 171 | # output 172 | hidden_states = self.proj_out(hidden_states) 173 | hidden_states = ( 174 | hidden_states.reshape(batch, height, weight, inner_dim) 175 | .permute(0, 3, 1, 2) 176 | .contiguous() 177 | ) 178 | 179 | output = hidden_states + residual 180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 181 | 182 | return output 183 | 184 | 185 | class TemporalTransformerBlock(nn.Module): 186 | def __init__( 187 | self, 188 | dim, 189 | num_attention_heads, 190 | attention_head_dim, 191 | attention_block_types=( 192 | "Temporal_Self", 193 | "Temporal_Self", 194 | ), 195 | dropout=0.0, 196 | norm_num_groups=32, 197 | cross_attention_dim=768, 198 | activation_fn="geglu", 199 | attention_bias=False, 200 | upcast_attention=False, 201 | cross_frame_attention_mode=None, 202 | temporal_position_encoding=False, 203 | temporal_position_encoding_max_len=24, 204 | ): 205 | super().__init__() 206 | 207 | attention_blocks = [] 208 | norms = [] 209 | 210 | for block_name in attention_block_types: 211 | attention_blocks.append( 212 | VersatileAttention( 213 | attention_mode=block_name.split("_")[0], 214 | cross_attention_dim=cross_attention_dim 215 | if block_name.endswith("_Cross") 216 | else None, 217 | query_dim=dim, 218 | heads=num_attention_heads, 219 | dim_head=attention_head_dim, 220 | dropout=dropout, 221 | bias=attention_bias, 222 | upcast_attention=upcast_attention, 223 | cross_frame_attention_mode=cross_frame_attention_mode, 224 | temporal_position_encoding=temporal_position_encoding, 225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 226 | ) 227 | ) 228 | norms.append(nn.LayerNorm(dim)) 229 | 230 | self.attention_blocks = nn.ModuleList(attention_blocks) 231 | self.norms = nn.ModuleList(norms) 232 | 233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 234 | self.ff_norm = nn.LayerNorm(dim) 235 | 236 | def forward( 237 | self, 238 | hidden_states, 239 | encoder_hidden_states=None, 240 | attention_mask=None, 241 | video_length=None, 242 | ): 243 | for attention_block, norm in zip(self.attention_blocks, self.norms): 244 | norm_hidden_states = norm(hidden_states) 245 | hidden_states = ( 246 | attention_block( 247 | norm_hidden_states, 248 | encoder_hidden_states=encoder_hidden_states 249 | if attention_block.is_cross_attention 250 | else None, 251 | video_length=video_length, 252 | ) 253 | + hidden_states 254 | ) 255 | 256 | 
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 257 | 258 | output = hidden_states 259 | return output 260 | 261 | 262 | class PositionalEncoding(nn.Module): 263 | def __init__(self, d_model, dropout=0.0, max_len=24): 264 | super().__init__() 265 | self.dropout = nn.Dropout(p=dropout) 266 | position = torch.arange(max_len).unsqueeze(1) 267 | div_term = torch.exp( 268 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) 269 | ) 270 | pe = torch.zeros(1, max_len, d_model) 271 | pe[0, :, 0::2] = torch.sin(position * div_term) 272 | pe[0, :, 1::2] = torch.cos(position * div_term) 273 | self.register_buffer("pe", pe) 274 | 275 | def forward(self, x): 276 | x = x + self.pe[:, : x.size(1)] 277 | return self.dropout(x) 278 | 279 | 280 | class VersatileAttention(Attention): 281 | def __init__( 282 | self, 283 | attention_mode=None, 284 | cross_frame_attention_mode=None, 285 | temporal_position_encoding=False, 286 | temporal_position_encoding_max_len=24, 287 | *args, 288 | **kwargs, 289 | ): 290 | super().__init__(*args, **kwargs) 291 | assert attention_mode == "Temporal" 292 | 293 | self.attention_mode = attention_mode 294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 295 | 296 | self.pos_encoder = ( 297 | PositionalEncoding( 298 | kwargs["query_dim"], 299 | dropout=0.0, 300 | max_len=temporal_position_encoding_max_len, 301 | ) 302 | if (temporal_position_encoding and attention_mode == "Temporal") 303 | else None 304 | ) 305 | 306 | def extra_repr(self): 307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 308 | 309 | def set_use_memory_efficient_attention_xformers( 310 | self, 311 | use_memory_efficient_attention_xformers: bool, 312 | attention_op: Optional[Callable] = None, 313 | ): 314 | if use_memory_efficient_attention_xformers: 315 | if not is_xformers_available(): 316 | raise ModuleNotFoundError( 317 | ( 318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 319 | " xformers" 320 | ), 321 | name="xformers", 322 | ) 323 | elif not torch.cuda.is_available(): 324 | raise ValueError( 325 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" 326 | " only available for GPU " 327 | ) 328 | else: 329 | try: 330 | # Make sure we can run the memory efficient attention 331 | _ = xformers.ops.memory_efficient_attention( 332 | torch.randn((1, 2, 40), device="cuda"), 333 | torch.randn((1, 2, 40), device="cuda"), 334 | torch.randn((1, 2, 40), device="cuda"), 335 | ) 336 | except Exception as e: 337 | raise e 338 | 339 | # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13. 340 | # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13. 341 | # You don't need XFormersAttnProcessor here. 
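# (Added note: the try/except above only verifies that xformers' memory-efficient attention
# can actually run; the plain AttnProcessor below is installed in both branches.)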
342 | # processor = XFormersAttnProcessor( 343 | # attention_op=attention_op, 344 | # ) 345 | processor = AttnProcessor() 346 | else: 347 | processor = AttnProcessor() 348 | 349 | self.set_processor(processor) 350 | 351 | def forward( 352 | self, 353 | hidden_states, 354 | encoder_hidden_states=None, 355 | attention_mask=None, 356 | video_length=None, 357 | **cross_attention_kwargs, 358 | ): 359 | if self.attention_mode == "Temporal": 360 | d = hidden_states.shape[1] # d means HxW 361 | hidden_states = rearrange( 362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 363 | ) 364 | 365 | if self.pos_encoder is not None: 366 | hidden_states = self.pos_encoder(hidden_states) 367 | 368 | encoder_hidden_states = ( 369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) 370 | if encoder_hidden_states is not None 371 | else encoder_hidden_states 372 | ) 373 | 374 | else: 375 | raise NotImplementedError 376 | 377 | hidden_states = self.processor( 378 | self, 379 | hidden_states, 380 | encoder_hidden_states=encoder_hidden_states, 381 | attention_mask=attention_mask, 382 | **cross_attention_kwargs, 383 | ) 384 | 385 | if self.attention_mode == "Temporal": 386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 387 | 388 | return hidden_states 389 | -------------------------------------------------------------------------------- /musepose/models/mutual_self_attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | from einops import rearrange 6 | 7 | from musepose.models.attention import TemporalBasicTransformerBlock 8 | 9 | from .attention import BasicTransformerBlock 10 | 11 | 12 | def torch_dfs(model: torch.nn.Module): 13 | result = [model] 14 | for child in model.children(): 15 | result += torch_dfs(child) 16 | return result 17 | 18 | 19 | class ReferenceAttentionControl: 20 | def __init__( 21 | self, 22 | unet, 23 | mode="write", 24 | do_classifier_free_guidance=False, 25 | attention_auto_machine_weight=float("inf"), 26 | gn_auto_machine_weight=1.0, 27 | style_fidelity=1.0, 28 | reference_attn=True, 29 | reference_adain=False, 30 | fusion_blocks="midup", 31 | batch_size=1, 32 | ) -> None: 33 | # 10. 
Modify self attention and group norm 34 | self.unet = unet 35 | assert mode in ["read", "write"] 36 | assert fusion_blocks in ["midup", "full"] 37 | self.reference_attn = reference_attn 38 | self.reference_adain = reference_adain 39 | self.fusion_blocks = fusion_blocks 40 | self.register_reference_hooks( 41 | mode, 42 | do_classifier_free_guidance, 43 | attention_auto_machine_weight, 44 | gn_auto_machine_weight, 45 | style_fidelity, 46 | reference_attn, 47 | reference_adain, 48 | fusion_blocks, 49 | batch_size=batch_size, 50 | ) 51 | 52 | def register_reference_hooks( 53 | self, 54 | mode, 55 | do_classifier_free_guidance, 56 | attention_auto_machine_weight, 57 | gn_auto_machine_weight, 58 | style_fidelity, 59 | reference_attn, 60 | reference_adain, 61 | dtype=torch.float16, 62 | batch_size=1, 63 | num_images_per_prompt=1, 64 | device=torch.device("cpu"), 65 | fusion_blocks="midup", 66 | ): 67 | MODE = mode 68 | do_classifier_free_guidance = do_classifier_free_guidance 69 | attention_auto_machine_weight = attention_auto_machine_weight 70 | gn_auto_machine_weight = gn_auto_machine_weight 71 | style_fidelity = style_fidelity 72 | reference_attn = reference_attn 73 | reference_adain = reference_adain 74 | fusion_blocks = fusion_blocks 75 | num_images_per_prompt = num_images_per_prompt 76 | dtype = dtype 77 | if do_classifier_free_guidance: 78 | uc_mask = ( 79 | torch.Tensor( 80 | [1] * batch_size * num_images_per_prompt * 16 81 | + [0] * batch_size * num_images_per_prompt * 16 82 | ) 83 | .to(device) 84 | .bool() 85 | ) 86 | else: 87 | uc_mask = ( 88 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2) 89 | .to(device) 90 | .bool() 91 | ) 92 | 93 | def hacked_basic_transformer_inner_forward( 94 | self, 95 | hidden_states: torch.FloatTensor, 96 | attention_mask: Optional[torch.FloatTensor] = None, 97 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 98 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 99 | timestep: Optional[torch.LongTensor] = None, 100 | cross_attention_kwargs: Dict[str, Any] = None, 101 | class_labels: Optional[torch.LongTensor] = None, 102 | video_length=None, 103 | ): 104 | if self.use_ada_layer_norm: # False 105 | norm_hidden_states = self.norm1(hidden_states, timestep) 106 | elif self.use_ada_layer_norm_zero: 107 | ( 108 | norm_hidden_states, 109 | gate_msa, 110 | shift_mlp, 111 | scale_mlp, 112 | gate_mlp, 113 | ) = self.norm1( 114 | hidden_states, 115 | timestep, 116 | class_labels, 117 | hidden_dtype=hidden_states.dtype, 118 | ) 119 | else: 120 | norm_hidden_states = self.norm1(hidden_states) 121 | 122 | # 1. 
Self-Attention 123 | # self.only_cross_attention = False 124 | cross_attention_kwargs = ( 125 | cross_attention_kwargs if cross_attention_kwargs is not None else {} 126 | ) 127 | if self.only_cross_attention: 128 | attn_output = self.attn1( 129 | norm_hidden_states, 130 | encoder_hidden_states=encoder_hidden_states 131 | if self.only_cross_attention 132 | else None, 133 | attention_mask=attention_mask, 134 | **cross_attention_kwargs, 135 | ) 136 | else: 137 | if MODE == "write": 138 | self.bank.append(norm_hidden_states.clone()) 139 | attn_output = self.attn1( 140 | norm_hidden_states, 141 | encoder_hidden_states=encoder_hidden_states 142 | if self.only_cross_attention 143 | else None, 144 | attention_mask=attention_mask, 145 | **cross_attention_kwargs, 146 | ) 147 | if MODE == "read": 148 | bank_fea = [ 149 | rearrange( 150 | d.unsqueeze(1).repeat(1, video_length, 1, 1), 151 | "b t l c -> (b t) l c", 152 | ) 153 | for d in self.bank 154 | ] 155 | modify_norm_hidden_states = torch.cat( 156 | [norm_hidden_states] + bank_fea, dim=1 157 | ) 158 | hidden_states_uc = ( 159 | self.attn1( 160 | norm_hidden_states, 161 | encoder_hidden_states=modify_norm_hidden_states, 162 | attention_mask=attention_mask, 163 | ) 164 | + hidden_states 165 | ) 166 | if do_classifier_free_guidance: 167 | hidden_states_c = hidden_states_uc.clone() 168 | _uc_mask = uc_mask.clone() 169 | if hidden_states.shape[0] != _uc_mask.shape[0]: 170 | _uc_mask = ( 171 | torch.Tensor( 172 | [1] * (hidden_states.shape[0] // 2) 173 | + [0] * (hidden_states.shape[0] // 2) 174 | ) 175 | .to(device) 176 | .bool() 177 | ) 178 | hidden_states_c[_uc_mask] = ( 179 | self.attn1( 180 | norm_hidden_states[_uc_mask], 181 | encoder_hidden_states=norm_hidden_states[_uc_mask], 182 | attention_mask=attention_mask, 183 | ) 184 | + hidden_states[_uc_mask] 185 | ) 186 | hidden_states = hidden_states_c.clone() 187 | else: 188 | hidden_states = hidden_states_uc 189 | 190 | # self.bank.clear() 191 | if self.attn2 is not None: 192 | # Cross-Attention 193 | norm_hidden_states = ( 194 | self.norm2(hidden_states, timestep) 195 | if self.use_ada_layer_norm 196 | else self.norm2(hidden_states) 197 | ) 198 | hidden_states = ( 199 | self.attn2( 200 | norm_hidden_states, 201 | encoder_hidden_states=encoder_hidden_states, 202 | attention_mask=attention_mask, 203 | ) 204 | + hidden_states 205 | ) 206 | 207 | # Feed-forward 208 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 209 | 210 | # Temporal-Attention 211 | if self.unet_use_temporal_attention: 212 | d = hidden_states.shape[1] 213 | hidden_states = rearrange( 214 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 215 | ) 216 | norm_hidden_states = ( 217 | self.norm_temp(hidden_states, timestep) 218 | if self.use_ada_layer_norm 219 | else self.norm_temp(hidden_states) 220 | ) 221 | hidden_states = ( 222 | self.attn_temp(norm_hidden_states) + hidden_states 223 | ) 224 | hidden_states = rearrange( 225 | hidden_states, "(b d) f c -> (b f) d c", d=d 226 | ) 227 | 228 | return hidden_states 229 | 230 | if self.use_ada_layer_norm_zero: 231 | attn_output = gate_msa.unsqueeze(1) * attn_output 232 | hidden_states = attn_output + hidden_states 233 | 234 | if self.attn2 is not None: 235 | norm_hidden_states = ( 236 | self.norm2(hidden_states, timestep) 237 | if self.use_ada_layer_norm 238 | else self.norm2(hidden_states) 239 | ) 240 | 241 | # 2. 
Cross-Attention 242 | attn_output = self.attn2( 243 | norm_hidden_states, 244 | encoder_hidden_states=encoder_hidden_states, 245 | attention_mask=encoder_attention_mask, 246 | **cross_attention_kwargs, 247 | ) 248 | hidden_states = attn_output + hidden_states 249 | 250 | # 3. Feed-forward 251 | norm_hidden_states = self.norm3(hidden_states) 252 | 253 | if self.use_ada_layer_norm_zero: 254 | norm_hidden_states = ( 255 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 256 | ) 257 | 258 | ff_output = self.ff(norm_hidden_states) 259 | 260 | if self.use_ada_layer_norm_zero: 261 | ff_output = gate_mlp.unsqueeze(1) * ff_output 262 | 263 | hidden_states = ff_output + hidden_states 264 | 265 | return hidden_states 266 | 267 | if self.reference_attn: 268 | if self.fusion_blocks == "midup": 269 | attn_modules = [ 270 | module 271 | for module in ( 272 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 273 | ) 274 | if isinstance(module, BasicTransformerBlock) 275 | or isinstance(module, TemporalBasicTransformerBlock) 276 | ] 277 | elif self.fusion_blocks == "full": 278 | attn_modules = [ 279 | module 280 | for module in torch_dfs(self.unet) 281 | if isinstance(module, BasicTransformerBlock) 282 | or isinstance(module, TemporalBasicTransformerBlock) 283 | ] 284 | attn_modules = sorted( 285 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 286 | ) 287 | 288 | for i, module in enumerate(attn_modules): 289 | module._original_inner_forward = module.forward 290 | if isinstance(module, BasicTransformerBlock): 291 | module.forward = hacked_basic_transformer_inner_forward.__get__( 292 | module, BasicTransformerBlock 293 | ) 294 | if isinstance(module, TemporalBasicTransformerBlock): 295 | module.forward = hacked_basic_transformer_inner_forward.__get__( 296 | module, TemporalBasicTransformerBlock 297 | ) 298 | 299 | module.bank = [] 300 | module.attn_weight = float(i) / float(len(attn_modules)) 301 | 302 | def update(self, writer, dtype=torch.float16): 303 | if self.reference_attn: 304 | if self.fusion_blocks == "midup": 305 | reader_attn_modules = [ 306 | module 307 | for module in ( 308 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 309 | ) 310 | if isinstance(module, TemporalBasicTransformerBlock) 311 | ] 312 | writer_attn_modules = [ 313 | module 314 | for module in ( 315 | torch_dfs(writer.unet.mid_block) 316 | + torch_dfs(writer.unet.up_blocks) 317 | ) 318 | if isinstance(module, BasicTransformerBlock) 319 | ] 320 | elif self.fusion_blocks == "full": 321 | reader_attn_modules = [ 322 | module 323 | for module in torch_dfs(self.unet) 324 | if isinstance(module, TemporalBasicTransformerBlock) 325 | ] 326 | writer_attn_modules = [ 327 | module 328 | for module in torch_dfs(writer.unet) 329 | if isinstance(module, BasicTransformerBlock) 330 | ] 331 | reader_attn_modules = sorted( 332 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 333 | ) 334 | writer_attn_modules = sorted( 335 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 336 | ) 337 | for r, w in zip(reader_attn_modules, writer_attn_modules): 338 | r.bank = [v.clone().to(dtype) for v in w.bank] 339 | # w.bank.clear() 340 | 341 | def clear(self): 342 | if self.reference_attn: 343 | if self.fusion_blocks == "midup": 344 | reader_attn_modules = [ 345 | module 346 | for module in ( 347 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 348 | ) 349 | if isinstance(module, BasicTransformerBlock) 350 | or isinstance(module, 
TemporalBasicTransformerBlock) 351 | ] 352 | elif self.fusion_blocks == "full": 353 | reader_attn_modules = [ 354 | module 355 | for module in torch_dfs(self.unet) 356 | if isinstance(module, BasicTransformerBlock) 357 | or isinstance(module, TemporalBasicTransformerBlock) 358 | ] 359 | reader_attn_modules = sorted( 360 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 361 | ) 362 | for r in reader_attn_modules: 363 | r.bank.clear() 364 | -------------------------------------------------------------------------------- /musepose/models/pose_guider.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | from diffusers.models.modeling_utils import ModelMixin 7 | 8 | from musepose.models.motion_module import zero_module 9 | from musepose.models.resnet import InflatedConv3d 10 | 11 | 12 | class PoseGuider(ModelMixin): 13 | def __init__( 14 | self, 15 | conditioning_embedding_channels: int, 16 | conditioning_channels: int = 3, 17 | block_out_channels: Tuple[int] = (16, 32, 64, 128), 18 | ): 19 | super().__init__() 20 | self.conv_in = InflatedConv3d( 21 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 22 | ) 23 | 24 | self.blocks = nn.ModuleList([]) 25 | 26 | for i in range(len(block_out_channels) - 1): 27 | channel_in = block_out_channels[i] 28 | channel_out = block_out_channels[i + 1] 29 | self.blocks.append( 30 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1) 31 | ) 32 | self.blocks.append( 33 | InflatedConv3d( 34 | channel_in, channel_out, kernel_size=3, padding=1, stride=2 35 | ) 36 | ) 37 | 38 | self.conv_out = zero_module( 39 | InflatedConv3d( 40 | block_out_channels[-1], 41 | conditioning_embedding_channels, 42 | kernel_size=3, 43 | padding=1, 44 | ) 45 | ) 46 | 47 | def forward(self, conditioning): 48 | embedding = self.conv_in(conditioning) 49 | embedding = F.silu(embedding) 50 | 51 | for block in self.blocks: 52 | embedding = block(embedding) 53 | embedding = F.silu(embedding) 54 | 55 | embedding = self.conv_out(embedding) 56 | 57 | return embedding 58 | -------------------------------------------------------------------------------- /musepose/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | 9 | class InflatedConv3d(nn.Conv2d): 10 | def forward(self, x): 11 | video_length = x.shape[2] 12 | 13 | x = rearrange(x, "b c f h w -> (b f) c h w") 14 | x = super().forward(x) 15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 16 | 17 | return x 18 | 19 | 20 | class InflatedGroupNorm(nn.GroupNorm): 21 | def forward(self, x): 22 | video_length = x.shape[2] 23 | 24 | x = rearrange(x, "b c f h w -> (b f) c h w") 25 | x = super().forward(x) 26 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 27 | 28 | return x 29 | 30 | 31 | class Upsample3D(nn.Module): 32 | def __init__( 33 | self, 34 | channels, 35 | use_conv=False, 36 | use_conv_transpose=False, 37 | out_channels=None, 38 | name="conv", 39 | ): 40 | super().__init__() 41 | self.channels = channels 42 | self.out_channels = out_channels or channels 43 | self.use_conv = use_conv 44 | self.use_conv_transpose = use_conv_transpose 45 | 
self.name = name 46 | 47 | conv = None 48 | if use_conv_transpose: 49 | raise NotImplementedError 50 | elif use_conv: 51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 52 | 53 | def forward(self, hidden_states, output_size=None): 54 | assert hidden_states.shape[1] == self.channels 55 | 56 | if self.use_conv_transpose: 57 | raise NotImplementedError 58 | 59 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 60 | dtype = hidden_states.dtype 61 | if dtype == torch.bfloat16: 62 | hidden_states = hidden_states.to(torch.float32) 63 | 64 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 65 | if hidden_states.shape[0] >= 64: 66 | hidden_states = hidden_states.contiguous() 67 | 68 | # if `output_size` is passed we force the interpolation output 69 | # size and do not make use of `scale_factor=2` 70 | if output_size is None: 71 | hidden_states = F.interpolate( 72 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" 73 | ) 74 | else: 75 | hidden_states = F.interpolate( 76 | hidden_states, size=output_size, mode="nearest" 77 | ) 78 | 79 | # If the input is bfloat16, we cast back to bfloat16 80 | if dtype == torch.bfloat16: 81 | hidden_states = hidden_states.to(dtype) 82 | 83 | # if self.use_conv: 84 | # if self.name == "conv": 85 | # hidden_states = self.conv(hidden_states) 86 | # else: 87 | # hidden_states = self.Conv2d_0(hidden_states) 88 | hidden_states = self.conv(hidden_states) 89 | 90 | return hidden_states 91 | 92 | 93 | class Downsample3D(nn.Module): 94 | def __init__( 95 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv" 96 | ): 97 | super().__init__() 98 | self.channels = channels 99 | self.out_channels = out_channels or channels 100 | self.use_conv = use_conv 101 | self.padding = padding 102 | stride = 2 103 | self.name = name 104 | 105 | if use_conv: 106 | self.conv = InflatedConv3d( 107 | self.channels, self.out_channels, 3, stride=stride, padding=padding 108 | ) 109 | else: 110 | raise NotImplementedError 111 | 112 | def forward(self, hidden_states): 113 | assert hidden_states.shape[1] == self.channels 114 | if self.use_conv and self.padding == 0: 115 | raise NotImplementedError 116 | 117 | assert hidden_states.shape[1] == self.channels 118 | hidden_states = self.conv(hidden_states) 119 | 120 | return hidden_states 121 | 122 | 123 | class ResnetBlock3D(nn.Module): 124 | def __init__( 125 | self, 126 | *, 127 | in_channels, 128 | out_channels=None, 129 | conv_shortcut=False, 130 | dropout=0.0, 131 | temb_channels=512, 132 | groups=32, 133 | groups_out=None, 134 | pre_norm=True, 135 | eps=1e-6, 136 | non_linearity="swish", 137 | time_embedding_norm="default", 138 | output_scale_factor=1.0, 139 | use_in_shortcut=None, 140 | use_inflated_groupnorm=None, 141 | ): 142 | super().__init__() 143 | self.pre_norm = pre_norm 144 | self.pre_norm = True 145 | self.in_channels = in_channels 146 | out_channels = in_channels if out_channels is None else out_channels 147 | self.out_channels = out_channels 148 | self.use_conv_shortcut = conv_shortcut 149 | self.time_embedding_norm = time_embedding_norm 150 | self.output_scale_factor = output_scale_factor 151 | 152 | if groups_out is None: 153 | groups_out = groups 154 | 155 | assert use_inflated_groupnorm != None 156 | if use_inflated_groupnorm: 157 | self.norm1 = InflatedGroupNorm( 158 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 159 | ) 160 | else: 161 | self.norm1 = 
torch.nn.GroupNorm( 162 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 163 | ) 164 | 165 | self.conv1 = InflatedConv3d( 166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 167 | ) 168 | 169 | if temb_channels is not None: 170 | if self.time_embedding_norm == "default": 171 | time_emb_proj_out_channels = out_channels 172 | elif self.time_embedding_norm == "scale_shift": 173 | time_emb_proj_out_channels = out_channels * 2 174 | else: 175 | raise ValueError( 176 | f"unknown time_embedding_norm : {self.time_embedding_norm} " 177 | ) 178 | 179 | self.time_emb_proj = torch.nn.Linear( 180 | temb_channels, time_emb_proj_out_channels 181 | ) 182 | else: 183 | self.time_emb_proj = None 184 | 185 | if use_inflated_groupnorm: 186 | self.norm2 = InflatedGroupNorm( 187 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 188 | ) 189 | else: 190 | self.norm2 = torch.nn.GroupNorm( 191 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 192 | ) 193 | self.dropout = torch.nn.Dropout(dropout) 194 | self.conv2 = InflatedConv3d( 195 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 196 | ) 197 | 198 | if non_linearity == "swish": 199 | self.nonlinearity = lambda x: F.silu(x) 200 | elif non_linearity == "mish": 201 | self.nonlinearity = Mish() 202 | elif non_linearity == "silu": 203 | self.nonlinearity = nn.SiLU() 204 | 205 | self.use_in_shortcut = ( 206 | self.in_channels != self.out_channels 207 | if use_in_shortcut is None 208 | else use_in_shortcut 209 | ) 210 | 211 | self.conv_shortcut = None 212 | if self.use_in_shortcut: 213 | self.conv_shortcut = InflatedConv3d( 214 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 215 | ) 216 | 217 | def forward(self, input_tensor, temb): 218 | hidden_states = input_tensor 219 | 220 | hidden_states = self.norm1(hidden_states) 221 | hidden_states = self.nonlinearity(hidden_states) 222 | 223 | hidden_states = self.conv1(hidden_states) 224 | 225 | if temb is not None: 226 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 227 | 228 | if temb is not None and self.time_embedding_norm == "default": 229 | hidden_states = hidden_states + temb 230 | 231 | hidden_states = self.norm2(hidden_states) 232 | 233 | if temb is not None and self.time_embedding_norm == "scale_shift": 234 | scale, shift = torch.chunk(temb, 2, dim=1) 235 | hidden_states = hidden_states * (1 + scale) + shift 236 | 237 | hidden_states = self.nonlinearity(hidden_states) 238 | 239 | hidden_states = self.dropout(hidden_states) 240 | hidden_states = self.conv2(hidden_states) 241 | 242 | if self.conv_shortcut is not None: 243 | input_tensor = self.conv_shortcut(input_tensor) 244 | 245 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 246 | 247 | return output_tensor 248 | 249 | 250 | class Mish(torch.nn.Module): 251 | def forward(self, hidden_states): 252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 253 | -------------------------------------------------------------------------------- /musepose/models/transformer_3d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models import ModelMixin 7 | from diffusers.utils import BaseOutput 8 | from diffusers.utils.import_utils import is_xformers_available 9 | 
from einops import rearrange, repeat 10 | from torch import nn 11 | 12 | from .attention import TemporalBasicTransformerBlock 13 | 14 | 15 | @dataclass 16 | class Transformer3DModelOutput(BaseOutput): 17 | sample: torch.FloatTensor 18 | 19 | 20 | if is_xformers_available(): 21 | import xformers 22 | import xformers.ops 23 | else: 24 | xformers = None 25 | 26 | 27 | class Transformer3DModel(ModelMixin, ConfigMixin): 28 | _supports_gradient_checkpointing = True 29 | 30 | @register_to_config 31 | def __init__( 32 | self, 33 | num_attention_heads: int = 16, 34 | attention_head_dim: int = 88, 35 | in_channels: Optional[int] = None, 36 | num_layers: int = 1, 37 | dropout: float = 0.0, 38 | norm_num_groups: int = 32, 39 | cross_attention_dim: Optional[int] = None, 40 | attention_bias: bool = False, 41 | activation_fn: str = "geglu", 42 | num_embeds_ada_norm: Optional[int] = None, 43 | use_linear_projection: bool = False, 44 | only_cross_attention: bool = False, 45 | upcast_attention: bool = False, 46 | unet_use_cross_frame_attention=None, 47 | unet_use_temporal_attention=None, 48 | ): 49 | super().__init__() 50 | self.use_linear_projection = use_linear_projection 51 | self.num_attention_heads = num_attention_heads 52 | self.attention_head_dim = attention_head_dim 53 | inner_dim = num_attention_heads * attention_head_dim 54 | 55 | # Define input layers 56 | self.in_channels = in_channels 57 | 58 | self.norm = torch.nn.GroupNorm( 59 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 60 | ) 61 | if use_linear_projection: 62 | self.proj_in = nn.Linear(in_channels, inner_dim) 63 | else: 64 | self.proj_in = nn.Conv2d( 65 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 66 | ) 67 | 68 | # Define transformers blocks 69 | self.transformer_blocks = nn.ModuleList( 70 | [ 71 | TemporalBasicTransformerBlock( 72 | inner_dim, 73 | num_attention_heads, 74 | attention_head_dim, 75 | dropout=dropout, 76 | cross_attention_dim=cross_attention_dim, 77 | activation_fn=activation_fn, 78 | num_embeds_ada_norm=num_embeds_ada_norm, 79 | attention_bias=attention_bias, 80 | only_cross_attention=only_cross_attention, 81 | upcast_attention=upcast_attention, 82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 83 | unet_use_temporal_attention=unet_use_temporal_attention, 84 | ) 85 | for d in range(num_layers) 86 | ] 87 | ) 88 | 89 | # 4. Define output layers 90 | if use_linear_projection: 91 | self.proj_out = nn.Linear(in_channels, inner_dim) 92 | else: 93 | self.proj_out = nn.Conv2d( 94 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 95 | ) 96 | 97 | self.gradient_checkpointing = False 98 | 99 | def _set_gradient_checkpointing(self, module, value=False): 100 | if hasattr(module, "gradient_checkpointing"): 101 | module.gradient_checkpointing = value 102 | 103 | def forward( 104 | self, 105 | hidden_states, 106 | encoder_hidden_states=None, 107 | timestep=None, 108 | return_dict: bool = True, 109 | ): 110 | # Input 111 | assert ( 112 | hidden_states.dim() == 5 113 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
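# (Added note on the shapes handled below: hidden_states is (b, c, f, h, w); the frame axis
# is folded into the batch so the 2D spatial transformer blocks run per frame, and
# encoder_hidden_states is repeated per frame to keep the batch dimensions aligned.)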
114 | video_length = hidden_states.shape[2] 115 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 116 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 117 | encoder_hidden_states = repeat( 118 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length 119 | ) 120 | 121 | batch, channel, height, weight = hidden_states.shape 122 | residual = hidden_states 123 | 124 | hidden_states = self.norm(hidden_states) 125 | if not self.use_linear_projection: 126 | hidden_states = self.proj_in(hidden_states) 127 | inner_dim = hidden_states.shape[1] 128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 129 | batch, height * weight, inner_dim 130 | ) 131 | else: 132 | inner_dim = hidden_states.shape[1] 133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 134 | batch, height * weight, inner_dim 135 | ) 136 | hidden_states = self.proj_in(hidden_states) 137 | 138 | # Blocks 139 | for i, block in enumerate(self.transformer_blocks): 140 | hidden_states = block( 141 | hidden_states, 142 | encoder_hidden_states=encoder_hidden_states, 143 | timestep=timestep, 144 | video_length=video_length, 145 | ) 146 | 147 | # Output 148 | if not self.use_linear_projection: 149 | hidden_states = ( 150 | hidden_states.reshape(batch, height, weight, inner_dim) 151 | .permute(0, 3, 1, 2) 152 | .contiguous() 153 | ) 154 | hidden_states = self.proj_out(hidden_states) 155 | else: 156 | hidden_states = self.proj_out(hidden_states) 157 | hidden_states = ( 158 | hidden_states.reshape(batch, height, weight, inner_dim) 159 | .permute(0, 3, 1, 2) 160 | .contiguous() 161 | ) 162 | 163 | output = hidden_states + residual 164 | 165 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 166 | if not return_dict: 167 | return (output,) 168 | 169 | return Transformer3DModelOutput(sample=output) 170 | -------------------------------------------------------------------------------- /musepose/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MusePose/61c52bd937224a614b3951419b735b639397cb62/musepose/pipelines/__init__.py -------------------------------------------------------------------------------- /musepose/pipelines/context.py: -------------------------------------------------------------------------------- 1 | # TODO: Adapted from cli 2 | from typing import Callable, List, Optional 3 | 4 | import numpy as np 5 | 6 | 7 | def ordered_halving(val): 8 | bin_str = f"{val:064b}" 9 | bin_flip = bin_str[::-1] 10 | as_int = int(bin_flip, 2) 11 | 12 | return as_int / (1 << 64) 13 | 14 | 15 | def uniform( 16 | step: int = ..., 17 | num_steps: Optional[int] = None, 18 | num_frames: int = ..., 19 | context_size: Optional[int] = None, 20 | context_stride: int = 3, 21 | context_overlap: int = 4, 22 | closed_loop: bool = False, 23 | ): 24 | if num_frames <= context_size: 25 | yield list(range(num_frames)) 26 | return 27 | 28 | context_stride = min( 29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 30 | ) 31 | 32 | for context_step in 1 << np.arange(context_stride): 33 | pad = int(round(num_frames * ordered_halving(step))) 34 | for j in range( 35 | int(ordered_halving(step) * context_step) + pad, 36 | num_frames + pad + (0 if closed_loop else -context_overlap), 37 | (context_size * context_step - context_overlap), 38 | ): 39 | yield [ 40 | e % num_frames 41 | for e in range(j, j + context_size * context_step, context_step) 42 | ] 43 | 44 | 45 | def 
get_context_scheduler(name: str) -> Callable: 46 | if name == "uniform": 47 | return uniform 48 | else: 49 | raise ValueError(f"Unknown context_overlap policy {name}") 50 | 51 | 52 | def get_total_steps( 53 | scheduler, 54 | timesteps: List[int], 55 | num_steps: Optional[int] = None, 56 | num_frames: int = ..., 57 | context_size: Optional[int] = None, 58 | context_stride: int = 3, 59 | context_overlap: int = 4, 60 | closed_loop: bool = True, 61 | ): 62 | return sum( 63 | len( 64 | list( 65 | scheduler( 66 | i, 67 | num_steps, 68 | num_frames, 69 | context_size, 70 | context_stride, 71 | context_overlap, 72 | ) 73 | ) 74 | ) 75 | for i in range(len(timesteps)) 76 | ) 77 | -------------------------------------------------------------------------------- /musepose/pipelines/pipeline_pose2img.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Callable, List, Optional, Union 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers import DiffusionPipeline 8 | from diffusers.image_processor import VaeImageProcessor 9 | from diffusers.schedulers import ( 10 | DDIMScheduler, 11 | DPMSolverMultistepScheduler, 12 | EulerAncestralDiscreteScheduler, 13 | EulerDiscreteScheduler, 14 | LMSDiscreteScheduler, 15 | PNDMScheduler, 16 | ) 17 | from diffusers.utils import BaseOutput, is_accelerate_available 18 | from diffusers.utils.torch_utils import randn_tensor 19 | from einops import rearrange 20 | from tqdm import tqdm 21 | from transformers import CLIPImageProcessor 22 | 23 | from musepose.models.mutual_self_attention import ReferenceAttentionControl 24 | 25 | 26 | @dataclass 27 | class Pose2ImagePipelineOutput(BaseOutput): 28 | images: Union[torch.Tensor, np.ndarray] 29 | 30 | 31 | class Pose2ImagePipeline(DiffusionPipeline): 32 | _optional_components = [] 33 | 34 | def __init__( 35 | self, 36 | vae, 37 | image_encoder, 38 | reference_unet, 39 | denoising_unet, 40 | pose_guider, 41 | scheduler: Union[ 42 | DDIMScheduler, 43 | PNDMScheduler, 44 | LMSDiscreteScheduler, 45 | EulerDiscreteScheduler, 46 | EulerAncestralDiscreteScheduler, 47 | DPMSolverMultistepScheduler, 48 | ], 49 | ): 50 | super().__init__() 51 | 52 | self.register_modules( 53 | vae=vae, 54 | image_encoder=image_encoder, 55 | reference_unet=reference_unet, 56 | denoising_unet=denoising_unet, 57 | pose_guider=pose_guider, 58 | scheduler=scheduler, 59 | ) 60 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 61 | self.clip_image_processor = CLIPImageProcessor() 62 | self.ref_image_processor = VaeImageProcessor( 63 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True 64 | ) 65 | self.cond_image_processor = VaeImageProcessor( 66 | vae_scale_factor=self.vae_scale_factor, 67 | do_convert_rgb=True, 68 | do_normalize=False, 69 | ) 70 | 71 | def enable_vae_slicing(self): 72 | self.vae.enable_slicing() 73 | 74 | def disable_vae_slicing(self): 75 | self.vae.disable_slicing() 76 | 77 | def enable_sequential_cpu_offload(self, gpu_id=0): 78 | if is_accelerate_available(): 79 | from accelerate import cpu_offload 80 | else: 81 | raise ImportError("Please install accelerate via `pip install accelerate`") 82 | 83 | device = torch.device(f"cuda:{gpu_id}") 84 | 85 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 86 | if cpu_offloaded_model is not None: 87 | cpu_offload(cpu_offloaded_model, device) 88 | 89 | @property 90 | def _execution_device(self): 91 | if self.device != 
torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 92 | return self.device 93 | for module in self.unet.modules(): 94 | if ( 95 | hasattr(module, "_hf_hook") 96 | and hasattr(module._hf_hook, "execution_device") 97 | and module._hf_hook.execution_device is not None 98 | ): 99 | return torch.device(module._hf_hook.execution_device) 100 | return self.device 101 | 102 | def decode_latents(self, latents): 103 | video_length = latents.shape[2] 104 | latents = 1 / 0.18215 * latents 105 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 106 | # video = self.vae.decode(latents).sample 107 | video = [] 108 | for frame_idx in tqdm(range(latents.shape[0])): 109 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample) 110 | video = torch.cat(video) 111 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 112 | video = (video / 2 + 0.5).clamp(0, 1) 113 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 114 | video = video.cpu().float().numpy() 115 | return video 116 | 117 | def prepare_extra_step_kwargs(self, generator, eta): 118 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 119 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 120 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 121 | # and should be between [0, 1] 122 | 123 | accepts_eta = "eta" in set( 124 | inspect.signature(self.scheduler.step).parameters.keys() 125 | ) 126 | extra_step_kwargs = {} 127 | if accepts_eta: 128 | extra_step_kwargs["eta"] = eta 129 | 130 | # check if the scheduler accepts generator 131 | accepts_generator = "generator" in set( 132 | inspect.signature(self.scheduler.step).parameters.keys() 133 | ) 134 | if accepts_generator: 135 | extra_step_kwargs["generator"] = generator 136 | return extra_step_kwargs 137 | 138 | def prepare_latents( 139 | self, 140 | batch_size, 141 | num_channels_latents, 142 | width, 143 | height, 144 | dtype, 145 | device, 146 | generator, 147 | latents=None, 148 | ): 149 | shape = ( 150 | batch_size, 151 | num_channels_latents, 152 | height // self.vae_scale_factor, 153 | width // self.vae_scale_factor, 154 | ) 155 | if isinstance(generator, list) and len(generator) != batch_size: 156 | raise ValueError( 157 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 158 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
159 | ) 160 | 161 | if latents is None: 162 | latents = randn_tensor( 163 | shape, generator=generator, device=device, dtype=dtype 164 | ) 165 | else: 166 | latents = latents.to(device) 167 | 168 | # scale the initial noise by the standard deviation required by the scheduler 169 | latents = latents * self.scheduler.init_noise_sigma 170 | return latents 171 | 172 | def prepare_condition( 173 | self, 174 | cond_image, 175 | width, 176 | height, 177 | device, 178 | dtype, 179 | do_classififer_free_guidance=False, 180 | ): 181 | image = self.cond_image_processor.preprocess( 182 | cond_image, height=height, width=width 183 | ).to(dtype=torch.float32) 184 | 185 | image = image.to(device=device, dtype=dtype) 186 | 187 | if do_classififer_free_guidance: 188 | image = torch.cat([image] * 2) 189 | 190 | return image 191 | 192 | @torch.no_grad() 193 | def __call__( 194 | self, 195 | ref_image, 196 | pose_image, 197 | width, 198 | height, 199 | num_inference_steps, 200 | guidance_scale, 201 | num_images_per_prompt=1, 202 | eta: float = 0.0, 203 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 204 | output_type: Optional[str] = "tensor", 205 | return_dict: bool = True, 206 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 207 | callback_steps: Optional[int] = 1, 208 | **kwargs, 209 | ): 210 | # Default height and width to unet 211 | height = height or self.unet.config.sample_size * self.vae_scale_factor 212 | width = width or self.unet.config.sample_size * self.vae_scale_factor 213 | 214 | device = self._execution_device 215 | 216 | do_classifier_free_guidance = guidance_scale > 1.0 217 | 218 | # Prepare timesteps 219 | self.scheduler.set_timesteps(num_inference_steps, device=device) 220 | timesteps = self.scheduler.timesteps 221 | 222 | batch_size = 1 223 | 224 | # Prepare clip image embeds 225 | clip_image = self.clip_image_processor.preprocess( 226 | ref_image.resize((224, 224)), return_tensors="pt" 227 | ).pixel_values 228 | clip_image_embeds = self.image_encoder( 229 | clip_image.to(device, dtype=self.image_encoder.dtype) 230 | ).image_embeds 231 | image_prompt_embeds = clip_image_embeds.unsqueeze(1) 232 | uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds) 233 | 234 | if do_classifier_free_guidance: 235 | image_prompt_embeds = torch.cat( 236 | [uncond_image_prompt_embeds, image_prompt_embeds], dim=0 237 | ) 238 | 239 | reference_control_writer = ReferenceAttentionControl( 240 | self.reference_unet, 241 | do_classifier_free_guidance=do_classifier_free_guidance, 242 | mode="write", 243 | batch_size=batch_size, 244 | fusion_blocks="full", 245 | ) 246 | reference_control_reader = ReferenceAttentionControl( 247 | self.denoising_unet, 248 | do_classifier_free_guidance=do_classifier_free_guidance, 249 | mode="read", 250 | batch_size=batch_size, 251 | fusion_blocks="full", 252 | ) 253 | 254 | num_channels_latents = self.denoising_unet.in_channels 255 | latents = self.prepare_latents( 256 | batch_size * num_images_per_prompt, 257 | num_channels_latents, 258 | width, 259 | height, 260 | clip_image_embeds.dtype, 261 | device, 262 | generator, 263 | ) 264 | latents = latents.unsqueeze(2) # (bs, c, 1, h', w') 265 | latents_dtype = latents.dtype 266 | 267 | # Prepare extra step kwargs. 
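# (Added note: eta is the DDIM η; it and the generator are only forwarded when the
# scheduler's step() signature accepts them, see prepare_extra_step_kwargs above.)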
268 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 269 | 270 | # Prepare ref image latents 271 | ref_image_tensor = self.ref_image_processor.preprocess( 272 | ref_image, height=height, width=width 273 | ) # (bs, c, width, height) 274 | ref_image_tensor = ref_image_tensor.to( 275 | dtype=self.vae.dtype, device=self.vae.device 276 | ) 277 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean 278 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) 279 | 280 | # Prepare pose condition image 281 | pose_cond_tensor = self.cond_image_processor.preprocess( 282 | pose_image, height=height, width=width 283 | ) 284 | pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w) 285 | pose_cond_tensor = pose_cond_tensor.to( 286 | device=device, dtype=self.pose_guider.dtype 287 | ) 288 | pose_fea = self.pose_guider(pose_cond_tensor) 289 | pose_fea = ( 290 | torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea 291 | ) 292 | 293 | # denoising loop 294 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 295 | with self.progress_bar(total=num_inference_steps) as progress_bar: 296 | for i, t in enumerate(timesteps): 297 | # 1. Forward reference image 298 | if i == 0: 299 | self.reference_unet( 300 | ref_image_latents.repeat( 301 | (2 if do_classifier_free_guidance else 1), 1, 1, 1 302 | ), 303 | torch.zeros_like(t), 304 | encoder_hidden_states=image_prompt_embeds, 305 | return_dict=False, 306 | ) 307 | 308 | # 2. Update reference unet feature into denosing net 309 | reference_control_reader.update(reference_control_writer) 310 | 311 | # 3.1 expand the latents if we are doing classifier free guidance 312 | latent_model_input = ( 313 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 314 | ) 315 | latent_model_input = self.scheduler.scale_model_input( 316 | latent_model_input, t 317 | ) 318 | 319 | noise_pred = self.denoising_unet( 320 | latent_model_input, 321 | t, 322 | encoder_hidden_states=image_prompt_embeds, 323 | pose_cond_fea=pose_fea, 324 | return_dict=False, 325 | )[0] 326 | 327 | # perform guidance 328 | if do_classifier_free_guidance: 329 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 330 | noise_pred = noise_pred_uncond + guidance_scale * ( 331 | noise_pred_text - noise_pred_uncond 332 | ) 333 | 334 | # compute the previous noisy sample x_t -> x_t-1 335 | latents = self.scheduler.step( 336 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False 337 | )[0] 338 | 339 | # call the callback, if provided 340 | if i == len(timesteps) - 1 or ( 341 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 342 | ): 343 | progress_bar.update() 344 | if callback is not None and i % callback_steps == 0: 345 | step_idx = i // getattr(self.scheduler, "order", 1) 346 | callback(step_idx, t, latents) 347 | reference_control_reader.clear() 348 | reference_control_writer.clear() 349 | 350 | # Post-processing 351 | image = self.decode_latents(latents) # (b, c, 1, h, w) 352 | 353 | # Convert to tensor 354 | if output_type == "tensor": 355 | image = torch.from_numpy(image) 356 | 357 | if not return_dict: 358 | return image 359 | 360 | return Pose2ImagePipelineOutput(images=image) 361 | -------------------------------------------------------------------------------- /musepose/pipelines/pipeline_pose2vid.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing 
import Callable, List, Optional, Union 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers import DiffusionPipeline 8 | from diffusers.image_processor import VaeImageProcessor 9 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler, 10 | EulerAncestralDiscreteScheduler, 11 | EulerDiscreteScheduler, LMSDiscreteScheduler, 12 | PNDMScheduler) 13 | from diffusers.utils import BaseOutput, is_accelerate_available 14 | from diffusers.utils.torch_utils import randn_tensor 15 | from einops import rearrange 16 | from tqdm import tqdm 17 | from transformers import CLIPImageProcessor 18 | 19 | from musepose.models.mutual_self_attention import ReferenceAttentionControl 20 | 21 | 22 | @dataclass 23 | class Pose2VideoPipelineOutput(BaseOutput): 24 | videos: Union[torch.Tensor, np.ndarray] 25 | 26 | 27 | class Pose2VideoPipeline(DiffusionPipeline): 28 | _optional_components = [] 29 | 30 | def __init__( 31 | self, 32 | vae, 33 | image_encoder, 34 | reference_unet, 35 | denoising_unet, 36 | pose_guider, 37 | scheduler: Union[ 38 | DDIMScheduler, 39 | PNDMScheduler, 40 | LMSDiscreteScheduler, 41 | EulerDiscreteScheduler, 42 | EulerAncestralDiscreteScheduler, 43 | DPMSolverMultistepScheduler, 44 | ], 45 | image_proj_model=None, 46 | tokenizer=None, 47 | text_encoder=None, 48 | ): 49 | super().__init__() 50 | 51 | self.register_modules( 52 | vae=vae, 53 | image_encoder=image_encoder, 54 | reference_unet=reference_unet, 55 | denoising_unet=denoising_unet, 56 | pose_guider=pose_guider, 57 | scheduler=scheduler, 58 | image_proj_model=image_proj_model, 59 | tokenizer=tokenizer, 60 | text_encoder=text_encoder, 61 | ) 62 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 63 | self.clip_image_processor = CLIPImageProcessor() 64 | self.ref_image_processor = VaeImageProcessor( 65 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True 66 | ) 67 | self.cond_image_processor = VaeImageProcessor( 68 | vae_scale_factor=self.vae_scale_factor, 69 | do_convert_rgb=True, 70 | do_normalize=False, 71 | ) 72 | 73 | def enable_vae_slicing(self): 74 | self.vae.enable_slicing() 75 | 76 | def disable_vae_slicing(self): 77 | self.vae.disable_slicing() 78 | 79 | def enable_sequential_cpu_offload(self, gpu_id=0): 80 | if is_accelerate_available(): 81 | from accelerate import cpu_offload 82 | else: 83 | raise ImportError("Please install accelerate via `pip install accelerate`") 84 | 85 | device = torch.device(f"cuda:{gpu_id}") 86 | 87 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 88 | if cpu_offloaded_model is not None: 89 | cpu_offload(cpu_offloaded_model, device) 90 | 91 | @property 92 | def _execution_device(self): 93 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 94 | return self.device 95 | for module in self.unet.modules(): 96 | if ( 97 | hasattr(module, "_hf_hook") 98 | and hasattr(module._hf_hook, "execution_device") 99 | and module._hf_hook.execution_device is not None 100 | ): 101 | return torch.device(module._hf_hook.execution_device) 102 | return self.device 103 | 104 | def decode_latents(self, latents): 105 | video_length = latents.shape[2] 106 | latents = 1 / 0.18215 * latents 107 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 108 | # video = self.vae.decode(latents).sample 109 | video = [] 110 | for frame_idx in tqdm(range(latents.shape[0])): 111 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample) 112 | video = torch.cat(video) 113 | video = rearrange(video, "(b f) 
c h w -> b c f h w", f=video_length) 114 | video = (video / 2 + 0.5).clamp(0, 1) 115 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 116 | video = video.cpu().float().numpy() 117 | return video 118 | 119 | def prepare_extra_step_kwargs(self, generator, eta): 120 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 121 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 122 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 123 | # and should be between [0, 1] 124 | 125 | accepts_eta = "eta" in set( 126 | inspect.signature(self.scheduler.step).parameters.keys() 127 | ) 128 | extra_step_kwargs = {} 129 | if accepts_eta: 130 | extra_step_kwargs["eta"] = eta 131 | 132 | # check if the scheduler accepts generator 133 | accepts_generator = "generator" in set( 134 | inspect.signature(self.scheduler.step).parameters.keys() 135 | ) 136 | if accepts_generator: 137 | extra_step_kwargs["generator"] = generator 138 | return extra_step_kwargs 139 | 140 | def prepare_latents( 141 | self, 142 | batch_size, 143 | num_channels_latents, 144 | width, 145 | height, 146 | video_length, 147 | dtype, 148 | device, 149 | generator, 150 | latents=None, 151 | ): 152 | shape = ( 153 | batch_size, 154 | num_channels_latents, 155 | video_length, 156 | height // self.vae_scale_factor, 157 | width // self.vae_scale_factor, 158 | ) 159 | if isinstance(generator, list) and len(generator) != batch_size: 160 | raise ValueError( 161 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 162 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
163 | ) 164 | 165 | if latents is None: 166 | latents = randn_tensor( 167 | shape, generator=generator, device=device, dtype=dtype 168 | ) 169 | else: 170 | latents = latents.to(device) 171 | 172 | # scale the initial noise by the standard deviation required by the scheduler 173 | latents = latents * self.scheduler.init_noise_sigma 174 | return latents 175 | 176 | def _encode_prompt( 177 | self, 178 | prompt, 179 | device, 180 | num_videos_per_prompt, 181 | do_classifier_free_guidance, 182 | negative_prompt, 183 | ): 184 | batch_size = len(prompt) if isinstance(prompt, list) else 1 185 | 186 | text_inputs = self.tokenizer( 187 | prompt, 188 | padding="max_length", 189 | max_length=self.tokenizer.model_max_length, 190 | truncation=True, 191 | return_tensors="pt", 192 | ) 193 | text_input_ids = text_inputs.input_ids 194 | untruncated_ids = self.tokenizer( 195 | prompt, padding="longest", return_tensors="pt" 196 | ).input_ids 197 | 198 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 199 | text_input_ids, untruncated_ids 200 | ): 201 | removed_text = self.tokenizer.batch_decode( 202 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 203 | ) 204 | 205 | if ( 206 | hasattr(self.text_encoder.config, "use_attention_mask") 207 | and self.text_encoder.config.use_attention_mask 208 | ): 209 | attention_mask = text_inputs.attention_mask.to(device) 210 | else: 211 | attention_mask = None 212 | 213 | text_embeddings = self.text_encoder( 214 | text_input_ids.to(device), 215 | attention_mask=attention_mask, 216 | ) 217 | text_embeddings = text_embeddings[0] 218 | 219 | # duplicate text embeddings for each generation per prompt, using mps friendly method 220 | bs_embed, seq_len, _ = text_embeddings.shape 221 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 222 | text_embeddings = text_embeddings.view( 223 | bs_embed * num_videos_per_prompt, seq_len, -1 224 | ) 225 | 226 | # get unconditional embeddings for classifier free guidance 227 | if do_classifier_free_guidance: 228 | uncond_tokens: List[str] 229 | if negative_prompt is None: 230 | uncond_tokens = [""] * batch_size 231 | elif type(prompt) is not type(negative_prompt): 232 | raise TypeError( 233 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 234 | f" {type(prompt)}." 235 | ) 236 | elif isinstance(negative_prompt, str): 237 | uncond_tokens = [negative_prompt] 238 | elif batch_size != len(negative_prompt): 239 | raise ValueError( 240 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 241 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 242 | " the batch size of `prompt`." 
243 | ) 244 | else: 245 | uncond_tokens = negative_prompt 246 | 247 | max_length = text_input_ids.shape[-1] 248 | uncond_input = self.tokenizer( 249 | uncond_tokens, 250 | padding="max_length", 251 | max_length=max_length, 252 | truncation=True, 253 | return_tensors="pt", 254 | ) 255 | 256 | if ( 257 | hasattr(self.text_encoder.config, "use_attention_mask") 258 | and self.text_encoder.config.use_attention_mask 259 | ): 260 | attention_mask = uncond_input.attention_mask.to(device) 261 | else: 262 | attention_mask = None 263 | 264 | uncond_embeddings = self.text_encoder( 265 | uncond_input.input_ids.to(device), 266 | attention_mask=attention_mask, 267 | ) 268 | uncond_embeddings = uncond_embeddings[0] 269 | 270 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 271 | seq_len = uncond_embeddings.shape[1] 272 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 273 | uncond_embeddings = uncond_embeddings.view( 274 | batch_size * num_videos_per_prompt, seq_len, -1 275 | ) 276 | 277 | # For classifier free guidance, we need to do two forward passes. 278 | # Here we concatenate the unconditional and text embeddings into a single batch 279 | # to avoid doing two forward passes 280 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 281 | 282 | return text_embeddings 283 | 284 | @torch.no_grad() 285 | def __call__( 286 | self, 287 | ref_image, 288 | pose_images, 289 | width, 290 | height, 291 | video_length, 292 | num_inference_steps, 293 | guidance_scale, 294 | num_images_per_prompt=1, 295 | eta: float = 0.0, 296 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 297 | output_type: Optional[str] = "tensor", 298 | return_dict: bool = True, 299 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 300 | callback_steps: Optional[int] = 1, 301 | **kwargs, 302 | ): 303 | # Default height and width to unet 304 | height = height or self.unet.config.sample_size * self.vae_scale_factor 305 | width = width or self.unet.config.sample_size * self.vae_scale_factor 306 | 307 | device = self._execution_device 308 | 309 | do_classifier_free_guidance = guidance_scale > 1.0 310 | 311 | # Prepare timesteps 312 | self.scheduler.set_timesteps(num_inference_steps, device=device) 313 | timesteps = self.scheduler.timesteps 314 | 315 | batch_size = 1 316 | 317 | # Prepare clip image embeds 318 | clip_image = self.clip_image_processor.preprocess( 319 | ref_image, return_tensors="pt" 320 | ).pixel_values 321 | clip_image_embeds = self.image_encoder( 322 | clip_image.to(device, dtype=self.image_encoder.dtype) 323 | ).image_embeds 324 | encoder_hidden_states = clip_image_embeds.unsqueeze(1) 325 | uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states) 326 | 327 | if do_classifier_free_guidance: 328 | encoder_hidden_states = torch.cat( 329 | [uncond_encoder_hidden_states, encoder_hidden_states], dim=0 330 | ) 331 | reference_control_writer = ReferenceAttentionControl( 332 | self.reference_unet, 333 | do_classifier_free_guidance=do_classifier_free_guidance, 334 | mode="write", 335 | batch_size=batch_size, 336 | fusion_blocks="full", 337 | ) 338 | reference_control_reader = ReferenceAttentionControl( 339 | self.denoising_unet, 340 | do_classifier_free_guidance=do_classifier_free_guidance, 341 | mode="read", 342 | batch_size=batch_size, 343 | fusion_blocks="full", 344 | ) 345 | 346 | num_channels_latents = self.denoising_unet.in_channels 347 | latents = self.prepare_latents( 348 | 
batch_size * num_images_per_prompt, 349 | num_channels_latents, 350 | width, 351 | height, 352 | video_length, 353 | clip_image_embeds.dtype, 354 | device, 355 | generator, 356 | ) 357 | 358 | # Prepare extra step kwargs. 359 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 360 | 361 | # Prepare ref image latents 362 | ref_image_tensor = self.ref_image_processor.preprocess( 363 | ref_image, height=height, width=width 364 | ) # (bs, c, width, height) 365 | ref_image_tensor = ref_image_tensor.to( 366 | dtype=self.vae.dtype, device=self.vae.device 367 | ) 368 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean 369 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) 370 | 371 | # Prepare a list of pose condition images 372 | pose_cond_tensor_list = [] 373 | for pose_image in pose_images: 374 | pose_cond_tensor = ( 375 | torch.from_numpy(np.array(pose_image.resize((width, height)))) / 255.0 376 | ) 377 | pose_cond_tensor = pose_cond_tensor.permute(2, 0, 1).unsqueeze( 378 | 1 379 | ) # (c, 1, h, w) 380 | pose_cond_tensor_list.append(pose_cond_tensor) 381 | pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w) 382 | pose_cond_tensor = pose_cond_tensor.unsqueeze(0) 383 | pose_cond_tensor = pose_cond_tensor.to( 384 | device=device, dtype=self.pose_guider.dtype 385 | ) 386 | pose_fea = self.pose_guider(pose_cond_tensor) 387 | pose_fea = ( 388 | torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea 389 | ) 390 | 391 | # denoising loop 392 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 393 | with self.progress_bar(total=num_inference_steps) as progress_bar: 394 | for i, t in enumerate(timesteps): 395 | # 1. Forward reference image 396 | if i == 0: 397 | self.reference_unet( 398 | ref_image_latents.repeat( 399 | (2 if do_classifier_free_guidance else 1), 1, 1, 1 400 | ), 401 | torch.zeros_like(t), 402 | # t, 403 | encoder_hidden_states=encoder_hidden_states, 404 | return_dict=False, 405 | ) 406 | reference_control_reader.update(reference_control_writer) 407 | 408 | # 3.1 expand the latents if we are doing classifier free guidance 409 | latent_model_input = ( 410 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 411 | ) 412 | latent_model_input = self.scheduler.scale_model_input( 413 | latent_model_input, t 414 | ) 415 | 416 | noise_pred = self.denoising_unet( 417 | latent_model_input, 418 | t, 419 | encoder_hidden_states=encoder_hidden_states, 420 | pose_cond_fea=pose_fea, 421 | return_dict=False, 422 | )[0] 423 | 424 | # perform guidance 425 | if do_classifier_free_guidance: 426 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 427 | noise_pred = noise_pred_uncond + guidance_scale * ( 428 | noise_pred_text - noise_pred_uncond 429 | ) 430 | 431 | # compute the previous noisy sample x_t -> x_t-1 432 | latents = self.scheduler.step( 433 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False 434 | )[0] 435 | 436 | # call the callback, if provided 437 | if i == len(timesteps) - 1 or ( 438 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 439 | ): 440 | progress_bar.update() 441 | if callback is not None and i % callback_steps == 0: 442 | step_idx = i // getattr(self.scheduler, "order", 1) 443 | callback(step_idx, t, latents) 444 | 445 | reference_control_reader.clear() 446 | reference_control_writer.clear() 447 | 448 | # Post-processing 449 | images = self.decode_latents(latents) # (b, c, f, h, w) 450 | 451 | # Convert to 
tensor 452 | if output_type == "tensor": 453 | images = torch.from_numpy(images) 454 | 455 | if not return_dict: 456 | return images 457 | 458 | return Pose2VideoPipelineOutput(videos=images) 459 | -------------------------------------------------------------------------------- /musepose/pipelines/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | tensor_interpolation = None 4 | 5 | 6 | def get_tensor_interpolation_method(): 7 | return tensor_interpolation 8 | 9 | 10 | def set_tensor_interpolation_method(is_slerp): 11 | global tensor_interpolation 12 | tensor_interpolation = slerp if is_slerp else linear 13 | 14 | 15 | def linear(v1, v2, t): 16 | return (1.0 - t) * v1 + t * v2 17 | 18 | 19 | def slerp( 20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 21 | ) -> torch.Tensor: 22 | u0 = v0 / v0.norm() 23 | u1 = v1 / v1.norm() 24 | dot = (u0 * u1).sum() 25 | if dot.abs() > DOT_THRESHOLD: 26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.') 27 | return (1.0 - t) * v0 + t * v1 28 | omega = dot.acos() 29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() 30 | -------------------------------------------------------------------------------- /musepose/utils/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import os.path as osp 4 | import shutil 5 | import sys 6 | from pathlib import Path 7 | 8 | import av 9 | import numpy as np 10 | import torch 11 | import torchvision 12 | from einops import rearrange 13 | from PIL import Image 14 | 15 | 16 | def seed_everything(seed): 17 | import random 18 | 19 | import numpy as np 20 | 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed % (2**32)) 24 | random.seed(seed) 25 | 26 | 27 | def import_filename(filename): 28 | spec = importlib.util.spec_from_file_location("mymodule", filename) 29 | module = importlib.util.module_from_spec(spec) 30 | sys.modules[spec.name] = module 31 | spec.loader.exec_module(module) 32 | return module 33 | 34 | 35 | def delete_additional_ckpt(base_path, num_keep): 36 | dirs = [] 37 | for d in os.listdir(base_path): 38 | if d.startswith("checkpoint-"): 39 | dirs.append(d) 40 | num_tot = len(dirs) 41 | if num_tot <= num_keep: 42 | return 43 | # ensure ckpt is sorted and delete the ealier! 
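    # checkpoint directories follow the "checkpoint-<global_step>" naming, so sorting on the integer suffix (rather than lexicographically) removes the oldest checkpoints and keeps the newest `num_keep`.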
44 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep] 45 | for d in del_dirs: 46 | path_to_dir = osp.join(base_path, d) 47 | if osp.exists(path_to_dir): 48 | shutil.rmtree(path_to_dir) 49 | 50 | 51 | def save_videos_from_pil(pil_images, path, fps=8): 52 | import av 53 | 54 | save_fmt = Path(path).suffix 55 | os.makedirs(os.path.dirname(path), exist_ok=True) 56 | width, height = pil_images[0].size 57 | 58 | if save_fmt == ".mp4": 59 | codec = "libx264" 60 | container = av.open(path, "w") 61 | stream = container.add_stream(codec, rate=fps) 62 | 63 | stream.width = width 64 | stream.height = height 65 | stream.pix_fmt = 'yuv420p' 66 | stream.bit_rate = 10000000 67 | stream.options["crf"] = "18" 68 | 69 | 70 | 71 | for pil_image in pil_images: 72 | # pil_image = Image.fromarray(image_arr).convert("RGB") 73 | av_frame = av.VideoFrame.from_image(pil_image) 74 | container.mux(stream.encode(av_frame)) 75 | container.mux(stream.encode()) 76 | container.close() 77 | 78 | elif save_fmt == ".gif": 79 | pil_images[0].save( 80 | fp=path, 81 | format="GIF", 82 | append_images=pil_images[1:], 83 | save_all=True, 84 | duration=(1 / fps * 1000), 85 | loop=0, 86 | ) 87 | else: 88 | raise ValueError("Unsupported file type. Use .mp4 or .gif.") 89 | 90 | 91 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 92 | videos = rearrange(videos, "b c t h w -> t b c h w") 93 | height, width = videos.shape[-2:] 94 | outputs = [] 95 | 96 | for x in videos: 97 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w) 98 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c) 99 | if rescale: 100 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 101 | x = (x * 255).numpy().astype(np.uint8) 102 | x = Image.fromarray(x) 103 | 104 | outputs.append(x) 105 | 106 | os.makedirs(os.path.dirname(path), exist_ok=True) 107 | 108 | save_videos_from_pil(outputs, path, fps) 109 | 110 | 111 | def read_frames(video_path): 112 | container = av.open(video_path) 113 | 114 | video_stream = next(s for s in container.streams if s.type == "video") 115 | frames = [] 116 | for packet in container.demux(video_stream): 117 | for frame in packet.decode(): 118 | image = Image.frombytes( 119 | "RGB", 120 | (frame.width, frame.height), 121 | frame.to_rgb().to_ndarray(), 122 | ) 123 | frames.append(image) 124 | 125 | return frames 126 | 127 | 128 | def get_fps(video_path): 129 | container = av.open(video_path) 130 | video_stream = next(s for s in container.streams if s.type == "video") 131 | fps = video_stream.average_rate 132 | container.close() 133 | return fps 134 | -------------------------------------------------------------------------------- /pose/config/dwpose-l_384x288.py: -------------------------------------------------------------------------------- 1 | # runtime 2 | max_epochs = 270 3 | stage2_num_epochs = 30 4 | base_lr = 4e-3 5 | 6 | train_cfg = dict(max_epochs=max_epochs, val_interval=10) 7 | randomness = dict(seed=21) 8 | 9 | # optimizer 10 | optim_wrapper = dict( 11 | type='OptimWrapper', 12 | optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), 13 | paramwise_cfg=dict( 14 | norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) 15 | 16 | # learning rate 17 | param_scheduler = [ 18 | dict( 19 | type='LinearLR', 20 | start_factor=1.0e-5, 21 | by_epoch=False, 22 | begin=0, 23 | end=1000), 24 | dict( 25 | # use cosine lr from 150 to 300 epoch 26 | type='CosineAnnealingLR', 27 | eta_min=base_lr * 0.05, 28 | begin=max_epochs // 2, 29 | end=max_epochs, 30 | 
T_max=max_epochs // 2, 31 | by_epoch=True, 32 | convert_to_iter_based=True), 33 | ] 34 | 35 | # automatically scaling LR based on the actual training batch size 36 | auto_scale_lr = dict(base_batch_size=512) 37 | 38 | # codec settings 39 | codec = dict( 40 | type='SimCCLabel', 41 | input_size=(288, 384), 42 | sigma=(6., 6.93), 43 | simcc_split_ratio=2.0, 44 | normalize=False, 45 | use_dark=False) 46 | 47 | # model settings 48 | model = dict( 49 | type='TopdownPoseEstimator', 50 | data_preprocessor=dict( 51 | type='PoseDataPreprocessor', 52 | mean=[123.675, 116.28, 103.53], 53 | std=[58.395, 57.12, 57.375], 54 | bgr_to_rgb=True), 55 | backbone=dict( 56 | _scope_='mmdet', 57 | type='CSPNeXt', 58 | arch='P5', 59 | expand_ratio=0.5, 60 | deepen_factor=1., 61 | widen_factor=1., 62 | out_indices=(4, ), 63 | channel_attention=True, 64 | norm_cfg=dict(type='SyncBN'), 65 | act_cfg=dict(type='SiLU'), 66 | init_cfg=dict( 67 | type='Pretrained', 68 | prefix='backbone.', 69 | checkpoint='https://download.openmmlab.com/mmpose/v1/projects/' 70 | 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa 71 | )), 72 | head=dict( 73 | type='RTMCCHead', 74 | in_channels=1024, 75 | out_channels=133, 76 | input_size=codec['input_size'], 77 | in_featuremap_size=(9, 12), 78 | simcc_split_ratio=codec['simcc_split_ratio'], 79 | final_layer_kernel_size=7, 80 | gau_cfg=dict( 81 | hidden_dims=256, 82 | s=128, 83 | expansion_factor=2, 84 | dropout_rate=0., 85 | drop_path=0., 86 | act_fn='SiLU', 87 | use_rel_bias=False, 88 | pos_enc=False), 89 | loss=dict( 90 | type='KLDiscretLoss', 91 | use_target_weight=True, 92 | beta=10., 93 | label_softmax=True), 94 | decoder=codec), 95 | test_cfg=dict(flip_test=True, )) 96 | 97 | # base dataset settings 98 | dataset_type = 'CocoWholeBodyDataset' 99 | data_mode = 'topdown' 100 | data_root = '/data/' 101 | 102 | backend_args = dict(backend='local') 103 | # backend_args = dict( 104 | # backend='petrel', 105 | # path_mapping=dict({ 106 | # f'{data_root}': 's3://openmmlab/datasets/detection/coco/', 107 | # f'{data_root}': 's3://openmmlab/datasets/detection/coco/' 108 | # })) 109 | 110 | # pipelines 111 | train_pipeline = [ 112 | dict(type='LoadImage', backend_args=backend_args), 113 | dict(type='GetBBoxCenterScale'), 114 | dict(type='RandomFlip', direction='horizontal'), 115 | dict(type='RandomHalfBody'), 116 | dict( 117 | type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), 118 | dict(type='TopdownAffine', input_size=codec['input_size']), 119 | dict(type='mmdet.YOLOXHSVRandomAug'), 120 | dict( 121 | type='Albumentation', 122 | transforms=[ 123 | dict(type='Blur', p=0.1), 124 | dict(type='MedianBlur', p=0.1), 125 | dict( 126 | type='CoarseDropout', 127 | max_holes=1, 128 | max_height=0.4, 129 | max_width=0.4, 130 | min_holes=1, 131 | min_height=0.2, 132 | min_width=0.2, 133 | p=1.0), 134 | ]), 135 | dict(type='GenerateTarget', encoder=codec), 136 | dict(type='PackPoseInputs') 137 | ] 138 | val_pipeline = [ 139 | dict(type='LoadImage', backend_args=backend_args), 140 | dict(type='GetBBoxCenterScale'), 141 | dict(type='TopdownAffine', input_size=codec['input_size']), 142 | dict(type='PackPoseInputs') 143 | ] 144 | 145 | train_pipeline_stage2 = [ 146 | dict(type='LoadImage', backend_args=backend_args), 147 | dict(type='GetBBoxCenterScale'), 148 | dict(type='RandomFlip', direction='horizontal'), 149 | dict(type='RandomHalfBody'), 150 | dict( 151 | type='RandomBBoxTransform', 152 | shift_factor=0., 153 | scale_factor=[0.75, 1.25], 154 | 
rotate_factor=60), 155 | dict(type='TopdownAffine', input_size=codec['input_size']), 156 | dict(type='mmdet.YOLOXHSVRandomAug'), 157 | dict( 158 | type='Albumentation', 159 | transforms=[ 160 | dict(type='Blur', p=0.1), 161 | dict(type='MedianBlur', p=0.1), 162 | dict( 163 | type='CoarseDropout', 164 | max_holes=1, 165 | max_height=0.4, 166 | max_width=0.4, 167 | min_holes=1, 168 | min_height=0.2, 169 | min_width=0.2, 170 | p=0.5), 171 | ]), 172 | dict(type='GenerateTarget', encoder=codec), 173 | dict(type='PackPoseInputs') 174 | ] 175 | 176 | datasets = [] 177 | dataset_coco=dict( 178 | type=dataset_type, 179 | data_root=data_root, 180 | data_mode=data_mode, 181 | ann_file='coco/annotations/coco_wholebody_train_v1.0.json', 182 | data_prefix=dict(img='coco/train2017/'), 183 | pipeline=[], 184 | ) 185 | datasets.append(dataset_coco) 186 | 187 | scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 188 | 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 189 | 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'] 190 | 191 | for i in range(len(scene)): 192 | datasets.append( 193 | dict( 194 | type=dataset_type, 195 | data_root=data_root, 196 | data_mode=data_mode, 197 | ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json', 198 | data_prefix=dict(img='UBody/images/'+scene[i]+'/'), 199 | pipeline=[], 200 | ) 201 | ) 202 | 203 | # data loaders 204 | train_dataloader = dict( 205 | batch_size=32, 206 | num_workers=10, 207 | persistent_workers=True, 208 | sampler=dict(type='DefaultSampler', shuffle=True), 209 | dataset=dict( 210 | type='CombinedDataset', 211 | metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), 212 | datasets=datasets, 213 | pipeline=train_pipeline, 214 | test_mode=False, 215 | )) 216 | val_dataloader = dict( 217 | batch_size=32, 218 | num_workers=10, 219 | persistent_workers=True, 220 | drop_last=False, 221 | sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), 222 | dataset=dict( 223 | type=dataset_type, 224 | data_root=data_root, 225 | data_mode=data_mode, 226 | ann_file='coco/annotations/coco_wholebody_val_v1.0.json', 227 | bbox_file=f'{data_root}coco/person_detection_results/' 228 | 'COCO_val2017_detections_AP_H_56_person.json', 229 | data_prefix=dict(img='coco/val2017/'), 230 | test_mode=True, 231 | pipeline=val_pipeline, 232 | )) 233 | test_dataloader = val_dataloader 234 | 235 | # hooks 236 | default_hooks = dict( 237 | checkpoint=dict( 238 | save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) 239 | 240 | custom_hooks = [ 241 | dict( 242 | type='EMAHook', 243 | ema_type='ExpMomentumEMA', 244 | momentum=0.0002, 245 | update_buffers=True, 246 | priority=49), 247 | dict( 248 | type='mmdet.PipelineSwitchHook', 249 | switch_epoch=max_epochs - stage2_num_epochs, 250 | switch_pipeline=train_pipeline_stage2) 251 | ] 252 | 253 | # evaluators 254 | val_evaluator = dict( 255 | type='CocoWholeBodyMetric', 256 | ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json') 257 | test_evaluator = val_evaluator -------------------------------------------------------------------------------- /pose/config/yolox_l_8xb8-300e_coco.py: -------------------------------------------------------------------------------- 1 | img_scale = (640, 640) # width, height 2 | 3 | # model settings 4 | model = dict( 5 | type='YOLOX', 6 | data_preprocessor=dict( 7 | type='DetDataPreprocessor', 8 | pad_size_divisor=32, 9 | batch_augments=[ 10 | dict( 11 | type='BatchSyncRandomResize', 12 | 
random_size_range=(480, 800), 13 | size_divisor=32, 14 | interval=10) 15 | ]), 16 | backbone=dict( 17 | type='CSPDarknet', 18 | deepen_factor=1.0, 19 | widen_factor=1.0, 20 | out_indices=(2, 3, 4), 21 | use_depthwise=False, 22 | spp_kernal_sizes=(5, 9, 13), 23 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), 24 | act_cfg=dict(type='Swish'), 25 | ), 26 | neck=dict( 27 | type='YOLOXPAFPN', 28 | in_channels=[256, 512, 1024], 29 | out_channels=256, 30 | num_csp_blocks=3, 31 | use_depthwise=False, 32 | upsample_cfg=dict(scale_factor=2, mode='nearest'), 33 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), 34 | act_cfg=dict(type='Swish')), 35 | bbox_head=dict( 36 | type='YOLOXHead', 37 | num_classes=80, 38 | in_channels=256, 39 | feat_channels=256, 40 | stacked_convs=2, 41 | strides=(8, 16, 32), 42 | use_depthwise=False, 43 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), 44 | act_cfg=dict(type='Swish'), 45 | loss_cls=dict( 46 | type='CrossEntropyLoss', 47 | use_sigmoid=True, 48 | reduction='sum', 49 | loss_weight=1.0), 50 | loss_bbox=dict( 51 | type='IoULoss', 52 | mode='square', 53 | eps=1e-16, 54 | reduction='sum', 55 | loss_weight=5.0), 56 | loss_obj=dict( 57 | type='CrossEntropyLoss', 58 | use_sigmoid=True, 59 | reduction='sum', 60 | loss_weight=1.0), 61 | loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)), 62 | train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)), 63 | # In order to align the source code, the threshold of the val phase is 64 | # 0.01, and the threshold of the test phase is 0.001. 65 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))) 66 | 67 | # dataset settings 68 | data_root = 'data/coco/' 69 | dataset_type = 'CocoDataset' 70 | 71 | # Example to use different file client 72 | # Method 1: simply set the data root and let the file I/O module 73 | # automatically infer from prefix (not support LMDB and Memcache yet) 74 | 75 | # data_root = 's3://openmmlab/datasets/detection/coco/' 76 | 77 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 78 | # backend_args = dict( 79 | # backend='petrel', 80 | # path_mapping=dict({ 81 | # './data/': 's3://openmmlab/datasets/detection/', 82 | # 'data/': 's3://openmmlab/datasets/detection/' 83 | # })) 84 | backend_args = None 85 | 86 | train_pipeline = [ 87 | dict(type='Mosaic', img_scale=img_scale, pad_val=114.0), 88 | dict( 89 | type='RandomAffine', 90 | scaling_ratio_range=(0.1, 2), 91 | # img_scale is (width, height) 92 | border=(-img_scale[0] // 2, -img_scale[1] // 2)), 93 | dict( 94 | type='MixUp', 95 | img_scale=img_scale, 96 | ratio_range=(0.8, 1.6), 97 | pad_val=114.0), 98 | dict(type='YOLOXHSVRandomAug'), 99 | dict(type='RandomFlip', prob=0.5), 100 | # According to the official implementation, multi-scale 101 | # training is not considered here but in the 102 | # 'mmdet/models/detectors/yolox.py'. 103 | # Resize and Pad are for the last 15 epochs when Mosaic, 104 | # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook. 105 | dict(type='Resize', scale=img_scale, keep_ratio=True), 106 | dict( 107 | type='Pad', 108 | pad_to_square=True, 109 | # If the image is three-channel, the pad value needs 110 | # to be set separately for each channel. 
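            # 114 is the same gray fill value used by the Mosaic and MixUp transforms above, applied per channel here.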
111 | pad_val=dict(img=(114.0, 114.0, 114.0))), 112 | dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), 113 | dict(type='PackDetInputs') 114 | ] 115 | 116 | train_dataset = dict( 117 | # use MultiImageMixDataset wrapper to support mosaic and mixup 118 | type='MultiImageMixDataset', 119 | dataset=dict( 120 | type=dataset_type, 121 | data_root=data_root, 122 | ann_file='annotations/instances_train2017.json', 123 | data_prefix=dict(img='train2017/'), 124 | pipeline=[ 125 | dict(type='LoadImageFromFile', backend_args=backend_args), 126 | dict(type='LoadAnnotations', with_bbox=True) 127 | ], 128 | filter_cfg=dict(filter_empty_gt=False, min_size=32), 129 | backend_args=backend_args), 130 | pipeline=train_pipeline) 131 | 132 | test_pipeline = [ 133 | dict(type='LoadImageFromFile', backend_args=backend_args), 134 | dict(type='Resize', scale=img_scale, keep_ratio=True), 135 | dict( 136 | type='Pad', 137 | pad_to_square=True, 138 | pad_val=dict(img=(114.0, 114.0, 114.0))), 139 | dict(type='LoadAnnotations', with_bbox=True), 140 | dict( 141 | type='PackDetInputs', 142 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 143 | 'scale_factor')) 144 | ] 145 | 146 | train_dataloader = dict( 147 | batch_size=8, 148 | num_workers=4, 149 | persistent_workers=True, 150 | sampler=dict(type='DefaultSampler', shuffle=True), 151 | dataset=train_dataset) 152 | val_dataloader = dict( 153 | batch_size=8, 154 | num_workers=4, 155 | persistent_workers=True, 156 | drop_last=False, 157 | sampler=dict(type='DefaultSampler', shuffle=False), 158 | dataset=dict( 159 | type=dataset_type, 160 | data_root=data_root, 161 | ann_file='annotations/instances_val2017.json', 162 | data_prefix=dict(img='val2017/'), 163 | test_mode=True, 164 | pipeline=test_pipeline, 165 | backend_args=backend_args)) 166 | test_dataloader = val_dataloader 167 | 168 | val_evaluator = dict( 169 | type='CocoMetric', 170 | ann_file=data_root + 'annotations/instances_val2017.json', 171 | metric='bbox', 172 | backend_args=backend_args) 173 | test_evaluator = val_evaluator 174 | 175 | # training settings 176 | max_epochs = 300 177 | num_last_epochs = 15 178 | interval = 10 179 | 180 | train_cfg = dict(max_epochs=max_epochs, val_interval=interval) 181 | 182 | # optimizer 183 | # default 8 gpu 184 | base_lr = 0.01 185 | optim_wrapper = dict( 186 | type='OptimWrapper', 187 | optimizer=dict( 188 | type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4, 189 | nesterov=True), 190 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) 191 | 192 | # learning rate 193 | param_scheduler = [ 194 | dict( 195 | # use quadratic formula to warm up 5 epochs 196 | # and lr is updated by iteration 197 | # TODO: fix default scope in get function 198 | type='mmdet.QuadraticWarmupLR', 199 | by_epoch=True, 200 | begin=0, 201 | end=5, 202 | convert_to_iter_based=True), 203 | dict( 204 | # use cosine lr from 5 to 285 epoch 205 | type='CosineAnnealingLR', 206 | eta_min=base_lr * 0.05, 207 | begin=5, 208 | T_max=max_epochs - num_last_epochs, 209 | end=max_epochs - num_last_epochs, 210 | by_epoch=True, 211 | convert_to_iter_based=True), 212 | dict( 213 | # use fixed lr during last 15 epochs 214 | type='ConstantLR', 215 | by_epoch=True, 216 | factor=1, 217 | begin=max_epochs - num_last_epochs, 218 | end=max_epochs, 219 | ) 220 | ] 221 | 222 | default_hooks = dict( 223 | checkpoint=dict( 224 | interval=interval, 225 | max_keep_ckpts=3 # only keep latest 3 checkpoints 226 | )) 227 | 228 | custom_hooks = [ 229 | dict( 230 | 
type='YOLOXModeSwitchHook', 231 | num_last_epochs=num_last_epochs, 232 | priority=48), 233 | dict(type='SyncNormHook', priority=48), 234 | dict( 235 | type='EMAHook', 236 | ema_type='ExpMomentumEMA', 237 | momentum=0.0001, 238 | update_buffers=True, 239 | priority=49) 240 | ] 241 | 242 | # NOTE: `auto_scale_lr` is for automatically scaling LR, 243 | # USER SHOULD NOT CHANGE ITS VALUES. 244 | # base_batch_size = (8 GPUs) x (8 samples per GPU) 245 | auto_scale_lr = dict(base_batch_size=64) 246 | -------------------------------------------------------------------------------- /pose/script/dwpose.py: -------------------------------------------------------------------------------- 1 | # Openpose 2 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose 3 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose 4 | # 3rd Edited by ControlNet 5 | # 4th Edited by ControlNet (added face and correct hands) 6 | 7 | import os 8 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 9 | 10 | import cv2 11 | import torch 12 | import numpy as np 13 | from PIL import Image 14 | 15 | 16 | import pose.script.util as util 17 | 18 | def resize_image(input_image, resolution): 19 | H, W, C = input_image.shape 20 | H = float(H) 21 | W = float(W) 22 | k = float(resolution) / min(H, W) 23 | H *= k 24 | W *= k 25 | H = int(np.round(H / 64.0)) * 64 26 | W = int(np.round(W / 64.0)) * 64 27 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 28 | return img 29 | 30 | def HWC3(x): 31 | assert x.dtype == np.uint8 32 | if x.ndim == 2: 33 | x = x[:, :, None] 34 | assert x.ndim == 3 35 | H, W, C = x.shape 36 | assert C == 1 or C == 3 or C == 4 37 | if C == 3: 38 | return x 39 | if C == 1: 40 | return np.concatenate([x, x, x], axis=2) 41 | if C == 4: 42 | color = x[:, :, 0:3].astype(np.float32) 43 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 44 | y = color * alpha + 255.0 * (1.0 - alpha) 45 | y = y.clip(0, 255).astype(np.uint8) 46 | return y 47 | 48 | def draw_pose(pose, H, W, draw_face): 49 | bodies = pose['bodies'] 50 | faces = pose['faces'] 51 | hands = pose['hands'] 52 | candidate = bodies['candidate'] 53 | subset = bodies['subset'] 54 | 55 | # only the most significant person 56 | faces = pose['faces'][:1] 57 | hands = pose['hands'][:2] 58 | candidate = bodies['candidate'][:18] 59 | subset = bodies['subset'][:1] 60 | 61 | # draw 62 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) 63 | canvas = util.draw_bodypose(canvas, candidate, subset) 64 | canvas = util.draw_handpose(canvas, hands) 65 | if draw_face == True: 66 | canvas = util.draw_facepose(canvas, faces) 67 | 68 | return canvas 69 | 70 | class DWposeDetector: 71 | def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu", keypoints_only=False): 72 | from pose.script.wholebody import Wholebody 73 | 74 | self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device) 75 | self.keypoints_only = keypoints_only 76 | def to(self, device): 77 | self.pose_estimation.to(device) 78 | return self 79 | ''' 80 | detect_resolution: resize the short edge to this size; this is the native rendering resolution when drawing the pose. 1024 is recommended. 81 | image_resolution: resize the short edge to this size; this is the file resolution when saving the pose. 768 is recommended. 82 | 83 | Actual detection resolutions: 84 | yolox: (640, 640) 85 | dwpose:(288, 384) 86 | ''' 87 | 88 | def __call__(self, input_image, detect_resolution=1024, image_resolution=768, output_type="pil", **kwargs): 89 | 90 | input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR) 91 | # cv2.imshow('', 
input_image) 92 | # cv2.waitKey(0) 93 | 94 | input_image = HWC3(input_image) 95 | input_image = resize_image(input_image, detect_resolution) 96 | H, W, C = input_image.shape 97 | 98 | with torch.no_grad(): 99 | candidate, subset = self.pose_estimation(input_image) 100 | nums, keys, locs = candidate.shape 101 | candidate[..., 0] /= float(W) 102 | candidate[..., 1] /= float(H) 103 | body = candidate[:,:18].copy() 104 | body = body.reshape(nums*18, locs) 105 | score = subset[:,:18] 106 | 107 | for i in range(len(score)): 108 | for j in range(len(score[i])): 109 | if score[i][j] > 0.3: 110 | score[i][j] = int(18*i+j) 111 | else: 112 | score[i][j] = -1 113 | 114 | un_visible = subset<0.3 115 | candidate[un_visible] = -1 116 | 117 | foot = candidate[:,18:24] 118 | 119 | faces = candidate[:,24:92] 120 | 121 | hands = candidate[:,92:113] 122 | hands = np.vstack([hands, candidate[:,113:]]) 123 | 124 | bodies = dict(candidate=body, subset=score) 125 | pose = dict(bodies=bodies, hands=hands, faces=faces) 126 | 127 | if self.keypoints_only==True: 128 | return pose 129 | else: 130 | detected_map = draw_pose(pose, H, W, draw_face=False) 131 | detected_map = HWC3(detected_map) 132 | img = resize_image(input_image, image_resolution) 133 | H, W, C = img.shape 134 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 135 | # cv2.imshow('detected_map',detected_map) 136 | # cv2.waitKey(0) 137 | 138 | if output_type == "pil": 139 | detected_map = cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB) 140 | detected_map = Image.fromarray(detected_map) 141 | 142 | return detected_map, pose 143 | 144 | -------------------------------------------------------------------------------- /pose/script/tool.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import os.path as osp 4 | import shutil 5 | import sys 6 | from pathlib import Path 7 | 8 | import av 9 | import numpy as np 10 | import torch 11 | import torchvision 12 | from einops import rearrange 13 | from PIL import Image 14 | 15 | 16 | def seed_everything(seed): 17 | import random 18 | 19 | import numpy as np 20 | 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed % (2**32)) 24 | random.seed(seed) 25 | 26 | 27 | def import_filename(filename): 28 | spec = importlib.util.spec_from_file_location("mymodule", filename) 29 | module = importlib.util.module_from_spec(spec) 30 | sys.modules[spec.name] = module 31 | spec.loader.exec_module(module) 32 | return module 33 | 34 | 35 | def delete_additional_ckpt(base_path, num_keep): 36 | dirs = [] 37 | for d in os.listdir(base_path): 38 | if d.startswith("checkpoint-"): 39 | dirs.append(d) 40 | num_tot = len(dirs) 41 | if num_tot <= num_keep: 42 | return 43 | # ensure ckpt is sorted and delete the ealier! 
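    # as in musepose/utils/util.py, the numeric "checkpoint-<global_step>" suffix is parsed as an int so ordering follows the training step rather than string order.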
44 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep] 45 | for d in del_dirs: 46 | path_to_dir = osp.join(base_path, d) 47 | if osp.exists(path_to_dir): 48 | shutil.rmtree(path_to_dir) 49 | 50 | 51 | def save_videos_from_pil(pil_images, path, fps): 52 | 53 | save_fmt = Path(path).suffix 54 | os.makedirs(os.path.dirname(path), exist_ok=True) 55 | width, height = pil_images[0].size 56 | 57 | if save_fmt == ".mp4": 58 | codec = "libx264" 59 | container = av.open(path, "w") 60 | stream = container.add_stream(codec, rate=fps) 61 | 62 | stream.width = width 63 | stream.height = height 64 | stream.pix_fmt = 'yuv420p' 65 | stream.bit_rate = 10000000 66 | stream.options["crf"] = "18" 67 | 68 | for pil_image in pil_images: 69 | # pil_image = Image.fromarray(image_arr).convert("RGB") 70 | av_frame = av.VideoFrame.from_image(pil_image) 71 | container.mux(stream.encode(av_frame)) 72 | container.mux(stream.encode()) 73 | container.close() 74 | 75 | elif save_fmt == ".gif": 76 | pil_images[0].save( 77 | fp=path, 78 | format="GIF", 79 | append_images=pil_images[1:], 80 | save_all=True, 81 | duration=(1 / fps * 1000), 82 | loop=0, 83 | ) 84 | else: 85 | raise ValueError("Unsupported file type. Use .mp4 or .gif.") 86 | 87 | 88 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 89 | videos = rearrange(videos, "b c t h w -> t b c h w") 90 | height, width = videos.shape[-2:] 91 | outputs = [] 92 | 93 | for x in videos: 94 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w) 95 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c) 96 | if rescale: 97 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 98 | x = (x * 255).numpy().astype(np.uint8) 99 | x = Image.fromarray(x) 100 | 101 | outputs.append(x) 102 | 103 | os.makedirs(os.path.dirname(path), exist_ok=True) 104 | 105 | save_videos_from_pil(outputs, path, fps) 106 | 107 | 108 | def read_frames(video_path): 109 | container = av.open(video_path) 110 | 111 | video_stream = next(s for s in container.streams if s.type == "video") 112 | frames = [] 113 | for packet in container.demux(video_stream): 114 | for frame in packet.decode(): 115 | image = Image.frombytes( 116 | "RGB", 117 | (frame.width, frame.height), 118 | frame.to_rgb().to_ndarray(), 119 | ) 120 | frames.append(image) 121 | 122 | return frames 123 | 124 | 125 | def get_fps(video_path): 126 | container = av.open(video_path) 127 | video_stream = next(s for s in container.streams if s.type == "video") 128 | fps = video_stream.average_rate 129 | container.close() 130 | return fps 131 | -------------------------------------------------------------------------------- /pose/script/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | eps = 0.01 7 | 8 | def smart_width(d): 9 | if d<5: 10 | return 1 11 | elif d<10: 12 | return 2 13 | elif d<20: 14 | return 3 15 | elif d<40: 16 | return 4 17 | elif d<80: 18 | return 5 19 | elif d<160: 20 | return 6 21 | elif d<320: 22 | return 7 23 | else: 24 | return 8 25 | 26 | 27 | 28 | def draw_bodypose(canvas, candidate, subset): 29 | H, W, C = canvas.shape 30 | candidate = np.array(candidate) 31 | subset = np.array(subset) 32 | 33 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ 34 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ 35 | [1, 16], [16, 18], [3, 17], [6, 18]] 36 | 37 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], 
[170, 255, 0], [85, 255, 0], [0, 255, 0], \ 38 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 39 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 40 | 41 | for i in range(17): 42 | for n in range(len(subset)): 43 | index = subset[n][np.array(limbSeq[i]) - 1] 44 | if -1 in index: 45 | continue 46 | Y = candidate[index.astype(int), 0] * float(W) 47 | X = candidate[index.astype(int), 1] * float(H) 48 | mX = np.mean(X) 49 | mY = np.mean(Y) 50 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 51 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 52 | 53 | width = smart_width(length) 54 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), width), int(angle), 0, 360, 1) 55 | cv2.fillConvexPoly(canvas, polygon, colors[i]) 56 | 57 | canvas = (canvas * 0.6).astype(np.uint8) 58 | 59 | for i in range(18): 60 | for n in range(len(subset)): 61 | index = int(subset[n][i]) 62 | if index == -1: 63 | continue 64 | x, y = candidate[index][0:2] 65 | x = int(x * W) 66 | y = int(y * H) 67 | radius = 4 68 | cv2.circle(canvas, (int(x), int(y)), radius, colors[i], thickness=-1) 69 | 70 | return canvas 71 | 72 | 73 | def draw_handpose(canvas, all_hand_peaks): 74 | import matplotlib 75 | 76 | H, W, C = canvas.shape 77 | 78 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ 79 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] 80 | 81 | # (person_number*2, 21, 2) 82 | for i in range(len(all_hand_peaks)): 83 | peaks = all_hand_peaks[i] 84 | peaks = np.array(peaks) 85 | 86 | for ie, e in enumerate(edges): 87 | 88 | x1, y1 = peaks[e[0]] 89 | x2, y2 = peaks[e[1]] 90 | 91 | x1 = int(x1 * W) 92 | y1 = int(y1 * H) 93 | x2 = int(x2 * W) 94 | y2 = int(y2 * H) 95 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps: 96 | length = ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5 97 | width = smart_width(length) 98 | cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=width) 99 | 100 | for _, keyponit in enumerate(peaks): 101 | x, y = keyponit 102 | 103 | x = int(x * W) 104 | y = int(y * H) 105 | if x > eps and y > eps: 106 | radius = 3 107 | cv2.circle(canvas, (x, y), radius, (0, 0, 255), thickness=-1) 108 | return canvas 109 | 110 | 111 | def draw_facepose(canvas, all_lmks): 112 | H, W, C = canvas.shape 113 | for lmks in all_lmks: 114 | lmks = np.array(lmks) 115 | for lmk in lmks: 116 | x, y = lmk 117 | x = int(x * W) 118 | y = int(y * H) 119 | if x > eps and y > eps: 120 | radius = 3 121 | cv2.circle(canvas, (x, y), radius, (255, 255, 255), thickness=-1) 122 | return canvas 123 | 124 | 125 | 126 | 127 | # Calculate the resolution 128 | def size_calculate(h, w, resolution): 129 | 130 | H = float(h) 131 | W = float(w) 132 | 133 | # resize the short edge to the resolution 134 | k = float(resolution) / min(H, W) # short edge 135 | H *= k 136 | W *= k 137 | 138 | # resize to the nearest integer multiple of 64 139 | H = int(np.round(H / 64.0)) * 64 140 | W = int(np.round(W / 64.0)) * 64 141 | return H, W 142 | 143 | 144 | 145 | def warpAffine_kps(kps, M): 146 | a = M[:,:2] 147 | t = M[:,2] 148 | kps = np.dot(kps, a.T) + t 149 | return kps 150 | 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /pose/script/wholebody.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
OpenMMLab. All rights reserved. 2 | import os 3 | import numpy as np 4 | import warnings 5 | 6 | try: 7 | import mmcv 8 | except ImportError: 9 | warnings.warn( 10 | "The module 'mmcv' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmcv>=2.0.1'" 11 | ) 12 | 13 | try: 14 | from mmpose.apis import inference_topdown 15 | from mmpose.apis import init_model as init_pose_estimator 16 | from mmpose.evaluation.functional import nms 17 | from mmpose.utils import adapt_mmdet_pipeline 18 | from mmpose.structures import merge_data_samples 19 | except ImportError: 20 | warnings.warn( 21 | "The module 'mmpose' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmpose>=1.1.0'" 22 | ) 23 | 24 | try: 25 | from mmdet.apis import inference_detector, init_detector 26 | except ImportError: 27 | warnings.warn( 28 | "The module 'mmdet' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmdet>=3.1.0'" 29 | ) 30 | 31 | 32 | class Wholebody: 33 | def __init__(self, 34 | det_config=None, det_ckpt=None, 35 | pose_config=None, pose_ckpt=None, 36 | device="cpu"): 37 | 38 | if det_config is None: 39 | det_config = os.path.join(os.path.dirname(__file__), "yolox_config/yolox_l_8xb8-300e_coco.py") 40 | 41 | if pose_config is None: 42 | pose_config = os.path.join(os.path.dirname(__file__), "dwpose_config/dwpose-l_384x288.py") 43 | 44 | if det_ckpt is None: 45 | det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth' 46 | 47 | if pose_ckpt is None: 48 | pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth" 49 | 50 | # build detector 51 | self.detector = init_detector(det_config, det_ckpt, device=device) 52 | self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg) 53 | 54 | # build pose estimator 55 | self.pose_estimator = init_pose_estimator( 56 | pose_config, 57 | pose_ckpt, 58 | device=device) 59 | 60 | def to(self, device): 61 | self.detector.to(device) 62 | self.pose_estimator.to(device) 63 | return self 64 | 65 | def __call__(self, oriImg): 66 | # predict bbox 67 | det_result = inference_detector(self.detector, oriImg) 68 | pred_instance = det_result.pred_instances.cpu().numpy() 69 | bboxes = np.concatenate( 70 | (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1) 71 | bboxes = bboxes[np.logical_and(pred_instance.labels == 0, 72 | pred_instance.scores > 0.5)] 73 | 74 | # set NMS threshold 75 | bboxes = bboxes[nms(bboxes, 0.7), :4] 76 | 77 | # predict keypoints 78 | if len(bboxes) == 0: 79 | pose_results = inference_topdown(self.pose_estimator, oriImg) 80 | else: 81 | pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes) 82 | preds = merge_data_samples(pose_results) 83 | preds = preds.pred_instances 84 | 85 | # preds = pose_results[0].pred_instances 86 | keypoints = preds.get('transformed_keypoints', 87 | preds.keypoints) 88 | if 'keypoint_scores' in preds: 89 | scores = preds.keypoint_scores 90 | else: 91 | scores = np.ones(keypoints.shape[:-1]) 92 | 93 | if 'keypoints_visible' in preds: 94 | visible = preds.keypoints_visible 95 | else: 96 | visible = np.ones(keypoints.shape[:-1]) 97 | keypoints_info = np.concatenate( 98 | (keypoints, scores[..., None], visible[..., None]), 99 | axis=-1) 100 | # compute neck joint 101 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 102 | # 
neck score when visualizing pred 103 | neck[:, 2:4] = np.logical_and( 104 | keypoints_info[:, 5, 2:4] > 0.3, 105 | keypoints_info[:, 6, 2:4] > 0.3).astype(int) 106 | new_keypoints_info = np.insert( 107 | keypoints_info, 17, neck, axis=1) 108 | mmpose_idx = [ 109 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 110 | ] 111 | openpose_idx = [ 112 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 113 | ] 114 | new_keypoints_info[:, openpose_idx] = \ 115 | new_keypoints_info[:, mmpose_idx] 116 | keypoints_info = new_keypoints_info 117 | 118 | keypoints, scores, visible = keypoints_info[ 119 | ..., :2], keypoints_info[..., 2], keypoints_info[..., 3] 120 | 121 | return keypoints, scores 122 | -------------------------------------------------------------------------------- /pretrained_weights/put_models_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MusePose/61c52bd937224a614b3951419b735b639397cb62/pretrained_weights/put_models_here.txt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | torchdiffeq==0.2.3 3 | torchmetrics==1.2.1 4 | torchsde==0.2.5 5 | torchvision==0.15.2 6 | accelerate==0.29.3 7 | av==11.0.0 8 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a 9 | decord==0.6.0 10 | diffusers>=0.24.0,<=0.27.2 11 | einops==0.4.1 12 | imageio==2.33.0 13 | imageio-ffmpeg==0.4.9 14 | ffmpeg-python==0.2.0 15 | omegaconf==2.2.3 16 | open-clip-torch==2.20.0 17 | opencv-contrib-python==4.8.1.78 18 | opencv-python==4.8.1.78 19 | scikit-image==0.21.0 20 | scikit-learn==1.3.2 21 | transformers==4.33.1 22 | xformers==0.0.22 23 | moviepy==1.0.3 24 | wget==3.2 25 | -------------------------------------------------------------------------------- /test_stage_1.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import argparse 3 | import os 4 | import sys 5 | from datetime import datetime 6 | from pathlib import Path 7 | from typing import List 8 | import glob 9 | 10 | import numpy as np 11 | import torch 12 | import torchvision 13 | from diffusers import AutoencoderKL, DDIMScheduler 14 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 15 | from einops import repeat 16 | from omegaconf import OmegaConf 17 | from PIL import Image 18 | from torchvision import transforms 19 | from transformers import CLIPVisionModelWithProjection 20 | 21 | 22 | from musepose.models.pose_guider import PoseGuider 23 | from musepose.models.unet_2d_condition import UNet2DConditionModel 24 | from musepose.models.unet_3d import UNet3DConditionModel 25 | from musepose.pipelines.pipeline_pose2img import Pose2ImagePipeline 26 | from musepose.utils.util import get_fps, read_frames, save_videos_grid 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--config",default="./configs/test_stage_1.yaml") 32 | parser.add_argument("-W", type=int, default=768) 33 | parser.add_argument("-H", type=int, default=768) 34 | parser.add_argument("--seed", type=int, default=42) 35 | parser.add_argument("--cnt", type=int, default=1) 36 | parser.add_argument("--cfg", type=float, default=7) 37 | parser.add_argument("--steps", type=int, default=20) 38 | parser.add_argument("--fps", 
type=int) 39 | args = parser.parse_args() 40 | 41 | return args 42 | 43 | 44 | 45 | def main(): 46 | args = parse_args() 47 | 48 | config = OmegaConf.load(args.config) 49 | 50 | if config.weight_dtype == "fp16": 51 | weight_dtype = torch.float16 52 | else: 53 | weight_dtype = torch.float32 54 | 55 | vae = AutoencoderKL.from_pretrained( 56 | config.pretrained_vae_path, 57 | ).to("cuda", dtype=weight_dtype) 58 | 59 | reference_unet = UNet2DConditionModel.from_pretrained( 60 | config.pretrained_base_model_path, 61 | subfolder="unet", 62 | ).to(dtype=weight_dtype, device="cuda") 63 | 64 | inference_config_path = config.inference_config 65 | infer_config = OmegaConf.load(inference_config_path) 66 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 67 | config.pretrained_base_model_path, 68 | # config.motion_module_path, 69 | "", 70 | subfolder="unet", 71 | unet_additional_kwargs={ 72 | "use_motion_module": False, 73 | "unet_use_temporal_attention": False, 74 | }, 75 | ).to(dtype=weight_dtype, device="cuda") 76 | 77 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to( 78 | dtype=weight_dtype, device="cuda" 79 | ) 80 | 81 | image_enc = CLIPVisionModelWithProjection.from_pretrained( 82 | config.image_encoder_path 83 | ).to(dtype=weight_dtype, device="cuda") 84 | 85 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) 86 | scheduler = DDIMScheduler(**sched_kwargs) 87 | 88 | 89 | width, height = args.W, args.H 90 | 91 | # load pretrained weights 92 | denoising_unet.load_state_dict( 93 | torch.load(config.denoising_unet_path, map_location="cpu"), 94 | strict=False, 95 | ) 96 | reference_unet.load_state_dict( 97 | torch.load(config.reference_unet_path, map_location="cpu"), 98 | ) 99 | pose_guider.load_state_dict( 100 | torch.load(config.pose_guider_path, map_location="cpu"), 101 | ) 102 | 103 | pipe = Pose2ImagePipeline( 104 | vae=vae, 105 | image_encoder=image_enc, 106 | reference_unet=reference_unet, 107 | denoising_unet=denoising_unet, 108 | pose_guider=pose_guider, 109 | scheduler=scheduler, 110 | ) 111 | 112 | pipe = pipe.to("cuda", dtype=weight_dtype) 113 | 114 | date_str = datetime.now().strftime("%Y%m%d") 115 | time_str = datetime.now().strftime("%H%M") 116 | 117 | m1 = config.pose_guider_path.split('.')[0].split('/')[-1] 118 | save_dir_name = f"{time_str}-{m1}" 119 | 120 | save_dir = Path(f"./output/image-{date_str}/{save_dir_name}") 121 | save_dir.mkdir(exist_ok=True, parents=True) 122 | 123 | def handle_single(ref_image_path, pose_path,seed): 124 | generator = torch.manual_seed(seed) 125 | ref_name = Path(ref_image_path).stem 126 | # pose_name = Path(pose_image_path).stem.replace("_kps", "") 127 | pose_name = Path(pose_path).stem 128 | 129 | ref_image_pil = Image.open(ref_image_path).convert("RGB") 130 | pose_image = Image.open(pose_path).convert("RGB") 131 | 132 | original_width, original_height = pose_image.size 133 | 134 | pose_transform = transforms.Compose( 135 | [transforms.Resize((height, width)), transforms.ToTensor()] 136 | ) 137 | 138 | pose_image_tensor = pose_transform(pose_image) 139 | pose_image_tensor = pose_image_tensor.unsqueeze(0) # (1, c, h, w) 140 | 141 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w) 142 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w) 143 | 144 | image = pipe( 145 | ref_image_pil, 146 | pose_image, 147 | width, 148 | height, 149 | args.steps, 150 | args.cfg, 151 | generator=generator, 152 | ).images 153 | 154 | image = image.squeeze(2).squeeze(0) # (c, 
h, w) 155 | image = image.transpose(0, 1).transpose(1, 2) # (h w c) 156 | #image = (image + 1.0) / 2.0 # -1,1 -> 0,1 157 | 158 | image = (image * 255).numpy().astype(np.uint8) 159 | image = Image.fromarray(image, 'RGB') 160 | # image.save(os.path.join(save_dir, f"{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}.png")) 161 | 162 | image_grid = Image.new('RGB',(original_width*3,original_height)) 163 | imgs = [ref_image_pil,pose_image,image] 164 | x_offset = 0 165 | for img in imgs: 166 | img = img.resize((original_width*2, original_height*2)) 167 | img.save(os.path.join(save_dir, f"res_{ref_name}_{pose_name}_{args.cfg}_{seed}.jpg")) 168 | img = img.resize((original_width,original_height)) 169 | image_grid.paste(img, (x_offset,0)) 170 | x_offset += img.size[0] 171 | image_grid.save(os.path.join(save_dir, f"grid_{ref_name}_{pose_name}_{args.cfg}_{seed}.jpg")) 172 | 173 | 174 | for ref_image_path_dir in config["test_cases"].keys(): 175 | if os.path.isdir(ref_image_path_dir): 176 | ref_image_paths = glob.glob(os.path.join(ref_image_path_dir, '*.jpg')) 177 | else: 178 | ref_image_paths = [ref_image_path_dir] 179 | for ref_image_path in ref_image_paths: 180 | for pose_image_path_dir in config["test_cases"][ref_image_path_dir]: 181 | if os.path.isdir(pose_image_path_dir): 182 | pose_image_paths = glob.glob(os.path.join(pose_image_path_dir, '*.jpg')) 183 | else: 184 | pose_image_paths = [pose_image_path_dir] 185 | for pose_image_path in pose_image_paths: 186 | for i in range(args.cnt): 187 | handle_single(ref_image_path, pose_image_path, args.seed + i) 188 | 189 | 190 | if __name__ == "__main__": 191 | main() 192 | -------------------------------------------------------------------------------- /test_stage_2.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import argparse 3 | from datetime import datetime 4 | from pathlib import Path 5 | from typing import List 6 | 7 | import av 8 | import numpy as np 9 | import torch 10 | import torchvision 11 | from diffusers import AutoencoderKL, DDIMScheduler 12 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 13 | from einops import repeat 14 | from omegaconf import OmegaConf 15 | from PIL import Image 16 | from torchvision import transforms 17 | from transformers import CLIPVisionModelWithProjection 18 | import glob 19 | import torch.nn.functional as F 20 | 21 | from musepose.models.pose_guider import PoseGuider 22 | from musepose.models.unet_2d_condition import UNet2DConditionModel 23 | from musepose.models.unet_3d import UNet3DConditionModel 24 | from musepose.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline 25 | from musepose.utils.util import get_fps, read_frames, save_videos_grid 26 | 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--config", type=str, default="./configs/test_stage_2.yaml") 32 | parser.add_argument("-W", type=int, default=768, help="Width") 33 | parser.add_argument("-H", type=int, default=768, help="Height") 34 | parser.add_argument("-L", type=int, default=300, help="video frame length") 35 | parser.add_argument("-S", type=int, default=48, help="video slice frame number") 36 | parser.add_argument("-O", type=int, default=4, help="video slice overlap frame number") 37 | 38 | parser.add_argument("--cfg", type=float, default=3.5, help="Classifier free guidance") 39 | parser.add_argument("--seed", type=int, default=99) 40 | parser.add_argument("--steps", type=int, default=20, 
help="DDIM sampling steps") 41 | parser.add_argument("--fps", type=int) 42 | 43 | parser.add_argument("--skip", type=int, default=1, help="frame sample rate = (skip+1)") 44 | args = parser.parse_args() 45 | 46 | print('Width:', args.W) 47 | print('Height:', args.H) 48 | print('Length:', args.L) 49 | print('Slice:', args.S) 50 | print('Overlap:', args.O) 51 | print('Classifier free guidance:', args.cfg) 52 | print('DDIM sampling steps :', args.steps) 53 | print("skip", args.skip) 54 | 55 | return args 56 | 57 | 58 | def scale_video(video,width,height): 59 | video_reshaped = video.view(-1, *video.shape[2:]) # [batch*frames, channels, height, width] 60 | scaled_video = F.interpolate(video_reshaped, size=(height, width), mode='bilinear', align_corners=False) 61 | scaled_video = scaled_video.view(*video.shape[:2], scaled_video.shape[1], height, width) # [batch, frames, channels, height, width] 62 | 63 | return scaled_video 64 | 65 | 66 | def main(): 67 | args = parse_args() 68 | 69 | config = OmegaConf.load(args.config) 70 | 71 | if config.weight_dtype == "fp16": 72 | weight_dtype = torch.float16 73 | else: 74 | weight_dtype = torch.float32 75 | 76 | vae = AutoencoderKL.from_pretrained( 77 | config.pretrained_vae_path, 78 | ).to("cuda", dtype=weight_dtype) 79 | 80 | reference_unet = UNet2DConditionModel.from_pretrained( 81 | config.pretrained_base_model_path, 82 | subfolder="unet", 83 | ).to(dtype=weight_dtype, device="cuda") 84 | 85 | inference_config_path = config.inference_config 86 | infer_config = OmegaConf.load(inference_config_path) 87 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 88 | config.pretrained_base_model_path, 89 | config.motion_module_path, 90 | subfolder="unet", 91 | unet_additional_kwargs=infer_config.unet_additional_kwargs, 92 | ).to(dtype=weight_dtype, device="cuda") 93 | 94 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to( 95 | dtype=weight_dtype, device="cuda" 96 | ) 97 | 98 | image_enc = CLIPVisionModelWithProjection.from_pretrained( 99 | config.image_encoder_path 100 | ).to(dtype=weight_dtype, device="cuda") 101 | 102 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) 103 | scheduler = DDIMScheduler(**sched_kwargs) 104 | 105 | generator = torch.manual_seed(args.seed) 106 | 107 | width, height = args.W, args.H 108 | 109 | # load pretrained weights 110 | denoising_unet.load_state_dict( 111 | torch.load(config.denoising_unet_path, map_location="cpu"), 112 | strict=False, 113 | ) 114 | reference_unet.load_state_dict( 115 | torch.load(config.reference_unet_path, map_location="cpu"), 116 | ) 117 | pose_guider.load_state_dict( 118 | torch.load(config.pose_guider_path, map_location="cpu"), 119 | ) 120 | 121 | pipe = Pose2VideoPipeline( 122 | vae=vae, 123 | image_encoder=image_enc, 124 | reference_unet=reference_unet, 125 | denoising_unet=denoising_unet, 126 | pose_guider=pose_guider, 127 | scheduler=scheduler, 128 | ) 129 | pipe = pipe.to("cuda", dtype=weight_dtype) 130 | 131 | date_str = datetime.now().strftime("%Y%m%d") 132 | time_str = datetime.now().strftime("%H%M") 133 | 134 | def handle_single(ref_image_path,pose_video_path): 135 | print ('handle===',ref_image_path, pose_video_path) 136 | ref_name = Path(ref_image_path).stem 137 | pose_name = Path(pose_video_path).stem.replace("_kps", "") 138 | 139 | ref_image_pil = Image.open(ref_image_path).convert("RGB") 140 | 141 | pose_list = [] 142 | pose_tensor_list = [] 143 | pose_images = read_frames(pose_video_path) 144 | src_fps = get_fps(pose_video_path) 145 | 
print(f"pose video has {len(pose_images)} frames, with {src_fps} fps") 146 | L = min(args.L, len(pose_images)) 147 | pose_transform = transforms.Compose( 148 | [transforms.Resize((height, width)), transforms.ToTensor()] 149 | ) 150 | original_width,original_height = 0,0 151 | 152 | pose_images = pose_images[::args.skip+1] 153 | print("processing length:", len(pose_images)) 154 | src_fps = src_fps // (args.skip + 1) 155 | print("fps", src_fps) 156 | L = L // ((args.skip + 1)) 157 | 158 | for pose_image_pil in pose_images[: L]: 159 | pose_tensor_list.append(pose_transform(pose_image_pil)) 160 | pose_list.append(pose_image_pil) 161 | original_width, original_height = pose_image_pil.size 162 | pose_image_pil = pose_image_pil.resize((width,height)) 163 | 164 | # repeart the last segment 165 | last_segment_frame_num = (L - args.S) % (args.S - args.O) 166 | repeart_frame_num = (args.S - args.O - last_segment_frame_num) % (args.S - args.O) 167 | for i in range(repeart_frame_num): 168 | pose_list.append(pose_list[-1]) 169 | pose_tensor_list.append(pose_tensor_list[-1]) 170 | 171 | 172 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w) 173 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w) 174 | ref_image_tensor = repeat(ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=L) 175 | 176 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w) 177 | pose_tensor = pose_tensor.transpose(0, 1) 178 | pose_tensor = pose_tensor.unsqueeze(0) 179 | 180 | video = pipe( 181 | ref_image_pil, 182 | pose_list, 183 | width, 184 | height, 185 | len(pose_list), 186 | args.steps, 187 | args.cfg, 188 | generator=generator, 189 | context_frames=args.S, 190 | context_stride=1, 191 | context_overlap=args.O, 192 | ).videos 193 | 194 | 195 | m1 = config.pose_guider_path.split('.')[0].split('/')[-1] 196 | m2 = config.motion_module_path.split('.')[0].split('/')[-1] 197 | 198 | save_dir_name = f"{time_str}-{args.cfg}-{m1}-{m2}" 199 | save_dir = Path(f"./output/video-{date_str}/{save_dir_name}") 200 | save_dir.mkdir(exist_ok=True, parents=True) 201 | 202 | result = scale_video(video[:,:,:L], original_width, original_height) 203 | save_videos_grid( 204 | result, 205 | f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.steps}_{args.skip}.mp4", 206 | n_rows=1, 207 | fps=src_fps if args.fps is None else args.fps, 208 | ) 209 | 210 | video = torch.cat([ref_image_tensor, pose_tensor[:,:,:L], video[:,:,:L]], dim=0) 211 | video = scale_video(video, original_width, original_height) 212 | save_videos_grid( 213 | video, 214 | f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.steps}_{args.skip}_{m1}_{m2}.mp4", 215 | n_rows=3, 216 | fps=src_fps if args.fps is None else args.fps, 217 | ) 218 | 219 | for ref_image_path_dir in config["test_cases"].keys(): 220 | if os.path.isdir(ref_image_path_dir): 221 | ref_image_paths = glob.glob(os.path.join(ref_image_path_dir, '*.jpg')) 222 | else: 223 | ref_image_paths = [ref_image_path_dir] 224 | for ref_image_path in ref_image_paths: 225 | for pose_video_path_dir in config["test_cases"][ref_image_path_dir]: 226 | if os.path.isdir(pose_video_path_dir): 227 | pose_video_paths = glob.glob(os.path.join(pose_video_path_dir, '*.mp4')) 228 | else: 229 | pose_video_paths = [pose_video_path_dir] 230 | for pose_video_path in pose_video_paths: 231 | handle_single(ref_image_path, pose_video_path) 232 | 233 | 234 | 235 | 236 | if __name__ == "__main__": 237 | main() 238 | --------------------------------------------------------------------------------
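The Wholebody wrapper in pose/script/wholebody.py chains a YOLOX person detector with the DWPose top-down estimator and returns per-person keypoints and scores remapped to OpenPose joint order, with a synthesized neck joint. A minimal usage sketch under stated assumptions: it presumes mmcv/mmdet/mmpose were installed via mim, that it runs from pose/script/ so the plain wholebody import resolves, and the image path and threshold handling are illustrative only.

# Minimal usage sketch for the Wholebody wrapper (not part of the repo).
import cv2
from wholebody import Wholebody

estimator = Wholebody(device="cuda:0")              # default YOLOX + DWPose configs/checkpoints
frame = cv2.imread("../../assets/images/ref.png")   # BGR uint8 array, as mmdet/mmpose expect
keypoints, scores = estimator(frame)                # per-person joint coordinates and confidences

# After remapping, OpenPose index 1 is the synthesized neck joint; keep joints whose score
# clears the same 0.3 threshold the wrapper applies when scoring the neck.
confident = scores > 0.3
print(f"detected {keypoints.shape[0]} person(s), {int(confident.sum())} joints above 0.3")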
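Both test scripts read the same family of keys from their YAML config: weight_dtype, paths to the VAE, base model, image encoder and the trained pose_guider / reference_unet / denoising_unet (plus motion_module for stage 2), an inference_config, and a test_cases mapping from a reference image (or a directory of *.jpg files) to a list of pose inputs (pose images for stage 1, pose videos for stage 2; directories are globbed). Below is a hedged sketch of that structure built in code rather than YAML: every path is a placeholder chosen for illustration and only the key names are taken from the scripts; configs/test_stage_1.yaml and configs/test_stage_2.yaml in the repo are the authoritative versions.

# Illustrative shape of the config consumed by test_stage_1.py / test_stage_2.py (all paths are placeholders).
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "weight_dtype": "fp16",                        # anything other than "fp16" falls back to fp32
    "pretrained_base_model_path": "pretrained_weights/base_model",   # diffusers layout with a unet/ subfolder
    "pretrained_vae_path": "pretrained_weights/sd-vae-ft-mse",
    "image_encoder_path": "pretrained_weights/image_encoder",
    "inference_config": "./configs/inference_v2.yaml",
    "denoising_unet_path": "pretrained_weights/MusePose/denoising_unet.pth",
    "reference_unet_path": "pretrained_weights/MusePose/reference_unet.pth",
    "pose_guider_path": "pretrained_weights/MusePose/pose_guider.pth",
    "motion_module_path": "pretrained_weights/MusePose/motion_module.pth",   # stage 2 only
    "test_cases": {
        "./assets/images/ref.png": [               # reference image, or a directory of *.jpg
            "./assets/videos/dance_kps.mp4",       # stage 2 pose video (stage 1 expects pose images)
        ],
    },
})

for ref, poses in cfg["test_cases"].items():       # mirrors the loop at the bottom of both scripts
    print(ref, list(poses))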
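In handle_single of test_stage_2.py, the pose sequence is padded by repeating its last frame so that its length fits the long-video sliding-window schedule: the first context window covers S frames and each further window advances by S - O frames, so frames are appended until (L - S) % (S - O) == 0. A small self-contained sketch of that arithmetic (the helper name is mine; S and O follow the script's defaults -S 48 -O 4):

# Sketch of the tail-padding rule used in test_stage_2.py (helper name is illustrative).
def padded_length(L: int, S: int = 48, O: int = 4) -> int:
    """Frame count after repeating the last pose frame so the sequence splits into
    windows of S frames that advance by S - O frames."""
    last_segment = (L - S) % (S - O)           # frames sticking out beyond the last full window
    repeat = (S - O - last_segment) % (S - O)  # copies of the final frame to append
    return L + repeat

assert padded_length(48) == 48   # exactly one window, nothing to pad
assert padded_length(50) == 92   # 2 extra frames -> pad up to 48 + 44
assert padded_length(92) == 92   # 48 + 44 already fits the schedule

With the remaining defaults this corresponds to an invocation like: python test_stage_2.py --config ./configs/test_stage_2.yaml -L 300 -S 48 -O 4 --cfg 3.5 --steps 20 --skip 1
--------------------------------------------------------------------------------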