├── .gitignore ├── LICENSE ├── README.md ├── StreamV2V ├── streamv2v │ ├── __init__.py │ ├── acceleration │ │ ├── __init__.py │ │ ├── sfast │ │ │ └── __init__.py │ │ └── tensorrt │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── engine.py │ │ │ ├── models.py │ │ │ └── utilities.py │ ├── image_filter.py │ ├── image_utils.py │ ├── models │ │ ├── __init__.py │ │ ├── attention_processor.py │ │ └── utils.py │ ├── pip_utils.py │ ├── pipeline.py │ └── tools │ │ ├── __init__.py │ │ └── install-tensorrt.py ├── utils │ ├── __init__.py │ ├── viewer.py │ └── wrapper.py └── vid2vid │ ├── README.md │ ├── batch_eval.py │ └── main.py ├── __init__.py ├── donate.jpg ├── lora_weights ├── face │ └── put lora model here └── style │ └── put lora model here ├── nodes.py ├── requirements.txt ├── web.png ├── web └── js │ ├── previewVideo.js │ ├── uploadAudio.js │ └── uploadVideo.js └── wechat.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.safetensors -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Stream V2V 2 | 8417 MAR 3 | UT AUSTIN RESEARCH LICENSE 4 | (NONCONFIDENTIAL SOURCE CODE) 5 | 6 | The University of Texas at Austin has developed certain software and documentation that it desires to make available without charge to anyone for academic, research, experimental or personal use. If you wish to distribute or make other use of the software, you may purchase a license to do so from The University of Texas at Austin (licensing@otc.utexas.edu). 7 | The accompanying source code is made available to you under the terms of this UT Research License (this “UTRL”). By installing or using the code, you are consenting to be bound by this UTRL. If you do not agree to the terms and conditions of this license, do not install or use any part of the code. 8 | The terms and conditions in this UTRL not only apply to the source code made available by Licensor, but also to any improvements to, or derivative works of, that source code made by you and to any object code compiled from such source code, improvements or derivative works. 9 | 10 | 1. DEFINITIONS. 11 | 1.1 “Commercial Use” shall mean use of Software or Documentation by Licensee for direct or indirect financial, commercial or strategic gain or advantage, including without limitation: (a) bundling or integrating the Software with any hardware product or another software product for transfer, sale or license to a third party (even if distributing the Software on separate media and not charging for the Software); (b) providing customers with a link to the Software or a copy of the Software for use with hardware or another software product purchased by that customer; or (c) use in connection with the performance of services for which Licensee is compensated. 12 | 1.2 “Derivative Products” means any improvements to, or other derivative works of, the Software made by Licensee, and any computer software programs, and accompanying documentation, developed by Licensee which are a modification of, enhancement to, derived from or based upon the Licensed Software or documentation provided by Licensor for the Licensed Software, and any object code compiled from such computer software programs. 13 | 1.3 “Documentation” shall mean all manuals, user documentation, and other related materials pertaining to the Software that are made available to Licensee in connection with the Software. 14 | 1.4 “Licensor” shall mean The University of Texas at Austin, on behalf of the Board of Regents of the University of Texas System, an agency of the State of Texas, whose address is 3925 W. Braker Lane, Suite 1.9A (R3500), Austin, Texas 78759. 15 | 1.5 “Licensee” or “you” shall mean the person or entity that has agreed to the terms hereof and is exercising rights granted hereunder. 16 | 1.6 “Software” shall mean the computer program(s) referred to as: “Stream V2V” (UT Tech ID 8417 MAR), which is made available under this UTRL in source code form, including any error corrections, bug fixes, patches, updates or other modifications that Licensor may in its sole discretion make available to Licensee from time to time, and any object code compiled from such source code. 17 | 18 | 2. GRANT OF RIGHTS. 19 | Subject to the terms and conditions hereunder, Licensor hereby grants to Licensee a worldwide, non-transferable, non-exclusive license to (a) install, use and reproduce the Software for academic, research, experimental and personal use (but specifically excluding Commercial Use); (b) use and modify the Software to create Derivative Products, subject to Section 3.2; (c) use the Documentation, if any, solely in connection with Licensee’s authorized use of the Software; and (d) a non-exclusive, royalty-free license for academic, research, experimental and personal use (but specifically excluding Commercial Use) to those patents, of which Diana Marculescu or Jeff Liang is a named inventor, that are licensable by Licensee and that are necessarily infringed by such authorized use of the Software, and solely in connection with Licensee’s authorized use of the Software. 20 | 21 | 3. RESTRICTIONS; COVENANTS. 22 | 3.1 Licensee may not: (a) distribute, sub-license or otherwise transfer copies or rights to the Software (or any portion thereof) or the Documentation; (b) use the Software (or any portion thereof) or Documentation for Commercial Use, or for any other use except as described in Section 2; (c) copy the Software or Documentation other than for archival and backup purposes; or (d) remove any product identification, copyright, proprietary notices or labels from the Software and Documentation. This UTRL confers no rights upon Licensee except those expressly granted herein. 23 | 3.2 Derivative Products. Licensee hereby agrees that it will provide a copy of all Derivative Products to Licensor and that its use of the Derivative Products will be subject to all of the same terms, conditions, restrictions and limitations on use imposed on the Software under this UTRL. Licensee hereby grants Licensor a worldwide, non-exclusive, royalty-free license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute Derivative Products. Licensee also hereby grants Licensor a worldwide, non-exclusive, royalty-free patent license to make, have made, use, offer to sell, sell, import and otherwise transfer the Derivative Products under those patent claims, from patents of which Diana Marculescu or Jeff Liang is a named inventor, that licensable by Licensee that are necessarily infringed by the Derivative Products. 24 | 25 | 4. CONFIDENTIALITY; PROTECTION OF SOFTWARE. 26 | 4.1 Reserved. 27 | 4.2 Proprietary Notices. Licensee shall maintain and place on any copy of Software or Documentation that it reproduces for internal use all notices as are authorized and/or required hereunder. Licensee shall include a copy of this UTRL and the following notice, on each copy of the Software and Documentation. Such license and notice shall be embedded in each copy of the Software, in the video screen display, on the physical medium embodying the Software copy and on any Documentation: 28 | Copyright © 2021, The University of Texas at Austin. All rights reserved. 29 | UNIVERSITY EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES CONCERNING THIS SOFTWARE AND DOCUMENTATION, INCLUDING ANY WARRANTIES OF MERCHANTABILITY, FITNESS FOR ANY PARTICULAR PURPOSE, NON-INFRINGEMENT AND WARRANTIES OF PERFORMANCE, AND ANY WARRANTY THAT MIGHT OTHERWISE ARISE FROM COURSE OF DEALING OR USAGE OF TRADE. NO WARRANTY IS EITHER EXPRESS OR IMPLIED WITH RESPECT TO THE USE OF THE SOFTWARE OR DOCUMENTATION. Under no circumstances shall University be liable for incidental, special, indirect, direct or consequential damages or loss of profits, interruption of business, or related expenses which may arise from use of Software or Documentation, including but not limited to those resulting from defects in Software and/or Documentation, or loss or inaccuracy of data of any kind. 30 | 31 | 5. WARRANTIES. 32 | 5.1 Disclaimer of Warranties. TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE SOFTWARE AND DOCUMENTATION ARE BEING PROVIDED ON AN “AS IS” BASIS WITHOUT ANY WARRANTIES OF ANY KIND RESPECTING THE SOFTWARE OR DOCUMENTATION, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTY OF DESIGN, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. 33 | 5.2 Limitation of Liability. UNDER NO CIRCUMSTANCES UNLESS REQUIRED BY APPLICABLE LAW SHALL LICENSOR BE LIABLE FOR INCIDENTAL, SPECIAL, INDIRECT, DIRECT OR CONSEQUENTIAL DAMAGES OR LOSS OF PROFITS, INTERRUPTION OF BUSINESS, OR RELATED EXPENSES WHICH MAY ARISE AS A RESULT OF THIS LICENSE OR OUT OF THE USE OR ATTEMPT OF USE OF SOFTWARE OR DOCUMENTATION INCLUDING BUT NOT LIMITED TO THOSE RESULTING FROM DEFECTS IN SOFTWARE AND/OR DOCUMENTATION, OR LOSS OR INACCURACY OF DATA OF ANY KIND. THE FOREGOING EXCLUSIONS AND LIMITATIONS WILL APPLY TO ALL CLAIMS AND ACTIONS OF ANY KIND, WHETHER BASED ON CONTRACT, TORT (INCLUDING, WITHOUT LIMITATION, NEGLIGENCE), OR ANY OTHER GROUNDS. 34 | 35 | 6. INDEMNIFICATION. 36 | Licensee shall indemnify, defend and hold harmless Licensor, the University of Texas System, their Regents, and their officers, agents and employees from and against any claims, demands, or causes of action whatsoever caused by, or arising out of, or resulting from, the exercise or practice of the license granted hereunder by Licensee, its officers, employees, agents or representatives. 37 | 38 | 7. TERMINATION. 39 | If Licensee breaches this UTRL, Licensee’s right to use the Software and Documentation will terminate immediately without notice, but all provisions of this UTRL except Section 2 will survive termination and continue in effect. Upon termination, Licensee must destroy all copies of the Software and Documentation. 40 | 8. GOVERNING LAW; JURISDICTION AND VENUE. 41 | 42 | The validity, interpretation, construction and performance of this UTRL shall be governed by the laws of the State of Texas. The Texas state courts of Travis County, Texas (or, if there is exclusive federal jurisdiction, the United States District Court for the Western District of Texas) shall have exclusive jurisdiction and venue over any dispute arising out of this UTRL, and Licensee consents to the jurisdiction of such courts. Application of the United Nations Convention on Contracts for the International Sale of Goods is expressly excluded. 43 | 44 | 9. EXPORT CONTROLS. 45 | This license is subject to all applicable export restrictions. Licensee must comply with all export and import laws and restrictions and regulations of any United States or foreign agency or authority relating to the Software and its use. 46 | 47 | 10. U.S. GOVERNMENT END-USERS. 48 | The Software is a “commercial item,” as that term is defined in 48 C.F.R. 2.101, consisting of “commercial computer software” and “commercial computer software documentation,” as such terms are used in 48 C.F.R. 12.212 (Sept. 1995) and 48 C.F.R. 227.7202 (June 1995). Consistent with 48 C.F.R. 12.212, 48 C.F.R. 27.405(b)(2) (June 1998) and 48 C.F.R. 227.7202, all U.S. Government End Users acquire the Software with only those rights as set forth herein. 49 | 50 | 11. MISCELLANEOUS 51 | If any provision hereof shall be held illegal, invalid or unenforceable, in whole or in part, such provision shall be modified to the minimum extent necessary to make it legal, valid and enforceable, and the legality, validity and enforceability of all other provisions of this UTRL shall not be affected thereby. Licensee may not assign this UTRL in whole or in part, without Licensor’s prior written consent. Any attempt to assign this UTRL without such consent will be null and void. This UTRL is the complete and exclusive statement between Licensee and Licensor relating to the subject matter hereof and supersedes all prior oral and written and all contemporaneous oral negotiations, commitments and understandings of the parties, if any. Any waiver by either party of any default or breach hereunder shall not constitute a waiver of any provision of this UTRL or of any subsequent default or breach of the same or a different kind. 52 | 53 | END OF LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-StreamV2V 2 | the comfyui custom node of [StreamV2V](https://github.com/Jeff-LiangF/streamv2v) 3 |
4 |
5 | webpage 6 |
7 |
8 | 9 | ## How to use 10 | make sure `ffmpeg` is worked in your commandline 11 | for Linux 12 | ``` 13 | apt update 14 | apt install ffmpeg 15 | ``` 16 | for Windows,you can install `ffmpeg` by [WingetUI](https://github.com/marticliment/WingetUI) automatically 17 | 18 | then! 19 | ``` 20 | git clone https://github.com/AIFSH/ComfyUI-StreamV2V.git 21 | cd ComfyUI-StreamV2V 22 | pip install -r requirements.txt 23 | 24 | ## insatll xformers match your torch,for torch==2.1.0+cu121 25 | pip install xformers==0.0.22.post7 26 | pip install accelerate 27 | ``` 28 | weights will be downloaded from huggingface automaticly! 29 | 30 | ## Tutorial 31 | - [Demo](https://www.bilibili.com/video/BV12m42157Us) 32 | 33 | ## WeChat Group && Donate 34 |
35 |
36 | Wechat 37 | donate 38 |
39 |
40 | 41 | ## Thanks 42 | - [StreamV2V](https://github.com/Jeff-LiangF/streamv2v) 43 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import StreamV2V 2 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/acceleration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-StreamV2V/ddc939992d873706ae54094d906ea328cb82613b/StreamV2V/streamv2v/acceleration/__init__.py -------------------------------------------------------------------------------- /StreamV2V/streamv2v/acceleration/sfast/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from sfast.compilers.stable_diffusion_pipeline_compiler import CompilationConfig, compile 4 | 5 | from ...pipeline import StreamV2V 6 | 7 | 8 | def accelerate_with_stable_fast( 9 | stream: StreamV2V, 10 | config: Optional[CompilationConfig] = None, 11 | ): 12 | if config is None: 13 | config = CompilationConfig.Default() 14 | # xformers and Triton are suggested for achieving best performance. 15 | try: 16 | import xformers 17 | 18 | config.enable_xformers = True 19 | except ImportError: 20 | print("xformers not installed, skip") 21 | try: 22 | import triton 23 | 24 | config.enable_triton = True 25 | except ImportError: 26 | print("Triton not installed, skip") 27 | # CUDA Graph is suggested for small batch sizes and small resolutions to reduce CPU overhead. 28 | config.enable_cuda_graph = True 29 | stream.pipe = compile(stream.pipe, config) 30 | stream.unet = stream.pipe.unet 31 | stream.vae = stream.pipe.vae 32 | stream.text_encoder = stream.pipe.text_encoder 33 | return stream 34 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/acceleration/tensorrt/__init__.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | 4 | import torch 5 | from diffusers import AutoencoderKL, UNet2DConditionModel 6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( 7 | retrieve_latents, 8 | ) 9 | from polygraphy import cuda 10 | 11 | from ...pipeline import StreamV2V 12 | from .builder import EngineBuilder, create_onnx_path 13 | from .engine import AutoencoderKLEngine, UNet2DConditionModelEngine 14 | from .models import VAE, BaseModel, UNet, VAEEncoder 15 | 16 | 17 | class TorchVAEEncoder(torch.nn.Module): 18 | def __init__(self, vae: AutoencoderKL): 19 | super().__init__() 20 | self.vae = vae 21 | 22 | def forward(self, x: torch.Tensor): 23 | return retrieve_latents(self.vae.encode(x)) 24 | 25 | 26 | def compile_vae_encoder( 27 | vae: TorchVAEEncoder, 28 | model_data: BaseModel, 29 | onnx_path: str, 30 | onnx_opt_path: str, 31 | engine_path: str, 32 | opt_batch_size: int = 1, 33 | engine_build_options: dict = {}, 34 | ): 35 | builder = EngineBuilder(model_data, vae, device=torch.device("cuda")) 36 | builder.build( 37 | onnx_path, 38 | onnx_opt_path, 39 | engine_path, 40 | opt_batch_size=opt_batch_size, 41 | **engine_build_options, 42 | ) 43 | 44 | 45 | def compile_vae_decoder( 46 | vae: AutoencoderKL, 47 | model_data: BaseModel, 48 | onnx_path: str, 49 | onnx_opt_path: str, 50 | engine_path: str, 51 | opt_batch_size: int = 1, 52 | engine_build_options: dict = {}, 53 | ): 54 | vae = vae.to(torch.device("cuda")) 55 | builder = EngineBuilder(model_data, vae, device=torch.device("cuda")) 56 | builder.build( 57 | onnx_path, 58 | onnx_opt_path, 59 | engine_path, 60 | opt_batch_size=opt_batch_size, 61 | **engine_build_options, 62 | ) 63 | 64 | 65 | def compile_unet( 66 | unet: UNet2DConditionModel, 67 | model_data: BaseModel, 68 | onnx_path: str, 69 | onnx_opt_path: str, 70 | engine_path: str, 71 | opt_batch_size: int = 1, 72 | engine_build_options: dict = {}, 73 | ): 74 | unet = unet.to(torch.device("cuda"), dtype=torch.float16) 75 | builder = EngineBuilder(model_data, unet, device=torch.device("cuda")) 76 | builder.build( 77 | onnx_path, 78 | onnx_opt_path, 79 | engine_path, 80 | opt_batch_size=opt_batch_size, 81 | **engine_build_options, 82 | ) 83 | 84 | 85 | def accelerate_with_tensorrt( 86 | stream: StreamV2V, 87 | engine_dir: str, 88 | max_batch_size: int = 2, 89 | min_batch_size: int = 1, 90 | use_cuda_graph: bool = False, 91 | engine_build_options: dict = {}, 92 | ): 93 | if "opt_batch_size" not in engine_build_options or engine_build_options["opt_batch_size"] is None: 94 | engine_build_options["opt_batch_size"] = max_batch_size 95 | text_encoder = stream.text_encoder 96 | unet = stream.unet 97 | vae = stream.vae 98 | 99 | del stream.unet, stream.vae, stream.pipe.unet, stream.pipe.vae 100 | 101 | vae_config = vae.config 102 | vae_dtype = vae.dtype 103 | 104 | unet.to(torch.device("cpu")) 105 | vae.to(torch.device("cpu")) 106 | 107 | gc.collect() 108 | torch.cuda.empty_cache() 109 | 110 | onnx_dir = os.path.join(engine_dir, "onnx") 111 | os.makedirs(onnx_dir, exist_ok=True) 112 | 113 | unet_engine_path = f"{engine_dir}/unet.engine" 114 | vae_encoder_engine_path = f"{engine_dir}/vae_encoder.engine" 115 | vae_decoder_engine_path = f"{engine_dir}/vae_decoder.engine" 116 | 117 | unet_model = UNet( 118 | fp16=True, 119 | device=stream.device, 120 | max_batch_size=max_batch_size, 121 | min_batch_size=min_batch_size, 122 | embedding_dim=text_encoder.config.hidden_size, 123 | unet_dim=unet.config.in_channels, 124 | ) 125 | vae_decoder_model = VAE( 126 | device=stream.device, 127 | max_batch_size=max_batch_size, 128 | min_batch_size=min_batch_size, 129 | ) 130 | vae_encoder_model = VAEEncoder( 131 | device=stream.device, 132 | max_batch_size=max_batch_size, 133 | min_batch_size=min_batch_size, 134 | ) 135 | 136 | if not os.path.exists(unet_engine_path): 137 | compile_unet( 138 | unet, 139 | unet_model, 140 | create_onnx_path("unet", onnx_dir, opt=False), 141 | create_onnx_path("unet", onnx_dir, opt=True), 142 | unet_engine_path, 143 | **engine_build_options, 144 | ) 145 | else: 146 | del unet 147 | 148 | if not os.path.exists(vae_decoder_engine_path): 149 | vae.forward = vae.decode 150 | compile_vae_decoder( 151 | vae, 152 | vae_decoder_model, 153 | create_onnx_path("vae_decoder", onnx_dir, opt=False), 154 | create_onnx_path("vae_decoder", onnx_dir, opt=True), 155 | vae_decoder_engine_path, 156 | **engine_build_options, 157 | ) 158 | 159 | if not os.path.exists(vae_encoder_engine_path): 160 | vae_encoder = TorchVAEEncoder(vae).to(torch.device("cuda")) 161 | compile_vae_encoder( 162 | vae_encoder, 163 | vae_encoder_model, 164 | create_onnx_path("vae_encoder", onnx_dir, opt=False), 165 | create_onnx_path("vae_encoder", onnx_dir, opt=True), 166 | vae_encoder_engine_path, 167 | **engine_build_options, 168 | ) 169 | 170 | del vae 171 | 172 | cuda_steram = cuda.Stream() 173 | 174 | stream.unet = UNet2DConditionModelEngine(unet_engine_path, cuda_steram, use_cuda_graph=use_cuda_graph) 175 | stream.vae = AutoencoderKLEngine( 176 | vae_encoder_engine_path, 177 | vae_decoder_engine_path, 178 | cuda_steram, 179 | stream.pipe.vae_scale_factor, 180 | use_cuda_graph=use_cuda_graph, 181 | ) 182 | setattr(stream.vae, "config", vae_config) 183 | setattr(stream.vae, "dtype", vae_dtype) 184 | 185 | gc.collect() 186 | torch.cuda.empty_cache() 187 | 188 | return stream 189 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/acceleration/tensorrt/builder.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from typing import * 4 | 5 | import torch 6 | 7 | from .models import BaseModel 8 | from .utilities import ( 9 | build_engine, 10 | export_onnx, 11 | optimize_onnx, 12 | ) 13 | 14 | 15 | def create_onnx_path(name, onnx_dir, opt=True): 16 | return os.path.join(onnx_dir, name + (".opt" if opt else "") + ".onnx") 17 | 18 | 19 | class EngineBuilder: 20 | def __init__( 21 | self, 22 | model: BaseModel, 23 | network: Any, 24 | device=torch.device("cuda"), 25 | ): 26 | self.device = device 27 | 28 | self.model = model 29 | self.network = network 30 | 31 | def build( 32 | self, 33 | onnx_path: str, 34 | onnx_opt_path: str, 35 | engine_path: str, 36 | opt_image_height: int = 512, 37 | opt_image_width: int = 512, 38 | opt_batch_size: int = 1, 39 | min_image_resolution: int = 256, 40 | max_image_resolution: int = 1024, 41 | build_enable_refit: bool = False, 42 | build_static_batch: bool = False, 43 | build_dynamic_shape: bool = False, 44 | build_all_tactics: bool = False, 45 | onnx_opset: int = 17, 46 | force_engine_build: bool = False, 47 | force_onnx_export: bool = False, 48 | force_onnx_optimize: bool = False, 49 | ): 50 | if not force_onnx_export and os.path.exists(onnx_path): 51 | print(f"Found cached model: {onnx_path}") 52 | else: 53 | print(f"Exporting model: {onnx_path}") 54 | export_onnx( 55 | self.network, 56 | onnx_path=onnx_path, 57 | model_data=self.model, 58 | opt_image_height=opt_image_height, 59 | opt_image_width=opt_image_width, 60 | opt_batch_size=opt_batch_size, 61 | onnx_opset=onnx_opset, 62 | ) 63 | del self.network 64 | gc.collect() 65 | torch.cuda.empty_cache() 66 | if not force_onnx_optimize and os.path.exists(onnx_opt_path): 67 | print(f"Found cached model: {onnx_opt_path}") 68 | else: 69 | print(f"Generating optimizing model: {onnx_opt_path}") 70 | optimize_onnx( 71 | onnx_path=onnx_path, 72 | onnx_opt_path=onnx_opt_path, 73 | model_data=self.model, 74 | ) 75 | self.model.min_latent_shape = min_image_resolution // 8 76 | self.model.max_latent_shape = max_image_resolution // 8 77 | if not force_engine_build and os.path.exists(engine_path): 78 | print(f"Found cached engine: {engine_path}") 79 | else: 80 | build_engine( 81 | engine_path=engine_path, 82 | onnx_opt_path=onnx_opt_path, 83 | model_data=self.model, 84 | opt_image_height=opt_image_height, 85 | opt_image_width=opt_image_width, 86 | opt_batch_size=opt_batch_size, 87 | build_static_batch=build_static_batch, 88 | build_dynamic_shape=build_dynamic_shape, 89 | build_all_tactics=build_all_tactics, 90 | build_enable_refit=build_enable_refit, 91 | ) 92 | 93 | gc.collect() 94 | torch.cuda.empty_cache() 95 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/acceleration/tensorrt/engine.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | from diffusers.models.autoencoder_tiny import AutoencoderTinyOutput 5 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput 6 | from diffusers.models.vae import DecoderOutput 7 | from polygraphy import cuda 8 | 9 | from .utilities import Engine 10 | 11 | 12 | class UNet2DConditionModelEngine: 13 | def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False): 14 | self.engine = Engine(filepath) 15 | self.stream = stream 16 | self.use_cuda_graph = use_cuda_graph 17 | 18 | self.engine.load() 19 | self.engine.activate() 20 | 21 | def __call__( 22 | self, 23 | latent_model_input: torch.Tensor, 24 | timestep: torch.Tensor, 25 | encoder_hidden_states: torch.Tensor, 26 | **kwargs, 27 | ) -> Any: 28 | if timestep.dtype != torch.float32: 29 | timestep = timestep.float() 30 | 31 | self.engine.allocate_buffers( 32 | shape_dict={ 33 | "sample": latent_model_input.shape, 34 | "timestep": timestep.shape, 35 | "encoder_hidden_states": encoder_hidden_states.shape, 36 | "latent": latent_model_input.shape, 37 | }, 38 | device=latent_model_input.device, 39 | ) 40 | 41 | noise_pred = self.engine.infer( 42 | { 43 | "sample": latent_model_input, 44 | "timestep": timestep, 45 | "encoder_hidden_states": encoder_hidden_states, 46 | }, 47 | self.stream, 48 | use_cuda_graph=self.use_cuda_graph, 49 | )["latent"] 50 | return UNet2DConditionOutput(sample=noise_pred) 51 | 52 | def to(self, *args, **kwargs): 53 | pass 54 | 55 | def forward(self, *args, **kwargs): 56 | pass 57 | 58 | 59 | class AutoencoderKLEngine: 60 | def __init__( 61 | self, 62 | encoder_path: str, 63 | decoder_path: str, 64 | stream: cuda.Stream, 65 | scaling_factor: int, 66 | use_cuda_graph: bool = False, 67 | ): 68 | self.encoder = Engine(encoder_path) 69 | self.decoder = Engine(decoder_path) 70 | self.stream = stream 71 | self.vae_scale_factor = scaling_factor 72 | self.use_cuda_graph = use_cuda_graph 73 | 74 | self.encoder.load() 75 | self.decoder.load() 76 | self.encoder.activate() 77 | self.decoder.activate() 78 | 79 | def encode(self, images: torch.Tensor, **kwargs): 80 | self.encoder.allocate_buffers( 81 | shape_dict={ 82 | "images": images.shape, 83 | "latent": ( 84 | images.shape[0], 85 | 4, 86 | images.shape[2] // self.vae_scale_factor, 87 | images.shape[3] // self.vae_scale_factor, 88 | ), 89 | }, 90 | device=images.device, 91 | ) 92 | latents = self.encoder.infer( 93 | {"images": images}, 94 | self.stream, 95 | use_cuda_graph=self.use_cuda_graph, 96 | )["latent"] 97 | return AutoencoderTinyOutput(latents=latents) 98 | 99 | def decode(self, latent: torch.Tensor, **kwargs): 100 | self.decoder.allocate_buffers( 101 | shape_dict={ 102 | "latent": latent.shape, 103 | "images": ( 104 | latent.shape[0], 105 | 3, 106 | latent.shape[2] * self.vae_scale_factor, 107 | latent.shape[3] * self.vae_scale_factor, 108 | ), 109 | }, 110 | device=latent.device, 111 | ) 112 | images = self.decoder.infer( 113 | {"latent": latent}, 114 | self.stream, 115 | use_cuda_graph=self.use_cuda_graph, 116 | )["images"] 117 | return DecoderOutput(sample=images) 118 | 119 | def to(self, *args, **kwargs): 120 | pass 121 | 122 | def forward(self, *args, **kwargs): 123 | pass 124 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/acceleration/tensorrt/models.py: -------------------------------------------------------------------------------- 1 | #! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/models.py 2 | 3 | # 4 | # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 5 | # SPDX-License-Identifier: Apache-2.0 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | import onnx_graphsurgeon as gs 21 | import torch 22 | from onnx import shape_inference 23 | from polygraphy.backend.onnx.loader import fold_constants 24 | 25 | 26 | class Optimizer: 27 | def __init__(self, onnx_graph, verbose=False): 28 | self.graph = gs.import_onnx(onnx_graph) 29 | self.verbose = verbose 30 | 31 | def info(self, prefix): 32 | if self.verbose: 33 | print( 34 | f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs" 35 | ) 36 | 37 | def cleanup(self, return_onnx=False): 38 | self.graph.cleanup().toposort() 39 | if return_onnx: 40 | return gs.export_onnx(self.graph) 41 | 42 | def select_outputs(self, keep, names=None): 43 | self.graph.outputs = [self.graph.outputs[o] for o in keep] 44 | if names: 45 | for i, name in enumerate(names): 46 | self.graph.outputs[i].name = name 47 | 48 | def fold_constants(self, return_onnx=False): 49 | onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) 50 | self.graph = gs.import_onnx(onnx_graph) 51 | if return_onnx: 52 | return onnx_graph 53 | 54 | def infer_shapes(self, return_onnx=False): 55 | onnx_graph = gs.export_onnx(self.graph) 56 | if onnx_graph.ByteSize() > 2147483648: 57 | raise TypeError("ERROR: model size exceeds supported 2GB limit") 58 | else: 59 | onnx_graph = shape_inference.infer_shapes(onnx_graph) 60 | 61 | self.graph = gs.import_onnx(onnx_graph) 62 | if return_onnx: 63 | return onnx_graph 64 | 65 | 66 | class BaseModel: 67 | def __init__( 68 | self, 69 | fp16=False, 70 | device="cuda", 71 | verbose=True, 72 | max_batch_size=16, 73 | min_batch_size=1, 74 | embedding_dim=768, 75 | text_maxlen=77, 76 | ): 77 | self.name = "SD Model" 78 | self.fp16 = fp16 79 | self.device = device 80 | self.verbose = verbose 81 | 82 | self.min_batch = min_batch_size 83 | self.max_batch = max_batch_size 84 | self.min_image_shape = 256 # min image resolution: 256x256 85 | self.max_image_shape = 1024 # max image resolution: 1024x1024 86 | self.min_latent_shape = self.min_image_shape // 8 87 | self.max_latent_shape = self.max_image_shape // 8 88 | 89 | self.embedding_dim = embedding_dim 90 | self.text_maxlen = text_maxlen 91 | 92 | def get_model(self): 93 | pass 94 | 95 | def get_input_names(self): 96 | pass 97 | 98 | def get_output_names(self): 99 | pass 100 | 101 | def get_dynamic_axes(self): 102 | return None 103 | 104 | def get_sample_input(self, batch_size, image_height, image_width): 105 | pass 106 | 107 | def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): 108 | return None 109 | 110 | def get_shape_dict(self, batch_size, image_height, image_width): 111 | return None 112 | 113 | def optimize(self, onnx_graph): 114 | opt = Optimizer(onnx_graph, verbose=self.verbose) 115 | opt.info(self.name + ": original") 116 | opt.cleanup() 117 | opt.info(self.name + ": cleanup") 118 | opt.fold_constants() 119 | opt.info(self.name + ": fold constants") 120 | opt.infer_shapes() 121 | opt.info(self.name + ": shape inference") 122 | onnx_opt_graph = opt.cleanup(return_onnx=True) 123 | opt.info(self.name + ": finished") 124 | return onnx_opt_graph 125 | 126 | def check_dims(self, batch_size, image_height, image_width): 127 | assert batch_size >= self.min_batch and batch_size <= self.max_batch 128 | assert image_height % 8 == 0 or image_width % 8 == 0 129 | latent_height = image_height // 8 130 | latent_width = image_width // 8 131 | assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape 132 | assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape 133 | return (latent_height, latent_width) 134 | 135 | def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape): 136 | min_batch = batch_size if static_batch else self.min_batch 137 | max_batch = batch_size if static_batch else self.max_batch 138 | latent_height = image_height // 8 139 | latent_width = image_width // 8 140 | min_image_height = image_height if static_shape else self.min_image_shape 141 | max_image_height = image_height if static_shape else self.max_image_shape 142 | min_image_width = image_width if static_shape else self.min_image_shape 143 | max_image_width = image_width if static_shape else self.max_image_shape 144 | min_latent_height = latent_height if static_shape else self.min_latent_shape 145 | max_latent_height = latent_height if static_shape else self.max_latent_shape 146 | min_latent_width = latent_width if static_shape else self.min_latent_shape 147 | max_latent_width = latent_width if static_shape else self.max_latent_shape 148 | return ( 149 | min_batch, 150 | max_batch, 151 | min_image_height, 152 | max_image_height, 153 | min_image_width, 154 | max_image_width, 155 | min_latent_height, 156 | max_latent_height, 157 | min_latent_width, 158 | max_latent_width, 159 | ) 160 | 161 | 162 | class CLIP(BaseModel): 163 | def __init__(self, device, max_batch_size, embedding_dim, min_batch_size=1): 164 | super(CLIP, self).__init__( 165 | device=device, 166 | max_batch_size=max_batch_size, 167 | min_batch_size=min_batch_size, 168 | embedding_dim=embedding_dim, 169 | ) 170 | self.name = "CLIP" 171 | 172 | def get_input_names(self): 173 | return ["input_ids"] 174 | 175 | def get_output_names(self): 176 | return ["text_embeddings", "pooler_output"] 177 | 178 | def get_dynamic_axes(self): 179 | return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} 180 | 181 | def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): 182 | self.check_dims(batch_size, image_height, image_width) 183 | min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( 184 | batch_size, image_height, image_width, static_batch, static_shape 185 | ) 186 | return { 187 | "input_ids": [ 188 | (min_batch, self.text_maxlen), 189 | (batch_size, self.text_maxlen), 190 | (max_batch, self.text_maxlen), 191 | ] 192 | } 193 | 194 | def get_shape_dict(self, batch_size, image_height, image_width): 195 | self.check_dims(batch_size, image_height, image_width) 196 | return { 197 | "input_ids": (batch_size, self.text_maxlen), 198 | "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), 199 | } 200 | 201 | def get_sample_input(self, batch_size, image_height, image_width): 202 | self.check_dims(batch_size, image_height, image_width) 203 | return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) 204 | 205 | def optimize(self, onnx_graph): 206 | opt = Optimizer(onnx_graph) 207 | opt.info(self.name + ": original") 208 | opt.select_outputs([0]) # delete graph output#1 209 | opt.cleanup() 210 | opt.info(self.name + ": remove output[1]") 211 | opt.fold_constants() 212 | opt.info(self.name + ": fold constants") 213 | opt.infer_shapes() 214 | opt.info(self.name + ": shape inference") 215 | opt.select_outputs([0], names=["text_embeddings"]) # rename network output 216 | opt.info(self.name + ": remove output[0]") 217 | opt_onnx_graph = opt.cleanup(return_onnx=True) 218 | opt.info(self.name + ": finished") 219 | return opt_onnx_graph 220 | 221 | 222 | class UNet(BaseModel): 223 | def __init__( 224 | self, 225 | fp16=False, 226 | device="cuda", 227 | max_batch_size=16, 228 | min_batch_size=1, 229 | embedding_dim=768, 230 | text_maxlen=77, 231 | unet_dim=4, 232 | ): 233 | super(UNet, self).__init__( 234 | fp16=fp16, 235 | device=device, 236 | max_batch_size=max_batch_size, 237 | min_batch_size=min_batch_size, 238 | embedding_dim=embedding_dim, 239 | text_maxlen=text_maxlen, 240 | ) 241 | self.unet_dim = unet_dim 242 | self.name = "UNet" 243 | 244 | def get_input_names(self): 245 | return ["sample", "timestep", "encoder_hidden_states"] 246 | 247 | def get_output_names(self): 248 | return ["latent"] 249 | 250 | def get_dynamic_axes(self): 251 | return { 252 | "sample": {0: "2B", 2: "H", 3: "W"}, 253 | "timestep": {0: "2B"}, 254 | "encoder_hidden_states": {0: "2B"}, 255 | "latent": {0: "2B", 2: "H", 3: "W"}, 256 | } 257 | 258 | def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): 259 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 260 | ( 261 | min_batch, 262 | max_batch, 263 | _, 264 | _, 265 | _, 266 | _, 267 | min_latent_height, 268 | max_latent_height, 269 | min_latent_width, 270 | max_latent_width, 271 | ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) 272 | return { 273 | "sample": [ 274 | (min_batch, self.unet_dim, min_latent_height, min_latent_width), 275 | (batch_size, self.unet_dim, latent_height, latent_width), 276 | (max_batch, self.unet_dim, max_latent_height, max_latent_width), 277 | ], 278 | "timestep": [(min_batch,), (batch_size,), (max_batch,)], 279 | "encoder_hidden_states": [ 280 | (min_batch, self.text_maxlen, self.embedding_dim), 281 | (batch_size, self.text_maxlen, self.embedding_dim), 282 | (max_batch, self.text_maxlen, self.embedding_dim), 283 | ], 284 | } 285 | 286 | def get_shape_dict(self, batch_size, image_height, image_width): 287 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 288 | return { 289 | "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), 290 | "timestep": (2 * batch_size,), 291 | "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), 292 | "latent": (2 * batch_size, 4, latent_height, latent_width), 293 | } 294 | 295 | def get_sample_input(self, batch_size, image_height, image_width): 296 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 297 | dtype = torch.float16 if self.fp16 else torch.float32 298 | return ( 299 | torch.randn( 300 | 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device 301 | ), 302 | torch.ones((2 * batch_size,), dtype=torch.float32, device=self.device), 303 | torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), 304 | ) 305 | 306 | 307 | class VAE(BaseModel): 308 | def __init__(self, device, max_batch_size, min_batch_size=1): 309 | super(VAE, self).__init__( 310 | device=device, 311 | max_batch_size=max_batch_size, 312 | min_batch_size=min_batch_size, 313 | embedding_dim=None, 314 | ) 315 | self.name = "VAE decoder" 316 | 317 | def get_input_names(self): 318 | return ["latent"] 319 | 320 | def get_output_names(self): 321 | return ["images"] 322 | 323 | def get_dynamic_axes(self): 324 | return { 325 | "latent": {0: "B", 2: "H", 3: "W"}, 326 | "images": {0: "B", 2: "8H", 3: "8W"}, 327 | } 328 | 329 | def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): 330 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 331 | ( 332 | min_batch, 333 | max_batch, 334 | _, 335 | _, 336 | _, 337 | _, 338 | min_latent_height, 339 | max_latent_height, 340 | min_latent_width, 341 | max_latent_width, 342 | ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) 343 | return { 344 | "latent": [ 345 | (min_batch, 4, min_latent_height, min_latent_width), 346 | (batch_size, 4, latent_height, latent_width), 347 | (max_batch, 4, max_latent_height, max_latent_width), 348 | ] 349 | } 350 | 351 | def get_shape_dict(self, batch_size, image_height, image_width): 352 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 353 | return { 354 | "latent": (batch_size, 4, latent_height, latent_width), 355 | "images": (batch_size, 3, image_height, image_width), 356 | } 357 | 358 | def get_sample_input(self, batch_size, image_height, image_width): 359 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 360 | return torch.randn( 361 | batch_size, 362 | 4, 363 | latent_height, 364 | latent_width, 365 | dtype=torch.float32, 366 | device=self.device, 367 | ) 368 | 369 | 370 | class VAEEncoder(BaseModel): 371 | def __init__(self, device, max_batch_size, min_batch_size=1): 372 | super(VAEEncoder, self).__init__( 373 | device=device, 374 | max_batch_size=max_batch_size, 375 | min_batch_size=min_batch_size, 376 | embedding_dim=None, 377 | ) 378 | self.name = "VAE encoder" 379 | 380 | def get_input_names(self): 381 | return ["images"] 382 | 383 | def get_output_names(self): 384 | return ["latent"] 385 | 386 | def get_dynamic_axes(self): 387 | return { 388 | "images": {0: "B", 2: "8H", 3: "8W"}, 389 | "latent": {0: "B", 2: "H", 3: "W"}, 390 | } 391 | 392 | def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): 393 | assert batch_size >= self.min_batch and batch_size <= self.max_batch 394 | min_batch = batch_size if static_batch else self.min_batch 395 | max_batch = batch_size if static_batch else self.max_batch 396 | self.check_dims(batch_size, image_height, image_width) 397 | ( 398 | min_batch, 399 | max_batch, 400 | min_image_height, 401 | max_image_height, 402 | min_image_width, 403 | max_image_width, 404 | _, 405 | _, 406 | _, 407 | _, 408 | ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) 409 | 410 | return { 411 | "images": [ 412 | (min_batch, 3, min_image_height, min_image_width), 413 | (batch_size, 3, image_height, image_width), 414 | (max_batch, 3, max_image_height, max_image_width), 415 | ], 416 | } 417 | 418 | def get_shape_dict(self, batch_size, image_height, image_width): 419 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 420 | return { 421 | "images": (batch_size, 3, image_height, image_width), 422 | "latent": (batch_size, 4, latent_height, latent_width), 423 | } 424 | 425 | def get_sample_input(self, batch_size, image_height, image_width): 426 | self.check_dims(batch_size, image_height, image_width) 427 | return torch.randn( 428 | batch_size, 429 | 3, 430 | image_height, 431 | image_width, 432 | dtype=torch.float32, 433 | device=self.device, 434 | ) 435 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/acceleration/tensorrt/utilities.py: -------------------------------------------------------------------------------- 1 | #! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py 2 | 3 | # 4 | # Copyright 2022 The HuggingFace Inc. team. 5 | # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 6 | # SPDX-License-Identifier: Apache-2.0 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | 21 | import gc 22 | from collections import OrderedDict 23 | from typing import * 24 | 25 | import numpy as np 26 | import onnx 27 | import onnx_graphsurgeon as gs 28 | import tensorrt as trt 29 | import torch 30 | from cuda import cudart 31 | from PIL import Image 32 | from polygraphy import cuda 33 | from polygraphy.backend.common import bytes_from_path 34 | from polygraphy.backend.trt import ( 35 | CreateConfig, 36 | Profile, 37 | engine_from_bytes, 38 | engine_from_network, 39 | network_from_onnx_path, 40 | save_engine, 41 | ) 42 | from polygraphy.backend.trt import util as trt_util 43 | 44 | from .models import CLIP, VAE, BaseModel, UNet, VAEEncoder 45 | 46 | 47 | TRT_LOGGER = trt.Logger(trt.Logger.ERROR) 48 | 49 | # Map of numpy dtype -> torch dtype 50 | numpy_to_torch_dtype_dict = { 51 | np.uint8: torch.uint8, 52 | np.int8: torch.int8, 53 | np.int16: torch.int16, 54 | np.int32: torch.int32, 55 | np.int64: torch.int64, 56 | np.float16: torch.float16, 57 | np.float32: torch.float32, 58 | np.float64: torch.float64, 59 | np.complex64: torch.complex64, 60 | np.complex128: torch.complex128, 61 | } 62 | if np.version.full_version >= "1.24.0": 63 | numpy_to_torch_dtype_dict[np.bool_] = torch.bool 64 | else: 65 | numpy_to_torch_dtype_dict[np.bool] = torch.bool 66 | 67 | # Map of torch dtype -> numpy dtype 68 | torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} 69 | 70 | 71 | def CUASSERT(cuda_ret): 72 | err = cuda_ret[0] 73 | if err != cudart.cudaError_t.cudaSuccess: 74 | raise RuntimeError( 75 | f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" 76 | ) 77 | if len(cuda_ret) > 1: 78 | return cuda_ret[1] 79 | return None 80 | 81 | 82 | class Engine: 83 | def __init__( 84 | self, 85 | engine_path, 86 | ): 87 | self.engine_path = engine_path 88 | self.engine = None 89 | self.context = None 90 | self.buffers = OrderedDict() 91 | self.tensors = OrderedDict() 92 | self.cuda_graph_instance = None # cuda graph 93 | 94 | def __del__(self): 95 | [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)] 96 | del self.engine 97 | del self.context 98 | del self.buffers 99 | del self.tensors 100 | 101 | def refit(self, onnx_path, onnx_refit_path): 102 | def convert_int64(arr): 103 | # TODO: smarter conversion 104 | if len(arr.shape) == 0: 105 | return np.int32(arr) 106 | return arr 107 | 108 | def add_to_map(refit_dict, name, values): 109 | if name in refit_dict: 110 | assert refit_dict[name] is None 111 | if values.dtype == np.int64: 112 | values = convert_int64(values) 113 | refit_dict[name] = values 114 | 115 | print(f"Refitting TensorRT engine with {onnx_refit_path} weights") 116 | refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes 117 | 118 | # Construct mapping from weight names in refit model -> original model 119 | name_map = {} 120 | for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): 121 | refit_node = refit_nodes[n] 122 | assert node.op == refit_node.op 123 | # Constant nodes in ONNX do not have inputs but have a constant output 124 | if node.op == "Constant": 125 | name_map[refit_node.outputs[0].name] = node.outputs[0].name 126 | # Handle scale and bias weights 127 | elif node.op == "Conv": 128 | if node.inputs[1].__class__ == gs.Constant: 129 | name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" 130 | if node.inputs[2].__class__ == gs.Constant: 131 | name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS" 132 | # For all other nodes: find node inputs that are initializers (gs.Constant) 133 | else: 134 | for i, inp in enumerate(node.inputs): 135 | if inp.__class__ == gs.Constant: 136 | name_map[refit_node.inputs[i].name] = inp.name 137 | 138 | def map_name(name): 139 | if name in name_map: 140 | return name_map[name] 141 | return name 142 | 143 | # Construct refit dictionary 144 | refit_dict = {} 145 | refitter = trt.Refitter(self.engine, TRT_LOGGER) 146 | all_weights = refitter.get_all() 147 | for layer_name, role in zip(all_weights[0], all_weights[1]): 148 | # for speciailized roles, use a unique name in the map: 149 | if role == trt.WeightsRole.KERNEL: 150 | name = layer_name + "_TRTKERNEL" 151 | elif role == trt.WeightsRole.BIAS: 152 | name = layer_name + "_TRTBIAS" 153 | else: 154 | name = layer_name 155 | 156 | assert name not in refit_dict, "Found duplicate layer: " + name 157 | refit_dict[name] = None 158 | 159 | for n in refit_nodes: 160 | # Constant nodes in ONNX do not have inputs but have a constant output 161 | if n.op == "Constant": 162 | name = map_name(n.outputs[0].name) 163 | print(f"Add Constant {name}\n") 164 | add_to_map(refit_dict, name, n.outputs[0].values) 165 | 166 | # Handle scale and bias weights 167 | elif n.op == "Conv": 168 | if n.inputs[1].__class__ == gs.Constant: 169 | name = map_name(n.name + "_TRTKERNEL") 170 | add_to_map(refit_dict, name, n.inputs[1].values) 171 | 172 | if n.inputs[2].__class__ == gs.Constant: 173 | name = map_name(n.name + "_TRTBIAS") 174 | add_to_map(refit_dict, name, n.inputs[2].values) 175 | 176 | # For all other nodes: find node inputs that are initializers (AKA gs.Constant) 177 | else: 178 | for inp in n.inputs: 179 | name = map_name(inp.name) 180 | if inp.__class__ == gs.Constant: 181 | add_to_map(refit_dict, name, inp.values) 182 | 183 | for layer_name, weights_role in zip(all_weights[0], all_weights[1]): 184 | if weights_role == trt.WeightsRole.KERNEL: 185 | custom_name = layer_name + "_TRTKERNEL" 186 | elif weights_role == trt.WeightsRole.BIAS: 187 | custom_name = layer_name + "_TRTBIAS" 188 | else: 189 | custom_name = layer_name 190 | 191 | # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model 192 | if layer_name.startswith("onnx::Trilu"): 193 | continue 194 | 195 | if refit_dict[custom_name] is not None: 196 | refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) 197 | else: 198 | print(f"[W] No refit weights for layer: {layer_name}") 199 | 200 | if not refitter.refit_cuda_engine(): 201 | print("Failed to refit!") 202 | exit(0) 203 | 204 | def build( 205 | self, 206 | onnx_path, 207 | fp16, 208 | input_profile=None, 209 | enable_refit=False, 210 | enable_all_tactics=False, 211 | timing_cache=None, 212 | workspace_size=0, 213 | ): 214 | print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") 215 | p = Profile() 216 | if input_profile: 217 | for name, dims in input_profile.items(): 218 | assert len(dims) == 3 219 | p.add(name, min=dims[0], opt=dims[1], max=dims[2]) 220 | 221 | config_kwargs = {} 222 | 223 | if workspace_size > 0: 224 | config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size} 225 | if not enable_all_tactics: 226 | config_kwargs["tactic_sources"] = [] 227 | 228 | engine = engine_from_network( 229 | network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), 230 | config=CreateConfig( 231 | fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs 232 | ), 233 | save_timing_cache=timing_cache, 234 | ) 235 | save_engine(engine, path=self.engine_path) 236 | 237 | def load(self): 238 | print(f"Loading TensorRT engine: {self.engine_path}") 239 | self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) 240 | 241 | def activate(self, reuse_device_memory=None): 242 | if reuse_device_memory: 243 | self.context = self.engine.create_execution_context_without_device_memory() 244 | self.context.device_memory = reuse_device_memory 245 | else: 246 | self.context = self.engine.create_execution_context() 247 | 248 | def allocate_buffers(self, shape_dict=None, device="cuda"): 249 | for idx in range(trt_util.get_bindings_per_profile(self.engine)): 250 | binding = self.engine[idx] 251 | if shape_dict and binding in shape_dict: 252 | shape = shape_dict[binding] 253 | else: 254 | shape = self.engine.get_binding_shape(binding) 255 | dtype = trt.nptype(self.engine.get_binding_dtype(binding)) 256 | if self.engine.binding_is_input(binding): 257 | self.context.set_binding_shape(idx, shape) 258 | tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) 259 | self.tensors[binding] = tensor 260 | 261 | def infer(self, feed_dict, stream, use_cuda_graph=False): 262 | for name, buf in feed_dict.items(): 263 | self.tensors[name].copy_(buf) 264 | 265 | for name, tensor in self.tensors.items(): 266 | self.context.set_tensor_address(name, tensor.data_ptr()) 267 | 268 | if use_cuda_graph: 269 | if self.cuda_graph_instance is not None: 270 | CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr)) 271 | CUASSERT(cudart.cudaStreamSynchronize(stream.ptr)) 272 | else: 273 | # do inference before CUDA graph capture 274 | noerror = self.context.execute_async_v3(stream.ptr) 275 | if not noerror: 276 | raise ValueError("ERROR: inference failed.") 277 | # capture cuda graph 278 | CUASSERT( 279 | cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal) 280 | ) 281 | self.context.execute_async_v3(stream.ptr) 282 | self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr)) 283 | self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0)) 284 | else: 285 | noerror = self.context.execute_async_v3(stream.ptr) 286 | if not noerror: 287 | raise ValueError("ERROR: inference failed.") 288 | 289 | return self.tensors 290 | 291 | 292 | def decode_images(images: torch.Tensor): 293 | images = ( 294 | ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() 295 | ) 296 | return [Image.fromarray(x) for x in images] 297 | 298 | 299 | def preprocess_image(image: Image.Image): 300 | w, h = image.size 301 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 302 | image = image.resize((w, h)) 303 | init_image = np.array(image).astype(np.float32) / 255.0 304 | init_image = init_image[None].transpose(0, 3, 1, 2) 305 | init_image = torch.from_numpy(init_image).contiguous() 306 | return 2.0 * init_image - 1.0 307 | 308 | 309 | def prepare_mask_and_masked_image(image: Image.Image, mask: Image.Image): 310 | if isinstance(image, Image.Image): 311 | image = np.array(image.convert("RGB")) 312 | image = image[None].transpose(0, 3, 1, 2) 313 | image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0 314 | if isinstance(mask, Image.Image): 315 | mask = np.array(mask.convert("L")) 316 | mask = mask.astype(np.float32) / 255.0 317 | mask = mask[None, None] 318 | mask[mask < 0.5] = 0 319 | mask[mask >= 0.5] = 1 320 | mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous() 321 | 322 | masked_image = image * (mask < 0.5) 323 | 324 | return mask, masked_image 325 | 326 | 327 | def create_models( 328 | model_id: str, 329 | use_auth_token: Optional[str], 330 | device: Union[str, torch.device], 331 | max_batch_size: int, 332 | unet_in_channels: int = 4, 333 | embedding_dim: int = 768, 334 | ): 335 | models = { 336 | "clip": CLIP( 337 | hf_token=use_auth_token, 338 | device=device, 339 | max_batch_size=max_batch_size, 340 | embedding_dim=embedding_dim, 341 | ), 342 | "unet": UNet( 343 | hf_token=use_auth_token, 344 | fp16=True, 345 | device=device, 346 | max_batch_size=max_batch_size, 347 | embedding_dim=embedding_dim, 348 | unet_dim=unet_in_channels, 349 | ), 350 | "vae": VAE( 351 | hf_token=use_auth_token, 352 | device=device, 353 | max_batch_size=max_batch_size, 354 | embedding_dim=embedding_dim, 355 | ), 356 | "vae_encoder": VAEEncoder( 357 | hf_token=use_auth_token, 358 | device=device, 359 | max_batch_size=max_batch_size, 360 | embedding_dim=embedding_dim, 361 | ), 362 | } 363 | return models 364 | 365 | 366 | def build_engine( 367 | engine_path: str, 368 | onnx_opt_path: str, 369 | model_data: BaseModel, 370 | opt_image_height: int, 371 | opt_image_width: int, 372 | opt_batch_size: int, 373 | build_static_batch: bool = False, 374 | build_dynamic_shape: bool = False, 375 | build_all_tactics: bool = False, 376 | build_enable_refit: bool = False, 377 | ): 378 | _, free_mem, _ = cudart.cudaMemGetInfo() 379 | GiB = 2**30 380 | if free_mem > 6 * GiB: 381 | activation_carveout = 4 * GiB 382 | max_workspace_size = free_mem - activation_carveout 383 | else: 384 | max_workspace_size = 0 385 | engine = Engine(engine_path) 386 | input_profile = model_data.get_input_profile( 387 | opt_batch_size, 388 | opt_image_height, 389 | opt_image_width, 390 | static_batch=build_static_batch, 391 | static_shape=not build_dynamic_shape, 392 | ) 393 | engine.build( 394 | onnx_opt_path, 395 | fp16=True, 396 | input_profile=input_profile, 397 | enable_refit=build_enable_refit, 398 | enable_all_tactics=build_all_tactics, 399 | workspace_size=max_workspace_size, 400 | ) 401 | 402 | return engine 403 | 404 | 405 | def export_onnx( 406 | model, 407 | onnx_path: str, 408 | model_data: BaseModel, 409 | opt_image_height: int, 410 | opt_image_width: int, 411 | opt_batch_size: int, 412 | onnx_opset: int, 413 | ): 414 | with torch.inference_mode(), torch.autocast("cuda"): 415 | inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) 416 | torch.onnx.export( 417 | model, 418 | inputs, 419 | onnx_path, 420 | export_params=True, 421 | opset_version=onnx_opset, 422 | do_constant_folding=True, 423 | input_names=model_data.get_input_names(), 424 | output_names=model_data.get_output_names(), 425 | dynamic_axes=model_data.get_dynamic_axes(), 426 | ) 427 | del model 428 | gc.collect() 429 | torch.cuda.empty_cache() 430 | 431 | 432 | def optimize_onnx( 433 | onnx_path: str, 434 | onnx_opt_path: str, 435 | model_data: BaseModel, 436 | ): 437 | onnx_opt_graph = model_data.optimize(onnx.load(onnx_path)) 438 | onnx.save(onnx_opt_graph, onnx_opt_path) 439 | del onnx_opt_graph 440 | gc.collect() 441 | torch.cuda.empty_cache() 442 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/image_filter.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import random 3 | 4 | import torch 5 | 6 | 7 | class SimilarImageFilter: 8 | def __init__(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None: 9 | self.threshold = threshold 10 | self.prev_tensor = None 11 | self.cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6) 12 | self.max_skip_frame = max_skip_frame 13 | self.skip_count = 0 14 | 15 | def __call__(self, x: torch.Tensor) -> Optional[torch.Tensor]: 16 | if self.prev_tensor is None: 17 | self.prev_tensor = x.detach().clone() 18 | return x 19 | else: 20 | cos_sim = self.cos(self.prev_tensor.reshape(-1), x.reshape(-1)).item() 21 | sample = random.uniform(0, 1) 22 | if self.threshold >= 1: 23 | skip_prob = 0 24 | else: 25 | skip_prob = max(0, 1 - (1 - cos_sim) / (1 - self.threshold)) 26 | 27 | # not skip frame 28 | if skip_prob < sample: 29 | self.prev_tensor = x.detach().clone() 30 | return x 31 | # skip frame 32 | else: 33 | if self.skip_count > self.max_skip_frame: 34 | self.skip_count = 0 35 | self.prev_tensor = x.detach().clone() 36 | return x 37 | else: 38 | self.skip_count += 1 39 | return None 40 | 41 | def set_threshold(self, threshold: float) -> None: 42 | self.threshold = threshold 43 | 44 | def set_max_skip_frame(self, max_skip_frame: float) -> None: 45 | self.max_skip_frame = max_skip_frame 46 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/image_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import PIL.Image 5 | import torch 6 | import torchvision 7 | 8 | 9 | def denormalize(images: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: 10 | """ 11 | Denormalize an image array to [0,1]. 12 | """ 13 | return (images / 2 + 0.5).clamp(0, 1) 14 | 15 | 16 | def pt_to_numpy(images: torch.Tensor) -> np.ndarray: 17 | """ 18 | Convert a PyTorch tensor to a NumPy image. 19 | """ 20 | images = images.cpu().permute(0, 2, 3, 1).float().numpy() 21 | return images 22 | 23 | 24 | def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: 25 | """ 26 | Convert a NumPy image or a batch of images to a PIL image. 27 | """ 28 | if images.ndim == 3: 29 | images = images[None, ...] 30 | images = (images * 255).round().astype("uint8") 31 | if images.shape[-1] == 1: 32 | # special case for grayscale (single channel) images 33 | pil_images = [ 34 | PIL.Image.fromarray(image.squeeze(), mode="L") for image in images 35 | ] 36 | else: 37 | pil_images = [PIL.Image.fromarray(image) for image in images] 38 | 39 | return pil_images 40 | 41 | 42 | def postprocess_image( 43 | image: torch.Tensor, 44 | output_type: str = "pil", 45 | do_denormalize: Optional[List[bool]] = None, 46 | ) -> Union[torch.Tensor, np.ndarray, PIL.Image.Image]: 47 | if not isinstance(image, torch.Tensor): 48 | raise ValueError( 49 | f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" 50 | ) 51 | 52 | if output_type == "latent": 53 | return image 54 | 55 | do_normalize_flg = True 56 | if do_denormalize is None: 57 | do_denormalize = [do_normalize_flg] * image.shape[0] 58 | 59 | image = torch.stack( 60 | [ 61 | denormalize(image[i]) if do_denormalize[i] else image[i] 62 | for i in range(image.shape[0]) 63 | ] 64 | ) 65 | 66 | if output_type == "pt": 67 | return image 68 | 69 | image = pt_to_numpy(image) 70 | 71 | if output_type == "np": 72 | return image 73 | 74 | if output_type == "pil": 75 | return numpy_to_pil(image) 76 | 77 | 78 | def process_image( 79 | image_pil: PIL.Image.Image, range: Tuple[int, int] = (-1, 1) 80 | ) -> Tuple[torch.Tensor, PIL.Image.Image]: 81 | image = torchvision.transforms.ToTensor()(image_pil) 82 | r_min, r_max = range[0], range[1] 83 | image = image * (r_max - r_min) + r_min 84 | return image[None, ...], image_pil 85 | 86 | 87 | def pil2tensor(image_pil: PIL.Image.Image) -> torch.Tensor: 88 | height = image_pil.height 89 | width = image_pil.width 90 | imgs = [] 91 | img, _ = process_image(image_pil) 92 | imgs.append(img) 93 | imgs = torch.vstack(imgs) 94 | images = torch.nn.functional.interpolate( 95 | imgs, size=(height, width), mode="bilinear" 96 | ) 97 | image_tensors = images.to(torch.float16) 98 | return image_tensors 99 | 100 | ### Optical flow utils 101 | 102 | def coords_grid(b, h, w, homogeneous=False, device=None): 103 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] 104 | 105 | stacks = [x, y] 106 | 107 | if homogeneous: 108 | ones = torch.ones_like(x) # [H, W] 109 | stacks.append(ones) 110 | 111 | grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] 112 | 113 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] 114 | 115 | if device is not None: 116 | grid = grid.to(device) 117 | 118 | return grid 119 | 120 | def flow_warp(feature, flow, mask=False, padding_mode='zeros'): 121 | b, c, h, w = feature.size() 122 | assert flow.size(1) == 2 123 | 124 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 125 | 126 | return bilinear_sample(feature, grid, padding_mode=padding_mode, 127 | return_mask=mask) 128 | 129 | def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): 130 | # img: [B, C, H, W] 131 | # sample_coords: [B, 2, H, W] in image scale 132 | if sample_coords.size(1) != 2: # [B, H, W, 2] 133 | sample_coords = sample_coords.permute(0, 3, 1, 2) 134 | 135 | b, _, h, w = sample_coords.shape 136 | 137 | # Normalize to [-1, 1] 138 | x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 139 | y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 140 | 141 | grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] 142 | 143 | img = torch.nn.functional.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) 144 | 145 | if return_mask: 146 | mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] 147 | 148 | return img, mask 149 | 150 | return img 151 | 152 | def forward_backward_consistency_check(fwd_flow, bwd_flow, 153 | alpha=0.1, 154 | beta=0.5 155 | ): 156 | # fwd_flow, bwd_flow: [B, 2, H, W] 157 | # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 158 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 159 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 160 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 161 | 162 | warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] 163 | warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] 164 | 165 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] 166 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) 167 | 168 | threshold = alpha * flow_mag + beta 169 | 170 | fwd_occ = (diff_fwd > threshold).float() # [B, H, W] 171 | bwd_occ = (diff_bwd > threshold).float() 172 | 173 | return fwd_occ, bwd_occ -------------------------------------------------------------------------------- /StreamV2V/streamv2v/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-StreamV2V/ddc939992d873706ae54094d906ea328cb82613b/StreamV2V/streamv2v/models/__init__.py -------------------------------------------------------------------------------- /StreamV2V/streamv2v/models/attention_processor.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from typing import Callable, Optional, Union 3 | from collections import deque 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from diffusers.models.attention_processor import Attention 10 | from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging 11 | from diffusers.utils.import_utils import is_xformers_available 12 | from diffusers.utils.torch_utils import maybe_allow_in_graph 13 | from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer 14 | 15 | from .utils import get_nn_feats, random_bipartite_soft_matching 16 | 17 | if is_xformers_available(): 18 | import xformers 19 | import xformers.ops 20 | else: 21 | xformers = None 22 | 23 | class CachedSTAttnProcessor2_0: 24 | r""" 25 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 26 | """ 27 | 28 | def __init__(self, name=None, use_feature_injection=False, 29 | feature_injection_strength=0.8, 30 | feature_similarity_threshold=0.98, 31 | interval=4, 32 | max_frames=1, 33 | use_tome_cache=False, 34 | tome_metric="keys", 35 | use_grid=False, 36 | tome_ratio=0.5): 37 | if not hasattr(F, "scaled_dot_product_attention"): 38 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 39 | self.name = name 40 | self.use_feature_injection = use_feature_injection 41 | self.fi_strength = feature_injection_strength 42 | self.threshold = feature_similarity_threshold 43 | self.zero_tensor = torch.tensor(0) 44 | self.frame_id = torch.tensor(0) 45 | self.interval = torch.tensor(interval) 46 | self.max_frames = max_frames 47 | self.cached_key = None 48 | self.cached_value = None 49 | self.cached_output = None 50 | self.use_tome_cache = use_tome_cache 51 | self.tome_metric = tome_metric 52 | self.use_grid = use_grid 53 | self.tome_ratio = tome_ratio 54 | 55 | def _tome_step_kvout(self, keys, values, outputs): 56 | keys = torch.cat([self.cached_key, keys], dim=1) 57 | values = torch.cat([self.cached_value, values], dim=1) 58 | outputs = torch.cat([self.cached_output, outputs], dim=1) 59 | m_kv_out, _, _= random_bipartite_soft_matching(metric=keys, use_grid=self.use_grid, ratio=self.tome_ratio) 60 | compact_keys, compact_values, compact_outputs = m_kv_out(keys, values, outputs) 61 | self.cached_key = compact_keys 62 | self.cached_value = compact_values 63 | self.cached_output = compact_outputs 64 | 65 | def __call__( 66 | self, 67 | attn: Attention, 68 | hidden_states: torch.FloatTensor, 69 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 70 | attention_mask: Optional[torch.FloatTensor] = None, 71 | temb: Optional[torch.FloatTensor] = None, 72 | scale: float = 1.0, 73 | ) -> torch.FloatTensor: 74 | residual = hidden_states 75 | if attn.spatial_norm is not None: 76 | hidden_states = attn.spatial_norm(hidden_states, temb) 77 | 78 | input_ndim = hidden_states.ndim 79 | 80 | if input_ndim == 4: 81 | batch_size, channel, height, width = hidden_states.shape 82 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 83 | 84 | batch_size, sequence_length, _ = ( 85 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 86 | ) 87 | 88 | if attention_mask is not None: 89 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 90 | # scaled_dot_product_attention expects attention_mask shape to be 91 | # (batch, heads, source_length, target_length) 92 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 93 | 94 | if attn.group_norm is not None: 95 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 96 | 97 | args = () if USE_PEFT_BACKEND else (scale,) 98 | query = attn.to_q(hidden_states, *args) 99 | 100 | is_selfattn = False 101 | if encoder_hidden_states is None: 102 | is_selfattn = True 103 | encoder_hidden_states = hidden_states 104 | elif attn.norm_cross: 105 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 106 | 107 | key = attn.to_k(encoder_hidden_states, *args) 108 | value = attn.to_v(encoder_hidden_states, *args) 109 | 110 | if is_selfattn: 111 | cached_key = key.clone() 112 | cached_value = value.clone() 113 | 114 | # Avoid if statement -> replace the dynamic graph to static graph 115 | if torch.equal(self.frame_id, self.zero_tensor): 116 | # ONNX 117 | self.cached_key = cached_key 118 | self.cached_value = cached_value 119 | 120 | key = torch.cat([key, self.cached_key], dim=1) 121 | value = torch.cat([value, self.cached_value], dim=1) 122 | 123 | inner_dim = key.shape[-1] 124 | head_dim = inner_dim // attn.heads 125 | 126 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 127 | 128 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 129 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 130 | 131 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 132 | # TODO: add support for attn.scale when we move to Torch 2.1 133 | hidden_states = F.scaled_dot_product_attention( 134 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 135 | ) 136 | 137 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 138 | hidden_states = hidden_states.to(query.dtype) 139 | 140 | # linear proj 141 | hidden_states = attn.to_out[0](hidden_states, *args) 142 | # dropout 143 | hidden_states = attn.to_out[1](hidden_states) 144 | 145 | if input_ndim == 4: 146 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 147 | 148 | if attn.residual_connection: 149 | hidden_states = hidden_states + residual 150 | 151 | hidden_states = hidden_states / attn.rescale_output_factor 152 | 153 | if is_selfattn: 154 | cached_output = hidden_states.clone() 155 | 156 | if torch.equal(self.frame_id, self.zero_tensor): 157 | self.cached_output = cached_output 158 | 159 | if self.use_feature_injection and ("up_blocks.0" in self.name or "up_blocks.1" in self.name or 'mid_block' in self.name): 160 | nn_hidden_states = get_nn_feats(hidden_states, self.cached_output, threshold=self.threshold) 161 | hidden_states = hidden_states * (1-self.fi_strength) + self.fi_strength * nn_hidden_states 162 | 163 | mod_result = torch.remainder(self.frame_id, self.interval) 164 | if torch.equal(mod_result, self.zero_tensor) and is_selfattn: 165 | self._tome_step_kvout(cached_key, cached_value, cached_output) 166 | 167 | self.frame_id = self.frame_id + 1 168 | 169 | return hidden_states 170 | 171 | 172 | 173 | class CachedSTXFormersAttnProcessor: 174 | r""" 175 | Processor for implementing memory efficient attention using xFormers. 176 | 177 | Args: 178 | attention_op (`Callable`, *optional*, defaults to `None`): 179 | The base 180 | [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to 181 | use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best 182 | operator. 183 | """ 184 | 185 | def __init__(self, attention_op: Optional[Callable] = None, name=None, 186 | use_feature_injection=False, feature_injection_strength=0.8, feature_similarity_threshold=0.98, 187 | interval=4, max_frames=4, use_tome_cache=False, tome_metric="keys", use_grid=False, tome_ratio=0.5): 188 | self.attention_op = attention_op 189 | self.name = name 190 | self.use_feature_injection = use_feature_injection 191 | self.fi_strength = feature_injection_strength 192 | self.threshold = feature_similarity_threshold 193 | self.frame_id = 0 194 | self.interval = interval 195 | self.cached_key = deque(maxlen=max_frames) 196 | self.cached_value = deque(maxlen=max_frames) 197 | self.cached_output = deque(maxlen=max_frames) 198 | self.use_tome_cache = use_tome_cache 199 | self.tome_metric = tome_metric 200 | self.use_grid = use_grid 201 | self.tome_ratio = tome_ratio 202 | 203 | def _tome_step_kvout(self, keys, values, outputs): 204 | if len(self.cached_value) == 1: 205 | keys = torch.cat(list(self.cached_key) + [keys], dim=1) 206 | values = torch.cat(list(self.cached_value) + [values], dim=1) 207 | outputs = torch.cat(list(self.cached_output) + [outputs], dim=1) 208 | m_kv_out, _, _= random_bipartite_soft_matching(metric=eval(self.tome_metric), use_grid=self.use_grid, ratio=self.tome_ratio) 209 | compact_keys, compact_values, compact_outputs = m_kv_out(keys, values, outputs) 210 | self.cached_key.append(compact_keys) 211 | self.cached_value.append(compact_values) 212 | self.cached_output.append(compact_outputs) 213 | else: 214 | self.cached_key.append(keys) 215 | self.cached_value.append(values) 216 | self.cached_output.append(outputs) 217 | 218 | def _tome_step_kv(self, keys, values): 219 | if len(self.cached_value) == 1: 220 | keys = torch.cat(list(self.cached_key) + [keys], dim=1) 221 | values = torch.cat(list(self.cached_value) + [values], dim=1) 222 | _, m_kv, _= random_bipartite_soft_matching(metric=eval(self.tome_metric), use_grid=self.use_grid, ratio=self.tome_ratio) 223 | compact_keys, compact_values = m_kv(keys, values) 224 | self.cached_key.append(compact_keys) 225 | self.cached_value.append(compact_values) 226 | else: 227 | self.cached_key.append(keys) 228 | self.cached_value.append(values) 229 | 230 | def _tome_step_out(self, outputs): 231 | if len(self.cached_value) == 1: 232 | outputs = torch.cat(list(self.cached_output) + [outputs], dim=1) 233 | _, _, m_out= random_bipartite_soft_matching(metric=outputs, use_grid=self.use_grid, ratio=self.tome_ratio) 234 | compact_outputs = m_out(outputs) 235 | self.cached_output.append(compact_outputs) 236 | else: 237 | self.cached_output.append(outputs) 238 | 239 | def __call__( 240 | self, 241 | attn: Attention, 242 | hidden_states: torch.FloatTensor, 243 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 244 | attention_mask: Optional[torch.FloatTensor] = None, 245 | temb: Optional[torch.FloatTensor] = None, 246 | scale: float = 1.0, 247 | ) -> torch.FloatTensor: 248 | residual = hidden_states 249 | 250 | args = () if USE_PEFT_BACKEND else (scale,) 251 | 252 | if attn.spatial_norm is not None: 253 | hidden_states = attn.spatial_norm(hidden_states, temb) 254 | 255 | input_ndim = hidden_states.ndim 256 | 257 | if input_ndim == 4: 258 | batch_size, channel, height, width = hidden_states.shape 259 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 260 | 261 | batch_size, key_tokens, _ = ( 262 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 263 | ) 264 | 265 | attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) 266 | if attention_mask is not None: 267 | # expand our mask's singleton query_tokens dimension: 268 | # [batch*heads, 1, key_tokens] -> 269 | # [batch*heads, query_tokens, key_tokens] 270 | # so that it can be added as a bias onto the attention scores that xformers computes: 271 | # [batch*heads, query_tokens, key_tokens] 272 | # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. 273 | _, query_tokens, _ = hidden_states.shape 274 | attention_mask = attention_mask.expand(-1, query_tokens, -1) 275 | 276 | if attn.group_norm is not None: 277 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 278 | 279 | query = attn.to_q(hidden_states, *args) 280 | 281 | is_selfattn = False 282 | if encoder_hidden_states is None: 283 | is_selfattn = True 284 | encoder_hidden_states = hidden_states 285 | elif attn.norm_cross: 286 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 287 | 288 | key = attn.to_k(encoder_hidden_states, *args) 289 | value = attn.to_v(encoder_hidden_states, *args) 290 | 291 | if is_selfattn: 292 | cached_key = key.clone() 293 | cached_value = value.clone() 294 | 295 | if len(self.cached_key) > 0: 296 | key = torch.cat([key] + list(self.cached_key), dim=1) 297 | value = torch.cat([value] + list(self.cached_value), dim=1) 298 | 299 | ## Code for storing and visualizing features 300 | # if self.frame_id % self.interval == 0: 301 | # # if "down_blocks.0" in self.name or "up_blocks.3" in self.name: 302 | # # feats = { 303 | # # "hidden_states": hidden_states.clone().cpu(), 304 | # # "query": query.clone().cpu(), 305 | # # "key": cached_key.cpu(), 306 | # # "value": cached_value.cpu(), 307 | # # } 308 | # # torch.save(feats, f'./outputs/self_attn_feats_SD/{self.name}.frame{self.frame_id}.pt') 309 | # if self.use_tome_cache: 310 | # cached_key, cached_value = self._tome_step(cached_key, cached_value) 311 | 312 | query = attn.head_to_batch_dim(query).contiguous() 313 | key = attn.head_to_batch_dim(key).contiguous() 314 | value = attn.head_to_batch_dim(value).contiguous() 315 | 316 | hidden_states = xformers.ops.memory_efficient_attention( 317 | query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale 318 | ) 319 | hidden_states = hidden_states.to(query.dtype) 320 | hidden_states = attn.batch_to_head_dim(hidden_states) 321 | 322 | # linear proj 323 | hidden_states = attn.to_out[0](hidden_states, *args) 324 | # dropout 325 | hidden_states = attn.to_out[1](hidden_states) 326 | 327 | if input_ndim == 4: 328 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 329 | 330 | if attn.residual_connection: 331 | hidden_states = hidden_states + residual 332 | 333 | hidden_states = hidden_states / attn.rescale_output_factor 334 | if is_selfattn: 335 | cached_output = hidden_states.clone() 336 | if self.use_feature_injection and ("up_blocks.0" in self.name or "up_blocks.1" in self.name or 'mid_block' in self.name): 337 | if len(self.cached_output) > 0: 338 | nn_hidden_states = get_nn_feats(hidden_states, self.cached_output, threshold=self.threshold) 339 | hidden_states = hidden_states * (1-self.fi_strength) + self.fi_strength * nn_hidden_states 340 | 341 | if self.frame_id % self.interval == 0: 342 | if is_selfattn: 343 | if self.use_tome_cache: 344 | self._tome_step_kvout(cached_key, cached_value, cached_output) 345 | else: 346 | self.cached_key.append(cached_key) 347 | self.cached_value.append(cached_value) 348 | self.cached_output.append(cached_output) 349 | self.frame_id += 1 350 | 351 | return hidden_states 352 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/models/utils.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import Tuple, Callable 3 | 4 | from einops import rearrange 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | def get_nn_feats(x, y, threshold=0.9): 9 | 10 | if type(x) is deque: 11 | x = torch.cat(list(x), dim=1) 12 | if type(y) is deque: 13 | y = torch.cat(list(y), dim=1) 14 | 15 | x_norm = F.normalize(x, p=2, dim=-1) 16 | y_norm = F.normalize(y, p=2, dim=-1) 17 | 18 | cosine_similarity = torch.matmul(x_norm, y_norm.transpose(1, 2)) 19 | 20 | max_cosine_values, nearest_neighbors_indices = torch.max(cosine_similarity, dim=-1) 21 | mask = max_cosine_values < threshold 22 | # print('mask ratio', torch.sum(mask)/x.shape[0]/x.shape[1]) 23 | indices_expanded = nearest_neighbors_indices.unsqueeze(-1).expand(-1, -1, x_norm.size(-1)) 24 | nearest_neighbor_tensor = torch.gather(y, 1, indices_expanded) 25 | selected_tensor = torch.where(mask.unsqueeze(-1), x, nearest_neighbor_tensor) 26 | 27 | return selected_tensor 28 | 29 | def get_nn_latent(x, y, threshold=0.9): 30 | 31 | assert len(x.shape) == 4 32 | _, c, h, w = x.shape 33 | x_ = rearrange(x, 'n c h w -> n (h w) c') 34 | y_ = [] 35 | for i in range(len(y)): 36 | y_.append(rearrange(y[i], 'n c h w -> n (h w) c')) 37 | y_ = torch.cat(y_, dim=1) 38 | x_norm = F.normalize(x_, p=2, dim=-1) 39 | y_norm = F.normalize(y_, p=2, dim=-1) 40 | 41 | cosine_similarity = torch.matmul(x_norm, y_norm.transpose(1, 2)) 42 | 43 | max_cosine_values, nearest_neighbors_indices = torch.max(cosine_similarity, dim=-1) 44 | mask = max_cosine_values < threshold 45 | indices_expanded = nearest_neighbors_indices.unsqueeze(-1).expand(-1, -1, x_norm.size(-1)) 46 | nearest_neighbor_tensor = torch.gather(y_, 1, indices_expanded) 47 | 48 | # Use values from x where the cosine similarity is below the threshold 49 | x_expanded = x_.expand_as(nearest_neighbor_tensor) 50 | selected_tensor = torch.where(mask.unsqueeze(-1), x_expanded, nearest_neighbor_tensor) 51 | 52 | selected_tensor = rearrange(selected_tensor, 'n (h w) c -> n c h w', h=h, w=w, c=c) 53 | 54 | return selected_tensor 55 | 56 | 57 | def random_bipartite_soft_matching( 58 | metric: torch.Tensor, use_grid: bool = False, ratio: float = 0.5 59 | ) -> Tuple[Callable, Callable]: 60 | """ 61 | Applies ToMe with the two sets as (r chosen randomly, the rest). 62 | Input size is [batch, tokens, channels]. 63 | 64 | This will reduce the number of tokens by a ratio of ratio/2. 65 | """ 66 | 67 | with torch.no_grad(): 68 | B, N, _ = metric.shape 69 | if use_grid: 70 | assert ratio == 0.5 71 | sample = torch.randint(2, size=(B, N//2, 1), device=metric.device) 72 | sample_alternate = 1 - sample 73 | grid = torch.arange(0, N, 2).view(1, N//2, 1).to(device=metric.device) 74 | grid = grid.repeat(4, 1, 1) 75 | rand_idx = torch.cat([sample + grid, sample_alternate + grid], dim = 1) 76 | else: 77 | rand_idx = torch.rand(B, N, 1, device=metric.device).argsort(dim=1) 78 | r = int(ratio * N) 79 | a_idx = rand_idx[:, :r, :] 80 | b_idx = rand_idx[:, r:, :] 81 | def split(x): 82 | C = x.shape[-1] 83 | a = x.gather(dim=1, index=a_idx.expand(B, r, C)) 84 | b = x.gather(dim=1, index=b_idx.expand(B, N - r, C)) 85 | return a, b 86 | 87 | metric = metric / metric.norm(dim=-1, keepdim=True) 88 | a, b = split(metric) 89 | scores = a @ b.transpose(-1, -2) 90 | 91 | _, dst_idx = scores.max(dim=-1) 92 | dst_idx = dst_idx[..., None] 93 | 94 | def merge_kv_out(keys: torch.Tensor, values: torch.Tensor, outputs: torch.Tensor, mode="mean") -> torch.Tensor: 95 | src_keys, dst_keys = split(keys) 96 | C_keys = src_keys.shape[-1] 97 | dst_keys = dst_keys.scatter_reduce(-2, dst_idx.expand(B, r, C_keys), src_keys, reduce=mode) 98 | 99 | src_values, dst_values = split(values) 100 | C_values = src_values.shape[-1] 101 | dst_values = dst_values.scatter_reduce(-2, dst_idx.expand(B, r, C_values), src_values, reduce=mode) 102 | 103 | src_outputs, dst_outputs = split(outputs) 104 | C_outputs = src_outputs.shape[-1] 105 | dst_outputs = dst_outputs.scatter_reduce(-2, dst_idx.expand(B, r, C_outputs), src_outputs, reduce=mode) 106 | 107 | return dst_keys, dst_values, dst_outputs 108 | 109 | def merge_kv(keys: torch.Tensor, values: torch.Tensor, mode="mean") -> torch.Tensor: 110 | src_keys, dst_keys = split(keys) 111 | C_keys = src_keys.shape[-1] 112 | dst_keys = dst_keys.scatter_reduce(-2, dst_idx.expand(B, r, C_keys), src_keys, reduce=mode) 113 | 114 | src_values, dst_values = split(values) 115 | C_values = src_values.shape[-1] 116 | dst_values = dst_values.scatter_reduce(-2, dst_idx.expand(B, r, C_values), src_values, reduce=mode) 117 | 118 | return dst_keys, dst_values 119 | 120 | def merge_out(outputs: torch.Tensor, mode="mean") -> torch.Tensor: 121 | src_outputs, dst_outputs = split(outputs) 122 | C_outputs = src_outputs.shape[-1] 123 | dst_outputs = dst_outputs.scatter_reduce(-2, dst_idx.expand(B, r, C_outputs), src_outputs, reduce=mode) 124 | 125 | return dst_outputs 126 | 127 | return merge_kv_out, merge_kv, merge_out -------------------------------------------------------------------------------- /StreamV2V/streamv2v/pip_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import importlib.util 3 | import os 4 | import subprocess 5 | import sys 6 | from typing import Dict, Optional 7 | 8 | from packaging.version import Version 9 | 10 | 11 | python = sys.executable 12 | index_url = os.environ.get("INDEX_URL", "") 13 | 14 | 15 | def version(package: str) -> Optional[Version]: 16 | try: 17 | return Version(importlib.import_module(package).__version__) 18 | except ModuleNotFoundError: 19 | return None 20 | 21 | 22 | def is_installed(package: str) -> bool: 23 | try: 24 | spec = importlib.util.find_spec(package) 25 | except ModuleNotFoundError: 26 | return False 27 | 28 | return spec is not None 29 | 30 | 31 | def run_python(command: str, env: Dict[str, str] = None) -> str: 32 | run_kwargs = { 33 | "args": f"\"{python}\" {command}", 34 | "shell": True, 35 | "env": os.environ if env is None else env, 36 | "encoding": "utf8", 37 | "errors": "ignore", 38 | } 39 | 40 | print(run_kwargs["args"]) 41 | 42 | result = subprocess.run(**run_kwargs) 43 | 44 | if result.returncode != 0: 45 | print(f"Error running command: {command}", file=sys.stderr) 46 | raise RuntimeError(f"Error running command: {command}") 47 | 48 | return result.stdout or "" 49 | 50 | 51 | def run_pip(command: str, env: Dict[str, str] = None) -> str: 52 | return run_python(f"-m pip {command}", env) 53 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/pipeline.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import time 4 | from typing import List, Optional, Union, Any, Dict, Tuple, Literal 5 | from collections import deque 6 | 7 | import numpy as np 8 | import PIL.Image 9 | import torch 10 | import torch.nn.functional as F 11 | from torchvision.models.optical_flow import raft_small 12 | 13 | from diffusers import LCMScheduler, StableDiffusionPipeline 14 | from diffusers.image_processor import VaeImageProcessor 15 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( 16 | retrieve_latents, 17 | ) 18 | from .image_utils import postprocess_image, forward_backward_consistency_check 19 | from .models.utils import get_nn_latent 20 | from .image_filter import SimilarImageFilter 21 | 22 | 23 | class StreamV2V: 24 | def __init__( 25 | self, 26 | pipe: StableDiffusionPipeline, 27 | t_index_list: List[int], 28 | torch_dtype: torch.dtype = torch.float16, 29 | width: int = 512, 30 | height: int = 512, 31 | do_add_noise: bool = True, 32 | use_denoising_batch: bool = True, 33 | frame_buffer_size: int = 1, 34 | cfg_type: Literal["none", "full", "self", "initialize"] = "self", 35 | ) -> None: 36 | self.device = pipe.device 37 | self.dtype = torch_dtype 38 | self.generator = None 39 | 40 | self.height = height 41 | self.width = width 42 | 43 | self.latent_height = int(height // pipe.vae_scale_factor) 44 | self.latent_width = int(width // pipe.vae_scale_factor) 45 | 46 | self.frame_bff_size = frame_buffer_size 47 | self.denoising_steps_num = len(t_index_list) 48 | 49 | self.cfg_type = cfg_type 50 | 51 | if use_denoising_batch: 52 | self.batch_size = self.denoising_steps_num * frame_buffer_size 53 | if self.cfg_type == "initialize": 54 | self.trt_unet_batch_size = ( 55 | self.denoising_steps_num + 1 56 | ) * self.frame_bff_size 57 | elif self.cfg_type == "full": 58 | self.trt_unet_batch_size = ( 59 | 2 * self.denoising_steps_num * self.frame_bff_size 60 | ) 61 | else: 62 | self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size 63 | else: 64 | self.trt_unet_batch_size = self.frame_bff_size 65 | self.batch_size = frame_buffer_size 66 | 67 | self.t_list = t_index_list 68 | 69 | self.do_add_noise = do_add_noise 70 | self.use_denoising_batch = use_denoising_batch 71 | 72 | self.similar_image_filter = False 73 | self.similar_filter = SimilarImageFilter() 74 | self.prev_image_tensor = None 75 | self.prev_x_t_latent = None 76 | self.prev_image_result = None 77 | 78 | self.pipe = pipe 79 | self.image_processor = VaeImageProcessor(pipe.vae_scale_factor) 80 | 81 | self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) 82 | self.text_encoder = pipe.text_encoder 83 | self.unet = pipe.unet 84 | self.vae = pipe.vae 85 | 86 | self.flow_model = raft_small(pretrained=True, progress=False).to(device=pipe.device).eval() 87 | 88 | self.cached_x_t_latent = deque(maxlen=4) 89 | 90 | self.inference_time_ema = 0 91 | 92 | def load_lcm_lora( 93 | self, 94 | pretrained_model_name_or_path_or_dict: Union[ 95 | str, Dict[str, torch.Tensor] 96 | ] = "latent-consistency/lcm-lora-sdv1-5", 97 | adapter_name: Optional[Any] = 'lcm', 98 | **kwargs, 99 | ) -> None: 100 | self.pipe.load_lora_weights( 101 | pretrained_model_name_or_path_or_dict, adapter_name, **kwargs 102 | ) 103 | 104 | def load_lora( 105 | self, 106 | pretrained_lora_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], 107 | adapter_name: Optional[Any] = None, 108 | **kwargs, 109 | ) -> None: 110 | self.pipe.load_lora_weights( 111 | pretrained_lora_model_name_or_path_or_dict, adapter_name, **kwargs 112 | ) 113 | 114 | def fuse_lora( 115 | self, 116 | fuse_unet: bool = True, 117 | fuse_text_encoder: bool = True, 118 | lora_scale: float = 1.0, 119 | safe_fusing: bool = False, 120 | ) -> None: 121 | self.pipe.fuse_lora( 122 | fuse_unet=fuse_unet, 123 | fuse_text_encoder=fuse_text_encoder, 124 | lora_scale=lora_scale, 125 | safe_fusing=safe_fusing, 126 | ) 127 | 128 | def enable_similar_image_filter(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None: 129 | self.similar_image_filter = True 130 | self.similar_filter.set_threshold(threshold) 131 | self.similar_filter.set_max_skip_frame(max_skip_frame) 132 | 133 | def disable_similar_image_filter(self) -> None: 134 | self.similar_image_filter = False 135 | 136 | @torch.no_grad() 137 | def prepare( 138 | self, 139 | prompt: str, 140 | negative_prompt: str = "", 141 | num_inference_steps: int = 50, 142 | guidance_scale: float = 1.2, 143 | delta: float = 1.0, 144 | generator: Optional[torch.Generator] = torch.Generator(), 145 | seed: int = 2, 146 | ) -> None: 147 | self.generator = generator 148 | self.generator.manual_seed(seed) 149 | # initialize x_t_latent (it can be any random tensor) 150 | if self.denoising_steps_num > 1: 151 | self.x_t_latent_buffer = torch.zeros( 152 | ( 153 | (self.denoising_steps_num - 1) * self.frame_bff_size, 154 | 4, 155 | self.latent_height, 156 | self.latent_width, 157 | ), 158 | dtype=self.dtype, 159 | device=self.device, 160 | ) 161 | else: 162 | self.x_t_latent_buffer = None 163 | 164 | if self.cfg_type == "none": 165 | self.guidance_scale = 1.0 166 | else: 167 | self.guidance_scale = guidance_scale 168 | self.delta = delta 169 | 170 | do_classifier_free_guidance = False 171 | if self.guidance_scale > 1.0: 172 | do_classifier_free_guidance = True 173 | 174 | encoder_output = self.pipe.encode_prompt( 175 | prompt=prompt, 176 | device=self.device, 177 | num_images_per_prompt=1, 178 | do_classifier_free_guidance=True, 179 | negative_prompt=negative_prompt, 180 | ) 181 | 182 | self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1) 183 | self.null_prompt_embeds = encoder_output[1] 184 | 185 | if self.use_denoising_batch and self.cfg_type == "full": 186 | uncond_prompt_embeds = encoder_output[1].repeat(self.batch_size, 1, 1) 187 | elif self.cfg_type == "initialize": 188 | uncond_prompt_embeds = encoder_output[1].repeat(self.frame_bff_size, 1, 1) 189 | 190 | if self.guidance_scale > 1.0 and ( 191 | self.cfg_type == "initialize" or self.cfg_type == "full" 192 | ): 193 | self.prompt_embeds = torch.cat( 194 | [uncond_prompt_embeds, self.prompt_embeds], dim=0 195 | ) 196 | 197 | self.scheduler.set_timesteps(num_inference_steps, self.device) 198 | self.timesteps = self.scheduler.timesteps.to(self.device) 199 | 200 | # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list 201 | self.sub_timesteps = [] 202 | for t in self.t_list: 203 | self.sub_timesteps.append(self.timesteps[t]) 204 | 205 | sub_timesteps_tensor = torch.tensor( 206 | self.sub_timesteps, dtype=torch.long, device=self.device 207 | ) 208 | self.sub_timesteps_tensor = torch.repeat_interleave( 209 | sub_timesteps_tensor, 210 | repeats=self.frame_bff_size if self.use_denoising_batch else 1, 211 | dim=0, 212 | ) 213 | 214 | self.init_noise = torch.randn( 215 | (self.batch_size, 4, self.latent_height, self.latent_width), 216 | generator=generator, 217 | ).to(device=self.device, dtype=self.dtype) 218 | 219 | self.randn_noise = self.init_noise[:1].clone() 220 | self.warp_noise = self.init_noise[:1].clone() 221 | 222 | self.stock_noise = torch.zeros_like(self.init_noise) 223 | 224 | c_skip_list = [] 225 | c_out_list = [] 226 | for timestep in self.sub_timesteps: 227 | c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete( 228 | timestep 229 | ) 230 | c_skip_list.append(c_skip) 231 | c_out_list.append(c_out) 232 | 233 | self.c_skip = ( 234 | torch.stack(c_skip_list) 235 | .view(len(self.t_list), 1, 1, 1) 236 | .to(dtype=self.dtype, device=self.device) 237 | ) 238 | self.c_out = ( 239 | torch.stack(c_out_list) 240 | .view(len(self.t_list), 1, 1, 1) 241 | .to(dtype=self.dtype, device=self.device) 242 | ) 243 | 244 | alpha_prod_t_sqrt_list = [] 245 | beta_prod_t_sqrt_list = [] 246 | for timestep in self.sub_timesteps: 247 | alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt() 248 | beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt() 249 | alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt) 250 | beta_prod_t_sqrt_list.append(beta_prod_t_sqrt) 251 | alpha_prod_t_sqrt = ( 252 | torch.stack(alpha_prod_t_sqrt_list) 253 | .view(len(self.t_list), 1, 1, 1) 254 | .to(dtype=self.dtype, device=self.device) 255 | ) 256 | beta_prod_t_sqrt = ( 257 | torch.stack(beta_prod_t_sqrt_list) 258 | .view(len(self.t_list), 1, 1, 1) 259 | .to(dtype=self.dtype, device=self.device) 260 | ) 261 | self.alpha_prod_t_sqrt = torch.repeat_interleave( 262 | alpha_prod_t_sqrt, 263 | repeats=self.frame_bff_size if self.use_denoising_batch else 1, 264 | dim=0, 265 | ) 266 | self.beta_prod_t_sqrt = torch.repeat_interleave( 267 | beta_prod_t_sqrt, 268 | repeats=self.frame_bff_size if self.use_denoising_batch else 1, 269 | dim=0, 270 | ) 271 | 272 | @torch.no_grad() 273 | def update_prompt(self, prompt: str) -> None: 274 | encoder_output = self.pipe.encode_prompt( 275 | prompt=prompt, 276 | device=self.device, 277 | num_images_per_prompt=1, 278 | do_classifier_free_guidance=False, 279 | ) 280 | self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1) 281 | 282 | def add_noise( 283 | self, 284 | original_samples: torch.Tensor, 285 | noise: torch.Tensor, 286 | t_index: int, 287 | ) -> torch.Tensor: 288 | noisy_samples = ( 289 | self.alpha_prod_t_sqrt[t_index] * original_samples 290 | + self.beta_prod_t_sqrt[t_index] * noise 291 | ) 292 | return noisy_samples 293 | 294 | def scheduler_step_batch( 295 | self, 296 | model_pred_batch: torch.Tensor, 297 | x_t_latent_batch: torch.Tensor, 298 | idx: Optional[int] = None, 299 | ) -> torch.Tensor: 300 | # TODO: use t_list to select beta_prod_t_sqrt 301 | if idx is None: 302 | F_theta = ( 303 | x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch 304 | ) / self.alpha_prod_t_sqrt 305 | denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch 306 | else: 307 | F_theta = ( 308 | x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch 309 | ) / self.alpha_prod_t_sqrt[idx] 310 | denoised_batch = ( 311 | self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch 312 | ) 313 | 314 | return denoised_batch 315 | 316 | def unet_step( 317 | self, 318 | x_t_latent: torch.Tensor, 319 | t_list: Union[torch.Tensor, list[int]], 320 | idx: Optional[int] = None, 321 | ) -> Tuple[torch.Tensor, torch.Tensor]: 322 | if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): 323 | x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0) 324 | t_list = torch.concat([t_list[0:1], t_list], dim=0) 325 | elif self.guidance_scale > 1.0 and (self.cfg_type == "full"): 326 | x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0) 327 | t_list = torch.concat([t_list, t_list], dim=0) 328 | else: 329 | x_t_latent_plus_uc = x_t_latent 330 | 331 | model_pred = self.unet( 332 | x_t_latent_plus_uc, 333 | t_list, 334 | encoder_hidden_states=self.prompt_embeds, 335 | return_dict=False, 336 | )[0] 337 | 338 | if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): 339 | noise_pred_text = model_pred[1:] 340 | self.stock_noise = torch.concat( 341 | [model_pred[0:1], self.stock_noise[1:]], dim=0 342 | ) # ここコメントアウトでself out cfg 343 | elif self.guidance_scale > 1.0 and (self.cfg_type == "full"): 344 | noise_pred_uncond, noise_pred_text = model_pred.chunk(2) 345 | else: 346 | noise_pred_text = model_pred 347 | if self.guidance_scale > 1.0 and ( 348 | self.cfg_type == "self" or self.cfg_type == "initialize" 349 | ): 350 | noise_pred_uncond = self.stock_noise * self.delta 351 | if self.guidance_scale > 1.0 and self.cfg_type != "none": 352 | model_pred = noise_pred_uncond + self.guidance_scale * ( 353 | noise_pred_text - noise_pred_uncond 354 | ) 355 | else: 356 | model_pred = noise_pred_text 357 | 358 | # compute the previous noisy sample x_t -> x_t-1 359 | if self.use_denoising_batch: 360 | denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx) 361 | if self.cfg_type == "self" or self.cfg_type == "initialize": 362 | scaled_noise = self.beta_prod_t_sqrt * self.stock_noise 363 | delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx) 364 | alpha_next = torch.concat( 365 | [ 366 | self.alpha_prod_t_sqrt[1:], 367 | torch.ones_like(self.alpha_prod_t_sqrt[0:1]), 368 | ], 369 | dim=0, 370 | ) 371 | delta_x = alpha_next * delta_x 372 | beta_next = torch.concat( 373 | [ 374 | self.beta_prod_t_sqrt[1:], 375 | torch.ones_like(self.beta_prod_t_sqrt[0:1]), 376 | ], 377 | dim=0, 378 | ) 379 | delta_x = delta_x / beta_next 380 | init_noise = torch.concat( 381 | [self.init_noise[1:], self.init_noise[0:1]], dim=0 382 | ) 383 | self.stock_noise = init_noise + delta_x 384 | 385 | else: 386 | # denoised_batch = self.scheduler.step(model_pred, t_list[0], x_t_latent).denoised 387 | denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx) 388 | 389 | return denoised_batch, model_pred 390 | 391 | 392 | def norm_noise(self, noise): 393 | # Compute mean and std of blended_noise 394 | mean = noise.mean() 395 | std = noise.std() 396 | 397 | # Normalize blended_noise to have mean=0 and std=1 398 | normalized_noise = (noise - mean) / std 399 | return normalized_noise 400 | 401 | def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor: 402 | image_tensors = image_tensors.to( 403 | device=self.device, 404 | dtype=self.vae.dtype, 405 | ) 406 | img_latent = retrieve_latents(self.vae.encode(image_tensors), self.generator) 407 | img_latent = img_latent * self.vae.config.scaling_factor 408 | x_t_latent = self.add_noise(img_latent, self.init_noise[0], 0) 409 | return x_t_latent 410 | 411 | def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor: 412 | output_latent = self.vae.decode( 413 | x_0_pred_out / self.vae.config.scaling_factor, return_dict=False 414 | )[0] 415 | return output_latent 416 | 417 | def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: 418 | prev_latent_batch = self.x_t_latent_buffer 419 | if self.use_denoising_batch: 420 | t_list = self.sub_timesteps_tensor 421 | if self.denoising_steps_num > 1: 422 | x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0) 423 | self.stock_noise = torch.cat( 424 | (self.init_noise[0:1], self.stock_noise[:-1]), dim=0 425 | ) 426 | x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list) 427 | 428 | if self.denoising_steps_num > 1: 429 | x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0) 430 | if self.do_add_noise: 431 | self.x_t_latent_buffer = ( 432 | self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] 433 | + self.beta_prod_t_sqrt[1:] * self.init_noise[1:] 434 | ) 435 | else: 436 | self.x_t_latent_buffer = ( 437 | self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] 438 | ) 439 | else: 440 | x_0_pred_out = x_0_pred_batch 441 | self.x_t_latent_buffer = None 442 | else: 443 | self.init_noise = x_t_latent 444 | for idx, t in enumerate(self.sub_timesteps_tensor): 445 | t = t.view( 446 | 1, 447 | ).repeat( 448 | self.frame_bff_size, 449 | ) 450 | x_0_pred, model_pred = self.unet_step(x_t_latent, t, idx) 451 | if idx < len(self.sub_timesteps_tensor) - 1: 452 | if self.do_add_noise: 453 | x_t_latent = self.alpha_prod_t_sqrt[ 454 | idx + 1 455 | ] * x_0_pred + self.beta_prod_t_sqrt[ 456 | idx + 1 457 | ] * torch.randn_like( 458 | x_0_pred, device=self.device, dtype=self.dtype 459 | ) 460 | else: 461 | x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred 462 | x_0_pred_out = x_0_pred 463 | return x_0_pred_out 464 | 465 | @torch.no_grad() 466 | def __call__( 467 | self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None 468 | ) -> torch.Tensor: 469 | start = torch.cuda.Event(enable_timing=True) 470 | end = torch.cuda.Event(enable_timing=True) 471 | start.record() 472 | if x is not None: 473 | x = self.image_processor.preprocess(x, self.height, self.width).to( 474 | device=self.device, dtype=self.dtype 475 | ) 476 | if self.similar_image_filter: 477 | x = self.similar_filter(x) 478 | if x is None: 479 | time.sleep(self.inference_time_ema) 480 | return self.prev_image_result 481 | x_t_latent = self.encode_image(x) 482 | else: 483 | # TODO: check the dimension of x_t_latent 484 | x_t_latent = torch.randn((1, 4, self.latent_height, self.latent_width)).to( 485 | device=self.device, dtype=self.dtype 486 | ) 487 | x_0_pred_out = self.predict_x0_batch(x_t_latent) 488 | x_output = self.decode_image(x_0_pred_out).detach().clone() 489 | 490 | self.prev_image_result = x_output 491 | end.record() 492 | torch.cuda.synchronize() 493 | inference_time = start.elapsed_time(end) / 1000 494 | self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time 495 | return x_output 496 | -------------------------------------------------------------------------------- /StreamV2V/streamv2v/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-StreamV2V/ddc939992d873706ae54094d906ea328cb82613b/StreamV2V/streamv2v/tools/__init__.py -------------------------------------------------------------------------------- /StreamV2V/streamv2v/tools/install-tensorrt.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import fire 4 | from packaging.version import Version 5 | 6 | from ..pip_utils import is_installed, run_pip, version 7 | import platform 8 | 9 | 10 | def get_cuda_version_from_torch() -> Optional[Literal["11", "12"]]: 11 | try: 12 | import torch 13 | except ImportError: 14 | return None 15 | 16 | return torch.version.cuda.split(".")[0] 17 | 18 | 19 | def install(cu: Optional[Literal["11", "12"]] = get_cuda_version_from_torch()): 20 | if cu is None or cu not in ["11", "12"]: 21 | print("Could not detect CUDA version. Please specify manually.") 22 | return 23 | print("Installing TensorRT requirements...") 24 | 25 | if is_installed("tensorrt"): 26 | if version("tensorrt") < Version("9.0.0"): 27 | run_pip("uninstall -y tensorrt") 28 | 29 | cudnn_name = f"nvidia-cudnn-cu{cu}==8.9.4.25" 30 | 31 | if not is_installed("tensorrt"): 32 | run_pip(f"install {cudnn_name} --no-cache-dir") 33 | run_pip( 34 | "install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir" 35 | ) 36 | 37 | if not is_installed("polygraphy"): 38 | run_pip( 39 | "install polygraphy==0.47.1 --extra-index-url https://pypi.ngc.nvidia.com" 40 | ) 41 | if not is_installed("onnx_graphsurgeon"): 42 | run_pip( 43 | "install onnx-graphsurgeon==0.3.26 --extra-index-url https://pypi.ngc.nvidia.com" 44 | ) 45 | # if platform.system() == 'Windows' and not is_installed("pywin32"): 46 | # run_pip( 47 | # "install pywin32" 48 | # ) 49 | 50 | pass 51 | 52 | 53 | if __name__ == "__main__": 54 | fire.Fire(install) 55 | -------------------------------------------------------------------------------- /StreamV2V/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-StreamV2V/ddc939992d873706ae54094d906ea328cb82613b/StreamV2V/utils/__init__.py -------------------------------------------------------------------------------- /StreamV2V/utils/viewer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import threading 4 | import time 5 | import tkinter as tk 6 | from multiprocessing import Queue 7 | from typing import List 8 | from PIL import Image, ImageTk 9 | from streamv2v.image_utils import postprocess_image 10 | 11 | sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) 12 | 13 | 14 | def update_image(image_data: Image.Image, label: tk.Label) -> None: 15 | """ 16 | Update the image displayed on a Tkinter label. 17 | 18 | Parameters 19 | ---------- 20 | image_data : Image.Image 21 | The image to be displayed. 22 | label : tk.Label 23 | The labels where the image will be updated. 24 | """ 25 | width = 512 26 | height = 512 27 | tk_image = ImageTk.PhotoImage(image_data, size=width) 28 | label.configure(image=tk_image, width=width, height=height) 29 | label.image = tk_image # keep a reference 30 | 31 | def _receive_images( 32 | queue: Queue, fps_queue: Queue, label: tk.Label, fps_label: tk.Label 33 | ) -> None: 34 | """ 35 | Continuously receive images from a queue and update the labels. 36 | 37 | Parameters 38 | ---------- 39 | queue : Queue 40 | The queue to receive images from. 41 | fps_queue : Queue 42 | The queue to put the calculated fps. 43 | label : tk.Label 44 | The label to update with images. 45 | fps_label : tk.Label 46 | The label to show fps. 47 | """ 48 | while True: 49 | try: 50 | if not queue.empty(): 51 | label.after( 52 | 0, 53 | update_image, 54 | postprocess_image(queue.get(block=False), output_type="pil")[0], 55 | label, 56 | ) 57 | if not fps_queue.empty(): 58 | fps_label.config(text=f"FPS: {fps_queue.get(block=False):.2f}") 59 | 60 | time.sleep(0.0005) 61 | except KeyboardInterrupt: 62 | return 63 | 64 | 65 | def receive_images(queue: Queue, fps_queue: Queue) -> None: 66 | """ 67 | Setup the Tkinter window and start the thread to receive images. 68 | 69 | Parameters 70 | ---------- 71 | queue : Queue 72 | The queue to receive images from. 73 | fps_queue : Queue 74 | The queue to put the calculated fps. 75 | """ 76 | root = tk.Tk() 77 | root.title("Image Viewer") 78 | label = tk.Label(root) 79 | fps_label = tk.Label(root, text="FPS: 0") 80 | label.grid(column=0) 81 | fps_label.grid(column=1) 82 | 83 | def on_closing(): 84 | print("window closed") 85 | root.quit() # stop event loop 86 | return 87 | 88 | thread = threading.Thread( 89 | target=_receive_images, args=(queue, fps_queue, label, fps_label), daemon=True 90 | ) 91 | thread.start() 92 | 93 | try: 94 | root.protocol("WM_DELETE_WINDOW", on_closing) 95 | root.mainloop() 96 | except KeyboardInterrupt: 97 | return 98 | 99 | -------------------------------------------------------------------------------- /StreamV2V/utils/wrapper.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from pathlib import Path 4 | import traceback 5 | from typing import List, Literal, Optional, Union, Dict 6 | 7 | import numpy as np 8 | import torch 9 | from diffusers import AutoencoderTiny, StableDiffusionPipeline 10 | from diffusers.models.attention_processor import XFormersAttnProcessor, AttnProcessor2_0 11 | from PIL import Image 12 | 13 | from streamv2v import StreamV2V 14 | from streamv2v.image_utils import postprocess_image 15 | from streamv2v.models.attention_processor import CachedSTXFormersAttnProcessor, CachedSTAttnProcessor2_0 16 | 17 | 18 | torch.set_grad_enabled(False) 19 | torch.backends.cuda.matmul.allow_tf32 = True 20 | torch.backends.cudnn.allow_tf32 = True 21 | 22 | 23 | class StreamV2VWrapper: 24 | def __init__( 25 | self, 26 | model_id_or_path: str, 27 | t_index_list: List[int], 28 | lora_dict: Optional[Dict[str, float]] = None, 29 | output_type: Literal["pil", "pt", "np", "latent"] = "pil", 30 | lcm_lora_id: Optional[str] = None, 31 | vae_id: Optional[str] = None, 32 | device: Literal["cpu", "cuda"] = "cuda", 33 | dtype: torch.dtype = torch.float16, 34 | frame_buffer_size: int = 1, 35 | width: int = 512, 36 | height: int = 512, 37 | warmup: int = 10, 38 | acceleration: Literal["none", "xformers", "tensorrt"] = "xformers", 39 | do_add_noise: bool = True, 40 | device_ids: Optional[List[int]] = None, 41 | use_lcm_lora: bool = True, 42 | use_tiny_vae: bool = True, 43 | enable_similar_image_filter: bool = False, 44 | similar_image_filter_threshold: float = 0.98, 45 | similar_image_filter_max_skip_frame: int = 10, 46 | use_denoising_batch: bool = True, 47 | cfg_type: Literal["none", "full", "self", "initialize"] = "self", 48 | use_cached_attn: bool = True, 49 | use_feature_injection: bool = True, 50 | feature_injection_strength: float = 0.8, 51 | feature_similarity_threshold: float = 0.98, 52 | cache_interval: int = 4, 53 | cache_maxframes: int = 1, 54 | use_tome_cache: bool = True, 55 | tome_metric: str = "keys", 56 | tome_ratio: float = 0.5, 57 | use_grid: bool = False, 58 | seed: int = 2, 59 | use_safety_checker: bool = False, 60 | engine_dir: Optional[Union[str, Path]] = "engines", 61 | ): 62 | """ 63 | Initializes the StreamV2VWrapper. 64 | 65 | Parameters 66 | ---------- 67 | model_id_or_path : str 68 | The model identifier or path to load. 69 | t_index_list : List[int] 70 | The list of indices to use for inference. 71 | lora_dict : Optional[Dict[str, float]], optional 72 | Dictionary of LoRA names and their corresponding scales, 73 | by default None. Example: {'LoRA_1': 0.5, 'LoRA_2': 0.7, ...} 74 | output_type : Literal["pil", "pt", "np", "latent"], optional 75 | The type of output image, by default "pil". 76 | lcm_lora_id : Optional[str], optional 77 | The identifier for the LCM-LoRA to load, by default None. 78 | If None, the default LCM-LoRA ("latent-consistency/lcm-lora-sdv1-5") is used. 79 | vae_id : Optional[str], optional 80 | The identifier for the VAE to load, by default None. 81 | If None, the default TinyVAE ("madebyollin/taesd") is used. 82 | device : Literal["cpu", "cuda"], optional 83 | The device to use for inference, by default "cuda". 84 | dtype : torch.dtype, optional 85 | The data type for inference, by default torch.float16. 86 | frame_buffer_size : int, optional 87 | The size of the frame buffer for denoising batch, by default 1. 88 | width : int, optional 89 | The width of the image, by default 512. 90 | height : int, optional 91 | The height of the image, by default 512. 92 | warmup : int, optional 93 | The number of warmup steps to perform, by default 10. 94 | acceleration : Literal["none", "xformers", "tensorrt"], optional 95 | The acceleration method, by default "xformers". 96 | do_add_noise : bool, optional 97 | Whether to add noise during denoising steps, by default True. 98 | device_ids : Optional[List[int]], optional 99 | List of device IDs to use for DataParallel, by default None. 100 | use_lcm_lora : bool, optional 101 | Whether to use LCM-LoRA, by default True. 102 | use_tiny_vae : bool, optional 103 | Whether to use TinyVAE, by default True. 104 | enable_similar_image_filter : bool, optional 105 | Whether to enable similar image filtering, by default False. 106 | similar_image_filter_threshold : float, optional 107 | The threshold for the similar image filter, by default 0.98. 108 | similar_image_filter_max_skip_frame : int, optional 109 | The maximum number of frames to skip for similar image filter, by default 10. 110 | use_denoising_batch : bool, optional 111 | Whether to use denoising batch, by default True. 112 | cfg_type : Literal["none", "full", "self", "initialize"], optional 113 | The CFG type for img2img mode, by default "self". 114 | use_cached_attn : bool, optional 115 | Whether to cache self-attention maps from previous frames to improve temporal consistency, by default True. 116 | use_feature_injection : bool, optional 117 | Whether to use feature maps from previous frames to improve temporal consistency, by default True. 118 | feature_injection_strength : float, optional 119 | The strength of feature injection, by default 0.8. 120 | feature_similarity_threshold : float, optional 121 | The similarity threshold for feature injection, by default 0.98. 122 | cache_interval : int, optional 123 | The interval at which to cache attention maps, by default 4. 124 | cache_maxframes : int, optional 125 | The maximum number of frames to cache attention maps, by default 1. 126 | use_tome_cache : bool, optional 127 | Whether to use Tome caching, by default True. 128 | tome_metric : str, optional 129 | The metric to use for Tome, by default "keys". 130 | tome_ratio : float, optional 131 | The ratio for Tome, by default 0.5. 132 | use_grid : bool, optional 133 | Whether to use grid, by default False. 134 | seed : int, optional 135 | The seed for random number generation, by default 2. 136 | use_safety_checker : bool, optional 137 | Whether to use a safety checker, by default False. 138 | engine_dir : Optional[Union[str, Path]], optional 139 | The directory for the engine, by default "engines". 140 | """ 141 | # TODO: Test SD turbo 142 | self.sd_turbo = "turbo" in model_id_or_path 143 | 144 | assert use_denoising_batch, "vid2vid mode must use denoising batch for now." 145 | 146 | self.device = device 147 | self.dtype = dtype 148 | self.width = width 149 | self.height = height 150 | self.output_type = output_type 151 | self.frame_buffer_size = frame_buffer_size 152 | self.batch_size = ( 153 | len(t_index_list) * frame_buffer_size 154 | if use_denoising_batch 155 | else frame_buffer_size 156 | ) 157 | 158 | self.use_denoising_batch = use_denoising_batch 159 | self.use_cached_attn = use_cached_attn 160 | self.use_feature_injection = use_feature_injection 161 | self.feature_injection_strength = feature_injection_strength 162 | self.feature_similarity_threshold = feature_similarity_threshold 163 | self.cache_interval = cache_interval 164 | self.cache_maxframes = cache_maxframes 165 | self.use_tome_cache = use_tome_cache 166 | self.tome_metric = tome_metric 167 | self.tome_ratio = tome_ratio 168 | self.use_grid = use_grid 169 | self.use_safety_checker = use_safety_checker 170 | 171 | self.stream: StreamV2V = self._load_model( 172 | model_id_or_path=model_id_or_path, 173 | lora_dict=lora_dict, 174 | lcm_lora_id=lcm_lora_id, 175 | vae_id=vae_id, 176 | t_index_list=t_index_list, 177 | acceleration=acceleration, 178 | warmup=warmup, 179 | do_add_noise=do_add_noise, 180 | use_lcm_lora=use_lcm_lora, 181 | use_tiny_vae=use_tiny_vae, 182 | cfg_type=cfg_type, 183 | seed=seed, 184 | engine_dir=engine_dir, 185 | ) 186 | 187 | if device_ids is not None: 188 | self.stream.unet = torch.nn.DataParallel( 189 | self.stream.unet, device_ids=device_ids 190 | ) 191 | 192 | if enable_similar_image_filter: 193 | self.stream.enable_similar_image_filter(similar_image_filter_threshold, similar_image_filter_max_skip_frame) 194 | 195 | def prepare( 196 | self, 197 | prompt: str, 198 | negative_prompt: str = "", 199 | num_inference_steps: int = 50, 200 | guidance_scale: float = 1.2, 201 | delta: float = 1.0, 202 | ) -> None: 203 | """ 204 | Prepares the model for inference. 205 | 206 | Parameters 207 | ---------- 208 | prompt : str 209 | The prompt to generate images from. 210 | num_inference_steps : int, optional 211 | The number of inference steps to perform, by default 50. 212 | guidance_scale : float, optional 213 | The guidance scale to use, by default 1.2. 214 | delta : float, optional 215 | The delta multiplier of virtual residual noise, 216 | by default 1.0. 217 | """ 218 | self.stream.prepare( 219 | prompt, 220 | negative_prompt, 221 | num_inference_steps=num_inference_steps, 222 | guidance_scale=guidance_scale, 223 | delta=delta, 224 | ) 225 | 226 | def __call__( 227 | self, 228 | image: Union[str, Image.Image, torch.Tensor], 229 | prompt: Optional[str] = None, 230 | ) -> Union[Image.Image, List[Image.Image]]: 231 | """ 232 | Performs img2img 233 | 234 | Parameters 235 | ---------- 236 | image : Optional[Union[str, Image.Image, torch.Tensor]] 237 | The image to generate from. 238 | prompt : Optional[str] 239 | The prompt to generate images from. 240 | 241 | Returns 242 | ------- 243 | Union[Image.Image, List[Image.Image]] 244 | The generated image. 245 | """ 246 | return self.img2img(image, prompt) 247 | 248 | def img2img( 249 | self, image: Union[str, Image.Image, torch.Tensor], prompt: Optional[str] = None 250 | ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]: 251 | """ 252 | Performs img2img. 253 | 254 | Parameters 255 | ---------- 256 | image : Union[str, Image.Image, torch.Tensor] 257 | The image to generate from. 258 | 259 | Returns 260 | ------- 261 | Image.Image 262 | The generated image. 263 | """ 264 | if prompt is not None: 265 | self.stream.update_prompt(prompt) 266 | 267 | if isinstance(image, str) or isinstance(image, Image.Image): 268 | image = self.preprocess_image(image) 269 | 270 | image_tensor = self.stream(image) 271 | image = self.postprocess_image(image_tensor, output_type=self.output_type) 272 | 273 | if self.use_safety_checker: 274 | safety_checker_input = self.feature_extractor( 275 | image, return_tensors="pt" 276 | ).to(self.device) 277 | _, has_nsfw_concept = self.safety_checker( 278 | images=image_tensor.to(self.dtype), 279 | clip_input=safety_checker_input.pixel_values.to(self.dtype), 280 | ) 281 | image = self.nsfw_fallback_img if has_nsfw_concept[0] else image 282 | 283 | return image 284 | 285 | def preprocess_image(self, image: Union[str, Image.Image]) -> torch.Tensor: 286 | """ 287 | Preprocesses the image. 288 | 289 | Parameters 290 | ---------- 291 | image : Union[str, Image.Image, torch.Tensor] 292 | The image to preprocess. 293 | 294 | Returns 295 | ------- 296 | torch.Tensor 297 | The preprocessed image. 298 | """ 299 | if isinstance(image, str): 300 | image = Image.open(image).convert("RGB").resize((self.width, self.height)) 301 | if isinstance(image, Image.Image): 302 | image = image.convert("RGB").resize((self.width, self.height)) 303 | 304 | return self.stream.image_processor.preprocess( 305 | image, self.height, self.width 306 | ).to(device=self.device, dtype=self.dtype) 307 | 308 | def postprocess_image( 309 | self, image_tensor: torch.Tensor, output_type: str = "pil" 310 | ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]: 311 | """ 312 | Postprocesses the image. 313 | 314 | Parameters 315 | ---------- 316 | image_tensor : torch.Tensor 317 | The image tensor to postprocess. 318 | 319 | Returns 320 | ------- 321 | Union[Image.Image, List[Image.Image]] 322 | The postprocessed image. 323 | """ 324 | if self.frame_buffer_size > 1: 325 | return postprocess_image(image_tensor.cpu(), output_type=output_type) 326 | else: 327 | return postprocess_image(image_tensor.cpu(), output_type=output_type)[0] 328 | 329 | def _load_model( 330 | self, 331 | model_id_or_path: str, 332 | t_index_list: List[int], 333 | lora_dict: Optional[Dict[str, float]] = None, 334 | lcm_lora_id: Optional[str] = None, 335 | vae_id: Optional[str] = None, 336 | acceleration: Literal["none", "xformers", "tensorrt"] = "xformers", 337 | warmup: int = 10, 338 | do_add_noise: bool = True, 339 | use_lcm_lora: bool = True, 340 | use_tiny_vae: bool = True, 341 | cfg_type: Literal["none", "full", "self", "initialize"] = "self", 342 | seed: int = 2, 343 | engine_dir: Optional[Union[str, Path]] = "engines", 344 | ) -> StreamV2V: 345 | """ 346 | Loads the model. 347 | 348 | This method does the following: 349 | 350 | 1. Loads the model from the model_id_or_path. 351 | 2. Loads and fuses the LCM-LoRA model from the lcm_lora_id if needed. 352 | 3. Loads the VAE model from the vae_id if needed. 353 | 4. Enables acceleration if needed. 354 | 5. Prepares the model for inference. 355 | 6. Load the safety checker if needed. 356 | 357 | Parameters 358 | ---------- 359 | model_id_or_path : str 360 | The model id or path to load. 361 | t_index_list : List[int] 362 | The t_index_list to use for inference. 363 | lora_dict : Optional[Dict[str, float]], optional 364 | The lora_dict to load, by default None. 365 | Keys are the LoRA names and values are the LoRA scales. 366 | Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} 367 | lcm_lora_id : Optional[str], optional 368 | The lcm_lora_id to load, by default None. 369 | vae_id : Optional[str], optional 370 | The vae_id to load, by default None. 371 | acceleration : Literal["none", "xfomers", "sfast", "tensorrt"], optional 372 | The acceleration method, by default "tensorrt". 373 | warmup : int, optional 374 | The number of warmup steps to perform, by default 10. 375 | do_add_noise : bool, optional 376 | Whether to add noise for following denoising steps or not, 377 | by default True. 378 | use_lcm_lora : bool, optional 379 | Whether to use LCM-LoRA or not, by default True. 380 | use_tiny_vae : bool, optional 381 | Whether to use TinyVAE or not, by default True. 382 | cfg_type : Literal["none", "full", "self", "initialize"], 383 | optional 384 | The cfg_type for img2img mode, by default " seed : int, optional 385 | ". 386 | seed : int, optional 387 | The seed, by default 2. 388 | 389 | Returns 390 | ------- 391 | StreamV2V 392 | The loaded model. 393 | """ 394 | 395 | try: # Load from local directory 396 | pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained( 397 | model_id_or_path, 398 | ).to(device=self.device, dtype=self.dtype) 399 | 400 | except ValueError: # Load from huggingface 401 | pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file( 402 | model_id_or_path, 403 | ).to(device=self.device, dtype=self.dtype) 404 | except Exception: # No model found 405 | traceback.print_exc() 406 | print("Model load has failed. Doesn't exist.") 407 | exit() 408 | 409 | stream = StreamV2V( 410 | pipe=pipe, 411 | t_index_list=t_index_list, 412 | torch_dtype=self.dtype, 413 | width=self.width, 414 | height=self.height, 415 | do_add_noise=do_add_noise, 416 | frame_buffer_size=self.frame_buffer_size, 417 | use_denoising_batch=self.use_denoising_batch, 418 | cfg_type=cfg_type, 419 | ) 420 | if not self.sd_turbo: 421 | if use_lcm_lora: 422 | if lcm_lora_id is not None: 423 | stream.load_lcm_lora( 424 | pretrained_model_name_or_path_or_dict=lcm_lora_id, 425 | adapter_name="lcm") 426 | else: 427 | stream.load_lcm_lora( 428 | pretrained_model_name_or_path_or_dict="latent-consistency/lcm-lora-sdv1-5", 429 | adapter_name="lcm" 430 | ) 431 | 432 | if lora_dict is not None: 433 | for lora_name, lora_scale in lora_dict.items(): 434 | stream.load_lora(lora_name) 435 | 436 | if use_tiny_vae: 437 | if vae_id is not None: 438 | stream.vae = AutoencoderTiny.from_pretrained(vae_id).to( 439 | device=pipe.device, dtype=pipe.dtype 440 | ) 441 | else: 442 | stream.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to( 443 | device=pipe.device, dtype=pipe.dtype 444 | ) 445 | 446 | try: 447 | if acceleration == "xformers": 448 | stream.pipe.enable_xformers_memory_efficient_attention() 449 | if self.use_cached_attn: 450 | attn_processors = stream.pipe.unet.attn_processors 451 | new_attn_processors = {} 452 | for key, attn_processor in attn_processors.items(): 453 | assert isinstance(attn_processor, XFormersAttnProcessor), \ 454 | "We only replace 'XFormersAttnProcessor' to 'CachedSTXFormersAttnProcessor'" 455 | new_attn_processors[key] = CachedSTXFormersAttnProcessor(name=key, 456 | use_feature_injection=self.use_feature_injection, 457 | feature_injection_strength=self.feature_injection_strength, 458 | feature_similarity_threshold=self.feature_similarity_threshold, 459 | interval=self.cache_interval, 460 | max_frames=self.cache_maxframes, 461 | use_tome_cache=self.use_tome_cache, 462 | tome_metric=self.tome_metric, 463 | tome_ratio=self.tome_ratio, 464 | use_grid=self.use_grid) 465 | stream.pipe.unet.set_attn_processor(new_attn_processors) 466 | 467 | if acceleration == "tensorrt": 468 | if self.use_cached_attn: 469 | raise NotImplementedError("TensorRT seems not support the costom attention_processor") 470 | else: 471 | stream.pipe.enable_xformers_memory_efficient_attention() 472 | if self.use_cached_attn: 473 | attn_processors = stream.pipe.unet.attn_processors 474 | new_attn_processors = {} 475 | for key, attn_processor in attn_processors.items(): 476 | assert isinstance(attn_processor, XFormersAttnProcessor), \ 477 | "We only replace 'XFormersAttnProcessor' to 'CachedSTXFormersAttnProcessor'" 478 | new_attn_processors[key] = CachedSTXFormersAttnProcessor(name=key, 479 | use_feature_injection=self.use_feature_injection, 480 | feature_injection_strength=self.feature_injection_strength, 481 | feature_similarity_threshold=self.feature_similarity_threshold, 482 | interval=self.cache_interval, 483 | max_frames=self.cache_maxframes, 484 | use_tome_cache=self.use_tome_cache, 485 | tome_metric=self.tome_metric, 486 | tome_ratio=self.tome_ratio, 487 | use_grid=self.use_grid) 488 | stream.pipe.unet.set_attn_processor(new_attn_processors) 489 | 490 | from polygraphy import cuda 491 | from streamv2v.acceleration.tensorrt import ( 492 | TorchVAEEncoder, 493 | compile_unet, 494 | compile_vae_decoder, 495 | compile_vae_encoder, 496 | ) 497 | from streamv2v.acceleration.tensorrt.engine import ( 498 | AutoencoderKLEngine, 499 | UNet2DConditionModelEngine, 500 | ) 501 | from streamv2v.acceleration.tensorrt.models import ( 502 | VAE, 503 | UNet, 504 | VAEEncoder, 505 | ) 506 | 507 | def create_prefix( 508 | model_id_or_path: str, 509 | max_batch_size: int, 510 | min_batch_size: int, 511 | ): 512 | maybe_path = Path(model_id_or_path) 513 | if maybe_path.exists(): 514 | return f"{maybe_path.stem}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--cache--{self.use_cached_attn}" 515 | else: 516 | return f"{model_id_or_path}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--cache--{self.use_cached_attn}" 517 | 518 | engine_dir = Path(engine_dir) 519 | unet_path = os.path.join( 520 | engine_dir, 521 | create_prefix( 522 | model_id_or_path=model_id_or_path, 523 | max_batch_size=stream.trt_unet_batch_size, 524 | min_batch_size=stream.trt_unet_batch_size, 525 | ), 526 | "unet.engine", 527 | ) 528 | vae_encoder_path = os.path.join( 529 | engine_dir, 530 | create_prefix( 531 | model_id_or_path=model_id_or_path, 532 | max_batch_size=stream.frame_bff_size, 533 | min_batch_size=stream.frame_bff_size, 534 | ), 535 | "vae_encoder.engine", 536 | ) 537 | vae_decoder_path = os.path.join( 538 | engine_dir, 539 | create_prefix( 540 | model_id_or_path=model_id_or_path, 541 | max_batch_size=stream.frame_bff_size, 542 | min_batch_size=stream.frame_bff_size, 543 | ), 544 | "vae_decoder.engine", 545 | ) 546 | 547 | if not os.path.exists(unet_path): 548 | os.makedirs(os.path.dirname(unet_path), exist_ok=True) 549 | unet_model = UNet( 550 | fp16=True, 551 | device=stream.device, 552 | max_batch_size=stream.trt_unet_batch_size, 553 | min_batch_size=stream.trt_unet_batch_size, 554 | embedding_dim=stream.text_encoder.config.hidden_size, 555 | unet_dim=stream.unet.config.in_channels, 556 | ) 557 | compile_unet( 558 | stream.unet, 559 | unet_model, 560 | unet_path + ".onnx", 561 | unet_path + ".opt.onnx", 562 | unet_path, 563 | opt_batch_size=stream.trt_unet_batch_size, 564 | ) 565 | 566 | if not os.path.exists(vae_decoder_path): 567 | os.makedirs(os.path.dirname(vae_decoder_path), exist_ok=True) 568 | stream.vae.forward = stream.vae.decode 569 | vae_decoder_model = VAE( 570 | device=stream.device, 571 | max_batch_size=stream.frame_bff_size, 572 | min_batch_size=stream.frame_bff_size, 573 | ) 574 | compile_vae_decoder( 575 | stream.vae, 576 | vae_decoder_model, 577 | vae_decoder_path + ".onnx", 578 | vae_decoder_path + ".opt.onnx", 579 | vae_decoder_path, 580 | opt_batch_size=stream.frame_bff_size, 581 | ) 582 | delattr(stream.vae, "forward") 583 | 584 | if not os.path.exists(vae_encoder_path): 585 | os.makedirs(os.path.dirname(vae_encoder_path), exist_ok=True) 586 | vae_encoder = TorchVAEEncoder(stream.vae).to(torch.device("cuda")) 587 | vae_encoder_model = VAEEncoder( 588 | device=stream.device, 589 | max_batch_size=stream.frame_bff_size, 590 | min_batch_size=stream.frame_bff_size, 591 | ) 592 | compile_vae_encoder( 593 | vae_encoder, 594 | vae_encoder_model, 595 | vae_encoder_path + ".onnx", 596 | vae_encoder_path + ".opt.onnx", 597 | vae_encoder_path, 598 | opt_batch_size=stream.frame_bff_size, 599 | ) 600 | 601 | cuda_steram = cuda.Stream() 602 | 603 | vae_config = stream.vae.config 604 | vae_dtype = stream.vae.dtype 605 | 606 | stream.unet = UNet2DConditionModelEngine( 607 | unet_path, cuda_steram, use_cuda_graph=False 608 | ) 609 | stream.vae = AutoencoderKLEngine( 610 | vae_encoder_path, 611 | vae_decoder_path, 612 | cuda_steram, 613 | stream.pipe.vae_scale_factor, 614 | use_cuda_graph=False, 615 | ) 616 | setattr(stream.vae, "config", vae_config) 617 | setattr(stream.vae, "dtype", vae_dtype) 618 | 619 | gc.collect() 620 | torch.cuda.empty_cache() 621 | 622 | print("TensorRT acceleration enabled.") 623 | if acceleration == "sfast": 624 | if self.use_cached_attn: 625 | raise NotImplementedError 626 | from streamv2v.acceleration.sfast import ( 627 | accelerate_with_stable_fast, 628 | ) 629 | 630 | stream = accelerate_with_stable_fast(stream) 631 | print("StableFast acceleration enabled.") 632 | except Exception: 633 | traceback.print_exc() 634 | print("Acceleration has failed. Falling back to normal mode.") 635 | 636 | if seed < 0: # Random seed 637 | seed = np.random.randint(0, 1000000) 638 | 639 | stream.prepare( 640 | "", 641 | "", 642 | num_inference_steps=50, 643 | guidance_scale=1.1 644 | if stream.cfg_type in ["full", "self", "initialize"] 645 | else 1.0, 646 | generator=torch.manual_seed(seed), 647 | seed=seed, 648 | ) 649 | 650 | if self.use_safety_checker: 651 | from transformers import CLIPFeatureExtractor 652 | from diffusers.pipelines.stable_diffusion.safety_checker import ( 653 | StableDiffusionSafetyChecker, 654 | ) 655 | 656 | self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( 657 | "CompVis/stable-diffusion-safety-checker" 658 | ).to(pipe.device) 659 | self.feature_extractor = CLIPFeatureExtractor.from_pretrained( 660 | "openai/clip-vit-base-patch32" 661 | ) 662 | self.nsfw_fallback_img = Image.new("RGB", (512, 512), (0, 0, 0)) 663 | 664 | return stream 665 | -------------------------------------------------------------------------------- /StreamV2V/vid2vid/README.md: -------------------------------------------------------------------------------- 1 | ## Get started with StreamV2V 2 | 3 | [English](./README.md) | [日本語](./README-ja.md) 4 | 5 | ### Prepartion 6 | 7 | We recommend to use [gdown](https://github.com/wkentaro/gdown) to prepare data and models. 8 | ```bash 9 | # Install gdown 10 | pip install gdown 11 | pip install --upgrade gdown 12 | ``` 13 | 14 | Download evaluation videos. 15 | 16 | ```bash 17 | cd vid2vid 18 | gdown https://drive.google.com/drive/folders/1q963FU9I4I8ml9_SeaW4jLb4kY3VkNak -O demo_selfie --folder 19 | ``` 20 | 21 | (Recommended) Download lora weights for better stylization. 22 | 23 | ```bash 24 | # Make sure you are under the directory of vid2vid 25 | gdown https://drive.google.com/drive/folders/1D7g-dnCQnjjogTPX-B3fttgdrp9nKeKw -O lora_weights --folder 26 | ``` 27 | 28 | | Trigger words | LORA weights | Source | 29 | |----------------------------------------------------------|------------------|-------------| 30 | | 'pixelart' , 'pixel art' , 'Pixel art' , 'PixArFK' | [Google drive](https://drive.google.com/file/d/1_-kEVFw_LnV1J2Nho6nZt4PUbymamypK/view?usp=drive_link) | [Civitai](https://civitai.com/models/185743/8bitdiffuser-64x-or-a-perfect-pixel-art-model) | 31 | | 'lowpoly', 'low poly', 'Low poly' | [Google drive](https://drive.google.com/file/d/1ZClfRljzKmxsU1Jj5OMwIuXQcnA1DwO9/view?usp=drive_link) | [Civitai](https://civitai.com/models/110435/y5-low-poly-style) | 32 | | 'Claymation', 'claymation' | [Google drive](https://drive.google.com/file/d/1GvPCbrPqJYj0_nRppSc2UD_1eRME-1tG/view?usp=drive_link) | [Civitai](https://civitai.com/models/25258/claymation-miniature) | 33 | | 'crayons', 'Crayons', 'crayons doodle', 'Crayons doodle' | [Google drive](https://drive.google.com/file/d/12ZMOy8CMzwB32RHSmff0h2TJC3lFDBmW/view?usp=drive_link) | [Civitai](https://civitai.com/models/90558/child-type-doodles) | 34 | | 'sketch', 'Sketch', 'pencil drawing', 'Pencil drawing' | [Google drive](https://drive.google.com/file/d/1NIBujegFMvFdjCW0vdrmD6fbNFKNROE4/view?usp=drive_link) | [Civitai](https://civitai.com/models/155490/pencil-sketch-or) | 35 | | 'oil painting', 'Oil painting' | [Google drive](https://drive.google.com/file/d/1fmS3fGeja0RM8YbZtbKw20fjXNzHrnxz/view?usp=drive_link) | [Civitai](https://civitai.com/models/84542/oil-paintingoil-brush-stroke) | 36 | 37 | ### Evaluation 38 | 39 | ```bash 40 | # Evaluate a single video 41 | python main.py --input ./demo_selfie/jeff_1.mp4 --prompt "Elon Musk is giving a talk." 42 | python main.py --input ./demo_selfie/jeff_1.mp4 --prompt "Claymation, a man is giving a talk." 43 | ``` 44 | 45 | ```bash 46 | # Evaluate a batch of videos 47 | python batch_eval.py --json_file ./demo_selfie/eval_jeff_celebrity.json # Face swap edits 48 | python batch_eval.py --json_file ./demo_selfie/eval_jeff_lorastyle.json # Stylization edits 49 | ``` 50 | 51 | CAUTION: The `--acceleration tensorrt` option is NOT SUPPORTED! I did try to accelerate the model with TensorRT, but due to the dynamic nature of the feature bank, I didn't succeed. If you are an expert on this, please contact me (jeffliang@utexas.edu) and we could discuss how to include you as a contributor. 52 | 53 | ### Ablation study using command 54 | 55 | ```bash 56 | # Do not use feature bank, the model would roll back into per-frame StreamDiffusion 57 | python main.py --input ./demo_selfie/jeff_1.mp4 --prompt "Claymation, a man is giving a talk." --use_cached_attn False --output_dir outputs_streamdiffusion 58 | ``` 59 | 60 | ```bash 61 | # Specify the noise strength. Higher the noise_strength means more noise is added to the starting frames. 62 | # Highter strength ususally leads to better edit effects but may sacrifice the consistency. By default, it is 0.4. 63 | python main.py --input ./demo_selfie/jeff_1.mp4 --prompt "Claymation, a man is giving a talk." --noise_strength 0.8 --output_dir outputs_strength 64 | ``` 65 | 66 | ```bash 67 | # Specify the diffusion steps. Higher steps ususally lead to higher quality but slower speed. 68 | # By default, it is 4. 69 | python main.py --input ./demo_selfie/jeff_1.mp4 --prompt "Claymation, a man is giving a talk." --diffusion_steps 1 --output_dir outputs_steps 70 | ``` 71 | 72 | ### Common Bugs 73 | 74 | #### ImportError Issue 75 | - **Error Message**: `ImportError: cannot import name 'packaging' from 'pkg_resources'`. 76 | - **Related GitHub Issue**: [setuptools issue #4961](https://github.com/vllm-project/vllm/issues/4961) 77 | 78 | **Potential Workaround**: 79 | Downgrade the setuptools package to resolve this issue. You can do this by running the following command in your terminal: 80 | 81 | ```bash 82 | pip install setuptools==69.5.1 83 | ``` -------------------------------------------------------------------------------- /StreamV2V/vid2vid/batch_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import subprocess 5 | 6 | def parse_arguments(): 7 | parser = argparse.ArgumentParser(description="Process a JSON file contains multiple edits.") 8 | parser.add_argument('--json_file', type=str, help='The path to the JSON file to process.') 9 | return parser.parse_args() 10 | 11 | data = [] 12 | args = parse_arguments() 13 | json_file = args.json_file 14 | 15 | # Load the JSON data 16 | with open(json_file, "r") as file: 17 | for line in file: 18 | data.append(json.loads(line)) 19 | 20 | for item in data: 21 | file_path = item["file_path"] 22 | src_vid_name = item["src_vid_name"] 23 | prompt = item["prompt"] 24 | diffusion_steps = item["diffusion_steps"] 25 | noise_strength = item["noise_strength"] 26 | try: 27 | model_id = item["model_id"] 28 | except: 29 | model_id = "runwayml/stable-diffusion-v1-5" 30 | command = [ 31 | 'python', "main.py", 32 | "--input", f"{file_path}/{src_vid_name}.mp4", 33 | "--prompt", prompt, 34 | "--model_id", model_id, 35 | "--diffusion_steps", diffusion_steps, 36 | "--noise_strength", noise_strength, 37 | "--acceleration", "xformers", 38 | "--use_cached_attn", 39 | "--use_feature_injection", 40 | "--cache_maxframes", "1", 41 | "--use_tome_cache", 42 | "--do_add_noise", 43 | "--guidance_scale", "1.0" 44 | ] 45 | subprocess.run(command) 46 | -------------------------------------------------------------------------------- /StreamV2V/vid2vid/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | from typing import Literal, Dict, Optional 5 | 6 | import fire 7 | import torch 8 | from torchvision.io import read_video, write_video 9 | from tqdm import tqdm 10 | 11 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 12 | 13 | from utils.wrapper import StreamV2VWrapper 14 | 15 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | 17 | 18 | def main( 19 | input: str, 20 | prompt: str, 21 | face_lora: str, 22 | face_lora_weights:float, 23 | style_lora: str, 24 | style_lora_weights:float, 25 | output_file: str = os.path.join(CURRENT_DIR, "outputs"), 26 | model_id: str = "runwayml/stable-diffusion-v1-5", 27 | scale: float = 1.0, 28 | guidance_scale: float = 1.0, 29 | diffusion_steps: int = 4, 30 | noise_strength: float = 0.4, 31 | acceleration: Literal["none", "xformers", "tensorrt"] = "xformers", 32 | use_denoising_batch: bool = True, 33 | use_cached_attn: bool = True, 34 | use_feature_injection: bool = True, 35 | feature_injection_strength: float = 0.8, 36 | feature_similarity_threshold: float = 0.98, 37 | cache_interval: int = 4, 38 | cache_maxframes: int = 1, 39 | use_tome_cache: bool = True, 40 | do_add_noise: bool = True, 41 | enable_similar_image_filter: bool = False, 42 | seed: int = 2, 43 | ): 44 | 45 | """ 46 | Perform video-to-video translation with StreamV2V. 47 | 48 | Parameters 49 | ---------- 50 | input: str 51 | The input video name. 52 | prompt: str 53 | The editting prompt to perform video translation. 54 | output_file: str, optional 55 | The directory of the output video. 56 | model_id: str, optional 57 | The base image diffusion model. 58 | By default, it is SD 1.5 ("runwayml/stable-diffusion-v1-5"). 59 | scale: float, optional 60 | The scale of the resolution, by default 1.0. 61 | guidance_scale: float, optional 62 | Classifier-free guidance (CFG). 63 | By default, it is not enabled, 1.0. 64 | diffusion_steps: int, optional 65 | Diffusion steps to perform. Higher steps ususally lead to higher quality but slower speed. 66 | By default, it is 4. 67 | noise_strength: float, optional 68 | Our editing method is SDEdit. Higher the noise_strength means more noise is added to the starting frames. 69 | Highter strength ususally leads to better edit effects but may sacrifice the consistency. 70 | By default, it is 0.4. 71 | acceleration: Literal["none", "xformers", "tensorrt"] = "xformers" 72 | The type of acceleration to use for video translation. 73 | By default, it is xformers. 74 | use_denoising_batch: bool, optional 75 | Whether to use denoising batch or not. 76 | By default, it is True. 77 | use_cached_attn: bool, optional 78 | Whether to cache the self attention maps of the pervious frames to imporve temporal consistency. 79 | If it is set to False, it would roll back to per-frame StreamDiffusion. 80 | By default, it is True 81 | use_feature_injection: bool, optional 82 | Whether directly to inject the features of the pervious frames to imporve temporal consistency. 83 | By default, it is True 84 | feature_injection_strength: float, optional 85 | The strength to perform feature injection. Higher value means higher weights from previous frames. 86 | By default, it is 0.8 87 | feature_similarity_threshold: float, optional 88 | The threshold to identify the similar features. 89 | By default, it is 0.98 90 | cache_interval: int, optional 91 | The frame interval to update the feature bank. 92 | By default, it is 4 93 | cache_maxframes: int, optional 94 | The max frames to cache in the feature bank. Use FIFO (First-In-First-Out) strategy to update. 95 | Only effective when use_tome_cache = False, otherwise, cache_maxframes is set to 1. 96 | use_tome_cache : bool, optional 97 | Use Token Merging (ToMe) to update the bank. 98 | By default, it is True. 99 | enable_similar_image_filter: bool, optional 100 | Whether to enable similar image filter or not, 101 | By default, it is False. 102 | seed: int, optional 103 | The seed, by default 2. if -1, use random seed. 104 | """ 105 | 106 | video_info = read_video(input) 107 | video = video_info[0] / 255 108 | fps = video_info[2]["video_fps"] 109 | height = int(video.shape[1] * scale) 110 | width = int(video.shape[2] * scale) 111 | 112 | init_step = int(50 * (1 - noise_strength)) 113 | interval = int(50 * noise_strength) // diffusion_steps 114 | t_index_list = [init_step + i * interval for i in range(diffusion_steps)] 115 | 116 | 117 | stream = StreamV2VWrapper( 118 | model_id_or_path=model_id, 119 | t_index_list=t_index_list, 120 | frame_buffer_size=1, 121 | width=width, 122 | height=height, 123 | warmup=10, 124 | acceleration=acceleration, 125 | do_add_noise=do_add_noise, 126 | output_type="pt", 127 | enable_similar_image_filter=enable_similar_image_filter, 128 | similar_image_filter_threshold=0.98, 129 | use_denoising_batch=use_denoising_batch, 130 | use_cached_attn=use_cached_attn, 131 | use_feature_injection=use_feature_injection, 132 | feature_injection_strength=feature_injection_strength, 133 | feature_similarity_threshold=feature_similarity_threshold, 134 | cache_interval=cache_interval, 135 | cache_maxframes=cache_maxframes, 136 | use_tome_cache=use_tome_cache, 137 | seed=seed, 138 | ) 139 | stream.prepare( 140 | prompt=prompt, 141 | num_inference_steps=50, 142 | guidance_scale=guidance_scale, 143 | ) 144 | 145 | # Specify LORAs 146 | # ComfyUI/custom_nodes/ComfyUI-StreamV2V/lora_weights 147 | lora_path = os.path.join(os.path.dirname(os.path.dirname(output_file)),"custom_nodes","ComfyUI-StreamV2V","lora_weights") 148 | if face_lora != "none": 149 | face_lora_path = os.path.join(lora_path,"face",face_lora) 150 | print(f"Use face LORA: in {face_lora_path}") 151 | stream.stream.load_lora(face_lora_path, adapter_name='face') 152 | stream.stream.pipe.set_adapters(["lcm", "face"], adapter_weights=[1.0, face_lora_weights]) 153 | if style_lora != "none": 154 | style_lora_path = os.path.join(lora_path,"style",style_lora) 155 | print(f"Use style LORA: in {style_lora_path}") 156 | stream.stream.load_lora(style_lora_path, adapter_name='style') 157 | stream.stream.pipe.set_adapters(["lcm", "style"], adapter_weights=[1.0, style_lora_weights]) 158 | 159 | ''' 160 | # Specify LORAs 161 | if any(word in prompt for word in ['pixelart', 'pixel art', 'Pixel art', 'PixArFK']): 162 | stream.stream.load_lora("./lora_weights/PixelArtRedmond15V-PixelArt-PIXARFK.safetensors", adapter_name='pixelart') 163 | stream.stream.pipe.set_adapters(["lcm", "pixelart"], adapter_weights=[1.0, 1.0]) 164 | print("Use LORA: pixelart in ./lora_weights/PixelArtRedmond15V-PixelArt-PIXARFK.safetensors") 165 | elif any(word in prompt for word in ['lowpoly', 'low poly', 'Low poly']): 166 | stream.stream.load_lora("./lora_weights/low_poly.safetensors", adapter_name='lowpoly') 167 | stream.stream.pipe.set_adapters(["lcm", "lowpoly"], adapter_weights=[1.0, 1.0]) 168 | print("Use LORA: lowpoly in ./lora_weights/low_poly.safetensors") 169 | elif any(word in prompt for word in ['Claymation', 'claymation']): 170 | stream.stream.load_lora("./lora_weights/Claymation.safetensors", adapter_name='claymation') 171 | stream.stream.pipe.set_adapters(["lcm", "claymation"], adapter_weights=[1.0, 1.0]) 172 | print("Use LORA: claymation in ./lora_weights/Claymation.safetensors") 173 | elif any(word in prompt for word in ['crayons', 'Crayons', 'crayons doodle', 'Crayons doodle']): 174 | stream.stream.load_lora("./lora_weights/doodle.safetensors", adapter_name='crayons') 175 | stream.stream.pipe.set_adapters(["lcm", "crayons"], adapter_weights=[1.0, 1.0]) 176 | print("Use LORA: crayons in ./lora_weights/doodle.safetensors") 177 | elif any(word in prompt for word in ['sketch', 'Sketch', 'pencil drawing', 'Pencil drawing']): 178 | stream.stream.load_lora("./lora_weights/Sketch_offcolor.safetensors", adapter_name='sketch') 179 | stream.stream.pipe.set_adapters(["lcm", "sketch"], adapter_weights=[1.0, 1.0]) 180 | print("Use LORA: sketch in ./lora_weights/Sketch_offcolor.safetensors") 181 | elif any(word in prompt for word in ['oil painting', 'Oil painting']): 182 | stream.stream.load_lora("./lora_weights/bichu-v0612.safetensors", adapter_name='oilpainting') 183 | stream.stream.pipe.set_adapters(["lcm", "oilpainting"], adapter_weights=[1.0, 1.0]) 184 | print("Use LORA: oilpainting in ./lora_weights/bichu-v0612.safetensors") 185 | ''' 186 | 187 | video_result = torch.zeros(video.shape[0], height, width, 3) 188 | 189 | for _ in range(stream.batch_size): 190 | stream(image=video[0].permute(2, 0, 1)) 191 | 192 | inference_time = [] 193 | for i in tqdm(range(video.shape[0])): 194 | iteration_start_time = time.time() 195 | output_image = stream(video[i].permute(2, 0, 1)) 196 | video_result[i] = output_image.permute(1, 2, 0) 197 | iteration_end_time = time.time() 198 | inference_time.append(iteration_end_time -iteration_start_time ) 199 | print(f'Avg time: {sum(inference_time[20:])/len(inference_time[20:])}') 200 | 201 | video_result = video_result * 255 202 | ''' 203 | prompt_txt = prompt.replace(' ', '-') 204 | input_vid = input.split('/')[-1] 205 | output = os.path.join(output_dir, f"{input_vid.rsplit('.', 1)[0]}_{prompt_txt}.{input_vid.rsplit('.', 1)[1]}") 206 | ''' 207 | write_video(output_file, video_result, fps=fps) 208 | 209 | 210 | if __name__ == "__main__": 211 | fire.Fire(main) 212 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import site 3 | now_dir = os.path.dirname(os.path.abspath(__file__)) 4 | 5 | site_packages_roots = [] 6 | for path in site.getsitepackages(): 7 | if "packages" in path: 8 | site_packages_roots.append(path) 9 | if(site_packages_roots==[]):site_packages_roots=["%s/runtime/Lib/site-packages" % now_dir] 10 | #os.environ["OPENBLAS_NUM_THREADS"] = "4" 11 | for site_packages_root in site_packages_roots: 12 | if os.path.exists(site_packages_root): 13 | try: 14 | with open("%s/streamv2v.pth" % (site_packages_root), "w") as f: 15 | f.write( 16 | "%s\n%s/StreamV2V\n" 17 | % (now_dir,now_dir) 18 | ) 19 | break 20 | except PermissionError: 21 | raise PermissionError 22 | 23 | if os.path.isfile("%s/streamv2v.pth" % (site_packages_root)): 24 | print("!!!streamv2v path was added to " + "%s/streamv2v.pth" % (site_packages_root) 25 | + "\n if meet `No module` error,try `python main.py` again, don't be foolish to pip install modules") 26 | 27 | 28 | from .nodes import LoadVideo,PreViewVideo,CombineAudioVideo,StreamV2V,LoadImagePath, LoadAudio 29 | WEB_DIRECTORY = "./web" 30 | # A dictionary that contains all nodes you want to export with their names 31 | # NOTE: names should be globally unique 32 | NODE_CLASS_MAPPINGS = { 33 | "LoadAudio": LoadAudio, 34 | "LoadVideo": LoadVideo, 35 | "PreViewVideo": PreViewVideo, 36 | "CombineAudioVideo": CombineAudioVideo, 37 | "StreamV2V": StreamV2V, 38 | "LoadImagePath": LoadImagePath 39 | } 40 | 41 | # A dictionary that contains the friendly/humanly readable titles for the nodes 42 | NODE_DISPLAY_NAME_MAPPINGS = { 43 | "streamv2v": "StreamV2V Node", 44 | "LoadVideo": "Video Loader", 45 | "PreViewVideo": "PreView Video", 46 | "CombineAudioVideo": "Combine Audio Video", 47 | "LoadImagePath": "LoadImagePath", 48 | "LoadAudio": "AudioLoader" 49 | } 50 | -------------------------------------------------------------------------------- /donate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-StreamV2V/ddc939992d873706ae54094d906ea328cb82613b/donate.jpg -------------------------------------------------------------------------------- /lora_weights/face/put lora model here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-StreamV2V/ddc939992d873706ae54094d906ea328cb82613b/lora_weights/face/put lora model here -------------------------------------------------------------------------------- /lora_weights/style/put lora model here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-StreamV2V/ddc939992d873706ae54094d906ea328cb82613b/lora_weights/style/put lora model here -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import folder_paths 5 | import streamv2v 6 | from pydub import AudioSegment 7 | from moviepy.editor import VideoFileClip,AudioFileClip 8 | 9 | input_path = folder_paths.get_input_directory() 10 | out_path = folder_paths.get_output_directory() 11 | now_dir = os.path.dirname(os.path.abspath(__file__)) 12 | 13 | lora_path = os.path.join(now_dir, "lora_weights") 14 | class StreamV2V: 15 | @classmethod 16 | def INPUT_TYPES(s): 17 | face_lora_list = os.listdir(os.path.join(lora_path,"face")) 18 | style_lora_list = os.listdir(os.path.join(lora_path,"style")) 19 | return { 20 | "required": { 21 | "input_video": ("VIDEO",), 22 | "prompt":("STRING",{ 23 | "default": "Elon Musk is giving a talk.", 24 | "multiline": True, 25 | }), 26 | "face_lora": (face_lora_list+["none"],{ 27 | "default": "none" 28 | }), 29 | "face_lora_weights": ("FLOAT",{ 30 | "default": 1., 31 | "min": 0., 32 | "max": 1., 33 | "step":0.1, 34 | "display": "slider" 35 | }), 36 | "style_lora":(style_lora_list+["none"],{ 37 | "default": "none" 38 | }), 39 | "style_lora_weights": ("FLOAT",{ 40 | "default": 1., 41 | "min": 0., 42 | "max": 1., 43 | "step":0.1, 44 | "display": "slider" 45 | }), 46 | "model_id":("STRING",{ 47 | "default": "runwayml/stable-diffusion-v1-5" 48 | }), 49 | "scale":("FLOAT",{ 50 | "default": 1. 51 | }), 52 | "guidance_scale":("FLOAT",{ 53 | "default": 1. 54 | }), 55 | "diffusion_steps":("INT",{ 56 | "default": 4 57 | }), 58 | "noise_strength":("FLOAT",{ 59 | "default": 0.4 60 | }), 61 | "acceleration":(["none", "xformers", "tensorrt"],{ 62 | "default": "xformers" 63 | }), 64 | "seed":("INT",{ 65 | "default": 42 66 | }), 67 | } 68 | } 69 | 70 | CATEGORY = "AIFSH_StreamV2V" 71 | 72 | RETURN_TYPES = ("VIDEO",) 73 | FUNCTION = "generate" 74 | 75 | def generate(self,input_video,prompt,face_lora,face_lora_weights, 76 | style_lora,style_lora_weights, 77 | model_id,scale,guidance_scale, 78 | diffusion_steps,noise_strength,acceleration,seed): 79 | python_exec = sys.executable or "python" 80 | parent_directory = os.path.join(now_dir,"StreamV2V","vid2vid") 81 | video_file = os.path.join(out_path, f"streamv2v_{time.time()}.mp4") 82 | streamv2v_cmd = f"{python_exec} {parent_directory}/main.py --input {input_video} \ 83 | --prompt '{prompt}' --output_file {video_file} --model_id {model_id} --scale {scale} \ 84 | --guidance_scale {guidance_scale} --diffusion_steps {diffusion_steps} --noise_strength {noise_strength} \ 85 | --acceleration {acceleration} --feature_similarity_threshold 0.98 --use_denoising_batch --use_cached_attn \ 86 | --use_feature_injection --feature_injection_strength 0.8 --cache_interval 4 --cache_maxframes 1 \ 87 | --use_tome_cache --do_add_noise --seed {seed} --face_lora '{face_lora}' --style_lora '{style_lora}'\ 88 | --style_lora_weights {style_lora_weights} --face_lora_weights {face_lora_weights}" 89 | 90 | print(streamv2v_cmd) 91 | os.system(streamv2v_cmd) 92 | return(video_file,) 93 | 94 | 95 | class LoadAudio: 96 | @classmethod 97 | def INPUT_TYPES(s): 98 | files = [f for f in os.listdir(input_path) if os.path.isfile(os.path.join(input_path, f)) and f.split('.')[-1] in ["wav", "mp3","WAV","flac","m4a"]] 99 | return {"required": 100 | {"audio": (sorted(files),)}, 101 | } 102 | 103 | CATEGORY = "AIFSH_StreamV2V" 104 | 105 | RETURN_TYPES = ("AUDIO",) 106 | FUNCTION = "load_audio" 107 | 108 | def load_audio(self, audio): 109 | audio_path = folder_paths.get_annotated_filepath(audio) 110 | return (audio_path,) 111 | 112 | class LoadImagePath: 113 | @classmethod 114 | def INPUT_TYPES(s): 115 | input_dir = folder_paths.get_input_directory() 116 | files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] 117 | return {"required": 118 | {"image": (sorted(files), {"image_upload": True})}, 119 | } 120 | 121 | CATEGORY = "AIFSH_StreamV2V" 122 | 123 | RETURN_TYPES = ("IMAGE",) 124 | FUNCTION = "load_image" 125 | def load_image(self, image): 126 | image_path = folder_paths.get_annotated_filepath(image) 127 | return (image_path,) 128 | 129 | class CombineAudioVideo: 130 | @classmethod 131 | def INPUT_TYPES(s): 132 | return {"required": 133 | {"vocal_AUDIO": ("AUDIO",), 134 | "bgm_AUDIO": ("AUDIO",), 135 | "video": ("VIDEO",) 136 | } 137 | } 138 | 139 | CATEGORY = "AIFSH_StreamV2V" 140 | DESCRIPTION = "hello world!" 141 | 142 | RETURN_TYPES = ("VIDEO",) 143 | 144 | OUTPUT_NODE = False 145 | 146 | FUNCTION = "combine" 147 | 148 | def combine(self, vocal_AUDIO,bgm_AUDIO,video): 149 | vocal = AudioSegment.from_file(vocal_AUDIO) 150 | bgm = AudioSegment.from_file(bgm_AUDIO) 151 | audio = vocal.overlay(bgm) 152 | audio_file = os.path.join(out_path,"ip_lap_voice.wav") 153 | audio.export(audio_file, format="wav") 154 | cm_video_file = os.path.join(out_path,"voice_"+os.path.basename(video)) 155 | video_clip = VideoFileClip(video) 156 | audio_clip = AudioFileClip(audio_file) 157 | new_video_clip = video_clip.set_audio(audio_clip) 158 | new_video_clip.write_videofile(cm_video_file) 159 | return (cm_video_file,) 160 | 161 | 162 | class PreViewVideo: 163 | @classmethod 164 | def INPUT_TYPES(s): 165 | return {"required":{ 166 | "video":("VIDEO",), 167 | }} 168 | 169 | CATEGORY = "AIFSH_StreamV2V" 170 | DESCRIPTION = "hello world!" 171 | 172 | RETURN_TYPES = () 173 | 174 | OUTPUT_NODE = True 175 | 176 | FUNCTION = "load_video" 177 | 178 | def load_video(self, video): 179 | video_name = os.path.basename(video) 180 | video_path_name = os.path.basename(os.path.dirname(video)) 181 | return {"ui":{"video":[video_name,video_path_name]}} 182 | 183 | class LoadVideo: 184 | @classmethod 185 | def INPUT_TYPES(s): 186 | files = [f for f in os.listdir(input_path) if os.path.isfile(os.path.join(input_path, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]] 187 | return {"required":{ 188 | "video":(files,), 189 | }} 190 | 191 | CATEGORY = "AIFSH_StreamV2V" 192 | DESCRIPTION = "hello world!" 193 | 194 | RETURN_TYPES = ("VIDEO","AUDIO") 195 | 196 | OUTPUT_NODE = False 197 | 198 | FUNCTION = "load_video" 199 | 200 | def load_video(self, video): 201 | video_path = os.path.join(input_path,video) 202 | video_clip = VideoFileClip(video_path) 203 | audio_path = os.path.join(input_path,video+".wav") 204 | try: 205 | video_clip.audio.write_audiofile(audio_path) 206 | except: 207 | print("none audio") 208 | return (video_path,audio_path,) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | av 3 | peft 4 | pydub 5 | moviepy 6 | diffusers==0.27.0 7 | fire 8 | omegaconf 9 | cuda-python 10 | onnxruntime-gpu 11 | protobuf==3.20.2 12 | colored 13 | pywin32;sys_platform == 'win32' -------------------------------------------------------------------------------- /web.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-StreamV2V/ddc939992d873706ae54094d906ea328cb82613b/web.png -------------------------------------------------------------------------------- /web/js/previewVideo.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from '../../../scripts/api.js' 3 | 4 | function fitHeight(node) { 5 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]) 6 | node?.graph?.setDirtyCanvas(true); 7 | } 8 | function chainCallback(object, property, callback) { 9 | if (object == undefined) { 10 | //This should not happen. 11 | console.error("Tried to add callback to non-existant object") 12 | return; 13 | } 14 | if (property in object) { 15 | const callback_orig = object[property] 16 | object[property] = function () { 17 | const r = callback_orig.apply(this, arguments); 18 | callback.apply(this, arguments); 19 | return r 20 | }; 21 | } else { 22 | object[property] = callback; 23 | } 24 | } 25 | 26 | function addPreviewOptions(nodeType) { 27 | chainCallback(nodeType.prototype, "getExtraMenuOptions", function(_, options) { 28 | // The intended way of appending options is returning a list of extra options, 29 | // but this isn't used in widgetInputs.js and would require 30 | // less generalization of chainCallback 31 | let optNew = [] 32 | try { 33 | const previewWidget = this.widgets.find((w) => w.name === "videopreview"); 34 | 35 | let url = null 36 | if (previewWidget.videoEl?.hidden == false && previewWidget.videoEl.src) { 37 | //Use full quality video 38 | //url = api.apiURL('/view?' + new URLSearchParams(previewWidget.value.params)); 39 | url = previewWidget.videoEl.src 40 | } 41 | if (url) { 42 | optNew.push( 43 | { 44 | content: "Open preview", 45 | callback: () => { 46 | window.open(url, "_blank") 47 | }, 48 | }, 49 | { 50 | content: "Save preview", 51 | callback: () => { 52 | const a = document.createElement("a"); 53 | a.href = url; 54 | a.setAttribute("download", new URLSearchParams(previewWidget.value.params).get("filename")); 55 | document.body.append(a); 56 | a.click(); 57 | requestAnimationFrame(() => a.remove()); 58 | }, 59 | } 60 | ); 61 | } 62 | if(options.length > 0 && options[0] != null && optNew.length > 0) { 63 | optNew.push(null); 64 | } 65 | options.unshift(...optNew); 66 | 67 | } catch (error) { 68 | console.log(error); 69 | } 70 | 71 | }); 72 | } 73 | function previewVideo(node,file,type){ 74 | var element = document.createElement("div"); 75 | const previewNode = node; 76 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, { 77 | serialize: false, 78 | hideOnZoom: false, 79 | getValue() { 80 | return element.value; 81 | }, 82 | setValue(v) { 83 | element.value = v; 84 | }, 85 | }); 86 | previewWidget.computeSize = function(width) { 87 | if (this.aspectRatio && !this.parentEl.hidden) { 88 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10; 89 | if (!(height > 0)) { 90 | height = 0; 91 | } 92 | this.computedHeight = height + 10; 93 | return [width, height]; 94 | } 95 | return [width, -4];//no loaded src, widget should not display 96 | } 97 | // element.style['pointer-events'] = "none" 98 | previewWidget.value = {hidden: false, paused: false, params: {}} 99 | previewWidget.parentEl = document.createElement("div"); 100 | previewWidget.parentEl.className = "video_preview"; 101 | previewWidget.parentEl.style['width'] = "100%" 102 | element.appendChild(previewWidget.parentEl); 103 | previewWidget.videoEl = document.createElement("video"); 104 | previewWidget.videoEl.controls = true; 105 | previewWidget.videoEl.loop = false; 106 | previewWidget.videoEl.muted = false; 107 | previewWidget.videoEl.style['width'] = "100%" 108 | previewWidget.videoEl.addEventListener("loadedmetadata", () => { 109 | 110 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight; 111 | fitHeight(this); 112 | }); 113 | previewWidget.videoEl.addEventListener("error", () => { 114 | //TODO: consider a way to properly notify the user why a preview isn't shown. 115 | previewWidget.parentEl.hidden = true; 116 | fitHeight(this); 117 | }); 118 | 119 | let params = { 120 | "filename": file, 121 | "type": type, 122 | } 123 | 124 | previewWidget.parentEl.hidden = previewWidget.value.hidden; 125 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden; 126 | let target_width = 256 127 | if (element.style?.width) { 128 | //overscale to allow scrolling. Endpoint won't return higher than native 129 | target_width = element.style.width.slice(0,-2)*2; 130 | } 131 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") { 132 | params.force_size = target_width+"x?" 133 | } else { 134 | let size = params.force_size.split("x") 135 | let ar = parseInt(size[0])/parseInt(size[1]) 136 | params.force_size = target_width+"x"+(target_width/ar) 137 | } 138 | 139 | previewWidget.videoEl.src = api.apiURL('/view?' + new URLSearchParams(params)); 140 | 141 | previewWidget.videoEl.hidden = false; 142 | previewWidget.parentEl.appendChild(previewWidget.videoEl) 143 | } 144 | 145 | app.registerExtension({ 146 | name: "StreamV2V.VideoPreviewer", 147 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 148 | if (nodeData?.name == "PreViewVideo") { 149 | nodeType.prototype.onExecuted = function (data) { 150 | previewVideo(this, data.video[0], data.video[1]); 151 | } 152 | addPreviewOptions(nodeType) 153 | } 154 | } 155 | }); 156 | -------------------------------------------------------------------------------- /web/js/uploadAudio.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from '../../../scripts/api.js' 3 | import { ComfyWidgets } from "../../../scripts/widgets.js" 4 | 5 | function fitHeight(node) { 6 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]) 7 | node?.graph?.setDirtyCanvas(true); 8 | } 9 | 10 | function previewAudio(node,file){ 11 | while (node.widgets.length > 2){ 12 | node.widgets.pop(); 13 | } 14 | try { 15 | var el = document.getElementById("uploadAudio"); 16 | el.remove(); 17 | } catch (error) { 18 | console.log(error); 19 | } 20 | var element = document.createElement("div"); 21 | element.id = "uploadAudio"; 22 | const previewNode = node; 23 | var previewWidget = node.addDOMWidget("audiopreview", "preview", element, { 24 | serialize: false, 25 | hideOnZoom: false, 26 | getValue() { 27 | return element.value; 28 | }, 29 | setValue(v) { 30 | element.value = v; 31 | }, 32 | }); 33 | previewWidget.computeSize = function(width) { 34 | if (this.aspectRatio && !this.parentEl.hidden) { 35 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10; 36 | if (!(height > 0)) { 37 | height = 0; 38 | } 39 | this.computedHeight = height + 10; 40 | return [width, height]; 41 | } 42 | return [width, -4];//no loaded src, widget should not display 43 | } 44 | // element.style['pointer-events'] = "none" 45 | previewWidget.value = {hidden: false, paused: false, params: {}} 46 | previewWidget.parentEl = document.createElement("div"); 47 | previewWidget.parentEl.className = "audio_preview"; 48 | previewWidget.parentEl.style['width'] = "100%" 49 | element.appendChild(previewWidget.parentEl); 50 | previewWidget.audioEl = document.createElement("audio"); 51 | previewWidget.audioEl.controls = true; 52 | previewWidget.audioEl.loop = false; 53 | previewWidget.audioEl.muted = false; 54 | previewWidget.audioEl.style['width'] = "100%" 55 | previewWidget.audioEl.addEventListener("loadedmetadata", () => { 56 | 57 | previewWidget.aspectRatio = previewWidget.audioEl.audioWidth / previewWidget.audioEl.audioHeight; 58 | fitHeight(this); 59 | }); 60 | previewWidget.audioEl.addEventListener("error", () => { 61 | //TODO: consider a way to properly notify the user why a preview isn't shown. 62 | previewWidget.parentEl.hidden = true; 63 | fitHeight(this); 64 | }); 65 | 66 | let params = { 67 | "filename": file, 68 | "type": "input", 69 | } 70 | 71 | previewWidget.parentEl.hidden = previewWidget.value.hidden; 72 | previewWidget.audioEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden; 73 | let target_width = 256 74 | if (element.style?.width) { 75 | //overscale to allow scrolling. Endpoint won't return higher than native 76 | target_width = element.style.width.slice(0,-2)*2; 77 | } 78 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") { 79 | params.force_size = target_width+"x?" 80 | } else { 81 | let size = params.force_size.split("x") 82 | let ar = parseInt(size[0])/parseInt(size[1]) 83 | params.force_size = target_width+"x"+(target_width/ar) 84 | } 85 | 86 | previewWidget.audioEl.src = api.apiURL('/view?' + new URLSearchParams(params)); 87 | 88 | previewWidget.audioEl.hidden = false; 89 | previewWidget.parentEl.appendChild(previewWidget.audioEl) 90 | } 91 | 92 | function audioUpload(node, inputName, inputData, app) { 93 | const audioWidget = node.widgets.find((w) => w.name === "audio"); 94 | let uploadWidget; 95 | /* 96 | A method that returns the required style for the html 97 | */ 98 | var default_value = audioWidget.value; 99 | Object.defineProperty(audioWidget, "value", { 100 | set : function(value) { 101 | this._real_value = value; 102 | }, 103 | 104 | get : function() { 105 | let value = ""; 106 | if (this._real_value) { 107 | value = this._real_value; 108 | } else { 109 | return default_value; 110 | } 111 | 112 | if (value.filename) { 113 | let real_value = value; 114 | value = ""; 115 | if (real_value.subfolder) { 116 | value = real_value.subfolder + "/"; 117 | } 118 | 119 | value += real_value.filename; 120 | 121 | if(real_value.type && real_value.type !== "input") 122 | value += ` [${real_value.type}]`; 123 | } 124 | return value; 125 | } 126 | }); 127 | async function uploadFile(file, updateNode, pasted = false) { 128 | try { 129 | // Wrap file in formdata so it includes filename 130 | const body = new FormData(); 131 | body.append("image", file); 132 | if (pasted) body.append("subfolder", "pasted"); 133 | const resp = await api.fetchApi("/upload/image", { 134 | method: "POST", 135 | body, 136 | }); 137 | 138 | if (resp.status === 200) { 139 | const data = await resp.json(); 140 | // Add the file to the dropdown list and update the widget value 141 | let path = data.name; 142 | if (data.subfolder) path = data.subfolder + "/" + path; 143 | 144 | if (!audioWidget.options.values.includes(path)) { 145 | audioWidget.options.values.push(path); 146 | } 147 | 148 | if (updateNode) { 149 | audioWidget.value = path; 150 | previewAudio(node,path) 151 | 152 | } 153 | } else { 154 | alert(resp.status + " - " + resp.statusText); 155 | } 156 | } catch (error) { 157 | alert(error); 158 | } 159 | } 160 | 161 | const fileInput = document.createElement("input"); 162 | Object.assign(fileInput, { 163 | type: "file", 164 | accept: "audio/mp3,audio/wav,audio/flac,audio/m4a", 165 | style: "display: none", 166 | onchange: async () => { 167 | if (fileInput.files.length) { 168 | await uploadFile(fileInput.files[0], true); 169 | } 170 | }, 171 | }); 172 | document.body.append(fileInput); 173 | 174 | // Create the button widget for selecting the files 175 | uploadWidget = node.addWidget("button", "choose audio file to upload", "Audio", () => { 176 | fileInput.click(); 177 | }); 178 | 179 | uploadWidget.serialize = false; 180 | 181 | previewAudio(node, audioWidget.value); 182 | const cb = node.callback; 183 | audioWidget.callback = function () { 184 | previewAudio(node,audioWidget.value); 185 | if (cb) { 186 | return cb.apply(this, arguments); 187 | } 188 | }; 189 | 190 | return { widget: uploadWidget }; 191 | } 192 | 193 | ComfyWidgets.AUDIOPLOAD = audioUpload; 194 | 195 | app.registerExtension({ 196 | name: "StreamV2V.UploadAudio", 197 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 198 | if (nodeData?.name == "LoadAudio") { 199 | nodeData.input.required.upload = ["AUDIOPLOAD"]; 200 | } 201 | }, 202 | }); 203 | 204 | -------------------------------------------------------------------------------- /web/js/uploadVideo.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from '../../../scripts/api.js' 3 | import { ComfyWidgets } from "../../../scripts/widgets.js" 4 | 5 | function fitHeight(node) { 6 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]) 7 | node?.graph?.setDirtyCanvas(true); 8 | } 9 | 10 | function previewVideo(node,file){ 11 | while (node.widgets.length > 2){ 12 | node.widgets.pop() 13 | } 14 | try { 15 | var el = document.getElementById("uploadVideo"); 16 | el.remove(); 17 | } catch (error) { 18 | console.log(error); 19 | } 20 | var element = document.createElement("div"); 21 | element.id = "uploadVideo"; 22 | const previewNode = node; 23 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, { 24 | serialize: false, 25 | hideOnZoom: false, 26 | getValue() { 27 | return element.value; 28 | }, 29 | setValue(v) { 30 | element.value = v; 31 | }, 32 | }); 33 | previewWidget.computeSize = function(width) { 34 | if (this.aspectRatio && !this.parentEl.hidden) { 35 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10; 36 | if (!(height > 0)) { 37 | height = 0; 38 | } 39 | this.computedHeight = height + 10; 40 | return [width, height]; 41 | } 42 | return [width, -4];//no loaded src, widget should not display 43 | } 44 | // element.style['pointer-events'] = "none" 45 | previewWidget.value = {hidden: false, paused: false, params: {}} 46 | previewWidget.parentEl = document.createElement("div"); 47 | previewWidget.parentEl.className = "video_preview"; 48 | previewWidget.parentEl.style['width'] = "100%" 49 | element.appendChild(previewWidget.parentEl); 50 | previewWidget.videoEl = document.createElement("video"); 51 | previewWidget.videoEl.controls = true; 52 | previewWidget.videoEl.loop = false; 53 | previewWidget.videoEl.muted = false; 54 | previewWidget.videoEl.style['width'] = "100%" 55 | previewWidget.videoEl.addEventListener("loadedmetadata", () => { 56 | 57 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight; 58 | fitHeight(this); 59 | }); 60 | previewWidget.videoEl.addEventListener("error", () => { 61 | //TODO: consider a way to properly notify the user why a preview isn't shown. 62 | previewWidget.parentEl.hidden = true; 63 | fitHeight(this); 64 | }); 65 | 66 | let params = { 67 | "filename": file, 68 | "type": "input", 69 | } 70 | 71 | previewWidget.parentEl.hidden = previewWidget.value.hidden; 72 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden; 73 | let target_width = 256 74 | if (element.style?.width) { 75 | //overscale to allow scrolling. Endpoint won't return higher than native 76 | target_width = element.style.width.slice(0,-2)*2; 77 | } 78 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") { 79 | params.force_size = target_width+"x?" 80 | } else { 81 | let size = params.force_size.split("x") 82 | let ar = parseInt(size[0])/parseInt(size[1]) 83 | params.force_size = target_width+"x"+(target_width/ar) 84 | } 85 | 86 | previewWidget.videoEl.src = api.apiURL('/view?' + new URLSearchParams(params)); 87 | 88 | previewWidget.videoEl.hidden = false; 89 | previewWidget.parentEl.appendChild(previewWidget.videoEl) 90 | } 91 | 92 | function videoUpload(node, inputName, inputData, app) { 93 | const videoWidget = node.widgets.find((w) => w.name === "video"); 94 | let uploadWidget; 95 | /* 96 | A method that returns the required style for the html 97 | */ 98 | var default_value = videoWidget.value; 99 | Object.defineProperty(videoWidget, "value", { 100 | set : function(value) { 101 | this._real_value = value; 102 | }, 103 | 104 | get : function() { 105 | let value = ""; 106 | if (this._real_value) { 107 | value = this._real_value; 108 | } else { 109 | return default_value; 110 | } 111 | 112 | if (value.filename) { 113 | let real_value = value; 114 | value = ""; 115 | if (real_value.subfolder) { 116 | value = real_value.subfolder + "/"; 117 | } 118 | 119 | value += real_value.filename; 120 | 121 | if(real_value.type && real_value.type !== "input") 122 | value += ` [${real_value.type}]`; 123 | } 124 | return value; 125 | } 126 | }); 127 | async function uploadFile(file, updateNode, pasted = false) { 128 | try { 129 | // Wrap file in formdata so it includes filename 130 | const body = new FormData(); 131 | body.append("image", file); 132 | if (pasted) body.append("subfolder", "pasted"); 133 | const resp = await api.fetchApi("/upload/image", { 134 | method: "POST", 135 | body, 136 | }); 137 | 138 | if (resp.status === 200) { 139 | const data = await resp.json(); 140 | // Add the file to the dropdown list and update the widget value 141 | let path = data.name; 142 | if (data.subfolder) path = data.subfolder + "/" + path; 143 | 144 | if (!videoWidget.options.values.includes(path)) { 145 | videoWidget.options.values.push(path); 146 | } 147 | 148 | if (updateNode) { 149 | videoWidget.value = path; 150 | previewVideo(node,path) 151 | 152 | } 153 | } else { 154 | alert(resp.status + " - " + resp.statusText); 155 | } 156 | } catch (error) { 157 | alert(error); 158 | } 159 | } 160 | 161 | const fileInput = document.createElement("input"); 162 | Object.assign(fileInput, { 163 | type: "file", 164 | accept: "video/webm,video/mp4,video/mkv,video/avi", 165 | style: "display: none", 166 | onchange: async () => { 167 | if (fileInput.files.length) { 168 | await uploadFile(fileInput.files[0], true); 169 | } 170 | }, 171 | }); 172 | document.body.append(fileInput); 173 | 174 | // Create the button widget for selecting the files 175 | uploadWidget = node.addWidget("button", "choose video file to upload", "Video", () => { 176 | fileInput.click(); 177 | }); 178 | 179 | uploadWidget.serialize = false; 180 | 181 | previewVideo(node, videoWidget.value); 182 | const cb = node.callback; 183 | videoWidget.callback = function () { 184 | previewVideo(node,videoWidget.value); 185 | if (cb) { 186 | return cb.apply(this, arguments); 187 | } 188 | }; 189 | 190 | return { widget: uploadWidget }; 191 | } 192 | 193 | ComfyWidgets.VIDEOPLOAD = videoUpload; 194 | 195 | app.registerExtension({ 196 | name: "StreamV2V.UploadVideo", 197 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 198 | if (nodeData?.name == "LoadVideo") { 199 | nodeData.input.required.upload = ["VIDEOPLOAD"]; 200 | } 201 | }, 202 | }); 203 | 204 | -------------------------------------------------------------------------------- /wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-StreamV2V/ddc939992d873706ae54094d906ea328cb82613b/wechat.jpg --------------------------------------------------------------------------------