├── .dockerignore ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── example_data │ ├── images │ │ └── demo1.jpg │ └── videos │ │ └── pose1.mp4 └── figures │ ├── latent_fusion.png │ ├── model_structure.png │ ├── preview_1.gif │ ├── preview_2.gif │ ├── preview_3.gif │ ├── preview_4.gif │ ├── preview_5.gif │ └── preview_6.gif ├── cog.yaml ├── configs └── test.yaml ├── constants.py ├── environment.yaml ├── inference.py ├── mimicmotion ├── __init__.py ├── dwpose │ ├── .gitignore │ ├── __init__.py │ ├── dwpose_detector.py │ ├── onnxdet.py │ ├── onnxpose.py │ ├── preprocess.py │ ├── util.py │ └── wholebody.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── pose_net.py │ └── unet.py ├── pipelines │ └── pipeline_mimicmotion.py └── utils │ ├── __init__.py │ ├── geglu_patch.py │ ├── loader.py │ └── utils.py └── predict.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Replicate 6 | /models/ 7 | /outputs/ 8 | *.gif 9 | *.mp4 10 | *.jpg 11 | *.jpeg 12 | *.png 13 | *.webp 14 | MimicMotion.pth 15 | 16 | # Exclude Git files 17 | .git 18 | .github 19 | .gitignore 20 | 21 | # Exclude Python cache files 22 | __pycache__ 23 | .mypy_cache 24 | .pytest_cache 25 | .ruff_cache 26 | 27 | # Exclude Python virtual environment 28 | /venv 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | #/site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 152 | .idea/ 153 | 154 | # custom ignores 155 | .DS_Store 156 | _.* 157 | 158 | # models and outputs 159 | models/ 160 | outputs/ 161 | 162 | # Replicate 163 | .cog -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Tencent is pleased to support the open source community by making MimicMotion available. 2 | 3 | Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) THL A29 Limited. 4 | 5 | MimicMotion is licensed under the Apache License Version 2.0 except for the third-party components listed below. 6 | 7 | 8 | Terms of the Apache License Version 2.0: 9 | -------------------------------------------------------------------- 10 | Apache License 11 | 12 | Version 2.0, January 2004 13 | 14 | http://www.apache.org/licenses/ 15 | 16 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 17 | 1. Definitions. 18 | 19 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 20 | 21 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 22 | 23 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 28 | 29 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 30 | 31 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 32 | 33 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 34 | 35 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 36 | 37 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 38 | 39 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 40 | 41 | 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 42 | 43 | 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 44 | 45 | You must give any other recipients of the Work or Derivative Works a copy of this License; and 46 | 47 | You must cause any modified files to carry prominent notices stating that You changed the files; and 48 | 49 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 50 | 51 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 52 | 53 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 54 | 55 | 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 56 | 57 | 6. Trademarks. 
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 58 | 59 | 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 60 | 61 | 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 62 | 63 | 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 64 | 65 | END OF TERMS AND CONDITIONS 66 | 67 | 68 | 69 | Other dependencies and licenses: 70 | 71 | 72 | Open Source Software Licensed under the Apache License Version 2.0: 73 | The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2023 THL A29 Limited. 74 | -------------------------------------------------------------------- 75 | 1. diffusers 76 | Copyright (c) diffusers original author and authors 77 | 78 | 2. DWPose 79 | Copyright 2018-2020 Open-MMLab. 80 | Please note this software has been modified by Tencent in this distribution. 81 | 82 | 3. transformers 83 | Copyright (c) transformers original author and authors 84 | 85 | 4. decord 86 | Copyright (c) DWPoseoriginal author and authors 87 | 88 | 89 | A copy of Apache 2.0 has been included in this file. 90 | 91 | 92 | 93 | Open Source Software Licensed under the BSD 3-Clause License: 94 | -------------------------------------------------------------------- 95 | 1. 
torch 96 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 97 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 98 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 99 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 100 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 101 | Copyright (c) 2011-2013 NYU (Clement Farabet) 102 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 103 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 104 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 105 | 106 | 2. omegaconf 107 | Copyright (c) 2018, Omry Yadan 108 | All rights reserved. 109 | 110 | 3. torchvision 111 | Copyright (c) Soumith Chintala 2016, 112 | All rights reserved. 113 | 114 | 115 | Terms of the BSD 3-Clause: 116 | -------------------------------------------------------------------- 117 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 118 | 119 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 120 | 121 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 122 | 123 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 124 | 125 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 126 | 127 | 128 | 129 | Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein: 130 | -------------------------------------------------------------------- 131 | 1. numpy 132 | Copyright (c) 2005-2023, NumPy Developers. 133 | All rights reserved. 134 | 135 | A copy of the BSD 3-Clause is included in this file. 136 | 137 | For the license of other third party components, please refer to the following URL: 138 | https://github.com/numpy/numpy/blob/v1.26.3/LICENSES_bundled.txt 139 | 140 | 141 | 142 | Open Source Software Licensed under the HPND License: 143 | -------------------------------------------------------------------- 144 | 1. Pillow 145 | Copyright © 2010-2023 by Jeffrey A. Clark (Alex) and contributors. 146 | 147 | 148 | Terms of the HPND License: 149 | -------------------------------------------------------------------- 150 | The Python Imaging Library (PIL) is 151 | 152 | Copyright © 1997-2011 by Secret Labs AB 153 | Copyright © 1995-2011 by Fredrik Lundh 154 | 155 | Pillow is the friendly PIL fork. 
It is 156 | 157 | Copyright © 2010-2023 by Jeffrey A. Clark (Alex) and contributors. 158 | 159 | Like PIL, Pillow is licensed under the open source HPND License: 160 | 161 | By obtaining, using, and/or copying this software and/or its associated 162 | documentation, you agree that you have read, understood, and will comply 163 | with the following terms and conditions: 164 | 165 | Permission to use, copy, modify and distribute this software and its 166 | documentation for any purpose and without fee is hereby granted, 167 | provided that the above copyright notice appears in all copies, and that 168 | both that copyright notice and this permission notice appear in supporting 169 | documentation, and that the name of Secret Labs AB or the author not be 170 | used in advertising or publicity pertaining to distribution of the software 171 | without specific, written prior permission. 172 | 173 | SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS 174 | SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. 175 | IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL, 176 | INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM 177 | LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE 178 | OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR 179 | PERFORMANCE OF THIS SOFTWARE. 180 | 181 | 182 | Open Source Software Licensed under the Matplotlib License and Other Licenses of the Third-Party Components therein: 183 | -------------------------------------------------------------------- 184 | 1. matplotlib 185 | Copyright (c) 186 | 2012- Matplotlib Development Team; All Rights Reserved 187 | 188 | 189 | Terms of the Matplotlib License: 190 | -------------------------------------------------------------------- 191 | License agreement for matplotlib versions 1.3.0 and later 192 | ========================================================= 193 | 194 | 1. This LICENSE AGREEMENT is between the Matplotlib Development Team 195 | ("MDT"), and the Individual or Organization ("Licensee") accessing and 196 | otherwise using matplotlib software in source or binary form and its 197 | associated documentation. 198 | 199 | 2. Subject to the terms and conditions of this License Agreement, MDT 200 | hereby grants Licensee a nonexclusive, royalty-free, world-wide license 201 | to reproduce, analyze, test, perform and/or display publicly, prepare 202 | derivative works, distribute, and otherwise use matplotlib 203 | alone or in any derivative version, provided, however, that MDT's 204 | License Agreement and MDT's notice of copyright, i.e., "Copyright (c) 205 | 2012- Matplotlib Development Team; All Rights Reserved" are retained in 206 | matplotlib alone or in any derivative version prepared by 207 | Licensee. 208 | 209 | 3. In the event Licensee prepares a derivative work that is based on or 210 | incorporates matplotlib or any part thereof, and wants to 211 | make the derivative work available to others as provided herein, then 212 | Licensee hereby agrees to include in any such work a brief summary of 213 | the changes made to matplotlib . 214 | 215 | 4. MDT is making matplotlib available to Licensee on an "AS 216 | IS" basis. MDT MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR 217 | IMPLIED. 
BY WAY OF EXAMPLE, BUT NOT LIMITATION, MDT MAKES NO AND 218 | DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS 219 | FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF MATPLOTLIB 220 | WILL NOT INFRINGE ANY THIRD PARTY RIGHTS. 221 | 222 | 5. MDT SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF MATPLOTLIB 223 | FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR 224 | LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING 225 | MATPLOTLIB , OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF 226 | THE POSSIBILITY THEREOF. 227 | 228 | 6. This License Agreement will automatically terminate upon a material 229 | breach of its terms and conditions. 230 | 231 | 7. Nothing in this License Agreement shall be deemed to create any 232 | relationship of agency, partnership, or joint venture between MDT and 233 | Licensee. This License Agreement does not grant permission to use MDT 234 | trademarks or trade name in a trademark sense to endorse or promote 235 | products or services of Licensee, or any third party. 236 | 237 | 8. By copying, installing or otherwise using matplotlib , 238 | Licensee agrees to be bound by the terms and conditions of this License 239 | Agreement. 240 | 241 | License agreement for matplotlib versions prior to 1.3.0 242 | ======================================================== 243 | 244 | 1. This LICENSE AGREEMENT is between John D. Hunter ("JDH"), and the 245 | Individual or Organization ("Licensee") accessing and otherwise using 246 | matplotlib software in source or binary form and its associated 247 | documentation. 248 | 249 | 2. Subject to the terms and conditions of this License Agreement, JDH 250 | hereby grants Licensee a nonexclusive, royalty-free, world-wide license 251 | to reproduce, analyze, test, perform and/or display publicly, prepare 252 | derivative works, distribute, and otherwise use matplotlib 253 | alone or in any derivative version, provided, however, that JDH's 254 | License Agreement and JDH's notice of copyright, i.e., "Copyright (c) 255 | 2002-2011 John D. Hunter; All Rights Reserved" are retained in 256 | matplotlib alone or in any derivative version prepared by 257 | Licensee. 258 | 259 | 3. In the event Licensee prepares a derivative work that is based on or 260 | incorporates matplotlib or any part thereof, and wants to 261 | make the derivative work available to others as provided herein, then 262 | Licensee hereby agrees to include in any such work a brief summary of 263 | the changes made to matplotlib. 264 | 265 | 4. JDH is making matplotlib available to Licensee on an "AS 266 | IS" basis. JDH MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR 267 | IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, JDH MAKES NO AND 268 | DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS 269 | FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF MATPLOTLIB 270 | WILL NOT INFRINGE ANY THIRD PARTY RIGHTS. 271 | 272 | 5. JDH SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF MATPLOTLIB 273 | FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR 274 | LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING 275 | MATPLOTLIB , OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF 276 | THE POSSIBILITY THEREOF. 277 | 278 | 6. This License Agreement will automatically terminate upon a material 279 | breach of its terms and conditions. 280 | 281 | 7. Nothing in this License Agreement shall be deemed to create any 282 | relationship of agency, partnership, or joint venture between JDH and 283 | Licensee. 
This License Agreement does not grant permission to use JDH 284 | trademarks or trade name in a trademark sense to endorse or promote 285 | products or services of Licensee, or any third party. 286 | 287 | 8. By copying, installing or otherwise using matplotlib, 288 | Licensee agrees to be bound by the terms and conditions of this License 289 | Agreement. 290 | 291 | For the license of other third party components, please refer to the following URL: 292 | https://github.com/matplotlib/matplotlib/tree/v3.8.0/LICENSE 293 | 294 | 295 | Open Source Software Licensed under the MIT License: 296 | -------------------------------------------------------------------- 297 | 1. einops 298 | Copyright (c) 2018 Alex Rogozhnikov 299 | 300 | 2. onnxruntime 301 | Copyright (c) Microsoft Corporation 302 | 303 | 3. OpenCV 304 | Copyright (c) Olli-Pekka Heinisuo 305 | 306 | 307 | Terms of the MIT License: 308 | -------------------------------------------------------------------- 309 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 310 | 311 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 312 | 313 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MimicMotion 2 | 3 | [![Replicate](https://replicate.com/zsxkib/mimic-motion/badge)](https://replicate.com/zsxkib/mimic-motion) 4 | 5 | MimicMotion: High-Quality Human Motion Video Generation with Confidence-aware Pose Guidance 6 |
7 | *Yuang Zhang1,2, Jiaxi Gu1, Li-Wen Wang1, Han Wang1,2, Junqi Cheng1, Yuefeng Zhu1, Fangyuan Zou1* 8 |
9 | [1Tencent; 2Shanghai Jiao Tong University] 10 | 11 |

12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | Highlights: rich details, good temporal smoothness, and long video length. 20 |

21 | 22 | ## Overview 23 | 24 |

25 | model architecture 26 |
27 | An overview of the framework of MimicMotion. 28 |

29 | 30 | In recent years, generative artificial intelligence has achieved significant advancements in the field of image generation, spawning a variety of applications. However, video generation still faces considerable challenges in various aspects such as controllability, video length, and richness of details, which hinder the application and popularization of this technology. In this work, we propose a controllable video generation framework, dubbed *MimicMotion*, which can generate high-quality videos of arbitrary length with any motion guidance. Comparing with previous methods, our approach has several highlights. Firstly, with confidence-aware pose guidance, temporal smoothness can be achieved so model robustness can be enhanced with large-scale training data. Secondly, regional loss amplification based on pose confidence significantly eases the distortion of image significantly. Lastly, for generating long smooth videos, a progressive latent fusion strategy is proposed. By this means, videos of arbitrary length can be generated with acceptable resource consumption. With extensive experiments and user studies, MimicMotion demonstrates significant improvements over previous approaches in multiple aspects. 31 | 32 | ## News 33 | 34 | * `[2024-07-08]`: 🔥 [A superior model checkpoint](https://huggingface.co/tencent/MimicMotion/blob/main/MimicMotion_1-1.pth) has been released as version 1.1. The maximum number of video frames has now been expanded from 16 to 72, significantly enhancing the video quality! 35 | * `[2024-07-01]`: Project page, code, technical report and [a basic model checkpoint](https://huggingface.co/tencent/MimicMotion/blob/main/MimicMotion_1.pth) are released. A better checkpoint supporting higher quality video generation will be released very soon. Stay tuned! 36 | 37 | ## Quickstart 38 | 39 | For the initial released version of the model checkpoint, it supports generating videos with a maximum of 72 frames at a 576x1024 resolution. If you encounter insufficient memory issues, you can appropriately reduce the number of frames. 40 | 41 | ### Environment setup 42 | 43 | Recommend python 3+ with torch 2.x are validated with an Nvidia V100 GPU. Follow the command below to install all the dependencies of python: 44 | 45 | ``` 46 | conda env create -f environment.yaml 47 | conda activate mimicmotion 48 | ``` 49 | 50 | ### Download weights 51 | If you experience connection issues with Hugging Face, you can utilize the mirror endpoint by setting the environment variable: `export HF_ENDPOINT=https://hf-mirror.com`. 52 | Please download weights manually as follows: 53 | ``` 54 | cd MimicMotions/ 55 | mkdir models 56 | ``` 57 | 1. Download DWPose pretrained model: [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main) 58 | ``` 59 | mkdir -p models/DWPose 60 | wget https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true -O models/DWPose/yolox_l.onnx 61 | wget https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true -O models/DWPose/dw-ll_ucoco_384.onnx 62 | ``` 63 | 2. Download the pre-trained checkpoint of MimicMotion from [Huggingface](https://huggingface.co/tencent/MimicMotion) 64 | ``` 65 | wget -P models/ https://huggingface.co/tencent/MimicMotion/resolve/main/MimicMotion_1-1.pth 66 | ``` 67 | 3. The SVD model [stabilityai/stable-video-diffusion-img2vid-xt-1-1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1) will be automatically downloaded. 
68 | 69 | Finally, all the weights should be organized in models as follows 70 | 71 | ``` 72 | models/ 73 | ├── DWPose 74 | │   ├── dw-ll_ucoco_384.onnx 75 | │   └── yolox_l.onnx 76 | └── MimicMotion_1-1.pth 77 | ``` 78 | 79 | ### Model inference 80 | 81 | A sample configuration for testing is provided as `test.yaml`. You can also easily modify the various configurations according to your needs. 82 | 83 | ``` 84 | python inference.py --inference_config configs/test.yaml 85 | ``` 86 | 87 | Tips: if your GPU memory is limited, try set env `PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:256`. 88 | 89 | ### VRAM requirement and Runtime 90 | 91 | For the 35s demo video, the 72-frame model requires 16GB VRAM (4060ti) and finishes in 20 minutes on a 4090 GPU. 92 | 93 | The minimum VRAM requirement for the 16-frame U-Net model is 8GB; however, the VAE decoder demands 16GB. You have the option to run the VAE decoder on CPU. 94 | 95 | ## Citation 96 | ```bib 97 | @article{mimicmotion2024, 98 | title={MimicMotion: High-Quality Human Motion Video Generation with Confidence-aware Pose Guidance}, 99 | author={Yuang Zhang and Jiaxi Gu and Li-Wen Wang and Han Wang and Junqi Cheng and Yuefeng Zhu and Fangyuan Zou}, 100 | journal={arXiv preprint arXiv:2406.19680}, 101 | year={2024} 102 | } 103 | ``` 104 | -------------------------------------------------------------------------------- /assets/example_data/images/demo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/assets/example_data/images/demo1.jpg -------------------------------------------------------------------------------- /assets/example_data/videos/pose1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/assets/example_data/videos/pose1.mp4 -------------------------------------------------------------------------------- /assets/figures/latent_fusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/assets/figures/latent_fusion.png -------------------------------------------------------------------------------- /assets/figures/model_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/assets/figures/model_structure.png -------------------------------------------------------------------------------- /assets/figures/preview_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/assets/figures/preview_1.gif -------------------------------------------------------------------------------- /assets/figures/preview_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/assets/figures/preview_2.gif -------------------------------------------------------------------------------- /assets/figures/preview_3.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/assets/figures/preview_3.gif -------------------------------------------------------------------------------- /assets/figures/preview_4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/assets/figures/preview_4.gif -------------------------------------------------------------------------------- /assets/figures/preview_5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/assets/figures/preview_5.gif -------------------------------------------------------------------------------- /assets/figures/preview_6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/assets/figures/preview_6.gif -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://cog.run/yaml 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | # cuda: "11.7" 8 | 9 | # a list of ubuntu apt packages to install 10 | system_packages: 11 | - "libgl1-mesa-glx" 12 | - "libglib2.0-0" 13 | 14 | # python version in the form '3.11' or '3.11.4' 15 | python_version: "3.11" 16 | 17 | # a list of packages in the format == 18 | python_packages: 19 | - "torch>=2.3" # 2.3.1 20 | - "torchvision>=0.18" # 0.18.1 21 | - "diffusers>=0.29" # 0.29.2 22 | - "transformers>=4.42" # 4.42.3 23 | - "decord>=0.6" # 0.6.0 24 | - "einops>=0.8" # 0.8.0 25 | - "omegaconf>=2.3" # 2.3.0 26 | - "opencv-python>=4.10" # 4.10.0.84 27 | - "matplotlib>=3.9" # 3.9.1 28 | - "onnxruntime>=1.18" # 1.18.1 29 | - "accelerate>=0.32" # 0.32.0 30 | - "av>=12.2" # 12.2.0, https://github.com/continue-revolution/sd-webui-animatediff/issues/377 31 | 32 | # commands run after the environment is setup 33 | run: 34 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget 35 | 36 | # predict.py defines how predictions are run on your model 37 | predict: "predict.py:Predictor" 38 | -------------------------------------------------------------------------------- /configs/test.yaml: -------------------------------------------------------------------------------- 1 | # base svd model path 2 | base_model_path: stabilityai/stable-video-diffusion-img2vid-xt-1-1 3 | 4 | # checkpoint path 5 | ckpt_path: models/MimicMotion_1-1.pth 6 | 7 | test_case: 8 | - ref_video_path: assets/example_data/videos/pose1.mp4 9 | ref_image_path: assets/example_data/images/demo1.jpg 10 | num_frames: 72 11 | resolution: 576 12 | frames_overlap: 6 13 | num_inference_steps: 25 14 | noise_aug_strength: 0 15 | guidance_scale: 2.0 16 | sample_stride: 2 17 | fps: 15 18 | seed: 42 19 | 20 | 21 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | # w/h apsect ratio 2 | ASPECT_RATIO = 9 / 16 3 | -------------------------------------------------------------------------------- /environment.yaml: 
-------------------------------------------------------------------------------- 1 | name: mimicmotion 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.11 7 | - pytorch=2.0.1 8 | - torchvision=0.15.2 9 | - pytorch-cuda=11.7 10 | - pip 11 | - pip: 12 | - diffusers==0.27.0 13 | - transformers==4.32.1 14 | - decord==0.6.0 15 | - einops 16 | - omegaconf 17 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | import math 5 | from omegaconf import OmegaConf 6 | from datetime import datetime 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import torch.jit 11 | from torchvision.datasets.folder import pil_loader 12 | from torchvision.transforms.functional import pil_to_tensor, resize, center_crop 13 | from torchvision.transforms.functional import to_pil_image 14 | 15 | 16 | from mimicmotion.utils.geglu_patch import patch_geglu_inplace 17 | patch_geglu_inplace() 18 | 19 | from constants import ASPECT_RATIO 20 | 21 | from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline 22 | from mimicmotion.utils.loader import create_pipeline 23 | from mimicmotion.utils.utils import save_to_mp4 24 | from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose 25 | 26 | logging.basicConfig(level=logging.INFO, format="%(asctime)s: [%(levelname)s] %(message)s") 27 | logger = logging.getLogger(__name__) 28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | 30 | 31 | def preprocess(video_path, image_path, resolution=576, sample_stride=2): 32 | """preprocess ref image pose and video pose 33 | 34 | Args: 35 | video_path (str): input video pose path 36 | image_path (str): reference image path 37 | resolution (int, optional): Defaults to 576. 38 | sample_stride (int, optional): Defaults to 2. 
39 | """ 40 | image_pixels = pil_loader(image_path) 41 | image_pixels = pil_to_tensor(image_pixels) # (c, h, w) 42 | h, w = image_pixels.shape[-2:] 43 | ############################ compute target h/w according to original aspect ratio ############################### 44 | if h>w: 45 | w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64 46 | else: 47 | w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution 48 | h_w_ratio = float(h) / float(w) 49 | if h_w_ratio < h_target / w_target: 50 | h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio) 51 | else: 52 | h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target 53 | image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None) 54 | image_pixels = center_crop(image_pixels, [h_target, w_target]) 55 | image_pixels = image_pixels.permute((1, 2, 0)).numpy() 56 | ##################################### get image&video pose value ################################################# 57 | image_pose = get_image_pose(image_pixels) 58 | video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride) 59 | pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose]) 60 | image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2)) 61 | return torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1 62 | 63 | 64 | def run_pipeline(pipeline: MimicMotionPipeline, image_pixels, pose_pixels, device, task_config): 65 | image_pixels = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5] 66 | generator = torch.Generator(device=device) 67 | generator.manual_seed(task_config.seed) 68 | frames = pipeline( 69 | image_pixels, image_pose=pose_pixels, num_frames=pose_pixels.size(0), 70 | tile_size=task_config.num_frames, tile_overlap=task_config.frames_overlap, 71 | height=pose_pixels.shape[-2], width=pose_pixels.shape[-1], fps=7, 72 | noise_aug_strength=task_config.noise_aug_strength, num_inference_steps=task_config.num_inference_steps, 73 | generator=generator, min_guidance_scale=task_config.guidance_scale, 74 | max_guidance_scale=task_config.guidance_scale, decode_chunk_size=8, output_type="pt", device=device 75 | ).frames.cpu() 76 | video_frames = (frames * 255.0).to(torch.uint8) 77 | 78 | for vid_idx in range(video_frames.shape[0]): 79 | # deprecated first frame because of ref image 80 | _video_frames = video_frames[vid_idx, 1:] 81 | 82 | return _video_frames 83 | 84 | 85 | @torch.no_grad() 86 | def main(args): 87 | if not args.no_use_float16 : 88 | torch.set_default_dtype(torch.float16) 89 | 90 | infer_config = OmegaConf.load(args.inference_config) 91 | pipeline = create_pipeline(infer_config, device) 92 | 93 | for task in infer_config.test_case: 94 | ############################################## Pre-process data ############################################## 95 | pose_pixels, image_pixels = preprocess( 96 | task.ref_video_path, task.ref_image_path, 97 | resolution=task.resolution, sample_stride=task.sample_stride 98 | ) 99 | ########################################### Run MimicMotion pipeline ########################################### 100 | _video_frames = run_pipeline( 101 | pipeline, 102 | image_pixels, pose_pixels, 103 | device, task 104 | ) 105 | ################################### save results to output folder. 
########################################### 106 | save_to_mp4( 107 | _video_frames, 108 | f"{args.output_dir}/{os.path.basename(task.ref_video_path).split('.')[0]}" \ 109 | f"_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4", 110 | fps=task.fps, 111 | ) 112 | 113 | def set_logger(log_file=None, log_level=logging.INFO): 114 | log_handler = logging.FileHandler(log_file, "w") 115 | log_handler.setFormatter( 116 | logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s") 117 | ) 118 | log_handler.setLevel(log_level) 119 | logger.addHandler(log_handler) 120 | 121 | 122 | if __name__ == "__main__": 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--log_file", type=str, default=None) 125 | parser.add_argument("--inference_config", type=str, default="configs/test.yaml") #ToDo 126 | parser.add_argument("--output_dir", type=str, default="outputs/", help="path to output") 127 | parser.add_argument("--no_use_float16", 128 | action="store_true", 129 | help="Whether use float16 to speed up inference", 130 | ) 131 | args = parser.parse_args() 132 | 133 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 134 | set_logger(args.log_file \ 135 | if args.log_file is not None else f"{args.output_dir}/{datetime.now().strftime('%Y%m%d%H%M%S')}.log") 136 | main(args) 137 | logger.info(f"--- Finished ---") 138 | 139 | -------------------------------------------------------------------------------- /mimicmotion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/mimicmotion/__init__.py -------------------------------------------------------------------------------- /mimicmotion/dwpose/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /mimicmotion/dwpose/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/mimicmotion/dwpose/__init__.py -------------------------------------------------------------------------------- /mimicmotion/dwpose/dwpose_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .wholebody import Wholebody 7 | 8 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | class DWposeDetector: 12 | """ 13 | A pose detect method for image-like data. 
14 | 15 | Parameters: 16 | model_det: (str) serialized ONNX format model path, 17 | such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx 18 | model_pose: (str) serialized ONNX format model path, 19 | such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx 20 | device: (str) 'cpu' or 'cuda:{device_id}' 21 | """ 22 | def __init__(self, model_det, model_pose, device='cpu'): 23 | self.args = model_det, model_pose, device 24 | 25 | def release_memory(self): 26 | if hasattr(self, 'pose_estimation'): 27 | del self.pose_estimation 28 | import gc; gc.collect() 29 | 30 | def __call__(self, oriImg): 31 | if not hasattr(self, 'pose_estimation'): 32 | self.pose_estimation = Wholebody(*self.args) 33 | 34 | oriImg = oriImg.copy() 35 | H, W, C = oriImg.shape 36 | with torch.no_grad(): 37 | candidate, score = self.pose_estimation(oriImg) 38 | nums, _, locs = candidate.shape 39 | candidate[..., 0] /= float(W) 40 | candidate[..., 1] /= float(H) 41 | body = candidate[:, :18].copy() 42 | body = body.reshape(nums * 18, locs) 43 | subset = score[:, :18].copy() 44 | for i in range(len(subset)): 45 | for j in range(len(subset[i])): 46 | if subset[i][j] > 0.3: 47 | subset[i][j] = int(18 * i + j) 48 | else: 49 | subset[i][j] = -1 50 | 51 | # un_visible = subset < 0.3 52 | # candidate[un_visible] = -1 53 | 54 | # foot = candidate[:, 18:24] 55 | 56 | faces = candidate[:, 24:92] 57 | 58 | hands = candidate[:, 92:113] 59 | hands = np.vstack([hands, candidate[:, 113:]]) 60 | 61 | faces_score = score[:, 24:92] 62 | hands_score = np.vstack([score[:, 92:113], score[:, 113:]]) 63 | 64 | bodies = dict(candidate=body, subset=subset, score=score[:, :18]) 65 | pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score) 66 | 67 | return pose 68 | 69 | dwpose_detector = DWposeDetector( 70 | model_det="models/DWPose/yolox_l.onnx", 71 | model_pose="models/DWPose/dw-ll_ucoco_384.onnx", 72 | device=device) 73 | -------------------------------------------------------------------------------- /mimicmotion/dwpose/onnxdet.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def nms(boxes, scores, nms_thr): 6 | """Single class NMS implemented in Numpy. 7 | 8 | Args: 9 | boxes (np.ndarray): shape=(N,4); N is number of boxes 10 | scores (np.ndarray): the score of bboxes 11 | nms_thr (float): the threshold in NMS 12 | 13 | Returns: 14 | List[int]: output bbox ids 15 | """ 16 | x1 = boxes[:, 0] 17 | y1 = boxes[:, 1] 18 | x2 = boxes[:, 2] 19 | y2 = boxes[:, 3] 20 | 21 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 22 | order = scores.argsort()[::-1] 23 | 24 | keep = [] 25 | while order.size > 0: 26 | i = order[0] 27 | keep.append(i) 28 | xx1 = np.maximum(x1[i], x1[order[1:]]) 29 | yy1 = np.maximum(y1[i], y1[order[1:]]) 30 | xx2 = np.minimum(x2[i], x2[order[1:]]) 31 | yy2 = np.minimum(y2[i], y2[order[1:]]) 32 | 33 | w = np.maximum(0.0, xx2 - xx1 + 1) 34 | h = np.maximum(0.0, yy2 - yy1 + 1) 35 | inter = w * h 36 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 37 | 38 | inds = np.where(ovr <= nms_thr)[0] 39 | order = order[inds + 1] 40 | 41 | return keep 42 | 43 | def multiclass_nms(boxes, scores, nms_thr, score_thr): 44 | """Multiclass NMS implemented in Numpy. Class-aware version. 
45 | 46 | Args: 47 | boxes (np.ndarray): shape=(N,4); N is number of boxes 48 | scores (np.ndarray): the score of bboxes 49 | nms_thr (float): the threshold in NMS 50 | score_thr (float): the threshold of cls score 51 | 52 | Returns: 53 | np.ndarray: outputs bboxes coordinate 54 | """ 55 | final_dets = [] 56 | num_classes = scores.shape[1] 57 | for cls_ind in range(num_classes): 58 | cls_scores = scores[:, cls_ind] 59 | valid_score_mask = cls_scores > score_thr 60 | if valid_score_mask.sum() == 0: 61 | continue 62 | else: 63 | valid_scores = cls_scores[valid_score_mask] 64 | valid_boxes = boxes[valid_score_mask] 65 | keep = nms(valid_boxes, valid_scores, nms_thr) 66 | if len(keep) > 0: 67 | cls_inds = np.ones((len(keep), 1)) * cls_ind 68 | dets = np.concatenate( 69 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 70 | ) 71 | final_dets.append(dets) 72 | if len(final_dets) == 0: 73 | return None 74 | return np.concatenate(final_dets, 0) 75 | 76 | def demo_postprocess(outputs, img_size, p6=False): 77 | grids = [] 78 | expanded_strides = [] 79 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] 80 | 81 | hsizes = [img_size[0] // stride for stride in strides] 82 | wsizes = [img_size[1] // stride for stride in strides] 83 | 84 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 85 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 86 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 87 | grids.append(grid) 88 | shape = grid.shape[:2] 89 | expanded_strides.append(np.full((*shape, 1), stride)) 90 | 91 | grids = np.concatenate(grids, 1) 92 | expanded_strides = np.concatenate(expanded_strides, 1) 93 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 94 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 95 | 96 | return outputs 97 | 98 | def preprocess(img, input_size, swap=(2, 0, 1)): 99 | if len(img.shape) == 3: 100 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 101 | else: 102 | padded_img = np.ones(input_size, dtype=np.uint8) * 114 103 | 104 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 105 | resized_img = cv2.resize( 106 | img, 107 | (int(img.shape[1] * r), int(img.shape[0] * r)), 108 | interpolation=cv2.INTER_LINEAR, 109 | ).astype(np.uint8) 110 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 111 | 112 | padded_img = padded_img.transpose(swap) 113 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 114 | return padded_img, r 115 | 116 | def inference_detector(session, oriImg): 117 | """run human detect 118 | """ 119 | input_shape = (640,640) 120 | img, ratio = preprocess(oriImg, input_shape) 121 | 122 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} 123 | output = session.run(None, ort_inputs) 124 | predictions = demo_postprocess(output[0], input_shape)[0] 125 | 126 | boxes = predictions[:, :4] 127 | scores = predictions[:, 4:5] * predictions[:, 5:] 128 | 129 | boxes_xyxy = np.ones_like(boxes) 130 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. 131 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. 132 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. 133 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. 
134 | boxes_xyxy /= ratio 135 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) 136 | if dets is not None: 137 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] 138 | isscore = final_scores>0.3 139 | iscat = final_cls_inds == 0 140 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)] 141 | final_boxes = final_boxes[isbbox] 142 | else: 143 | final_boxes = np.array([]) 144 | 145 | return final_boxes 146 | -------------------------------------------------------------------------------- /mimicmotion/dwpose/onnxpose.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import cv2 4 | import numpy as np 5 | import onnxruntime as ort 6 | 7 | def preprocess( 8 | img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) 9 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 10 | """Do preprocessing for RTMPose model inference. 11 | 12 | Args: 13 | img (np.ndarray): Input image in shape. 14 | input_size (tuple): Input image size in shape (w, h). 15 | 16 | Returns: 17 | tuple: 18 | - resized_img (np.ndarray): Preprocessed image. 19 | - center (np.ndarray): Center of image. 20 | - scale (np.ndarray): Scale of image. 21 | """ 22 | # get shape of image 23 | img_shape = img.shape[:2] 24 | out_img, out_center, out_scale = [], [], [] 25 | if len(out_bbox) == 0: 26 | out_bbox = [[0, 0, img_shape[1], img_shape[0]]] 27 | for i in range(len(out_bbox)): 28 | x0 = out_bbox[i][0] 29 | y0 = out_bbox[i][1] 30 | x1 = out_bbox[i][2] 31 | y1 = out_bbox[i][3] 32 | bbox = np.array([x0, y0, x1, y1]) 33 | 34 | # get center and scale 35 | center, scale = bbox_xyxy2cs(bbox, padding=1.25) 36 | 37 | # do affine transformation 38 | resized_img, scale = top_down_affine(input_size, scale, center, img) 39 | 40 | # normalize image 41 | mean = np.array([123.675, 116.28, 103.53]) 42 | std = np.array([58.395, 57.12, 57.375]) 43 | resized_img = (resized_img - mean) / std 44 | 45 | out_img.append(resized_img) 46 | out_center.append(center) 47 | out_scale.append(scale) 48 | 49 | return out_img, out_center, out_scale 50 | 51 | 52 | def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: 53 | """Inference RTMPose model. 54 | 55 | Args: 56 | sess (ort.InferenceSession): ONNXRuntime session. 57 | img (np.ndarray): Input image in shape. 58 | 59 | Returns: 60 | outputs (np.ndarray): Output of RTMPose model. 61 | """ 62 | all_out = [] 63 | # build input 64 | for i in range(len(img)): 65 | input = [img[i].transpose(2, 0, 1)] 66 | 67 | # build output 68 | sess_input = {sess.get_inputs()[0].name: input} 69 | sess_output = [] 70 | for out in sess.get_outputs(): 71 | sess_output.append(out.name) 72 | 73 | # run model 74 | outputs = sess.run(sess_output, sess_input) 75 | all_out.append(outputs) 76 | 77 | return all_out 78 | 79 | 80 | def postprocess(outputs: List[np.ndarray], 81 | model_input_size: Tuple[int, int], 82 | center: Tuple[int, int], 83 | scale: Tuple[int, int], 84 | simcc_split_ratio: float = 2.0 85 | ) -> Tuple[np.ndarray, np.ndarray]: 86 | """Postprocess for RTMPose model output. 87 | 88 | Args: 89 | outputs (np.ndarray): Output of RTMPose model. 90 | model_input_size (tuple): RTMPose model Input image size. 91 | center (tuple): Center of bbox in shape (x, y). 92 | scale (tuple): Scale of bbox in shape (w, h). 93 | simcc_split_ratio (float): Split ratio of simcc. 94 | 95 | Returns: 96 | tuple: 97 | - keypoints (np.ndarray): Rescaled keypoints. 
98 |         - scores (np.ndarray): Model predict scores.
99 |     """
100 |     all_key = []
101 |     all_score = []
102 |     for i in range(len(outputs)):
103 |         # use simcc to decode
104 |         simcc_x, simcc_y = outputs[i]
105 |         keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
106 | 
107 |         # rescale keypoints
108 |         keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
109 |         all_key.append(keypoints[0])
110 |         all_score.append(scores[0])
111 | 
112 |     return np.array(all_key), np.array(all_score)
113 | 
114 | 
115 | def bbox_xyxy2cs(bbox: np.ndarray,
116 |                  padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
117 |     """Transform the bbox format from (x1, y1, x2, y2) into (center, scale)
118 | 
119 |     Args:
120 |         bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
121 |             as (left, top, right, bottom)
122 |         padding (float): BBox padding factor that will be multiplied to scale.
123 |             Default: 1.0
124 | 
125 |     Returns:
126 |         tuple: A tuple containing center and scale.
127 |         - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
128 |             (n, 2)
129 |         - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
130 |             (n, 2)
131 |     """
132 |     # convert single bbox from (4, ) to (1, 4)
133 |     dim = bbox.ndim
134 |     if dim == 1:
135 |         bbox = bbox[None, :]
136 | 
137 |     # get bbox center and scale
138 |     x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
139 |     center = np.hstack([x1 + x2, y1 + y2]) * 0.5
140 |     scale = np.hstack([x2 - x1, y2 - y1]) * padding
141 | 
142 |     if dim == 1:
143 |         center = center[0]
144 |         scale = scale[0]
145 | 
146 |     return center, scale
147 | 
148 | 
149 | def _fix_aspect_ratio(bbox_scale: np.ndarray,
150 |                       aspect_ratio: float) -> np.ndarray:
151 |     """Extend the scale to match the given aspect ratio.
152 | 
153 |     Args:
154 |         bbox_scale (np.ndarray): The image scale (w, h) in shape (2, )
155 |         aspect_ratio (float): The ratio of ``w/h``
156 | 
157 |     Returns:
158 |         np.ndarray: The reshaped image scale in (2, )
159 |     """
160 |     w, h = np.hsplit(bbox_scale, [1])
161 |     bbox_scale = np.where(w > h * aspect_ratio,
162 |                           np.hstack([w, w / aspect_ratio]),
163 |                           np.hstack([h * aspect_ratio, h]))
164 |     return bbox_scale
165 | 
166 | 
167 | def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
168 |     """Rotate a point by an angle.
169 | 
170 |     Args:
171 |         pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
172 |         angle_rad (float): rotation angle in radian
173 | 
174 |     Returns:
175 |         np.ndarray: Rotated point in shape (2, )
176 |     """
177 |     sn, cs = np.sin(angle_rad), np.cos(angle_rad)
178 |     rot_mat = np.array([[cs, -sn], [sn, cs]])
179 |     return rot_mat @ pt
180 | 
181 | 
182 | def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
183 |     """To calculate the affine matrix, three pairs of points are required. This
184 |     function is used to get the 3rd point, given 2D points a & b.
185 | 
186 |     The 3rd point is defined by rotating vector `a - b` by 90 degrees
187 |     anticlockwise, using b as the rotation center.
188 | 
189 |     Args:
190 |         a (np.ndarray): The 1st point (x,y) in shape (2, )
191 |         b (np.ndarray): The 2nd point (x,y) in shape (2, )
192 | 
193 |     Returns:
194 |         np.ndarray: The 3rd point.
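    Example (illustrative): with ``a = np.array([1., 0.])`` and ``b = np.array([0., 0.])``,
    the direction ``a - b`` is (1, 0); rotating it 90 degrees anticlockwise about ``b``
    yields the returned third point (0, 1).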
195 | """ 196 | direction = a - b 197 | c = b + np.r_[-direction[1], direction[0]] 198 | return c 199 | 200 | 201 | def get_warp_matrix(center: np.ndarray, 202 | scale: np.ndarray, 203 | rot: float, 204 | output_size: Tuple[int, int], 205 | shift: Tuple[float, float] = (0., 0.), 206 | inv: bool = False) -> np.ndarray: 207 | """Calculate the affine transformation matrix that can warp the bbox area 208 | in the input image to the output size. 209 | 210 | Args: 211 | center (np.ndarray[2, ]): Center of the bounding box (x, y). 212 | scale (np.ndarray[2, ]): Scale of the bounding box 213 | wrt [width, height]. 214 | rot (float): Rotation angle (degree). 215 | output_size (np.ndarray[2, ] | list(2,)): Size of the 216 | destination heatmaps. 217 | shift (0-100%): Shift translation ratio wrt the width/height. 218 | Default (0., 0.). 219 | inv (bool): Option to inverse the affine transform direction. 220 | (inv=False: src->dst or inv=True: dst->src) 221 | 222 | Returns: 223 | np.ndarray: A 2x3 transformation matrix 224 | """ 225 | shift = np.array(shift) 226 | src_w = scale[0] 227 | dst_w = output_size[0] 228 | dst_h = output_size[1] 229 | 230 | # compute transformation matrix 231 | rot_rad = np.deg2rad(rot) 232 | src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) 233 | dst_dir = np.array([0., dst_w * -0.5]) 234 | 235 | # get four corners of the src rectangle in the original image 236 | src = np.zeros((3, 2), dtype=np.float32) 237 | src[0, :] = center + scale * shift 238 | src[1, :] = center + src_dir + scale * shift 239 | src[2, :] = _get_3rd_point(src[0, :], src[1, :]) 240 | 241 | # get four corners of the dst rectangle in the input image 242 | dst = np.zeros((3, 2), dtype=np.float32) 243 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 244 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir 245 | dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) 246 | 247 | if inv: 248 | warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 249 | else: 250 | warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 251 | 252 | return warp_mat 253 | 254 | 255 | def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, 256 | img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 257 | """Get the bbox image as the model input by affine transform. 258 | 259 | Args: 260 | input_size (dict): The input size of the model. 261 | bbox_scale (dict): The bbox scale of the img. 262 | bbox_center (dict): The bbox center of the img. 263 | img (np.ndarray): The original image. 264 | 265 | Returns: 266 | tuple: A tuple containing center and scale. 267 | - np.ndarray[float32]: img after affine transform. 268 | - np.ndarray[float32]: bbox scale after affine transform. 269 | """ 270 | w, h = input_size 271 | warp_size = (int(w), int(h)) 272 | 273 | # reshape bbox to fixed aspect ratio 274 | bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) 275 | 276 | # get the affine matrix 277 | center = bbox_center 278 | scale = bbox_scale 279 | rot = 0 280 | warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) 281 | 282 | # do affine transform 283 | img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) 284 | 285 | return img, bbox_scale 286 | 287 | 288 | def get_simcc_maximum(simcc_x: np.ndarray, 289 | simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 290 | """Get maximum response location and value from simcc representations. 
291 | 292 | Note: 293 | instance number: N 294 | num_keypoints: K 295 | heatmap height: H 296 | heatmap width: W 297 | 298 | Args: 299 | simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) 300 | simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) 301 | 302 | Returns: 303 | tuple: 304 | - locs (np.ndarray): locations of maximum heatmap responses in shape 305 | (K, 2) or (N, K, 2) 306 | - vals (np.ndarray): values of maximum heatmap responses in shape 307 | (K,) or (N, K) 308 | """ 309 | N, K, Wx = simcc_x.shape 310 | simcc_x = simcc_x.reshape(N * K, -1) 311 | simcc_y = simcc_y.reshape(N * K, -1) 312 | 313 | # get maximum value locations 314 | x_locs = np.argmax(simcc_x, axis=1) 315 | y_locs = np.argmax(simcc_y, axis=1) 316 | locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) 317 | max_val_x = np.amax(simcc_x, axis=1) 318 | max_val_y = np.amax(simcc_y, axis=1) 319 | 320 | # get maximum value across x and y axis 321 | mask = max_val_x > max_val_y 322 | max_val_x[mask] = max_val_y[mask] 323 | vals = max_val_x 324 | locs[vals <= 0.] = -1 325 | 326 | # reshape 327 | locs = locs.reshape(N, K, 2) 328 | vals = vals.reshape(N, K) 329 | 330 | return locs, vals 331 | 332 | 333 | def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, 334 | simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: 335 | """Modulate simcc distribution with Gaussian. 336 | 337 | Args: 338 | simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. 339 | simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. 340 | simcc_split_ratio (int): The split ratio of simcc. 341 | 342 | Returns: 343 | tuple: A tuple containing center and scale. 344 | - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) 345 | - np.ndarray[float32]: scores in shape (K,) or (n, K) 346 | """ 347 | keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) 348 | keypoints /= simcc_split_ratio 349 | 350 | return keypoints, scores 351 | 352 | 353 | def inference_pose(session, out_bbox, oriImg): 354 | """run pose detect 355 | 356 | Args: 357 | session (ort.InferenceSession): ONNXRuntime session. 358 | out_bbox (np.ndarray): bbox list 359 | oriImg (np.ndarray): Input image in shape. 360 | 361 | Returns: 362 | tuple: 363 | - keypoints (np.ndarray): Rescaled keypoints. 364 | - scores (np.ndarray): Model predict scores. 365 | """ 366 | h, w = session.get_inputs()[0].shape[2:] 367 | model_input_size = (w, h) 368 | # preprocess for rtm-pose model inference. 369 | resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) 370 | # run pose estimation for processed img 371 | outputs = inference(session, resized_img) 372 | # postprocess for rtm-pose model output. 373 | keypoints, scores = postprocess(outputs, model_input_size, center, scale) 374 | 375 | return keypoints, scores 376 | -------------------------------------------------------------------------------- /mimicmotion/dwpose/preprocess.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import decord 3 | import numpy as np 4 | 5 | from .util import draw_pose 6 | from .dwpose_detector import dwpose_detector as dwprocessor 7 | 8 | 9 | def get_video_pose( 10 | video_path: str, 11 | ref_image: np.ndarray, 12 | sample_stride: int=1): 13 | """preprocess ref image pose and video pose 14 | 15 | Args: 16 | video_path (str): video pose path 17 | ref_image (np.ndarray): reference image 18 | sample_stride (int, optional): Defaults to 1. 
19 | 20 | Returns: 21 | np.ndarray: sequence of video pose 22 | """ 23 | # select ref-keypoint from reference pose for pose rescale 24 | ref_pose = dwprocessor(ref_image) 25 | ref_keypoint_id = [0, 1, 2, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] 26 | ref_keypoint_id = [i for i in ref_keypoint_id \ 27 | if len(ref_pose['bodies']['subset']) > 0 and ref_pose['bodies']['subset'][0][i] >= .0] 28 | ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id] 29 | 30 | height, width, _ = ref_image.shape 31 | 32 | # read input video 33 | vr = decord.VideoReader(video_path, ctx=decord.cpu(0)) 34 | sample_stride *= max(1, int(vr.get_avg_fps() / 24)) 35 | 36 | frames = vr.get_batch(list(range(0, len(vr), sample_stride))).asnumpy() 37 | detected_poses = [dwprocessor(frm) for frm in tqdm(frames, desc="DWPose")] 38 | dwprocessor.release_memory() 39 | 40 | detected_bodies = np.stack( 41 | [p['bodies']['candidate'] for p in detected_poses if p['bodies']['candidate'].shape[0] == 18])[:, 42 | ref_keypoint_id] 43 | # compute linear-rescale params 44 | ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1) 45 | fh, fw, _ = vr[0].shape 46 | ax = ay / (fh / fw / height * width) 47 | bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax) 48 | a = np.array([ax, ay]) 49 | b = np.array([bx, by]) 50 | output_pose = [] 51 | # pose rescale 52 | for detected_pose in detected_poses: 53 | detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] * a + b 54 | detected_pose['faces'] = detected_pose['faces'] * a + b 55 | detected_pose['hands'] = detected_pose['hands'] * a + b 56 | im = draw_pose(detected_pose, height, width) 57 | output_pose.append(np.array(im)) 58 | return np.stack(output_pose) 59 | 60 | 61 | def get_image_pose(ref_image): 62 | """process image pose 63 | 64 | Args: 65 | ref_image (np.ndarray): reference image pixel value 66 | 67 | Returns: 68 | np.ndarray: pose visual image in RGB-mode 69 | """ 70 | height, width, _ = ref_image.shape 71 | ref_pose = dwprocessor(ref_image) 72 | pose_img = draw_pose(ref_pose, height, width) 73 | return np.array(pose_img) 74 | -------------------------------------------------------------------------------- /mimicmotion/dwpose/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import matplotlib 4 | import cv2 5 | 6 | 7 | eps = 0.01 8 | 9 | def alpha_blend_color(color, alpha): 10 | """blend color according to point conf 11 | """ 12 | return [int(c * alpha) for c in color] 13 | 14 | def draw_bodypose(canvas, candidate, subset, score): 15 | H, W, C = canvas.shape 16 | candidate = np.array(candidate) 17 | subset = np.array(subset) 18 | 19 | stickwidth = 4 20 | 21 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ 22 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ 23 | [1, 16], [16, 18], [3, 17], [6, 18]] 24 | 25 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 26 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 27 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 28 | 29 | for i in range(17): 30 | for n in range(len(subset)): 31 | index = subset[n][np.array(limbSeq[i]) - 1] 32 | conf = score[n][np.array(limbSeq[i]) - 1] 33 | if conf[0] < 0.3 or conf[1] < 0.3: 34 | continue 35 | Y = candidate[index.astype(int), 
0] * float(W)
36 |                 X = candidate[index.astype(int), 1] * float(H)
37 |                 mX = np.mean(X)
38 |                 mY = np.mean(Y)
39 |                 length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
40 |                 angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
41 |                 polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
42 |                 cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1]))
43 | 
44 |     canvas = (canvas * 0.6).astype(np.uint8)
45 | 
46 |     for i in range(18):
47 |         for n in range(len(subset)):
48 |             index = int(subset[n][i])
49 |             if index == -1:
50 |                 continue
51 |             x, y = candidate[index][0:2]
52 |             conf = score[n][i]
53 |             x = int(x * W)
54 |             y = int(y * H)
55 |             cv2.circle(canvas, (int(x), int(y)), 4, alpha_blend_color(colors[i], conf), thickness=-1)
56 | 
57 |     return canvas
58 | 
59 | def draw_handpose(canvas, all_hand_peaks, all_hand_scores):
60 |     H, W, C = canvas.shape
61 | 
62 |     edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
63 |              [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
64 | 
65 |     for peaks, scores in zip(all_hand_peaks, all_hand_scores):
66 | 
67 |         for ie, e in enumerate(edges):
68 |             x1, y1 = peaks[e[0]]
69 |             x2, y2 = peaks[e[1]]
70 |             x1 = int(x1 * W)
71 |             y1 = int(y1 * H)
72 |             x2 = int(x2 * W)
73 |             y2 = int(y2 * H)
74 |             score = int(scores[e[0]] * scores[e[1]] * 255)
75 |             if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
76 |                 cv2.line(canvas, (x1, y1), (x2, y2),
77 |                          matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * score, thickness=2)
78 | 
79 |         for i, keypoint in enumerate(peaks):
80 |             x, y = keypoint
81 |             x = int(x * W)
82 |             y = int(y * H)
83 |             score = int(scores[i] * 255)
84 |             if x > eps and y > eps:
85 |                 cv2.circle(canvas, (x, y), 4, (0, 0, score), thickness=-1)
86 |     return canvas
87 | 
88 | def draw_facepose(canvas, all_lmks, all_scores):
89 |     H, W, C = canvas.shape
90 |     for lmks, scores in zip(all_lmks, all_scores):
91 |         for lmk, score in zip(lmks, scores):
92 |             x, y = lmk
93 |             x = int(x * W)
94 |             y = int(y * H)
95 |             conf = int(score * 255)
96 |             if x > eps and y > eps:
97 |                 cv2.circle(canvas, (x, y), 3, (conf, conf, conf), thickness=-1)
98 |     return canvas
99 | 
100 | def draw_pose(pose, H, W, ref_w=2160):
101 |     """visualize DWPose outputs
102 | 
103 |     Args:
104 |         pose (dict): DWposeDetector outputs in dwpose_detector.py
105 |         H (int): height
106 |         W (int): width
107 |         ref_w (int, optional): reference resolution for the drawing canvas; the shorter
108 |             side is scaled to this size before drawing and resized back afterwards. Defaults to 2160.
108 | 109 | Returns: 110 | np.ndarray: image pixel value in RGB mode 111 | """ 112 | bodies = pose['bodies'] 113 | faces = pose['faces'] 114 | hands = pose['hands'] 115 | candidate = bodies['candidate'] 116 | subset = bodies['subset'] 117 | 118 | sz = min(H, W) 119 | sr = (ref_w / sz) if sz != ref_w else 1 120 | 121 | ########################################## create zero canvas ################################################## 122 | canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8) 123 | 124 | ########################################### draw body pose ##################################################### 125 | canvas = draw_bodypose(canvas, candidate, subset, score=bodies['score']) 126 | 127 | ########################################### draw hand pose ##################################################### 128 | canvas = draw_handpose(canvas, hands, pose['hands_score']) 129 | 130 | ########################################### draw face pose ##################################################### 131 | canvas = draw_facepose(canvas, faces, pose['faces_score']) 132 | 133 | return cv2.cvtColor(cv2.resize(canvas, (W, H)), cv2.COLOR_BGR2RGB).transpose(2, 0, 1) 134 | -------------------------------------------------------------------------------- /mimicmotion/dwpose/wholebody.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnxruntime as ort 3 | 4 | from .onnxdet import inference_detector 5 | from .onnxpose import inference_pose 6 | 7 | 8 | class Wholebody: 9 | """detect human pose by dwpose 10 | """ 11 | def __init__(self, model_det, model_pose, device="cpu"): 12 | providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider'] 13 | provider_options = None if device == 'cpu' else [{'device_id': 0}] 14 | 15 | self.session_det = ort.InferenceSession( 16 | path_or_bytes=model_det, providers=providers, provider_options=provider_options 17 | ) 18 | self.session_pose = ort.InferenceSession( 19 | path_or_bytes=model_pose, providers=providers, provider_options=provider_options 20 | ) 21 | 22 | def __call__(self, oriImg): 23 | """call to process dwpose-detect 24 | 25 | Args: 26 | oriImg (np.ndarray): detected image 27 | 28 | """ 29 | det_result = inference_detector(self.session_det, oriImg) 30 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) 31 | 32 | keypoints_info = np.concatenate( 33 | (keypoints, scores[..., None]), axis=-1) 34 | # compute neck joint 35 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 36 | # neck score when visualizing pred 37 | neck[:, 2:4] = np.logical_and( 38 | keypoints_info[:, 5, 2:4] > 0.3, 39 | keypoints_info[:, 6, 2:4] > 0.3).astype(int) 40 | new_keypoints_info = np.insert( 41 | keypoints_info, 17, neck, axis=1) 42 | mmpose_idx = [ 43 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 44 | ] 45 | openpose_idx = [ 46 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 47 | ] 48 | new_keypoints_info[:, openpose_idx] = \ 49 | new_keypoints_info[:, mmpose_idx] 50 | keypoints_info = new_keypoints_info 51 | 52 | keypoints, scores = keypoints_info[ 53 | ..., :2], keypoints_info[..., 2] 54 | 55 | return keypoints, scores 56 | 57 | 58 | -------------------------------------------------------------------------------- /mimicmotion/modules/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/mimicmotion/modules/__init__.py -------------------------------------------------------------------------------- /mimicmotion/modules/attention.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock 7 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 8 | from diffusers.models.modeling_utils import ModelMixin 9 | from diffusers.models.resnet import AlphaBlender 10 | from diffusers.utils import BaseOutput 11 | from torch import nn 12 | 13 | 14 | @dataclass 15 | class TransformerTemporalModelOutput(BaseOutput): 16 | """ 17 | The output of [`TransformerTemporalModel`]. 18 | 19 | Args: 20 | sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): 21 | The hidden states output conditioned on `encoder_hidden_states` input. 22 | """ 23 | 24 | sample: torch.FloatTensor 25 | 26 | 27 | class TransformerTemporalModel(ModelMixin, ConfigMixin): 28 | """ 29 | A Transformer model for video-like data. 30 | 31 | Parameters: 32 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 33 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 34 | in_channels (`int`, *optional*): 35 | The number of channels in the input and output (specify if the input is **continuous**). 36 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 37 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 38 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 39 | attention_bias (`bool`, *optional*): 40 | Configure if the `TransformerBlock` attention should contain a bias parameter. 41 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 42 | This is fixed during training since it is used to learn a number of position embeddings. 43 | activation_fn (`str`, *optional*, defaults to `"geglu"`): 44 | Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported 45 | activation functions. 46 | norm_elementwise_affine (`bool`, *optional*): 47 | Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. 48 | double_self_attention (`bool`, *optional*): 49 | Configure if each `TransformerBlock` should contain two self-attention layers. 50 | positional_embeddings: (`str`, *optional*): 51 | The type of positional embeddings to apply to the sequence input before passing use. 52 | num_positional_embeddings: (`int`, *optional*): 53 | The maximum length of the sequence over which to apply positional embeddings. 
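    Example (a minimal usage sketch; the shapes below are illustrative, not required):

        >>> import torch
        >>> model = TransformerTemporalModel(num_attention_heads=8, attention_head_dim=40, in_channels=320)
        >>> hidden = torch.randn(2 * 16, 320, 32, 32)   # (batch * frames, channels, height, width)
        >>> out = model(hidden, num_frames=16).sample   # same shape as `hidden`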
54 | """ 55 | 56 | @register_to_config 57 | def __init__( 58 | self, 59 | num_attention_heads: int = 16, 60 | attention_head_dim: int = 88, 61 | in_channels: Optional[int] = None, 62 | out_channels: Optional[int] = None, 63 | num_layers: int = 1, 64 | dropout: float = 0.0, 65 | norm_num_groups: int = 32, 66 | cross_attention_dim: Optional[int] = None, 67 | attention_bias: bool = False, 68 | sample_size: Optional[int] = None, 69 | activation_fn: str = "geglu", 70 | norm_elementwise_affine: bool = True, 71 | double_self_attention: bool = True, 72 | positional_embeddings: Optional[str] = None, 73 | num_positional_embeddings: Optional[int] = None, 74 | ): 75 | super().__init__() 76 | self.num_attention_heads = num_attention_heads 77 | self.attention_head_dim = attention_head_dim 78 | inner_dim = num_attention_heads * attention_head_dim 79 | 80 | self.in_channels = in_channels 81 | 82 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 83 | self.proj_in = nn.Linear(in_channels, inner_dim) 84 | 85 | # 3. Define transformers blocks 86 | self.transformer_blocks = nn.ModuleList( 87 | [ 88 | BasicTransformerBlock( 89 | inner_dim, 90 | num_attention_heads, 91 | attention_head_dim, 92 | dropout=dropout, 93 | cross_attention_dim=cross_attention_dim, 94 | activation_fn=activation_fn, 95 | attention_bias=attention_bias, 96 | double_self_attention=double_self_attention, 97 | norm_elementwise_affine=norm_elementwise_affine, 98 | positional_embeddings=positional_embeddings, 99 | num_positional_embeddings=num_positional_embeddings, 100 | ) 101 | for d in range(num_layers) 102 | ] 103 | ) 104 | 105 | self.proj_out = nn.Linear(inner_dim, in_channels) 106 | 107 | def forward( 108 | self, 109 | hidden_states: torch.FloatTensor, 110 | encoder_hidden_states: Optional[torch.LongTensor] = None, 111 | timestep: Optional[torch.LongTensor] = None, 112 | class_labels: torch.LongTensor = None, 113 | num_frames: int = 1, 114 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 115 | return_dict: bool = True, 116 | ) -> TransformerTemporalModelOutput: 117 | """ 118 | The [`TransformerTemporal`] forward method. 119 | 120 | Args: 121 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, 122 | `torch.FloatTensor` of shape `(batch size, channel, height, width)`if continuous): Input hidden_states. 123 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 124 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 125 | self-attention. 126 | timestep ( `torch.LongTensor`, *optional*): 127 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 128 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 129 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 130 | `AdaLayerZeroNorm`. 131 | num_frames (`int`, *optional*, defaults to 1): 132 | The number of frames to be processed per batch. This is used to reshape the hidden states. 133 | cross_attention_kwargs (`dict`, *optional*): 134 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 135 | `self.processor` in [diffusers.models.attention_processor]( 136 | https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
137 | return_dict (`bool`, *optional*, defaults to `True`): 138 | Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 139 | tuple. 140 | 141 | Returns: 142 | [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: 143 | If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is 144 | returned, otherwise a `tuple` where the first element is the sample tensor. 145 | """ 146 | # 1. Input 147 | batch_frames, channel, height, width = hidden_states.shape 148 | batch_size = batch_frames // num_frames 149 | 150 | residual = hidden_states 151 | 152 | hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) 153 | hidden_states = hidden_states.permute(0, 2, 1, 3, 4) 154 | 155 | hidden_states = self.norm(hidden_states) 156 | hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) 157 | 158 | hidden_states = self.proj_in(hidden_states) 159 | 160 | # 2. Blocks 161 | for block in self.transformer_blocks: 162 | hidden_states = block( 163 | hidden_states, 164 | encoder_hidden_states=encoder_hidden_states, 165 | timestep=timestep, 166 | cross_attention_kwargs=cross_attention_kwargs, 167 | class_labels=class_labels, 168 | ) 169 | 170 | # 3. Output 171 | hidden_states = self.proj_out(hidden_states) 172 | hidden_states = ( 173 | hidden_states[None, None, :] 174 | .reshape(batch_size, height, width, num_frames, channel) 175 | .permute(0, 3, 4, 1, 2) 176 | .contiguous() 177 | ) 178 | hidden_states = hidden_states.reshape(batch_frames, channel, height, width) 179 | 180 | output = hidden_states + residual 181 | 182 | if not return_dict: 183 | return (output,) 184 | 185 | return TransformerTemporalModelOutput(sample=output) 186 | 187 | 188 | class TransformerSpatioTemporalModel(nn.Module): 189 | """ 190 | A Transformer model for video-like data. 191 | 192 | Parameters: 193 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 194 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 195 | in_channels (`int`, *optional*): 196 | The number of channels in the input and output (specify if the input is **continuous**). 197 | out_channels (`int`, *optional*): 198 | The number of channels in the output (specify if the input is **continuous**). 199 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 200 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 201 | """ 202 | 203 | def __init__( 204 | self, 205 | num_attention_heads: int = 16, 206 | attention_head_dim: int = 88, 207 | in_channels: int = 320, 208 | out_channels: Optional[int] = None, 209 | num_layers: int = 1, 210 | cross_attention_dim: Optional[int] = None, 211 | ): 212 | super().__init__() 213 | self.num_attention_heads = num_attention_heads 214 | self.attention_head_dim = attention_head_dim 215 | 216 | inner_dim = num_attention_heads * attention_head_dim 217 | self.inner_dim = inner_dim 218 | 219 | # 2. Define input layers 220 | self.in_channels = in_channels 221 | self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) 222 | self.proj_in = nn.Linear(in_channels, inner_dim) 223 | 224 | # 3. 
Define transformers blocks 225 | self.transformer_blocks = nn.ModuleList( 226 | [ 227 | BasicTransformerBlock( 228 | inner_dim, 229 | num_attention_heads, 230 | attention_head_dim, 231 | cross_attention_dim=cross_attention_dim, 232 | ) 233 | for d in range(num_layers) 234 | ] 235 | ) 236 | 237 | time_mix_inner_dim = inner_dim 238 | self.temporal_transformer_blocks = nn.ModuleList( 239 | [ 240 | TemporalBasicTransformerBlock( 241 | inner_dim, 242 | time_mix_inner_dim, 243 | num_attention_heads, 244 | attention_head_dim, 245 | cross_attention_dim=cross_attention_dim, 246 | ) 247 | for _ in range(num_layers) 248 | ] 249 | ) 250 | 251 | time_embed_dim = in_channels * 4 252 | self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) 253 | self.time_proj = Timesteps(in_channels, True, 0) 254 | self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") 255 | 256 | # 4. Define output layers 257 | self.out_channels = in_channels if out_channels is None else out_channels 258 | # TODO: should use out_channels for continuous projections 259 | self.proj_out = nn.Linear(inner_dim, in_channels) 260 | 261 | self.gradient_checkpointing = False 262 | 263 | def forward( 264 | self, 265 | hidden_states: torch.Tensor, 266 | encoder_hidden_states: Optional[torch.Tensor] = None, 267 | image_only_indicator: Optional[torch.Tensor] = None, 268 | return_dict: bool = True, 269 | ): 270 | """ 271 | Args: 272 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 273 | Input hidden_states. 274 | num_frames (`int`): 275 | The number of frames to be processed per batch. This is used to reshape the hidden states. 276 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 277 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 278 | self-attention. 279 | image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*): 280 | A tensor indicating whether the input contains only images. 1 indicates that the input contains only 281 | images, 0 indicates that the input contains video frames. 282 | return_dict (`bool`, *optional*, defaults to `True`): 283 | Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] 284 | instead of a plain tuple. 285 | 286 | Returns: 287 | [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: 288 | If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is 289 | returned, otherwise a `tuple` where the first element is the sample tensor. 290 | """ 291 | # 1. 
Input 292 | batch_frames, _, height, width = hidden_states.shape 293 | num_frames = image_only_indicator.shape[-1] 294 | batch_size = batch_frames // num_frames 295 | 296 | time_context = encoder_hidden_states 297 | time_context_first_timestep = time_context[None, :].reshape( 298 | batch_size, num_frames, -1, time_context.shape[-1] 299 | )[:, 0] 300 | time_context = time_context_first_timestep[None, :].broadcast_to( 301 | height * width, batch_size, 1, time_context.shape[-1] 302 | ) 303 | time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1]) 304 | 305 | residual = hidden_states 306 | 307 | hidden_states = self.norm(hidden_states) 308 | inner_dim = hidden_states.shape[1] 309 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) 310 | hidden_states = torch.utils.checkpoint.checkpoint(self.proj_in, hidden_states) 311 | 312 | num_frames_emb = torch.arange(num_frames, device=hidden_states.device) 313 | num_frames_emb = num_frames_emb.repeat(batch_size, 1) 314 | num_frames_emb = num_frames_emb.reshape(-1) 315 | t_emb = self.time_proj(num_frames_emb) 316 | 317 | # `Timesteps` does not contain any weights and will always return f32 tensors 318 | # but time_embedding might actually be running in fp16. so we need to cast here. 319 | # there might be better ways to encapsulate this. 320 | t_emb = t_emb.to(dtype=hidden_states.dtype) 321 | 322 | emb = self.time_pos_embed(t_emb) 323 | emb = emb[:, None, :] 324 | 325 | # 2. Blocks 326 | for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): 327 | if self.gradient_checkpointing: 328 | hidden_states = torch.utils.checkpoint.checkpoint( 329 | block, 330 | hidden_states, 331 | None, 332 | encoder_hidden_states, 333 | None, 334 | use_reentrant=False, 335 | ) 336 | else: 337 | hidden_states = block( 338 | hidden_states, 339 | encoder_hidden_states=encoder_hidden_states, 340 | ) 341 | 342 | hidden_states_mix = hidden_states 343 | hidden_states_mix = hidden_states_mix + emb 344 | 345 | if self.gradient_checkpointing: 346 | hidden_states_mix = torch.utils.checkpoint.checkpoint( 347 | temporal_block, 348 | hidden_states_mix, 349 | num_frames, 350 | time_context, 351 | ) 352 | hidden_states = self.time_mixer( 353 | x_spatial=hidden_states, 354 | x_temporal=hidden_states_mix, 355 | image_only_indicator=image_only_indicator, 356 | ) 357 | else: 358 | hidden_states_mix = temporal_block( 359 | hidden_states_mix, 360 | num_frames=num_frames, 361 | encoder_hidden_states=time_context, 362 | ) 363 | hidden_states = self.time_mixer( 364 | x_spatial=hidden_states, 365 | x_temporal=hidden_states_mix, 366 | image_only_indicator=image_only_indicator, 367 | ) 368 | 369 | # 3. 
Output 370 | hidden_states = torch.utils.checkpoint.checkpoint(self.proj_out, hidden_states) 371 | hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 372 | 373 | output = hidden_states + residual 374 | 375 | if not return_dict: 376 | return (output,) 377 | 378 | return TransformerTemporalModelOutput(sample=output) 379 | -------------------------------------------------------------------------------- /mimicmotion/modules/pose_net.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import einops 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | 9 | 10 | class PoseNet(nn.Module): 11 | """a tiny conv network for introducing pose sequence as the condition 12 | """ 13 | def __init__(self, noise_latent_channels=320, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | # multiple convolution layers 16 | self.conv_layers = nn.Sequential( 17 | nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1), 18 | nn.SiLU(), 19 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1), 20 | nn.SiLU(), 21 | 22 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1), 23 | nn.SiLU(), 24 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1), 25 | nn.SiLU(), 26 | 27 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1), 28 | nn.SiLU(), 29 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1), 30 | nn.SiLU(), 31 | 32 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1), 33 | nn.SiLU(), 34 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 35 | nn.SiLU() 36 | ) 37 | 38 | # Final projection layer 39 | self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1) 40 | 41 | # Initialize layers 42 | self._initialize_weights() 43 | 44 | self.scale = nn.Parameter(torch.ones(1) * 2) 45 | 46 | def _initialize_weights(self): 47 | """Initialize weights with He. initialization and zero out the biases 48 | """ 49 | for m in self.conv_layers: 50 | if isinstance(m, nn.Conv2d): 51 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 52 | init.normal_(m.weight, mean=0.0, std=np.sqrt(2. 
/ n)) 53 | if m.bias is not None: 54 | init.zeros_(m.bias) 55 | init.zeros_(self.final_proj.weight) 56 | if self.final_proj.bias is not None: 57 | init.zeros_(self.final_proj.bias) 58 | 59 | def forward(self, x): 60 | if x.ndim == 5: 61 | x = einops.rearrange(x, "b f c h w -> (b f) c h w") 62 | x = self.conv_layers(x) 63 | x = self.final_proj(x) 64 | 65 | return x * self.scale 66 | 67 | @classmethod 68 | def from_pretrained(cls, pretrained_model_path): 69 | """load pretrained pose-net weights 70 | """ 71 | if not Path(pretrained_model_path).exists(): 72 | print(f"There is no model file in {pretrained_model_path}") 73 | print(f"loaded PoseNet's pretrained weights from {pretrained_model_path}.") 74 | 75 | state_dict = torch.load(pretrained_model_path, map_location="cpu") 76 | model = PoseNet(noise_latent_channels=320) 77 | 78 | model.load_state_dict(state_dict, strict=True) 79 | 80 | return model 81 | -------------------------------------------------------------------------------- /mimicmotion/modules/unet.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.loaders import UNet2DConditionLoadersMixin 8 | from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor 9 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 10 | from diffusers.models.modeling_utils import ModelMixin 11 | from diffusers.utils import BaseOutput, logging 12 | 13 | from diffusers.models.unets.unet_3d_blocks import get_down_block, get_up_block, UNetMidBlockSpatioTemporal 14 | 15 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 16 | 17 | 18 | @dataclass 19 | class UNetSpatioTemporalConditionOutput(BaseOutput): 20 | """ 21 | The output of [`UNetSpatioTemporalConditionModel`]. 22 | 23 | Args: 24 | sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): 25 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 26 | """ 27 | 28 | sample: torch.FloatTensor = None 29 | 30 | 31 | class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): 32 | r""" 33 | A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, 34 | and a timestep and returns a sample shaped output. 35 | 36 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 37 | for all models (such as downloading or saving). 38 | 39 | Parameters: 40 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 41 | Height and width of input/output sample. 42 | in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. 43 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 44 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", 45 | "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): 46 | The tuple of downsample blocks to use. 
47 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", 48 | "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): 49 | The tuple of upsample blocks to use. 50 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 51 | The tuple of output channels for each block. 52 | addition_time_embed_dim: (`int`, defaults to 256): 53 | Dimension to to encode the additional time ids. 54 | projection_class_embeddings_input_dim (`int`, defaults to 768): 55 | The dimension of the projection of encoded `added_time_ids`. 56 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 57 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 58 | The dimension of the cross attention features. 59 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): 60 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 61 | [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], 62 | [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], 63 | [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. 64 | num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): 65 | The number of attention heads. 66 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 67 | """ 68 | 69 | _supports_gradient_checkpointing = True 70 | 71 | @register_to_config 72 | def __init__( 73 | self, 74 | sample_size: Optional[int] = None, 75 | in_channels: int = 8, 76 | out_channels: int = 4, 77 | down_block_types: Tuple[str] = ( 78 | "CrossAttnDownBlockSpatioTemporal", 79 | "CrossAttnDownBlockSpatioTemporal", 80 | "CrossAttnDownBlockSpatioTemporal", 81 | "DownBlockSpatioTemporal", 82 | ), 83 | up_block_types: Tuple[str] = ( 84 | "UpBlockSpatioTemporal", 85 | "CrossAttnUpBlockSpatioTemporal", 86 | "CrossAttnUpBlockSpatioTemporal", 87 | "CrossAttnUpBlockSpatioTemporal", 88 | ), 89 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 90 | addition_time_embed_dim: int = 256, 91 | projection_class_embeddings_input_dim: int = 768, 92 | layers_per_block: Union[int, Tuple[int]] = 2, 93 | cross_attention_dim: Union[int, Tuple[int]] = 1024, 94 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, 95 | num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), 96 | num_frames: int = 25, 97 | ): 98 | super().__init__() 99 | 100 | self.sample_size = sample_size 101 | 102 | # Check inputs 103 | if len(down_block_types) != len(up_block_types): 104 | raise ValueError( 105 | f"Must provide the same number of `down_block_types` as `up_block_types`. " \ 106 | f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 107 | ) 108 | 109 | if len(block_out_channels) != len(down_block_types): 110 | raise ValueError( 111 | f"Must provide the same number of `block_out_channels` as `down_block_types`. " \ 112 | f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 113 | ) 114 | 115 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 116 | raise ValueError( 117 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. " \ 118 | f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
119 | ) 120 | 121 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): 122 | raise ValueError( 123 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. " \ 124 | f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 125 | ) 126 | 127 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): 128 | raise ValueError( 129 | f"Must provide the same number of `layers_per_block` as `down_block_types`. " \ 130 | f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 131 | ) 132 | 133 | # input 134 | self.conv_in = nn.Conv2d( 135 | in_channels, 136 | block_out_channels[0], 137 | kernel_size=3, 138 | padding=1, 139 | ) 140 | 141 | # time 142 | time_embed_dim = block_out_channels[0] * 4 143 | 144 | self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) 145 | timestep_input_dim = block_out_channels[0] 146 | 147 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 148 | 149 | self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) 150 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 151 | 152 | self.down_blocks = nn.ModuleList([]) 153 | self.up_blocks = nn.ModuleList([]) 154 | 155 | if isinstance(num_attention_heads, int): 156 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 157 | 158 | if isinstance(cross_attention_dim, int): 159 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 160 | 161 | if isinstance(layers_per_block, int): 162 | layers_per_block = [layers_per_block] * len(down_block_types) 163 | 164 | if isinstance(transformer_layers_per_block, int): 165 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 166 | 167 | blocks_time_embed_dim = time_embed_dim 168 | 169 | # down 170 | output_channel = block_out_channels[0] 171 | for i, down_block_type in enumerate(down_block_types): 172 | input_channel = output_channel 173 | output_channel = block_out_channels[i] 174 | is_final_block = i == len(block_out_channels) - 1 175 | 176 | down_block = get_down_block( 177 | down_block_type, 178 | num_layers=layers_per_block[i], 179 | transformer_layers_per_block=transformer_layers_per_block[i], 180 | in_channels=input_channel, 181 | out_channels=output_channel, 182 | temb_channels=blocks_time_embed_dim, 183 | add_downsample=not is_final_block, 184 | resnet_eps=1e-5, 185 | cross_attention_dim=cross_attention_dim[i], 186 | num_attention_heads=num_attention_heads[i], 187 | resnet_act_fn="silu", 188 | ) 189 | self.down_blocks.append(down_block) 190 | 191 | # mid 192 | self.mid_block = UNetMidBlockSpatioTemporal( 193 | block_out_channels[-1], 194 | temb_channels=blocks_time_embed_dim, 195 | transformer_layers_per_block=transformer_layers_per_block[-1], 196 | cross_attention_dim=cross_attention_dim[-1], 197 | num_attention_heads=num_attention_heads[-1], 198 | ) 199 | 200 | # count how many layers upsample the images 201 | self.num_upsamplers = 0 202 | 203 | # up 204 | reversed_block_out_channels = list(reversed(block_out_channels)) 205 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 206 | reversed_layers_per_block = list(reversed(layers_per_block)) 207 | reversed_cross_attention_dim = list(reversed(cross_attention_dim)) 208 | reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) 209 | 210 | 
output_channel = reversed_block_out_channels[0] 211 | for i, up_block_type in enumerate(up_block_types): 212 | is_final_block = i == len(block_out_channels) - 1 213 | 214 | prev_output_channel = output_channel 215 | output_channel = reversed_block_out_channels[i] 216 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 217 | 218 | # add upsample block for all BUT final layer 219 | if not is_final_block: 220 | add_upsample = True 221 | self.num_upsamplers += 1 222 | else: 223 | add_upsample = False 224 | 225 | up_block = get_up_block( 226 | up_block_type, 227 | num_layers=reversed_layers_per_block[i] + 1, 228 | transformer_layers_per_block=reversed_transformer_layers_per_block[i], 229 | in_channels=input_channel, 230 | out_channels=output_channel, 231 | prev_output_channel=prev_output_channel, 232 | temb_channels=blocks_time_embed_dim, 233 | add_upsample=add_upsample, 234 | resnet_eps=1e-5, 235 | resolution_idx=i, 236 | cross_attention_dim=reversed_cross_attention_dim[i], 237 | num_attention_heads=reversed_num_attention_heads[i], 238 | resnet_act_fn="silu", 239 | ) 240 | self.up_blocks.append(up_block) 241 | prev_output_channel = output_channel 242 | 243 | # out 244 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) 245 | self.conv_act = nn.SiLU() 246 | 247 | self.conv_out = nn.Conv2d( 248 | block_out_channels[0], 249 | out_channels, 250 | kernel_size=3, 251 | padding=1, 252 | ) 253 | 254 | @property 255 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 256 | r""" 257 | Returns: 258 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 259 | indexed by its weight name. 260 | """ 261 | # set recursively 262 | processors = {} 263 | 264 | def fn_recursive_add_processors( 265 | name: str, 266 | module: torch.nn.Module, 267 | processors: Dict[str, AttentionProcessor], 268 | ): 269 | if hasattr(module, "get_processor"): 270 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 271 | 272 | for sub_name, child in module.named_children(): 273 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 274 | 275 | return processors 276 | 277 | for name, module in self.named_children(): 278 | fn_recursive_add_processors(name, module, processors) 279 | 280 | return processors 281 | 282 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 283 | r""" 284 | Sets the attention processor to use to compute attention. 285 | 286 | Parameters: 287 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 288 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 289 | for **all** `Attention` layers. 290 | 291 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 292 | processor. This is strongly recommended when setting trainable attention processors. 293 | 294 | """ 295 | count = len(self.attn_processors.keys()) 296 | 297 | if isinstance(processor, dict) and len(processor) != count: 298 | raise ValueError( 299 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 300 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
301 | ) 302 | 303 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 304 | if hasattr(module, "set_processor"): 305 | if not isinstance(processor, dict): 306 | module.set_processor(processor) 307 | else: 308 | module.set_processor(processor.pop(f"{name}.processor")) 309 | 310 | for sub_name, child in module.named_children(): 311 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 312 | 313 | for name, module in self.named_children(): 314 | fn_recursive_attn_processor(name, module, processor) 315 | 316 | def set_default_attn_processor(self): 317 | """ 318 | Disables custom attention processors and sets the default attention implementation. 319 | """ 320 | if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 321 | processor = AttnProcessor() 322 | else: 323 | raise ValueError( 324 | f"Cannot call `set_default_attn_processor` " \ 325 | f"when attention processors are of type {next(iter(self.attn_processors.values()))}" 326 | ) 327 | 328 | self.set_attn_processor(processor) 329 | 330 | def _set_gradient_checkpointing(self, module, value=False): 331 | if hasattr(module, "gradient_checkpointing"): 332 | module.gradient_checkpointing = value 333 | 334 | # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 335 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 336 | """ 337 | Sets the attention processor to use [feed forward 338 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 339 | 340 | Parameters: 341 | chunk_size (`int`, *optional*): 342 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 343 | over each tensor of dim=`dim`. 344 | dim (`int`, *optional*, defaults to `0`): 345 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 346 | or dim=1 (sequence length). 347 | """ 348 | if dim not in [0, 1]: 349 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 350 | 351 | # By default chunk size is 1 352 | chunk_size = chunk_size or 1 353 | 354 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 355 | if hasattr(module, "set_chunk_feed_forward"): 356 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 357 | 358 | for child in module.children(): 359 | fn_recursive_feed_forward(child, chunk_size, dim) 360 | 361 | for module in self.children(): 362 | fn_recursive_feed_forward(module, chunk_size, dim) 363 | 364 | def forward( 365 | self, 366 | sample: torch.FloatTensor, 367 | timestep: Union[torch.Tensor, float, int], 368 | encoder_hidden_states: torch.Tensor, 369 | added_time_ids: torch.Tensor, 370 | pose_latents: torch.Tensor = None, 371 | image_only_indicator: bool = False, 372 | return_dict: bool = True, 373 | ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: 374 | r""" 375 | The [`UNetSpatioTemporalConditionModel`] forward method. 376 | 377 | Args: 378 | sample (`torch.FloatTensor`): 379 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. 380 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 381 | encoder_hidden_states (`torch.FloatTensor`): 382 | The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. 
383 | added_time_ids: (`torch.FloatTensor`): 384 | The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal 385 | embeddings and added to the time embeddings. 386 | pose_latents: (`torch.FloatTensor`): 387 | The additional latents for pose sequences. 388 | image_only_indicator (`bool`, *optional*, defaults to `False`): 389 | Whether or not training with all images. 390 | return_dict (`bool`, *optional*, defaults to `True`): 391 | Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] 392 | instead of a plain tuple. 393 | Returns: 394 | [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: 395 | If `return_dict` is True, 396 | an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, 397 | otherwise a `tuple` is returned where the first element is the sample tensor. 398 | """ 399 | # 1. time 400 | timesteps = timestep 401 | if not torch.is_tensor(timesteps): 402 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 403 | # This would be a good case for the `match` statement (Python 3.10+) 404 | is_mps = sample.device.type == "mps" 405 | if isinstance(timestep, float): 406 | dtype = torch.float32 if is_mps else torch.float64 407 | else: 408 | dtype = torch.int32 if is_mps else torch.int64 409 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 410 | elif len(timesteps.shape) == 0: 411 | timesteps = timesteps[None].to(sample.device) 412 | 413 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 414 | batch_size, num_frames = sample.shape[:2] 415 | timesteps = timesteps.expand(batch_size) 416 | 417 | t_emb = self.time_proj(timesteps) 418 | 419 | # `Timesteps` does not contain any weights and will always return f32 tensors 420 | # but time_embedding might actually be running in fp16. so we need to cast here. 421 | # there might be better ways to encapsulate this. 422 | t_emb = t_emb.to(dtype=sample.dtype) 423 | 424 | emb = self.time_embedding(t_emb) 425 | 426 | time_embeds = self.add_time_proj(added_time_ids.flatten()) 427 | time_embeds = time_embeds.reshape((batch_size, -1)) 428 | time_embeds = time_embeds.to(emb.dtype) 429 | aug_emb = self.add_embedding(time_embeds) 430 | emb = emb + aug_emb 431 | 432 | # Flatten the batch and frames dimensions 433 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] 434 | sample = sample.flatten(0, 1) 435 | # Repeat the embeddings num_video_frames times 436 | # emb: [batch, channels] -> [batch * frames, channels] 437 | emb = emb.repeat_interleave(num_frames, dim=0) 438 | # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] 439 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) 440 | 441 | # 2. 
pre-process 442 | sample = self.conv_in(sample) 443 | if pose_latents is not None: 444 | sample = sample + pose_latents 445 | 446 | image_only_indicator = torch.ones(batch_size, num_frames, dtype=sample.dtype, device=sample.device) \ 447 | if image_only_indicator else torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) 448 | 449 | down_block_res_samples = (sample,) 450 | for downsample_block in self.down_blocks: 451 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 452 | sample, res_samples = downsample_block( 453 | hidden_states=sample, 454 | temb=emb, 455 | encoder_hidden_states=encoder_hidden_states, 456 | image_only_indicator=image_only_indicator, 457 | ) 458 | else: 459 | sample, res_samples = downsample_block( 460 | hidden_states=sample, 461 | temb=emb, 462 | image_only_indicator=image_only_indicator, 463 | ) 464 | 465 | down_block_res_samples += res_samples 466 | 467 | # 4. mid 468 | sample = self.mid_block( 469 | hidden_states=sample, 470 | temb=emb, 471 | encoder_hidden_states=encoder_hidden_states, 472 | image_only_indicator=image_only_indicator, 473 | ) 474 | 475 | # 5. up 476 | for i, upsample_block in enumerate(self.up_blocks): 477 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 478 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 479 | 480 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 481 | sample = upsample_block( 482 | hidden_states=sample, 483 | temb=emb, 484 | res_hidden_states_tuple=res_samples, 485 | encoder_hidden_states=encoder_hidden_states, 486 | image_only_indicator=image_only_indicator, 487 | ) 488 | else: 489 | sample = upsample_block( 490 | hidden_states=sample, 491 | temb=emb, 492 | res_hidden_states_tuple=res_samples, 493 | image_only_indicator=image_only_indicator, 494 | ) 495 | 496 | # 6. post-process 497 | sample = self.conv_norm_out(sample) 498 | sample = self.conv_act(sample) 499 | sample = self.conv_out(sample) 500 | 501 | # 7. 
Reshape back to original shape 502 | sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) 503 | 504 | if not return_dict: 505 | return (sample,) 506 | 507 | return UNetSpatioTemporalConditionOutput(sample=sample) 508 | -------------------------------------------------------------------------------- /mimicmotion/pipelines/pipeline_mimicmotion.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Callable, Dict, List, Optional, Union 4 | 5 | import PIL.Image 6 | import einops 7 | import numpy as np 8 | import torch 9 | from diffusers.image_processor import VaeImageProcessor, PipelineImageInput 10 | from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel 11 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 12 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps 13 | from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion \ 14 | import _resize_with_antialiasing, _append_dims 15 | from diffusers.schedulers import EulerDiscreteScheduler 16 | from diffusers.utils import BaseOutput, logging 17 | from diffusers.utils.torch_utils import is_compiled_module, randn_tensor 18 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 19 | 20 | from ..modules.pose_net import PoseNet 21 | 22 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 23 | 24 | 25 | def _append_dims(x, target_dims): 26 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 27 | dims_to_append = target_dims - x.ndim 28 | if dims_to_append < 0: 29 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") 30 | return x[(...,) + (None,) * dims_to_append] 31 | 32 | 33 | # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid 34 | def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): 35 | batch_size, channels, num_frames, height, width = video.shape 36 | outputs = [] 37 | for batch_idx in range(batch_size): 38 | batch_vid = video[batch_idx].permute(1, 0, 2, 3) 39 | batch_output = processor.postprocess(batch_vid, output_type) 40 | 41 | outputs.append(batch_output) 42 | 43 | if output_type == "np": 44 | outputs = np.stack(outputs) 45 | 46 | elif output_type == "pt": 47 | outputs = torch.stack(outputs) 48 | 49 | elif not output_type == "pil": 50 | raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") 51 | 52 | return outputs 53 | 54 | 55 | @dataclass 56 | class MimicMotionPipelineOutput(BaseOutput): 57 | r""" 58 | Output class for mimicmotion pipeline. 59 | 60 | Args: 61 | frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]): 62 | List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, 63 | num_frames, height, width, num_channels)`. 64 | """ 65 | 66 | frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor] 67 | 68 | 69 | class MimicMotionPipeline(DiffusionPipeline): 70 | r""" 71 | Pipeline to generate video from an input image using Stable Video Diffusion. 72 | 73 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 74 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
75 | 76 | Args: 77 | vae ([`AutoencoderKLTemporalDecoder`]): 78 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 79 | image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): 80 | Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K] 81 | (https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). 82 | unet ([`UNetSpatioTemporalConditionModel`]): 83 | A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. 84 | scheduler ([`EulerDiscreteScheduler`]): 85 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. 86 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 87 | A `CLIPImageProcessor` to extract features from generated images. 88 | pose_net ([`PoseNet`]): 89 | A `` to inject pose signals into unet. 90 | """ 91 | 92 | model_cpu_offload_seq = "image_encoder->unet->vae" 93 | _callback_tensor_inputs = ["latents"] 94 | 95 | def __init__( 96 | self, 97 | vae: AutoencoderKLTemporalDecoder, 98 | image_encoder: CLIPVisionModelWithProjection, 99 | unet: UNetSpatioTemporalConditionModel, 100 | scheduler: EulerDiscreteScheduler, 101 | feature_extractor: CLIPImageProcessor, 102 | pose_net: PoseNet, 103 | ): 104 | super().__init__() 105 | 106 | self.register_modules( 107 | vae=vae, 108 | image_encoder=image_encoder, 109 | unet=unet, 110 | scheduler=scheduler, 111 | feature_extractor=feature_extractor, 112 | pose_net=pose_net, 113 | ) 114 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 115 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 116 | 117 | def _encode_image( 118 | self, 119 | image: PipelineImageInput, 120 | device: Union[str, torch.device], 121 | num_videos_per_prompt: int, 122 | do_classifier_free_guidance: bool): 123 | dtype = next(self.image_encoder.parameters()).dtype 124 | 125 | if not isinstance(image, torch.Tensor): 126 | image = self.image_processor.pil_to_numpy(image) 127 | image = self.image_processor.numpy_to_pt(image) 128 | 129 | # We normalize the image before resizing to match with the original implementation. 130 | # Then we unnormalize it after resizing. 131 | image = image * 2.0 - 1.0 132 | image = _resize_with_antialiasing(image, (224, 224)) 133 | image = (image + 1.0) / 2.0 134 | 135 | # Normalize the image with for CLIP input 136 | image = self.feature_extractor( 137 | images=image, 138 | do_normalize=True, 139 | do_center_crop=False, 140 | do_resize=False, 141 | do_rescale=False, 142 | return_tensors="pt", 143 | ).pixel_values 144 | 145 | image = image.to(device=device, dtype=dtype) 146 | image_embeddings = self.image_encoder(image).image_embeds 147 | image_embeddings = image_embeddings.unsqueeze(1) 148 | 149 | # duplicate image embeddings for each generation per prompt, using mps friendly method 150 | bs_embed, seq_len, _ = image_embeddings.shape 151 | image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) 152 | image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 153 | 154 | if do_classifier_free_guidance: 155 | negative_image_embeddings = torch.zeros_like(image_embeddings) 156 | 157 | # For classifier free guidance, we need to do two forward passes. 
158 | # Here we concatenate the unconditional and conditional image embeddings into a single batch 159 | # to avoid doing two forward passes 160 | image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) 161 | 162 | return image_embeddings 163 | 164 | def _encode_vae_image( 165 | self, 166 | image: torch.Tensor, 167 | device: Union[str, torch.device], 168 | num_videos_per_prompt: int, 169 | do_classifier_free_guidance: bool, 170 | ): 171 | image = image.to(device=device, dtype=self.vae.dtype) 172 | image_latents = self.vae.encode(image).latent_dist.mode() 173 | 174 | if do_classifier_free_guidance: 175 | negative_image_latents = torch.zeros_like(image_latents) 176 | 177 | # For classifier free guidance, we need to do two forward passes. 178 | # Here we concatenate the unconditional and conditional image latents into a single batch 179 | # to avoid doing two forward passes 180 | image_latents = torch.cat([negative_image_latents, image_latents]) 181 | 182 | # duplicate image_latents for each generation per prompt, using mps friendly method 183 | image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) 184 | 185 | return image_latents 186 | 187 | def _get_add_time_ids( 188 | self, 189 | fps: int, 190 | motion_bucket_id: int, 191 | noise_aug_strength: float, 192 | dtype: torch.dtype, 193 | batch_size: int, 194 | num_videos_per_prompt: int, 195 | do_classifier_free_guidance: bool, 196 | ): 197 | add_time_ids = [fps, motion_bucket_id, noise_aug_strength] 198 | 199 | passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) 200 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features 201 | 202 | if expected_add_embed_dim != passed_add_embed_dim: 203 | raise ValueError( 204 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \ 205 | f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. " \ 206 | f"Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
207 | ) 208 | 209 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 210 | add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) 211 | 212 | if do_classifier_free_guidance: 213 | add_time_ids = torch.cat([add_time_ids, add_time_ids]) 214 | 215 | return add_time_ids 216 | 217 | def decode_latents( 218 | self, 219 | latents: torch.Tensor, 220 | num_frames: int, 221 | decode_chunk_size: int = 8): 222 | # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] 223 | latents = latents.flatten(0, 1) 224 | 225 | latents = 1 / self.vae.config.scaling_factor * latents 226 | 227 | forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward 228 | accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) 229 | 230 | # decode decode_chunk_size frames at a time to avoid OOM 231 | frames = [] 232 | for i in range(0, latents.shape[0], decode_chunk_size): 233 | num_frames_in = latents[i: i + decode_chunk_size].shape[0] 234 | decode_kwargs = {} 235 | if accepts_num_frames: 236 | # we only pass num_frames_in if it's expected 237 | decode_kwargs["num_frames"] = num_frames_in 238 | 239 | frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample 240 | frames.append(frame.cpu()) 241 | frames = torch.cat(frames, dim=0) 242 | 243 | # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] 244 | frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) 245 | 246 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 247 | frames = frames.float() 248 | return frames 249 | 250 | def check_inputs(self, image, height, width): 251 | if ( 252 | not isinstance(image, torch.Tensor) 253 | and not isinstance(image, PIL.Image.Image) 254 | and not isinstance(image, list) 255 | ): 256 | raise ValueError( 257 | "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" 258 | f" {type(image)}" 259 | ) 260 | 261 | if height % 8 != 0 or width % 8 != 0: 262 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 263 | 264 | def prepare_latents( 265 | self, 266 | batch_size: int, 267 | num_frames: int, 268 | num_channels_latents: int, 269 | height: int, 270 | width: int, 271 | dtype: torch.dtype, 272 | device: Union[str, torch.device], 273 | generator: torch.Generator, 274 | latents: Optional[torch.Tensor] = None, 275 | ): 276 | shape = ( 277 | batch_size, 278 | num_frames, 279 | num_channels_latents // 2, 280 | height // self.vae_scale_factor, 281 | width // self.vae_scale_factor, 282 | ) 283 | if isinstance(generator, list) and len(generator) != batch_size: 284 | raise ValueError( 285 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 286 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
287 | ) 288 | 289 | if latents is None: 290 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 291 | else: 292 | latents = latents.to(device) 293 | 294 | # scale the initial noise by the standard deviation required by the scheduler 295 | latents = latents * self.scheduler.init_noise_sigma 296 | return latents 297 | 298 | @property 299 | def guidance_scale(self): 300 | return self._guidance_scale 301 | 302 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 303 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 304 | # corresponds to doing no classifier free guidance. 305 | @property 306 | def do_classifier_free_guidance(self): 307 | if isinstance(self.guidance_scale, (int, float)): 308 | return self.guidance_scale > 1 309 | return self.guidance_scale.max() > 1 310 | 311 | @property 312 | def num_timesteps(self): 313 | return self._num_timesteps 314 | 315 | def prepare_extra_step_kwargs(self, generator, eta): 316 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 317 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 318 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 319 | # and should be between [0, 1] 320 | 321 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 322 | extra_step_kwargs = {} 323 | if accepts_eta: 324 | extra_step_kwargs["eta"] = eta 325 | 326 | # check if the scheduler accepts generator 327 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 328 | if accepts_generator: 329 | extra_step_kwargs["generator"] = generator 330 | return extra_step_kwargs 331 | 332 | @torch.no_grad() 333 | def __call__( 334 | self, 335 | image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], 336 | image_pose: Union[torch.FloatTensor], 337 | height: int = 576, 338 | width: int = 1024, 339 | num_frames: Optional[int] = None, 340 | tile_size: Optional[int] = 16, 341 | tile_overlap: Optional[int] = 4, 342 | num_inference_steps: int = 25, 343 | min_guidance_scale: float = 1.0, 344 | max_guidance_scale: float = 3.0, 345 | fps: int = 7, 346 | motion_bucket_id: int = 127, 347 | noise_aug_strength: float = 0.02, 348 | image_only_indicator: bool = False, 349 | decode_chunk_size: Optional[int] = None, 350 | num_videos_per_prompt: Optional[int] = 1, 351 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 352 | latents: Optional[torch.FloatTensor] = None, 353 | output_type: Optional[str] = "pil", 354 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 355 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 356 | return_dict: bool = True, 357 | device: Union[str, torch.device] =None, 358 | ): 359 | r""" 360 | The call function to the pipeline for generation. 361 | 362 | Args: 363 | image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): 364 | Image or images to guide image generation. If you provide a tensor, it needs to be compatible with 365 | [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/ 366 | feature_extractor/preprocessor_config.json). 367 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 368 | The height in pixels of the generated image. 
369 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 370 | The width in pixels of the generated image. 371 | num_frames (`int`, *optional*): 372 | The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` 373 | and to 25 for `stable-video-diffusion-img2vid-xt`. 374 | num_inference_steps (`int`, *optional*, defaults to 25): 375 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 376 | expense of slower inference. This parameter is modulated by `strength`. 377 | min_guidance_scale (`float`, *optional*, defaults to 1.0): 378 | The minimum guidance scale. Used for the classifier free guidance with first frame. 379 | max_guidance_scale (`float`, *optional*, defaults to 3.0): 380 | The maximum guidance scale. Used for the classifier free guidance with last frame. 381 | fps (`int`, *optional*, defaults to 7): 382 | Frames per second. The rate at which the generated images shall be exported to a video after generation. 383 | Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. 384 | motion_bucket_id (`int`, *optional*, defaults to 127): 385 | The motion bucket ID. Used as conditioning for the generation. 386 | The higher the number the more motion will be in the video. 387 | noise_aug_strength (`float`, *optional*, defaults to 0.02): 388 | The amount of noise added to the init image; 389 | the higher it is, the less the video will look like the init image. Increase it for more motion. 390 | image_only_indicator (`bool`, *optional*, defaults to False): 391 | Whether to treat the inputs as a batch of images instead of videos. 392 | decode_chunk_size (`int`, *optional*): 393 | The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency 394 | between frames, but also the higher the memory consumption. 395 | By default, the decoder will decode all frames at once for maximal quality. 396 | Reduce `decode_chunk_size` to reduce memory usage. 397 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 398 | The number of videos to generate per prompt. 399 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 400 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 401 | generation deterministic. 402 | latents (`torch.FloatTensor`, *optional*): 403 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 404 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 405 | tensor is generated by sampling using the supplied random `generator`. 406 | output_type (`str`, *optional*, defaults to `"pil"`): 407 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 408 | callback_on_step_end (`Callable`, *optional*): 409 | A function that is called at the end of each denoising step during inference. The function is called 410 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 411 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 412 | `callback_on_step_end_tensor_inputs`. 413 | callback_on_step_end_tensor_inputs (`List`, *optional*): 414 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 415 | will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the 416 | `._callback_tensor_inputs` attribute of your pipeline class. 417 | return_dict (`bool`, *optional*, defaults to `True`): 418 | Whether to return a [`MimicMotionPipelineOutput`] instead of a 419 | plain tuple. 420 | device (`str` or `torch.device`, *optional*): 421 | The device on which the pipeline runs. 422 | 423 | Returns: 424 | [`MimicMotionPipelineOutput`] or `tuple`: 425 | If `return_dict` is `True`, 426 | a [`MimicMotionPipelineOutput`] is returned, 427 | otherwise a `tuple` is returned where the first element is a list of lists with the generated frames. 428 | 429 | Examples: 430 | 431 | ```py 432 | from diffusers import StableVideoDiffusionPipeline 433 | from diffusers.utils import load_image, export_to_video 434 | import torch 435 | pipe = StableVideoDiffusionPipeline.from_pretrained( 436 | "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") 437 | pipe.to("cuda") 438 | 439 | image = load_image( 440 | "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200") 441 | image = image.resize((1024, 576)) 442 | 443 | frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] 444 | export_to_video(frames, "generated.mp4", fps=7) 445 | ``` 446 | """ 447 | # 0. Default height and width to unet 448 | height = height or self.unet.config.sample_size * self.vae_scale_factor 449 | width = width or self.unet.config.sample_size * self.vae_scale_factor 450 | 451 | num_frames = num_frames if num_frames is not None else self.unet.config.num_frames 452 | decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames 453 | 454 | # 1. Check inputs. Raise error if not correct 455 | self.check_inputs(image, height, width) 456 | 457 | # 2. Define call parameters 458 | if isinstance(image, PIL.Image.Image): 459 | batch_size = 1 460 | elif isinstance(image, list): 461 | batch_size = len(image) 462 | else: 463 | batch_size = image.shape[0] 464 | device = device if device is not None else self._execution_device 465 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 466 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 467 | # corresponds to doing no classifier free guidance. 468 | self._guidance_scale = max_guidance_scale 469 | 470 | # 3. Encode input image 471 | self.image_encoder.to(device) 472 | image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance) 473 | self.image_encoder.cpu() 474 | 475 | # NOTE: Stable Diffusion Video was conditioned on fps - 1, which 476 | # is why it is reduced here. 477 | fps = fps - 1 478 | 479 | # 4. 
Encode input image using VAE 480 | image = self.image_processor.preprocess(image, height=height, width=width).to(device) 481 | noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) 482 | image = image + noise_aug_strength * noise 483 | 484 | self.vae.to(device) 485 | image_latents = self._encode_vae_image( 486 | image, 487 | device=device, 488 | num_videos_per_prompt=num_videos_per_prompt, 489 | do_classifier_free_guidance=self.do_classifier_free_guidance, 490 | ) 491 | image_latents = image_latents.to(image_embeddings.dtype) 492 | self.vae.cpu() 493 | 494 | # Repeat the image latents for each frame so we can concatenate them with the noise 495 | # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] 496 | image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) 497 | 498 | # 5. Get Added Time IDs 499 | added_time_ids = self._get_add_time_ids( 500 | fps, 501 | motion_bucket_id, 502 | noise_aug_strength, 503 | image_embeddings.dtype, 504 | batch_size, 505 | num_videos_per_prompt, 506 | self.do_classifier_free_guidance, 507 | ) 508 | added_time_ids = added_time_ids.to(device) 509 | 510 | # 4. Prepare timesteps 511 | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None) 512 | 513 | # 5. Prepare latent variables 514 | num_channels_latents = self.unet.config.in_channels 515 | latents = self.prepare_latents( 516 | batch_size * num_videos_per_prompt, 517 | tile_size, 518 | num_channels_latents, 519 | height, 520 | width, 521 | image_embeddings.dtype, 522 | device, 523 | generator, 524 | latents, 525 | ) 526 | latents = latents.repeat(1, num_frames // tile_size + 1, 1, 1, 1)[:, :num_frames] 527 | 528 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 529 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0) 530 | 531 | # 7. Prepare guidance scale 532 | guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) 533 | guidance_scale = guidance_scale.to(device, latents.dtype) 534 | guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) 535 | guidance_scale = _append_dims(guidance_scale, latents.ndim) 536 | 537 | self._guidance_scale = guidance_scale 538 | 539 | # 8. 
Denoising loop 540 | self._num_timesteps = len(timesteps) 541 | indices = [[0, *range(i + 1, min(i + tile_size, num_frames))] for i in 542 | range(0, num_frames - tile_size + 1, tile_size - tile_overlap)] 543 | if indices[-1][-1] < num_frames - 1: 544 | indices.append([0, *range(num_frames - tile_size + 1, num_frames)]) 545 | 546 | self.pose_net.to(device) 547 | self.unet.to(device) 548 | 549 | with torch.cuda.device(device): 550 | torch.cuda.empty_cache() 551 | 552 | with self.progress_bar(total=len(timesteps) * len(indices)) as progress_bar: 553 | for i, t in enumerate(timesteps): 554 | # expand the latents if we are doing classifier free guidance 555 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 556 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 557 | 558 | # Concatenate image_latents over channels dimension 559 | latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) 560 | 561 | # predict the noise residual 562 | noise_pred = torch.zeros_like(image_latents) 563 | noise_pred_cnt = image_latents.new_zeros((num_frames,)) 564 | weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size 565 | weight = torch.minimum(weight, 2 - weight) 566 | for idx in indices: 567 | 568 | # classification-free inference 569 | pose_latents = self.pose_net(image_pose[idx].to(device)) 570 | _noise_pred = self.unet( 571 | latent_model_input[:1, idx], 572 | t, 573 | encoder_hidden_states=image_embeddings[:1], 574 | added_time_ids=added_time_ids[:1], 575 | pose_latents=None, 576 | image_only_indicator=image_only_indicator, 577 | return_dict=False, 578 | )[0] 579 | noise_pred[:1, idx] += _noise_pred * weight[:, None, None, None] 580 | 581 | # normal inference 582 | _noise_pred = self.unet( 583 | latent_model_input[1:, idx], 584 | t, 585 | encoder_hidden_states=image_embeddings[1:], 586 | added_time_ids=added_time_ids[1:], 587 | pose_latents=pose_latents, 588 | image_only_indicator=image_only_indicator, 589 | return_dict=False, 590 | )[0] 591 | noise_pred[1:, idx] += _noise_pred * weight[:, None, None, None] 592 | 593 | noise_pred_cnt[idx] += weight 594 | progress_bar.update() 595 | noise_pred.div_(noise_pred_cnt[:, None, None, None]) 596 | 597 | # perform guidance 598 | if self.do_classifier_free_guidance: 599 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 600 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) 601 | 602 | # compute the previous noisy sample x_t -> x_t-1 603 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 604 | 605 | if callback_on_step_end is not None: 606 | callback_kwargs = {} 607 | for k in callback_on_step_end_tensor_inputs: 608 | callback_kwargs[k] = locals()[k] 609 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 610 | 611 | latents = callback_outputs.pop("latents", latents) 612 | 613 | self.pose_net.cpu() 614 | self.unet.cpu() 615 | 616 | if not output_type == "latent": 617 | self.vae.decoder.to(device) 618 | frames = self.decode_latents(latents, num_frames, decode_chunk_size) 619 | frames = tensor2vid(frames, self.image_processor, output_type=output_type) 620 | else: 621 | frames = latents 622 | 623 | self.maybe_free_model_hooks() 624 | 625 | if not return_dict: 626 | return frames 627 | 628 | return MimicMotionPipelineOutput(frames=frames) 629 | -------------------------------------------------------------------------------- 
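The tiled denoising scheme in `MimicMotionPipeline.__call__` above is the least obvious part of the file: latent frames are denoised in overlapping windows of `tile_size` frames, frame 0 is prepended to every window, and overlapping predictions are cross-faded with a triangular weight before the scheduler step. The standalone sketch below reproduces only that index and weight construction so it can be inspected in isolation; the helper names (`build_tiles`, `blend_weight`) and the toy sizes are introduced here for illustration and are not part of the repository.

```py
import torch


def build_tiles(num_frames: int, tile_size: int, tile_overlap: int):
    """Reproduce the tile/index construction used in the denoising loop above."""
    indices = [[0, *range(i + 1, min(i + tile_size, num_frames))]
               for i in range(0, num_frames - tile_size + 1, tile_size - tile_overlap)]
    if indices[-1][-1] < num_frames - 1:
        # add one trailing tile so the last frames are also covered
        indices.append([0, *range(num_frames - tile_size + 1, num_frames)])
    return indices


def blend_weight(tile_size: int) -> torch.Tensor:
    """Triangular blending weight: small at the tile borders, large in the middle."""
    weight = (torch.arange(tile_size) + 0.5) * 2.0 / tile_size
    return torch.minimum(weight, 2 - weight)


if __name__ == "__main__":
    num_frames, tile_size, tile_overlap = 32, 16, 4  # toy values for illustration only
    tiles = build_tiles(num_frames, tile_size, tile_overlap)
    weight = blend_weight(tile_size)

    # Accumulate per-frame weights exactly like `noise_pred_cnt` in the loop above;
    # dividing the summed noise predictions by this count averages overlapping tiles.
    count = torch.zeros(num_frames)
    for idx in tiles:
        count[idx] += weight  # each tile contains exactly `tile_size` frame indices

    print(tiles)   # frame 0 appears in every tile
    print(weight)  # 0.0625, 0.1875, ..., peaks in the middle, ..., 0.0625
    assert bool(torch.all(count > 0))  # every frame receives at least one weighted prediction
```

Because index 0 (whose pose input comes from the reference image, see `preprocess` below) is prepended to every tile, each UNet window shares a common anchor frame, and the triangular weights taper off toward tile borders so neighbouring windows blend smoothly instead of producing visible seams.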
/mimicmotion/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MimicMotion/62f91e1f4ab750e1ab0d96c08f2d199b6e84e8b1/mimicmotion/utils/__init__.py -------------------------------------------------------------------------------- /mimicmotion/utils/geglu_patch.py: -------------------------------------------------------------------------------- 1 | import diffusers.models.activations 2 | 3 | 4 | def patch_geglu_inplace(): 5 | """Patch GEGLU with inplace multiplication to save GPU memory.""" 6 | def forward(self, hidden_states): 7 | hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) 8 | return hidden_states.mul_(self.gelu(gate)) 9 | diffusers.models.activations.GEGLU.forward = forward 10 | -------------------------------------------------------------------------------- /mimicmotion/utils/loader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.utils.checkpoint 5 | from diffusers.models import AutoencoderKLTemporalDecoder 6 | from diffusers.schedulers import EulerDiscreteScheduler 7 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 8 | 9 | from ..modules.unet import UNetSpatioTemporalConditionModel 10 | from ..modules.pose_net import PoseNet 11 | from ..pipelines.pipeline_mimicmotion import MimicMotionPipeline 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | class MimicMotionModel(torch.nn.Module): 16 | def __init__(self, base_model_path): 17 | """construnct base model components and load pretrained svd model except pose-net 18 | Args: 19 | base_model_path (str): pretrained svd model path 20 | """ 21 | super().__init__() 22 | self.unet = UNetSpatioTemporalConditionModel.from_config( 23 | UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder="unet")) 24 | self.vae = AutoencoderKLTemporalDecoder.from_pretrained( 25 | base_model_path, subfolder="vae", torch_dtype=torch.float16, variant="fp16") 26 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( 27 | base_model_path, subfolder="image_encoder", torch_dtype=torch.float16, variant="fp16") 28 | self.noise_scheduler = EulerDiscreteScheduler.from_pretrained( 29 | base_model_path, subfolder="scheduler") 30 | self.feature_extractor = CLIPImageProcessor.from_pretrained( 31 | base_model_path, subfolder="feature_extractor") 32 | # pose_net 33 | self.pose_net = PoseNet(noise_latent_channels=self.unet.config.block_out_channels[0]) 34 | 35 | def create_pipeline(infer_config, device): 36 | """create mimicmotion pipeline and load pretrained weight 37 | 38 | Args: 39 | infer_config (str): 40 | device (str or torch.device): "cpu" or "cuda:{device_id}" 41 | """ 42 | mimicmotion_models = MimicMotionModel(infer_config.base_model_path) 43 | mimicmotion_models.load_state_dict(torch.load(infer_config.ckpt_path, map_location="cpu"), strict=False) 44 | pipeline = MimicMotionPipeline( 45 | vae=mimicmotion_models.vae, 46 | image_encoder=mimicmotion_models.image_encoder, 47 | unet=mimicmotion_models.unet, 48 | scheduler=mimicmotion_models.noise_scheduler, 49 | feature_extractor=mimicmotion_models.feature_extractor, 50 | pose_net=mimicmotion_models.pose_net 51 | ) 52 | return pipeline 53 | 54 | -------------------------------------------------------------------------------- /mimicmotion/utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from 
pathlib import Path 3 | 4 | from torchvision.io import write_video 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | def save_to_mp4(frames, save_path, fps=7): 9 | frames = frames.permute((0, 2, 3, 1)) # (f, c, h, w) to (f, h, w, c) 10 | Path(save_path).parent.mkdir(parents=True, exist_ok=True) 11 | write_video(save_path, frames, fps=fps) 12 | 13 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # predict.py 2 | import subprocess 3 | import time 4 | from cog import BasePredictor, Input, Path 5 | import os 6 | import torch 7 | import numpy as np 8 | from PIL import Image 9 | from omegaconf import OmegaConf 10 | from datetime import datetime 11 | 12 | from torchvision.transforms.functional import pil_to_tensor, resize, center_crop 13 | from constants import ASPECT_RATIO 14 | 15 | MODEL_CACHE = "models" 16 | os.environ["HF_DATASETS_OFFLINE"] = "1" 17 | os.environ["TRANSFORMERS_OFFLINE"] = "1" 18 | os.environ["HF_HOME"] = MODEL_CACHE 19 | os.environ["TORCH_HOME"] = MODEL_CACHE 20 | os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE 21 | os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE 22 | os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE 23 | 24 | BASE_URL = f"https://weights.replicate.delivery/default/MimicMotion/{MODEL_CACHE}/" 25 | 26 | 27 | def download_weights(url: str, dest: str) -> None: 28 | # NOTE WHEN YOU EXTRACT SPECIFY THE PARENT FOLDER 29 | start = time.time() 30 | print("[!] Initiating download from URL: ", url) 31 | print("[~] Destination path: ", dest) 32 | if ".tar" in dest: 33 | dest = os.path.dirname(dest) 34 | command = ["pget", "-vf" + ("x" if ".tar" in url else ""), url, dest] 35 | try: 36 | print(f"[~] Running command: {' '.join(command)}") 37 | subprocess.check_call(command, close_fds=False) 38 | except subprocess.CalledProcessError as e: 39 | print( 40 | f"[ERROR] Failed to download weights. Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}." 
41 | ) 42 | raise 43 | print("[+] Download completed in: ", time.time() - start, "seconds") 44 | 45 | 46 | class Predictor(BasePredictor): 47 | def setup(self): 48 | """Load the model into memory to make running multiple predictions efficient""" 49 | 50 | if not os.path.exists(MODEL_CACHE): 51 | os.makedirs(MODEL_CACHE) 52 | model_files = [ 53 | "DWPose.tar", 54 | "MimicMotion.pth", 55 | "MimicMotion_1-1.pth", 56 | "SVD.tar", 57 | ] 58 | for model_file in model_files: 59 | url = BASE_URL + model_file 60 | filename = url.split("/")[-1] 61 | dest_path = os.path.join(MODEL_CACHE, filename) 62 | if not os.path.exists(dest_path.replace(".tar", "")): 63 | download_weights(url, dest_path) 64 | 65 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 66 | print(f"Using device: {self.device}") 67 | 68 | # Move imports here and make them global 69 | # This ensures model files are downloaded before importing mimicmotion modules 70 | global MimicMotionPipeline, create_pipeline, save_to_mp4, get_video_pose, get_image_pose 71 | from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline 72 | from mimicmotion.utils.loader import create_pipeline 73 | from mimicmotion.utils.utils import save_to_mp4 74 | from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose 75 | 76 | # Load config with new checkpoint as default 77 | self.config = OmegaConf.create( 78 | { 79 | "base_model_path": "models/SVD/stable-video-diffusion-img2vid-xt-1-1", 80 | "ckpt_path": "models/MimicMotion_1-1.pth", 81 | } 82 | ) 83 | 84 | # Create the pipeline with the new checkpoint 85 | self.pipeline = create_pipeline(self.config, self.device) 86 | self.current_checkpoint = "v1-1" 87 | self.current_dtype = torch.get_default_dtype() 88 | 89 | def predict( 90 | self, 91 | motion_video: Path = Input( 92 | description="Reference video file containing the motion to be mimicked" 93 | ), 94 | appearance_image: Path = Input( 95 | description="Reference image file for the appearance of the generated video" 96 | ), 97 | resolution: int = Input( 98 | description="Height of the output video in pixels. Width is automatically calculated.", 99 | default=576, 100 | ge=64, 101 | le=1024, 102 | ), 103 | chunk_size: int = Input( 104 | description="Number of frames to generate in each processing chunk", 105 | default=16, 106 | ge=2, 107 | ), 108 | frames_overlap: int = Input( 109 | description="Number of overlapping frames between chunks for smoother transitions", 110 | default=6, 111 | ge=0, 112 | ), 113 | denoising_steps: int = Input( 114 | description="Number of denoising steps in the diffusion process. More steps can improve quality but increase processing time.", 115 | default=25, 116 | ge=1, 117 | le=100, 118 | ), 119 | noise_strength: float = Input( 120 | description="Strength of noise augmentation. Higher values add more variation but may reduce coherence with the reference.", 121 | default=0.0, 122 | ge=0.0, 123 | le=1.0, 124 | ), 125 | guidance_scale: float = Input( 126 | description="Strength of guidance towards the reference. Higher values adhere more closely to the reference but may reduce creativity.", 127 | default=2.0, 128 | ge=0.1, 129 | le=10.0, 130 | ), 131 | sample_stride: int = Input( 132 | description="Interval for sampling frames from the reference video. Higher values skip more frames.", 133 | default=2, 134 | ge=1, 135 | ), 136 | output_frames_per_second: int = Input( 137 | description="Frames per second of the output video. 
Affects playback speed.", 138 | default=15, 139 | ge=1, 140 | le=60, 141 | ), 142 | seed: int = Input( 143 | description="Random seed. Leave blank to randomize the seed", 144 | default=None, 145 | ), 146 | checkpoint_version: str = Input( 147 | description="Choose the checkpoint version to use", 148 | choices=["v1", "v1-1"], 149 | default="v1-1", 150 | ), 151 | ) -> Path: 152 | """Run a single prediction on the model""" 153 | 154 | ref_video = motion_video 155 | ref_image = appearance_image 156 | num_frames = chunk_size 157 | num_inference_steps = denoising_steps 158 | noise_aug_strength = noise_strength 159 | fps = output_frames_per_second 160 | use_fp16 = True 161 | 162 | if seed is None: 163 | seed = int.from_bytes(os.urandom(2), "big") 164 | print(f"Using seed: {seed}") 165 | 166 | need_pipeline_update = False 167 | 168 | # Check if we need to switch checkpoints 169 | if checkpoint_version != self.current_checkpoint: 170 | if checkpoint_version == "v1": 171 | self.config.ckpt_path = "models/MimicMotion.pth" 172 | else: # v1-1 173 | self.config.ckpt_path = "models/MimicMotion_1-1.pth" 174 | need_pipeline_update = True 175 | self.current_checkpoint = checkpoint_version 176 | 177 | # Check if we need to switch dtype 178 | target_dtype = torch.float16 if use_fp16 else torch.float32 179 | if target_dtype != self.current_dtype: 180 | torch.set_default_dtype(target_dtype) 181 | need_pipeline_update = True 182 | self.current_dtype = target_dtype 183 | 184 | # Update pipeline if needed 185 | if need_pipeline_update: 186 | print( 187 | f"Updating pipeline with checkpoint: {self.config.ckpt_path} and dtype: {torch.get_default_dtype()}" 188 | ) 189 | self.pipeline = create_pipeline(self.config, self.device) 190 | 191 | print(f"Using checkpoint: {self.config.ckpt_path}") 192 | print(f"Using dtype: {torch.get_default_dtype()}") 193 | 194 | print( 195 | f"[!] ({type(ref_video)}) ref_video={ref_video}, " 196 | f"[!] ({type(ref_image)}) ref_image={ref_image}, " 197 | f"[!] ({type(resolution)}) resolution={resolution}, " 198 | f"[!] ({type(num_frames)}) num_frames={num_frames}, " 199 | f"[!] ({type(frames_overlap)}) frames_overlap={frames_overlap}, " 200 | f"[!] ({type(num_inference_steps)}) num_inference_steps={num_inference_steps}, " 201 | f"[!] ({type(noise_aug_strength)}) noise_aug_strength={noise_aug_strength}, " 202 | f"[!] ({type(guidance_scale)}) guidance_scale={guidance_scale}, " 203 | f"[!] ({type(sample_stride)}) sample_stride={sample_stride}, " 204 | f"[!] ({type(fps)}) fps={fps}, " 205 | f"[!] ({type(seed)}) seed={seed}, " 206 | f"[!] 
({type(use_fp16)}) use_fp16={use_fp16}" 207 | ) 208 | 209 | # Input validation 210 | if not ref_video.exists(): 211 | raise ValueError(f"Reference video file does not exist: {ref_video}") 212 | if not ref_image.exists(): 213 | raise ValueError(f"Reference image file does not exist: {ref_image}") 214 | 215 | if resolution % 8 != 0: 216 | raise ValueError(f"Resolution must be a multiple of 8, got {resolution}") 217 | 218 | if resolution < 64 or resolution > 1024: 219 | raise ValueError( 220 | f"Resolution must be between 64 and 1024, got {resolution}" 221 | ) 222 | 223 | if num_frames <= frames_overlap: 224 | raise ValueError( 225 | f"Number of frames ({num_frames}) must be greater than frames overlap ({frames_overlap})" 226 | ) 227 | 228 | if num_frames < 2: 229 | raise ValueError(f"Number of frames must be at least 2, got {num_frames}") 230 | 231 | if frames_overlap < 0: 232 | raise ValueError( 233 | f"Frames overlap must be non-negative, got {frames_overlap}" 234 | ) 235 | 236 | if num_inference_steps < 1 or num_inference_steps > 100: 237 | raise ValueError( 238 | f"Number of inference steps must be between 1 and 100, got {num_inference_steps}" 239 | ) 240 | 241 | if noise_aug_strength < 0.0 or noise_aug_strength > 1.0: 242 | raise ValueError( 243 | f"Noise augmentation strength must be between 0.0 and 1.0, got {noise_aug_strength}" 244 | ) 245 | 246 | if guidance_scale < 0.1 or guidance_scale > 10.0: 247 | raise ValueError( 248 | f"Guidance scale must be between 0.1 and 10.0, got {guidance_scale}" 249 | ) 250 | 251 | if sample_stride < 1: 252 | raise ValueError(f"Sample stride must be at least 1, got {sample_stride}") 253 | 254 | if fps < 1 or fps > 60: 255 | raise ValueError(f"FPS must be between 1 and 60, got {fps}") 256 | 257 | try: 258 | # Preprocess 259 | pose_pixels, image_pixels = self.preprocess( 260 | str(ref_video), 261 | str(ref_image), 262 | resolution=resolution, 263 | sample_stride=sample_stride, 264 | ) 265 | 266 | # Run pipeline 267 | video_frames = self.run_pipeline( 268 | image_pixels, 269 | pose_pixels, 270 | num_frames=num_frames, 271 | frames_overlap=frames_overlap, 272 | num_inference_steps=num_inference_steps, 273 | noise_aug_strength=noise_aug_strength, 274 | guidance_scale=guidance_scale, 275 | seed=seed, 276 | ) 277 | 278 | # Save output 279 | output_path = f"/tmp/output_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4" 280 | save_to_mp4(video_frames, output_path, fps=fps) 281 | 282 | return Path(output_path) 283 | 284 | except Exception as e: 285 | print(f"An error occurred during prediction: {str(e)}") 286 | raise 287 | 288 | def preprocess(self, video_path, image_path, resolution=576, sample_stride=2): 289 | image_pixels = Image.open(image_path).convert("RGB") 290 | image_pixels = pil_to_tensor(image_pixels) # (c, h, w) 291 | h, w = image_pixels.shape[-2:] 292 | 293 | if h > w: 294 | w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64 295 | else: 296 | w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution 297 | 298 | h_w_ratio = float(h) / float(w) 299 | if h_w_ratio < h_target / w_target: 300 | h_resize, w_resize = h_target, int(h_target / h_w_ratio) 301 | else: 302 | h_resize, w_resize = int(w_target * h_w_ratio), w_target 303 | 304 | image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None) 305 | image_pixels = center_crop(image_pixels, [h_target, w_target]) 306 | image_pixels = image_pixels.permute((1, 2, 0)).numpy() 307 | 308 | image_pose = get_image_pose(image_pixels) 309 | video_pose = 
get_video_pose( 310 | video_path, image_pixels, sample_stride=sample_stride 311 | ) 312 | 313 | pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose]) 314 | image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2)) 315 | 316 | return ( 317 | torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, 318 | torch.from_numpy(image_pixels) / 127.5 - 1, 319 | ) 320 | 321 | def run_pipeline( 322 | self, 323 | image_pixels, 324 | pose_pixels, 325 | num_frames, 326 | frames_overlap, 327 | num_inference_steps, 328 | noise_aug_strength, 329 | guidance_scale, 330 | seed, 331 | ): 332 | image_pixels = [ 333 | Image.fromarray( 334 | (img.cpu().numpy().transpose(1, 2, 0) * 127.5 + 127.5).astype(np.uint8) 335 | ) 336 | for img in image_pixels 337 | ] 338 | pose_pixels = pose_pixels.unsqueeze(0).to(self.device) 339 | 340 | generator = torch.Generator(device=self.device) 341 | generator.manual_seed(seed) 342 | 343 | frames = self.pipeline( 344 | image_pixels, 345 | image_pose=pose_pixels, 346 | num_frames=pose_pixels.size(1), 347 | tile_size=num_frames, 348 | tile_overlap=frames_overlap, 349 | height=pose_pixels.shape[-2], 350 | width=pose_pixels.shape[-1], 351 | fps=7, 352 | noise_aug_strength=noise_aug_strength, 353 | num_inference_steps=num_inference_steps, 354 | generator=generator, 355 | min_guidance_scale=guidance_scale, 356 | max_guidance_scale=guidance_scale, 357 | decode_chunk_size=8, 358 | output_type="pt", 359 | device=self.device, 360 | ).frames.cpu() 361 | 362 | video_frames = (frames * 255.0).to(torch.uint8) 363 | return video_frames[0, 1:] # Remove the first frame (reference image) 364 | --------------------------------------------------------------------------------
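One detail of `Predictor.preprocess` above that is easy to gloss over is how the output canvas is derived from the single `resolution` input: the long side is computed from `ASPECT_RATIO` and snapped down to a multiple of 64, the reference image is resized so that it fully covers that canvas while keeping its own aspect ratio, and the excess is then center-cropped away. The sketch below repeats that arithmetic in isolation; it assumes `ASPECT_RATIO` in `constants.py` is 9/16 (the file is not shown in this excerpt), and `target_size` together with the 3000x4000 example input is introduced purely for illustration.

```py
# Minimal sketch of the target-size arithmetic in Predictor.preprocess.
ASPECT_RATIO = 9 / 16  # assumed value; see constants.py in the repository


def target_size(h: int, w: int, resolution: int = 576):
    """Return ((h_target, w_target), (h_resize, w_resize)) as preprocess computes them."""
    if h > w:   # portrait reference image: width is `resolution`, height is the long side
        w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
    else:       # landscape or square reference image
        w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution

    # Resize so the image covers the target canvas while preserving its own aspect ratio;
    # the subsequent center crop trims whichever dimension overshoots.
    h_w_ratio = float(h) / float(w)
    if h_w_ratio < h_target / w_target:
        h_resize, w_resize = h_target, int(h_target / h_w_ratio)
    else:
        h_resize, w_resize = int(w_target * h_w_ratio), w_target
    return (h_target, w_target), (h_resize, w_resize)


if __name__ == "__main__":
    # Hypothetical 3000x4000 (h x w) landscape photo with the default resolution of 576:
    # the target canvas is 576x1024, the resize lands at 768x1024, and the center crop
    # removes 96 rows from the top and the bottom.
    print(target_size(3000, 4000, resolution=576))  # ((576, 1024), (768, 1024))
```

Both target dimensions come out divisible by 8 whenever `resolution` is (which `predict` enforces), satisfying the pipeline's `check_inputs` requirement, and the computed long side is additionally snapped to a multiple of 64.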