├── LICENSE
├── README.md
├── asset
│   ├── Aragaki.mp4
│   ├── cxk.mp4
│   ├── jijin.mp4
│   ├── kara.mp4
│   ├── lyl.mp4
│   ├── num18.mp4
│   ├── pipeline.png
│   ├── solo.mp4
│   ├── zhiji_logo.png
│   └── zl.mp4
├── configs
│   ├── inference
│   │   ├── audio
│   │   │   └── lyl.wav
│   │   ├── head_pose_temp
│   │   │   ├── pose_ref_video.mp4
│   │   │   └── pose_temp.npy
│   │   ├── inference_audio.yaml
│   │   ├── inference_v1.yaml
│   │   ├── inference_v2.yaml
│   │   ├── pose_videos
│   │   │   └── solo_pose.mp4
│   │   ├── ref_images
│   │   │   ├── Aragaki.png
│   │   │   ├── lyl.png
│   │   │   └── solo.png
│   │   └── video
│   │       └── Aragaki_song.mp4
│   ├── prompts
│   │   ├── animation.yaml
│   │   ├── animation_audio.yaml
│   │   ├── animation_facereenac.yaml
│   │   └── test_cases.py
│   └── train
│       ├── stage1.yaml
│       └── stage2.yaml
├── pretrained_model
│   └── Put pre-trained weights here.txt
├── requirements.txt
├── scripts
│   ├── app.py
│   ├── audio2vid.py
│   ├── generate_ref_pose.py
│   ├── pose2vid.py
│   ├── prepare_video.py
│   ├── preprocess_dataset.py
│   ├── vid2pose.py
│   └── vid2vid.py
├── src
│   ├── audio_models
│   │   ├── mish.py
│   │   ├── model.py
│   │   ├── pose_model.py
│   │   ├── torch_utils.py
│   │   └── wav2vec2.py
│   ├── dataset
│   │   └── dataset_face.py
│   ├── models
│   │   ├── attention.py
│   │   ├── motion_module.py
│   │   ├── mutual_self_attention.py
│   │   ├── pose_guider.py
│   │   ├── resnet.py
│   │   ├── transformer_2d.py
│   │   ├── transformer_3d.py
│   │   ├── unet_2d_blocks.py
│   │   ├── unet_2d_condition.py
│   │   ├── unet_3d.py
│   │   └── unet_3d_blocks.py
│   ├── pipelines
│   │   ├── context.py
│   │   ├── pipeline_pose2img.py
│   │   ├── pipeline_pose2vid.py
│   │   ├── pipeline_pose2vid_long.py
│   │   └── utils.py
│   └── utils
│       ├── audio_util.py
│       ├── draw_util.py
│       ├── face_landmark.py
│       ├── frame_interpolation.py
│       ├── mp_models
│       │   ├── blaze_face_short_range.tflite
│       │   ├── face_landmarker_v2_with_blendshapes.task
│       │   └── pose_landmarker_heavy.task
│       ├── mp_utils.py
│       ├── pose_util.py
│       └── util.py
├── train_stage_1.py
└── train_stage_2.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AniPortrait
2 |
3 | **AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations**
4 |
5 | Authors: Huawei Wei, Zejun Yang, Zhisheng Wang
6 |
7 | Organization: Tencent Games Zhiji, Tencent
8 |
9 | 
10 |
11 | Here we propose AniPortrait, a novel framework for generating high-quality animation driven by
12 | audio and a reference portrait image. You can also provide a video to achieve face reenactment.
13 |
14 |
15 |
16 |
17 |
18 | ## Pipeline
19 |
20 | ![pipeline](./asset/pipeline.png)
21 |
22 | ## Updates / TODO List
23 |
24 | - ✅ [2024/03/27] Now our paper is available on arXiv.
25 |
26 | - ✅ [2024/03/27] Update the code to generate pose_temp.npy for head pose control.
27 |
28 | - ✅ [2024/04/02] Update a new pose retargeting strategy for vid2vid. Now we support a substantial pose difference between the reference image and the source video.
29 |
30 | - ✅ [2024/04/03] We release our Gradio [demo](https://huggingface.co/spaces/ZJYang/AniPortrait_official) on HuggingFace Spaces (thanks to the HF team for their free GPU support)!
31 |
32 | - ✅ [2024/04/07] Update a frame interpolation module to accelerate the inference process. Now you can add `-acc` to the inference commands for faster video generation.
33 |
34 | - ✅ [2024/04/21] We have released the audio2pose model and [pre-trained weight](https://huggingface.co/ZJYang/AniPortrait/tree/main) for audio2video. Please update the code and download the weight file to try it out.
35 |
36 | ## Various Generated Videos
37 |
38 | ### Self driven
39 |
51 | ### Face reenactment
52 |
64 | Video Source: [鹿火CAVY from bilibili](https://www.bilibili.com/video/BV1H4421F7dE/?spm_id_from=333.337.search-card.all.click)
65 |
66 | ### Audio driven
67 |
88 | ## Installation
89 |
90 | ### Build environment
91 |
92 | We recommend Python >= 3.10 and CUDA 11.7. Then build the environment as follows:
93 |
94 | ```shell
95 | pip install -r requirements.txt
96 | ```
97 |
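If you prefer an isolated environment, a minimal setup sketch (assuming you use conda; the environment name `aniportrait` is arbitrary) could look like:

```shell
# Create and activate a fresh environment with the recommended Python version
conda create -n aniportrait python=3.10 -y
conda activate aniportrait

# Install the pinned dependencies (torch==2.0.1, xformers==0.0.22, etc.)
pip install -r requirements.txt
```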
98 | ### Download weights
99 |
100 | All the weights should be placed under the `./pretrained_weights` directory. You can download the weights manually as follows:
101 |
102 | 1. Download our trained [weights](https://huggingface.co/ZJYang/AniPortrait/tree/main), which include the following parts: `denoising_unet.pth`, `reference_unet.pth`, `pose_guider.pth`, `motion_module.pth`, `audio2mesh.pt`, `audio2pose.pt` and `film_net_fp16.pt`. You can also download from [wisemodel](https://wisemodel.cn/models/zjyang8510/AniPortrait).
103 |
104 | 2. Download the pretrained weights of the base models and other components (see the download sketch after this list):
105 | - [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5)
106 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
107 | - [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder)
108 | - [wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h)
109 |
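As a rough sketch of one way to fetch these base components (not part of the official instructions; it assumes a recent `huggingface_hub` that provides `huggingface-cli download` and that the linked repos are still available):

```shell
huggingface-cli download runwayml/stable-diffusion-v1-5 --local-dir ./pretrained_weights/stable-diffusion-v1-5
huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir ./pretrained_weights/sd-vae-ft-mse
# Only the image_encoder subfolder of sd-image-variations-diffusers is needed
huggingface-cli download lambdalabs/sd-image-variations-diffusers --include "image_encoder/*" --local-dir ./pretrained_weights
huggingface-cli download facebook/wav2vec2-base-960h --local-dir ./pretrained_weights/wav2vec2-base-960h
```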
110 | Finally, these weights should be organized as follows:
111 |
112 | ```text
113 | ./pretrained_weights/
114 | |-- image_encoder
115 | | |-- config.json
116 | | `-- pytorch_model.bin
117 | |-- sd-vae-ft-mse
118 | | |-- config.json
119 | | |-- diffusion_pytorch_model.bin
120 | | `-- diffusion_pytorch_model.safetensors
121 | |-- stable-diffusion-v1-5
122 | | |-- feature_extractor
123 | | | `-- preprocessor_config.json
124 | | |-- model_index.json
125 | | |-- unet
126 | | | |-- config.json
127 | | | `-- diffusion_pytorch_model.bin
128 | | `-- v1-inference.yaml
129 | |-- wav2vec2-base-960h
130 | | |-- config.json
131 | | |-- feature_extractor_config.json
132 | | |-- preprocessor_config.json
133 | | |-- pytorch_model.bin
134 | | |-- README.md
135 | | |-- special_tokens_map.json
136 | | |-- tokenizer_config.json
137 | | `-- vocab.json
138 | |-- audio2mesh.pt
139 | |-- audio2pose.pt
140 | |-- denoising_unet.pth
141 | |-- film_net_fp16.pt
142 | |-- motion_module.pth
143 | |-- pose_guider.pth
144 | `-- reference_unet.pth
145 | ```
146 |
147 | Note: If you have already downloaded some of the pretrained models, such as `StableDiffusion V1.5`, you can specify their paths in the config file (e.g. `./configs/prompts/animation.yaml`).
148 |
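For instance, if you already keep Stable Diffusion 1.5 and its companions somewhere else on disk, you could point the corresponding keys in `./configs/prompts/animation.yaml` at those locations (the paths below are placeholders, not shipped defaults):

```yaml
pretrained_base_model_path: '/path/to/your/stable-diffusion-v1-5'
pretrained_vae_path: '/path/to/your/sd-vae-ft-mse'
image_encoder_path: '/path/to/your/image_encoder'
```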
149 |
150 | ## Gradio Web UI
151 |
152 | You can try out our web demo with the following command. We also provide an online demo on Hugging Face Spaces.
153 |
154 |
155 | ```shell
156 | python -m scripts.app
157 | ```
158 |
159 | ## Inference
160 |
161 | Kindly note that you can set `-L` to the desired number of generated frames in the command, for example, `-L 300`.
162 |
163 | **Acceleration method**: If it takes a long time to generate a video, you can download [film_net_fp16.pt](https://huggingface.co/ZJYang/AniPortrait/tree/main) and put it under the `./pretrained_weights` directory. Then add `-acc` to the command (see the combined example below).
164 |
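For example, a command that combines both options with one of the inference scripts described below might look like:

```shell
# 300 generated frames, with frame-interpolation acceleration enabled
python -m scripts.audio2vid --config ./configs/prompts/animation_audio.yaml -W 512 -H 512 -L 300 -acc
```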
165 | Here are the CLI commands for running the inference scripts:
166 |
167 | ### Self driven
168 |
169 | ```shell
170 | python -m scripts.pose2vid --config ./configs/prompts/animation.yaml -W 512 -H 512 -acc
171 | ```
172 |
173 | You can refer to the format of `animation.yaml` to add your own reference images or pose videos. To convert a raw video into a pose video (keypoint sequence), run the following command:
174 |
175 | ```shell
176 | python -m scripts.vid2pose --video_path pose_video_path.mp4
177 | ```
178 |
179 | ### Face reenactment
180 |
181 | ```shell
182 | python -m scripts.vid2vid --config ./configs/prompts/animation_facereenac.yaml -W 512 -H 512 -acc
183 | ```
184 |
185 | Add source face videos and reference images in `animation_facereenac.yaml`.
186 |
187 | ### Audio driven
188 |
189 | ```shell
190 | python -m scripts.audio2vid --config ./configs/prompts/animation_audio.yaml -W 512 -H 512 -acc
191 | ```
192 |
193 | Add audio files and reference images in `animation_audio.yaml`.
194 |
195 | Deleting `pose_temp` in `./configs/prompts/animation_audio.yaml` enables the audio2pose model.
196 |
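For reference, the head pose template is controlled by a single optional key in `./configs/prompts/animation_audio.yaml`, which ships commented out so the audio2pose model is used:

```yaml
# path of your custom head pose template
# pose_temp: "./configs/inference/head_pose_temp/pose_temp.npy"
```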
197 | You can also use the following command to generate a `pose_temp.npy` file for head pose control:
198 |
199 | ```shell
200 | python -m scripts.generate_ref_pose --ref_video ./configs/inference/head_pose_temp/pose_ref_video.mp4 --save_path ./configs/inference/head_pose_temp/pose.npy
201 | ```
202 |
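To use the generated file for head pose control, point `pose_temp` in `./configs/prompts/animation_audio.yaml` at the path you passed to `--save_path`, for example:

```yaml
pose_temp: "./configs/inference/head_pose_temp/pose.npy"
```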
203 | ## Training
204 |
205 | ### Data preparation
206 | Download [VFHQ](https://liangbinxie.github.io/projects/vfhq/) and [CelebV-HQ](https://github.com/CelebV-HQ/CelebV-HQ).
207 |
208 | Extract keypoints from the raw videos and write the training JSON file (here is an example of processing VFHQ):
209 |
210 | ```shell
211 | python -m scripts.preprocess_dataset --input_dir VFHQ_PATH --output_dir SAVE_PATH --training_json JSON_PATH
212 | ```
213 |
214 | Update the following lines in the training config file:
215 |
216 | ```yaml
217 | data:
218 | json_path: JSON_PATH
219 | ```
220 |
221 | ### Stage1
222 |
223 | Run command:
224 |
225 | ```shell
226 | accelerate launch train_stage_1.py --config ./configs/train/stage1.yaml
227 | ```
228 |
229 | ### Stage2
230 |
231 | Put the pretrained motion module weights `mm_sd_v15_v2.ckpt` ([download link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt)) under `./pretrained_weights`.
232 |
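A hedged download sketch (again assuming a recent `huggingface_hub`; note that the shipped `./configs/train/stage2.yaml` reads the checkpoint from `mm_path: './pretrained_model/mm_sd_v15_v2.ckpt'`, so place the file wherever your config points):

```shell
huggingface-cli download guoyww/animatediff mm_sd_v15_v2.ckpt --local-dir ./pretrained_weights
```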
233 | Specify the stage1 training weights in the config file `stage2.yaml`, for example:
234 |
235 | ```yaml
236 | stage1_ckpt_dir: './exp_output/stage1'
237 | stage1_ckpt_step: 30000
238 | ```
239 |
240 | Run command:
241 |
242 | ```shell
243 | accelerate launch train_stage_2.py --config ./configs/train/stage2.yaml
244 | ```
245 |
246 | ## Acknowledgements
247 |
248 | We first thank the authors of [EMO](https://github.com/HumanAIGC/EMO); some of the images and audio clips in our demos are from EMO. Additionally, we would like to thank the contributors to the [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone), [magic-animate](https://github.com/magic-research/magic-animate), [AnimateDiff](https://github.com/guoyww/AnimateDiff) and [Open-AnimateAnyone](https://github.com/guoqincode/Open-AnimateAnyone) repositories for their open research and exploration.
249 |
250 | ## Citation
251 |
252 | ```
253 | @misc{wei2024aniportrait,
254 | title={AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations},
255 | author={Huawei Wei and Zejun Yang and Zhisheng Wang},
256 | year={2024},
257 | eprint={2403.17694},
258 | archivePrefix={arXiv},
259 | primaryClass={cs.CV}
260 | }
261 | ```
--------------------------------------------------------------------------------
/asset/Aragaki.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/Aragaki.mp4
--------------------------------------------------------------------------------
/asset/cxk.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/cxk.mp4
--------------------------------------------------------------------------------
/asset/jijin.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/jijin.mp4
--------------------------------------------------------------------------------
/asset/kara.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/kara.mp4
--------------------------------------------------------------------------------
/asset/lyl.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/lyl.mp4
--------------------------------------------------------------------------------
/asset/num18.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/num18.mp4
--------------------------------------------------------------------------------
/asset/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/pipeline.png
--------------------------------------------------------------------------------
/asset/solo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/solo.mp4
--------------------------------------------------------------------------------
/asset/zhiji_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/zhiji_logo.png
--------------------------------------------------------------------------------
/asset/zl.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/zl.mp4
--------------------------------------------------------------------------------
/configs/inference/audio/lyl.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/audio/lyl.wav
--------------------------------------------------------------------------------
/configs/inference/head_pose_temp/pose_ref_video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/head_pose_temp/pose_ref_video.mp4
--------------------------------------------------------------------------------
/configs/inference/head_pose_temp/pose_temp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/head_pose_temp/pose_temp.npy
--------------------------------------------------------------------------------
/configs/inference/inference_audio.yaml:
--------------------------------------------------------------------------------
1 | a2m_model:
2 | out_dim: 1404
3 | latent_dim: 512
4 | model_path: ./pretrained_model/wav2vec2-base-960h
5 | only_last_fetures: True
6 | from_pretrained: True
7 |
8 | a2p_model:
9 | out_dim: 6
10 | latent_dim: 512
11 | model_path: ./pretrained_model/wav2vec2-base-960h
12 | only_last_fetures: True
13 | from_pretrained: True
14 |
15 | pretrained_model:
16 | a2m_ckpt: ./pretrained_model/audio2mesh.pt
17 | a2p_ckpt: ./pretrained_model/audio2pose.pt
18 |
--------------------------------------------------------------------------------
/configs/inference/inference_v1.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | unet_use_cross_frame_attention: false
3 | unet_use_temporal_attention: false
4 | use_motion_module: true
5 | motion_module_resolutions: [1,2,4,8]
6 | motion_module_mid_block: false
7 | motion_module_decoder_only: false
8 | motion_module_type: "Vanilla"
9 |
10 | motion_module_kwargs:
11 | num_attention_heads: 8
12 | num_transformer_block: 1
13 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
14 | temporal_position_encoding: true
15 | temporal_position_encoding_max_len: 24
16 | temporal_attention_dim_div: 1
17 |
18 | noise_scheduler_kwargs:
19 | beta_start: 0.00085
20 | beta_end: 0.012
21 | beta_schedule: "linear"
22 | steps_offset: 1
23 | clip_sample: False
--------------------------------------------------------------------------------
/configs/inference/inference_v2.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_inflated_groupnorm: true
3 | unet_use_cross_frame_attention: false
4 | unet_use_temporal_attention: false
5 | use_motion_module: true
6 | motion_module_resolutions:
7 | - 1
8 | - 2
9 | - 4
10 | - 8
11 | motion_module_mid_block: true
12 | motion_module_decoder_only: false
13 | motion_module_type: Vanilla
14 | motion_module_kwargs:
15 | num_attention_heads: 8
16 | num_transformer_block: 1
17 | attention_block_types:
18 | - Temporal_Self
19 | - Temporal_Self
20 | temporal_position_encoding: true
21 | temporal_position_encoding_max_len: 32
22 | temporal_attention_dim_div: 1
23 |
24 | noise_scheduler_kwargs:
25 | beta_start: 0.00085
26 | beta_end: 0.012
27 | beta_schedule: "linear"
28 | clip_sample: false
29 | steps_offset: 1
30 | ### Zero-SNR params
31 | prediction_type: "v_prediction"
32 | rescale_betas_zero_snr: True
33 | timestep_spacing: "trailing"
34 |
35 | sampler: DDIM
--------------------------------------------------------------------------------
/configs/inference/pose_videos/solo_pose.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/pose_videos/solo_pose.mp4
--------------------------------------------------------------------------------
/configs/inference/ref_images/Aragaki.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/ref_images/Aragaki.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/lyl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/ref_images/lyl.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/solo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/ref_images/solo.png
--------------------------------------------------------------------------------
/configs/inference/video/Aragaki_song.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/video/Aragaki_song.mp4
--------------------------------------------------------------------------------
/configs/prompts/animation.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: './pretrained_model/stable-diffusion-v1-5'
2 | pretrained_vae_path: './pretrained_model/sd-vae-ft-mse'
3 | image_encoder_path: './pretrained_model/image_encoder'
4 |
5 | denoising_unet_path: "./pretrained_model/denoising_unet.pth"
6 | reference_unet_path: "./pretrained_model/reference_unet.pth"
7 | pose_guider_path: "./pretrained_model/pose_guider.pth"
8 | motion_module_path: "./pretrained_model/motion_module.pth"
9 |
10 | inference_config: "./configs/inference/inference_v2.yaml"
11 | weight_dtype: 'fp16'
12 |
13 | test_cases:
14 | "./configs/inference/ref_images/solo.png":
15 | - "./configs/inference/pose_videos/solo_pose.mp4"
16 |
--------------------------------------------------------------------------------
/configs/prompts/animation_audio.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: './pretrained_model/stable-diffusion-v1-5'
2 | pretrained_vae_path: './pretrained_model/sd-vae-ft-mse'
3 | image_encoder_path: './pretrained_model/image_encoder'
4 |
5 | denoising_unet_path: "./pretrained_model/denoising_unet.pth"
6 | reference_unet_path: "./pretrained_model/reference_unet.pth"
7 | pose_guider_path: "./pretrained_model/pose_guider.pth"
8 | motion_module_path: "./pretrained_model/motion_module.pth"
9 |
10 | audio_inference_config: "./configs/inference/inference_audio.yaml"
11 | inference_config: "./configs/inference/inference_v2.yaml"
12 | weight_dtype: 'fp16'
13 |
14 | # path of your custom head pose template
15 | # pose_temp: "./configs/inference/head_pose_temp/pose_temp.npy"
16 |
17 | test_cases:
18 | "./configs/inference/ref_images/lyl.png":
19 | - "./configs/inference/audio/lyl.wav"
20 |
--------------------------------------------------------------------------------
/configs/prompts/animation_facereenac.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: './pretrained_model/stable-diffusion-v1-5'
2 | pretrained_vae_path: './pretrained_model/sd-vae-ft-mse'
3 | image_encoder_path: './pretrained_model/image_encoder'
4 |
5 | denoising_unet_path: "./pretrained_model/denoising_unet.pth"
6 | reference_unet_path: "./pretrained_model/reference_unet.pth"
7 | pose_guider_path: "./pretrained_model/pose_guider.pth"
8 | motion_module_path: "./pretrained_model/motion_module.pth"
9 |
10 | inference_config: "./configs/inference/inference_v2.yaml"
11 | weight_dtype: 'fp16'
12 |
13 | test_cases:
14 | "./configs/inference/ref_images/Aragaki.png":
15 | - "./configs/inference/video/Aragaki_song.mp4"
16 |
--------------------------------------------------------------------------------
/configs/prompts/test_cases.py:
--------------------------------------------------------------------------------
1 | TestCasesDict = {
2 | 0: [
3 | {
4 | "./configs/inference/ref_images/Aragaki.png": [
5 | "./configs/inference/pose_videos/Aragaki_pose.mp4",
6 | "./configs/inference/pose_videos/solo_pose.mp4",
7 | ]
8 | },
9 | {
10 | "./configs/inference/ref_images/solo.png": [
11 | "./configs/inference/pose_videos/solo_pose.mp4",
12 | "./configs/inference/pose_videos/Aragaki_pose.mp4",
13 | ]
14 | },
15 | ],
16 | }
17 |
--------------------------------------------------------------------------------
/configs/train/stage1.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | json_path: "/data/VFHQ/training_data.json"
3 | sample_size: [512, 512]
4 | sample_stride: 4
5 | sample_n_frames: 16
6 |
7 | solver:
8 | gradient_accumulation_steps: 1
9 | mixed_precision: 'fp16'
10 | enable_xformers_memory_efficient_attention: True
11 | gradient_checkpointing: False
12 | max_train_steps: 300000
13 | max_grad_norm: 1.0
14 | # lr
15 | learning_rate: 1.0e-5
16 | scale_lr: False
17 | lr_warmup_steps: 1
18 | lr_scheduler: 'constant'
19 |
20 | # optimizer
21 | use_8bit_adam: True
22 | adam_beta1: 0.9
23 | adam_beta2: 0.999
24 | adam_weight_decay: 1.0e-2
25 | adam_epsilon: 1.0e-8
26 |
27 | val:
28 | validation_steps: 500
29 | validation_steps_tuple: [1, 10, 50, 100, 200, 500]
30 |
31 |
32 | noise_scheduler_kwargs:
33 | num_train_timesteps: 1000
34 | beta_start: 0.00085
35 | beta_end: 0.012
36 | beta_schedule: "scaled_linear"
37 | steps_offset: 1
38 | clip_sample: false
39 |
40 | base_model_path: './pretrained_model/stable-diffusion-v1-5'
41 | vae_model_path: './pretrained_model/sd-vae-ft-mse'
42 | image_encoder_path: './pretrained_model/image_encoder'
43 | controlnet_openpose_path: ''
44 |
45 | train_bs: 2
46 |
47 | weight_dtype: 'fp16' # [fp16, fp32]
48 | uncond_ratio: 0.1
49 | noise_offset: 0.05
50 | snr_gamma: 5.0
51 | enable_zero_snr: True
52 | pose_guider_pretrain: True
53 |
54 | seed: 12580
55 | resume_from_checkpoint: ''
56 | checkpointing_steps: 2000
57 | save_model_epoch_interval: 5
58 | exp_name: 'stage1'
59 | output_dir: './exp_output'
--------------------------------------------------------------------------------
/configs/train/stage2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | json_path: "/data/VFHQ/training_data.json"
3 | sample_size: [512, 512]
4 | sample_stride: 1
5 | sample_n_frames: 16
6 | sample_stride_aug: True
7 |
8 |
9 | solver:
10 | gradient_accumulation_steps: 1
11 | mixed_precision: 'fp16'
12 | enable_xformers_memory_efficient_attention: True
13 | gradient_checkpointing: True
14 | max_train_steps: 40000
15 | max_grad_norm: 1.0
16 | # lr
17 | learning_rate: 1e-5
18 | scale_lr: False
19 | lr_warmup_steps: 1
20 | lr_scheduler: 'constant'
21 |
22 | # optimizer
23 | use_8bit_adam: True
24 | adam_beta1: 0.9
25 | adam_beta2: 0.999
26 | adam_weight_decay: 1.0e-2
27 | adam_epsilon: 1.0e-8
28 |
29 | val:
30 | validation_steps: 500
31 | validation_steps_tuple: [1, 10, 20, 50, 80]
32 |
33 |
34 | noise_scheduler_kwargs:
35 | num_train_timesteps: 1000
36 | beta_start: 0.00085
37 | beta_end: 0.012
38 | beta_schedule: "linear"
39 | steps_offset: 1
40 | clip_sample: false
41 |
42 | base_model_path: './pretrained_model/stable-diffusion-v1-5'
43 | vae_model_path: './pretrained_model/sd-vae-ft-mse'
44 | image_encoder_path: './pretrained_model/image_encoder'
45 | mm_path: './pretrained_model/mm_sd_v15_v2.ckpt'
46 |
47 | train_bs: 1
48 |
49 | weight_dtype: 'fp16' # [fp16, fp32]
50 | uncond_ratio: 0.1
51 | noise_offset: 0.05
52 | snr_gamma: 5.0
53 | enable_zero_snr: True
54 | stage1_ckpt_dir: './exp_output/stage1'
55 | stage1_ckpt_step: 300000
56 |
57 | seed: 12580
58 | resume_from_checkpoint: ''
59 | checkpointing_steps: 2000
60 | exp_name: 'stage2'
61 | output_dir: './exp_output'
--------------------------------------------------------------------------------
/pretrained_model/Put pre-trained weights here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/pretrained_model/Put pre-trained weights here.txt
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.21.0
2 | av==11.0.0
3 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a
4 | decord==0.6.0
5 | diffusers==0.24.0
6 | einops==0.4.1
7 | gradio==4.24.0
8 | gradio_client==0.14.0
9 | imageio==2.33.0
10 | imageio-ffmpeg==0.4.9
11 | numpy==1.24.4
12 | omegaconf==2.2.3
13 | onnxruntime-gpu==1.16.3
14 | open-clip-torch==2.20.0
15 | opencv-contrib-python==4.8.1.78
16 | opencv-python==4.8.1.78
17 | Pillow==9.5.0
18 | scikit-image==0.21.0
19 | scikit-learn==1.3.2
20 | scipy==1.11.4
21 | torch==2.0.1
22 | torchdiffeq==0.2.3
23 | torchmetrics==1.2.1
24 | torchsde==0.2.5
25 | torchvision==0.15.2
26 | tqdm==4.66.1
27 | transformers==4.30.2
28 | xformers==0.0.22
29 | controlnet-aux==0.0.7
30 | mediapipe==0.10.11
31 | librosa==0.9.2
32 | ffmpeg-python==0.2.0
--------------------------------------------------------------------------------
/scripts/audio2vid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import ffmpeg
4 | import random
5 | from datetime import datetime
6 | from pathlib import Path
7 | from typing import List
8 | import subprocess
9 | import av
10 | import numpy as np
11 | import cv2
12 | import torch
13 | import torchvision
14 | from diffusers import AutoencoderKL, DDIMScheduler
15 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
16 | from einops import repeat
17 | from omegaconf import OmegaConf
18 | from PIL import Image
19 | from torchvision import transforms
20 | from transformers import CLIPVisionModelWithProjection
21 |
22 | from configs.prompts.test_cases import TestCasesDict
23 | from src.models.pose_guider import PoseGuider
24 | from src.models.unet_2d_condition import UNet2DConditionModel
25 | from src.models.unet_3d import UNet3DConditionModel
26 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
27 | from src.utils.util import get_fps, read_frames, save_videos_grid
28 |
29 | from src.audio_models.model import Audio2MeshModel
30 | from src.audio_models.pose_model import Audio2PoseModel
31 | from src.utils.audio_util import prepare_audio_feature
32 | from src.utils.mp_utils import LMKExtractor
33 | from src.utils.draw_util import FaceMeshVisualizer
34 | from src.utils.pose_util import project_points, smooth_pose_seq
35 | from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
36 |
37 |
38 | def parse_args():
39 | parser = argparse.ArgumentParser()
40 | parser.add_argument("--config", type=str, default='./configs/prompts/animation_audio.yaml')
41 | parser.add_argument("-W", type=int, default=512)
42 | parser.add_argument("-H", type=int, default=512)
43 | parser.add_argument("-L", type=int)
44 | parser.add_argument("--seed", type=int, default=42)
45 | parser.add_argument("--cfg", type=float, default=3.5)
46 | parser.add_argument("--steps", type=int, default=25)
47 | parser.add_argument("--fps", type=int, default=30)
48 | parser.add_argument("-acc", "--accelerate", action='store_true')
49 | parser.add_argument("--fi_step", type=int, default=3)
50 | args = parser.parse_args()
51 |
52 | return args
53 |
54 | def main():
55 | args = parse_args()
56 |
57 | config = OmegaConf.load(args.config)
58 |
59 | if config.weight_dtype == "fp16":
60 | weight_dtype = torch.float16
61 | else:
62 | weight_dtype = torch.float32
63 |
64 | audio_infer_config = OmegaConf.load(config.audio_inference_config)
65 | # prepare model
66 | a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
67 | a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
68 | a2m_model.cuda().eval()
69 |
70 | a2p_model = Audio2PoseModel(audio_infer_config['a2p_model'])
71 | a2p_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2p_ckpt']), strict=False)
72 | a2p_model.cuda().eval()
73 |
74 | vae = AutoencoderKL.from_pretrained(
75 | config.pretrained_vae_path,
76 | ).to("cuda", dtype=weight_dtype)
77 |
78 | reference_unet = UNet2DConditionModel.from_pretrained(
79 | config.pretrained_base_model_path,
80 | subfolder="unet",
81 | ).to(dtype=weight_dtype, device="cuda")
82 |
83 | inference_config_path = config.inference_config
84 | infer_config = OmegaConf.load(inference_config_path)
85 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
86 | config.pretrained_base_model_path,
87 | config.motion_module_path,
88 | subfolder="unet",
89 | unet_additional_kwargs=infer_config.unet_additional_kwargs,
90 | ).to(dtype=weight_dtype, device="cuda")
91 |
92 |
93 | pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype)  # pose guider with cross-attention enabled
94 |
95 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
96 | config.image_encoder_path
97 | ).to(dtype=weight_dtype, device="cuda")
98 |
99 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
100 | scheduler = DDIMScheduler(**sched_kwargs)
101 |
102 | generator = torch.manual_seed(args.seed)
103 |
104 | width, height = args.W, args.H
105 |
106 | # load pretrained weights
107 | denoising_unet.load_state_dict(
108 | torch.load(config.denoising_unet_path, map_location="cpu"),
109 | strict=False,
110 | )
111 | reference_unet.load_state_dict(
112 | torch.load(config.reference_unet_path, map_location="cpu"),
113 | )
114 | pose_guider.load_state_dict(
115 | torch.load(config.pose_guider_path, map_location="cpu"),
116 | )
117 |
118 | pipe = Pose2VideoPipeline(
119 | vae=vae,
120 | image_encoder=image_enc,
121 | reference_unet=reference_unet,
122 | denoising_unet=denoising_unet,
123 | pose_guider=pose_guider,
124 | scheduler=scheduler,
125 | )
126 | pipe = pipe.to("cuda", dtype=weight_dtype)
127 |
128 | date_str = datetime.now().strftime("%Y%m%d")
129 | time_str = datetime.now().strftime("%H%M")
130 | save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}"
131 |
132 | save_dir = Path(f"output/{date_str}/{save_dir_name}")
133 | save_dir.mkdir(exist_ok=True, parents=True)
134 |
135 |
136 | lmk_extractor = LMKExtractor()
137 | vis = FaceMeshVisualizer(forehead_edge=False)
138 |
139 | if args.accelerate:
140 | frame_inter_model = init_frame_interpolation_model()
141 |
142 | for ref_image_path in config["test_cases"].keys():
143 | # Each ref_image may correspond to multiple actions
144 | for audio_path in config["test_cases"][ref_image_path]:
145 | ref_name = Path(ref_image_path).stem
146 | audio_name = Path(audio_path).stem
147 |
148 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
149 | ref_image_np = cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR)
150 | ref_image_np = cv2.resize(ref_image_np, (args.W, args.H))  # cv2.resize expects (width, height)
151 |
152 | face_result = lmk_extractor(ref_image_np)
153 | assert face_result is not None, "No face detected."
154 | lmks = face_result['lmks'].astype(np.float32)
155 | ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
156 |
157 | sample = prepare_audio_feature(audio_path, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
158 | sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
159 | sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
160 |
161 | # inference
162 | pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
163 | pred = pred.squeeze().detach().cpu().numpy()
164 | pred = pred.reshape(pred.shape[0], -1, 3)
165 | pred = pred + face_result['lmks3d']
166 |
167 | if 'pose_temp' in config and config['pose_temp'] is not None:
168 | pose_seq = np.load(config['pose_temp'])
169 | mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
170 | pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
171 | else:
172 | id_seed = random.randint(0, 99)
173 | id_seed = torch.LongTensor([id_seed]).cuda()
174 |
175 | # Currently, only inference up to a maximum length of 10 seconds is supported.
176 | chunk_duration = 5 # 5 seconds
177 | sr = 16000
178 | fps = 30
179 | chunk_size = sr * chunk_duration
180 |
181 | audio_chunks = list(sample['audio_feature'].split(chunk_size, dim=1))
182 | seq_len_list = [chunk_duration*fps] * (len(audio_chunks) - 1) + [sample['seq_len'] % (chunk_duration*fps)] # 30 fps
183 |
184 | audio_chunks[-2] = torch.cat((audio_chunks[-2], audio_chunks[-1]), dim=1)
185 | seq_len_list[-2] = seq_len_list[-2] + seq_len_list[-1]
186 | del audio_chunks[-1]
187 | del seq_len_list[-1]
188 |
189 | pose_seq = []
190 | for audio, seq_len in zip(audio_chunks, seq_len_list):
191 | pose_seq_chunk = a2p_model.infer(audio, seq_len, id_seed)
192 | pose_seq_chunk = pose_seq_chunk.squeeze().detach().cpu().numpy()
193 | pose_seq_chunk[:, :3] *= 0.5
194 | pose_seq.append(pose_seq_chunk)
195 |
196 | pose_seq = np.concatenate(pose_seq, 0)
197 | pose_seq = smooth_pose_seq(pose_seq, 7)
198 |
199 | # project 3D mesh to 2D landmark
200 | projected_vertices = project_points(pred, face_result['trans_mat'], pose_seq, [height, width])
201 |
202 | pose_images = []
203 | for i, verts in enumerate(projected_vertices):
204 | lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
205 | pose_images.append(lmk_img)
206 |
207 | pose_list = []
208 | pose_tensor_list = []
209 | print(f"pose video has {len(pose_images)} frames, with {args.fps} fps")
210 | pose_transform = transforms.Compose(
211 | [transforms.Resize((height, width)), transforms.ToTensor()]
212 | )
213 | args_L = len(pose_images) if args.L is None else args.L
214 | for pose_image_np in pose_images[: args_L]:
215 | pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
216 | pose_tensor_list.append(pose_transform(pose_image_pil))
217 | sub_step = args.fi_step if args.accelerate else 1
218 | for pose_image_np in pose_images[: args_L: sub_step]:
219 | pose_image_np = cv2.resize(pose_image_np, (width, height))
220 | pose_list.append(pose_image_np)
221 |
222 | pose_list = np.array(pose_list)
223 |
224 | video_length = len(pose_list)
225 |
226 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
227 | pose_tensor = pose_tensor.transpose(0, 1)
228 | pose_tensor = pose_tensor.unsqueeze(0)
229 |
230 | video = pipe(
231 | ref_image_pil,
232 | pose_list,
233 | ref_pose,
234 | width,
235 | height,
236 | video_length,
237 | args.steps,
238 | args.cfg,
239 | generator=generator,
240 | ).videos
241 |
242 | if args.accelerate:
243 | video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=args.fi_step-1)
244 |
245 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
246 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(
247 | 0
248 | ) # (1, c, 1, h, w)
249 | ref_image_tensor = repeat(
250 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=video.shape[2]
251 | )
252 |
253 | video = torch.cat([ref_image_tensor, pose_tensor[:,:,:video.shape[2]], video], dim=0)
254 | save_path = f"{save_dir}/{ref_name}_{audio_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}_noaudio.mp4"
255 | save_videos_grid(
256 | video,
257 | save_path,
258 | n_rows=3,
259 | fps=args.fps,
260 | )
261 |
262 | stream = ffmpeg.input(save_path)
263 | audio = ffmpeg.input(audio_path)
264 | ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
265 | os.remove(save_path)
266 |
267 | if __name__ == "__main__":
268 | main()
269 |
--------------------------------------------------------------------------------
/scripts/generate_ref_pose.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 |
5 | import numpy as np
6 | import cv2
7 | from tqdm import tqdm
8 | from scipy.spatial.transform import Rotation as R
9 | from scipy.interpolate import interp1d
10 |
11 | from src.utils.mp_utils import LMKExtractor
12 | from src.utils.pose_util import smooth_pose_seq, matrix_to_euler_and_translation
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--ref_video", type=str, default='', help='path of input video')
18 | parser.add_argument("--save_path", type=str, default='', help='path to save pose')
19 | args = parser.parse_args()
20 | return args
21 |
22 |
23 | def main():
24 | args = parse_args()
25 |
26 | lmk_extractor = LMKExtractor()
27 |
28 | cap = cv2.VideoCapture(args.ref_video)
29 |
30 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
31 | fps = cap.get(cv2.CAP_PROP_FPS)
32 |
33 | pbar = tqdm(range(total_frames), desc="processing ...")
34 |
35 | trans_mat_list = []
36 | while cap.isOpened():
37 | ret, frame = cap.read()
38 | if not ret:
39 | break
40 |
41 | pbar.update(1)
42 | result = lmk_extractor(frame)
43 | if result is None:
44 | break
45 | trans_mat_list.append(result['trans_mat'].astype(np.float32))
46 | cap.release()
47 |
48 | total_frames = len(trans_mat_list)
49 |
50 | trans_mat_arr = np.array(trans_mat_list)
51 |
52 | # compute delta pose
53 | trans_mat_inv_frame_0 = np.linalg.inv(trans_mat_arr[0])
54 | pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
55 |
56 | for i in range(pose_arr.shape[0]):
57 | pose_mat = trans_mat_inv_frame_0 @ trans_mat_arr[i]
58 | euler_angles, translation_vector = matrix_to_euler_and_translation(pose_mat)
59 | pose_arr[i, :3] = euler_angles
60 | pose_arr[i, 3:6] = translation_vector
61 |
62 | # interpolate to 30 fps
63 | new_fps = 30
64 | old_time = np.linspace(0, total_frames / fps, total_frames)
65 | new_time = np.linspace(0, total_frames / fps, int(total_frames * new_fps / fps))
66 |
67 | pose_arr_interp = np.zeros((len(new_time), 6))
68 | for i in range(6):
69 | interp_func = interp1d(old_time, pose_arr[:, i])
70 | pose_arr_interp[:, i] = interp_func(new_time)
71 |
72 | pose_arr_smooth = smooth_pose_seq(pose_arr_interp)
73 | np.save(args.save_path, pose_arr_smooth)
74 |
75 |
76 | if __name__ == "__main__":
77 | main()
78 |
79 |
--------------------------------------------------------------------------------
/scripts/pose2vid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import ffmpeg
4 | from datetime import datetime
5 | from pathlib import Path
6 | from typing import List
7 | import subprocess
8 | import av
9 | import numpy as np
10 | import cv2
11 | import torch
12 | import torchvision
13 | from diffusers import AutoencoderKL, DDIMScheduler
14 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
15 | from einops import repeat
16 | from omegaconf import OmegaConf
17 | from PIL import Image
18 | from torchvision import transforms
19 | from transformers import CLIPVisionModelWithProjection
20 |
21 | from configs.prompts.test_cases import TestCasesDict
22 | from src.models.pose_guider import PoseGuider
23 | from src.models.unet_2d_condition import UNet2DConditionModel
24 | from src.models.unet_3d import UNet3DConditionModel
25 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
26 | from src.utils.util import get_fps, read_frames, save_videos_grid
27 | from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
28 |
29 | from src.utils.mp_utils import LMKExtractor
30 | from src.utils.draw_util import FaceMeshVisualizer
31 |
32 |
33 | def parse_args():
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument("--config", type=str, default='./configs/prompts/animation.yaml')
36 | parser.add_argument("-W", type=int, default=512)
37 | parser.add_argument("-H", type=int, default=512)
38 | parser.add_argument("-L", type=int)
39 | parser.add_argument("--seed", type=int, default=42)
40 | parser.add_argument("--cfg", type=float, default=3.5)
41 | parser.add_argument("--steps", type=int, default=25)
42 | parser.add_argument("--fps", type=int)
43 | parser.add_argument("-acc", "--accelerate", action='store_true')
44 | parser.add_argument("--fi_step", type=int, default=3)
45 | args = parser.parse_args()
46 |
47 | return args
48 |
49 | def main():
50 | args = parse_args()
51 |
52 | config = OmegaConf.load(args.config)
53 |
54 | if config.weight_dtype == "fp16":
55 | weight_dtype = torch.float16
56 | else:
57 | weight_dtype = torch.float32
58 |
59 | vae = AutoencoderKL.from_pretrained(
60 | config.pretrained_vae_path,
61 | ).to("cuda", dtype=weight_dtype)
62 |
63 | reference_unet = UNet2DConditionModel.from_pretrained(
64 | config.pretrained_base_model_path,
65 | subfolder="unet",
66 | ).to(dtype=weight_dtype, device="cuda")
67 |
68 | inference_config_path = config.inference_config
69 | infer_config = OmegaConf.load(inference_config_path)
70 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
71 | config.pretrained_base_model_path,
72 | config.motion_module_path,
73 | subfolder="unet",
74 | unet_additional_kwargs=infer_config.unet_additional_kwargs,
75 | ).to(dtype=weight_dtype, device="cuda")
76 |
77 | pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype)  # pose guider with cross-attention enabled
78 |
79 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
80 | config.image_encoder_path
81 | ).to(dtype=weight_dtype, device="cuda")
82 |
83 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
84 | scheduler = DDIMScheduler(**sched_kwargs)
85 |
86 | generator = torch.manual_seed(args.seed)
87 |
88 | width, height = args.W, args.H
89 |
90 | # load pretrained weights
91 | denoising_unet.load_state_dict(
92 | torch.load(config.denoising_unet_path, map_location="cpu"),
93 | strict=False,
94 | )
95 | reference_unet.load_state_dict(
96 | torch.load(config.reference_unet_path, map_location="cpu"),
97 | )
98 | pose_guider.load_state_dict(
99 | torch.load(config.pose_guider_path, map_location="cpu"),
100 | )
101 |
102 | pipe = Pose2VideoPipeline(
103 | vae=vae,
104 | image_encoder=image_enc,
105 | reference_unet=reference_unet,
106 | denoising_unet=denoising_unet,
107 | pose_guider=pose_guider,
108 | scheduler=scheduler,
109 | )
110 | pipe = pipe.to("cuda", dtype=weight_dtype)
111 |
112 | date_str = datetime.now().strftime("%Y%m%d")
113 | time_str = datetime.now().strftime("%H%M")
114 | save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}"
115 |
116 | save_dir = Path(f"output/{date_str}/{save_dir_name}")
117 | save_dir.mkdir(exist_ok=True, parents=True)
118 |
119 |
120 | lmk_extractor = LMKExtractor()
121 | vis = FaceMeshVisualizer(forehead_edge=False)
122 |
123 | if args.accelerate:
124 | frame_inter_model = init_frame_interpolation_model()
125 |
126 | for ref_image_path in config["test_cases"].keys():
127 | # Each ref_image may correspond to multiple actions
128 | for pose_video_path in config["test_cases"][ref_image_path]:
129 | ref_name = Path(ref_image_path).stem
130 | pose_name = Path(pose_video_path).stem
131 |
132 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
133 | ref_image_np = cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR)
134 | ref_image_np = cv2.resize(ref_image_np, (args.W, args.H))  # cv2.resize expects (width, height)
135 |
136 | face_result = lmk_extractor(ref_image_np)
137 |             assert face_result is not None, "Cannot detect a face in the reference image."
138 | lmks = face_result['lmks'].astype(np.float32)
139 | ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
140 |
141 | pose_list = []
142 | pose_tensor_list = []
143 | pose_images = read_frames(pose_video_path)
144 | src_fps = get_fps(pose_video_path)
145 | print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
146 | pose_transform = transforms.Compose(
147 | [transforms.Resize((height, width)), transforms.ToTensor()]
148 | )
149 | args_L = len(pose_images) if args.L is None else args.L
150 | for pose_image_pil in pose_images[: args_L]:
151 | pose_tensor_list.append(pose_transform(pose_image_pil))
152 | sub_step = args.fi_step if args.accelerate else 1
153 |             for pose_image_pil in pose_images[: args_L: sub_step]:
154 | pose_image_np = cv2.cvtColor(np.array(pose_image_pil), cv2.COLOR_RGB2BGR)
155 | pose_image_np = cv2.resize(pose_image_np, (width, height))
156 | pose_list.append(pose_image_np)
157 |
158 | pose_list = np.array(pose_list)
159 |
160 | video_length = len(pose_list)
161 |
162 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
163 | pose_tensor = pose_tensor.transpose(0, 1)
164 | pose_tensor = pose_tensor.unsqueeze(0)
165 |
166 | video = pipe(
167 | ref_image_pil,
168 | pose_list,
169 | ref_pose,
170 | width,
171 | height,
172 | video_length,
173 | args.steps,
174 | args.cfg,
175 | generator=generator,
176 | ).videos
177 |
178 | if args.accelerate:
179 | video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=args.fi_step-1)
180 |
181 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
182 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(
183 | 0
184 | ) # (1, c, 1, h, w)
185 | ref_image_tensor = repeat(
186 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=video.shape[2]
187 | )
188 |
189 | video = torch.cat([ref_image_tensor, pose_tensor[:,:,:video.shape[2]], video], dim=0)
190 | save_path = f"{save_dir}/{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}_noaudio.mp4"
191 | save_videos_grid(
192 | video,
193 | save_path,
194 | n_rows=3,
195 | fps=src_fps if args.fps is None else args.fps,
196 | )
197 |
198 | audio_output = 'audio_from_video.aac'
199 | # extract audio
200 | ffmpeg.input(pose_video_path).output(audio_output, acodec='copy').run()
201 | # merge audio and video
202 | stream = ffmpeg.input(save_path)
203 | audio = ffmpeg.input(audio_output)
204 | ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
205 |
206 | os.remove(save_path)
207 | os.remove(audio_output)
208 |
209 | if __name__ == "__main__":
210 | main()
211 |
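212 | # Note (added, not part of the original script): the saved grid stacks three panels per frame
213 | # (reference image, driving pose sequence, generated video), and the driving video's audio track
214 | # is muxed back into the final mp4 with ffmpeg.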
--------------------------------------------------------------------------------
/scripts/prepare_video.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from src.utils.mp_utils import LMKExtractor
3 | from src.utils.draw_util import FaceMeshVisualizer
4 |
5 | import os
6 | import numpy as np
7 | import cv2
8 |
9 |
10 | def crop_video(input_file, output_file):
11 | width = 800
12 | height = 800
13 | x = 550
14 | y = 50
15 |
16 | ffmpeg_cmd = f'ffmpeg -i {input_file} -filter:v "crop={width}:{height}:{x}:{y}" -c:a copy {output_file}'
17 | subprocess.call(ffmpeg_cmd, shell=True)
18 |
19 |
20 |
21 | def extract_and_draw_lmks(input_file, output_file):
22 | lmk_extractor = LMKExtractor()
23 | vis = FaceMeshVisualizer()
24 |
25 | cap = cv2.VideoCapture(input_file)
26 | frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
27 | fps = cap.get(cv2.CAP_PROP_FPS)
28 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
29 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
30 |
31 | # Define the codec and create VideoWriter object
32 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
33 | out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
34 |
35 | write_ref_img = False
36 |     for i in range(200):  # process at most the first 200 frames
37 | ret, frame = cap.read()
38 |
39 | if not ret:
40 | break
41 |
42 | if not write_ref_img:
43 | write_ref_img = True
44 | cv2.imwrite(os.path.join(os.path.dirname(output_file), "ref_img.jpg"), frame)
45 |
46 | result = lmk_extractor(frame)
47 |
48 | if result is not None:
49 | lmks = result['lmks'].astype(np.float32)
50 | lmk_img = vis.draw_landmarks((frame.shape[1], frame.shape[0]), lmks, normed=True)
51 | out.write(lmk_img)
52 | else:
53 |             print('failed to extract face landmarks for this frame')
54 |
55 |
56 | if __name__ == "__main__":
57 |
58 | input_file = "./Moore-AnimateAnyone/examples/video.mp4"
59 | lmk_video_path = "./Moore-AnimateAnyone/examples/pose.mp4"
60 |
61 | # crop video
62 | # crop_video(input_file, output_file)
63 |
64 | # extract and draw lmks
65 | extract_and_draw_lmks(input_file, lmk_video_path)
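66 |     # Note (added, illustrative): crop_video applies a hard-coded 800x800 crop at offset (x=550, y=50);
67 |     # adjust those values and define output_file before un-commenting the crop_video call above.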
--------------------------------------------------------------------------------
/scripts/preprocess_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import numpy as np
5 | import cv2
6 | from tqdm import tqdm
7 | import glob
8 | import json
9 |
10 | from src.utils.mp_utils import LMKExtractor
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--input_dir", type=str, default='', help='path of dataset')
16 | parser.add_argument("--output_dir", type=str, default='', help='path to save extracted annotations')
17 | parser.add_argument("--training_json", type=str, default='', help='path to save training json')
18 | args = parser.parse_args()
19 | return args
20 |
21 |
22 | def generate_training_json_mesh(video_dir, face_info_dir, res_json_path, min_clip_length=30):
23 | video_name_list = sorted(os.listdir(face_info_dir))
24 | res_data_dic = {}
25 |
26 | pbar = tqdm(range(len(video_name_list)))
27 |
28 |
29 | for video_index, video_name in enumerate(video_name_list):
30 | pbar.update(1)
31 |
32 | tem_dic = {}
33 | tem_tem_dic = {}
34 | video_clip_dir = os.path.join(video_dir, video_name)
35 | lmks_clip_dir = os.path.join(face_info_dir, video_name)
36 |
37 | video_clip_num = 1
38 | video_data_list = []
39 |
40 | frame_path_list = sorted(glob.glob(os.path.join(video_clip_dir, '*.png')))
41 | lmks_path_list = sorted(glob.glob(os.path.join(lmks_clip_dir, '*lmks.npy')))
42 |
43 | min_len = min(len(frame_path_list), len(lmks_path_list))
44 | frame_path_list = frame_path_list[:min_len]
45 | lmks_path_list = lmks_path_list[:min_len]
46 |
47 |
48 | if min_len < min_clip_length:
49 |             print('clip too short, skip: {} (length {})'.format(video_name, min_len))
50 | video_clip_num -= 1
51 | continue
52 |
53 | first_frame_basename = os.path.basename(frame_path_list[0]).split('.')[0]
54 | first_lmks_basename = os.path.basename(lmks_path_list[0]).split('_')[0]
55 | last_frame_basename = os.path.basename(frame_path_list[-1]).split('.')[0]
56 | last_lmks_basename = os.path.basename(lmks_path_list[-1]).split('_')[0]
57 |
58 | if (first_frame_basename != first_lmks_basename) or (last_frame_basename != last_lmks_basename):
59 |             print('frame/lmks mismatch, skip: {}, length {}/{} (frame/lmks)'.format(video_name, len(frame_path_list), len(lmks_path_list)))
60 | video_clip_num -= 1
61 | continue
62 |
63 | frame_name_list = [os.path.join(video_name, os.path.basename(item)) for item in frame_path_list]
64 |
65 | tem_tem_dic['frame_name_list'] = frame_name_list
66 | tem_tem_dic['frame_path_list'] = frame_path_list
67 | tem_tem_dic['lmks_list'] = lmks_path_list
68 | video_data_list.append(tem_tem_dic)
69 |
70 | tem_dic['video_clip_num'] = video_clip_num
71 | tem_dic['clip_data_list'] = video_data_list
72 | res_data_dic[video_name] = tem_dic
73 |
74 | with open(res_json_path, 'w') as f:
75 | json.dump(res_data_dic, f)
76 |
77 |
78 | def main():
79 | args = parse_args()
80 |
81 | os.makedirs(args.output_dir, exist_ok=True)
82 | folders = [f.path for f in os.scandir(args.input_dir) if f.is_dir()]
83 | folders.sort()
84 |
85 | lmk_extractor = LMKExtractor()
86 |
87 | pbar = tqdm(range(len(folders)), desc="processing ...")
88 | for folder in folders:
89 | pbar.update(1)
90 | output_subdir = os.path.join(args.output_dir, os.path.basename(folder))
91 | os.makedirs(output_subdir, exist_ok=True)
92 | for img_file in sorted(glob.glob(os.path.join(folder, "*.png"))):
93 | base = os.path.basename(img_file)
94 | lmks_output_file = os.path.join(output_subdir, os.path.splitext(base)[0] + "_lmks.npy")
95 | lmks3d_output_file = os.path.join(output_subdir, os.path.splitext(base)[0] + "_lmks3d.npy")
96 | trans_mat_output_file = os.path.join(output_subdir, os.path.splitext(base)[0] + "_trans_mat.npy")
97 | bs_output_file = os.path.join(output_subdir, os.path.splitext(base)[0] + "_bs.npy")
98 |
99 | img = cv2.imread(img_file)
100 | result = lmk_extractor(img)
101 |
102 | if result is not None:
103 | np.save(lmks_output_file, result['lmks'].astype(np.float32))
104 | np.save(lmks3d_output_file, result['lmks3d'].astype(np.float32))
105 | np.save(trans_mat_output_file, result['trans_mat'].astype(np.float32))
106 | np.save(bs_output_file, np.array(result['bs']).astype(np.float32))
107 |
108 | # write json
109 | generate_training_json_mesh(args.input_dir, args.output_dir, args.training_json, min_clip_length=30)
110 |
111 |
112 | if __name__ == "__main__":
113 | main()
114 |
115 |
116 |
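117 | # Illustrative invocation (added; paths are placeholders):
118 | #   python scripts/preprocess_dataset.py --input_dir /path/to/frames \
119 | #       --output_dir /path/to/face_info --training_json /path/to/train.json
120 | # For every frame it saves *_lmks.npy, *_lmks3d.npy, *_trans_mat.npy and *_bs.npy, and
121 | # generate_training_json_mesh then gathers them into the json consumed by src/dataset/dataset_face.py.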
--------------------------------------------------------------------------------
/scripts/vid2pose.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ffmpeg
3 | from PIL import Image
4 | import cv2
5 | from tqdm import tqdm
6 |
7 | from src.utils.util import get_fps, read_frames, save_videos_from_pil
8 | import numpy as np
9 | from src.utils.draw_util import FaceMeshVisualizer
10 | from src.utils.mp_utils import LMKExtractor
11 |
12 | if __name__ == "__main__":
13 | import argparse
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--video_path", type=str)
17 | args = parser.parse_args()
18 |
19 | if not os.path.exists(args.video_path):
20 | raise ValueError(f"Path: {args.video_path} not exists")
21 |
22 | dir_path, video_name = (
23 | os.path.dirname(args.video_path),
24 | os.path.splitext(os.path.basename(args.video_path))[0],
25 | )
26 | out_path = os.path.join(dir_path, video_name + "_kps_noaudio.mp4")
27 |
28 | lmk_extractor = LMKExtractor()
29 | vis = FaceMeshVisualizer(forehead_edge=False)
30 |
31 | width = 512
32 | height = 512
33 |
34 | fps = get_fps(args.video_path)
35 | frames = read_frames(args.video_path)
36 | kps_results = []
37 | for i, frame_pil in enumerate(tqdm(frames)):
38 | image_np = cv2.cvtColor(np.array(frame_pil), cv2.COLOR_RGB2BGR)
39 |         image_np = cv2.resize(image_np, (width, height))  # cv2.resize expects (width, height)
40 | face_result = lmk_extractor(image_np)
41 |         if face_result is not None:
42 |             lmks = face_result['lmks'].astype(np.float32)
43 |             pose_img = vis.draw_landmarks((image_np.shape[1], image_np.shape[0]), lmks, normed=True)
44 |             pose_img = Image.fromarray(cv2.cvtColor(pose_img, cv2.COLOR_BGR2RGB))
45 |         else:
46 |             pose_img = kps_results[-1]  # no face detected: reuse the previous frame's pose image
47 |
48 | kps_results.append(pose_img)
49 |
50 | print(out_path.replace('_noaudio.mp4', '.mp4'))
51 | save_videos_from_pil(kps_results, out_path, fps=fps)
52 |
53 | audio_output = 'audio_from_video.aac'
54 | ffmpeg.input(args.video_path).output(audio_output, acodec='copy').run()
55 | stream = ffmpeg.input(out_path)
56 | audio = ffmpeg.input(audio_output)
57 | ffmpeg.output(stream.video, audio.audio, out_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
58 | os.remove(out_path)
59 | os.remove(audio_output)
60 |
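61 |     # Illustrative invocation (added; any face video works, e.g. the sample shipped in this repo):
62 |     #   python scripts/vid2pose.py --video_path ./configs/inference/video/Aragaki_song.mp4
63 |     # The result is written next to the input as <video_name>_kps.mp4 with the original audio muxed back in.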
--------------------------------------------------------------------------------
/scripts/vid2vid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import ffmpeg
4 | from datetime import datetime
5 | from pathlib import Path
6 | from typing import List
7 | import subprocess
8 | import av
9 | import numpy as np
10 | import cv2
11 | import torch
12 | import torchvision
13 | from diffusers import AutoencoderKL, DDIMScheduler
14 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
15 | from einops import repeat
16 | from omegaconf import OmegaConf
17 | from PIL import Image
18 | from torchvision import transforms
19 | from transformers import CLIPVisionModelWithProjection
20 |
21 | from configs.prompts.test_cases import TestCasesDict
22 | from src.models.pose_guider import PoseGuider
23 | from src.models.unet_2d_condition import UNet2DConditionModel
24 | from src.models.unet_3d import UNet3DConditionModel
25 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
26 | from src.utils.util import get_fps, read_frames, save_videos_grid
27 |
28 | from src.utils.mp_utils import LMKExtractor
29 | from src.utils.draw_util import FaceMeshVisualizer
30 | from src.utils.pose_util import project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix, smooth_pose_seq
31 | from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
32 |
33 |
34 | def parse_args():
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument("--config", type=str, default='./configs/prompts/animation_facereenac.yaml')
37 | parser.add_argument("-W", type=int, default=512)
38 | parser.add_argument("-H", type=int, default=512)
39 | parser.add_argument("-L", type=int)
40 | parser.add_argument("--seed", type=int, default=42)
41 | parser.add_argument("--cfg", type=float, default=3.5)
42 | parser.add_argument("--steps", type=int, default=25)
43 | parser.add_argument("--fps", type=int)
44 | parser.add_argument("-acc", "--accelerate", action='store_true')
45 | parser.add_argument("--fi_step", type=int, default=3)
46 | args = parser.parse_args()
47 |
48 | return args
49 |
50 | def main():
51 | args = parse_args()
52 |
53 | config = OmegaConf.load(args.config)
54 |
55 | if config.weight_dtype == "fp16":
56 | weight_dtype = torch.float16
57 | else:
58 | weight_dtype = torch.float32
59 |
60 | vae = AutoencoderKL.from_pretrained(
61 | config.pretrained_vae_path,
62 | ).to("cuda", dtype=weight_dtype)
63 |
64 | reference_unet = UNet2DConditionModel.from_pretrained(
65 | config.pretrained_base_model_path,
66 | subfolder="unet",
67 | ).to(dtype=weight_dtype, device="cuda")
68 |
69 | inference_config_path = config.inference_config
70 | infer_config = OmegaConf.load(inference_config_path)
71 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
72 | config.pretrained_base_model_path,
73 | config.motion_module_path,
74 | subfolder="unet",
75 | unet_additional_kwargs=infer_config.unet_additional_kwargs,
76 | ).to(dtype=weight_dtype, device="cuda")
77 |
78 |     pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # use_ca=True: the pose guider uses cross attention
79 |
80 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
81 | config.image_encoder_path
82 | ).to(dtype=weight_dtype, device="cuda")
83 |
84 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
85 | scheduler = DDIMScheduler(**sched_kwargs)
86 |
87 | generator = torch.manual_seed(args.seed)
88 |
89 | width, height = args.W, args.H
90 |
91 | # load pretrained weights
92 | denoising_unet.load_state_dict(
93 | torch.load(config.denoising_unet_path, map_location="cpu"),
94 | strict=False,
95 | )
96 | reference_unet.load_state_dict(
97 | torch.load(config.reference_unet_path, map_location="cpu"),
98 | )
99 | pose_guider.load_state_dict(
100 | torch.load(config.pose_guider_path, map_location="cpu"),
101 | )
102 |
103 | pipe = Pose2VideoPipeline(
104 | vae=vae,
105 | image_encoder=image_enc,
106 | reference_unet=reference_unet,
107 | denoising_unet=denoising_unet,
108 | pose_guider=pose_guider,
109 | scheduler=scheduler,
110 | )
111 | pipe = pipe.to("cuda", dtype=weight_dtype)
112 |
113 | date_str = datetime.now().strftime("%Y%m%d")
114 | time_str = datetime.now().strftime("%H%M")
115 | save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}"
116 |
117 | save_dir = Path(f"output/{date_str}/{save_dir_name}")
118 | save_dir.mkdir(exist_ok=True, parents=True)
119 |
120 |
121 | lmk_extractor = LMKExtractor()
122 | vis = FaceMeshVisualizer(forehead_edge=False)
123 |
124 | if args.accelerate:
125 | frame_inter_model = init_frame_interpolation_model()
126 |
127 | for ref_image_path in config["test_cases"].keys():
128 | # Each ref_image may correspond to multiple actions
129 | for source_video_path in config["test_cases"][ref_image_path]:
130 | ref_name = Path(ref_image_path).stem
131 | pose_name = Path(source_video_path).stem
132 |
133 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
134 | ref_image_np = cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR)
135 |             ref_image_np = cv2.resize(ref_image_np, (args.W, args.H))  # cv2.resize expects (width, height)
136 |
137 | face_result = lmk_extractor(ref_image_np)
138 |             assert face_result is not None, "Cannot detect a face in the reference image."
139 | lmks = face_result['lmks'].astype(np.float32)
140 | ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
141 |
142 |
143 |
144 | source_images = read_frames(source_video_path)
145 | src_fps = get_fps(source_video_path)
146 | print(f"source video has {len(source_images)} frames, with {src_fps} fps")
147 | pose_transform = transforms.Compose(
148 | [transforms.Resize((height, width)), transforms.ToTensor()]
149 | )
150 |
151 | step = 1
152 | if src_fps == 60:
153 | src_fps = 30
154 | step = 2
155 |
156 | pose_trans_list = []
157 | verts_list = []
158 | bs_list = []
159 | src_tensor_list = []
160 | args_L = len(source_images) if args.L is None else args.L*step
161 | for src_image_pil in source_images[: args_L: step]:
162 | src_tensor_list.append(pose_transform(src_image_pil))
163 | sub_step = step*args.fi_step if args.accelerate else step
164 | for src_image_pil in source_images[: args_L: sub_step]:
165 | src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
166 | frame_height, frame_width, _ = src_img_np.shape
167 | src_img_result = lmk_extractor(src_img_np)
168 | if src_img_result is None:
169 | break
170 | pose_trans_list.append(src_img_result['trans_mat'])
171 | verts_list.append(src_img_result['lmks3d'])
172 | bs_list.append(src_img_result['bs'])
173 |
174 | trans_mat_arr = np.array(pose_trans_list)
175 | verts_arr = np.array(verts_list)
176 | bs_arr = np.array(bs_list)
177 | min_bs_idx = np.argmin(bs_arr.sum(1))
178 |
179 | # compute delta pose
180 | pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
181 |
182 | for i in range(pose_arr.shape[0]):
183 | euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat_arr[i]) # real pose of source
184 | pose_arr[i, :3] = euler_angles
185 | pose_arr[i, 3:6] = translation_vector
186 |
187 | init_tran_vec = face_result['trans_mat'][:3, 3] # init translation of tgt
188 | pose_arr[:, 3:6] = pose_arr[:, 3:6] - pose_arr[0, 3:6] + init_tran_vec # (relative translation of source) + (init translation of tgt)
189 |
190 | pose_arr_smooth = smooth_pose_seq(pose_arr, window_size=3)
191 | pose_mat_smooth = [euler_and_translation_to_matrix(pose_arr_smooth[i][:3], pose_arr_smooth[i][3:6]) for i in range(pose_arr_smooth.shape[0])]
192 | pose_mat_smooth = np.array(pose_mat_smooth)
193 |
194 | # face retarget
195 | verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result['lmks3d']
196 | # project 3D mesh to 2D landmark
197 | projected_vertices = project_points_with_trans(verts_arr, pose_mat_smooth, [frame_height, frame_width])
198 |
199 | pose_list = []
200 | for i, verts in enumerate(projected_vertices):
201 | lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
202 | pose_image_np = cv2.resize(lmk_img, (width, height))
203 | pose_list.append(pose_image_np)
204 |
205 | pose_list = np.array(pose_list)
206 |
207 | video_length = len(pose_list)
208 |
209 | src_tensor = torch.stack(src_tensor_list, dim=0) # (f, c, h, w)
210 | src_tensor = src_tensor.transpose(0, 1)
211 | src_tensor = src_tensor.unsqueeze(0)
212 |
213 | video = pipe(
214 | ref_image_pil,
215 | pose_list,
216 | ref_pose,
217 | width,
218 | height,
219 | video_length,
220 | args.steps,
221 | args.cfg,
222 | generator=generator,
223 | ).videos
224 |
225 | if args.accelerate:
226 | video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=args.fi_step-1)
227 |
228 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
229 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(
230 | 0
231 | ) # (1, c, 1, h, w)
232 | ref_image_tensor = repeat(
233 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=video.shape[2]
234 | )
235 |
236 | video = torch.cat([ref_image_tensor, video, src_tensor[:,:,:video.shape[2]]], dim=0)
237 | save_path = f"{save_dir}/{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}_noaudio.mp4"
238 | save_videos_grid(
239 | video,
240 | save_path,
241 | n_rows=3,
242 | fps=src_fps if args.fps is None else args.fps,
243 | )
244 |
245 | audio_output = 'audio_from_video.aac'
246 | # extract audio
247 | ffmpeg.input(source_video_path).output(audio_output, acodec='copy').run()
248 | # merge audio and video
249 | stream = ffmpeg.input(save_path)
250 | audio = ffmpeg.input(audio_output)
251 | ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
252 |
253 | os.remove(save_path)
254 | os.remove(audio_output)
255 |
256 | if __name__ == "__main__":
257 | main()
258 |
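259 | # Illustrative invocation (added; flags as defined in parse_args above):
260 | #   python scripts/vid2vid.py --config ./configs/prompts/animation_facereenac.yaml -acc
261 | # test_cases in the config maps each reference image to one or more driving videos; the driving
262 | # video's head pose and expression are retargeted onto the reference identity before rendering.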
--------------------------------------------------------------------------------
/src/audio_models/mish.py:
--------------------------------------------------------------------------------
1 | """
2 | Applies the mish function element-wise:
3 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
4 | """
5 |
6 | # import pytorch
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import nn
10 |
11 | @torch.jit.script
12 | def mish(input):
13 | """
14 | Applies the mish function element-wise:
15 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
16 | See additional documentation for mish class.
17 | """
18 | return input * torch.tanh(F.softplus(input))
19 |
20 | class Mish(nn.Module):
21 | """
22 | Applies the mish function element-wise:
23 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
24 |
25 | Shape:
26 | - Input: (N, *) where * means, any number of additional
27 | dimensions
28 | - Output: (N, *), same shape as the input
29 |
30 | Examples:
31 | >>> m = Mish()
32 | >>> input = torch.randn(2)
33 | >>> output = m(input)
34 |
35 | Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
36 | """
37 |
38 | def __init__(self):
39 | """
40 | Init method.
41 | """
42 | super().__init__()
43 |
44 | def forward(self, input):
45 | """
46 | Forward pass of the function.
47 | """
48 |         if hasattr(F, "mish"):  # native mish exists since PyTorch 1.9; comparing version strings is unreliable ("1.10" < "1.9")
49 | return F.mish(input)
50 | else:
51 | return mish(input)
--------------------------------------------------------------------------------
/src/audio_models/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from transformers import Wav2Vec2Config
6 |
7 | from .torch_utils import get_mask_from_lengths
8 | from .wav2vec2 import Wav2Vec2Model
9 |
10 |
11 | class Audio2MeshModel(nn.Module):
12 | def __init__(
13 | self,
14 | config
15 | ):
16 | super().__init__()
17 | out_dim = config['out_dim']
18 | latent_dim = config['latent_dim']
19 | model_path = config['model_path']
20 | only_last_fetures = config['only_last_fetures']
21 | from_pretrained = config['from_pretrained']
22 |
23 | self._only_last_features = only_last_fetures
24 |
25 | self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True)
26 | if from_pretrained:
27 | self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True)
28 | else:
29 | self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config)
30 | self.audio_encoder.feature_extractor._freeze_parameters()
31 |
32 | hidden_size = self.audio_encoder_config.hidden_size
33 |
34 | self.in_fn = nn.Linear(hidden_size, latent_dim)
35 |
36 | self.out_fn = nn.Linear(latent_dim, out_dim)
37 | nn.init.constant_(self.out_fn.weight, 0)
38 | nn.init.constant_(self.out_fn.bias, 0)
39 |
40 | def forward(self, audio, label, audio_len=None):
41 |         attention_mask = ~get_mask_from_lengths(audio_len) if audio_len is not None else None  # truth-testing a batched tensor raises
42 |
43 | seq_len = label.shape[1]
44 |
45 | embeddings = self.audio_encoder(audio, seq_len=seq_len, output_hidden_states=True,
46 | attention_mask=attention_mask)
47 |
48 | if self._only_last_features:
49 | hidden_states = embeddings.last_hidden_state
50 | else:
51 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)
52 |
53 | layer_in = self.in_fn(hidden_states)
54 | out = self.out_fn(layer_in)
55 |
56 | return out, None
57 |
58 | def infer(self, input_value, seq_len):
59 | embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True)
60 |
61 | if self._only_last_features:
62 | hidden_states = embeddings.last_hidden_state
63 | else:
64 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)
65 |
66 | layer_in = self.in_fn(hidden_states)
67 | out = self.out_fn(layer_in)
68 |
69 | return out
70 |
71 |
72 |
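73 | # Construction sketch (added; keys mirror __init__ above, values are placeholders, not defaults;
74 | # note the config key keeps the original spelling 'only_last_fetures'):
75 | # cfg = {'out_dim': OUT_DIM, 'latent_dim': LATENT_DIM, 'model_path': 'path/to/wav2vec2',
76 | #        'only_last_fetures': True, 'from_pretrained': True}
77 | # a2m = Audio2MeshModel(cfg)
78 | # pred = a2m.infer(input_value, seq_len)   # -> (batch, seq_len, out_dim)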
--------------------------------------------------------------------------------
/src/audio_models/pose_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | from transformers import Wav2Vec2Config
6 |
7 | from .torch_utils import get_mask_from_lengths
8 | from .wav2vec2 import Wav2Vec2Model
9 |
10 |
11 | def init_biased_mask(n_head, max_seq_len, period):
12 | def get_slopes(n):
13 | def get_slopes_power_of_2(n):
14 | start = (2**(-2**-(math.log2(n)-3)))
15 | ratio = start
16 | return [start*ratio**i for i in range(n)]
17 | if math.log2(n).is_integer():
18 | return get_slopes_power_of_2(n)
19 | else:
20 | closest_power_of_2 = 2**math.floor(math.log2(n))
21 | return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
22 | slopes = torch.Tensor(get_slopes(n_head))
23 | bias = torch.arange(start=0, end=max_seq_len, step=period).unsqueeze(1).repeat(1,period).view(-1)//(period)
24 | bias = - torch.flip(bias,dims=[0])
25 | alibi = torch.zeros(max_seq_len, max_seq_len)
26 | for i in range(max_seq_len):
27 | alibi[i, :i+1] = bias[-(i+1):]
28 | alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
29 | mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1)
30 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
31 | mask = mask.unsqueeze(0) + alibi
32 | return mask
33 |
34 |
35 | def enc_dec_mask(device, T, S):
36 | mask = torch.ones(T, S)
37 | for i in range(T):
38 | mask[i, i] = 0
39 | return (mask==1).to(device=device)
40 |
41 |
42 | class PositionalEncoding(nn.Module):
43 | def __init__(self, d_model, max_len=600):
44 | super(PositionalEncoding, self).__init__()
45 | pe = torch.zeros(max_len, d_model)
46 | position = torch.arange(0, max_len).unsqueeze(1).float()
47 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
48 | pe[:, 0::2] = torch.sin(position * div_term)
49 | pe[:, 1::2] = torch.cos(position * div_term)
50 | pe = pe.unsqueeze(0)
51 | self.register_buffer('pe', pe)
52 |
53 | def forward(self, x):
54 | x = x + self.pe[:, :x.size(1)]
55 | return x
56 |
57 |
58 | class Audio2PoseModel(nn.Module):
59 | def __init__(
60 | self,
61 | config
62 | ):
63 |
64 | super().__init__()
65 |
66 | latent_dim = config['latent_dim']
67 | model_path = config['model_path']
68 | only_last_fetures = config['only_last_fetures']
69 | from_pretrained = config['from_pretrained']
70 | out_dim = config['out_dim']
71 |
72 | self.out_dim = out_dim
73 |
74 | self._only_last_features = only_last_fetures
75 |
76 | self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True)
77 | if from_pretrained:
78 | self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True)
79 | else:
80 | self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config)
81 | self.audio_encoder.feature_extractor._freeze_parameters()
82 |
83 | hidden_size = self.audio_encoder_config.hidden_size
84 |
85 | self.pose_map = nn.Linear(out_dim, latent_dim)
86 | self.in_fn = nn.Linear(hidden_size, latent_dim)
87 |
88 | self.PPE = PositionalEncoding(latent_dim)
89 | self.biased_mask = init_biased_mask(n_head = 8, max_seq_len = 600, period=1)
90 | decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, nhead=8, dim_feedforward=2*latent_dim, batch_first=True)
91 | self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=8)
92 | self.pose_map_r = nn.Linear(latent_dim, out_dim)
93 |
94 | self.id_embed = nn.Embedding(100, latent_dim) # 100 ids
95 |
96 |
97 | def infer(self, input_value, seq_len, id_seed=None):
98 | embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True)
99 |
100 | if self._only_last_features:
101 | hidden_states = embeddings.last_hidden_state
102 | else:
103 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)
104 |
105 | hidden_states = self.in_fn(hidden_states)
106 |
107 | id_embedding = self.id_embed(id_seed).unsqueeze(1)
108 |
109 | init_pose = torch.zeros([hidden_states.shape[0], 1, self.out_dim]).to(hidden_states.device)
110 | for i in range(seq_len):
111 | if i==0:
112 | pose_emb = self.pose_map(init_pose)
113 | pose_input = self.PPE(pose_emb)
114 | else:
115 | pose_input = self.PPE(pose_emb)
116 |
117 | pose_input = pose_input + id_embedding
118 | tgt_mask = self.biased_mask[:, :pose_input.shape[1], :pose_input.shape[1]].clone().detach().to(hidden_states.device)
119 | memory_mask = enc_dec_mask(hidden_states.device, pose_input.shape[1], hidden_states.shape[1])
120 | pose_out = self.transformer_decoder(pose_input, hidden_states, tgt_mask=tgt_mask, memory_mask=memory_mask)
121 | pose_out = self.pose_map_r(pose_out)
122 | new_output = self.pose_map(pose_out[:,-1,:]).unsqueeze(1)
123 | pose_emb = torch.cat((pose_emb, new_output), 1)
124 | return pose_out
125 |
126 |
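127 | # Notes (added):
128 | #  - init_biased_mask builds an ALiBi-style per-head linear bias on top of a causal mask.
129 | #  - enc_dec_mask lets decoder step i attend only to the audio feature at the same index i.
130 | #  - infer() decodes the pose sequence autoregressively, one frame per loop iteration, conditioned
131 | #    on an identity embedding selected by id_seed (an index into nn.Embedding(100, latent_dim)).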
--------------------------------------------------------------------------------
/src/audio_models/torch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def get_mask_from_lengths(lengths, max_len=None):
6 | lengths = lengths.to(torch.long)
7 | if max_len is None:
8 | max_len = torch.max(lengths).item()
9 |
10 | ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
11 | mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
12 |
13 | return mask
14 |
15 |
16 | def linear_interpolation(features, seq_len):
17 | features = features.transpose(1, 2)
18 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
19 | return output_features.transpose(1, 2)
20 |
21 |
22 | if __name__ == "__main__":
23 | import numpy as np
24 | mask = ~get_mask_from_lengths(torch.from_numpy(np.array([4,6])))
25 |     print(mask)  # inspect the inverted padding mask
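26 |     # Added illustration: linear_interpolation resamples along time, (B, T, C) -> (B, seq_len, C).
27 |     feats = torch.randn(2, 100, 8)
28 |     print(linear_interpolation(feats, seq_len=16).shape)  # torch.Size([2, 16, 8])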
--------------------------------------------------------------------------------
/src/audio_models/wav2vec2.py:
--------------------------------------------------------------------------------
1 | from transformers import Wav2Vec2Config, Wav2Vec2Model
2 | from transformers.modeling_outputs import BaseModelOutput
3 |
4 | from .torch_utils import linear_interpolation
5 |
6 | # The implementation of Wav2Vec2Model is borrowed from
7 | # https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
8 | # and the encoder is initialized with the pre-trained wav2vec 2.0 weights.
9 | class Wav2Vec2Model(Wav2Vec2Model):
10 | def __init__(self, config: Wav2Vec2Config):
11 | super().__init__(config)
12 |
13 | def forward(
14 | self,
15 | input_values,
16 | seq_len,
17 | attention_mask=None,
18 | mask_time_indices=None,
19 | output_attentions=None,
20 | output_hidden_states=None,
21 | return_dict=None,
22 | ):
23 | self.config.output_attentions = True
24 |
25 | output_hidden_states = (
26 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
27 | )
28 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
29 |
30 | extract_features = self.feature_extractor(input_values)
31 | extract_features = extract_features.transpose(1, 2)
32 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
33 |
34 | if attention_mask is not None:
35 | # compute reduced attention_mask corresponding to feature vectors
36 | attention_mask = self._get_feature_vector_attention_mask(
37 | extract_features.shape[1], attention_mask, add_adapter=False
38 | )
39 |
40 | hidden_states, extract_features = self.feature_projection(extract_features)
41 | hidden_states = self._mask_hidden_states(
42 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
43 | )
44 |
45 | encoder_outputs = self.encoder(
46 | hidden_states,
47 | attention_mask=attention_mask,
48 | output_attentions=output_attentions,
49 | output_hidden_states=output_hidden_states,
50 | return_dict=return_dict,
51 | )
52 |
53 | hidden_states = encoder_outputs[0]
54 |
55 | if self.adapter is not None:
56 | hidden_states = self.adapter(hidden_states)
57 |
58 | if not return_dict:
59 | return (hidden_states, ) + encoder_outputs[1:]
60 | return BaseModelOutput(
61 | last_hidden_state=hidden_states,
62 | hidden_states=encoder_outputs.hidden_states,
63 | attentions=encoder_outputs.attentions,
64 | )
65 |
66 |
67 | def feature_extract(
68 | self,
69 | input_values,
70 | seq_len,
71 | ):
72 | extract_features = self.feature_extractor(input_values)
73 | extract_features = extract_features.transpose(1, 2)
74 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
75 |
76 | return extract_features
77 |
78 | def encode(
79 | self,
80 | extract_features,
81 | attention_mask=None,
82 | mask_time_indices=None,
83 | output_attentions=None,
84 | output_hidden_states=None,
85 | return_dict=None,
86 | ):
87 | self.config.output_attentions = True
88 |
89 | output_hidden_states = (
90 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
91 | )
92 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
93 |
94 | if attention_mask is not None:
95 | # compute reduced attention_mask corresponding to feature vectors
96 | attention_mask = self._get_feature_vector_attention_mask(
97 | extract_features.shape[1], attention_mask, add_adapter=False
98 | )
99 |
100 |
101 | hidden_states, extract_features = self.feature_projection(extract_features)
102 | hidden_states = self._mask_hidden_states(
103 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
104 | )
105 |
106 | encoder_outputs = self.encoder(
107 | hidden_states,
108 | attention_mask=attention_mask,
109 | output_attentions=output_attentions,
110 | output_hidden_states=output_hidden_states,
111 | return_dict=return_dict,
112 | )
113 |
114 | hidden_states = encoder_outputs[0]
115 |
116 | if self.adapter is not None:
117 | hidden_states = self.adapter(hidden_states)
118 |
119 | if not return_dict:
120 | return (hidden_states, ) + encoder_outputs[1:]
121 | return BaseModelOutput(
122 | last_hidden_state=hidden_states,
123 | hidden_states=encoder_outputs.hidden_states,
124 | attentions=encoder_outputs.attentions,
125 | )
126 |
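127 | # Usage note (added): forward() runs CNN feature extraction and transformer encoding in one pass,
128 | # while feature_extract() / encode() expose the same two stages separately so features can be cached.
129 | # In both paths the CNN features are linearly interpolated to seq_len so that one hidden state lines
130 | # up with each target video frame.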
--------------------------------------------------------------------------------
/src/dataset/dataset_face.py:
--------------------------------------------------------------------------------
1 | import os, io, csv, math, random, pdb
2 | import cv2
3 | import numpy as np
4 | import json
5 | from PIL import Image
6 | from einops import rearrange
7 |
8 | import torch
9 | import torchvision.transforms as transforms
10 | from torch.utils.data.dataset import Dataset
11 | from transformers import CLIPImageProcessor
12 | import torch.distributed as dist
13 |
14 |
15 | from src.utils.draw_util import FaceMeshVisualizer
16 |
17 | def zero_rank_print(s):
18 |     if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
19 |
20 |
21 |
22 | class FaceDatasetValid(Dataset):
23 | def __init__(
24 | self,
25 | json_path,
26 | extra_json_path=None,
27 | sample_size=[512, 512], sample_stride=4, sample_n_frames=16,
28 | is_image=False,
29 | sample_stride_aug=False
30 | ):
31 | zero_rank_print(f"loading annotations from {json_path} ...")
32 | self.data_dic_name_list, self.data_dic = self.get_data(json_path, extra_json_path)
33 |
34 | self.length = len(self.data_dic_name_list)
35 | zero_rank_print(f"data scale: {self.length}")
36 |
37 | self.sample_stride = sample_stride
38 | self.sample_n_frames = sample_n_frames
39 |
40 | self.sample_stride_aug = sample_stride_aug
41 |
42 | self.sample_size = sample_size
43 | self.resize = transforms.Resize((sample_size[0], sample_size[1]))
44 |
45 |
46 | self.pixel_transforms = transforms.Compose([
47 | transforms.Resize([sample_size[1], sample_size[0]]),
48 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
49 | ])
50 |
51 | self.visualizer = FaceMeshVisualizer(forehead_edge=False)
52 | self.clip_image_processor = CLIPImageProcessor()
53 | self.is_image = is_image
54 |
55 | def get_data(self, json_name, extra_json_name, augment_num=1):
56 | zero_rank_print(f"start loading data: {json_name}")
57 | with open(json_name,'r') as f:
58 | data_dic = json.load(f)
59 |
60 | data_dic_name_list = []
61 | for augment_index in range(augment_num):
62 | for video_name in data_dic.keys():
63 | data_dic_name_list.append(video_name)
64 |
65 | invalid_video_name_list = []
66 | for video_name in data_dic_name_list:
67 | video_clip_num = len(data_dic[video_name]['clip_data_list'])
68 | if video_clip_num < 1:
69 | invalid_video_name_list.append(video_name)
70 | for name in invalid_video_name_list:
71 | data_dic_name_list.remove(name)
72 |
73 |
74 | if extra_json_name is not None:
75 | zero_rank_print(f"start loading data: {extra_json_name}")
76 | with open(extra_json_name,'r') as f:
77 | extra_data_dic = json.load(f)
78 | data_dic.update(extra_data_dic)
79 | for augment_index in range(3*augment_num):
80 | for video_name in extra_data_dic.keys():
81 | data_dic_name_list.append(video_name)
82 | random.shuffle(data_dic_name_list)
83 | zero_rank_print("finish loading")
84 | return data_dic_name_list, data_dic
85 |
86 | def __len__(self):
87 | return len(self.data_dic_name_list)
88 |
89 | def get_batch_wo_pose(self, index):
90 | video_name = self.data_dic_name_list[index]
91 | video_clip_num = len(self.data_dic[video_name]['clip_data_list'])
92 |
93 | source_anchor = random.sample(range(video_clip_num), 1)[0]
94 | source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list']
95 | source_mesh2d_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['lmks_list']
96 |
97 | video_length = len(source_image_path_list)
98 |
99 | if self.sample_stride_aug:
100 | tmp_sample_stride = self.sample_stride if random.random() > 0.5 else 4
101 | else:
102 | tmp_sample_stride = self.sample_stride
103 |
104 | if not self.is_image:
105 | clip_length = min(video_length, (self.sample_n_frames - 1) * tmp_sample_stride + 1)
106 | start_idx = random.randint(0, video_length - clip_length)
107 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
108 | else:
109 | batch_index = [random.randint(0, video_length - 1)]
110 |
111 | ref_img_idx = random.randint(0, video_length - 1)
112 |
113 | ref_img = cv2.imread(source_image_path_list[ref_img_idx])
114 | ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)
115 | ref_img = self.contrast_normalization(ref_img)
116 |
117 | ref_mesh2d_clip = np.load(source_mesh2d_path_list[ref_img_idx]).astype(float)
118 | ref_pose_image = self.visualizer.draw_landmarks(self.sample_size, ref_mesh2d_clip, normed=True)
119 |
120 | images = [cv2.imread(source_image_path_list[idx]) for idx in batch_index]
121 | images = [cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB) for bgr_image in images]
122 | image_np = np.array([self.contrast_normalization(img) for img in images])
123 |
124 | pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous()
125 | pixel_values = pixel_values / 255.
126 |
127 | mesh2d_clip = np.array([np.load(source_mesh2d_path_list[idx]).astype(float) for idx in batch_index])
128 |
129 | pixel_values_pose = []
130 | for frame_id in range(mesh2d_clip.shape[0]):
131 | normed_mesh2d = mesh2d_clip[frame_id]
132 |
133 | pose_image = self.visualizer.draw_landmarks(self.sample_size, normed_mesh2d, normed=True)
134 | pixel_values_pose.append(pose_image)
135 | pixel_values_pose = np.array(pixel_values_pose)
136 |
137 | if self.is_image:
138 | pixel_values = pixel_values[0]
139 | pixel_values_pose = pixel_values_pose[0]
140 | image_np = image_np[0]
141 |
142 | return ref_img, pixel_values_pose, image_np, ref_pose_image
143 |
144 | def contrast_normalization(self, image, lower_bound=0, upper_bound=255):
145 | # convert input image to float32
146 | image = image.astype(np.float32)
147 |
148 | # normalize the image
149 | normalized_image = image * (upper_bound - lower_bound) / 255 + lower_bound
150 |
151 | # convert to uint8
152 | normalized_image = normalized_image.astype(np.uint8)
153 |
154 | return normalized_image
155 |
156 | def __getitem__(self, idx):
157 | ref_img, pixel_values_pose, tar_gt, pixel_values_ref_pose = self.get_batch_wo_pose(idx)
158 |
159 | sample = dict(
160 | pixel_values_pose=pixel_values_pose,
161 | ref_img=ref_img,
162 | tar_gt=tar_gt,
163 | pixel_values_ref_pose=pixel_values_ref_pose,
164 | )
165 |
166 | return sample
167 |
168 |
169 |
170 | class FaceDataset(Dataset):
171 | def __init__(
172 | self,
173 | json_path,
174 | extra_json_path=None,
175 | sample_size=[512, 512], sample_stride=4, sample_n_frames=16,
176 | is_image=False,
177 | sample_stride_aug=False
178 | ):
179 | zero_rank_print(f"loading annotations from {json_path} ...")
180 | self.data_dic_name_list, self.data_dic = self.get_data(json_path, extra_json_path)
181 |
182 | self.length = len(self.data_dic_name_list)
183 | zero_rank_print(f"data scale: {self.length}")
184 |
185 | self.sample_stride = sample_stride
186 | self.sample_n_frames = sample_n_frames
187 |
188 | self.sample_stride_aug = sample_stride_aug
189 |
190 | self.sample_size = sample_size
191 | self.resize = transforms.Resize((sample_size[0], sample_size[1]))
192 |
193 |
194 | self.pixel_transforms = transforms.Compose([
195 | transforms.Resize([sample_size[1], sample_size[0]]),
196 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
197 | ])
198 |
199 | self.visualizer = FaceMeshVisualizer(forehead_edge=False)
200 | self.clip_image_processor = CLIPImageProcessor()
201 | self.is_image = is_image
202 |
203 | def get_data(self, json_name, extra_json_name, augment_num=1):
204 | zero_rank_print(f"start loading data: {json_name}")
205 | with open(json_name,'r') as f:
206 | data_dic = json.load(f)
207 |
208 | data_dic_name_list = []
209 | for augment_index in range(augment_num):
210 | for video_name in data_dic.keys():
211 | data_dic_name_list.append(video_name)
212 |
213 | invalid_video_name_list = []
214 | for video_name in data_dic_name_list:
215 | video_clip_num = len(data_dic[video_name]['clip_data_list'])
216 | if video_clip_num < 1:
217 | invalid_video_name_list.append(video_name)
218 | for name in invalid_video_name_list:
219 | data_dic_name_list.remove(name)
220 |
221 |
222 | if extra_json_name is not None:
223 | zero_rank_print(f"start loading data: {extra_json_name}")
224 | with open(extra_json_name,'r') as f:
225 | extra_data_dic = json.load(f)
226 | data_dic.update(extra_data_dic)
227 | for augment_index in range(3*augment_num):
228 | for video_name in extra_data_dic.keys():
229 | data_dic_name_list.append(video_name)
230 | random.shuffle(data_dic_name_list)
231 | zero_rank_print("finish loading")
232 | return data_dic_name_list, data_dic
233 |
234 | def __len__(self):
235 | return len(self.data_dic_name_list)
236 |
237 |
238 | def get_batch_wo_pose(self, index):
239 | video_name = self.data_dic_name_list[index]
240 | video_clip_num = len(self.data_dic[video_name]['clip_data_list'])
241 |
242 | source_anchor = random.sample(range(video_clip_num), 1)[0]
243 | source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list']
244 | source_mesh2d_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['lmks_list']
245 |
246 | video_length = len(source_image_path_list)
247 |
248 | if self.sample_stride_aug:
249 | tmp_sample_stride = self.sample_stride if random.random() > 0.5 else 4
250 | else:
251 | tmp_sample_stride = self.sample_stride
252 |
253 | if not self.is_image:
254 | clip_length = min(video_length, (self.sample_n_frames - 1) * tmp_sample_stride + 1)
255 | start_idx = random.randint(0, video_length - clip_length)
256 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
257 | else:
258 | batch_index = [random.randint(0, video_length - 1)]
259 |
260 | ref_img_idx = random.randint(0, video_length - 1)
261 | ref_img = cv2.imread(source_image_path_list[ref_img_idx])
262 | ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)
263 | ref_img = self.contrast_normalization(ref_img)
264 | ref_img_pil = Image.fromarray(ref_img)
265 |
266 | clip_ref_image = self.clip_image_processor(images=ref_img_pil, return_tensors="pt").pixel_values
267 |
268 | pixel_values_ref_img = torch.from_numpy(ref_img).permute(2, 0, 1).contiguous()
269 | pixel_values_ref_img = pixel_values_ref_img / 255.
270 |
271 | ref_mesh2d_clip = np.load(source_mesh2d_path_list[ref_img_idx]).astype(float)
272 | ref_pose_image = self.visualizer.draw_landmarks(self.sample_size, ref_mesh2d_clip, normed=True)
273 | pixel_values_ref_pose = torch.from_numpy(ref_pose_image).permute(2, 0, 1).contiguous()
274 | pixel_values_ref_pose = pixel_values_ref_pose / 255.
275 |
276 | images = [cv2.imread(source_image_path_list[idx]) for idx in batch_index]
277 | images = [cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB) for bgr_image in images]
278 | image_np = np.array([self.contrast_normalization(img) for img in images])
279 |
280 | pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous()
281 | pixel_values = pixel_values / 255.
282 |
283 | mesh2d_clip = np.array([np.load(source_mesh2d_path_list[idx]).astype(float) for idx in batch_index])
284 |
285 | pixel_values_pose = []
286 | for frame_id in range(mesh2d_clip.shape[0]):
287 | normed_mesh2d = mesh2d_clip[frame_id]
288 |
289 | pose_image = self.visualizer.draw_landmarks(self.sample_size, normed_mesh2d, normed=True)
290 |
291 | pixel_values_pose.append(pose_image)
292 |
293 | pixel_values_pose = np.array(pixel_values_pose)
294 | pixel_values_pose = torch.from_numpy(pixel_values_pose).permute(0, 3, 1, 2).contiguous()
295 | pixel_values_pose = pixel_values_pose / 255.
296 |
297 | if self.is_image:
298 | pixel_values = pixel_values[0]
299 | pixel_values_pose = pixel_values_pose[0]
300 |
301 | return pixel_values, pixel_values_pose, clip_ref_image, pixel_values_ref_img, pixel_values_ref_pose
302 |
303 | def contrast_normalization(self, image, lower_bound=0, upper_bound=255):
304 | image = image.astype(np.float32)
305 | normalized_image = image * (upper_bound - lower_bound) / 255 + lower_bound
306 | normalized_image = normalized_image.astype(np.uint8)
307 |
308 | return normalized_image
309 |
310 | def __getitem__(self, idx):
311 | pixel_values, pixel_values_pose, clip_ref_image, pixel_values_ref_img, pixel_values_ref_pose = self.get_batch_wo_pose(idx)
312 |
313 | pixel_values = self.pixel_transforms(pixel_values)
314 | pixel_values_pose = self.pixel_transforms(pixel_values_pose)
315 |
316 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
317 | pixel_values_ref_img = self.pixel_transforms(pixel_values_ref_img)
318 | pixel_values_ref_img = pixel_values_ref_img.squeeze(0)
319 |
320 | pixel_values_ref_pose = pixel_values_ref_pose.unsqueeze(0)
321 | pixel_values_ref_pose = self.pixel_transforms(pixel_values_ref_pose)
322 | pixel_values_ref_pose = pixel_values_ref_pose.squeeze(0)
323 |
324 | drop_image_embeds = 1 if random.random() < 0.1 else 0
325 |
326 | sample = dict(
327 | pixel_values=pixel_values,
328 | pixel_values_pose=pixel_values_pose,
329 | clip_ref_image=clip_ref_image,
330 | pixel_values_ref_img=pixel_values_ref_img,
331 | drop_image_embeds=drop_image_embeds,
332 | pixel_values_ref_pose=pixel_values_ref_pose,
333 | )
334 |
335 | return sample
336 |
337 | def collate_fn(data):
338 | pixel_values = torch.stack([example["pixel_values"] for example in data])
339 | pixel_values_pose = torch.stack([example["pixel_values_pose"] for example in data])
340 | clip_ref_image = torch.cat([example["clip_ref_image"] for example in data])
341 | pixel_values_ref_img = torch.stack([example["pixel_values_ref_img"] for example in data])
342 | drop_image_embeds = [example["drop_image_embeds"] for example in data]
343 | drop_image_embeds = torch.Tensor(drop_image_embeds)
344 | pixel_values_ref_pose = torch.stack([example["pixel_values_ref_pose"] for example in data])
345 |
346 | return {
347 | "pixel_values": pixel_values,
348 | "pixel_values_pose": pixel_values_pose,
349 | "clip_ref_image": clip_ref_image,
350 | "pixel_values_ref_img": pixel_values_ref_img,
351 | "drop_image_embeds": drop_image_embeds,
352 | "pixel_values_ref_pose": pixel_values_ref_pose,
353 | }
354 |
355 |
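356 | # DataLoader sketch (added; the json path is a placeholder):
357 | # from torch.utils.data import DataLoader
358 | # train_dataset = FaceDataset(json_path='path/to/train.json')
359 | # train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
360 | # Each batch contains pixel_values, pixel_values_pose, clip_ref_image, pixel_values_ref_img,
361 | # drop_image_embeds and pixel_values_ref_pose, as assembled by collate_fn above.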
--------------------------------------------------------------------------------
/src/models/motion_module.py:
--------------------------------------------------------------------------------
1 | # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2 | import math
3 | from dataclasses import dataclass
4 | from typing import Callable, Optional
5 |
6 | import torch
7 | from diffusers.models.attention import FeedForward
8 | from diffusers.models.attention_processor import Attention, AttnProcessor
9 | from diffusers.utils import BaseOutput
10 | from diffusers.utils.import_utils import is_xformers_available
11 | from einops import rearrange, repeat
12 | from torch import nn
13 |
14 |
15 | def zero_module(module):
16 | # Zero out the parameters of a module and return it.
17 | for p in module.parameters():
18 | p.detach().zero_()
19 | return module
20 |
21 |
22 | @dataclass
23 | class TemporalTransformer3DModelOutput(BaseOutput):
24 | sample: torch.FloatTensor
25 |
26 |
27 | if is_xformers_available():
28 | import xformers
29 | import xformers.ops
30 | else:
31 | xformers = None
32 |
33 |
34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35 | if motion_module_type == "Vanilla":
36 | return VanillaTemporalModule(
37 | in_channels=in_channels,
38 | **motion_module_kwargs,
39 | )
40 | else:
41 | raise ValueError
42 |
43 |
44 | class VanillaTemporalModule(nn.Module):
45 | def __init__(
46 | self,
47 | in_channels,
48 | num_attention_heads=8,
49 | num_transformer_block=2,
50 | attention_block_types=("Temporal_Self", "Temporal_Self"),
51 | cross_frame_attention_mode=None,
52 | temporal_position_encoding=False,
53 | temporal_position_encoding_max_len=24,
54 | temporal_attention_dim_div=1,
55 | zero_initialize=True,
56 | ):
57 | super().__init__()
58 |
59 | self.temporal_transformer = TemporalTransformer3DModel(
60 | in_channels=in_channels,
61 | num_attention_heads=num_attention_heads,
62 | attention_head_dim=in_channels
63 | // num_attention_heads
64 | // temporal_attention_dim_div,
65 | num_layers=num_transformer_block,
66 | attention_block_types=attention_block_types,
67 | cross_frame_attention_mode=cross_frame_attention_mode,
68 | temporal_position_encoding=temporal_position_encoding,
69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70 | )
71 |
72 | if zero_initialize:
73 | self.temporal_transformer.proj_out = zero_module(
74 | self.temporal_transformer.proj_out
75 | )
76 |
77 | def forward(
78 | self,
79 | input_tensor,
80 | temb,
81 | encoder_hidden_states,
82 | attention_mask=None,
83 | anchor_frame_idx=None,
84 | ):
85 | hidden_states = input_tensor
86 | hidden_states = self.temporal_transformer(
87 | hidden_states, encoder_hidden_states, attention_mask
88 | )
89 |
90 | output = hidden_states
91 | return output
92 |
93 |
94 | class TemporalTransformer3DModel(nn.Module):
95 | def __init__(
96 | self,
97 | in_channels,
98 | num_attention_heads,
99 | attention_head_dim,
100 | num_layers,
101 | attention_block_types=(
102 | "Temporal_Self",
103 | "Temporal_Self",
104 | ),
105 | dropout=0.0,
106 | norm_num_groups=32,
107 | cross_attention_dim=768,
108 | activation_fn="geglu",
109 | attention_bias=False,
110 | upcast_attention=False,
111 | cross_frame_attention_mode=None,
112 | temporal_position_encoding=False,
113 | temporal_position_encoding_max_len=24,
114 | ):
115 | super().__init__()
116 |
117 | inner_dim = num_attention_heads * attention_head_dim
118 |
119 | self.norm = torch.nn.GroupNorm(
120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121 | )
122 | self.proj_in = nn.Linear(in_channels, inner_dim)
123 |
124 | self.transformer_blocks = nn.ModuleList(
125 | [
126 | TemporalTransformerBlock(
127 | dim=inner_dim,
128 | num_attention_heads=num_attention_heads,
129 | attention_head_dim=attention_head_dim,
130 | attention_block_types=attention_block_types,
131 | dropout=dropout,
132 | norm_num_groups=norm_num_groups,
133 | cross_attention_dim=cross_attention_dim,
134 | activation_fn=activation_fn,
135 | attention_bias=attention_bias,
136 | upcast_attention=upcast_attention,
137 | cross_frame_attention_mode=cross_frame_attention_mode,
138 | temporal_position_encoding=temporal_position_encoding,
139 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140 | )
141 | for d in range(num_layers)
142 | ]
143 | )
144 | self.proj_out = nn.Linear(inner_dim, in_channels)
145 |
146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147 | assert (
148 | hidden_states.dim() == 5
149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150 | video_length = hidden_states.shape[2]
151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152 |
153 |         batch, channel, height, width = hidden_states.shape
154 | residual = hidden_states
155 |
156 | hidden_states = self.norm(hidden_states)
157 | inner_dim = hidden_states.shape[1]
158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159 |             batch, height * width, inner_dim
160 | )
161 | hidden_states = self.proj_in(hidden_states)
162 |
163 | # Transformer Blocks
164 | for block in self.transformer_blocks:
165 | hidden_states = block(
166 | hidden_states,
167 | encoder_hidden_states=encoder_hidden_states,
168 | video_length=video_length,
169 | )
170 |
171 | # output
172 | hidden_states = self.proj_out(hidden_states)
173 | hidden_states = (
174 |             hidden_states.reshape(batch, height, width, inner_dim)
175 | .permute(0, 3, 1, 2)
176 | .contiguous()
177 | )
178 |
179 | output = hidden_states + residual
180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181 |
182 | return output
183 |
184 |
185 | class TemporalTransformerBlock(nn.Module):
186 | def __init__(
187 | self,
188 | dim,
189 | num_attention_heads,
190 | attention_head_dim,
191 | attention_block_types=(
192 | "Temporal_Self",
193 | "Temporal_Self",
194 | ),
195 | dropout=0.0,
196 | norm_num_groups=32,
197 | cross_attention_dim=768,
198 | activation_fn="geglu",
199 | attention_bias=False,
200 | upcast_attention=False,
201 | cross_frame_attention_mode=None,
202 | temporal_position_encoding=False,
203 | temporal_position_encoding_max_len=24,
204 | ):
205 | super().__init__()
206 |
207 | attention_blocks = []
208 | norms = []
209 |
210 | for block_name in attention_block_types:
211 | attention_blocks.append(
212 | VersatileAttention(
213 | attention_mode=block_name.split("_")[0],
214 | cross_attention_dim=cross_attention_dim
215 | if block_name.endswith("_Cross")
216 | else None,
217 | query_dim=dim,
218 | heads=num_attention_heads,
219 | dim_head=attention_head_dim,
220 | dropout=dropout,
221 | bias=attention_bias,
222 | upcast_attention=upcast_attention,
223 | cross_frame_attention_mode=cross_frame_attention_mode,
224 | temporal_position_encoding=temporal_position_encoding,
225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226 | )
227 | )
228 | norms.append(nn.LayerNorm(dim))
229 |
230 | self.attention_blocks = nn.ModuleList(attention_blocks)
231 | self.norms = nn.ModuleList(norms)
232 |
233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234 | self.ff_norm = nn.LayerNorm(dim)
235 |
236 | def forward(
237 | self,
238 | hidden_states,
239 | encoder_hidden_states=None,
240 | attention_mask=None,
241 | video_length=None,
242 | ):
243 | for attention_block, norm in zip(self.attention_blocks, self.norms):
244 | norm_hidden_states = norm(hidden_states)
245 | hidden_states = (
246 | attention_block(
247 | norm_hidden_states,
248 | encoder_hidden_states=encoder_hidden_states
249 | if attention_block.is_cross_attention
250 | else None,
251 | video_length=video_length,
252 | )
253 | + hidden_states
254 | )
255 |
256 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257 |
258 | output = hidden_states
259 | return output
260 |
261 |
262 | class PositionalEncoding(nn.Module):
263 | def __init__(self, d_model, dropout=0.0, max_len=24):
264 | super().__init__()
265 | self.dropout = nn.Dropout(p=dropout)
266 | position = torch.arange(max_len).unsqueeze(1)
267 | div_term = torch.exp(
268 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269 | )
270 | pe = torch.zeros(1, max_len, d_model)
271 | pe[0, :, 0::2] = torch.sin(position * div_term)
272 | pe[0, :, 1::2] = torch.cos(position * div_term)
273 | self.register_buffer("pe", pe)
274 |
275 | def forward(self, x):
276 | x = x + self.pe[:, : x.size(1)]
277 | return self.dropout(x)
278 |
279 |
280 | class VersatileAttention(Attention):
281 | def __init__(
282 | self,
283 | attention_mode=None,
284 | cross_frame_attention_mode=None,
285 | temporal_position_encoding=False,
286 | temporal_position_encoding_max_len=24,
287 | *args,
288 | **kwargs,
289 | ):
290 | super().__init__(*args, **kwargs)
291 | assert attention_mode == "Temporal"
292 |
293 | self.attention_mode = attention_mode
294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295 |
296 | self.pos_encoder = (
297 | PositionalEncoding(
298 | kwargs["query_dim"],
299 | dropout=0.0,
300 | max_len=temporal_position_encoding_max_len,
301 | )
302 | if (temporal_position_encoding and attention_mode == "Temporal")
303 | else None
304 | )
305 |
306 | def extra_repr(self):
307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308 |
309 | def set_use_memory_efficient_attention_xformers(
310 | self,
311 | use_memory_efficient_attention_xformers: bool,
312 | attention_op: Optional[Callable] = None,
313 | ):
314 | if use_memory_efficient_attention_xformers:
315 | if not is_xformers_available():
316 | raise ModuleNotFoundError(
317 | (
318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319 | " xformers"
320 | ),
321 | name="xformers",
322 | )
323 | elif not torch.cuda.is_available():
324 | raise ValueError(
325 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326 | " only available for GPU "
327 | )
328 | else:
329 | try:
330 | # Make sure we can run the memory efficient attention
331 | _ = xformers.ops.memory_efficient_attention(
332 | torch.randn((1, 2, 40), device="cuda"),
333 | torch.randn((1, 2, 40), device="cuda"),
334 | torch.randn((1, 2, 40), device="cuda"),
335 | )
336 | except Exception as e:
337 | raise e
338 |
339 | # XFormersAttnProcessor corrupts video generation and was only useful with PyTorch 1.13.
340 | # In PyTorch 2.0.1 the default AttnProcessor behaves the same as XFormersAttnProcessor did in PyTorch 1.13,
341 | # so XFormersAttnProcessor is not needed here.
342 | # processor = XFormersAttnProcessor(
343 | # attention_op=attention_op,
344 | # )
345 | processor = AttnProcessor()
346 | else:
347 | processor = AttnProcessor()
348 |
349 | self.set_processor(processor)
350 |
351 | def forward(
352 | self,
353 | hidden_states,
354 | encoder_hidden_states=None,
355 | attention_mask=None,
356 | video_length=None,
357 | **cross_attention_kwargs,
358 | ):
359 | if self.attention_mode == "Temporal":
360 | d = hidden_states.shape[1] # d means HxW
361 | hidden_states = rearrange(
362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
363 | )
364 |
365 | if self.pos_encoder is not None:
366 | hidden_states = self.pos_encoder(hidden_states)
367 |
368 | encoder_hidden_states = (
369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370 | if encoder_hidden_states is not None
371 | else encoder_hidden_states
372 | )
373 |
374 | else:
375 | raise NotImplementedError
376 |
377 | hidden_states = self.processor(
378 | self,
379 | hidden_states,
380 | encoder_hidden_states=encoder_hidden_states,
381 | attention_mask=attention_mask,
382 | **cross_attention_kwargs,
383 | )
384 |
385 | if self.attention_mode == "Temporal":
386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387 |
388 | return hidden_states
389 |
--------------------------------------------------------------------------------
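
The temporal attention above hinges on a single rearrange trick: spatial positions are folded into the batch axis so that attention runs only along the frame axis, then the fold is undone. A minimal sketch of that round trip, using the same einops patterns as `VersatileAttention.forward` (tensor sizes are illustrative only):

```python
import torch
from einops import rearrange

b, f, d, c = 2, 16, 64, 320      # batch, frames, spatial tokens (H*W), channels
x = torch.randn(b * f, d, c)     # layout inside the spatial transformer: "(b f) d c"

# fold spatial tokens into the batch so attention only sees the frame axis
x_t = rearrange(x, "(b f) d c -> (b d) f c", f=f)
assert x_t.shape == (b * d, f, c)

# ... temporal self-attention would run here on sequences of length f ...

# undo the fold to return to the per-frame spatial layout
x_back = rearrange(x_t, "(b d) f c -> (b f) d c", d=d)
assert torch.equal(x, x_back)
```
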
/src/models/mutual_self_attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2 | from typing import Any, Dict, Optional
3 |
4 | import torch
5 | from einops import rearrange
6 |
7 | from src.models.attention import TemporalBasicTransformerBlock
8 |
9 | from .attention import BasicTransformerBlock
10 |
11 |
12 | def torch_dfs(model: torch.nn.Module):
13 | result = [model]
14 | for child in model.children():
15 | result += torch_dfs(child)
16 | return result
17 |
18 |
19 | class ReferenceAttentionControl:
20 | def __init__(
21 | self,
22 | unet,
23 | mode="write",
24 | do_classifier_free_guidance=False,
25 | attention_auto_machine_weight=float("inf"),
26 | gn_auto_machine_weight=1.0,
27 | style_fidelity=1.0,
28 | reference_attn=True,
29 | reference_adain=False,
30 | fusion_blocks="midup",
31 | batch_size=1,
32 | ) -> None:
33 | # 10. Modify self attention and group norm
34 | self.unet = unet
35 | assert mode in ["read", "write"]
36 | assert fusion_blocks in ["midup", "full"]
37 | self.reference_attn = reference_attn
38 | self.reference_adain = reference_adain
39 | self.fusion_blocks = fusion_blocks
40 | self.register_reference_hooks(
41 | mode,
42 | do_classifier_free_guidance,
43 | attention_auto_machine_weight,
44 | gn_auto_machine_weight,
45 | style_fidelity,
46 | reference_attn,
47 | reference_adain,
48 | fusion_blocks,
49 | batch_size=batch_size,
50 | )
51 |
52 | def register_reference_hooks(
53 | self,
54 | mode,
55 | do_classifier_free_guidance,
56 | attention_auto_machine_weight,
57 | gn_auto_machine_weight,
58 | style_fidelity,
59 | reference_attn,
60 | reference_adain,
61 | dtype=torch.float16,
62 | batch_size=1,
63 | num_images_per_prompt=1,
64 | device=torch.device("cpu"),
65 | fusion_blocks="midup",
66 | ):
67 | MODE = mode
68 | do_classifier_free_guidance = do_classifier_free_guidance
69 | attention_auto_machine_weight = attention_auto_machine_weight
70 | gn_auto_machine_weight = gn_auto_machine_weight
71 | style_fidelity = style_fidelity
72 | reference_attn = reference_attn
73 | reference_adain = reference_adain
74 | fusion_blocks = fusion_blocks
75 | num_images_per_prompt = num_images_per_prompt
76 | dtype = dtype
77 | if do_classifier_free_guidance:
78 | uc_mask = (
79 | torch.Tensor(
80 | [1] * batch_size * num_images_per_prompt * 16
81 | + [0] * batch_size * num_images_per_prompt * 16
82 | )
83 | .to(device)
84 | .bool()
85 | )
86 | else:
87 | uc_mask = (
88 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89 | .to(device)
90 | .bool()
91 | )
92 |
93 | def hacked_basic_transformer_inner_forward(
94 | self,
95 | hidden_states: torch.FloatTensor,
96 | attention_mask: Optional[torch.FloatTensor] = None,
97 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
98 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
99 | timestep: Optional[torch.LongTensor] = None,
100 | cross_attention_kwargs: Dict[str, Any] = None,
101 | class_labels: Optional[torch.LongTensor] = None,
102 | video_length=None,
103 | ):
104 | if self.use_ada_layer_norm: # False
105 | norm_hidden_states = self.norm1(hidden_states, timestep)
106 | elif self.use_ada_layer_norm_zero:
107 | (
108 | norm_hidden_states,
109 | gate_msa,
110 | shift_mlp,
111 | scale_mlp,
112 | gate_mlp,
113 | ) = self.norm1(
114 | hidden_states,
115 | timestep,
116 | class_labels,
117 | hidden_dtype=hidden_states.dtype,
118 | )
119 | else:
120 | norm_hidden_states = self.norm1(hidden_states)
121 |
122 | # 1. Self-Attention
123 | # self.only_cross_attention = False
124 | cross_attention_kwargs = (
125 | cross_attention_kwargs if cross_attention_kwargs is not None else {}
126 | )
127 | if self.only_cross_attention:
128 | attn_output = self.attn1(
129 | norm_hidden_states,
130 | encoder_hidden_states=encoder_hidden_states
131 | if self.only_cross_attention
132 | else None,
133 | attention_mask=attention_mask,
134 | **cross_attention_kwargs,
135 | )
136 | else:
137 | if MODE == "write":
138 | self.bank.append(norm_hidden_states.clone())
139 | attn_output = self.attn1(
140 | norm_hidden_states,
141 | encoder_hidden_states=encoder_hidden_states
142 | if self.only_cross_attention
143 | else None,
144 | attention_mask=attention_mask,
145 | **cross_attention_kwargs,
146 | )
147 | if MODE == "read":
148 | bank_fea = [
149 | rearrange(
150 | d.unsqueeze(1).repeat(1, video_length, 1, 1),
151 | "b t l c -> (b t) l c",
152 | )
153 | for d in self.bank
154 | ]
155 | modify_norm_hidden_states = torch.cat(
156 | [norm_hidden_states] + bank_fea, dim=1
157 | )
158 | hidden_states_uc = (
159 | self.attn1(
160 | norm_hidden_states,
161 | encoder_hidden_states=modify_norm_hidden_states,
162 | attention_mask=attention_mask,
163 | )
164 | + hidden_states
165 | )
166 | if do_classifier_free_guidance:
167 | hidden_states_c = hidden_states_uc.clone()
168 | _uc_mask = uc_mask.clone()
169 | if hidden_states.shape[0] != _uc_mask.shape[0]:
170 | _uc_mask = (
171 | torch.Tensor(
172 | [1] * (hidden_states.shape[0] // 2)
173 | + [0] * (hidden_states.shape[0] // 2)
174 | )
175 | .to(device)
176 | .bool()
177 | )
178 | hidden_states_c[_uc_mask] = (
179 | self.attn1(
180 | norm_hidden_states[_uc_mask],
181 | encoder_hidden_states=norm_hidden_states[_uc_mask],
182 | attention_mask=attention_mask,
183 | )
184 | + hidden_states[_uc_mask]
185 | )
186 | hidden_states = hidden_states_c.clone()
187 | else:
188 | hidden_states = hidden_states_uc
189 |
190 | # self.bank.clear()
191 | if self.attn2 is not None:
192 | # Cross-Attention
193 | norm_hidden_states = (
194 | self.norm2(hidden_states, timestep)
195 | if self.use_ada_layer_norm
196 | else self.norm2(hidden_states)
197 | )
198 | hidden_states = (
199 | self.attn2(
200 | norm_hidden_states,
201 | encoder_hidden_states=encoder_hidden_states,
202 | attention_mask=attention_mask,
203 | )
204 | + hidden_states
205 | )
206 |
207 | # Feed-forward
208 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
209 |
210 | # Temporal-Attention
211 | if self.unet_use_temporal_attention:
212 | d = hidden_states.shape[1]
213 | hidden_states = rearrange(
214 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
215 | )
216 | norm_hidden_states = (
217 | self.norm_temp(hidden_states, timestep)
218 | if self.use_ada_layer_norm
219 | else self.norm_temp(hidden_states)
220 | )
221 | hidden_states = (
222 | self.attn_temp(norm_hidden_states) + hidden_states
223 | )
224 | hidden_states = rearrange(
225 | hidden_states, "(b d) f c -> (b f) d c", d=d
226 | )
227 |
228 | return hidden_states
229 |
230 | if self.use_ada_layer_norm_zero:
231 | attn_output = gate_msa.unsqueeze(1) * attn_output
232 | hidden_states = attn_output + hidden_states
233 |
234 | if self.attn2 is not None:
235 | norm_hidden_states = (
236 | self.norm2(hidden_states, timestep)
237 | if self.use_ada_layer_norm
238 | else self.norm2(hidden_states)
239 | )
240 |
241 | # 2. Cross-Attention
242 | attn_output = self.attn2(
243 | norm_hidden_states,
244 | encoder_hidden_states=encoder_hidden_states,
245 | attention_mask=encoder_attention_mask,
246 | **cross_attention_kwargs,
247 | )
248 | hidden_states = attn_output + hidden_states
249 |
250 | # 3. Feed-forward
251 | norm_hidden_states = self.norm3(hidden_states)
252 |
253 | if self.use_ada_layer_norm_zero:
254 | norm_hidden_states = (
255 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
256 | )
257 |
258 | ff_output = self.ff(norm_hidden_states)
259 |
260 | if self.use_ada_layer_norm_zero:
261 | ff_output = gate_mlp.unsqueeze(1) * ff_output
262 |
263 | hidden_states = ff_output + hidden_states
264 |
265 | return hidden_states
266 |
267 | if self.reference_attn:
268 | if self.fusion_blocks == "midup":
269 | attn_modules = [
270 | module
271 | for module in (
272 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
273 | )
274 | if isinstance(module, BasicTransformerBlock)
275 | or isinstance(module, TemporalBasicTransformerBlock)
276 | ]
277 | elif self.fusion_blocks == "full":
278 | attn_modules = [
279 | module
280 | for module in torch_dfs(self.unet)
281 | if isinstance(module, BasicTransformerBlock)
282 | or isinstance(module, TemporalBasicTransformerBlock)
283 | ]
284 | attn_modules = sorted(
285 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
286 | )
287 |
288 | for i, module in enumerate(attn_modules):
289 | module._original_inner_forward = module.forward
290 | if isinstance(module, BasicTransformerBlock):
291 | module.forward = hacked_basic_transformer_inner_forward.__get__(
292 | module, BasicTransformerBlock
293 | )
294 | if isinstance(module, TemporalBasicTransformerBlock):
295 | module.forward = hacked_basic_transformer_inner_forward.__get__(
296 | module, TemporalBasicTransformerBlock
297 | )
298 |
299 | module.bank = []
300 | module.attn_weight = float(i) / float(len(attn_modules))
301 |
302 | def update(self, writer, dtype=torch.float16):
303 | if self.reference_attn:
304 | if self.fusion_blocks == "midup":
305 | reader_attn_modules = [
306 | module
307 | for module in (
308 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
309 | )
310 | if isinstance(module, TemporalBasicTransformerBlock)
311 | ]
312 | writer_attn_modules = [
313 | module
314 | for module in (
315 | torch_dfs(writer.unet.mid_block)
316 | + torch_dfs(writer.unet.up_blocks)
317 | )
318 | if isinstance(module, BasicTransformerBlock)
319 | ]
320 | elif self.fusion_blocks == "full":
321 | reader_attn_modules = [
322 | module
323 | for module in torch_dfs(self.unet)
324 | if isinstance(module, TemporalBasicTransformerBlock)
325 | ]
326 | writer_attn_modules = [
327 | module
328 | for module in torch_dfs(writer.unet)
329 | if isinstance(module, BasicTransformerBlock)
330 | ]
331 | reader_attn_modules = sorted(
332 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
333 | )
334 | writer_attn_modules = sorted(
335 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
336 | )
337 | for r, w in zip(reader_attn_modules, writer_attn_modules):
338 | r.bank = [v.clone().to(dtype) for v in w.bank]
339 | # w.bank.clear()
340 |
341 | def clear(self):
342 | if self.reference_attn:
343 | if self.fusion_blocks == "midup":
344 | reader_attn_modules = [
345 | module
346 | for module in (
347 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
348 | )
349 | if isinstance(module, BasicTransformerBlock)
350 | or isinstance(module, TemporalBasicTransformerBlock)
351 | ]
352 | elif self.fusion_blocks == "full":
353 | reader_attn_modules = [
354 | module
355 | for module in torch_dfs(self.unet)
356 | if isinstance(module, BasicTransformerBlock)
357 | or isinstance(module, TemporalBasicTransformerBlock)
358 | ]
359 | reader_attn_modules = sorted(
360 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
361 | )
362 | for r in reader_attn_modules:
363 | r.bank.clear()
364 |
--------------------------------------------------------------------------------
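
`ReferenceAttentionControl` is used in pairs: a "write" controller hooked into the reference UNet caches each self-attention's hidden states, and a "read" controller on the denoising UNet concatenates those cached features into its own self-attention. The pose-to-image pipeline below follows exactly this pattern; a condensed sketch (the two UNets are assumed to be already constructed):

```python
# writer caches reference features, reader consumes them
writer = ReferenceAttentionControl(
    reference_unet, mode="write", do_classifier_free_guidance=True, fusion_blocks="full"
)
reader = ReferenceAttentionControl(
    denoising_unet, mode="read", do_classifier_free_guidance=True, fusion_blocks="full"
)

# 1. one forward pass of the reference UNet fills the attention banks ("write")
# 2. copy the cached features into the denoising UNet ("read")
reader.update(writer)
# 3. run the denoising steps; every hooked self-attention now also attends to the reference
# 4. drop the cached features before the next sample
reader.clear()
writer.clear()
```
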
/src/models/pose_guider.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 | from einops import rearrange
6 | import numpy as np
7 | from diffusers.models.modeling_utils import ModelMixin
8 |
9 | from typing import Any, Dict, Optional
10 | from src.models.attention import BasicTransformerBlock
11 |
12 |
13 | class PoseGuider(ModelMixin):
14 | def __init__(self, noise_latent_channels=320, use_ca=True):
15 | super(PoseGuider, self).__init__()
16 |
17 | self.use_ca = use_ca
18 |
19 | self.conv_layers = nn.Sequential(
20 | nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
21 | nn.BatchNorm2d(3),
22 | nn.ReLU(),
23 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
24 | nn.BatchNorm2d(16),
25 | nn.ReLU(),
26 |
27 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
28 | nn.BatchNorm2d(16),
29 | nn.ReLU(),
30 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
31 | nn.BatchNorm2d(32),
32 | nn.ReLU(),
33 |
34 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
35 | nn.BatchNorm2d(32),
36 | nn.ReLU(),
37 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
38 | nn.BatchNorm2d(64),
39 | nn.ReLU(),
40 |
41 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
42 | nn.BatchNorm2d(64),
43 | nn.ReLU(),
44 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
45 | nn.BatchNorm2d(128),
46 | nn.ReLU()
47 | )
48 |
49 | # Final projection layer
50 | self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)
51 |
52 | self.conv_layers_1 = nn.Sequential(
53 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, padding=1),
54 | nn.BatchNorm2d(noise_latent_channels),
55 | nn.ReLU(),
56 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, stride=2, padding=1),
57 | nn.BatchNorm2d(noise_latent_channels),
58 | nn.ReLU(),
59 | )
60 |
61 | self.conv_layers_2 = nn.Sequential(
62 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, padding=1),
63 | nn.BatchNorm2d(noise_latent_channels),
64 | nn.ReLU(),
65 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels*2, kernel_size=3, stride=2, padding=1),
66 | nn.BatchNorm2d(noise_latent_channels*2),
67 | nn.ReLU(),
68 | )
69 |
70 | self.conv_layers_3 = nn.Sequential(
71 | nn.Conv2d(in_channels=noise_latent_channels*2, out_channels=noise_latent_channels*2, kernel_size=3, padding=1),
72 | nn.BatchNorm2d(noise_latent_channels*2),
73 | nn.ReLU(),
74 | nn.Conv2d(in_channels=noise_latent_channels*2, out_channels=noise_latent_channels*4, kernel_size=3, stride=2, padding=1),
75 | nn.BatchNorm2d(noise_latent_channels*4),
76 | nn.ReLU(),
77 | )
78 |
79 | self.conv_layers_4 = nn.Sequential(
80 | nn.Conv2d(in_channels=noise_latent_channels*4, out_channels=noise_latent_channels*4, kernel_size=3, padding=1),
81 | nn.BatchNorm2d(noise_latent_channels*4),
82 | nn.ReLU(),
83 | )
84 |
85 | if self.use_ca:
86 | self.cross_attn1 = Transformer2DModel(in_channels=noise_latent_channels)
87 | self.cross_attn2 = Transformer2DModel(in_channels=noise_latent_channels*2)
88 | self.cross_attn3 = Transformer2DModel(in_channels=noise_latent_channels*4)
89 | self.cross_attn4 = Transformer2DModel(in_channels=noise_latent_channels*4)
90 |
91 | # Initialize layers
92 | self._initialize_weights()
93 |
94 | self.scale = nn.Parameter(torch.ones(1) * 2)
95 |
96 | # def _initialize_weights(self):
97 | # # Initialize weights with Gaussian distribution and zero out the final layer
98 | # for m in self.conv_layers:
99 | # if isinstance(m, nn.Conv2d):
100 | # init.normal_(m.weight, mean=0.0, std=0.02)
101 | # if m.bias is not None:
102 | # init.zeros_(m.bias)
103 |
104 | # init.zeros_(self.final_proj.weight)
105 | # if self.final_proj.bias is not None:
106 | # init.zeros_(self.final_proj.bias)
107 |
108 | def _initialize_weights(self):
109 | # Initialize weights with He initialization and zero out the biases
110 | conv_blocks = [self.conv_layers, self.conv_layers_1, self.conv_layers_2, self.conv_layers_3, self.conv_layers_4]
111 | for block_item in conv_blocks:
112 | for m in block_item:
113 | if isinstance(m, nn.Conv2d):
114 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
115 | init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n))
116 | if m.bias is not None:
117 | init.zeros_(m.bias)
118 |
119 | # For the final projection layer, initialize weights to zero (or you may choose to use He initialization here as well)
120 | init.zeros_(self.final_proj.weight)
121 | if self.final_proj.bias is not None:
122 | init.zeros_(self.final_proj.bias)
123 |
124 | def forward(self, x, ref_x):
125 | fea = []
126 | b = x.shape[0]
127 |
128 | x = rearrange(x, "b c f h w -> (b f) c h w")
129 | x = self.conv_layers(x)
130 | x = self.final_proj(x)
131 | x = x * self.scale
132 | # x = rearrange(x, "(b f) c h w -> b c f h w", b=b)
133 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
134 |
135 | x = self.conv_layers_1(x)
136 | if self.use_ca:
137 | ref_x = self.conv_layers(ref_x)
138 | ref_x = self.final_proj(ref_x)
139 | ref_x = ref_x * self.scale
140 | ref_x = self.conv_layers_1(ref_x)
141 | x = self.cross_attn1(x, ref_x)
142 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
143 |
144 | x = self.conv_layers_2(x)
145 | if self.use_ca:
146 | ref_x = self.conv_layers_2(ref_x)
147 | x = self.cross_attn2(x, ref_x)
148 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
149 |
150 | x = self.conv_layers_3(x)
151 | if self.use_ca:
152 | ref_x = self.conv_layers_3(ref_x)
153 | x = self.cross_attn3(x, ref_x)
154 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
155 |
156 | x = self.conv_layers_4(x)
157 | if self.use_ca:
158 | ref_x = self.conv_layers_4(ref_x)
159 | x = self.cross_attn4(x, ref_x)
160 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
161 |
162 | return fea
163 |
164 | @classmethod
165 | def from_pretrained(cls,pretrained_model_path):
166 | if not os.path.exists(pretrained_model_path):
167 | raise FileNotFoundError(f"There is no model file in {pretrained_model_path}")
168 | print(f"loading PoseGuider's pretrained weights from {pretrained_model_path} ...")
169 | 
170 | state_dict = torch.load(pretrained_model_path, map_location="cpu")
171 | model = cls(noise_latent_channels=320)
172 |
173 | m, u = model.load_state_dict(state_dict, strict=True)
174 | # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
175 | params = [p.numel() for n, p in model.named_parameters()]
176 | print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")
177 |
178 | return model
179 |
180 |
181 | class Transformer2DModel(ModelMixin):
182 | _supports_gradient_checkpointing = True
183 | def __init__(
184 | self,
185 | num_attention_heads: int = 16,
186 | attention_head_dim: int = 88,
187 | in_channels: Optional[int] = None,
188 | num_layers: int = 1,
189 | dropout: float = 0.0,
190 | norm_num_groups: int = 32,
191 | cross_attention_dim: Optional[int] = None,
192 | attention_bias: bool = False,
193 | activation_fn: str = "geglu",
194 | num_embeds_ada_norm: Optional[int] = None,
195 | use_linear_projection: bool = False,
196 | only_cross_attention: bool = False,
197 | double_self_attention: bool = False,
198 | upcast_attention: bool = False,
199 | norm_type: str = "layer_norm",
200 | norm_elementwise_affine: bool = True,
201 | norm_eps: float = 1e-5,
202 | attention_type: str = "default",
203 | ):
204 | super().__init__()
205 | self.use_linear_projection = use_linear_projection
206 | self.num_attention_heads = num_attention_heads
207 | self.attention_head_dim = attention_head_dim
208 | inner_dim = num_attention_heads * attention_head_dim
209 |
210 | self.in_channels = in_channels
211 |
212 | self.norm = torch.nn.GroupNorm(
213 | num_groups=norm_num_groups,
214 | num_channels=in_channels,
215 | eps=1e-6,
216 | affine=True,
217 | )
218 | if use_linear_projection:
219 | self.proj_in = nn.Linear(in_channels, inner_dim)
220 | else:
221 | self.proj_in = nn.Conv2d(
222 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
223 | )
224 |
225 | # 3. Define transformers blocks
226 | self.transformer_blocks = nn.ModuleList(
227 | [
228 | BasicTransformerBlock(
229 | inner_dim,
230 | num_attention_heads,
231 | attention_head_dim,
232 | dropout=dropout,
233 | cross_attention_dim=cross_attention_dim,
234 | activation_fn=activation_fn,
235 | num_embeds_ada_norm=num_embeds_ada_norm,
236 | attention_bias=attention_bias,
237 | only_cross_attention=only_cross_attention,
238 | double_self_attention=double_self_attention,
239 | upcast_attention=upcast_attention,
240 | norm_type=norm_type,
241 | norm_elementwise_affine=norm_elementwise_affine,
242 | norm_eps=norm_eps,
243 | attention_type=attention_type,
244 | )
245 | for d in range(num_layers)
246 | ]
247 | )
248 |
249 | if use_linear_projection:
250 | self.proj_out = nn.Linear(inner_dim, in_channels)
251 | else:
252 | self.proj_out = nn.Conv2d(
253 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
254 | )
255 |
256 | self.gradient_checkpointing = False
257 |
258 | def _set_gradient_checkpointing(self, module, value=False):
259 | if hasattr(module, "gradient_checkpointing"):
260 | module.gradient_checkpointing = value
261 |
262 | def forward(
263 | self,
264 | hidden_states: torch.Tensor,
265 | encoder_hidden_states: Optional[torch.Tensor] = None,
266 | timestep: Optional[torch.LongTensor] = None,
267 | ):
268 | batch, _, height, width = hidden_states.shape
269 | residual = hidden_states
270 |
271 | hidden_states = self.norm(hidden_states)
272 | if not self.use_linear_projection:
273 | hidden_states = self.proj_in(hidden_states)
274 | inner_dim = hidden_states.shape[1]
275 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
276 | batch, height * width, inner_dim
277 | )
278 | else:
279 | inner_dim = hidden_states.shape[1]
280 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
281 | batch, height * width, inner_dim
282 | )
283 | hidden_states = self.proj_in(hidden_states)
284 |
285 | for block in self.transformer_blocks:
286 | hidden_states = block(
287 | hidden_states,
288 | encoder_hidden_states=encoder_hidden_states,
289 | timestep=timestep,
290 | )
291 |
292 | if not self.use_linear_projection:
293 | hidden_states = (
294 | hidden_states.reshape(batch, height, width, inner_dim)
295 | .permute(0, 3, 1, 2)
296 | .contiguous()
297 | )
298 | hidden_states = self.proj_out(hidden_states)
299 | else:
300 | hidden_states = self.proj_out(hidden_states)
301 | hidden_states = (
302 | hidden_states.reshape(batch, height, width, inner_dim)
303 | .permute(0, 3, 1, 2)
304 | .contiguous()
305 | )
306 |
307 | output = hidden_states + residual
308 | return output
309 |
310 |
311 | if __name__ == '__main__':
312 | model = PoseGuider(noise_latent_channels=320).to(device="cuda")
313 |
314 | input_data = torch.randn(1,3,1,512,512).to(device="cuda")
315 | input_data1 = torch.randn(1,3,512,512).to(device="cuda")
316 |
317 | output = model(input_data, input_data1)
318 | for item in output:
319 | print(item.shape)
320 |
321 | # tf_model = Transformer2DModel(
322 | # in_channels=320
323 | # ).to('cuda')
324 |
325 | # input_data = torch.randn(4,320,32,32).to(device="cuda")
326 | # # input_emb = torch.randn(4,1,768).to(device="cuda")
327 | # input_emb = torch.randn(4,320,32,32).to(device="cuda")
328 | # o1 = tf_model(input_data, input_emb)
329 | # print(o1.shape)
330 |
--------------------------------------------------------------------------------
/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 | from typing import Dict, Optional
8 |
9 |
10 | class InflatedConv3d(nn.Conv2d):
11 | def forward(self, x):
12 | video_length = x.shape[2]
13 |
14 | x = rearrange(x, "b c f h w -> (b f) c h w")
15 | x = super().forward(x)
16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17 |
18 | return x
19 |
20 |
21 | class InflatedGroupNorm(nn.GroupNorm):
22 | def forward(self, x):
23 | video_length = x.shape[2]
24 |
25 | x = rearrange(x, "b c f h w -> (b f) c h w")
26 | x = super().forward(x)
27 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
28 |
29 | return x
30 |
31 |
32 | class Upsample3D(nn.Module):
33 | def __init__(
34 | self,
35 | channels,
36 | use_conv=False,
37 | use_conv_transpose=False,
38 | out_channels=None,
39 | name="conv",
40 | ):
41 | super().__init__()
42 | self.channels = channels
43 | self.out_channels = out_channels or channels
44 | self.use_conv = use_conv
45 | self.use_conv_transpose = use_conv_transpose
46 | self.name = name
47 |
48 | conv = None
49 | if use_conv_transpose:
50 | raise NotImplementedError
51 | elif use_conv:
52 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
53 |
54 | def forward(self, hidden_states, output_size=None):
55 | assert hidden_states.shape[1] == self.channels
56 |
57 | if self.use_conv_transpose:
58 | raise NotImplementedError
59 |
60 | # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
61 | dtype = hidden_states.dtype
62 | if dtype == torch.bfloat16:
63 | hidden_states = hidden_states.to(torch.float32)
64 |
65 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
66 | if hidden_states.shape[0] >= 64:
67 | hidden_states = hidden_states.contiguous()
68 |
69 | # if `output_size` is passed we force the interpolation output
70 | # size and do not make use of `scale_factor=2`
71 | if output_size is None:
72 | hidden_states = F.interpolate(
73 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
74 | )
75 | else:
76 | hidden_states = F.interpolate(
77 | hidden_states, size=output_size, mode="nearest"
78 | )
79 |
80 | # If the input is bfloat16, we cast back to bfloat16
81 | if dtype == torch.bfloat16:
82 | hidden_states = hidden_states.to(dtype)
83 |
84 | # if self.use_conv:
85 | # if self.name == "conv":
86 | # hidden_states = self.conv(hidden_states)
87 | # else:
88 | # hidden_states = self.Conv2d_0(hidden_states)
89 | hidden_states = self.conv(hidden_states)
90 |
91 | return hidden_states
92 |
93 |
94 | class Downsample3D(nn.Module):
95 | def __init__(
96 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
97 | ):
98 | super().__init__()
99 | self.channels = channels
100 | self.out_channels = out_channels or channels
101 | self.use_conv = use_conv
102 | self.padding = padding
103 | stride = 2
104 | self.name = name
105 |
106 | if use_conv:
107 | self.conv = InflatedConv3d(
108 | self.channels, self.out_channels, 3, stride=stride, padding=padding
109 | )
110 | else:
111 | raise NotImplementedError
112 |
113 | def forward(self, hidden_states):
114 | assert hidden_states.shape[1] == self.channels
115 | if self.use_conv and self.padding == 0:
116 | raise NotImplementedError
117 |
118 | assert hidden_states.shape[1] == self.channels
119 | hidden_states = self.conv(hidden_states)
120 |
121 | return hidden_states
122 |
123 |
124 | class ResnetBlock3D(nn.Module):
125 | def __init__(
126 | self,
127 | *,
128 | in_channels,
129 | out_channels=None,
130 | conv_shortcut=False,
131 | dropout=0.0,
132 | temb_channels=512,
133 | groups=32,
134 | groups_out=None,
135 | pre_norm=True,
136 | eps=1e-6,
137 | non_linearity="swish",
138 | time_embedding_norm="default",
139 | output_scale_factor=1.0,
140 | use_in_shortcut=None,
141 | use_inflated_groupnorm=None,
142 | ):
143 | super().__init__()
144 | self.pre_norm = pre_norm
145 | self.pre_norm = True
146 | self.in_channels = in_channels
147 | out_channels = in_channels if out_channels is None else out_channels
148 | self.out_channels = out_channels
149 | self.use_conv_shortcut = conv_shortcut
150 | self.time_embedding_norm = time_embedding_norm
151 | self.output_scale_factor = output_scale_factor
152 |
153 | if groups_out is None:
154 | groups_out = groups
155 |
156 | assert use_inflated_groupnorm != None
157 | if use_inflated_groupnorm:
158 | self.norm1 = InflatedGroupNorm(
159 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
160 | )
161 | else:
162 | self.norm1 = torch.nn.GroupNorm(
163 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
164 | )
165 |
166 | self.conv1 = InflatedConv3d(
167 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
168 | )
169 |
170 | if temb_channels is not None:
171 | if self.time_embedding_norm == "default":
172 | time_emb_proj_out_channels = out_channels
173 | elif self.time_embedding_norm == "scale_shift":
174 | time_emb_proj_out_channels = out_channels * 2
175 | else:
176 | raise ValueError(
177 | f"unknown time_embedding_norm : {self.time_embedding_norm} "
178 | )
179 |
180 | self.time_emb_proj = torch.nn.Linear(
181 | temb_channels, time_emb_proj_out_channels
182 | )
183 | else:
184 | self.time_emb_proj = None
185 |
186 | if use_inflated_groupnorm:
187 | self.norm2 = InflatedGroupNorm(
188 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
189 | )
190 | else:
191 | self.norm2 = torch.nn.GroupNorm(
192 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
193 | )
194 | self.dropout = torch.nn.Dropout(dropout)
195 | self.conv2 = InflatedConv3d(
196 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
197 | )
198 |
199 | if non_linearity == "swish":
200 | self.nonlinearity = lambda x: F.silu(x)
201 | elif non_linearity == "mish":
202 | self.nonlinearity = Mish()
203 | elif non_linearity == "silu":
204 | self.nonlinearity = nn.SiLU()
205 |
206 | self.use_in_shortcut = (
207 | self.in_channels != self.out_channels
208 | if use_in_shortcut is None
209 | else use_in_shortcut
210 | )
211 |
212 | self.conv_shortcut = None
213 | if self.use_in_shortcut:
214 | self.conv_shortcut = InflatedConv3d(
215 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
216 | )
217 |
218 | def forward(self, input_tensor, temb):
219 | hidden_states = input_tensor
220 |
221 | hidden_states = self.norm1(hidden_states)
222 | hidden_states = self.nonlinearity(hidden_states)
223 |
224 | hidden_states = self.conv1(hidden_states)
225 |
226 | if temb is not None:
227 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
228 |
229 | if temb is not None and self.time_embedding_norm == "default":
230 | hidden_states = hidden_states + temb
231 |
232 | hidden_states = self.norm2(hidden_states)
233 |
234 | if temb is not None and self.time_embedding_norm == "scale_shift":
235 | scale, shift = torch.chunk(temb, 2, dim=1)
236 | hidden_states = hidden_states * (1 + scale) + shift
237 |
238 | hidden_states = self.nonlinearity(hidden_states)
239 |
240 | hidden_states = self.dropout(hidden_states)
241 | hidden_states = self.conv2(hidden_states)
242 |
243 | if self.conv_shortcut is not None:
244 | input_tensor = self.conv_shortcut(input_tensor)
245 |
246 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
247 |
248 | return output_tensor
249 |
250 | class Mish(torch.nn.Module):
251 | def forward(self, hidden_states):
252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
253 |
--------------------------------------------------------------------------------
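
`InflatedConv3d` and `InflatedGroupNorm` are ordinary 2D layers applied frame by frame: the frame axis is folded into the batch, the 2D op runs, and the video layout is restored. A small shape check (sizes chosen arbitrarily):

```python
import torch

from src.models.resnet import InflatedConv3d, InflatedGroupNorm

x = torch.randn(2, 320, 8, 32, 32)    # (batch, channels, frames, height, width)

conv = InflatedConv3d(320, 640, kernel_size=3, padding=1)   # nn.Conv2d signature
norm = InflatedGroupNorm(num_groups=32, num_channels=320)   # nn.GroupNorm signature

print(conv(x).shape)   # torch.Size([2, 640, 8, 32, 32])
print(norm(x).shape)   # torch.Size([2, 320, 8, 32, 32])
```
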
/src/models/transformer_3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Dict
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models import ModelMixin
7 | from diffusers.utils import BaseOutput
8 | from diffusers.utils.import_utils import is_xformers_available
9 | from einops import rearrange, repeat
10 | from torch import nn
11 |
12 | from .attention import TemporalBasicTransformerBlock, ResidualTemporalBasicTransformerBlock
13 |
14 |
15 | @dataclass
16 | class Transformer3DModelOutput(BaseOutput):
17 | sample: torch.FloatTensor
18 |
19 |
20 | if is_xformers_available():
21 | import xformers
22 | import xformers.ops
23 | else:
24 | xformers = None
25 |
26 |
27 | class Transformer3DModel(ModelMixin, ConfigMixin):
28 | _supports_gradient_checkpointing = True
29 |
30 | @register_to_config
31 | def __init__(
32 | self,
33 | num_attention_heads: int = 16,
34 | attention_head_dim: int = 88,
35 | in_channels: Optional[int] = None,
36 | num_layers: int = 1,
37 | dropout: float = 0.0,
38 | norm_num_groups: int = 32,
39 | cross_attention_dim: Optional[int] = None,
40 | attention_bias: bool = False,
41 | activation_fn: str = "geglu",
42 | num_embeds_ada_norm: Optional[int] = None,
43 | use_linear_projection: bool = False,
44 | only_cross_attention: bool = False,
45 | upcast_attention: bool = False,
46 | unet_use_cross_frame_attention=None,
47 | unet_use_temporal_attention=None,
48 | ):
49 | super().__init__()
50 | self.use_linear_projection = use_linear_projection
51 | self.num_attention_heads = num_attention_heads
52 | self.attention_head_dim = attention_head_dim
53 | inner_dim = num_attention_heads * attention_head_dim
54 |
55 | # Define input layers
56 | self.in_channels = in_channels
57 |
58 | self.norm = torch.nn.GroupNorm(
59 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60 | )
61 | if use_linear_projection:
62 | self.proj_in = nn.Linear(in_channels, inner_dim)
63 | else:
64 | self.proj_in = nn.Conv2d(
65 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66 | )
67 |
68 | # Define transformers blocks
69 | self.transformer_blocks = nn.ModuleList(
70 | [
71 | TemporalBasicTransformerBlock(
72 | inner_dim,
73 | num_attention_heads,
74 | attention_head_dim,
75 | dropout=dropout,
76 | cross_attention_dim=cross_attention_dim,
77 | activation_fn=activation_fn,
78 | num_embeds_ada_norm=num_embeds_ada_norm,
79 | attention_bias=attention_bias,
80 | only_cross_attention=only_cross_attention,
81 | upcast_attention=upcast_attention,
82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83 | unet_use_temporal_attention=unet_use_temporal_attention,
84 | )
85 | for d in range(num_layers)
86 | ]
87 | )
88 |
89 | # 4. Define output layers
90 | if use_linear_projection:
91 | self.proj_out = nn.Linear(inner_dim, in_channels)
92 | else:
93 | self.proj_out = nn.Conv2d(
94 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95 | )
96 |
97 | self.gradient_checkpointing = False
98 |
99 | def _set_gradient_checkpointing(self, module, value=False):
100 | if hasattr(module, "gradient_checkpointing"):
101 | module.gradient_checkpointing = value
102 |
103 | def forward(
104 | self,
105 | hidden_states,
106 | encoder_hidden_states=None,
107 | timestep=None,
108 | return_dict: bool = True,
109 | ):
110 | # Input
111 | assert (
112 | hidden_states.dim() == 5
113 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114 | video_length = hidden_states.shape[2]
115 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
117 | encoder_hidden_states = repeat(
118 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length
119 | )
120 |
121 | batch, channel, height, weight = hidden_states.shape
122 | residual = hidden_states
123 |
124 | hidden_states = self.norm(hidden_states)
125 | if not self.use_linear_projection:
126 | hidden_states = self.proj_in(hidden_states)
127 | inner_dim = hidden_states.shape[1]
128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
129 | batch, height * weight, inner_dim
130 | )
131 | else:
132 | inner_dim = hidden_states.shape[1]
133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134 | batch, height * weight, inner_dim
135 | )
136 | hidden_states = self.proj_in(hidden_states)
137 |
138 | # Blocks
139 | for i, block in enumerate(self.transformer_blocks):
140 | hidden_states = block(
141 | hidden_states,
142 | encoder_hidden_states=encoder_hidden_states,
143 | timestep=timestep,
144 | video_length=video_length,
145 | )
146 |
147 | # Output
148 | if not self.use_linear_projection:
149 | hidden_states = (
150 | hidden_states.reshape(batch, height, weight, inner_dim)
151 | .permute(0, 3, 1, 2)
152 | .contiguous()
153 | )
154 | hidden_states = self.proj_out(hidden_states)
155 | else:
156 | hidden_states = self.proj_out(hidden_states)
157 | hidden_states = (
158 | hidden_states.reshape(batch, height, weight, inner_dim)
159 | .permute(0, 3, 1, 2)
160 | .contiguous()
161 | )
162 |
163 | output = hidden_states + residual
164 |
165 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166 | if not return_dict:
167 | return (output,)
168 |
169 | return Transformer3DModelOutput(sample=output)
170 |
--------------------------------------------------------------------------------
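
`Transformer3DModel` is a spatial transformer applied independently to every frame, with the conditioning tokens repeated per frame when needed. A hedged construction sketch, assuming the `TemporalBasicTransformerBlock` from src/models/attention.py behaves like the AnimateDiff-style block it is adapted from (all sizes are illustrative):

```python
import torch

from src.models.transformer_3d import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,      # inner_dim = 8 * 40 = 320 = in_channels
    in_channels=320,
    cross_attention_dim=768,
)

hidden_states = torch.randn(1, 320, 4, 32, 32)    # (batch, channels, frames, height, width)
encoder_hidden_states = torch.randn(1, 77, 768)   # conditioning tokens, repeated per frame internally

out = model(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.sample.shape)                           # torch.Size([1, 320, 4, 32, 32])
```
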
/src/pipelines/context.py:
--------------------------------------------------------------------------------
1 | # TODO: Adapted from cli
2 | from typing import Callable, List, Optional
3 |
4 | import numpy as np
5 |
6 |
7 | def ordered_halving(val):
8 | bin_str = f"{val:064b}"
9 | bin_flip = bin_str[::-1]
10 | as_int = int(bin_flip, 2)
11 |
12 | return as_int / (1 << 64)
13 |
14 |
15 | def uniform(
16 | step: int = ...,
17 | num_steps: Optional[int] = None,
18 | num_frames: int = ...,
19 | context_size: Optional[int] = None,
20 | context_stride: int = 3,
21 | context_overlap: int = 4,
22 | closed_loop: bool = True,
23 | ):
24 | if num_frames <= context_size:
25 | yield list(range(num_frames))
26 | return
27 |
28 | context_stride = min(
29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30 | )
31 |
32 | for context_step in 1 << np.arange(context_stride):
33 | pad = int(round(num_frames * ordered_halving(step)))
34 | for j in range(
35 | int(ordered_halving(step) * context_step) + pad,
36 | num_frames + pad + (0 if closed_loop else -context_overlap),
37 | (context_size * context_step - context_overlap),
38 | ):
39 | yield [
40 | e % num_frames
41 | for e in range(j, j + context_size * context_step, context_step)
42 | ]
43 |
44 |
45 | def get_context_scheduler(name: str) -> Callable:
46 | if name == "uniform":
47 | return uniform
48 | else:
49 | raise ValueError(f"Unknown context scheduler: {name}")
50 |
51 |
52 | def get_total_steps(
53 | scheduler,
54 | timesteps: List[int],
55 | num_steps: Optional[int] = None,
56 | num_frames: int = ...,
57 | context_size: Optional[int] = None,
58 | context_stride: int = 3,
59 | context_overlap: int = 4,
60 | closed_loop: bool = True,
61 | ):
62 | return sum(
63 | len(
64 | list(
65 | scheduler(
66 | i,
67 | num_steps,
68 | num_frames,
69 | context_size,
70 | context_stride,
71 | context_overlap,
72 | )
73 | )
74 | )
75 | for i in range(len(timesteps))
76 | )
77 |
--------------------------------------------------------------------------------
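
The `uniform` scheduler yields overlapping windows of frame indices that wrap around the clip, so clips longer than one temporal context can be processed in overlapping chunks. A quick way to see what it produces (parameter values are arbitrary):

```python
from src.pipelines.context import get_context_scheduler

scheduler = get_context_scheduler("uniform")

# 24 frames, 16-frame context windows, 4-frame overlap
for window in scheduler(
    step=0,
    num_frames=24,
    context_size=16,
    context_stride=1,
    context_overlap=4,
):
    print(window)
# [0, 1, ..., 15]
# [12, 13, ..., 23, 0, 1, 2, 3]   <- the second window wraps around (closed_loop=True)
```
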
/src/pipelines/pipeline_pose2img.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import dataclass
3 | from typing import Callable, List, Optional, Union
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | import torchvision.transforms as transforms
9 | from diffusers import DiffusionPipeline
10 | from diffusers.image_processor import VaeImageProcessor
11 | from diffusers.schedulers import (
12 | DDIMScheduler,
13 | DPMSolverMultistepScheduler,
14 | EulerAncestralDiscreteScheduler,
15 | EulerDiscreteScheduler,
16 | LMSDiscreteScheduler,
17 | PNDMScheduler,
18 | )
19 | from diffusers.utils import BaseOutput, is_accelerate_available
20 | from diffusers.utils.torch_utils import randn_tensor
21 | from einops import rearrange
22 | from tqdm import tqdm
23 | from transformers import CLIPImageProcessor
24 |
25 | from src.models.mutual_self_attention import ReferenceAttentionControl
26 |
27 |
28 | @dataclass
29 | class Pose2ImagePipelineOutput(BaseOutput):
30 | images: Union[torch.Tensor, np.ndarray]
31 |
32 |
33 | class Pose2ImagePipeline(DiffusionPipeline):
34 | _optional_components = []
35 |
36 | def __init__(
37 | self,
38 | vae,
39 | image_encoder,
40 | reference_unet,
41 | denoising_unet,
42 | pose_guider,
43 | scheduler: Union[
44 | DDIMScheduler,
45 | PNDMScheduler,
46 | LMSDiscreteScheduler,
47 | EulerDiscreteScheduler,
48 | EulerAncestralDiscreteScheduler,
49 | DPMSolverMultistepScheduler,
50 | ],
51 | ):
52 | super().__init__()
53 |
54 | self.register_modules(
55 | vae=vae,
56 | image_encoder=image_encoder,
57 | reference_unet=reference_unet,
58 | denoising_unet=denoising_unet,
59 | pose_guider=pose_guider,
60 | scheduler=scheduler,
61 | )
62 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
63 | self.clip_image_processor = CLIPImageProcessor()
64 | self.ref_image_processor = VaeImageProcessor(
65 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
66 | )
67 | self.cond_image_processor = VaeImageProcessor(
68 | vae_scale_factor=self.vae_scale_factor,
69 | do_convert_rgb=True,
70 | do_normalize=True,
71 | )
72 |
73 | def enable_vae_slicing(self):
74 | self.vae.enable_slicing()
75 |
76 | def disable_vae_slicing(self):
77 | self.vae.disable_slicing()
78 |
79 | def enable_sequential_cpu_offload(self, gpu_id=0):
80 | if is_accelerate_available():
81 | from accelerate import cpu_offload
82 | else:
83 | raise ImportError("Please install accelerate via `pip install accelerate`")
84 |
85 | device = torch.device(f"cuda:{gpu_id}")
86 |
87 | for cpu_offloaded_model in [self.denoising_unet, self.reference_unet, self.image_encoder, self.vae]:
88 | if cpu_offloaded_model is not None:
89 | cpu_offload(cpu_offloaded_model, device)
90 |
91 | @property
92 | def _execution_device(self):
93 | if self.device != torch.device("meta") or not hasattr(self.denoising_unet, "_hf_hook"):
94 | return self.device
95 | for module in self.denoising_unet.modules():
96 | if (
97 | hasattr(module, "_hf_hook")
98 | and hasattr(module._hf_hook, "execution_device")
99 | and module._hf_hook.execution_device is not None
100 | ):
101 | return torch.device(module._hf_hook.execution_device)
102 | return self.device
103 |
104 | def decode_latents(self, latents):
105 | video_length = latents.shape[2]
106 | latents = 1 / 0.18215 * latents
107 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
108 | # video = self.vae.decode(latents).sample
109 | video = []
110 | for frame_idx in tqdm(range(latents.shape[0])):
111 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
112 | video = torch.cat(video)
113 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
114 | video = (video / 2 + 0.5).clamp(0, 1)
115 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
116 | video = video.cpu().float().numpy()
117 | return video
118 |
119 | def prepare_extra_step_kwargs(self, generator, eta):
120 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
121 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
122 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
123 | # and should be between [0, 1]
124 |
125 | accepts_eta = "eta" in set(
126 | inspect.signature(self.scheduler.step).parameters.keys()
127 | )
128 | extra_step_kwargs = {}
129 | if accepts_eta:
130 | extra_step_kwargs["eta"] = eta
131 |
132 | # check if the scheduler accepts generator
133 | accepts_generator = "generator" in set(
134 | inspect.signature(self.scheduler.step).parameters.keys()
135 | )
136 | if accepts_generator:
137 | extra_step_kwargs["generator"] = generator
138 | return extra_step_kwargs
139 |
140 | def prepare_latents(
141 | self,
142 | batch_size,
143 | num_channels_latents,
144 | width,
145 | height,
146 | dtype,
147 | device,
148 | generator,
149 | latents=None,
150 | ):
151 | shape = (
152 | batch_size,
153 | num_channels_latents,
154 | height // self.vae_scale_factor,
155 | width // self.vae_scale_factor,
156 | )
157 | if isinstance(generator, list) and len(generator) != batch_size:
158 | raise ValueError(
159 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
160 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
161 | )
162 |
163 | if latents is None:
164 | latents = randn_tensor(
165 | shape, generator=generator, device=device, dtype=dtype
166 | )
167 | else:
168 | latents = latents.to(device)
169 |
170 | # scale the initial noise by the standard deviation required by the scheduler
171 | latents = latents * self.scheduler.init_noise_sigma
172 | return latents
173 |
174 | def prepare_condition(
175 | self,
176 | cond_image,
177 | width,
178 | height,
179 | device,
180 | dtype,
181 | do_classififer_free_guidance=False,
182 | ):
183 | image = self.cond_image_processor.preprocess(
184 | cond_image, height=height, width=width
185 | ).to(dtype=torch.float32)
186 |
187 | image = image.to(device=device, dtype=dtype)
188 |
189 |
190 | if do_classififer_free_guidance:
191 | image = torch.cat([image] * 2)
192 |
193 | return image
194 |
195 | @torch.no_grad()
196 | def __call__(
197 | self,
198 | ref_image,
199 | pose_image,
200 | ref_pose_image,
201 | width,
202 | height,
203 | num_inference_steps,
204 | guidance_scale,
205 | num_images_per_prompt=1,
206 | eta: float = 0.0,
207 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
208 | output_type: Optional[str] = "tensor",
209 | return_dict: bool = True,
210 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
211 | callback_steps: Optional[int] = 1,
212 | **kwargs,
213 | ):
214 | # Default height and width to the denoising unet's sample size
215 | height = height or self.denoising_unet.config.sample_size * self.vae_scale_factor
216 | width = width or self.denoising_unet.config.sample_size * self.vae_scale_factor
217 |
218 | device = self._execution_device
219 |
220 | do_classifier_free_guidance = guidance_scale > 1.0
221 |
222 | # Prepare timesteps
223 | self.scheduler.set_timesteps(num_inference_steps, device=device)
224 | timesteps = self.scheduler.timesteps
225 |
226 | batch_size = 1
227 |
228 | # Prepare clip image embeds
229 | clip_image = self.clip_image_processor.preprocess(
230 | ref_image.resize((224, 224)), return_tensors="pt"
231 | ).pixel_values
232 | clip_image_embeds = self.image_encoder(
233 | clip_image.to(device, dtype=self.image_encoder.dtype)
234 | ).image_embeds
235 | image_prompt_embeds = clip_image_embeds.unsqueeze(1)
236 |
237 | uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
238 |
239 | if do_classifier_free_guidance:
240 | image_prompt_embeds = torch.cat(
241 | [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
242 | )
243 |
244 | reference_control_writer = ReferenceAttentionControl(
245 | self.reference_unet,
246 | do_classifier_free_guidance=do_classifier_free_guidance,
247 | mode="write",
248 | batch_size=batch_size,
249 | fusion_blocks="full",
250 | )
251 | reference_control_reader = ReferenceAttentionControl(
252 | self.denoising_unet,
253 | do_classifier_free_guidance=do_classifier_free_guidance,
254 | mode="read",
255 | batch_size=batch_size,
256 | fusion_blocks="full",
257 | )
258 |
259 | num_channels_latents = self.denoising_unet.in_channels
260 | latents = self.prepare_latents(
261 | batch_size * num_images_per_prompt,
262 | num_channels_latents,
263 | width,
264 | height,
265 | clip_image_embeds.dtype,
266 | device,
267 | generator,
268 | )
269 | latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
270 | latents_dtype = latents.dtype
271 |
272 | # Prepare extra step kwargs.
273 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
274 |
275 | # Prepare ref image latents
276 | ref_image_tensor = self.ref_image_processor.preprocess(
277 | ref_image, height=height, width=width
278 | ) # (bs, c, width, height)
279 | ref_image_tensor = ref_image_tensor.to(
280 | dtype=self.vae.dtype, device=self.vae.device
281 | )
282 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
283 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
284 |
285 | # Prepare pose condition image
286 | pose_cond_tensor = self.cond_image_processor.preprocess(
287 | pose_image, height=height, width=width
288 | )
289 |
290 | pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
291 | pose_cond_tensor = pose_cond_tensor.to(
292 | device=device, dtype=self.pose_guider.dtype
293 | )
294 |
295 | ref_pose_tensor = self.cond_image_processor.preprocess(
296 | ref_pose_image, height=height, width=width
297 | )
298 | ref_pose_tensor = ref_pose_tensor.to(
299 | device=device, dtype=self.pose_guider.dtype
300 | )
301 |
302 | pose_fea = self.pose_guider(pose_cond_tensor, ref_pose_tensor)
303 | if do_classifier_free_guidance:
304 | for idxx in range(len(pose_fea)):
305 | pose_fea[idxx] = torch.cat([pose_fea[idxx]] * 2)
306 |
307 | # denoising loop
308 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
309 | with self.progress_bar(total=num_inference_steps) as progress_bar:
310 | for i, t in enumerate(timesteps):
311 | # 1. Forward reference image
312 | if i == 0:
313 | self.reference_unet(
314 | ref_image_latents.repeat(
315 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
316 | ),
317 | torch.zeros_like(t),
318 | encoder_hidden_states=image_prompt_embeds,
319 | return_dict=False,
320 | )
321 |
322 | # 2. Update reference unet features into the denoising unet
323 | reference_control_reader.update(reference_control_writer)
324 |
325 | # 3.1 expand the latents if we are doing classifier free guidance
326 | latent_model_input = (
327 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents
328 | )
329 | latent_model_input = self.scheduler.scale_model_input(
330 | latent_model_input, t
331 | )
332 |
333 | noise_pred = self.denoising_unet(
334 | latent_model_input,
335 | t,
336 | encoder_hidden_states=image_prompt_embeds,
337 | pose_cond_fea=pose_fea,
338 | return_dict=False,
339 | )[0]
340 |
341 | # perform guidance
342 | if do_classifier_free_guidance:
343 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
344 | noise_pred = noise_pred_uncond + guidance_scale * (
345 | noise_pred_text - noise_pred_uncond
346 | )
347 |
348 | # compute the previous noisy sample x_t -> x_t-1
349 | latents = self.scheduler.step(
350 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False
351 | )[0]
352 |
353 | # call the callback, if provided
354 | if i == len(timesteps) - 1 or (
355 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
356 | ):
357 | progress_bar.update()
358 | if callback is not None and i % callback_steps == 0:
359 | step_idx = i // getattr(self.scheduler, "order", 1)
360 | callback(step_idx, t, latents)
361 | reference_control_reader.clear()
362 | reference_control_writer.clear()
363 |
364 | # Post-processing
365 | image = self.decode_latents(latents) # (b, c, 1, h, w)
366 |
367 | # Convert to tensor
368 | if output_type == "tensor":
369 | image = torch.from_numpy(image)
370 |
371 | if not return_dict:
372 | return image
373 |
374 | return Pose2ImagePipelineOutput(images=image)
375 |
--------------------------------------------------------------------------------
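
End to end, the pipeline is assembled from the modules above plus a VAE, a CLIP image encoder and a scheduler, and then called with a reference image, a target pose image and the reference's own pose image. A hedged sketch of the call, assuming `pipe` is an already-assembled `Pose2ImagePipeline` (the pose image paths are placeholders, not files shipped in this listing):

```python
from PIL import Image

ref_image = Image.open("configs/inference/ref_images/solo.png").convert("RGB")
pose_image = Image.open("target_pose.png").convert("RGB")      # placeholder pose render
ref_pose_image = Image.open("ref_pose.png").convert("RGB")     # placeholder pose render

result = pipe(
    ref_image,
    pose_image,
    ref_pose_image,
    width=512,
    height=512,
    num_inference_steps=25,
    guidance_scale=3.5,
)
image = result.images   # tensor of shape (batch, channels, 1, height, width) in [0, 1]
```
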
/src/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | tensor_interpolation = None
4 |
5 |
6 | def get_tensor_interpolation_method():
7 | return tensor_interpolation
8 |
9 |
10 | def set_tensor_interpolation_method(is_slerp):
11 | global tensor_interpolation
12 | tensor_interpolation = slerp if is_slerp else linear
13 |
14 |
15 | def linear(v1, v2, t):
16 | return (1.0 - t) * v1 + t * v2
17 |
18 |
19 | def slerp(
20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
21 | ) -> torch.Tensor:
22 | u0 = v0 / v0.norm()
23 | u1 = v1 / v1.norm()
24 | dot = (u0 * u1).sum()
25 | if dot.abs() > DOT_THRESHOLD:
26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
27 | return (1.0 - t) * v0 + t * v1
28 | omega = dot.acos()
29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
30 |
--------------------------------------------------------------------------------
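These helpers are selected once per run: set_tensor_interpolation_method picks slerp or linear, and callers fetch the chosen function through get_tensor_interpolation_method. A minimal usage sketch, assuming the repo root is on PYTHONPATH:

import torch
from src.pipelines.utils import (
    get_tensor_interpolation_method,
    set_tensor_interpolation_method,
)

set_tensor_interpolation_method(is_slerp=True)  # False selects plain linear interpolation
interp = get_tensor_interpolation_method()

v0, v1 = torch.randn(4, 64), torch.randn(4, 64)
midpoint = interp(v0, v1, 0.5)                  # blend halfway between the two tensors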
/src/utils/audio_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 |
4 | import librosa
5 | import numpy as np
6 | from transformers import Wav2Vec2FeatureExtractor
7 |
8 |
9 | class DataProcessor:
10 | def __init__(self, sampling_rate, wav2vec_model_path):
11 | self._processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
12 | self._sampling_rate = sampling_rate
13 |
14 | def extract_feature(self, audio_path):
15 | speech_array, sampling_rate = librosa.load(audio_path, sr=self._sampling_rate)
16 | input_value = np.squeeze(self._processor(speech_array, sampling_rate=sampling_rate).input_values)
17 | return input_value
18 |
19 |
20 | def prepare_audio_feature(wav_file, fps=30, sampling_rate=16000, wav2vec_model_path=None):
21 | data_preprocessor = DataProcessor(sampling_rate, wav2vec_model_path)
22 |
23 | input_value = data_preprocessor.extract_feature(wav_file)
24 | seq_len = math.ceil(len(input_value)/sampling_rate*fps)
25 | return {
26 | "audio_feature": input_value,
27 | "seq_len": seq_len
28 | }
29 |
30 |
31 |
--------------------------------------------------------------------------------
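prepare_audio_feature loads the wav at the requested sampling rate, runs it through the Wav2Vec2 feature extractor, and derives how many video frames the clip spans at the given fps. A hedged usage sketch; the wav2vec2 checkpoint directory is an assumption (the extractor is loaded with local_files_only=True, so the model must already be downloaded):

from src.utils.audio_util import prepare_audio_feature

sample = prepare_audio_feature(
    "configs/inference/audio/lyl.wav",
    fps=30,
    sampling_rate=16000,
    wav2vec_model_path="./pretrained_model/wav2vec2-base-960h",  # hypothetical local path
)
print(sample["audio_feature"].shape, sample["seq_len"])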
/src/utils/draw_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import mediapipe as mp
3 | import numpy as np
4 | from mediapipe.framework.formats import landmark_pb2
5 |
6 | class FaceMeshVisualizer:
7 | def __init__(self, forehead_edge=False):
8 | self.mp_drawing = mp.solutions.drawing_utils
9 | mp_face_mesh = mp.solutions.face_mesh
10 | self.mp_face_mesh = mp_face_mesh
11 | self.forehead_edge = forehead_edge
12 |
13 | DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
14 | f_thick = 2
15 | f_rad = 1
16 | right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
17 | right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
18 | right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
19 | left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
20 | left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
21 | left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
22 | head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
23 |
24 | mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad)
25 | mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad)
26 |
27 | mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad)
28 | mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad)
29 |
30 | mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad)
31 | mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad)
32 |
33 | mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad)
34 |         mouth_draw_itr = DrawingSpec(color=(150, 120, 100), thickness=f_thick, circle_radius=f_rad)
35 |
36 | FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)]
37 | FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)]
38 |
39 | FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)]
40 | FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)]
41 |
42 | FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)]
43 | FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)]
44 |
45 | FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)]
46 | FACEMESH_LIPS_INNER_TOP_RIGHT = [(13,312),(312,311),(311,310),(310,415),(415,308)]
47 |
48 | FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), (397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)]
49 |
50 | # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
51 | face_connection_spec = {}
52 | if self.forehead_edge:
53 | for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
54 | face_connection_spec[edge] = head_draw
55 | else:
56 | for edge in FACEMESH_CUSTOM_FACE_OVAL:
57 | face_connection_spec[edge] = head_draw
58 | for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
59 | face_connection_spec[edge] = left_eye_draw
60 | for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
61 | face_connection_spec[edge] = left_eyebrow_draw
62 | # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
63 | # face_connection_spec[edge] = left_iris_draw
64 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
65 | face_connection_spec[edge] = right_eye_draw
66 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
67 | face_connection_spec[edge] = right_eyebrow_draw
68 | # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
69 | # face_connection_spec[edge] = right_iris_draw
70 | # for edge in mp_face_mesh.FACEMESH_LIPS:
71 | # face_connection_spec[edge] = mouth_draw
72 |
73 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT:
74 | face_connection_spec[edge] = mouth_draw_obl
75 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT:
76 | face_connection_spec[edge] = mouth_draw_obr
77 | for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT:
78 | face_connection_spec[edge] = mouth_draw_ibl
79 | for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT:
80 | face_connection_spec[edge] = mouth_draw_ibr
81 | for edge in FACEMESH_LIPS_OUTER_TOP_LEFT:
82 | face_connection_spec[edge] = mouth_draw_otl
83 | for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT:
84 | face_connection_spec[edge] = mouth_draw_otr
85 | for edge in FACEMESH_LIPS_INNER_TOP_LEFT:
86 | face_connection_spec[edge] = mouth_draw_itl
87 | for edge in FACEMESH_LIPS_INNER_TOP_RIGHT:
88 | face_connection_spec[edge] = mouth_draw_itr
89 |
90 |
91 | iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
92 |
93 | self.face_connection_spec = face_connection_spec
94 | def draw_pupils(self, image, landmark_list, drawing_spec, halfwidth: int = 2):
95 | """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
96 | landmarks. Until our PR is merged into mediapipe, we need this separate method."""
97 | if len(image.shape) != 3:
98 | raise ValueError("Input image must be H,W,C.")
99 | image_rows, image_cols, image_channels = image.shape
100 | if image_channels != 3: # BGR channels
101 | raise ValueError('Input image must contain three channel bgr data.')
102 | for idx, landmark in enumerate(landmark_list.landmark):
103 | if (
104 | (landmark.HasField('visibility') and landmark.visibility < 0.9) or
105 | (landmark.HasField('presence') and landmark.presence < 0.5)
106 | ):
107 | continue
108 | if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
109 | continue
110 | image_x = int(image_cols*landmark.x)
111 | image_y = int(image_rows*landmark.y)
112 | draw_color = None
113 |         if isinstance(drawing_spec, dict):  # per-landmark-index mapping, e.g. iris_landmark_spec
114 | if drawing_spec.get(idx) is None:
115 | continue
116 | else:
117 | draw_color = drawing_spec[idx].color
118 |         elif isinstance(drawing_spec, self.mp_drawing.DrawingSpec):
119 | draw_color = drawing_spec.color
120 | image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color
121 |
122 |
123 |
124 | def draw_landmarks(self, image_size, keypoints, normed=False):
125 | ini_size = [512, 512]
126 | image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
127 | new_landmarks = landmark_pb2.NormalizedLandmarkList()
128 | for i in range(keypoints.shape[0]):
129 | landmark = new_landmarks.landmark.add()
130 | if normed:
131 | landmark.x = keypoints[i, 0]
132 | landmark.y = keypoints[i, 1]
133 | else:
134 | landmark.x = keypoints[i, 0] / image_size[0]
135 | landmark.y = keypoints[i, 1] / image_size[1]
136 | landmark.z = 1.0
137 |
138 | self.mp_drawing.draw_landmarks(
139 | image=image,
140 | landmark_list=new_landmarks,
141 | connections=self.face_connection_spec.keys(),
142 | landmark_drawing_spec=None,
143 | connection_drawing_spec=self.face_connection_spec
144 | )
145 | # draw_pupils(image, face_landmarks, iris_landmark_spec, 2)
146 | image = cv2.resize(image, (image_size[0], image_size[1]))
147 |
148 | return image
149 |
150 |
--------------------------------------------------------------------------------
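FaceMeshVisualizer renders a set of MediaPipe face landmarks as the colored edge drawing used as the pose condition image. A minimal sketch pairing it with the LMKExtractor defined in src/utils/mp_utils.py below; the image path is one of the repo's reference images, and shapes are assumptions:

import cv2
from src.utils.draw_util import FaceMeshVisualizer
from src.utils.mp_utils import LMKExtractor

img = cv2.imread("configs/inference/ref_images/lyl.png")  # BGR image, as the extractor expects
lmk_extractor = LMKExtractor()
vis = FaceMeshVisualizer(forehead_edge=False)

result = lmk_extractor(img)
if result is not None:
    # result["lmks"] is normalized to [0, 1], so pass normed=True
    pose_img = vis.draw_landmarks((img.shape[1], img.shape[0]), result["lmks"], normed=True)
    cv2.imwrite("pose_vis.png", pose_img)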
/src/utils/frame_interpolation.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/dajes/frame-interpolation-pytorch
2 | import os
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import bisect
7 | import shutil
8 | import pdb
9 | from tqdm import tqdm
10 |
11 | def init_frame_interpolation_model():
12 | print("Initializing frame interpolation model")
13 | checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt")
14 |
15 | model = torch.jit.load(checkpoint_name, map_location='cpu')
16 | model.eval()
17 | model = model.half()
18 | model = model.to(device="cuda")
19 | return model
20 |
21 |
22 | def batch_images_interpolation_tool(input_tensor, model, inter_frames=1):
23 |
24 | video_tensor = []
25 | frame_num = input_tensor.shape[2] # bs, channel, frame, height, width
26 |
27 | for idx in tqdm(range(frame_num-1)):
28 | image1 = input_tensor[:,:,idx]
29 | image2 = input_tensor[:,:,idx+1]
30 |
31 | results = [image1, image2]
32 |
33 | inter_frames = int(inter_frames)
34 | idxes = [0, inter_frames + 1]
35 | remains = list(range(1, inter_frames + 1))
36 |
37 | splits = torch.linspace(0, 1, inter_frames + 2)
38 |
39 | for _ in range(len(remains)):
40 | starts = splits[idxes[:-1]]
41 | ends = splits[idxes[1:]]
42 | distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
43 | matrix = torch.argmin(distances).item()
44 | start_i, step = np.unravel_index(matrix, distances.shape)
45 | end_i = start_i + 1
46 |
47 | x0 = results[start_i]
48 | x1 = results[end_i]
49 |
50 | x0 = x0.half()
51 | x1 = x1.half()
52 | x0 = x0.cuda()
53 | x1 = x1.cuda()
54 |
55 | dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
56 |
57 | with torch.no_grad():
58 | prediction = model(x0, x1, dt)
59 | insert_position = bisect.bisect_left(idxes, remains[step])
60 | idxes.insert(insert_position, remains[step])
61 | results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
62 | del remains[step]
63 |
64 | for sub_idx in range(len(results)-1):
65 | video_tensor.append(results[sub_idx].unsqueeze(2))
66 |
67 | video_tensor.append(input_tensor[:,:,-1].unsqueeze(2))
68 | video_tensor = torch.cat(video_tensor, dim=2)
69 | return video_tensor
--------------------------------------------------------------------------------
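batch_images_interpolation_tool walks a (bs, c, frame, h, w) tensor pairwise and inserts inter_frames FILM-predicted frames between each neighbouring pair, so an 8-frame clip with inter_frames=1 comes back with 15 frames. A hedged sketch; it assumes a CUDA device and the film_net_fp16.pt checkpoint under ./pretrained_model:

import torch
from src.utils.frame_interpolation import (
    batch_images_interpolation_tool,
    init_frame_interpolation_model,
)

model = init_frame_interpolation_model()  # loads ./pretrained_model/film_net_fp16.pt onto cuda

video = torch.rand(1, 3, 8, 512, 512)     # (bs, c, frame, h, w), values in [0, 1]
smoothed = batch_images_interpolation_tool(video, model, inter_frames=1)
print(smoothed.shape)                     # torch.Size([1, 3, 15, 512, 512])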
/src/utils/mp_models/blaze_face_short_range.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/src/utils/mp_models/blaze_face_short_range.tflite
--------------------------------------------------------------------------------
/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task
--------------------------------------------------------------------------------
/src/utils/mp_models/pose_landmarker_heavy.task:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/src/utils/mp_models/pose_landmarker_heavy.task
--------------------------------------------------------------------------------
/src/utils/mp_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import time
5 | from tqdm import tqdm
6 | import multiprocessing
7 | import glob
8 |
9 | import mediapipe as mp
10 | from mediapipe import solutions
11 | from mediapipe.framework.formats import landmark_pb2
12 | from mediapipe.tasks import python
13 | from mediapipe.tasks.python import vision
14 | from . import face_landmark
15 |
16 | CUR_DIR = os.path.dirname(__file__)
17 |
18 |
19 | class LMKExtractor():
20 | def __init__(self, FPS=25):
21 |         # Create a FaceLandmarker object.
22 | self.mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE
23 | base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/face_landmarker_v2_with_blendshapes.task'))
24 | base_options.delegate = mp.tasks.BaseOptions.Delegate.CPU
25 | options = vision.FaceLandmarkerOptions(base_options=base_options,
26 | running_mode=self.mode,
27 | output_face_blendshapes=True,
28 | output_facial_transformation_matrixes=True,
29 | num_faces=1)
30 | self.detector = face_landmark.FaceLandmarker.create_from_options(options)
31 | self.last_ts = 0
32 | self.frame_ms = int(1000 / FPS)
33 |
34 | det_base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/blaze_face_short_range.tflite'))
35 | det_options = vision.FaceDetectorOptions(base_options=det_base_options)
36 | self.det_detector = vision.FaceDetector.create_from_options(det_options)
37 |
38 |
39 | def __call__(self, img):
40 | frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
41 | image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
42 | t0 = time.time()
43 | if self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.VIDEO:
44 | det_result = self.det_detector.detect(image)
45 | if len(det_result.detections) != 1:
46 | return None
47 | self.last_ts += self.frame_ms
48 | try:
49 | detection_result, mesh3d = self.detector.detect_for_video(image, timestamp_ms=self.last_ts)
50 |             except Exception:
51 | return None
52 | elif self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE:
53 | # det_result = self.det_detector.detect(image)
54 |
55 | # if len(det_result.detections) != 1:
56 | # return None
57 | try:
58 | detection_result, mesh3d = self.detector.detect(image)
59 |             except Exception:
60 | return None
61 |
62 |
63 | bs_list = detection_result.face_blendshapes
64 | if len(bs_list) == 1:
65 | bs = bs_list[0]
66 | bs_values = []
67 | for index in range(len(bs)):
68 | bs_values.append(bs[index].score)
69 | bs_values = bs_values[1:] # remove neutral
70 | trans_mat = detection_result.facial_transformation_matrixes[0]
71 | face_landmarks_list = detection_result.face_landmarks
72 | face_landmarks = face_landmarks_list[0]
73 | lmks = []
74 | for index in range(len(face_landmarks)):
75 | x = face_landmarks[index].x
76 | y = face_landmarks[index].y
77 | z = face_landmarks[index].z
78 | lmks.append([x, y, z])
79 | lmks = np.array(lmks)
80 |
81 | lmks3d = np.array(mesh3d.vertex_buffer)
82 | lmks3d = lmks3d.reshape(-1, 5)[:, :3]
83 | mp_tris = np.array(mesh3d.index_buffer).reshape(-1, 3) + 1
84 |
85 | return {
86 | "lmks": lmks,
87 | 'lmks3d': lmks3d,
88 | "trans_mat": trans_mat,
89 | 'faces': mp_tris,
90 | "bs": bs_values
91 | }
92 | else:
93 | # print('multiple faces in the image: {}'.format(img_path))
94 | return None
95 |
--------------------------------------------------------------------------------
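For a single detected face, LMKExtractor returns the normalized 2D landmarks, the 3D mesh vertices and triangle faces, the 4x4 facial transformation matrix, and the blendshape scores (with the neutral entry removed); otherwise it returns None. A small sketch of inspecting that output on one of the repo's reference images:

import cv2
from src.utils.mp_utils import LMKExtractor

lmk_extractor = LMKExtractor()
frame = cv2.imread("configs/inference/ref_images/Aragaki.png")  # BGR, as expected by __call__

result = lmk_extractor(frame)
if result is None:
    print("no single face detected")
else:
    print(result["lmks"].shape)       # normalized (x, y, z) landmarks
    print(result["lmks3d"].shape)     # 3D mesh vertices
    print(result["trans_mat"].shape)  # (4, 4) head-pose matrix
    print(len(result["bs"]))          # blendshape scores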
/src/utils/pose_util.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | from scipy.spatial.transform import Rotation as R
5 |
6 |
7 | def create_perspective_matrix(aspect_ratio):
8 | kDegreesToRadians = np.pi / 180.
9 | near = 1
10 | far = 10000
11 | perspective_matrix = np.zeros(16, dtype=np.float32)
12 |
13 | # Standard perspective projection matrix calculations.
14 | f = 1.0 / np.tan(kDegreesToRadians * 63 / 2.)
15 |
16 | denom = 1.0 / (near - far)
17 | perspective_matrix[0] = f / aspect_ratio
18 | perspective_matrix[5] = f
19 | perspective_matrix[10] = (near + far) * denom
20 | perspective_matrix[11] = -1.
21 | perspective_matrix[14] = 1. * far * near * denom
22 |
23 |     # The environment's origin is in the top-left corner, so an additional
24 |     # flip along the Y-axis is required to render correctly.
25 |
26 | perspective_matrix[5] *= -1.
27 | return perspective_matrix
28 |
29 |
30 | def project_points(points_3d, transformation_matrix, pose_vectors, image_shape):
31 | P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T
32 | L, N, _ = points_3d.shape
33 | projected_points = np.zeros((L, N, 2))
34 | for i in range(L):
35 | points_3d_frame = points_3d[i]
36 | ones = np.ones((points_3d_frame.shape[0], 1))
37 | points_3d_homogeneous = np.hstack([points_3d_frame, ones])
38 | transformed_points = points_3d_homogeneous @ (transformation_matrix @ euler_and_translation_to_matrix(pose_vectors[i][:3], pose_vectors[i][3:])).T @ P
39 | projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1
40 | projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1]
41 | projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0]
42 | projected_points[i] = projected_points_frame
43 | return projected_points
44 |
45 |
46 | def project_points_with_trans(points_3d, transformation_matrix, image_shape):
47 | P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T
48 | L, N, _ = points_3d.shape
49 | projected_points = np.zeros((L, N, 2))
50 | for i in range(L):
51 | points_3d_frame = points_3d[i]
52 | ones = np.ones((points_3d_frame.shape[0], 1))
53 | points_3d_homogeneous = np.hstack([points_3d_frame, ones])
54 | transformed_points = points_3d_homogeneous @ transformation_matrix[i].T @ P
55 | projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1
56 | projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1]
57 | projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0]
58 | projected_points[i] = projected_points_frame
59 | return projected_points
60 |
61 |
62 | def euler_and_translation_to_matrix(euler_angles, translation_vector):
63 | rotation = R.from_euler('xyz', euler_angles, degrees=True)
64 | rotation_matrix = rotation.as_matrix()
65 |
66 | matrix = np.eye(4)
67 | matrix[:3, :3] = rotation_matrix
68 | matrix[:3, 3] = translation_vector
69 |
70 | return matrix
71 |
72 |
73 | def matrix_to_euler_and_translation(matrix):
74 | rotation_matrix = matrix[:3, :3]
75 | translation_vector = matrix[:3, 3]
76 | rotation = R.from_matrix(rotation_matrix)
77 | euler_angles = rotation.as_euler('xyz', degrees=True)
78 | return euler_angles, translation_vector
79 |
80 |
81 | def smooth_pose_seq(pose_seq, window_size=5):
82 | smoothed_pose_seq = np.zeros_like(pose_seq)
83 |
84 | for i in range(len(pose_seq)):
85 | start = max(0, i - window_size // 2)
86 | end = min(len(pose_seq), i + window_size // 2 + 1)
87 | smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0)
88 |
89 | return smoothed_pose_seq
--------------------------------------------------------------------------------
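euler_and_translation_to_matrix and matrix_to_euler_and_translation are inverses of each other (away from gimbal lock), and smooth_pose_seq is a centered moving average over a pose sequence. A small round-trip sketch with arbitrary values:

import numpy as np
from src.utils.pose_util import (
    euler_and_translation_to_matrix,
    matrix_to_euler_and_translation,
    smooth_pose_seq,
)

mat = euler_and_translation_to_matrix([10.0, -5.0, 3.0], [0.1, 0.0, -2.0])
euler, trans = matrix_to_euler_and_translation(mat)  # recovers the inputs up to float error

pose_seq = np.random.randn(25, 6)                    # e.g. 25 frames of [euler_xyz, translation]
smoothed = smooth_pose_seq(pose_seq, window_size=5)  # 5-frame window, clipped at the sequence ends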
/src/utils/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import os.path as osp
4 | import shutil
5 | import sys
6 | import cv2
7 | from pathlib import Path
8 |
9 | import av
10 | import numpy as np
11 | import torch
12 | import torchvision
13 | from einops import rearrange
14 | from PIL import Image
15 |
16 |
17 | def seed_everything(seed):
18 | import random
19 |
20 | import numpy as np
21 |
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed_all(seed)
24 | np.random.seed(seed % (2**32))
25 | random.seed(seed)
26 |
27 |
28 | def import_filename(filename):
29 | spec = importlib.util.spec_from_file_location("mymodule", filename)
30 | module = importlib.util.module_from_spec(spec)
31 | sys.modules[spec.name] = module
32 | spec.loader.exec_module(module)
33 | return module
34 |
35 |
36 | def delete_additional_ckpt(base_path, num_keep):
37 | dirs = []
38 | for d in os.listdir(base_path):
39 | if d.startswith("checkpoint-"):
40 | dirs.append(d)
41 | num_tot = len(dirs)
42 | if num_tot <= num_keep:
43 | return
44 |     # ensure the checkpoints are sorted and delete the earliest ones
45 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
46 | for d in del_dirs:
47 | path_to_dir = osp.join(base_path, d)
48 | if osp.exists(path_to_dir):
49 | shutil.rmtree(path_to_dir)
50 |
51 |
52 | def save_videos_from_pil(pil_images, path, fps=8):
53 | import av
54 |
55 | save_fmt = Path(path).suffix
56 | os.makedirs(os.path.dirname(path), exist_ok=True)
57 | width, height = pil_images[0].size
58 |
59 | if save_fmt == ".mp4":
60 | codec = "libx264"
61 | container = av.open(path, "w")
62 | stream = container.add_stream(codec, rate=fps)
63 |
64 | stream.width = width
65 | stream.height = height
66 |
67 | for pil_image in pil_images:
68 | # pil_image = Image.fromarray(image_arr).convert("RGB")
69 | av_frame = av.VideoFrame.from_image(pil_image)
70 | container.mux(stream.encode(av_frame))
71 | container.mux(stream.encode())
72 | container.close()
73 |
74 | elif save_fmt == ".gif":
75 | pil_images[0].save(
76 | fp=path,
77 | format="GIF",
78 | append_images=pil_images[1:],
79 | save_all=True,
80 | duration=(1 / fps * 1000),
81 | loop=0,
82 | )
83 | else:
84 | raise ValueError("Unsupported file type. Use .mp4 or .gif.")
85 |
86 |
87 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
88 | videos = rearrange(videos, "b c t h w -> t b c h w")
89 | height, width = videos.shape[-2:]
90 | outputs = []
91 |
92 | for x in videos:
93 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
94 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
95 | if rescale:
96 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
97 | x = (x * 255).numpy().astype(np.uint8)
98 | x = Image.fromarray(x)
99 |
100 | outputs.append(x)
101 |
102 | os.makedirs(os.path.dirname(path), exist_ok=True)
103 |
104 | save_videos_from_pil(outputs, path, fps)
105 |
106 |
107 | def read_frames(video_path):
108 | container = av.open(video_path)
109 |
110 | video_stream = next(s for s in container.streams if s.type == "video")
111 | frames = []
112 | for packet in container.demux(video_stream):
113 | for frame in packet.decode():
114 | image = Image.frombytes(
115 | "RGB",
116 | (frame.width, frame.height),
117 | frame.to_rgb().to_ndarray(),
118 | )
119 | frames.append(image)
120 |
121 | return frames
122 |
123 |
124 | def get_fps(video_path):
125 | container = av.open(video_path)
126 | video_stream = next(s for s in container.streams if s.type == "video")
127 | fps = video_stream.average_rate
128 | container.close()
129 | return fps
130 |
131 | def crop_face(img, lmk_extractor, expand=1.5):
132 | result = lmk_extractor(img) # cv2 BGR
133 |
134 | if result is None:
135 | return None
136 |
137 | H, W, _ = img.shape
138 | lmks = result['lmks']
139 | lmks[:, 0] *= W
140 | lmks[:, 1] *= H
141 |
142 | x_min = np.min(lmks[:, 0])
143 | x_max = np.max(lmks[:, 0])
144 | y_min = np.min(lmks[:, 1])
145 | y_max = np.max(lmks[:, 1])
146 |
147 | width = x_max - x_min
148 | height = y_max - y_min
149 |
150 | if width*height >= W*H*0.15:
151 | if W == H:
152 | return img
153 | size = min(H, W)
154 | offset = int((max(H, W) - size)/2)
155 | if size == H:
156 | return img[:, offset:-offset]
157 | else:
158 | return img[offset:-offset, :]
159 | else:
160 | center_x = x_min + width / 2
161 | center_y = y_min + height / 2
162 |
163 | width *= expand
164 | height *= expand
165 |
166 | size = max(width, height)
167 |
168 | x_min = int(center_x - size / 2)
169 | x_max = int(center_x + size / 2)
170 | y_min = int(center_y - size / 2)
171 | y_max = int(center_y + size / 2)
172 |
173 | top = max(0, -y_min)
174 | bottom = max(0, y_max - img.shape[0])
175 | left = max(0, -x_min)
176 | right = max(0, x_max - img.shape[1])
177 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
178 |
179 | cropped_img = img[y_min + top:y_max + top, x_min + left:x_max + left]
180 |
181 | return cropped_img
--------------------------------------------------------------------------------
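read_frames and get_fps decode a video into PIL images with PyAV, and save_videos_from_pil re-encodes a list of PIL images as .mp4 or .gif. A hedged round-trip sketch using one of the repo's sample pose videos; the output path is arbitrary:

from src.utils.util import get_fps, read_frames, save_videos_from_pil

video_path = "configs/inference/pose_videos/solo_pose.mp4"
frames = read_frames(video_path)
fps = get_fps(video_path)

# re-encode roughly the first two seconds as a gif preview
save_videos_from_pil(frames[: int(2 * fps)], "output/solo_preview.gif", fps=int(fps))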