├── .github │ └── workflows │ └── publish.yml ├── .gitignore ├── README.md ├── __init__.py ├── controlnet │ ├── controlnet_instantx.py │ └── controlnet_instantx_format2.py ├── nodes.py ├── pyproject.toml └── requirements.txt /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | # Skip the workflow if this is a forked repository. 16 | if: github.event.repository.fork == false 17 | steps: 18 | - name: Check out code 19 | uses: actions/checkout@v4 20 | - name: Publish Custom Node 21 | uses: Comfy-Org/publish-node-action@main 22 | with: 23 | ## Add your own personal access token to your GitHub repository secrets and reference it here. 24 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Note - This loader node is deprecated 2 | 3 | Important update regarding the InstantX Union ControlNet: the latest version of ComfyUI now includes native support for the InstantX/Shakker Labs Union ControlNet Pro, which produces higher-quality outputs than the alpha version this loader supports.
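In recent ComfyUI builds the union model can typically be used with the built-in Load ControlNet node together with the SetUnionControlNetType node (exact node names may vary between ComfyUI versions), so this custom loader should no longer be needed.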
4 | 5 | You can find an updated workflow here: https://civitai.com/models/709352 6 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: eesahe 3 | @title: eesahe's Nodes 4 | @nickname: eesahesNodes 5 | @description: InstantX's Flux union ControlNet loader and implementation 6 | """ 7 | 8 | from .nodes import InstantXFluxUnionControlNetLoader 9 | 10 | NODE_CLASS_MAPPINGS = { 11 | "InstantX Flux Union ControlNet Loader": InstantXFluxUnionControlNetLoader 12 | } 13 | -------------------------------------------------------------------------------- /controlnet/controlnet_instantx.py: -------------------------------------------------------------------------------- 1 | #Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from einops import rearrange, repeat 6 | 7 | from comfy.ldm.flux.layers import (timestep_embedding) 8 | 9 | from comfy.ldm.flux.model import Flux 10 | import comfy.ldm.common_dit 11 | import operator as op 12 | import sys 13 | import torch.nn.functional as F 14 | import numbers 15 | from diffusers.models.normalization import AdaLayerNormContinuous 16 | from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock 17 | from diffusers.utils.import_utils import is_torch_version 18 | import numpy as np 19 | 20 | if is_torch_version(">=", "2.1.0"): 21 | LayerNorm = nn.LayerNorm 22 | else: 23 | # Has optional bias parameter compared to torch layer norm 24 | # TODO: replace with torch layernorm once min required torch version >= 2.1 25 | class LayerNorm(nn.Module): 26 | def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): 27 | super().__init__() 28 | 29 | self.eps = eps 30 | 31 | if isinstance(dim, numbers.Integral): 32 | dim = (dim,) 33 | 34 | self.dim = torch.Size(dim) 35 | 36 | if elementwise_affine: 37 | self.weight = nn.Parameter(torch.ones(dim)) 38 | self.bias = nn.Parameter(torch.zeros(dim)) if bias else None 39 | else: 40 | self.weight = None 41 | self.bias = None 42 | 43 | def forward(self, input): 44 | return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) 45 | 46 | class FluxUnionControlNetModeEmbedder(nn.Module): 47 | def __init__(self, num_mode, out_channels): 48 | super().__init__() 49 | self.mode_embber = nn.Embedding(num_mode, out_channels) # note: the 'mode_embber' misspelling is kept so the state-dict key names used when loading checkpoints stay unchanged 50 | self.norm = nn.LayerNorm(out_channels, eps=1e-6) 51 | self.fc = nn.Linear(out_channels, out_channels) 52 | 53 | def forward(self, x): 54 | x_emb = self.mode_embber(x) 55 | x_emb = self.norm(x_emb) 56 | x_emb = self.fc(x_emb) 57 | x_emb = x_emb[:, 0] 58 | return x_emb 59 | 60 | def zero_module(module): 61 | for p in module.parameters(): 62 | nn.init.zeros_(p) 63 | return module 64 | 65 | # YiYi to-do: refactor rope related functions/classes 66 | def apply_rope(xq, xk, freqs_cis): 67 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 68 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 69 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] 70 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] 71 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 72 | 73 | 74 | 75 | class FluxUnionControlNetInputEmbedder(nn.Module): 76 | def __init__(self, in_channels, out_channels, num_attention_heads=24,
mlp_ratio=4.0, attention_head_dim=128, dtype=None, device=None, operations=None, depth=2): 77 | super().__init__() 78 | self.x_embedder = nn.Sequential(nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels)) 79 | self.norm = AdaLayerNormContinuous(out_channels, out_channels, elementwise_affine=False, eps=1e-6) 80 | self.fc = nn.Linear(out_channels, out_channels) 81 | self.emb_embedder = nn.Sequential(nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels)) 82 | 83 | """ self.single_blocks = nn.ModuleList( 84 | [ 85 | SingleStreamBlock( 86 | out_channels, num_attention_heads, dtype=dtype, device=device, operations=operations 87 | ) 88 | for i in range(2) 89 | ] 90 | ) """ 91 | self.single_transformer_blocks = nn.ModuleList( 92 | [ 93 | FluxSingleTransformerBlock( 94 | dim=out_channels, 95 | num_attention_heads=num_attention_heads, 96 | attention_head_dim=attention_head_dim, 97 | ) 98 | for i in range(depth) 99 | ] 100 | ) 101 | 102 | self.out = zero_module(nn.Linear(out_channels, out_channels)) 103 | 104 | def forward(self, x, mode_emb): 105 | mode_token = self.emb_embedder(mode_emb)[:, None] 106 | x_emb = self.fc(self.norm(self.x_embedder(x), mode_emb)) 107 | hidden_states = torch.cat([mode_token, x_emb], dim=1) 108 | for index_block, block in enumerate(self.single_transformer_blocks): 109 | hidden_states = block( 110 | hidden_states=hidden_states, 111 | temb=mode_emb, 112 | ) 113 | hidden_states = self.out(hidden_states) 114 | res = hidden_states[:, 1:] 115 | return res 116 | 117 | class InstantXControlNetFlux(Flux): 118 | def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs): 119 | kwargs["depth"] = 0 120 | kwargs["depth_single_blocks"] = 0 121 | depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2) 122 | super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) 123 | 124 | self.transformer_blocks = nn.ModuleList( 125 | [ 126 | FluxTransformerBlock( 127 | dim=self.hidden_size, 128 | num_attention_heads=24, 129 | attention_head_dim=128, 130 | ).to(dtype=dtype) 131 | for i in range(5) 132 | ] 133 | ) 134 | 135 | self.single_transformer_blocks = nn.ModuleList( 136 | [ 137 | FluxSingleTransformerBlock( 138 | dim=self.hidden_size, 139 | num_attention_heads=24, 140 | attention_head_dim=128, 141 | ).to(dtype=dtype) 142 | for i in range(10) 143 | ] 144 | ) 145 | 146 | self.require_vae = True 147 | # add ControlNet blocks 148 | self.controlnet_blocks = nn.ModuleList([]) 149 | for _ in range(len(self.transformer_blocks)): 150 | controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) 151 | controlnet_block = zero_module(controlnet_block) 152 | self.controlnet_blocks.append(controlnet_block) 153 | 154 | self.controlnet_single_blocks = nn.ModuleList([]) 155 | for _ in range(len(self.single_transformer_blocks)): 156 | controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) 157 | controlnet_block = zero_module(controlnet_block) 158 | self.controlnet_single_blocks.append(controlnet_block) 159 | 160 | # TODO support both union and unimodal 161 | self.union = True #num_mode is not None 162 | num_mode = 10 163 | if self.union: 164 | self.controlnet_mode_embedder = zero_module(FluxUnionControlNetModeEmbedder(num_mode, self.hidden_size)).to(device=device, dtype=dtype) 165 | self.controlnet_x_embedder = FluxUnionControlNetInputEmbedder(self.in_channels, self.hidden_size, 
operations=operations, depth=depth_single_blocks_controlnet).to(device=device, dtype=dtype) 166 | self.controlnet_mode_token_embedder = nn.Sequential(nn.LayerNorm(self.hidden_size, eps=1e-6), nn.Linear(self.hidden_size, self.hidden_size)).to(device=device, dtype=dtype) 167 | else: 168 | self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size)).to(device=device, dtype=dtype) 169 | self.gradient_checkpointing = False 170 | 171 | @staticmethod 172 | # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents 173 | def _pack_latents(latents, batch_size, num_channels_latents, height, width): 174 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 175 | latents = latents.permute(0, 2, 4, 1, 3, 5) 176 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 177 | 178 | return latents 179 | 180 | def set_hint_latents(self, hint_latents): 181 | vae_shift_factor = 0.1159 182 | vae_scaling_factor = 0.3611 183 | num_channels_latents = self.in_channels // 4 184 | hint_latents = (hint_latents - vae_shift_factor) * vae_scaling_factor 185 | 186 | height, width = hint_latents.shape[2:] 187 | hint_latents = self._pack_latents( 188 | hint_latents, 189 | hint_latents.shape[0], 190 | num_channels_latents, 191 | height, 192 | width, 193 | ) 194 | self.hint_latents = hint_latents.to(device=self.device, dtype=self.dtype) 195 | 196 | def forward_orig( 197 | self, 198 | img: Tensor, 199 | img_ids: Tensor, 200 | controlnet_cond: Tensor, 201 | txt: Tensor, 202 | txt_ids: Tensor, 203 | timesteps: Tensor, 204 | y: Tensor, 205 | guidance: Tensor = None, 206 | controlnet_mode = None 207 | ) -> Tensor: 208 | if img.ndim != 3 or txt.ndim != 3: 209 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 210 | 211 | batch_size = img.shape[0] 212 | 213 | img = self.img_in(img) 214 | vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype)) 215 | if self.params.guidance_embed: 216 | vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype))) 217 | vec.add_(self.vector_in(y)) 218 | 219 | if self.union: 220 | if controlnet_mode is None: 221 | raise ValueError('union ControlNet requires a controlnet_mode, but none was provided') 222 | controlnet_mode = torch.tensor([[controlnet_mode]], device=self.device) 223 | emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype) 224 | vec = vec + emb_controlnet_mode 225 | img = img + self.controlnet_x_embedder(controlnet_cond, emb_controlnet_mode) 226 | else: 227 | img = img + self.controlnet_x_embedder(controlnet_cond) 228 | 229 | txt = self.txt_in(txt) 230 | 231 | if self.union: 232 | token_controlnet_mode = self.controlnet_mode_token_embedder(emb_controlnet_mode)[:, None] 233 | token_controlnet_mode = token_controlnet_mode.expand(txt.size(0), -1, -1) 234 | txt = torch.cat([token_controlnet_mode, txt], dim=1) 235 | txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1) 236 | 237 | ids = torch.cat((txt_ids, img_ids), dim=1) 238 | pe = self.pe_embedder(ids).to(dtype=self.dtype, device=self.device) 239 | 240 | block_res_samples = () 241 | for block in self.transformer_blocks: 242 | txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe) 243 | block_res_samples = block_res_samples + (img,) 244 | 245 | img = torch.cat([txt, img], dim=1) 246 | 247 | single_block_res_samples = () 248 | for block in self.single_transformer_blocks: 249 | img =
block(hidden_states=img, temb=vec, image_rotary_emb=pe) 250 | single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1] :],) 251 | 252 | controlnet_block_res_samples = () 253 | for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): 254 | block_res_sample = controlnet_block(block_res_sample) 255 | controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) 256 | 257 | controlnet_single_block_res_samples = () 258 | for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks): 259 | single_block_res_sample = single_controlnet_block(single_block_res_sample) 260 | controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,) 261 | 262 | n_single_blocks = 38 263 | n_double_blocks = 19 264 | 265 | # Expand controlnet_block_res_samples to match n_double_blocks 266 | expanded_controlnet_block_res_samples = [] 267 | interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples))) 268 | for i in range(n_double_blocks): 269 | index = i // interval_control_double 270 | expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index]) 271 | 272 | # Expand controlnet_single_block_res_samples to match n_single_blocks 273 | expanded_controlnet_single_block_res_samples = [] 274 | interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples))) 275 | for i in range(n_single_blocks): 276 | index = i // interval_control_single 277 | expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index]) 278 | 279 | return { 280 | "input": expanded_controlnet_block_res_samples, 281 | "output": expanded_controlnet_single_block_res_samples 282 | } 283 | 284 | def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs): 285 | bs, c, h, w = x.shape 286 | patch_size = 2 287 | x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) 288 | 289 | height_control_image, width_control_image = hint.shape[2:] 290 | num_channels_latents = self.in_channels // 4 291 | hint = self._pack_latents( 292 | hint, 293 | hint.shape[0], 294 | num_channels_latents, 295 | height_control_image, 296 | width_control_image, 297 | ) 298 | img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) 299 | 300 | h_len = ((h + (patch_size // 2)) // patch_size) 301 | w_len = ((w + (patch_size // 2)) // patch_size) 302 | img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) 303 | img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None] 304 | img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] 305 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 306 | 307 | txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) 308 | return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type) 309 | -------------------------------------------------------------------------------- /controlnet/controlnet_instantx_format2.py: -------------------------------------------------------------------------------- 1 | #Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from einops import rearrange, repeat 6 | 
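# NOTE: this file mirrors controlnet_instantx.py but targets the alternative "format 2" checkpoint layout, where the union mode embedder is a plain nn.Embedding whose output is concatenated to the text tokens rather than added to the conditioning vector.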
7 | from comfy.ldm.flux.layers import (timestep_embedding) 8 | 9 | from comfy.ldm.flux.model import Flux 10 | import comfy.ldm.common_dit 11 | import operator as op 12 | import sys 13 | import torch.nn.functional as F 14 | import numbers 15 | from diffusers.models.normalization import AdaLayerNormContinuous 16 | from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock 17 | from diffusers.utils.import_utils import is_torch_version 18 | import numpy as np 19 | 20 | if is_torch_version(">=", "2.1.0"): 21 | LayerNorm = nn.LayerNorm 22 | else: 23 | # Has optional bias parameter compared to torch layer norm 24 | # TODO: replace with torch layernorm once min required torch version >= 2.1 25 | class LayerNorm(nn.Module): 26 | def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): 27 | super().__init__() 28 | 29 | self.eps = eps 30 | 31 | if isinstance(dim, numbers.Integral): 32 | dim = (dim,) 33 | 34 | self.dim = torch.Size(dim) 35 | 36 | if elementwise_affine: 37 | self.weight = nn.Parameter(torch.ones(dim)) 38 | self.bias = nn.Parameter(torch.zeros(dim)) if bias else None 39 | else: 40 | self.weight = None 41 | self.bias = None 42 | 43 | def forward(self, input): 44 | return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) 45 | 46 | def zero_module(module): 47 | for p in module.parameters(): 48 | nn.init.zeros_(p) 49 | return module 50 | 51 | # YiYi to-do: refactor rope related functions/classes 52 | def apply_rope(xq, xk, freqs_cis): 53 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 54 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 55 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] 56 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] 57 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 58 | 59 | class InstantXControlNetFluxFormat2(Flux): 60 | def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs): 61 | kwargs["depth"] = 0 62 | kwargs["depth_single_blocks"] = 0 63 | depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2) 64 | super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) 65 | 66 | self.transformer_blocks = nn.ModuleList( 67 | [ 68 | FluxTransformerBlock( 69 | dim=self.hidden_size, 70 | num_attention_heads=24, 71 | attention_head_dim=128, 72 | ).to(dtype=dtype) 73 | for i in range(5) 74 | ] 75 | ) 76 | 77 | self.single_transformer_blocks = nn.ModuleList( 78 | [ 79 | FluxSingleTransformerBlock( 80 | dim=self.hidden_size, 81 | num_attention_heads=24, 82 | attention_head_dim=128, 83 | ).to(dtype=dtype) 84 | for i in range(10) 85 | ] 86 | ) 87 | 88 | self.require_vae = True 89 | # add ControlNet blocks 90 | self.controlnet_blocks = nn.ModuleList([]) 91 | for _ in range(len(self.transformer_blocks)): 92 | controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) 93 | controlnet_block = zero_module(controlnet_block) 94 | self.controlnet_blocks.append(controlnet_block) 95 | 96 | self.controlnet_single_blocks = nn.ModuleList([]) 97 | for _ in range(len(self.single_transformer_blocks)): 98 | controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) 99 | controlnet_block = zero_module(controlnet_block) 100 | self.controlnet_single_blocks.append(controlnet_block) 101 | 102 | # TODO 
support both union and unimodal 103 | self.union = True #num_mode is not None 104 | num_mode = 10 105 | if self.union: 106 | self.controlnet_mode_embedder = nn.Embedding(num_mode, self.hidden_size) 107 | self.controlnet_x_embedder = zero_module(operations.Linear(self.in_channels, self.hidden_size).to(device=device, dtype=dtype)) 108 | self.gradient_checkpointing = False 109 | 110 | @staticmethod 111 | # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents 112 | def _pack_latents(latents, batch_size, num_channels_latents, height, width): 113 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 114 | latents = latents.permute(0, 2, 4, 1, 3, 5) 115 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 116 | 117 | return latents 118 | 119 | def forward_orig( 120 | self, 121 | img: Tensor, 122 | img_ids: Tensor, 123 | controlnet_cond: Tensor, 124 | txt: Tensor, 125 | txt_ids: Tensor, 126 | timesteps: Tensor, 127 | y: Tensor, 128 | guidance: Tensor = None, 129 | controlnet_mode = None 130 | ) -> Tensor: 131 | if img.ndim != 3 or txt.ndim != 3: 132 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 133 | 134 | batch_size = img.shape[0] 135 | 136 | img = self.img_in(img) 137 | vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype)) 138 | if self.params.guidance_embed: 139 | vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype))) 140 | vec.add_(self.vector_in(y)) 141 | 142 | txt = self.txt_in(txt) 143 | 144 | if self.union: 145 | if controlnet_mode is None: 146 | raise ValueError('union ControlNet requires a controlnet_mode, but none was provided') 147 | controlnet_mode = torch.tensor(controlnet_mode).to(self.device, dtype=torch.long) 148 | controlnet_mode = controlnet_mode.reshape([-1, 1]) 149 | emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype) 150 | txt = torch.cat([emb_controlnet_mode, txt], dim=1) 151 | txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1) 152 | 153 | img = img + self.controlnet_x_embedder(controlnet_cond) 154 | 155 | txt_ids = txt_ids.expand(img_ids.size(0), -1, -1) 156 | ids = torch.cat((txt_ids, img_ids), dim=1) 157 | pe = self.pe_embedder(ids) 158 | 159 | block_res_samples = () 160 | for block in self.transformer_blocks: 161 | txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe) 162 | block_res_samples = block_res_samples + (img,) 163 | 164 | img = torch.cat([txt, img], dim=1) 165 | 166 | single_block_res_samples = () 167 | for block in self.single_transformer_blocks: 168 | img = block(hidden_states=img, temb=vec, image_rotary_emb=pe) 169 | single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1] :],) 170 | 171 | controlnet_block_res_samples = () 172 | for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): 173 | block_res_sample = controlnet_block(block_res_sample) 174 | controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) 175 | 176 | controlnet_single_block_res_samples = () 177 | for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks): 178 | single_block_res_sample = single_controlnet_block(single_block_res_sample) 179 | controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,) 180 | 181 | n_single_blocks = 38 182 | n_double_blocks = 19 183 | 184 | #
Expand controlnet_block_res_samples to match n_double_blocks 185 | expanded_controlnet_block_res_samples = [] 186 | interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples))) 187 | for i in range(n_double_blocks): 188 | index = i // interval_control_double 189 | expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index]) 190 | 191 | # Expand controlnet_single_block_res_samples to match n_single_blocks 192 | expanded_controlnet_single_block_res_samples = [] 193 | interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples))) 194 | for i in range(n_single_blocks): 195 | index = i // interval_control_single 196 | expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index]) 197 | 198 | return { 199 | "input": expanded_controlnet_block_res_samples, 200 | "output": expanded_controlnet_single_block_res_samples 201 | } 202 | 203 | def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs): 204 | bs, c, h, w = x.shape 205 | patch_size = 2 206 | x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) 207 | 208 | height_control_image, width_control_image = hint.shape[2:] 209 | num_channels_latents = self.in_channels // 4 210 | hint = self._pack_latents( 211 | hint, 212 | hint.shape[0], 213 | num_channels_latents, 214 | height_control_image, 215 | width_control_image, 216 | ) 217 | img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) 218 | 219 | h_len = ((h + (patch_size // 2)) // patch_size) 220 | w_len = ((w + (patch_size // 2)) // patch_size) 221 | img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) 222 | img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None] 223 | img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] 224 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 225 | 226 | txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) 227 | return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type) 228 | -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import comfy.model_management as mm 2 | import folder_paths 3 | import logging 4 | import comfy 5 | import torch 6 | from .controlnet.controlnet_instantx import InstantXControlNetFlux 7 | from .controlnet.controlnet_instantx_format2 import InstantXControlNetFluxFormat2 8 | from comfy.controlnet import ControlNet, controlnet_load_state_dict 9 | from nodes import ControlNetApplyAdvanced 10 | 11 | def load_controlnet_flux_instantx(sd, controlnet_class, weight_dtype): 12 | keys_to_keep = [ 13 | "controlnet_", 14 | "single_transformer_blocks", 15 | "transformer_blocks" 16 | ] 17 | preserved_keys = {k: v.cpu() for k, v in sd.items() if any(k.startswith(key) for key in keys_to_keep)} 18 | 19 | new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") 20 | 21 | keys_to_discard = [ 22 | "double_blocks", 23 | "single_blocks" 24 | ] 25 | new_sd = {k: v for k, v in new_sd.items() if not any(k.startswith(discard_key) for discard_key in keys_to_discard)} 26 | new_sd.update(preserved_keys) 27 | 28 | config = { 29 | "image_model": "flux", 30 | "axes_dim": [16, 56, 56], 31 | "in_channels": 16, 32 | 
"depth": 5, 33 | "depth_single_blocks": 10, 34 | "context_in_dim": 4096, 35 | "num_heads": 24, 36 | "guidance_embed": True, 37 | "hidden_size": 3072, 38 | "mlp_ratio": 4.0, 39 | "theta": 10000, 40 | "qkv_bias": True, 41 | "vec_in_dim": 768 42 | } 43 | 44 | device=mm.get_torch_device() 45 | 46 | if weight_dtype == "fp8_e4m3fn": 47 | dtype=torch.float8_e4m3fn 48 | operations = comfy.ops.manual_cast 49 | elif weight_dtype == "fp8_e5m2": 50 | dtype=torch.float8_e5m2 51 | operations = comfy.ops.manual_cast 52 | else: 53 | dtype=torch.bfloat16 54 | operations = comfy.ops.disable_weight_init 55 | 56 | control_model = controlnet_class(operations=operations, device=device, dtype=dtype, **config) 57 | control_model = controlnet_load_state_dict(control_model, new_sd) 58 | extra_conds = ['y', 'guidance', 'control_type'] 59 | latent_format = comfy.latent_formats.SD3() 60 | # TODO check manual cast dtype 61 | control = ControlNet(control_model, compression_ratio=1, load_device=device, manual_cast_dtype=torch.bfloat16, extra_conds=extra_conds, latent_format=latent_format) 62 | return control 63 | 64 | def load_controlnet(full_path, weight_dtype): 65 | controlnet_data = comfy.utils.load_torch_file(full_path, safe_load=True) 66 | if "controlnet_mode_embedder.fc.weight" in controlnet_data: 67 | return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFlux, weight_dtype) 68 | if "controlnet_mode_embedder.weight" in controlnet_data: 69 | return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFluxFormat2, weight_dtype) 70 | assert False, f"Only InstantX union controlnet supported. Could not find key 'controlnet_mode_embedder.fc.weight' in {full_path}" 71 | 72 | INSTANTX_UNION_CONTROLNET_TYPES = { 73 | "canny": 0, 74 | "tile": 1, 75 | "depth": 2, 76 | "blur": 3, 77 | "pose": 4, 78 | "gray": 5, 79 | "lq": 6 80 | } 81 | 82 | class InstantXFluxUnionControlNetLoader: 83 | @classmethod 84 | def INPUT_TYPES(s): 85 | return { 86 | "required": { 87 | "control_net_name": (folder_paths.get_filename_list("controlnet"),), 88 | "type": (list(INSTANTX_UNION_CONTROLNET_TYPES.keys()),), 89 | #"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],) 90 | } 91 | } 92 | 93 | RETURN_TYPES = ("CONTROL_NET",) 94 | FUNCTION = "load_controlnet" 95 | CATEGORY = "loaders" 96 | 97 | def load_controlnet(self, control_net_name, type, weight_dtype="default"): 98 | controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) 99 | controlnet = load_controlnet(controlnet_path, weight_dtype) 100 | 101 | type_number = INSTANTX_UNION_CONTROLNET_TYPES.get(type, -1) 102 | controlnet.set_extra_arg("control_type", type_number) 103 | 104 | return (controlnet,) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-eesahesnodes" 3 | description = "InstantX's Flux union ControlNet loader and implementation" 4 | version = "1.0.0" 5 | license = {file = "LICENSE"} 6 | dependencies = ["diffusers>=0.30.0", "einops>=0.7.0"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/EeroHeikkinen/ComfyUI-eesahesNodes" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "eesahe" 14 | DisplayName = "ComfyUI-eesahesNodes" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | diffusers>=0.30.0 2 | einops>=0.7.0 --------------------------------------------------------------------------------