├── streamdiffusion ├── tools │ ├── __init__.py │ └── install-tensorrt.py ├── acceleration │ ├── __init__.py │ ├── sfast │ │ └── __init__.py │ └── tensorrt │ │ ├── builder.py │ │ ├── engine.py │ │ ├── __init__.py │ │ ├── utilities.py │ │ └── models.py ├── __init__.py ├── pip_utils.py ├── image_filter.py ├── image_utils.py ├── pipeline.py └── wrapper.py ├── workflow.png ├── __init__.py ├── .github └── workflows │ └── publish.yml ├── README.md ├── nodes.py └── LICENSE /streamdiffusion/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /streamdiffusion/acceleration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /streamdiffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import StreamDiffusion 2 | -------------------------------------------------------------------------------- /workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jesenzhang/ComfyUI_StreamDiffusion/HEAD/workflow.png -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 3 | 4 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'jesenzhang' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | ## Add your own personal access token to your Github Repository secrets and reference it here. 25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 26 | -------------------------------------------------------------------------------- /streamdiffusion/acceleration/sfast/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from sfast.compilers.stable_diffusion_pipeline_compiler import CompilationConfig, compile 4 | 5 | from ...pipeline import StreamDiffusion 6 | 7 | 8 | def accelerate_with_stable_fast( 9 | stream: StreamDiffusion, 10 | config: Optional[CompilationConfig] = None, 11 | ): 12 | if config is None: 13 | config = CompilationConfig.Default() 14 | # xformers and Triton are suggested for achieving best performance. 15 | try: 16 | import xformers 17 | 18 | config.enable_xformers = True 19 | except ImportError: 20 | print("xformers not installed, skip") 21 | try: 22 | import triton 23 | 24 | config.enable_triton = True 25 | except ImportError: 26 | print("Triton not installed, skip") 27 | # CUDA Graph is suggested for small batch sizes and small resolutions to reduce CPU overhead. 
28 | config.enable_cuda_graph = True 29 | stream.pipe = compile(stream.pipe, config) 30 | stream.unet = stream.pipe.unet 31 | stream.vae = stream.pipe.vae 32 | stream.text_encoder = stream.pipe.text_encoder 33 | return stream 34 | -------------------------------------------------------------------------------- /streamdiffusion/pip_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import importlib.util 3 | import os 4 | import subprocess 5 | import sys 6 | from typing import Dict, Optional 7 | 8 | from packaging.version import Version 9 | 10 | 11 | python = sys.executable 12 | index_url = os.environ.get("INDEX_URL", "") 13 | 14 | 15 | def version(package: str) -> Optional[Version]: 16 | try: 17 | return Version(importlib.import_module(package).__version__) 18 | except ModuleNotFoundError: 19 | return None 20 | 21 | 22 | def is_installed(package: str) -> bool: 23 | try: 24 | spec = importlib.util.find_spec(package) 25 | except ModuleNotFoundError: 26 | return False 27 | 28 | return spec is not None 29 | 30 | 31 | def run_python(command: str, env: Dict[str, str] = None) -> str: 32 | run_kwargs = { 33 | "args": f"{python} {command}", 34 | "shell": True, 35 | "env": os.environ if env is None else env, 36 | "encoding": "utf8", 37 | "errors": "ignore", 38 | } 39 | 40 | print(run_kwargs["args"]) 41 | 42 | result = subprocess.run(**run_kwargs) 43 | 44 | if result.returncode != 0: 45 | print(f"Error running command: {command}", file=sys.stderr) 46 | raise RuntimeError(f"Error running command: {command}") 47 | 48 | return result.stdout or "" 49 | 50 | 51 | def run_pip(command: str, env: Dict[str, str] = None) -> str: 52 | return run_python(f"-m pip {command}", env) 53 | -------------------------------------------------------------------------------- /streamdiffusion/tools/install-tensorrt.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import fire 4 | from packaging.version import Version 5 | 6 | from ..pip_utils import is_installed, run_pip, version 7 | 8 | 9 | def get_cuda_version_from_torch() -> Optional[Literal["11", "12"]]: 10 | try: 11 | import torch 12 | except ImportError: 13 | return None 14 | 15 | return torch.version.cuda.split(".")[0] 16 | 17 | 18 | def install(cu: Optional[Literal["11", "12"]] = get_cuda_version_from_torch()): 19 | if cu is None or cu not in ["11", "12"]: 20 | print("Could not detect CUDA version. 
Please specify manually.") 21 | return 22 | print("Installing TensorRT requirements...") 23 | 24 | if is_installed("tensorrt"): 25 | if version("tensorrt") < Version("9.0.0"): 26 | run_pip("uninstall -y tensorrt") 27 | 28 | cudnn_name = f"nvidia-cudnn-cu{cu}==8.9.4.25" 29 | 30 | if not is_installed("tensorrt"): 31 | run_pip(f"install {cudnn_name} --no-cache-dir") 32 | run_pip( 33 | "install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir" 34 | ) 35 | 36 | if not is_installed("polygraphy"): 37 | run_pip( 38 | "install polygraphy==0.47.1 --extra-index-url https://pypi.ngc.nvidia.com" 39 | ) 40 | if not is_installed("onnx_graphsurgeon"): 41 | run_pip( 42 | "install onnx-graphsurgeon==0.3.26 --extra-index-url https://pypi.ngc.nvidia.com" 43 | ) 44 | 45 | pass 46 | 47 | 48 | if __name__ == "__main__": 49 | fire.Fire(install) 50 | -------------------------------------------------------------------------------- /streamdiffusion/image_filter.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import random 3 | 4 | import torch 5 | 6 | 7 | class SimilarImageFilter: 8 | def __init__(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None: 9 | self.threshold = threshold 10 | self.prev_tensor = None 11 | self.cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6) 12 | self.max_skip_frame = max_skip_frame 13 | self.skip_count = 0 14 | 15 | def __call__(self, x: torch.Tensor) -> Optional[torch.Tensor]: 16 | if self.prev_tensor is None: 17 | self.prev_tensor = x.detach().clone() 18 | return x 19 | else: 20 | cos_sim = self.cos(self.prev_tensor.reshape(-1), x.reshape(-1)).item() 21 | sample = random.uniform(0, 1) 22 | if self.threshold >= 1: 23 | skip_prob = 0 24 | else: 25 | skip_prob = max(0, 1 - (1 - cos_sim) / (1 - self.threshold)) 26 | 27 | # not skip frame 28 | if skip_prob < sample: 29 | self.prev_tensor = x.detach().clone() 30 | return x 31 | # skip frame 32 | else: 33 | if self.skip_count > self.max_skip_frame: 34 | self.skip_count = 0 35 | self.prev_tensor = x.detach().clone() 36 | return x 37 | else: 38 | self.skip_count += 1 39 | return None 40 | 41 | def set_threshold(self, threshold: float) -> None: 42 | self.threshold = threshold 43 | 44 | def set_max_skip_frame(self, max_skip_frame: float) -> None: 45 | self.max_skip_frame = max_skip_frame 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI_StreamDiffusion 2 | 3 | # This is a simple implementation StreamDiffusion for ComfyUI 4 | 5 | 6 | # StreamDiffusion: A Pipeline-Level Solution for Real-Time Interactive Generation 7 | 8 | **Authors:** [Akio Kodaira](https://www.linkedin.com/in/akio-kodaira-1a7b98252/), [Chenfeng Xu](https://www.chenfengx.com/), Toshiki Hazama, [Takanori Yoshimoto](https://twitter.com/__ramu0e__), [Kohei Ohno](https://www.linkedin.com/in/kohei--ohno/), [Shogo Mitsuhori](https://me.ddpn.world/), [Soichi Sugano](https://twitter.com/toni_nimono), [Hanying Cho](https://twitter.com/hanyingcl), [Zhijian Liu](https://zhijianliu.com/), [Kurt Keutzer](https://scholar.google.com/citations?hl=en&user=ID9QePIAAAAJ) 9 | 10 | StreamDiffusion is an innovative diffusion pipeline designed for real-time interactive generation. It introduces significant performance enhancements to current diffusion-based image generation techniques. 
11 | 12 | [![arXiv](https://img.shields.io/badge/arXiv-2312.12491-b31b1b.svg)](https://arxiv.org/abs/2312.12491) 13 | [![Hugging Face Papers](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-papers-yellow)](https://huggingface.co/papers/2312.12491) 14 | 15 | # Simple Use 16 | You can download the workflow image below and import it into ComfyUI.
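If you prefer to drive the pipeline from Python instead of the workflow, the sketch below mirrors what the `StreamDiffusion_Loader` and `StreamDiffusion_Sampler` nodes in `nodes.py` do with `StreamDiffusionWrapper`. Treat it as a minimal, untested sketch: it assumes the bundled `streamdiffusion` package is importable, the `prepare()` argument order is assumed from the sampler node's call, and the model path and prompts are placeholders.

```python
import torch
from streamdiffusion.wrapper import StreamDiffusionWrapper

# Build the wrapper roughly the way StreamDiffusion_Loader does
# (model path and LoRA settings here are placeholders).
stream = StreamDiffusionWrapper(
    model_id_or_path="KBlueLeaf/kohaku-v2.1",  # the loader's "Baked ckpt" default
    lora_dict=None,
    t_index_list=[32, 40, 45],
    frame_buffer_size=1,
    width=512,
    height=512,
    warmup=10,
    acceleration="none",        # e.g. "none", "sfast" or "tensorrt"
    use_tiny_vae=True,
    device=torch.device("cuda"),
    use_lcm_lora=True,
    output_type="pt",
    dtype=torch.float16,
    lcm_lora_id=None,
    vae_id=None,
)

# Prepare and sample roughly the way StreamDiffusion_Sampler does.
# The positional arguments mirror the node's model.prepare(...) call.
stream.prepare(
    "1girl, best quality",   # positive prompt
    "low quality, blurry",   # negative prompt
    50,                      # steps
    1.2,                     # cfg
    1.0,                     # delta
    [32, 40, 45],            # t_index_list
    True,                    # add_noise
    False,                   # enable_similar_image_filter
    0.98,                    # similar_image_filter_threshold
    10,                      # similar_image_filter_max_skip_frame
    True,                    # use_denoising_batch
    "self",                  # cfg_type
    0,                       # seed
    1,                       # batch_size
    False,                   # use_safety_checker
)

# txt2img: the sampler node passes None when no image is connected.
image = stream.sample(None)        # float tensor in NCHW layout
image = image.permute(0, 2, 3, 1)  # back to ComfyUI's NHWC IMAGE layout
```

For img2img, pass a `(1, 3, H, W)` tensor instead of `None`; the sampler node converts ComfyUI's IMAGE with `image.movedim(-1, 1)` before sampling.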

18 | 19 |

20 | 21 | # img2img 22 | img2img can be done by send a image to the image imput in the sampler node,but the batch_size must be 1. 23 | 24 | # StreamDiffusion_Sampler 25 | Input Latent is not implemented for now. 26 | 27 | ### Lora stack 28 | You can set Lora stack by using LoRA Stacker from Efficiency Nodes. 29 | 30 | ## Support 31 | Thank you for being awesome! -------------------------------------------------------------------------------- /streamdiffusion/image_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import PIL.Image 5 | import torch 6 | import torchvision 7 | 8 | 9 | def denormalize(images: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: 10 | """ 11 | Denormalize an image array to [0,1]. 12 | """ 13 | return (images.float() / 2 + 0.5).clamp(0, 1) 14 | 15 | 16 | def pt_to_numpy(images: torch.Tensor) -> np.ndarray: 17 | """ 18 | Convert a PyTorch tensor to a NumPy image. 19 | """ 20 | images = images.cpu().permute(0, 2, 3, 1).float().numpy() 21 | return images 22 | 23 | 24 | def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: 25 | """ 26 | Convert a NumPy image or a batch of images to a PIL image. 27 | """ 28 | if images.ndim == 3: 29 | images = images[None, ...] 30 | images = (images * 255).round().astype("uint8") 31 | if images.shape[-1] == 1: 32 | # special case for grayscale (single channel) images 33 | pil_images = [ 34 | PIL.Image.fromarray(image.squeeze(), mode="L") for image in images 35 | ] 36 | else: 37 | pil_images = [PIL.Image.fromarray(image) for image in images] 38 | 39 | return pil_images 40 | 41 | 42 | def postprocess_image( 43 | image: torch.Tensor, 44 | output_type: str = "pil", 45 | do_denormalize: Optional[List[bool]] = None, 46 | ) -> Union[torch.Tensor, np.ndarray, PIL.Image.Image]: 47 | if not isinstance(image, torch.Tensor): 48 | raise ValueError( 49 | f"Input for postprocessing is in incorrect format: {type(image)}. 
We only support pytorch tensor" 50 | ) 51 | 52 | if output_type == "latent": 53 | return image 54 | 55 | do_normalize_flg = True 56 | if do_denormalize is None: 57 | do_denormalize = [do_normalize_flg] * image.shape[0] 58 | 59 | image = torch.stack( 60 | [ 61 | denormalize(image[i]) if do_denormalize[i] else image[i] 62 | for i in range(image.shape[0]) 63 | ] 64 | ) 65 | 66 | if output_type == "pt": 67 | return image 68 | 69 | image = pt_to_numpy(image) 70 | 71 | if output_type == "np": 72 | return image 73 | 74 | if output_type == "pil": 75 | return numpy_to_pil(image) 76 | 77 | 78 | def process_image( 79 | image_pil: PIL.Image.Image, range: Tuple[int, int] = (-1, 1) 80 | ) -> Tuple[torch.Tensor, PIL.Image.Image]: 81 | image = torchvision.transforms.ToTensor()(image_pil) 82 | r_min, r_max = range[0], range[1] 83 | image = image * (r_max - r_min) + r_min 84 | return image[None, ...], image_pil 85 | 86 | 87 | def pil2tensor(image_pil: PIL.Image.Image) -> torch.Tensor: 88 | height = image_pil.height 89 | width = image_pil.width 90 | imgs = [] 91 | img, _ = process_image(image_pil) 92 | imgs.append(img) 93 | imgs = torch.vstack(imgs) 94 | images = torch.nn.functional.interpolate( 95 | imgs, size=(height, width), mode="bilinear" 96 | ) 97 | image_tensors = images.to(torch.float16) 98 | return image_tensors 99 | -------------------------------------------------------------------------------- /streamdiffusion/acceleration/tensorrt/builder.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from typing import * 4 | 5 | import torch 6 | 7 | from .models import BaseModel 8 | from .utilities import ( 9 | build_engine, 10 | export_onnx, 11 | optimize_onnx, 12 | ) 13 | 14 | 15 | def create_onnx_path(name, onnx_dir, opt=True): 16 | return os.path.join(onnx_dir, name + (".opt" if opt else "") + ".onnx") 17 | 18 | 19 | class EngineBuilder: 20 | def __init__( 21 | self, 22 | model: BaseModel, 23 | network: Any, 24 | device=torch.device("cuda"), 25 | ): 26 | self.device = device 27 | 28 | self.model = model 29 | self.network = network 30 | 31 | def build( 32 | self, 33 | onnx_path: str, 34 | onnx_opt_path: str, 35 | engine_path: str, 36 | opt_image_height: int = 512, 37 | opt_image_width: int = 512, 38 | opt_batch_size: int = 1, 39 | min_image_resolution: int = 256, 40 | max_image_resolution: int = 1024, 41 | build_enable_refit: bool = False, 42 | build_static_batch: bool = False, 43 | build_dynamic_shape: bool = False, 44 | build_all_tactics: bool = False, 45 | onnx_opset: int = 17, 46 | force_engine_build: bool = False, 47 | force_onnx_export: bool = False, 48 | force_onnx_optimize: bool = False, 49 | ): 50 | if not force_onnx_export and os.path.exists(onnx_path): 51 | print(f"Found cached model: {onnx_path}") 52 | else: 53 | print(f"Exporting model: {onnx_path}") 54 | export_onnx( 55 | self.network, 56 | onnx_path=onnx_path, 57 | model_data=self.model, 58 | opt_image_height=opt_image_height, 59 | opt_image_width=opt_image_width, 60 | opt_batch_size=opt_batch_size, 61 | onnx_opset=onnx_opset, 62 | ) 63 | del self.network 64 | gc.collect() 65 | torch.cuda.empty_cache() 66 | if not force_onnx_optimize and os.path.exists(onnx_opt_path): 67 | print(f"Found cached model: {onnx_opt_path}") 68 | else: 69 | print(f"Generating optimizing model: {onnx_opt_path}") 70 | optimize_onnx( 71 | onnx_path=onnx_path, 72 | onnx_opt_path=onnx_opt_path, 73 | model_data=self.model, 74 | ) 75 | self.model.min_latent_shape = min_image_resolution // 8 76 | 
self.model.max_latent_shape = max_image_resolution // 8 77 | if not force_engine_build and os.path.exists(engine_path): 78 | print(f"Found cached engine: {engine_path}") 79 | else: 80 | build_engine( 81 | engine_path=engine_path, 82 | onnx_opt_path=onnx_opt_path, 83 | model_data=self.model, 84 | opt_image_height=opt_image_height, 85 | opt_image_width=opt_image_width, 86 | opt_batch_size=opt_batch_size, 87 | build_static_batch=build_static_batch, 88 | build_dynamic_shape=build_dynamic_shape, 89 | build_all_tactics=build_all_tactics, 90 | build_enable_refit=build_enable_refit, 91 | ) 92 | 93 | gc.collect() 94 | torch.cuda.empty_cache() 95 | -------------------------------------------------------------------------------- /streamdiffusion/acceleration/tensorrt/engine.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | from diffusers.models.autoencoder_tiny import AutoencoderTinyOutput 5 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput 6 | from diffusers.models.vae import DecoderOutput 7 | from polygraphy import cuda 8 | 9 | from .utilities import Engine 10 | 11 | 12 | class UNet2DConditionModelEngine: 13 | def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False): 14 | self.engine = Engine(filepath) 15 | self.stream = stream 16 | self.use_cuda_graph = use_cuda_graph 17 | 18 | self.engine.load() 19 | self.engine.activate() 20 | 21 | def __call__( 22 | self, 23 | latent_model_input: torch.Tensor, 24 | timestep: torch.Tensor, 25 | encoder_hidden_states: torch.Tensor, 26 | **kwargs, 27 | ) -> Any: 28 | if timestep.dtype != torch.float32: 29 | timestep = timestep.float() 30 | 31 | self.engine.allocate_buffers( 32 | shape_dict={ 33 | "sample": latent_model_input.shape, 34 | "timestep": timestep.shape, 35 | "encoder_hidden_states": encoder_hidden_states.shape, 36 | "latent": latent_model_input.shape, 37 | }, 38 | device=latent_model_input.device, 39 | ) 40 | 41 | noise_pred = self.engine.infer( 42 | { 43 | "sample": latent_model_input, 44 | "timestep": timestep, 45 | "encoder_hidden_states": encoder_hidden_states, 46 | }, 47 | self.stream, 48 | use_cuda_graph=self.use_cuda_graph, 49 | )["latent"] 50 | return UNet2DConditionOutput(sample=noise_pred) 51 | 52 | def to(self, *args, **kwargs): 53 | pass 54 | 55 | def forward(self, *args, **kwargs): 56 | pass 57 | 58 | 59 | class AutoencoderKLEngine: 60 | def __init__( 61 | self, 62 | encoder_path: str, 63 | decoder_path: str, 64 | stream: cuda.Stream, 65 | scaling_factor: int, 66 | use_cuda_graph: bool = False, 67 | ): 68 | self.encoder = Engine(encoder_path) 69 | self.decoder = Engine(decoder_path) 70 | self.stream = stream 71 | self.vae_scale_factor = scaling_factor 72 | self.use_cuda_graph = use_cuda_graph 73 | 74 | self.encoder.load() 75 | self.decoder.load() 76 | self.encoder.activate() 77 | self.decoder.activate() 78 | 79 | def encode(self, images: torch.Tensor, **kwargs): 80 | self.encoder.allocate_buffers( 81 | shape_dict={ 82 | "images": images.shape, 83 | "latent": ( 84 | images.shape[0], 85 | 4, 86 | images.shape[2] // self.vae_scale_factor, 87 | images.shape[3] // self.vae_scale_factor, 88 | ), 89 | }, 90 | device=images.device, 91 | ) 92 | latents = self.encoder.infer( 93 | {"images": images}, 94 | self.stream, 95 | use_cuda_graph=self.use_cuda_graph, 96 | )["latent"] 97 | return AutoencoderTinyOutput(latents=latents) 98 | 99 | def decode(self, latent: torch.Tensor, **kwargs): 100 | self.decoder.allocate_buffers( 
101 | shape_dict={ 102 | "latent": latent.shape, 103 | "images": ( 104 | latent.shape[0], 105 | 3, 106 | latent.shape[2] * self.vae_scale_factor, 107 | latent.shape[3] * self.vae_scale_factor, 108 | ), 109 | }, 110 | device=latent.device, 111 | ) 112 | images = self.decoder.infer( 113 | {"latent": latent}, 114 | self.stream, 115 | use_cuda_graph=self.use_cuda_graph, 116 | )["images"] 117 | return DecoderOutput(sample=images) 118 | 119 | def to(self, *args, **kwargs): 120 | pass 121 | 122 | def forward(self, *args, **kwargs): 123 | pass 124 | -------------------------------------------------------------------------------- /streamdiffusion/acceleration/tensorrt/__init__.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | 4 | import torch 5 | from diffusers import AutoencoderKL, UNet2DConditionModel 6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( 7 | retrieve_latents, 8 | ) 9 | from polygraphy import cuda 10 | 11 | from ...pipeline import StreamDiffusion 12 | from .builder import EngineBuilder, create_onnx_path 13 | from .engine import AutoencoderKLEngine, UNet2DConditionModelEngine 14 | from .models import VAE, BaseModel, UNet, VAEEncoder 15 | 16 | 17 | class TorchVAEEncoder(torch.nn.Module): 18 | def __init__(self, vae: AutoencoderKL): 19 | super().__init__() 20 | self.vae = vae 21 | 22 | def forward(self, x: torch.Tensor): 23 | return retrieve_latents(self.vae.encode(x)) 24 | 25 | 26 | def compile_vae_encoder( 27 | vae: TorchVAEEncoder, 28 | model_data: BaseModel, 29 | onnx_path: str, 30 | onnx_opt_path: str, 31 | engine_path: str, 32 | opt_batch_size: int = 1, 33 | engine_build_options: dict = {}, 34 | ): 35 | builder = EngineBuilder(model_data, vae, device=torch.device("cuda")) 36 | builder.build( 37 | onnx_path, 38 | onnx_opt_path, 39 | engine_path, 40 | opt_batch_size=opt_batch_size, 41 | **engine_build_options, 42 | ) 43 | 44 | 45 | def compile_vae_decoder( 46 | vae: AutoencoderKL, 47 | model_data: BaseModel, 48 | onnx_path: str, 49 | onnx_opt_path: str, 50 | engine_path: str, 51 | opt_batch_size: int = 1, 52 | engine_build_options: dict = {}, 53 | ): 54 | vae = vae.to(torch.device("cuda")) 55 | builder = EngineBuilder(model_data, vae, device=torch.device("cuda")) 56 | builder.build( 57 | onnx_path, 58 | onnx_opt_path, 59 | engine_path, 60 | opt_batch_size=opt_batch_size, 61 | **engine_build_options, 62 | ) 63 | 64 | 65 | def compile_unet( 66 | unet: UNet2DConditionModel, 67 | model_data: BaseModel, 68 | onnx_path: str, 69 | onnx_opt_path: str, 70 | engine_path: str, 71 | opt_batch_size: int = 1, 72 | engine_build_options: dict = {}, 73 | ): 74 | unet = unet.to(torch.device("cuda"), dtype=torch.float16) 75 | builder = EngineBuilder(model_data, unet, device=torch.device("cuda")) 76 | builder.build( 77 | onnx_path, 78 | onnx_opt_path, 79 | engine_path, 80 | opt_batch_size=opt_batch_size, 81 | **engine_build_options, 82 | ) 83 | 84 | 85 | def accelerate_with_tensorrt( 86 | stream: StreamDiffusion, 87 | engine_dir: str, 88 | max_batch_size: int = 2, 89 | min_batch_size: int = 1, 90 | use_cuda_graph: bool = False, 91 | engine_build_options: dict = {}, 92 | ): 93 | if "opt_batch_size" not in engine_build_options or engine_build_options["opt_batch_size"] is None: 94 | engine_build_options["opt_batch_size"] = max_batch_size 95 | text_encoder = stream.text_encoder 96 | unet = stream.unet 97 | vae = stream.vae 98 | 99 | del stream.unet, stream.vae, stream.pipe.unet, stream.pipe.vae 100 | 
101 | vae_config = vae.config 102 | vae_dtype = vae.dtype 103 | 104 | unet.to(torch.device("cpu")) 105 | vae.to(torch.device("cpu")) 106 | 107 | gc.collect() 108 | torch.cuda.empty_cache() 109 | 110 | onnx_dir = os.path.join(engine_dir, "onnx") 111 | os.makedirs(onnx_dir, exist_ok=True) 112 | 113 | unet_engine_path = f"{engine_dir}/unet.engine" 114 | vae_encoder_engine_path = f"{engine_dir}/vae_encoder.engine" 115 | vae_decoder_engine_path = f"{engine_dir}/vae_decoder.engine" 116 | 117 | unet_model = UNet( 118 | fp16=True, 119 | device=stream.device, 120 | max_batch_size=max_batch_size, 121 | min_batch_size=min_batch_size, 122 | embedding_dim=text_encoder.config.hidden_size, 123 | unet_dim=unet.config.in_channels, 124 | ) 125 | vae_decoder_model = VAE( 126 | device=stream.device, 127 | max_batch_size=max_batch_size, 128 | min_batch_size=min_batch_size, 129 | ) 130 | vae_encoder_model = VAEEncoder( 131 | device=stream.device, 132 | max_batch_size=max_batch_size, 133 | min_batch_size=min_batch_size, 134 | ) 135 | 136 | if not os.path.exists(unet_engine_path): 137 | compile_unet( 138 | unet, 139 | unet_model, 140 | create_onnx_path("unet", onnx_dir, opt=False), 141 | create_onnx_path("unet", onnx_dir, opt=True), 142 | opt_batch_size=max_batch_size, 143 | **engine_build_options, 144 | ) 145 | else: 146 | del unet 147 | 148 | if not os.path.exists(vae_decoder_engine_path): 149 | compile_vae_decoder( 150 | vae, 151 | vae_decoder_model, 152 | create_onnx_path("vae_decoder", onnx_dir, opt=False), 153 | create_onnx_path("vae_decoder", onnx_dir, opt=True), 154 | opt_batch_size=max_batch_size, 155 | **engine_build_options, 156 | ) 157 | 158 | if not os.path.exists(vae_encoder_engine_path): 159 | vae_encoder = TorchVAEEncoder(vae).to(torch.device("cuda")) 160 | compile_vae_encoder( 161 | vae_encoder, 162 | vae_encoder_model, 163 | create_onnx_path("vae_encoder", onnx_dir, opt=False), 164 | create_onnx_path("vae_encoder", onnx_dir, opt=True), 165 | opt_batch_size=max_batch_size, 166 | **engine_build_options, 167 | ) 168 | 169 | del vae 170 | 171 | cuda_steram = cuda.Stream() 172 | 173 | stream.unet = UNet2DConditionModelEngine(unet_engine_path, cuda_steram, use_cuda_graph=use_cuda_graph) 174 | stream.vae = AutoencoderKLEngine( 175 | vae_encoder_engine_path, 176 | vae_decoder_engine_path, 177 | cuda_steram, 178 | stream.pipe.vae_scale_factor, 179 | use_cuda_graph=use_cuda_graph, 180 | ) 181 | setattr(stream.vae, "config", vae_config) 182 | setattr(stream.vae, "dtype", vae_dtype) 183 | 184 | gc.collect() 185 | torch.cuda.empty_cache() 186 | 187 | return stream 188 | -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from .streamdiffusion.pipeline import StreamDiffusion 4 | from .streamdiffusion.wrapper import StreamDiffusionWrapper 5 | from diffusers import AutoencoderTiny, StableDiffusionPipeline 6 | from pathlib import Path 7 | import traceback 8 | from typing import List, Literal, Optional, Union, Dict 9 | import torch 10 | import gc 11 | # Get the absolute path of various directories 12 | my_dir = os.path.dirname(os.path.abspath(__file__)) 13 | custom_nodes_dir = os.path.abspath(os.path.join(my_dir, '..')) 14 | comfy_dir = os.path.abspath(os.path.join(my_dir, '..', '..')) 15 | 16 | # Construct the path to the font file 17 | font_path = os.path.join(my_dir, 'arial.ttf') 18 | 19 | # Append comfy_dir to sys.path & import files 20 | 
sys.path.append(comfy_dir) 21 | import folder_paths 22 | 23 | 24 | import comfy.sample 25 | import comfy.samplers 26 | import comfy.sd 27 | import comfy.utils 28 | import comfy.latent_formats 29 | import comfy.model_management 30 | 31 | # Append my_dir to sys.path & import files 32 | sys.path.append(my_dir) 33 | 34 | from PIL import Image, ImageOps, ImageSequence 35 | from PIL.PngImagePlugin import PngInfo 36 | import numpy as np 37 | 38 | MAX_RESOLUTION=8192 39 | 40 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 41 | 42 | def get_async_loop(): 43 | loop = None 44 | try: 45 | loop = asyncio.get_event_loop() 46 | except: 47 | loop = asyncio.new_event_loop() 48 | asyncio.set_event_loop(loop) 49 | return loop 50 | 51 | 52 | class StreamDiffusion_Loader: 53 | @classmethod 54 | def INPUT_TYPES(s): 55 | loras = ["None"] + folder_paths.get_filename_list("loras") 56 | return { 57 | "required": { 58 | "ckpt_name": (["Baked ckpt"]+folder_paths.get_filename_list("checkpoints"), ), 59 | "vae_name": (["Baked VAE"] + folder_paths.get_filename_list("vae"),), 60 | "lcm_lora": (loras,), 61 | "acceleration": (["none", "xfomers", "sfast", "tensorrt"],), 62 | 63 | "use_tiny_vae": ("BOOLEAN", { "default": True }), 64 | "use_lcm_lora": ("BOOLEAN", { "default": True }), 65 | }, 66 | "optional": { 67 | "lora_stack": ("LORA_STACK", ), 68 | }, 69 | } 70 | RETURN_TYPES = ("MODEL",) 71 | RETURN_NAMES = ("MODEL",) 72 | 73 | FUNCTION = "efficientloader" 74 | CATEGORY = "StreamDiffusion/Loader" 75 | 76 | def efficientloader(self,ckpt_name,vae_name,lcm_lora,acceleration,use_tiny_vae,use_lcm_lora,lora_stack=None): 77 | 78 | device = comfy.model_management.get_torch_device() 79 | device_name = comfy.model_management.get_torch_device_name(device) 80 | vae_dtype=comfy.model_management.vae_dtype() 81 | 82 | if ckpt_name =='Baked ckpt': 83 | ckpt_path="KBlueLeaf/kohaku-v2.1" 84 | else: 85 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 86 | 87 | if vae_name=='Baked VAE': 88 | vae_id = None 89 | else: 90 | vae_id=vae_name 91 | 92 | if lcm_lora=='None': 93 | lcm_lora_id=None 94 | else: 95 | lcm_lora_id= folder_paths.get_full_path("loras", lcm_lora) 96 | 97 | lora_dict =None 98 | if lora_stack is not None: 99 | lora_dict={} 100 | for lora_name, lora_scale, strength_clip in lora_stack: 101 | full_lora_name=folder_paths.get_full_path("loras", lora_name) 102 | lora_dict[full_lora_name]=lora_scale 103 | 104 | t_index_list=[32,40,45] 105 | 106 | stream = StreamDiffusionWrapper( 107 | model_id_or_path=ckpt_path, 108 | lora_dict=lora_dict, 109 | t_index_list=t_index_list, 110 | frame_buffer_size=1, 111 | width=512, 112 | height=512, 113 | warmup=10, 114 | acceleration=acceleration, 115 | use_tiny_vae =use_tiny_vae, 116 | device=device, 117 | use_lcm_lora = use_lcm_lora, 118 | output_type = 'pt', 119 | dtype = torch.float16, 120 | lcm_lora_id=lcm_lora_id, 121 | vae_id =vae_id, 122 | ) 123 | 124 | return (stream,) 125 | 126 | 127 | class StreamDiffusion_Sampler: 128 | @classmethod 129 | def INPUT_TYPES(s): 130 | return { 131 | "required":{ 132 | "model": ("MODEL",), 133 | "positive": ("STRING", {"default": "CLIP_POSITIVE","multiline": True}), 134 | "negative": ("STRING", {"default": "CLIP_NEGATIVE", "multiline": True}), 135 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 136 | "steps": ("INT", {"default": 50, "min": 1, "max": 10000}), 137 | "cfg": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 100.0}), 138 | "delta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), 139 | 140 | 
"width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), 141 | "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), 142 | "batch_size": ("INT", {"default": 1, "min": 1, "max": 10000, "step": 1}),#The frame buffer size for denoising batch, by default 1. 143 | 144 | "index_list": ("STRING", {"default": "32,40,45","multiline": False}), 145 | 146 | "cfg_type": (["none", "full", "self", "initialize"],), 147 | 148 | "add_noise": ("BOOLEAN", { "default": True }), 149 | "use_denoising_batch": ("BOOLEAN", { "default": True }), 150 | "enable_similar_image_filter": ("BOOLEAN", { "default": False }), 151 | "use_safety_checker": ("BOOLEAN", { "default": False }), 152 | 153 | }, 154 | "optional": { 155 | "similar_image_filter_threshold": ("FLOAT", {"default": 0.98, "min": 0.0, "max": 100.0,"step": 0.01}), 156 | "similar_image_filter_max_skip_frame": ("INT", {"default": 10, "min": 0, "max": 100}), 157 | "latent": ("LATENT",), 158 | "image": ("IMAGE",), 159 | "lora_stack": ("LORA_STACK", ), 160 | } 161 | } 162 | 163 | RETURN_TYPES = ("IMAGE",) 164 | RETURN_NAMES = ("IMAGE",) 165 | 166 | FUNCTION = "sample" 167 | CATEGORY = "StreamDiffusion/Sampler" 168 | 169 | @torch.no_grad() 170 | def sample(self,model,positive,negative,seed,steps,cfg,delta,width,height,batch_size,index_list,cfg_type,add_noise,use_denoising_batch,enable_similar_image_filter=False,use_safety_checker=False,similar_image_filter_threshold= 0.98,similar_image_filter_max_skip_frame=10,latent=None,image=None,lora_stack=None): 171 | device = comfy.model_management.get_torch_device() 172 | device_name = comfy.model_management.get_torch_device_name(device) 173 | vae_dtype=comfy.model_management.vae_dtype() 174 | 175 | # latent_image = latent["samples"] 176 | # latent_image = latent_image.to(model.device) 177 | # batch_size,channel,latent_height,latent_width,=latent_image.shape 178 | # width=latent_width*8 179 | # height =latent_height*8 180 | 181 | t_index_list=[32,40,45] 182 | t_index_list =[int(i) for i in index_list.split(',')] 183 | 184 | if image ==None: 185 | mode = "txt2img" 186 | else: 187 | mode = "img2img" 188 | image = image.movedim(-1,1) 189 | 190 | if cfg <= 1.0: 191 | cfg_type = "none" 192 | 193 | if batch_size>1 and mode=="txt2img": 194 | use_denoising_batch=False 195 | 196 | 197 | # stream.set_sampler_param(t_index_list=t_index_list, 198 | # width=width, 199 | # height=height, 200 | # do_add_noise= add_noise=='enable', 201 | # frame_buffer_size=frame_buffer_size, 202 | # use_denoising_batch=use_denoising_batch, 203 | # cfg_type=cfg_type,) 204 | 205 | model.prepare( 206 | positive, 207 | negative, 208 | steps, 209 | cfg, 210 | delta, 211 | t_index_list, 212 | add_noise, 213 | enable_similar_image_filter, 214 | similar_image_filter_threshold, 215 | similar_image_filter_max_skip_frame, 216 | use_denoising_batch, 217 | cfg_type, 218 | seed, 219 | batch_size, 220 | use_safety_checker, 221 | ) 222 | # latent_image = self.predict_x0_batch( 223 | # torch.randn((stream.batch_size, 4, stream.latent_height, stream.latent_width)).to( 224 | # device=stream.device, dtype=stream.dtype 225 | # ) 226 | # ) 227 | 228 | # if batch_size==1: 229 | # for _ in range(stream.batch_size - 1): 230 | # stream() 231 | 232 | output = model.sample(image).permute(0, 2, 3, 1) 233 | 234 | return (output,) 235 | 236 | 237 | NODE_CLASS_MAPPINGS = { 238 | "StreamDiffusion_Loader": StreamDiffusion_Loader, 239 | "StreamDiffusion_Sampler":StreamDiffusion_Sampler, 240 | 241 | } 242 | NODE_DISPLAY_NAME_MAPPINGS = { 243 | 
"StreamDiffusion_Loader": "StreamDiffusion_Loader", 244 | "StreamDiffusion_Sampler":"StreamDiffusion_Sampler", 245 | 246 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /streamdiffusion/acceleration/tensorrt/utilities.py: -------------------------------------------------------------------------------- 1 | #! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py 2 | 3 | # 4 | # Copyright 2022 The HuggingFace Inc. team. 5 | # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 6 | # SPDX-License-Identifier: Apache-2.0 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | # 20 | 21 | import gc 22 | from collections import OrderedDict 23 | from typing import * 24 | 25 | import numpy as np 26 | import onnx 27 | import onnx_graphsurgeon as gs 28 | import tensorrt as trt 29 | import torch 30 | from cuda import cudart 31 | from PIL import Image 32 | from polygraphy import cuda 33 | from polygraphy.backend.common import bytes_from_path 34 | from polygraphy.backend.trt import ( 35 | CreateConfig, 36 | Profile, 37 | engine_from_bytes, 38 | engine_from_network, 39 | network_from_onnx_path, 40 | save_engine, 41 | ) 42 | from polygraphy.backend.trt import util as trt_util 43 | 44 | from .models import CLIP, VAE, BaseModel, UNet, VAEEncoder 45 | 46 | 47 | TRT_LOGGER = trt.Logger(trt.Logger.ERROR) 48 | 49 | # Map of numpy dtype -> torch dtype 50 | numpy_to_torch_dtype_dict = { 51 | np.uint8: torch.uint8, 52 | np.int8: torch.int8, 53 | np.int16: torch.int16, 54 | np.int32: torch.int32, 55 | np.int64: torch.int64, 56 | np.float16: torch.float16, 57 | np.float32: torch.float32, 58 | np.float64: torch.float64, 59 | np.complex64: torch.complex64, 60 | np.complex128: torch.complex128, 61 | } 62 | if np.version.full_version >= "1.24.0": 63 | numpy_to_torch_dtype_dict[np.bool_] = torch.bool 64 | else: 65 | numpy_to_torch_dtype_dict[np.bool] = torch.bool 66 | 67 | # Map of torch dtype -> numpy dtype 68 | torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} 69 | 70 | 71 | def CUASSERT(cuda_ret): 72 | err = cuda_ret[0] 73 | if err != cudart.cudaError_t.cudaSuccess: 74 | raise RuntimeError( 75 | f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" 76 | ) 77 | if len(cuda_ret) > 1: 78 | return cuda_ret[1] 79 | return None 80 | 81 | 82 | class Engine: 83 | def __init__( 84 | self, 85 | engine_path, 86 | ): 87 | self.engine_path = engine_path 88 | self.engine = None 89 | self.context = None 90 | self.buffers = OrderedDict() 91 | self.tensors = OrderedDict() 92 | self.cuda_graph_instance = None # cuda graph 93 | 94 | def __del__(self): 95 | [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)] 96 | del self.engine 97 | del self.context 98 | del self.buffers 99 | del self.tensors 100 | 101 | def refit(self, onnx_path, onnx_refit_path): 102 | def convert_int64(arr): 103 | # TODO: smarter conversion 104 | if len(arr.shape) == 0: 105 | return np.int32(arr) 106 | return arr 107 | 108 | def add_to_map(refit_dict, name, values): 109 | if name in refit_dict: 110 | assert refit_dict[name] is None 111 | if values.dtype == np.int64: 112 | values = convert_int64(values) 113 | refit_dict[name] = values 114 | 115 | print(f"Refitting TensorRT engine with {onnx_refit_path} weights") 116 | refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes 117 | 118 | # Construct mapping from weight names in refit model -> original model 119 | name_map = {} 120 | for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): 121 | refit_node = refit_nodes[n] 122 | assert node.op == refit_node.op 123 | # Constant nodes in ONNX do not have inputs but have a constant output 124 | if node.op == "Constant": 125 | name_map[refit_node.outputs[0].name] = node.outputs[0].name 126 | # Handle scale and bias weights 127 | elif node.op == "Conv": 128 | if node.inputs[1].__class__ == gs.Constant: 129 | name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" 130 | if node.inputs[2].__class__ == gs.Constant: 131 | name_map[refit_node.name 
+ "_TRTBIAS"] = node.name + "_TRTBIAS" 132 | # For all other nodes: find node inputs that are initializers (gs.Constant) 133 | else: 134 | for i, inp in enumerate(node.inputs): 135 | if inp.__class__ == gs.Constant: 136 | name_map[refit_node.inputs[i].name] = inp.name 137 | 138 | def map_name(name): 139 | if name in name_map: 140 | return name_map[name] 141 | return name 142 | 143 | # Construct refit dictionary 144 | refit_dict = {} 145 | refitter = trt.Refitter(self.engine, TRT_LOGGER) 146 | all_weights = refitter.get_all() 147 | for layer_name, role in zip(all_weights[0], all_weights[1]): 148 | # for speciailized roles, use a unique name in the map: 149 | if role == trt.WeightsRole.KERNEL: 150 | name = layer_name + "_TRTKERNEL" 151 | elif role == trt.WeightsRole.BIAS: 152 | name = layer_name + "_TRTBIAS" 153 | else: 154 | name = layer_name 155 | 156 | assert name not in refit_dict, "Found duplicate layer: " + name 157 | refit_dict[name] = None 158 | 159 | for n in refit_nodes: 160 | # Constant nodes in ONNX do not have inputs but have a constant output 161 | if n.op == "Constant": 162 | name = map_name(n.outputs[0].name) 163 | print(f"Add Constant {name}\n") 164 | add_to_map(refit_dict, name, n.outputs[0].values) 165 | 166 | # Handle scale and bias weights 167 | elif n.op == "Conv": 168 | if n.inputs[1].__class__ == gs.Constant: 169 | name = map_name(n.name + "_TRTKERNEL") 170 | add_to_map(refit_dict, name, n.inputs[1].values) 171 | 172 | if n.inputs[2].__class__ == gs.Constant: 173 | name = map_name(n.name + "_TRTBIAS") 174 | add_to_map(refit_dict, name, n.inputs[2].values) 175 | 176 | # For all other nodes: find node inputs that are initializers (AKA gs.Constant) 177 | else: 178 | for inp in n.inputs: 179 | name = map_name(inp.name) 180 | if inp.__class__ == gs.Constant: 181 | add_to_map(refit_dict, name, inp.values) 182 | 183 | for layer_name, weights_role in zip(all_weights[0], all_weights[1]): 184 | if weights_role == trt.WeightsRole.KERNEL: 185 | custom_name = layer_name + "_TRTKERNEL" 186 | elif weights_role == trt.WeightsRole.BIAS: 187 | custom_name = layer_name + "_TRTBIAS" 188 | else: 189 | custom_name = layer_name 190 | 191 | # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model 192 | if layer_name.startswith("onnx::Trilu"): 193 | continue 194 | 195 | if refit_dict[custom_name] is not None: 196 | refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) 197 | else: 198 | print(f"[W] No refit weights for layer: {layer_name}") 199 | 200 | if not refitter.refit_cuda_engine(): 201 | print("Failed to refit!") 202 | exit(0) 203 | 204 | def build( 205 | self, 206 | onnx_path, 207 | fp16, 208 | input_profile=None, 209 | enable_refit=False, 210 | enable_all_tactics=False, 211 | timing_cache=None, 212 | workspace_size=0, 213 | ): 214 | print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") 215 | p = Profile() 216 | if input_profile: 217 | for name, dims in input_profile.items(): 218 | assert len(dims) == 3 219 | p.add(name, min=dims[0], opt=dims[1], max=dims[2]) 220 | 221 | config_kwargs = {} 222 | 223 | if workspace_size > 0: 224 | config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size} 225 | if not enable_all_tactics: 226 | config_kwargs["tactic_sources"] = [] 227 | 228 | engine = engine_from_network( 229 | network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), 230 | config=CreateConfig( 231 | fp16=fp16, refittable=enable_refit, profiles=[p], 
load_timing_cache=timing_cache, **config_kwargs 232 | ), 233 | save_timing_cache=timing_cache, 234 | ) 235 | save_engine(engine, path=self.engine_path) 236 | 237 | def load(self): 238 | print(f"Loading TensorRT engine: {self.engine_path}") 239 | self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) 240 | 241 | def activate(self, reuse_device_memory=None): 242 | if reuse_device_memory: 243 | self.context = self.engine.create_execution_context_without_device_memory() 244 | self.context.device_memory = reuse_device_memory 245 | else: 246 | self.context = self.engine.create_execution_context() 247 | 248 | def allocate_buffers(self, shape_dict=None, device="cuda"): 249 | for idx in range(trt_util.get_bindings_per_profile(self.engine)): 250 | binding = self.engine[idx] 251 | if shape_dict and binding in shape_dict: 252 | shape = shape_dict[binding] 253 | else: 254 | shape = self.engine.get_binding_shape(binding) 255 | dtype = trt.nptype(self.engine.get_binding_dtype(binding)) 256 | if self.engine.binding_is_input(binding): 257 | self.context.set_binding_shape(idx, shape) 258 | tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) 259 | self.tensors[binding] = tensor 260 | 261 | def infer(self, feed_dict, stream, use_cuda_graph=False): 262 | for name, buf in feed_dict.items(): 263 | self.tensors[name].copy_(buf) 264 | 265 | for name, tensor in self.tensors.items(): 266 | self.context.set_tensor_address(name, tensor.data_ptr()) 267 | 268 | if use_cuda_graph: 269 | if self.cuda_graph_instance is not None: 270 | CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr)) 271 | CUASSERT(cudart.cudaStreamSynchronize(stream.ptr)) 272 | else: 273 | # do inference before CUDA graph capture 274 | noerror = self.context.execute_async_v3(stream.ptr) 275 | if not noerror: 276 | raise ValueError("ERROR: inference failed.") 277 | # capture cuda graph 278 | CUASSERT( 279 | cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal) 280 | ) 281 | self.context.execute_async_v3(stream.ptr) 282 | self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr)) 283 | self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0)) 284 | else: 285 | noerror = self.context.execute_async_v3(stream.ptr) 286 | if not noerror: 287 | raise ValueError("ERROR: inference failed.") 288 | 289 | return self.tensors 290 | 291 | 292 | def decode_images(images: torch.Tensor): 293 | images = ( 294 | ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() 295 | ) 296 | return [Image.fromarray(x) for x in images] 297 | 298 | 299 | def preprocess_image(image: Image.Image): 300 | w, h = image.size 301 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 302 | image = image.resize((w, h)) 303 | init_image = np.array(image).astype(np.float32) / 255.0 304 | init_image = init_image[None].transpose(0, 3, 1, 2) 305 | init_image = torch.from_numpy(init_image).contiguous() 306 | return 2.0 * init_image - 1.0 307 | 308 | 309 | def prepare_mask_and_masked_image(image: Image.Image, mask: Image.Image): 310 | if isinstance(image, Image.Image): 311 | image = np.array(image.convert("RGB")) 312 | image = image[None].transpose(0, 3, 1, 2) 313 | image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0 314 | if isinstance(mask, Image.Image): 315 | mask = np.array(mask.convert("L")) 316 | mask = mask.astype(np.float32) / 255.0 317 | mask = 
mask[None, None] 318 | mask[mask < 0.5] = 0 319 | mask[mask >= 0.5] = 1 320 | mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous() 321 | 322 | masked_image = image * (mask < 0.5) 323 | 324 | return mask, masked_image 325 | 326 | 327 | def create_models( 328 | model_id: str, 329 | use_auth_token: Optional[str], 330 | device: Union[str, torch.device], 331 | max_batch_size: int, 332 | unet_in_channels: int = 4, 333 | embedding_dim: int = 768, 334 | ): 335 | models = { 336 | "clip": CLIP( 337 | hf_token=use_auth_token, 338 | device=device, 339 | max_batch_size=max_batch_size, 340 | embedding_dim=embedding_dim, 341 | ), 342 | "unet": UNet( 343 | hf_token=use_auth_token, 344 | fp16=True, 345 | device=device, 346 | max_batch_size=max_batch_size, 347 | embedding_dim=embedding_dim, 348 | unet_dim=unet_in_channels, 349 | ), 350 | "vae": VAE( 351 | hf_token=use_auth_token, 352 | device=device, 353 | max_batch_size=max_batch_size, 354 | embedding_dim=embedding_dim, 355 | ), 356 | "vae_encoder": VAEEncoder( 357 | hf_token=use_auth_token, 358 | device=device, 359 | max_batch_size=max_batch_size, 360 | embedding_dim=embedding_dim, 361 | ), 362 | } 363 | return models 364 | 365 | 366 | def build_engine( 367 | engine_path: str, 368 | onnx_opt_path: str, 369 | model_data: BaseModel, 370 | opt_image_height: int, 371 | opt_image_width: int, 372 | opt_batch_size: int, 373 | build_static_batch: bool = False, 374 | build_dynamic_shape: bool = False, 375 | build_all_tactics: bool = False, 376 | build_enable_refit: bool = False, 377 | ): 378 | _, free_mem, _ = cudart.cudaMemGetInfo() 379 | GiB = 2**30 380 | if free_mem > 6 * GiB: 381 | activation_carveout = 4 * GiB 382 | max_workspace_size = free_mem - activation_carveout 383 | else: 384 | max_workspace_size = 0 385 | engine = Engine(engine_path) 386 | input_profile = model_data.get_input_profile( 387 | opt_batch_size, 388 | opt_image_height, 389 | opt_image_width, 390 | static_batch=build_static_batch, 391 | static_shape=not build_dynamic_shape, 392 | ) 393 | engine.build( 394 | onnx_opt_path, 395 | fp16=True, 396 | input_profile=input_profile, 397 | enable_refit=build_enable_refit, 398 | enable_all_tactics=build_all_tactics, 399 | workspace_size=max_workspace_size, 400 | ) 401 | 402 | return engine 403 | 404 | 405 | def export_onnx( 406 | model, 407 | onnx_path: str, 408 | model_data: BaseModel, 409 | opt_image_height: int, 410 | opt_image_width: int, 411 | opt_batch_size: int, 412 | onnx_opset: int, 413 | ): 414 | with torch.inference_mode(), torch.autocast("cuda"): 415 | inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) 416 | torch.onnx.export( 417 | model, 418 | inputs, 419 | onnx_path, 420 | export_params=True, 421 | opset_version=onnx_opset, 422 | do_constant_folding=True, 423 | input_names=model_data.get_input_names(), 424 | output_names=model_data.get_output_names(), 425 | dynamic_axes=model_data.get_dynamic_axes(), 426 | ) 427 | del model 428 | gc.collect() 429 | torch.cuda.empty_cache() 430 | 431 | 432 | def optimize_onnx( 433 | onnx_path: str, 434 | onnx_opt_path: str, 435 | model_data: BaseModel, 436 | ): 437 | onnx_opt_graph = model_data.optimize(onnx.load(onnx_path)) 438 | onnx.save(onnx_opt_graph, onnx_opt_path) 439 | del onnx_opt_graph 440 | gc.collect() 441 | torch.cuda.empty_cache() 442 | -------------------------------------------------------------------------------- /streamdiffusion/acceleration/tensorrt/models.py: 
-------------------------------------------------------------------------------- 1 | #! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/models.py 2 | 3 | # 4 | # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 5 | # SPDX-License-Identifier: Apache-2.0 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | import onnx_graphsurgeon as gs 21 | import torch 22 | from onnx import shape_inference 23 | from polygraphy.backend.onnx.loader import fold_constants 24 | 25 | 26 | class Optimizer: 27 | def __init__(self, onnx_graph, verbose=False): 28 | self.graph = gs.import_onnx(onnx_graph) 29 | self.verbose = verbose 30 | 31 | def info(self, prefix): 32 | if self.verbose: 33 | print( 34 | f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs" 35 | ) 36 | 37 | def cleanup(self, return_onnx=False): 38 | self.graph.cleanup().toposort() 39 | if return_onnx: 40 | return gs.export_onnx(self.graph) 41 | 42 | def select_outputs(self, keep, names=None): 43 | self.graph.outputs = [self.graph.outputs[o] for o in keep] 44 | if names: 45 | for i, name in enumerate(names): 46 | self.graph.outputs[i].name = name 47 | 48 | def fold_constants(self, return_onnx=False): 49 | onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) 50 | self.graph = gs.import_onnx(onnx_graph) 51 | if return_onnx: 52 | return onnx_graph 53 | 54 | def infer_shapes(self, return_onnx=False): 55 | onnx_graph = gs.export_onnx(self.graph) 56 | if onnx_graph.ByteSize() > 2147483648: 57 | raise TypeError("ERROR: model size exceeds supported 2GB limit") 58 | else: 59 | onnx_graph = shape_inference.infer_shapes(onnx_graph) 60 | 61 | self.graph = gs.import_onnx(onnx_graph) 62 | if return_onnx: 63 | return onnx_graph 64 | 65 | 66 | class BaseModel: 67 | def __init__( 68 | self, 69 | fp16=False, 70 | device="cuda", 71 | verbose=True, 72 | max_batch_size=16, 73 | min_batch_size=1, 74 | embedding_dim=768, 75 | text_maxlen=77, 76 | ): 77 | self.name = "SD Model" 78 | self.fp16 = fp16 79 | self.device = device 80 | self.verbose = verbose 81 | 82 | self.min_batch = min_batch_size 83 | self.max_batch = max_batch_size 84 | self.min_image_shape = 256 # min image resolution: 256x256 85 | self.max_image_shape = 1024 # max image resolution: 1024x1024 86 | self.min_latent_shape = self.min_image_shape // 8 87 | self.max_latent_shape = self.max_image_shape // 8 88 | 89 | self.embedding_dim = embedding_dim 90 | self.text_maxlen = text_maxlen 91 | 92 | def get_model(self): 93 | pass 94 | 95 | def get_input_names(self): 96 | pass 97 | 98 | def get_output_names(self): 99 | pass 100 | 101 | def get_dynamic_axes(self): 102 | return None 103 | 104 | def get_sample_input(self, batch_size, image_height, image_width): 105 | pass 106 | 107 | def get_input_profile(self, batch_size, image_height, image_width, 
static_batch, static_shape): 108 | return None 109 | 110 | def get_shape_dict(self, batch_size, image_height, image_width): 111 | return None 112 | 113 | def optimize(self, onnx_graph): 114 | opt = Optimizer(onnx_graph, verbose=self.verbose) 115 | opt.info(self.name + ": original") 116 | opt.cleanup() 117 | opt.info(self.name + ": cleanup") 118 | opt.fold_constants() 119 | opt.info(self.name + ": fold constants") 120 | opt.infer_shapes() 121 | opt.info(self.name + ": shape inference") 122 | onnx_opt_graph = opt.cleanup(return_onnx=True) 123 | opt.info(self.name + ": finished") 124 | return onnx_opt_graph 125 | 126 | def check_dims(self, batch_size, image_height, image_width): 127 | assert batch_size >= self.min_batch and batch_size <= self.max_batch 128 | assert image_height % 8 == 0 or image_width % 8 == 0 129 | latent_height = image_height // 8 130 | latent_width = image_width // 8 131 | assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape 132 | assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape 133 | return (latent_height, latent_width) 134 | 135 | def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape): 136 | min_batch = batch_size if static_batch else self.min_batch 137 | max_batch = batch_size if static_batch else self.max_batch 138 | latent_height = image_height // 8 139 | latent_width = image_width // 8 140 | min_image_height = image_height if static_shape else self.min_image_shape 141 | max_image_height = image_height if static_shape else self.max_image_shape 142 | min_image_width = image_width if static_shape else self.min_image_shape 143 | max_image_width = image_width if static_shape else self.max_image_shape 144 | min_latent_height = latent_height if static_shape else self.min_latent_shape 145 | max_latent_height = latent_height if static_shape else self.max_latent_shape 146 | min_latent_width = latent_width if static_shape else self.min_latent_shape 147 | max_latent_width = latent_width if static_shape else self.max_latent_shape 148 | return ( 149 | min_batch, 150 | max_batch, 151 | min_image_height, 152 | max_image_height, 153 | min_image_width, 154 | max_image_width, 155 | min_latent_height, 156 | max_latent_height, 157 | min_latent_width, 158 | max_latent_width, 159 | ) 160 | 161 | 162 | class CLIP(BaseModel): 163 | def __init__(self, device, max_batch_size, embedding_dim, min_batch_size=1): 164 | super(CLIP, self).__init__( 165 | device=device, 166 | max_batch_size=max_batch_size, 167 | min_batch_size=min_batch_size, 168 | embedding_dim=embedding_dim, 169 | ) 170 | self.name = "CLIP" 171 | 172 | def get_input_names(self): 173 | return ["input_ids"] 174 | 175 | def get_output_names(self): 176 | return ["text_embeddings", "pooler_output"] 177 | 178 | def get_dynamic_axes(self): 179 | return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} 180 | 181 | def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): 182 | self.check_dims(batch_size, image_height, image_width) 183 | min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( 184 | batch_size, image_height, image_width, static_batch, static_shape 185 | ) 186 | return { 187 | "input_ids": [ 188 | (min_batch, self.text_maxlen), 189 | (batch_size, self.text_maxlen), 190 | (max_batch, self.text_maxlen), 191 | ] 192 | } 193 | 194 | def get_shape_dict(self, batch_size, image_height, image_width): 195 | self.check_dims(batch_size, image_height, image_width) 196 | 
return { 197 | "input_ids": (batch_size, self.text_maxlen), 198 | "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), 199 | } 200 | 201 | def get_sample_input(self, batch_size, image_height, image_width): 202 | self.check_dims(batch_size, image_height, image_width) 203 | return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) 204 | 205 | def optimize(self, onnx_graph): 206 | opt = Optimizer(onnx_graph) 207 | opt.info(self.name + ": original") 208 | opt.select_outputs([0]) # delete graph output#1 209 | opt.cleanup() 210 | opt.info(self.name + ": remove output[1]") 211 | opt.fold_constants() 212 | opt.info(self.name + ": fold constants") 213 | opt.infer_shapes() 214 | opt.info(self.name + ": shape inference") 215 | opt.select_outputs([0], names=["text_embeddings"]) # rename network output 216 | opt.info(self.name + ": remove output[0]") 217 | opt_onnx_graph = opt.cleanup(return_onnx=True) 218 | opt.info(self.name + ": finished") 219 | return opt_onnx_graph 220 | 221 | 222 | class UNet(BaseModel): 223 | def __init__( 224 | self, 225 | fp16=False, 226 | device="cuda", 227 | max_batch_size=16, 228 | min_batch_size=1, 229 | embedding_dim=768, 230 | text_maxlen=77, 231 | unet_dim=4, 232 | ): 233 | super(UNet, self).__init__( 234 | fp16=fp16, 235 | device=device, 236 | max_batch_size=max_batch_size, 237 | min_batch_size=min_batch_size, 238 | embedding_dim=embedding_dim, 239 | text_maxlen=text_maxlen, 240 | ) 241 | self.unet_dim = unet_dim 242 | self.name = "UNet" 243 | 244 | def get_input_names(self): 245 | return ["sample", "timestep", "encoder_hidden_states"] 246 | 247 | def get_output_names(self): 248 | return ["latent"] 249 | 250 | def get_dynamic_axes(self): 251 | return { 252 | "sample": {0: "2B", 2: "H", 3: "W"}, 253 | "timestep": {0: "2B"}, 254 | "encoder_hidden_states": {0: "2B"}, 255 | "latent": {0: "2B", 2: "H", 3: "W"}, 256 | } 257 | 258 | def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): 259 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 260 | ( 261 | min_batch, 262 | max_batch, 263 | _, 264 | _, 265 | _, 266 | _, 267 | min_latent_height, 268 | max_latent_height, 269 | min_latent_width, 270 | max_latent_width, 271 | ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) 272 | return { 273 | "sample": [ 274 | (min_batch, self.unet_dim, min_latent_height, min_latent_width), 275 | (batch_size, self.unet_dim, latent_height, latent_width), 276 | (max_batch, self.unet_dim, max_latent_height, max_latent_width), 277 | ], 278 | "timestep": [(min_batch,), (batch_size,), (max_batch,)], 279 | "encoder_hidden_states": [ 280 | (min_batch, self.text_maxlen, self.embedding_dim), 281 | (batch_size, self.text_maxlen, self.embedding_dim), 282 | (max_batch, self.text_maxlen, self.embedding_dim), 283 | ], 284 | } 285 | 286 | def get_shape_dict(self, batch_size, image_height, image_width): 287 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 288 | return { 289 | "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), 290 | "timestep": (2 * batch_size,), 291 | "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), 292 | "latent": (2 * batch_size, 4, latent_height, latent_width), 293 | } 294 | 295 | def get_sample_input(self, batch_size, image_height, image_width): 296 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 297 
| dtype = torch.float16 if self.fp16 else torch.float32 298 | return ( 299 | torch.randn( 300 | 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device 301 | ), 302 | torch.ones((2 * batch_size,), dtype=torch.float32, device=self.device), 303 | torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), 304 | ) 305 | 306 | 307 | class VAE(BaseModel): 308 | def __init__(self, device, max_batch_size, min_batch_size=1): 309 | super(VAE, self).__init__( 310 | device=device, 311 | max_batch_size=max_batch_size, 312 | min_batch_size=min_batch_size, 313 | embedding_dim=None, 314 | ) 315 | self.name = "VAE decoder" 316 | 317 | def get_input_names(self): 318 | return ["latent"] 319 | 320 | def get_output_names(self): 321 | return ["images"] 322 | 323 | def get_dynamic_axes(self): 324 | return { 325 | "latent": {0: "B", 2: "H", 3: "W"}, 326 | "images": {0: "B", 2: "8H", 3: "8W"}, 327 | } 328 | 329 | def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): 330 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 331 | ( 332 | min_batch, 333 | max_batch, 334 | _, 335 | _, 336 | _, 337 | _, 338 | min_latent_height, 339 | max_latent_height, 340 | min_latent_width, 341 | max_latent_width, 342 | ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) 343 | return { 344 | "latent": [ 345 | (min_batch, 4, min_latent_height, min_latent_width), 346 | (batch_size, 4, latent_height, latent_width), 347 | (max_batch, 4, max_latent_height, max_latent_width), 348 | ] 349 | } 350 | 351 | def get_shape_dict(self, batch_size, image_height, image_width): 352 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 353 | return { 354 | "latent": (batch_size, 4, latent_height, latent_width), 355 | "images": (batch_size, 3, image_height, image_width), 356 | } 357 | 358 | def get_sample_input(self, batch_size, image_height, image_width): 359 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 360 | return torch.randn( 361 | batch_size, 362 | 4, 363 | latent_height, 364 | latent_width, 365 | dtype=torch.float32, 366 | device=self.device, 367 | ) 368 | 369 | 370 | class VAEEncoder(BaseModel): 371 | def __init__(self, device, max_batch_size, min_batch_size=1): 372 | super(VAEEncoder, self).__init__( 373 | device=device, 374 | max_batch_size=max_batch_size, 375 | min_batch_size=min_batch_size, 376 | embedding_dim=None, 377 | ) 378 | self.name = "VAE encoder" 379 | 380 | def get_input_names(self): 381 | return ["images"] 382 | 383 | def get_output_names(self): 384 | return ["latent"] 385 | 386 | def get_dynamic_axes(self): 387 | return { 388 | "images": {0: "B", 2: "8H", 3: "8W"}, 389 | "latent": {0: "B", 2: "H", 3: "W"}, 390 | } 391 | 392 | def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): 393 | assert batch_size >= self.min_batch and batch_size <= self.max_batch 394 | min_batch = batch_size if static_batch else self.min_batch 395 | max_batch = batch_size if static_batch else self.max_batch 396 | self.check_dims(batch_size, image_height, image_width) 397 | ( 398 | min_batch, 399 | max_batch, 400 | min_image_height, 401 | max_image_height, 402 | min_image_width, 403 | max_image_width, 404 | _, 405 | _, 406 | _, 407 | _, 408 | ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) 409 | 410 | 
return { 411 | "images": [ 412 | (min_batch, 3, min_image_height, min_image_width), 413 | (batch_size, 3, image_height, image_width), 414 | (max_batch, 3, max_image_height, max_image_width), 415 | ], 416 | } 417 | 418 | def get_shape_dict(self, batch_size, image_height, image_width): 419 | latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) 420 | return { 421 | "images": (batch_size, 3, image_height, image_width), 422 | "latent": (batch_size, 4, latent_height, latent_width), 423 | } 424 | 425 | def get_sample_input(self, batch_size, image_height, image_width): 426 | self.check_dims(batch_size, image_height, image_width) 427 | return torch.randn( 428 | batch_size, 429 | 3, 430 | image_height, 431 | image_width, 432 | dtype=torch.float32, 433 | device=self.device, 434 | ) 435 | -------------------------------------------------------------------------------- /streamdiffusion/pipeline.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import List, Optional, Union, Any, Dict, Tuple, Literal 3 | 4 | import numpy as np 5 | import PIL.Image 6 | import torch 7 | from diffusers import LCMScheduler, StableDiffusionPipeline 8 | from diffusers.image_processor import VaeImageProcessor 9 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( 10 | retrieve_latents, 11 | ) 12 | 13 | from .image_filter import SimilarImageFilter 14 | from .image_utils import postprocess_image 15 | 16 | 17 | class StreamDiffusion: 18 | def __init__( 19 | self, 20 | pipe: StableDiffusionPipeline, 21 | t_index_list: List[int], 22 | torch_dtype: torch.dtype = torch.float16, 23 | width: int = 512, 24 | height: int = 512, 25 | do_add_noise: bool = True, 26 | use_denoising_batch: bool = True, 27 | frame_buffer_size: int = 1, 28 | cfg_type: Literal["none", "full", "self", "initialize"] = "self", 29 | ) -> None: 30 | self.device = pipe.device 31 | self.dtype = torch_dtype 32 | self.generator = None 33 | 34 | self.height = height 35 | self.width = width 36 | 37 | self.latent_height = int(height // pipe.vae_scale_factor) 38 | self.latent_width = int(width // pipe.vae_scale_factor) 39 | 40 | self.frame_bff_size = frame_buffer_size 41 | self.denoising_steps_num = len(t_index_list) 42 | 43 | self.cfg_type = cfg_type 44 | 45 | if use_denoising_batch: 46 | self.batch_size = self.denoising_steps_num * frame_buffer_size 47 | if self.cfg_type == "initialize": 48 | self.trt_unet_batch_size = ( 49 | self.denoising_steps_num + 1 50 | ) * self.frame_bff_size 51 | elif self.cfg_type == "full": 52 | self.trt_unet_batch_size = ( 53 | 2 * self.denoising_steps_num * self.frame_bff_size 54 | ) 55 | else: 56 | self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size 57 | else: 58 | self.trt_unet_batch_size = self.frame_bff_size 59 | self.batch_size = frame_buffer_size 60 | 61 | self.t_list = t_index_list 62 | 63 | self.do_add_noise = do_add_noise 64 | self.use_denoising_batch = use_denoising_batch 65 | 66 | self.similar_image_filter = False 67 | self.similar_filter = SimilarImageFilter() 68 | self.prev_image_result = None 69 | 70 | self.pipe = pipe 71 | self.image_processor = VaeImageProcessor(pipe.vae_scale_factor) 72 | 73 | self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) 74 | self.text_encoder = pipe.text_encoder 75 | self.unet = pipe.unet 76 | self.vae = pipe.vae 77 | 78 | self.inference_time_ema = 0 79 | 80 | def set_sampler_param(self, 81 | t_index_list: List[int], 82 | width: int = 
512, 83 | height: int = 512, 84 | do_add_noise: bool = True, 85 | use_denoising_batch: bool = True, 86 | frame_buffer_size: int = 1, 87 | cfg_type: Literal["none", "full", "self", "initialize"] = "self",): 88 | self.height = height 89 | self.width = width 90 | 91 | self.latent_height = int(height // self.pipe.vae_scale_factor) 92 | self.latent_width = int(width // self.pipe.vae_scale_factor) 93 | 94 | self.frame_bff_size = frame_buffer_size 95 | 96 | self.cfg_type = cfg_type 97 | self.t_list = t_index_list 98 | 99 | self.do_add_noise = do_add_noise 100 | self.use_denoising_batch = use_denoising_batch 101 | 102 | self.inference_time_ema = 0 103 | 104 | self.denoising_steps_num = len(self.t_list) 105 | if self.use_denoising_batch: 106 | self.batch_size = self.denoising_steps_num * self.frame_bff_size 107 | if self.cfg_type == "initialize": 108 | self.trt_unet_batch_size = ( 109 | self.denoising_steps_num + 1 110 | ) * self.frame_bff_size 111 | elif self.cfg_type == "full": 112 | self.trt_unet_batch_size = ( 113 | 2 * self.denoising_steps_num * self.frame_bff_size 114 | ) 115 | else: 116 | self.trt_unet_batch_size = self.denoising_steps_num * self.frame_bff_size 117 | else: 118 | self.trt_unet_batch_size = self.frame_bff_size 119 | self.batch_size = self.frame_bff_size 120 | 121 | def load_lcm_lora( 122 | self, 123 | pretrained_model_name_or_path_or_dict: Union[ 124 | str, Dict[str, torch.Tensor] 125 | ] = "latent-consistency/lcm-lora-sdv1-5", 126 | adapter_name: Optional[Any] = None, 127 | **kwargs, 128 | ) -> None: 129 | self.pipe.load_lora_weights( 130 | pretrained_model_name_or_path_or_dict, adapter_name, **kwargs 131 | ) 132 | 133 | def load_lora( 134 | self, 135 | pretrained_lora_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], 136 | adapter_name: Optional[Any] = None, 137 | **kwargs, 138 | ) -> None: 139 | self.pipe.load_lora_weights( 140 | pretrained_lora_model_name_or_path_or_dict, adapter_name, **kwargs 141 | ) 142 | 143 | def fuse_lora( 144 | self, 145 | fuse_unet: bool = True, 146 | fuse_text_encoder: bool = True, 147 | lora_scale: float = 1.0, 148 | safe_fusing: bool = False, 149 | ) -> None: 150 | self.pipe.fuse_lora( 151 | fuse_unet=fuse_unet, 152 | fuse_text_encoder=fuse_text_encoder, 153 | lora_scale=lora_scale, 154 | safe_fusing=safe_fusing, 155 | ) 156 | 157 | def enable_similar_image_filter(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None: 158 | self.similar_image_filter = True 159 | self.similar_filter.set_threshold(threshold) 160 | self.similar_filter.set_max_skip_frame(max_skip_frame) 161 | 162 | def disable_similar_image_filter(self) -> None: 163 | self.similar_image_filter = False 164 | 165 | @torch.no_grad() 166 | def prepare( 167 | self, 168 | prompt: str, 169 | negative_prompt: str = "", 170 | num_inference_steps: int = 50, 171 | guidance_scale: float = 1.2, 172 | delta: float = 1.0, 173 | generator: Optional[torch.Generator] = torch.Generator(), 174 | seed: int = 2, 175 | ) -> None: 176 | self.generator = generator 177 | self.generator.manual_seed(seed) 178 | # initialize x_t_latent (it can be any random tensor) 179 | if self.denoising_steps_num > 1: 180 | self.x_t_latent_buffer = torch.zeros( 181 | ( 182 | (self.denoising_steps_num - 1) * self.frame_bff_size, 183 | 4, 184 | self.latent_height, 185 | self.latent_width, 186 | ), 187 | dtype=self.dtype, 188 | device=self.device, 189 | ) 190 | else: 191 | self.x_t_latent_buffer = None 192 | 193 | if self.cfg_type == "none": 194 | self.guidance_scale = 1.0 195 | else: 196 | 
self.guidance_scale = guidance_scale 197 | self.delta = delta 198 | 199 | do_classifier_free_guidance = False 200 | if self.guidance_scale > 1.0: 201 | do_classifier_free_guidance = True 202 | 203 | encoder_output = self.pipe.encode_prompt( 204 | prompt=prompt, 205 | device=self.device, 206 | num_images_per_prompt=1, 207 | do_classifier_free_guidance=do_classifier_free_guidance, 208 | negative_prompt=negative_prompt, 209 | ) 210 | self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1) 211 | 212 | if self.use_denoising_batch and self.cfg_type == "full": 213 | uncond_prompt_embeds = encoder_output[1].repeat(self.batch_size, 1, 1) 214 | elif self.cfg_type == "initialize": 215 | uncond_prompt_embeds = encoder_output[1].repeat(self.frame_bff_size, 1, 1) 216 | 217 | if self.guidance_scale > 1.0 and ( 218 | self.cfg_type == "initialize" or self.cfg_type == "full" 219 | ): 220 | self.prompt_embeds = torch.cat( 221 | [uncond_prompt_embeds, self.prompt_embeds], dim=0 222 | ) 223 | 224 | self.scheduler.set_timesteps(num_inference_steps, self.device) 225 | self.timesteps = self.scheduler.timesteps.to(self.device) 226 | 227 | # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list 228 | self.sub_timesteps = [] 229 | for t in self.t_list: 230 | self.sub_timesteps.append(self.timesteps[t]) 231 | 232 | sub_timesteps_tensor = torch.tensor( 233 | self.sub_timesteps, dtype=torch.long, device=self.device 234 | ) 235 | self.sub_timesteps_tensor = torch.repeat_interleave( 236 | sub_timesteps_tensor, 237 | repeats=self.frame_bff_size if self.use_denoising_batch else 1, 238 | dim=0, 239 | ) 240 | 241 | self.init_noise = torch.randn( 242 | (self.batch_size, 4, self.latent_height, self.latent_width), 243 | generator=generator, 244 | ).to(device=self.device, dtype=self.dtype) 245 | 246 | self.stock_noise = torch.zeros_like(self.init_noise) 247 | 248 | c_skip_list = [] 249 | c_out_list = [] 250 | for timestep in self.sub_timesteps: 251 | c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete( 252 | timestep 253 | ) 254 | c_skip_list.append(c_skip) 255 | c_out_list.append(c_out) 256 | 257 | self.c_skip = ( 258 | torch.stack(c_skip_list) 259 | .view(len(self.t_list), 1, 1, 1) 260 | .to(dtype=self.dtype, device=self.device) 261 | ) 262 | self.c_out = ( 263 | torch.stack(c_out_list) 264 | .view(len(self.t_list), 1, 1, 1) 265 | .to(dtype=self.dtype, device=self.device) 266 | ) 267 | 268 | alpha_prod_t_sqrt_list = [] 269 | beta_prod_t_sqrt_list = [] 270 | for timestep in self.sub_timesteps: 271 | alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt() 272 | beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt() 273 | alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt) 274 | beta_prod_t_sqrt_list.append(beta_prod_t_sqrt) 275 | alpha_prod_t_sqrt = ( 276 | torch.stack(alpha_prod_t_sqrt_list) 277 | .view(len(self.t_list), 1, 1, 1) 278 | .to(dtype=self.dtype, device=self.device) 279 | ) 280 | beta_prod_t_sqrt = ( 281 | torch.stack(beta_prod_t_sqrt_list) 282 | .view(len(self.t_list), 1, 1, 1) 283 | .to(dtype=self.dtype, device=self.device) 284 | ) 285 | self.alpha_prod_t_sqrt = torch.repeat_interleave( 286 | alpha_prod_t_sqrt, 287 | repeats=self.frame_bff_size if self.use_denoising_batch else 1, 288 | dim=0, 289 | ) 290 | self.beta_prod_t_sqrt = torch.repeat_interleave( 291 | beta_prod_t_sqrt, 292 | repeats=self.frame_bff_size if self.use_denoising_batch else 1, 293 | dim=0, 294 | ) 295 | 296 | @torch.no_grad() 297 | def 
update_prompt(self, prompt: str,negative_prompt: Optional[str] = None) -> None: 298 | encoder_output = self.pipe.encode_prompt( 299 | prompt=prompt, 300 | negative_prompt=negative_prompt, 301 | device=self.device, 302 | num_images_per_prompt=1, 303 | do_classifier_free_guidance=False, 304 | ) 305 | self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1) 306 | 307 | def add_noise( 308 | self, 309 | original_samples: torch.Tensor, 310 | noise: torch.Tensor, 311 | t_index: int, 312 | ) -> torch.Tensor: 313 | noisy_samples = ( 314 | self.alpha_prod_t_sqrt[t_index] * original_samples 315 | + self.beta_prod_t_sqrt[t_index] * noise 316 | ) 317 | return noisy_samples 318 | 319 | def scheduler_step_batch( 320 | self, 321 | model_pred_batch: torch.Tensor, 322 | x_t_latent_batch: torch.Tensor, 323 | idx: Optional[int] = None, 324 | ) -> torch.Tensor: 325 | # TODO: use t_list to select beta_prod_t_sqrt 326 | if idx is None: 327 | F_theta = ( 328 | x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch 329 | ) / self.alpha_prod_t_sqrt 330 | denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch 331 | else: 332 | F_theta = ( 333 | x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch 334 | ) / self.alpha_prod_t_sqrt[idx] 335 | denoised_batch = ( 336 | self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch 337 | ) 338 | 339 | return denoised_batch 340 | 341 | def unet_step( 342 | self, 343 | x_t_latent: torch.Tensor, 344 | t_list: Union[torch.Tensor, list[int]], 345 | idx: Optional[int] = None, 346 | ) -> Tuple[torch.Tensor, torch.Tensor]: 347 | if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): 348 | x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0) 349 | t_list = torch.concat([t_list[0:1], t_list], dim=0) 350 | elif self.guidance_scale > 1.0 and (self.cfg_type == "full"): 351 | x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0) 352 | t_list = torch.concat([t_list, t_list], dim=0) 353 | else: 354 | x_t_latent_plus_uc = x_t_latent 355 | 356 | model_pred = self.unet( 357 | x_t_latent_plus_uc, 358 | t_list, 359 | encoder_hidden_states=self.prompt_embeds, 360 | return_dict=False, 361 | )[0] 362 | 363 | if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): 364 | noise_pred_text = model_pred[1:] 365 | self.stock_noise = torch.concat( 366 | [model_pred[0:1], self.stock_noise[1:]], dim=0 367 | ) # ここコメントアウトでself out cfg 368 | elif self.guidance_scale > 1.0 and (self.cfg_type == "full"): 369 | noise_pred_uncond, noise_pred_text = model_pred.chunk(2) 370 | else: 371 | noise_pred_text = model_pred 372 | if self.guidance_scale > 1.0 and ( 373 | self.cfg_type == "self" or self.cfg_type == "initialize" 374 | ): 375 | noise_pred_uncond = self.stock_noise * self.delta 376 | if self.guidance_scale > 1.0 and self.cfg_type != "none": 377 | model_pred = noise_pred_uncond + self.guidance_scale * ( 378 | noise_pred_text - noise_pred_uncond 379 | ) 380 | else: 381 | model_pred = noise_pred_text 382 | 383 | # compute the previous noisy sample x_t -> x_t-1 384 | if self.use_denoising_batch: 385 | denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx) 386 | if self.cfg_type == "self" or self.cfg_type == "initialize": 387 | scaled_noise = self.beta_prod_t_sqrt * self.stock_noise 388 | delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx) 389 | alpha_next = torch.concat( 390 | [ 391 | self.alpha_prod_t_sqrt[1:], 392 | torch.ones_like(self.alpha_prod_t_sqrt[0:1]), 393 | ], 394 | dim=0, 
395 | ) 396 | delta_x = alpha_next * delta_x 397 | beta_next = torch.concat( 398 | [ 399 | self.beta_prod_t_sqrt[1:], 400 | torch.ones_like(self.beta_prod_t_sqrt[0:1]), 401 | ], 402 | dim=0, 403 | ) 404 | delta_x = delta_x / beta_next 405 | init_noise = torch.concat( 406 | [self.init_noise[1:], self.init_noise[0:1]], dim=0 407 | ) 408 | self.stock_noise = init_noise + delta_x 409 | 410 | else: 411 | # denoised_batch = self.scheduler.step(model_pred, t_list[0], x_t_latent).denoised 412 | denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx) 413 | 414 | return denoised_batch, model_pred 415 | 416 | def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor: 417 | image_tensors = image_tensors.to( 418 | device=self.device, 419 | dtype=self.vae.dtype, 420 | ) 421 | img_latent = retrieve_latents(self.vae.encode(image_tensors), self.generator) 422 | img_latent = img_latent * self.vae.config.scaling_factor 423 | x_t_latent = self.add_noise(img_latent, self.init_noise[0], 0) 424 | return x_t_latent 425 | 426 | def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor: 427 | output_latents = self.vae.decode( 428 | x_0_pred_out / self.vae.config.scaling_factor, return_dict=False 429 | ) 430 | output_latent =output_latents[0] 431 | return output_latent 432 | 433 | def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: 434 | prev_latent_batch = self.x_t_latent_buffer 435 | 436 | if self.use_denoising_batch: 437 | t_list = self.sub_timesteps_tensor 438 | if self.denoising_steps_num > 1: 439 | x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0) 440 | self.stock_noise = torch.cat( 441 | (self.init_noise[0:1], self.stock_noise[:-1]), dim=0 442 | ) 443 | x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list) 444 | 445 | if self.denoising_steps_num > 1: 446 | x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0) 447 | if self.do_add_noise: 448 | self.x_t_latent_buffer = ( 449 | self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] 450 | + self.beta_prod_t_sqrt[1:] * self.init_noise[1:] 451 | ) 452 | else: 453 | self.x_t_latent_buffer = ( 454 | self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] 455 | ) 456 | else: 457 | x_0_pred_out = x_0_pred_batch 458 | self.x_t_latent_buffer = None 459 | else: 460 | self.init_noise = x_t_latent 461 | for idx, t in enumerate(self.sub_timesteps_tensor): 462 | t = t.view( 463 | 1, 464 | ).repeat( 465 | self.frame_bff_size, 466 | ) 467 | x_0_pred, model_pred = self.unet_step(x_t_latent, t, idx) 468 | if idx < len(self.sub_timesteps_tensor) - 1: 469 | if self.do_add_noise: 470 | x_t_latent = self.alpha_prod_t_sqrt[ 471 | idx + 1 472 | ] * x_0_pred + self.beta_prod_t_sqrt[ 473 | idx + 1 474 | ] * torch.randn_like( 475 | x_0_pred, device=self.device, dtype=self.dtype 476 | ) 477 | else: 478 | x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred 479 | x_0_pred_out = x_0_pred 480 | 481 | return x_0_pred_out 482 | 483 | @torch.no_grad() 484 | def __call__( 485 | self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None 486 | ) -> torch.Tensor: 487 | start = torch.cuda.Event(enable_timing=True) 488 | end = torch.cuda.Event(enable_timing=True) 489 | start.record() 490 | if x is not None: 491 | x = self.image_processor.preprocess(x, self.height, self.width).to( 492 | device=self.device, dtype=self.dtype 493 | ) 494 | if self.similar_image_filter: 495 | x = self.similar_filter(x) 496 | if x is None: 497 | time.sleep(self.inference_time_ema) 498 | return self.prev_image_result 499 | x_t_latent = self.encode_image(x) 500 
| else: 501 | # TODO: check the dimension of x_t_latent 502 | x_t_latent = torch.randn((1, 4, self.latent_height, self.latent_width)).to( 503 | device=self.device, dtype=self.dtype 504 | ) 505 | x_0_pred_out = self.predict_x0_batch(x_t_latent) 506 | x_output = self.decode_image(x_0_pred_out).detach().clone() 507 | 508 | self.prev_image_result = x_output 509 | end.record() 510 | torch.cuda.synchronize() 511 | inference_time = start.elapsed_time(end) / 1000 512 | self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time 513 | return x_output 514 | 515 | @torch.no_grad() 516 | def sample( 517 | self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None 518 | ) -> torch.Tensor: 519 | start = torch.cuda.Event(enable_timing=True) 520 | end = torch.cuda.Event(enable_timing=True) 521 | start.record() 522 | if x is not None: 523 | x = self.image_processor.preprocess(x, self.height, self.width).to( 524 | device=self.device, dtype=self.dtype 525 | ) 526 | if self.similar_image_filter: 527 | x = self.similar_filter(x) 528 | if x is None: 529 | time.sleep(self.inference_time_ema) 530 | return self.prev_image_result 531 | x_t_latent = self.encode_image(x) 532 | b,c,h,w=x_t_latent.shape 533 | 534 | # x_t_latent=x_t_latent.repeat((2, 1,1,1)) 535 | else: 536 | # TODO: check the dimension of x_t_latent 537 | x_t_latent = torch.randn((self.frame_bff_size, 4, self.latent_height, self.latent_width)).to( 538 | device=self.device, dtype=self.dtype 539 | ) 540 | x_0_pred_out = self.predict_x0_batch(x_t_latent) 541 | x_output = self.decode_image(x_0_pred_out).detach().clone() 542 | self.prev_image_result = x_output 543 | end.record() 544 | torch.cuda.synchronize() 545 | inference_time = start.elapsed_time(end) / 1000 546 | self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time 547 | return x_output 548 | 549 | 550 | @torch.no_grad() 551 | def txt2img(self, batch_size: int = 1) -> torch.Tensor: 552 | x_0_pred_out = self.predict_x0_batch( 553 | torch.randn((batch_size, 4, self.latent_height, self.latent_width)).to( 554 | device=self.device, dtype=self.dtype 555 | ) 556 | ) 557 | x_output = self.decode_image(x_0_pred_out).detach().clone() 558 | return x_output 559 | 560 | def txt2img_sd_turbo(self, batch_size: int = 1) -> torch.Tensor: 561 | x_t_latent = torch.randn( 562 | (batch_size, 4, self.latent_height, self.latent_width), 563 | device=self.device, 564 | dtype=self.dtype, 565 | ) 566 | model_pred = self.unet( 567 | x_t_latent, 568 | self.sub_timesteps_tensor, 569 | encoder_hidden_states=self.prompt_embeds, 570 | return_dict=False, 571 | )[0] 572 | x_0_pred_out = ( 573 | x_t_latent - self.beta_prod_t_sqrt * model_pred 574 | ) / self.alpha_prod_t_sqrt 575 | return self.decode_image(x_0_pred_out) 576 | -------------------------------------------------------------------------------- /streamdiffusion/wrapper.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from pathlib import Path 4 | import traceback 5 | from typing import List, Literal, Optional, Union, Dict 6 | 7 | import numpy as np 8 | import torch 9 | from diffusers import AutoencoderTiny, StableDiffusionPipeline 10 | from PIL import Image 11 | 12 | from .pipeline import StreamDiffusion 13 | from .image_utils import postprocess_image 14 | 15 | 16 | torch.set_grad_enabled(False) 17 | torch.backends.cuda.matmul.allow_tf32 = True 18 | torch.backends.cudnn.allow_tf32 = True 19 | 20 | 21 | class StreamDiffusionWrapper: 22 | def __init__( 23 | self, 
24 | model_id_or_path: str, 25 | t_index_list: List[int], 26 | lora_dict: Optional[Dict[str, float]] = None, 27 | mode: Literal["img2img", "txt2img"] = "img2img", 28 | output_type: Literal["pil", "pt", "np", "latent"] = "pil", 29 | lcm_lora_id: Optional[str] = None, 30 | vae_id: Optional[str] = None, 31 | device: Literal["cpu", "cuda"] = "cuda", 32 | dtype: torch.dtype = torch.float16, 33 | frame_buffer_size: int = 1, 34 | width: int = 512, 35 | height: int = 512, 36 | warmup: int = 10, 37 | acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", 38 | do_add_noise: bool = True, 39 | device_ids: Optional[List[int]] = None, 40 | use_lcm_lora: bool = True, 41 | use_tiny_vae: bool = True, 42 | enable_similar_image_filter: bool = False, 43 | similar_image_filter_threshold: float = 0.98, 44 | similar_image_filter_max_skip_frame: int = 10, 45 | use_denoising_batch: bool = True, 46 | cfg_type: Literal["none", "full", "self", "initialize"] = "self", 47 | seed: int = 2, 48 | use_safety_checker: bool = False, 49 | ): 50 | """ 51 | Initializes the StreamDiffusionWrapper. 52 | 53 | Parameters 54 | ---------- 55 | model_id_or_path : str 56 | The model id or path to load. 57 | t_index_list : List[int] 58 | The t_index_list to use for inference. 59 | lora_dict : Optional[Dict[str, float]], optional 60 | The lora_dict to load, by default None. 61 | Keys are the LoRA names and values are the LoRA scales. 62 | Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} 63 | mode : Literal["img2img", "txt2img"], optional 64 | txt2img or img2img, by default "img2img". 65 | output_type : Literal["pil", "pt", "np", "latent"], optional 66 | The output type of image, by default "pil". 67 | lcm_lora_id : Optional[str], optional 68 | The lcm_lora_id to load, by default None. 69 | If None, the default LCM-LoRA 70 | ("latent-consistency/lcm-lora-sdv1-5") will be used. 71 | vae_id : Optional[str], optional 72 | The vae_id to load, by default None. 73 | If None, the default TinyVAE 74 | ("madebyollin/taesd") will be used. 75 | device : Literal["cpu", "cuda"], optional 76 | The device to use for inference, by default "cuda". 77 | dtype : torch.dtype, optional 78 | The dtype for inference, by default torch.float16. 79 | frame_buffer_size : int, optional 80 | The frame buffer size for denoising batch, by default 1. 81 | width : int, optional 82 | The width of the image, by default 512. 83 | height : int, optional 84 | The height of the image, by default 512. 85 | warmup : int, optional 86 | The number of warmup steps to perform, by default 10. 87 | acceleration : Literal["none", "xformers", "tensorrt"], optional 88 | The acceleration method, by default "tensorrt". 89 | do_add_noise : bool, optional 90 | Whether to add noise for following denoising steps or not, 91 | by default True. 92 | device_ids : Optional[List[int]], optional 93 | The device ids to use for DataParallel, by default None. 94 | use_lcm_lora : bool, optional 95 | Whether to use LCM-LoRA or not, by default True. 96 | use_tiny_vae : bool, optional 97 | Whether to use TinyVAE or not, by default True. 98 | enable_similar_image_filter : bool, optional 99 | Whether to enable similar image filter or not, 100 | by default False. 101 | similar_image_filter_threshold : float, optional 102 | The threshold for similar image filter, by default 0.98. 103 | similar_image_filter_max_skip_frame : int, optional 104 | The max skip frame for similar image filter, by default 10. 
105 | use_denoising_batch : bool, optional 106 | Whether to use denoising batch or not, by default True. 107 | cfg_type : Literal["none", "full", "self", "initialize"], 108 | optional 109 | The cfg_type for img2img mode, by default "self". 110 | You cannot use anything other than "none" for txt2img mode. 111 | seed : int, optional 112 | The seed, by default 2. 113 | use_safety_checker : bool, optional 114 | Whether to use safety checker or not, by default False. 115 | """ 116 | self.sd_turbo = "turbo" in model_id_or_path 117 | 118 | if mode == "txt2img": 119 | if cfg_type != "none": 120 | raise ValueError( 121 | f"txt2img mode accepts only cfg_type = 'none', but got {cfg_type}" 122 | ) 123 | if use_denoising_batch and frame_buffer_size > 1: 124 | if not self.sd_turbo: 125 | raise ValueError( 126 | "txt2img mode cannot use denoising batch with frame_buffer_size > 1." 127 | ) 128 | 129 | if mode == "img2img": 130 | if not use_denoising_batch: 131 | raise NotImplementedError( 132 | "img2img mode must use denoising batch for now." 133 | ) 134 | 135 | self.device = device 136 | self.dtype = dtype 137 | self.width = width 138 | self.height = height 139 | self.mode = mode 140 | self.output_type = output_type 141 | self.frame_buffer_size = frame_buffer_size 142 | self.batch_size = ( 143 | len(t_index_list) * frame_buffer_size 144 | if use_denoising_batch 145 | else frame_buffer_size 146 | ) 147 | self.t_index_list =t_index_list 148 | self.cfg_type =cfg_type 149 | self.use_denoising_batch = use_denoising_batch 150 | self.use_safety_checker = use_safety_checker 151 | self.do_add_noise =do_add_noise 152 | self.seed=seed 153 | 154 | self.stream: StreamDiffusion = self._load_model( 155 | model_id_or_path=model_id_or_path, 156 | lora_dict=lora_dict, 157 | lcm_lora_id=lcm_lora_id, 158 | vae_id=vae_id, 159 | t_index_list=t_index_list, 160 | acceleration=acceleration, 161 | warmup=warmup, 162 | do_add_noise=do_add_noise, 163 | use_lcm_lora=use_lcm_lora, 164 | use_tiny_vae=use_tiny_vae, 165 | cfg_type=cfg_type, 166 | seed=seed, 167 | ) 168 | 169 | if device_ids is not None: 170 | self.stream.unet = torch.nn.DataParallel( 171 | self.stream.unet, device_ids=device_ids 172 | ) 173 | 174 | if enable_similar_image_filter: 175 | self.stream.enable_similar_image_filter(similar_image_filter_threshold, similar_image_filter_max_skip_frame) 176 | 177 | def prepare( 178 | self, 179 | prompt: str, 180 | negative_prompt: str = "", 181 | num_inference_steps: int = 50, 182 | guidance_scale: float = 1.2, 183 | delta: float = 1.0, 184 | t_index_list: List[int]=[16,32,45], 185 | do_add_noise: bool = True, 186 | enable_similar_image_filter: bool = False, 187 | similar_image_filter_threshold: float = 0.98, 188 | similar_image_filter_max_skip_frame: int = 10, 189 | use_denoising_batch: bool = True, 190 | cfg_type: Literal["none", "full", "self", "initialize"] = "self", 191 | seed: int = 2, 192 | frame_buffer_size:int=1, 193 | use_safety_checker: bool = False, 194 | ) -> None: 195 | """ 196 | Prepares the model for inference. 197 | 198 | Parameters 199 | ---------- 200 | prompt : str 201 | The prompt to generate images from. 202 | num_inference_steps : int, optional 203 | The number of inference steps to perform, by default 50. 204 | guidance_scale : float, optional 205 | The guidance scale to use, by default 1.2. 206 | delta : float, optional 207 | The delta multiplier of virtual residual noise, 208 | by default 1.0. 
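        Examples
        --------
        Editor's illustrative sketch; the prompts, t_index_list, and seed below
        are placeholder values, not recommended settings:

        >>> wrapper.prepare(
        ...     prompt="a watercolor painting of a lighthouse, best quality",
        ...     negative_prompt="low quality, blurry",
        ...     num_inference_steps=50,
        ...     guidance_scale=1.2,
        ...     delta=1.0,
        ...     t_index_list=[16, 32, 45],
        ...     seed=42,
        ... )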
209 |         """
210 |         # self.stream.prepare(
211 |         #     prompt,
212 |         #     negative_prompt,
213 |         #     num_inference_steps=num_inference_steps,
214 |         #     guidance_scale=guidance_scale,
215 |         #     delta=delta,
216 |         # )
217 |         self.prompt = prompt
218 |         self.negative_prompt = negative_prompt
219 |         self.num_inference_steps = num_inference_steps
220 |         self.guidance_scale = guidance_scale
221 |         self.delta = delta
222 | 
223 |         self.frame_buffer_size = frame_buffer_size
224 |         self.batch_size = (
225 |             len(t_index_list) * frame_buffer_size
226 |             if use_denoising_batch
227 |             else frame_buffer_size
228 |         )
229 |         self.t_index_list = t_index_list
230 |         self.cfg_type = cfg_type
231 |         self.use_denoising_batch = use_denoising_batch
232 |         self.use_safety_checker = use_safety_checker
233 |         self.do_add_noise = do_add_noise
234 |         self.seed = seed
235 | 
236 |         if enable_similar_image_filter:
237 |             self.stream.enable_similar_image_filter(similar_image_filter_threshold, similar_image_filter_max_skip_frame)
238 |         else:
239 |             self.stream.disable_similar_image_filter()
240 | 
241 |         if self.use_safety_checker:
242 |             if getattr(self, "safety_checker", None) is None or getattr(self, "feature_extractor", None) is None:
243 |                 from transformers import CLIPFeatureExtractor
244 |                 from diffusers.pipelines.stable_diffusion.safety_checker import (
245 |                     StableDiffusionSafetyChecker,
246 |                 )
247 | 
248 |                 self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
249 |                     "CompVis/stable-diffusion-safety-checker"
250 |                 ).to(self.stream.device)
251 |                 self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
252 |                     "openai/clip-vit-base-patch32"
253 |                 )
254 |                 self.nsfw_fallback_img = Image.new("RGB", (512, 512), (0, 0, 0))
255 | 
256 |     def __call__(
257 |         self,
258 |         image: Optional[Union[str, Image.Image, torch.Tensor]] = None,
259 |         prompt: Optional[str] = None,
260 |     ) -> Union[Image.Image, List[Image.Image]]:
261 |         """
262 |         Performs img2img or txt2img based on the mode.
263 | 
264 |         Parameters
265 |         ----------
266 |         image : Optional[Union[str, Image.Image, torch.Tensor]]
267 |             The image to generate from.
268 |         prompt : Optional[str]
269 |             The prompt to generate images from.
270 | 
271 |         Returns
272 |         -------
273 |         Union[Image.Image, List[Image.Image]]
274 |             The generated image.
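        Examples
        --------
        Editor's illustrative sketch; the file name and prompt are placeholders:

        >>> result = wrapper(image="frame.png")                    # wrapper built with mode="img2img"
        >>> result = wrapper(prompt="a cyberpunk city at night")   # wrapper built with mode="txt2img"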
275 |         """
276 |         if self.mode == "img2img":
277 |             return self.img2img(image)
278 |         else:
279 |             return self.txt2img(prompt)
280 | 
281 |     def sample(self, image: Optional[Union[str, Image.Image, torch.Tensor]] = None,
282 |                prompt: Optional[str] = None, negative_prompt: Optional[str] = None) -> List[Image.Image]:
283 | 
284 |         use_denoising_batch = self.use_denoising_batch
285 |         if image is not None:
286 |             # img2img
287 |             if isinstance(image, str) or isinstance(image, Image.Image):
288 |                 image = self.preprocess_image(image)
289 | 
290 |             use_denoising_batch = True
291 |             self.stream.set_sampler_param(t_index_list=self.t_index_list,
292 |                                           width=self.width,
293 |                                           height=self.height,
294 |                                           do_add_noise=self.do_add_noise,
295 |                                           use_denoising_batch=use_denoising_batch,
296 |                                           frame_buffer_size=self.frame_buffer_size,
297 |                                           cfg_type=self.cfg_type)
298 |         else:
299 |             # txt2img
300 |             if self.frame_buffer_size > 1 and self.use_denoising_batch:
301 |                 use_denoising_batch = False
302 |             self.stream.set_sampler_param(t_index_list=self.t_index_list,
303 |                                           width=self.width,
304 |                                           height=self.height,
305 |                                           do_add_noise=self.do_add_noise,
306 |                                           use_denoising_batch=use_denoising_batch,
307 |                                           frame_buffer_size=self.frame_buffer_size,
308 |                                           cfg_type='none')
309 | 
310 |         self.stream.prepare(
311 |             prompt=self.prompt,
312 |             negative_prompt=self.negative_prompt,
313 |             num_inference_steps=self.num_inference_steps,
314 |             guidance_scale=self.guidance_scale,
315 |             delta=self.delta,
316 |             seed=self.seed,
317 |         )
318 | 
319 |         if prompt is not None:
320 |             self.stream.update_prompt(prompt, negative_prompt)
321 | 
322 |         self.batch_size = (
323 |             len(self.t_index_list) * self.frame_buffer_size
324 |             if use_denoising_batch
325 |             else self.frame_buffer_size
326 |         )
327 | 
328 |         if self.frame_buffer_size == 1:
329 |             for _ in range(self.batch_size):
330 |                 self.stream.sample(image)
331 | 
332 |         image_tensor = self.stream.sample(image)
333 |         image = postprocess_image(image_tensor.cpu(), output_type=self.output_type)
334 | 
335 |         if self.use_safety_checker:
336 |             safety_checker_input = self.feature_extractor(
337 |                 image, return_tensors="pt"
338 |             ).to(self.device)
339 |             _, has_nsfw_concept = self.safety_checker(
340 |                 images=image_tensor.to(self.dtype),
341 |                 clip_input=safety_checker_input.pixel_values.to(self.dtype),
342 |             )
343 |             image = self.nsfw_fallback_img if has_nsfw_concept[0] else image
344 |         return image
345 | 
346 |     def txt2img(
347 |         self, prompt: Optional[str] = None
348 |     ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
349 |         """
350 |         Performs txt2img.
351 | 
352 |         Parameters
353 |         ----------
354 |         prompt : Optional[str]
355 |             The prompt to generate images from.
356 | 
357 |         Returns
358 |         -------
359 |         Union[Image.Image, List[Image.Image]]
360 |             The generated image.
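        Examples
        --------
        Editor's illustrative sketch; assumes the default output_type="pil" and
        frame_buffer_size=1, so a single PIL image is returned:

        >>> image = wrapper.txt2img("a photo of a corgi wearing sunglasses")
        >>> image.save("corgi.png")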
361 | """ 362 | if prompt is not None: 363 | self.stream.update_prompt(prompt) 364 | 365 | if self.sd_turbo: 366 | image_tensor = self.stream.txt2img_sd_turbo(self.batch_size) 367 | else: 368 | image_tensor = self.stream.txt2img(self.frame_buffer_size) 369 | image = self.postprocess_image(image_tensor, output_type=self.output_type) 370 | 371 | if self.use_safety_checker: 372 | safety_checker_input = self.feature_extractor( 373 | image, return_tensors="pt" 374 | ).to(self.device) 375 | _, has_nsfw_concept = self.safety_checker( 376 | images=image_tensor.to(self.dtype), 377 | clip_input=safety_checker_input.pixel_values.to(self.dtype), 378 | ) 379 | image = self.nsfw_fallback_img if has_nsfw_concept[0] else image 380 | 381 | return image 382 | 383 | def img2img( 384 | self, image: Union[str, Image.Image, torch.Tensor] 385 | ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]: 386 | """ 387 | Performs img2img. 388 | 389 | Parameters 390 | ---------- 391 | image : Union[str, Image.Image, torch.Tensor] 392 | The image to generate from. 393 | 394 | Returns 395 | ------- 396 | Image.Image 397 | The generated image. 398 | """ 399 | if isinstance(image, str) or isinstance(image, Image.Image): 400 | image = self.preprocess_image(image) 401 | 402 | image_tensor = self.stream(image) 403 | image = self.postprocess_image(image_tensor, output_type=self.output_type) 404 | 405 | if self.use_safety_checker: 406 | safety_checker_input = self.feature_extractor( 407 | image, return_tensors="pt" 408 | ).to(self.device) 409 | _, has_nsfw_concept = self.safety_checker( 410 | images=image_tensor.to(self.dtype), 411 | clip_input=safety_checker_input.pixel_values.to(self.dtype), 412 | ) 413 | image = self.nsfw_fallback_img if has_nsfw_concept[0] else image 414 | 415 | return image 416 | 417 | def preprocess_image(self, image: Union[str, Image.Image]) -> torch.Tensor: 418 | """ 419 | Preprocesses the image. 420 | 421 | Parameters 422 | ---------- 423 | image : Union[str, Image.Image, torch.Tensor] 424 | The image to preprocess. 425 | 426 | Returns 427 | ------- 428 | torch.Tensor 429 | The preprocessed image. 430 | """ 431 | if isinstance(image, str): 432 | image = Image.open(image).convert("RGB").resize((self.width, self.height)) 433 | if isinstance(image, Image.Image): 434 | image = image.convert("RGB").resize((self.width, self.height)) 435 | 436 | return self.stream.image_processor.preprocess( 437 | image, self.height, self.width 438 | ).to(device=self.device, dtype=self.dtype) 439 | 440 | def postprocess_image( 441 | self, image_tensor: torch.Tensor, output_type: str = "pil" 442 | ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]: 443 | """ 444 | Postprocesses the image. 445 | 446 | Parameters 447 | ---------- 448 | image_tensor : torch.Tensor 449 | The image tensor to postprocess. 450 | 451 | Returns 452 | ------- 453 | Union[Image.Image, List[Image.Image]] 454 | The postprocessed image. 
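        Examples
        --------
        Editor's illustrative sketch; the decoded tensor comes from the underlying
        StreamDiffusion pipeline:

        >>> decoded = wrapper.stream.txt2img(batch_size=1)
        >>> pil_image = wrapper.postprocess_image(decoded, output_type="pil")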
455 | """ 456 | if self.frame_buffer_size > 1: 457 | return postprocess_image(image_tensor.cpu(), output_type=output_type) 458 | else: 459 | return postprocess_image(image_tensor.cpu(), output_type=output_type)[0] 460 | 461 | def _load_model( 462 | self, 463 | model_id_or_path: str, 464 | t_index_list: List[int], 465 | lora_dict: Optional[Dict[str, float]] = None, 466 | lcm_lora_id: Optional[str] = None, 467 | vae_id: Optional[str] = None, 468 | acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", 469 | warmup: int = 10, 470 | do_add_noise: bool = True, 471 | use_lcm_lora: bool = True, 472 | use_tiny_vae: bool = True, 473 | cfg_type: Literal["none", "full", "self", "initialize"] = "self", 474 | seed: int = 2, 475 | ) -> StreamDiffusion: 476 | """ 477 | Loads the model. 478 | 479 | This method does the following: 480 | 481 | 1. Loads the model from the model_id_or_path. 482 | 2. Loads and fuses the LCM-LoRA model from the lcm_lora_id if needed. 483 | 3. Loads the VAE model from the vae_id if needed. 484 | 4. Enables acceleration if needed. 485 | 5. Prepares the model for inference. 486 | 6. Load the safety checker if needed. 487 | 488 | Parameters 489 | ---------- 490 | model_id_or_path : str 491 | The model id or path to load. 492 | t_index_list : List[int] 493 | The t_index_list to use for inference. 494 | lora_dict : Optional[Dict[str, float]], optional 495 | The lora_dict to load, by default None. 496 | Keys are the LoRA names and values are the LoRA scales. 497 | Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} 498 | lcm_lora_id : Optional[str], optional 499 | The lcm_lora_id to load, by default None. 500 | vae_id : Optional[str], optional 501 | The vae_id to load, by default None. 502 | acceleration : Literal["none", "xfomers", "sfast", "tensorrt"], optional 503 | The acceleration method, by default "tensorrt". 504 | warmup : int, optional 505 | The number of warmup steps to perform, by default 10. 506 | do_add_noise : bool, optional 507 | Whether to add noise for following denoising steps or not, 508 | by default True. 509 | use_lcm_lora : bool, optional 510 | Whether to use LCM-LoRA or not, by default True. 511 | use_tiny_vae : bool, optional 512 | Whether to use TinyVAE or not, by default True. 513 | cfg_type : Literal["none", "full", "self", "initialize"], 514 | optional 515 | The cfg_type for img2img mode, by default "self". 516 | You cannot use anything other than "none" for txt2img mode. 517 | seed : int, optional 518 | The seed, by default 2. 519 | 520 | Returns 521 | ------- 522 | StreamDiffusion 523 | The loaded model. 524 | """ 525 | 526 | try: # Load from local directory 527 | pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained( 528 | model_id_or_path, 529 | ).to(device=self.device, dtype=self.dtype) 530 | 531 | except ValueError: # Load from huggingface 532 | pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file( 533 | model_id_or_path, 534 | ).to(device=self.device, dtype=self.dtype) 535 | except Exception: # No model found 536 | traceback.print_exc() 537 | print("Model load has failed. 
Doesn't exist.") 538 | exit() 539 | 540 | stream = StreamDiffusion( 541 | pipe=pipe, 542 | t_index_list=t_index_list, 543 | torch_dtype=self.dtype, 544 | width=self.width, 545 | height=self.height, 546 | do_add_noise=do_add_noise, 547 | frame_buffer_size=self.frame_buffer_size, 548 | use_denoising_batch=self.use_denoising_batch, 549 | cfg_type=cfg_type, 550 | ) 551 | if not self.sd_turbo: 552 | if use_lcm_lora: 553 | if lcm_lora_id is not None: 554 | stream.load_lcm_lora( 555 | pretrained_model_name_or_path_or_dict=lcm_lora_id 556 | ) 557 | else: 558 | stream.load_lcm_lora() 559 | stream.fuse_lora() 560 | 561 | if lora_dict is not None: 562 | for lora_name, lora_scale in lora_dict.items(): 563 | stream.load_lora(lora_name) 564 | stream.fuse_lora(lora_scale=lora_scale) 565 | print(f"Use LoRA: {lora_name} in weights {lora_scale}") 566 | 567 | if use_tiny_vae: 568 | if vae_id is not None: 569 | stream.vae = AutoencoderTiny.from_pretrained(vae_id).to( 570 | device=pipe.device, dtype=pipe.dtype 571 | ) 572 | else: 573 | stream.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to( 574 | device=pipe.device, dtype=pipe.dtype 575 | ) 576 | 577 | try: 578 | if acceleration == "xformers": 579 | stream.pipe.enable_xformers_memory_efficient_attention() 580 | if acceleration == "tensorrt": 581 | from polygraphy import cuda 582 | from streamdiffusion.acceleration.tensorrt import ( 583 | TorchVAEEncoder, 584 | compile_unet, 585 | compile_vae_decoder, 586 | compile_vae_encoder, 587 | ) 588 | from streamdiffusion.acceleration.tensorrt.engine import ( 589 | AutoencoderKLEngine, 590 | UNet2DConditionModelEngine, 591 | ) 592 | from streamdiffusion.acceleration.tensorrt.models import ( 593 | VAE, 594 | UNet, 595 | VAEEncoder, 596 | ) 597 | 598 | def create_prefix( 599 | model_id_or_path: str, 600 | max_batch_size: int, 601 | min_batch_size: int, 602 | ): 603 | maybe_path = Path(model_id_or_path) 604 | if maybe_path.exists(): 605 | return f"{maybe_path.stem}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}" 606 | else: 607 | return f"{model_id_or_path}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}" 608 | 609 | engine_dir = os.path.join("engines") 610 | unet_path = os.path.join( 611 | engine_dir, 612 | create_prefix( 613 | model_id_or_path=model_id_or_path, 614 | max_batch_size=stream.trt_unet_batch_size, 615 | min_batch_size=stream.trt_unet_batch_size, 616 | ), 617 | "unet.engine", 618 | ) 619 | vae_encoder_path = os.path.join( 620 | engine_dir, 621 | create_prefix( 622 | model_id_or_path=model_id_or_path, 623 | max_batch_size=self.batch_size 624 | if self.mode == "txt2img" 625 | else stream.frame_bff_size, 626 | min_batch_size=self.batch_size 627 | if self.mode == "txt2img" 628 | else stream.frame_bff_size, 629 | ), 630 | "vae_encoder.engine", 631 | ) 632 | vae_decoder_path = os.path.join( 633 | engine_dir, 634 | create_prefix( 635 | model_id_or_path=model_id_or_path, 636 | max_batch_size=self.batch_size 637 | if self.mode == "txt2img" 638 | else stream.frame_bff_size, 639 | min_batch_size=self.batch_size 640 | if self.mode == "txt2img" 641 | else stream.frame_bff_size, 642 | ), 643 | "vae_decoder.engine", 644 | ) 645 | 646 | if not os.path.exists(unet_path): 647 | os.makedirs(os.path.dirname(unet_path), exist_ok=True) 648 | unet_model = UNet( 649 | fp16=True, 650 | device=stream.device, 651 | max_batch_size=stream.trt_unet_batch_size, 652 | 
652 |                         min_batch_size=stream.trt_unet_batch_size,
653 |                         embedding_dim=stream.text_encoder.config.hidden_size,
654 |                         unet_dim=stream.unet.config.in_channels,
655 |                     )
656 |                     compile_unet(
657 |                         stream.unet,
658 |                         unet_model,
659 |                         unet_path + ".onnx",
660 |                         unet_path + ".opt.onnx",
661 |                         unet_path,
662 |                         opt_batch_size=stream.trt_unet_batch_size,
663 |                     )
664 | 
665 |                 if not os.path.exists(vae_decoder_path):
666 |                     os.makedirs(os.path.dirname(vae_decoder_path), exist_ok=True)
667 |                     stream.vae.forward = stream.vae.decode
668 |                     vae_decoder_model = VAE(
669 |                         device=stream.device,
670 |                         max_batch_size=self.batch_size
671 |                         if self.mode == "txt2img"
672 |                         else stream.frame_bff_size,
673 |                         min_batch_size=self.batch_size
674 |                         if self.mode == "txt2img"
675 |                         else stream.frame_bff_size,
676 |                     )
677 |                     compile_vae_decoder(
678 |                         stream.vae,
679 |                         vae_decoder_model,
680 |                         vae_decoder_path + ".onnx",
681 |                         vae_decoder_path + ".opt.onnx",
682 |                         vae_decoder_path,
683 |                         opt_batch_size=self.batch_size
684 |                         if self.mode == "txt2img"
685 |                         else stream.frame_bff_size,
686 |                     )
687 |                     delattr(stream.vae, "forward")
688 | 
689 |                 if not os.path.exists(vae_encoder_path):
690 |                     os.makedirs(os.path.dirname(vae_encoder_path), exist_ok=True)
691 |                     vae_encoder = TorchVAEEncoder(stream.vae).to(torch.device("cuda"))
692 |                     vae_encoder_model = VAEEncoder(
693 |                         device=stream.device,
694 |                         max_batch_size=self.batch_size
695 |                         if self.mode == "txt2img"
696 |                         else stream.frame_bff_size,
697 |                         min_batch_size=self.batch_size
698 |                         if self.mode == "txt2img"
699 |                         else stream.frame_bff_size,
700 |                     )
701 |                     compile_vae_encoder(
702 |                         vae_encoder,
703 |                         vae_encoder_model,
704 |                         vae_encoder_path + ".onnx",
705 |                         vae_encoder_path + ".opt.onnx",
706 |                         vae_encoder_path,
707 |                         opt_batch_size=self.batch_size
708 |                         if self.mode == "txt2img"
709 |                         else stream.frame_bff_size,
710 |                     )
711 | 
712 |                 cuda_stream = cuda.Stream()
713 | 
714 |                 vae_config = stream.vae.config
715 |                 vae_dtype = stream.vae.dtype
716 | 
717 |                 stream.unet = UNet2DConditionModelEngine(
718 |                     unet_path, cuda_stream, use_cuda_graph=False
719 |                 )
720 |                 stream.vae = AutoencoderKLEngine(
721 |                     vae_encoder_path,
722 |                     vae_decoder_path,
723 |                     cuda_stream,
724 |                     stream.pipe.vae_scale_factor,
725 |                     use_cuda_graph=False,
726 |                 )
727 |                 setattr(stream.vae, "config", vae_config)
728 |                 setattr(stream.vae, "dtype", vae_dtype)
729 | 
730 |                 gc.collect()
731 |                 torch.cuda.empty_cache()
732 | 
733 |                 print("TensorRT acceleration enabled.")
734 |             if acceleration == "sfast":
735 |                 from streamdiffusion.acceleration.sfast import (
736 |                     accelerate_with_stable_fast,
737 |                 )
738 | 
739 |                 stream = accelerate_with_stable_fast(stream)
740 |                 print("StableFast acceleration enabled.")
741 |         except Exception:
742 |             traceback.print_exc()
743 |             print("Acceleration has failed. Falling back to normal mode.")
Falling back to normal mode.") 744 | 745 | if seed < 0: # Random seed 746 | seed = np.random.randint(0, 1000000) 747 | 748 | stream.prepare( 749 | "", 750 | "", 751 | num_inference_steps=50, 752 | guidance_scale=1.1 753 | if stream.cfg_type in ["full", "self", "initialize"] 754 | else 1.0, 755 | generator=torch.manual_seed(seed), 756 | seed=seed, 757 | ) 758 | 759 | if self.use_safety_checker: 760 | from transformers import CLIPFeatureExtractor 761 | from diffusers.pipelines.stable_diffusion.safety_checker import ( 762 | StableDiffusionSafetyChecker, 763 | ) 764 | 765 | self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( 766 | "CompVis/stable-diffusion-safety-checker" 767 | ).to(pipe.device) 768 | self.feature_extractor = CLIPFeatureExtractor.from_pretrained( 769 | "openai/clip-vit-base-patch32" 770 | ) 771 | self.nsfw_fallback_img = Image.new("RGB", (512, 512), (0, 0, 0)) 772 | 773 | return stream 774 | --------------------------------------------------------------------------------