├── .DS_Store ├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── hi_diffusers ├── .DS_Store ├── models │ ├── attention.py │ ├── attention_processor.py │ ├── embeddings.py │ ├── moe.py │ └── transformers │ │ └── transformer_hidream_image.py ├── pipelines │ ├── .DS_Store │ └── hidream_image │ │ ├── pipeline_hidream_image.py │ │ ├── pipeline_hidream_image_to_image.py │ │ └── pipeline_output.py └── schedulers │ ├── flash_flow_match.py │ └── fm_solvers_unipc.py ├── hidreamsampler.py ├── pyproject.toml ├── requirements.txt └── sample_workflow ├── ComfyUI HiDream GGUF Simple.json ├── Sample HiDream Sampler Workflow.json └── workflow.png /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lum3on/comfyui_HiDream-Sampler/98ad017cac93b782e2af95411e4c10d493ecb841/.DS_Store -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'lum3on' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | with: 22 | submodules: true 23 | - name: Publish Custom Node 24 | uses: Comfy-Org/publish-node-action@v1 25 | with: 26 | ## Add your own personal access token to your Github Repository secrets and reference it here. 27 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | 3 | .vscode/* 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 lum3on 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
 22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | ## [comfyanonymous](https://github.com/comfyanonymous/ComfyUI/commit/9ad792f92706e2179c58b2e5348164acafa69288) now supports HiDream natively, so I will focus on improving the HiDream Advanced Node and drop the dependency on diffusers soon(TM).
  2 | 
  3 | In the meantime, update your ComfyUI and you can test the model with less of the struggle you might have installing this node ;)
  4 | You need their official [HiDream Model](https://huggingface.co/city96/HiDream-I1-Dev-gguf/tree/main) and the four [text encoders](https://huggingface.co/Comfy-Org/HiDream-I1_ComfyUI/tree/main/split_files/text_encoders).
  5 | A sample workflow is in the `sample_workflow` folder.
  6 | 
  7 | ## Announcement
  8 | 
  9 | We now have a bunch of forks of this project, and [SanDiegoDude/ComfyUI-HiDream-Sampler](https://github.com/SanDiegoDude/ComfyUI-HiDream-Sampler/) is actively maintained, so feel free to check out his awesome work, which might solve some issues.
 10 | 
 11 | I will have a bit of time over Easter to work on the node, but it will be updated at my pace and capabilities.
 12 | 
 13 | If anyone would like to contribute, I'd LOVE you to reach out or open PRs, as I am not able to solve all the issues alone (I used to be an [illustrator](https://benjaminbertram.com/), not a dev by trade :D).
 14 | 
 15 | What is on my list and where you could support me:
 16 | 
 17 | - [ ] **Edit capabilities:** Integrate [editing capabilities](https://github.com/HiDream-ai/HiDream-E1)
 18 | - [ ] **GGUF support:** Integrate the [Calcuis model](https://huggingface.co/calcuis/hidream-gguf/tree/main)
 19 | - [ ] **Local checkpoints:** Get rid of the Hugging Face download logic
 20 | - [ ] **Bat installer:** For the standalone/Windows/portable users of ComfyUI
 21 | - [ ] **Make installation overall easier:** Currently many users have problems with the installation process. Meanwhile, try to follow this [video](https://www.youtube.com/watch?v=KRnJCLdgdRE) to get going.
 22 | - [ ] **Multi-image for Img2img:** Multi in, multi out
 23 | - [ ] **Better uncensored LLM:** An LLM which does not OOM; suggestions are welcome
 24 | - [ ] **Manual attention selection:** At least for the advanced mode, choose between sage, sdpa, or flash manually
 25 | - [ ] **Beautify codebase:** Currently there is a lot of repetition, which is prone to errors
 26 | - [ ] **Cancel generation:** Currently generation can't be cancelled with the stop button
 27 | - [ ] **System prompt presets:** Explore system prompts that work well for various use cases
 28 | - [ ] **Clean up UI:** Clearer naming, tooltips for inputs, and rearranged fields for faster work
 29 | - [ ] **Explore LoRA:** While the first LoRAs for HiDream are being explored, it would be good to know how to implement them without big performance losses
 30 | - [ ] **Explore HiResFix capabilities via Img2img:** Dig into HiResFix to get a proper upscale via native HiDream
 31 | 
 32 | ## Updates 14.04.
 33 | 
 34 | - Fixed uncensored LLM support
 35 | - Fixed pipelines
 36 | - Fixed image output
 37 | - Fixed cache overload
 38 | - Added multi-image generation (up to 8; not yet supported for Img2img!)
 39 | - Fixed SageAttention being picked up as the first attention method.
 40 | - Fix from [burgstall](https://github.com/Burgstall-labs): auto-gptq is no longer required! Make sure to git pull and pip install -r requirements.txt!
 41 | - Added Image2image functionality
 42 | - Flash Attention is no longer needed thanks to [power88](https://github.com/power88)
 43 | - Added working uncensored Llama support (available via HiDream Sampler Advanced) thanks to [sdujack2012](https://github.com/sdujack2012), but beware: it is not quantized, so you could get an OOM
 44 | 
 45 | ![image](sample_workflow/workflow.png)
 46 | 
 47 | # HiDreamSampler for ComfyUI
 48 | 
 49 | A custom ComfyUI node for generating images using the HiDream AI model.
 50 | 
 51 | ## Features
 52 | - Supports `full`, `dev`, and `fast` model types.
 53 | - Configurable resolution and inference steps.
 54 | - Uses 4-bit quantization for lower memory usage.
 55 | 
 56 | ## Installation
 57 | ### Basic installation
 58 | 1. Clone this repository into your `ComfyUI/custom_nodes/` directory:
 59 | ```bash
 60 | git clone https://github.com/lum3on/comfyui_HiDream-Sampler ComfyUI/custom_nodes/comfyui_HiDream-Sampler
 61 | ```
 62 | 
 63 | 2. Install the requirements:
 64 | ```bash
 65 | pip install -r requirements.txt
 66 | ```
 67 | or for the portable version:
 68 | ```bash
 69 | .\python_embeded\python.exe -m pip install -r .\ComfyUI\custom_nodes\comfyui_HiDream-Sampler\requirements.txt
 70 | ```
 71 | 
 72 | 3. Restart ComfyUI.
 73 | 
 74 | Steps to install SageAttention 1:
 75 | - Install Triton.
 76 | Windows prebuilt wheel, [download here](https://huggingface.co/madbuda/triton-windows-builds):
 77 | ```bash
 78 | .\python_embeded\python.exe -s -m pip install (Your downloaded whl package)
 79 | ```
 80 | Linux:
 81 | ```bash
 82 | python3 -m pip install triton
 83 | ```
 84 | 
 85 | - Install the sageattention package:
 86 | ```bash
 87 | .\python_embeded\python.exe -s -m pip install sageattention==1.0.6
 88 | ```
 89 | PyTorch SDPA is automatically installed when you install PyTorch 2 (a ComfyUI requirement). However, if your torch version is lower than 2, use this command to update to the latest version.
 90 | - Linux
 91 | ```
 92 | python3 -m pip install torch torchvision torchaudio
 93 | ```
 94 | - Windows
 95 | ```
 96 | .\python_embeded\python.exe -s -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
 97 | ```
 98 | 
 99 | ## Download the weights
100 | Here are the weights you need (they will also be downloaded automatically when running the workflow). Please use huggingface-cli to download them.
101 | - Llama text encoders
102 | 
103 | | Model | Huggingface repo |
104 | |------------------------|---------------------------|
105 | | 4-bit Llama text encoder | hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4 |
106 | | Uncensored 4-bit Llama text encoder | John6666/Llama-3.1-8B-Lexi-Uncensored-V2-nf4 |
107 | | Original Llama text encoder | nvidia/Llama-3.1-Nemotron-Nano-8B-v1 |
108 | 
109 | - Quantized diffusion models (thanks to `azaneko` for the quantized models!)
110 | 
111 | | Model | Huggingface repo |
112 | |------------------------|---------------------------|
113 | | 4-bit HiDream Full | azaneko/HiDream-I1-Full-nf4 |
114 | | 4-bit HiDream Dev | azaneko/HiDream-I1-Dev-nf4 |
115 | | 4-bit HiDream Fast | azaneko/HiDream-I1-Fast-nf4 |
116 | 
117 | - Full-weight diffusion models (optional, not recommended unless you have high VRAM)
118 | 
119 | | Model | Huggingface repo |
120 | |------------------------|---------------------------|
121 | | HiDream Full | HiDream-ai/HiDream-I1-Full |
122 | | HiDream Dev | HiDream-ai/HiDream-I1-Dev |
123 | | HiDream Fast | HiDream-ai/HiDream-I1-Fast |
124 | 
125 | You can download these weights with this command:
126 | ```shell
127 | huggingface-cli download (Huggingface repo)
128 | ```
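If you would rather pre-fetch the weights from Python (for example in a setup script), the same downloads can be done with `huggingface_hub`. This is only a sketch: it assumes `huggingface_hub` is installed in your environment (it ships as a dependency of diffusers/transformers) and uses the repo IDs from the tables above.

```python
# Optional sketch: pre-download weights via the Python API instead of the CLI.
# Assumes `huggingface_hub` is installed; repo IDs are taken from the tables above.
from huggingface_hub import snapshot_download

repos = [
    "azaneko/HiDream-I1-Full-nf4",                          # 4-bit HiDream Full
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4",  # 4-bit Llama text encoder
]

for repo_id in repos:
    local_dir = snapshot_download(repo_id)  # cached under ~/.cache/huggingface by default
    print(f"{repo_id} -> {local_dir}")
```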
129 | For regions that cannot connect to Hugging Face, use one of these commands to set a mirror.
130 | 
131 | Windows CMD:
132 | ```shell
133 | set HF_ENDPOINT=https://hf-mirror.com
134 | ```
135 | Windows PowerShell:
136 | ```shell
137 | $env:HF_ENDPOINT = "https://hf-mirror.com"
138 | ```
139 | Linux:
140 | ```shell
141 | export HF_ENDPOINT=https://hf-mirror.com
142 | ```
143 | 
144 | ## Usage
145 | - Add the HiDreamSampler node to your workflow.
146 | - Configure the inputs:
147 |   - `model_type`: Choose `full`, `dev`, or `fast`.
148 |   - `prompt`: Enter your text prompt (e.g., "A photo of an astronaut riding a horse on the moon").
149 |   - `resolution`: Select from the available options (e.g., "1024 × 1024 (Square)").
150 |   - `seed`: Set a random seed.
151 |   - `override_steps` and `override_cfg`: Optionally override the default steps and guidance scale.
152 | - Connect the output to a PreviewImage or SaveImage node.
153 | 
154 | ## Requirements
155 | - ComfyUI
156 | - GPU (for model inference)
157 | - Models are cached after the first load to improve performance; the 4-bit quantized models come from https://github.com/hykilpikonna/HiDream-I1-nf4.
158 | - Ensure you have sufficient VRAM (e.g., 16GB+ recommended for full mode).
159 | 
160 | ## Credits
161 | 
162 | Merged with [SanDiegoDude/ComfyUI-HiDream-Sampler](https://github.com/SanDiegoDude/ComfyUI-HiDream-Sampler/), who implemented a cleaner version of my original NF4 / fp8 support.
163 | 
164 | - Added NF4 (Full/Dev/Fast) download and load support
165 | - Added better memory handling
166 | - Added more informative CLI output for TQDM
167 | - Full/Dev/Fast requires roughly 27GB VRAM
168 | - NF4 requires roughly 15GB VRAM
169 | 
170 | Built upon the original [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1).
171 | 
172 | ## Star History
173 | 
174 | [![Star History Chart](https://api.star-history.com/svg?repos=lum3on/comfyui_HiDream-Sampler&type=Date)](https://www.star-history.com/#lum3on/comfyui_HiDream-Sampler&Date)
175 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
 1 | from .hidreamsampler import HiDreamSampler, HiDreamSamplerAdvanced, HiDreamImg2Img
 2 | 
 3 | NODE_CLASS_MAPPINGS = {
 4 |     "HiDreamSampler": HiDreamSampler,
 5 |     "HiDreamSamplerAdvanced": HiDreamSamplerAdvanced,
 6 |     "HiDreamImg2Img": HiDreamImg2Img
 7 | }
 8 | 
 9 | NODE_DISPLAY_NAME_MAPPINGS = {
10 |     "HiDreamSampler": "HiDream Sampler",
11 |     "HiDreamSamplerAdvanced": "HiDream Sampler (Advanced)",
12 |     "HiDreamImg2Img": "HiDream Image to Image"
13 | }
14 | 
15 | WEB_DIRECTORY = "./web"
16 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
--------------------------------------------------------------------------------
/hi_diffusers/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lum3on/comfyui_HiDream-Sampler/98ad017cac93b782e2af95411e4c10d493ecb841/hi_diffusers/.DS_Store
--------------------------------------------------------------------------------
/hi_diffusers/models/attention.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | from typing import Optional
 5 | from diffusers.models.attention_processor import Attention
 6 | from diffusers.utils.torch_utils import maybe_allow_in_graph
 7 | 
 8 | # Add RMSNorm if
missing from PyTorch (needed for <2.1.0) 9 | if not hasattr(nn, 'RMSNorm'): 10 | print("Adding RMSNorm implementation for compatibility with older PyTorch") 11 | 12 | class RMSNorm(nn.Module): 13 | def __init__(self, dim, eps=1e-6): 14 | super().__init__() 15 | self.eps = eps 16 | self.weight = nn.Parameter(torch.ones(dim)) 17 | 18 | def forward(self, x): 19 | # Calculate RMS 20 | norm_x = torch.mean(x * x, dim=-1, keepdim=True) 21 | x_normed = x * torch.rsqrt(norm_x + self.eps) 22 | return self.weight * x_normed 23 | 24 | # Add to torch.nn for compatibility 25 | torch.nn.RMSNorm = RMSNorm 26 | 27 | @maybe_allow_in_graph 28 | class HiDreamAttention(Attention): 29 | def __init__( 30 | self, 31 | query_dim: int, 32 | heads: int = 8, 33 | dim_head: int = 64, 34 | upcast_attention: bool = False, 35 | upcast_softmax: bool = False, 36 | scale_qk: bool = True, 37 | eps: float = 1e-5, 38 | processor = None, 39 | out_dim: int = None, 40 | single: bool = False 41 | ): 42 | super(Attention, self).__init__() 43 | self.inner_dim = out_dim if out_dim is not None else dim_head * heads 44 | self.query_dim = query_dim 45 | self.upcast_attention = upcast_attention 46 | self.upcast_softmax = upcast_softmax 47 | self.out_dim = out_dim if out_dim is not None else query_dim 48 | 49 | self.scale_qk = scale_qk 50 | self.scale = dim_head**-0.5 if self.scale_qk else 1.0 51 | 52 | self.heads = out_dim // dim_head if out_dim is not None else heads 53 | self.sliceable_head_dim = heads 54 | self.single = single 55 | 56 | linear_cls = nn.Linear 57 | self.linear_cls = linear_cls 58 | self.to_q = linear_cls(query_dim, self.inner_dim) 59 | self.to_k = linear_cls(self.inner_dim, self.inner_dim) 60 | self.to_v = linear_cls(self.inner_dim, self.inner_dim) 61 | self.to_out = linear_cls(self.inner_dim, self.out_dim) 62 | self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps) 63 | self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps) 64 | 65 | if not single: 66 | self.to_q_t = linear_cls(query_dim, self.inner_dim) 67 | self.to_k_t = linear_cls(self.inner_dim, self.inner_dim) 68 | self.to_v_t = linear_cls(self.inner_dim, self.inner_dim) 69 | self.to_out_t = linear_cls(self.inner_dim, self.out_dim) 70 | self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps) 71 | self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps) 72 | 73 | self.set_processor(processor) 74 | self.apply(self._init_weights) 75 | 76 | def _init_weights(self, m): 77 | if isinstance(m, nn.Linear): 78 | nn.init.xavier_uniform_(m.weight) 79 | if m.bias is not None: 80 | nn.init.constant_(m.bias, 0) 81 | 82 | def forward( 83 | self, 84 | norm_image_tokens: torch.FloatTensor, 85 | image_tokens_masks: torch.FloatTensor = None, 86 | norm_text_tokens: torch.FloatTensor = None, 87 | rope: torch.FloatTensor = None, 88 | ) -> torch.Tensor: 89 | return self.processor( 90 | self, 91 | image_tokens = norm_image_tokens, 92 | image_tokens_masks = image_tokens_masks, 93 | text_tokens = norm_text_tokens, 94 | rope = rope, 95 | ) 96 | 97 | class FeedForwardSwiGLU(nn.Module): 98 | def __init__( 99 | self, 100 | dim: int, 101 | hidden_dim: int, 102 | multiple_of: int = 256, 103 | ffn_dim_multiplier: Optional[float] = None, 104 | ): 105 | super().__init__() 106 | hidden_dim = int(2 * hidden_dim / 3) 107 | # custom dim factor multiplier 108 | if ffn_dim_multiplier is not None: 109 | hidden_dim = int(ffn_dim_multiplier * hidden_dim) 110 | hidden_dim = multiple_of * ( 111 | (hidden_dim + multiple_of - 1) // multiple_of 112 | ) 113 | 114 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 115 | self.w2 = 
nn.Linear(hidden_dim, dim, bias=False) 116 | self.w3 = nn.Linear(dim, hidden_dim, bias=False) 117 | self.apply(self._init_weights) 118 | 119 | def _init_weights(self, m): 120 | if isinstance(m, nn.Linear): 121 | nn.init.xavier_uniform_(m.weight) 122 | if m.bias is not None: 123 | nn.init.constant_(m.bias, 0) 124 | 125 | def forward(self, x): 126 | return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) -------------------------------------------------------------------------------- /hi_diffusers/models/attention_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | from .attention import HiDreamAttention 4 | 5 | USE_FLASH_ATTN = False 6 | USE_FLASH_ATTN3 = False 7 | 8 | 9 | # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py 10 | def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 11 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 12 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 13 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] 14 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] 15 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 16 | 17 | def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): 18 | # Check if we've already determined the attention type 19 | if not hasattr(attention, "_attention_type"): 20 | # Determine which attention implementation to use 21 | attention_type = "sdpa" # default 22 | 23 | try: 24 | import importlib.util 25 | if importlib.util.find_spec("sageattention"): 26 | attention_type = "SageAttention" 27 | elif importlib.util.find_spec("flash_attn"): 28 | attention_type = "FlashAttention2" 29 | elif importlib.util.find_spec("flash_attn_interface"): 30 | attention_type = "FlashAttention3" 31 | except: 32 | pass 33 | 34 | # Cache the result 35 | attention._attention_type = attention_type 36 | print(f"using {attention_type}") 37 | 38 | # Get the cached attention type 39 | attention_type = attention._attention_type 40 | 41 | 42 | # Execute the appropriate attention implementation 43 | if attention_type == "SageAttention": 44 | from sageattention import sageattn 45 | hidden_states = sageattn(query, key, value, tensor_layout="NHD", is_causal=False) 46 | elif attention_type == "FlashAttention2": 47 | from flash_attn import flash_attn_func 48 | hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False) 49 | elif attention_type == "FlashAttention3": 50 | from flash_attn_interface import flash_attn_func 51 | hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0] 52 | elif attention_type == "sdpa": 53 | q = query.transpose(1, 2) 54 | k = key.transpose(1, 2) 55 | v = value.transpose(1, 2) 56 | 57 | hidden_states = torch.nn.functional.scaled_dot_product_attention( 58 | q, k, v, 59 | attn_mask=None, 60 | dropout_p=0.0, 61 | is_causal=False 62 | ) 63 | hidden_states = hidden_states.transpose(1, 2) 64 | else: 65 | raise ValueError("Invalid attention implementation") 66 | 67 | hidden_states = hidden_states.flatten(-2) 68 | hidden_states = hidden_states.to(query.dtype) 69 | return hidden_states 70 | 71 | class HiDreamAttnProcessor_flashattn: 72 | """Attention processor used typically in processing the SD3-like self-attention projections.""" 73 | 74 | def __call__( 75 | self, 76 | attn: HiDreamAttention, 77 | 
image_tokens: torch.FloatTensor, 78 | image_tokens_masks: Optional[torch.FloatTensor] = None, 79 | text_tokens: Optional[torch.FloatTensor] = None, 80 | rope: torch.FloatTensor = None, 81 | *args, 82 | **kwargs, 83 | ) -> torch.FloatTensor: 84 | dtype = image_tokens.dtype 85 | batch_size = image_tokens.shape[0] 86 | 87 | query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype) 88 | key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype) 89 | value_i = attn.to_v(image_tokens) 90 | 91 | inner_dim = key_i.shape[-1] 92 | head_dim = inner_dim // attn.heads 93 | 94 | query_i = query_i.view(batch_size, -1, attn.heads, head_dim) 95 | key_i = key_i.view(batch_size, -1, attn.heads, head_dim) 96 | value_i = value_i.view(batch_size, -1, attn.heads, head_dim) 97 | if image_tokens_masks is not None: 98 | key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1) 99 | 100 | if not attn.single: 101 | query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype) 102 | key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype) 103 | value_t = attn.to_v_t(text_tokens) 104 | 105 | query_t = query_t.view(batch_size, -1, attn.heads, head_dim) 106 | key_t = key_t.view(batch_size, -1, attn.heads, head_dim) 107 | value_t = value_t.view(batch_size, -1, attn.heads, head_dim) 108 | 109 | num_image_tokens = query_i.shape[1] 110 | num_text_tokens = query_t.shape[1] 111 | query = torch.cat([query_i, query_t], dim=1) 112 | key = torch.cat([key_i, key_t], dim=1) 113 | value = torch.cat([value_i, value_t], dim=1) 114 | else: 115 | query = query_i 116 | key = key_i 117 | value = value_i 118 | 119 | if query.shape[-1] == rope.shape[-3] * 2: 120 | query, key = apply_rope(query, key, rope) 121 | else: 122 | query_1, query_2 = query.chunk(2, dim=-1) 123 | key_1, key_2 = key.chunk(2, dim=-1) 124 | query_1, key_1 = apply_rope(query_1, key_1, rope) 125 | query = torch.cat([query_1, query_2], dim=-1) 126 | key = torch.cat([key_1, key_2], dim=-1) 127 | 128 | hidden_states = attention(query, key, value) 129 | 130 | if not attn.single: 131 | hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) 132 | hidden_states_i = attn.to_out(hidden_states_i) 133 | hidden_states_t = attn.to_out_t(hidden_states_t) 134 | return hidden_states_i, hidden_states_t 135 | else: 136 | hidden_states = attn.to_out(hidden_states) 137 | return hidden_states -------------------------------------------------------------------------------- /hi_diffusers/models/embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import List 4 | from diffusers.models.embeddings import Timesteps, TimestepEmbedding 5 | 6 | # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py 7 | def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: 8 | assert dim % 2 == 0, "The dimension must be even." 
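    # The computation below builds the rotary-embedding rotation matrices: `pos` holds integer
    # positions of shape (batch_size, seq_length), `dim` is the per-axis rotary size, the
    # frequencies are omega_k = 1 / theta**(2k / dim), and the result is a float tensor of
    # shape (batch_size, seq_length, dim // 2, 2, 2) holding the 2x2 matrices
    # [[cos, -sin], [sin, cos]] that apply_rope() in attention_processor.py applies to q/k.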
9 | 10 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim 11 | omega = 1.0 / (theta**scale) 12 | 13 | batch_size, seq_length = pos.shape 14 | out = torch.einsum("...n,d->...nd", pos, omega) 15 | cos_out = torch.cos(out) 16 | sin_out = torch.sin(out) 17 | 18 | stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) 19 | out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) 20 | return out.float() 21 | 22 | # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py 23 | class EmbedND(nn.Module): 24 | def __init__(self, theta: int, axes_dim: List[int]): 25 | super().__init__() 26 | self.theta = theta 27 | self.axes_dim = axes_dim 28 | 29 | def forward(self, ids: torch.Tensor) -> torch.Tensor: 30 | n_axes = ids.shape[-1] 31 | emb = torch.cat( 32 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], 33 | dim=-3, 34 | ) 35 | return emb.unsqueeze(2) 36 | 37 | class PatchEmbed(nn.Module): 38 | def __init__( 39 | self, 40 | patch_size=2, 41 | in_channels=4, 42 | out_channels=1024, 43 | ): 44 | super().__init__() 45 | self.patch_size = patch_size 46 | self.out_channels = out_channels 47 | self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True) 48 | self.apply(self._init_weights) 49 | 50 | def _init_weights(self, m): 51 | if isinstance(m, nn.Linear): 52 | nn.init.xavier_uniform_(m.weight) 53 | if m.bias is not None: 54 | nn.init.constant_(m.bias, 0) 55 | 56 | def forward(self, latent): 57 | latent = self.proj(latent) 58 | return latent 59 | 60 | class PooledEmbed(nn.Module): 61 | def __init__(self, text_emb_dim, hidden_size): 62 | super().__init__() 63 | self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size) 64 | self.apply(self._init_weights) 65 | 66 | def _init_weights(self, m): 67 | if isinstance(m, nn.Linear): 68 | nn.init.normal_(m.weight, std=0.02) 69 | if m.bias is not None: 70 | nn.init.constant_(m.bias, 0) 71 | 72 | def forward(self, pooled_embed): 73 | return self.pooled_embedder(pooled_embed) 74 | 75 | class TimestepEmbed(nn.Module): 76 | def __init__(self, hidden_size, frequency_embedding_size=256): 77 | super().__init__() 78 | self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0) 79 | self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size) 80 | self.apply(self._init_weights) 81 | 82 | def _init_weights(self, m): 83 | if isinstance(m, nn.Linear): 84 | nn.init.normal_(m.weight, std=0.02) 85 | if m.bias is not None: 86 | nn.init.constant_(m.bias, 0) 87 | 88 | def forward(self, timesteps, wdtype): 89 | t_emb = self.time_proj(timesteps).to(dtype=wdtype) 90 | t_emb = self.timestep_embedder(t_emb) 91 | return t_emb 92 | 93 | class OutEmbed(nn.Module): 94 | def __init__(self, hidden_size, patch_size, out_channels): 95 | super().__init__() 96 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 97 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 98 | self.adaLN_modulation = nn.Sequential( 99 | nn.SiLU(), 100 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 101 | ) 102 | self.apply(self._init_weights) 103 | 104 | def _init_weights(self, m): 105 | if isinstance(m, nn.Linear): 106 | nn.init.zeros_(m.weight) 107 | if m.bias is not None: 108 | nn.init.constant_(m.bias, 0) 109 | 110 | def forward(self, x, adaln_input): 111 | shift, scale = 
self.adaLN_modulation(adaln_input).chunk(2, dim=1) 112 | x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 113 | x = self.linear(x) 114 | return x -------------------------------------------------------------------------------- /hi_diffusers/models/moe.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from .attention import FeedForwardSwiGLU 6 | from torch.distributed.nn.functional import all_gather 7 | 8 | _LOAD_BALANCING_LOSS = [] 9 | def save_load_balancing_loss(loss): 10 | global _LOAD_BALANCING_LOSS 11 | _LOAD_BALANCING_LOSS.append(loss) 12 | 13 | def clear_load_balancing_loss(): 14 | global _LOAD_BALANCING_LOSS 15 | _LOAD_BALANCING_LOSS.clear() 16 | 17 | def get_load_balancing_loss(): 18 | global _LOAD_BALANCING_LOSS 19 | return _LOAD_BALANCING_LOSS 20 | 21 | def batched_load_balancing_loss(): 22 | aux_losses_arr = get_load_balancing_loss() 23 | alpha = aux_losses_arr[0][-1] 24 | Pi = torch.stack([ent[1] for ent in aux_losses_arr], dim=0) 25 | fi = torch.stack([ent[2] for ent in aux_losses_arr], dim=0) 26 | 27 | fi_list = all_gather(fi) 28 | fi = torch.stack(fi_list, 0).mean(0) 29 | 30 | aux_loss = (Pi * fi).sum(-1).mean() * alpha 31 | return aux_loss 32 | 33 | # Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py 34 | class MoEGate(nn.Module): 35 | def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01): 36 | super().__init__() 37 | self.top_k = num_activated_experts 38 | self.n_routed_experts = num_routed_experts 39 | 40 | self.scoring_func = 'softmax' 41 | self.alpha = aux_loss_alpha 42 | self.seq_aux = False 43 | 44 | # topk selection algorithm 45 | self.norm_topk_prob = False 46 | self.gating_dim = embed_dim 47 | self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) 48 | self.reset_parameters() 49 | 50 | def reset_parameters(self) -> None: 51 | import torch.nn.init as init 52 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 53 | 54 | def forward(self, hidden_states): 55 | bsz, seq_len, h = hidden_states.shape 56 | # print(bsz, seq_len, h) 57 | ### compute gating score 58 | hidden_states = hidden_states.view(-1, h) 59 | logits = F.linear(hidden_states, self.weight, None) 60 | if self.scoring_func == 'softmax': 61 | scores = logits.softmax(dim=-1) 62 | else: 63 | raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}') 64 | 65 | ### select top-k experts 66 | topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) 67 | 68 | ### norm gate to sum 1 69 | if self.top_k > 1 and self.norm_topk_prob: 70 | denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 71 | topk_weight = topk_weight / denominator 72 | 73 | ### expert-level computation auxiliary loss 74 | if self.training and self.alpha > 0.0: 75 | scores_for_aux = scores 76 | aux_topk = self.top_k 77 | # always compute aux loss based on the naive greedy topk method 78 | topk_idx_for_aux_loss = topk_idx.view(bsz, -1) 79 | if self.seq_aux: 80 | scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) 81 | ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device) 82 | ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts) 83 | aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 
1).mean() * self.alpha 84 | else: 85 | mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) 86 | ce = mask_ce.float().mean(0) 87 | 88 | Pi = scores_for_aux.mean(0) 89 | fi = ce * self.n_routed_experts 90 | aux_loss = (Pi * fi).sum() * self.alpha 91 | save_load_balancing_loss((aux_loss, Pi, fi, self.alpha)) 92 | else: 93 | aux_loss = None 94 | return topk_idx, topk_weight, aux_loss 95 | 96 | # Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py 97 | class MOEFeedForwardSwiGLU(nn.Module): 98 | def __init__( 99 | self, 100 | dim: int, 101 | hidden_dim: int, 102 | num_routed_experts: int, 103 | num_activated_experts: int, 104 | ): 105 | super().__init__() 106 | self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2) 107 | self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]) 108 | self.gate = MoEGate( 109 | embed_dim = dim, 110 | num_routed_experts = num_routed_experts, 111 | num_activated_experts = num_activated_experts 112 | ) 113 | self.num_activated_experts = num_activated_experts 114 | 115 | def forward(self, x): 116 | wtype = x.dtype 117 | identity = x 118 | orig_shape = x.shape 119 | topk_idx, topk_weight, aux_loss = self.gate(x) 120 | x = x.view(-1, x.shape[-1]) 121 | flat_topk_idx = topk_idx.view(-1) 122 | if self.training: 123 | x = x.repeat_interleave(self.num_activated_experts, dim=0) 124 | y = torch.empty_like(x, dtype=wtype) 125 | for i, expert in enumerate(self.experts): 126 | y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype) 127 | y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) 128 | y = y.view(*orig_shape).to(dtype=wtype) 129 | #y = AddAuxiliaryLoss.apply(y, aux_loss) 130 | else: 131 | y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) 132 | y = y + self.shared_experts(identity) 133 | return y 134 | 135 | @torch.no_grad() 136 | def moe_infer(self, x, flat_expert_indices, flat_expert_weights): 137 | expert_cache = torch.zeros_like(x) 138 | idxs = flat_expert_indices.argsort() 139 | tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) 140 | token_idxs = idxs // self.num_activated_experts 141 | for i, end_idx in enumerate(tokens_per_expert): 142 | start_idx = 0 if i == 0 else tokens_per_expert[i-1] 143 | if start_idx == end_idx: 144 | continue 145 | expert = self.experts[i] 146 | exp_token_idx = token_idxs[start_idx:end_idx] 147 | expert_tokens = x[exp_token_idx] 148 | expert_out = expert(expert_tokens) 149 | expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) 150 | 151 | # for fp16 and other dtype 152 | expert_cache = expert_cache.to(expert_out.dtype) 153 | expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum') 154 | return expert_cache 155 | -------------------------------------------------------------------------------- /hi_diffusers/models/transformers/transformer_hidream_image.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import einops 6 | from einops import repeat 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 10 | from diffusers.models.modeling_utils import ModelMixin 11 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, 
logging, scale_lora_layers, unscale_lora_layers 12 | from diffusers.utils.torch_utils import maybe_allow_in_graph 13 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 14 | from ..embeddings import PatchEmbed, PooledEmbed, TimestepEmbed, EmbedND, OutEmbed 15 | from ..attention import HiDreamAttention, FeedForwardSwiGLU 16 | from ..attention_processor import HiDreamAttnProcessor_flashattn 17 | from ..moe import MOEFeedForwardSwiGLU 18 | 19 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 20 | 21 | class TextProjection(nn.Module): 22 | def __init__(self, in_features, hidden_size): 23 | super().__init__() 24 | self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False) 25 | 26 | def forward(self, caption): 27 | hidden_states = self.linear(caption) 28 | return hidden_states 29 | 30 | class BlockType: 31 | TransformerBlock = 1 32 | SingleTransformerBlock = 2 33 | 34 | @maybe_allow_in_graph 35 | class HiDreamImageSingleTransformerBlock(nn.Module): 36 | def __init__( 37 | self, 38 | dim: int, 39 | num_attention_heads: int, 40 | attention_head_dim: int, 41 | num_routed_experts: int = 4, 42 | num_activated_experts: int = 2, 43 | ): 44 | super().__init__() 45 | self.num_attention_heads = num_attention_heads 46 | self.adaLN_modulation = nn.Sequential( 47 | nn.SiLU(), 48 | nn.Linear(dim, 6 * dim, bias=True) 49 | ) 50 | nn.init.zeros_(self.adaLN_modulation[1].weight) 51 | nn.init.zeros_(self.adaLN_modulation[1].bias) 52 | 53 | # 1. Attention 54 | self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) 55 | self.attn1 = HiDreamAttention( 56 | query_dim=dim, 57 | heads=num_attention_heads, 58 | dim_head=attention_head_dim, 59 | processor = HiDreamAttnProcessor_flashattn(), 60 | single = True 61 | ) 62 | 63 | # 3. Feed-forward 64 | self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) 65 | if num_routed_experts > 0: 66 | self.ff_i = MOEFeedForwardSwiGLU( 67 | dim = dim, 68 | hidden_dim = 4 * dim, 69 | num_routed_experts = num_routed_experts, 70 | num_activated_experts = num_activated_experts, 71 | ) 72 | else: 73 | self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim) 74 | 75 | def forward( 76 | self, 77 | image_tokens: torch.FloatTensor, 78 | image_tokens_masks: Optional[torch.FloatTensor] = None, 79 | text_tokens: Optional[torch.FloatTensor] = None, 80 | adaln_input: Optional[torch.FloatTensor] = None, 81 | rope: torch.FloatTensor = None, 82 | 83 | ) -> torch.FloatTensor: 84 | wtype = image_tokens.dtype 85 | shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ 86 | self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1) 87 | 88 | # 1. MM-Attention 89 | norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) 90 | norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i 91 | attn_output_i = self.attn1( 92 | norm_image_tokens, 93 | image_tokens_masks, 94 | rope = rope, 95 | ) 96 | image_tokens = gate_msa_i * attn_output_i + image_tokens 97 | 98 | # 2. 
Feed-forward 99 | norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) 100 | norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i 101 | ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype)) 102 | image_tokens = ff_output_i + image_tokens 103 | return image_tokens 104 | 105 | @maybe_allow_in_graph 106 | class HiDreamImageTransformerBlock(nn.Module): 107 | def __init__( 108 | self, 109 | dim: int, 110 | num_attention_heads: int, 111 | attention_head_dim: int, 112 | num_routed_experts: int = 4, 113 | num_activated_experts: int = 2, 114 | ): 115 | super().__init__() 116 | self.num_attention_heads = num_attention_heads 117 | self.adaLN_modulation = nn.Sequential( 118 | nn.SiLU(), 119 | nn.Linear(dim, 12 * dim, bias=True) 120 | ) 121 | nn.init.zeros_(self.adaLN_modulation[1].weight) 122 | nn.init.zeros_(self.adaLN_modulation[1].bias) 123 | 124 | # 1. Attention 125 | self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) 126 | self.norm1_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) 127 | self.attn1 = HiDreamAttention( 128 | query_dim=dim, 129 | heads=num_attention_heads, 130 | dim_head=attention_head_dim, 131 | processor = HiDreamAttnProcessor_flashattn(), 132 | single = False 133 | ) 134 | 135 | # 3. Feed-forward 136 | self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) 137 | if num_routed_experts > 0: 138 | self.ff_i = MOEFeedForwardSwiGLU( 139 | dim = dim, 140 | hidden_dim = 4 * dim, 141 | num_routed_experts = num_routed_experts, 142 | num_activated_experts = num_activated_experts, 143 | ) 144 | else: 145 | self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim) 146 | self.norm3_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) 147 | self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim) 148 | 149 | def forward( 150 | self, 151 | image_tokens: torch.FloatTensor, 152 | image_tokens_masks: Optional[torch.FloatTensor] = None, 153 | text_tokens: Optional[torch.FloatTensor] = None, 154 | adaln_input: Optional[torch.FloatTensor] = None, 155 | rope: torch.FloatTensor = None, 156 | ) -> torch.FloatTensor: 157 | wtype = image_tokens.dtype 158 | shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ 159 | shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \ 160 | self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1) 161 | 162 | # 1. MM-Attention 163 | norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) 164 | norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i 165 | norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype) 166 | norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t 167 | 168 | attn_output_i, attn_output_t = self.attn1( 169 | norm_image_tokens, 170 | image_tokens_masks, 171 | norm_text_tokens, 172 | rope = rope, 173 | ) 174 | 175 | image_tokens = gate_msa_i * attn_output_i + image_tokens 176 | text_tokens = gate_msa_t * attn_output_t + text_tokens 177 | 178 | # 2. 
Feed-forward 179 | norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) 180 | norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i 181 | norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype) 182 | norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t 183 | 184 | ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens) 185 | ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens) 186 | image_tokens = ff_output_i + image_tokens 187 | text_tokens = ff_output_t + text_tokens 188 | return image_tokens, text_tokens 189 | 190 | @maybe_allow_in_graph 191 | class HiDreamImageBlock(nn.Module): 192 | def __init__( 193 | self, 194 | dim: int, 195 | num_attention_heads: int, 196 | attention_head_dim: int, 197 | num_routed_experts: int = 4, 198 | num_activated_experts: int = 2, 199 | block_type: BlockType = BlockType.TransformerBlock, 200 | ): 201 | super().__init__() 202 | block_classes = { 203 | BlockType.TransformerBlock: HiDreamImageTransformerBlock, 204 | BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock, 205 | } 206 | self.block = block_classes[block_type]( 207 | dim, 208 | num_attention_heads, 209 | attention_head_dim, 210 | num_routed_experts, 211 | num_activated_experts, 212 | ) 213 | 214 | def forward( 215 | self, 216 | image_tokens: torch.FloatTensor, 217 | image_tokens_masks: Optional[torch.FloatTensor] = None, 218 | text_tokens: Optional[torch.FloatTensor] = None, 219 | adaln_input: torch.FloatTensor = None, 220 | rope: torch.FloatTensor = None, 221 | ) -> torch.FloatTensor: 222 | return self.block( 223 | image_tokens, 224 | image_tokens_masks, 225 | text_tokens, 226 | adaln_input, 227 | rope, 228 | ) 229 | 230 | class HiDreamImageTransformer2DModel( 231 | ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin 232 | ): 233 | _supports_gradient_checkpointing = True 234 | _no_split_modules = ["HiDreamImageBlock"] 235 | 236 | @register_to_config 237 | def __init__( 238 | self, 239 | patch_size: Optional[int] = None, 240 | in_channels: int = 64, 241 | out_channels: Optional[int] = None, 242 | num_layers: int = 16, 243 | num_single_layers: int = 32, 244 | attention_head_dim: int = 128, 245 | num_attention_heads: int = 20, 246 | caption_channels: List[int] = None, 247 | text_emb_dim: int = 2048, 248 | num_routed_experts: int = 4, 249 | num_activated_experts: int = 2, 250 | axes_dims_rope: Tuple[int, int] = (32, 32), 251 | max_resolution: Tuple[int, int] = (128, 128), 252 | llama_layers: List[int] = None, 253 | ): 254 | super().__init__() 255 | self.out_channels = out_channels or in_channels 256 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 257 | self.llama_layers = llama_layers 258 | 259 | self.t_embedder = TimestepEmbed(self.inner_dim) 260 | self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim) 261 | self.x_embedder = PatchEmbed( 262 | patch_size = patch_size, 263 | in_channels = in_channels, 264 | out_channels = self.inner_dim, 265 | ) 266 | self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope) 267 | 268 | self.double_stream_blocks = nn.ModuleList( 269 | [ 270 | HiDreamImageBlock( 271 | dim = self.inner_dim, 272 | num_attention_heads = self.config.num_attention_heads, 273 | attention_head_dim = self.config.attention_head_dim, 274 | num_routed_experts = num_routed_experts, 275 | num_activated_experts = num_activated_experts, 276 | block_type = BlockType.TransformerBlock, 277 | ) 278 | for i in range(self.config.num_layers) 279 | ] 280 | ) 281 | 282 | 
self.single_stream_blocks = nn.ModuleList( 283 | [ 284 | HiDreamImageBlock( 285 | dim = self.inner_dim, 286 | num_attention_heads = self.config.num_attention_heads, 287 | attention_head_dim = self.config.attention_head_dim, 288 | num_routed_experts = num_routed_experts, 289 | num_activated_experts = num_activated_experts, 290 | block_type = BlockType.SingleTransformerBlock 291 | ) 292 | for i in range(self.config.num_single_layers) 293 | ] 294 | ) 295 | 296 | self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels) 297 | 298 | caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ] 299 | caption_projection = [] 300 | for caption_channel in caption_channels: 301 | caption_projection.append(TextProjection(in_features = caption_channel, hidden_size = self.inner_dim)) 302 | self.caption_projection = nn.ModuleList(caption_projection) 303 | self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size) 304 | 305 | self.gradient_checkpointing = False 306 | 307 | def _set_gradient_checkpointing(self, module, value=False): 308 | if hasattr(module, "gradient_checkpointing"): 309 | module.gradient_checkpointing = value 310 | 311 | def expand_timesteps(self, timesteps, batch_size, device): 312 | if not torch.is_tensor(timesteps): 313 | is_mps = device.type == "mps" 314 | if isinstance(timesteps, float): 315 | dtype = torch.float32 if is_mps else torch.float64 316 | else: 317 | dtype = torch.int32 if is_mps else torch.int64 318 | timesteps = torch.tensor([timesteps], dtype=dtype, device=device) 319 | elif len(timesteps.shape) == 0: 320 | timesteps = timesteps[None].to(device) 321 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 322 | timesteps = timesteps.expand(batch_size) 323 | return timesteps 324 | 325 | def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]: 326 | if is_training: 327 | x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size) 328 | else: 329 | x_arr = [] 330 | for i, img_size in enumerate(img_sizes): 331 | pH, pW = img_size 332 | x_arr.append( 333 | einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)', 334 | p1=self.config.patch_size, p2=self.config.patch_size) 335 | ) 336 | x = torch.cat(x_arr, dim=0) 337 | return x 338 | 339 | def patchify(self, x, max_seq, img_sizes=None): 340 | pz2 = self.config.patch_size * self.config.patch_size 341 | if isinstance(x, torch.Tensor): 342 | B, C = x.shape[0], x.shape[1] 343 | device = x.device 344 | dtype = x.dtype 345 | else: 346 | B, C = len(x), x[0].shape[0] 347 | device = x[0].device 348 | dtype = x[0].dtype 349 | x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device) 350 | 351 | if img_sizes is not None: 352 | for i, img_size in enumerate(img_sizes): 353 | x_masks[i, 0:img_size[0] * img_size[1]] = 1 354 | x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2) 355 | elif isinstance(x, torch.Tensor): 356 | pH, pW = x.shape[-2] // self.config.patch_size, x.shape[-1] // self.config.patch_size 357 | x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.config.patch_size, p2=self.config.patch_size) 358 | img_sizes = [[pH, pW]] * B 359 | x_masks = None 360 | else: 361 | raise NotImplementedError 362 | return x, x_masks, img_sizes 363 | 364 | def forward( 365 | self, 366 | hidden_states: torch.Tensor, 367 | timesteps: torch.LongTensor = None, 368 | 
encoder_hidden_states: torch.Tensor = None, 369 | pooled_embeds: torch.Tensor = None, 370 | img_sizes: Optional[List[Tuple[int, int]]] = None, 371 | img_ids: Optional[torch.Tensor] = None, 372 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 373 | return_dict: bool = True, 374 | ): 375 | if joint_attention_kwargs is not None: 376 | joint_attention_kwargs = joint_attention_kwargs.copy() 377 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 378 | else: 379 | lora_scale = 1.0 380 | 381 | if USE_PEFT_BACKEND: 382 | # weight the lora layers by setting `lora_scale` for each PEFT layer 383 | scale_lora_layers(self, lora_scale) 384 | else: 385 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 386 | logger.warning( 387 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 388 | ) 389 | 390 | # spatial forward 391 | batch_size = hidden_states.shape[0] 392 | hidden_states_type = hidden_states.dtype 393 | 394 | # 0. time 395 | timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) 396 | timesteps = self.t_embedder(timesteps, hidden_states_type) 397 | p_embedder = self.p_embedder(pooled_embeds) 398 | adaln_input = timesteps + p_embedder 399 | 400 | hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) 401 | if image_tokens_masks is None: 402 | pH, pW = img_sizes[0] 403 | img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) 404 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] 405 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] 406 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) 407 | hidden_states = self.x_embedder(hidden_states) 408 | 409 | T5_encoder_hidden_states = encoder_hidden_states[0] 410 | encoder_hidden_states = encoder_hidden_states[-1] 411 | encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] 412 | 413 | if self.caption_projection is not None: 414 | new_encoder_hidden_states = [] 415 | for i, enc_hidden_state in enumerate(encoder_hidden_states): 416 | enc_hidden_state = self.caption_projection[i](enc_hidden_state) 417 | enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) 418 | new_encoder_hidden_states.append(enc_hidden_state) 419 | encoder_hidden_states = new_encoder_hidden_states 420 | T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) 421 | T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) 422 | encoder_hidden_states.append(T5_encoder_hidden_states) 423 | 424 | txt_ids = torch.zeros( 425 | batch_size, 426 | encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1], 427 | 3, 428 | device=img_ids.device, dtype=img_ids.dtype 429 | ) 430 | ids = torch.cat((img_ids, txt_ids), dim=1) 431 | rope = self.pe_embedder(ids) 432 | 433 | # 2. 
Blocks 434 | block_id = 0 435 | initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) 436 | initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] 437 | for bid, block in enumerate(self.double_stream_blocks): 438 | cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] 439 | cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) 440 | if self.training and self.gradient_checkpointing: 441 | def create_custom_forward(module, return_dict=None): 442 | def custom_forward(*inputs): 443 | if return_dict is not None: 444 | return module(*inputs, return_dict=return_dict) 445 | else: 446 | return module(*inputs) 447 | return custom_forward 448 | 449 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 450 | hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint( 451 | create_custom_forward(block), 452 | hidden_states, 453 | image_tokens_masks, 454 | cur_encoder_hidden_states, 455 | adaln_input, 456 | rope, 457 | **ckpt_kwargs, 458 | ) 459 | else: 460 | hidden_states, initial_encoder_hidden_states = block( 461 | image_tokens = hidden_states, 462 | image_tokens_masks = image_tokens_masks, 463 | text_tokens = cur_encoder_hidden_states, 464 | adaln_input = adaln_input, 465 | rope = rope, 466 | ) 467 | initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] 468 | block_id += 1 469 | 470 | image_tokens_seq_len = hidden_states.shape[1] 471 | hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) 472 | hidden_states_seq_len = hidden_states.shape[1] 473 | if image_tokens_masks is not None: 474 | encoder_attention_mask_ones = torch.ones( 475 | (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), 476 | device=image_tokens_masks.device, dtype=image_tokens_masks.dtype 477 | ) 478 | image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1) 479 | 480 | for bid, block in enumerate(self.single_stream_blocks): 481 | cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] 482 | hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) 483 | if self.training and self.gradient_checkpointing: 484 | def create_custom_forward(module, return_dict=None): 485 | def custom_forward(*inputs): 486 | if return_dict is not None: 487 | return module(*inputs, return_dict=return_dict) 488 | else: 489 | return module(*inputs) 490 | return custom_forward 491 | 492 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 493 | hidden_states = torch.utils.checkpoint.checkpoint( 494 | create_custom_forward(block), 495 | hidden_states, 496 | image_tokens_masks, 497 | None, 498 | adaln_input, 499 | rope, 500 | **ckpt_kwargs, 501 | ) 502 | else: 503 | hidden_states = block( 504 | image_tokens = hidden_states, 505 | image_tokens_masks = image_tokens_masks, 506 | text_tokens = None, 507 | adaln_input = adaln_input, 508 | rope = rope, 509 | ) 510 | hidden_states = hidden_states[:, :hidden_states_seq_len] 511 | block_id += 1 512 | 513 | hidden_states = hidden_states[:, :image_tokens_seq_len, ...] 
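        # The slice above keeps only the image tokens; the text tokens appended for the
        # single-stream blocks are discarded before the final projection back to patch pixels.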
514 | output = self.final_layer(hidden_states, adaln_input) 515 | output = self.unpatchify(output, img_sizes, self.training) 516 | if image_tokens_masks is not None: 517 | image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len] 518 | 519 | if USE_PEFT_BACKEND: 520 | # remove `lora_scale` from each PEFT layer 521 | unscale_lora_layers(self, lora_scale) 522 | 523 | if not return_dict: 524 | return (output, image_tokens_masks) 525 | return Transformer2DModelOutput(sample=output, mask=image_tokens_masks) 526 | 527 | -------------------------------------------------------------------------------- /hi_diffusers/pipelines/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lum3on/comfyui_HiDream-Sampler/98ad017cac93b782e2af95411e4c10d493ecb841/hi_diffusers/pipelines/.DS_Store -------------------------------------------------------------------------------- /hi_diffusers/pipelines/hidream_image/pipeline_hidream_image.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Callable, Dict, List, Optional, Union 3 | import math 4 | import einops 5 | import torch 6 | from transformers import ( 7 | CLIPTextModelWithProjection, 8 | CLIPTokenizer, 9 | T5EncoderModel, 10 | T5Tokenizer, 11 | LlamaForCausalLM, 12 | PreTrainedTokenizerFast 13 | ) 14 | 15 | from diffusers.image_processor import VaeImageProcessor 16 | from diffusers.loaders import FromSingleFileMixin 17 | from diffusers.models.autoencoders import AutoencoderKL 18 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 19 | from diffusers.utils import ( 20 | USE_PEFT_BACKEND, 21 | is_torch_xla_available, 22 | logging, 23 | ) 24 | from diffusers.utils.torch_utils import randn_tensor 25 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 26 | from .pipeline_output import HiDreamImagePipelineOutput 27 | from ...models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel 28 | from ...schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler 29 | 30 | if is_torch_xla_available(): 31 | import torch_xla.core.xla_model as xm 32 | 33 | XLA_AVAILABLE = True 34 | else: 35 | XLA_AVAILABLE = False 36 | 37 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 38 | 39 | # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift 40 | def calculate_shift( 41 | image_seq_len, 42 | base_seq_len: int = 256, 43 | max_seq_len: int = 4096, 44 | base_shift: float = 0.5, 45 | max_shift: float = 1.15, 46 | ): 47 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 48 | b = base_shift - m * base_seq_len 49 | mu = image_seq_len * m + b 50 | return mu 51 | 52 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 53 | def retrieve_timesteps( 54 | scheduler, 55 | num_inference_steps: Optional[int] = None, 56 | device: Optional[Union[str, torch.device]] = None, 57 | timesteps: Optional[List[int]] = None, 58 | sigmas: Optional[List[float]] = None, 59 | **kwargs, 60 | ): 61 | r""" 62 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 63 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 64 | 65 | Args: 66 | scheduler (`SchedulerMixin`): 67 | The scheduler to get timesteps from. 68 | num_inference_steps (`int`): 69 | The number of diffusion steps used when generating samples with a pre-trained model. 
If used, `timesteps` 70 | must be `None`. 71 | device (`str` or `torch.device`, *optional*): 72 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 73 | timesteps (`List[int]`, *optional*): 74 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 75 | `num_inference_steps` and `sigmas` must be `None`. 76 | sigmas (`List[float]`, *optional*): 77 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 78 | `num_inference_steps` and `timesteps` must be `None`. 79 | 80 | Returns: 81 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 82 | second element is the number of inference steps. 83 | """ 84 | if timesteps is not None and sigmas is not None: 85 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 86 | if timesteps is not None: 87 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 88 | if not accepts_timesteps: 89 | raise ValueError( 90 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 91 | f" timestep schedules. Please check whether you are using the correct scheduler." 92 | ) 93 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 94 | timesteps = scheduler.timesteps 95 | num_inference_steps = len(timesteps) 96 | elif sigmas is not None: 97 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 98 | if not accept_sigmas: 99 | raise ValueError( 100 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 101 | f" sigmas schedules. Please check whether you are using the correct scheduler." 
102 | ) 103 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 104 | timesteps = scheduler.timesteps 105 | num_inference_steps = len(timesteps) 106 | else: 107 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 108 | timesteps = scheduler.timesteps 109 | return timesteps, num_inference_steps 110 | 111 | class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin): 112 | model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae" 113 | _optional_components = ["image_encoder", "feature_extractor"] 114 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 115 | 116 | def __init__( 117 | self, 118 | scheduler: FlowMatchEulerDiscreteScheduler, 119 | vae: AutoencoderKL, 120 | text_encoder: CLIPTextModelWithProjection, 121 | tokenizer: CLIPTokenizer, 122 | text_encoder_2: CLIPTextModelWithProjection, 123 | tokenizer_2: CLIPTokenizer, 124 | text_encoder_3: T5EncoderModel, 125 | tokenizer_3: T5Tokenizer, 126 | text_encoder_4: LlamaForCausalLM, 127 | tokenizer_4: PreTrainedTokenizerFast, 128 | ): 129 | super().__init__() 130 | 131 | self.register_modules( 132 | vae=vae, 133 | text_encoder=text_encoder, 134 | text_encoder_2=text_encoder_2, 135 | text_encoder_3=text_encoder_3, 136 | text_encoder_4=text_encoder_4, 137 | tokenizer=tokenizer, 138 | tokenizer_2=tokenizer_2, 139 | tokenizer_3=tokenizer_3, 140 | tokenizer_4=tokenizer_4, 141 | scheduler=scheduler, 142 | ) 143 | self.vae_scale_factor = ( 144 | 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 145 | ) 146 | # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible 147 | # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this 148 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) 149 | self.default_sample_size = 128 150 | self.tokenizer_4.pad_token = self.tokenizer_4.eos_token 151 | 152 | def _get_t5_prompt_embeds( 153 | self, 154 | prompt: Union[str, List[str]] = None, 155 | num_images_per_prompt: int = 1, 156 | max_sequence_length: int = 128, 157 | device: Optional[torch.device] = None, 158 | dtype: Optional[torch.dtype] = None, 159 | ): 160 | device = device or self._execution_device 161 | dtype = dtype or self.text_encoder_3.dtype 162 | 163 | prompt = [prompt] if isinstance(prompt, str) else prompt 164 | batch_size = len(prompt) 165 | 166 | text_inputs = self.tokenizer_3( 167 | prompt, 168 | padding="max_length", 169 | max_length=min(max_sequence_length, self.tokenizer_3.model_max_length), 170 | truncation=True, 171 | add_special_tokens=True, 172 | return_tensors="pt", 173 | ) 174 | text_input_ids = text_inputs.input_ids 175 | attention_mask = text_inputs.attention_mask 176 | untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids 177 | 178 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 179 | removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1]) 180 | logger.warning( 181 | "The following part of your input was truncated because `max_sequence_length` is set to " 182 | f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}" 183 | ) 184 | 185 | prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] 186 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 187 | _, seq_len, _ = prompt_embeds.shape 188 | 189 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 190 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 191 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 192 | return prompt_embeds 193 | 194 | def _get_clip_prompt_embeds( 195 | self, 196 | tokenizer, 197 | text_encoder, 198 | prompt: Union[str, List[str]], 199 | num_images_per_prompt: int = 1, 200 | max_sequence_length: int = 128, 201 | device: Optional[torch.device] = None, 202 | dtype: Optional[torch.dtype] = None, 203 | ): 204 | device = device or self._execution_device 205 | dtype = dtype or text_encoder.dtype 206 | 207 | prompt = [prompt] if isinstance(prompt, str) else prompt 208 | batch_size = len(prompt) 209 | 210 | text_inputs = tokenizer( 211 | prompt, 212 | padding="max_length", 213 | max_length=min(max_sequence_length, 218), 214 | truncation=True, 215 | return_tensors="pt", 216 | ) 217 | 218 | text_input_ids = text_inputs.input_ids 219 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) 220 | 221 | # Use pooled output of CLIPTextModel 222 | prompt_embeds = prompt_embeds[0] 223 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 224 | 225 | # duplicate text embeddings for each generation per prompt, using mps friendly method 226 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) 227 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 228 | 229 | return prompt_embeds 230 | 231 | def _get_llama3_prompt_embeds( 232 | self, 233 | prompt: Union[str, List[str]] = None, 234 | 
num_images_per_prompt: int = 1, 235 | max_sequence_length: int = 128, 236 | system_prompt: Optional[str] = "", 237 | device: Optional[torch.device] = None, 238 | dtype: Optional[torch.dtype] = None, 239 | ): 240 | device = device or self._execution_device 241 | dtype = dtype or self.text_encoder_4.dtype 242 | prompt = [prompt] if isinstance(prompt, str) else prompt 243 | batch_size = len(prompt) 244 | 245 | # Format prompts with system message 246 | formatted_prompts = [] 247 | for p in prompt: 248 | if system_prompt: 249 | formatted_prompt = f"<|system|>\n{system_prompt}\n<|user|>\n{p}\n<|assistant|>" 250 | else: 251 | formatted_prompt = f"<|user|>\n{p}\n<|assistant|>" 252 | formatted_prompts.append(formatted_prompt) 253 | 254 | # Calculate the actual maximum length being used 255 | actual_max_length = min(max_sequence_length, self.tokenizer_4.model_max_length) 256 | 257 | text_inputs = self.tokenizer_4( 258 | formatted_prompts, 259 | padding="max_length", 260 | max_length=actual_max_length, # Use the calculated length 261 | truncation=True, 262 | add_special_tokens=True, 263 | return_tensors="pt", 264 | ) 265 | 266 | text_input_ids = text_inputs.input_ids 267 | attention_mask = text_inputs.attention_mask 268 | untruncated_ids = self.tokenizer_4(formatted_prompts, padding="longest", return_tensors="pt").input_ids 269 | 270 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 271 | removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, actual_max_length - 1 : -1]) 272 | logger.warning( 273 | "The following part of your input was truncated because the LLaMA max sequence length " 274 | f"is set to {actual_max_length} tokens: {removed_text}" 275 | ) 276 | 277 | # Rest of your function remains unchanged 278 | outputs = self.text_encoder_4( 279 | text_input_ids.to(device), 280 | attention_mask=attention_mask.to(device), 281 | output_hidden_states=True, 282 | output_attentions=True 283 | ) 284 | prompt_embeds = outputs.hidden_states[1:] 285 | prompt_embeds = torch.stack(prompt_embeds, dim=0) 286 | _,_ , seq_len, dim = prompt_embeds.shape 287 | 288 | # duplicate text embeddings and attention mask for each generation per prompt 289 | prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1) 290 | prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim) 291 | 292 | return prompt_embeds 293 | 294 | def encode_prompt( 295 | self, 296 | prompt: Union[str, List[str]], 297 | prompt_2: Union[str, List[str]], 298 | prompt_3: Union[str, List[str]], 299 | prompt_4: Union[str, List[str]], 300 | device: Optional[torch.device] = None, 301 | dtype: Optional[torch.dtype] = None, 302 | num_images_per_prompt: int = 1, 303 | do_classifier_free_guidance: bool = True, 304 | negative_prompt: Optional[Union[str, List[str]]] = None, 305 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 306 | negative_prompt_3: Optional[Union[str, List[str]]] = None, 307 | negative_prompt_4: Optional[Union[str, List[str]]] = None, 308 | prompt_embeds: Optional[List[torch.FloatTensor]] = None, 309 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 310 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 311 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 312 | max_sequence_length: int = 128, 313 | max_sequence_length_clip_l: Optional[int] = None, 314 | max_sequence_length_openclip: Optional[int] = None, 315 | max_sequence_length_t5: Optional[int] = None, 316 | 
max_sequence_length_llama: Optional[int] = None, 317 | lora_scale: Optional[float] = None, 318 | llm_system_prompt: str = "", 319 | clip_l_scale: float = 1.0, 320 | openclip_scale: float = 1.0, 321 | t5_scale: float = 1.0, 322 | llama_scale: float = 1.0, 323 | ): 324 | prompt = [prompt] if isinstance(prompt, str) else prompt 325 | if prompt is not None: 326 | batch_size = len(prompt) 327 | else: 328 | batch_size = prompt_embeds.shape[0] 329 | 330 | # Pass all sequence length parameters to _encode_prompt 331 | prompt_embeds, pooled_prompt_embeds = self._encode_prompt( 332 | prompt = prompt, 333 | prompt_2 = prompt_2, 334 | prompt_3 = prompt_3, 335 | prompt_4 = prompt_4, 336 | device = device, 337 | dtype = dtype, 338 | num_images_per_prompt = num_images_per_prompt, 339 | prompt_embeds = prompt_embeds, 340 | pooled_prompt_embeds = pooled_prompt_embeds, 341 | max_sequence_length = max_sequence_length, 342 | max_sequence_length_clip_l = max_sequence_length_clip_l, 343 | max_sequence_length_openclip = max_sequence_length_openclip, 344 | max_sequence_length_t5 = max_sequence_length_t5, 345 | max_sequence_length_llama = max_sequence_length_llama, 346 | llm_system_prompt=llm_system_prompt, 347 | clip_l_scale=clip_l_scale, 348 | openclip_scale=openclip_scale, 349 | t5_scale=t5_scale, 350 | llama_scale=llama_scale, 351 | ) 352 | 353 | if do_classifier_free_guidance and negative_prompt_embeds is None: 354 | negative_prompt = negative_prompt or "" 355 | negative_prompt_2 = negative_prompt_2 or negative_prompt 356 | negative_prompt_3 = negative_prompt_3 or negative_prompt 357 | negative_prompt_4 = negative_prompt_4 or negative_prompt 358 | 359 | # normalize str to list 360 | negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt 361 | negative_prompt_2 = ( 362 | batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 363 | ) 364 | negative_prompt_3 = ( 365 | batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 366 | ) 367 | negative_prompt_4 = ( 368 | batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4 369 | ) 370 | 371 | if prompt is not None and type(prompt) is not type(negative_prompt): 372 | raise TypeError( 373 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 374 | f" {type(prompt)}." 375 | ) 376 | elif batch_size != len(negative_prompt): 377 | raise ValueError( 378 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 379 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 380 | " the batch size of `prompt`." 
381 | ) 382 | 383 | negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt( 384 | prompt = negative_prompt, 385 | prompt_2 = negative_prompt_2, 386 | prompt_3 = negative_prompt_3, 387 | prompt_4 = negative_prompt_4, 388 | device = device, 389 | dtype = dtype, 390 | num_images_per_prompt = num_images_per_prompt, 391 | prompt_embeds = negative_prompt_embeds, 392 | pooled_prompt_embeds = negative_pooled_prompt_embeds, 393 | max_sequence_length = max_sequence_length, 394 | max_sequence_length_clip_l = max_sequence_length_clip_l, 395 | max_sequence_length_openclip = max_sequence_length_openclip, 396 | max_sequence_length_t5 = max_sequence_length_t5, 397 | max_sequence_length_llama = max_sequence_length_llama, 398 | llm_system_prompt=llm_system_prompt, 399 | clip_l_scale=clip_l_scale, 400 | openclip_scale=openclip_scale, 401 | t5_scale=t5_scale, 402 | llama_scale=llama_scale, 403 | ) 404 | return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds 405 | 406 | def _encode_prompt( 407 | self, 408 | prompt: Union[str, List[str]], 409 | prompt_2: Union[str, List[str]], 410 | prompt_3: Union[str, List[str]], 411 | prompt_4: Union[str, List[str]], 412 | device: Optional[torch.device] = None, 413 | dtype: Optional[torch.dtype] = None, 414 | num_images_per_prompt: int = 1, 415 | prompt_embeds: Optional[List[torch.FloatTensor]] = None, 416 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 417 | max_sequence_length: int = 128, 418 | max_sequence_length_clip_l: Optional[int] = None, 419 | max_sequence_length_openclip: Optional[int] = None, 420 | max_sequence_length_t5: Optional[int] = None, 421 | max_sequence_length_llama: Optional[int] = None, 422 | llm_system_prompt: str = "", 423 | clip_l_scale: float = 1.0, 424 | openclip_scale: float = 1.0, 425 | t5_scale: float = 1.0, 426 | llama_scale: float = 1.0, 427 | ): 428 | device = device or self._execution_device 429 | # Set defaults for individual encoders if not specified 430 | clip_l_length = max_sequence_length_clip_l if max_sequence_length_clip_l is not None else max_sequence_length 431 | openclip_length = max_sequence_length_openclip if max_sequence_length_openclip is not None else max_sequence_length 432 | t5_length = max_sequence_length_t5 if max_sequence_length_t5 is not None else max_sequence_length 433 | llama_length = max_sequence_length_llama if max_sequence_length_llama is not None else max_sequence_length 434 | 435 | if prompt_embeds is None: 436 | prompt_2 = prompt_2 or prompt 437 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 438 | prompt_3 = prompt_3 or prompt 439 | prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 440 | prompt_4 = prompt_4 or prompt 441 | prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4 442 | 443 | # Get CLIP-L embeddings 444 | pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( 445 | self.tokenizer, 446 | self.text_encoder, 447 | prompt = prompt, 448 | num_images_per_prompt = num_images_per_prompt, 449 | max_sequence_length = clip_l_length, # CLIP-L specific length 450 | device = device, 451 | dtype = dtype, 452 | ) 453 | 454 | # Get OpenCLIP embeddings 455 | pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( 456 | self.tokenizer_2, 457 | self.text_encoder_2, 458 | prompt = prompt_2, 459 | num_images_per_prompt = num_images_per_prompt, 460 | max_sequence_length = openclip_length, # OpenCLIP specific length 461 | device = device, 462 | dtype = dtype, 463 | ) 464 | 465 | # Apply clip scaling 
factors 466 | pooled_prompt_embeds_1 = pooled_prompt_embeds_1 * clip_l_scale 467 | pooled_prompt_embeds_2 = pooled_prompt_embeds_2 * openclip_scale 468 | 469 | pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) 470 | 471 | # Get T5 embeddings 472 | t5_prompt_embeds = self._get_t5_prompt_embeds( 473 | prompt = prompt_3, 474 | num_images_per_prompt = num_images_per_prompt, 475 | max_sequence_length = t5_length, # T5 specific length 476 | device = device, 477 | dtype = dtype 478 | ) 479 | # Get LLM embeddings 480 | llama3_prompt_embeds = self._get_llama3_prompt_embeds( 481 | prompt = prompt_4, 482 | num_images_per_prompt = num_images_per_prompt, 483 | max_sequence_length = llama_length, 484 | system_prompt = llm_system_prompt, # Add this parameter 485 | device = device, 486 | dtype = dtype 487 | ) 488 | 489 | # Apply T5 and LLM scaling factors 490 | t5_prompt_embeds = t5_prompt_embeds * t5_scale 491 | llama3_prompt_embeds = llama3_prompt_embeds * llama_scale 492 | 493 | prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] 494 | 495 | return prompt_embeds, pooled_prompt_embeds 496 | 497 | def enable_vae_slicing(self): 498 | r""" 499 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 500 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 501 | """ 502 | self.vae.enable_slicing() 503 | 504 | def disable_vae_slicing(self): 505 | r""" 506 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 507 | computing decoding in one step. 508 | """ 509 | self.vae.disable_slicing() 510 | 511 | def enable_vae_tiling(self): 512 | r""" 513 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 514 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 515 | processing larger images. 516 | """ 517 | self.vae.enable_tiling() 518 | 519 | def disable_vae_tiling(self): 520 | r""" 521 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to 522 | computing decoding in one step. 523 | """ 524 | self.vae.disable_tiling() 525 | 526 | def prepare_latents( 527 | self, 528 | batch_size, 529 | num_channels_latents, 530 | height, 531 | width, 532 | dtype, 533 | device, 534 | generator, 535 | latents=None, 536 | ): 537 | # VAE applies 8x compression on images but we must also account for packing which requires 538 | # latent height and width to be divisible by 2. 
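        # Worked example (added comment): with the default vae_scale_factor of 8 the divisor below
        # is 16, so a requested 1024x1024 image yields a 128x128 latent: 2 * (1024 // 16) = 128.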
539 | height = 2 * (int(height) // (self.vae_scale_factor * 2)) 540 | width = 2 * (int(width) // (self.vae_scale_factor * 2)) 541 | 542 | shape = (batch_size, num_channels_latents, height, width) 543 | 544 | if latents is None: 545 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 546 | else: 547 | if latents.shape != shape: 548 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 549 | latents = latents.to(device) 550 | return latents 551 | 552 | @property 553 | def guidance_scale(self): 554 | return self._guidance_scale 555 | 556 | @property 557 | def do_classifier_free_guidance(self): 558 | return self._guidance_scale > 1 559 | 560 | @property 561 | def joint_attention_kwargs(self): 562 | return self._joint_attention_kwargs 563 | 564 | @property 565 | def num_timesteps(self): 566 | return self._num_timesteps 567 | 568 | @property 569 | def interrupt(self): 570 | return self._interrupt 571 | 572 | @torch.no_grad() 573 | def __call__( 574 | self, 575 | prompt: Union[str, List[str]] = None, 576 | prompt_2: Optional[Union[str, List[str]]] = None, 577 | prompt_3: Optional[Union[str, List[str]]] = None, 578 | prompt_4: Optional[Union[str, List[str]]] = None, 579 | height: Optional[int] = None, 580 | width: Optional[int] = None, 581 | num_inference_steps: int = 50, 582 | sigmas: Optional[List[float]] = None, 583 | guidance_scale: float = 5.0, 584 | negative_prompt: Optional[Union[str, List[str]]] = None, 585 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 586 | negative_prompt_3: Optional[Union[str, List[str]]] = None, 587 | negative_prompt_4: Optional[Union[str, List[str]]] = None, 588 | num_images_per_prompt: Optional[int] = 1, 589 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 590 | latents: Optional[torch.FloatTensor] = None, 591 | prompt_embeds: Optional[torch.FloatTensor] = None, 592 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 593 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 594 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 595 | output_type: Optional[str] = "pil", 596 | return_dict: bool = True, 597 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 598 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 599 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 600 | max_sequence_length: int = 128, 601 | max_sequence_length_clip_l: Optional[int] = None, 602 | max_sequence_length_openclip: Optional[int] = None, 603 | max_sequence_length_t5: Optional[int] = None, 604 | max_sequence_length_llama: Optional[int] = None, 605 | llm_system_prompt: str = "", 606 | clip_l_scale: float = 1.0, 607 | openclip_scale: float = 1.0, 608 | t5_scale: float = 1.0, 609 | llama_scale: float = 1.0, 610 | ): 611 | # disable scaling entirely 612 | height = height or self.default_sample_size * self.vae_scale_factor 613 | width = width or self.default_sample_size * self.vae_scale_factor 614 | division = self.vae_scale_factor * 2 615 | 616 | # Force dimensions to be divisible by division without any area scaling 617 | width = int(width // division * division) 618 | height = int(height // division * division) 619 | 620 | # Ensure minimum dimensions 621 | width = max(width, division) 622 | height = max(height, division) 623 | 624 | self._guidance_scale = guidance_scale 625 | self._joint_attention_kwargs = joint_attention_kwargs 626 | self._interrupt = False 627 | 628 | # 2. 
Define call parameters 629 | if prompt is not None and isinstance(prompt, str): 630 | batch_size = 1 631 | elif prompt is not None and isinstance(prompt, list): 632 | batch_size = len(prompt) 633 | else: 634 | batch_size = prompt_embeds.shape[0] 635 | 636 | device = self._execution_device 637 | 638 | lora_scale = ( 639 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 640 | ) 641 | ( 642 | prompt_embeds, 643 | negative_prompt_embeds, 644 | pooled_prompt_embeds, 645 | negative_pooled_prompt_embeds, 646 | ) = self.encode_prompt( 647 | prompt=prompt, 648 | prompt_2=prompt_2, 649 | prompt_3=prompt_3, 650 | prompt_4=prompt_4, 651 | negative_prompt=negative_prompt, 652 | negative_prompt_2=negative_prompt_2, 653 | negative_prompt_3=negative_prompt_3, 654 | negative_prompt_4=negative_prompt_4, 655 | do_classifier_free_guidance=self.do_classifier_free_guidance, 656 | prompt_embeds=prompt_embeds, 657 | negative_prompt_embeds=negative_prompt_embeds, 658 | pooled_prompt_embeds=pooled_prompt_embeds, 659 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 660 | device=device, 661 | num_images_per_prompt=num_images_per_prompt, 662 | max_sequence_length=max_sequence_length, 663 | lora_scale=lora_scale, 664 | llm_system_prompt=llm_system_prompt, 665 | clip_l_scale=clip_l_scale, 666 | openclip_scale=openclip_scale, 667 | t5_scale=t5_scale, 668 | llama_scale=llama_scale, 669 | ) 670 | 671 | if self.do_classifier_free_guidance: 672 | prompt_embeds_arr = [] 673 | for n, p in zip(negative_prompt_embeds, prompt_embeds): 674 | if len(n.shape) == 3: 675 | prompt_embeds_arr.append(torch.cat([n, p], dim=0)) 676 | else: 677 | prompt_embeds_arr.append(torch.cat([n, p], dim=1)) 678 | prompt_embeds = prompt_embeds_arr 679 | pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) 680 | 681 | # 4. Prepare latent variables 682 | num_channels_latents = self.transformer.config.in_channels 683 | latents = self.prepare_latents( 684 | batch_size * num_images_per_prompt, 685 | num_channels_latents, 686 | height, 687 | width, 688 | pooled_prompt_embeds.dtype, 689 | device, 690 | generator, 691 | latents, 692 | ) 693 | 694 | if latents.shape[-2] != latents.shape[-1]: 695 | B, C, H, W = latents.shape 696 | pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size 697 | 698 | img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1) 699 | img_ids = torch.zeros(pH, pW, 3) 700 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] 701 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] 702 | img_ids = img_ids.reshape(pH * pW, -1) 703 | img_ids_pad = torch.zeros(self.transformer.max_seq, 3) 704 | img_ids_pad[:pH*pW, :] = img_ids 705 | 706 | img_sizes = img_sizes.unsqueeze(0).to(latents.device) 707 | img_ids = img_ids_pad.unsqueeze(0).to(latents.device) 708 | if self.do_classifier_free_guidance: 709 | img_sizes = img_sizes.repeat(2 * B, 1) 710 | img_ids = img_ids.repeat(2 * B, 1, 1) 711 | else: 712 | img_sizes = img_ids = None 713 | 714 | # 5. 
Prepare timesteps 715 | mu = calculate_shift(self.transformer.max_seq) 716 | scheduler_kwargs = {"mu": mu} 717 | if isinstance(self.scheduler, FlowUniPCMultistepScheduler): 718 | self.scheduler.set_timesteps(num_inference_steps, device=device, shift=math.exp(mu)) 719 | timesteps = self.scheduler.timesteps 720 | else: 721 | timesteps, num_inference_steps = retrieve_timesteps( 722 | self.scheduler, 723 | num_inference_steps, 724 | device, 725 | sigmas=sigmas, 726 | **scheduler_kwargs, 727 | ) 728 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 729 | self._num_timesteps = len(timesteps) 730 | 731 | # 6. Denoising loop 732 | with self.progress_bar(total=num_inference_steps) as progress_bar: 733 | for i, t in enumerate(timesteps): 734 | if self.interrupt: 735 | continue 736 | 737 | # expand the latents if we are doing classifier free guidance 738 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 739 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 740 | timestep = t.expand(latent_model_input.shape[0]) 741 | 742 | if latent_model_input.shape[-2] != latent_model_input.shape[-1]: 743 | B, C, H, W = latent_model_input.shape 744 | patch_size = self.transformer.config.patch_size 745 | pH, pW = H // patch_size, W // patch_size 746 | out = torch.zeros( 747 | (B, C, self.transformer.max_seq, patch_size * patch_size), 748 | dtype=latent_model_input.dtype, 749 | device=latent_model_input.device 750 | ) 751 | latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size) 752 | out[:, :, 0:pH*pW] = latent_model_input 753 | latent_model_input = out 754 | 755 | noise_pred = self.transformer( 756 | hidden_states = latent_model_input, 757 | timesteps = timestep, 758 | encoder_hidden_states = prompt_embeds, 759 | pooled_embeds = pooled_prompt_embeds, 760 | img_sizes = img_sizes, 761 | img_ids = img_ids, 762 | return_dict = False, 763 | )[0] 764 | noise_pred = -noise_pred 765 | 766 | # perform guidance 767 | if self.do_classifier_free_guidance: 768 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 769 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 770 | 771 | # compute the previous noisy sample x_t -> x_t-1 772 | latents_dtype = latents.dtype 773 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 774 | 775 | if latents.dtype != latents_dtype: 776 | if torch.backends.mps.is_available(): 777 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 778 | latents = latents.to(latents_dtype) 779 | 780 | if callback_on_step_end is not None: 781 | callback_kwargs = {} 782 | for k in callback_on_step_end_tensor_inputs: 783 | callback_kwargs[k] = locals()[k] 784 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 785 | 786 | latents = callback_outputs.pop("latents", latents) 787 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 788 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 789 | 790 | # call the callback, if provided 791 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 792 | progress_bar.update() 793 | 794 | if XLA_AVAILABLE: 795 | xm.mark_step() 796 | 797 | if output_type == "latent": 798 | image = latents 799 | 800 | else: 801 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 802 | 803 | image = self.vae.decode(latents, return_dict=False)[0] 804 | image = self.image_processor.postprocess(image, output_type=output_type) 805 | 806 | # Offload all models 807 | self.maybe_free_model_hooks() 808 | 809 | if not return_dict: 810 | return (image,) 811 | 812 | return HiDreamImagePipelineOutput(images=image) -------------------------------------------------------------------------------- /hi_diffusers/pipelines/hidream_image/pipeline_hidream_image_to_image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import einops 4 | from typing import Any, Callable, Dict, List, Optional, Union 5 | from .pipeline_hidream_image import HiDreamImagePipeline, calculate_shift, retrieve_timesteps 6 | from .pipeline_output import HiDreamImagePipelineOutput 7 | from ...schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler 8 | from ...schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler 9 | from diffusers.utils import is_torch_xla_available 10 | 11 | if is_torch_xla_available(): 12 | import torch_xla.core.xla_model as xm 13 | XLA_AVAILABLE = True 14 | else: 15 | XLA_AVAILABLE = False 16 | 17 | class HiDreamImageToImagePipeline(HiDreamImagePipeline): 18 | @torch.no_grad() 19 | def __call__( 20 | self, 21 | prompt: Union[str, List[str]] = None, 22 | prompt_2: Optional[Union[str, List[str]]] = None, 23 | prompt_3: Optional[Union[str, List[str]]] = None, 24 | prompt_4: Optional[Union[str, List[str]]] = None, 25 | height: Optional[int] = None, 26 | width: Optional[int] = None, 27 | num_inference_steps: int = 50, 28 | sigmas: Optional[List[float]] = None, 29 | guidance_scale: float = 5.0, 30 | negative_prompt: Optional[Union[str, List[str]]] = None, 31 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 32 | negative_prompt_3: Optional[Union[str, List[str]]] = None, 33 | negative_prompt_4: Optional[Union[str, List[str]]] = None, 34 | num_images_per_prompt: Optional[int] = 1, 35 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 36 | latents: Optional[torch.FloatTensor] = None, 37 | prompt_embeds: Optional[torch.FloatTensor] = None, 38 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 39 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 40 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 41 | output_type: Optional[str] = "pil", 42 | return_dict: bool = True, 43 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 44 | 
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 45 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 46 | max_sequence_length: int = 128, 47 | max_sequence_length_clip_l: Optional[int] = None, 48 | max_sequence_length_openclip: Optional[int] = None, 49 | max_sequence_length_t5: Optional[int] = None, 50 | max_sequence_length_llama: Optional[int] = None, 51 | llm_system_prompt: str = "You are a creative AI assistant that helps create detailed, vivid images based on user descriptions.", 52 | clip_l_scale: float = 1.0, 53 | openclip_scale: float = 1.0, 54 | t5_scale: float = 1.0, 55 | llama_scale: float = 1.0, 56 | # Add img2img specific parameters 57 | init_image: Optional[torch.FloatTensor] = None, 58 | denoising_strength: float = 0.75, 59 | ): 60 | # Handle dimensions 61 | height = height or self.default_sample_size * self.vae_scale_factor 62 | width = width or self.default_sample_size * self.vae_scale_factor 63 | division = self.vae_scale_factor * 2 64 | 65 | # Force dimensions to be divisible by division without any area scaling 66 | width = int(width // division * division) 67 | height = int(height // division * division) 68 | 69 | # Ensure minimum dimensions 70 | width = max(width, division) 71 | height = max(height, division) 72 | 73 | self._guidance_scale = guidance_scale 74 | self._joint_attention_kwargs = joint_attention_kwargs 75 | self._interrupt = False 76 | 77 | # Define call parameters 78 | if prompt is not None and isinstance(prompt, str): 79 | batch_size = 1 80 | elif prompt is not None and isinstance(prompt, list): 81 | batch_size = len(prompt) 82 | else: 83 | batch_size = prompt_embeds.shape[0] 84 | 85 | device = self._execution_device 86 | lora_scale = ( 87 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 88 | ) 89 | 90 | # Encode prompt 91 | ( 92 | prompt_embeds, 93 | negative_prompt_embeds, 94 | pooled_prompt_embeds, 95 | negative_pooled_prompt_embeds, 96 | ) = self.encode_prompt( 97 | prompt=prompt, 98 | prompt_2=prompt_2, 99 | prompt_3=prompt_3, 100 | prompt_4=prompt_4, 101 | negative_prompt=negative_prompt, 102 | negative_prompt_2=negative_prompt_2, 103 | negative_prompt_3=negative_prompt_3, 104 | negative_prompt_4=negative_prompt_4, 105 | do_classifier_free_guidance=self.do_classifier_free_guidance, 106 | prompt_embeds=prompt_embeds, 107 | negative_prompt_embeds=negative_prompt_embeds, 108 | pooled_prompt_embeds=pooled_prompt_embeds, 109 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 110 | device=device, 111 | num_images_per_prompt=num_images_per_prompt, 112 | max_sequence_length=max_sequence_length, 113 | max_sequence_length_clip_l=max_sequence_length_clip_l, 114 | max_sequence_length_openclip=max_sequence_length_openclip, 115 | max_sequence_length_t5=max_sequence_length_t5, 116 | max_sequence_length_llama=max_sequence_length_llama, 117 | llm_system_prompt=llm_system_prompt, 118 | clip_l_scale=clip_l_scale, 119 | openclip_scale=openclip_scale, 120 | t5_scale=t5_scale, 121 | llama_scale=llama_scale, 122 | lora_scale=lora_scale, 123 | ) 124 | 125 | if self.do_classifier_free_guidance: 126 | prompt_embeds_arr = [] 127 | for n, p in zip(negative_prompt_embeds, prompt_embeds): 128 | if len(n.shape) == 3: 129 | prompt_embeds_arr.append(torch.cat([n, p], dim=0)) 130 | else: 131 | prompt_embeds_arr.append(torch.cat([n, p], dim=1)) 132 | prompt_embeds = prompt_embeds_arr 133 | pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) 
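        # Descriptive note (added comment): for classifier-free guidance the negative and positive
        # embeddings are stacked so the transformer evaluates both in a single batched forward pass.
        # 3-D tensors (e.g. the T5 sequence embeddings) are concatenated on dim 0 (the batch axis),
        # while the 4-D LLaMA stack of per-layer hidden states is concatenated on dim 1, which is
        # its batch axis.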
134 | 135 | # Prepare latent variables - this is where we handle img2img 136 | num_channels_latents = self.transformer.config.in_channels 137 | 138 | # If we have an init_image, we want to encode it to latents 139 | if init_image is not None: 140 | # Preprocess the input image to latent representation 141 | init_image = init_image.to(device=device, dtype=self.vae.dtype) 142 | 143 | # Ensure correct shape [B, C, H, W] 144 | # ComfyUI typically provides [B, H, W, C] 145 | if init_image.shape[3] == 3: # [B, H, W, C] 146 | init_image = init_image.permute(0, 3, 1, 2) 147 | 148 | # Scale to [-1, 1] 149 | init_image = 2 * init_image - 1.0 150 | 151 | # Encode the image to latent space 152 | latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) 153 | latents = latents * self.vae.config.scaling_factor 154 | 155 | # If we're working with a batch of 1, repeat for each image_per_prompt 156 | if latents.shape[0] == 1 and batch_size * num_images_per_prompt > 1: 157 | latents = latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1) 158 | else: 159 | # For regular txt2img, prepare random latents 160 | latents = self.prepare_latents( 161 | batch_size * num_images_per_prompt, 162 | num_channels_latents, 163 | height, 164 | width, 165 | pooled_prompt_embeds.dtype, 166 | device, 167 | generator, 168 | latents, 169 | ) 170 | 171 | # Prepare for different aspect ratios 172 | if latents.shape[-2] != latents.shape[-1]: 173 | B, C, H, W = latents.shape 174 | pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size 175 | img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1) 176 | img_ids = torch.zeros(pH, pW, 3) 177 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] 178 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] 179 | img_ids = img_ids.reshape(pH * pW, -1) 180 | img_ids_pad = torch.zeros(self.transformer.max_seq, 3) 181 | img_ids_pad[:pH*pW, :] = img_ids 182 | img_sizes = img_sizes.unsqueeze(0).to(latents.device) 183 | img_ids = img_ids_pad.unsqueeze(0).to(latents.device) 184 | if self.do_classifier_free_guidance: 185 | img_sizes = img_sizes.repeat(2 * B, 1) 186 | img_ids = img_ids.repeat(2 * B, 1, 1) 187 | else: 188 | img_sizes = img_ids = None 189 | 190 | # Prepare timesteps 191 | mu = calculate_shift(self.transformer.max_seq) 192 | scheduler_kwargs = {"mu": mu} 193 | 194 | if isinstance(self.scheduler, FlowUniPCMultistepScheduler): 195 | self.scheduler.set_timesteps(num_inference_steps, device=device, shift=math.exp(mu)) 196 | timesteps = self.scheduler.timesteps 197 | else: 198 | timesteps, num_inference_steps = retrieve_timesteps( 199 | self.scheduler, 200 | num_inference_steps, 201 | device, 202 | sigmas=sigmas, 203 | **scheduler_kwargs, 204 | ) 205 | 206 | # For img2img, we need to modify the timesteps based on denoising_strength 207 | if init_image is not None and denoising_strength > 0.0: 208 | # Calculate the starting timestep based on denoising strength 209 | start_step = int(num_inference_steps * (1.0 - denoising_strength)) 210 | 211 | # Skip steps based on denoising strength 212 | if start_step > 0: 213 | timesteps = timesteps[start_step:] 214 | print(f"Starting denoising from step {start_step}/{num_inference_steps} (strength: {denoising_strength})") 215 | 216 | # Create noise 217 | noise = torch.randn(latents.shape, dtype=latents.dtype, device=device, generator=generator) 218 | 219 | # Get starting timestep 220 | t_start = timesteps[0].unsqueeze(0) 221 | 222 | # Set the scheduler's step index 
for proper noise scaling 223 | self.scheduler._step_index = start_step 224 | 225 | # Apply noise using the appropriate scheduler method 226 | if isinstance(self.scheduler, FlowUniPCMultistepScheduler): 227 | print(f"Using UniPC add_noise with timestep {t_start}") 228 | latents = self.scheduler.add_noise( 229 | original_samples=latents, 230 | noise=noise, 231 | timesteps=t_start 232 | ) 233 | else: # FlashFlowMatchEulerDiscreteScheduler or variants 234 | print(f"Using FlashFlow scale_noise with timestep {t_start}") 235 | latents = self.scheduler.scale_noise( 236 | sample=latents, 237 | timestep=t_start, 238 | noise=noise 239 | ) 240 | 241 | # Denoising loop 242 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 243 | self._num_timesteps = len(timesteps) 244 | 245 | with self.progress_bar(total=len(timesteps)) as progress_bar: 246 | for i, t in enumerate(timesteps): 247 | if self.interrupt: 248 | continue 249 | 250 | # expand the latents if we are doing classifier free guidance 251 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 252 | 253 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 254 | timestep = t.expand(latent_model_input.shape[0]) 255 | 256 | if latent_model_input.shape[-2] != latent_model_input.shape[-1]: 257 | B, C, H, W = latent_model_input.shape 258 | patch_size = self.transformer.config.patch_size 259 | pH, pW = H // patch_size, W // patch_size 260 | out = torch.zeros( 261 | (B, C, self.transformer.max_seq, patch_size * patch_size), 262 | dtype=latent_model_input.dtype, 263 | device=latent_model_input.device 264 | ) 265 | latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size) 266 | out[:, :, 0:pH*pW] = latent_model_input 267 | latent_model_input = out 268 | 269 | noise_pred = self.transformer( 270 | hidden_states = latent_model_input, 271 | timesteps = timestep, 272 | encoder_hidden_states = prompt_embeds, 273 | pooled_embeds = pooled_prompt_embeds, 274 | img_sizes = img_sizes, 275 | img_ids = img_ids, 276 | return_dict = False, 277 | )[0] 278 | 279 | noise_pred = -noise_pred 280 | 281 | # perform guidance 282 | if self.do_classifier_free_guidance: 283 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 284 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 285 | 286 | # compute the previous noisy sample x_t -> x_t-1 287 | latents_dtype = latents.dtype 288 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 289 | 290 | if latents.dtype != latents_dtype: 291 | if torch.backends.mps.is_available(): 292 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug 293 | latents = latents.to(latents_dtype) 294 | 295 | if callback_on_step_end is not None: 296 | callback_kwargs = {} 297 | for k in callback_on_step_end_tensor_inputs: 298 | callback_kwargs[k] = locals()[k] 299 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 300 | latents = callback_outputs.pop("latents", latents) 301 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 302 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 303 | 304 | # call the callback, if provided 305 | progress_bar.update() 306 | 307 | if XLA_AVAILABLE: 308 | xm.mark_step() 309 | 310 | # Post-processing 311 | if output_type == "latent": 312 | image = latents 313 | else: 314 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 315 | image = self.vae.decode(latents, return_dict=False)[0] 316 | image = self.image_processor.postprocess(image, output_type=output_type) 317 | 318 | # Offload all models 319 | self.maybe_free_model_hooks() 320 | 321 | if not return_dict: 322 | return (image,) 323 | 324 | return HiDreamImagePipelineOutput(images=image) -------------------------------------------------------------------------------- /hi_diffusers/pipelines/hidream_image/pipeline_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Union 3 | 4 | import numpy as np 5 | import PIL.Image 6 | 7 | from diffusers.utils import BaseOutput 8 | 9 | 10 | @dataclass 11 | class HiDreamImagePipelineOutput(BaseOutput): 12 | """ 13 | Output class for HiDreamImage pipelines. 14 | 15 | Args: 16 | images (`List[PIL.Image.Image]` or `np.ndarray`) 17 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, 18 | num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 19 | """ 20 | 21 | images: Union[List[PIL.Image.Image], np.ndarray] -------------------------------------------------------------------------------- /hi_diffusers/schedulers/flash_flow_match.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
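#
# Overview (added comment): this module implements a flow-matching Euler sampler. A noisy sample at
# noise level `sigma` is formed as `sigma * noise + (1 - sigma) * x0` (see `scale_noise` below); the
# model output is treated as the flow from data to noise, so `denoised = sample - model_output * sigma`
# recovers the data estimate, which `step` then re-noises with fresh random noise at the next sigma
# in the schedule.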
14 | 15 | import math 16 | from dataclasses import dataclass 17 | from typing import List, Optional, Tuple, Union 18 | 19 | import numpy as np 20 | import torch 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 23 | from diffusers.utils import BaseOutput, is_scipy_available, logging 24 | from diffusers.utils.torch_utils import randn_tensor 25 | 26 | if is_scipy_available(): 27 | import scipy.stats 28 | 29 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 30 | 31 | 32 | @dataclass 33 | class FlashFlowMatchEulerDiscreteSchedulerOutput(BaseOutput): 34 | """ 35 | Output class for the scheduler's `step` function output. 36 | 37 | Args: 38 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 39 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 40 | denoising loop. 41 | """ 42 | 43 | prev_sample: torch.FloatTensor 44 | 45 | 46 | class FlashFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): 47 | """ 48 | Euler scheduler. 49 | 50 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 51 | methods the library implements for all schedulers such as loading and saving. 52 | 53 | Args: 54 | num_train_timesteps (`int`, defaults to 1000): 55 | The number of diffusion steps to train the model. 56 | timestep_spacing (`str`, defaults to `"linspace"`): 57 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and 58 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 59 | shift (`float`, defaults to 1.0): 60 | The shift value for the timestep schedule. 61 | """ 62 | 63 | _compatibles = [] 64 | order = 1 65 | 66 | @register_to_config 67 | def __init__( 68 | self, 69 | num_train_timesteps: int = 1000, 70 | shift: float = 1.0, 71 | use_dynamic_shifting=False, 72 | base_shift: Optional[float] = 0.5, 73 | max_shift: Optional[float] = 1.15, 74 | base_image_seq_len: Optional[int] = 256, 75 | max_image_seq_len: Optional[int] = 4096, 76 | invert_sigmas: bool = False, 77 | use_karras_sigmas: Optional[bool] = False, 78 | use_exponential_sigmas: Optional[bool] = False, 79 | use_beta_sigmas: Optional[bool] = False, 80 | ): 81 | if self.config.use_beta_sigmas and not is_scipy_available(): 82 | raise ImportError("Make sure to install scipy if you want to use beta sigmas.") 83 | if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: 84 | raise ValueError( 85 | "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." 
86 | ) 87 | timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() 88 | timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) 89 | 90 | sigmas = timesteps / num_train_timesteps 91 | if not use_dynamic_shifting: 92 | # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution 93 | sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) 94 | 95 | self.timesteps = sigmas * num_train_timesteps 96 | 97 | self._step_index = None 98 | self._begin_index = None 99 | 100 | self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication 101 | self.sigma_min = self.sigmas[-1].item() 102 | self.sigma_max = self.sigmas[0].item() 103 | 104 | @property 105 | def step_index(self): 106 | """ 107 | The index counter for current timestep. It will increase 1 after each scheduler step. 108 | """ 109 | return self._step_index 110 | 111 | @property 112 | def begin_index(self): 113 | """ 114 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method. 115 | """ 116 | return self._begin_index 117 | 118 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 119 | def set_begin_index(self, begin_index: int = 0): 120 | """ 121 | Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 122 | 123 | Args: 124 | begin_index (`int`): 125 | The begin index for the scheduler. 126 | """ 127 | self._begin_index = begin_index 128 | 129 | def scale_noise( 130 | self, 131 | sample: torch.FloatTensor, 132 | timestep: Union[float, torch.FloatTensor], 133 | noise: Optional[torch.FloatTensor] = None, 134 | ) -> torch.FloatTensor: 135 | """ 136 | Forward process in flow-matching 137 | 138 | Args: 139 | sample (`torch.FloatTensor`): 140 | The input sample. 141 | timestep (`int`, *optional*): 142 | The current timestep in the diffusion chain. 143 | 144 | Returns: 145 | `torch.FloatTensor`: 146 | A scaled input sample. 
147 | """ 148 | # Make sure sigmas and timesteps have the same device and dtype as original_samples 149 | sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) 150 | 151 | if sample.device.type == "mps" and torch.is_floating_point(timestep): 152 | # mps does not support float64 153 | schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) 154 | timestep = timestep.to(sample.device, dtype=torch.float32) 155 | else: 156 | schedule_timesteps = self.timesteps.to(sample.device) 157 | timestep = timestep.to(sample.device) 158 | 159 | # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index 160 | if self.begin_index is None: 161 | step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] 162 | elif self.step_index is not None: 163 | # add_noise is called after first denoising step (for inpainting) 164 | step_indices = [self.step_index] * timestep.shape[0] 165 | else: 166 | # add noise is called before first denoising step to create initial latent(img2img) 167 | step_indices = [self.begin_index] * timestep.shape[0] 168 | 169 | sigma = sigmas[step_indices].flatten() 170 | while len(sigma.shape) < len(sample.shape): 171 | sigma = sigma.unsqueeze(-1) 172 | 173 | sample = sigma * noise + (1.0 - sigma) * sample 174 | 175 | return sample 176 | 177 | def _sigma_to_t(self, sigma): 178 | return sigma * self.config.num_train_timesteps 179 | 180 | def time_shift(self, mu: float, sigma: float, t: torch.Tensor): 181 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 182 | 183 | def set_timesteps( 184 | self, 185 | num_inference_steps: int = None, 186 | device: Union[str, torch.device] = None, 187 | sigmas: Optional[List[float]] = None, 188 | mu: Optional[float] = None, 189 | ): 190 | """ 191 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 192 | 193 | Args: 194 | num_inference_steps (`int`): 195 | The number of diffusion steps used when generating samples with a pre-trained model. 196 | device (`str` or `torch.device`, *optional*): 197 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
198 | """ 199 | if self.config.use_dynamic_shifting and mu is None: 200 | raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") 201 | 202 | if sigmas is None: 203 | timesteps = np.linspace( 204 | self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps 205 | ) 206 | 207 | sigmas = timesteps / self.config.num_train_timesteps 208 | else: 209 | sigmas = np.array(sigmas).astype(np.float32) 210 | num_inference_steps = len(sigmas) 211 | self.num_inference_steps = num_inference_steps 212 | 213 | if self.config.use_dynamic_shifting: 214 | sigmas = self.time_shift(mu, 1.0, sigmas) 215 | else: 216 | sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) 217 | 218 | if self.config.use_karras_sigmas: 219 | sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) 220 | 221 | elif self.config.use_exponential_sigmas: 222 | sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) 223 | 224 | elif self.config.use_beta_sigmas: 225 | sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) 226 | 227 | sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) 228 | timesteps = sigmas * self.config.num_train_timesteps 229 | 230 | if self.config.invert_sigmas: 231 | sigmas = 1.0 - sigmas 232 | timesteps = sigmas * self.config.num_train_timesteps 233 | sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) 234 | else: 235 | sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) 236 | 237 | self.timesteps = timesteps.to(device=device) 238 | self.sigmas = sigmas 239 | self._step_index = None 240 | self._begin_index = None 241 | 242 | def index_for_timestep(self, timestep, schedule_timesteps=None): 243 | if schedule_timesteps is None: 244 | schedule_timesteps = self.timesteps 245 | 246 | indices = (schedule_timesteps == timestep).nonzero() 247 | 248 | # The sigma index that is taken for the **very** first `step` 249 | # is always the second index (or the last index if there is only 1) 250 | # This way we can ensure we don't accidentally skip a sigma in 251 | # case we start in the middle of the denoising schedule (e.g. for image-to-image) 252 | pos = 1 if len(indices) > 1 else 0 253 | 254 | return indices[pos].item() 255 | 256 | def _init_step_index(self, timestep): 257 | if self.begin_index is None: 258 | if isinstance(timestep, torch.Tensor): 259 | timestep = timestep.to(self.timesteps.device) 260 | self._step_index = self.index_for_timestep(timestep) 261 | else: 262 | self._step_index = self._begin_index 263 | 264 | def step( 265 | self, 266 | model_output: torch.FloatTensor, 267 | timestep: Union[float, torch.FloatTensor], 268 | sample: torch.FloatTensor, 269 | s_churn: float = 0.0, 270 | s_tmin: float = 0.0, 271 | s_tmax: float = float("inf"), 272 | s_noise: float = 1.0, 273 | generator: Optional[torch.Generator] = None, 274 | return_dict: bool = True, 275 | ) -> Union[FlashFlowMatchEulerDiscreteSchedulerOutput, Tuple]: 276 | """ 277 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 278 | process from the learned model outputs (most often the predicted noise). 279 | 280 | Args: 281 | model_output (`torch.FloatTensor`): 282 | The direct output from learned diffusion model. 283 | timestep (`float`): 284 | The current discrete timestep in the diffusion chain. 
285 | sample (`torch.FloatTensor`): 286 | A current instance of a sample created by the diffusion process. 287 | s_churn (`float`): 288 | s_tmin (`float`): 289 | s_tmax (`float`): 290 | s_noise (`float`, defaults to 1.0): 291 | Scaling factor for noise added to the sample. 292 | generator (`torch.Generator`, *optional*): 293 | A random number generator. 294 | return_dict (`bool`): 295 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 296 | tuple. 297 | 298 | Returns: 299 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 300 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 301 | returned, otherwise a tuple is returned where the first element is the sample tensor. 302 | """ 303 | 304 | if ( 305 | isinstance(timestep, int) 306 | or isinstance(timestep, torch.IntTensor) 307 | or isinstance(timestep, torch.LongTensor) 308 | ): 309 | raise ValueError( 310 | ( 311 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 312 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 313 | " one of the `scheduler.timesteps` as a timestep." 314 | ), 315 | ) 316 | 317 | if self.step_index is None: 318 | self._init_step_index(timestep) 319 | 320 | # Upcast to avoid precision issues when computing prev_sample 321 | 322 | sigma = self.sigmas[self.step_index] 323 | 324 | # Upcast to avoid precision issues when computing prev_sample 325 | sample = sample.to(torch.float32) 326 | 327 | denoised = sample - model_output * sigma 328 | 329 | if self.step_index < self.num_inference_steps - 1: 330 | sigma_next = self.sigmas[self.step_index + 1] 331 | noise = randn_tensor( 332 | model_output.shape, 333 | generator=generator, 334 | device=model_output.device, 335 | dtype=denoised.dtype, 336 | ) 337 | sample = sigma_next * noise + (1.0 - sigma_next) * denoised 338 | 339 | self._step_index += 1 340 | sample = sample.to(model_output.dtype) 341 | 342 | if not return_dict: 343 | return (sample,) 344 | 345 | return FlashFlowMatchEulerDiscreteSchedulerOutput(prev_sample=sample) 346 | 347 | # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras 348 | def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: 349 | """Constructs the noise schedule of Karras et al. 
(2022).""" 350 | 351 | # Hack to make sure that other schedulers which copy this function don't break 352 | # TODO: Add this logic to the other schedulers 353 | if hasattr(self.config, "sigma_min"): 354 | sigma_min = self.config.sigma_min 355 | else: 356 | sigma_min = None 357 | 358 | if hasattr(self.config, "sigma_max"): 359 | sigma_max = self.config.sigma_max 360 | else: 361 | sigma_max = None 362 | 363 | sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() 364 | sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() 365 | 366 | rho = 7.0 # 7.0 is the value used in the paper 367 | ramp = np.linspace(0, 1, num_inference_steps) 368 | min_inv_rho = sigma_min ** (1 / rho) 369 | max_inv_rho = sigma_max ** (1 / rho) 370 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 371 | return sigmas 372 | 373 | # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential 374 | def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: 375 | """Constructs an exponential noise schedule.""" 376 | 377 | # Hack to make sure that other schedulers which copy this function don't break 378 | # TODO: Add this logic to the other schedulers 379 | if hasattr(self.config, "sigma_min"): 380 | sigma_min = self.config.sigma_min 381 | else: 382 | sigma_min = None 383 | 384 | if hasattr(self.config, "sigma_max"): 385 | sigma_max = self.config.sigma_max 386 | else: 387 | sigma_max = None 388 | 389 | sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() 390 | sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() 391 | 392 | sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) 393 | return sigmas 394 | 395 | # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta 396 | def _convert_to_beta( 397 | self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 398 | ) -> torch.Tensor: 399 | """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" 400 | 401 | # Hack to make sure that other schedulers which copy this function don't break 402 | # TODO: Add this logic to the other schedulers 403 | if hasattr(self.config, "sigma_min"): 404 | sigma_min = self.config.sigma_min 405 | else: 406 | sigma_min = None 407 | 408 | if hasattr(self.config, "sigma_max"): 409 | sigma_max = self.config.sigma_max 410 | else: 411 | sigma_max = None 412 | 413 | sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() 414 | sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() 415 | 416 | sigmas = np.array( 417 | [ 418 | sigma_min + (ppf * (sigma_max - sigma_min)) 419 | for ppf in [ 420 | scipy.stats.beta.ppf(timestep, alpha, beta) 421 | for timestep in 1 - np.linspace(0, 1, num_inference_steps) 422 | ] 423 | ] 424 | ) 425 | return sigmas 426 | 427 | def __len__(self): 428 | return self.config.num_train_timesteps 429 | -------------------------------------------------------------------------------- /hi_diffusers/schedulers/fm_solvers_unipc.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py 2 | # Convert unipc for flow matching 3 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
4 | 5 | import math 6 | from typing import List, Optional, Tuple, Union 7 | 8 | import numpy as np 9 | import torch 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, 12 | SchedulerMixin, 13 | SchedulerOutput) 14 | from diffusers.utils import deprecate, is_scipy_available 15 | 16 | if is_scipy_available(): 17 | import scipy.stats 18 | 19 | 20 | class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): 21 | """ 22 | `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. 23 | 24 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 25 | methods the library implements for all schedulers such as loading and saving. 26 | 27 | Args: 28 | num_train_timesteps (`int`, defaults to 1000): 29 | The number of diffusion steps to train the model. 30 | solver_order (`int`, default `2`): 31 | The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` 32 | due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for 33 | unconditional sampling. 34 | prediction_type (`str`, defaults to "flow_prediction"): 35 | Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts 36 | the flow of the diffusion process. 37 | thresholding (`bool`, defaults to `False`): 38 | Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such 39 | as Stable Diffusion. 40 | dynamic_thresholding_ratio (`float`, defaults to 0.995): 41 | The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. 42 | sample_max_value (`float`, defaults to 1.0): 43 | The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. 44 | predict_x0 (`bool`, defaults to `True`): 45 | Whether to use the updating algorithm on the predicted x0. 46 | solver_type (`str`, default `bh2`): 47 | Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` 48 | otherwise. 49 | lower_order_final (`bool`, default `True`): 50 | Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can 51 | stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. 52 | disable_corrector (`list`, default `[]`): 53 | Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` 54 | and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is 55 | usually disabled during the first few steps. 56 | solver_p (`SchedulerMixin`, default `None`): 57 | Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. 58 | use_karras_sigmas (`bool`, *optional*, defaults to `False`): 59 | Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, 60 | the sigmas are determined according to a sequence of noise levels {σi}. 61 | use_exponential_sigmas (`bool`, *optional*, defaults to `False`): 62 | Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. 63 | timestep_spacing (`str`, defaults to `"linspace"`): 64 | The way the timesteps should be scaled. 
Refer to Table 2 of the [Common Diffusion Noise Schedules and 65 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 66 | steps_offset (`int`, defaults to 0): 67 | An offset added to the inference steps, as required by some model families. 68 | final_sigmas_type (`str`, defaults to `"zero"`): 69 | The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final 70 | sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 71 | """ 72 | 73 | _compatibles = [e.name for e in KarrasDiffusionSchedulers] 74 | order = 1 75 | 76 | @register_to_config 77 | def __init__( 78 | self, 79 | num_train_timesteps: int = 1000, 80 | solver_order: int = 2, 81 | prediction_type: str = "flow_prediction", 82 | shift: Optional[float] = 1.0, 83 | use_dynamic_shifting=False, 84 | thresholding: bool = False, 85 | dynamic_thresholding_ratio: float = 0.995, 86 | sample_max_value: float = 1.0, 87 | predict_x0: bool = True, 88 | solver_type: str = "bh2", 89 | lower_order_final: bool = True, 90 | disable_corrector: List[int] = [], 91 | solver_p: SchedulerMixin = None, 92 | timestep_spacing: str = "linspace", 93 | steps_offset: int = 0, 94 | final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" 95 | ): 96 | 97 | if solver_type not in ["bh1", "bh2"]: 98 | if solver_type in ["midpoint", "heun", "logrho"]: 99 | self.register_to_config(solver_type="bh2") 100 | else: 101 | raise NotImplementedError( 102 | f"{solver_type} is not implemented for {self.__class__}") 103 | 104 | self.predict_x0 = predict_x0 105 | # setable values 106 | self.num_inference_steps = None 107 | alphas = np.linspace(1, 1 / num_train_timesteps, 108 | num_train_timesteps)[::-1].copy() 109 | sigmas = 1.0 - alphas 110 | sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) 111 | 112 | if not use_dynamic_shifting: 113 | # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution 114 | sigmas = shift * sigmas / (1 + 115 | (shift - 1) * sigmas) # pyright: ignore 116 | 117 | self.sigmas = sigmas 118 | self.timesteps = sigmas * num_train_timesteps 119 | 120 | self.model_outputs = [None] * solver_order 121 | self.timestep_list = [None] * solver_order 122 | self.lower_order_nums = 0 123 | self.disable_corrector = disable_corrector 124 | self.solver_p = solver_p 125 | self.last_sample = None 126 | self._step_index = None 127 | self._begin_index = None 128 | 129 | self.sigmas = self.sigmas.to( 130 | "cpu") # to avoid too much CPU/GPU communication 131 | self.sigma_min = self.sigmas[-1].item() 132 | self.sigma_max = self.sigmas[0].item() 133 | 134 | @property 135 | def step_index(self): 136 | """ 137 | The index counter for current timestep. It will increase 1 after each scheduler step. 138 | """ 139 | return self._step_index 140 | 141 | @property 142 | def begin_index(self): 143 | """ 144 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method. 145 | """ 146 | return self._begin_index 147 | 148 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 149 | def set_begin_index(self, begin_index: int = 0): 150 | """ 151 | Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 152 | 153 | Args: 154 | begin_index (`int`): 155 | The begin index for the scheduler. 
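        If the begin index is never set, `_init_step_index` falls back to locating the current
        timestep in `self.timesteps` via `index_for_timestep`. Example (illustrative):

            # start denoising part-way through the schedule, e.g. for image-to-image
            scheduler.set_begin_index(10)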
156 | """ 157 | self._begin_index = begin_index 158 | 159 | # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps 160 | def set_timesteps( 161 | self, 162 | num_inference_steps: Union[int, None] = None, 163 | device: Union[str, torch.device] = None, 164 | sigmas: Optional[List[float]] = None, 165 | mu: Optional[Union[float, None]] = None, 166 | shift: Optional[Union[float, None]] = None, 167 | ): 168 | """ 169 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 170 | Args: 171 | num_inference_steps (`int`): 172 | Total number of the spacing of the time steps. 173 | device (`str` or `torch.device`, *optional*): 174 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 175 | """ 176 | 177 | if self.config.use_dynamic_shifting and mu is None: 178 | raise ValueError( 179 | " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" 180 | ) 181 | 182 | if sigmas is None: 183 | sigmas = np.linspace(self.sigma_max, self.sigma_min, 184 | num_inference_steps + 185 | 1).copy()[:-1] # pyright: ignore 186 | 187 | if self.config.use_dynamic_shifting: 188 | sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore 189 | else: 190 | if shift is None: 191 | shift = self.config.shift 192 | sigmas = shift * sigmas / (1 + 193 | (shift - 1) * sigmas) # pyright: ignore 194 | 195 | if self.config.final_sigmas_type == "sigma_min": 196 | sigma_last = ((1 - self.alphas_cumprod[0]) / 197 | self.alphas_cumprod[0])**0.5 198 | elif self.config.final_sigmas_type == "zero": 199 | sigma_last = 0 200 | else: 201 | raise ValueError( 202 | f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" 203 | ) 204 | 205 | timesteps = sigmas * self.config.num_train_timesteps 206 | sigmas = np.concatenate([sigmas, [sigma_last] 207 | ]).astype(np.float32) # pyright: ignore 208 | 209 | self.sigmas = torch.from_numpy(sigmas) 210 | self.timesteps = torch.from_numpy(timesteps).to( 211 | device=device, dtype=torch.int64) 212 | 213 | self.num_inference_steps = len(timesteps) 214 | 215 | self.model_outputs = [ 216 | None, 217 | ] * self.config.solver_order 218 | self.lower_order_nums = 0 219 | self.last_sample = None 220 | if self.solver_p: 221 | self.solver_p.set_timesteps(self.num_inference_steps, device=device) 222 | 223 | # add an index counter for schedulers that allow duplicated timesteps 224 | self._step_index = None 225 | self._begin_index = None 226 | self.sigmas = self.sigmas.to( 227 | "cpu") # to avoid too much CPU/GPU communication 228 | 229 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample 230 | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: 231 | """ 232 | "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the 233 | prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by 234 | s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing 235 | pixels from saturation at each step. We find that dynamic thresholding results in significantly better 236 | photorealism as well as better image-text alignment, especially when using very large guidance weights." 
237 |
238 | https://arxiv.org/abs/2205.11487
239 | """
240 | dtype = sample.dtype
241 | batch_size, channels, *remaining_dims = sample.shape
242 |
243 | if dtype not in (torch.float32, torch.float64):
244 | sample = sample.float(
245 | ) # upcast for quantile calculation, and clamp not implemented for cpu half
246 |
247 | # Flatten sample for doing quantile calculation along each image
248 | sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
249 |
250 | abs_sample = sample.abs() # "a certain percentile absolute pixel value"
251 |
252 | s = torch.quantile(
253 | abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
254 | s = torch.clamp(
255 | s, min=1, max=self.config.sample_max_value
256 | ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
257 | s = s.unsqueeze(
258 | 1) # (batch_size, 1) because clamp will broadcast along dim=0
259 | sample = torch.clamp(
260 | sample, -s, s
261 | ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
262 |
263 | sample = sample.reshape(batch_size, channels, *remaining_dims)
264 | sample = sample.to(dtype)
265 |
266 | return sample
267 |
268 | # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
269 | def _sigma_to_t(self, sigma):
270 | return sigma * self.config.num_train_timesteps
271 |
272 | def _sigma_to_alpha_sigma_t(self, sigma):
273 | return 1 - sigma, sigma
274 |
275 | # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
276 | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
277 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
278 |
279 | def convert_model_output(
280 | self,
281 | model_output: torch.Tensor,
282 | *args,
283 | sample: torch.Tensor = None,
284 | **kwargs,
285 | ) -> torch.Tensor:
286 | r"""
287 | Convert the model output to the corresponding type the UniPC algorithm needs.
288 |
289 | Args:
290 | model_output (`torch.Tensor`):
291 | The direct output from the learned diffusion model.
292 | timestep (`int`):
293 | The current discrete timestep in the diffusion chain.
294 | sample (`torch.Tensor`):
295 | A current instance of a sample created by the diffusion process.
296 |
297 | Returns:
298 | `torch.Tensor`:
299 | The converted model output.
300 | """
301 | timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
302 | if sample is None:
303 | if len(args) > 1:
304 | sample = args[1]
305 | else:
306 | raise ValueError(
307 | "missing `sample` as a required keyword argument")
308 | if timestep is not None:
309 | deprecate(
310 | "timesteps",
311 | "1.0.0",
312 | "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
313 | )
314 |
315 | sigma = self.sigmas[self.step_index]
316 | alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
317 |
318 | if self.predict_x0:
319 | if self.config.prediction_type == "flow_prediction":
320 | sigma_t = self.sigmas[self.step_index]
321 | x0_pred = sample - sigma_t * model_output
322 | else:
323 | raise ValueError(
324 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
325 | " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
326 | )
327 |
328 | if self.config.thresholding:
329 | x0_pred = self._threshold_sample(x0_pred)
330 |
331 | return x0_pred
332 | else:
333 | if self.config.prediction_type == "flow_prediction":
334 | sigma_t = self.sigmas[self.step_index]
335 | epsilon = sample - (1 - sigma_t) * model_output
336 | else:
337 | raise ValueError(
338 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
339 | " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
340 | )
341 |
342 | if self.config.thresholding:
343 | sigma_t = self.sigmas[self.step_index]
344 | x0_pred = sample - sigma_t * model_output
345 | x0_pred = self._threshold_sample(x0_pred)
346 | epsilon = model_output + x0_pred
347 |
348 | return epsilon
349 |
350 | def multistep_uni_p_bh_update(
351 | self,
352 | model_output: torch.Tensor,
353 | *args,
354 | sample: torch.Tensor = None,
355 | order: int = None, # pyright: ignore
356 | **kwargs,
357 | ) -> torch.Tensor:
358 | """
359 | One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
360 |
361 | Args:
362 | model_output (`torch.Tensor`):
363 | The direct output from the learned diffusion model at the current timestep.
364 | prev_timestep (`int`):
365 | The previous discrete timestep in the diffusion chain.
366 | sample (`torch.Tensor`):
367 | A current instance of a sample created by the diffusion process.
368 | order (`int`):
369 | The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
370 |
371 | Returns:
372 | `torch.Tensor`:
373 | The sample tensor at the previous timestep.
374 | """
375 | prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
376 | "prev_timestep", None)
377 | if sample is None:
378 | if len(args) > 1:
379 | sample = args[1]
380 | else:
381 | raise ValueError(
382 | "missing `sample` as a required keyword argument")
383 | if order is None:
384 | if len(args) > 2:
385 | order = args[2]
386 | else:
387 | raise ValueError(
388 | "missing `order` as a required keyword argument")
389 | if prev_timestep is not None:
390 | deprecate(
391 | "prev_timestep",
392 | "1.0.0",
393 | "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
394 | )
395 | model_output_list = self.model_outputs
396 |
397 | s0 = self.timestep_list[-1]
398 | m0 = model_output_list[-1]
399 | x = sample
400 |
401 | if self.solver_p:
402 | x_t = self.solver_p.step(model_output, s0, x).prev_sample
403 | return x_t
404 |
405 | sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
406 | self.step_index] # pyright: ignore
407 | alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
408 | alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
409 |
410 | lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
411 | lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
412 |
413 | h = lambda_t - lambda_s0
414 | device = sample.device
415 |
416 | rks = []
417 | D1s = []
418 | for i in range(1, order):
419 | si = self.step_index - i # pyright: ignore
420 | mi = model_output_list[-(i + 1)]
421 | alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
422 | lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
423 | rk = (lambda_si - lambda_s0) / h
424 | rks.append(rk)
425 | D1s.append((mi - m0) / rk) # pyright: ignore
426 |
427 | rks.append(1.0)
428 | rks = torch.tensor(rks, device=device)
429 |
430 | R = []
431 | b = []
432 |
433 | hh = -h if self.predict_x0 else h
434 | h_phi_1 =
torch.expm1(hh) # h\phi_1(h) = e^h - 1 435 | h_phi_k = h_phi_1 / hh - 1 436 | 437 | factorial_i = 1 438 | 439 | if self.config.solver_type == "bh1": 440 | B_h = hh 441 | elif self.config.solver_type == "bh2": 442 | B_h = torch.expm1(hh) 443 | else: 444 | raise NotImplementedError() 445 | 446 | for i in range(1, order + 1): 447 | R.append(torch.pow(rks, i - 1)) 448 | b.append(h_phi_k * factorial_i / B_h) 449 | factorial_i *= i + 1 450 | h_phi_k = h_phi_k / hh - 1 / factorial_i 451 | 452 | R = torch.stack(R) 453 | b = torch.tensor(b, device=device) 454 | 455 | if len(D1s) > 0: 456 | D1s = torch.stack(D1s, dim=1) # (B, K) 457 | # for order 2, we use a simplified version 458 | if order == 2: 459 | rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) 460 | else: 461 | rhos_p = torch.linalg.solve(R[:-1, :-1], 462 | b[:-1]).to(device).to(x.dtype) 463 | else: 464 | D1s = None 465 | 466 | if self.predict_x0: 467 | x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 468 | if D1s is not None: 469 | pred_res = torch.einsum("k,bkc...->bc...", rhos_p, 470 | D1s) # pyright: ignore 471 | else: 472 | pred_res = 0 473 | x_t = x_t_ - alpha_t * B_h * pred_res 474 | else: 475 | x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 476 | if D1s is not None: 477 | pred_res = torch.einsum("k,bkc...->bc...", rhos_p, 478 | D1s) # pyright: ignore 479 | else: 480 | pred_res = 0 481 | x_t = x_t_ - sigma_t * B_h * pred_res 482 | 483 | x_t = x_t.to(x.dtype) 484 | return x_t 485 | 486 | def multistep_uni_c_bh_update( 487 | self, 488 | this_model_output: torch.Tensor, 489 | *args, 490 | last_sample: torch.Tensor = None, 491 | this_sample: torch.Tensor = None, 492 | order: int = None, # pyright: ignore 493 | **kwargs, 494 | ) -> torch.Tensor: 495 | """ 496 | One step for the UniC (B(h) version). 497 | 498 | Args: 499 | this_model_output (`torch.Tensor`): 500 | The model outputs at `x_t`. 501 | this_timestep (`int`): 502 | The current timestep `t`. 503 | last_sample (`torch.Tensor`): 504 | The generated sample before the last predictor `x_{t-1}`. 505 | this_sample (`torch.Tensor`): 506 | The generated sample after the last predictor `x_{t}`. 507 | order (`int`): 508 | The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. 509 | 510 | Returns: 511 | `torch.Tensor`: 512 | The corrected sample tensor at the current timestep. 
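        Note:
            The corrector refines the sample produced by the preceding `multistep_uni_p_bh_update`
            call, using the fresh model output evaluated at that predicted sample. The corrector
            weights `rhos_c` come from solving the small linear system `R @ rhos_c = b` built from
            the stored model-output differences (a fixed weight of 0.5 is used when `order == 1`).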
513 | """ 514 | this_timestep = args[0] if len(args) > 0 else kwargs.pop( 515 | "this_timestep", None) 516 | if last_sample is None: 517 | if len(args) > 1: 518 | last_sample = args[1] 519 | else: 520 | raise ValueError( 521 | " missing`last_sample` as a required keyward argument") 522 | if this_sample is None: 523 | if len(args) > 2: 524 | this_sample = args[2] 525 | else: 526 | raise ValueError( 527 | " missing`this_sample` as a required keyward argument") 528 | if order is None: 529 | if len(args) > 3: 530 | order = args[3] 531 | else: 532 | raise ValueError( 533 | " missing`order` as a required keyward argument") 534 | if this_timestep is not None: 535 | deprecate( 536 | "this_timestep", 537 | "1.0.0", 538 | "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", 539 | ) 540 | 541 | model_output_list = self.model_outputs 542 | 543 | m0 = model_output_list[-1] 544 | x = last_sample 545 | x_t = this_sample 546 | model_t = this_model_output 547 | 548 | sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ 549 | self.step_index - 1] # pyright: ignore 550 | alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) 551 | alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) 552 | 553 | lambda_t = torch.log(alpha_t) - torch.log(sigma_t) 554 | lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) 555 | 556 | h = lambda_t - lambda_s0 557 | device = this_sample.device 558 | 559 | rks = [] 560 | D1s = [] 561 | for i in range(1, order): 562 | si = self.step_index - (i + 1) # pyright: ignore 563 | mi = model_output_list[-(i + 1)] 564 | alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) 565 | lambda_si = torch.log(alpha_si) - torch.log(sigma_si) 566 | rk = (lambda_si - lambda_s0) / h 567 | rks.append(rk) 568 | D1s.append((mi - m0) / rk) # pyright: ignore 569 | 570 | rks.append(1.0) 571 | rks = torch.tensor(rks, device=device) 572 | 573 | R = [] 574 | b = [] 575 | 576 | hh = -h if self.predict_x0 else h 577 | h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 578 | h_phi_k = h_phi_1 / hh - 1 579 | 580 | factorial_i = 1 581 | 582 | if self.config.solver_type == "bh1": 583 | B_h = hh 584 | elif self.config.solver_type == "bh2": 585 | B_h = torch.expm1(hh) 586 | else: 587 | raise NotImplementedError() 588 | 589 | for i in range(1, order + 1): 590 | R.append(torch.pow(rks, i - 1)) 591 | b.append(h_phi_k * factorial_i / B_h) 592 | factorial_i *= i + 1 593 | h_phi_k = h_phi_k / hh - 1 / factorial_i 594 | 595 | R = torch.stack(R) 596 | b = torch.tensor(b, device=device) 597 | 598 | if len(D1s) > 0: 599 | D1s = torch.stack(D1s, dim=1) 600 | else: 601 | D1s = None 602 | 603 | # for order 1, we use a simplified version 604 | if order == 1: 605 | rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) 606 | else: 607 | rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) 608 | 609 | if self.predict_x0: 610 | x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 611 | if D1s is not None: 612 | corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) 613 | else: 614 | corr_res = 0 615 | D1_t = model_t - m0 616 | x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) 617 | else: 618 | x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 619 | if D1s is not None: 620 | corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) 621 | else: 622 | corr_res = 0 623 | D1_t = model_t - m0 624 | x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) 625 | x_t = x_t.to(x.dtype) 626 | 
return x_t 627 | 628 | def index_for_timestep(self, timestep, schedule_timesteps=None): 629 | if schedule_timesteps is None: 630 | schedule_timesteps = self.timesteps 631 | 632 | indices = (schedule_timesteps == timestep).nonzero() 633 | 634 | # The sigma index that is taken for the **very** first `step` 635 | # is always the second index (or the last index if there is only 1) 636 | # This way we can ensure we don't accidentally skip a sigma in 637 | # case we start in the middle of the denoising schedule (e.g. for image-to-image) 638 | pos = 1 if len(indices) > 1 else 0 639 | 640 | return indices[pos].item() 641 | 642 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index 643 | def _init_step_index(self, timestep): 644 | """ 645 | Initialize the step_index counter for the scheduler. 646 | """ 647 | 648 | if self.begin_index is None: 649 | if isinstance(timestep, torch.Tensor): 650 | timestep = timestep.to(self.timesteps.device) 651 | self._step_index = self.index_for_timestep(timestep) 652 | else: 653 | self._step_index = self._begin_index 654 | 655 | def step(self, 656 | model_output: torch.Tensor, 657 | timestep: Union[int, torch.Tensor], 658 | sample: torch.Tensor, 659 | return_dict: bool = True, 660 | generator=None) -> Union[SchedulerOutput, Tuple]: 661 | """ 662 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with 663 | the multistep UniPC. 664 | 665 | Args: 666 | model_output (`torch.Tensor`): 667 | The direct output from learned diffusion model. 668 | timestep (`int`): 669 | The current discrete timestep in the diffusion chain. 670 | sample (`torch.Tensor`): 671 | A current instance of a sample created by the diffusion process. 672 | return_dict (`bool`): 673 | Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. 674 | 675 | Returns: 676 | [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: 677 | If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a 678 | tuple is returned where the first element is the sample tensor. 
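        Example (illustrative sketch; `transformer` stands in for any model returning a flow/velocity
        prediction for the current latents and timestep):

            scheduler = FlowUniPCMultistepScheduler(shift=3.0)
            scheduler.set_timesteps(num_inference_steps=30, device="cuda")
            latents = torch.randn(1, 16, 128, 128, device="cuda")
            for t in scheduler.timesteps:
                model_output = transformer(latents, t)
                latents = scheduler.step(model_output, t, latents).prev_sample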
679 | 680 | """ 681 | if self.num_inference_steps is None: 682 | raise ValueError( 683 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 684 | ) 685 | 686 | if self.step_index is None: 687 | self._init_step_index(timestep) 688 | 689 | use_corrector = ( 690 | self.step_index > 0 and 691 | self.step_index - 1 not in self.disable_corrector and 692 | self.last_sample is not None # pyright: ignore 693 | ) 694 | 695 | model_output_convert = self.convert_model_output( 696 | model_output, sample=sample) 697 | if use_corrector: 698 | sample = self.multistep_uni_c_bh_update( 699 | this_model_output=model_output_convert, 700 | last_sample=self.last_sample, 701 | this_sample=sample, 702 | order=self.this_order, 703 | ) 704 | 705 | for i in range(self.config.solver_order - 1): 706 | self.model_outputs[i] = self.model_outputs[i + 1] 707 | self.timestep_list[i] = self.timestep_list[i + 1] 708 | 709 | self.model_outputs[-1] = model_output_convert 710 | self.timestep_list[-1] = timestep # pyright: ignore 711 | 712 | if self.config.lower_order_final: 713 | this_order = min(self.config.solver_order, 714 | len(self.timesteps) - 715 | self.step_index) # pyright: ignore 716 | else: 717 | this_order = self.config.solver_order 718 | 719 | self.this_order = min(this_order, 720 | self.lower_order_nums + 1) # warmup for multistep 721 | assert self.this_order > 0 722 | 723 | self.last_sample = sample 724 | prev_sample = self.multistep_uni_p_bh_update( 725 | model_output=model_output, # pass the original non-converted model output, in case solver-p is used 726 | sample=sample, 727 | order=self.this_order, 728 | ) 729 | 730 | if self.lower_order_nums < self.config.solver_order: 731 | self.lower_order_nums += 1 732 | 733 | # upon completion increase step index by one 734 | self._step_index += 1 # pyright: ignore 735 | 736 | if not return_dict: 737 | return (prev_sample,) 738 | 739 | return SchedulerOutput(prev_sample=prev_sample) 740 | 741 | def scale_model_input(self, sample: torch.Tensor, *args, 742 | **kwargs) -> torch.Tensor: 743 | """ 744 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 745 | current timestep. 746 | 747 | Args: 748 | sample (`torch.Tensor`): 749 | The input sample. 750 | 751 | Returns: 752 | `torch.Tensor`: 753 | A scaled input sample. 
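        Note:
            This scheduler applies no input scaling; the sample is returned unchanged. The method is
            kept so pipelines that call `scheduler.scale_model_input(latents, t)` work without
            special-casing this scheduler.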
754 | """ 755 | return sample 756 | 757 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise 758 | def add_noise( 759 | self, 760 | original_samples: torch.Tensor, 761 | noise: torch.Tensor, 762 | timesteps: torch.IntTensor, 763 | ) -> torch.Tensor: 764 | # Make sure sigmas and timesteps have the same device and dtype as original_samples 765 | sigmas = self.sigmas.to( 766 | device=original_samples.device, dtype=original_samples.dtype) 767 | if original_samples.device.type == "mps" and torch.is_floating_point( 768 | timesteps): 769 | # mps does not support float64 770 | schedule_timesteps = self.timesteps.to( 771 | original_samples.device, dtype=torch.float32) 772 | timesteps = timesteps.to( 773 | original_samples.device, dtype=torch.float32) 774 | else: 775 | schedule_timesteps = self.timesteps.to(original_samples.device) 776 | timesteps = timesteps.to(original_samples.device) 777 | 778 | # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index 779 | if self.begin_index is None: 780 | step_indices = [ 781 | self.index_for_timestep(t, schedule_timesteps) 782 | for t in timesteps 783 | ] 784 | elif self.step_index is not None: 785 | # add_noise is called after first denoising step (for inpainting) 786 | step_indices = [self.step_index] * timesteps.shape[0] 787 | else: 788 | # add noise is called before first denoising step to create initial latent(img2img) 789 | step_indices = [self.begin_index] * timesteps.shape[0] 790 | 791 | sigma = sigmas[step_indices].flatten() 792 | while len(sigma.shape) < len(original_samples.shape): 793 | sigma = sigma.unsqueeze(-1) 794 | 795 | alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) 796 | noisy_samples = alpha_t * original_samples + sigma_t * noise 797 | return noisy_samples 798 | 799 | def __len__(self): 800 | return self.config.num_train_timesteps 801 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | 3 | name = "hidream-sampler" 4 | description = "A custom ComfyUI node for generating images using the HiDream AI model. Uses quantization for lower memory usage. Simple, Advanced and Img2img mode." 
5 | version = "1.2.0" 6 | license = {file = "LICENSE"} 7 | dependencies = ["# Core dependencies", "transformers>=4.36.0", "diffusers>=0.26.0", "torch>=2.0.0", "numpy>=1.24.0", "Pillow>=10.0.0", "# For standard (BNB) models", "bitsandbytes>=0.41.0", "# For NF4 models", "optimum>=1.12.0", "accelerate>=0.25.0", "gptqmodel>=2.2.0", "# gptqmodel might need some more dependencies", "device-smi", "tokenicer", "threadpoolctl", "logbar", "datasets"] 8 | 9 | [project.urls] 10 | Repository = "https://github.com/lum3on/comfyui_HiDream-Sampler" 11 | # Used by Comfy Registry https://comfyregistry.org 12 | 13 | [tool.comfy] 14 | 15 | PublisherId = "lum3on" 16 | DisplayName = "comfyui_HiDream-Sampler" 17 | Icon = "hhttps://avatars.githubusercontent.com/u/197819028?v=4" 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | # Core ComfyUI/HiDream dependencies 3 | transformers>=4.36.0 4 | diffusers>=0.26.0 5 | torch>=2.0.0 6 | numpy>=1.24.0 7 | Pillow>=10.0.0 8 | safetensors>=0.4.0 9 | huggingface_hub>=0.17.0 10 | 11 | # For standard 4-bit (BNB) models (e.g., 'full', 'dev', 'fast') 12 | bitsandbytes>=0.41.0 13 | 14 | # For NF4/GPTQ models (e.g., 'full-nf4', 'dev-nf4', 'fast-nf4') 15 | optimum>=1.12.0 16 | accelerate>=0.25.0 17 | sageattention 18 | 19 | # Add this based on user feedback 20 | gptqmodel>=2.0.0 21 | 22 | # 3.3.0 required for NVIDIA RTX 50xx (Blackwell) 23 | triton-windows==3.3.0a0.post17; sys_platform == "win32" 24 | # Keep Triton 3 according to earlier requirements for non-windows platforms. 25 | triton>=3.0.0,<3.1.0; sys_platform != "win32" -------------------------------------------------------------------------------- /sample_workflow/ComfyUI HiDream GGUF Simple.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "56d4aafd-359f-4d6d-b993-7290fa2fb2c5", 3 | "revision": 0, 4 | "last_node_id": 173, 5 | "last_link_id": 439, 6 | "nodes": [ 7 | { 8 | "id": 167, 9 | "type": "SaveImage", 10 | "pos": [ 11 | 524.5581665039062, 12 | 694.3413696289062 13 | ], 14 | "size": [ 15 | 1008.216552734375, 16 | 611.4951782226562 17 | ], 18 | "flags": {}, 19 | "order": 10, 20 | "mode": 0, 21 | "inputs": [ 22 | { 23 | "name": "images", 24 | "type": "IMAGE", 25 | "link": 433 26 | } 27 | ], 28 | "outputs": [], 29 | "properties": { 30 | "cnr_id": "comfy-core", 31 | "ver": "0.3.28", 32 | "Node name for S&R": "SaveImage" 33 | }, 34 | "widgets_values": [ 35 | "HiDream", 36 | "" 37 | ] 38 | }, 39 | { 40 | "id": 44, 41 | "type": "CLIPTextEncode", 42 | "pos": [ 43 | -273.22076416015625, 44 | 1090.7685546875 45 | ], 46 | "size": [ 47 | 421.59600830078125, 48 | 209.1119842529297 49 | ], 50 | "flags": {}, 51 | "order": 5, 52 | "mode": 0, 53 | "inputs": [ 54 | { 55 | "name": "clip", 56 | "type": "CLIP", 57 | "link": 416 58 | } 59 | ], 60 | "outputs": [ 61 | { 62 | "name": "CONDITIONING", 63 | "type": "CONDITIONING", 64 | "slot_index": 0, 65 | "links": [ 66 | 435 67 | ] 68 | } 69 | ], 70 | "title": "CLIP Text Encode (Negative)", 71 | "properties": { 72 | "cnr_id": "comfy-core", 73 | "ver": "0.3.26", 74 | "Node name for S&R": "CLIPTextEncode" 75 | }, 76 | "widgets_values": [ 77 | "bad quality, artifacts" 78 | ], 79 | "color": "#322", 80 | "bgcolor": "#533" 81 | }, 82 | { 83 | "id": 164, 84 | "type": "CLIPTextEncode", 85 | "pos": [ 86 | -273.2846374511719, 87 | 827.217529296875 88 | ], 89 | "size": [ 90 | 420.07513427734375, 91 | 
217.4784393310547 92 | ], 93 | "flags": {}, 94 | "order": 6, 95 | "mode": 0, 96 | "inputs": [ 97 | { 98 | "name": "clip", 99 | "type": "CLIP", 100 | "link": 426 101 | } 102 | ], 103 | "outputs": [ 104 | { 105 | "name": "CONDITIONING", 106 | "type": "CONDITIONING", 107 | "slot_index": 0, 108 | "links": [ 109 | 422, 110 | 436 111 | ] 112 | } 113 | ], 114 | "title": "CLIP Text Encode (Positive Prompt)", 115 | "properties": { 116 | "cnr_id": "comfy-core", 117 | "ver": "0.3.26", 118 | "Node name for S&R": "CLIPTextEncode" 119 | }, 120 | "widgets_values": [ 121 | "This artwork showcases a mesmerizing blend of decay and technology. The scene depicts an old car parked inside an abandoned, dilapidated room. The room features a large, circular hole in the ceiling, allowing soft, natural light to filter through, illuminating the interior. A rectangular opening in the wall reveals a misty, tree-filled landscape, creating a surreal, otherworldly atmosphere. The space is filled with rubble and debris, adding to the sense of ruin. On the left side of the room, a futuristic-looking screen emits a soft, blue glow, contrasting with the old, decaying environment. The lighting is dramatic, with strong contrasts between light and shadow, enhancing the overall sense of desolation and mystery." 122 | ], 123 | "color": "#232", 124 | "bgcolor": "#353" 125 | }, 126 | { 127 | "id": 161, 128 | "type": "QuadrupleCLIPLoader", 129 | "pos": [ 130 | -271.17388916015625, 131 | 644.1995239257812 132 | ], 133 | "size": [ 134 | 419.6007385253906, 135 | 130 136 | ], 137 | "flags": {}, 138 | "order": 0, 139 | "mode": 0, 140 | "inputs": [], 141 | "outputs": [ 142 | { 143 | "name": "CLIP", 144 | "type": "CLIP", 145 | "links": [ 146 | 416, 147 | 426 148 | ] 149 | } 150 | ], 151 | "properties": { 152 | "cnr_id": "comfy-core", 153 | "ver": "0.3.28", 154 | "Node name for S&R": "QuadrupleCLIPLoader" 155 | }, 156 | "widgets_values": [ 157 | "clip_g_hidream.safetensors", 158 | "clip_l_hidream.safetensors", 159 | "t5xxl_fp8_e4m3fn_scaled_HiDream.safetensors", 160 | "llama_3.1_8b_instruct_fp8_scaled.safetensors" 161 | ] 162 | }, 163 | { 164 | "id": 162, 165 | "type": "UnetLoaderGGUF", 166 | "pos": [ 167 | -281.3764343261719, 168 | 535.6951904296875 169 | ], 170 | "size": [ 171 | 419.80206298828125, 172 | 58 173 | ], 174 | "flags": {}, 175 | "order": 1, 176 | "mode": 0, 177 | "inputs": [], 178 | "outputs": [ 179 | { 180 | "name": "MODEL", 181 | "type": "MODEL", 182 | "links": [ 183 | 418, 184 | 437, 185 | 438 186 | ] 187 | } 188 | ], 189 | "properties": { 190 | "cnr_id": "comfyui-gguf", 191 | "ver": "298192ed60f8ca821c6fe5f8030cae23424cada5", 192 | "Node name for S&R": "UnetLoaderGGUF" 193 | }, 194 | "widgets_values": [ 195 | "Hunyuan-FastDrive\\hidream-i1-dev-Q8_0.gguf" 196 | ] 197 | }, 198 | { 199 | "id": 105, 200 | "type": "VAELoader", 201 | "pos": [ 202 | 175.5850372314453, 203 | 696.0665283203125 204 | ], 205 | "size": [ 206 | 332.7481689453125, 207 | 64.84506225585938 208 | ], 209 | "flags": {}, 210 | "order": 2, 211 | "mode": 0, 212 | "inputs": [], 213 | "outputs": [ 214 | { 215 | "name": "VAE", 216 | "type": "VAE", 217 | "slot_index": 0, 218 | "links": [ 219 | 432 220 | ] 221 | } 222 | ], 223 | "properties": { 224 | "cnr_id": "comfy-core", 225 | "ver": "0.3.26", 226 | "Node name for S&R": "VAELoader" 227 | }, 228 | "widgets_values": [ 229 | "diffusion_pytorch_model.safetensors" 230 | ] 231 | }, 232 | { 233 | "id": 169, 234 | "type": "VAEDecode", 235 | "pos": [ 236 | 386.3603515625, 237 | 593.2472534179688 238 | ], 239 | "size": [ 
240 | 146.87254333496094, 241 | 46 242 | ], 243 | "flags": {}, 244 | "order": 9, 245 | "mode": 0, 246 | "inputs": [ 247 | { 248 | "name": "samples", 249 | "type": "LATENT", 250 | "link": 431 251 | }, 252 | { 253 | "name": "vae", 254 | "type": "VAE", 255 | "link": 432 256 | } 257 | ], 258 | "outputs": [ 259 | { 260 | "name": "IMAGE", 261 | "type": "IMAGE", 262 | "links": [ 263 | 433 264 | ] 265 | } 266 | ], 267 | "properties": { 268 | "cnr_id": "comfy-core", 269 | "ver": "0.3.28", 270 | "Node name for S&R": "VAEDecode" 271 | }, 272 | "widgets_values": [] 273 | }, 274 | { 275 | "id": 173, 276 | "type": "Note", 277 | "pos": [ 278 | 571.3648681640625, 279 | 541.4556274414062 280 | ], 281 | "size": [ 282 | 529.4127197265625, 283 | 111.57780456542969 284 | ], 285 | "flags": {}, 286 | "order": 3, 287 | "mode": 0, 288 | "inputs": [], 289 | "outputs": [], 290 | "properties": {}, 291 | "widgets_values": [ 292 | "Use the GGUF conversions from City96:\n\nhttps://huggingface.co/city96/HiDream-I1-Dev-gguf/tree/main\n\nThe 4 Text Encoder models can be downloaded here:\n\nhttps://huggingface.co/Comfy-Org/HiDream-I1_ComfyUI/tree/main/split_files/text_encoders" 293 | ], 294 | "color": "#432", 295 | "bgcolor": "#653" 296 | }, 297 | { 298 | "id": 171, 299 | "type": "ModelSamplingSD3", 300 | "pos": [ 301 | 158.48165893554688, 302 | 586.4207763671875 303 | ], 304 | "size": [ 305 | 210, 306 | 58 307 | ], 308 | "flags": {}, 309 | "order": 7, 310 | "mode": 0, 311 | "inputs": [ 312 | { 313 | "name": "model", 314 | "type": "MODEL", 315 | "link": 438 316 | } 317 | ], 318 | "outputs": [ 319 | { 320 | "name": "MODEL", 321 | "type": "MODEL", 322 | "slot_index": 0, 323 | "links": [ 324 | 439 325 | ] 326 | } 327 | ], 328 | "properties": { 329 | "cnr_id": "comfy-core", 330 | "ver": "0.3.26", 331 | "Node name for S&R": "ModelSamplingSD3" 332 | }, 333 | "widgets_values": [ 334 | 9 335 | ] 336 | }, 337 | { 338 | "id": 170, 339 | "type": "EmptyLatentImage", 340 | "pos": [ 341 | 176.3371124267578, 342 | 815.2146606445312 343 | ], 344 | "size": [ 345 | 315, 346 | 106 347 | ], 348 | "flags": {}, 349 | "order": 4, 350 | "mode": 0, 351 | "inputs": [], 352 | "outputs": [ 353 | { 354 | "name": "LATENT", 355 | "type": "LATENT", 356 | "links": [ 357 | 434 358 | ] 359 | } 360 | ], 361 | "properties": { 362 | "cnr_id": "comfy-core", 363 | "ver": "0.3.28", 364 | "Node name for S&R": "EmptyLatentImage" 365 | }, 366 | "widgets_values": [ 367 | 1344, 368 | 768, 369 | 1 370 | ] 371 | }, 372 | { 373 | "id": 168, 374 | "type": "KSamplerAdvanced", 375 | "pos": [ 376 | 179.55612182617188, 377 | 971.3984985351562 378 | ], 379 | "size": [ 380 | 315, 381 | 334 382 | ], 383 | "flags": {}, 384 | "order": 8, 385 | "mode": 0, 386 | "inputs": [ 387 | { 388 | "name": "model", 389 | "type": "MODEL", 390 | "link": 439 391 | }, 392 | { 393 | "name": "positive", 394 | "type": "CONDITIONING", 395 | "link": 436 396 | }, 397 | { 398 | "name": "negative", 399 | "type": "CONDITIONING", 400 | "link": 435 401 | }, 402 | { 403 | "name": "latent_image", 404 | "type": "LATENT", 405 | "link": 434 406 | } 407 | ], 408 | "outputs": [ 409 | { 410 | "name": "LATENT", 411 | "type": "LATENT", 412 | "links": [ 413 | 431 414 | ] 415 | } 416 | ], 417 | "properties": { 418 | "cnr_id": "comfy-core", 419 | "ver": "0.3.28", 420 | "Node name for S&R": "KSamplerAdvanced" 421 | }, 422 | "widgets_values": [ 423 | "enable", 424 | 234021387576769, 425 | "randomize", 426 | 20, 427 | 1, 428 | "lcm_custom_noise", 429 | "normal", 430 | 0, 431 | 10000, 432 | "disable" 433 | ] 434 | } 435 | ], 
436 | "links": [ 437 | [ 438 | 416, 439 | 161, 440 | 0, 441 | 44, 442 | 0, 443 | "CLIP" 444 | ], 445 | [ 446 | 426, 447 | 161, 448 | 0, 449 | 164, 450 | 0, 451 | "CLIP" 452 | ], 453 | [ 454 | 431, 455 | 168, 456 | 0, 457 | 169, 458 | 0, 459 | "LATENT" 460 | ], 461 | [ 462 | 432, 463 | 105, 464 | 0, 465 | 169, 466 | 1, 467 | "VAE" 468 | ], 469 | [ 470 | 433, 471 | 169, 472 | 0, 473 | 167, 474 | 0, 475 | "IMAGE" 476 | ], 477 | [ 478 | 434, 479 | 170, 480 | 0, 481 | 168, 482 | 3, 483 | "LATENT" 484 | ], 485 | [ 486 | 435, 487 | 44, 488 | 0, 489 | 168, 490 | 2, 491 | "CONDITIONING" 492 | ], 493 | [ 494 | 436, 495 | 164, 496 | 0, 497 | 168, 498 | 1, 499 | "CONDITIONING" 500 | ], 501 | [ 502 | 438, 503 | 162, 504 | 0, 505 | 171, 506 | 0, 507 | "MODEL" 508 | ], 509 | [ 510 | 439, 511 | 171, 512 | 0, 513 | 168, 514 | 0, 515 | "MODEL" 516 | ] 517 | ], 518 | "groups": [], 519 | "config": {}, 520 | "extra": { 521 | "ds": { 522 | "scale": 1, 523 | "offset": [ 524 | 354.642445219116, 525 | -436.043685518403 526 | ] 527 | }, 528 | "groupNodes": {}, 529 | "node_versions": { 530 | "ComfyUI_Fill-Nodes": "5f646932e6ecf92fa6d6dc5de149867451065353", 531 | "comfyui_dagthomas": "4f901fbf8d05bd1f120a30eac709cf9edcf37ebe", 532 | "comfy-core": "0.3.15", 533 | "comfyui_patches_ll": "1acfbbd58848f7f7a40b1e5c7f88d1be876f4938", 534 | "ComfyUI-HunyuanVideoMultiLora": "d0a2d5fe1fc9e8b4756567e3e0b9751bd570c859", 535 | "ComfyUI-VideoHelperSuite": "c36626c6028faca912eafcedbc71f1d342fb4d2a", 536 | "ComfyUI-Easy-Use": "037080ac3935b5a398a12a6510dd600c762d8983", 537 | "darkprompts": "65154e1975671f343468d488d20ed9ca6230d45f", 538 | "ComfyUI-Custom-Scripts": "a53ef9b617ed1331640d7a2cd97644995908dc00", 539 | "ComfyUI-Impact-Pack": "c6056b132d7e155c3ece42b77e08ea45bde1bfef" 540 | }, 541 | "ue_links": [], 542 | "VHS_latentpreview": false, 543 | "VHS_latentpreviewrate": 0, 544 | "VHS_MetadataImage": true, 545 | "VHS_KeepIntermediate": true 546 | }, 547 | "version": 0.4 548 | } -------------------------------------------------------------------------------- /sample_workflow/Sample HiDream Sampler Workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "64c118a4-1506-48f5-a531-a0620e2d46fe", 3 | "revision": 0, 4 | "last_node_id": 24, 5 | "last_link_id": 40, 6 | "nodes": [ 7 | { 8 | "id": 13, 9 | "type": "Note", 10 | "pos": [ 11 | -280, 12 | -180 13 | ], 14 | "size": [ 15 | 240, 16 | 160 17 | ], 18 | "flags": {}, 19 | "order": 0, 20 | "mode": 0, 21 | "inputs": [], 22 | "outputs": [], 23 | "properties": {}, 24 | "widgets_values": [ 25 | "Negative Prompt only works for Full & full-nf4 models.\n\nnf4 models require ~15GB of VRAM\nnon-nf4 models require ~27GB of VRAM" 26 | ], 27 | "color": "#432", 28 | "bgcolor": "#653" 29 | }, 30 | { 31 | "id": 21, 32 | "type": "PrimitiveNode", 33 | "pos": [ 34 | -280, 35 | -380 36 | ], 37 | "size": [ 38 | 240, 39 | 140 40 | ], 41 | "flags": {}, 42 | "order": 1, 43 | "mode": 0, 44 | "inputs": [], 45 | "outputs": [ 46 | { 47 | "name": "STRING", 48 | "type": "STRING", 49 | "widget": { 50 | "name": "primary_prompt" 51 | }, 52 | "links": [ 53 | 37, 54 | 38 55 | ] 56 | } 57 | ], 58 | "title": "Input Primary Prompt", 59 | "properties": { 60 | "Run widget replace on values": false 61 | }, 62 | "widgets_values": [ 63 | "A collectable rare postage stamp with the text \"US Postal Service\" and the year \"1918\" and the price of 15 cents, featuring a US WW1 Soldier peeking over the top of a trench, stamp collection" 64 | ], 65 | "color": "#232", 66 | 
"bgcolor": "#353" 67 | }, 68 | { 69 | "id": 17, 70 | "type": "PreviewImage", 71 | "pos": [ 72 | 1220, 73 | -180 74 | ], 75 | "size": [ 76 | 460, 77 | 460 78 | ], 79 | "flags": {}, 80 | "order": 5, 81 | "mode": 0, 82 | "inputs": [ 83 | { 84 | "name": "images", 85 | "type": "IMAGE", 86 | "link": 30 87 | } 88 | ], 89 | "outputs": [], 90 | "properties": { 91 | "Node name for S&R": "PreviewImage", 92 | "cnr_id": "comfy-core", 93 | "ver": "0.3.27" 94 | }, 95 | "widgets_values": [ 96 | "" 97 | ], 98 | "color": "#223", 99 | "bgcolor": "#335" 100 | }, 101 | { 102 | "id": 24, 103 | "type": "PreviewImage", 104 | "pos": [ 105 | 2120, 106 | -180 107 | ], 108 | "size": [ 109 | 460, 110 | 460 111 | ], 112 | "flags": {}, 113 | "order": 7, 114 | "mode": 0, 115 | "inputs": [ 116 | { 117 | "name": "images", 118 | "type": "IMAGE", 119 | "link": 40 120 | } 121 | ], 122 | "outputs": [], 123 | "properties": { 124 | "Node name for S&R": "PreviewImage", 125 | "cnr_id": "comfy-core", 126 | "ver": "0.3.27" 127 | }, 128 | "widgets_values": [ 129 | "" 130 | ], 131 | "color": "#432", 132 | "bgcolor": "#653" 133 | }, 134 | { 135 | "id": 23, 136 | "type": "HiDreamImg2Img", 137 | "pos": [ 138 | 1700, 139 | -180 140 | ], 141 | "size": [ 142 | 400, 143 | 580 144 | ], 145 | "flags": {}, 146 | "order": 6, 147 | "mode": 0, 148 | "inputs": [ 149 | { 150 | "name": "image", 151 | "type": "IMAGE", 152 | "link": 39 153 | } 154 | ], 155 | "outputs": [ 156 | { 157 | "name": "image", 158 | "type": "IMAGE", 159 | "links": [ 160 | 40 161 | ] 162 | } 163 | ], 164 | "properties": { 165 | "Node name for S&R": "HiDreamImg2Img" 166 | }, 167 | "widgets_values": [ 168 | "dev-nf4", 169 | 0.8000000000000002, 170 | "A collectable rare postage stamp with the text \"US Postal Service\" and the year \"1918\" and the price of 15 cents, featuring Micky Mouse peeking over the top of a trench, stamp collection", 171 | "", 172 | 532986874756016, 173 | "randomize", 174 | "Default for model", 175 | -1, 176 | -1, 177 | false, 178 | "You are a creative AI assistant that helps create detailed, vivid images based on user descriptions.", 179 | 1, 180 | 1, 181 | 1, 182 | 1 183 | ], 184 | "color": "#432", 185 | "bgcolor": "#653" 186 | }, 187 | { 188 | "id": 7, 189 | "type": "HiDreamSampler", 190 | "pos": [ 191 | 840, 192 | -180 193 | ], 194 | "size": [ 195 | 340, 196 | 352 197 | ], 198 | "flags": {}, 199 | "order": 3, 200 | "mode": 0, 201 | "inputs": [ 202 | { 203 | "name": "prompt", 204 | "type": "STRING", 205 | "widget": { 206 | "name": "prompt" 207 | }, 208 | "link": 38 209 | } 210 | ], 211 | "outputs": [ 212 | { 213 | "name": "image", 214 | "type": "IMAGE", 215 | "links": [ 216 | 30, 217 | 39 218 | ] 219 | } 220 | ], 221 | "properties": { 222 | "Node name for S&R": "HiDreamSampler", 223 | "aux_id": "SanDiegoDude/ComfyUI-HiDream-Sampler", 224 | "ver": "ae6c2ac8896c26b1fc0e58b8f70d478881663dd2" 225 | }, 226 | "widgets_values": [ 227 | "dev-nf4", 228 | "A collectable rare postage stamp with the text \"US Postal Service\" and the year \"1918\" and the price of 15 cents, featuring a US WW1 Soldier peeking over the top of a trench, stamp collection", 229 | "", 230 | "1024 × 1024 (Square)", 231 | 1, 232 | 42, 233 | "fixed", 234 | "Default for model", 235 | -1, 236 | -1, 237 | 0, 238 | 0 239 | ], 240 | "color": "#223", 241 | "bgcolor": "#335" 242 | }, 243 | { 244 | "id": 19, 245 | "type": "HiDreamSamplerAdvanced", 246 | "pos": [ 247 | 0, 248 | -180 249 | ], 250 | "size": [ 251 | 360, 252 | 838 253 | ], 254 | "flags": {}, 255 | "order": 2, 256 | "mode": 0, 257 | 
"inputs": [ 258 | { 259 | "name": "primary_prompt", 260 | "type": "STRING", 261 | "widget": { 262 | "name": "primary_prompt" 263 | }, 264 | "link": 37 265 | } 266 | ], 267 | "outputs": [ 268 | { 269 | "name": "image", 270 | "type": "IMAGE", 271 | "links": [ 272 | 28 273 | ] 274 | } 275 | ], 276 | "properties": { 277 | "Node name for S&R": "HiDreamSamplerAdvanced", 278 | "aux_id": "SanDiegoDude/ComfyUI-HiDream-Sampler", 279 | "ver": "7f3d4bbddbfca35aa84e9c386a94a1ce846bcc57" 280 | }, 281 | "widgets_values": [ 282 | "dev-nf4", 283 | "A collectable rare postage stamp with the text \"US Postal Service\" and the year \"1918\" and the price of 15 cents, featuring a US WW1 Soldier peeking over the top of a trench, stamp collection", 284 | "deformed, ugly, watermark, text, scribbles, noise, static, low quality, bad quality, jpeg artifacts, low resolution", 285 | "1024 × 1024 (Square)", 286 | 4, 287 | 42, 288 | "fixed", 289 | "Default for model", 290 | -1, 291 | -1, 292 | false, 293 | 0, 294 | false, 295 | "painted by Georgia O'Keefe", 296 | "painted by Georgia O'Keefe", 297 | "painted by Georgia O'Keefe", 298 | "", 299 | "You are a creative AI assistant that helps create detailed, vivid images based on user descriptions.", 300 | 1, 301 | 1, 302 | 1, 303 | 1, 304 | 77, 305 | 150, 306 | 256, 307 | 256 308 | ], 309 | "color": "#2a363b", 310 | "bgcolor": "#3f5159" 311 | }, 312 | { 313 | "id": 2, 314 | "type": "PreviewImage", 315 | "pos": [ 316 | 380, 317 | -180 318 | ], 319 | "size": [ 320 | 440, 321 | 420 322 | ], 323 | "flags": {}, 324 | "order": 4, 325 | "mode": 0, 326 | "inputs": [ 327 | { 328 | "name": "images", 329 | "type": "IMAGE", 330 | "link": 28 331 | } 332 | ], 333 | "outputs": [], 334 | "properties": { 335 | "Node name for S&R": "PreviewImage", 336 | "cnr_id": "comfy-core", 337 | "ver": "0.3.27" 338 | }, 339 | "widgets_values": [ 340 | "" 341 | ], 342 | "color": "#2a363b", 343 | "bgcolor": "#3f5159" 344 | } 345 | ], 346 | "links": [ 347 | [ 348 | 28, 349 | 19, 350 | 0, 351 | 2, 352 | 0, 353 | "IMAGE" 354 | ], 355 | [ 356 | 30, 357 | 7, 358 | 0, 359 | 17, 360 | 0, 361 | "IMAGE" 362 | ], 363 | [ 364 | 37, 365 | 21, 366 | 0, 367 | 19, 368 | 0, 369 | "STRING" 370 | ], 371 | [ 372 | 38, 373 | 21, 374 | 0, 375 | 7, 376 | 0, 377 | "STRING" 378 | ], 379 | [ 380 | 39, 381 | 7, 382 | 0, 383 | 23, 384 | 0, 385 | "IMAGE" 386 | ], 387 | [ 388 | 40, 389 | 23, 390 | 0, 391 | 24, 392 | 0, 393 | "IMAGE" 394 | ] 395 | ], 396 | "groups": [], 397 | "config": {}, 398 | "extra": { 399 | "ds": { 400 | "scale": 0.6303940863128511, 401 | "offset": [ 402 | 308.380239914095, 403 | 483.3423162860777 404 | ] 405 | }, 406 | "ue_links": [], 407 | "VHS_latentpreview": false, 408 | "VHS_latentpreviewrate": 0, 409 | "VHS_MetadataImage": true, 410 | "VHS_KeepIntermediate": true 411 | }, 412 | "version": 0.4 413 | } -------------------------------------------------------------------------------- /sample_workflow/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lum3on/comfyui_HiDream-Sampler/98ad017cac93b782e2af95411e4c10d493ecb841/sample_workflow/workflow.png --------------------------------------------------------------------------------