├── .gitignore
├── .gitmodules
├── LICENSE
├── NOTICE
├── README.md
├── app.py
├── assets
│   ├── cai-xukun.mp4
│   └── mini_program_maliang.png
├── configs
│   ├── inference
│   │   ├── inference_v1.yaml
│   │   ├── inference_v2.yaml
│   │   ├── pose_images
│   │   │   └── pose-1.png
│   │   ├── pose_videos
│   │   │   ├── anyone-video-1_kps.mp4
│   │   │   ├── anyone-video-2_kps.mp4
│   │   │   ├── anyone-video-4_kps.mp4
│   │   │   └── anyone-video-5_kps.mp4
│   │   └── ref_images
│   │       ├── anyone-1.png
│   │       ├── anyone-10.png
│   │       ├── anyone-11.png
│   │       ├── anyone-2.png
│   │       ├── anyone-3.png
│   │       └── anyone-5.png
│   ├── prompts
│   │   ├── animation.yaml
│   │   └── test_cases.py
│   └── train
│       ├── stage1.yaml
│       └── stage2.yaml
├── install.ps1
├── install_cn.ps1
├── requirements-windows.txt
├── requirements.txt
├── run_VidControlnetAux_gui.ps1
├── run_gui.ps1
├── scripts
│   └── pose2vid.py
├── src
│   ├── __init__.py
│   ├── dataset
│   │   ├── dance_image.py
│   │   └── dance_video.py
│   ├── dwpose
│   │   ├── __init__.py
│   │   ├── onnxdet.py
│   │   ├── onnxpose.py
│   │   ├── util.py
│   │   └── wholebody.py
│   ├── models
│   │   ├── attention.py
│   │   ├── model_util.py
│   │   ├── motion_module.py
│   │   ├── mutual_self_attention.py
│   │   ├── pose_guider.py
│   │   ├── resnet.py
│   │   ├── transformer_2d.py
│   │   ├── transformer_3d.py
│   │   ├── unet_2d_blocks.py
│   │   ├── unet_2d_condition.py
│   │   ├── unet_3d.py
│   │   └── unet_3d_blocks.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── context.py
│   │   ├── pipeline_pose2img.py
│   │   ├── pipeline_pose2vid.py
│   │   ├── pipeline_pose2vid_long.py
│   │   └── utils.py
│   └── utils
│       └── util.py
├── tools
│   ├── download_weights.py
│   ├── extract_dwpose_from_vid.py
│   ├── extract_meta_info.py
│   └── vid2pose.py
├── train_stage_1.py
└── train_stage_2.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | pretrained_weights/
3 | output/
4 | outputs/
5 | .venv/
6 | venv/
7 | huggingface/
8 | .venv/
9 | mlruns/
10 | data/
11 |
12 | *.pth
13 | *.pt
14 | *.pkl
15 | *.bin
16 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "video_controlnet_aux"]
2 | path = video_controlnet_aux
3 | url = https://github.com/sdbds/video_controlnet_aux
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright @2023-2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright [yyyy] [name of copyright owner]
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | ==============================================================
2 | This repo also contains various third-party components and some code modified from other repos under other open source licenses. The following sections contain licensing information for such third-party libraries.
3 |
4 | -----------------------------
5 | majic-animate
6 | BSD 3-Clause License
7 | Copyright (c) Bytedance Inc.
8 |
9 | -----------------------------
10 | animatediff
11 | Apache License, Version 2.0
12 |
13 | -----------------------------
14 | Dwpose
15 | Apache License, Version 2.0
16 |
17 | -----------------------------
18 | inference pipeline for animatediff-cli-prompt-travel
19 | animatediff-cli-prompt-travel
20 | Apache License, Version 2.0
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🤗 Introduction
2 |
3 | **update** 🏋️🏋️🏋️ We have released our training code! Now you can train your own AnimateAnyone models. See [here](#training) for more details. Have fun!
4 |
5 | **update**: 🔥🔥🔥 We have launched a HuggingFace Spaces demo of Moore-AnimateAnyone [here](https://huggingface.co/spaces/xunsong/Moore-AnimateAnyone)!
6 |
7 | This repository reproduces [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone). To align with the results demonstrated in the original paper, we adopt various approaches and tricks, which may differ somewhat from the paper and from another [implementation](https://github.com/guoqincode/Open-AnimateAnyone).
8 | 
9 | It's worth noting that this is a very preliminary version, aiming to approximate the performance (roughly 80% in our tests) shown in [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone).
10 | 
11 | We will continue to develop it, and we also welcome feedback and ideas from the community. The enhanced version will also be launched on our [MoBi MaLiang](https://maliang.mthreads.com/) AIGC platform, running on our own full-featured GPU S4000 cloud computing platform.
12 |
13 | # 📝 Release Plans
14 |
15 | - [x] Inference codes and pretrained weights
16 | - [x] Training scripts
17 |
18 | # 🎞️ Examples
19 |
20 | Here are some results we generated, at a resolution of 512x768.
21 |
22 | https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/f0454f30-6726-4ad4-80a7-5b7a15619057
23 |
24 | https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/337ff231-68a3-4760-a9f9-5113654acf48
25 |
47 | **Limitation**: We observe the following shortcomings in the current version:
48 | 1. Some artifacts may appear in the background, even when the reference image has a clean background.
49 | 2. Suboptimal results may arise when there is a scale mismatch between the reference image and keypoints. We have yet to implement preprocessing techniques as mentioned in the [paper](https://arxiv.org/pdf/2311.17117.pdf).
50 | 3. Some flickering and jittering may occur when the motion sequence is subtle or the scene is static.
51 |
52 | These issues will be addressed and improved in the near future. We appreciate your patience!
53 |
54 | # ⚒️ Installation
55 |
56 | Prerequisites: `3.11>=python>=3.8`, `CUDA>=11.3`, `ffmpeg` and `git`.
57 |
58 | Python and Git:
59 |
60 | - Python 3.10.11: https://www.python.org/ftp/python/3.10.11/python-3.10.11-amd64.exe
61 | - git: https://git-scm.com/download/win
62 |
63 | - Install [ffmpeg](https://ffmpeg.org/) for your operating system
64 | (https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/)
65 |
66 | Note: step 4 of that guide uses the Windows system environment variable settings to add ffmpeg to the PATH.
67 |
68 | Give PowerShell unrestricted script access so the venv can work:
69 | 
70 | - Open an administrator PowerShell window
71 | - Type `Set-ExecutionPolicy Unrestricted` and answer `A`
72 | - Close the admin PowerShell window
73 |
74 | ```
75 | git clone --recurse-submodules https://github.com/sdbds/Moore-AnimateAnyone-for-windows/
76 | ```
77 |
78 | Install by running `install.ps1` with PowerShell (or `install_cn.ps1` for Chinese users).
79 |
80 | ### Use local model
81 |
82 | Loading a local safetensors or ckpt file is supported: edit the pretrained model path (e.g. `pretrained_base_model_path`) in `configs/prompts/animation.yaml` to point to your local SD1.5 model,
83 | such as `"D:\\stablediffusion-webui\\models\\Stable-diffusion\\v1-5-pruned.ckpt"`.
84 |
85 | ## No need to download models manually
86 | ~~Download weights~~
87 |
88 | ~~Download our trained [weights](https://huggingface.co/patrolli/AnimateAnyone/tree/main), which include four parts: `denoising_unet.pth`, `reference_unet.pth`, `pose_guider.pth` and `motion_module.pth`.~~
89 |
90 | ~~Download pretrained weight of based models and other components:~~
91 | ~~- [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5)~~
92 | ~~- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)~~
93 | ~~- [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder)~~
94 |
95 | ~~Download dwpose weights (`dw-ll_ucoco_384.onnx`, `yolox_l.onnx`) following [this](https://github.com/IDEA-Research/DWPose?tab=readme-ov-file#-dwpose-for-controlnet).~~
96 |
97 | ~~Put these weights under a directory, like `./pretrained_weights`, and organize them as follows:~~
98 |
99 | ```text
100 | ./pretrained_weights/
101 | |-- DWPose
102 | | |-- dw-ll_ucoco_384.onnx
103 | | `-- yolox_l.onnx
104 | |-- image_encoder
105 | | |-- config.json
106 | | `-- pytorch_model.bin
107 | |-- denoising_unet.pth
108 | |-- motion_module.pth
109 | |-- pose_guider.pth
110 | |-- reference_unet.pth
111 | `-- stable-diffusion-v1-5
112 | |-- feature_extractor
113 | | `-- preprocessor_config.json
114 | |-- model_index.json
115 | |-- unet
116 | | |-- config.json
117 | | `-- diffusion_pytorch_model.bin
118 | `-- v1-inference.yaml
119 | ```
120 |
121 | ~~Note: If you have already installed some of the pretrained models, such as `StableDiffusion V1.5`, you can specify their paths in the config file (e.g. `./configs/prompts/animation.yaml`).~~
122 |
123 | # 🚀 Training and Inference
124 |
125 | ## Inference
126 |
127 | Here is the CLI command for running the inference script:
128 |
129 | ```shell
130 | python -m scripts.pose2vid --config ./configs/prompts/animation.yaml -W 512 -H 784 -L 64
131 | ```
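
Besides `-W`, `-H` and `-L`, the `parse_args()` in `scripts/pose2vid.py` also accepts `--seed`, `--cfg`, `--steps` and `--fps`; a fuller invocation might look like:

```shell
python -m scripts.pose2vid --config ./configs/prompts/animation.yaml -W 512 -H 784 -L 64 --seed 42 --cfg 3.5 --steps 30 --fps 24
```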
132 |
133 | You can follow the format of `animation.yaml` to add your own reference images or pose videos (a sample `test_cases` entry is sketched after the command below). To convert a raw video into a pose video (keypoint sequence), run the following command:
134 |
135 | ```shell
136 | python tools/vid2pose.py --video_path /path/to/your/video.mp4
137 | ```
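
Each entry under `test_cases` in `animation.yaml` maps a reference image to one or more pose videos. A minimal sketch, using the bundled example assets (substitute your own paths):

```yaml
test_cases:
  "./configs/inference/ref_images/anyone-2.png":
    - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
    - "./path/to/your/own_pose_video_kps.mp4"  # hypothetical path for illustration
```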
138 |
139 | # 🎨 Gradio Demo
140 |
141 | ### Local Gradio Demo:
142 |
143 | Launch the local Gradio demo on a GPU:
144 | 
145 | Run `run_gui.ps1` with PowerShell.
146 | 
147 | Then open the Gradio demo in your local browser.
148 |
149 | ### Online Gradio Demo: [HuggingFace Spaces](https://huggingface.co/spaces/xunsong/Moore-AnimateAnyone)
150 | ## Training
151 | 
152 | Note: package dependencies have been updated; you may need to upgrade your environment via `pip install -r requirements.txt` before training.
153 |
154 | ### Data Preparation
155 |
156 | Extract keypoints from raw videos:
157 |
158 | ```shell
159 | python tools/extract_dwpose_from_vid.py --video_root /path/to/your/video_dir
160 | ```
161 |
162 | Extract the meta info of the dataset:
163 |
164 | ```shell
165 | python tools/extract_meta_info.py --root_path /path/to/your/video_dir --dataset_name anyone
166 | ```
167 |
168 | Update the `meta_paths` entry in the training config file:
169 |
170 | ```yaml
171 | data:
172 | meta_paths:
173 | - "./data/anyone_meta.json"
174 | ```
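
The meta file referenced above is a JSON list pairing each training video with its extracted keypoint video, per the format comment in `src/dataset/dance_image.py`; a sketch of one entry (the paths below are illustrative only):

```json
[
  {
    "video_path": "/path/to/your/video_dir/video-001.mp4",
    "kps_path": "/path/to/your/video_dir/video-001_kps.mp4"
  }
]
```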
175 |
176 | ### Stage1
177 |
178 | Put the [openpose controlnet weights](https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/tree/main) under `./pretrained_weights`; they are used to initialize the pose_guider.
179 | 
180 | Put [sd-image-variations](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main) under `./pretrained_weights`; it is used to initialize the UNet weights.
181 |
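Based on the default paths in `configs/train/stage1.yaml`, the training-time `./pretrained_weights` layout then looks roughly like this (a sketch, not an exhaustive listing):

```text
./pretrained_weights/
|-- control_v11p_sd15_openpose
|   `-- diffusion_pytorch_model.bin
|-- sd-image-variations-diffusers
|   |-- image_encoder
|   `-- ...
`-- sd-vae-ft-mse
```
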
182 | Run command:
183 |
184 | ```shell
185 | accelerate launch train_stage_1.py --config configs/train/stage1.yaml
186 | ```
187 |
188 | ### Stage2
189 |
190 | Put the pretrained motion module weights `mm_sd_v15_v2.ckpt` ([download link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt)) under `./pretrained_weights`.
191 |
192 | Specify the stage1 training weights in the config file `stage2.yaml`, for example:
193 |
194 | ```yaml
195 | stage1_ckpt_dir: './exp_output/stage1'
196 | stage1_ckpt_step: 30000
197 | ```
198 |
199 | Run command:
200 |
201 | ```shell
202 | accelerate launch train_stage_2.py --config configs/train/stage2.yaml
203 | ```
204 |
205 | **HuggingFace Demo**: We have launched a quick preview demo of Moore-AnimateAnyone on [HuggingFace Spaces](https://huggingface.co/spaces/xunsong/Moore-AnimateAnyone)!
206 |
207 | We appreciate the assistance provided by the HuggingFace team in setting up this demo.
208 |
209 | To reduce waiting time, we limit the size (width, height, and length) and inference steps when generating videos.
210 |
211 | If you have your own GPU resources (>= 16GB VRAM), you can run a local Gradio app via the following command:
212 |
213 | `python app.py`
214 |
215 | # Community Contributions
216 |
217 | - Installation for Windows users: [Moore-AnimateAnyone-for-windows](https://github.com/sdbds/Moore-AnimateAnyone-for-windows)
218 |
219 | # 🖌️ Try on Mobi MaLiang
220 |
221 | We will launch this model on our [MoBi MaLiang](https://maliang.mthreads.com/) AIGC platform, which runs on our own full-featured GPU S4000 cloud computing platform. MoBi MaLiang has already integrated various AIGC applications and functionalities (e.g. text-to-image, controllable generation). You can experience it by [clicking this link](https://maliang.mthreads.com/) or by scanning the QR code below via WeChat!
222 |
223 |
224 | ![MoBi MaLiang WeChat mini-program QR code](./assets/mini_program_maliang.png)
227 |
228 | # ⚖️ Disclaimer
229 |
230 | This project is intended for academic research, and we explicitly disclaim any responsibility for user-generated content. Users are solely liable for their actions while using the generative model. The project contributors have no legal affiliation with, nor accountability for, users' behaviors. It is imperative to use the generative model responsibly, adhering to both ethical and legal standards.
231 |
232 | # 🙏🏻 Acknowledgements
233 |
234 | We first thank the authors of [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone). Additionally, we would like to thank the contributors to the [magic-animate](https://github.com/magic-research/magic-animate), [animatediff](https://github.com/guoyww/AnimateDiff) and [Open-AnimateAnyone](https://github.com/guoqincode/Open-AnimateAnyone) repositories for their open research and exploration. Furthermore, our repo incorporates some code from [dwpose](https://github.com/IDEA-Research/DWPose) and [animatediff-cli-prompt-travel](https://github.com/s9roll7/animatediff-cli-prompt-travel/), and we extend our thanks to them as well.
235 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from datetime import datetime
4 |
5 | import gradio as gr
6 | import numpy as np
7 | import torch
8 | from diffusers import AutoencoderKL, DDIMScheduler
9 | from diffusers.models.attention_processor import AttnProcessor2_0
10 | from einops import repeat
11 | from omegaconf import OmegaConf
12 | from PIL import Image
13 | from torchvision import transforms
14 | from transformers import CLIPVisionModelWithProjection
15 |
16 | from src.models.model_util import load_models, torch_gc, get_torch_device
17 | from src.models.pose_guider import PoseGuider
18 | from src.models.unet_2d_condition import UNet2DConditionModel
19 | from src.models.unet_3d import UNet3DConditionModel
20 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
21 | from src.utils.util import get_fps, read_frames, save_videos_grid
22 |
23 |
24 | class AnimateController:
25 | def __init__(
26 | self,
27 | config_path="./configs/prompts/animation.yaml",
28 | weight_dtype=torch.float16,
29 | ):
30 | # Read pretrained weights path from config
31 | self.config = OmegaConf.load(config_path)
32 | self.pipeline = None
33 | self.weight_dtype = weight_dtype
34 |
35 | def animate(
36 | self,
37 | ref_image,
38 | pose_video_path,
39 | width=512,
40 | height=768,
41 | length=24,
42 | num_inference_steps=25,
43 | cfg=3.5,
44 | seed=123,
45 | ):
46 | generator = torch.manual_seed(seed)
47 | self.device = get_torch_device()
48 | if isinstance(ref_image, np.ndarray):
49 | ref_image = Image.fromarray(ref_image)
50 | if self.pipeline is None:
51 | (
52 | _,
53 | _,
54 | self.unet,
55 | _,
56 | self.vae,
57 | ) = load_models(
58 | self.config.pretrained_base_model_path,
59 | scheduler_name="",
60 | v2=False,
61 | v_pred=False,
62 | weight_dtype=self.weight_dtype,
63 | )
64 | self.vae = self.vae.to(self.device, dtype=self.weight_dtype)
65 | self.reference_unet = self.unet.to(dtype=self.weight_dtype, device=self.device)
66 |
67 | inference_config_path = self.config.inference_config
68 | infer_config = OmegaConf.load(inference_config_path)
69 | self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
70 | self.config.pretrained_base_model_path,
71 | self.config.motion_module_path,
72 | subfolder="unet",
73 | unet_additional_kwargs=infer_config.unet_additional_kwargs,
74 | ).to(dtype=self.weight_dtype, device=self.device)
75 |
76 | self.pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
77 | dtype=self.weight_dtype, device=self.device
78 | )
79 |
80 | self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
81 | self.config.image_encoder_path
82 | ).to(dtype=self.weight_dtype, device=self.device)
83 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
84 | self.scheduler = DDIMScheduler(**sched_kwargs)
85 |
86 | # load pretrained weights
87 | self.denoising_unet.load_state_dict(
88 | torch.load(self.config.denoising_unet_path, map_location="cpu"),
89 | strict=False,
90 | )
91 | self.reference_unet.load_state_dict(
92 | torch.load(self.config.reference_unet_path, map_location="cpu"),
93 | )
94 |
95 | self.pose_guider.load_state_dict(
96 | torch.load(self.config.pose_guider_path, map_location="cpu"),
97 | )
98 |
99 | self.denoising_unet.set_attn_processor(AttnProcessor2_0())
100 | self.reference_unet.set_attn_processor(AttnProcessor2_0())
101 |
102 | pipe = Pose2VideoPipeline(
103 | vae=self.vae,
104 | image_encoder=self.image_enc,
105 | reference_unet=self.reference_unet,
106 | denoising_unet=self.denoising_unet,
107 | pose_guider=self.pose_guider,
108 | scheduler=self.scheduler,
109 | )
110 | pipe = pipe.to(self.device, dtype=self.weight_dtype)
111 | self.pipeline = pipe
112 |
113 | pose_images = read_frames(pose_video_path)
114 | src_fps = get_fps(pose_video_path)
115 |
116 | pose_list = []
117 | pose_tensor_list = []
118 | pose_transform = transforms.Compose(
119 | [transforms.Resize((height, width)), transforms.ToTensor()]
120 | )
121 | for pose_image_pil in pose_images[:length]:
122 | pose_list.append(pose_image_pil)
123 | pose_tensor_list.append(pose_transform(pose_image_pil))
124 |
125 | video = self.pipeline(
126 | ref_image,
127 | pose_list,
128 | width=width,
129 | height=height,
130 | video_length=length,
131 | num_inference_steps=num_inference_steps,
132 | guidance_scale=cfg,
133 | generator=generator,
134 | ).videos
135 |
136 | ref_image_tensor = pose_transform(ref_image) # (c, h, w)
137 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
138 | ref_image_tensor = repeat(
139 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=length
140 | )
141 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
142 | pose_tensor = pose_tensor.transpose(0, 1)
143 | pose_tensor = pose_tensor.unsqueeze(0)
144 | video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
145 |
146 | save_dir = f"./output/gradio"
147 | if not os.path.exists(save_dir):
148 | os.makedirs(save_dir, exist_ok=True)
149 | date_str = datetime.now().strftime("%Y%m%d")
150 | time_str = datetime.now().strftime("%H%M")
151 | out_path = os.path.join(save_dir, f"{date_str}T{time_str}.mp4")
152 | save_videos_grid(
153 | video,
154 | out_path,
155 | n_rows=3,
156 | fps=src_fps,
157 | )
158 | self.vae.to("cpu")
159 | self.image_enc.to("cpu")
160 | self.reference_unet.to("cpu")
161 | self.denoising_unet.to("cpu")
162 | self.pose_guider.to("cpu")
163 | torch_gc()
164 |
165 | return out_path
166 |
167 |
168 | controller = AnimateController()
169 |
170 |
171 | def ui():
172 | with gr.Blocks() as demo:
173 | gr.Markdown(
174 | """
175 | # Moore-AnimateAnyone Demo
176 | """
177 | )
178 | animation = gr.Video(
179 | format="mp4",
180 | label="Animation Results",
181 | height=448,
182 | autoplay=True,
183 | )
184 |
185 | with gr.Row():
186 | reference_image = gr.Image(label="Reference Image")
187 | motion_sequence = gr.Video(
188 | format="mp4", label="Motion Sequence", height=512
189 | )
190 |
191 | with gr.Column():
192 | width_slider = gr.Slider(
193 | label="Width", minimum=448, maximum=768, value=512, step=64
194 | )
195 | height_slider = gr.Slider(
196 | label="Height", minimum=512, maximum=1024, value=768, step=64
197 | )
198 | length_slider = gr.Slider(
199 | label="Video Length", minimum=24, maximum=128, value=24, step=24
200 | )
201 | with gr.Row():
202 | seed_textbox = gr.Textbox(label="Seed", value=-1)
203 | seed_button = gr.Button(
204 | value="\U0001F3B2", elem_classes="toolbutton"
205 | )
206 | seed_button.click(
207 | fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)),
208 | inputs=[],
209 | outputs=[seed_textbox],
210 | )
211 | with gr.Row():
212 | sampling_steps = gr.Slider(
213 | label="Sampling steps",
214 | value=25,
215 | info="default: 25",
216 | step=5,
217 | maximum=30,
218 | minimum=10,
219 | )
220 | guidance_scale = gr.Slider(
221 | label="Guidance scale",
222 | value=3.5,
223 | info="default: 3.5",
224 | step=0.5,
225 | maximum=10,
226 | minimum=2.0,
227 | )
228 | submit = gr.Button("Animate")
229 |
230 | def read_video(video):
231 | return video
232 |
233 | def read_image(image):
234 | return Image.fromarray(image)
235 |
236 | # when user uploads a new video
237 | motion_sequence.upload(read_video, motion_sequence, motion_sequence)
238 | # when `first_frame` is updated
239 | reference_image.upload(read_image, reference_image, reference_image)
240 | # when the `submit` button is clicked
241 | submit.click(
242 | controller.animate,
243 | [
244 | reference_image,
245 | motion_sequence,
246 | width_slider,
247 | height_slider,
248 | length_slider,
249 | sampling_steps,
250 | guidance_scale,
251 | seed_textbox,
252 | ],
253 | animation,
254 | )
255 |
256 | # Examples
257 | gr.Markdown("## Examples")
258 | gr.Examples(
259 | examples=[
260 | [
261 | "./configs/inference/ref_images/anyone-5.png",
262 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
263 | ],
264 | [
265 | "./configs/inference/ref_images/anyone-10.png",
266 | "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
267 | ],
268 | [
269 | "./configs/inference/ref_images/anyone-2.png",
270 | "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
271 | ],
272 | ],
273 | inputs=[reference_image, motion_sequence],
274 | outputs=animation,
275 | )
276 |
277 | return demo
278 |
279 |
280 | demo = ui()
281 | demo.launch(share=True)
282 |
--------------------------------------------------------------------------------
/assets/cai-xukun.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/assets/cai-xukun.mp4
--------------------------------------------------------------------------------
/assets/mini_program_maliang.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/assets/mini_program_maliang.png
--------------------------------------------------------------------------------
/configs/inference/inference_v1.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | unet_use_cross_frame_attention: false
3 | unet_use_temporal_attention: false
4 | use_motion_module: true
5 | motion_module_resolutions: [1,2,4,8]
6 | motion_module_mid_block: false
7 | motion_module_decoder_only: false
8 | motion_module_type: "Vanilla"
9 |
10 | motion_module_kwargs:
11 | num_attention_heads: 8
12 | num_transformer_block: 1
13 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
14 | temporal_position_encoding: true
15 | temporal_position_encoding_max_len: 24
16 | temporal_attention_dim_div: 1
17 |
18 | noise_scheduler_kwargs:
19 | beta_start: 0.00085
20 | beta_end: 0.012
21 | beta_schedule: "linear"
22 | steps_offset: 1
23 | clip_sample: False
--------------------------------------------------------------------------------
/configs/inference/inference_v2.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_inflated_groupnorm: true
3 | unet_use_cross_frame_attention: false
4 | unet_use_temporal_attention: false
5 | use_motion_module: true
6 | motion_module_resolutions:
7 | - 1
8 | - 2
9 | - 4
10 | - 8
11 | motion_module_mid_block: true
12 | motion_module_decoder_only: false
13 | motion_module_type: Vanilla
14 | motion_module_kwargs:
15 | num_attention_heads: 8
16 | num_transformer_block: 1
17 | attention_block_types:
18 | - Temporal_Self
19 | - Temporal_Self
20 | temporal_position_encoding: true
21 | temporal_position_encoding_max_len: 32
22 | temporal_attention_dim_div: 1
23 |
24 | noise_scheduler_kwargs:
25 | beta_start: 0.00085
26 | beta_end: 0.012
27 | beta_schedule: "linear"
28 | clip_sample: false
29 | steps_offset: 1
30 | ### Zero-SNR params
31 | prediction_type: "v_prediction"
32 | rescale_betas_zero_snr: True
33 | timestep_spacing: "trailing"
34 |
35 | sampler: DDIM
--------------------------------------------------------------------------------
/configs/inference/pose_images/pose-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/configs/inference/pose_images/pose-1.png
--------------------------------------------------------------------------------
/configs/inference/pose_videos/anyone-video-1_kps.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/configs/inference/pose_videos/anyone-video-1_kps.mp4
--------------------------------------------------------------------------------
/configs/inference/pose_videos/anyone-video-2_kps.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/configs/inference/pose_videos/anyone-video-2_kps.mp4
--------------------------------------------------------------------------------
/configs/inference/pose_videos/anyone-video-4_kps.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/configs/inference/pose_videos/anyone-video-4_kps.mp4
--------------------------------------------------------------------------------
/configs/inference/pose_videos/anyone-video-5_kps.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/configs/inference/pose_videos/anyone-video-5_kps.mp4
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/configs/inference/ref_images/anyone-1.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/configs/inference/ref_images/anyone-10.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/configs/inference/ref_images/anyone-11.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/configs/inference/ref_images/anyone-2.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/configs/inference/ref_images/anyone-3.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/configs/inference/ref_images/anyone-5.png
--------------------------------------------------------------------------------
/configs/prompts/animation.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: "./pretrained_weights/stable-diffusion-v1-5/"
2 | pretrained_vae_path: "./pretrained_weights/sd-vae-ft-mse"
3 | image_encoder_path: "./pretrained_weights/image_encoder"
4 | denoising_unet_path: "./pretrained_weights/denoising_unet.pth"
5 | reference_unet_path: "./pretrained_weights/reference_unet.pth"
6 | pose_guider_path: "./pretrained_weights/pose_guider.pth"
7 | motion_module_path: "./pretrained_weights/motion_module.pth"
8 |
9 | inference_config: "./configs/inference/inference_v2.yaml"
10 | weight_dtype: 'fp16'
11 |
12 | test_cases:
13 | "./configs/inference/ref_images/anyone-2.png":
14 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
15 | - "./configs/inference/pose_videos/anyone-video-5_kps.mp4"
16 | "./configs/inference/ref_images/anyone-10.png":
17 | - "./configs/inference/pose_videos/anyone-video-1_kps.mp4"
18 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
19 | "./configs/inference/ref_images/anyone-11.png":
20 | - "./configs/inference/pose_videos/anyone-video-1_kps.mp4"
21 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
22 | "./configs/inference/ref_images/anyone-3.png":
23 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
24 | - "./configs/inference/pose_videos/anyone-video-5_kps.mp4"
25 | "./configs/inference/ref_images/anyone-5.png":
26 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
27 |
--------------------------------------------------------------------------------
/configs/prompts/test_cases.py:
--------------------------------------------------------------------------------
1 | TestCasesDict = {
2 | 0: [
3 | {
4 | "./configs/inference/ref_images/anyone-2.png": [
5 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
6 | "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
7 | ]
8 | },
9 | {
10 | "./configs/inference/ref_images/anyone-10.png": [
11 | "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
12 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
13 | ]
14 | },
15 | {
16 | "./configs/inference/ref_images/anyone-11.png": [
17 | "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
18 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
19 | ]
20 | },
21 | {
22 | "./configs/inference/anyone-ref-3.png": [
23 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
24 | "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
25 | ]
26 | },
27 | {
28 | "./configs/inference/ref_images/anyone-5.png": [
29 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
30 | ]
31 | },
32 | ],
33 | }
34 |
--------------------------------------------------------------------------------
/configs/train/stage1.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_bs: 4
3 | train_width: 768
4 | train_height: 768
5 | meta_paths:
6 | - "./data/fashion_meta.json"
7 | # Margin of frame indexes between ref and tgt images
8 | sample_margin: 30
9 |
10 | solver:
11 | gradient_accumulation_steps: 1
12 | mixed_precision: 'fp16'
13 | enable_xformers_memory_efficient_attention: True
14 | gradient_checkpointing: False
15 | max_train_steps: 30000
16 | max_grad_norm: 1.0
17 | # lr
18 | learning_rate: 1.0e-5
19 | scale_lr: False
20 | lr_warmup_steps: 1
21 | lr_scheduler: 'constant'
22 |
23 | # optimizer
24 | use_8bit_adam: False
25 | adam_beta1: 0.9
26 | adam_beta2: 0.999
27 | adam_weight_decay: 1.0e-2
28 | adam_epsilon: 1.0e-8
29 |
30 | val:
31 | validation_steps: 200
32 |
33 |
34 | noise_scheduler_kwargs:
35 | num_train_timesteps: 1000
36 | beta_start: 0.00085
37 | beta_end: 0.012
38 | beta_schedule: "scaled_linear"
39 | steps_offset: 1
40 | clip_sample: false
41 |
42 | base_model_path: './pretrained_weights/sd-image-variations-diffusers'
43 | vae_model_path: './pretrained_weights/sd-vae-ft-mse'
44 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder'
45 | controlnet_openpose_path: './pretrained_weights/control_v11p_sd15_openpose/diffusion_pytorch_model.bin'
46 |
47 | weight_dtype: 'fp16' # [fp16, fp32]
48 | uncond_ratio: 0.1
49 | noise_offset: 0.05
50 | snr_gamma: 5.0
51 | enable_zero_snr: True
52 | pose_guider_pretrain: True
53 |
54 | seed: 12580
55 | resume_from_checkpoint: ''
56 | checkpointing_steps: 2000
57 | save_model_epoch_interval: 5
58 | exp_name: 'stage1'
59 | output_dir: './exp_output'
--------------------------------------------------------------------------------
/configs/train/stage2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_bs: 1
3 | train_width: 512
4 | train_height: 512
5 | meta_paths:
6 | - "./data/fashion_meta.json"
7 | sample_rate: 4
8 | n_sample_frames: 24
9 |
10 | solver:
11 | gradient_accumulation_steps: 1
12 | mixed_precision: 'fp16'
13 | enable_xformers_memory_efficient_attention: True
14 | gradient_checkpointing: True
15 | max_train_steps: 10000
16 | max_grad_norm: 1.0
17 | # lr
18 | learning_rate: 1e-5
19 | scale_lr: False
20 | lr_warmup_steps: 1
21 | lr_scheduler: 'constant'
22 |
23 | # optimizer
24 | use_8bit_adam: True
25 | adam_beta1: 0.9
26 | adam_beta2: 0.999
27 | adam_weight_decay: 1.0e-2
28 | adam_epsilon: 1.0e-8
29 |
30 | val:
31 | validation_steps: 20
32 |
33 |
34 | noise_scheduler_kwargs:
35 | num_train_timesteps: 1000
36 | beta_start: 0.00085
37 | beta_end: 0.012
38 | beta_schedule: "linear"
39 | steps_offset: 1
40 | clip_sample: false
41 |
42 | base_model_path: './pretrained_weights/stable-diffusion-v1-5'
43 | vae_model_path: './pretrained_weights/sd-vae-ft-mse'
44 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder'
45 | mm_path: './pretrained_weights/mm_sd_v15_v2.ckpt'
46 |
47 | weight_dtype: 'fp16' # [fp16, fp32]
48 | uncond_ratio: 0.1
49 | noise_offset: 0.05
50 | snr_gamma: 5.0
51 | enable_zero_snr: True
52 | stage1_ckpt_dir: './exp_output/stage1'
53 | stage1_ckpt_step: 980
54 |
55 | seed: 12580
56 | resume_from_checkpoint: ''
57 | checkpointing_steps: 2000
58 | exp_name: 'stage2'
59 | output_dir: './exp_output'
--------------------------------------------------------------------------------
/install.ps1:
--------------------------------------------------------------------------------
1 | Set-Location $PSScriptRoot
2 |
3 | $Env:PIP_DISABLE_PIP_VERSION_CHECK = 1
4 |
5 | if (!(Test-Path -Path "venv")) {
6 | Write-Output "Creating venv for python..."
7 | python -m venv venv
8 | }
9 | .\venv\Scripts\activate
10 |
11 | Write-Output "install deps..."
12 | pip install -U -r requirements-windows.txt
13 |
14 | Write-Output "check models..."
15 |
16 | if (!(Test-Path -Path "pretrained_weights")) {
17 | Write-Output "Downloading pretrained_weights..."
18 | git lfs install
19 | git lfs clone https://huggingface.co/patrolli/AnimateAnyone pretrained_weights
20 | if (Test-Path -Path "pretrained_weights/.git/lfs") {
21 | Remove-Item -Path pretrained_weights/.git/* -Recurse -Force
22 | }
23 | }
24 |
25 | Set-Location .\pretrained_weights
26 |
27 | if (!(Test-Path -Path "image_encoder")) {
28 | Write-Output "Downloading image_encoder models..."
29 | git lfs install
30 | git lfs clone https://huggingface.co/bdsqlsz/image_encoder
31 | if (Test-Path -Path "image_encoder/.git/lfs") {
32 | Remove-Item -Path image_encoder/.git/* -Recurse -Force
33 | }
34 | }
35 |
36 | $install_SD15 = Read-Host "Do you need to download SD15? If you don't have any SD15 model locally select y, if you want to change to another SD1.5 model select n. [y/n] (Default is y)"
37 | if ($install_SD15 -eq "y" -or $install_SD15 -eq "Y" -or $install_SD15 -eq "") {
38 | if (!(Test-Path -Path "stable-diffusion-v1-5")) {
39 | Write-Output "Downloading stable-diffusion-v1-5 models..."
40 | git lfs clone https://huggingface.co/bdsqlsz/stable-diffusion-v1-5
41 |
42 | }
43 | if (Test-Path -Path "stable-diffusion-v1-5/.git/lfs") {
44 | Remove-Item -Path stable-diffusion-v1-5/.git/lfs/* -Recurse -Force
45 | }
46 | }
47 |
48 |
49 | if (!(Test-Path -Path "DWPose")) {
50 | Write-Output "Downloading dwpose models..."
51 | mkdir "DWPose"
52 | wget https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx -OutFile DWPose/dw-ll_ucoco_384.onnx
53 | wget https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx -OutFile DWPose/yolox_l.onnx
54 | }
55 |
56 | Write-Output "Installing Video_controlnet_aux..."
57 |
58 | git submodule update --recursive --init
59 |
60 | Set-Location $PSScriptRoot/video_controlnet_aux
61 | pip install -r requirements.txt
62 | pip install -r requirements-video.txt
63 |
64 | Write-Output "Install completed"
65 | Read-Host | Out-Null ;
66 |
--------------------------------------------------------------------------------
/install_cn.ps1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/install_cn.ps1
--------------------------------------------------------------------------------
/requirements-windows.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu118
2 | accelerate==0.21.0
3 | av==11.0.0
4 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a
5 | decord==0.6.0
6 | diffusers==0.24.0
7 | einops==0.4.1
8 | gradio==3.41.2
9 | gradio_client==0.5.0
10 | imageio==2.33.0
11 | imageio-ffmpeg==0.4.9
12 | numpy==1.24.4
13 | omegaconf==2.2.3
14 | onnxruntime==1.16.3
15 | onnxruntime-gpu==1.16.3
16 | open-clip-torch==2.20.0
17 | opencv-contrib-python==4.8.1.78
18 | opencv-python==4.8.1.78
19 | Pillow==9.5.0
20 | scikit-image==0.21.0
21 | scikit-learn==1.3.2
22 | scipy==1.11.4
23 | torch==2.0.1
24 | torchdiffeq==0.2.3
25 | torchmetrics==1.2.1
26 | torchsde==0.2.5
27 | torchvision==0.15.2
28 | tqdm==4.66.1
29 | transformers==4.30.2
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.21.0
2 | av==11.0.0
3 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a
4 | decord==0.6.0
5 | diffusers==0.24.0
6 | einops==0.4.1
7 | gradio==3.41.2
8 | gradio_client==0.5.0
9 | imageio==2.33.0
10 | imageio-ffmpeg==0.4.9
11 | numpy==1.23.5
12 | omegaconf==2.2.3
13 | onnxruntime-gpu==1.16.3
14 | open-clip-torch==2.20.0
15 | opencv-contrib-python==4.8.1.78
16 | opencv-python==4.8.1.78
17 | Pillow==9.5.0
18 | scikit-image==0.21.0
19 | scikit-learn==1.3.2
20 | scipy==1.11.4
21 | torch==2.0.1
22 | torchdiffeq==0.2.3
23 | torchmetrics==1.2.1
24 | torchsde==0.2.5
25 | torchvision==0.15.2
26 | tqdm==4.66.1
27 | transformers==4.30.2
28 | mlflow==2.9.2
29 | xformers==0.0.22
30 | controlnet-aux==0.0.7
--------------------------------------------------------------------------------
/run_VidControlnetAux_gui.ps1:
--------------------------------------------------------------------------------
1 | $input_path="./assets/cai-xukun.mp4"
2 | $output_path="./outputs/"
3 |
4 |
5 | Set-Location $PSScriptRoot
6 | .\venv\Scripts\activate
7 |
8 | $Env:HF_HOME = "./huggingface"
9 | $Env:XFORMERS_FORCE_DISABLE_TRITON = "1"
10 | #$Env:PYTHONPATH = $PSScriptRoot
11 | $ext_args = [System.Collections.ArrayList]::new()
12 |
13 | if ($input_path) {
14 | [void]$ext_args.Add("-i=$input_path")
15 | }
16 |
17 | if ($output_path) {
18 | [void]$ext_args.Add("-o=$output_path")
19 | }
20 |
21 |
22 | python.exe "video_controlnet_aux/src/video_controlnet_aux.py" $ext_args
23 |
--------------------------------------------------------------------------------
/run_gui.ps1:
--------------------------------------------------------------------------------
1 | Set-Location $PSScriptRoot
2 | .\venv\Scripts\activate
3 |
4 | $Env:HF_HOME = "./huggingface"
5 | $Env:XFORMERS_FORCE_DISABLE_TRITON = "1"
6 | $Env:PYTHONPATH = $PSScriptRoot
7 |
8 | python.exe "app.py"
--------------------------------------------------------------------------------
/scripts/pose2vid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from datetime import datetime
4 | from pathlib import Path
5 | from typing import List
6 |
7 | import av
8 | import numpy as np
9 | import torch
10 | import torchvision
11 | from diffusers import AutoencoderKL, DDIMScheduler
12 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
13 | from einops import repeat
14 | from omegaconf import OmegaConf
15 | from PIL import Image
16 | from torchvision import transforms
17 | from transformers import CLIPVisionModelWithProjection
18 |
19 | from configs.prompts.test_cases import TestCasesDict
20 | from src.models.pose_guider import PoseGuider
21 | from src.models.unet_2d_condition import UNet2DConditionModel
22 | from src.models.unet_3d import UNet3DConditionModel
23 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
24 | from src.utils.util import get_fps, read_frames, save_videos_grid
25 | from src.models.model_util import get_torch_device
26 |
27 | def parse_args():
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("--config")
30 | parser.add_argument("-W", type=int, default=512)
31 | parser.add_argument("-H", type=int, default=784)
32 | parser.add_argument("-L", type=int, default=24)
33 | parser.add_argument("--seed", type=int, default=42)
34 | parser.add_argument("--cfg", type=float, default=3.5)
35 | parser.add_argument("--steps", type=int, default=30)
36 | parser.add_argument("--fps", type=int)
37 | args = parser.parse_args()
38 |
39 | return args
40 |
41 |
42 | def main():
43 | device = get_torch_device()
44 |
45 | args = parse_args()
46 |
47 | config = OmegaConf.load(args.config)
48 |
49 | if config.weight_dtype == "fp16":
50 | weight_dtype = torch.float16
51 | else:
52 | weight_dtype = torch.float32
53 |
54 | vae = AutoencoderKL.from_pretrained(
55 | config.pretrained_vae_path,
56 | ).to(device, dtype=weight_dtype)
57 |
58 | reference_unet = UNet2DConditionModel.from_pretrained(
59 | config.pretrained_base_model_path,
60 | subfolder="unet",
61 | ).to(dtype=weight_dtype, device=device)
62 |
63 | inference_config_path = config.inference_config
64 | infer_config = OmegaConf.load(inference_config_path)
65 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
66 | config.pretrained_base_model_path,
67 | config.motion_module_path,
68 | subfolder="unet",
69 | unet_additional_kwargs=infer_config.unet_additional_kwargs,
70 | ).to(dtype=weight_dtype, device=device)
71 |
72 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
73 | dtype=weight_dtype, device=device
74 | )
75 |
76 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
77 | config.image_encoder_path
78 | ).to(dtype=weight_dtype, device=device)
79 |
80 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
81 | scheduler = DDIMScheduler(**sched_kwargs)
82 |
83 | generator = torch.manual_seed(args.seed)
84 |
85 | width, height = args.W, args.H
86 |
87 | # load pretrained weights
88 | denoising_unet.load_state_dict(
89 | torch.load(config.denoising_unet_path, map_location="cpu"),
90 | strict=False,
91 | )
92 | reference_unet.load_state_dict(
93 | torch.load(config.reference_unet_path, map_location="cpu"),
94 | )
95 | pose_guider.load_state_dict(
96 | torch.load(config.pose_guider_path, map_location="cpu"),
97 | )
98 |
99 | pipe = Pose2VideoPipeline(
100 | vae=vae,
101 | image_encoder=image_enc,
102 | reference_unet=reference_unet,
103 | denoising_unet=denoising_unet,
104 | pose_guider=pose_guider,
105 | scheduler=scheduler,
106 | )
107 | pipe = pipe.to(device, dtype=weight_dtype)
108 |
109 | date_str = datetime.now().strftime("%Y%m%d")
110 | time_str = datetime.now().strftime("%H%M")
111 | save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}"
112 |
113 | save_dir = Path(f"output/{date_str}/{save_dir_name}")
114 | save_dir.mkdir(exist_ok=True, parents=True)
115 |
116 | for ref_image_path in config["test_cases"].keys():
117 | # Each ref_image may correspond to multiple actions
118 | for pose_video_path in config["test_cases"][ref_image_path]:
119 | ref_name = Path(ref_image_path).stem
120 | pose_name = Path(pose_video_path).stem.replace("_kps", "")
121 |
122 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
123 |
124 | pose_list = []
125 | pose_tensor_list = []
126 | pose_images = read_frames(pose_video_path)
127 | src_fps = get_fps(pose_video_path)
128 | print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
129 | pose_transform = transforms.Compose(
130 | [transforms.Resize((height, width)), transforms.ToTensor()]
131 | )
132 | for pose_image_pil in pose_images[: args.L]:
133 | pose_tensor_list.append(pose_transform(pose_image_pil))
134 | pose_list.append(pose_image_pil)
135 |
136 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
137 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(
138 | 0
139 | ) # (1, c, 1, h, w)
140 | ref_image_tensor = repeat(
141 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=args.L
142 | )
143 |
144 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
145 | pose_tensor = pose_tensor.transpose(0, 1)
146 | pose_tensor = pose_tensor.unsqueeze(0)
147 |
148 | video = pipe(
149 | ref_image_pil,
150 | pose_list,
151 | width,
152 | height,
153 | args.L,
154 | args.steps,
155 | args.cfg,
156 | generator=generator,
157 | ).videos
158 |
159 | video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
160 | save_videos_grid(
161 | video,
162 | f"{save_dir}/{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}.mp4",
163 | n_rows=3,
164 | fps=src_fps if args.fps is None else args.fps,
165 | )
166 |
167 |
168 | if __name__ == "__main__":
169 | main()
170 |
--------------------------------------------------------------------------------
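Note on the tensor bookkeeping in scripts/pose2vid.py above: the reference image is a single frame, so it is expanded along a frame axis before being concatenated with the pose and generated videos for the 3-row preview grid. A minimal, shape-only sketch (toy sizes, assuming torch and einops are installed):

import torch
from einops import repeat

L = 24                               # frame count, matching the -L default above
ref = torch.rand(3, 784, 512)        # transformed reference image, (c, h, w)
ref = ref.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
ref = repeat(ref, "b c f h w -> b c (repeat f) h w", repeat=L)

pose = torch.rand(L, 3, 784, 512)    # stacked pose frames, (f, c, h, w)
pose = pose.transpose(0, 1).unsqueeze(0)   # (1, c, f, h, w)

print(ref.shape, pose.shape)         # both torch.Size([1, 3, 24, 784, 512])

Both tensors then line up with the pipeline output for the torch.cat(..., dim=0) and save_videos_grid(..., n_rows=3) call in the script.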
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/src/__init__.py
--------------------------------------------------------------------------------
/src/dataset/dance_image.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | import torch
5 | import torchvision.transforms as transforms
6 | from decord import VideoReader
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 | from transformers import CLIPImageProcessor
10 |
11 |
12 | class HumanDanceDataset(Dataset):
13 | def __init__(
14 | self,
15 | img_size,
16 | img_scale=(1.0, 1.0),
17 | img_ratio=(0.9, 1.0),
18 | drop_ratio=0.1,
19 | data_meta_paths=["./data/fashion_meta.json"],
20 | sample_margin=30,
21 | ):
22 | super().__init__()
23 |
24 | self.img_size = img_size
25 | self.img_scale = img_scale
26 | self.img_ratio = img_ratio
27 | self.sample_margin = sample_margin
28 |
29 | # -----
30 | # vid_meta format:
31 | # [{'video_path': , 'kps_path': , 'other':},
32 | # {'video_path': , 'kps_path': , 'other':}]
33 | # -----
34 | vid_meta = []
35 | for data_meta_path in data_meta_paths:
36 | vid_meta.extend(json.load(open(data_meta_path, "r")))
37 | self.vid_meta = vid_meta
38 |
39 | self.clip_image_processor = CLIPImageProcessor()
40 |
41 | self.transform = transforms.Compose(
42 | [
43 | transforms.RandomResizedCrop(
44 | self.img_size,
45 | scale=self.img_scale,
46 | ratio=self.img_ratio,
47 | interpolation=transforms.InterpolationMode.BILINEAR,
48 | ),
49 | transforms.ToTensor(),
50 | transforms.Normalize([0.5], [0.5]),
51 | ]
52 | )
53 |
54 | self.cond_transform = transforms.Compose(
55 | [
56 | transforms.RandomResizedCrop(
57 | self.img_size,
58 | scale=self.img_scale,
59 | ratio=self.img_ratio,
60 | interpolation=transforms.InterpolationMode.BILINEAR,
61 | ),
62 | transforms.ToTensor(),
63 | ]
64 | )
65 |
66 | self.drop_ratio = drop_ratio
67 |
68 | def augmentation(self, image, transform, state=None):
69 | if state is not None:
70 | torch.set_rng_state(state)
71 | return transform(image)
72 |
73 | def __getitem__(self, index):
74 | video_meta = self.vid_meta[index]
75 | video_path = video_meta["video_path"]
76 | kps_path = video_meta["kps_path"]
77 |
78 | video_reader = VideoReader(video_path)
79 | kps_reader = VideoReader(kps_path)
80 |
81 | assert len(video_reader) == len(
82 | kps_reader
83 | ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
84 |
85 | video_length = len(video_reader)
86 |
87 | margin = min(self.sample_margin, video_length)
88 |
89 | ref_img_idx = random.randint(0, video_length - 1)
90 | if ref_img_idx + margin < video_length:
91 | tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1)
92 | elif ref_img_idx - margin > 0:
93 | tgt_img_idx = random.randint(0, ref_img_idx - margin)
94 | else:
95 | tgt_img_idx = random.randint(0, video_length - 1)
96 |
97 | ref_img = video_reader[ref_img_idx]
98 | ref_img_pil = Image.fromarray(ref_img.asnumpy())
99 | tgt_img = video_reader[tgt_img_idx]
100 | tgt_img_pil = Image.fromarray(tgt_img.asnumpy())
101 |
102 | tgt_pose = kps_reader[tgt_img_idx]
103 | tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy())
104 |
105 | state = torch.get_rng_state()
106 | tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
107 | tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
108 | ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
109 | clip_image = self.clip_image_processor(
110 | images=ref_img_pil, return_tensors="pt"
111 | ).pixel_values[0]
112 |
113 | sample = dict(
114 | video_dir=video_path,
115 | img=tgt_img,
116 | tgt_pose=tgt_pose_img,
117 | ref_img=ref_img_vae,
118 | clip_images=clip_image,
119 | )
120 |
121 | return sample
122 |
123 | def __len__(self):
124 | return len(self.vid_meta)
125 |
--------------------------------------------------------------------------------
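The `state` argument in `HumanDanceDataset.augmentation` above is what keeps the target frame, its pose map, and the reference frame spatially aligned: restoring the same torch RNG state before each RandomResizedCrop makes every crop sample identical parameters. A minimal sketch of the trick (toy random image; assumes a recent torchvision that draws crop parameters from the torch RNG):

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

crop = transforms.RandomResizedCrop(256, scale=(0.5, 1.0))
img = Image.fromarray(np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8))

state = torch.get_rng_state()
first = crop(img)
torch.set_rng_state(state)   # replay the RNG state before the second crop
second = crop(img)

# Same RNG state -> same crop box and scale -> identical outputs.
assert (transforms.ToTensor()(first) == transforms.ToTensor()(second)).all()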
/src/dataset/dance_video.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from typing import List
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | import torchvision.transforms as transforms
9 | from decord import VideoReader
10 | from PIL import Image
11 | from torch.utils.data import Dataset
12 | from transformers import CLIPImageProcessor
13 |
14 |
15 | class HumanDanceVideoDataset(Dataset):
16 | def __init__(
17 | self,
18 | sample_rate,
19 | n_sample_frames,
20 | width,
21 | height,
22 | img_scale=(1.0, 1.0),
23 | img_ratio=(0.9, 1.0),
24 | drop_ratio=0.1,
25 | data_meta_paths=["./data/fashion_meta.json"],
26 | ):
27 | super().__init__()
28 | self.sample_rate = sample_rate
29 | self.n_sample_frames = n_sample_frames
30 | self.width = width
31 | self.height = height
32 | self.img_scale = img_scale
33 | self.img_ratio = img_ratio
34 |
35 | vid_meta = []
36 | for data_meta_path in data_meta_paths:
37 | vid_meta.extend(json.load(open(data_meta_path, "r")))
38 | self.vid_meta = vid_meta
39 |
40 | self.clip_image_processor = CLIPImageProcessor()
41 |
42 | self.pixel_transform = transforms.Compose(
43 | [
44 | transforms.RandomResizedCrop(
45 | (height, width),
46 | scale=self.img_scale,
47 | ratio=self.img_ratio,
48 | interpolation=transforms.InterpolationMode.BILINEAR,
49 | ),
50 | transforms.ToTensor(),
51 | transforms.Normalize([0.5], [0.5]),
52 | ]
53 | )
54 |
55 | self.cond_transform = transforms.Compose(
56 | [
57 | transforms.RandomResizedCrop(
58 | (height, width),
59 | scale=self.img_scale,
60 | ratio=self.img_ratio,
61 | interpolation=transforms.InterpolationMode.BILINEAR,
62 | ),
63 | transforms.ToTensor(),
64 | ]
65 | )
66 |
67 | self.drop_ratio = drop_ratio
68 |
69 | def augmentation(self, images, transform, state=None):
70 | if state is not None:
71 | torch.set_rng_state(state)
72 | if isinstance(images, List):
73 | transformed_images = [transform(img) for img in images]
74 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
75 | else:
76 | ret_tensor = transform(images) # (c, h, w)
77 | return ret_tensor
78 |
79 | def __getitem__(self, index):
80 | video_meta = self.vid_meta[index]
81 | video_path = video_meta["video_path"]
82 | kps_path = video_meta["kps_path"]
83 |
84 | video_reader = VideoReader(video_path)
85 | kps_reader = VideoReader(kps_path)
86 |
87 | assert len(video_reader) == len(
88 | kps_reader
89 | ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
90 |
91 | video_length = len(video_reader)
92 |
93 | clip_length = min(
94 | video_length, (self.n_sample_frames - 1) * self.sample_rate + 1
95 | )
96 | start_idx = random.randint(0, video_length - clip_length)
97 | batch_index = np.linspace(
98 | start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
99 | ).tolist()
100 |
101 | # read frames and kps
102 | vid_pil_image_list = []
103 | pose_pil_image_list = []
104 | for index in batch_index:
105 | img = video_reader[index]
106 | vid_pil_image_list.append(Image.fromarray(img.asnumpy()))
107 | img = kps_reader[index]
108 | pose_pil_image_list.append(Image.fromarray(img.asnumpy()))
109 |
110 | ref_img_idx = random.randint(0, video_length - 1)
111 | ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy())
112 |
113 | # transform
114 | state = torch.get_rng_state()
115 | pixel_values_vid = self.augmentation(
116 | vid_pil_image_list, self.pixel_transform, state
117 | )
118 | pixel_values_pose = self.augmentation(
119 | pose_pil_image_list, self.cond_transform, state
120 | )
121 | pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
122 | clip_ref_img = self.clip_image_processor(
123 | images=ref_img, return_tensors="pt"
124 | ).pixel_values[0]
125 |
126 | sample = dict(
127 | video_dir=video_path,
128 | pixel_values_vid=pixel_values_vid,
129 | pixel_values_pose=pixel_values_pose,
130 | pixel_values_ref_img=pixel_values_ref_img,
131 | clip_ref_img=clip_ref_img,
132 | )
133 |
134 | return sample
135 |
136 | def __len__(self):
137 | return len(self.vid_meta)
138 |
--------------------------------------------------------------------------------
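A small numeric sketch of the clip sampling done in `HumanDanceVideoDataset.__getitem__` above: with an assumed sample_rate=4 and n_sample_frames=8, the clip spans (8 - 1) * 4 + 1 = 29 source frames and np.linspace picks 8 evenly spaced indices inside it.

import random
import numpy as np

video_length = 120        # assumed source video length
sample_rate = 4
n_sample_frames = 8

clip_length = min(video_length, (n_sample_frames - 1) * sample_rate + 1)   # 29
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(
    start_idx, start_idx + clip_length - 1, n_sample_frames, dtype=int
).tolist()

print(clip_length, batch_index)   # e.g. 29 [10, 14, 18, 22, 26, 30, 34, 38]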
/src/dwpose/__init__.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | # Openpose
3 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
4 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
5 | # 3rd Edited by ControlNet
6 | # 4th Edited by ControlNet (added face and correct hands)
7 |
8 | import copy
9 | import os
10 |
11 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
12 | import cv2
13 | import numpy as np
14 | import torch
15 | from controlnet_aux.util import HWC3, resize_image
16 | from PIL import Image
17 |
18 | from . import util
19 | from .wholebody import Wholebody
20 |
21 |
22 | def draw_pose(pose, H, W):
23 | bodies = pose["bodies"]
24 | faces = pose["faces"]
25 | hands = pose["hands"]
26 | candidate = bodies["candidate"]
27 | subset = bodies["subset"]
28 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
29 |
30 | canvas = util.draw_bodypose(canvas, candidate, subset)
31 |
32 | canvas = util.draw_handpose(canvas, hands)
33 |
34 | canvas = util.draw_facepose(canvas, faces)
35 |
36 | return canvas
37 |
38 |
39 | class DWposeDetector:
40 | def __init__(self):
41 | pass
42 |
43 | def to(self, device):
44 | self.pose_estimation = Wholebody(device)
45 | return self
46 |
47 | def cal_height(self, input_image):
48 | input_image = cv2.cvtColor(
49 | np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR
50 | )
51 |
52 | input_image = HWC3(input_image)
53 | H, W, C = input_image.shape
54 | with torch.no_grad():
55 | candidate, subset = self.pose_estimation(input_image)
56 | nums, keys, locs = candidate.shape
57 | # candidate[..., 0] /= float(W)
58 | # candidate[..., 1] /= float(H)
59 | body = candidate
60 | return body[0, ..., 1].min(), body[..., 1].max() - body[..., 1].min()
61 |
62 | def __call__(
63 | self,
64 | input_image,
65 | detect_resolution=512,
66 | image_resolution=512,
67 | output_type="pil",
68 | **kwargs,
69 | ):
70 | input_image = cv2.cvtColor(
71 | np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR
72 | )
73 |
74 | input_image = HWC3(input_image)
75 | input_image = resize_image(input_image, detect_resolution)
76 | H, W, C = input_image.shape
77 | with torch.no_grad():
78 | candidate, subset = self.pose_estimation(input_image)
79 | nums, keys, locs = candidate.shape
80 | candidate[..., 0] /= float(W)
81 | candidate[..., 1] /= float(H)
82 | score = subset[:, :18]
83 | max_ind = np.mean(score, axis=-1).argmax(axis=0)
84 | score = score[[max_ind]]
85 | body = candidate[:, :18].copy()
86 | body = body[[max_ind]]
87 | nums = 1
88 | body = body.reshape(nums * 18, locs)
89 | body_score = copy.deepcopy(score)
90 | for i in range(len(score)):
91 | for j in range(len(score[i])):
92 | if score[i][j] > 0.3:
93 | score[i][j] = int(18 * i + j)
94 | else:
95 | score[i][j] = -1
96 |
97 | un_visible = subset < 0.3
98 | candidate[un_visible] = -1
99 |
100 | foot = candidate[:, 18:24]
101 |
102 | faces = candidate[[max_ind], 24:92]
103 |
104 | hands = candidate[[max_ind], 92:113]
105 | hands = np.vstack([hands, candidate[[max_ind], 113:]])
106 |
107 | bodies = dict(candidate=body, subset=score)
108 | pose = dict(bodies=bodies, hands=hands, faces=faces)
109 |
110 | detected_map = draw_pose(pose, H, W)
111 | detected_map = HWC3(detected_map)
112 |
113 | img = resize_image(input_image, image_resolution)
114 | H, W, C = img.shape
115 |
116 | detected_map = cv2.resize(
117 | detected_map, (W, H), interpolation=cv2.INTER_LINEAR
118 | )
119 |
120 | if output_type == "pil":
121 | detected_map = Image.fromarray(detected_map)
122 |
123 | return detected_map, body_score
124 |
--------------------------------------------------------------------------------
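A hedged usage sketch for `DWposeDetector` defined above. It assumes the DWPose ONNX weights sit under ./pretrained_weights/DWPose/ (see wholebody.py below) and that a CUDA device is available; the input path is hypothetical.

from PIL import Image
from src.dwpose import DWposeDetector

detector = DWposeDetector().to("cuda:0")   # .to() builds the ONNX detection/pose sessions
image = Image.open("reference.png").convert("RGB")   # hypothetical input image

# __call__ returns the rendered skeleton (PIL by default) and the body keypoint scores.
pose_map, body_score = detector(image, detect_resolution=512, image_resolution=512)
pose_map.save("reference_pose.png")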
/src/dwpose/onnxdet.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | import cv2
3 | import numpy as np
4 | import onnxruntime
5 |
6 |
7 | def nms(boxes, scores, nms_thr):
8 | """Single class NMS implemented in Numpy."""
9 | x1 = boxes[:, 0]
10 | y1 = boxes[:, 1]
11 | x2 = boxes[:, 2]
12 | y2 = boxes[:, 3]
13 |
14 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
15 | order = scores.argsort()[::-1]
16 |
17 | keep = []
18 | while order.size > 0:
19 | i = order[0]
20 | keep.append(i)
21 | xx1 = np.maximum(x1[i], x1[order[1:]])
22 | yy1 = np.maximum(y1[i], y1[order[1:]])
23 | xx2 = np.minimum(x2[i], x2[order[1:]])
24 | yy2 = np.minimum(y2[i], y2[order[1:]])
25 |
26 | w = np.maximum(0.0, xx2 - xx1 + 1)
27 | h = np.maximum(0.0, yy2 - yy1 + 1)
28 | inter = w * h
29 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
30 |
31 | inds = np.where(ovr <= nms_thr)[0]
32 | order = order[inds + 1]
33 |
34 | return keep
35 |
36 |
37 | def multiclass_nms(boxes, scores, nms_thr, score_thr):
38 | """Multiclass NMS implemented in Numpy. Class-aware version."""
39 | final_dets = []
40 | num_classes = scores.shape[1]
41 | for cls_ind in range(num_classes):
42 | cls_scores = scores[:, cls_ind]
43 | valid_score_mask = cls_scores > score_thr
44 | if valid_score_mask.sum() == 0:
45 | continue
46 | else:
47 | valid_scores = cls_scores[valid_score_mask]
48 | valid_boxes = boxes[valid_score_mask]
49 | keep = nms(valid_boxes, valid_scores, nms_thr)
50 | if len(keep) > 0:
51 | cls_inds = np.ones((len(keep), 1)) * cls_ind
52 | dets = np.concatenate(
53 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
54 | )
55 | final_dets.append(dets)
56 | if len(final_dets) == 0:
57 | return None
58 | return np.concatenate(final_dets, 0)
59 |
60 |
61 | def demo_postprocess(outputs, img_size, p6=False):
62 | grids = []
63 | expanded_strides = []
64 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
65 |
66 | hsizes = [img_size[0] // stride for stride in strides]
67 | wsizes = [img_size[1] // stride for stride in strides]
68 |
69 | for hsize, wsize, stride in zip(hsizes, wsizes, strides):
70 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
71 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
72 | grids.append(grid)
73 | shape = grid.shape[:2]
74 | expanded_strides.append(np.full((*shape, 1), stride))
75 |
76 | grids = np.concatenate(grids, 1)
77 | expanded_strides = np.concatenate(expanded_strides, 1)
78 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
79 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
80 |
81 | return outputs
82 |
83 |
84 | def preprocess(img, input_size, swap=(2, 0, 1)):
85 | if len(img.shape) == 3:
86 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
87 | else:
88 | padded_img = np.ones(input_size, dtype=np.uint8) * 114
89 |
90 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
91 | resized_img = cv2.resize(
92 | img,
93 | (int(img.shape[1] * r), int(img.shape[0] * r)),
94 | interpolation=cv2.INTER_LINEAR,
95 | ).astype(np.uint8)
96 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
97 |
98 | padded_img = padded_img.transpose(swap)
99 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
100 | return padded_img, r
101 |
102 |
103 | def inference_detector(session, oriImg):
104 | input_shape = (640, 640)
105 | img, ratio = preprocess(oriImg, input_shape)
106 |
107 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
108 | output = session.run(None, ort_inputs)
109 | predictions = demo_postprocess(output[0], input_shape)[0]
110 |
111 | boxes = predictions[:, :4]
112 | scores = predictions[:, 4:5] * predictions[:, 5:]
113 |
114 | boxes_xyxy = np.ones_like(boxes)
115 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0
116 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0
117 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0
118 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0
119 | boxes_xyxy /= ratio
120 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
121 | if dets is not None:
122 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
123 | isscore = final_scores > 0.3
124 | iscat = final_cls_inds == 0
125 | isbbox = [i and j for (i, j) in zip(isscore, iscat)]
126 | final_boxes = final_boxes[isbbox]
127 | else:
128 | return []
129 |
130 | return final_boxes
131 |
--------------------------------------------------------------------------------
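A toy check of the `nms` helper above (boxes assumed in (x1, y1, x2, y2) pixel coordinates): two heavily overlapping boxes collapse to the higher-scoring one, while a disjoint box survives.

import numpy as np
from src.dwpose.onnxdet import nms

boxes = np.array(
    [
        [0, 0, 100, 100],       # score 0.9
        [10, 10, 110, 110],     # score 0.8, IoU ~0.68 with the first box
        [200, 200, 300, 300],   # score 0.7, disjoint
    ],
    dtype=np.float32,
)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)

keep = nms(boxes, scores, nms_thr=0.45)
print(keep)   # boxes 0 and 2 survive; box 1 is suppressed by box 0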
/src/dwpose/onnxpose.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | from typing import List, Tuple
3 |
4 | import cv2
5 | import numpy as np
6 | import onnxruntime as ort
7 |
8 |
9 | def preprocess(
10 | img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
11 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
12 | """Do preprocessing for RTMPose model inference.
13 |
14 | Args:
15 | img (np.ndarray): Input image in shape (H, W, C).
16 | input_size (tuple): Input image size in shape (w, h).
17 |
18 | Returns:
19 | tuple:
20 | - resized_img (np.ndarray): Preprocessed image.
21 | - center (np.ndarray): Center of image.
22 | - scale (np.ndarray): Scale of image.
23 | """
24 | # get shape of image
25 | img_shape = img.shape[:2]
26 | out_img, out_center, out_scale = [], [], []
27 | if len(out_bbox) == 0:
28 | out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
29 | for i in range(len(out_bbox)):
30 | x0 = out_bbox[i][0]
31 | y0 = out_bbox[i][1]
32 | x1 = out_bbox[i][2]
33 | y1 = out_bbox[i][3]
34 | bbox = np.array([x0, y0, x1, y1])
35 |
36 | # get center and scale
37 | center, scale = bbox_xyxy2cs(bbox, padding=1.25)
38 |
39 | # do affine transformation
40 | resized_img, scale = top_down_affine(input_size, scale, center, img)
41 |
42 | # normalize image
43 | mean = np.array([123.675, 116.28, 103.53])
44 | std = np.array([58.395, 57.12, 57.375])
45 | resized_img = (resized_img - mean) / std
46 |
47 | out_img.append(resized_img)
48 | out_center.append(center)
49 | out_scale.append(scale)
50 |
51 | return out_img, out_center, out_scale
52 |
53 |
54 | def inference(sess: ort.InferenceSession, img: List[np.ndarray]) -> List[np.ndarray]:
55 | """Inference RTMPose model.
56 |
57 | Args:
58 | sess (ort.InferenceSession): ONNXRuntime session.
59 | img (List[np.ndarray]): Preprocessed images from preprocess().
60 |
61 | Returns:
62 | outputs (np.ndarray): Output of RTMPose model.
63 | """
64 | all_out = []
65 | # build input
66 | for i in range(len(img)):
67 | input = [img[i].transpose(2, 0, 1)]
68 |
69 | # build output
70 | sess_input = {sess.get_inputs()[0].name: input}
71 | sess_output = []
72 | for out in sess.get_outputs():
73 | sess_output.append(out.name)
74 |
75 | # run model
76 | outputs = sess.run(sess_output, sess_input)
77 | all_out.append(outputs)
78 |
79 | return all_out
80 |
81 |
82 | def postprocess(
83 | outputs: List[np.ndarray],
84 | model_input_size: Tuple[int, int],
85 | center: Tuple[int, int],
86 | scale: Tuple[int, int],
87 | simcc_split_ratio: float = 2.0,
88 | ) -> Tuple[np.ndarray, np.ndarray]:
89 | """Postprocess for RTMPose model output.
90 |
91 | Args:
92 | outputs (np.ndarray): Output of RTMPose model.
93 | model_input_size (tuple): RTMPose model Input image size.
94 | center (tuple): Center of bbox in shape (x, y).
95 | scale (tuple): Scale of bbox in shape (w, h).
96 | simcc_split_ratio (float): Split ratio of simcc.
97 |
98 | Returns:
99 | tuple:
100 | - keypoints (np.ndarray): Rescaled keypoints.
101 | - scores (np.ndarray): Model predict scores.
102 | """
103 | all_key = []
104 | all_score = []
105 | for i in range(len(outputs)):
106 | # use simcc to decode
107 | simcc_x, simcc_y = outputs[i]
108 | keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
109 |
110 | # rescale keypoints
111 | keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
112 | all_key.append(keypoints[0])
113 | all_score.append(scores[0])
114 |
115 | return np.array(all_key), np.array(all_score)
116 |
117 |
118 | def bbox_xyxy2cs(
119 | bbox: np.ndarray, padding: float = 1.0
120 | ) -> Tuple[np.ndarray, np.ndarray]:
121 | """Transform the bbox format from (x1, y1, x2, y2) into (center, scale).
122 |
123 | Args:
124 | bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
125 | as (left, top, right, bottom)
126 | padding (float): BBox padding factor that will be multiplied to scale.
127 | Default: 1.0
128 |
129 | Returns:
130 | tuple: A tuple containing center and scale.
131 | - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
132 | (n, 2)
133 | - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
134 | (n, 2)
135 | """
136 | # convert single bbox from (4, ) to (1, 4)
137 | dim = bbox.ndim
138 | if dim == 1:
139 | bbox = bbox[None, :]
140 |
141 | # get bbox center and scale
142 | x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
143 | center = np.hstack([x1 + x2, y1 + y2]) * 0.5
144 | scale = np.hstack([x2 - x1, y2 - y1]) * padding
145 |
146 | if dim == 1:
147 | center = center[0]
148 | scale = scale[0]
149 |
150 | return center, scale
151 |
152 |
153 | def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float) -> np.ndarray:
154 | """Extend the scale to match the given aspect ratio.
155 |
156 | Args:
157 | bbox_scale (np.ndarray): The bbox scale (w, h) in shape (2, )
158 | aspect_ratio (float): The ratio of ``w/h``
159 |
160 | Returns:
161 | np.ndarray: The reshaped image scale in (2, )
162 | """
163 | w, h = np.hsplit(bbox_scale, [1])
164 | bbox_scale = np.where(
165 | w > h * aspect_ratio,
166 | np.hstack([w, w / aspect_ratio]),
167 | np.hstack([h * aspect_ratio, h]),
168 | )
169 | return bbox_scale
170 |
171 |
172 | def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
173 | """Rotate a point by an angle.
174 |
175 | Args:
176 | pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
177 | angle_rad (float): rotation angle in radian
178 |
179 | Returns:
180 | np.ndarray: Rotated point in shape (2, )
181 | """
182 | sn, cs = np.sin(angle_rad), np.cos(angle_rad)
183 | rot_mat = np.array([[cs, -sn], [sn, cs]])
184 | return rot_mat @ pt
185 |
186 |
187 | def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
188 | """To calculate the affine matrix, three pairs of points are required. This
189 | function is used to get the 3rd point, given 2D points a & b.
190 |
191 | The 3rd point is defined by rotating vector `a - b` by 90 degrees
192 | anticlockwise, using b as the rotation center.
193 |
194 | Args:
195 | a (np.ndarray): The 1st point (x,y) in shape (2, )
196 | b (np.ndarray): The 2nd point (x,y) in shape (2, )
197 |
198 | Returns:
199 | np.ndarray: The 3rd point.
200 | """
201 | direction = a - b
202 | c = b + np.r_[-direction[1], direction[0]]
203 | return c
204 |
205 |
206 | def get_warp_matrix(
207 | center: np.ndarray,
208 | scale: np.ndarray,
209 | rot: float,
210 | output_size: Tuple[int, int],
211 | shift: Tuple[float, float] = (0.0, 0.0),
212 | inv: bool = False,
213 | ) -> np.ndarray:
214 | """Calculate the affine transformation matrix that can warp the bbox area
215 | in the input image to the output size.
216 |
217 | Args:
218 | center (np.ndarray[2, ]): Center of the bounding box (x, y).
219 | scale (np.ndarray[2, ]): Scale of the bounding box
220 | wrt [width, height].
221 | rot (float): Rotation angle (degree).
222 | output_size (np.ndarray[2, ] | list(2,)): Size of the
223 | destination heatmaps.
224 | shift (tuple): Shift translation ratio wrt the width/height, in [0, 1].
225 | Default (0., 0.).
226 | inv (bool): Option to inverse the affine transform direction.
227 | (inv=False: src->dst or inv=True: dst->src)
228 |
229 | Returns:
230 | np.ndarray: A 2x3 transformation matrix
231 | """
232 | shift = np.array(shift)
233 | src_w = scale[0]
234 | dst_w = output_size[0]
235 | dst_h = output_size[1]
236 |
237 | # compute transformation matrix
238 | rot_rad = np.deg2rad(rot)
239 | src_dir = _rotate_point(np.array([0.0, src_w * -0.5]), rot_rad)
240 | dst_dir = np.array([0.0, dst_w * -0.5])
241 |
242 | # get four corners of the src rectangle in the original image
243 | src = np.zeros((3, 2), dtype=np.float32)
244 | src[0, :] = center + scale * shift
245 | src[1, :] = center + src_dir + scale * shift
246 | src[2, :] = _get_3rd_point(src[0, :], src[1, :])
247 |
248 | # get four corners of the dst rectangle in the input image
249 | dst = np.zeros((3, 2), dtype=np.float32)
250 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
251 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
252 | dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
253 |
254 | if inv:
255 | warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
256 | else:
257 | warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
258 |
259 | return warp_mat
260 |
261 |
262 | def top_down_affine(
263 | input_size: Tuple[int, int], bbox_scale: np.ndarray, bbox_center: np.ndarray, img: np.ndarray
264 | ) -> Tuple[np.ndarray, np.ndarray]:
265 | """Get the bbox image as the model input by affine transform.
266 |
267 | Args:
268 | input_size (tuple): The input size of the model.
269 | bbox_scale (np.ndarray): The bbox scale of the img.
270 | bbox_center (np.ndarray): The bbox center of the img.
271 | img (np.ndarray): The original image.
272 |
273 | Returns:
274 | tuple: A tuple containing the warped image and the updated bbox scale.
275 | - np.ndarray[float32]: img after affine transform.
276 | - np.ndarray[float32]: bbox scale after affine transform.
277 | """
278 | w, h = input_size
279 | warp_size = (int(w), int(h))
280 |
281 | # reshape bbox to fixed aspect ratio
282 | bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
283 |
284 | # get the affine matrix
285 | center = bbox_center
286 | scale = bbox_scale
287 | rot = 0
288 | warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
289 |
290 | # do affine transform
291 | img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
292 |
293 | return img, bbox_scale
294 |
295 |
296 | def get_simcc_maximum(
297 | simcc_x: np.ndarray, simcc_y: np.ndarray
298 | ) -> Tuple[np.ndarray, np.ndarray]:
299 | """Get maximum response location and value from simcc representations.
300 |
301 | Note:
302 | instance number: N
303 | num_keypoints: K
304 | heatmap height: H
305 | heatmap width: W
306 |
307 | Args:
308 | simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
309 | simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
310 |
311 | Returns:
312 | tuple:
313 | - locs (np.ndarray): locations of maximum heatmap responses in shape
314 | (K, 2) or (N, K, 2)
315 | - vals (np.ndarray): values of maximum heatmap responses in shape
316 | (K,) or (N, K)
317 | """
318 | N, K, Wx = simcc_x.shape
319 | simcc_x = simcc_x.reshape(N * K, -1)
320 | simcc_y = simcc_y.reshape(N * K, -1)
321 |
322 | # get maximum value locations
323 | x_locs = np.argmax(simcc_x, axis=1)
324 | y_locs = np.argmax(simcc_y, axis=1)
325 | locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
326 | max_val_x = np.amax(simcc_x, axis=1)
327 | max_val_y = np.amax(simcc_y, axis=1)
328 |
329 | # get maximum value across x and y axis
330 | mask = max_val_x > max_val_y
331 | max_val_x[mask] = max_val_y[mask]
332 | vals = max_val_x
333 | locs[vals <= 0.0] = -1
334 |
335 | # reshape
336 | locs = locs.reshape(N, K, 2)
337 | vals = vals.reshape(N, K)
338 |
339 | return locs, vals
340 |
341 |
342 | def decode(
343 | simcc_x: np.ndarray, simcc_y: np.ndarray, simcc_split_ratio
344 | ) -> Tuple[np.ndarray, np.ndarray]:
345 | """Decode SimCC representations into keypoint locations and scores.
346 |
347 | Args:
348 | simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
349 | simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
350 | simcc_split_ratio (int): The split ratio of simcc.
351 |
352 | Returns:
353 | tuple: A tuple containing keypoints and scores.
354 | - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
355 | - np.ndarray[float32]: scores in shape (K,) or (n, K)
356 | """
357 | keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
358 | keypoints /= simcc_split_ratio
359 |
360 | return keypoints, scores
361 |
362 |
363 | def inference_pose(session, out_bbox, oriImg):
364 | h, w = session.get_inputs()[0].shape[2:]
365 | model_input_size = (w, h)
366 | resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
367 | outputs = inference(session, resized_img)
368 | keypoints, scores = postprocess(outputs, model_input_size, center, scale)
369 |
370 | return keypoints, scores
371 |
--------------------------------------------------------------------------------
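A quick numeric check of `bbox_xyxy2cs` above: for a 100x200 box with padding=1.25, the center is the box midpoint and the scale is the padded width and height.

import numpy as np
from src.dwpose.onnxpose import bbox_xyxy2cs

bbox = np.array([50, 100, 150, 300], dtype=np.float32)   # x1, y1, x2, y2 -> w=100, h=200
center, scale = bbox_xyxy2cs(bbox, padding=1.25)

print(center)   # [100. 200.]
print(scale)    # [125. 250.]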
/src/dwpose/util.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | import math
3 | import numpy as np
4 | import matplotlib
5 | import cv2
6 |
7 |
8 | eps = 0.01
9 |
10 |
11 | def smart_resize(x, s):
12 | Ht, Wt = s
13 | if x.ndim == 2:
14 | Ho, Wo = x.shape
15 | Co = 1
16 | else:
17 | Ho, Wo, Co = x.shape
18 | if Co == 3 or Co == 1:
19 | k = float(Ht + Wt) / float(Ho + Wo)
20 | return cv2.resize(
21 | x,
22 | (int(Wt), int(Ht)),
23 | interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4,
24 | )
25 | else:
26 | return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
27 |
28 |
29 | def smart_resize_k(x, fx, fy):
30 | if x.ndim == 2:
31 | Ho, Wo = x.shape
32 | Co = 1
33 | else:
34 | Ho, Wo, Co = x.shape
35 | Ht, Wt = Ho * fy, Wo * fx
36 | if Co == 3 or Co == 1:
37 | k = float(Ht + Wt) / float(Ho + Wo)
38 | return cv2.resize(
39 | x,
40 | (int(Wt), int(Ht)),
41 | interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4,
42 | )
43 | else:
44 | return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
45 |
46 |
47 | def padRightDownCorner(img, stride, padValue):
48 | h = img.shape[0]
49 | w = img.shape[1]
50 |
51 | pad = 4 * [None]
52 | pad[0] = 0 # up
53 | pad[1] = 0 # left
54 | pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
55 | pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
56 |
57 | img_padded = img
58 | pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
59 | img_padded = np.concatenate((pad_up, img_padded), axis=0)
60 | pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
61 | img_padded = np.concatenate((pad_left, img_padded), axis=1)
62 | pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
63 | img_padded = np.concatenate((img_padded, pad_down), axis=0)
64 | pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
65 | img_padded = np.concatenate((img_padded, pad_right), axis=1)
66 |
67 | return img_padded, pad
68 |
69 |
70 | def transfer(model, model_weights):
71 | transfered_model_weights = {}
72 | for weights_name in model.state_dict().keys():
73 | transfered_model_weights[weights_name] = model_weights[
74 | ".".join(weights_name.split(".")[1:])
75 | ]
76 | return transfered_model_weights
77 |
78 |
79 | def draw_bodypose(canvas, candidate, subset):
80 | H, W, C = canvas.shape
81 | candidate = np.array(candidate)
82 | subset = np.array(subset)
83 |
84 | stickwidth = 4
85 |
86 | limbSeq = [
87 | [2, 3],
88 | [2, 6],
89 | [3, 4],
90 | [4, 5],
91 | [6, 7],
92 | [7, 8],
93 | [2, 9],
94 | [9, 10],
95 | [10, 11],
96 | [2, 12],
97 | [12, 13],
98 | [13, 14],
99 | [2, 1],
100 | [1, 15],
101 | [15, 17],
102 | [1, 16],
103 | [16, 18],
104 | [3, 17],
105 | [6, 18],
106 | ]
107 |
108 | colors = [
109 | [255, 0, 0],
110 | [255, 85, 0],
111 | [255, 170, 0],
112 | [255, 255, 0],
113 | [170, 255, 0],
114 | [85, 255, 0],
115 | [0, 255, 0],
116 | [0, 255, 85],
117 | [0, 255, 170],
118 | [0, 255, 255],
119 | [0, 170, 255],
120 | [0, 85, 255],
121 | [0, 0, 255],
122 | [85, 0, 255],
123 | [170, 0, 255],
124 | [255, 0, 255],
125 | [255, 0, 170],
126 | [255, 0, 85],
127 | ]
128 |
129 | for i in range(17):
130 | for n in range(len(subset)):
131 | index = subset[n][np.array(limbSeq[i]) - 1]
132 | if -1 in index:
133 | continue
134 | Y = candidate[index.astype(int), 0] * float(W)
135 | X = candidate[index.astype(int), 1] * float(H)
136 | mX = np.mean(X)
137 | mY = np.mean(Y)
138 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
139 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
140 | polygon = cv2.ellipse2Poly(
141 | (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1
142 | )
143 | cv2.fillConvexPoly(canvas, polygon, colors[i])
144 |
145 | canvas = (canvas * 0.6).astype(np.uint8)
146 |
147 | for i in range(18):
148 | for n in range(len(subset)):
149 | index = int(subset[n][i])
150 | if index == -1:
151 | continue
152 | x, y = candidate[index][0:2]
153 | x = int(x * W)
154 | y = int(y * H)
155 | cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
156 |
157 | return canvas
158 |
159 |
160 | def draw_handpose(canvas, all_hand_peaks):
161 | H, W, C = canvas.shape
162 |
163 | edges = [
164 | [0, 1],
165 | [1, 2],
166 | [2, 3],
167 | [3, 4],
168 | [0, 5],
169 | [5, 6],
170 | [6, 7],
171 | [7, 8],
172 | [0, 9],
173 | [9, 10],
174 | [10, 11],
175 | [11, 12],
176 | [0, 13],
177 | [13, 14],
178 | [14, 15],
179 | [15, 16],
180 | [0, 17],
181 | [17, 18],
182 | [18, 19],
183 | [19, 20],
184 | ]
185 |
186 | for peaks in all_hand_peaks:
187 | peaks = np.array(peaks)
188 |
189 | for ie, e in enumerate(edges):
190 | x1, y1 = peaks[e[0]]
191 | x2, y2 = peaks[e[1]]
192 | x1 = int(x1 * W)
193 | y1 = int(y1 * H)
194 | x2 = int(x2 * W)
195 | y2 = int(y2 * H)
196 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
197 | cv2.line(
198 | canvas,
199 | (x1, y1),
200 | (x2, y2),
201 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0])
202 | * 255,
203 | thickness=2,
204 | )
205 |
206 | for i, keypoint in enumerate(peaks):
207 | x, y = keypoint
208 | x = int(x * W)
209 | y = int(y * H)
210 | if x > eps and y > eps:
211 | cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
212 | return canvas
213 |
214 |
215 | def draw_facepose(canvas, all_lmks):
216 | H, W, C = canvas.shape
217 | for lmks in all_lmks:
218 | lmks = np.array(lmks)
219 | for lmk in lmks:
220 | x, y = lmk
221 | x = int(x * W)
222 | y = int(y * H)
223 | if x > eps and y > eps:
224 | cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
225 | return canvas
226 |
227 |
228 | # detect hand according to body pose keypoints
229 | # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
230 | def handDetect(candidate, subset, oriImg):
231 | # right hand: wrist 4, elbow 3, shoulder 2
232 | # left hand: wrist 7, elbow 6, shoulder 5
233 | ratioWristElbow = 0.33
234 | detect_result = []
235 | image_height, image_width = oriImg.shape[0:2]
236 | for person in subset.astype(int):
237 | # if any of three not detected
238 | has_left = np.sum(person[[5, 6, 7]] == -1) == 0
239 | has_right = np.sum(person[[2, 3, 4]] == -1) == 0
240 | if not (has_left or has_right):
241 | continue
242 | hands = []
243 | # left hand
244 | if has_left:
245 | left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
246 | x1, y1 = candidate[left_shoulder_index][:2]
247 | x2, y2 = candidate[left_elbow_index][:2]
248 | x3, y3 = candidate[left_wrist_index][:2]
249 | hands.append([x1, y1, x2, y2, x3, y3, True])
250 | # right hand
251 | if has_right:
252 | right_shoulder_index, right_elbow_index, right_wrist_index = person[
253 | [2, 3, 4]
254 | ]
255 | x1, y1 = candidate[right_shoulder_index][:2]
256 | x2, y2 = candidate[right_elbow_index][:2]
257 | x3, y3 = candidate[right_wrist_index][:2]
258 | hands.append([x1, y1, x2, y2, x3, y3, False])
259 |
260 | for x1, y1, x2, y2, x3, y3, is_left in hands:
261 | # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
262 | # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
263 | # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
264 | # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
265 | # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
266 | # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
267 | x = x3 + ratioWristElbow * (x3 - x2)
268 | y = y3 + ratioWristElbow * (y3 - y2)
269 | distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
270 | distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
271 | width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
272 | # x-y refers to the center --> offset to topLeft point
273 | # handRectangle.x -= handRectangle.width / 2.f;
274 | # handRectangle.y -= handRectangle.height / 2.f;
275 | x -= width / 2
276 | y -= width / 2 # width = height
277 | # overflow the image
278 | if x < 0:
279 | x = 0
280 | if y < 0:
281 | y = 0
282 | width1 = width
283 | width2 = width
284 | if x + width > image_width:
285 | width1 = image_width - x
286 | if y + width > image_height:
287 | width2 = image_height - y
288 | width = min(width1, width2)
289 | # discard hand boxes smaller than 20 pixels
290 | if width >= 20:
291 | detect_result.append([int(x), int(y), int(width), is_left])
292 |
293 | """
294 | return value: [[x, y, w, True if left hand else False]].
295 | width=height since the network requires a square input.
296 | x, y are the coordinates of the top-left corner.
297 | """
298 | return detect_result
299 |
300 |
301 | # Written by Lvmin
302 | def faceDetect(candidate, subset, oriImg):
303 | # left right eye ear 14 15 16 17
304 | detect_result = []
305 | image_height, image_width = oriImg.shape[0:2]
306 | for person in subset.astype(int):
307 | has_head = person[0] > -1
308 | if not has_head:
309 | continue
310 |
311 | has_left_eye = person[14] > -1
312 | has_right_eye = person[15] > -1
313 | has_left_ear = person[16] > -1
314 | has_right_ear = person[17] > -1
315 |
316 | if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
317 | continue
318 |
319 | head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
320 |
321 | width = 0.0
322 | x0, y0 = candidate[head][:2]
323 |
324 | if has_left_eye:
325 | x1, y1 = candidate[left_eye][:2]
326 | d = max(abs(x0 - x1), abs(y0 - y1))
327 | width = max(width, d * 3.0)
328 |
329 | if has_right_eye:
330 | x1, y1 = candidate[right_eye][:2]
331 | d = max(abs(x0 - x1), abs(y0 - y1))
332 | width = max(width, d * 3.0)
333 |
334 | if has_left_ear:
335 | x1, y1 = candidate[left_ear][:2]
336 | d = max(abs(x0 - x1), abs(y0 - y1))
337 | width = max(width, d * 1.5)
338 |
339 | if has_right_ear:
340 | x1, y1 = candidate[right_ear][:2]
341 | d = max(abs(x0 - x1), abs(y0 - y1))
342 | width = max(width, d * 1.5)
343 |
344 | x, y = x0, y0
345 |
346 | x -= width
347 | y -= width
348 |
349 | if x < 0:
350 | x = 0
351 |
352 | if y < 0:
353 | y = 0
354 |
355 | width1 = width * 2
356 | width2 = width * 2
357 |
358 | if x + width > image_width:
359 | width1 = image_width - x
360 |
361 | if y + width > image_height:
362 | width2 = image_height - y
363 |
364 | width = min(width1, width2)
365 |
366 | if width >= 20:
367 | detect_result.append([int(x), int(y), int(width)])
368 |
369 | return detect_result
370 |
371 |
372 | # get max index of 2d array
373 | def npmax(array):
374 | arrayindex = array.argmax(1)
375 | arrayvalue = array.max(1)
376 | i = arrayvalue.argmax()
377 | j = arrayindex[i]
378 | return i, j
379 |
--------------------------------------------------------------------------------
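A toy rendering sketch for `draw_bodypose` above. The `candidate` array holds normalized (x, y) keypoints and `subset` maps each of the 18 OpenPose joints of a person to a row index in `candidate` (or -1 when missing); random values are used here only to exercise the drawing code.

import numpy as np
from src.dwpose import util

H, W = 512, 512
canvas = np.zeros((H, W, 3), dtype=np.uint8)

candidate = np.random.rand(18, 2)                   # normalized (x, y) per keypoint
subset = np.arange(18, dtype=float).reshape(1, 18)  # one person, all joints visible

canvas = util.draw_bodypose(canvas, candidate, subset)
print(canvas.shape, canvas.dtype)   # (512, 512, 3) uint8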
/src/dwpose/wholebody.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | from pathlib import Path
3 |
4 | import cv2
5 | import numpy as np
6 | import onnxruntime as ort
7 |
8 | from .onnxdet import inference_detector
9 | from .onnxpose import inference_pose
10 |
11 | ModelDataPathPrefix = Path("./pretrained_weights")
12 |
13 |
14 | class Wholebody:
15 | def __init__(self, device="cuda:0"):
16 | providers = (
17 | ["CPUExecutionProvider"] if device == "cpu" else ["CUDAExecutionProvider"]
18 | )
19 | onnx_det = ModelDataPathPrefix.joinpath("DWPose/yolox_l.onnx")
20 | onnx_pose = ModelDataPathPrefix.joinpath("DWPose/dw-ll_ucoco_384.onnx")
21 |
22 | self.session_det = ort.InferenceSession(
23 | path_or_bytes=onnx_det, providers=providers
24 | )
25 | self.session_pose = ort.InferenceSession(
26 | path_or_bytes=onnx_pose, providers=providers
27 | )
28 |
29 | def __call__(self, oriImg):
30 | det_result = inference_detector(self.session_det, oriImg)
31 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
32 |
33 | keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
34 | # compute neck joint
35 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
36 | # neck score when visualizing pred
37 | neck[:, 2:4] = np.logical_and(
38 | keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3
39 | ).astype(int)
40 | new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
41 | mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
42 | openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
43 | new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
44 | keypoints_info = new_keypoints_info
45 |
46 | keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]
47 |
48 | return keypoints, scores
49 |
--------------------------------------------------------------------------------
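The RTMPose output above has no explicit neck joint; `Wholebody.__call__` synthesizes one as the midpoint of the two shoulders (keypoints 5 and 6) and then reorders the first joints into OpenPose order. A shape-only numpy sketch of that step on random data:

import numpy as np

keypoints_info = np.random.rand(1, 133, 3)   # (people, wholebody keypoints, x/y/score)

neck = np.mean(keypoints_info[:, [5, 6]], axis=1)      # shoulder midpoint
neck[:, 2:4] = np.logical_and(
    keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3
).astype(int)
new_info = np.insert(keypoints_info, 17, neck, axis=1)  # now 134 keypoints

mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
new_info[:, openpose_idx] = new_info[:, mmpose_idx]

keypoints, scores = new_info[..., :2], new_info[..., 2]
print(keypoints.shape, scores.shape)   # (1, 134, 2) (1, 134)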
/src/models/model_util.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Union, Optional
2 |
3 | import torch
4 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
5 | from diffusers import (
6 | SchedulerMixin,
7 | StableDiffusionPipeline,
8 | StableDiffusionXLPipeline,
9 | AutoencoderKL,
10 | )
11 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
12 | convert_ldm_unet_checkpoint,
13 | )
14 | from safetensors.torch import load_file
15 | from src.models.unet_2d_condition import UNet2DConditionModel
16 | from diffusers.schedulers import (
17 | DDIMScheduler,
18 | DDPMScheduler,
19 | LMSDiscreteScheduler,
20 | EulerAncestralDiscreteScheduler,
21 | UniPCMultistepScheduler,
22 | )
23 |
24 | from omegaconf import OmegaConf
25 |
26 | # Model parameters of the Diffusers version of Stable Diffusion
27 | NUM_TRAIN_TIMESTEPS = 1000
28 | BETA_START = 0.00085
29 | BETA_END = 0.0120
30 |
31 | UNET_PARAMS_MODEL_CHANNELS = 320
32 | UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
33 | UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
34 | UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
35 | UNET_PARAMS_IN_CHANNELS = 4
36 | UNET_PARAMS_OUT_CHANNELS = 4
37 | UNET_PARAMS_NUM_RES_BLOCKS = 2
38 | UNET_PARAMS_CONTEXT_DIM = 768
39 | UNET_PARAMS_NUM_HEADS = 8
40 | # UNET_PARAMS_USE_LINEAR_PROJECTION = False
41 |
42 | VAE_PARAMS_Z_CHANNELS = 4
43 | VAE_PARAMS_RESOLUTION = 256
44 | VAE_PARAMS_IN_CHANNELS = 3
45 | VAE_PARAMS_OUT_CH = 3
46 | VAE_PARAMS_CH = 128
47 | VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
48 | VAE_PARAMS_NUM_RES_BLOCKS = 2
49 |
50 | # V2
51 | V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
52 | V2_UNET_PARAMS_CONTEXT_DIM = 1024
53 | # V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
54 |
55 | TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
56 | TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
57 |
58 | AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a", "uniPC"]
59 |
60 | SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
61 |
62 | DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
63 |
64 |
65 | def load_checkpoint_with_text_encoder_conversion(ckpt_path: str, device="cpu"):
66 | # Handle checkpoints whose text encoder keys use a different layout (no 'text_model' prefix)
67 | TEXT_ENCODER_KEY_REPLACEMENTS = [
68 | (
69 | "cond_stage_model.transformer.embeddings.",
70 | "cond_stage_model.transformer.text_model.embeddings.",
71 | ),
72 | (
73 | "cond_stage_model.transformer.encoder.",
74 | "cond_stage_model.transformer.text_model.encoder.",
75 | ),
76 | (
77 | "cond_stage_model.transformer.final_layer_norm.",
78 | "cond_stage_model.transformer.text_model.final_layer_norm.",
79 | ),
80 | ]
81 |
82 | if ckpt_path.endswith(".safetensors"):
83 | checkpoint = None
84 | state_dict = load_file(ckpt_path) # load_file(ckpt_path, device) may cause an error
85 | else:
86 | checkpoint = torch.load(ckpt_path, map_location=device)
87 | if "state_dict" in checkpoint:
88 | state_dict = checkpoint["state_dict"]
89 | else:
90 | state_dict = checkpoint
91 | checkpoint = None
92 |
93 | key_reps = []
94 | for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
95 | for key in state_dict.keys():
96 | if key.startswith(rep_from):
97 | new_key = rep_to + key[len(rep_from) :]
98 | key_reps.append((key, new_key))
99 |
100 | for key, new_key in key_reps:
101 | state_dict[new_key] = state_dict[key]
102 | del state_dict[key]
103 |
104 | return checkpoint, state_dict
105 |
106 |
107 | def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
108 | """
109 | Creates a config for the diffusers based on the config of the LDM model.
110 | """
111 | # unet_params = original_config.model.params.unet_config.params
112 |
113 | block_out_channels = [
114 | UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
115 | ]
116 |
117 | down_block_types = []
118 | resolution = 1
119 | for i in range(len(block_out_channels)):
120 | block_type = (
121 | "CrossAttnDownBlock2D"
122 | if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
123 | else "DownBlock2D"
124 | )
125 | down_block_types.append(block_type)
126 | if i != len(block_out_channels) - 1:
127 | resolution *= 2
128 |
129 | up_block_types = []
130 | for i in range(len(block_out_channels)):
131 | block_type = (
132 | "CrossAttnUpBlock2D"
133 | if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
134 | else "UpBlock2D"
135 | )
136 | up_block_types.append(block_type)
137 | resolution //= 2
138 |
139 | config = dict(
140 | sample_size=UNET_PARAMS_IMAGE_SIZE,
141 | in_channels=UNET_PARAMS_IN_CHANNELS,
142 | out_channels=UNET_PARAMS_OUT_CHANNELS,
143 | down_block_types=tuple(down_block_types),
144 | up_block_types=tuple(up_block_types),
145 | block_out_channels=tuple(block_out_channels),
146 | layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
147 | cross_attention_dim=UNET_PARAMS_CONTEXT_DIM
148 | if not v2
149 | else V2_UNET_PARAMS_CONTEXT_DIM,
150 | attention_head_dim=UNET_PARAMS_NUM_HEADS
151 | if not v2
152 | else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
153 | # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
154 | )
155 | if v2 and use_linear_projection_in_v2:
156 | config["use_linear_projection"] = True
157 |
158 | return config
159 |
160 |
161 | def load_diffusers_model(
162 | pretrained_model_name_or_path: str,
163 | v2: bool = False,
164 | clip_skip: Optional[int] = None,
165 | weight_dtype: torch.dtype = torch.float32,
166 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
167 | if v2:
168 | tokenizer = CLIPTokenizer.from_pretrained(
169 | TOKENIZER_V2_MODEL_NAME,
170 | subfolder="tokenizer",
171 | torch_dtype=weight_dtype,
172 | cache_dir=DIFFUSERS_CACHE_DIR,
173 | )
174 | text_encoder = CLIPTextModel.from_pretrained(
175 | pretrained_model_name_or_path,
176 | subfolder="text_encoder",
177 | # default is clip skip 2
178 | num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
179 | torch_dtype=weight_dtype,
180 | cache_dir=DIFFUSERS_CACHE_DIR,
181 | )
182 | else:
183 | tokenizer = CLIPTokenizer.from_pretrained(
184 | TOKENIZER_V1_MODEL_NAME,
185 | subfolder="tokenizer",
186 | torch_dtype=weight_dtype,
187 | cache_dir=DIFFUSERS_CACHE_DIR,
188 | )
189 | text_encoder = CLIPTextModel.from_pretrained(
190 | pretrained_model_name_or_path,
191 | subfolder="text_encoder",
192 | num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
193 | torch_dtype=weight_dtype,
194 | cache_dir=DIFFUSERS_CACHE_DIR,
195 | )
196 |
197 | unet = UNet2DConditionModel.from_pretrained(
198 | pretrained_model_name_or_path,
199 | subfolder="unet",
200 | torch_dtype=weight_dtype,
201 | cache_dir=DIFFUSERS_CACHE_DIR,
202 | )
203 |
204 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
205 |
206 | return tokenizer, text_encoder, unet, vae
207 |
208 |
209 | def load_checkpoint_model(
210 | checkpoint_path: str,
211 | v2: bool = False,
212 | clip_skip: Optional[int] = None,
213 | weight_dtype: torch.dtype = torch.float32,
214 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
215 | pipe = StableDiffusionPipeline.from_single_file(
216 | checkpoint_path,
217 | upcast_attention=True if v2 else False,
218 | torch_dtype=weight_dtype,
219 | cache_dir=DIFFUSERS_CACHE_DIR,
220 | )
221 |
222 | _, state_dict = load_checkpoint_with_text_encoder_conversion(checkpoint_path)
223 | unet_config = create_unet_diffusers_config(v2, use_linear_projection_in_v2=v2)
224 | unet_config["class_embed_type"] = None
225 | unet_config["addition_embed_type"] = None
226 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
227 | unet = UNet2DConditionModel(**unet_config)
228 | converted_unet_checkpoint.pop("conv_out.weight")
229 | converted_unet_checkpoint.pop("conv_out.bias")
230 | unet.load_state_dict(converted_unet_checkpoint)
231 |
232 | tokenizer = pipe.tokenizer
233 | text_encoder = pipe.text_encoder
234 | vae = pipe.vae
235 | if clip_skip is not None:
236 | if v2:
237 | text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
238 | else:
239 | text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
240 |
241 | del pipe
242 |
243 | return tokenizer, text_encoder, unet, vae
244 |
245 |
246 | def load_models(
247 | pretrained_model_name_or_path: str,
248 | scheduler_name: str,
249 | v2: bool = False,
250 | v_pred: bool = False,
251 | weight_dtype: torch.dtype = torch.float32,
252 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin, AutoencoderKL]:
253 | if pretrained_model_name_or_path.endswith(
254 | ".ckpt"
255 | ) or pretrained_model_name_or_path.endswith(".safetensors"):
256 | tokenizer, text_encoder, unet, vae = load_checkpoint_model(
257 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
258 | )
259 | else: # diffusers
260 | tokenizer, text_encoder, unet, vae = load_diffusers_model(
261 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
262 | )
263 |
264 | if scheduler_name:
265 | scheduler = create_noise_scheduler(
266 | scheduler_name,
267 | prediction_type="v_prediction" if v_pred else "epsilon",
268 | )
269 | else:
270 | scheduler = None
271 |
272 | return tokenizer, text_encoder, unet, scheduler, vae
273 |
274 |
275 | def load_diffusers_model_xl(
276 | pretrained_model_name_or_path: str,
277 | weight_dtype: torch.dtype = torch.float32,
278 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel, AutoencoderKL]:
279 | # returns [tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet, vae
280 |
281 | tokenizers = [
282 | CLIPTokenizer.from_pretrained(
283 | pretrained_model_name_or_path,
284 | subfolder="tokenizer",
285 | torch_dtype=weight_dtype,
286 | cache_dir=DIFFUSERS_CACHE_DIR,
287 | ),
288 | CLIPTokenizer.from_pretrained(
289 | pretrained_model_name_or_path,
290 | subfolder="tokenizer_2",
291 | torch_dtype=weight_dtype,
292 | cache_dir=DIFFUSERS_CACHE_DIR,
293 | pad_token_id=0, # same as open clip
294 | ),
295 | ]
296 |
297 | text_encoders = [
298 | CLIPTextModel.from_pretrained(
299 | pretrained_model_name_or_path,
300 | subfolder="text_encoder",
301 | torch_dtype=weight_dtype,
302 | cache_dir=DIFFUSERS_CACHE_DIR,
303 | ),
304 | CLIPTextModelWithProjection.from_pretrained(
305 | pretrained_model_name_or_path,
306 | subfolder="text_encoder_2",
307 | torch_dtype=weight_dtype,
308 | cache_dir=DIFFUSERS_CACHE_DIR,
309 | ),
310 | ]
311 |
312 | unet = UNet2DConditionModel.from_pretrained(
313 | pretrained_model_name_or_path,
314 | subfolder="unet",
315 | torch_dtype=weight_dtype,
316 | cache_dir=DIFFUSERS_CACHE_DIR,
317 | )
318 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
319 | return tokenizers, text_encoders, unet, vae
320 |
321 |
322 | def load_checkpoint_model_xl(
323 | checkpoint_path: str,
324 | weight_dtype: torch.dtype = torch.float32,
325 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel, AutoencoderKL]:
326 | pipe = StableDiffusionXLPipeline.from_single_file(
327 | checkpoint_path,
328 | torch_dtype=weight_dtype,
329 | cache_dir=DIFFUSERS_CACHE_DIR,
330 | )
331 |
332 | unet = pipe.unet
333 | vae = pipe.vae
334 | tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
335 | text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
336 | if len(text_encoders) == 2:
337 | text_encoders[1].pad_token_id = 0
338 |
339 | del pipe
340 |
341 | return tokenizers, text_encoders, unet, vae
342 |
343 |
344 | def load_models_xl(
345 | pretrained_model_name_or_path: str,
346 | scheduler_name: str,
347 | weight_dtype: torch.dtype = torch.float32,
348 | noise_scheduler_kwargs=None,
349 | ) -> tuple[
350 | list[CLIPTokenizer],
351 | list[SDXL_TEXT_ENCODER_TYPE],
352 | UNet2DConditionModel,
353 | SchedulerMixin, AutoencoderKL,
354 | ]:
355 | if pretrained_model_name_or_path.endswith(
356 | ".ckpt"
357 | ) or pretrained_model_name_or_path.endswith(".safetensors"):
358 | (tokenizers, text_encoders, unet, vae) = load_checkpoint_model_xl(
359 | pretrained_model_name_or_path, weight_dtype
360 | )
361 | else: # diffusers
362 | (tokenizers, text_encoders, unet, vae) = load_diffusers_model_xl(
363 | pretrained_model_name_or_path, weight_dtype
364 | )
365 | if scheduler_name:
366 | scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs)
367 | else:
368 | scheduler = None
369 |
370 | return tokenizers, text_encoders, unet, scheduler, vae
371 |
372 |
373 | def create_noise_scheduler(
374 | scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
375 | noise_scheduler_kwargs=None,
376 | prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
377 | ) -> SchedulerMixin:
378 | name = scheduler_name.lower().replace(" ", "_")
379 | if name.lower() == "ddim":
380 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
381 | scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
382 | elif name.lower() == "ddpm":
383 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
384 | scheduler = DDPMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
385 | elif name.lower() == "lms":
386 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
387 | scheduler = LMSDiscreteScheduler(
388 | **OmegaConf.to_container(noise_scheduler_kwargs)
389 | )
390 | elif name.lower() == "euler_a":
391 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
392 | scheduler = EulerAncestralDiscreteScheduler(
393 | **OmegaConf.to_container(noise_scheduler_kwargs)
394 | )
395 | elif name.lower() == "unipc":
396 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/unipc
397 | scheduler = UniPCMultistepScheduler(
398 | **OmegaConf.to_container(noise_scheduler_kwargs)
399 | )
400 | else:
401 | raise ValueError(f"Unknown scheduler name: {name}")
402 |
403 | return scheduler
404 |
405 |
406 | def torch_gc():
407 | import gc
408 |
409 | gc.collect()
410 | if torch.cuda.is_available():
411 | with torch.cuda.device("cuda"):
412 | torch.cuda.empty_cache()
413 | torch.cuda.ipc_collect()
414 |
415 |
416 | from enum import Enum
417 |
418 |
419 | class CPUState(Enum):
420 | GPU = 0
421 | CPU = 1
422 | MPS = 2
423 |
424 |
425 | cpu_state = CPUState.GPU
426 | xpu_available = False
427 | directml_enabled = False
428 |
429 |
430 | def is_intel_xpu():
431 | global cpu_state
432 | global xpu_available
433 | if cpu_state == CPUState.GPU:
434 | if xpu_available:
435 | return True
436 | return False
437 |
438 |
439 | try:
440 | import intel_extension_for_pytorch as ipex
441 |
442 | if torch.xpu.is_available():
443 | xpu_available = True
444 | except Exception:
445 | pass
446 |
447 | try:
448 | if torch.backends.mps.is_available():
449 | cpu_state = CPUState.MPS
450 | import torch.mps
451 | except Exception:
452 | pass
453 |
454 |
455 | def get_torch_device():
456 | global directml_enabled
457 | global cpu_state
458 | if directml_enabled:
459 | global directml_device
460 | return directml_device
461 | if cpu_state == CPUState.MPS:
462 | return torch.device("mps")
463 | if cpu_state == CPUState.CPU:
464 | return torch.device("cpu")
465 | else:
466 | if is_intel_xpu():
467 | return torch.device("xpu")
468 | else:
469 | return torch.device(torch.cuda.current_device())
470 |
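create_noise_scheduler expects its kwargs as an OmegaConf mapping, since it passes them through OmegaConf.to_container. A minimal sketch of combining the helpers above at setup time; the DDIM settings are illustrative values, not defaults taken from this repository:

from omegaconf import OmegaConf

# Build a DDIM scheduler from an OmegaConf config (illustrative settings).
scheduler = create_noise_scheduler(
    "ddim",
    OmegaConf.create({"num_train_timesteps": 1000, "beta_schedule": "linear"}),
)

# Resolve the execution device (MPS/XPU if detected, otherwise the current CUDA device).
device = get_torch_device()

# Free cached GPU memory after large allocations.
torch_gc()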
--------------------------------------------------------------------------------
/src/models/motion_module.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2 | import math
3 | from dataclasses import dataclass
4 | from typing import Callable, Optional
5 |
6 | import torch
7 | from diffusers.models.attention import FeedForward
8 | from diffusers.models.attention_processor import Attention, AttnProcessor
9 | from diffusers.utils import BaseOutput
10 | from diffusers.utils.import_utils import is_xformers_available
11 | from einops import rearrange, repeat
12 | from torch import nn
13 | from src.models.model_util import get_torch_device
14 |
15 |
16 | def zero_module(module):
17 | # Zero out the parameters of a module and return it.
18 | for p in module.parameters():
19 | p.detach().zero_()
20 | return module
21 |
22 |
23 | @dataclass
24 | class TemporalTransformer3DModelOutput(BaseOutput):
25 | sample: torch.FloatTensor
26 |
27 |
28 | if is_xformers_available():
29 | import xformers
30 | import xformers.ops
31 | else:
32 | xformers = None
33 |
34 |
35 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
36 | if motion_module_type == "Vanilla":
37 | return VanillaTemporalModule(
38 | in_channels=in_channels,
39 | **motion_module_kwargs,
40 | )
41 | else:
42 |         raise ValueError(f"Unknown motion_module_type: {motion_module_type}")
43 |
44 |
45 | class VanillaTemporalModule(nn.Module):
46 | def __init__(
47 | self,
48 | in_channels,
49 | num_attention_heads=8,
50 | num_transformer_block=2,
51 | attention_block_types=("Temporal_Self", "Temporal_Self"),
52 | cross_frame_attention_mode=None,
53 | temporal_position_encoding=False,
54 | temporal_position_encoding_max_len=24,
55 | temporal_attention_dim_div=1,
56 | zero_initialize=True,
57 | ):
58 | super().__init__()
59 |
60 | self.temporal_transformer = TemporalTransformer3DModel(
61 | in_channels=in_channels,
62 | num_attention_heads=num_attention_heads,
63 | attention_head_dim=in_channels
64 | // num_attention_heads
65 | // temporal_attention_dim_div,
66 | num_layers=num_transformer_block,
67 | attention_block_types=attention_block_types,
68 | cross_frame_attention_mode=cross_frame_attention_mode,
69 | temporal_position_encoding=temporal_position_encoding,
70 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
71 | )
72 |
73 | if zero_initialize:
74 | self.temporal_transformer.proj_out = zero_module(
75 | self.temporal_transformer.proj_out
76 | )
77 |
78 | def forward(
79 | self,
80 | input_tensor,
81 | temb,
82 | encoder_hidden_states,
83 | attention_mask=None,
84 | anchor_frame_idx=None,
85 | ):
86 | hidden_states = input_tensor
87 | hidden_states = self.temporal_transformer(
88 | hidden_states, encoder_hidden_states, attention_mask
89 | )
90 |
91 | output = hidden_states
92 | return output
93 |
94 |
95 | class TemporalTransformer3DModel(nn.Module):
96 | def __init__(
97 | self,
98 | in_channels,
99 | num_attention_heads,
100 | attention_head_dim,
101 | num_layers,
102 | attention_block_types=(
103 | "Temporal_Self",
104 | "Temporal_Self",
105 | ),
106 | dropout=0.0,
107 | norm_num_groups=32,
108 | cross_attention_dim=768,
109 | activation_fn="geglu",
110 | attention_bias=False,
111 | upcast_attention=False,
112 | cross_frame_attention_mode=None,
113 | temporal_position_encoding=False,
114 | temporal_position_encoding_max_len=24,
115 | ):
116 | super().__init__()
117 |
118 | inner_dim = num_attention_heads * attention_head_dim
119 |
120 | self.norm = torch.nn.GroupNorm(
121 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
122 | )
123 | self.proj_in = nn.Linear(in_channels, inner_dim)
124 |
125 | self.transformer_blocks = nn.ModuleList(
126 | [
127 | TemporalTransformerBlock(
128 | dim=inner_dim,
129 | num_attention_heads=num_attention_heads,
130 | attention_head_dim=attention_head_dim,
131 | attention_block_types=attention_block_types,
132 | dropout=dropout,
133 | norm_num_groups=norm_num_groups,
134 | cross_attention_dim=cross_attention_dim,
135 | activation_fn=activation_fn,
136 | attention_bias=attention_bias,
137 | upcast_attention=upcast_attention,
138 | cross_frame_attention_mode=cross_frame_attention_mode,
139 | temporal_position_encoding=temporal_position_encoding,
140 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
141 | )
142 | for d in range(num_layers)
143 | ]
144 | )
145 | self.proj_out = nn.Linear(inner_dim, in_channels)
146 |
147 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
148 | assert (
149 | hidden_states.dim() == 5
150 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
151 | video_length = hidden_states.shape[2]
152 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
153 |
154 | batch, channel, height, weight = hidden_states.shape
155 | residual = hidden_states
156 |
157 | hidden_states = self.norm(hidden_states)
158 | inner_dim = hidden_states.shape[1]
159 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
160 | batch, height * weight, inner_dim
161 | )
162 | hidden_states = self.proj_in(hidden_states)
163 |
164 | # Transformer Blocks
165 | for block in self.transformer_blocks:
166 | hidden_states = block(
167 | hidden_states,
168 | encoder_hidden_states=encoder_hidden_states,
169 | video_length=video_length,
170 | )
171 |
172 | # output
173 | hidden_states = self.proj_out(hidden_states)
174 | hidden_states = (
175 | hidden_states.reshape(batch, height, weight, inner_dim)
176 | .permute(0, 3, 1, 2)
177 | .contiguous()
178 | )
179 |
180 | output = hidden_states + residual
181 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
182 |
183 | return output
184 |
185 |
186 | class TemporalTransformerBlock(nn.Module):
187 | def __init__(
188 | self,
189 | dim,
190 | num_attention_heads,
191 | attention_head_dim,
192 | attention_block_types=(
193 | "Temporal_Self",
194 | "Temporal_Self",
195 | ),
196 | dropout=0.0,
197 | norm_num_groups=32,
198 | cross_attention_dim=768,
199 | activation_fn="geglu",
200 | attention_bias=False,
201 | upcast_attention=False,
202 | cross_frame_attention_mode=None,
203 | temporal_position_encoding=False,
204 | temporal_position_encoding_max_len=24,
205 | ):
206 | super().__init__()
207 |
208 | attention_blocks = []
209 | norms = []
210 |
211 | for block_name in attention_block_types:
212 | attention_blocks.append(
213 | VersatileAttention(
214 | attention_mode=block_name.split("_")[0],
215 | cross_attention_dim=cross_attention_dim
216 | if block_name.endswith("_Cross")
217 | else None,
218 | query_dim=dim,
219 | heads=num_attention_heads,
220 | dim_head=attention_head_dim,
221 | dropout=dropout,
222 | bias=attention_bias,
223 | upcast_attention=upcast_attention,
224 | cross_frame_attention_mode=cross_frame_attention_mode,
225 | temporal_position_encoding=temporal_position_encoding,
226 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
227 | )
228 | )
229 | norms.append(nn.LayerNorm(dim))
230 |
231 | self.attention_blocks = nn.ModuleList(attention_blocks)
232 | self.norms = nn.ModuleList(norms)
233 |
234 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
235 | self.ff_norm = nn.LayerNorm(dim)
236 | self.device = get_torch_device()
237 |
238 | def forward(
239 | self,
240 | hidden_states,
241 | encoder_hidden_states=None,
242 | attention_mask=None,
243 | video_length=None,
244 | ):
245 | for attention_block, norm in zip(self.attention_blocks, self.norms):
246 | norm_hidden_states = norm(hidden_states)
247 | hidden_states = (
248 | attention_block(
249 | norm_hidden_states,
250 | encoder_hidden_states=encoder_hidden_states
251 | if attention_block.is_cross_attention
252 | else None,
253 | video_length=video_length,
254 | )
255 | + hidden_states
256 | )
257 |
258 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
259 |
260 | output = hidden_states
261 | return output
262 |
263 |
264 | class PositionalEncoding(nn.Module):
265 | def __init__(self, d_model, dropout=0.0, max_len=24):
266 | super().__init__()
267 | self.dropout = nn.Dropout(p=dropout)
268 | position = torch.arange(max_len).unsqueeze(1)
269 | div_term = torch.exp(
270 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
271 | )
272 | pe = torch.zeros(1, max_len, d_model)
273 | pe[0, :, 0::2] = torch.sin(position * div_term)
274 | pe[0, :, 1::2] = torch.cos(position * div_term)
275 | self.register_buffer("pe", pe)
276 |
277 | def forward(self, x):
278 | x = x + self.pe[:, : x.size(1)]
279 | return self.dropout(x)
280 |
281 |
282 | class VersatileAttention(Attention):
283 | def __init__(
284 | self,
285 | attention_mode=None,
286 | cross_frame_attention_mode=None,
287 | temporal_position_encoding=False,
288 | temporal_position_encoding_max_len=24,
289 | *args,
290 | **kwargs,
291 | ):
292 | super().__init__(*args, **kwargs)
293 | assert attention_mode == "Temporal"
294 |
295 | self.attention_mode = attention_mode
296 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None
297 |
298 | self.pos_encoder = (
299 | PositionalEncoding(
300 | kwargs["query_dim"],
301 | dropout=0.0,
302 | max_len=temporal_position_encoding_max_len,
303 | )
304 | if (temporal_position_encoding and attention_mode == "Temporal")
305 | else None
306 | )
307 |
308 | def extra_repr(self):
309 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
310 |
311 | def set_use_memory_efficient_attention_xformers(
312 | self,
313 | use_memory_efficient_attention_xformers: bool,
314 | attention_op: Optional[Callable] = None,
315 | ):
316 | if use_memory_efficient_attention_xformers:
317 | if not is_xformers_available():
318 | raise ModuleNotFoundError(
319 | (
320 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
321 | " xformers"
322 | ),
323 | name="xformers",
324 | )
325 | elif not torch.cuda.is_available():
326 | raise ValueError(
327 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
328 | " only available for GPU "
329 | )
330 | else:
331 | try:
332 | # Make sure we can run the memory efficient attention
333 | _ = xformers.ops.memory_efficient_attention(
334 | torch.randn((1, 2, 40), device=self.device),
335 | torch.randn((1, 2, 40), device=self.device),
336 | torch.randn((1, 2, 40), device=self.device),
337 | )
338 | except Exception as e:
339 | raise e
340 |
341 |             # XFormersAttnProcessor corrupts video generation and only works with PyTorch 1.13.
342 |             # With PyTorch 2.0.1, the default AttnProcessor behaves the same as XFormersAttnProcessor did in PyTorch 1.13,
343 |             # so XFormersAttnProcessor is not needed here.
344 | # processor = XFormersAttnProcessor(
345 | # attention_op=attention_op,
346 | # )
347 | processor = AttnProcessor()
348 | else:
349 | processor = AttnProcessor()
350 |
351 | self.set_processor(processor)
352 |
353 | def forward(
354 | self,
355 | hidden_states,
356 | encoder_hidden_states=None,
357 | attention_mask=None,
358 | video_length=None,
359 | **cross_attention_kwargs,
360 | ):
361 | if self.attention_mode == "Temporal":
362 | d = hidden_states.shape[1] # d means HxW
363 | hidden_states = rearrange(
364 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
365 | )
366 |
367 | if self.pos_encoder is not None:
368 | hidden_states = self.pos_encoder(hidden_states)
369 |
370 | encoder_hidden_states = (
371 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
372 | if encoder_hidden_states is not None
373 | else encoder_hidden_states
374 | )
375 |
376 | else:
377 | raise NotImplementedError
378 |
379 | hidden_states = self.processor(
380 | self,
381 | hidden_states,
382 | encoder_hidden_states=encoder_hidden_states,
383 | attention_mask=attention_mask,
384 | **cross_attention_kwargs,
385 | )
386 |
387 | if self.attention_mode == "Temporal":
388 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
389 |
390 | return hidden_states
391 |
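A small shape check for the module above. The channel count, frame count and kwargs are illustrative assumptions, and construction needs CUDA, MPS or XPU to be available because get_torch_device (used inside TemporalTransformerBlock) has no CPU fallback:

import torch

motion_module = get_motion_module(
    in_channels=320,
    motion_module_type="Vanilla",
    motion_module_kwargs={
        "num_attention_heads": 8,
        "num_transformer_block": 1,
        "temporal_position_encoding": True,
        "temporal_position_encoding_max_len": 24,
    },
)

# (batch, channels, frames, height, width); the module is residual, so the shape is preserved.
x = torch.randn(1, 320, 8, 32, 32)
out = motion_module(x, temb=None, encoder_hidden_states=None)
assert out.shape == x.shape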
--------------------------------------------------------------------------------
/src/models/mutual_self_attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2 | from typing import Any, Dict, Optional
3 |
4 | import torch
5 | from einops import rearrange
6 |
7 | from src.models.attention import TemporalBasicTransformerBlock
8 |
9 | from .attention import BasicTransformerBlock
10 |
11 |
12 | def torch_dfs(model: torch.nn.Module):
13 | result = [model]
14 | for child in model.children():
15 | result += torch_dfs(child)
16 | return result
17 |
18 |
19 | class ReferenceAttentionControl:
20 | def __init__(
21 | self,
22 | unet,
23 | mode="write",
24 | do_classifier_free_guidance=False,
25 | attention_auto_machine_weight=float("inf"),
26 | gn_auto_machine_weight=1.0,
27 | style_fidelity=1.0,
28 | reference_attn=True,
29 | reference_adain=False,
30 | fusion_blocks="midup",
31 | batch_size=1,
32 | ) -> None:
33 |         # Modify self-attention and group norm
34 | self.unet = unet
35 | assert mode in ["read", "write"]
36 | assert fusion_blocks in ["midup", "full"]
37 | self.reference_attn = reference_attn
38 | self.reference_adain = reference_adain
39 | self.fusion_blocks = fusion_blocks
40 | self.register_reference_hooks(
41 | mode,
42 | do_classifier_free_guidance,
43 | attention_auto_machine_weight,
44 | gn_auto_machine_weight,
45 | style_fidelity,
46 | reference_attn,
47 | reference_adain,
48 | fusion_blocks,
49 | batch_size=batch_size,
50 | )
51 |
52 | def register_reference_hooks(
53 | self,
54 | mode,
55 | do_classifier_free_guidance,
56 | attention_auto_machine_weight,
57 | gn_auto_machine_weight,
58 | style_fidelity,
59 | reference_attn,
60 | reference_adain,
61 | dtype=torch.float16,
62 | batch_size=1,
63 | num_images_per_prompt=1,
64 | device=torch.device("cpu"),
65 | fusion_blocks="midup",
66 | ):
67 | MODE = mode
68 | do_classifier_free_guidance = do_classifier_free_guidance
69 | attention_auto_machine_weight = attention_auto_machine_weight
70 | gn_auto_machine_weight = gn_auto_machine_weight
71 | style_fidelity = style_fidelity
72 | reference_attn = reference_attn
73 | reference_adain = reference_adain
74 | fusion_blocks = fusion_blocks
75 | num_images_per_prompt = num_images_per_prompt
76 | dtype = dtype
77 | if do_classifier_free_guidance:
78 | uc_mask = (
79 | torch.Tensor(
80 | [1] * batch_size * num_images_per_prompt * 16
81 | + [0] * batch_size * num_images_per_prompt * 16
82 | )
83 | .to(device)
84 | .bool()
85 | )
86 | else:
87 | uc_mask = (
88 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89 | .to(device)
90 | .bool()
91 | )
92 |
93 | def hacked_basic_transformer_inner_forward(
94 | self,
95 | hidden_states: torch.FloatTensor,
96 | attention_mask: Optional[torch.FloatTensor] = None,
97 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
98 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
99 | timestep: Optional[torch.LongTensor] = None,
100 | cross_attention_kwargs: Dict[str, Any] = None,
101 | class_labels: Optional[torch.LongTensor] = None,
102 | video_length=None,
103 | ):
104 | if self.use_ada_layer_norm: # False
105 | norm_hidden_states = self.norm1(hidden_states, timestep)
106 | elif self.use_ada_layer_norm_zero:
107 | (
108 | norm_hidden_states,
109 | gate_msa,
110 | shift_mlp,
111 | scale_mlp,
112 | gate_mlp,
113 | ) = self.norm1(
114 | hidden_states,
115 | timestep,
116 | class_labels,
117 | hidden_dtype=hidden_states.dtype,
118 | )
119 | else:
120 | norm_hidden_states = self.norm1(hidden_states)
121 |
122 | # 1. Self-Attention
123 | # self.only_cross_attention = False
124 | cross_attention_kwargs = (
125 | cross_attention_kwargs if cross_attention_kwargs is not None else {}
126 | )
127 | if self.only_cross_attention:
128 | attn_output = self.attn1(
129 | norm_hidden_states,
130 | encoder_hidden_states=encoder_hidden_states
131 | if self.only_cross_attention
132 | else None,
133 | attention_mask=attention_mask,
134 | **cross_attention_kwargs,
135 | )
136 | else:
137 | if MODE == "write":
138 | self.bank.append(norm_hidden_states.clone())
139 | attn_output = self.attn1(
140 | norm_hidden_states,
141 | encoder_hidden_states=encoder_hidden_states
142 | if self.only_cross_attention
143 | else None,
144 | attention_mask=attention_mask,
145 | **cross_attention_kwargs,
146 | )
147 | if MODE == "read":
148 | bank_fea = [
149 | rearrange(
150 | d.unsqueeze(1).repeat(1, video_length, 1, 1),
151 | "b t l c -> (b t) l c",
152 | )
153 | for d in self.bank
154 | ]
155 | modify_norm_hidden_states = torch.cat(
156 | [norm_hidden_states] + bank_fea, dim=1
157 | )
158 | hidden_states_uc = (
159 | self.attn1(
160 | norm_hidden_states,
161 | encoder_hidden_states=modify_norm_hidden_states,
162 | attention_mask=attention_mask,
163 | )
164 | + hidden_states
165 | )
166 | if do_classifier_free_guidance:
167 | hidden_states_c = hidden_states_uc.clone()
168 | _uc_mask = uc_mask.clone()
169 | if hidden_states.shape[0] != _uc_mask.shape[0]:
170 | _uc_mask = (
171 | torch.Tensor(
172 | [1] * (hidden_states.shape[0] // 2)
173 | + [0] * (hidden_states.shape[0] // 2)
174 | )
175 | .to(device)
176 | .bool()
177 | )
178 | hidden_states_c[_uc_mask] = (
179 | self.attn1(
180 | norm_hidden_states[_uc_mask],
181 | encoder_hidden_states=norm_hidden_states[_uc_mask],
182 | attention_mask=attention_mask,
183 | )
184 | + hidden_states[_uc_mask]
185 | )
186 | hidden_states = hidden_states_c.clone()
187 | else:
188 | hidden_states = hidden_states_uc
189 |
190 | # self.bank.clear()
191 | if self.attn2 is not None:
192 | # Cross-Attention
193 | norm_hidden_states = (
194 | self.norm2(hidden_states, timestep)
195 | if self.use_ada_layer_norm
196 | else self.norm2(hidden_states)
197 | )
198 | hidden_states = (
199 | self.attn2(
200 | norm_hidden_states,
201 | encoder_hidden_states=encoder_hidden_states,
202 | attention_mask=attention_mask,
203 | )
204 | + hidden_states
205 | )
206 |
207 | # Feed-forward
208 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
209 |
210 | # Temporal-Attention
211 | if self.unet_use_temporal_attention:
212 | d = hidden_states.shape[1]
213 | hidden_states = rearrange(
214 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
215 | )
216 | norm_hidden_states = (
217 | self.norm_temp(hidden_states, timestep)
218 | if self.use_ada_layer_norm
219 | else self.norm_temp(hidden_states)
220 | )
221 | hidden_states = (
222 | self.attn_temp(norm_hidden_states) + hidden_states
223 | )
224 | hidden_states = rearrange(
225 | hidden_states, "(b d) f c -> (b f) d c", d=d
226 | )
227 |
228 | return hidden_states
229 |
230 | if self.use_ada_layer_norm_zero:
231 | attn_output = gate_msa.unsqueeze(1) * attn_output
232 | hidden_states = attn_output + hidden_states
233 |
234 | if self.attn2 is not None:
235 | norm_hidden_states = (
236 | self.norm2(hidden_states, timestep)
237 | if self.use_ada_layer_norm
238 | else self.norm2(hidden_states)
239 | )
240 |
241 | # 2. Cross-Attention
242 | attn_output = self.attn2(
243 | norm_hidden_states,
244 | encoder_hidden_states=encoder_hidden_states,
245 | attention_mask=encoder_attention_mask,
246 | **cross_attention_kwargs,
247 | )
248 | hidden_states = attn_output + hidden_states
249 |
250 | # 3. Feed-forward
251 | norm_hidden_states = self.norm3(hidden_states)
252 |
253 | if self.use_ada_layer_norm_zero:
254 | norm_hidden_states = (
255 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
256 | )
257 |
258 | ff_output = self.ff(norm_hidden_states)
259 |
260 | if self.use_ada_layer_norm_zero:
261 | ff_output = gate_mlp.unsqueeze(1) * ff_output
262 |
263 | hidden_states = ff_output + hidden_states
264 |
265 | return hidden_states
266 |
267 | if self.reference_attn:
268 | if self.fusion_blocks == "midup":
269 | attn_modules = [
270 | module
271 | for module in (
272 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
273 | )
274 | if isinstance(module, BasicTransformerBlock)
275 | or isinstance(module, TemporalBasicTransformerBlock)
276 | ]
277 | elif self.fusion_blocks == "full":
278 | attn_modules = [
279 | module
280 | for module in torch_dfs(self.unet)
281 | if isinstance(module, BasicTransformerBlock)
282 | or isinstance(module, TemporalBasicTransformerBlock)
283 | ]
284 | attn_modules = sorted(
285 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
286 | )
287 |
288 | for i, module in enumerate(attn_modules):
289 | module._original_inner_forward = module.forward
290 | if isinstance(module, BasicTransformerBlock):
291 | module.forward = hacked_basic_transformer_inner_forward.__get__(
292 | module, BasicTransformerBlock
293 | )
294 | if isinstance(module, TemporalBasicTransformerBlock):
295 | module.forward = hacked_basic_transformer_inner_forward.__get__(
296 | module, TemporalBasicTransformerBlock
297 | )
298 |
299 | module.bank = []
300 | module.attn_weight = float(i) / float(len(attn_modules))
301 |
302 | def update(self, writer, dtype=torch.float16):
303 | if self.reference_attn:
304 | if self.fusion_blocks == "midup":
305 | reader_attn_modules = [
306 | module
307 | for module in (
308 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
309 | )
310 | if isinstance(module, TemporalBasicTransformerBlock)
311 | ]
312 | writer_attn_modules = [
313 | module
314 | for module in (
315 | torch_dfs(writer.unet.mid_block)
316 | + torch_dfs(writer.unet.up_blocks)
317 | )
318 | if isinstance(module, BasicTransformerBlock)
319 | ]
320 | elif self.fusion_blocks == "full":
321 | reader_attn_modules = [
322 | module
323 | for module in torch_dfs(self.unet)
324 | if isinstance(module, TemporalBasicTransformerBlock)
325 | ]
326 | writer_attn_modules = [
327 | module
328 | for module in torch_dfs(writer.unet)
329 | if isinstance(module, BasicTransformerBlock)
330 | ]
331 | reader_attn_modules = sorted(
332 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
333 | )
334 | writer_attn_modules = sorted(
335 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
336 | )
337 | for r, w in zip(reader_attn_modules, writer_attn_modules):
338 | r.bank = [v.clone().to(dtype) for v in w.bank]
339 | # w.bank.clear()
340 |
341 | def clear(self):
342 | if self.reference_attn:
343 | if self.fusion_blocks == "midup":
344 | reader_attn_modules = [
345 | module
346 | for module in (
347 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
348 | )
349 | if isinstance(module, BasicTransformerBlock)
350 | or isinstance(module, TemporalBasicTransformerBlock)
351 | ]
352 | elif self.fusion_blocks == "full":
353 | reader_attn_modules = [
354 | module
355 | for module in torch_dfs(self.unet)
356 | if isinstance(module, BasicTransformerBlock)
357 | or isinstance(module, TemporalBasicTransformerBlock)
358 | ]
359 | reader_attn_modules = sorted(
360 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
361 | )
362 | for r in reader_attn_modules:
363 | r.bank.clear()
364 |
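The class is used in pairs, exactly as the pipelines further below do: a writer wrapped around the reference UNet fills the attention banks, and a reader wrapped around the denoising UNet consumes them. A schematic sketch in which reference_unet, denoising_unet, ref_latents, t and clip_embeds are placeholders:

writer = ReferenceAttentionControl(reference_unet, mode="write", fusion_blocks="full")
reader = ReferenceAttentionControl(denoising_unet, mode="read", fusion_blocks="full")

# One pass through the reference UNet stores its self-attention features in the banks.
reference_unet(ref_latents, t, encoder_hidden_states=clip_embeds, return_dict=False)
# Copy the banked features into the reader so the denoising UNet can attend to them.
reader.update(writer)
# ... run the denoising loop with denoising_unet ...
# Reset the banks once the sample is finished.
reader.clear()
writer.clear()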
--------------------------------------------------------------------------------
/src/models/pose_guider.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.nn.init as init
6 | from diffusers.models.modeling_utils import ModelMixin
7 |
8 | from src.models.motion_module import zero_module
9 | from src.models.resnet import InflatedConv3d
10 |
11 |
12 | class PoseGuider(ModelMixin):
13 | def __init__(
14 | self,
15 | conditioning_embedding_channels: int,
16 | conditioning_channels: int = 3,
17 | block_out_channels: Tuple[int] = (16, 32, 64, 128),
18 | ):
19 | super().__init__()
20 | self.conv_in = InflatedConv3d(
21 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
22 | )
23 |
24 | self.blocks = nn.ModuleList([])
25 |
26 | for i in range(len(block_out_channels) - 1):
27 | channel_in = block_out_channels[i]
28 | channel_out = block_out_channels[i + 1]
29 | self.blocks.append(
30 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
31 | )
32 | self.blocks.append(
33 | InflatedConv3d(
34 | channel_in, channel_out, kernel_size=3, padding=1, stride=2
35 | )
36 | )
37 |
38 | self.conv_out = zero_module(
39 | InflatedConv3d(
40 | block_out_channels[-1],
41 | conditioning_embedding_channels,
42 | kernel_size=3,
43 | padding=1,
44 | )
45 | )
46 |
47 | def forward(self, conditioning):
48 | embedding = self.conv_in(conditioning)
49 | embedding = F.silu(embedding)
50 |
51 | for block in self.blocks:
52 | embedding = block(embedding)
53 | embedding = F.silu(embedding)
54 |
55 | embedding = self.conv_out(embedding)
56 |
57 | return embedding
58 |
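With the default block widths, the guider downsamples the pose map by a factor of eight (three stride-2 convolutions) and then projects to the requested embedding channels; a quick shape check with assumed sizes:

import torch

guider = PoseGuider(conditioning_embedding_channels=320)
pose = torch.randn(1, 3, 4, 512, 512)      # (batch, rgb, frames, height, width)
feat = guider(pose)
assert feat.shape == (1, 320, 4, 64, 64)   # 512 / 2**3 = 64, channels = 320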
--------------------------------------------------------------------------------
/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 |
8 |
9 | class InflatedConv3d(nn.Conv2d):
10 | def forward(self, x):
11 | video_length = x.shape[2]
12 |
13 | x = rearrange(x, "b c f h w -> (b f) c h w")
14 | x = super().forward(x)
15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16 |
17 | return x
18 |
19 |
20 | class InflatedGroupNorm(nn.GroupNorm):
21 | def forward(self, x):
22 | video_length = x.shape[2]
23 |
24 | x = rearrange(x, "b c f h w -> (b f) c h w")
25 | x = super().forward(x)
26 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27 |
28 | return x
29 |
30 |
31 | class Upsample3D(nn.Module):
32 | def __init__(
33 | self,
34 | channels,
35 | use_conv=False,
36 | use_conv_transpose=False,
37 | out_channels=None,
38 | name="conv",
39 | ):
40 | super().__init__()
41 | self.channels = channels
42 | self.out_channels = out_channels or channels
43 | self.use_conv = use_conv
44 | self.use_conv_transpose = use_conv_transpose
45 | self.name = name
46 |
47 | conv = None
48 | if use_conv_transpose:
49 | raise NotImplementedError
50 | elif use_conv:
51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52 |
53 | def forward(self, hidden_states, output_size=None):
54 | assert hidden_states.shape[1] == self.channels
55 |
56 | if self.use_conv_transpose:
57 | raise NotImplementedError
58 |
59 |         # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
60 | dtype = hidden_states.dtype
61 | if dtype == torch.bfloat16:
62 | hidden_states = hidden_states.to(torch.float32)
63 |
64 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65 | if hidden_states.shape[0] >= 64:
66 | hidden_states = hidden_states.contiguous()
67 |
68 | # if `output_size` is passed we force the interpolation output
69 | # size and do not make use of `scale_factor=2`
70 | if output_size is None:
71 | hidden_states = F.interpolate(
72 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73 | )
74 | else:
75 | hidden_states = F.interpolate(
76 | hidden_states, size=output_size, mode="nearest"
77 | )
78 |
79 | # If the input is bfloat16, we cast back to bfloat16
80 | if dtype == torch.bfloat16:
81 | hidden_states = hidden_states.to(dtype)
82 |
83 | # if self.use_conv:
84 | # if self.name == "conv":
85 | # hidden_states = self.conv(hidden_states)
86 | # else:
87 | # hidden_states = self.Conv2d_0(hidden_states)
88 | hidden_states = self.conv(hidden_states)
89 |
90 | return hidden_states
91 |
92 |
93 | class Downsample3D(nn.Module):
94 | def __init__(
95 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96 | ):
97 | super().__init__()
98 | self.channels = channels
99 | self.out_channels = out_channels or channels
100 | self.use_conv = use_conv
101 | self.padding = padding
102 | stride = 2
103 | self.name = name
104 |
105 | if use_conv:
106 | self.conv = InflatedConv3d(
107 | self.channels, self.out_channels, 3, stride=stride, padding=padding
108 | )
109 | else:
110 | raise NotImplementedError
111 |
112 | def forward(self, hidden_states):
113 | assert hidden_states.shape[1] == self.channels
114 | if self.use_conv and self.padding == 0:
115 | raise NotImplementedError
116 |
117 | assert hidden_states.shape[1] == self.channels
118 | hidden_states = self.conv(hidden_states)
119 |
120 | return hidden_states
121 |
122 |
123 | class ResnetBlock3D(nn.Module):
124 | def __init__(
125 | self,
126 | *,
127 | in_channels,
128 | out_channels=None,
129 | conv_shortcut=False,
130 | dropout=0.0,
131 | temb_channels=512,
132 | groups=32,
133 | groups_out=None,
134 | pre_norm=True,
135 | eps=1e-6,
136 | non_linearity="swish",
137 | time_embedding_norm="default",
138 | output_scale_factor=1.0,
139 | use_in_shortcut=None,
140 | use_inflated_groupnorm=None,
141 | ):
142 | super().__init__()
143 | self.pre_norm = pre_norm
144 | self.pre_norm = True
145 | self.in_channels = in_channels
146 | out_channels = in_channels if out_channels is None else out_channels
147 | self.out_channels = out_channels
148 | self.use_conv_shortcut = conv_shortcut
149 | self.time_embedding_norm = time_embedding_norm
150 | self.output_scale_factor = output_scale_factor
151 |
152 | if groups_out is None:
153 | groups_out = groups
154 |
155 |         assert use_inflated_groupnorm is not None
156 | if use_inflated_groupnorm:
157 | self.norm1 = InflatedGroupNorm(
158 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159 | )
160 | else:
161 | self.norm1 = torch.nn.GroupNorm(
162 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163 | )
164 |
165 | self.conv1 = InflatedConv3d(
166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
167 | )
168 |
169 | if temb_channels is not None:
170 | if self.time_embedding_norm == "default":
171 | time_emb_proj_out_channels = out_channels
172 | elif self.time_embedding_norm == "scale_shift":
173 | time_emb_proj_out_channels = out_channels * 2
174 | else:
175 | raise ValueError(
176 | f"unknown time_embedding_norm : {self.time_embedding_norm} "
177 | )
178 |
179 | self.time_emb_proj = torch.nn.Linear(
180 | temb_channels, time_emb_proj_out_channels
181 | )
182 | else:
183 | self.time_emb_proj = None
184 |
185 | if use_inflated_groupnorm:
186 | self.norm2 = InflatedGroupNorm(
187 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188 | )
189 | else:
190 | self.norm2 = torch.nn.GroupNorm(
191 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192 | )
193 | self.dropout = torch.nn.Dropout(dropout)
194 | self.conv2 = InflatedConv3d(
195 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
196 | )
197 |
198 | if non_linearity == "swish":
199 | self.nonlinearity = lambda x: F.silu(x)
200 | elif non_linearity == "mish":
201 | self.nonlinearity = Mish()
202 | elif non_linearity == "silu":
203 | self.nonlinearity = nn.SiLU()
204 |
205 | self.use_in_shortcut = (
206 | self.in_channels != self.out_channels
207 | if use_in_shortcut is None
208 | else use_in_shortcut
209 | )
210 |
211 | self.conv_shortcut = None
212 | if self.use_in_shortcut:
213 | self.conv_shortcut = InflatedConv3d(
214 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
215 | )
216 |
217 | def forward(self, input_tensor, temb):
218 | hidden_states = input_tensor
219 |
220 | hidden_states = self.norm1(hidden_states)
221 | hidden_states = self.nonlinearity(hidden_states)
222 |
223 | hidden_states = self.conv1(hidden_states)
224 |
225 | if temb is not None:
226 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
227 |
228 | if temb is not None and self.time_embedding_norm == "default":
229 | hidden_states = hidden_states + temb
230 |
231 | hidden_states = self.norm2(hidden_states)
232 |
233 | if temb is not None and self.time_embedding_norm == "scale_shift":
234 | scale, shift = torch.chunk(temb, 2, dim=1)
235 | hidden_states = hidden_states * (1 + scale) + shift
236 |
237 | hidden_states = self.nonlinearity(hidden_states)
238 |
239 | hidden_states = self.dropout(hidden_states)
240 | hidden_states = self.conv2(hidden_states)
241 |
242 | if self.conv_shortcut is not None:
243 | input_tensor = self.conv_shortcut(input_tensor)
244 |
245 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
246 |
247 | return output_tensor
248 |
249 |
250 | class Mish(torch.nn.Module):
251 | def forward(self, hidden_states):
252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
253 |
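Both inflated layers simply fold the frame axis into the batch so that their 2-D parents run on every frame independently; a minimal check with made-up sizes:

import torch

conv = InflatedConv3d(3, 8, kernel_size=3, padding=1)
norm = InflatedGroupNorm(num_groups=4, num_channels=8)

x = torch.randn(2, 3, 5, 32, 32)        # (batch, channels, frames, height, width)
y = norm(conv(x))
assert y.shape == (2, 8, 5, 32, 32)     # spatial size kept, channels changed by the conv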
--------------------------------------------------------------------------------
/src/models/transformer_3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models import ModelMixin
7 | from diffusers.utils import BaseOutput
8 | from diffusers.utils.import_utils import is_xformers_available
9 | from einops import rearrange, repeat
10 | from torch import nn
11 |
12 | from .attention import TemporalBasicTransformerBlock
13 |
14 |
15 | @dataclass
16 | class Transformer3DModelOutput(BaseOutput):
17 | sample: torch.FloatTensor
18 |
19 |
20 | if is_xformers_available():
21 | import xformers
22 | import xformers.ops
23 | else:
24 | xformers = None
25 |
26 |
27 | class Transformer3DModel(ModelMixin, ConfigMixin):
28 | _supports_gradient_checkpointing = True
29 |
30 | @register_to_config
31 | def __init__(
32 | self,
33 | num_attention_heads: int = 16,
34 | attention_head_dim: int = 88,
35 | in_channels: Optional[int] = None,
36 | num_layers: int = 1,
37 | dropout: float = 0.0,
38 | norm_num_groups: int = 32,
39 | cross_attention_dim: Optional[int] = None,
40 | attention_bias: bool = False,
41 | activation_fn: str = "geglu",
42 | num_embeds_ada_norm: Optional[int] = None,
43 | use_linear_projection: bool = False,
44 | only_cross_attention: bool = False,
45 | upcast_attention: bool = False,
46 | unet_use_cross_frame_attention=None,
47 | unet_use_temporal_attention=None,
48 | ):
49 | super().__init__()
50 | self.use_linear_projection = use_linear_projection
51 | self.num_attention_heads = num_attention_heads
52 | self.attention_head_dim = attention_head_dim
53 | inner_dim = num_attention_heads * attention_head_dim
54 |
55 | # Define input layers
56 | self.in_channels = in_channels
57 |
58 | self.norm = torch.nn.GroupNorm(
59 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60 | )
61 | if use_linear_projection:
62 | self.proj_in = nn.Linear(in_channels, inner_dim)
63 | else:
64 | self.proj_in = nn.Conv2d(
65 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66 | )
67 |
68 | # Define transformers blocks
69 | self.transformer_blocks = nn.ModuleList(
70 | [
71 | TemporalBasicTransformerBlock(
72 | inner_dim,
73 | num_attention_heads,
74 | attention_head_dim,
75 | dropout=dropout,
76 | cross_attention_dim=cross_attention_dim,
77 | activation_fn=activation_fn,
78 | num_embeds_ada_norm=num_embeds_ada_norm,
79 | attention_bias=attention_bias,
80 | only_cross_attention=only_cross_attention,
81 | upcast_attention=upcast_attention,
82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83 | unet_use_temporal_attention=unet_use_temporal_attention,
84 | )
85 | for d in range(num_layers)
86 | ]
87 | )
88 |
89 | # 4. Define output layers
90 | if use_linear_projection:
91 | self.proj_out = nn.Linear(in_channels, inner_dim)
92 | else:
93 | self.proj_out = nn.Conv2d(
94 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95 | )
96 |
97 | self.gradient_checkpointing = False
98 |
99 | def _set_gradient_checkpointing(self, module, value=False):
100 | if hasattr(module, "gradient_checkpointing"):
101 | module.gradient_checkpointing = value
102 |
103 | def forward(
104 | self,
105 | hidden_states,
106 | encoder_hidden_states=None,
107 | timestep=None,
108 | return_dict: bool = True,
109 | ):
110 | # Input
111 | assert (
112 | hidden_states.dim() == 5
113 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114 | video_length = hidden_states.shape[2]
115 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
117 | encoder_hidden_states = repeat(
118 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length
119 | )
120 |
121 | batch, channel, height, weight = hidden_states.shape
122 | residual = hidden_states
123 |
124 | hidden_states = self.norm(hidden_states)
125 | if not self.use_linear_projection:
126 | hidden_states = self.proj_in(hidden_states)
127 | inner_dim = hidden_states.shape[1]
128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
129 | batch, height * weight, inner_dim
130 | )
131 | else:
132 | inner_dim = hidden_states.shape[1]
133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134 | batch, height * weight, inner_dim
135 | )
136 | hidden_states = self.proj_in(hidden_states)
137 |
138 | # Blocks
139 | for i, block in enumerate(self.transformer_blocks):
140 | hidden_states = block(
141 | hidden_states,
142 | encoder_hidden_states=encoder_hidden_states,
143 | timestep=timestep,
144 | video_length=video_length,
145 | )
146 |
147 | # Output
148 | if not self.use_linear_projection:
149 | hidden_states = (
150 | hidden_states.reshape(batch, height, weight, inner_dim)
151 | .permute(0, 3, 1, 2)
152 | .contiguous()
153 | )
154 | hidden_states = self.proj_out(hidden_states)
155 | else:
156 | hidden_states = self.proj_out(hidden_states)
157 | hidden_states = (
158 | hidden_states.reshape(batch, height, weight, inner_dim)
159 | .permute(0, 3, 1, 2)
160 | .contiguous()
161 | )
162 |
163 | output = hidden_states + residual
164 |
165 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166 | if not return_dict:
167 | return (output,)
168 |
169 | return Transformer3DModelOutput(sample=output)
170 |
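Shape-wise the model mirrors its 2-D counterpart applied per frame. A hedged sketch; TemporalBasicTransformerBlock lives in src/models/attention.py, which is not shown here, so the kwargs and shapes below are assumptions:

import torch

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,
    in_channels=320,
    cross_attention_dim=768,
)
x = torch.randn(1, 320, 8, 32, 32)      # (batch, channels, frames, height, width)
text = torch.randn(1, 77, 768)          # per-clip conditioning states, repeated per frame internally
out = model(x, encoder_hidden_states=text).sample
# out keeps the same (b, c, f, h, w) layout as x: proj_out plus the residual preserve the shape.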
--------------------------------------------------------------------------------
/src/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/Moore-AnimateAnyone-for-windows/7cc83a44681509cd6e9138a21ec87cb8a6108c2d/src/pipelines/__init__.py
--------------------------------------------------------------------------------
/src/pipelines/context.py:
--------------------------------------------------------------------------------
1 | # TODO: Adapted from cli
2 | from typing import Callable, List, Optional
3 |
4 | import numpy as np
5 |
6 |
7 | def ordered_halving(val):
8 | bin_str = f"{val:064b}"
9 | bin_flip = bin_str[::-1]
10 | as_int = int(bin_flip, 2)
11 |
12 | return as_int / (1 << 64)
13 |
14 |
15 | def uniform(
16 | step: int = ...,
17 | num_steps: Optional[int] = None,
18 | num_frames: int = ...,
19 | context_size: Optional[int] = None,
20 | context_stride: int = 3,
21 | context_overlap: int = 4,
22 | closed_loop: bool = True,
23 | ):
24 | if num_frames <= context_size:
25 | yield list(range(num_frames))
26 | return
27 |
28 | context_stride = min(
29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30 | )
31 |
32 | for context_step in 1 << np.arange(context_stride):
33 | pad = int(round(num_frames * ordered_halving(step)))
34 | for j in range(
35 | int(ordered_halving(step) * context_step) + pad,
36 | num_frames + pad + (0 if closed_loop else -context_overlap),
37 | (context_size * context_step - context_overlap),
38 | ):
39 | yield [
40 | e % num_frames
41 | for e in range(j, j + context_size * context_step, context_step)
42 | ]
43 |
44 |
45 | def get_context_scheduler(name: str) -> Callable:
46 | if name == "uniform":
47 | return uniform
48 | else:
49 | raise ValueError(f"Unknown context_overlap policy {name}")
50 |
51 |
52 | def get_total_steps(
53 | scheduler,
54 | timesteps: List[int],
55 | num_steps: Optional[int] = None,
56 | num_frames: int = ...,
57 | context_size: Optional[int] = None,
58 | context_stride: int = 3,
59 | context_overlap: int = 4,
60 | closed_loop: bool = True,
61 | ):
62 | return sum(
63 | len(
64 | list(
65 | scheduler(
66 | i,
67 | num_steps,
68 | num_frames,
69 | context_size,
70 | context_stride,
71 | context_overlap,
72 | )
73 | )
74 | )
75 | for i in range(len(timesteps))
76 | )
77 |
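The "uniform" policy yields overlapping, possibly wrapping windows of frame indices; printing them for a 24-frame clip with a 16-frame context makes the pattern visible:

scheduler = get_context_scheduler("uniform")
for window in scheduler(
    step=0, num_steps=20, num_frames=24,
    context_size=16, context_stride=1, context_overlap=4,
):
    print(window)
# -> [0, 1, ..., 15] and [12, 13, ..., 23, 0, 1, 2, 3] (indices wrap because closed_loop=True)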
--------------------------------------------------------------------------------
/src/pipelines/pipeline_pose2img.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import dataclass
3 | from typing import Callable, List, Optional, Union
4 |
5 | import numpy as np
6 | import torch
7 | from diffusers import DiffusionPipeline
8 | from diffusers.image_processor import VaeImageProcessor
9 | from diffusers.schedulers import (
10 | DDIMScheduler,
11 | DPMSolverMultistepScheduler,
12 | EulerAncestralDiscreteScheduler,
13 | EulerDiscreteScheduler,
14 | LMSDiscreteScheduler,
15 | PNDMScheduler,
16 | )
17 | from diffusers.utils import BaseOutput, is_accelerate_available
18 | from diffusers.utils.torch_utils import randn_tensor
19 | from einops import rearrange
20 | from tqdm import tqdm
21 | from transformers import CLIPImageProcessor
22 |
23 | from src.models.mutual_self_attention import ReferenceAttentionControl
24 |
25 |
26 | @dataclass
27 | class Pose2ImagePipelineOutput(BaseOutput):
28 | images: Union[torch.Tensor, np.ndarray]
29 |
30 |
31 | class Pose2ImagePipeline(DiffusionPipeline):
32 | _optional_components = []
33 |
34 | def __init__(
35 | self,
36 | vae,
37 | image_encoder,
38 | reference_unet,
39 | denoising_unet,
40 | pose_guider,
41 | scheduler: Union[
42 | DDIMScheduler,
43 | PNDMScheduler,
44 | LMSDiscreteScheduler,
45 | EulerDiscreteScheduler,
46 | EulerAncestralDiscreteScheduler,
47 | DPMSolverMultistepScheduler,
48 | ],
49 | ):
50 | super().__init__()
51 |
52 | self.register_modules(
53 | vae=vae,
54 | image_encoder=image_encoder,
55 | reference_unet=reference_unet,
56 | denoising_unet=denoising_unet,
57 | pose_guider=pose_guider,
58 | scheduler=scheduler,
59 | )
60 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
61 | self.clip_image_processor = CLIPImageProcessor()
62 | self.ref_image_processor = VaeImageProcessor(
63 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
64 | )
65 | self.cond_image_processor = VaeImageProcessor(
66 | vae_scale_factor=self.vae_scale_factor,
67 | do_convert_rgb=True,
68 | do_normalize=False,
69 | )
70 |
71 | def enable_vae_slicing(self):
72 | self.vae.enable_slicing()
73 |
74 | def disable_vae_slicing(self):
75 | self.vae.disable_slicing()
76 |
77 | def enable_sequential_cpu_offload(self, gpu_id=0):
78 | if is_accelerate_available():
79 | from accelerate import cpu_offload
80 | else:
81 | raise ImportError("Please install accelerate via `pip install accelerate`")
82 |
83 | device = torch.device(f"cuda:{gpu_id}")
84 |
85 |         for cpu_offloaded_model in [self.denoising_unet, self.reference_unet, self.image_encoder, self.vae]:
86 | if cpu_offloaded_model is not None:
87 | cpu_offload(cpu_offloaded_model, device)
88 |
89 | @property
90 | def _execution_device(self):
91 |         if self.device != torch.device("meta") or not hasattr(self.denoising_unet, "_hf_hook"):
92 | return self.device
93 |         for module in self.denoising_unet.modules():
94 | if (
95 | hasattr(module, "_hf_hook")
96 | and hasattr(module._hf_hook, "execution_device")
97 | and module._hf_hook.execution_device is not None
98 | ):
99 | return torch.device(module._hf_hook.execution_device)
100 | return self.device
101 |
102 | def decode_latents(self, latents):
103 | video_length = latents.shape[2]
104 | latents = 1 / 0.18215 * latents
105 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
106 | # video = self.vae.decode(latents).sample
107 | video = []
108 | for frame_idx in tqdm(range(latents.shape[0])):
109 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
110 | video = torch.cat(video)
111 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
112 | video = (video / 2 + 0.5).clamp(0, 1)
113 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
114 | video = video.cpu().float().numpy()
115 | return video
116 |
117 | def prepare_extra_step_kwargs(self, generator, eta):
118 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
119 |         # eta (η) is only used with the DDIMScheduler; it will be ignored by other schedulers.
120 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
121 | # and should be between [0, 1]
122 |
123 | accepts_eta = "eta" in set(
124 | inspect.signature(self.scheduler.step).parameters.keys()
125 | )
126 | extra_step_kwargs = {}
127 | if accepts_eta:
128 | extra_step_kwargs["eta"] = eta
129 |
130 | # check if the scheduler accepts generator
131 | accepts_generator = "generator" in set(
132 | inspect.signature(self.scheduler.step).parameters.keys()
133 | )
134 | if accepts_generator:
135 | extra_step_kwargs["generator"] = generator
136 | return extra_step_kwargs
137 |
138 | def prepare_latents(
139 | self,
140 | batch_size,
141 | num_channels_latents,
142 | width,
143 | height,
144 | dtype,
145 | device,
146 | generator,
147 | latents=None,
148 | ):
149 | shape = (
150 | batch_size,
151 | num_channels_latents,
152 | height // self.vae_scale_factor,
153 | width // self.vae_scale_factor,
154 | )
155 | if isinstance(generator, list) and len(generator) != batch_size:
156 | raise ValueError(
157 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
158 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
159 | )
160 |
161 | if latents is None:
162 | latents = randn_tensor(
163 | shape, generator=generator, device=device, dtype=dtype
164 | )
165 | else:
166 | latents = latents.to(device)
167 |
168 | # scale the initial noise by the standard deviation required by the scheduler
169 | latents = latents * self.scheduler.init_noise_sigma
170 | return latents
171 |
172 | def prepare_condition(
173 | self,
174 | cond_image,
175 | width,
176 | height,
177 | device,
178 | dtype,
179 | do_classififer_free_guidance=False,
180 | ):
181 | image = self.cond_image_processor.preprocess(
182 | cond_image, height=height, width=width
183 | ).to(dtype=torch.float32)
184 |
185 | image = image.to(device=device, dtype=dtype)
186 |
187 | if do_classififer_free_guidance:
188 | image = torch.cat([image] * 2)
189 |
190 | return image
191 |
192 | @torch.no_grad()
193 | def __call__(
194 | self,
195 | ref_image,
196 | pose_image,
197 | width,
198 | height,
199 | num_inference_steps,
200 | guidance_scale,
201 | num_images_per_prompt=1,
202 | eta: float = 0.0,
203 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
204 | output_type: Optional[str] = "tensor",
205 | return_dict: bool = True,
206 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
207 | callback_steps: Optional[int] = 1,
208 | **kwargs,
209 | ):
210 | # Default height and width to unet
211 | height = height or self.unet.config.sample_size * self.vae_scale_factor
212 | width = width or self.unet.config.sample_size * self.vae_scale_factor
213 |
214 | device = self._execution_device
215 |
216 | do_classifier_free_guidance = guidance_scale > 1.0
217 |
218 | # Prepare timesteps
219 | self.scheduler.set_timesteps(num_inference_steps, device=device)
220 | timesteps = self.scheduler.timesteps
221 |
222 | batch_size = 1
223 |
224 | # Prepare clip image embeds
225 | clip_image = self.clip_image_processor.preprocess(
226 | ref_image.resize((224, 224)), return_tensors="pt"
227 | ).pixel_values
228 | clip_image_embeds = self.image_encoder(
229 | clip_image.to(device, dtype=self.image_encoder.dtype)
230 | ).image_embeds
231 | image_prompt_embeds = clip_image_embeds.unsqueeze(1)
232 | uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
233 |
234 | if do_classifier_free_guidance:
235 | image_prompt_embeds = torch.cat(
236 | [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
237 | )
238 |
239 | reference_control_writer = ReferenceAttentionControl(
240 | self.reference_unet,
241 | do_classifier_free_guidance=do_classifier_free_guidance,
242 | mode="write",
243 | batch_size=batch_size,
244 | fusion_blocks="full",
245 | )
246 | reference_control_reader = ReferenceAttentionControl(
247 | self.denoising_unet,
248 | do_classifier_free_guidance=do_classifier_free_guidance,
249 | mode="read",
250 | batch_size=batch_size,
251 | fusion_blocks="full",
252 | )
253 |
254 | num_channels_latents = self.denoising_unet.in_channels
255 | latents = self.prepare_latents(
256 | batch_size * num_images_per_prompt,
257 | num_channels_latents,
258 | width,
259 | height,
260 | clip_image_embeds.dtype,
261 | device,
262 | generator,
263 | )
264 | latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
265 | latents_dtype = latents.dtype
266 |
267 | # Prepare extra step kwargs.
268 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
269 |
270 | # Prepare ref image latents
271 | ref_image_tensor = self.ref_image_processor.preprocess(
272 | ref_image, height=height, width=width
273 |         ) # (bs, c, height, width)
274 | ref_image_tensor = ref_image_tensor.to(
275 | dtype=self.vae.dtype, device=self.vae.device
276 | )
277 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
278 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
279 |
280 | # Prepare pose condition image
281 | pose_cond_tensor = self.cond_image_processor.preprocess(
282 | pose_image, height=height, width=width
283 | )
284 | pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
285 | pose_cond_tensor = pose_cond_tensor.to(
286 | device=device, dtype=self.pose_guider.dtype
287 | )
288 | pose_fea = self.pose_guider(pose_cond_tensor)
289 | pose_fea = (
290 | torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
291 | )
292 |
293 | # denoising loop
294 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
295 | with self.progress_bar(total=num_inference_steps) as progress_bar:
296 | for i, t in enumerate(timesteps):
297 | # 1. Forward reference image
298 | if i == 0:
299 | self.reference_unet(
300 | ref_image_latents.repeat(
301 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
302 | ),
303 | torch.zeros_like(t),
304 | encoder_hidden_states=image_prompt_embeds,
305 | return_dict=False,
306 | )
307 |
308 |                 # 2. Update reference UNet features into the denoising UNet
309 | reference_control_reader.update(reference_control_writer)
310 |
311 | # 3.1 expand the latents if we are doing classifier free guidance
312 | latent_model_input = (
313 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents
314 | )
315 | latent_model_input = self.scheduler.scale_model_input(
316 | latent_model_input, t
317 | )
318 |
319 | noise_pred = self.denoising_unet(
320 | latent_model_input,
321 | t,
322 | encoder_hidden_states=image_prompt_embeds,
323 | pose_cond_fea=pose_fea,
324 | return_dict=False,
325 | )[0]
326 |
327 | # perform guidance
328 | if do_classifier_free_guidance:
329 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
330 | noise_pred = noise_pred_uncond + guidance_scale * (
331 | noise_pred_text - noise_pred_uncond
332 | )
333 |
334 | # compute the previous noisy sample x_t -> x_t-1
335 | latents = self.scheduler.step(
336 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False
337 | )[0]
338 |
339 | # call the callback, if provided
340 | if i == len(timesteps) - 1 or (
341 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
342 | ):
343 | progress_bar.update()
344 | if callback is not None and i % callback_steps == 0:
345 | step_idx = i // getattr(self.scheduler, "order", 1)
346 | callback(step_idx, t, latents)
347 | reference_control_reader.clear()
348 | reference_control_writer.clear()
349 |
350 | # Post-processing
351 | image = self.decode_latents(latents) # (b, c, 1, h, w)
352 |
353 | # Convert to tensor
354 | if output_type == "tensor":
355 | image = torch.from_numpy(image)
356 |
357 | if not return_dict:
358 | return image
359 |
360 | return Pose2ImagePipelineOutput(images=image)
361 |
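A hedged sketch of calling the pipeline. The components (vae, image_encoder, reference_unet, denoising_unet, pose_guider, scheduler) are assumed to be constructed and loaded elsewhere, the reference and pose images come from the configs shipped with the repository, and the resolution and guidance values are illustrative:

import torch
from PIL import Image

pipe = Pose2ImagePipeline(
    vae, image_encoder, reference_unet, denoising_unet, pose_guider, scheduler
).to("cuda", torch.float16)

ref = Image.open("configs/inference/ref_images/anyone-1.png").convert("RGB")
pose = Image.open("configs/inference/pose_images/pose-1.png").convert("RGB")

result = pipe(ref, pose, width=512, height=784, num_inference_steps=25, guidance_scale=3.5)
image = result.images          # (b, c, 1, h, w) tensor in [0, 1]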
--------------------------------------------------------------------------------
/src/pipelines/pipeline_pose2vid.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import dataclass
3 | from typing import Callable, List, Optional, Union
4 |
5 | import numpy as np
6 | import torch
7 | from diffusers import DiffusionPipeline
8 | from diffusers.image_processor import VaeImageProcessor
9 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
10 | EulerAncestralDiscreteScheduler,
11 | EulerDiscreteScheduler, LMSDiscreteScheduler,
12 | PNDMScheduler)
13 | from diffusers.utils import BaseOutput, is_accelerate_available
14 | from diffusers.utils.torch_utils import randn_tensor
15 | from einops import rearrange
16 | from tqdm import tqdm
17 | from transformers import CLIPImageProcessor
18 |
19 | from src.models.mutual_self_attention import ReferenceAttentionControl
20 |
21 |
22 | @dataclass
23 | class Pose2VideoPipelineOutput(BaseOutput):
24 | videos: Union[torch.Tensor, np.ndarray]
25 |
26 |
27 | class Pose2VideoPipeline(DiffusionPipeline):
28 | _optional_components = []
29 |
30 | def __init__(
31 | self,
32 | vae,
33 | image_encoder,
34 | reference_unet,
35 | denoising_unet,
36 | pose_guider,
37 | scheduler: Union[
38 | DDIMScheduler,
39 | PNDMScheduler,
40 | LMSDiscreteScheduler,
41 | EulerDiscreteScheduler,
42 | EulerAncestralDiscreteScheduler,
43 | DPMSolverMultistepScheduler,
44 | ],
45 | image_proj_model=None,
46 | tokenizer=None,
47 | text_encoder=None,
48 | ):
49 | super().__init__()
50 |
51 | self.register_modules(
52 | vae=vae,
53 | image_encoder=image_encoder,
54 | reference_unet=reference_unet,
55 | denoising_unet=denoising_unet,
56 | pose_guider=pose_guider,
57 | scheduler=scheduler,
58 | image_proj_model=image_proj_model,
59 | tokenizer=tokenizer,
60 | text_encoder=text_encoder,
61 | )
62 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
63 | self.clip_image_processor = CLIPImageProcessor()
64 | self.ref_image_processor = VaeImageProcessor(
65 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
66 | )
67 | self.cond_image_processor = VaeImageProcessor(
68 | vae_scale_factor=self.vae_scale_factor,
69 | do_convert_rgb=True,
70 | do_normalize=False,
71 | )
72 |
73 | def enable_vae_slicing(self):
74 | self.vae.enable_slicing()
75 |
76 | def disable_vae_slicing(self):
77 | self.vae.disable_slicing()
78 |
79 | def enable_sequential_cpu_offload(self, gpu_id=0):
80 | if is_accelerate_available():
81 | from accelerate import cpu_offload
82 | else:
83 | raise ImportError("Please install accelerate via `pip install accelerate`")
84 |
85 | device = torch.device(f"cuda:{gpu_id}")
86 |
87 |         for cpu_offloaded_model in [self.denoising_unet, self.text_encoder, self.vae]:
88 | if cpu_offloaded_model is not None:
89 | cpu_offload(cpu_offloaded_model, device)
90 |
91 | @property
92 | def _execution_device(self):
93 |         if self.device != torch.device("meta") or not hasattr(self.denoising_unet, "_hf_hook"):
94 |             return self.device
95 |         for module in self.denoising_unet.modules():
96 | if (
97 | hasattr(module, "_hf_hook")
98 | and hasattr(module._hf_hook, "execution_device")
99 | and module._hf_hook.execution_device is not None
100 | ):
101 | return torch.device(module._hf_hook.execution_device)
102 | return self.device
103 |
104 | def decode_latents(self, latents):
105 | video_length = latents.shape[2]
106 | latents = 1 / 0.18215 * latents
107 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
108 | # video = self.vae.decode(latents).sample
109 | video = []
110 | for frame_idx in tqdm(range(latents.shape[0])):
111 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
112 | video = torch.cat(video)
113 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
114 | video = (video / 2 + 0.5).clamp(0, 1)
115 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
116 | video = video.cpu().float().numpy()
117 | return video
118 |
119 | def prepare_extra_step_kwargs(self, generator, eta):
120 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
121 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
122 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
123 | # and should be between [0, 1]
124 |
125 | accepts_eta = "eta" in set(
126 | inspect.signature(self.scheduler.step).parameters.keys()
127 | )
128 | extra_step_kwargs = {}
129 | if accepts_eta:
130 | extra_step_kwargs["eta"] = eta
131 |
132 | # check if the scheduler accepts generator
133 | accepts_generator = "generator" in set(
134 | inspect.signature(self.scheduler.step).parameters.keys()
135 | )
136 | if accepts_generator:
137 | extra_step_kwargs["generator"] = generator
138 | return extra_step_kwargs
139 |
140 | def prepare_latents(
141 | self,
142 | batch_size,
143 | num_channels_latents,
144 | width,
145 | height,
146 | video_length,
147 | dtype,
148 | device,
149 | generator,
150 | latents=None,
151 | ):
152 | shape = (
153 | batch_size,
154 | num_channels_latents,
155 | video_length,
156 | height // self.vae_scale_factor,
157 | width // self.vae_scale_factor,
158 | )
159 | if isinstance(generator, list) and len(generator) != batch_size:
160 | raise ValueError(
161 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
162 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
163 | )
164 |
165 | if latents is None:
166 | latents = randn_tensor(
167 | shape, generator=generator, device=device, dtype=dtype
168 | )
169 | else:
170 | latents = latents.to(device)
171 |
172 | # scale the initial noise by the standard deviation required by the scheduler
173 | latents = latents * self.scheduler.init_noise_sigma
174 | return latents
175 |
176 | def _encode_prompt(
177 | self,
178 | prompt,
179 | device,
180 | num_videos_per_prompt,
181 | do_classifier_free_guidance,
182 | negative_prompt,
183 | ):
184 | batch_size = len(prompt) if isinstance(prompt, list) else 1
185 |
186 | text_inputs = self.tokenizer(
187 | prompt,
188 | padding="max_length",
189 | max_length=self.tokenizer.model_max_length,
190 | truncation=True,
191 | return_tensors="pt",
192 | )
193 | text_input_ids = text_inputs.input_ids
194 | untruncated_ids = self.tokenizer(
195 | prompt, padding="longest", return_tensors="pt"
196 | ).input_ids
197 |
198 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
199 | text_input_ids, untruncated_ids
200 | ):
201 | removed_text = self.tokenizer.batch_decode(
202 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
203 | )
204 |
205 | if (
206 | hasattr(self.text_encoder.config, "use_attention_mask")
207 | and self.text_encoder.config.use_attention_mask
208 | ):
209 | attention_mask = text_inputs.attention_mask.to(device)
210 | else:
211 | attention_mask = None
212 |
213 | text_embeddings = self.text_encoder(
214 | text_input_ids.to(device),
215 | attention_mask=attention_mask,
216 | )
217 | text_embeddings = text_embeddings[0]
218 |
219 | # duplicate text embeddings for each generation per prompt, using mps friendly method
220 | bs_embed, seq_len, _ = text_embeddings.shape
221 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
222 | text_embeddings = text_embeddings.view(
223 | bs_embed * num_videos_per_prompt, seq_len, -1
224 | )
225 |
226 | # get unconditional embeddings for classifier free guidance
227 | if do_classifier_free_guidance:
228 | uncond_tokens: List[str]
229 | if negative_prompt is None:
230 | uncond_tokens = [""] * batch_size
231 | elif type(prompt) is not type(negative_prompt):
232 | raise TypeError(
233 |                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
234 | f" {type(prompt)}."
235 | )
236 | elif isinstance(negative_prompt, str):
237 | uncond_tokens = [negative_prompt]
238 | elif batch_size != len(negative_prompt):
239 | raise ValueError(
240 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
241 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
242 | " the batch size of `prompt`."
243 | )
244 | else:
245 | uncond_tokens = negative_prompt
246 |
247 | max_length = text_input_ids.shape[-1]
248 | uncond_input = self.tokenizer(
249 | uncond_tokens,
250 | padding="max_length",
251 | max_length=max_length,
252 | truncation=True,
253 | return_tensors="pt",
254 | )
255 |
256 | if (
257 | hasattr(self.text_encoder.config, "use_attention_mask")
258 | and self.text_encoder.config.use_attention_mask
259 | ):
260 | attention_mask = uncond_input.attention_mask.to(device)
261 | else:
262 | attention_mask = None
263 |
264 | uncond_embeddings = self.text_encoder(
265 | uncond_input.input_ids.to(device),
266 | attention_mask=attention_mask,
267 | )
268 | uncond_embeddings = uncond_embeddings[0]
269 |
270 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
271 | seq_len = uncond_embeddings.shape[1]
272 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
273 | uncond_embeddings = uncond_embeddings.view(
274 | batch_size * num_videos_per_prompt, seq_len, -1
275 | )
276 |
277 | # For classifier free guidance, we need to do two forward passes.
278 | # Here we concatenate the unconditional and text embeddings into a single batch
279 | # to avoid doing two forward passes
280 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
281 |
282 | return text_embeddings
283 |
284 | @torch.no_grad()
285 | def __call__(
286 | self,
287 | ref_image,
288 | pose_images,
289 | width,
290 | height,
291 | video_length,
292 | num_inference_steps,
293 | guidance_scale,
294 | num_images_per_prompt=1,
295 | eta: float = 0.0,
296 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
297 | output_type: Optional[str] = "tensor",
298 | return_dict: bool = True,
299 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
300 | callback_steps: Optional[int] = 1,
301 | **kwargs,
302 | ):
303 | # Default height and width to unet
304 |         height = height or self.denoising_unet.config.sample_size * self.vae_scale_factor
305 |         width = width or self.denoising_unet.config.sample_size * self.vae_scale_factor
306 |
307 | device = self._execution_device
308 |
309 | do_classifier_free_guidance = guidance_scale > 1.0
310 |
311 | # Prepare timesteps
312 | self.scheduler.set_timesteps(num_inference_steps, device=device)
313 | timesteps = self.scheduler.timesteps
314 |
315 | batch_size = 1
316 |
317 | # Prepare clip image embeds
318 | clip_image = self.clip_image_processor.preprocess(
319 | ref_image, return_tensors="pt"
320 | ).pixel_values
321 | clip_image_embeds = self.image_encoder(
322 | clip_image.to(device, dtype=self.image_encoder.dtype)
323 | ).image_embeds
324 | encoder_hidden_states = clip_image_embeds.unsqueeze(1)
325 | uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
326 |
327 | if do_classifier_free_guidance:
328 | encoder_hidden_states = torch.cat(
329 | [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
330 | )
331 | reference_control_writer = ReferenceAttentionControl(
332 | self.reference_unet,
333 | do_classifier_free_guidance=do_classifier_free_guidance,
334 | mode="write",
335 | batch_size=batch_size,
336 | fusion_blocks="full",
337 | )
338 | reference_control_reader = ReferenceAttentionControl(
339 | self.denoising_unet,
340 | do_classifier_free_guidance=do_classifier_free_guidance,
341 | mode="read",
342 | batch_size=batch_size,
343 | fusion_blocks="full",
344 | )
345 |
346 | num_channels_latents = self.denoising_unet.in_channels
347 | latents = self.prepare_latents(
348 | batch_size * num_images_per_prompt,
349 | num_channels_latents,
350 | width,
351 | height,
352 | video_length,
353 | clip_image_embeds.dtype,
354 | device,
355 | generator,
356 | )
357 |
358 | # Prepare extra step kwargs.
359 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
360 |
361 | # Prepare ref image latents
362 | ref_image_tensor = self.ref_image_processor.preprocess(
363 | ref_image, height=height, width=width
364 |         )  # (bs, c, height, width)
365 | ref_image_tensor = ref_image_tensor.to(
366 | dtype=self.vae.dtype, device=self.vae.device
367 | )
368 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
369 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
370 |
371 | # Prepare a list of pose condition images
372 | pose_cond_tensor_list = []
373 | for pose_image in pose_images:
374 | pose_cond_tensor = (
375 | torch.from_numpy(np.array(pose_image.resize((width, height)))) / 255.0
376 | )
377 | pose_cond_tensor = pose_cond_tensor.permute(2, 0, 1).unsqueeze(
378 | 1
379 | ) # (c, 1, h, w)
380 | pose_cond_tensor_list.append(pose_cond_tensor)
381 | pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w)
382 | pose_cond_tensor = pose_cond_tensor.unsqueeze(0)
383 | pose_cond_tensor = pose_cond_tensor.to(
384 | device=device, dtype=self.pose_guider.dtype
385 | )
386 | pose_fea = self.pose_guider(pose_cond_tensor)
387 | pose_fea = (
388 | torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
389 | )
390 |
391 | # denoising loop
392 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
393 | with self.progress_bar(total=num_inference_steps) as progress_bar:
394 | for i, t in enumerate(timesteps):
395 | # 1. Forward reference image
396 | if i == 0:
397 | self.reference_unet(
398 | ref_image_latents.repeat(
399 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
400 | ),
401 | torch.zeros_like(t),
402 | # t,
403 | encoder_hidden_states=encoder_hidden_states,
404 | return_dict=False,
405 | )
406 | reference_control_reader.update(reference_control_writer)
407 |
408 | # 3.1 expand the latents if we are doing classifier free guidance
409 | latent_model_input = (
410 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents
411 | )
412 | latent_model_input = self.scheduler.scale_model_input(
413 | latent_model_input, t
414 | )
415 |
416 | noise_pred = self.denoising_unet(
417 | latent_model_input,
418 | t,
419 | encoder_hidden_states=encoder_hidden_states,
420 | pose_cond_fea=pose_fea,
421 | return_dict=False,
422 | )[0]
423 |
424 | # perform guidance
425 | if do_classifier_free_guidance:
426 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
427 | noise_pred = noise_pred_uncond + guidance_scale * (
428 | noise_pred_text - noise_pred_uncond
429 | )
430 |
431 | # compute the previous noisy sample x_t -> x_t-1
432 | latents = self.scheduler.step(
433 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False
434 | )[0]
435 |
436 | # call the callback, if provided
437 | if i == len(timesteps) - 1 or (
438 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
439 | ):
440 | progress_bar.update()
441 | if callback is not None and i % callback_steps == 0:
442 | step_idx = i // getattr(self.scheduler, "order", 1)
443 | callback(step_idx, t, latents)
444 |
445 | reference_control_reader.clear()
446 | reference_control_writer.clear()
447 |
448 | # Post-processing
449 | images = self.decode_latents(latents) # (b, c, f, h, w)
450 |
451 | # Convert to tensor
452 | if output_type == "tensor":
453 | images = torch.from_numpy(images)
454 |
455 | if not return_dict:
456 | return images
457 |
458 | return Pose2VideoPipelineOutput(videos=images)
459 |
--------------------------------------------------------------------------------
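
To make the constructor and __call__ signature above concrete, here is a minimal usage sketch for Pose2VideoPipeline. It assumes the six components (vae, image_encoder, reference_unet, denoising_unet, pose_guider, scheduler) have already been instantiated and loaded from the pretrained weights elsewhere; the loading code, the paths, and the sampling parameters shown here are illustrative, not prescriptive.

import torch
from PIL import Image

from src.pipelines.pipeline_pose2vid import Pose2VideoPipeline
from src.utils.util import read_frames, save_videos_grid

# vae, image_encoder, reference_unet, denoising_unet, pose_guider and scheduler
# are assumed to be constructed and weight-loaded beforehand (not shown here).
pipe = Pose2VideoPipeline(
    vae=vae,
    image_encoder=image_encoder,
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    pose_guider=pose_guider,
    scheduler=scheduler,
).to("cuda", torch.float16)

ref_image = Image.open("/path/to/reference.png").convert("RGB")  # hypothetical path
pose_images = read_frames("/path/to/pose_kps.mp4")               # hypothetical path

result = pipe(
    ref_image,
    pose_images,
    width=512,
    height=768,
    video_length=len(pose_images),
    num_inference_steps=25,
    guidance_scale=3.5,
)
video = result.videos  # (b, c, f, h, w) float tensor in [0, 1]

save_videos_grid(video, "output/result.mp4", n_rows=1, fps=8)
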
/src/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | tensor_interpolation = None
4 |
5 |
6 | def get_tensor_interpolation_method():
7 | return tensor_interpolation
8 |
9 |
10 | def set_tensor_interpolation_method(is_slerp):
11 | global tensor_interpolation
12 | tensor_interpolation = slerp if is_slerp else linear
13 |
14 |
15 | def linear(v1, v2, t):
16 | return (1.0 - t) * v1 + t * v2
17 |
18 |
19 | def slerp(
20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
21 | ) -> torch.Tensor:
22 | u0 = v0 / v0.norm()
23 | u1 = v1 / v1.norm()
24 | dot = (u0 * u1).sum()
25 | if dot.abs() > DOT_THRESHOLD:
26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
27 | return (1.0 - t) * v0 + t * v1
28 | omega = dot.acos()
29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
30 |
--------------------------------------------------------------------------------
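
A quick illustration of the helpers above: set_tensor_interpolation_method selects between linear and slerp, and downstream code fetches whichever was chosen through get_tensor_interpolation_method. The tensors below are random placeholders.

import torch

from src.pipelines.utils import (get_tensor_interpolation_method,
                                 set_tensor_interpolation_method)

set_tensor_interpolation_method(is_slerp=True)  # pick spherical interpolation
interp = get_tensor_interpolation_method()

v0 = torch.randn(320)
v1 = torch.randn(320)
midpoint = interp(v0, v1, 0.5)  # blend halfway between v0 and v1
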
/src/utils/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import os.path as osp
4 | import shutil
5 | import sys
6 | from pathlib import Path
7 |
8 | import av
9 | import numpy as np
10 | import torch
11 | import torchvision
12 | from einops import rearrange
13 | from PIL import Image
14 |
15 |
16 | def seed_everything(seed):
17 | import random
18 |
19 | import numpy as np
20 |
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 | np.random.seed(seed % (2**32))
24 | random.seed(seed)
25 |
26 |
27 | def import_filename(filename):
28 | spec = importlib.util.spec_from_file_location("mymodule", filename)
29 | module = importlib.util.module_from_spec(spec)
30 | sys.modules[spec.name] = module
31 | spec.loader.exec_module(module)
32 | return module
33 |
34 |
35 | def delete_additional_ckpt(base_path, num_keep):
36 | dirs = []
37 | for d in os.listdir(base_path):
38 | if d.startswith("checkpoint-"):
39 | dirs.append(d)
40 | num_tot = len(dirs)
41 | if num_tot <= num_keep:
42 | return
43 |     # ensure checkpoints are sorted by step and delete the earliest ones
44 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
45 | for d in del_dirs:
46 | path_to_dir = osp.join(base_path, d)
47 | if osp.exists(path_to_dir):
48 | shutil.rmtree(path_to_dir)
49 |
50 |
51 | def save_videos_from_pil(pil_images, path, fps=8):
52 | import av
53 |
54 | save_fmt = Path(path).suffix
55 | os.makedirs(os.path.dirname(path), exist_ok=True)
56 | width, height = pil_images[0].size
57 |
58 | if save_fmt == ".mp4":
59 | codec = "libx264"
60 | container = av.open(path, "w")
61 | stream = container.add_stream(codec, rate=fps)
62 |
63 | stream.width = width
64 | stream.height = height
65 |
66 | for pil_image in pil_images:
67 | # pil_image = Image.fromarray(image_arr).convert("RGB")
68 | av_frame = av.VideoFrame.from_image(pil_image)
69 | container.mux(stream.encode(av_frame))
70 | container.mux(stream.encode())
71 | container.close()
72 |
73 | elif save_fmt == ".gif":
74 | pil_images[0].save(
75 | fp=path,
76 | format="GIF",
77 | append_images=pil_images[1:],
78 | save_all=True,
79 | duration=(1 / fps * 1000),
80 | loop=0,
81 | )
82 | else:
83 | raise ValueError("Unsupported file type. Use .mp4 or .gif.")
84 |
85 |
86 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
87 | videos = rearrange(videos, "b c t h w -> t b c h w")
88 | height, width = videos.shape[-2:]
89 | outputs = []
90 |
91 | for x in videos:
92 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
93 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
94 | if rescale:
95 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
96 | x = (x * 255).numpy().astype(np.uint8)
97 | x = Image.fromarray(x)
98 |
99 | outputs.append(x)
100 |
101 | os.makedirs(os.path.dirname(path), exist_ok=True)
102 |
103 | save_videos_from_pil(outputs, path, fps)
104 |
105 |
106 | def read_frames(video_path):
107 | container = av.open(video_path)
108 |
109 | video_stream = next(s for s in container.streams if s.type == "video")
110 | frames = []
111 | for packet in container.demux(video_stream):
112 | for frame in packet.decode():
113 | image = Image.frombytes(
114 | "RGB",
115 | (frame.width, frame.height),
116 | frame.to_rgb().to_ndarray(),
117 | )
118 | frames.append(image)
119 |     container.close()
120 |     return frames
121 |
122 |
123 | def get_fps(video_path):
124 | container = av.open(video_path)
125 | video_stream = next(s for s in container.streams if s.type == "video")
126 | fps = video_stream.average_rate
127 | container.close()
128 | return fps
129 |
--------------------------------------------------------------------------------
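
The video helpers above compose naturally: get_fps and read_frames decode a clip into a list of PIL images, and save_videos_from_pil re-encodes such a list as .mp4 or .gif. A small sketch with a hypothetical input path:

from src.utils.util import get_fps, read_frames, save_videos_from_pil

video_path = "/path/to/some_video.mp4"  # hypothetical input
fps = get_fps(video_path)
frames = read_frames(video_path)        # list of RGB PIL.Image frames

# keep roughly the first two seconds and write them back out as a GIF preview
preview = frames[: int(fps * 2)]
save_videos_from_pil(preview, "output/preview.gif", fps=8)
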
/tools/download_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path, PurePosixPath
3 |
4 | from huggingface_hub import hf_hub_download
5 |
6 |
7 | def prepare_base_model():
8 | print(f'Preparing base stable-diffusion-v1-5 weights...')
9 | local_dir = "./pretrained_weights/stable-diffusion-v1-5"
10 | os.makedirs(local_dir, exist_ok=True)
11 | for hub_file in ["unet/config.json", "unet/diffusion_pytorch_model.bin"]:
12 | path = Path(hub_file)
13 | saved_path = local_dir / path
14 | if os.path.exists(saved_path):
15 | continue
16 | hf_hub_download(
17 | repo_id="runwayml/stable-diffusion-v1-5",
18 | subfolder=PurePosixPath(path.parent),
19 | filename=PurePosixPath(path.name),
20 | local_dir=local_dir,
21 | )
22 |
23 |
24 | def prepare_image_encoder():
25 | print(f"Preparing image encoder weights...")
26 | local_dir = "./pretrained_weights"
27 | os.makedirs(local_dir, exist_ok=True)
28 | for hub_file in ["image_encoder/config.json", "image_encoder/pytorch_model.bin"]:
29 | path = Path(hub_file)
30 | saved_path = local_dir / path
31 | if os.path.exists(saved_path):
32 | continue
33 | hf_hub_download(
34 | repo_id="lambdalabs/sd-image-variations-diffusers",
35 | subfolder=PurePosixPath(path.parent),
36 | filename=PurePosixPath(path.name),
37 | local_dir=local_dir,
38 | )
39 |
40 |
41 | def prepare_dwpose():
42 | print(f"Preparing DWPose weights...")
43 | local_dir = "./pretrained_weights/DWPose"
44 | os.makedirs(local_dir, exist_ok=True)
45 | for hub_file in [
46 | "dw-ll_ucoco_384.onnx",
47 | "yolox_l.onnx",
48 | ]:
49 | path = Path(hub_file)
50 | saved_path = local_dir / path
51 | if os.path.exists(saved_path):
52 | continue
53 |
54 | hf_hub_download(
55 | repo_id="yzd-v/DWPose",
56 | subfolder=PurePosixPath(path.parent),
57 | filename=PurePosixPath(path.name),
58 | local_dir=local_dir,
59 | )
60 |
61 |
62 | def prepare_vae():
63 | print(f"Preparing vae weights...")
64 | local_dir = "./pretrained_weights/sd-vae-ft-mse"
65 | os.makedirs(local_dir, exist_ok=True)
66 | for hub_file in [
67 | "config.json",
68 | "diffusion_pytorch_model.bin",
69 | ]:
70 | path = Path(hub_file)
71 | saved_path = local_dir / path
72 | if os.path.exists(saved_path):
73 | continue
74 |
75 | hf_hub_download(
76 | repo_id="stabilityai/sd-vae-ft-mse",
77 | subfolder=PurePosixPath(path.parent),
78 | filename=PurePosixPath(path.name),
79 | local_dir=local_dir,
80 | )
81 |
82 |
83 | def prepare_anyone():
84 | print(f"Preparing AnimateAnyone weights...")
85 | local_dir = "./pretrained_weights"
86 | os.makedirs(local_dir, exist_ok=True)
87 | for hub_file in [
88 | "denoising_unet.pth",
89 | "motion_module.pth",
90 | "pose_guider.pth",
91 | "reference_unet.pth",
92 | ]:
93 | path = Path(hub_file)
94 | saved_path = local_dir / path
95 | if os.path.exists(saved_path):
96 | continue
97 |
98 | hf_hub_download(
99 | repo_id="patrolli/AnimateAnyone",
100 | subfolder=PurePosixPath(path.parent),
101 | filename=PurePosixPath(path.name),
102 | local_dir=local_dir,
103 | )
104 |
105 | if __name__ == '__main__':
106 | prepare_base_model()
107 | prepare_image_encoder()
108 | prepare_dwpose()
109 | prepare_vae()
110 | prepare_anyone()
111 |
--------------------------------------------------------------------------------
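
The script is normally run once before inference; each prepare_* function skips files that already exist, so re-running it is cheap. A sketch of invoking it selectively from Python (assuming the repository root is on PYTHONPATH and the Hugging Face Hub is reachable):

# full download: python tools/download_weights.py
from tools.download_weights import prepare_dwpose, prepare_vae

prepare_dwpose()  # -> ./pretrained_weights/DWPose/{dw-ll_ucoco_384.onnx, yolox_l.onnx}
prepare_vae()     # -> ./pretrained_weights/sd-vae-ft-mse/{config.json, diffusion_pytorch_model.bin}
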
/tools/extract_dwpose_from_vid.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import os
3 | import random
4 | from pathlib import Path
5 |
6 | import numpy as np
7 |
8 | from src.dwpose import DWposeDetector
9 | from src.utils.util import get_fps, read_frames, save_videos_from_pil
10 |
11 | # Extract dwpose mp4 videos from raw videos
12 | # /path/to/video_dataset/*/*.mp4 -> /path/to/video_dataset_dwpose/*/*.mp4
13 |
14 |
15 | def process_single_video(video_path, detector, root_dir, save_dir):
16 | relative_path = os.path.relpath(video_path, root_dir)
17 | print(relative_path, video_path, root_dir)
18 | out_path = os.path.join(save_dir, relative_path)
19 | if os.path.exists(out_path):
20 | return
21 |
22 | output_dir = Path(os.path.dirname(os.path.join(save_dir, relative_path)))
23 | if not output_dir.exists():
24 | output_dir.mkdir(parents=True, exist_ok=True)
25 |
26 | fps = get_fps(video_path)
27 | frames = read_frames(video_path)
28 | kps_results = []
29 | for i, frame_pil in enumerate(frames):
30 | result, score = detector(frame_pil)
31 |         score = np.mean(score, axis=-1)  # detection scores are computed but not used below
32 |
33 | kps_results.append(result)
34 |
35 | save_videos_from_pil(kps_results, out_path, fps=fps)
36 |
37 |
38 | def process_batch_videos(video_list, detector, root_dir, save_dir):
39 | for i, video_path in enumerate(video_list):
40 | print(f"Process {i}/{len(video_list)} video")
41 | process_single_video(video_path, detector, root_dir, save_dir)
42 |
43 |
44 | if __name__ == "__main__":
45 | # -----
46 | # NOTE:
47 | # python tools/extract_dwpose_from_vid.py --video_root /path/to/video_dir
48 | # -----
49 | import argparse
50 |
51 | parser = argparse.ArgumentParser()
52 | parser.add_argument("--video_root", type=str)
53 | parser.add_argument(
54 | "--save_dir", type=str, help="Path to save extracted pose videos"
55 | )
56 | parser.add_argument("-j", type=int, default=4, help="Num workers")
57 | args = parser.parse_args()
58 | num_workers = args.j
59 | if args.save_dir is None:
60 | save_dir = args.video_root + "_dwpose"
61 | else:
62 | save_dir = args.save_dir
63 | if not os.path.exists(save_dir):
64 | os.makedirs(save_dir)
65 | cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
66 |     gpu_ids = list(range(len(cuda_visible_devices.split(","))))
67 |     print(f"available gpu ids: {gpu_ids}")
68 |
69 | # collect all video_folder paths
70 | video_mp4_paths = set()
71 | for root, dirs, files in os.walk(args.video_root):
72 | for name in files:
73 | if name.endswith(".mp4"):
74 | video_mp4_paths.add(os.path.join(root, name))
75 | video_mp4_paths = list(video_mp4_paths)
76 | random.shuffle(video_mp4_paths)
77 |
78 |     # split the videos into one chunk per worker
79 | batch_size = (len(video_mp4_paths) + num_workers - 1) // num_workers
80 | print(f"Num videos: {len(video_mp4_paths)} {batch_size = }")
81 | video_chunks = [
82 | video_mp4_paths[i : i + batch_size]
83 | for i in range(0, len(video_mp4_paths), batch_size)
84 | ]
85 |
86 | with concurrent.futures.ThreadPoolExecutor() as executor:
87 | futures = []
88 | for i, chunk in enumerate(video_chunks):
89 | # init detector
90 | gpu_id = gpu_ids[i % len(gpu_ids)]
91 | detector = DWposeDetector()
92 | # torch.cuda.set_device(gpu_id)
93 | detector = detector.to(f"cuda:{gpu_id}")
94 |
95 | futures.append(
96 | executor.submit(
97 | process_batch_videos, chunk, detector, args.video_root, save_dir
98 | )
99 | )
100 | for future in concurrent.futures.as_completed(futures):
101 | future.result()
102 |
--------------------------------------------------------------------------------
/tools/extract_meta_info.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | # -----
6 | # Output: a JSON list of entries like
7 | # [{'video_path': ..., 'kps_path': ...}, {'video_path': ..., 'kps_path': ...}]
8 | # -----
9 | # python tools/extract_meta_info.py --root_path /path/to/video_dir --dataset_name fashion
10 | # -----
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--root_path", type=str)
13 | parser.add_argument("--dataset_name", type=str)
14 | parser.add_argument("--meta_info_name", type=str)
15 |
16 | args = parser.parse_args()
17 |
18 | if args.meta_info_name is None:
19 | args.meta_info_name = args.dataset_name
20 |
21 | pose_dir = args.root_path + "_dwpose"
22 |
23 | # collect all video_folder paths
24 | video_mp4_paths = set()
25 | for root, dirs, files in os.walk(args.root_path):
26 | for name in files:
27 | if name.endswith(".mp4"):
28 | video_mp4_paths.add(os.path.join(root, name))
29 | video_mp4_paths = list(video_mp4_paths)
30 |
31 | meta_infos = []
32 | for video_mp4_path in video_mp4_paths:
33 | relative_video_name = os.path.relpath(video_mp4_path, args.root_path)
34 | kps_path = os.path.join(pose_dir, relative_video_name)
35 | meta_infos.append({"video_path": video_mp4_path, "kps_path": kps_path})
36 | os.makedirs("./data", exist_ok=True)
37 | with open(f"./data/{args.meta_info_name}_meta.json", "w") as fp:
38 |     json.dump(meta_infos, fp)
--------------------------------------------------------------------------------
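
The resulting ./data/<meta_info_name>_meta.json is a flat list of video/pose path pairs. A minimal sketch of reading it back (the fashion name follows the example command above):

import json

# produced by: python tools/extract_meta_info.py --root_path /path/to/video_dir --dataset_name fashion
with open("./data/fashion_meta.json") as fp:
    meta_infos = json.load(fp)

for item in meta_infos[:3]:
    print(item["video_path"], "->", item["kps_path"])
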
/tools/vid2pose.py:
--------------------------------------------------------------------------------
1 | from src.dwpose import DWposeDetector
2 | import os
3 | from pathlib import Path
4 |
5 | from src.utils.util import get_fps, read_frames, save_videos_from_pil
6 | import numpy as np
7 | from src.models.model_util import get_torch_device
8 |
9 |
10 | if __name__ == "__main__":
11 | import argparse
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--video_path", type=str)
15 | args = parser.parse_args()
16 |
17 | if not os.path.exists(args.video_path):
18 |         raise ValueError(f"Path {args.video_path} does not exist")
19 |
20 | dir_path, video_name = (
21 | os.path.dirname(args.video_path),
22 | os.path.splitext(os.path.basename(args.video_path))[0],
23 | )
24 | out_path = os.path.join(dir_path, video_name + "_kps.mp4")
25 |
26 | detector = DWposeDetector()
27 | detector = detector.to(get_torch_device())
28 |
29 | fps = get_fps(args.video_path)
30 | frames = read_frames(args.video_path)
31 | kps_results = []
32 | for i, frame_pil in enumerate(frames):
33 | result, score = detector(frame_pil)
34 |         score = np.mean(score, axis=-1)  # detection scores are computed but not used below
35 |
36 | kps_results.append(result)
37 |
38 | print(out_path)
39 | save_videos_from_pil(kps_results, out_path, fps=fps)
40 |
--------------------------------------------------------------------------------