152 |├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── assets ├── attn-mask.png └── framework.jpg ├── configs ├── base_config.yaml ├── disneyPixar.yaml ├── kFelted.yaml ├── moxin.yaml ├── origami.yaml ├── pixart.yaml └── toonyou.yaml ├── demo ├── .gitattributes ├── .gitignore ├── README.md ├── app.py ├── config.py ├── connection_manager.py ├── demo_cfg.yaml ├── demo_cfg_arknight.yaml ├── frontend │ ├── .eslintignore │ ├── .eslintrc.cjs │ ├── .gitignore │ ├── .npmrc │ ├── .prettierignore │ ├── .prettierrc │ ├── README.md │ ├── package-lock.json │ ├── package.json │ ├── postcss.config.js │ ├── src │ │ ├── app.css │ │ ├── app.d.ts │ │ ├── app.html │ │ ├── lib │ │ │ ├── components │ │ │ │ ├── Button.svelte │ │ │ │ ├── Checkbox.svelte │ │ │ │ ├── ImagePlayer.svelte │ │ │ │ ├── InputRange.svelte │ │ │ │ ├── MediaListSwitcher.svelte │ │ │ │ ├── PipelineOptions.svelte │ │ │ │ ├── SeedInput.svelte │ │ │ │ ├── Selectlist.svelte │ │ │ │ ├── TextArea.svelte │ │ │ │ ├── VideoInput.svelte │ │ │ │ └── Warning.svelte │ │ │ ├── icons │ │ │ │ ├── floppy.svelte │ │ │ │ ├── screen.svelte │ │ │ │ └── spinner.svelte │ │ │ ├── index.ts │ │ │ ├── lcmLive.ts │ │ │ ├── mediaStream.ts │ │ │ ├── store.ts │ │ │ ├── types.ts │ │ │ └── utils.ts │ │ └── routes │ │ │ ├── +layout.svelte │ │ │ ├── +page.svelte │ │ │ └── +page.ts │ ├── svelte.config.js │ ├── tailwind.config.js │ ├── tsconfig.json │ └── vite.config.ts ├── main.py ├── requirements.txt ├── start.sh ├── util.py └── vid2vid.py ├── live2diff ├── __init__.py ├── acceleration │ ├── __init__.py │ └── tensorrt │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── engine.py │ │ ├── models.py │ │ └── utilities.py ├── animatediff │ ├── __init__.py │ ├── converter │ │ ├── __init__.py │ │ ├── convert.py │ │ ├── convert_from_ckpt.py │ │ └── convert_lora_safetensor_to_diffusers.py │ ├── models │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── depth_utils.py │ │ ├── motion_module.py │ │ ├── positional_encoding.py │ │ ├── resnet.py │ │ ├── stream_motion_module.py │ │ ├── unet_blocks_streaming.py │ │ ├── unet_blocks_warmup.py │ │ ├── unet_depth_streaming.py │ │ └── unet_depth_warmup.py │ └── pipeline │ │ ├── __init__.py │ │ ├── loader.py │ │ └── pipeline_animatediff_depth.py ├── image_filter.py ├── image_utils.py ├── pipeline_stream_animation_depth.py └── utils │ ├── __init__.py │ ├── config.py │ ├── io.py │ └── wrapper.py ├── pyproject.toml ├── scripts └── download.sh ├── setup.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # https://github.com/github/gitignore/blob/main/Python.gitignore 2 | 3 | .vscode/ 4 | engines/ 5 | output/ 6 | *.csv 7 | *.mp4 8 | *.png 9 | !assets/*.mp4 10 | !assets/*.png 11 | *.safetensors 12 | result_lcm.png 13 | model.ckpt 14 | !images/inputs/input.png 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | cover/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | db.sqlite3 77 | db.sqlite3-journal 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | .pybuilder/ 91 | target/ 92 | 93 | # Jupyter Notebook 94 | .ipynb_checkpoints 95 | 96 | # IPython 97 | profile_default/ 98 | ipython_config.py 99 | 100 | # pyenv 101 | # For a library or package, you might want to ignore these files since the code is 102 | # intended to run in multiple environments; otherwise, check them in: 103 | # .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # poetry 113 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 114 | # This is especially recommended for binary packages to ensure reproducibility, and is more 115 | # commonly ignored for libraries. 116 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 117 | #poetry.lock 118 | 119 | # pdm 120 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 121 | #pdm.lock 122 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 123 | # in version control. 124 | # https://pdm.fming.dev/#use-with-ide 125 | .pdm.toml 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .venv 140 | env/ 141 | venv/ 142 | ENV/ 143 | env.bak/ 144 | venv.bak/ 145 | 146 | # Spyder project settings 147 | .spyderproject 148 | .spyproject 149 | 150 | # Rope project settings 151 | .ropeproject 152 | 153 | # mkdocs documentation 154 | /site 155 | 156 | # mypy 157 | .mypy_cache/ 158 | .dmypy.json 159 | dmypy.json 160 | 161 | # Pyre type checker 162 | .pyre/ 163 | 164 | # pytype static type analyzer 165 | .pytype/ 166 | 167 | # Cython debug symbols 168 | cython_debug/ 169 | 170 | # PyCharm 171 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 172 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 173 | # and can be added to the global gitignore or merged into this file. For a more nuclear 174 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 175 | #.idea/ 176 | 177 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
178 | 179 | # dependencies 180 | *node_modules 181 | */.pnp 182 | .pnp.js 183 | 184 | # testing 185 | */coverage 186 | 187 | # production 188 | */build 189 | 190 | # misc 191 | .DS_Store 192 | .env.local 193 | .env.development.local 194 | .env.test.local 195 | .env.production.local 196 | 197 | npm-debug.log* 198 | yarn-debug.log* 199 | yarn-error.log* 200 | 201 | *.venv 202 | 203 | __pycache__/ 204 | *.py[cod] 205 | *$py.class 206 | 207 | models/RealESR* 208 | *.safetensors 209 | 210 | work_dirs/ 211 | tests/ 212 | data/ 213 | 214 | models/Model/ 215 | *.safetensors 216 | *.ckpt 217 | *.pt 218 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "live2diff/MiDaS"] 2 | path = live2diff/MiDaS 3 | url = git@github.com:lewiji/MiDaS.git 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.3.5 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [ --fix ] 9 | # Run the formatter. 10 | - id: ruff-format 11 | - repo: https://github.com/codespell-project/codespell 12 | rev: v2.2.1 13 | hooks: 14 | - id: codespell 15 | args: ["-L", "warmup,mose,parms", "--skip", "*.json"] 16 | - repo: https://github.com/pre-commit/pre-commit-hooks 17 | rev: v4.3.0 18 | hooks: 19 | - id: trailing-whitespace 20 | - id: check-yaml 21 | - id: end-of-file-fixer 22 | - id: requirements-txt-fixer 23 | - id: fix-encoding-pragma 24 | args: ["--remove"] 25 | - id: mixed-line-ending 26 | args: ["--fix=lf"] 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Live2Diff: **Live** Stream Translation via Uni-directional Attention in Video **Diffusion** Models 2 | 3 |
[HTML demo tables stripped during extraction: the root README and a second demo README embed result videos captioned "Human Face (Web Camera Input)" and "Anime Character (Screen Video Input)".]
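The demo pipeline reproduced further below (demo/vid2vid.py) drives StreamAnimateDiffusionDepthWrapper with a warmup-then-stream pattern: it buffers WARMUP_FRAMES preprocessed frames, calls prepare() once, and then translates every subsequent frame with a single stream(image=...) call. The following is a minimal offline sketch of that loop; the import path (live2diff.utils.wrapper), the choice of configs/disneyPixar.yaml, the prompt, the frame filenames, and the assumption that the omitted constructor arguments have usable defaults are illustrative guesses, not documented API.

import torch
from PIL import Image

# Assumed import path, based on live2diff/utils/wrapper.py in the tree above.
from live2diff.utils.wrapper import StreamAnimateDiffusionDepthWrapper

WARMUP_FRAMES = 8  # same constant as demo/vid2vid.py

# Placeholder inputs: any config from configs/ and any ordered frame source.
config_path = "configs/disneyPixar.yaml"
frames = [Image.open(f"frame_{i:04d}.png") for i in range(64)]

stream = StreamAnimateDiffusionDepthWrapper(
    few_step_model_type="lcm",   # mirrors the values used in demo/vid2vid.py
    config_path=config_path,
    cfg_type="none",
    frame_buffer_size=1,
    width=512,
    height=512,
    use_tiny_vae=True,
    output_type="pil",
)

# Warm up on the first frames, then stream the rest one by one.
warmup = torch.stack([stream.preprocess_image(f) for f in frames[:WARMUP_FRAMES]])
stream.prepare(warmup_frames=warmup, prompt="an illustrative style prompt", guidance_scale=1)

for frame in frames[WARMUP_FRAMES:]:
    out = stream(image=stream.preprocess_image(frame))  # PIL.Image: one frame in, one stylized frame out
    out.save("translated_frame.png")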
113 | There are {currentQueueSize} 114 | user(s) sharing the same GPU, affecting real-time performance. Maximum queue size is {maxQueueSize}. 115 | Duplicate and run it on your own GPU. 120 |
121 | {/if} 122 |Loading...
152 |26 | This demo showcases 27 | Live2Diff 31 | 32 | pipeline using 33 | LCM-LoRA with a MJPEG stream server. 38 |
39 | """ 40 | 41 | 42 | WARMUP_FRAMES = 8 43 | WINDOW_SIZE = 16 44 | 45 | 46 | class Pipeline: 47 | class Info(BaseModel): 48 | name: str = "Live2Diff" 49 | input_mode: str = "image" 50 | page_content: str = page_content 51 | 52 | def build_input_params(self, default_prompt: str = default_prompt, width=512, height=512): 53 | class InputParams(BaseModel): 54 | prompt: str = Field( 55 | default_prompt, 56 | title="Prompt", 57 | field="textarea", 58 | id="prompt", 59 | ) 60 | width: int = Field( 61 | 512, 62 | min=2, 63 | max=15, 64 | title="Width", 65 | disabled=True, 66 | hide=True, 67 | id="width", 68 | ) 69 | height: int = Field( 70 | 512, 71 | min=2, 72 | max=15, 73 | title="Height", 74 | disabled=True, 75 | hide=True, 76 | id="height", 77 | ) 78 | 79 | return InputParams 80 | 81 | def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype): 82 | config_path = args.config 83 | 84 | cfg = load_config(config_path) 85 | prompt = args.prompt or cfg.prompt or default_prompt 86 | 87 | self.InputParams = self.build_input_params(default_prompt=prompt) 88 | params = self.InputParams() 89 | 90 | num_inference_steps = args.num_inference_steps or cfg.get("num_inference_steps", None) 91 | strength = args.strength or cfg.get("strength", None) 92 | t_index_list = args.t_index_list or cfg.get("t_index_list", None) 93 | 94 | self.stream = StreamAnimateDiffusionDepthWrapper( 95 | few_step_model_type="lcm", 96 | config_path=config_path, 97 | cfg_type="none", 98 | strength=strength, 99 | num_inference_steps=num_inference_steps, 100 | t_index_list=t_index_list, 101 | frame_buffer_size=1, 102 | width=params.width, 103 | height=params.height, 104 | acceleration=args.acceleration, 105 | do_add_noise=True, 106 | output_type="pil", 107 | enable_similar_image_filter=True, 108 | similar_image_filter_threshold=0.98, 109 | use_denoising_batch=True, 110 | use_tiny_vae=True, 111 | seed=args.seed, 112 | engine_dir=args.engine_dir, 113 | ) 114 | 115 | self.last_prompt = prompt 116 | 117 | self.warmup_frame_list = [] 118 | self.has_prepared = False 119 | 120 | def predict(self, params: "Pipeline.InputParams") -> Image.Image: 121 | prompt = params.prompt 122 | if prompt != self.last_prompt: 123 | self.last_prompt = prompt 124 | self.warmup_frame_list.clear() 125 | 126 | if len(self.warmup_frame_list) < WARMUP_FRAMES: 127 | # from PIL import Image 128 | self.warmup_frame_list.append(self.stream.preprocess_image(params.image)) 129 | 130 | elif len(self.warmup_frame_list) == WARMUP_FRAMES and not self.has_prepared: 131 | warmup_frames = torch.stack(self.warmup_frame_list) 132 | self.stream.prepare( 133 | warmup_frames=warmup_frames, 134 | prompt=prompt, 135 | guidance_scale=1, 136 | ) 137 | self.has_prepared = True 138 | 139 | if self.has_prepared: 140 | image_tensor = self.stream.preprocess_image(params.image) 141 | output_image = self.stream(image=image_tensor) 142 | return output_image 143 | else: 144 | return Image.new("RGB", (params.width, params.height)) 145 | -------------------------------------------------------------------------------- /live2diff/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_stream_animation_depth import StreamAnimateDiffusionDepth 2 | 3 | 4 | __all__ = ["StreamAnimateDiffusionDepth"] 5 | -------------------------------------------------------------------------------- /live2diff/acceleration/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/open-mmlab/Live2Diff/e40aa03a9a9d17f1232fd7a4566b3ee793e6893f/live2diff/acceleration/__init__.py -------------------------------------------------------------------------------- /live2diff/acceleration/tensorrt/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffusers import AutoencoderKL 4 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( 5 | retrieve_latents, 6 | ) 7 | 8 | from .builder import EngineBuilder 9 | from .models import BaseModel 10 | 11 | 12 | class TorchVAEEncoder(torch.nn.Module): 13 | def __init__(self, vae: AutoencoderKL): 14 | super().__init__() 15 | self.vae = vae 16 | 17 | def forward(self, x: torch.Tensor): 18 | return retrieve_latents(self.vae.encode(x)) 19 | 20 | 21 | def compile_engine( 22 | torch_model: nn.Module, 23 | model_data: BaseModel, 24 | onnx_path: str, 25 | onnx_opt_path: str, 26 | engine_path: str, 27 | opt_image_height: int = 512, 28 | opt_image_width: int = 512, 29 | opt_batch_size: int = 1, 30 | engine_build_options: dict = {}, 31 | ): 32 | builder = EngineBuilder( 33 | model_data, 34 | torch_model, 35 | device=torch.device("cuda"), 36 | ) 37 | builder.build( 38 | onnx_path, 39 | onnx_opt_path, 40 | engine_path, 41 | opt_image_height=opt_image_height, 42 | opt_image_width=opt_image_width, 43 | opt_batch_size=opt_batch_size, 44 | **engine_build_options, 45 | ) 46 | -------------------------------------------------------------------------------- /live2diff/acceleration/tensorrt/builder.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from typing import * 4 | 5 | import torch 6 | 7 | from .models import BaseModel 8 | from .utilities import ( 9 | build_engine, 10 | export_onnx, 11 | handle_onnx_batch_norm, 12 | optimize_onnx, 13 | ) 14 | 15 | 16 | class EngineBuilder: 17 | def __init__( 18 | self, 19 | model: BaseModel, 20 | network: Any, 21 | device=torch.device("cuda"), 22 | ): 23 | self.device = device 24 | 25 | self.model = model 26 | self.network = network 27 | 28 | def build( 29 | self, 30 | onnx_path: str, 31 | onnx_opt_path: str, 32 | engine_path: str, 33 | opt_image_height: int = 512, 34 | opt_image_width: int = 512, 35 | opt_batch_size: int = 1, 36 | min_image_resolution: int = 256, 37 | max_image_resolution: int = 1024, 38 | build_enable_refit: bool = False, 39 | build_static_batch: bool = False, 40 | build_dynamic_shape: bool = False, 41 | build_all_tactics: bool = False, 42 | onnx_opset: int = 17, 43 | force_engine_build: bool = False, 44 | force_onnx_export: bool = False, 45 | force_onnx_optimize: bool = False, 46 | ignore_onnx_optimize: bool = False, 47 | auto_cast: bool = True, 48 | handle_batch_norm: bool = False, 49 | ): 50 | if not force_onnx_export and os.path.exists(onnx_path): 51 | print(f"Found cached model: {onnx_path}") 52 | else: 53 | print(f"Exporting model: {onnx_path}") 54 | export_onnx( 55 | self.network, 56 | onnx_path=onnx_path, 57 | model_data=self.model, 58 | opt_image_height=opt_image_height, 59 | opt_image_width=opt_image_width, 60 | opt_batch_size=opt_batch_size, 61 | onnx_opset=onnx_opset, 62 | auto_cast=auto_cast, 63 | ) 64 | del self.network 65 | gc.collect() 66 | torch.cuda.empty_cache() 67 | 68 | if handle_batch_norm: 69 | print(f"Handle Batch Norm for {onnx_path}") 70 | handle_onnx_batch_norm(onnx_path) 71 | 72 | if ignore_onnx_optimize: 73 | print(f"Ignore onnx optimize for 
{onnx_path}.") 74 | onnx_opt_path = onnx_path 75 | elif not force_onnx_optimize and os.path.exists(onnx_opt_path): 76 | print(f"Found cached model: {onnx_opt_path}") 77 | else: 78 | print(f"Generating optimizing model: {onnx_opt_path}") 79 | optimize_onnx( 80 | onnx_path=onnx_path, 81 | onnx_opt_path=onnx_opt_path, 82 | model_data=self.model, 83 | ) 84 | self.model.min_latent_shape = min_image_resolution // 8 85 | self.model.max_latent_shape = max_image_resolution // 8 86 | if not force_engine_build and os.path.exists(engine_path): 87 | print(f"Found cached engine: {engine_path}") 88 | else: 89 | build_engine( 90 | engine_path=engine_path, 91 | onnx_opt_path=onnx_opt_path, 92 | model_data=self.model, 93 | opt_image_height=opt_image_height, 94 | opt_image_width=opt_image_width, 95 | opt_batch_size=opt_batch_size, 96 | build_static_batch=build_static_batch, 97 | build_dynamic_shape=build_dynamic_shape, 98 | build_all_tactics=build_all_tactics, 99 | build_enable_refit=build_enable_refit, 100 | ) 101 | 102 | gc.collect() 103 | torch.cuda.empty_cache() 104 | -------------------------------------------------------------------------------- /live2diff/acceleration/tensorrt/engine.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | from polygraphy import cuda 5 | 6 | from live2diff.animatediff.models.unet_depth_streaming import UNet3DConditionStreamingOutput 7 | 8 | from .utilities import Engine 9 | 10 | 11 | try: 12 | from diffusers.models.autoencoder_tiny import AutoencoderTinyOutput 13 | except ImportError: 14 | from dataclasses import dataclass 15 | 16 | from diffusers.utils import BaseOutput 17 | 18 | @dataclass 19 | class AutoencoderTinyOutput(BaseOutput): 20 | """ 21 | Output of AutoencoderTiny encoding method. 22 | 23 | Args: 24 | latents (`torch.Tensor`): Encoded outputs of the `Encoder`. 25 | 26 | """ 27 | 28 | latents: torch.Tensor 29 | 30 | 31 | try: 32 | from diffusers.models.vae import DecoderOutput 33 | except ImportError: 34 | from dataclasses import dataclass 35 | 36 | from diffusers.utils import BaseOutput 37 | 38 | @dataclass 39 | class DecoderOutput(BaseOutput): 40 | r""" 41 | Output of decoding method. 42 | 43 | Args: 44 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 45 | The decoded output sample from the last layer of the model. 
46 | """ 47 | 48 | sample: torch.FloatTensor 49 | 50 | 51 | class AutoencoderKLEngine: 52 | def __init__( 53 | self, 54 | encoder_path: str, 55 | decoder_path: str, 56 | stream: cuda.Stream, 57 | scaling_factor: int, 58 | use_cuda_graph: bool = False, 59 | ): 60 | self.encoder = Engine(encoder_path) 61 | self.decoder = Engine(decoder_path) 62 | self.stream = stream 63 | self.vae_scale_factor = scaling_factor 64 | self.use_cuda_graph = use_cuda_graph 65 | 66 | self.encoder.load() 67 | self.decoder.load() 68 | self.encoder.activate() 69 | self.decoder.activate() 70 | 71 | def encode(self, images: torch.Tensor, **kwargs): 72 | self.encoder.allocate_buffers( 73 | shape_dict={ 74 | "images": images.shape, 75 | "latent": ( 76 | images.shape[0], 77 | 4, 78 | images.shape[2] // self.vae_scale_factor, 79 | images.shape[3] // self.vae_scale_factor, 80 | ), 81 | }, 82 | device=images.device, 83 | ) 84 | latents = self.encoder.infer( 85 | {"images": images}, 86 | self.stream, 87 | use_cuda_graph=self.use_cuda_graph, 88 | )["latent"] 89 | return AutoencoderTinyOutput(latents=latents) 90 | 91 | def decode(self, latent: torch.Tensor, **kwargs): 92 | self.decoder.allocate_buffers( 93 | shape_dict={ 94 | "latent": latent.shape, 95 | "images": ( 96 | latent.shape[0], 97 | 3, 98 | latent.shape[2] * self.vae_scale_factor, 99 | latent.shape[3] * self.vae_scale_factor, 100 | ), 101 | }, 102 | device=latent.device, 103 | ) 104 | images = self.decoder.infer( 105 | {"latent": latent}, 106 | self.stream, 107 | use_cuda_graph=self.use_cuda_graph, 108 | )["images"] 109 | return DecoderOutput(sample=images) 110 | 111 | def to(self, *args, **kwargs): 112 | pass 113 | 114 | def forward(self, *args, **kwargs): 115 | pass 116 | 117 | 118 | class UNet2DConditionModelDepthEngine: 119 | def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False): 120 | self.engine = Engine(filepath) 121 | self.stream = stream 122 | self.use_cuda_graph = use_cuda_graph 123 | 124 | self.init_profiler() 125 | 126 | self.engine.load() 127 | self.engine.activate(profiler=self.profiler) 128 | self.has_allocated = False 129 | 130 | def init_profiler(self): 131 | import tensorrt 132 | 133 | class Profiler(tensorrt.IProfiler): 134 | def __init__(self): 135 | tensorrt.IProfiler.__init__(self) 136 | 137 | def report_layer_time(self, layer_name, ms): 138 | print(f"{layer_name}: {ms} ms") 139 | 140 | self.profiler = Profiler() 141 | 142 | def __call__( 143 | self, 144 | latent_model_input: torch.Tensor, 145 | timestep: torch.Tensor, 146 | encoder_hidden_states: torch.Tensor, 147 | temporal_attention_mask: torch.Tensor, 148 | depth_sample: torch.Tensor, 149 | kv_cache: List[torch.Tensor], 150 | pe_idx: torch.Tensor, 151 | update_idx: torch.Tensor, 152 | **kwargs, 153 | ) -> Any: 154 | if timestep.dtype != torch.float32: 155 | timestep = timestep.float() 156 | 157 | feed_dict = { 158 | "sample": latent_model_input, 159 | "timestep": timestep, 160 | "encoder_hidden_states": encoder_hidden_states, 161 | "temporal_attention_mask": temporal_attention_mask, 162 | "depth_sample": depth_sample, 163 | "pe_idx": pe_idx, 164 | "update_idx": update_idx, 165 | } 166 | for idx, cache in enumerate(kv_cache): 167 | feed_dict[f"kv_cache_{idx}"] = cache 168 | shape_dict = {k: v.shape for k, v in feed_dict.items()} 169 | 170 | if not self.has_allocated: 171 | self.engine.allocate_buffers( 172 | shape_dict=shape_dict, 173 | device=latent_model_input.device, 174 | ) 175 | self.has_allocated = True 176 | 177 | output = self.engine.infer( 178 | 
feed_dict, 179 | self.stream, 180 | use_cuda_graph=self.use_cuda_graph, 181 | ) 182 | 183 | noise_pred = output["latent"] 184 | kv_cache = [output[f"kv_cache_out_{idx}"] for idx in range(len(kv_cache))] 185 | return UNet3DConditionStreamingOutput(sample=noise_pred, kv_cache=kv_cache) 186 | 187 | def to(self, *args, **kwargs): 188 | pass 189 | 190 | def forward(self, *args, **kwargs): 191 | pass 192 | 193 | 194 | class MidasEngine: 195 | def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False): 196 | self.engine = Engine(filepath) 197 | self.stream = stream 198 | self.use_cuda_graph = use_cuda_graph 199 | 200 | self.engine.load() 201 | self.engine.activate() 202 | self.has_allocated = False 203 | self.default_batch_size = 1 204 | 205 | def __call__( 206 | self, 207 | images: torch.Tensor, 208 | **kwargs, 209 | ) -> Any: 210 | if not self.has_allocated or images.shape[0] != self.default_batch_size: 211 | bz = images.shape[0] 212 | self.engine.allocate_buffers( 213 | shape_dict={ 214 | "images": (bz, 3, 384, 384), 215 | "depth_map": (bz, 384, 384), 216 | }, 217 | device=images.device, 218 | ) 219 | self.has_allocated = True 220 | self.default_batch_size = bz 221 | 222 | depth_map = self.engine.infer( 223 | { 224 | "images": images, 225 | }, 226 | self.stream, 227 | use_cuda_graph=self.use_cuda_graph, 228 | )["depth_map"] # (1, 384, 384) 229 | 230 | return depth_map 231 | 232 | def norm(self, x): 233 | return (x - x.min()) / (x.max() - x.min()) 234 | 235 | def to(self, *args, **kwargs): 236 | pass 237 | 238 | def forward(self, *args, **kwargs): 239 | pass 240 | -------------------------------------------------------------------------------- /live2diff/acceleration/tensorrt/utilities.py: -------------------------------------------------------------------------------- 1 | #! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py 2 | 3 | # 4 | # Copyright 2022 The HuggingFace Inc. team. 5 | # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 6 | # SPDX-License-Identifier: Apache-2.0 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | # 20 | 21 | import gc 22 | from collections import OrderedDict 23 | from typing import * 24 | 25 | import numpy as np 26 | import onnx 27 | import onnx_graphsurgeon as gs 28 | import tensorrt as trt 29 | import torch 30 | from cuda import cudart 31 | from PIL import Image 32 | from polygraphy import cuda 33 | from polygraphy.backend.common import bytes_from_path 34 | from polygraphy.backend.trt import ( 35 | CreateConfig, 36 | Profile, 37 | engine_from_bytes, 38 | engine_from_network, 39 | network_from_onnx_path, 40 | save_engine, 41 | ) 42 | 43 | from .models import BaseModel 44 | 45 | 46 | TRT_LOGGER = trt.Logger(trt.Logger.ERROR) 47 | 48 | # Map of numpy dtype -> torch dtype 49 | numpy_to_torch_dtype_dict = { 50 | np.uint8: torch.uint8, 51 | np.int8: torch.int8, 52 | np.int16: torch.int16, 53 | np.int32: torch.int32, 54 | np.int64: torch.int64, 55 | np.float16: torch.float16, 56 | np.float32: torch.float32, 57 | np.float64: torch.float64, 58 | np.complex64: torch.complex64, 59 | np.complex128: torch.complex128, 60 | } 61 | if np.version.full_version >= "1.24.0": 62 | numpy_to_torch_dtype_dict[np.bool_] = torch.bool 63 | else: 64 | numpy_to_torch_dtype_dict[np.bool] = torch.bool 65 | 66 | # Map of torch dtype -> numpy dtype 67 | torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} 68 | 69 | 70 | def CUASSERT(cuda_ret): 71 | err = cuda_ret[0] 72 | if err != cudart.cudaError_t.cudaSuccess: 73 | raise RuntimeError( 74 | f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" 75 | ) 76 | if len(cuda_ret) > 1: 77 | return cuda_ret[1] 78 | return None 79 | 80 | 81 | class Engine: 82 | def __init__( 83 | self, 84 | engine_path, 85 | ): 86 | self.engine_path = engine_path 87 | self.engine = None 88 | self.context = None 89 | self.buffers = OrderedDict() 90 | self.tensors = OrderedDict() 91 | self.cuda_graph_instance = None # cuda graph 92 | 93 | def __del__(self): 94 | [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)] 95 | del self.engine 96 | del self.context 97 | del self.buffers 98 | del self.tensors 99 | 100 | def refit(self, onnx_path, onnx_refit_path): 101 | def convert_int64(arr): 102 | # TODO: smarter conversion 103 | if len(arr.shape) == 0: 104 | return np.int32(arr) 105 | return arr 106 | 107 | def add_to_map(refit_dict, name, values): 108 | if name in refit_dict: 109 | assert refit_dict[name] is None 110 | if values.dtype == np.int64: 111 | values = convert_int64(values) 112 | refit_dict[name] = values 113 | 114 | print(f"Refitting TensorRT engine with {onnx_refit_path} weights") 115 | refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes 116 | 117 | # Construct mapping from weight names in refit model -> original model 118 | name_map = {} 119 | for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): 120 | refit_node = refit_nodes[n] 121 | assert node.op == refit_node.op 122 | # Constant nodes in ONNX do not have inputs but have a constant output 123 | if node.op == "Constant": 124 | name_map[refit_node.outputs[0].name] = node.outputs[0].name 125 | # Handle scale and bias weights 126 | elif node.op == "Conv": 127 | if node.inputs[1].__class__ == gs.Constant: 128 | name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" 129 | if node.inputs[2].__class__ == gs.Constant: 130 | name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS" 131 | # For all other nodes: find node inputs 
that are initializers (gs.Constant) 132 | else: 133 | for i, inp in enumerate(node.inputs): 134 | if inp.__class__ == gs.Constant: 135 | name_map[refit_node.inputs[i].name] = inp.name 136 | 137 | def map_name(name): 138 | if name in name_map: 139 | return name_map[name] 140 | return name 141 | 142 | # Construct refit dictionary 143 | refit_dict = {} 144 | refitter = trt.Refitter(self.engine, TRT_LOGGER) 145 | all_weights = refitter.get_all() 146 | for layer_name, role in zip(all_weights[0], all_weights[1]): 147 | # for speciailized roles, use a unique name in the map: 148 | if role == trt.WeightsRole.KERNEL: 149 | name = layer_name + "_TRTKERNEL" 150 | elif role == trt.WeightsRole.BIAS: 151 | name = layer_name + "_TRTBIAS" 152 | else: 153 | name = layer_name 154 | 155 | assert name not in refit_dict, "Found duplicate layer: " + name 156 | refit_dict[name] = None 157 | 158 | for n in refit_nodes: 159 | # Constant nodes in ONNX do not have inputs but have a constant output 160 | if n.op == "Constant": 161 | name = map_name(n.outputs[0].name) 162 | print(f"Add Constant {name}\n") 163 | add_to_map(refit_dict, name, n.outputs[0].values) 164 | 165 | # Handle scale and bias weights 166 | elif n.op == "Conv": 167 | if n.inputs[1].__class__ == gs.Constant: 168 | name = map_name(n.name + "_TRTKERNEL") 169 | add_to_map(refit_dict, name, n.inputs[1].values) 170 | 171 | if n.inputs[2].__class__ == gs.Constant: 172 | name = map_name(n.name + "_TRTBIAS") 173 | add_to_map(refit_dict, name, n.inputs[2].values) 174 | 175 | # For all other nodes: find node inputs that are initializers (AKA gs.Constant) 176 | else: 177 | for inp in n.inputs: 178 | name = map_name(inp.name) 179 | if inp.__class__ == gs.Constant: 180 | add_to_map(refit_dict, name, inp.values) 181 | 182 | for layer_name, weights_role in zip(all_weights[0], all_weights[1]): 183 | if weights_role == trt.WeightsRole.KERNEL: 184 | custom_name = layer_name + "_TRTKERNEL" 185 | elif weights_role == trt.WeightsRole.BIAS: 186 | custom_name = layer_name + "_TRTBIAS" 187 | else: 188 | custom_name = layer_name 189 | 190 | # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model 191 | if layer_name.startswith("onnx::Trilu"): 192 | continue 193 | 194 | if refit_dict[custom_name] is not None: 195 | refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) 196 | else: 197 | print(f"[W] No refit weights for layer: {layer_name}") 198 | 199 | if not refitter.refit_cuda_engine(): 200 | print("Failed to refit!") 201 | exit(0) 202 | 203 | def build( 204 | self, 205 | onnx_path, 206 | fp16, 207 | input_profile=None, 208 | enable_refit=False, 209 | enable_all_tactics=False, 210 | timing_cache=None, 211 | workspace_size=0, 212 | ): 213 | print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") 214 | p = Profile() 215 | if input_profile: 216 | for name, dims in input_profile.items(): 217 | assert len(dims) == 3 218 | p.add(name, min=dims[0], opt=dims[1], max=dims[2]) 219 | 220 | config_kwargs = {} 221 | 222 | if workspace_size > 0: 223 | config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size} 224 | if not enable_all_tactics: 225 | config_kwargs["tactic_sources"] = [] 226 | 227 | engine = engine_from_network( 228 | network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), 229 | config=CreateConfig( 230 | fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs 231 | ), 232 | save_timing_cache=timing_cache, 233 | ) 234 | 
save_engine(engine, path=self.engine_path) 235 | 236 | def load(self): 237 | print(f"Loading TensorRT engine: {self.engine_path}") 238 | self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) 239 | 240 | def activate(self, reuse_device_memory=None, profiler=None): 241 | if reuse_device_memory: 242 | self.context = self.engine.create_execution_context_without_device_memory() 243 | self.context.device_memory = reuse_device_memory 244 | else: 245 | self.context = self.engine.create_execution_context() 246 | 247 | def allocate_buffers(self, shape_dict=None, device="cuda"): 248 | # NOTE: API for tensorrt 10.01 249 | from tensorrt import TensorIOMode 250 | 251 | for idx in range(self.engine.num_io_tensors): 252 | binding = self.engine[idx] 253 | if shape_dict and binding in shape_dict: 254 | shape = shape_dict[binding] 255 | else: 256 | shape = self.engine.get_tensor_shape(binding) 257 | dtype = trt.nptype(self.engine.get_tensor_dtype(binding)) 258 | tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype], device=device) 259 | self.tensors[binding] = tensor 260 | 261 | binding_mode = self.engine.get_tensor_mode(binding) 262 | if binding_mode == TensorIOMode.INPUT: 263 | self.context.set_input_shape(binding, shape) 264 | self.has_allocated = True 265 | 266 | def infer(self, feed_dict, stream, use_cuda_graph=False): 267 | for name, buf in feed_dict.items(): 268 | self.tensors[name].copy_(buf) 269 | 270 | for name, tensor in self.tensors.items(): 271 | self.context.set_tensor_address(name, tensor.data_ptr()) 272 | 273 | if use_cuda_graph: 274 | if self.cuda_graph_instance is not None: 275 | CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr)) 276 | CUASSERT(cudart.cudaStreamSynchronize(stream.ptr)) 277 | else: 278 | # do inference before CUDA graph capture 279 | noerror = self.context.execute_async_v3(stream.ptr) 280 | if not noerror: 281 | raise ValueError("ERROR: inference failed.") 282 | # capture cuda graph 283 | CUASSERT( 284 | cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal) 285 | ) 286 | self.context.execute_async_v3(stream.ptr) 287 | self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr)) 288 | self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0)) 289 | else: 290 | noerror = self.context.execute_async_v3(stream.ptr) 291 | if not noerror: 292 | raise ValueError("ERROR: inference failed.") 293 | 294 | return self.tensors 295 | 296 | 297 | def decode_images(images: torch.Tensor): 298 | images = ( 299 | ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() 300 | ) 301 | return [Image.fromarray(x) for x in images] 302 | 303 | 304 | def preprocess_image(image: Image.Image): 305 | w, h = image.size 306 | w, h = [x - x % 32 for x in (w, h)] # resize to integer multiple of 32 307 | image = image.resize((w, h)) 308 | init_image = np.array(image).astype(np.float32) / 255.0 309 | init_image = init_image[None].transpose(0, 3, 1, 2) 310 | init_image = torch.from_numpy(init_image).contiguous() 311 | return 2.0 * init_image - 1.0 312 | 313 | 314 | def prepare_mask_and_masked_image(image: Image.Image, mask: Image.Image): 315 | if isinstance(image, Image.Image): 316 | image = np.array(image.convert("RGB")) 317 | image = image[None].transpose(0, 3, 1, 2) 318 | image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0 319 | if isinstance(mask, Image.Image): 320 | mask = np.array(mask.convert("L")) 321 | 
mask = mask.astype(np.float32) / 255.0 322 | mask = mask[None, None] 323 | mask[mask < 0.5] = 0 324 | mask[mask >= 0.5] = 1 325 | mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous() 326 | 327 | masked_image = image * (mask < 0.5) 328 | 329 | return mask, masked_image 330 | 331 | 332 | def build_engine( 333 | engine_path: str, 334 | onnx_opt_path: str, 335 | model_data: BaseModel, 336 | opt_image_height: int, 337 | opt_image_width: int, 338 | opt_batch_size: int, 339 | build_static_batch: bool = False, 340 | build_dynamic_shape: bool = False, 341 | build_all_tactics: bool = False, 342 | build_enable_refit: bool = False, 343 | ): 344 | _, free_mem, _ = cudart.cudaMemGetInfo() 345 | GiB = 2**30 346 | if free_mem > 6 * GiB: 347 | activation_carveout = 4 * GiB 348 | max_workspace_size = free_mem - activation_carveout 349 | else: 350 | max_workspace_size = 0 351 | engine = Engine(engine_path) 352 | input_profile = model_data.get_input_profile( 353 | opt_batch_size, 354 | opt_image_height, 355 | opt_image_width, 356 | static_batch=build_static_batch, 357 | static_shape=not build_dynamic_shape, 358 | ) 359 | engine.build( 360 | onnx_opt_path, 361 | fp16=True, 362 | input_profile=input_profile, 363 | enable_refit=build_enable_refit, 364 | enable_all_tactics=build_all_tactics, 365 | workspace_size=max_workspace_size, 366 | ) 367 | 368 | return engine 369 | 370 | 371 | def export_onnx( 372 | model, 373 | onnx_path: str, 374 | model_data: BaseModel, 375 | opt_image_height: int, 376 | opt_image_width: int, 377 | opt_batch_size: int, 378 | onnx_opset: int, 379 | auto_cast: bool = True, 380 | ): 381 | from contextlib import contextmanager 382 | 383 | @contextmanager 384 | def auto_cast_manager(enabled): 385 | if enabled: 386 | with torch.inference_mode(), torch.autocast("cuda"): 387 | yield 388 | else: 389 | yield 390 | 391 | with auto_cast_manager(auto_cast): 392 | inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) 393 | torch.onnx.export( 394 | model, 395 | inputs, 396 | onnx_path, 397 | export_params=True, 398 | opset_version=onnx_opset, 399 | do_constant_folding=True, 400 | input_names=model_data.get_input_names(), 401 | output_names=model_data.get_output_names(), 402 | dynamic_axes=model_data.get_dynamic_axes(), 403 | ) 404 | del model 405 | gc.collect() 406 | torch.cuda.empty_cache() 407 | 408 | 409 | def optimize_onnx( 410 | onnx_path: str, 411 | onnx_opt_path: str, 412 | model_data: BaseModel, 413 | ): 414 | model_data.optimize(onnx_path, onnx_opt_path) 415 | # # onnx_opt_graph = model_data.optimize(onnx.load(onnx_path)) 416 | # onnx_opt_graph = model_data.optimize(onnx_path) 417 | # onnx.save(onnx_opt_graph, onnx_opt_path) 418 | # del onnx_opt_graph 419 | # gc.collect() 420 | # torch.cuda.empty_cache() 421 | 422 | 423 | def handle_onnx_batch_norm(onnx_path: str): 424 | onnx_model = onnx.load(onnx_path) 425 | for node in onnx_model.graph.node: 426 | if node.op_type == "BatchNormalization": 427 | for attribute in node.attribute: 428 | if attribute.name == "training_mode": 429 | if attribute.i == 1: 430 | node.output.remove(node.output[1]) 431 | node.output.remove(node.output[1]) 432 | attribute.i = 0 433 | 434 | onnx.save_model(onnx_model, onnx_path) 435 | -------------------------------------------------------------------------------- /live2diff/animatediff/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/open-mmlab/Live2Diff/e40aa03a9a9d17f1232fd7a4566b3ee793e6893f/live2diff/animatediff/__init__.py -------------------------------------------------------------------------------- /live2diff/animatediff/converter/__init__.py: -------------------------------------------------------------------------------- 1 | from .convert import load_third_party_checkpoints, load_third_party_unet 2 | 3 | 4 | __all__ = ["load_third_party_checkpoints", "load_third_party_unet"] 5 | -------------------------------------------------------------------------------- /live2diff/animatediff/converter/convert.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from diffusers.pipelines import StableDiffusionPipeline 5 | from safetensors import safe_open 6 | 7 | from .convert_from_ckpt import convert_ldm_clip_checkpoint, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint 8 | from .convert_lora_safetensor_to_diffusers import convert_lora_model_level 9 | 10 | 11 | def load_third_party_checkpoints( 12 | pipeline: StableDiffusionPipeline, 13 | third_party_dict: dict, 14 | dreambooth_path: Optional[str] = None, 15 | ): 16 | """ 17 | Modified from https://github.com/open-mmlab/PIA/blob/4b1ee136542e807a13c1adfe52f4e8e5fcc65cdb/animatediff/pipelines/i2v_pipeline.py#L165 18 | """ 19 | vae = third_party_dict.get("vae", None) 20 | lora_list = third_party_dict.get("lora_list", []) 21 | 22 | dreambooth = dreambooth_path or third_party_dict.get("dreambooth", None) 23 | 24 | text_embedding_dict = third_party_dict.get("text_embedding_dict", {}) 25 | 26 | if dreambooth is not None: 27 | dreambooth_state_dict = {} 28 | if dreambooth.endswith(".safetensors"): 29 | with safe_open(dreambooth, framework="pt", device="cpu") as f: 30 | for key in f.keys(): 31 | dreambooth_state_dict[key] = f.get_tensor(key) 32 | else: 33 | dreambooth_state_dict = torch.load(dreambooth, map_location="cpu") 34 | if "state_dict" in dreambooth_state_dict: 35 | dreambooth_state_dict = dreambooth_state_dict["state_dict"] 36 | # load unet 37 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config) 38 | pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 39 | 40 | # load vae from dreambooth (if need) 41 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config) 42 | # add prefix for compiled model 43 | if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]: 44 | converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()} 45 | pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True) 46 | 47 | # load text encoder (if need) 48 | text_encoder_checkpoint = convert_ldm_clip_checkpoint(dreambooth_state_dict) 49 | if text_encoder_checkpoint: 50 | pipeline.text_encoder.load_state_dict(text_encoder_checkpoint, strict=False) 51 | 52 | if vae is not None: 53 | vae_state_dict = {} 54 | if vae.endswith("safetensors"): 55 | with safe_open(vae, framework="pt", device="cpu") as f: 56 | for key in f.keys(): 57 | vae_state_dict[key] = f.get_tensor(key) 58 | elif vae.endswith("ckpt") or vae.endswith("pt"): 59 | vae_state_dict = torch.load(vae, map_location="cpu") 60 | if "state_dict" in vae_state_dict: 61 | vae_state_dict = vae_state_dict["state_dict"] 62 | 63 | vae_state_dict = {f"first_stage_model.{k}": v for k, v in vae_state_dict.items()} 64 | 65 | converted_vae_checkpoint = 
convert_ldm_vae_checkpoint(vae_state_dict, pipeline.vae.config) 66 | # add prefix for compiled model 67 | if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]: 68 | converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()} 69 | pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True) 70 | 71 | if lora_list: 72 | for lora_dict in lora_list: 73 | lora, lora_alpha = lora_dict["lora"], lora_dict["lora_alpha"] 74 | lora_state_dict = {} 75 | with safe_open(lora, framework="pt", device="cpu") as file: 76 | for k in file.keys(): 77 | lora_state_dict[k] = file.get_tensor(k) 78 | pipeline.unet, pipeline.text_encoder = convert_lora_model_level( 79 | lora_state_dict, 80 | pipeline.unet, 81 | pipeline.text_encoder, 82 | alpha=lora_alpha, 83 | ) 84 | print(f'Add LoRA "{lora}":{lora_alpha} to pipeline.') 85 | 86 | if text_embedding_dict is not None: 87 | from diffusers.loaders import TextualInversionLoaderMixin 88 | 89 | assert isinstance( 90 | pipeline, TextualInversionLoaderMixin 91 | ), "Pipeline must inherit from TextualInversionLoaderMixin." 92 | 93 | for token, embedding_path in text_embedding_dict.items(): 94 | pipeline.load_textual_inversion(embedding_path, token) 95 | 96 | return pipeline 97 | 98 | 99 | def load_third_party_unet(unet, third_party_dict: dict, dreambooth_path: Optional[str] = None): 100 | lora_list = third_party_dict.get("lora_list", []) 101 | dreambooth = dreambooth_path or third_party_dict.get("dreambooth", None) 102 | 103 | if dreambooth is not None: 104 | dreambooth_state_dict = {} 105 | if dreambooth.endswith(".safetensors"): 106 | with safe_open(dreambooth, framework="pt", device="cpu") as f: 107 | for key in f.keys(): 108 | dreambooth_state_dict[key] = f.get_tensor(key) 109 | else: 110 | dreambooth_state_dict = torch.load(dreambooth, map_location="cpu") 111 | if "state_dict" in dreambooth_state_dict: 112 | dreambooth_state_dict = dreambooth_state_dict["state_dict"] 113 | # load unet 114 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, unet.config) 115 | unet.load_state_dict(converted_unet_checkpoint, strict=False) 116 | 117 | if lora_list: 118 | for lora_dict in lora_list: 119 | lora, lora_alpha = lora_dict["lora"], lora_dict["lora_alpha"] 120 | lora_state_dict = {} 121 | 122 | with safe_open(lora, framework="pt", device="cpu") as file: 123 | for k in file.keys(): 124 | if "text" not in k: 125 | lora_state_dict[k] = file.get_tensor(k) 126 | unet, _ = convert_lora_model_level( 127 | lora_state_dict, 128 | unet, 129 | None, 130 | alpha=lora_alpha, 131 | ) 132 | print(f'Add LoRA "{lora}":{lora_alpha} to Warmup UNet.') 133 | 134 | return unet 135 | -------------------------------------------------------------------------------- /live2diff/animatediff/converter/convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/PIA/blob/main/animatediff/utils/convert_lora_safetensor_to_diffusers.py and 2 | # https://github.com/guoyww/AnimateDiff/blob/main/animatediff/utils/convert_lora_safetensor_to_diffusers.py 3 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | """Conversion script for the LoRA's safetensors checkpoints.""" 18 | 19 | import torch 20 | 21 | 22 | def convert_lora_model_level( 23 | state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6 24 | ): 25 | """convert lora in model level instead of pipeline leval""" 26 | 27 | visited = [] 28 | 29 | # directly update weight in diffusers model 30 | for key in state_dict: 31 | # it is suggested to print out the key, it usually will be something like below 32 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 33 | 34 | # as we have set the alpha beforehand, so just skip 35 | if ".alpha" in key or key in visited: 36 | continue 37 | 38 | if "text" in key: 39 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 40 | assert text_encoder is not None, "text_encoder must be passed since lora contains text encoder layers" 41 | curr_layer = text_encoder 42 | else: 43 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 44 | curr_layer = unet 45 | 46 | # find the target layer 47 | temp_name = layer_infos.pop(0) 48 | while len(layer_infos) > -1: 49 | try: 50 | curr_layer = curr_layer.__getattr__(temp_name) 51 | if len(layer_infos) > 0: 52 | temp_name = layer_infos.pop(0) 53 | elif len(layer_infos) == 0: 54 | break 55 | except Exception: 56 | if len(temp_name) > 0: 57 | temp_name += "_" + layer_infos.pop(0) 58 | else: 59 | temp_name = layer_infos.pop(0) 60 | 61 | pair_keys = [] 62 | if "lora_down" in key: 63 | pair_keys.append(key.replace("lora_down", "lora_up")) 64 | pair_keys.append(key) 65 | else: 66 | pair_keys.append(key) 67 | pair_keys.append(key.replace("lora_up", "lora_down")) 68 | 69 | # update weight 70 | # NOTE: load lycon, maybe have bugs :( 71 | if "conv_in" in pair_keys[0]: 72 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 73 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 74 | weight_up = weight_up.view(weight_up.size(0), -1) 75 | weight_down = weight_down.view(weight_down.size(0), -1) 76 | shape = list(curr_layer.weight.data.shape) 77 | shape[1] = 4 78 | curr_layer.weight.data[:, :4, ...] 
+= alpha * (weight_up @ weight_down).view(*shape) 79 | elif "conv" in pair_keys[0]: 80 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 81 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 82 | weight_up = weight_up.view(weight_up.size(0), -1) 83 | weight_down = weight_down.view(weight_down.size(0), -1) 84 | shape = list(curr_layer.weight.data.shape) 85 | curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape) 86 | elif len(state_dict[pair_keys[0]].shape) == 4: 87 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 88 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 89 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to( 90 | curr_layer.weight.data.device 91 | ) 92 | else: 93 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 94 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 95 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 96 | 97 | # update visited list 98 | for item in pair_keys: 99 | visited.append(item) 100 | 101 | return unet, text_encoder 102 | -------------------------------------------------------------------------------- /live2diff/animatediff/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/Live2Diff/e40aa03a9a9d17f1232fd7a4566b3ee793e6893f/live2diff/animatediff/models/__init__.py -------------------------------------------------------------------------------- /live2diff/animatediff/models/depth_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | try: 6 | from ...MiDaS.midas.dpt_depth import DPTDepthModel 7 | except ImportError: 8 | print('Please pull the MiDaS submodule via "git submodule update --init --recursive"!') 9 | 10 | 11 | class MidasDetector(nn.Module): 12 | def __init__(self, model_path="./models/dpt_hybrid_384"): 13 | super().__init__() 14 | 15 | self.model = DPTDepthModel(path=model_path, backbone="vitb_rn50_384", non_negative=True) 16 | self.model.requires_grad_(False) 17 | self.model.eval() 18 | 19 | @property 20 | def dtype(self): 21 | return next(self.parameters()).dtype 22 | 23 | @property 24 | def device(self): 25 | return next(self.parameters()).device 26 | 27 | @torch.no_grad() 28 | def forward(self, images: torch.Tensor): 29 | """ 30 | Input: [b, c, h, w] 31 | """ 32 | return self.model(images) 33 | -------------------------------------------------------------------------------- /live2diff/animatediff/models/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | def __init__(self, d_model, dropout=0.0, max_len=32): 10 | super().__init__() 11 | self.dropout = nn.Dropout(p=dropout) 12 | position = torch.arange(max_len).unsqueeze(1) 13 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 14 | pe = torch.zeros(1, max_len, d_model) 15 | pe[0, :, 0::2] = torch.sin(position * div_term) 16 | pe[0, :, 1::2] = torch.cos(position * div_term) 17 | self.register_buffer("pe", pe) 18 | 19 | def forward(self, x, roll: Optional[int] = None, full_video_length: Optional[int] = None): 20 | """ 21 | Support roll for positional encoding. 
22 | We select the first `full_video_length` elements and roll it by `roll`. 23 | And then select the first `x.size(1)` elements and add them to `x`. 24 | 25 | Take full_video_length = 4, roll = 2, and x.size(1) = 1 as example. 26 | 27 | If the original positional encoding is: 28 | [1, 2, 3, 4, 5, 6, 7, 8] 29 | The rolled encoding is: 30 | [3, 4, 1, 2] 31 | And the selected encoding added to input is: 32 | [3, 4] 33 | 34 | """ 35 | if roll is None: 36 | pe = self.pe[:, : x.size(1)] 37 | else: 38 | assert full_video_length is not None, "full_video_length must be passed when roll is not None." 39 | pe = self.pe[:, :full_video_length].roll(shifts=roll, dims=1)[:, : x.size(1)] 40 | x = x + pe 41 | return self.dropout(x) 42 | -------------------------------------------------------------------------------- /live2diff/animatediff/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | 9 | 10 | def zero_module(module): 11 | # Zero out the parameters of a module and return it. 12 | for p in module.parameters(): 13 | p.detach().zero_() 14 | return module 15 | 16 | 17 | class MappingNetwork(nn.Module): 18 | """ 19 | Modified from https://github.com/huggingface/diffusers/blob/196835695ed6fa3ec53b888088d9d5581e8f8e94/src/diffusers/models/controlnet.py#L66-L108 # noqa 20 | """ 21 | 22 | def __init__( 23 | self, 24 | conditioning_embedding_channels: int, 25 | conditioning_channels: int = 3, 26 | block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), 27 | ): 28 | super().__init__() 29 | 30 | self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) 31 | 32 | self.blocks = nn.ModuleList([]) 33 | 34 | for i in range(len(block_out_channels) - 1): 35 | channel_in = block_out_channels[i] 36 | channel_out = block_out_channels[i + 1] 37 | self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)) 38 | self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1)) 39 | 40 | self.conv_out = zero_module( 41 | InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) 42 | ) 43 | 44 | def forward(self, conditioning): 45 | embedding = self.conv_in(conditioning) 46 | embedding = F.silu(embedding) 47 | 48 | for block in self.blocks: 49 | embedding = block(embedding) 50 | embedding = F.silu(embedding) 51 | 52 | embedding = self.conv_out(embedding) 53 | 54 | return embedding 55 | 56 | 57 | class InflatedConv3d(nn.Conv2d): 58 | def forward(self, x): 59 | video_length = x.shape[2] 60 | 61 | x = rearrange(x, "b c f h w -> (b f) c h w") 62 | x = super().forward(x) 63 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 64 | 65 | return x 66 | 67 | 68 | class InflatedGroupNorm(nn.GroupNorm): 69 | def forward(self, x): 70 | video_length = x.shape[2] 71 | 72 | x = rearrange(x, "b c f h w -> (b f) c h w") 73 | x = super().forward(x) 74 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 75 | 76 | return x 77 | 78 | 79 | class Upsample3D(nn.Module): 80 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 81 | super().__init__() 82 | self.channels = channels 83 | self.out_channels = out_channels or channels 84 | self.use_conv = use_conv 85 
| self.use_conv_transpose = use_conv_transpose 86 | self.name = name 87 | 88 | # conv = None 89 | if use_conv_transpose: 90 | raise NotImplementedError 91 | elif use_conv: 92 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 93 | 94 | def forward(self, hidden_states, output_size=None): 95 | assert hidden_states.shape[1] == self.channels 96 | 97 | if self.use_conv_transpose: 98 | raise NotImplementedError 99 | 100 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 101 | dtype = hidden_states.dtype 102 | if dtype == torch.bfloat16: 103 | hidden_states = hidden_states.to(torch.float32) 104 | 105 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 106 | if hidden_states.shape[0] >= 64: 107 | hidden_states = hidden_states.contiguous() 108 | 109 | # if `output_size` is passed we force the interpolation output 110 | # size and do not make use of `scale_factor=2` 111 | if output_size is None: 112 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 113 | else: 114 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 115 | 116 | # If the input is bfloat16, we cast back to bfloat16 117 | if dtype == torch.bfloat16: 118 | hidden_states = hidden_states.to(dtype) 119 | 120 | # if self.use_conv: 121 | # if self.name == "conv": 122 | # hidden_states = self.conv(hidden_states) 123 | # else: 124 | # hidden_states = self.Conv2d_0(hidden_states) 125 | hidden_states = self.conv(hidden_states) 126 | 127 | return hidden_states 128 | 129 | 130 | class Downsample3D(nn.Module): 131 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 132 | super().__init__() 133 | self.channels = channels 134 | self.out_channels = out_channels or channels 135 | self.use_conv = use_conv 136 | self.padding = padding 137 | stride = 2 138 | self.name = name 139 | 140 | if use_conv: 141 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 142 | else: 143 | raise NotImplementedError 144 | 145 | def forward(self, hidden_states): 146 | assert hidden_states.shape[1] == self.channels 147 | if self.use_conv and self.padding == 0: 148 | raise NotImplementedError 149 | 150 | assert hidden_states.shape[1] == self.channels 151 | hidden_states = self.conv(hidden_states) 152 | 153 | return hidden_states 154 | 155 | 156 | class ResnetBlock3D(nn.Module): 157 | def __init__( 158 | self, 159 | *, 160 | in_channels, 161 | out_channels=None, 162 | conv_shortcut=False, 163 | dropout=0.0, 164 | temb_channels=512, 165 | groups=32, 166 | groups_out=None, 167 | pre_norm=True, 168 | eps=1e-6, 169 | non_linearity="swish", 170 | time_embedding_norm="default", 171 | output_scale_factor=1.0, 172 | use_in_shortcut=None, 173 | use_inflated_groupnorm=False, 174 | ): 175 | super().__init__() 176 | self.pre_norm = pre_norm 177 | self.pre_norm = True 178 | self.in_channels = in_channels 179 | out_channels = in_channels if out_channels is None else out_channels 180 | self.out_channels = out_channels 181 | self.use_conv_shortcut = conv_shortcut 182 | self.time_embedding_norm = time_embedding_norm 183 | self.output_scale_factor = output_scale_factor 184 | 185 | if groups_out is None: 186 | groups_out = groups 187 | 188 | assert use_inflated_groupnorm is not None 189 | if use_inflated_groupnorm: 190 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 191 | 
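        # InflatedGroupNorm folds the frame dim into the batch dim, so normalization
        # statistics are computed per frame; the plain GroupNorm branch below instead
        # normalizes over the temporal dim together with the spatial dims.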
else: 192 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 193 | 194 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 195 | 196 | if temb_channels is not None: 197 | if self.time_embedding_norm == "default": 198 | time_emb_proj_out_channels = out_channels 199 | elif self.time_embedding_norm == "scale_shift": 200 | time_emb_proj_out_channels = out_channels * 2 201 | else: 202 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 203 | 204 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 205 | else: 206 | self.time_emb_proj = None 207 | 208 | if use_inflated_groupnorm: 209 | self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 210 | else: 211 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 212 | 213 | self.dropout = torch.nn.Dropout(dropout) 214 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 215 | 216 | if non_linearity == "swish": 217 | self.nonlinearity = lambda x: F.silu(x) 218 | elif non_linearity == "mish": 219 | self.nonlinearity = Mish() 220 | elif non_linearity == "silu": 221 | self.nonlinearity = nn.SiLU() 222 | 223 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 224 | 225 | self.conv_shortcut = None 226 | if self.use_in_shortcut: 227 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 228 | 229 | def forward(self, input_tensor, temb): 230 | hidden_states = input_tensor 231 | 232 | hidden_states = self.norm1(hidden_states) 233 | hidden_states = self.nonlinearity(hidden_states) 234 | 235 | hidden_states = self.conv1(hidden_states) 236 | 237 | if temb is not None: 238 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 239 | 240 | if temb is not None and self.time_embedding_norm == "default": 241 | hidden_states = hidden_states + temb 242 | 243 | hidden_states = self.norm2(hidden_states) 244 | 245 | if temb is not None and self.time_embedding_norm == "scale_shift": 246 | scale, shift = torch.chunk(temb, 2, dim=1) 247 | hidden_states = hidden_states * (1 + scale) + shift 248 | 249 | hidden_states = self.nonlinearity(hidden_states) 250 | 251 | hidden_states = self.dropout(hidden_states) 252 | hidden_states = self.conv2(hidden_states) 253 | 254 | if self.conv_shortcut is not None: 255 | input_tensor = self.conv_shortcut(input_tensor) 256 | 257 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 258 | 259 | return output_tensor 260 | 261 | 262 | class Mish(torch.nn.Module): 263 | def forward(self, hidden_states): 264 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 265 | -------------------------------------------------------------------------------- /live2diff/animatediff/models/stream_motion_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | 5 | from .attention import CrossAttention 6 | from .positional_encoding import PositionalEncoding 7 | 8 | 9 | class StreamTemporalAttention(CrossAttention): 10 | """ 11 | 12 | * window_size: The max length of attention window. 13 | * sink_size: The number sink token. 
14 | * positional_rule: absolute, relative 15 | 16 | Therefore, the seq length of temporal self-attention will be: 17 | sink_length + cache_size 18 | 19 | """ 20 | 21 | def __init__( 22 | self, 23 | attention_mode=None, 24 | cross_frame_attention_mode=None, 25 | temporal_position_encoding=False, 26 | temporal_position_encoding_max_len=32, 27 | window_size=8, 28 | sink_size=0, 29 | *args, 30 | **kwargs, 31 | ): 32 | super().__init__(*args, **kwargs) 33 | 34 | self.attention_mode = self._orig_attention_mode = attention_mode 35 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 36 | 37 | self.pos_encoder = PositionalEncoding( 38 | kwargs["query_dim"], 39 | dropout=0.0, 40 | max_len=temporal_position_encoding_max_len, 41 | ) 42 | 43 | self.window_size = window_size 44 | self.sink_size = sink_size 45 | self.cache_size = self.window_size - self.sink_size 46 | assert self.cache_size >= 0, ( 47 | "cache_size must be greater or equal to 0. Please check your configuration. " 48 | f"window_size: {window_size}, sink_size: {sink_size}, " 49 | f"cache_size: {self.cache_size}" 50 | ) 51 | 52 | self.motion_module_idx = None 53 | 54 | def set_index(self, idx): 55 | self.motion_module_idx = idx 56 | 57 | @torch.no_grad() 58 | def set_cache(self, denoising_steps_num: int): 59 | """ 60 | larger buffer index means cleaner latent 61 | """ 62 | device = next(self.parameters()).device 63 | dtype = next(self.parameters()).dtype 64 | 65 | # [t, 2, hw, L, c], 2 means k and v 66 | kv_cache = torch.zeros( 67 | denoising_steps_num, 68 | 2, 69 | self.h * self.w, 70 | self.window_size, 71 | self.kv_channels, 72 | device=device, 73 | dtype=dtype, 74 | ) 75 | self.denoising_steps_num = denoising_steps_num 76 | 77 | return kv_cache 78 | 79 | @torch.no_grad() 80 | def prepare_pe_buffer(self): 81 | """In AnimateDiff, Temporal Self-attention use absolute positional encoding: 82 | q = w_q * (x + pe) + bias 83 | k = w_k * (x + pe) + bias 84 | v = w_v * (x + pe) + bias 85 | 86 | If we want to conduct relative positional encoding with kv-cache, we should pre-calcute 87 | `w_q/k/v * pe` and then cache `w_q/k/v * x + bias` 88 | """ 89 | 90 | pe_list = self.pos_encoder.pe[:, : self.window_size] # [1, window_size, ch] 91 | q_pe = F.linear(pe_list, self.to_q.weight) 92 | k_pe = F.linear(pe_list, self.to_k.weight) 93 | v_pe = F.linear(pe_list, self.to_v.weight) 94 | 95 | self.register_buffer("q_pe", q_pe) 96 | self.register_buffer("k_pe", k_pe) 97 | self.register_buffer("v_pe", v_pe) 98 | 99 | def prepare_qkv_full_and_cache(self, hidden_states, kv_cache, pe_idx, update_idx): 100 | """ 101 | hidden_states: [(N * bhw), F, c], 102 | kv_cache: [2, N, hw, L, c] 103 | 104 | * for warmup case: `N` should be 1 and `F` should be warmup_size (`sink_size`) 105 | * for streaming case: `N` should be `denoising_steps_num` and `F` should be `chunk_size` 106 | 107 | """ 108 | q_layer = self.to_q(hidden_states) 109 | k_layer = self.to_k(hidden_states) 110 | v_layer = self.to_v(hidden_states) 111 | 112 | q_layer = rearrange(q_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num) 113 | k_layer = rearrange(k_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num) 114 | v_layer = rearrange(v_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num) 115 | 116 | # onnx & trt friendly indexing 117 | for idx in range(self.denoising_steps_num): 118 | kv_cache[idx, 0, :, update_idx[idx]] = k_layer[idx, :, 0] 119 | kv_cache[idx, 1, :, update_idx[idx]] = v_layer[idx, :, 0] 120 | 121 | k_full = kv_cache[:, 0] 122 | v_full = 
kv_cache[:, 1] 123 | 124 | kv_idx = pe_idx 125 | q_idx = torch.stack([kv_idx[idx, update_idx[idx]] for idx in range(self.denoising_steps_num)]).unsqueeze_( 126 | 1 127 | ) # [timesteps, 1] 128 | 129 | pe_k = torch.cat( 130 | [self.k_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0 131 | ) # [n, window_size, c] 132 | pe_v = torch.cat( 133 | [self.v_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0 134 | ) # [n, window_size, c] 135 | pe_q = torch.cat( 136 | [self.q_pe.index_select(1, q_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0 137 | ) # [n, window_size, c] 138 | 139 | q_layer = q_layer + pe_q.unsqueeze(1) 140 | k_full = k_full + pe_k.unsqueeze(1) 141 | v_full = v_full + pe_v.unsqueeze(1) 142 | 143 | q_layer = rearrange(q_layer, "n bhw f c -> (n bhw) f c") 144 | k_full = rearrange(k_full, "n bhw f c -> (n bhw) f c") 145 | v_full = rearrange(v_full, "n bhw f c -> (n bhw) f c") 146 | 147 | return q_layer, k_full, v_full 148 | 149 | def forward( 150 | self, 151 | hidden_states, 152 | encoder_hidden_states=None, 153 | attention_mask=None, 154 | video_length=None, 155 | temporal_attention_mask=None, 156 | kv_cache=None, 157 | pe_idx=None, 158 | update_idx=None, 159 | *args, 160 | **kwargs, 161 | ): 162 | """ 163 | temporal_attention_mask: attention mask specific for the temporal self-attention. 164 | """ 165 | 166 | d = hidden_states.shape[1] 167 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 168 | 169 | if self.group_norm is not None: 170 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 171 | 172 | query_layer, key_full, value_full = self.prepare_qkv_full_and_cache( 173 | hidden_states, kv_cache, pe_idx, update_idx 174 | ) 175 | 176 | # [(n * hw * b), f, c] -> [(n * hw * b * head), f, c // head] 177 | query_layer = self.reshape_heads_to_batch_dim(query_layer) 178 | key_full = self.reshape_heads_to_batch_dim(key_full) 179 | value_full = self.reshape_heads_to_batch_dim(value_full) 180 | 181 | if temporal_attention_mask is not None: 182 | q_size = query_layer.shape[1] 183 | # [n, self.window_size] -> [n, hw, q_size, window_size] 184 | temporal_attention_mask_ = temporal_attention_mask[:, None, None, :].repeat(1, self.h * self.w, q_size, 1) 185 | temporal_attention_mask_ = rearrange(temporal_attention_mask_, "n hw Q KV -> (n hw) Q KV") 186 | temporal_attention_mask_ = temporal_attention_mask_.repeat_interleave(self.heads, dim=0) 187 | else: 188 | temporal_attention_mask_ = None 189 | 190 | # attention, what we cannot get enough of 191 | if hasattr(F, "scaled_dot_product_attention"): 192 | hidden_states = self._memory_efficient_attention_pt20( 193 | query_layer, key_full, value_full, attention_mask=temporal_attention_mask_ 194 | ) 195 | 196 | elif self._use_memory_efficient_attention_xformers: 197 | hidden_states = self._memory_efficient_attention_xformers( 198 | query_layer, key_full, value_full, attention_mask=temporal_attention_mask_ 199 | ) 200 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 201 | hidden_states = hidden_states.to(query_layer.dtype) 202 | else: 203 | hidden_states = self._attention(query_layer, key_full, value_full, temporal_attention_mask_) 204 | 205 | # linear proj 206 | hidden_states = self.to_out[0](hidden_states) 207 | 208 | # dropout 209 | hidden_states = self.to_out[1](hidden_states) 210 | 211 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 212 | 213 | 
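        # hidden_states is back in the same "(b f) d c" layout as the input: every query
        # frame has attended over its denoising-step-specific KV cache window of
        # `window_size` frames, with positional encodings gathered from the precomputed
        # q_pe / k_pe / v_pe buffers according to pe_idx / update_idx.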
return hidden_states 214 | -------------------------------------------------------------------------------- /live2diff/animatediff/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_animatediff_depth import AnimationDepthPipeline 2 | 3 | 4 | __all__ = ["AnimationDepthPipeline"] 5 | -------------------------------------------------------------------------------- /live2diff/animatediff/pipeline/loader.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | import torch 4 | from diffusers.loaders.lora import LoraLoaderMixin 5 | from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT 6 | from diffusers.utils import USE_PEFT_BACKEND 7 | 8 | 9 | class LoraLoaderWithWarmup(LoraLoaderMixin): 10 | unet_warmup_name = "unet_warmup" 11 | 12 | def load_lora_weights( 13 | self, 14 | pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], 15 | adapter_name=None, 16 | **kwargs, 17 | ): 18 | # load lora for text encoder and unet-streaming 19 | super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name=adapter_name, **kwargs) 20 | 21 | # load lora for unet-warmup 22 | state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) 23 | low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) 24 | 25 | self.load_lora_into_unet( 26 | state_dict, 27 | network_alphas=network_alphas, 28 | unet=getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup, 29 | low_cpu_mem_usage=low_cpu_mem_usage, 30 | adapter_name=adapter_name, 31 | _pipeline=self, 32 | ) 33 | 34 | def fuse_lora( 35 | self, 36 | fuse_unet: bool = True, 37 | fuse_text_encoder: bool = True, 38 | lora_scale: float = 1.0, 39 | safe_fusing: bool = False, 40 | adapter_names: Optional[List[str]] = None, 41 | ): 42 | # fuse lora for text encoder and unet-streaming 43 | super().fuse_lora(fuse_unet, fuse_text_encoder, lora_scale, safe_fusing, adapter_names) 44 | 45 | # fuse lora for unet-warmup 46 | if fuse_unet: 47 | unet_warmup = ( 48 | getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup 49 | ) 50 | unet_warmup.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) 51 | 52 | def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True): 53 | # unfuse lora for text encoder and unet-streaming 54 | super().unfuse_lora(unfuse_unet, unfuse_text_encoder) 55 | 56 | # unfuse lora for unet-warmup 57 | if unfuse_unet: 58 | unet_warmup = ( 59 | getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup 60 | ) 61 | if not USE_PEFT_BACKEND: 62 | unet_warmup.unfuse_lora() 63 | else: 64 | from peft.tuners.tuners_utils import BaseTunerLayer 65 | 66 | for module in unet_warmup.modules(): 67 | if isinstance(module, BaseTunerLayer): 68 | module.unmerge() 69 | -------------------------------------------------------------------------------- /live2diff/animatediff/pipeline/pipeline_animatediff_depth.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/open-mmlab/PIA/blob/main/animatediff/pipelines/i2v_pipeline.py 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | from diffusers.configuration_utils import 
FrozenDict 9 | from diffusers.loaders import TextualInversionLoaderMixin 10 | from diffusers.models import AutoencoderKL 11 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 12 | from diffusers.schedulers import ( 13 | DDIMScheduler, 14 | DPMSolverMultistepScheduler, 15 | EulerAncestralDiscreteScheduler, 16 | EulerDiscreteScheduler, 17 | LMSDiscreteScheduler, 18 | PNDMScheduler, 19 | ) 20 | from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging 21 | from packaging import version 22 | from transformers import CLIPTextModel, CLIPTokenizer 23 | 24 | from ..models.depth_utils import MidasDetector 25 | from ..models.unet_depth_streaming import UNet3DConditionStreamingModel 26 | from .loader import LoraLoaderWithWarmup 27 | 28 | 29 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 30 | 31 | 32 | @dataclass 33 | class AnimationPipelineOutput(BaseOutput): 34 | videos: Union[torch.Tensor, np.ndarray] 35 | input_images: Optional[Union[torch.Tensor, np.ndarray]] = None 36 | 37 | 38 | class AnimationDepthPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderWithWarmup): 39 | _optional_components = [] 40 | 41 | def __init__( 42 | self, 43 | vae: AutoencoderKL, 44 | text_encoder: CLIPTextModel, 45 | tokenizer: CLIPTokenizer, 46 | unet: UNet3DConditionStreamingModel, 47 | depth_model: MidasDetector, 48 | scheduler: Union[ 49 | DDIMScheduler, 50 | PNDMScheduler, 51 | LMSDiscreteScheduler, 52 | EulerDiscreteScheduler, 53 | EulerAncestralDiscreteScheduler, 54 | DPMSolverMultistepScheduler, 55 | ], 56 | ): 57 | super().__init__() 58 | 59 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 60 | deprecation_message = ( 61 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 62 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 63 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 64 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 65 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 66 | " file" 67 | ) 68 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 69 | new_config = dict(scheduler.config) 70 | new_config["steps_offset"] = 1 71 | scheduler._internal_dict = FrozenDict(new_config) 72 | 73 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 74 | deprecation_message = ( 75 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 76 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 77 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 78 | " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 79 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 80 | ) 81 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 82 | new_config = dict(scheduler.config) 83 | new_config["clip_sample"] = False 84 | scheduler._internal_dict = FrozenDict(new_config) 85 | 86 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 87 | version.parse(unet.config._diffusers_version).base_version 88 | ) < version.parse("0.9.0.dev0") 89 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 90 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 91 | deprecation_message = ( 92 | "The configuration file of the unet has set the default `sample_size` to smaller than" 93 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 94 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 95 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 96 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 97 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 98 | " in the config might lead to incorrect results in future versions. If you have downloaded this" 99 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 100 | " the `unet/config.json` file" 101 | ) 102 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 103 | new_config = dict(unet.config) 104 | new_config["sample_size"] = 64 105 | unet._internal_dict = FrozenDict(new_config) 106 | 107 | self.register_modules( 108 | vae=vae, 109 | text_encoder=text_encoder, 110 | tokenizer=tokenizer, 111 | unet=unet, 112 | depth_model=depth_model, 113 | scheduler=scheduler, 114 | ) 115 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 116 | self.log_denoising_mean = False 117 | 118 | def enable_vae_slicing(self): 119 | self.vae.enable_slicing() 120 | 121 | def disable_vae_slicing(self): 122 | self.vae.disable_slicing() 123 | 124 | def enable_sequential_cpu_offload(self, gpu_id=0): 125 | if is_accelerate_available(): 126 | from accelerate import cpu_offload 127 | else: 128 | raise ImportError("Please install accelerate via `pip install accelerate`") 129 | 130 | device = torch.device(f"cuda:{gpu_id}") 131 | 132 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 133 | if cpu_offloaded_model is not None: 134 | cpu_offload(cpu_offloaded_model, device) 135 | 136 | @property 137 | def _execution_device(self): 138 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 139 | return self.device 140 | for module in self.unet.modules(): 141 | if ( 142 | hasattr(module, "_hf_hook") 143 | and hasattr(module._hf_hook, "execution_device") 144 | and module._hf_hook.execution_device is not None 145 | ): 146 | return torch.device(module._hf_hook.execution_device) 147 | return self.device 148 | 149 | def _encode_prompt( 150 | self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt, clip_skip=None 151 | ): 152 | batch_size = len(prompt) if isinstance(prompt, list) else 1 153 | 154 | text_inputs = self.tokenizer( 155 | prompt, 156 | padding="max_length", 157 | 
max_length=self.tokenizer.model_max_length, 158 | truncation=True, 159 | return_tensors="pt", 160 | ) 161 | text_input_ids = text_inputs.input_ids 162 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 163 | 164 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 165 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 166 | logger.warning( 167 | "The following part of your input was truncated because CLIP can only handle sequences up to" 168 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 169 | ) 170 | 171 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 172 | attention_mask = text_inputs.attention_mask.to(device) 173 | else: 174 | attention_mask = None 175 | 176 | if clip_skip is None: 177 | text_embeddings = self.text_encoder( 178 | text_input_ids.to(device), 179 | attention_mask=attention_mask, 180 | ) 181 | text_embeddings = text_embeddings[0] 182 | else: 183 | # support ckip skip here, suitable for model based on NAI~ 184 | text_embeddings = self.text_encoder( 185 | text_input_ids.to(device), 186 | attention_mask=attention_mask, 187 | output_hidden_states=True, 188 | ) 189 | text_embeddings = text_embeddings[-1][-(clip_skip + 1)] 190 | text_embeddings = self.text_encoder.text_model.final_layer_norm(text_embeddings) 191 | 192 | # duplicate text embeddings for each generation per prompt, using mps friendly method 193 | bs_embed, seq_len, _ = text_embeddings.shape 194 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 195 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 196 | 197 | # get unconditional embeddings for classifier free guidance 198 | if do_classifier_free_guidance: 199 | uncond_tokens: List[str] 200 | if negative_prompt is None: 201 | uncond_tokens = [""] * batch_size 202 | elif type(prompt) is not type(negative_prompt): 203 | raise TypeError( 204 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 205 | f" {type(prompt)}." 206 | ) 207 | elif isinstance(negative_prompt, str): 208 | uncond_tokens = [negative_prompt] 209 | elif batch_size != len(negative_prompt): 210 | raise ValueError( 211 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 212 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 213 | " the batch size of `prompt`." 
214 | ) 215 | else: 216 | uncond_tokens = negative_prompt 217 | 218 | max_length = text_input_ids.shape[-1] 219 | uncond_input = self.tokenizer( 220 | uncond_tokens, 221 | padding="max_length", 222 | max_length=max_length, 223 | truncation=True, 224 | return_tensors="pt", 225 | ) 226 | 227 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 228 | attention_mask = uncond_input.attention_mask.to(device) 229 | else: 230 | attention_mask = None 231 | 232 | uncond_embeddings = self.text_encoder( 233 | uncond_input.input_ids.to(device), 234 | attention_mask=attention_mask, 235 | ) 236 | uncond_embeddings = uncond_embeddings[0] 237 | 238 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 239 | seq_len = uncond_embeddings.shape[1] 240 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 241 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 242 | 243 | # For classifier free guidance, we need to do two forward passes. 244 | # Here we concatenate the unconditional and text embeddings into a single batch 245 | # to avoid doing two forward passes 246 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 247 | 248 | return text_embeddings 249 | 250 | @classmethod 251 | def build_pipeline(cls, config_path: str, dreambooth: Optional[str] = None): 252 | """We build pipeline from config path""" 253 | from omegaconf import OmegaConf 254 | 255 | from ...utils.config import load_config 256 | from ..converter import load_third_party_checkpoints 257 | from ..models.unet_depth_streaming import UNet3DConditionStreamingModel 258 | 259 | cfg = load_config(config_path) 260 | pretrained_model_path = cfg.pretrained_model_path 261 | unet_additional_kwargs = cfg.get("unet_additional_kwargs", {}) 262 | noise_scheduler_kwargs = cfg.noise_scheduler_kwargs 263 | third_party_dict = cfg.get("third_party_dict", {}) 264 | 265 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) 266 | 267 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 268 | vae = vae.to(device="cuda", dtype=torch.bfloat16) 269 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 270 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 271 | text_encoder = text_encoder.to(device="cuda", dtype=torch.float16) 272 | 273 | unet = UNet3DConditionStreamingModel.from_pretrained_2d( 274 | pretrained_model_path, 275 | subfolder="unet", 276 | unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) if unet_additional_kwargs else {}, 277 | ) 278 | 279 | motion_module_path = cfg.motion_module_path 280 | # load motion module to unet 281 | mm_checkpoint = torch.load(motion_module_path, map_location="cuda") 282 | if "global_step" in mm_checkpoint: 283 | print(f"global_step: {mm_checkpoint['global_step']}") 284 | state_dict = mm_checkpoint["state_dict"] if "state_dict" in mm_checkpoint else mm_checkpoint 285 | # NOTE: hard code here: remove `grid` from state_dict 286 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if "grid" not in k} 287 | 288 | m, u = unet.load_state_dict(state_dict, strict=False) 289 | assert len(u) == 0, f"Find unexpected keys ({len(u)}): {u}" 290 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") 291 | 292 | unet = unet.to(dtype=torch.float16) 293 | depth_model = 
MidasDetector(cfg.depth_model_path).to(device="cuda", dtype=torch.float16) 294 | 295 | pipeline = cls( 296 | unet=unet, 297 | vae=vae, 298 | tokenizer=tokenizer, 299 | text_encoder=text_encoder, 300 | depth_model=depth_model, 301 | scheduler=noise_scheduler, 302 | ) 303 | pipeline = load_third_party_checkpoints(pipeline, third_party_dict, dreambooth) 304 | 305 | return pipeline 306 | 307 | @classmethod 308 | def build_warmup_unet(cls, config_path: str, dreambooth: Optional[str] = None): 309 | from omegaconf import OmegaConf 310 | 311 | from ...utils.config import load_config 312 | from ..converter import load_third_party_unet 313 | from ..models.unet_depth_warmup import UNet3DConditionWarmupModel 314 | 315 | cfg = load_config(config_path) 316 | pretrained_model_path = cfg.pretrained_model_path 317 | unet_additional_kwargs = cfg.get("unet_additional_kwargs", {}) 318 | third_party_dict = cfg.get("third_party_dict", {}) 319 | 320 | unet = UNet3DConditionWarmupModel.from_pretrained_2d( 321 | pretrained_model_path, 322 | subfolder="unet", 323 | unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) if unet_additional_kwargs else {}, 324 | ) 325 | motion_module_path = cfg.motion_module_path 326 | # load motion module to unet 327 | mm_checkpoint = torch.load(motion_module_path, map_location="cpu") 328 | if "global_step" in mm_checkpoint: 329 | print(f"global_step: {mm_checkpoint['global_step']}") 330 | state_dict = mm_checkpoint["state_dict"] if "state_dict" in mm_checkpoint else mm_checkpoint 331 | # NOTE: hard code here: remove `grid` from state_dict 332 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if "grid" not in k} 333 | 334 | m, u = unet.load_state_dict(state_dict, strict=False) 335 | assert len(u) == 0, f"Find unexpected keys ({len(u)}): {u}" 336 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") 337 | 338 | unet = load_third_party_unet(unet, third_party_dict, dreambooth) 339 | return unet 340 | 341 | def prepare_cache(self, height: int, width: int, denoising_steps_num: int): 342 | vae = self.vae 343 | scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) 344 | self.unet.set_info_for_attn(height // scale_factor, width // scale_factor) 345 | kv_cache_list = self.unet.prepare_cache(denoising_steps_num) 346 | return kv_cache_list 347 | 348 | def prepare_warmup_unet(self, height: int, width: int, unet): 349 | vae = self.vae 350 | scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) 351 | unet.set_info_for_attn(height // scale_factor, width // scale_factor) 352 | -------------------------------------------------------------------------------- /live2diff/image_filter.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | 7 | class SimilarImageFilter: 8 | def __init__(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None: 9 | self.threshold = threshold 10 | self.prev_tensor = None 11 | self.cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6) 12 | self.max_skip_frame = max_skip_frame 13 | self.skip_count = 0 14 | 15 | def __call__(self, x: torch.Tensor) -> Optional[torch.Tensor]: 16 | if self.prev_tensor is None: 17 | self.prev_tensor = x.detach().clone() 18 | return x 19 | else: 20 | cos_sim = self.cos(self.prev_tensor.reshape(-1), x.reshape(-1)).item() 21 | sample = random.uniform(0, 1) 22 | if self.threshold >= 1: 23 | skip_prob = 0 24 | else: 25 | skip_prob = max(0, 1 - (1 - cos_sim) / (1 - self.threshold)) 
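            # skip_prob is 0 for cos_sim <= threshold and grows linearly to 1 as cos_sim
            # approaches 1. For example, with threshold=0.98 and cos_sim=0.99:
            # skip_prob = 1 - (1 - 0.99) / (1 - 0.98) = 0.5, i.e. the frame is skipped
            # with probability 0.5 (the skip_count / max_skip_frame check below still
            # forces an update after too many consecutive skips).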
26 | 27 | # not skip frame 28 | if skip_prob < sample: 29 | self.prev_tensor = x.detach().clone() 30 | return x 31 | # skip frame 32 | else: 33 | if self.skip_count > self.max_skip_frame: 34 | self.skip_count = 0 35 | self.prev_tensor = x.detach().clone() 36 | return x 37 | else: 38 | self.skip_count += 1 39 | return None 40 | 41 | def set_threshold(self, threshold: float) -> None: 42 | self.threshold = threshold 43 | 44 | def set_max_skip_frame(self, max_skip_frame: float) -> None: 45 | self.max_skip_frame = max_skip_frame 46 | -------------------------------------------------------------------------------- /live2diff/image_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import PIL.Image 5 | import torch 6 | import torchvision 7 | 8 | 9 | def denormalize(images: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: 10 | """ 11 | Denormalize an image array to [0,1]. 12 | """ 13 | return (images / 2 + 0.5).clamp(0, 1) 14 | 15 | 16 | def pt_to_numpy(images: torch.Tensor) -> np.ndarray: 17 | """ 18 | Convert a PyTorch tensor to a NumPy image. 19 | """ 20 | images = images.cpu().permute(0, 2, 3, 1).float().numpy() 21 | return images 22 | 23 | 24 | def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: 25 | """ 26 | Convert a NumPy image or a batch of images to a PIL image. 27 | """ 28 | if images.ndim == 3: 29 | images = images[None, ...] 30 | images = (images * 255).round().astype("uint8") 31 | if images.shape[-1] == 1: 32 | # special case for grayscale (single channel) images 33 | pil_images = [PIL.Image.fromarray(image.squeeze(), mode="L") for image in images] 34 | else: 35 | pil_images = [PIL.Image.fromarray(image) for image in images] 36 | 37 | return pil_images 38 | 39 | 40 | def postprocess_image( 41 | image: torch.Tensor, 42 | output_type: str = "pil", 43 | do_denormalize: Optional[List[bool]] = None, 44 | ) -> Union[torch.Tensor, np.ndarray, PIL.Image.Image]: 45 | if not isinstance(image, torch.Tensor): 46 | raise ValueError( 47 | f"Input for postprocessing is in incorrect format: {type(image)}. 
We only support pytorch tensor" 48 | ) 49 | 50 | if output_type == "latent": 51 | return image 52 | 53 | do_normalize_flg = True 54 | if do_denormalize is None: 55 | do_denormalize = [do_normalize_flg] * image.shape[0] 56 | 57 | image = torch.stack([denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]) 58 | 59 | if output_type == "pt": 60 | return image 61 | 62 | image = pt_to_numpy(image) 63 | 64 | if output_type == "np": 65 | return image 66 | 67 | if output_type == "pil": 68 | return numpy_to_pil(image) 69 | 70 | 71 | def process_image( 72 | image_pil: PIL.Image.Image, range: Tuple[int, int] = (-1, 1) 73 | ) -> Tuple[torch.Tensor, PIL.Image.Image]: 74 | image = torchvision.transforms.ToTensor()(image_pil) 75 | r_min, r_max = range[0], range[1] 76 | image = image * (r_max - r_min) + r_min 77 | return image[None, ...], image_pil 78 | 79 | 80 | def pil2tensor(image_pil: PIL.Image.Image) -> torch.Tensor: 81 | height = image_pil.height 82 | width = image_pil.width 83 | imgs = [] 84 | img, _ = process_image(image_pil) 85 | imgs.append(img) 86 | imgs = torch.vstack(imgs) 87 | images = torch.nn.functional.interpolate(imgs, size=(height, width), mode="bilinear") 88 | image_tensors = images.to(torch.float16) 89 | return image_tensors 90 | -------------------------------------------------------------------------------- /live2diff/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/Live2Diff/e40aa03a9a9d17f1232fd7a4566b3ee793e6893f/live2diff/utils/__init__.py -------------------------------------------------------------------------------- /live2diff/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | from omegaconf import OmegaConf 5 | 6 | 7 | config_suffix = [".yaml"] 8 | 9 | 10 | def load_config(config: str) -> OmegaConf: 11 | config = OmegaConf.load(config) 12 | base_config = config.pop("base", None) 13 | 14 | if base_config: 15 | config = OmegaConf.merge(OmegaConf.load(base_config), config) 16 | 17 | return config 18 | 19 | 20 | def dump_config(config: OmegaConf, save_path: str = None): 21 | from omegaconf import Container 22 | 23 | if isinstance(config, Container): 24 | if not save_path.endswith(".yaml"): 25 | save_dir = save_path 26 | save_path = osp.join(save_dir, "config.yaml") 27 | else: 28 | save_dir = osp.basename(config) 29 | os.makedirs(save_dir, exist_ok=True) 30 | OmegaConf.save(config, save_path) 31 | 32 | else: 33 | raise TypeError("Only support saving `Config` from `OmegaConf`.") 34 | 35 | print(f"Dump Config to {save_path}.") 36 | -------------------------------------------------------------------------------- /live2diff/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import imageio 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | from einops import rearrange 9 | from PIL import Image 10 | 11 | 12 | def read_video_frames(folder: str, height=None, width=None): 13 | """ 14 | Read video frames from the given folder. 
15 | 16 | Output: 17 | frames, in [0, 255], uint8, THWC 18 | """ 19 | _SUPPORTED_EXTENSIONS = [".png", ".jpg", ".jpeg"] 20 | 21 | frames = [f for f in os.listdir(folder) if osp.splitext(f)[1] in _SUPPORTED_EXTENSIONS] 22 | # sort frames 23 | sorted_frames = sorted(frames, key=lambda x: int(osp.splitext(x)[0])) 24 | sorted_frames = [osp.join(folder, f) for f in sorted_frames] 25 | 26 | if height is not None and width is not None: 27 | sorted_frames = [np.array(Image.open(f).resize((width, height))) for f in sorted_frames] 28 | else: 29 | sorted_frames = [np.array(Image.open(f)) for f in sorted_frames] 30 | sorted_frames = torch.stack([torch.from_numpy(f) for f in sorted_frames], dim=0) 31 | return sorted_frames 32 | 33 | 34 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 35 | videos = rearrange(videos, "b c t h w -> t b c h w") 36 | outputs = [] 37 | for x in videos: 38 | x = torchvision.utils.make_grid(x, nrow=n_rows) 39 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 40 | if rescale: 41 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 42 | x = (x * 255).numpy().astype(np.uint8) 43 | outputs.append(x) 44 | 45 | parent_dir = os.path.dirname(path) 46 | if parent_dir != "": 47 | os.makedirs(parent_dir, exist_ok=True) 48 | imageio.mimsave(path, outputs, fps=fps, loop=0) 49 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | # Never enforce `E501` (line length violations). 3 | lint.ignore = ["C901", "E501", "E741", "F402", "F403", "F405", "F823"] 4 | lint.select = ["C", "E", "F", "I", "W"] 5 | line-length = 119 6 | 7 | # Ignore import violations in all `__init__.py` files. 8 | [tool.ruff.lint.per-file-ignores] 9 | "__init__.py" = ["E402", "F401", "F811"] 10 | 11 | [tool.ruff.lint.isort] 12 | lines-after-imports = 2 13 | known-first-party = ["live2diff"] 14 | 15 | [tool.ruff.format] 16 | # Like Black, use double quotes for strings. 17 | quote-style = "double" 18 | 19 | # Like Black, indent with spaces, rather than tabs. 20 | indent-style = "space" 21 | 22 | # Like Black, respect magic trailing commas. 23 | skip-magic-trailing-comma = false 24 | 25 | # Like Black, automatically detect the appropriate line ending. 26 | line-ending = "auto" 27 | 28 | [build-system] 29 | requires = ["setuptools"] 30 | build-backend = "setuptools.build_meta" 31 | -------------------------------------------------------------------------------- /scripts/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | TOKEN=$2 3 | 4 | download_disney() { 5 | echo "Download checkpoint for Disney..." 6 | wget https://civitai.com/api/download/models/69832\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 7 | } 8 | 9 | download_moxin () { 10 | echo "Download checkpoints for MoXin..." 11 | wget https://civitai.com/api/download/models/106289\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 12 | wget https://civitai.com/api/download/models/14856\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate 13 | } 14 | 15 | download_pixart () { 16 | echo "Download checkpoint for PixArt..." 17 | wget https://civitai.com/api/download/models/220049\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 18 | } 19 | 20 | download_origami () { 21 | echo "Download checkpoints for origami..." 
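    # All styles follow the same pattern: ${TOKEN} is the civitai API token passed as the
    # second CLI argument (TOKEN=$2), and --content-disposition keeps the filename
    # suggested by the server instead of the raw query-string URL.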
22 | wget https://civitai.com/api/download/models/270085\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 23 | wget https://civitai.com/api/download/models/266928\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate 24 | } 25 | 26 | download_threeDelicacy () { 27 | echo "Download checkpoints for threeDelicacy..." 28 | wget https://civitai.com/api/download/models/36473\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 29 | } 30 | 31 | download_toonyou () { 32 | echo "Download checkpoint for Toonyou..." 33 | wget https://civitai.com/api/download/models/125771\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 34 | } 35 | 36 | download_zaum () { 37 | echo "Download checkpoints for Zaum..." 38 | wget https://civitai.com/api/download/models/428862\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 39 | wget https://civitai.com/api/download/models/18989\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate 40 | } 41 | 42 | download_felted () { 43 | echo "Download checkpoints for Felted..." 44 | wget https://civitai.com/api/download/models/428862\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate 45 | wget https://civitai.com/api/download/models/86739\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate 46 | } 47 | 48 | if [ -z "$1" ]; then 49 | echo "Please input the model you want to download." 50 | echo "Supported model: all, disney, moxin, pixart, paperArt, threeDelicacy, toonyou, zaum." 51 | exit 1 52 | fi 53 | 54 | declare -A download_func=( 55 | ["disney"]="download_disney" 56 | ["moxin"]="download_moxin" 57 | ["pixart"]="download_pixart" 58 | ["origami"]="download_origami" 59 | ["threeDelicacy"]="download_threeDelicacy" 60 | ["toonyou"]="download_toonyou" 61 | ["zaum"]="download_zaum" 62 | ["felted"]="download_felted" 63 | ) 64 | 65 | execute_function() { 66 | local key="$1" 67 | if [[ -n "${download_func[$key]}" ]]; then 68 | ${download_func[$key]} 69 | else 70 | echo "Function not found for key: $key" 71 | fi 72 | } 73 | 74 | 75 | for arg in "$@"; do 76 | case "$arg" in 77 | disney|moxin|pixart|origami|threeDelicacy|toonyou|zaum|felted) 78 | model_name="$arg" 79 | execute_function "$model_name" 80 | ;; 81 | all) 82 | for model_name in "${!download_func[@]}"; do 83 | execute_function "$model_name" 84 | done 85 | ;; 86 | *) 87 | echo "Invalid argument: $arg." 
88 | exit 1 89 | ;; 90 | esac 91 | done 92 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | deps = [ 5 | "diffusers==0.25.0", 6 | "transformers", 7 | "accelerate", 8 | "fire", 9 | "einops", 10 | "omegaconf", 11 | "imageio", 12 | "timm==0.6.7", 13 | "lightning", 14 | "peft", 15 | "av", 16 | "decord", 17 | "pillow", 18 | "pywin32;sys_platform == 'win32'", 19 | ] 20 | 21 | deps_tensorrt = [ 22 | "onnx==1.16.0", 23 | "onnxruntime==1.16.3", 24 | "protobuf==5.27.0", 25 | "polygraphy", 26 | "onnx-graphsurgeon", 27 | "cuda-python", 28 | "tensorrt==10.0.1", 29 | "colored", 30 | ] 31 | deps_tensorrt_cu11 = [ 32 | "tensorrt_cu11_libs==10.0.1", 33 | "tensorrt_cu11_bindings==10.0.1", 34 | ] 35 | deps_tensorrt_cu12 = [ 36 | "tensorrt_cu12_libs==10.0.1", 37 | "tensorrt_cu12_bindings==10.0.1", 38 | ] 39 | extras = { 40 | "tensorrt_cu11": deps_tensorrt + deps_tensorrt_cu11, 41 | "tensorrt_cu12": deps_tensorrt + deps_tensorrt_cu12, 42 | } 43 | 44 | 45 | if __name__ == "__main__": 46 | setup( 47 | name="Live2Diff", 48 | version="0.1", 49 | description="real-time interactive video translation pipeline", 50 | long_description=open("README.md", "r", encoding="utf-8").read(), 51 | long_description_content_type="text/markdown", 52 | keywords="deep learning diffusion pytorch stable diffusion streamdiffusion real-time next-frame prediction", 53 | license="Apache 2.0 License", 54 | author="leo", 55 | author_email="xingzhening@pjlab.org.cn", 56 | url="https://github.com/open-mmlab/Live2Diff", 57 | package_dir={"": "live2diff"}, 58 | packages=find_packages("live2diff"), 59 | python_requires=">=3.10.0", 60 | install_requires=deps, 61 | extras_require=extras, 62 | ) 63 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List, Literal, Optional 3 | 4 | import fire 5 | import numpy as np 6 | import torch 7 | from decord import VideoReader 8 | from PIL import Image 9 | from torchvision import transforms 10 | from torchvision.io import write_video 11 | from tqdm import tqdm 12 | 13 | from live2diff.utils.config import load_config 14 | from live2diff.utils.io import read_video_frames, save_videos_grid 15 | from live2diff.utils.wrapper import StreamAnimateDiffusionDepthWrapper 16 | 17 | 18 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 19 | 20 | 21 | def main( 22 | input: str, 23 | config_path: str, 24 | prompt: Optional[str] = None, 25 | prompt_template: Optional[str] = None, 26 | output: str = os.path.join("outputs", "output.mp4"), 27 | dreambooth_path: Optional[str] = None, 28 | lora_dict: Optional[Dict[str, float]] = None, 29 | height: int = 512, 30 | width: int = 512, 31 | max_frames: int = -1, 32 | num_inference_steps: Optional[int] = None, 33 | t_index_list: Optional[List[int]] = None, 34 | strength: Optional[float] = None, 35 | acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", 36 | enable_similar_image_filter: bool = False, 37 | few_step_model_type: str = "lcm", 38 | enable_tiny_vae: bool = True, 39 | fps: int = 16, 40 | save_input: bool = True, 41 | seed: int = 42, 42 | ): 43 | """ 44 | Process for generating images based on a prompt using a specified model. 
45 | 
46 |     Parameters
47 |     ----------
48 |     input : str
49 |         The path to the input video, or a directory of video frames to load images from.
50 |     config_path : str
51 |         The path to the config file.
52 |     prompt : str, optional
53 |         The prompt to generate images from.
54 |     prompt_template : str, optional
55 |         The template for the specific DreamBooth / LoRA. If not None, it must contain `{}`,
56 |         and the prompt used for inference will be `prompt_template.format(prompt)`.
57 |     output : str, optional
58 |         The output video path to save the results to.
59 |     dreambooth_path : str, optional
60 |         The path to the DreamBooth checkpoint to use for generation, by default None.
61 |     lora_dict : Optional[Dict[str, float]], optional
62 |         The lora_dict to load, by default None.
63 |         Keys are the LoRA names and values are the LoRA scales.
64 |         Example: `python test.py --lora_dict='{"LoRA_1" : 0.5 , "LoRA_2" : 0.7 ,...}'`
65 |     height : int, optional
66 |         The height of the image, by default 512.
67 |     width : int, optional
68 |         The width of the image, by default 512.
69 |     max_frames : int, optional
70 |         The maximum number of frames to process, by default -1 (no limit).
71 |     acceleration : Literal["none", "xformers", "tensorrt"]
72 |         The type of acceleration to use for image generation.
73 |     enable_similar_image_filter : bool, optional
74 |         Whether to enable the similar image filter or not,
75 |         by default False.
76 |     fps : int
77 |         The fps of the output video, by default 16.
78 |     save_input : bool, optional
79 |         Whether to save the input video or not, by default True.
80 |         If true, the input video will be saved as `output` + "_inp.mp4".
81 |     seed : int, optional
82 |         The seed, by default 42. If -1, a random seed is used.
83 |     """
84 | 
85 |     if os.path.isdir(input):
86 |         video = read_video_frames(input) / 255
87 |     elif input.endswith(".mp4"):
88 |         reader = VideoReader(input)
89 |         total_frames = len(reader)
90 |         frame_indices = np.arange(total_frames)
91 |         video = reader.get_batch(frame_indices).asnumpy() / 255
92 |         video = torch.from_numpy(video)
93 |     elif input.endswith(".gif"):
94 |         video_frames = []
95 |         image = Image.open(input)
96 |         for frame_idx in range(image.n_frames):
97 |             image.seek(frame_idx)
98 |             video_frames.append(np.array(image.convert("RGB")))
99 |         video = torch.from_numpy(np.array(video_frames)) / 255
100 | 
101 |     video = video[2:]
102 | 
103 |     height = int(height // 8 * 8)
104 |     width = int(width // 8 * 8)
105 | 
106 |     trans = transforms.Compose(
107 |         [
108 |             transforms.Resize(min(height, width), antialias=True),
109 |             transforms.CenterCrop((height, width)),
110 |         ]
111 |     )
112 |     video = trans(video.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
113 | 
114 |     if max_frames > 0:
115 |         video = video[: min(max_frames, len(video))]
116 |         print(f"Clipping video to {len(video)} frames.")
117 | 
118 |     cfg = load_config(config_path)
119 |     print("Inference Config:")
120 |     print(cfg)
121 | 
122 |     # handle prompt
123 |     cfg_prompt = cfg.get("prompt", None)
124 |     prompt = prompt or cfg_prompt
125 | 
126 |     prompt_template = prompt_template or cfg.get("prompt_template", None)
127 |     if prompt_template is not None:
128 |         assert "{}" in prompt_template, '"{}" must be contained in "prompt_template".'
129 |         prompt = prompt_template.format(prompt)
130 | 
131 |     print(f'Converted input prompt to "{prompt}".')
132 | 
133 |     # handle timesteps
134 |     num_inference_steps = num_inference_steps or cfg.get("num_inference_steps", None)
135 |     strength = strength or cfg.get("strength", None)
136 |     t_index_list = t_index_list or cfg.get("t_index_list", None)
137 | 
138 |     stream = StreamAnimateDiffusionDepthWrapper(
139 |         few_step_model_type=few_step_model_type,
140 |         config_path=config_path,
141 |         cfg_type="none",
142 |         dreambooth_path=dreambooth_path,
143 |         lora_dict=lora_dict,
144 |         strength=strength,
145 |         num_inference_steps=num_inference_steps,
146 |         t_index_list=t_index_list,
147 |         frame_buffer_size=1,
148 |         width=width,
149 |         height=height,
150 |         acceleration=acceleration,
151 |         do_add_noise=True,
152 |         output_type="pt",
153 |         enable_similar_image_filter=enable_similar_image_filter,
154 |         similar_image_filter_threshold=0.98,
155 |         use_denoising_batch=True,
156 |         use_tiny_vae=enable_tiny_vae,
157 |         seed=seed,
158 |     )
159 |     warmup_frames = video[:8].permute(0, 3, 1, 2)
160 |     warmup_results = stream.prepare(
161 |         warmup_frames=warmup_frames,
162 |         prompt=prompt,
163 |         guidance_scale=1,
164 |     )
165 |     video_result = torch.zeros(video.shape[0], height, width, 3)
166 |     warmup_results = warmup_results.cpu().float()
167 |     video_result[:8] = warmup_results
168 | 
169 |     skip_frames = stream.batch_size - 1
170 |     for i in tqdm(range(8, video.shape[0])):
171 |         output_image = stream(video[i].permute(2, 0, 1))
172 |         if i - 8 >= skip_frames:
173 |             video_result[i - skip_frames] = output_image.permute(1, 2, 0)
174 |     if skip_frames > 0:  # avoid emptying the result when skip_frames == 0
175 |         video_result = video_result[:-skip_frames]
176 | 
177 |     save_root = os.path.dirname(output)
178 |     if save_root != "":
179 |         os.makedirs(save_root, exist_ok=True)
180 |     if output.endswith(".mp4"):
181 |         video_result = video_result * 255
182 |         write_video(output, video_result, fps=fps)
183 |         if save_input:
184 |             write_video(output.replace(".mp4", "_inp.mp4"), video * 255, fps=fps)
185 |     elif output.endswith(".gif"):
186 |         save_videos_grid(
187 |             video_result.permute(3, 0, 1, 2)[None, ...],
188 |             output,
189 |             rescale=False,
190 |             fps=fps,
191 |         )
192 |         if save_input:
193 |             save_videos_grid(
194 |                 video.permute(3, 0, 1, 2)[None, ...],
195 |                 output.replace(".gif", "_inp.gif"),
196 |                 rescale=False,
197 |                 fps=fps,
198 |             )
199 |     else:
200 |         raise TypeError(f"Unsupported output format: {output}")
201 |     print("Inference time ema: ", stream.stream.inference_time_ema)
202 |     inference_time_list = np.array(stream.stream.inference_time_list)
203 |     print(f"Inference time mean & std: {inference_time_list.mean()} +/- {inference_time_list.std()}")
204 |     if hasattr(stream.stream, "depth_time_ema"):
205 |         print("Depth time ema: ", stream.stream.depth_time_ema)
206 | 
207 |     print(f'Video saved to "{output}".')
208 | 
209 | 
210 | if __name__ == "__main__":
211 |     fire.Fire(main)
212 | 
--------------------------------------------------------------------------------
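
A minimal usage sketch for the offline pipeline above, assuming the ToonYou checkpoint has already been fetched with scripts/download.sh (which expects a CivitAI token, as configured in that script) and using a placeholder input video path; the flags map directly onto the keyword arguments of main() in test.py via fire:

# Fetch the ToonYou checkpoint into ./models (see scripts/download.sh for the CivitAI token setup).
bash scripts/download.sh toonyou

# Offline video-to-video translation; the input path below is a placeholder.
python test.py \
    --input ./videos/example.mp4 \
    --config_path ./configs/toonyou.yaml \
    --acceleration xformers \
    --output ./outputs/toonyou.mp4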