├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── assets ├── attn-mask.png └── framework.jpg ├── configs ├── base_config.yaml ├── disneyPixar.yaml ├── kFelted.yaml ├── moxin.yaml ├── origami.yaml ├── pixart.yaml └── toonyou.yaml ├── demo ├── .gitattributes ├── .gitignore ├── README.md ├── app.py ├── config.py ├── connection_manager.py ├── demo_cfg.yaml ├── demo_cfg_arknight.yaml ├── frontend │ ├── .eslintignore │ ├── .eslintrc.cjs │ ├── .gitignore │ ├── .npmrc │ ├── .prettierignore │ ├── .prettierrc │ ├── README.md │ ├── package-lock.json │ ├── package.json │ ├── postcss.config.js │ ├── src │ │ ├── app.css │ │ ├── app.d.ts │ │ ├── app.html │ │ ├── lib │ │ │ ├── components │ │ │ │ ├── Button.svelte │ │ │ │ ├── Checkbox.svelte │ │ │ │ ├── ImagePlayer.svelte │ │ │ │ ├── InputRange.svelte │ │ │ │ ├── MediaListSwitcher.svelte │ │ │ │ ├── PipelineOptions.svelte │ │ │ │ ├── SeedInput.svelte │ │ │ │ ├── Selectlist.svelte │ │ │ │ ├── TextArea.svelte │ │ │ │ ├── VideoInput.svelte │ │ │ │ └── Warning.svelte │ │ │ ├── icons │ │ │ │ ├── floppy.svelte │ │ │ │ ├── screen.svelte │ │ │ │ └── spinner.svelte │ │ │ ├── index.ts │ │ │ ├── lcmLive.ts │ │ │ ├── mediaStream.ts │ │ │ ├── store.ts │ │ │ ├── types.ts │ │ │ └── utils.ts │ │ └── routes │ │ │ ├── +layout.svelte │ │ │ ├── +page.svelte │ │ │ └── +page.ts │ ├── svelte.config.js │ ├── tailwind.config.js │ ├── tsconfig.json │ └── vite.config.ts ├── main.py ├── requirements.txt ├── start.sh ├── util.py └── vid2vid.py ├── live2diff ├── __init__.py ├── acceleration │ ├── __init__.py │ └── tensorrt │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── engine.py │ │ ├── models.py │ │ └── utilities.py ├── animatediff │ ├── __init__.py │ ├── converter │ │ ├── __init__.py │ │ ├── convert.py │ │ ├── convert_from_ckpt.py │ │ └── convert_lora_safetensor_to_diffusers.py │ ├── models │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── depth_utils.py │ │ ├── motion_module.py │ │ ├── positional_encoding.py │ │ ├── resnet.py │ │ ├── stream_motion_module.py │ │ ├── unet_blocks_streaming.py │ │ ├── unet_blocks_warmup.py │ │ ├── unet_depth_streaming.py │ │ └── unet_depth_warmup.py │ └── pipeline │ │ ├── __init__.py │ │ ├── loader.py │ │ └── pipeline_animatediff_depth.py ├── image_filter.py ├── image_utils.py ├── pipeline_stream_animation_depth.py └── utils │ ├── __init__.py │ ├── config.py │ ├── io.py │ └── wrapper.py ├── pyproject.toml ├── scripts └── download.sh ├── setup.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # https://github.com/github/gitignore/blob/main/Python.gitignore 2 | 3 | .vscode/ 4 | engines/ 5 | output/ 6 | *.csv 7 | *.mp4 8 | *.png 9 | !assets/*.mp4 10 | !assets/*.png 11 | *.safetensors 12 | result_lcm.png 13 | model.ckpt 14 | !images/inputs/input.png 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | cover/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | db.sqlite3 77 | db.sqlite3-journal 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | .pybuilder/ 91 | target/ 92 | 93 | # Jupyter Notebook 94 | .ipynb_checkpoints 95 | 96 | # IPython 97 | profile_default/ 98 | ipython_config.py 99 | 100 | # pyenv 101 | # For a library or package, you might want to ignore these files since the code is 102 | # intended to run in multiple environments; otherwise, check them in: 103 | # .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # poetry 113 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 114 | # This is especially recommended for binary packages to ensure reproducibility, and is more 115 | # commonly ignored for libraries. 116 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 117 | #poetry.lock 118 | 119 | # pdm 120 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 121 | #pdm.lock 122 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 123 | # in version control. 124 | # https://pdm.fming.dev/#use-with-ide 125 | .pdm.toml 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .venv 140 | env/ 141 | venv/ 142 | ENV/ 143 | env.bak/ 144 | venv.bak/ 145 | 146 | # Spyder project settings 147 | .spyderproject 148 | .spyproject 149 | 150 | # Rope project settings 151 | .ropeproject 152 | 153 | # mkdocs documentation 154 | /site 155 | 156 | # mypy 157 | .mypy_cache/ 158 | .dmypy.json 159 | dmypy.json 160 | 161 | # Pyre type checker 162 | .pyre/ 163 | 164 | # pytype static type analyzer 165 | .pytype/ 166 | 167 | # Cython debug symbols 168 | cython_debug/ 169 | 170 | # PyCharm 171 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 172 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 173 | # and can be added to the global gitignore or merged into this file. For a more nuclear 174 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 175 | #.idea/ 176 | 177 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
178 | 179 | # dependencies 180 | *node_modules 181 | */.pnp 182 | .pnp.js 183 | 184 | # testing 185 | */coverage 186 | 187 | # production 188 | */build 189 | 190 | # misc 191 | .DS_Store 192 | .env.local 193 | .env.development.local 194 | .env.test.local 195 | .env.production.local 196 | 197 | npm-debug.log* 198 | yarn-debug.log* 199 | yarn-error.log* 200 | 201 | *.venv 202 | 203 | __pycache__/ 204 | *.py[cod] 205 | *$py.class 206 | 207 | models/RealESR* 208 | *.safetensors 209 | 210 | work_dirs/ 211 | tests/ 212 | data/ 213 | 214 | models/Model/ 215 | *.safetensors 216 | *.ckpt 217 | *.pt 218 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "live2diff/MiDaS"] 2 | path = live2diff/MiDaS 3 | url = git@github.com:lewiji/MiDaS.git 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.3.5 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [ --fix ] 9 | # Run the formatter. 10 | - id: ruff-format 11 | - repo: https://github.com/codespell-project/codespell 12 | rev: v2.2.1 13 | hooks: 14 | - id: codespell 15 | args: ["-L", "warmup,mose,parms", "--skip", "*.json"] 16 | - repo: https://github.com/pre-commit/pre-commit-hooks 17 | rev: v4.3.0 18 | hooks: 19 | - id: trailing-whitespace 20 | - id: check-yaml 21 | - id: end-of-file-fixer 22 | - id: requirements-txt-fixer 23 | - id: fix-encoding-pragma 24 | args: ["--remove"] 25 | - id: mixed-line-ending 26 | args: ["--fix=lf"] 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Live2Diff: **Live** Stream Translation via Uni-directional Attention in Video **Diffusion** Models 2 | 3 |


6 | 7 | **Authors:** [Zhening Xing](https://github.com/LeoXing1996), [Gereon Fox](https://people.mpi-inf.mpg.de/~gfox/), [Yanhong Zeng](https://zengyh1900.github.io/), [Xingang Pan](https://xingangpan.github.io/), [Mohamed Elgharib](https://people.mpi-inf.mpg.de/~elgharib/), [Christian Theobalt](https://people.mpi-inf.mpg.de/~theobalt/), [Kai Chen †](https://chenkai.site/) (†: corresponding author) 8 | 9 | 10 | [![arXiv](https://img.shields.io/badge/arXiv-2407.08701-b31b1b.svg)](https://arxiv.org/abs/2407.08701) 11 | [![Project Page](https://img.shields.io/badge/Project-Page-blue)](https://live2diff.github.io/) 12 | 13 | Open in Hugging Face 14 | 15 | [![HuggingFace Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/Leoxing/Live2Diff) 16 | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Oo0KpOboBAO10ZG_nIB55zbDHpt-Pu68?usp=sharing) 17 | 18 | ## Introduction Video 19 | 20 | [![YouTube Introduction Video](https://github.com/user-attachments/assets/548e200e-90c3-4d51-a1d2-3f5ba78cb151)](https://youtu.be/4w2cLRW3RX0) 21 | 22 | ## Release 23 | 24 | * [2024/07/18] We release the [HuggingFace space](https://huggingface.co/spaces/Leoxing/Live2Diff), code, and [checkpoints](https://huggingface.co/Leoxing/Live2Diff). 25 | * [2024/07/22] We release the [Colab Demo](https://colab.research.google.com/drive/1Oo0KpOboBAO10ZG_nIB55zbDHpt-Pu68?usp=sharing). 26 | 27 | ## TODO List 28 | 29 | - [x] Support Colab 30 | 31 | ## Key Features 32 | 33 |


36 | 37 | * **Uni-directional** Temporal Attention with **Warmup** Mechanism 38 | * **Multitimestep KV-Cache** for Temporal Attention during Inference 39 | * **Depth Prior** for Better Structure Consistency 40 | * Compatible with **DreamBooth and LoRA** for Various Styles 41 | * **TensorRT** Supported 42 | 43 | The speed evaluation is conducted on **Ubuntu 20.04.6 LTS** and **PyTorch 2.2.2** with an **RTX 4090 GPU** and an **Intel(R) Xeon(R) Platinum 8352V CPU**. The number of denoising steps is set to 2. 44 | 45 | | Resolution | TensorRT | FPS | 46 | | :--------: | :------: | :-------: | 47 | | 512 x 512 | **On** | **16.43** | 48 | | 512 x 512 | Off | 6.91 | 49 | | 768 x 512 | **On** | **12.15** | 50 | | 768 x 512 | Off | 6.29 | 51 | 52 | ## Installation 53 | 54 | ### Step0: Clone this repository and submodule 55 | 56 | ```bash 57 | git clone https://github.com/open-mmlab/Live2Diff.git 58 | # or via ssh 59 | git clone git@github.com:open-mmlab/Live2Diff.git 60 | 61 | cd Live2Diff 62 | git submodule update --init --recursive 63 | ``` 64 | 65 | ### Step1: Set Up the Environment 66 | 67 | Create a virtual environment via conda: 68 | 69 | ```bash 70 | conda create -n live2diff python=3.10 71 | conda activate live2diff 72 | ``` 73 | 74 | ### Step2: Install PyTorch and xformers 75 | 76 | Select the appropriate version for your system. 77 | 78 | ```bash 79 | # CUDA 11.8 80 | pip install torch torchvision xformers --index-url https://download.pytorch.org/whl/cu118 81 | # CUDA 12.1 82 | pip install torch torchvision xformers --index-url https://download.pytorch.org/whl/cu121 83 | ``` 84 | 85 | Please refer to https://pytorch.org/ for more details. 86 | 87 | ### Step3: Install Project 88 | 89 | If you want to use TensorRT acceleration (we recommend it), you can install it with the following command. 90 | 91 | ```bash 92 | # for cuda 11.x 93 | pip install ."[tensorrt_cu11]" 94 | # for cuda 12.x 95 | pip install ."[tensorrt_cu12]" 96 | ``` 97 | 98 | Otherwise, you can install it via 99 | 100 | ```bash 101 | pip install . 102 | ``` 103 | 104 | If you want to install it in development mode (a.k.a. an "editable install"), you can add the `-e` option. 105 | 106 | ```bash 107 | # for cuda 11.x 108 | pip install -e ."[tensorrt_cu11]" 109 | # for cuda 12.x 110 | pip install -e ."[tensorrt_cu12]" 111 | # or 112 | pip install -e . 113 | ``` 114 | 115 | ### Step4: Download Checkpoints and Demo Data 116 | 117 | 1. Download StableDiffusion-v1-5 118 | 119 | ```bash 120 | huggingface-cli download runwayml/stable-diffusion-v1-5 --local-dir ./models/Model/stable-diffusion-v1-5 121 | ``` 122 | 123 | 2. Download the Live2Diff checkpoint from [HuggingFace](https://huggingface.co/Leoxing/Live2Diff) and put it under the `models` folder. 124 | 125 | 3. Download the depth detector from MiDaS's official [release](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) and put it under the `models` folder. 126 | 127 | 4. Apply for a download token from [civitAI](https://education.civitai.com/civitais-guide-to-downloading-via-api/) and then download DreamBooths and LoRAs via the script: 128 | 129 | ```bash 130 | # download all DreamBooth/Lora 131 | bash scripts/download.sh all YOUR_TOKEN 132 | # or download the one you want to use 133 | bash scripts/download.sh disney YOUR_TOKEN 134 | ``` 135 | 136 | 5. Download the demo data from [OneDrive](https://pjlab-my.sharepoint.cn/:f:/g/personal/xingzhening_pjlab_org_cn/EpefezlxFXNBk93RDttYLMUBP2bofb6AZDfyRIkGapmIrQ?e=A6h2Eb).
137 | 138 | Then the data structure of the `models` folder should look like this: 139 | 140 | ```bash 141 | ./ 142 | |-- models 143 | | |-- LoRA 144 | | | |-- MoXinV1.safetensors 145 | | | `-- ... 146 | | |-- Model 147 | | | |-- 3Guofeng3_v34.safetensors 148 | | | |-- ... 149 | | | `-- stable-diffusion-v1-5 150 | | |-- live2diff.ckpt 151 | | `-- dpt_hybrid_384.pt 152 | `-- data 153 | |-- 1.mp4 154 | |-- 2.mp4 155 | |-- 3.mp4 156 | `-- 4.mp4 157 | ``` 158 | 159 | ### Note 160 | 161 | The above installation steps (e.g. the [download script](#step4-download-checkpoints-and-demo-data)) are intended for Linux users and are not well tested on Windows. If you face any difficulties, please feel free to open an issue 🤗. 162 | 163 | ## Quick Start 164 | 165 | You can try the examples under the [`data`](./data) directory. For example, 166 | ```bash 167 | # with TensorRT acceleration; please be patient the first time, it may take more than 20 minutes 168 | python test.py ./data/1.mp4 ./configs/disneyPixar.yaml --max-frames -1 --prompt "1man is talking" --output work_dirs/1-disneyPixar.mp4 --height 512 --width 512 --acceleration tensorrt 169 | 170 | # without TensorRT acceleration 171 | python test.py ./data/2.mp4 ./configs/disneyPixar.yaml --max-frames -1 --prompt "1man is talking" --output work_dirs/1-disneyPixar.mp4 --height 512 --width 512 --acceleration none 172 | ``` 173 | 174 | You can adjust the denoising strength via `--num-inference-steps`, `--strength`, and `--t-index-list`. Please refer to `test.py` for more details. 175 | 176 | ## Troubleshooting 177 | 178 | 1. If you hit a CUDA out-of-memory error with TensorRT, please try to reduce `t-index-list` or `strength`. When running inference with TensorRT, we maintain a group of buffers for the kv-cache, which consumes more memory. Reducing `t-index-list` or `strength` shrinks the kv-cache and saves GPU memory (see the sketch below for how these parameters relate). 179 | 180 | ## Real-Time Video2Video Demo 181 | 182 | There is an interactive vid2vid demo in the [`demo`](./demo) directory! 183 | 184 | Please refer to [`demo/README.md`](./demo/README.md) for more details. 185 | 186 |
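The exact relationship between `--num-inference-steps`, `--strength`, and `--t-index-list` is defined by the project code; the sketch below only illustrates one plausible mapping, inferred from the values shipped in `configs/`. The helper name and formula are assumptions for illustration, not the repository's implementation:

```python
# Hypothetical helper (not part of the repository): derive a t_index_list from
# num_inference_steps, strength, and the number of denoising steps. The outputs
# happen to reproduce the shipped configs, but the real mapping in test.py may differ.
def strength_to_t_index_list(num_inference_steps: int, strength: float, num_denoising_steps: int) -> list[int]:
    # A higher strength starts denoising earlier in the schedule, so more of each frame is re-drawn.
    start = int(num_inference_steps * (1.0 - strength))
    spacing = (num_inference_steps - start) // num_denoising_steps
    return [start + i * spacing for i in range(num_denoising_steps)]


if __name__ == "__main__":
    print(strength_to_t_index_list(50, 0.4, 3))  # [30, 36, 42] -- same values as disneyPixar.yaml
    print(strength_to_t_index_list(50, 0.5, 4))  # [25, 31, 37, 43] -- same values as toonyou.yaml
```

Either way, fewer denoising indices (or a lower strength) means a smaller kv-cache, which is why the troubleshooting tip above suggests reducing them when you run out of GPU memory.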
187 | [Demo video table: Human Face (Web Camera Input) | Anime Character (Screen Video Input)]
209 | 210 | ## Acknowledgements 211 | 212 | The video and image demos in this GitHub repository were generated using [LCM-LoRA](https://huggingface.co/latent-consistency/lcm-lora-sdv1-5). The stream batch from [StreamDiffusion](https://github.com/cumulo-autumn/StreamDiffusion) is used for model acceleration. The design of the video diffusion model is adapted from [AnimateDiff](https://github.com/guoyww/AnimateDiff). We use a third-party implementation of [MiDaS](https://github.com/lewiji/MiDaS) that supports ONNX export. Our online demo is modified from [Real-Time-Latent-Consistency-Model](https://github.com/radames/Real-Time-Latent-Consistency-Model/). 213 | 214 | ## BibTeX 215 | 216 | If you find it helpful, please consider citing our work: 217 | 218 | ```bibtex 219 | @article{xing2024live2diff, 220 | title={Live2Diff: Live Stream Translation via Uni-directional Attention in Video Diffusion Models}, 221 | author={Zhening Xing and Gereon Fox and Yanhong Zeng and Xingang Pan and Mohamed Elgharib and Christian Theobalt and Kai Chen}, 222 | journal={arXiv preprint arXiv:2407.08701}, 223 | year={2024} 224 | } 225 | ``` 226 | -------------------------------------------------------------------------------- /assets/attn-mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/Live2Diff/e40aa03a9a9d17f1232fd7a4566b3ee793e6893f/assets/attn-mask.png -------------------------------------------------------------------------------- /assets/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/Live2Diff/e40aa03a9a9d17f1232fd7a4566b3ee793e6893f/assets/framework.jpg -------------------------------------------------------------------------------- /configs/base_config.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: "./models/Model/stable-diffusion-v1-5" 2 | 3 | motion_module_path: './models/live2diff.ckpt' 4 | depth_model_path: './models/dpt_hybrid_384.pt' 5 | 6 | unet_additional_kwargs: 7 | cond_mapping: true 8 | use_inflated_groupnorm: true 9 | use_motion_module : true 10 | motion_module_resolutions : [ 1,2,4,8 ] 11 | unet_use_cross_frame_attention : false 12 | unet_use_temporal_attention : false 13 | 14 | motion_module_type: Streaming 15 | motion_module_kwargs: 16 | num_attention_heads : 8 17 | num_transformer_block : 1 18 | attention_block_types : [ "Temporal_Self", "Temporal_Self" ] 19 | temporal_position_encoding : true 20 | temporal_position_encoding_max_len : 24 21 | temporal_attention_dim_div : 1 22 | zero_initialize : true 23 | 24 | attention_class_name : 'stream' 25 | 26 | attention_kwargs: 27 | window_size: 16 28 | sink_size: 8 29 | 30 | noise_scheduler_kwargs: 31 | num_train_timesteps: 1000 32 | beta_start: 0.00085 33 | beta_end: 0.012 34 | beta_schedule: "linear" 35 | steps_offset: 1 36 | clip_sample: False 37 | -------------------------------------------------------------------------------- /configs/disneyPixar.yaml: -------------------------------------------------------------------------------- 1 | # good s0.4 2 | base: "./configs/base_config.yaml" 3 | 4 | prompt_template: "masterpiece, best quality, intricate, print, pattern, {}" 5 | 6 | third_party_dict: 7 | dreambooth: "./models/Model/disneyPixarCartoon_v10.safetensors" 8 | clip_skip: 2 9 | 10 | num_inference_steps: 50 11 | t_index_list: [30, 36, 42] 12 |
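The `attention_kwargs` entry in `configs/base_config.yaml` above (`window_size: 16`, `sink_size: 8`) is where the uni-directional temporal attention with warmup frames and the multi-timestep kv-cache from the Key Features list are configured. The snippet below is only a rough sketch of the kind of attention mask those two numbers describe; the function and the exact masking rule are assumptions for illustration, while the real implementation lives in `live2diff/animatediff/models/stream_motion_module.py` and may differ:

```python
import torch


def unidirectional_window_mask(num_frames: int, window_size: int = 16, sink_size: int = 8) -> torch.Tensor:
    """Illustrative boolean mask: mask[i, j] is True when frame i may attend to frame j.

    Assumption: every frame can attend to the first `sink_size` (warmup) frames plus a
    sliding window of the most recent `window_size` frames up to itself, and never to
    future frames (uni-directional attention).
    """
    idx = torch.arange(num_frames)
    i, j = idx[:, None], idx[None, :]
    causal = j <= i                    # no attention to future frames
    in_window = j > (i - window_size)  # recent frames within the sliding window
    is_sink = j < sink_size            # warmup ("sink") frames stay visible to every later frame
    return causal & (in_window | is_sink)


if __name__ == "__main__":
    print(unidirectional_window_mask(24, window_size=16, sink_size=8).int())
```

During streaming inference, the Key Features list indicates this is realized with a multi-timestep kv-cache rather than an explicit full mask, which is also why reducing the number of denoising steps reduces memory, as noted in the Troubleshooting section.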
-------------------------------------------------------------------------------- /configs/kFelted.yaml: -------------------------------------------------------------------------------- 1 | # good, s06-4 and s05-4 2 | base: "./configs/base_config.yaml" 3 | 4 | prompt_template: "masterpiece, best quality, felted, {}," 5 | 6 | third_party_dict: 7 | dreambooth: "./models/Model/revAnimated_v2RebirthVAE.safetensors" 8 | lora_list: 9 | - lora: './models/LoRA/kFeltedReV.safetensors' 10 | lora_alpha: 1 11 | clip_skip: 2 12 | 13 | num_inference_steps: 50 14 | t_index_list: [25, 34, 43] 15 | # or 16 | # num_inference_steps: 50 17 | # t_index_list: [20, 27, 34, 41] 18 | -------------------------------------------------------------------------------- /configs/moxin.yaml: -------------------------------------------------------------------------------- 1 | base: "./configs/base_config.yaml" 2 | 3 | prompt_template: 'shukezouma, negative space, shuimobysim, official art,extremely detailed CG,unity 8k wallpaper,chinese ink painting, {}' 4 | 5 | third_party_dict: 6 | lora_list: 7 | - lora: "./models/LoRA/MoXinV1.safetensors" 8 | lora_alpha: 0.7 9 | dreambooth: "./models/Model/3Guofeng3_v34.safetensors" 10 | clip_skip: 2 11 | 12 | num_inference_steps: 50 13 | t_index_list: [30, 36, 42] 14 | -------------------------------------------------------------------------------- /configs/origami.yaml: -------------------------------------------------------------------------------- 1 | # good, s04!!!! 2 | 3 | base: "./configs/base_config.yaml" 4 | 5 | prompt_template: "(masterpiece),best quality, a origami paper of {}" 6 | 7 | third_party_dict: 8 | dreambooth: "./models/Model/helloartdoor_V122p.safetensors" 9 | lora_list: 10 | - lora: "./models/LoRA/ral-origami-sd15.safetensors" 11 | lora_alpha: 1 12 | clip_skip: 2 13 | 14 | num_inference_steps: 50 15 | t_index_list: [30, 36, 42] 16 | -------------------------------------------------------------------------------- /configs/pixart.yaml: -------------------------------------------------------------------------------- 1 | # good, 0.4 & 0.3 2 | 3 | base: "./configs/base_config.yaml" 4 | 5 | prompt_template: "(masterpiece), best quality, {}" 6 | 7 | third_party_dict: 8 | dreambooth: "./models/Model/aziibpixelmix_v10.safetensors" 9 | clip_skip: 2 10 | 11 | num_inference_steps: 4 12 | strength: 0.6 13 | -------------------------------------------------------------------------------- /configs/toonyou.yaml: -------------------------------------------------------------------------------- 1 | base: "./configs/base_config.yaml" 2 | 3 | prompt: "masterpiece, best quality, intricate, print, pattern, {}" 4 | 5 | third_party_dict: 6 | dreambooth: "./models/Model/toonyou_beta6.safetensors" 7 | clip_skip: 2 8 | 9 | num_inference_steps: 50 10 | t_index_list: [25, 31, 37, 43] 11 | -------------------------------------------------------------------------------- /demo/.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model 
filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | -------------------------------------------------------------------------------- /demo/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | venv/ 3 | public/ 4 | *.pem 5 | !lib/ 6 | !static/ 7 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | # Video2Video Example 2 | 3 |
4 | [Demo video table: Human Face (Web Camera Input) | Anime Character (Screen Video Input)]
26 | 27 | This example, based on this [MJPEG server](https://github.com/radames/Real-Time-Latent-Consistency-Model/), runs image-to-image with a live webcam feed or screen capture in a web browser. 28 | 29 | ## Usage 30 | 31 | ### 1. Prepare Dependencies 32 | 33 | You need Node.js 18+ and Python 3.10 to run this example. Please make sure you've installed all dependencies according to the [installation instructions](../README.md#installation). 34 | 35 | ```bash 36 | cd frontend 37 | npm i 38 | npm run build 39 | cd .. 40 | pip install -r requirements.txt 41 | ``` 42 | 43 | If you have difficulties installing `npm`, you can try installing it via `conda`: 44 | 45 | ```bash 46 | conda install -c conda-forge nodejs 47 | ``` 48 | 49 | ### 2. Run Demo 50 | 51 | If you run the demo with the default [config](./demo_cfg.yaml), you need to download the model for the `felted` style first: 52 | 53 | ```bash 54 | bash ../scripts/download.sh felted YOUR_TOKEN 55 | ``` 56 | 57 | Then, you can run the demo with the following command, and open `http://127.0.0.1:7860` in your browser: 58 | 59 | ```bash 60 | # with TensorRT acceleration; please be patient the first time, it may take more than 20 minutes 61 | python main.py --port 7860 --host 127.0.0.1 --acceleration tensorrt 62 | # if you don't have TensorRT, you can run it with `none` acceleration 63 | python main.py --port 7860 --host 127.0.0.1 --acceleration none 64 | ``` 65 | 66 | If you want to run this demo on a remote server, you can set the host to `0.0.0.0`, e.g. 67 | 68 | ```bash 69 | python main.py --port 7860 --host 0.0.0.0 --acceleration tensorrt 70 | ``` 71 | -------------------------------------------------------------------------------- /demo/app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import mimetypes 3 | import os 4 | import time 5 | import uuid 6 | from types import SimpleNamespace 7 | 8 | import markdown2 9 | import torch 10 | from config import Args, config 11 | from connection_manager import ConnectionManager, ServerFullException 12 | from fastapi import FastAPI, HTTPException, Request, WebSocket 13 | from fastapi.middleware.cors import CORSMiddleware 14 | from fastapi.responses import JSONResponse, StreamingResponse 15 | from fastapi.staticfiles import StaticFiles 16 | from util import bytes_to_pil, pil_to_frame 17 | from vid2vid import Pipeline 18 | 19 | 20 | # fix mime error on windows 21 | mimetypes.add_type("application/javascript", ".js") 22 | 23 | THROTTLE = 1.0 / 120 24 | # logging.basicConfig(level=logging.DEBUG) 25 | 26 | 27 | class App: 28 | def __init__(self, config: Args): 29 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | torch_dtype = torch.float16 31 | pipeline = Pipeline(config, device, torch_dtype) 32 | self.args = config 33 | self.pipeline = pipeline 34 | self.app = FastAPI() 35 | self.conn_manager = ConnectionManager() 36 | self.init_app() 37 | 38 | def init_app(self): 39 | self.app.add_middleware( 40 | CORSMiddleware, 41 | allow_origins=["*"], 42 | allow_credentials=True, 43 | allow_methods=["*"], 44 | allow_headers=["*"], 45 | ) 46 | 47 | @self.app.websocket("/api/ws/{user_id}") 48 | async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket): 49 | try: 50 | await self.conn_manager.connect(user_id, websocket, self.args.max_queue_size) 51 | await handle_websocket_data(user_id) 52 | except ServerFullException as e: 53 | logging.error(f"Server Full: {e}") 54 | finally: 55 | await self.conn_manager.disconnect(user_id) 56 |
logging.info(f"User disconnected: {user_id}") 57 | 58 | async def handle_websocket_data(user_id: uuid.UUID): 59 | if not self.conn_manager.check_user(user_id): 60 | return HTTPException(status_code=404, detail="User not found") 61 | last_time = time.time() 62 | try: 63 | while True: 64 | if self.args.timeout > 0 and time.time() - last_time > self.args.timeout: 65 | await self.conn_manager.send_json( 66 | user_id, 67 | { 68 | "status": "timeout", 69 | "message": "Your session has ended", 70 | }, 71 | ) 72 | await self.conn_manager.disconnect(user_id) 73 | return 74 | data = await self.conn_manager.receive_json(user_id) 75 | if data["status"] == "next_frame": 76 | info = self.pipeline.Info() 77 | params = await self.conn_manager.receive_json(user_id) 78 | params = self.pipeline.InputParams(**params) 79 | params = SimpleNamespace(**params.model_dump()) 80 | if info.input_mode == "image": 81 | image_data = await self.conn_manager.receive_bytes(user_id) 82 | if len(image_data) == 0: 83 | await self.conn_manager.send_json(user_id, {"status": "send_frame"}) 84 | continue 85 | params.image = bytes_to_pil(image_data) 86 | await self.conn_manager.update_data(user_id, params) 87 | 88 | except Exception as e: 89 | logging.error(f"Websocket Error: {e}, {user_id} ") 90 | await self.conn_manager.disconnect(user_id) 91 | 92 | @self.app.get("/api/queue") 93 | async def get_queue_size(): 94 | queue_size = self.conn_manager.get_user_count() 95 | return JSONResponse({"queue_size": queue_size}) 96 | 97 | @self.app.get("/api/stream/{user_id}") 98 | async def stream(user_id: uuid.UUID, request: Request): 99 | try: 100 | 101 | async def generate(): 102 | while True: 103 | last_time = time.time() 104 | await self.conn_manager.send_json(user_id, {"status": "send_frame"}) 105 | params = await self.conn_manager.get_latest_data(user_id) 106 | if params is None: 107 | continue 108 | image = self.pipeline.predict(params) 109 | if image is None: 110 | continue 111 | frame = pil_to_frame(image) 112 | yield frame 113 | if self.args.debug: 114 | print(f"Time taken: {time.time() - last_time}") 115 | 116 | return StreamingResponse( 117 | generate(), 118 | media_type="multipart/x-mixed-replace;boundary=frame", 119 | headers={"Cache-Control": "no-cache"}, 120 | ) 121 | except Exception as e: 122 | logging.error(f"Streaming Error: {e}, {user_id} ") 123 | return HTTPException(status_code=404, detail="User not found") 124 | 125 | # route to setup frontend 126 | @self.app.get("/api/settings") 127 | async def settings(): 128 | info_schema = self.pipeline.Info.model_json_schema() 129 | info = self.pipeline.Info() 130 | if info.page_content: 131 | page_content = markdown2.markdown(info.page_content) 132 | 133 | input_params = self.pipeline.InputParams.model_json_schema() 134 | return JSONResponse( 135 | { 136 | "info": info_schema, 137 | "input_params": input_params, 138 | "max_queue_size": self.args.max_queue_size, 139 | "page_content": page_content if info.page_content else "", 140 | } 141 | ) 142 | 143 | if not os.path.exists("public"): 144 | os.makedirs("public") 145 | 146 | self.app.mount("/", StaticFiles(directory="./frontend/public", html=True), name="public") 147 | 148 | 149 | app = App(config).app 150 | -------------------------------------------------------------------------------- /demo/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import List, NamedTuple 4 | 5 | 6 | class Args(NamedTuple): 7 | host: str 8 | port: int 9 | reload: 
bool 10 | max_queue_size: int 11 | timeout: float 12 | safety_checker: bool 13 | taesd: bool 14 | ssl_certfile: str 15 | ssl_keyfile: str 16 | debug: bool 17 | acceleration: str 18 | engine_dir: str 19 | config: str 20 | seed: int 21 | num_inference_steps: int 22 | strength: float 23 | t_index_list: List[int] 24 | prompt: str 25 | 26 | def pretty_print(self): 27 | print("\n") 28 | for field, value in self._asdict().items(): 29 | print(f"{field}: {value}") 30 | print("\n") 31 | 32 | 33 | MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0)) 34 | TIMEOUT = float(os.environ.get("TIMEOUT", 0)) 35 | SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None) == "True" 36 | USE_TAESD = os.environ.get("USE_TAESD", "True") == "True" 37 | ENGINE_DIR = os.environ.get("ENGINE_DIR", "engines") 38 | ACCELERATION = os.environ.get("ACCELERATION", "tensorrt") 39 | 40 | default_host = os.getenv("HOST", "0.0.0.0") 41 | default_port = int(os.getenv("PORT", "7860")) 42 | default_mode = os.getenv("MODE", "default") 43 | 44 | parser = argparse.ArgumentParser(description="Run the app") 45 | parser.add_argument("--host", type=str, default=default_host, help="Host address") 46 | parser.add_argument("--port", type=int, default=default_port, help="Port number") 47 | parser.add_argument("--reload", action="store_true", help="Reload code on change") 48 | parser.add_argument( 49 | "--max-queue-size", 50 | dest="max_queue_size", 51 | type=int, 52 | default=MAX_QUEUE_SIZE, 53 | help="Max Queue Size", 54 | ) 55 | parser.add_argument("--timeout", type=float, default=TIMEOUT, help="Timeout") 56 | parser.add_argument( 57 | "--safety-checker", 58 | dest="safety_checker", 59 | action="store_true", 60 | default=SAFETY_CHECKER, 61 | help="Safety Checker", 62 | ) 63 | parser.add_argument( 64 | "--taesd", 65 | dest="taesd", 66 | action="store_true", 67 | help="Use Tiny Autoencoder", 68 | ) 69 | parser.add_argument( 70 | "--no-taesd", 71 | dest="taesd", 72 | action="store_false", 73 | help="Use Tiny Autoencoder", 74 | ) 75 | parser.add_argument( 76 | "--ssl-certfile", 77 | dest="ssl_certfile", 78 | type=str, 79 | default=None, 80 | help="SSL certfile", 81 | ) 82 | parser.add_argument( 83 | "--ssl-keyfile", 84 | dest="ssl_keyfile", 85 | type=str, 86 | default=None, 87 | help="SSL keyfile", 88 | ) 89 | parser.add_argument( 90 | "--debug", 91 | action="store_true", 92 | default=False, 93 | help="Debug", 94 | ) 95 | parser.add_argument( 96 | "--acceleration", 97 | type=str, 98 | default=ACCELERATION, 99 | choices=["none", "xformers", "tensorrt"], 100 | help="Acceleration", 101 | ) 102 | parser.add_argument( 103 | "--engine-dir", 104 | dest="engine_dir", 105 | type=str, 106 | default=ENGINE_DIR, 107 | help="Engine Dir", 108 | ) 109 | parser.add_argument( 110 | "--config", 111 | default="./demo_cfg.yaml", 112 | ) 113 | parser.add_argument("--num-inference-steps", type=int, default=None) 114 | parser.add_argument("--strength", type=float, default=None) 115 | parser.add_argument("--t-index-list", type=list) 116 | parser.add_argument("--seed", default=42) 117 | parser.add_argument("--prompt", type=str) 118 | 119 | parser.set_defaults(taesd=USE_TAESD) 120 | config = Args(**vars(parser.parse_args())) 121 | config.pretty_print() 122 | -------------------------------------------------------------------------------- /demo/connection_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from types import SimpleNamespace 4 | from typing import Dict, Union 5 | from 
uuid import UUID 6 | 7 | from fastapi import WebSocket 8 | from starlette.websockets import WebSocketState 9 | 10 | 11 | Connections = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]] 12 | 13 | 14 | class ServerFullException(Exception): 15 | """Exception raised when the server is full.""" 16 | 17 | pass 18 | 19 | 20 | class ConnectionManager: 21 | def __init__(self): 22 | self.active_connections: Connections = {} 23 | 24 | async def connect(self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0): 25 | await websocket.accept() 26 | user_count = self.get_user_count() 27 | print(f"User count: {user_count}") 28 | if max_queue_size > 0 and user_count >= max_queue_size: 29 | print("Server is full") 30 | await websocket.send_json({"status": "error", "message": "Server is full"}) 31 | await websocket.close() 32 | raise ServerFullException("Server is full") 33 | print(f"New user connected: {user_id}") 34 | self.active_connections[user_id] = { 35 | "websocket": websocket, 36 | "queue": asyncio.Queue(), 37 | } 38 | await websocket.send_json( 39 | {"status": "connected", "message": "Connected"}, 40 | ) 41 | await websocket.send_json({"status": "wait"}) 42 | await websocket.send_json({"status": "send_frame"}) 43 | 44 | def check_user(self, user_id: UUID) -> bool: 45 | return user_id in self.active_connections 46 | 47 | async def update_data(self, user_id: UUID, new_data: SimpleNamespace): 48 | user_session = self.active_connections.get(user_id) 49 | if user_session: 50 | queue = user_session["queue"] 51 | await queue.put(new_data) 52 | 53 | async def get_latest_data(self, user_id: UUID) -> SimpleNamespace: 54 | user_session = self.active_connections.get(user_id) 55 | if user_session: 56 | queue = user_session["queue"] 57 | try: 58 | return await queue.get() 59 | except asyncio.QueueEmpty: 60 | return None 61 | 62 | def delete_user(self, user_id: UUID): 63 | user_session = self.active_connections.pop(user_id, None) 64 | if user_session: 65 | queue = user_session["queue"] 66 | while not queue.empty(): 67 | try: 68 | queue.get_nowait() 69 | except asyncio.QueueEmpty: 70 | continue 71 | 72 | def get_user_count(self) -> int: 73 | return len(self.active_connections) 74 | 75 | def get_websocket(self, user_id: UUID) -> WebSocket: 76 | user_session = self.active_connections.get(user_id) 77 | if user_session: 78 | websocket = user_session["websocket"] 79 | if websocket.client_state == WebSocketState.CONNECTED: 80 | return user_session["websocket"] 81 | return None 82 | 83 | async def disconnect(self, user_id: UUID): 84 | websocket = self.get_websocket(user_id) 85 | if websocket: 86 | await websocket.close() 87 | self.delete_user(user_id) 88 | 89 | async def send_json(self, user_id: UUID, data: Dict): 90 | try: 91 | websocket = self.get_websocket(user_id) 92 | if websocket: 93 | await websocket.send_json(data) 94 | except Exception as e: 95 | logging.error(f"Error: Send json: {e}") 96 | 97 | async def receive_json(self, user_id: UUID) -> Dict: 98 | try: 99 | websocket = self.get_websocket(user_id) 100 | if websocket: 101 | return await websocket.receive_json() 102 | except Exception as e: 103 | logging.error(f"Error: Receive json: {e}") 104 | 105 | async def receive_bytes(self, user_id: UUID) -> bytes: 106 | try: 107 | websocket = self.get_websocket(user_id) 108 | if websocket: 109 | return await websocket.receive_bytes() 110 | except Exception as e: 111 | logging.error(f"Error: Receive bytes: {e}") 112 | -------------------------------------------------------------------------------- 
/demo/demo_cfg.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: "../models/Model/stable-diffusion-v1-5" 2 | 3 | motion_module_path: '../models/live2diff.ckpt' 4 | depth_model_path: '../models/dpt_hybrid_384.pt' 5 | 6 | unet_additional_kwargs: 7 | cond_mapping: true 8 | use_inflated_groupnorm: true 9 | use_motion_module : true 10 | motion_module_resolutions : [ 1,2,4,8 ] 11 | unet_use_cross_frame_attention : false 12 | unet_use_temporal_attention : false 13 | 14 | motion_module_type: Streaming 15 | motion_module_kwargs: 16 | num_attention_heads : 8 17 | num_transformer_block : 1 18 | attention_block_types : [ "Temporal_Self", "Temporal_Self" ] 19 | temporal_position_encoding : true 20 | temporal_position_encoding_max_len : 24 21 | temporal_attention_dim_div : 1 22 | zero_initialize : true 23 | 24 | attention_class_name : 'stream' 25 | 26 | attention_kwargs: 27 | window_size: 16 28 | sink_size: 8 29 | 30 | noise_scheduler_kwargs: 31 | num_train_timesteps: 1000 32 | beta_start: 0.00085 33 | beta_end: 0.012 34 | beta_schedule: "linear" 35 | steps_offset: 1 36 | clip_sample: False 37 | 38 | third_party_dict: 39 | dreambooth: "../models/Model/revAnimated_v2RebirthVAE.safetensors" 40 | lora_list: 41 | - lora: '../models/LoRA/kFeltedReV.safetensors' 42 | lora_alpha: 1 43 | clip_skip: 2 44 | 45 | num_inference_steps: 50 46 | t_index_list: [30, 40] 47 | prompt: "masterpiece, best quality, felted, 1man with glasses, glasses, play with his pen" 48 | -------------------------------------------------------------------------------- /demo/demo_cfg_arknight.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: "../models/Model/stable-diffusion-v1-5" 2 | 3 | motion_module_path: '../models/live2diff.ckpt' 4 | depth_model_path: '../models/dpt_hybrid_384.pt' 5 | 6 | unet_additional_kwargs: 7 | cond_mapping: true 8 | use_inflated_groupnorm: true 9 | use_motion_module : true 10 | motion_module_resolutions : [ 1,2,4,8 ] 11 | unet_use_cross_frame_attention : false 12 | unet_use_temporal_attention : false 13 | 14 | motion_module_type: Streaming 15 | motion_module_kwargs: 16 | num_attention_heads : 8 17 | num_transformer_block : 1 18 | attention_block_types : [ "Temporal_Self", "Temporal_Self" ] 19 | temporal_position_encoding : true 20 | temporal_position_encoding_max_len : 24 21 | temporal_attention_dim_div : 1 22 | zero_initialize : true 23 | 24 | attention_class_name : 'stream' 25 | 26 | attention_kwargs: 27 | window_size: 16 28 | sink_size: 8 29 | 30 | noise_scheduler_kwargs: 31 | num_train_timesteps: 1000 32 | beta_start: 0.00085 33 | beta_end: 0.012 34 | beta_schedule: "linear" 35 | steps_offset: 1 36 | clip_sample: False 37 | 38 | third_party_dict: 39 | dreambooth: "../models/Model/aziibpixelmix_v10.safetensors" 40 | lora_list: 41 | - lora: '../models/LoRA/kFeltedReV.safetensors' 42 | lora_alpha: 1 43 | clip_skip: 2 44 | 45 | num_inference_steps: 50 46 | t_index_list: [35, 40, 45] 47 | 48 | prompt: "masterpiece, best quality, 1gril" 49 | -------------------------------------------------------------------------------- /demo/frontend/.eslintignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | node_modules 3 | /build 4 | /.svelte-kit 5 | /package 6 | .env 7 | .env.* 8 | !.env.example 9 | 10 | # Ignore files for PNPM, NPM and YARN 11 | pnpm-lock.yaml 12 | package-lock.json 13 | yarn.lock 14 | 
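Coming back to the configuration files above: the style configs under `configs/` reference `base_config.yaml` through a `base:` key, while `demo_cfg.yaml` and `demo_cfg_arknight.yaml` simply inline all of the base fields. The snippet below is a minimal, assumption-laden sketch of how such a `base:`-plus-override layout can be resolved with plain PyYAML; the project's actual loader lives in `live2diff/utils/config.py` and is not reproduced here:

```python
import yaml  # PyYAML


def load_config(path: str) -> dict:
    """Illustrative loader: read a YAML config and, if it declares a `base:` key,
    recursively load that base config and apply the current file as overrides."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    base_path = cfg.pop("base", None)
    if base_path is None:
        return cfg
    merged = load_config(base_path)  # paths like "./configs/base_config.yaml" assume the repo root as cwd

    def merge(dst: dict, src: dict) -> dict:
        for key, value in src.items():
            if isinstance(value, dict) and isinstance(dst.get(key), dict):
                merge(dst[key], value)
            else:
                dst[key] = value
        return dst

    return merge(merged, cfg)


# e.g. load_config("./configs/disneyPixar.yaml") would yield the base settings plus the
# DreamBooth path, prompt template, and t_index_list overrides from disneyPixar.yaml.
```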
-------------------------------------------------------------------------------- /demo/frontend/.eslintrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | root: true, 3 | extends: [ 4 | 'eslint:recommended', 5 | 'plugin:@typescript-eslint/recommended', 6 | 'plugin:svelte/recommended', 7 | 'prettier' 8 | ], 9 | parser: '@typescript-eslint/parser', 10 | plugins: ['@typescript-eslint'], 11 | parserOptions: { 12 | sourceType: 'module', 13 | ecmaVersion: 2020, 14 | extraFileExtensions: ['.svelte'] 15 | }, 16 | env: { 17 | browser: true, 18 | es2017: true, 19 | node: true 20 | }, 21 | overrides: [ 22 | { 23 | files: ['*.svelte'], 24 | parser: 'svelte-eslint-parser', 25 | parserOptions: { 26 | parser: '@typescript-eslint/parser' 27 | } 28 | } 29 | ] 30 | }; 31 | -------------------------------------------------------------------------------- /demo/frontend/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | node_modules 3 | /build 4 | /.svelte-kit 5 | /package 6 | .env 7 | .env.* 8 | !.env.example 9 | vite.config.js.timestamp-* 10 | vite.config.ts.timestamp-* 11 | -------------------------------------------------------------------------------- /demo/frontend/.npmrc: -------------------------------------------------------------------------------- 1 | engine-strict=true 2 | -------------------------------------------------------------------------------- /demo/frontend/.prettierignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | node_modules 3 | /build 4 | /.svelte-kit 5 | /package 6 | .env 7 | .env.* 8 | !.env.example 9 | 10 | # Ignore files for PNPM, NPM and YARN 11 | pnpm-lock.yaml 12 | package-lock.json 13 | yarn.lock 14 | -------------------------------------------------------------------------------- /demo/frontend/.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "useTabs": false, 3 | "singleQuote": true, 4 | "trailingComma": "none", 5 | "printWidth": 100, 6 | "plugins": [ 7 | "prettier-plugin-svelte", 8 | "prettier-plugin-organize-imports", 9 | "prettier-plugin-tailwindcss" 10 | ], 11 | "overrides": [ 12 | { 13 | "files": "*.svelte", 14 | "options": { 15 | "parser": "svelte" 16 | } 17 | } 18 | ] 19 | } 20 | -------------------------------------------------------------------------------- /demo/frontend/README.md: -------------------------------------------------------------------------------- 1 | # create-svelte 2 | 3 | Everything you need to build a Svelte project, powered by [`create-svelte`](https://github.com/sveltejs/kit/tree/master/packages/create-svelte). 4 | 5 | ## Creating a project 6 | 7 | If you're seeing this, you've probably already done this step. Congrats! 8 | 9 | ```bash 10 | # create a new project in the current directory 11 | npm create svelte@latest 12 | 13 | # create a new project in my-app 14 | npm create svelte@latest my-app 15 | ``` 16 | 17 | ## Developing 18 | 19 | Once you've created a project and installed dependencies with `npm install` (or `pnpm install` or `yarn`), start a development server: 20 | 21 | ```bash 22 | npm run dev 23 | 24 | # or start the server and open the app in a new browser tab 25 | npm run dev -- --open 26 | ``` 27 | 28 | ## Building 29 | 30 | To create a production version of your app: 31 | 32 | ```bash 33 | npm run build 34 | ``` 35 | 36 | You can preview the production build with `npm run preview`. 
37 | 38 | > To deploy your app, you may need to install an [adapter](https://kit.svelte.dev/docs/adapters) for your target environment. 39 | -------------------------------------------------------------------------------- /demo/frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "frontend", 3 | "version": "0.0.1", 4 | "private": true, 5 | "scripts": { 6 | "dev": "vite dev", 7 | "build": "vite build", 8 | "preview": "vite preview", 9 | "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json", 10 | "check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch", 11 | "lint": "prettier --check . && eslint .", 12 | "format": "prettier --write ." 13 | }, 14 | "devDependencies": { 15 | "@sveltejs/adapter-auto": "^2.0.0", 16 | "@sveltejs/adapter-static": "^2.0.3", 17 | "@sveltejs/kit": "^1.20.4", 18 | "@typescript-eslint/eslint-plugin": "^6.0.0", 19 | "@typescript-eslint/parser": "^6.0.0", 20 | "autoprefixer": "^10.4.16", 21 | "eslint": "^8.28.0", 22 | "eslint-config-prettier": "^9.0.0", 23 | "eslint-plugin-svelte": "^2.30.0", 24 | "postcss": "^8.4.31", 25 | "prettier": "^3.1.0", 26 | "prettier-plugin-organize-imports": "^3.2.4", 27 | "prettier-plugin-svelte": "^3.1.0", 28 | "prettier-plugin-tailwindcss": "^0.5.7", 29 | "svelte": "^4.0.5", 30 | "svelte-check": "^3.4.3", 31 | "tailwindcss": "^3.3.5", 32 | "tslib": "^2.4.1", 33 | "typescript": "^5.0.0", 34 | "vite": "^4.4.2" 35 | }, 36 | "type": "module", 37 | "dependencies": { 38 | "piexifjs": "^1.0.6", 39 | "rvfc-polyfill": "^1.0.7" 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /demo/frontend/postcss.config.js: -------------------------------------------------------------------------------- 1 | export default { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {} 5 | } 6 | }; 7 | -------------------------------------------------------------------------------- /demo/frontend/src/app.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | -------------------------------------------------------------------------------- /demo/frontend/src/app.d.ts: -------------------------------------------------------------------------------- 1 | // See https://kit.svelte.dev/docs/types#app 2 | // for information about these interfaces 3 | declare global { 4 | namespace App { 5 | // interface Error {} 6 | // interface Locals {} 7 | // interface PageData {} 8 | // interface Platform {} 9 | } 10 | } 11 | 12 | export {}; 13 | -------------------------------------------------------------------------------- /demo/frontend/src/app.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | %sveltekit.head% 8 | 9 | 10 |
%sveltekit.body%
11 | 12 | 13 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/components/Button.svelte: -------------------------------------------------------------------------------- 1 | 6 | 7 | 10 | 11 | 16 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/components/Checkbox.svelte: -------------------------------------------------------------------------------- 1 | 11 | 12 |
13 | 14 | 15 |
16 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/components/ImagePlayer.svelte: -------------------------------------------------------------------------------- 1 | 23 | 24 |
27 | 28 | {#if isLCMRunning && $streamId} 29 | 34 |
35 | 43 |
44 | {:else} 45 | 49 | {/if} 50 |
51 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/components/InputRange.svelte: -------------------------------------------------------------------------------- 1 | 11 | 12 |
13 | 14 | 24 | 31 |
32 | 53 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/components/MediaListSwitcher.svelte: -------------------------------------------------------------------------------- 1 | 17 | 18 |
19 | 28 | {#if $mediaDevices} 29 | 39 | {/if} 40 |
41 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/components/PipelineOptions.svelte: -------------------------------------------------------------------------------- 1 | 19 | 20 |
21 |
22 | {#if featuredOptions} 23 | {#each featuredOptions as params} 24 | {#if params.field === FieldType.RANGE} 25 | 26 | {:else if params.field === FieldType.SEED} 27 | 28 | {:else if params.field === FieldType.TEXTAREA} 29 | 30 | {:else if params.field === FieldType.CHECKBOX} 31 | 32 | {:else if params.field === FieldType.SELECT} 33 | 34 | {/if} 35 | {/each} 36 | {/if} 37 |
38 | {#if advanceOptions && advanceOptions.length > 0} 39 |
40 | Advanced Options 41 |
46 | {#each advanceOptions as params} 47 | {#if params.field === FieldType.RANGE} 48 | 49 | {:else if params.field === FieldType.SEED} 50 | 51 | {:else if params.field === FieldType.TEXTAREA} 52 | 53 | {:else if params.field === FieldType.CHECKBOX} 54 | 55 | {:else if params.field === FieldType.SELECT} 56 | 57 | {/if} 58 | {/each} 59 |
60 |
61 | {/if} 62 |
63 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/components/SeedInput.svelte: -------------------------------------------------------------------------------- 1 | 16 | 17 |
18 | 19 | 27 | 28 |
29 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/components/Selectlist.svelte: -------------------------------------------------------------------------------- 1 | 11 | 12 |
13 | 14 | {#if params?.values} 15 | 25 | {/if} 26 |
27 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/components/TextArea.svelte: -------------------------------------------------------------------------------- 1 | 11 | 12 |
13 | 16 |
17 | 24 |
25 |
26 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/components/VideoInput.svelte: -------------------------------------------------------------------------------- 1 | 101 | 102 |
103 |
104 | {#if $mediaDevices.length > 0} 105 |
106 | 107 |
108 | {/if} 109 | 120 | 122 |
123 |
124 | 125 | 129 | 130 |
131 |
132 | 133 | 140 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/components/Warning.svelte: -------------------------------------------------------------------------------- 1 | 13 | 14 | {#if message} 15 |
(message = '')}> 16 |
17 | {message} 18 |
19 |
20 |
21 | {/if} 22 | 23 | 28 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/icons/floppy.svelte: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | 10 | 11 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/icons/screen.svelte: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | 10 | 11 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/icons/spinner.svelte: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | 10 | 11 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/index.ts: -------------------------------------------------------------------------------- 1 | // place files you want to import through the `$lib` alias in this folder. 2 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/lcmLive.ts: -------------------------------------------------------------------------------- 1 | import { writable } from 'svelte/store'; 2 | 3 | 4 | export enum LCMLiveStatus { 5 | CONNECTED = "connected", 6 | DISCONNECTED = "disconnected", 7 | WAIT = "wait", 8 | SEND_FRAME = "send_frame", 9 | TIMEOUT = "timeout", 10 | } 11 | 12 | const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED; 13 | 14 | export const lcmLiveStatus = writable(initStatus); 15 | export const streamId = writable(null); 16 | 17 | let websocket: WebSocket | null = null; 18 | export const lcmLiveActions = { 19 | async start(getSreamdata: () => any[]) { 20 | return new Promise((resolve, reject) => { 21 | 22 | try { 23 | const userId = crypto.randomUUID(); 24 | const websocketURL = `${window.location.protocol === "https:" ? 
"wss" : "ws" 25 | }:${window.location.host}/api/ws/${userId}`; 26 | 27 | websocket = new WebSocket(websocketURL); 28 | websocket.onopen = () => { 29 | console.log("Connected to websocket"); 30 | }; 31 | websocket.onclose = () => { 32 | lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED); 33 | console.log("Disconnected from websocket"); 34 | }; 35 | websocket.onerror = (err) => { 36 | console.error(err); 37 | }; 38 | websocket.onmessage = (event) => { 39 | const data = JSON.parse(event.data); 40 | switch (data.status) { 41 | case "connected": 42 | lcmLiveStatus.set(LCMLiveStatus.CONNECTED); 43 | streamId.set(userId); 44 | resolve({ status: "connected", userId }); 45 | break; 46 | case "send_frame": 47 | lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME); 48 | const streamData = getSreamdata(); 49 | websocket?.send(JSON.stringify({ status: "next_frame" })); 50 | for (const d of streamData) { 51 | this.send(d); 52 | } 53 | break; 54 | case "wait": 55 | lcmLiveStatus.set(LCMLiveStatus.WAIT); 56 | break; 57 | case "timeout": 58 | console.log("timeout"); 59 | lcmLiveStatus.set(LCMLiveStatus.TIMEOUT); 60 | streamId.set(null); 61 | reject(new Error("timeout")); 62 | break; 63 | case "error": 64 | console.log(data.message); 65 | lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED); 66 | streamId.set(null); 67 | reject(new Error(data.message)); 68 | break; 69 | } 70 | }; 71 | 72 | } catch (err) { 73 | console.error(err); 74 | lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED); 75 | streamId.set(null); 76 | reject(err); 77 | } 78 | }); 79 | }, 80 | send(data: Blob | { [key: string]: any }) { 81 | if (websocket && websocket.readyState === WebSocket.OPEN) { 82 | if (data instanceof Blob) { 83 | websocket.send(data); 84 | } else { 85 | websocket.send(JSON.stringify(data)); 86 | } 87 | } else { 88 | console.log("WebSocket not connected"); 89 | } 90 | }, 91 | async stop() { 92 | lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED); 93 | if (websocket) { 94 | websocket.close(); 95 | } 96 | websocket = null; 97 | streamId.set(null); 98 | }, 99 | }; 100 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/mediaStream.ts: -------------------------------------------------------------------------------- 1 | import { writable, type Writable, get } from 'svelte/store'; 2 | 3 | export enum MediaStreamStatusEnum { 4 | INIT = "init", 5 | CONNECTED = "connected", 6 | DISCONNECTED = "disconnected", 7 | } 8 | export const onFrameChangeStore: Writable<{ blob: Blob }> = writable({ blob: new Blob() }); 9 | 10 | export const mediaDevices = writable([]); 11 | export const mediaStreamStatus = writable(MediaStreamStatusEnum.INIT); 12 | export const mediaStream = writable(null); 13 | 14 | export const mediaStreamActions = { 15 | async enumerateDevices() { 16 | // console.log("Enumerating devices"); 17 | await navigator.mediaDevices.enumerateDevices() 18 | .then(devices => { 19 | const cameras = devices.filter(device => device.kind === 'videoinput'); 20 | mediaDevices.set(cameras); 21 | }) 22 | .catch(err => { 23 | console.error(err); 24 | }); 25 | }, 26 | async start(mediaDevicedID?: string) { 27 | const constraints = { 28 | audio: false, 29 | video: { 30 | width: 1024, height: 1024, deviceId: mediaDevicedID 31 | } 32 | }; 33 | 34 | await navigator.mediaDevices 35 | .getUserMedia(constraints) 36 | .then((stream) => { 37 | mediaStreamStatus.set(MediaStreamStatusEnum.CONNECTED); 38 | mediaStream.set(stream); 39 | }) 40 | .catch((err) => { 41 | console.error(`${err.name}: ${err.message}`); 42 | 
mediaStreamStatus.set(MediaStreamStatusEnum.DISCONNECTED); 43 | mediaStream.set(null); 44 | }); 45 | }, 46 | async startScreenCapture() { 47 | const displayMediaOptions = { 48 | video: { 49 | displaySurface: "window", 50 | }, 51 | audio: false, 52 | surfaceSwitching: "include" 53 | }; 54 | 55 | 56 | let captureStream = null; 57 | 58 | try { 59 | captureStream = await navigator.mediaDevices.getDisplayMedia(displayMediaOptions); 60 | const videoTrack = captureStream.getVideoTracks()[0]; 61 | 62 | console.log("Track settings:"); 63 | console.log(JSON.stringify(videoTrack.getSettings(), null, 2)); 64 | console.log("Track constraints:"); 65 | console.log(JSON.stringify(videoTrack.getConstraints(), null, 2)); 66 | mediaStreamStatus.set(MediaStreamStatusEnum.CONNECTED); 67 | mediaStream.set(captureStream) 68 | } catch (err) { 69 | console.error(err); 70 | } 71 | 72 | }, 73 | async switchCamera(mediaDevicedID: string) { 74 | if (get(mediaStreamStatus) !== MediaStreamStatusEnum.CONNECTED) { 75 | return; 76 | } 77 | const constraints = { 78 | audio: false, 79 | video: { width: 1024, height: 1024, deviceId: mediaDevicedID } 80 | }; 81 | await navigator.mediaDevices 82 | .getUserMedia(constraints) 83 | .then((stream) => { 84 | mediaStreamStatus.set(MediaStreamStatusEnum.CONNECTED); 85 | mediaStream.set(stream) 86 | }) 87 | .catch((err) => { 88 | console.error(`${err.name}: ${err.message}`); 89 | }); 90 | }, 91 | async stop() { 92 | navigator.mediaDevices.getUserMedia({ video: true }).then((stream) => { 93 | stream.getTracks().forEach((track) => track.stop()); 94 | }); 95 | mediaStreamStatus.set(MediaStreamStatusEnum.DISCONNECTED); 96 | mediaStream.set(null); 97 | }, 98 | }; 99 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/store.ts: -------------------------------------------------------------------------------- 1 | 2 | import { derived, writable, get, type Writable, type Readable } from 'svelte/store'; 3 | 4 | export const pipelineValues: Writable> = writable({}); 5 | export const deboucedPipelineValues: Readable> 6 | = derived(pipelineValues, ($pipelineValues, set) => { 7 | const debounced = setTimeout(() => { 8 | set($pipelineValues); 9 | }, 100); 10 | return () => clearTimeout(debounced); 11 | }); 12 | 13 | 14 | 15 | export const getPipelineValues = () => get(pipelineValues); 16 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/types.ts: -------------------------------------------------------------------------------- 1 | export const enum FieldType { 2 | RANGE = "range", 3 | SEED = "seed", 4 | TEXTAREA = "textarea", 5 | CHECKBOX = "checkbox", 6 | SELECT = "select", 7 | } 8 | export const enum PipelineMode { 9 | IMAGE = "image", 10 | VIDEO = "video", 11 | TEXT = "text", 12 | } 13 | 14 | 15 | export interface Fields { 16 | [key: string]: FieldProps; 17 | } 18 | 19 | export interface FieldProps { 20 | default: number | string; 21 | max?: number; 22 | min?: number; 23 | title: string; 24 | field: FieldType; 25 | step?: number; 26 | disabled?: boolean; 27 | hide?: boolean; 28 | id: string; 29 | values?: string[]; 30 | } 31 | export interface PipelineInfo { 32 | title: { 33 | default: string; 34 | } 35 | name: string; 36 | description: string; 37 | input_mode: { 38 | default: PipelineMode; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /demo/frontend/src/lib/utils.ts: 
-------------------------------------------------------------------------------- 1 | import * as piexif from "piexifjs"; 2 | 3 | interface IImageInfo { 4 | prompt?: string; 5 | negative_prompt?: string; 6 | seed?: number; 7 | guidance_scale?: number; 8 | } 9 | 10 | export function snapImage(imageEl: HTMLImageElement, info: IImageInfo) { 11 | try { 12 | const zeroth: { [key: string]: any } = {}; 13 | const exif: { [key: string]: any } = {}; 14 | const gps: { [key: string]: any } = {}; 15 | zeroth[piexif.ImageIFD.Make] = "LCM Image-to-Image ControlNet"; 16 | zeroth[piexif.ImageIFD.ImageDescription] = `prompt: ${info?.prompt} | negative_prompt: ${info?.negative_prompt} | seed: ${info?.seed} | guidance_scale: ${info?.guidance_scale}`; 17 | zeroth[piexif.ImageIFD.Software] = "https://github.com/radames/Real-Time-Latent-Consistency-Model"; 18 | exif[piexif.ExifIFD.DateTimeOriginal] = new Date().toISOString(); 19 | 20 | const exifObj = { "0th": zeroth, "Exif": exif, "GPS": gps }; 21 | const exifBytes = piexif.dump(exifObj); 22 | 23 | const canvas = document.createElement("canvas"); 24 | canvas.width = imageEl.naturalWidth; 25 | canvas.height = imageEl.naturalHeight; 26 | const ctx = canvas.getContext("2d") as CanvasRenderingContext2D; 27 | ctx.drawImage(imageEl, 0, 0); 28 | const dataURL = canvas.toDataURL("image/jpeg"); 29 | const withExif = piexif.insert(exifBytes, dataURL); 30 | 31 | const a = document.createElement("a"); 32 | a.href = withExif; 33 | a.download = `lcm_txt_2_img${Date.now()}.jpg`; 34 | a.click(); 35 | } catch (err) { 36 | console.log(err); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /demo/frontend/src/routes/+layout.svelte: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /demo/frontend/src/routes/+page.svelte: -------------------------------------------------------------------------------- 1 | 96 | 97 | 98 | 101 | 102 | 103 | {@debug isImageMode, pipelineParams} 104 | 105 |
106 | 107 |
108 | {#if pageContent} 109 | {@html pageContent} 110 | {/if} 111 | {#if maxQueueSize > 0} 112 |

113 | There are {currentQueueSize} 114 | user(s) sharing the same GPU, affecting real-time performance. Maximum queue size is {maxQueueSize}. 115 | Duplicate and run it on your own GPU. 120 |

121 | {/if} 122 |
123 | {#if pipelineParams} 124 |
125 | {#if isImageMode} 126 |
127 | 131 |
132 | {/if} 133 |
134 | 135 |
136 |
137 | 144 | 145 |
146 |
147 | {:else} 148 | 149 |
150 | 151 |

Loading...

152 |
153 | {/if} 154 |
155 | 156 | 161 | -------------------------------------------------------------------------------- /demo/frontend/src/routes/+page.ts: -------------------------------------------------------------------------------- 1 | export const prerender = true 2 | -------------------------------------------------------------------------------- /demo/frontend/svelte.config.js: -------------------------------------------------------------------------------- 1 | import adapter from '@sveltejs/adapter-static'; 2 | import { vitePreprocess } from '@sveltejs/kit/vite'; 3 | /** @type {import('@sveltejs/kit').Config} */ 4 | const config = { 5 | preprocess: vitePreprocess({ postcss: true }), 6 | kit: { 7 | adapter: adapter({ 8 | pages: 'public', 9 | assets: 'public', 10 | fallback: undefined, 11 | precompress: false, 12 | strict: true 13 | }) 14 | } 15 | }; 16 | 17 | export default config; 18 | -------------------------------------------------------------------------------- /demo/frontend/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | export default { 3 | content: ['./src/**/*.{html,js,svelte,ts}', '../**/*.py'], 4 | theme: { 5 | extend: {} 6 | }, 7 | plugins: [] 8 | }; 9 | -------------------------------------------------------------------------------- /demo/frontend/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "./.svelte-kit/tsconfig.json", 3 | "compilerOptions": { 4 | "allowJs": true, 5 | "checkJs": true, 6 | "esModuleInterop": true, 7 | "forceConsistentCasingInFileNames": true, 8 | "resolveJsonModule": true, 9 | "skipLibCheck": true, 10 | "sourceMap": true, 11 | "strict": true 12 | } 13 | // Path aliases are handled by https://kit.svelte.dev/docs/configuration#alias 14 | // 15 | // If you want to overwrite includes/excludes, make sure to copy over the relevant includes/excludes 16 | // from the referenced tsconfig.json - TypeScript does not merge them in 17 | } 18 | -------------------------------------------------------------------------------- /demo/frontend/vite.config.ts: -------------------------------------------------------------------------------- 1 | import { sveltekit } from '@sveltejs/kit/vite'; 2 | import { defineConfig } from 'vite'; 3 | 4 | export default defineConfig({ 5 | plugins: [sveltekit()], 6 | server: { 7 | proxy: { 8 | '/api': 'http://localhost:7860', 9 | '/api/ws': { 10 | target: 'ws://localhost:7860', 11 | ws: true 12 | } 13 | }, 14 | } 15 | }); 16 | -------------------------------------------------------------------------------- /demo/main.py: -------------------------------------------------------------------------------- 1 | from config import config 2 | 3 | 4 | if __name__ == "__main__": 5 | import uvicorn 6 | 7 | uvicorn.run( 8 | "app:app", 9 | host=config.host, 10 | port=config.port, 11 | reload=config.reload, 12 | ssl_certfile=config.ssl_certfile, 13 | ssl_keyfile=config.ssl_keyfile, 14 | ) 15 | -------------------------------------------------------------------------------- /demo/requirements.txt: -------------------------------------------------------------------------------- 1 | compel==2.0.2 2 | fastapi==0.104.1 3 | markdown2 4 | uvicorn[standard]==0.24.0.post1 5 | -------------------------------------------------------------------------------- /demo/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd frontend 3 | npm install 4 | npm 
run build 5 | if [ $? -eq 0 ]; then 6 | echo -e "\033[1;32m\nfrontend build success \033[0m" 7 | else 8 | echo -e "\033[1;31m\nfrontend build failed\n\033[0m" >&2; exit 1 9 | fi 10 | cd ../ 11 | python main.py --port 7860 --host 0.0.0.0 --engine-dir ../engines --acceleration tensorrt 12 | -------------------------------------------------------------------------------- /demo/util.py: -------------------------------------------------------------------------------- 1 | import io 2 | from importlib import import_module 3 | from types import ModuleType 4 | 5 | from PIL import Image 6 | 7 | 8 | def get_pipeline_class(pipeline_name: str) -> ModuleType: 9 | try: 10 | module = import_module(f"pipelines.{pipeline_name}") 11 | except ModuleNotFoundError: 12 | raise ValueError(f"Pipeline {pipeline_name} module not found") 13 | 14 | pipeline_class = getattr(module, "Pipeline", None) 15 | 16 | if pipeline_class is None: 17 | raise ValueError(f"'Pipeline' class not found in module '{pipeline_name}'.") 18 | 19 | return pipeline_class 20 | 21 | 22 | def bytes_to_pil(image_bytes: bytes) -> Image.Image: 23 | image = Image.open(io.BytesIO(image_bytes)) 24 | return image 25 | 26 | 27 | def pil_to_frame(image: Image.Image) -> bytes: 28 | frame_data = io.BytesIO() 29 | image.save(frame_data, format="JPEG") 30 | frame_data = frame_data.getvalue() 31 | return ( 32 | b"--frame\r\n" 33 | + b"Content-Type: image/jpeg\r\n" 34 | + f"Content-Length: {len(frame_data)}\r\n\r\n".encode() 35 | + frame_data 36 | + b"\r\n" 37 | ) 38 | 39 | 40 | def is_firefox(user_agent: str) -> bool: 41 | return "Firefox" in user_agent 42 | -------------------------------------------------------------------------------- /demo/vid2vid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | sys.path.append( 6 | os.path.join( 7 | os.path.dirname(__file__), 8 | "..", 9 | ) 10 | ) 11 | 12 | import torch 13 | from config import Args 14 | from PIL import Image 15 | from pydantic import BaseModel, Field 16 | 17 | from live2diff.utils.config import load_config 18 | from live2diff.utils.wrapper import StreamAnimateDiffusionDepthWrapper 19 | 20 | 21 | default_prompt = "masterpiece, best quality, felted, 1man with glasses, glasses, play with his pen" 22 | 23 | page_content = """

Live2Diff:

24 |

Live Stream Translation via Uni-directional Attention in Video Diffusion Models

25 |

26 | This demo showcases 27 | Live2Diff 31 | 32 | pipeline using 33 | LCM-LoRA with a MJPEG stream server. 38 |

39 | """ 40 | 41 | 42 | WARMUP_FRAMES = 8 43 | WINDOW_SIZE = 16 44 | 45 | 46 | class Pipeline: 47 | class Info(BaseModel): 48 | name: str = "Live2Diff" 49 | input_mode: str = "image" 50 | page_content: str = page_content 51 | 52 | def build_input_params(self, default_prompt: str = default_prompt, width=512, height=512): 53 | class InputParams(BaseModel): 54 | prompt: str = Field( 55 | default_prompt, 56 | title="Prompt", 57 | field="textarea", 58 | id="prompt", 59 | ) 60 | width: int = Field( 61 | 512, 62 | min=2, 63 | max=15, 64 | title="Width", 65 | disabled=True, 66 | hide=True, 67 | id="width", 68 | ) 69 | height: int = Field( 70 | 512, 71 | min=2, 72 | max=15, 73 | title="Height", 74 | disabled=True, 75 | hide=True, 76 | id="height", 77 | ) 78 | 79 | return InputParams 80 | 81 | def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype): 82 | config_path = args.config 83 | 84 | cfg = load_config(config_path) 85 | prompt = args.prompt or cfg.prompt or default_prompt 86 | 87 | self.InputParams = self.build_input_params(default_prompt=prompt) 88 | params = self.InputParams() 89 | 90 | num_inference_steps = args.num_inference_steps or cfg.get("num_inference_steps", None) 91 | strength = args.strength or cfg.get("strength", None) 92 | t_index_list = args.t_index_list or cfg.get("t_index_list", None) 93 | 94 | self.stream = StreamAnimateDiffusionDepthWrapper( 95 | few_step_model_type="lcm", 96 | config_path=config_path, 97 | cfg_type="none", 98 | strength=strength, 99 | num_inference_steps=num_inference_steps, 100 | t_index_list=t_index_list, 101 | frame_buffer_size=1, 102 | width=params.width, 103 | height=params.height, 104 | acceleration=args.acceleration, 105 | do_add_noise=True, 106 | output_type="pil", 107 | enable_similar_image_filter=True, 108 | similar_image_filter_threshold=0.98, 109 | use_denoising_batch=True, 110 | use_tiny_vae=True, 111 | seed=args.seed, 112 | engine_dir=args.engine_dir, 113 | ) 114 | 115 | self.last_prompt = prompt 116 | 117 | self.warmup_frame_list = [] 118 | self.has_prepared = False 119 | 120 | def predict(self, params: "Pipeline.InputParams") -> Image.Image: 121 | prompt = params.prompt 122 | if prompt != self.last_prompt: 123 | self.last_prompt = prompt 124 | self.warmup_frame_list.clear() 125 | 126 | if len(self.warmup_frame_list) < WARMUP_FRAMES: 127 | # from PIL import Image 128 | self.warmup_frame_list.append(self.stream.preprocess_image(params.image)) 129 | 130 | elif len(self.warmup_frame_list) == WARMUP_FRAMES and not self.has_prepared: 131 | warmup_frames = torch.stack(self.warmup_frame_list) 132 | self.stream.prepare( 133 | warmup_frames=warmup_frames, 134 | prompt=prompt, 135 | guidance_scale=1, 136 | ) 137 | self.has_prepared = True 138 | 139 | if self.has_prepared: 140 | image_tensor = self.stream.preprocess_image(params.image) 141 | output_image = self.stream(image=image_tensor) 142 | return output_image 143 | else: 144 | return Image.new("RGB", (params.width, params.height)) 145 | -------------------------------------------------------------------------------- /live2diff/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_stream_animation_depth import StreamAnimateDiffusionDepth 2 | 3 | 4 | __all__ = ["StreamAnimateDiffusionDepth"] 5 | -------------------------------------------------------------------------------- /live2diff/acceleration/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/open-mmlab/Live2Diff/e40aa03a9a9d17f1232fd7a4566b3ee793e6893f/live2diff/acceleration/__init__.py -------------------------------------------------------------------------------- /live2diff/acceleration/tensorrt/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffusers import AutoencoderKL 4 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( 5 | retrieve_latents, 6 | ) 7 | 8 | from .builder import EngineBuilder 9 | from .models import BaseModel 10 | 11 | 12 | class TorchVAEEncoder(torch.nn.Module): 13 | def __init__(self, vae: AutoencoderKL): 14 | super().__init__() 15 | self.vae = vae 16 | 17 | def forward(self, x: torch.Tensor): 18 | return retrieve_latents(self.vae.encode(x)) 19 | 20 | 21 | def compile_engine( 22 | torch_model: nn.Module, 23 | model_data: BaseModel, 24 | onnx_path: str, 25 | onnx_opt_path: str, 26 | engine_path: str, 27 | opt_image_height: int = 512, 28 | opt_image_width: int = 512, 29 | opt_batch_size: int = 1, 30 | engine_build_options: dict = {}, 31 | ): 32 | builder = EngineBuilder( 33 | model_data, 34 | torch_model, 35 | device=torch.device("cuda"), 36 | ) 37 | builder.build( 38 | onnx_path, 39 | onnx_opt_path, 40 | engine_path, 41 | opt_image_height=opt_image_height, 42 | opt_image_width=opt_image_width, 43 | opt_batch_size=opt_batch_size, 44 | **engine_build_options, 45 | ) 46 | -------------------------------------------------------------------------------- /live2diff/acceleration/tensorrt/builder.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from typing import * 4 | 5 | import torch 6 | 7 | from .models import BaseModel 8 | from .utilities import ( 9 | build_engine, 10 | export_onnx, 11 | handle_onnx_batch_norm, 12 | optimize_onnx, 13 | ) 14 | 15 | 16 | class EngineBuilder: 17 | def __init__( 18 | self, 19 | model: BaseModel, 20 | network: Any, 21 | device=torch.device("cuda"), 22 | ): 23 | self.device = device 24 | 25 | self.model = model 26 | self.network = network 27 | 28 | def build( 29 | self, 30 | onnx_path: str, 31 | onnx_opt_path: str, 32 | engine_path: str, 33 | opt_image_height: int = 512, 34 | opt_image_width: int = 512, 35 | opt_batch_size: int = 1, 36 | min_image_resolution: int = 256, 37 | max_image_resolution: int = 1024, 38 | build_enable_refit: bool = False, 39 | build_static_batch: bool = False, 40 | build_dynamic_shape: bool = False, 41 | build_all_tactics: bool = False, 42 | onnx_opset: int = 17, 43 | force_engine_build: bool = False, 44 | force_onnx_export: bool = False, 45 | force_onnx_optimize: bool = False, 46 | ignore_onnx_optimize: bool = False, 47 | auto_cast: bool = True, 48 | handle_batch_norm: bool = False, 49 | ): 50 | if not force_onnx_export and os.path.exists(onnx_path): 51 | print(f"Found cached model: {onnx_path}") 52 | else: 53 | print(f"Exporting model: {onnx_path}") 54 | export_onnx( 55 | self.network, 56 | onnx_path=onnx_path, 57 | model_data=self.model, 58 | opt_image_height=opt_image_height, 59 | opt_image_width=opt_image_width, 60 | opt_batch_size=opt_batch_size, 61 | onnx_opset=onnx_opset, 62 | auto_cast=auto_cast, 63 | ) 64 | del self.network 65 | gc.collect() 66 | torch.cuda.empty_cache() 67 | 68 | if handle_batch_norm: 69 | print(f"Handle Batch Norm for {onnx_path}") 70 | handle_onnx_batch_norm(onnx_path) 71 | 72 | if ignore_onnx_optimize: 73 | print(f"Ignore onnx optimize for 
{onnx_path}.") 74 | onnx_opt_path = onnx_path 75 | elif not force_onnx_optimize and os.path.exists(onnx_opt_path): 76 | print(f"Found cached model: {onnx_opt_path}") 77 | else: 78 | print(f"Generating optimizing model: {onnx_opt_path}") 79 | optimize_onnx( 80 | onnx_path=onnx_path, 81 | onnx_opt_path=onnx_opt_path, 82 | model_data=self.model, 83 | ) 84 | self.model.min_latent_shape = min_image_resolution // 8 85 | self.model.max_latent_shape = max_image_resolution // 8 86 | if not force_engine_build and os.path.exists(engine_path): 87 | print(f"Found cached engine: {engine_path}") 88 | else: 89 | build_engine( 90 | engine_path=engine_path, 91 | onnx_opt_path=onnx_opt_path, 92 | model_data=self.model, 93 | opt_image_height=opt_image_height, 94 | opt_image_width=opt_image_width, 95 | opt_batch_size=opt_batch_size, 96 | build_static_batch=build_static_batch, 97 | build_dynamic_shape=build_dynamic_shape, 98 | build_all_tactics=build_all_tactics, 99 | build_enable_refit=build_enable_refit, 100 | ) 101 | 102 | gc.collect() 103 | torch.cuda.empty_cache() 104 | -------------------------------------------------------------------------------- /live2diff/acceleration/tensorrt/engine.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | from polygraphy import cuda 5 | 6 | from live2diff.animatediff.models.unet_depth_streaming import UNet3DConditionStreamingOutput 7 | 8 | from .utilities import Engine 9 | 10 | 11 | try: 12 | from diffusers.models.autoencoder_tiny import AutoencoderTinyOutput 13 | except ImportError: 14 | from dataclasses import dataclass 15 | 16 | from diffusers.utils import BaseOutput 17 | 18 | @dataclass 19 | class AutoencoderTinyOutput(BaseOutput): 20 | """ 21 | Output of AutoencoderTiny encoding method. 22 | 23 | Args: 24 | latents (`torch.Tensor`): Encoded outputs of the `Encoder`. 25 | 26 | """ 27 | 28 | latents: torch.Tensor 29 | 30 | 31 | try: 32 | from diffusers.models.vae import DecoderOutput 33 | except ImportError: 34 | from dataclasses import dataclass 35 | 36 | from diffusers.utils import BaseOutput 37 | 38 | @dataclass 39 | class DecoderOutput(BaseOutput): 40 | r""" 41 | Output of decoding method. 42 | 43 | Args: 44 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 45 | The decoded output sample from the last layer of the model. 
46 | """ 47 | 48 | sample: torch.FloatTensor 49 | 50 | 51 | class AutoencoderKLEngine: 52 | def __init__( 53 | self, 54 | encoder_path: str, 55 | decoder_path: str, 56 | stream: cuda.Stream, 57 | scaling_factor: int, 58 | use_cuda_graph: bool = False, 59 | ): 60 | self.encoder = Engine(encoder_path) 61 | self.decoder = Engine(decoder_path) 62 | self.stream = stream 63 | self.vae_scale_factor = scaling_factor 64 | self.use_cuda_graph = use_cuda_graph 65 | 66 | self.encoder.load() 67 | self.decoder.load() 68 | self.encoder.activate() 69 | self.decoder.activate() 70 | 71 | def encode(self, images: torch.Tensor, **kwargs): 72 | self.encoder.allocate_buffers( 73 | shape_dict={ 74 | "images": images.shape, 75 | "latent": ( 76 | images.shape[0], 77 | 4, 78 | images.shape[2] // self.vae_scale_factor, 79 | images.shape[3] // self.vae_scale_factor, 80 | ), 81 | }, 82 | device=images.device, 83 | ) 84 | latents = self.encoder.infer( 85 | {"images": images}, 86 | self.stream, 87 | use_cuda_graph=self.use_cuda_graph, 88 | )["latent"] 89 | return AutoencoderTinyOutput(latents=latents) 90 | 91 | def decode(self, latent: torch.Tensor, **kwargs): 92 | self.decoder.allocate_buffers( 93 | shape_dict={ 94 | "latent": latent.shape, 95 | "images": ( 96 | latent.shape[0], 97 | 3, 98 | latent.shape[2] * self.vae_scale_factor, 99 | latent.shape[3] * self.vae_scale_factor, 100 | ), 101 | }, 102 | device=latent.device, 103 | ) 104 | images = self.decoder.infer( 105 | {"latent": latent}, 106 | self.stream, 107 | use_cuda_graph=self.use_cuda_graph, 108 | )["images"] 109 | return DecoderOutput(sample=images) 110 | 111 | def to(self, *args, **kwargs): 112 | pass 113 | 114 | def forward(self, *args, **kwargs): 115 | pass 116 | 117 | 118 | class UNet2DConditionModelDepthEngine: 119 | def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False): 120 | self.engine = Engine(filepath) 121 | self.stream = stream 122 | self.use_cuda_graph = use_cuda_graph 123 | 124 | self.init_profiler() 125 | 126 | self.engine.load() 127 | self.engine.activate(profiler=self.profiler) 128 | self.has_allocated = False 129 | 130 | def init_profiler(self): 131 | import tensorrt 132 | 133 | class Profiler(tensorrt.IProfiler): 134 | def __init__(self): 135 | tensorrt.IProfiler.__init__(self) 136 | 137 | def report_layer_time(self, layer_name, ms): 138 | print(f"{layer_name}: {ms} ms") 139 | 140 | self.profiler = Profiler() 141 | 142 | def __call__( 143 | self, 144 | latent_model_input: torch.Tensor, 145 | timestep: torch.Tensor, 146 | encoder_hidden_states: torch.Tensor, 147 | temporal_attention_mask: torch.Tensor, 148 | depth_sample: torch.Tensor, 149 | kv_cache: List[torch.Tensor], 150 | pe_idx: torch.Tensor, 151 | update_idx: torch.Tensor, 152 | **kwargs, 153 | ) -> Any: 154 | if timestep.dtype != torch.float32: 155 | timestep = timestep.float() 156 | 157 | feed_dict = { 158 | "sample": latent_model_input, 159 | "timestep": timestep, 160 | "encoder_hidden_states": encoder_hidden_states, 161 | "temporal_attention_mask": temporal_attention_mask, 162 | "depth_sample": depth_sample, 163 | "pe_idx": pe_idx, 164 | "update_idx": update_idx, 165 | } 166 | for idx, cache in enumerate(kv_cache): 167 | feed_dict[f"kv_cache_{idx}"] = cache 168 | shape_dict = {k: v.shape for k, v in feed_dict.items()} 169 | 170 | if not self.has_allocated: 171 | self.engine.allocate_buffers( 172 | shape_dict=shape_dict, 173 | device=latent_model_input.device, 174 | ) 175 | self.has_allocated = True 176 | 177 | output = self.engine.infer( 178 | 
feed_dict, 179 | self.stream, 180 | use_cuda_graph=self.use_cuda_graph, 181 | ) 182 | 183 | noise_pred = output["latent"] 184 | kv_cache = [output[f"kv_cache_out_{idx}"] for idx in range(len(kv_cache))] 185 | return UNet3DConditionStreamingOutput(sample=noise_pred, kv_cache=kv_cache) 186 | 187 | def to(self, *args, **kwargs): 188 | pass 189 | 190 | def forward(self, *args, **kwargs): 191 | pass 192 | 193 | 194 | class MidasEngine: 195 | def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False): 196 | self.engine = Engine(filepath) 197 | self.stream = stream 198 | self.use_cuda_graph = use_cuda_graph 199 | 200 | self.engine.load() 201 | self.engine.activate() 202 | self.has_allocated = False 203 | self.default_batch_size = 1 204 | 205 | def __call__( 206 | self, 207 | images: torch.Tensor, 208 | **kwargs, 209 | ) -> Any: 210 | if not self.has_allocated or images.shape[0] != self.default_batch_size: 211 | bz = images.shape[0] 212 | self.engine.allocate_buffers( 213 | shape_dict={ 214 | "images": (bz, 3, 384, 384), 215 | "depth_map": (bz, 384, 384), 216 | }, 217 | device=images.device, 218 | ) 219 | self.has_allocated = True 220 | self.default_batch_size = bz 221 | 222 | depth_map = self.engine.infer( 223 | { 224 | "images": images, 225 | }, 226 | self.stream, 227 | use_cuda_graph=self.use_cuda_graph, 228 | )["depth_map"] # (1, 384, 384) 229 | 230 | return depth_map 231 | 232 | def norm(self, x): 233 | return (x - x.min()) / (x.max() - x.min()) 234 | 235 | def to(self, *args, **kwargs): 236 | pass 237 | 238 | def forward(self, *args, **kwargs): 239 | pass 240 | -------------------------------------------------------------------------------- /live2diff/acceleration/tensorrt/utilities.py: -------------------------------------------------------------------------------- 1 | #! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py 2 | 3 | # 4 | # Copyright 2022 The HuggingFace Inc. team. 5 | # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 6 | # SPDX-License-Identifier: Apache-2.0 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | # 20 | 21 | import gc 22 | from collections import OrderedDict 23 | from typing import * 24 | 25 | import numpy as np 26 | import onnx 27 | import onnx_graphsurgeon as gs 28 | import tensorrt as trt 29 | import torch 30 | from cuda import cudart 31 | from PIL import Image 32 | from polygraphy import cuda 33 | from polygraphy.backend.common import bytes_from_path 34 | from polygraphy.backend.trt import ( 35 | CreateConfig, 36 | Profile, 37 | engine_from_bytes, 38 | engine_from_network, 39 | network_from_onnx_path, 40 | save_engine, 41 | ) 42 | 43 | from .models import BaseModel 44 | 45 | 46 | TRT_LOGGER = trt.Logger(trt.Logger.ERROR) 47 | 48 | # Map of numpy dtype -> torch dtype 49 | numpy_to_torch_dtype_dict = { 50 | np.uint8: torch.uint8, 51 | np.int8: torch.int8, 52 | np.int16: torch.int16, 53 | np.int32: torch.int32, 54 | np.int64: torch.int64, 55 | np.float16: torch.float16, 56 | np.float32: torch.float32, 57 | np.float64: torch.float64, 58 | np.complex64: torch.complex64, 59 | np.complex128: torch.complex128, 60 | } 61 | if np.version.full_version >= "1.24.0": 62 | numpy_to_torch_dtype_dict[np.bool_] = torch.bool 63 | else: 64 | numpy_to_torch_dtype_dict[np.bool] = torch.bool 65 | 66 | # Map of torch dtype -> numpy dtype 67 | torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} 68 | 69 | 70 | def CUASSERT(cuda_ret): 71 | err = cuda_ret[0] 72 | if err != cudart.cudaError_t.cudaSuccess: 73 | raise RuntimeError( 74 | f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" 75 | ) 76 | if len(cuda_ret) > 1: 77 | return cuda_ret[1] 78 | return None 79 | 80 | 81 | class Engine: 82 | def __init__( 83 | self, 84 | engine_path, 85 | ): 86 | self.engine_path = engine_path 87 | self.engine = None 88 | self.context = None 89 | self.buffers = OrderedDict() 90 | self.tensors = OrderedDict() 91 | self.cuda_graph_instance = None # cuda graph 92 | 93 | def __del__(self): 94 | [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)] 95 | del self.engine 96 | del self.context 97 | del self.buffers 98 | del self.tensors 99 | 100 | def refit(self, onnx_path, onnx_refit_path): 101 | def convert_int64(arr): 102 | # TODO: smarter conversion 103 | if len(arr.shape) == 0: 104 | return np.int32(arr) 105 | return arr 106 | 107 | def add_to_map(refit_dict, name, values): 108 | if name in refit_dict: 109 | assert refit_dict[name] is None 110 | if values.dtype == np.int64: 111 | values = convert_int64(values) 112 | refit_dict[name] = values 113 | 114 | print(f"Refitting TensorRT engine with {onnx_refit_path} weights") 115 | refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes 116 | 117 | # Construct mapping from weight names in refit model -> original model 118 | name_map = {} 119 | for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): 120 | refit_node = refit_nodes[n] 121 | assert node.op == refit_node.op 122 | # Constant nodes in ONNX do not have inputs but have a constant output 123 | if node.op == "Constant": 124 | name_map[refit_node.outputs[0].name] = node.outputs[0].name 125 | # Handle scale and bias weights 126 | elif node.op == "Conv": 127 | if node.inputs[1].__class__ == gs.Constant: 128 | name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" 129 | if node.inputs[2].__class__ == gs.Constant: 130 | name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS" 131 | # For all other nodes: find node inputs 
that are initializers (gs.Constant) 132 | else: 133 | for i, inp in enumerate(node.inputs): 134 | if inp.__class__ == gs.Constant: 135 | name_map[refit_node.inputs[i].name] = inp.name 136 | 137 | def map_name(name): 138 | if name in name_map: 139 | return name_map[name] 140 | return name 141 | 142 | # Construct refit dictionary 143 | refit_dict = {} 144 | refitter = trt.Refitter(self.engine, TRT_LOGGER) 145 | all_weights = refitter.get_all() 146 | for layer_name, role in zip(all_weights[0], all_weights[1]): 147 | # for speciailized roles, use a unique name in the map: 148 | if role == trt.WeightsRole.KERNEL: 149 | name = layer_name + "_TRTKERNEL" 150 | elif role == trt.WeightsRole.BIAS: 151 | name = layer_name + "_TRTBIAS" 152 | else: 153 | name = layer_name 154 | 155 | assert name not in refit_dict, "Found duplicate layer: " + name 156 | refit_dict[name] = None 157 | 158 | for n in refit_nodes: 159 | # Constant nodes in ONNX do not have inputs but have a constant output 160 | if n.op == "Constant": 161 | name = map_name(n.outputs[0].name) 162 | print(f"Add Constant {name}\n") 163 | add_to_map(refit_dict, name, n.outputs[0].values) 164 | 165 | # Handle scale and bias weights 166 | elif n.op == "Conv": 167 | if n.inputs[1].__class__ == gs.Constant: 168 | name = map_name(n.name + "_TRTKERNEL") 169 | add_to_map(refit_dict, name, n.inputs[1].values) 170 | 171 | if n.inputs[2].__class__ == gs.Constant: 172 | name = map_name(n.name + "_TRTBIAS") 173 | add_to_map(refit_dict, name, n.inputs[2].values) 174 | 175 | # For all other nodes: find node inputs that are initializers (AKA gs.Constant) 176 | else: 177 | for inp in n.inputs: 178 | name = map_name(inp.name) 179 | if inp.__class__ == gs.Constant: 180 | add_to_map(refit_dict, name, inp.values) 181 | 182 | for layer_name, weights_role in zip(all_weights[0], all_weights[1]): 183 | if weights_role == trt.WeightsRole.KERNEL: 184 | custom_name = layer_name + "_TRTKERNEL" 185 | elif weights_role == trt.WeightsRole.BIAS: 186 | custom_name = layer_name + "_TRTBIAS" 187 | else: 188 | custom_name = layer_name 189 | 190 | # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model 191 | if layer_name.startswith("onnx::Trilu"): 192 | continue 193 | 194 | if refit_dict[custom_name] is not None: 195 | refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) 196 | else: 197 | print(f"[W] No refit weights for layer: {layer_name}") 198 | 199 | if not refitter.refit_cuda_engine(): 200 | print("Failed to refit!") 201 | exit(0) 202 | 203 | def build( 204 | self, 205 | onnx_path, 206 | fp16, 207 | input_profile=None, 208 | enable_refit=False, 209 | enable_all_tactics=False, 210 | timing_cache=None, 211 | workspace_size=0, 212 | ): 213 | print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") 214 | p = Profile() 215 | if input_profile: 216 | for name, dims in input_profile.items(): 217 | assert len(dims) == 3 218 | p.add(name, min=dims[0], opt=dims[1], max=dims[2]) 219 | 220 | config_kwargs = {} 221 | 222 | if workspace_size > 0: 223 | config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size} 224 | if not enable_all_tactics: 225 | config_kwargs["tactic_sources"] = [] 226 | 227 | engine = engine_from_network( 228 | network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), 229 | config=CreateConfig( 230 | fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs 231 | ), 232 | save_timing_cache=timing_cache, 233 | ) 234 | 
save_engine(engine, path=self.engine_path) 235 | 236 | def load(self): 237 | print(f"Loading TensorRT engine: {self.engine_path}") 238 | self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) 239 | 240 | def activate(self, reuse_device_memory=None, profiler=None): 241 | if reuse_device_memory: 242 | self.context = self.engine.create_execution_context_without_device_memory() 243 | self.context.device_memory = reuse_device_memory 244 | else: 245 | self.context = self.engine.create_execution_context() 246 | 247 | def allocate_buffers(self, shape_dict=None, device="cuda"): 248 | # NOTE: API for tensorrt 10.01 249 | from tensorrt import TensorIOMode 250 | 251 | for idx in range(self.engine.num_io_tensors): 252 | binding = self.engine[idx] 253 | if shape_dict and binding in shape_dict: 254 | shape = shape_dict[binding] 255 | else: 256 | shape = self.engine.get_tensor_shape(binding) 257 | dtype = trt.nptype(self.engine.get_tensor_dtype(binding)) 258 | tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype], device=device) 259 | self.tensors[binding] = tensor 260 | 261 | binding_mode = self.engine.get_tensor_mode(binding) 262 | if binding_mode == TensorIOMode.INPUT: 263 | self.context.set_input_shape(binding, shape) 264 | self.has_allocated = True 265 | 266 | def infer(self, feed_dict, stream, use_cuda_graph=False): 267 | for name, buf in feed_dict.items(): 268 | self.tensors[name].copy_(buf) 269 | 270 | for name, tensor in self.tensors.items(): 271 | self.context.set_tensor_address(name, tensor.data_ptr()) 272 | 273 | if use_cuda_graph: 274 | if self.cuda_graph_instance is not None: 275 | CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr)) 276 | CUASSERT(cudart.cudaStreamSynchronize(stream.ptr)) 277 | else: 278 | # do inference before CUDA graph capture 279 | noerror = self.context.execute_async_v3(stream.ptr) 280 | if not noerror: 281 | raise ValueError("ERROR: inference failed.") 282 | # capture cuda graph 283 | CUASSERT( 284 | cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal) 285 | ) 286 | self.context.execute_async_v3(stream.ptr) 287 | self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr)) 288 | self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0)) 289 | else: 290 | noerror = self.context.execute_async_v3(stream.ptr) 291 | if not noerror: 292 | raise ValueError("ERROR: inference failed.") 293 | 294 | return self.tensors 295 | 296 | 297 | def decode_images(images: torch.Tensor): 298 | images = ( 299 | ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() 300 | ) 301 | return [Image.fromarray(x) for x in images] 302 | 303 | 304 | def preprocess_image(image: Image.Image): 305 | w, h = image.size 306 | w, h = [x - x % 32 for x in (w, h)] # resize to integer multiple of 32 307 | image = image.resize((w, h)) 308 | init_image = np.array(image).astype(np.float32) / 255.0 309 | init_image = init_image[None].transpose(0, 3, 1, 2) 310 | init_image = torch.from_numpy(init_image).contiguous() 311 | return 2.0 * init_image - 1.0 312 | 313 | 314 | def prepare_mask_and_masked_image(image: Image.Image, mask: Image.Image): 315 | if isinstance(image, Image.Image): 316 | image = np.array(image.convert("RGB")) 317 | image = image[None].transpose(0, 3, 1, 2) 318 | image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0 319 | if isinstance(mask, Image.Image): 320 | mask = np.array(mask.convert("L")) 321 | 
mask = mask.astype(np.float32) / 255.0 322 | mask = mask[None, None] 323 | mask[mask < 0.5] = 0 324 | mask[mask >= 0.5] = 1 325 | mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous() 326 | 327 | masked_image = image * (mask < 0.5) 328 | 329 | return mask, masked_image 330 | 331 | 332 | def build_engine( 333 | engine_path: str, 334 | onnx_opt_path: str, 335 | model_data: BaseModel, 336 | opt_image_height: int, 337 | opt_image_width: int, 338 | opt_batch_size: int, 339 | build_static_batch: bool = False, 340 | build_dynamic_shape: bool = False, 341 | build_all_tactics: bool = False, 342 | build_enable_refit: bool = False, 343 | ): 344 | _, free_mem, _ = cudart.cudaMemGetInfo() 345 | GiB = 2**30 346 | if free_mem > 6 * GiB: 347 | activation_carveout = 4 * GiB 348 | max_workspace_size = free_mem - activation_carveout 349 | else: 350 | max_workspace_size = 0 351 | engine = Engine(engine_path) 352 | input_profile = model_data.get_input_profile( 353 | opt_batch_size, 354 | opt_image_height, 355 | opt_image_width, 356 | static_batch=build_static_batch, 357 | static_shape=not build_dynamic_shape, 358 | ) 359 | engine.build( 360 | onnx_opt_path, 361 | fp16=True, 362 | input_profile=input_profile, 363 | enable_refit=build_enable_refit, 364 | enable_all_tactics=build_all_tactics, 365 | workspace_size=max_workspace_size, 366 | ) 367 | 368 | return engine 369 | 370 | 371 | def export_onnx( 372 | model, 373 | onnx_path: str, 374 | model_data: BaseModel, 375 | opt_image_height: int, 376 | opt_image_width: int, 377 | opt_batch_size: int, 378 | onnx_opset: int, 379 | auto_cast: bool = True, 380 | ): 381 | from contextlib import contextmanager 382 | 383 | @contextmanager 384 | def auto_cast_manager(enabled): 385 | if enabled: 386 | with torch.inference_mode(), torch.autocast("cuda"): 387 | yield 388 | else: 389 | yield 390 | 391 | with auto_cast_manager(auto_cast): 392 | inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) 393 | torch.onnx.export( 394 | model, 395 | inputs, 396 | onnx_path, 397 | export_params=True, 398 | opset_version=onnx_opset, 399 | do_constant_folding=True, 400 | input_names=model_data.get_input_names(), 401 | output_names=model_data.get_output_names(), 402 | dynamic_axes=model_data.get_dynamic_axes(), 403 | ) 404 | del model 405 | gc.collect() 406 | torch.cuda.empty_cache() 407 | 408 | 409 | def optimize_onnx( 410 | onnx_path: str, 411 | onnx_opt_path: str, 412 | model_data: BaseModel, 413 | ): 414 | model_data.optimize(onnx_path, onnx_opt_path) 415 | # # onnx_opt_graph = model_data.optimize(onnx.load(onnx_path)) 416 | # onnx_opt_graph = model_data.optimize(onnx_path) 417 | # onnx.save(onnx_opt_graph, onnx_opt_path) 418 | # del onnx_opt_graph 419 | # gc.collect() 420 | # torch.cuda.empty_cache() 421 | 422 | 423 | def handle_onnx_batch_norm(onnx_path: str): 424 | onnx_model = onnx.load(onnx_path) 425 | for node in onnx_model.graph.node: 426 | if node.op_type == "BatchNormalization": 427 | for attribute in node.attribute: 428 | if attribute.name == "training_mode": 429 | if attribute.i == 1: 430 | node.output.remove(node.output[1]) 431 | node.output.remove(node.output[1]) 432 | attribute.i = 0 433 | 434 | onnx.save_model(onnx_model, onnx_path) 435 | -------------------------------------------------------------------------------- /live2diff/animatediff/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/open-mmlab/Live2Diff/e40aa03a9a9d17f1232fd7a4566b3ee793e6893f/live2diff/animatediff/__init__.py -------------------------------------------------------------------------------- /live2diff/animatediff/converter/__init__.py: -------------------------------------------------------------------------------- 1 | from .convert import load_third_party_checkpoints, load_third_party_unet 2 | 3 | 4 | __all__ = ["load_third_party_checkpoints", "load_third_party_unet"] 5 | -------------------------------------------------------------------------------- /live2diff/animatediff/converter/convert.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from diffusers.pipelines import StableDiffusionPipeline 5 | from safetensors import safe_open 6 | 7 | from .convert_from_ckpt import convert_ldm_clip_checkpoint, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint 8 | from .convert_lora_safetensor_to_diffusers import convert_lora_model_level 9 | 10 | 11 | def load_third_party_checkpoints( 12 | pipeline: StableDiffusionPipeline, 13 | third_party_dict: dict, 14 | dreambooth_path: Optional[str] = None, 15 | ): 16 | """ 17 | Modified from https://github.com/open-mmlab/PIA/blob/4b1ee136542e807a13c1adfe52f4e8e5fcc65cdb/animatediff/pipelines/i2v_pipeline.py#L165 18 | """ 19 | vae = third_party_dict.get("vae", None) 20 | lora_list = third_party_dict.get("lora_list", []) 21 | 22 | dreambooth = dreambooth_path or third_party_dict.get("dreambooth", None) 23 | 24 | text_embedding_dict = third_party_dict.get("text_embedding_dict", {}) 25 | 26 | if dreambooth is not None: 27 | dreambooth_state_dict = {} 28 | if dreambooth.endswith(".safetensors"): 29 | with safe_open(dreambooth, framework="pt", device="cpu") as f: 30 | for key in f.keys(): 31 | dreambooth_state_dict[key] = f.get_tensor(key) 32 | else: 33 | dreambooth_state_dict = torch.load(dreambooth, map_location="cpu") 34 | if "state_dict" in dreambooth_state_dict: 35 | dreambooth_state_dict = dreambooth_state_dict["state_dict"] 36 | # load unet 37 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config) 38 | pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 39 | 40 | # load vae from dreambooth (if need) 41 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config) 42 | # add prefix for compiled model 43 | if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]: 44 | converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()} 45 | pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True) 46 | 47 | # load text encoder (if need) 48 | text_encoder_checkpoint = convert_ldm_clip_checkpoint(dreambooth_state_dict) 49 | if text_encoder_checkpoint: 50 | pipeline.text_encoder.load_state_dict(text_encoder_checkpoint, strict=False) 51 | 52 | if vae is not None: 53 | vae_state_dict = {} 54 | if vae.endswith("safetensors"): 55 | with safe_open(vae, framework="pt", device="cpu") as f: 56 | for key in f.keys(): 57 | vae_state_dict[key] = f.get_tensor(key) 58 | elif vae.endswith("ckpt") or vae.endswith("pt"): 59 | vae_state_dict = torch.load(vae, map_location="cpu") 60 | if "state_dict" in vae_state_dict: 61 | vae_state_dict = vae_state_dict["state_dict"] 62 | 63 | vae_state_dict = {f"first_stage_model.{k}": v for k, v in vae_state_dict.items()} 64 | 65 | converted_vae_checkpoint = 
convert_ldm_vae_checkpoint(vae_state_dict, pipeline.vae.config) 66 | # add prefix for compiled model 67 | if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]: 68 | converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()} 69 | pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True) 70 | 71 | if lora_list: 72 | for lora_dict in lora_list: 73 | lora, lora_alpha = lora_dict["lora"], lora_dict["lora_alpha"] 74 | lora_state_dict = {} 75 | with safe_open(lora, framework="pt", device="cpu") as file: 76 | for k in file.keys(): 77 | lora_state_dict[k] = file.get_tensor(k) 78 | pipeline.unet, pipeline.text_encoder = convert_lora_model_level( 79 | lora_state_dict, 80 | pipeline.unet, 81 | pipeline.text_encoder, 82 | alpha=lora_alpha, 83 | ) 84 | print(f'Add LoRA "{lora}":{lora_alpha} to pipeline.') 85 | 86 | if text_embedding_dict is not None: 87 | from diffusers.loaders import TextualInversionLoaderMixin 88 | 89 | assert isinstance( 90 | pipeline, TextualInversionLoaderMixin 91 | ), "Pipeline must inherit from TextualInversionLoaderMixin." 92 | 93 | for token, embedding_path in text_embedding_dict.items(): 94 | pipeline.load_textual_inversion(embedding_path, token) 95 | 96 | return pipeline 97 | 98 | 99 | def load_third_party_unet(unet, third_party_dict: dict, dreambooth_path: Optional[str] = None): 100 | lora_list = third_party_dict.get("lora_list", []) 101 | dreambooth = dreambooth_path or third_party_dict.get("dreambooth", None) 102 | 103 | if dreambooth is not None: 104 | dreambooth_state_dict = {} 105 | if dreambooth.endswith(".safetensors"): 106 | with safe_open(dreambooth, framework="pt", device="cpu") as f: 107 | for key in f.keys(): 108 | dreambooth_state_dict[key] = f.get_tensor(key) 109 | else: 110 | dreambooth_state_dict = torch.load(dreambooth, map_location="cpu") 111 | if "state_dict" in dreambooth_state_dict: 112 | dreambooth_state_dict = dreambooth_state_dict["state_dict"] 113 | # load unet 114 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, unet.config) 115 | unet.load_state_dict(converted_unet_checkpoint, strict=False) 116 | 117 | if lora_list: 118 | for lora_dict in lora_list: 119 | lora, lora_alpha = lora_dict["lora"], lora_dict["lora_alpha"] 120 | lora_state_dict = {} 121 | 122 | with safe_open(lora, framework="pt", device="cpu") as file: 123 | for k in file.keys(): 124 | if "text" not in k: 125 | lora_state_dict[k] = file.get_tensor(k) 126 | unet, _ = convert_lora_model_level( 127 | lora_state_dict, 128 | unet, 129 | None, 130 | alpha=lora_alpha, 131 | ) 132 | print(f'Add LoRA "{lora}":{lora_alpha} to Warmup UNet.') 133 | 134 | return unet 135 | -------------------------------------------------------------------------------- /live2diff/animatediff/converter/convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/PIA/blob/main/animatediff/utils/convert_lora_safetensor_to_diffusers.py and 2 | # https://github.com/guoyww/AnimateDiff/blob/main/animatediff/utils/convert_lora_safetensor_to_diffusers.py 3 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | """Conversion script for the LoRA's safetensors checkpoints.""" 18 | 19 | import torch 20 | 21 | 22 | def convert_lora_model_level( 23 | state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6 24 | ): 25 | """convert lora in model level instead of pipeline leval""" 26 | 27 | visited = [] 28 | 29 | # directly update weight in diffusers model 30 | for key in state_dict: 31 | # it is suggested to print out the key, it usually will be something like below 32 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 33 | 34 | # as we have set the alpha beforehand, so just skip 35 | if ".alpha" in key or key in visited: 36 | continue 37 | 38 | if "text" in key: 39 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 40 | assert text_encoder is not None, "text_encoder must be passed since lora contains text encoder layers" 41 | curr_layer = text_encoder 42 | else: 43 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 44 | curr_layer = unet 45 | 46 | # find the target layer 47 | temp_name = layer_infos.pop(0) 48 | while len(layer_infos) > -1: 49 | try: 50 | curr_layer = curr_layer.__getattr__(temp_name) 51 | if len(layer_infos) > 0: 52 | temp_name = layer_infos.pop(0) 53 | elif len(layer_infos) == 0: 54 | break 55 | except Exception: 56 | if len(temp_name) > 0: 57 | temp_name += "_" + layer_infos.pop(0) 58 | else: 59 | temp_name = layer_infos.pop(0) 60 | 61 | pair_keys = [] 62 | if "lora_down" in key: 63 | pair_keys.append(key.replace("lora_down", "lora_up")) 64 | pair_keys.append(key) 65 | else: 66 | pair_keys.append(key) 67 | pair_keys.append(key.replace("lora_up", "lora_down")) 68 | 69 | # update weight 70 | # NOTE: load lycon, maybe have bugs :( 71 | if "conv_in" in pair_keys[0]: 72 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 73 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 74 | weight_up = weight_up.view(weight_up.size(0), -1) 75 | weight_down = weight_down.view(weight_down.size(0), -1) 76 | shape = list(curr_layer.weight.data.shape) 77 | shape[1] = 4 78 | curr_layer.weight.data[:, :4, ...] 
+= alpha * (weight_up @ weight_down).view(*shape) 79 | elif "conv" in pair_keys[0]: 80 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 81 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 82 | weight_up = weight_up.view(weight_up.size(0), -1) 83 | weight_down = weight_down.view(weight_down.size(0), -1) 84 | shape = list(curr_layer.weight.data.shape) 85 | curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape) 86 | elif len(state_dict[pair_keys[0]].shape) == 4: 87 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 88 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 89 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to( 90 | curr_layer.weight.data.device 91 | ) 92 | else: 93 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 94 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 95 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 96 | 97 | # update visited list 98 | for item in pair_keys: 99 | visited.append(item) 100 | 101 | return unet, text_encoder 102 | -------------------------------------------------------------------------------- /live2diff/animatediff/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/Live2Diff/e40aa03a9a9d17f1232fd7a4566b3ee793e6893f/live2diff/animatediff/models/__init__.py -------------------------------------------------------------------------------- /live2diff/animatediff/models/depth_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | try: 6 | from ...MiDaS.midas.dpt_depth import DPTDepthModel 7 | except ImportError: 8 | print('Please pull the MiDaS submodule via "git submodule update --init --recursive"!') 9 | 10 | 11 | class MidasDetector(nn.Module): 12 | def __init__(self, model_path="./models/dpt_hybrid_384"): 13 | super().__init__() 14 | 15 | self.model = DPTDepthModel(path=model_path, backbone="vitb_rn50_384", non_negative=True) 16 | self.model.requires_grad_(False) 17 | self.model.eval() 18 | 19 | @property 20 | def dtype(self): 21 | return next(self.parameters()).dtype 22 | 23 | @property 24 | def device(self): 25 | return next(self.parameters()).device 26 | 27 | @torch.no_grad() 28 | def forward(self, images: torch.Tensor): 29 | """ 30 | Input: [b, c, h, w] 31 | """ 32 | return self.model(images) 33 | -------------------------------------------------------------------------------- /live2diff/animatediff/models/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | def __init__(self, d_model, dropout=0.0, max_len=32): 10 | super().__init__() 11 | self.dropout = nn.Dropout(p=dropout) 12 | position = torch.arange(max_len).unsqueeze(1) 13 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 14 | pe = torch.zeros(1, max_len, d_model) 15 | pe[0, :, 0::2] = torch.sin(position * div_term) 16 | pe[0, :, 1::2] = torch.cos(position * div_term) 17 | self.register_buffer("pe", pe) 18 | 19 | def forward(self, x, roll: Optional[int] = None, full_video_length: Optional[int] = None): 20 | """ 21 | Support roll for positional encoding. 
22 | We select the first `full_video_length` elements and roll it by `roll`. 23 | And then select the first `x.size(1)` elements and add them to `x`. 24 | 25 | Take full_video_length = 4, roll = 2, and x.size(1) = 1 as example. 26 | 27 | If the original positional encoding is: 28 | [1, 2, 3, 4, 5, 6, 7, 8] 29 | The rolled encoding is: 30 | [3, 4, 1, 2] 31 | And the selected encoding added to input is: 32 | [3, 4] 33 | 34 | """ 35 | if roll is None: 36 | pe = self.pe[:, : x.size(1)] 37 | else: 38 | assert full_video_length is not None, "full_video_length must be passed when roll is not None." 39 | pe = self.pe[:, :full_video_length].roll(shifts=roll, dims=1)[:, : x.size(1)] 40 | x = x + pe 41 | return self.dropout(x) 42 | -------------------------------------------------------------------------------- /live2diff/animatediff/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | 9 | 10 | def zero_module(module): 11 | # Zero out the parameters of a module and return it. 12 | for p in module.parameters(): 13 | p.detach().zero_() 14 | return module 15 | 16 | 17 | class MappingNetwork(nn.Module): 18 | """ 19 | Modified from https://github.com/huggingface/diffusers/blob/196835695ed6fa3ec53b888088d9d5581e8f8e94/src/diffusers/models/controlnet.py#L66-L108 # noqa 20 | """ 21 | 22 | def __init__( 23 | self, 24 | conditioning_embedding_channels: int, 25 | conditioning_channels: int = 3, 26 | block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), 27 | ): 28 | super().__init__() 29 | 30 | self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) 31 | 32 | self.blocks = nn.ModuleList([]) 33 | 34 | for i in range(len(block_out_channels) - 1): 35 | channel_in = block_out_channels[i] 36 | channel_out = block_out_channels[i + 1] 37 | self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)) 38 | self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1)) 39 | 40 | self.conv_out = zero_module( 41 | InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) 42 | ) 43 | 44 | def forward(self, conditioning): 45 | embedding = self.conv_in(conditioning) 46 | embedding = F.silu(embedding) 47 | 48 | for block in self.blocks: 49 | embedding = block(embedding) 50 | embedding = F.silu(embedding) 51 | 52 | embedding = self.conv_out(embedding) 53 | 54 | return embedding 55 | 56 | 57 | class InflatedConv3d(nn.Conv2d): 58 | def forward(self, x): 59 | video_length = x.shape[2] 60 | 61 | x = rearrange(x, "b c f h w -> (b f) c h w") 62 | x = super().forward(x) 63 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 64 | 65 | return x 66 | 67 | 68 | class InflatedGroupNorm(nn.GroupNorm): 69 | def forward(self, x): 70 | video_length = x.shape[2] 71 | 72 | x = rearrange(x, "b c f h w -> (b f) c h w") 73 | x = super().forward(x) 74 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 75 | 76 | return x 77 | 78 | 79 | class Upsample3D(nn.Module): 80 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 81 | super().__init__() 82 | self.channels = channels 83 | self.out_channels = out_channels or channels 84 | self.use_conv = use_conv 85 
| self.use_conv_transpose = use_conv_transpose 86 | self.name = name 87 | 88 | # conv = None 89 | if use_conv_transpose: 90 | raise NotImplementedError 91 | elif use_conv: 92 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 93 | 94 | def forward(self, hidden_states, output_size=None): 95 | assert hidden_states.shape[1] == self.channels 96 | 97 | if self.use_conv_transpose: 98 | raise NotImplementedError 99 | 100 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 101 | dtype = hidden_states.dtype 102 | if dtype == torch.bfloat16: 103 | hidden_states = hidden_states.to(torch.float32) 104 | 105 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 106 | if hidden_states.shape[0] >= 64: 107 | hidden_states = hidden_states.contiguous() 108 | 109 | # if `output_size` is passed we force the interpolation output 110 | # size and do not make use of `scale_factor=2` 111 | if output_size is None: 112 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 113 | else: 114 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 115 | 116 | # If the input is bfloat16, we cast back to bfloat16 117 | if dtype == torch.bfloat16: 118 | hidden_states = hidden_states.to(dtype) 119 | 120 | # if self.use_conv: 121 | # if self.name == "conv": 122 | # hidden_states = self.conv(hidden_states) 123 | # else: 124 | # hidden_states = self.Conv2d_0(hidden_states) 125 | hidden_states = self.conv(hidden_states) 126 | 127 | return hidden_states 128 | 129 | 130 | class Downsample3D(nn.Module): 131 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 132 | super().__init__() 133 | self.channels = channels 134 | self.out_channels = out_channels or channels 135 | self.use_conv = use_conv 136 | self.padding = padding 137 | stride = 2 138 | self.name = name 139 | 140 | if use_conv: 141 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 142 | else: 143 | raise NotImplementedError 144 | 145 | def forward(self, hidden_states): 146 | assert hidden_states.shape[1] == self.channels 147 | if self.use_conv and self.padding == 0: 148 | raise NotImplementedError 149 | 150 | assert hidden_states.shape[1] == self.channels 151 | hidden_states = self.conv(hidden_states) 152 | 153 | return hidden_states 154 | 155 | 156 | class ResnetBlock3D(nn.Module): 157 | def __init__( 158 | self, 159 | *, 160 | in_channels, 161 | out_channels=None, 162 | conv_shortcut=False, 163 | dropout=0.0, 164 | temb_channels=512, 165 | groups=32, 166 | groups_out=None, 167 | pre_norm=True, 168 | eps=1e-6, 169 | non_linearity="swish", 170 | time_embedding_norm="default", 171 | output_scale_factor=1.0, 172 | use_in_shortcut=None, 173 | use_inflated_groupnorm=False, 174 | ): 175 | super().__init__() 176 | self.pre_norm = pre_norm 177 | self.pre_norm = True 178 | self.in_channels = in_channels 179 | out_channels = in_channels if out_channels is None else out_channels 180 | self.out_channels = out_channels 181 | self.use_conv_shortcut = conv_shortcut 182 | self.time_embedding_norm = time_embedding_norm 183 | self.output_scale_factor = output_scale_factor 184 | 185 | if groups_out is None: 186 | groups_out = groups 187 | 188 | assert use_inflated_groupnorm is not None 189 | if use_inflated_groupnorm: 190 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 191 | 
else: 192 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 193 | 194 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 195 | 196 | if temb_channels is not None: 197 | if self.time_embedding_norm == "default": 198 | time_emb_proj_out_channels = out_channels 199 | elif self.time_embedding_norm == "scale_shift": 200 | time_emb_proj_out_channels = out_channels * 2 201 | else: 202 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 203 | 204 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 205 | else: 206 | self.time_emb_proj = None 207 | 208 | if use_inflated_groupnorm: 209 | self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 210 | else: 211 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 212 | 213 | self.dropout = torch.nn.Dropout(dropout) 214 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 215 | 216 | if non_linearity == "swish": 217 | self.nonlinearity = lambda x: F.silu(x) 218 | elif non_linearity == "mish": 219 | self.nonlinearity = Mish() 220 | elif non_linearity == "silu": 221 | self.nonlinearity = nn.SiLU() 222 | 223 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 224 | 225 | self.conv_shortcut = None 226 | if self.use_in_shortcut: 227 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 228 | 229 | def forward(self, input_tensor, temb): 230 | hidden_states = input_tensor 231 | 232 | hidden_states = self.norm1(hidden_states) 233 | hidden_states = self.nonlinearity(hidden_states) 234 | 235 | hidden_states = self.conv1(hidden_states) 236 | 237 | if temb is not None: 238 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 239 | 240 | if temb is not None and self.time_embedding_norm == "default": 241 | hidden_states = hidden_states + temb 242 | 243 | hidden_states = self.norm2(hidden_states) 244 | 245 | if temb is not None and self.time_embedding_norm == "scale_shift": 246 | scale, shift = torch.chunk(temb, 2, dim=1) 247 | hidden_states = hidden_states * (1 + scale) + shift 248 | 249 | hidden_states = self.nonlinearity(hidden_states) 250 | 251 | hidden_states = self.dropout(hidden_states) 252 | hidden_states = self.conv2(hidden_states) 253 | 254 | if self.conv_shortcut is not None: 255 | input_tensor = self.conv_shortcut(input_tensor) 256 | 257 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 258 | 259 | return output_tensor 260 | 261 | 262 | class Mish(torch.nn.Module): 263 | def forward(self, hidden_states): 264 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 265 | -------------------------------------------------------------------------------- /live2diff/animatediff/models/stream_motion_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | 5 | from .attention import CrossAttention 6 | from .positional_encoding import PositionalEncoding 7 | 8 | 9 | class StreamTemporalAttention(CrossAttention): 10 | """ 11 | 12 | * window_size: The max length of attention window. 13 | * sink_size: The number sink token. 
14 | * positional_rule: absolute, relative 15 | 16 | Therefore, the seq length of temporal self-attention will be: 17 | sink_length + cache_size 18 | 19 | """ 20 | 21 | def __init__( 22 | self, 23 | attention_mode=None, 24 | cross_frame_attention_mode=None, 25 | temporal_position_encoding=False, 26 | temporal_position_encoding_max_len=32, 27 | window_size=8, 28 | sink_size=0, 29 | *args, 30 | **kwargs, 31 | ): 32 | super().__init__(*args, **kwargs) 33 | 34 | self.attention_mode = self._orig_attention_mode = attention_mode 35 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 36 | 37 | self.pos_encoder = PositionalEncoding( 38 | kwargs["query_dim"], 39 | dropout=0.0, 40 | max_len=temporal_position_encoding_max_len, 41 | ) 42 | 43 | self.window_size = window_size 44 | self.sink_size = sink_size 45 | self.cache_size = self.window_size - self.sink_size 46 | assert self.cache_size >= 0, ( 47 | "cache_size must be greater or equal to 0. Please check your configuration. " 48 | f"window_size: {window_size}, sink_size: {sink_size}, " 49 | f"cache_size: {self.cache_size}" 50 | ) 51 | 52 | self.motion_module_idx = None 53 | 54 | def set_index(self, idx): 55 | self.motion_module_idx = idx 56 | 57 | @torch.no_grad() 58 | def set_cache(self, denoising_steps_num: int): 59 | """ 60 | larger buffer index means cleaner latent 61 | """ 62 | device = next(self.parameters()).device 63 | dtype = next(self.parameters()).dtype 64 | 65 | # [t, 2, hw, L, c], 2 means k and v 66 | kv_cache = torch.zeros( 67 | denoising_steps_num, 68 | 2, 69 | self.h * self.w, 70 | self.window_size, 71 | self.kv_channels, 72 | device=device, 73 | dtype=dtype, 74 | ) 75 | self.denoising_steps_num = denoising_steps_num 76 | 77 | return kv_cache 78 | 79 | @torch.no_grad() 80 | def prepare_pe_buffer(self): 81 | """In AnimateDiff, Temporal Self-attention use absolute positional encoding: 82 | q = w_q * (x + pe) + bias 83 | k = w_k * (x + pe) + bias 84 | v = w_v * (x + pe) + bias 85 | 86 | If we want to conduct relative positional encoding with kv-cache, we should pre-calcute 87 | `w_q/k/v * pe` and then cache `w_q/k/v * x + bias` 88 | """ 89 | 90 | pe_list = self.pos_encoder.pe[:, : self.window_size] # [1, window_size, ch] 91 | q_pe = F.linear(pe_list, self.to_q.weight) 92 | k_pe = F.linear(pe_list, self.to_k.weight) 93 | v_pe = F.linear(pe_list, self.to_v.weight) 94 | 95 | self.register_buffer("q_pe", q_pe) 96 | self.register_buffer("k_pe", k_pe) 97 | self.register_buffer("v_pe", v_pe) 98 | 99 | def prepare_qkv_full_and_cache(self, hidden_states, kv_cache, pe_idx, update_idx): 100 | """ 101 | hidden_states: [(N * bhw), F, c], 102 | kv_cache: [2, N, hw, L, c] 103 | 104 | * for warmup case: `N` should be 1 and `F` should be warmup_size (`sink_size`) 105 | * for streaming case: `N` should be `denoising_steps_num` and `F` should be `chunk_size` 106 | 107 | """ 108 | q_layer = self.to_q(hidden_states) 109 | k_layer = self.to_k(hidden_states) 110 | v_layer = self.to_v(hidden_states) 111 | 112 | q_layer = rearrange(q_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num) 113 | k_layer = rearrange(k_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num) 114 | v_layer = rearrange(v_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num) 115 | 116 | # onnx & trt friendly indexing 117 | for idx in range(self.denoising_steps_num): 118 | kv_cache[idx, 0, :, update_idx[idx]] = k_layer[idx, :, 0] 119 | kv_cache[idx, 1, :, update_idx[idx]] = v_layer[idx, :, 0] 120 | 121 | k_full = kv_cache[:, 0] 122 | v_full = 
kv_cache[:, 1] 123 | 124 | kv_idx = pe_idx 125 | q_idx = torch.stack([kv_idx[idx, update_idx[idx]] for idx in range(self.denoising_steps_num)]).unsqueeze_( 126 | 1 127 | ) # [timesteps, 1] 128 | 129 | pe_k = torch.cat( 130 | [self.k_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0 131 | ) # [n, window_size, c] 132 | pe_v = torch.cat( 133 | [self.v_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0 134 | ) # [n, window_size, c] 135 | pe_q = torch.cat( 136 | [self.q_pe.index_select(1, q_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0 137 | ) # [n, window_size, c] 138 | 139 | q_layer = q_layer + pe_q.unsqueeze(1) 140 | k_full = k_full + pe_k.unsqueeze(1) 141 | v_full = v_full + pe_v.unsqueeze(1) 142 | 143 | q_layer = rearrange(q_layer, "n bhw f c -> (n bhw) f c") 144 | k_full = rearrange(k_full, "n bhw f c -> (n bhw) f c") 145 | v_full = rearrange(v_full, "n bhw f c -> (n bhw) f c") 146 | 147 | return q_layer, k_full, v_full 148 | 149 | def forward( 150 | self, 151 | hidden_states, 152 | encoder_hidden_states=None, 153 | attention_mask=None, 154 | video_length=None, 155 | temporal_attention_mask=None, 156 | kv_cache=None, 157 | pe_idx=None, 158 | update_idx=None, 159 | *args, 160 | **kwargs, 161 | ): 162 | """ 163 | temporal_attention_mask: attention mask specific for the temporal self-attention. 164 | """ 165 | 166 | d = hidden_states.shape[1] 167 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 168 | 169 | if self.group_norm is not None: 170 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 171 | 172 | query_layer, key_full, value_full = self.prepare_qkv_full_and_cache( 173 | hidden_states, kv_cache, pe_idx, update_idx 174 | ) 175 | 176 | # [(n * hw * b), f, c] -> [(n * hw * b * head), f, c // head] 177 | query_layer = self.reshape_heads_to_batch_dim(query_layer) 178 | key_full = self.reshape_heads_to_batch_dim(key_full) 179 | value_full = self.reshape_heads_to_batch_dim(value_full) 180 | 181 | if temporal_attention_mask is not None: 182 | q_size = query_layer.shape[1] 183 | # [n, self.window_size] -> [n, hw, q_size, window_size] 184 | temporal_attention_mask_ = temporal_attention_mask[:, None, None, :].repeat(1, self.h * self.w, q_size, 1) 185 | temporal_attention_mask_ = rearrange(temporal_attention_mask_, "n hw Q KV -> (n hw) Q KV") 186 | temporal_attention_mask_ = temporal_attention_mask_.repeat_interleave(self.heads, dim=0) 187 | else: 188 | temporal_attention_mask_ = None 189 | 190 | # attention, what we cannot get enough of 191 | if hasattr(F, "scaled_dot_product_attention"): 192 | hidden_states = self._memory_efficient_attention_pt20( 193 | query_layer, key_full, value_full, attention_mask=temporal_attention_mask_ 194 | ) 195 | 196 | elif self._use_memory_efficient_attention_xformers: 197 | hidden_states = self._memory_efficient_attention_xformers( 198 | query_layer, key_full, value_full, attention_mask=temporal_attention_mask_ 199 | ) 200 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 201 | hidden_states = hidden_states.to(query_layer.dtype) 202 | else: 203 | hidden_states = self._attention(query_layer, key_full, value_full, temporal_attention_mask_) 204 | 205 | # linear proj 206 | hidden_states = self.to_out[0](hidden_states) 207 | 208 | # dropout 209 | hidden_states = self.to_out[1](hidden_states) 210 | 211 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 212 | 213 | 
return hidden_states 214 | -------------------------------------------------------------------------------- /live2diff/animatediff/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_animatediff_depth import AnimationDepthPipeline 2 | 3 | 4 | __all__ = ["AnimationDepthPipeline"] 5 | -------------------------------------------------------------------------------- /live2diff/animatediff/pipeline/loader.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | import torch 4 | from diffusers.loaders.lora import LoraLoaderMixin 5 | from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT 6 | from diffusers.utils import USE_PEFT_BACKEND 7 | 8 | 9 | class LoraLoaderWithWarmup(LoraLoaderMixin): 10 | unet_warmup_name = "unet_warmup" 11 | 12 | def load_lora_weights( 13 | self, 14 | pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], 15 | adapter_name=None, 16 | **kwargs, 17 | ): 18 | # load lora for text encoder and unet-streaming 19 | super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name=adapter_name, **kwargs) 20 | 21 | # load lora for unet-warmup 22 | state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) 23 | low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) 24 | 25 | self.load_lora_into_unet( 26 | state_dict, 27 | network_alphas=network_alphas, 28 | unet=getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup, 29 | low_cpu_mem_usage=low_cpu_mem_usage, 30 | adapter_name=adapter_name, 31 | _pipeline=self, 32 | ) 33 | 34 | def fuse_lora( 35 | self, 36 | fuse_unet: bool = True, 37 | fuse_text_encoder: bool = True, 38 | lora_scale: float = 1.0, 39 | safe_fusing: bool = False, 40 | adapter_names: Optional[List[str]] = None, 41 | ): 42 | # fuse lora for text encoder and unet-streaming 43 | super().fuse_lora(fuse_unet, fuse_text_encoder, lora_scale, safe_fusing, adapter_names) 44 | 45 | # fuse lora for unet-warmup 46 | if fuse_unet: 47 | unet_warmup = ( 48 | getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup 49 | ) 50 | unet_warmup.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) 51 | 52 | def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True): 53 | # unfuse lora for text encoder and unet-streaming 54 | super().unfuse_lora(unfuse_unet, unfuse_text_encoder) 55 | 56 | # unfuse lora for unet-warmup 57 | if unfuse_unet: 58 | unet_warmup = ( 59 | getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup 60 | ) 61 | if not USE_PEFT_BACKEND: 62 | unet_warmup.unfuse_lora() 63 | else: 64 | from peft.tuners.tuners_utils import BaseTunerLayer 65 | 66 | for module in unet_warmup.modules(): 67 | if isinstance(module, BaseTunerLayer): 68 | module.unmerge() 69 | -------------------------------------------------------------------------------- /live2diff/animatediff/pipeline/pipeline_animatediff_depth.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/open-mmlab/PIA/blob/main/animatediff/pipelines/i2v_pipeline.py 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | from diffusers.configuration_utils import 
FrozenDict 9 | from diffusers.loaders import TextualInversionLoaderMixin 10 | from diffusers.models import AutoencoderKL 11 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 12 | from diffusers.schedulers import ( 13 | DDIMScheduler, 14 | DPMSolverMultistepScheduler, 15 | EulerAncestralDiscreteScheduler, 16 | EulerDiscreteScheduler, 17 | LMSDiscreteScheduler, 18 | PNDMScheduler, 19 | ) 20 | from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging 21 | from packaging import version 22 | from transformers import CLIPTextModel, CLIPTokenizer 23 | 24 | from ..models.depth_utils import MidasDetector 25 | from ..models.unet_depth_streaming import UNet3DConditionStreamingModel 26 | from .loader import LoraLoaderWithWarmup 27 | 28 | 29 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 30 | 31 | 32 | @dataclass 33 | class AnimationPipelineOutput(BaseOutput): 34 | videos: Union[torch.Tensor, np.ndarray] 35 | input_images: Optional[Union[torch.Tensor, np.ndarray]] = None 36 | 37 | 38 | class AnimationDepthPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderWithWarmup): 39 | _optional_components = [] 40 | 41 | def __init__( 42 | self, 43 | vae: AutoencoderKL, 44 | text_encoder: CLIPTextModel, 45 | tokenizer: CLIPTokenizer, 46 | unet: UNet3DConditionStreamingModel, 47 | depth_model: MidasDetector, 48 | scheduler: Union[ 49 | DDIMScheduler, 50 | PNDMScheduler, 51 | LMSDiscreteScheduler, 52 | EulerDiscreteScheduler, 53 | EulerAncestralDiscreteScheduler, 54 | DPMSolverMultistepScheduler, 55 | ], 56 | ): 57 | super().__init__() 58 | 59 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 60 | deprecation_message = ( 61 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 62 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 63 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 64 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 65 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 66 | " file" 67 | ) 68 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 69 | new_config = dict(scheduler.config) 70 | new_config["steps_offset"] = 1 71 | scheduler._internal_dict = FrozenDict(new_config) 72 | 73 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 74 | deprecation_message = ( 75 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 76 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 77 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 78 | " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 79 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 80 | ) 81 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 82 | new_config = dict(scheduler.config) 83 | new_config["clip_sample"] = False 84 | scheduler._internal_dict = FrozenDict(new_config) 85 | 86 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 87 | version.parse(unet.config._diffusers_version).base_version 88 | ) < version.parse("0.9.0.dev0") 89 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 90 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 91 | deprecation_message = ( 92 | "The configuration file of the unet has set the default `sample_size` to smaller than" 93 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 94 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 95 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 96 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 97 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 98 | " in the config might lead to incorrect results in future versions. If you have downloaded this" 99 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 100 | " the `unet/config.json` file" 101 | ) 102 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 103 | new_config = dict(unet.config) 104 | new_config["sample_size"] = 64 105 | unet._internal_dict = FrozenDict(new_config) 106 | 107 | self.register_modules( 108 | vae=vae, 109 | text_encoder=text_encoder, 110 | tokenizer=tokenizer, 111 | unet=unet, 112 | depth_model=depth_model, 113 | scheduler=scheduler, 114 | ) 115 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 116 | self.log_denoising_mean = False 117 | 118 | def enable_vae_slicing(self): 119 | self.vae.enable_slicing() 120 | 121 | def disable_vae_slicing(self): 122 | self.vae.disable_slicing() 123 | 124 | def enable_sequential_cpu_offload(self, gpu_id=0): 125 | if is_accelerate_available(): 126 | from accelerate import cpu_offload 127 | else: 128 | raise ImportError("Please install accelerate via `pip install accelerate`") 129 | 130 | device = torch.device(f"cuda:{gpu_id}") 131 | 132 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 133 | if cpu_offloaded_model is not None: 134 | cpu_offload(cpu_offloaded_model, device) 135 | 136 | @property 137 | def _execution_device(self): 138 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 139 | return self.device 140 | for module in self.unet.modules(): 141 | if ( 142 | hasattr(module, "_hf_hook") 143 | and hasattr(module._hf_hook, "execution_device") 144 | and module._hf_hook.execution_device is not None 145 | ): 146 | return torch.device(module._hf_hook.execution_device) 147 | return self.device 148 | 149 | def _encode_prompt( 150 | self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt, clip_skip=None 151 | ): 152 | batch_size = len(prompt) if isinstance(prompt, list) else 1 153 | 154 | text_inputs = self.tokenizer( 155 | prompt, 156 | padding="max_length", 157 | 
max_length=self.tokenizer.model_max_length, 158 | truncation=True, 159 | return_tensors="pt", 160 | ) 161 | text_input_ids = text_inputs.input_ids 162 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 163 | 164 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 165 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 166 | logger.warning( 167 | "The following part of your input was truncated because CLIP can only handle sequences up to" 168 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 169 | ) 170 | 171 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 172 | attention_mask = text_inputs.attention_mask.to(device) 173 | else: 174 | attention_mask = None 175 | 176 | if clip_skip is None: 177 | text_embeddings = self.text_encoder( 178 | text_input_ids.to(device), 179 | attention_mask=attention_mask, 180 | ) 181 | text_embeddings = text_embeddings[0] 182 | else: 183 | # support ckip skip here, suitable for model based on NAI~ 184 | text_embeddings = self.text_encoder( 185 | text_input_ids.to(device), 186 | attention_mask=attention_mask, 187 | output_hidden_states=True, 188 | ) 189 | text_embeddings = text_embeddings[-1][-(clip_skip + 1)] 190 | text_embeddings = self.text_encoder.text_model.final_layer_norm(text_embeddings) 191 | 192 | # duplicate text embeddings for each generation per prompt, using mps friendly method 193 | bs_embed, seq_len, _ = text_embeddings.shape 194 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 195 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 196 | 197 | # get unconditional embeddings for classifier free guidance 198 | if do_classifier_free_guidance: 199 | uncond_tokens: List[str] 200 | if negative_prompt is None: 201 | uncond_tokens = [""] * batch_size 202 | elif type(prompt) is not type(negative_prompt): 203 | raise TypeError( 204 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 205 | f" {type(prompt)}." 206 | ) 207 | elif isinstance(negative_prompt, str): 208 | uncond_tokens = [negative_prompt] 209 | elif batch_size != len(negative_prompt): 210 | raise ValueError( 211 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 212 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 213 | " the batch size of `prompt`." 
214 | ) 215 | else: 216 | uncond_tokens = negative_prompt 217 | 218 | max_length = text_input_ids.shape[-1] 219 | uncond_input = self.tokenizer( 220 | uncond_tokens, 221 | padding="max_length", 222 | max_length=max_length, 223 | truncation=True, 224 | return_tensors="pt", 225 | ) 226 | 227 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 228 | attention_mask = uncond_input.attention_mask.to(device) 229 | else: 230 | attention_mask = None 231 | 232 | uncond_embeddings = self.text_encoder( 233 | uncond_input.input_ids.to(device), 234 | attention_mask=attention_mask, 235 | ) 236 | uncond_embeddings = uncond_embeddings[0] 237 | 238 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 239 | seq_len = uncond_embeddings.shape[1] 240 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 241 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 242 | 243 | # For classifier free guidance, we need to do two forward passes. 244 | # Here we concatenate the unconditional and text embeddings into a single batch 245 | # to avoid doing two forward passes 246 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 247 | 248 | return text_embeddings 249 | 250 | @classmethod 251 | def build_pipeline(cls, config_path: str, dreambooth: Optional[str] = None): 252 | """We build pipeline from config path""" 253 | from omegaconf import OmegaConf 254 | 255 | from ...utils.config import load_config 256 | from ..converter import load_third_party_checkpoints 257 | from ..models.unet_depth_streaming import UNet3DConditionStreamingModel 258 | 259 | cfg = load_config(config_path) 260 | pretrained_model_path = cfg.pretrained_model_path 261 | unet_additional_kwargs = cfg.get("unet_additional_kwargs", {}) 262 | noise_scheduler_kwargs = cfg.noise_scheduler_kwargs 263 | third_party_dict = cfg.get("third_party_dict", {}) 264 | 265 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) 266 | 267 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 268 | vae = vae.to(device="cuda", dtype=torch.bfloat16) 269 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 270 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 271 | text_encoder = text_encoder.to(device="cuda", dtype=torch.float16) 272 | 273 | unet = UNet3DConditionStreamingModel.from_pretrained_2d( 274 | pretrained_model_path, 275 | subfolder="unet", 276 | unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) if unet_additional_kwargs else {}, 277 | ) 278 | 279 | motion_module_path = cfg.motion_module_path 280 | # load motion module to unet 281 | mm_checkpoint = torch.load(motion_module_path, map_location="cuda") 282 | if "global_step" in mm_checkpoint: 283 | print(f"global_step: {mm_checkpoint['global_step']}") 284 | state_dict = mm_checkpoint["state_dict"] if "state_dict" in mm_checkpoint else mm_checkpoint 285 | # NOTE: hard code here: remove `grid` from state_dict 286 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if "grid" not in k} 287 | 288 | m, u = unet.load_state_dict(state_dict, strict=False) 289 | assert len(u) == 0, f"Find unexpected keys ({len(u)}): {u}" 290 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") 291 | 292 | unet = unet.to(dtype=torch.float16) 293 | depth_model = 
MidasDetector(cfg.depth_model_path).to(device="cuda", dtype=torch.float16) 294 | 295 | pipeline = cls( 296 | unet=unet, 297 | vae=vae, 298 | tokenizer=tokenizer, 299 | text_encoder=text_encoder, 300 | depth_model=depth_model, 301 | scheduler=noise_scheduler, 302 | ) 303 | pipeline = load_third_party_checkpoints(pipeline, third_party_dict, dreambooth) 304 | 305 | return pipeline 306 | 307 | @classmethod 308 | def build_warmup_unet(cls, config_path: str, dreambooth: Optional[str] = None): 309 | from omegaconf import OmegaConf 310 | 311 | from ...utils.config import load_config 312 | from ..converter import load_third_party_unet 313 | from ..models.unet_depth_warmup import UNet3DConditionWarmupModel 314 | 315 | cfg = load_config(config_path) 316 | pretrained_model_path = cfg.pretrained_model_path 317 | unet_additional_kwargs = cfg.get("unet_additional_kwargs", {}) 318 | third_party_dict = cfg.get("third_party_dict", {}) 319 | 320 | unet = UNet3DConditionWarmupModel.from_pretrained_2d( 321 | pretrained_model_path, 322 | subfolder="unet", 323 | unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) if unet_additional_kwargs else {}, 324 | ) 325 | motion_module_path = cfg.motion_module_path 326 | # load motion module to unet 327 | mm_checkpoint = torch.load(motion_module_path, map_location="cpu") 328 | if "global_step" in mm_checkpoint: 329 | print(f"global_step: {mm_checkpoint['global_step']}") 330 | state_dict = mm_checkpoint["state_dict"] if "state_dict" in mm_checkpoint else mm_checkpoint 331 | # NOTE: hard code here: remove `grid` from state_dict 332 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if "grid" not in k} 333 | 334 | m, u = unet.load_state_dict(state_dict, strict=False) 335 | assert len(u) == 0, f"Find unexpected keys ({len(u)}): {u}" 336 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") 337 | 338 | unet = load_third_party_unet(unet, third_party_dict, dreambooth) 339 | return unet 340 | 341 | def prepare_cache(self, height: int, width: int, denoising_steps_num: int): 342 | vae = self.vae 343 | scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) 344 | self.unet.set_info_for_attn(height // scale_factor, width // scale_factor) 345 | kv_cache_list = self.unet.prepare_cache(denoising_steps_num) 346 | return kv_cache_list 347 | 348 | def prepare_warmup_unet(self, height: int, width: int, unet): 349 | vae = self.vae 350 | scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) 351 | unet.set_info_for_attn(height // scale_factor, width // scale_factor) 352 | -------------------------------------------------------------------------------- /live2diff/image_filter.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | 7 | class SimilarImageFilter: 8 | def __init__(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None: 9 | self.threshold = threshold 10 | self.prev_tensor = None 11 | self.cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6) 12 | self.max_skip_frame = max_skip_frame 13 | self.skip_count = 0 14 | 15 | def __call__(self, x: torch.Tensor) -> Optional[torch.Tensor]: 16 | if self.prev_tensor is None: 17 | self.prev_tensor = x.detach().clone() 18 | return x 19 | else: 20 | cos_sim = self.cos(self.prev_tensor.reshape(-1), x.reshape(-1)).item() 21 | sample = random.uniform(0, 1) 22 | if self.threshold >= 1: 23 | skip_prob = 0 24 | else: 25 | skip_prob = max(0, 1 - (1 - cos_sim) / (1 - self.threshold)) 
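            # skip_prob rises linearly as cos_sim approaches 1:
            #   cos_sim <= threshold -> skip_prob = 0 (the frame is always kept)
            #   cos_sim == 1.0       -> skip_prob = 1 (the frame is always a skip candidate)
            # e.g. with threshold=0.98 and cos_sim=0.99, skip_prob = 1 - 0.01 / 0.02 = 0.5,
            # so a near-duplicate frame is skipped about half the time (capped by max_skip_frame).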
26 | 27 | # not skip frame 28 | if skip_prob < sample: 29 | self.prev_tensor = x.detach().clone() 30 | return x 31 | # skip frame 32 | else: 33 | if self.skip_count > self.max_skip_frame: 34 | self.skip_count = 0 35 | self.prev_tensor = x.detach().clone() 36 | return x 37 | else: 38 | self.skip_count += 1 39 | return None 40 | 41 | def set_threshold(self, threshold: float) -> None: 42 | self.threshold = threshold 43 | 44 | def set_max_skip_frame(self, max_skip_frame: float) -> None: 45 | self.max_skip_frame = max_skip_frame 46 | -------------------------------------------------------------------------------- /live2diff/image_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import PIL.Image 5 | import torch 6 | import torchvision 7 | 8 | 9 | def denormalize(images: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: 10 | """ 11 | Denormalize an image array to [0,1]. 12 | """ 13 | return (images / 2 + 0.5).clamp(0, 1) 14 | 15 | 16 | def pt_to_numpy(images: torch.Tensor) -> np.ndarray: 17 | """ 18 | Convert a PyTorch tensor to a NumPy image. 19 | """ 20 | images = images.cpu().permute(0, 2, 3, 1).float().numpy() 21 | return images 22 | 23 | 24 | def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: 25 | """ 26 | Convert a NumPy image or a batch of images to a PIL image. 27 | """ 28 | if images.ndim == 3: 29 | images = images[None, ...] 30 | images = (images * 255).round().astype("uint8") 31 | if images.shape[-1] == 1: 32 | # special case for grayscale (single channel) images 33 | pil_images = [PIL.Image.fromarray(image.squeeze(), mode="L") for image in images] 34 | else: 35 | pil_images = [PIL.Image.fromarray(image) for image in images] 36 | 37 | return pil_images 38 | 39 | 40 | def postprocess_image( 41 | image: torch.Tensor, 42 | output_type: str = "pil", 43 | do_denormalize: Optional[List[bool]] = None, 44 | ) -> Union[torch.Tensor, np.ndarray, PIL.Image.Image]: 45 | if not isinstance(image, torch.Tensor): 46 | raise ValueError( 47 | f"Input for postprocessing is in incorrect format: {type(image)}. 
We only support pytorch tensor" 48 | ) 49 | 50 | if output_type == "latent": 51 | return image 52 | 53 | do_normalize_flg = True 54 | if do_denormalize is None: 55 | do_denormalize = [do_normalize_flg] * image.shape[0] 56 | 57 | image = torch.stack([denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]) 58 | 59 | if output_type == "pt": 60 | return image 61 | 62 | image = pt_to_numpy(image) 63 | 64 | if output_type == "np": 65 | return image 66 | 67 | if output_type == "pil": 68 | return numpy_to_pil(image) 69 | 70 | 71 | def process_image( 72 | image_pil: PIL.Image.Image, range: Tuple[int, int] = (-1, 1) 73 | ) -> Tuple[torch.Tensor, PIL.Image.Image]: 74 | image = torchvision.transforms.ToTensor()(image_pil) 75 | r_min, r_max = range[0], range[1] 76 | image = image * (r_max - r_min) + r_min 77 | return image[None, ...], image_pil 78 | 79 | 80 | def pil2tensor(image_pil: PIL.Image.Image) -> torch.Tensor: 81 | height = image_pil.height 82 | width = image_pil.width 83 | imgs = [] 84 | img, _ = process_image(image_pil) 85 | imgs.append(img) 86 | imgs = torch.vstack(imgs) 87 | images = torch.nn.functional.interpolate(imgs, size=(height, width), mode="bilinear") 88 | image_tensors = images.to(torch.float16) 89 | return image_tensors 90 | -------------------------------------------------------------------------------- /live2diff/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/Live2Diff/e40aa03a9a9d17f1232fd7a4566b3ee793e6893f/live2diff/utils/__init__.py -------------------------------------------------------------------------------- /live2diff/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | from omegaconf import OmegaConf 5 | 6 | 7 | config_suffix = [".yaml"] 8 | 9 | 10 | def load_config(config: str) -> OmegaConf: 11 | config = OmegaConf.load(config) 12 | base_config = config.pop("base", None) 13 | 14 | if base_config: 15 | config = OmegaConf.merge(OmegaConf.load(base_config), config) 16 | 17 | return config 18 | 19 | 20 | def dump_config(config: OmegaConf, save_path: str = None): 21 | from omegaconf import Container 22 | 23 | if isinstance(config, Container): 24 | if not save_path.endswith(".yaml"): 25 | save_dir = save_path 26 | save_path = osp.join(save_dir, "config.yaml") 27 | else: 28 | save_dir = osp.basename(config) 29 | os.makedirs(save_dir, exist_ok=True) 30 | OmegaConf.save(config, save_path) 31 | 32 | else: 33 | raise TypeError("Only support saving `Config` from `OmegaConf`.") 34 | 35 | print(f"Dump Config to {save_path}.") 36 | -------------------------------------------------------------------------------- /live2diff/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import imageio 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | from einops import rearrange 9 | from PIL import Image 10 | 11 | 12 | def read_video_frames(folder: str, height=None, width=None): 13 | """ 14 | Read video frames from the given folder. 
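    Frames must be named by an integer index (e.g. "0.png", "1.png", ...), because they
    are sorted via `int(osp.splitext(name)[0])`; supported extensions are .png, .jpg, .jpeg.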
15 | 16 | Output: 17 | frames, in [0, 255], uint8, THWC 18 | """ 19 | _SUPPORTED_EXTENSIONS = [".png", ".jpg", ".jpeg"] 20 | 21 | frames = [f for f in os.listdir(folder) if osp.splitext(f)[1] in _SUPPORTED_EXTENSIONS] 22 | # sort frames 23 | sorted_frames = sorted(frames, key=lambda x: int(osp.splitext(x)[0])) 24 | sorted_frames = [osp.join(folder, f) for f in sorted_frames] 25 | 26 | if height is not None and width is not None: 27 | sorted_frames = [np.array(Image.open(f).resize((width, height))) for f in sorted_frames] 28 | else: 29 | sorted_frames = [np.array(Image.open(f)) for f in sorted_frames] 30 | sorted_frames = torch.stack([torch.from_numpy(f) for f in sorted_frames], dim=0) 31 | return sorted_frames 32 | 33 | 34 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 35 | videos = rearrange(videos, "b c t h w -> t b c h w") 36 | outputs = [] 37 | for x in videos: 38 | x = torchvision.utils.make_grid(x, nrow=n_rows) 39 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 40 | if rescale: 41 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 42 | x = (x * 255).numpy().astype(np.uint8) 43 | outputs.append(x) 44 | 45 | parent_dir = os.path.dirname(path) 46 | if parent_dir != "": 47 | os.makedirs(parent_dir, exist_ok=True) 48 | imageio.mimsave(path, outputs, fps=fps, loop=0) 49 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | # Never enforce `E501` (line length violations). 3 | lint.ignore = ["C901", "E501", "E741", "F402", "F403", "F405", "F823"] 4 | lint.select = ["C", "E", "F", "I", "W"] 5 | line-length = 119 6 | 7 | # Ignore import violations in all `__init__.py` files. 8 | [tool.ruff.lint.per-file-ignores] 9 | "__init__.py" = ["E402", "F401", "F811"] 10 | 11 | [tool.ruff.lint.isort] 12 | lines-after-imports = 2 13 | known-first-party = ["live2diff"] 14 | 15 | [tool.ruff.format] 16 | # Like Black, use double quotes for strings. 17 | quote-style = "double" 18 | 19 | # Like Black, indent with spaces, rather than tabs. 20 | indent-style = "space" 21 | 22 | # Like Black, respect magic trailing commas. 23 | skip-magic-trailing-comma = false 24 | 25 | # Like Black, automatically detect the appropriate line ending. 26 | line-ending = "auto" 27 | 28 | [build-system] 29 | requires = ["setuptools"] 30 | build-backend = "setuptools.build_meta" 31 | -------------------------------------------------------------------------------- /scripts/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | TOKEN=$2 3 | 4 | download_disney() { 5 | echo "Download checkpoint for Disney..." 6 | wget https://civitai.com/api/download/models/69832\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 7 | } 8 | 9 | download_moxin () { 10 | echo "Download checkpoints for MoXin..." 11 | wget https://civitai.com/api/download/models/106289\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 12 | wget https://civitai.com/api/download/models/14856\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate 13 | } 14 | 15 | download_pixart () { 16 | echo "Download checkpoint for PixArt..." 17 | wget https://civitai.com/api/download/models/220049\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 18 | } 19 | 20 | download_origami () { 21 | echo "Download checkpoints for origami..." 
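    # The "?token=${TOKEN}" query string authenticates against the civitai.com download API,
    # and --content-disposition keeps the checkpoint filename chosen by the server.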
22 | wget https://civitai.com/api/download/models/270085\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 23 | wget https://civitai.com/api/download/models/266928\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate 24 | } 25 | 26 | download_threeDelicacy () { 27 | echo "Download checkpoints for threeDelicacy..." 28 | wget https://civitai.com/api/download/models/36473\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 29 | } 30 | 31 | download_toonyou () { 32 | echo "Download checkpoint for Toonyou..." 33 | wget https://civitai.com/api/download/models/125771\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 34 | } 35 | 36 | download_zaum () { 37 | echo "Download checkpoints for Zaum..." 38 | wget https://civitai.com/api/download/models/428862\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 39 | wget https://civitai.com/api/download/models/18989\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate 40 | } 41 | 42 | download_felted () { 43 | echo "Download checkpoints for Felted..." 44 | wget https://civitai.com/api/download/models/428862\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 45 | wget https://civitai.com/api/download/models/86739\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate 46 | } 47 | 48 | if [ -z "$1" ]; then 49 | echo "Please input the model you want to download." 50 | echo "Supported model: all, disney, moxin, pixart, paperArt, threeDelicacy, toonyou, zaum." 51 | exit 1 52 | fi 53 | 54 | declare -A download_func=( 55 | ["disney"]="download_disney" 56 | ["moxin"]="download_moxin" 57 | ["pixart"]="download_pixart" 58 | ["origami"]="download_origami" 59 | ["threeDelicacy"]="download_threeDelicacy" 60 | ["toonyou"]="download_toonyou" 61 | ["zaum"]="download_zaum" 62 | ["felted"]="download_felted" 63 | ) 64 | 65 | execute_function() { 66 | local key="$1" 67 | if [[ -n "${download_func[$key]}" ]]; then 68 | ${download_func[$key]} 69 | else 70 | echo "Function not found for key: $key" 71 | fi 72 | } 73 | 74 | 75 | for arg in "$@"; do 76 | case "$arg" in 77 | disney|moxin|pixart|origami|threeDelicacy|toonyou|zaum|felted) 78 | model_name="$arg" 79 | execute_function "$model_name" 80 | ;; 81 | all) 82 | for model_name in "${!download_func[@]}"; do 83 | execute_function "$model_name" 84 | done 85 | ;; 86 | *) 87 | echo "Invalid argument: $arg." 
88 | exit 1 89 | ;; 90 | esac 91 | done 92 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | deps = [ 5 | "diffusers==0.25.0", 6 | "transformers", 7 | "accelerate", 8 | "fire", 9 | "einops", 10 | "omegaconf", 11 | "imageio", 12 | "timm==0.6.7", 13 | "lightning", 14 | "peft", 15 | "av", 16 | "decord", 17 | "pillow", 18 | "pywin32;sys_platform == 'win32'", 19 | ] 20 | 21 | deps_tensorrt = [ 22 | "onnx==1.16.0", 23 | "onnxruntime==1.16.3", 24 | "protobuf==5.27.0", 25 | "polygraphy", 26 | "onnx-graphsurgeon", 27 | "cuda-python", 28 | "tensorrt==10.0.1", 29 | "colored", 30 | ] 31 | deps_tensorrt_cu11 = [ 32 | "tensorrt_cu11_libs==10.0.1", 33 | "tensorrt_cu11_bindings==10.0.1", 34 | ] 35 | deps_tensorrt_cu12 = [ 36 | "tensorrt_cu12_libs==10.0.1", 37 | "tensorrt_cu12_bindings==10.0.1", 38 | ] 39 | extras = { 40 | "tensorrt_cu11": deps_tensorrt + deps_tensorrt_cu11, 41 | "tensorrt_cu12": deps_tensorrt + deps_tensorrt_cu12, 42 | } 43 | 44 | 45 | if __name__ == "__main__": 46 | setup( 47 | name="Live2Diff", 48 | version="0.1", 49 | description="real-time interactive video translation pipeline", 50 | long_description=open("README.md", "r", encoding="utf-8").read(), 51 | long_description_content_type="text/markdown", 52 | keywords="deep learning diffusion pytorch stable diffusion streamdiffusion real-time next-frame prediction", 53 | license="Apache 2.0 License", 54 | author="leo", 55 | author_email="xingzhening@pjlab.org.cn", 56 | url="https://github.com/open-mmlab/Live2Diff", 57 | package_dir={"": "live2diff"}, 58 | packages=find_packages("live2diff"), 59 | python_requires=">=3.10.0", 60 | install_requires=deps, 61 | extras_require=extras, 62 | ) 63 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List, Literal, Optional 3 | 4 | import fire 5 | import numpy as np 6 | import torch 7 | from decord import VideoReader 8 | from PIL import Image 9 | from torchvision import transforms 10 | from torchvision.io import write_video 11 | from tqdm import tqdm 12 | 13 | from live2diff.utils.config import load_config 14 | from live2diff.utils.io import read_video_frames, save_videos_grid 15 | from live2diff.utils.wrapper import StreamAnimateDiffusionDepthWrapper 16 | 17 | 18 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 19 | 20 | 21 | def main( 22 | input: str, 23 | config_path: str, 24 | prompt: Optional[str] = None, 25 | prompt_template: Optional[str] = None, 26 | output: str = os.path.join("outputs", "output.mp4"), 27 | dreambooth_path: Optional[str] = None, 28 | lora_dict: Optional[Dict[str, float]] = None, 29 | height: int = 512, 30 | width: int = 512, 31 | max_frames: int = -1, 32 | num_inference_steps: Optional[int] = None, 33 | t_index_list: Optional[List[int]] = None, 34 | strength: Optional[float] = None, 35 | acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", 36 | enable_similar_image_filter: bool = False, 37 | few_step_model_type: str = "lcm", 38 | enable_tiny_vae: bool = True, 39 | fps: int = 16, 40 | save_input: bool = True, 41 | seed: int = 42, 42 | ): 43 | """ 44 | Process for generating images based on a prompt using a specified model. 
45 | 
46 |     Parameters
47 |     ----------
48 |     input : str
49 |         The path to the input video (.mp4 or .gif) or to a directory of video frames.
50 |     config_path : str
51 |         The path to the config file.
52 |     prompt : str, optional
53 |         The prompt to generate images from. If None, the prompt from the config file is used.
54 |     prompt_template : str, optional
55 |         The template for a specific DreamBooth / LoRA. If not None, it must contain `{}`,
56 |         and the prompt used for inference will be `prompt_template.format(prompt)`.
57 |     output : str, optional
58 |         The output video path (.mp4 or .gif) to save the results to.
59 |     dreambooth_path : str, optional
60 |         The path to the DreamBooth checkpoint to use, by default None.
61 |     lora_dict : Optional[Dict[str, float]], optional
62 |         The lora_dict to load, by default None.
63 |         Keys are the LoRA names and values are the LoRA scales.
64 |         Example: `python test.py --lora_dict='{"LoRA_1" : 0.5 , "LoRA_2" : 0.7 ,...}'`
65 |     height : int, optional
66 |         The height of the image, by default 512.
67 |     width : int, optional
68 |         The width of the image, by default 512.
69 |     max_frames : int, optional
70 |         The maximum number of frames to process, by default -1 (process all frames).
71 |     acceleration : Literal["none", "xformers", "tensorrt"]
72 |         The type of acceleration to use for image generation, by default "tensorrt".
73 |     enable_similar_image_filter : bool, optional
74 |         Whether to enable the similar image filter or not,
75 |         by default False.
76 |     fps : int
77 |         The fps of the output video, by default 16.
78 |     save_input : bool, optional
79 |         Whether to save the input video or not, by default True.
80 |         If True, the input video will be saved as `output` + "_inp.mp4".
81 |     seed : int, optional
82 |         The random seed, by default 42. If -1, a random seed is used.
83 |     """
84 | 
85 |     if os.path.isdir(input):
86 |         video = read_video_frames(input) / 255
87 |     elif input.endswith(".mp4"):
88 |         reader = VideoReader(input)
89 |         total_frames = len(reader)
90 |         frame_indices = np.arange(total_frames)
91 |         video = reader.get_batch(frame_indices).asnumpy() / 255
92 |         video = torch.from_numpy(video)
93 |     elif input.endswith(".gif"):
94 |         video_frames = []
95 |         image = Image.open(input)
96 |         for frames in range(image.n_frames):
97 |             image.seek(frames)
98 |             video_frames.append(np.array(image.convert("RGB")))
99 |         video = torch.from_numpy(np.array(video_frames)) / 255
100 | 
101 |     video = video[2:]
102 | 
103 |     height = int(height // 8 * 8)
104 |     width = int(width // 8 * 8)
105 | 
106 |     trans = transforms.Compose(
107 |         [
108 |             transforms.Resize(min(height, width), antialias=True),
109 |             transforms.CenterCrop((height, width)),
110 |         ]
111 |     )
112 |     video = trans(video.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
113 | 
114 |     if max_frames > 0:
115 |         video = video[: min(max_frames, len(video))]
116 |         print(f"Clipping video to {len(video)} frames.")
117 | 
118 |     cfg = load_config(config_path)
119 |     print("Inference Config:")
120 |     print(cfg)
121 | 
122 |     # handle prompt
123 |     cfg_prompt = cfg.get("prompt", None)
124 |     prompt = prompt or cfg_prompt
125 | 
126 |     prompt_template = prompt_template or cfg.get("prompt_template", None)
127 |     if prompt_template is not None:
128 |         assert "{}" in prompt_template, '"{}" must be contained in "prompt_template".'
129 | prompt = prompt_template.format(prompt) 130 | 131 | print(f'Convert input prompt to "{prompt}".') 132 | 133 | # handle timesteps 134 | num_inference_steps = num_inference_steps or cfg.get("num_inference_steps", None) 135 | strength = strength or cfg.get("strength", None) 136 | t_index_list = t_index_list or cfg.get("t_index_list", None) 137 | 138 | stream = StreamAnimateDiffusionDepthWrapper( 139 | few_step_model_type=few_step_model_type, 140 | config_path=config_path, 141 | cfg_type="none", 142 | dreambooth_path=dreambooth_path, 143 | lora_dict=lora_dict, 144 | strength=strength, 145 | num_inference_steps=num_inference_steps, 146 | t_index_list=t_index_list, 147 | frame_buffer_size=1, 148 | width=width, 149 | height=height, 150 | acceleration=acceleration, 151 | do_add_noise=True, 152 | output_type="pt", 153 | enable_similar_image_filter=enable_similar_image_filter, 154 | similar_image_filter_threshold=0.98, 155 | use_denoising_batch=True, 156 | use_tiny_vae=enable_tiny_vae, 157 | seed=seed, 158 | ) 159 | warmup_frames = video[:8].permute(0, 3, 1, 2) 160 | warmup_results = stream.prepare( 161 | warmup_frames=warmup_frames, 162 | prompt=prompt, 163 | guidance_scale=1, 164 | ) 165 | video_result = torch.zeros(video.shape[0], height, width, 3) 166 | warmup_results = warmup_results.cpu().float() 167 | video_result[:8] = warmup_results 168 | 169 | skip_frames = stream.batch_size - 1 170 | for i in tqdm(range(8, video.shape[0])): 171 | output_image = stream(video[i].permute(2, 0, 1)) 172 | if i - 8 >= skip_frames: 173 | video_result[i - skip_frames] = output_image.permute(1, 2, 0) 174 | video_result = video_result[:-skip_frames] 175 | # video_result = video_result[:8] 176 | 177 | save_root = os.path.dirname(output) 178 | if save_root != "": 179 | os.makedirs(save_root, exist_ok=True) 180 | if output.endswith(".mp4"): 181 | video_result = video_result * 255 182 | write_video(output, video_result, fps=fps) 183 | if save_input: 184 | write_video(output.replace(".mp4", "_inp.mp4"), video * 255, fps=fps) 185 | elif output.endswith(".gif"): 186 | save_videos_grid( 187 | video_result.permute(3, 0, 1, 2)[None, ...], 188 | output, 189 | rescale=False, 190 | fps=fps, 191 | ) 192 | if save_input: 193 | save_videos_grid( 194 | video.permute(3, 0, 1, 2)[None, ...], 195 | output.replace(".gif", "_inp.gif"), 196 | rescale=False, 197 | fps=fps, 198 | ) 199 | else: 200 | raise TypeError(f"Unsupported output format: {output}") 201 | print("Inference time ema: ", stream.stream.inference_time_ema) 202 | inference_time_list = np.array(stream.stream.inference_time_list) 203 | print(f"Inference time mean & std: {inference_time_list.mean()} +/- {inference_time_list.std()}") 204 | if hasattr(stream.stream, "depth_time_ema"): 205 | print("Depth time ema: ", stream.stream.depth_time_ema) 206 | 207 | print(f'Video saved to "{output}".') 208 | 209 | 210 | if __name__ == "__main__": 211 | fire.Fire(main) 212 | --------------------------------------------------------------------------------
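Example end-to-end usage (a minimal sketch, not part of the repository): the Civitai token, input video path, and prompt below are placeholders; the flags mirror the keyword arguments of `main()` in test.py, which is exposed as a CLI via `fire.Fire(main)`.

    # Fetch the Toonyou checkpoint into ./models (token value is a placeholder).
    bash scripts/download.sh toonyou YOUR_CIVITAI_TOKEN

    # Offline video-to-video translation with the toonyou config; "xformers" avoids
    # the TensorRT engine build required by the default acceleration mode.
    python test.py \
        --input ./data/input.mp4 \
        --config_path ./configs/toonyou.yaml \
        --prompt "your prompt here" \
        --acceleration xformers \
        --output ./outputs/output.mp4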