├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── chroma
│   ├── layers.py
│   └── model.py
├── flux
│   ├── layers.py
│   └── model.py
├── hidream
│   └── model.py
├── hunyuan_video
│   └── model.py
├── node.py
├── sample.py
├── samplers.py
├── sd
│   ├── attention.py
│   └── openaimodel.py
├── sd3
│   └── mmdit.py
├── utils.py
├── wan
│   └── model.py
├── workflow.png
└── workflows
    ├── NAG-Chroma-ComfyUI-Workflow.json
    ├── NAG-DMD2-ComfyUI-Workflow.json
    ├── NAG-Flux-Dev-ComfyUI-Workflow.json
    ├── NAG-Flux-Kontext-Dev-ComfyUI-Workflow.json
    ├── NAG-Flux-Schnell-ComfyUI-Workflow.json
    ├── NAG-Hunyuan-ComfyUI-Workflow.json
    ├── NAG-SD15-ComfyUI-Workflow.json
    ├── NAG-SD3.5-Turbo-ComfyUI-Workflow.json
    ├── NAG-Wan-ComfyUI-Workflow.json
    └── NAG-Wan-Fast-ComfyUI-Workflow.json
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | .idea/ 169 | 170 | # Abstra 171 | # Abstra is an AI-powered process automation framework. 172 | # Ignore directories containing user credentials, local state, and settings. 173 | # Learn more at https://abstra.io/docs 174 | .abstra/ 175 | 176 | # Visual Studio Code 177 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 178 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 179 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 180 | # you could uncomment the following to ignore the entire vscode folder 181 | # .vscode/ 182 | 183 | # Ruff stuff: 184 | .ruff_cache/ 185 | 186 | # PyPI configuration file 187 | .pypirc 188 | 189 | # Cursor 190 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 191 | # exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data 192 | # refer to https://docs.cursor.com/context/ignore-files 193 | .cursorignore 194 | .cursorindexingignore 195 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Dar-Yen Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-NAG 2 | 3 | Implementation of [Normalized Attention Guidance: Universal Negative Guidance for Diffusion Models](https://chendaryen.github.io/NAG.github.io/) for [ComfyUI](https://github.com/comfyanonymous/ComfyUI). 4 | 5 | NAG restores effective negative prompting in few-step diffusion models, and complements CFG in multi-step sampling for improved quality and control. 6 | 7 | Paper: https://arxiv.org/abs/2505.21179 8 | 9 | Code: https://github.com/ChenDarYen/Normalized-Attention-Guidance 10 | 11 | Wan2.1 Demo: https://huggingface.co/spaces/ChenDY/NAG_wan2-1-fast 12 | 13 | LTX Video Demo: https://huggingface.co/spaces/ChenDY/NAG_ltx-video-distilled 14 | 15 | Flux-Dev Demo: https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-dev 16 | 17 | ![comfyui-nag](workflow.png?cache=20250628) 18 | 19 | ## News 20 | 21 | 2025-07-06: Add three new nodes: 22 | - `KSamplerWithNAG (Advanced)` as a drop-in replacement for `KSampler (Advanced)`. 23 | - `SamplerCustomWithNAG` for `SamplerCustom`. 24 | - `NAGGuider` for `BasicGuider`. 25 | 26 | 2025-07-02: `HiDream` is now supported! 27 | 28 | 2025-07-02: Add support for `TeaCache` and `WaveSpeed` to accelerate NAG sampling! 29 | 30 | 2025-06-30: Fix a major bug affecting `Flux`, `Flux Kontext` and `Chroma`, resulting in degraded guidance. Please update your NAG node! 31 | 32 | 2025-06-29: Add compile model support. You can now use compile model nodes like `TorchCompileModel` to speed up NAG sampling! 33 | 34 | 2025-06-28: `Flux Kontext` is now supported. Check out the [workflow](https://github.com/ChenDarYen/ComfyUI-NAG/blob/main/workflows/NAG-Flux-Kontext-Dev-ComfyUI-Workflow.json)! 35 | 36 | 2025-06-26: `Hunyuan video` is now supported! 37 | 38 | 2025-06-25: `Wan` video generation is now supported (GGUF compatible)! 
Try it out with the new [workflow](https://github.com/ChenDarYen/ComfyUI-NAG/blob/main/workflows/NAG-Wan-Fast-ComfyUI-Workflow.json)! 39 | 40 | ## Nodes 41 | 42 | - `KSamplerWithNAG`, `KSamplerWithNAG (Advanced)`, `SamplerCustomWithNAG` 43 | - `NAGGuider`, `NAGCFGGuider` 44 | 45 | ## Usage 46 | 47 | To use NAG, simply replace 48 | - `KSampler` with `KSamplerWithNAG`. 49 | - `KSampler (Advanced)` with `KSamplerWithNAG (Advanced)`. 50 | - `SamplerCustom` with `SamplerCustomWithNAG`. 51 | - `BasicGuider` with `NAGGuider`. 52 | - `CFGGuider` with `NAGCFGGuider`. 53 | 54 | We currently support `Flux`, `Flux Kontext`, `Wan`, `Vace Wan`, `Hunyuan Video`, `Chroma`, `SD3.5`, `SDXL` and `SD`. 55 | 56 | Example workflows are available in the `./workflows` directory! 57 | 58 | ## Key Inputs 59 | 60 | When working with a new model, it's recommended to first find a good combination of `nag_tau` and `nag_alpha` that keeps the negative guidance effective without introducing artifacts. 61 | 62 | Once you're satisfied, keep `nag_tau` and `nag_alpha` fixed and, in most cases, tune only `nag_scale` to control the strength of the guidance. 63 | 64 | Use `nag_sigma_end` to reduce computation without much quality drop. 65 | 66 | For flow-based models like `Flux`, `nag_sigma_end = 0.75` achieves near-identical results with significantly improved speed. For diffusion-based `SDXL`, a good default is `nag_sigma_end = 4`. 67 | 68 | - `nag_scale`: The scale for attention feature extrapolation. Higher values result in stronger negative guidance. 69 | - `nag_tau`: The normalisation threshold. Higher values result in stronger negative guidance. 70 | - `nag_alpha`: The blending factor between the original and extrapolated attention. Higher values result in stronger negative guidance. 71 | - `nag_sigma_end`: NAG stays active only until the sampling sigma drops below `nag_sigma_end`. 72 | 73 | ### Rule of Thumb 74 | 75 | - For image-reference tasks (e.g., Image2Video), use lower `nag_tau` and `nag_alpha` to preserve the reference content more faithfully. 76 | - For models that require more sampling steps and higher CFG, also prefer lower `nag_tau` and `nag_alpha`. 77 | - For few-step models, you can use higher `nag_tau` and `nag_alpha` for stronger negative guidance. 
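### How the inputs combine (sketch)

Internally, every supported model routes the image tokens' attention output through a single `nag()` helper (imported from `utils.py`, which is not reproduced in this excerpt). The following is a minimal sketch of that operation, assuming the L1-norm formulation described in the paper — the function body here is an illustration, not a copy of `utils.nag` — showing how `nag_scale`, `nag_tau` and `nag_alpha` interact:

```python
import torch

def nag(z_positive, z_negative, nag_scale, nag_tau, nag_alpha):
    # Sketch only; see utils.py in this repo for the authoritative version.
    # 1. Extrapolate the attention features away from the negative branch.
    z_guidance = z_positive * nag_scale - z_negative * (nag_scale - 1)

    # 2. Normalise: cap the L1-norm growth of the guided features at
    #    nag_tau times the norm of the positive features.
    norm_positive = torch.norm(z_positive, p=1, dim=-1, keepdim=True)
    norm_guidance = torch.norm(z_guidance, p=1, dim=-1, keepdim=True)
    ratio = norm_guidance / norm_positive
    z_guidance = z_guidance * (torch.clamp(ratio, max=nag_tau) / (ratio + 1e-7))

    # 3. Blend the normalised guidance back into the positive features.
    return z_guidance * nag_alpha + z_positive * (1 - nag_alpha)
```

This is why all three knobs push in the same direction: `nag_scale` lengthens the extrapolation, `nag_tau` raises the norm cap, and `nag_alpha` weights the blend toward the guided features.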
78 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .node import NODE_CLASS_MAPPINGS 2 | __all__ = ['NODE_CLASS_MAPPINGS'] 3 | -------------------------------------------------------------------------------- /chroma/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from comfy.ldm.flux.math import attention 5 | from comfy.ldm.chroma.layers import DoubleStreamBlock, SingleStreamBlock 6 | 7 | from ..utils import nag 8 | 9 | 10 | class NAGDoubleStreamBlock(DoubleStreamBlock): 11 | def __init__( 12 | self, 13 | *args, 14 | nag_scale: float = 1, 15 | nag_tau: float = 2.5, 16 | nag_alpha: float = 0.25, 17 | **kwargs, 18 | ): 19 | super().__init__(*args, **kwargs) 20 | self.nag_scale = nag_scale 21 | self.nag_tau = nag_tau 22 | self.nag_alpha = nag_alpha 23 | 24 | def forward( 25 | self, 26 | img: Tensor, 27 | txt: Tensor, 28 | pe: Tensor, 29 | pe_negative: Tensor, 30 | vec: Tensor, 31 | attn_mask=None, 32 | context_pad_len: int = 0, 33 | nag_pad_len: int = 0, 34 | ): 35 | origin_bsz = len(txt) - len(img) 36 | assert origin_bsz != 0 37 | 38 | (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec 39 | 40 | # prepare image for attention 41 | img_modulated = torch.addcmul(img_mod1.shift[:-origin_bsz], 1 + img_mod1.scale[:-origin_bsz], self.img_norm1(img)) 42 | img_qkv = self.img_attn.qkv(img_modulated) 43 | img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 44 | 1, 4) 45 | img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) 46 | 47 | # prepare txt for attention 48 | txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt)) 49 | txt_qkv = self.txt_attn.qkv(txt_modulated) 50 | txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 51 | 1, 4) 52 | txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) 53 | 54 | txt_q_negative, txt_q = txt_q[-origin_bsz:, :, nag_pad_len:], txt_q[:-origin_bsz, :, context_pad_len:] 55 | txt_k_negative, txt_k = txt_k[-origin_bsz:, :, nag_pad_len:], txt_k[:-origin_bsz, :, context_pad_len:] 56 | txt_v_negative, txt_v = txt_v[-origin_bsz:, :, nag_pad_len:], txt_v[:-origin_bsz, :, context_pad_len:] 57 | 58 | img_q_negative = img_q[-origin_bsz:] 59 | img_k_negative = img_k[-origin_bsz:] 60 | img_v_negative = img_v[-origin_bsz:] 61 | 62 | # run actual attention 63 | attn_negative = attention( 64 | torch.cat((txt_q_negative, img_q_negative), dim=2), 65 | torch.cat((txt_k_negative, img_k_negative), dim=2), 66 | torch.cat((txt_v_negative, img_v_negative), dim=2), 67 | pe=pe_negative, mask=attn_mask, 68 | ) 69 | attn = attention( 70 | torch.cat((txt_q, img_q), dim=2), 71 | torch.cat((txt_k, img_k), dim=2), 72 | torch.cat((txt_v, img_v), dim=2), 73 | pe=pe, mask=attn_mask, 74 | ) 75 | 76 | txt_attn_negative, img_attn_negative = attn_negative[:, : txt.shape[1] - nag_pad_len], attn_negative[:, txt.shape[1] - nag_pad_len:] 77 | txt_attn, img_attn = attn[:, : txt.shape[1] - context_pad_len], attn[:, txt.shape[1] - context_pad_len:] 78 | 79 | # NAG 80 | img_attn_positive = img_attn[-origin_bsz:] 81 | img_attn_guidance = nag(img_attn_positive, img_attn_negative, self.nag_scale, self.nag_tau, self.nag_alpha) 82 | 83 | img_attn = torch.cat([img_attn[:-origin_bsz], img_attn_guidance], dim=0) 84 | 85 | # calculate the img 
blocks 86 | img.addcmul_(img_mod1.gate[:-origin_bsz], self.img_attn.proj(img_attn)) 87 | img.addcmul_( 88 | img_mod2.gate[:-origin_bsz], 89 | self.img_mlp(torch.addcmul(img_mod2.shift[:-origin_bsz], 1 + img_mod2.scale[:-origin_bsz], self.img_norm2(img))), 90 | ) 91 | 92 | # calculate the txt blocks 93 | txt[:-origin_bsz, context_pad_len:].addcmul_(txt_mod1.gate[:-origin_bsz], self.txt_attn.proj(txt_attn)) 94 | txt[-origin_bsz:, nag_pad_len:].addcmul_(txt_mod1.gate[-origin_bsz:], self.txt_attn.proj(txt_attn_negative)) 95 | txt.addcmul_(txt_mod2.gate, 96 | self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt)))) 97 | 98 | if txt.dtype == torch.float16: 99 | txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) 100 | 101 | return img, txt 102 | 103 | 104 | class NAGSingleStreamBlock(SingleStreamBlock): 105 | def __init__( 106 | self, 107 | *args, 108 | nag_scale: float = 1, 109 | nag_tau: float = 2.5, 110 | nag_alpha: float = 0.25, 111 | **kwargs, 112 | ): 113 | super().__init__(*args, **kwargs) 114 | self.nag_scale = nag_scale 115 | self.nag_tau = nag_tau 116 | self.nag_alpha = nag_alpha 117 | 118 | def forward( 119 | self, 120 | x: Tensor, 121 | pe: Tensor, 122 | pe_negative: Tensor, 123 | vec: Tensor, 124 | attn_mask=None, 125 | txt_length: int = None, 126 | origin_bsz: int = None, 127 | context_pad_len: int = 0, 128 | nag_pad_len: int = 0, 129 | ) -> Tensor: 130 | mod = vec 131 | x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x)) 132 | qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) 133 | 134 | q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 135 | q, k = self.norm(q, k, v) 136 | 137 | # NAG 138 | q, q_negative = q[:-origin_bsz, :, context_pad_len:], q[-origin_bsz:, :, nag_pad_len:] 139 | k, k_negative = k[:-origin_bsz, :, context_pad_len:], k[-origin_bsz:, :, nag_pad_len:] 140 | v, v_negative = v[:-origin_bsz, :, context_pad_len:], v[-origin_bsz:, :, nag_pad_len:] 141 | 142 | attn_negative = attention(q_negative, k_negative, v_negative, pe=pe_negative, mask=attn_mask) 143 | attn = attention(q, k, v, pe=pe, mask=attn_mask) 144 | 145 | img_attn_negative = attn_negative[:, txt_length - nag_pad_len:] 146 | img_attn = attn[:, txt_length - context_pad_len:] 147 | 148 | img_attn_positive = img_attn[-origin_bsz:] 149 | img_attn_guidance = nag(img_attn_positive, img_attn_negative, self.nag_scale, self.nag_tau, self.nag_alpha) 150 | 151 | attn_negative[:, txt_length - nag_pad_len:] = img_attn_guidance 152 | attn[-origin_bsz:, txt_length - context_pad_len:] = img_attn_guidance 153 | 154 | # compute activation in mlp stream, cat again and run second linear layer 155 | output_negative = self.linear2(torch.cat((attn_negative, self.mlp_act(mlp[-origin_bsz:, nag_pad_len:])), 2)) 156 | output = self.linear2(torch.cat((attn, self.mlp_act(mlp[:-origin_bsz, context_pad_len:])), 2)) 157 | 158 | x[:-origin_bsz, context_pad_len:].addcmul_(mod.gate[:-origin_bsz], output) 159 | x[-origin_bsz:, nag_pad_len:].addcmul_(mod.gate[-origin_bsz:], output_negative) 160 | 161 | if x.dtype == torch.float16: 162 | x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) 163 | return x 164 | -------------------------------------------------------------------------------- /chroma/model.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from types import MethodType 3 | 4 | import torch 5 | from torch import
Tensor 6 | from einops import rearrange, repeat 7 | import comfy.ldm.common_dit 8 | 9 | from comfy.ldm.flux.layers import timestep_embedding 10 | from comfy.ldm.chroma.layers import ( 11 | DoubleStreamBlock, 12 | SingleStreamBlock, 13 | ) 14 | from comfy.ldm.chroma.model import Chroma 15 | 16 | from .layers import NAGDoubleStreamBlock, NAGSingleStreamBlock 17 | from ..utils import cat_context, check_nag_activation, NAGSwitch 18 | 19 | 20 | class NAGChroma(Chroma): 21 | def forward_orig( 22 | self, 23 | img: Tensor, 24 | img_ids: Tensor, 25 | txt: Tensor, 26 | txt_ids: Tensor, 27 | txt_ids_negative: Tensor, 28 | timesteps: Tensor, 29 | guidance: Tensor = None, 30 | control = None, 31 | transformer_options={}, 32 | attn_mask: Tensor = None, 33 | ) -> Tensor: 34 | patches_replace = transformer_options.get("patches_replace", {}) 35 | if img.ndim != 3 or txt.ndim != 3: 36 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 37 | 38 | # running on sequences img 39 | img = self.img_in(img) 40 | 41 | # distilled vector guidance 42 | mod_index_length = 344 43 | distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype) 44 | # guidance = guidance * 45 | distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype) 46 | 47 | # get all modulation index 48 | modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype) 49 | # we need to broadcast the modulation index here so each batch has all of the index 50 | modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype) 51 | # and we need to broadcast timestep and guidance along too 52 | timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype) 53 | # then and only then we could concatenate it together 54 | input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype) 55 | 56 | mod_vectors = self.distilled_guidance_layer(input_vec) 57 | 58 | origin_bsz = len(txt) - len(img) 59 | mod_vectors = torch.cat((mod_vectors, mod_vectors[-origin_bsz:]), dim=0) 60 | 61 | txt = self.txt_in(txt) 62 | 63 | ids = torch.cat((txt_ids, img_ids), dim=1) 64 | ids_negative = torch.cat((txt_ids_negative, img_ids[-origin_bsz:]), dim=1) 65 | pe = self.pe_embedder(ids) 66 | pe_negative = self.pe_embedder(ids_negative) 67 | 68 | blocks_replace = patches_replace.get("dit", {}) 69 | for i, block in enumerate(self.double_blocks): 70 | if i not in self.skip_mmdit: 71 | double_mod = ( 72 | self.get_modulations(mod_vectors, "double_img", idx=i), 73 | self.get_modulations(mod_vectors, "double_txt", idx=i), 74 | ) 75 | if ("double_block", i) in blocks_replace: 76 | def block_wrap(args): 77 | out = {} 78 | out["img"], out["txt"] = block(img=args["img"], 79 | txt=args["txt"], 80 | vec=args["vec"], 81 | pe=args["pe"], 82 | pe_negative=args["pe_negative"], 83 | attn_mask=args.get("attn_mask")) 84 | return out 85 | 86 | out = blocks_replace[("double_block", i)]({"img": img, 87 | "txt": txt, 88 | "vec": double_mod, 89 | "pe": pe, 90 | "pe_negative": pe_negative, 91 | "attn_mask": attn_mask}, 92 | {"original_block": block_wrap}) 93 | txt = out["txt"] 94 | img = out["img"] 95 | else: 96 | img, txt = block(img=img, 97 | txt=txt, 98 | vec=double_mod, 99 | pe=pe, 100 | pe_negative=pe_negative, 101 | attn_mask=attn_mask) 102 | 103 | if control is not None: # Controlnet 104 | 
control_i = control.get("input") 105 | if i < len(control_i): 106 | add = control_i[i] 107 | if add is not None: 108 | img += add 109 | 110 | img = torch.cat((img, img[-origin_bsz:]), dim=0) 111 | img = torch.cat((txt, img), 1) 112 | 113 | for i, block in enumerate(self.single_blocks): 114 | if i not in self.skip_dit: 115 | single_mod = self.get_modulations(mod_vectors, "single", idx=i) 116 | if ("single_block", i) in blocks_replace: 117 | def block_wrap(args): 118 | out = {} 119 | out["img"] = block(args["img"], 120 | vec=args["vec"], 121 | pe=args["pe"], 122 | pe_negative=args["pe_negative"], 123 | attn_mask=args.get("attn_mask")) 124 | return out 125 | 126 | out = blocks_replace[("single_block", i)]({"img": img, 127 | "vec": single_mod, 128 | "pe": pe, 129 | "pe_negative": pe_negative, 130 | "attn_mask": attn_mask}, 131 | {"original_block": block_wrap}) 132 | img = out["img"] 133 | else: 134 | img = block(img, vec=single_mod, pe=pe, pe_negative=pe_negative, attn_mask=attn_mask) 135 | 136 | if control is not None: # Controlnet 137 | control_o = control.get("output") 138 | if i < len(control_o): 139 | add = control_o[i] 140 | if add is not None: 141 | img[:, txt.shape[1] :, ...] += add 142 | 143 | img = img[:-origin_bsz] 144 | img = img[:, txt.shape[1] :, ...] 145 | final_mod = self.get_modulations(mod_vectors[:-origin_bsz], "final") 146 | img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) 147 | return img 148 | 149 | def forward( 150 | self, 151 | x, 152 | timestep, 153 | context, 154 | guidance, 155 | control=None, 156 | transformer_options={}, 157 | 158 | nag_negative_context=None, 159 | nag_sigma_end=0., 160 | 161 | **kwargs, 162 | ): 163 | bs, c, h, w = x.shape 164 | patch_size = 2 165 | x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) 166 | 167 | img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) 168 | 169 | h_len = ((h + (patch_size // 2)) // patch_size) 170 | w_len = ((w + (patch_size // 2)) // patch_size) 171 | img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) 172 | img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) 173 | img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) 174 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 175 | 176 | apply_nag = check_nag_activation(transformer_options, nag_sigma_end) 177 | if apply_nag: 178 | origin_context_len = context.shape[1] 179 | nag_bsz, nag_negative_context_len = nag_negative_context.shape[:2] 180 | context = cat_context(context, nag_negative_context, trim_context=True) 181 | context_pad_len = context.shape[1] - origin_context_len 182 | nag_pad_len = context.shape[1] - nag_negative_context_len 183 | 184 | forward_orig_ = self.forward_orig 185 | double_blocks_forward = list() 186 | single_blocks_forward = list() 187 | 188 | self.forward_orig = MethodType(NAGChroma.forward_orig, self) 189 | for block in self.double_blocks: 190 | double_blocks_forward.append(block.forward) 191 | block.forward = MethodType( 192 | partial( 193 | NAGDoubleStreamBlock.forward, 194 | context_pad_len=context_pad_len, 195 | nag_pad_len=nag_pad_len, 196 | ), 197 | block, 198 | ) 199 | for block in self.single_blocks: 200 | single_blocks_forward.append(block.forward) 201 | block.forward = MethodType( 202 | partial( 203 | NAGSingleStreamBlock.forward, 204 | txt_length=context.shape[1], 205 | 
origin_bsz=nag_bsz, 206 | context_pad_len=context_pad_len, 207 | nag_pad_len=nag_pad_len, 208 | ), 209 | block, 210 | ) 211 | 212 | txt_ids = torch.zeros((bs, origin_context_len, 3), device=x.device, dtype=x.dtype) 213 | txt_ids_negative = torch.zeros((nag_bsz, nag_negative_context_len, 3), device=x.device, dtype=x.dtype) 214 | out = self.forward_orig( 215 | img, img_ids, context, txt_ids, txt_ids_negative, timestep, guidance, control, transformer_options, 216 | attn_mask=kwargs.get("attention_mask", None), 217 | ) 218 | 219 | self.forward_orig = forward_orig_ 220 | for block in self.double_blocks: 221 | block.forward = double_blocks_forward.pop(0) 222 | for block in self.single_blocks: 223 | block.forward = single_blocks_forward.pop(0) 224 | 225 | else: 226 | txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) 227 | out = self.forward_orig( 228 | img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, 229 | attn_mask=kwargs.get("attention_mask", None), 230 | ) 231 | 232 | return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w] 233 | 234 | 235 | class NAGChromaSwitch(NAGSwitch): 236 | def set_nag(self): 237 | self.model.forward = MethodType( 238 | partial( 239 | NAGChroma.forward, 240 | nag_negative_context=self.nag_negative_cond[0][0], 241 | nag_negative_y=self.nag_negative_cond[0][1]["pooled_output"], 242 | nag_sigma_end=self.nag_sigma_end, 243 | ), 244 | self.model, 245 | ) 246 | for block in self.model.double_blocks: 247 | block.nag_scale = self.nag_scale 248 | block.nag_tau = self.nag_tau 249 | block.nag_alpha = self.nag_alpha 250 | for block in self.model.single_blocks: 251 | block.nag_scale = self.nag_scale 252 | block.nag_tau = self.nag_tau 253 | block.nag_alpha = self.nag_alpha 254 | -------------------------------------------------------------------------------- /flux/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from comfy.ldm.flux.math import attention 5 | from comfy.ldm.flux.layers import DoubleStreamBlock, SingleStreamBlock, apply_mod 6 | 7 | from ..utils import nag 8 | 9 | 10 | class NAGDoubleStreamBlock(DoubleStreamBlock): 11 | def __init__( 12 | self, 13 | *args, 14 | nag_scale: float = 1, 15 | nag_tau: float = 2.5, 16 | nag_alpha: float = 0.25, 17 | **kwargs, 18 | ): 19 | super().__init__(*args, **kwargs) 20 | self.nag_scale = nag_scale 21 | self.nag_tau = nag_tau 22 | self.nag_alpha = nag_alpha 23 | 24 | def forward( 25 | self, 26 | img: Tensor, 27 | txt: Tensor, 28 | vec: Tensor, 29 | pe: Tensor, 30 | pe_negative: Tensor, 31 | attn_mask=None, 32 | modulation_dims_img=None, 33 | modulation_dims_txt=None, 34 | context_pad_len: int = 0, 35 | nag_pad_len: int = 0, 36 | ): 37 | origin_bsz = len(txt) - len(img) 38 | assert origin_bsz != 0 39 | 40 | img_mod1, img_mod2 = self.img_mod(vec[:-origin_bsz]) 41 | txt_mod1, txt_mod2 = self.txt_mod(vec) 42 | 43 | # prepare image for attention 44 | img_modulated = self.img_norm1(img) 45 | img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img) 46 | img_qkv = self.img_attn.qkv(img_modulated) 47 | img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 48 | 1, 4) 49 | img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) 50 | 51 | # prepare txt for attention 52 | txt_modulated = self.txt_norm1(txt) 53 | txt_modulated = 
apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt) 54 | txt_qkv = self.txt_attn.qkv(txt_modulated) 55 | txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 56 | 1, 4) 57 | txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) 58 | 59 | txt_q_negative, txt_q = txt_q[-origin_bsz:, :, nag_pad_len:], txt_q[:-origin_bsz, :, context_pad_len:] 60 | txt_k_negative, txt_k = txt_k[-origin_bsz:, :, nag_pad_len:], txt_k[:-origin_bsz, :, context_pad_len:] 61 | txt_v_negative, txt_v = txt_v[-origin_bsz:, :, nag_pad_len:], txt_v[:-origin_bsz, :, context_pad_len:] 62 | 63 | img_q_negative = img_q[-origin_bsz:] 64 | img_k_negative = img_k[-origin_bsz:] 65 | img_v_negative = img_v[-origin_bsz:] 66 | 67 | if self.flipped_img_txt: 68 | # run actual attention 69 | attn_negative = attention( 70 | torch.cat((img_q_negative, txt_q_negative), dim=2), 71 | torch.cat((img_k_negative, txt_k_negative), dim=2), 72 | torch.cat((img_v_negative, txt_v_negative), dim=2), 73 | pe=pe_negative, mask=attn_mask, 74 | ) 75 | attn = attention( 76 | torch.cat((img_q, txt_q), dim=2), 77 | torch.cat((img_k, txt_k), dim=2), 78 | torch.cat((img_v, txt_v), dim=2), 79 | pe=pe, mask=attn_mask, 80 | ) 81 | 82 | img_attn_negative, txt_attn_negative = attn_negative[:, :img.shape[1]], attn_negative[:, img.shape[1]:] 83 | img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] 84 | else: 85 | # run actual attention 86 | attn_negative = attention( 87 | torch.cat((txt_q_negative, img_q_negative), dim=2), 88 | torch.cat((txt_k_negative, img_k_negative), dim=2), 89 | torch.cat((txt_v_negative, img_v_negative), dim=2), 90 | pe=pe_negative, mask=attn_mask, 91 | ) 92 | attn = attention( 93 | torch.cat((txt_q, img_q), dim=2), 94 | torch.cat((txt_k, img_k), dim=2), 95 | torch.cat((txt_v, img_v), dim=2), 96 | pe=pe, mask=attn_mask, 97 | ) 98 | 99 | txt_attn_negative, img_attn_negative = attn_negative[:, : txt.shape[1] - nag_pad_len], attn_negative[:, txt.shape[1] - nag_pad_len:] 100 | txt_attn, img_attn = attn[:, : txt.shape[1] - context_pad_len], attn[:, txt.shape[1] - context_pad_len:] 101 | 102 | # NAG 103 | img_attn_positive = img_attn[-origin_bsz:] 104 | img_attn_guidance = nag(img_attn_positive, img_attn_negative, self.nag_scale, self.nag_tau, self.nag_alpha) 105 | 106 | img_attn = torch.cat([img_attn[:-origin_bsz], img_attn_guidance], dim=0) 107 | 108 | # calculate the img blocks 109 | img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) 110 | img = img + apply_mod( 111 | self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), 112 | img_mod2.gate, None, modulation_dims_img) 113 | 114 | # calculate the txt blocks 115 | txt[:-origin_bsz, context_pad_len:].addcmul_(txt_mod1.gate[:-origin_bsz], self.txt_attn.proj(txt_attn)) 116 | txt[-origin_bsz:, nag_pad_len:].addcmul_(txt_mod1.gate[-origin_bsz:], self.txt_attn.proj(txt_attn_negative)) 117 | txt += apply_mod( 118 | self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), 119 | txt_mod2.gate, None, modulation_dims_txt) 120 | 121 | if txt.dtype == torch.float16: 122 | txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) 123 | 124 | return img, txt 125 | 126 | 127 | class NAGSingleStreamBlock(SingleStreamBlock): 128 | def __init__( 129 | self, 130 | *args, 131 | nag_scale: float = 1, 132 | nag_tau: float = 2.5, 133 | nag_alpha: float = 0.25, 134 | 
**kwargs, 135 | ): 136 | super().__init__(*args, **kwargs) 137 | self.nag_scale = nag_scale 138 | self.nag_tau = nag_tau 139 | self.nag_alpha = nag_alpha 140 | 141 | def forward( 142 | self, 143 | x: Tensor, 144 | vec: Tensor, 145 | pe: Tensor, 146 | pe_negative: Tensor, 147 | attn_mask=None, 148 | modulation_dims=None, 149 | 150 | txt_length: int = None, 151 | img_length: int = None, 152 | origin_bsz: int = None, 153 | context_pad_len: int = 0, 154 | nag_pad_len: int = 0, 155 | ) -> Tensor: 156 | mod= self.modulation(vec)[0] 157 | qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) 158 | 159 | q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 160 | q, k = self.norm(q, k, v) 161 | 162 | # NAG 163 | if txt_length is not None: 164 | def remove_pad_and_get_neg(feature, pad_dim=2): 165 | assert pad_dim in [1, 2] 166 | if pad_dim == 2: 167 | feature_negative = feature[-origin_bsz:, :, nag_pad_len:] 168 | feature = feature[:-origin_bsz, :, context_pad_len:] 169 | else: 170 | feature_negative = feature[-origin_bsz:, nag_pad_len:] 171 | feature = feature[:-origin_bsz, context_pad_len:] 172 | 173 | return feature_negative, feature 174 | 175 | else: 176 | def remove_pad_and_get_neg(feature, pad_dim=2): 177 | assert pad_dim in [1, 2] 178 | if pad_dim == 2: 179 | feature_negative = torch.cat([feature[-origin_bsz:, :, :img_length], feature[-origin_bsz:, :, img_length + nag_pad_len:]], dim=2) 180 | feature = torch.cat([feature[:-origin_bsz, :, :img_length], feature[:-origin_bsz, :, img_length + context_pad_len:]], dim=2) 181 | else: 182 | feature_negative = torch.cat([feature[-origin_bsz:, :img_length], feature[-origin_bsz:, img_length + nag_pad_len:]], dim=1) 183 | feature = torch.cat([feature[:-origin_bsz, :img_length], feature[:-origin_bsz, img_length + context_pad_len:]], dim=1) 184 | return feature_negative, feature 185 | 186 | q_negative, q = remove_pad_and_get_neg(q) 187 | k_negative, k = remove_pad_and_get_neg(k) 188 | v_negative, v = remove_pad_and_get_neg(v) 189 | 190 | # compute attention 191 | attn_negative = attention(q_negative, k_negative, v_negative, pe=pe_negative, mask=attn_mask) 192 | attn = attention(q, k, v, pe=pe, mask=attn_mask) 193 | 194 | if txt_length is not None: 195 | img_attn_negative = attn_negative[:, txt_length - nag_pad_len:] 196 | img_attn = attn[:, txt_length - context_pad_len:] 197 | else: 198 | img_attn_negative = attn_negative[:, :img_length] 199 | img_attn = attn[:, :img_length] 200 | 201 | img_attn_positive = img_attn[-origin_bsz:] 202 | img_attn_guidance = nag(img_attn_positive, img_attn_negative, self.nag_scale, self.nag_tau, self.nag_alpha) 203 | 204 | if txt_length is not None: 205 | attn_negative[:, txt_length - nag_pad_len:] = img_attn_guidance 206 | attn[-origin_bsz:, txt_length - context_pad_len:] = img_attn_guidance 207 | else: 208 | attn_negative[:, :img_length] = img_attn_guidance 209 | attn[-origin_bsz:, :img_length] = img_attn_guidance 210 | 211 | # compute activation in mlp stream, cat again and run second linear layer 212 | mlp_negative, mlp = remove_pad_and_get_neg(mlp, pad_dim=1) 213 | output_negative = self.linear2(torch.cat((attn_negative, self.mlp_act(mlp_negative)), 2)) 214 | output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) 215 | 216 | if txt_length is not None: 217 | x[:-origin_bsz, context_pad_len:] += apply_mod(output, mod.gate[:-origin_bsz], None, modulation_dims) 218 | 
x[-origin_bsz:, nag_pad_len:] += apply_mod(output_negative, mod.gate[-origin_bsz:], None, modulation_dims) 219 | else: 220 | x[:-origin_bsz, :img_length] += apply_mod(output[:, :img_length], mod.gate[:-origin_bsz], None, modulation_dims) 221 | x[:-origin_bsz, img_length + context_pad_len:] += apply_mod(output[:, img_length:], mod.gate[:-origin_bsz], None, modulation_dims) 222 | x[-origin_bsz:, :img_length] += apply_mod(output_negative[:, :img_length], mod.gate[-origin_bsz:], None, modulation_dims) 223 | x[-origin_bsz:, img_length + nag_pad_len:] += apply_mod(output_negative[:, img_length:], mod.gate[-origin_bsz:], None, modulation_dims) 224 | 225 | if x.dtype == torch.float16: 226 | x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) 227 | return x 228 | -------------------------------------------------------------------------------- /hidream/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from types import MethodType 3 | from functools import partial 4 | 5 | import torch 6 | from einops import repeat 7 | 8 | from comfy.ldm.flux.math import apply_rope 9 | import comfy.model_management 10 | import comfy.ldm.common_dit 11 | from comfy.ldm.hidream.model import ( 12 | HiDreamImageTransformer2DModel, 13 | HiDreamAttention, 14 | HiDreamImageTransformerBlock, 15 | attention, 16 | ) 17 | 18 | from ..utils import nag, cat_context, check_nag_activation, NAGSwitch 19 | 20 | 21 | class NAGHiDreamAttnProcessor_flashattn: 22 | """Attention processor used typically in processing the SD3-like self-attention projections.""" 23 | def __init__( 24 | self, 25 | nag_scale: float = 1.0, 26 | nag_tau=2.5, 27 | nag_alpha=0.25, 28 | encoder_hidden_states_length: int = None, 29 | origin_batch_size: int = None, 30 | ): 31 | self.nag_scale = nag_scale 32 | self.nag_tau = nag_tau 33 | self.nag_alpha = nag_alpha 34 | self.encoder_hidden_states_length = encoder_hidden_states_length 35 | self.origin_batch_size = origin_batch_size 36 | 37 | def __call__( 38 | self, 39 | attn, 40 | image_tokens: torch.FloatTensor, 41 | image_tokens_masks: Optional[torch.FloatTensor] = None, 42 | text_tokens: Optional[torch.FloatTensor] = None, 43 | rope: torch.FloatTensor = None, 44 | *args, 45 | **kwargs, 46 | ) -> torch.FloatTensor: 47 | dtype = image_tokens.dtype 48 | batch_size = image_tokens.shape[0] 49 | origin_batch_size = self.origin_batch_size 50 | txt_batch_size = text_tokens.shape[0] if text_tokens is not None else batch_size 51 | if text_tokens is not None: 52 | assert txt_batch_size == batch_size + origin_batch_size 53 | 54 | query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype) 55 | key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype) 56 | value_i = attn.to_v(image_tokens) 57 | 58 | inner_dim = key_i.shape[-1] 59 | head_dim = inner_dim // attn.heads 60 | 61 | query_i = query_i.view(batch_size, -1, attn.heads, head_dim) 62 | key_i = key_i.view(batch_size, -1, attn.heads, head_dim) 63 | value_i = value_i.view(batch_size, -1, attn.heads, head_dim) 64 | if image_tokens_masks is not None: 65 | key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1) 66 | 67 | if not attn.single: 68 | query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype) 69 | key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype) 70 | value_t = attn.to_v_t(text_tokens) 71 | 72 | query_t = query_t.view(txt_batch_size, -1, attn.heads, head_dim) 73 | key_t = key_t.view(txt_batch_size, -1, attn.heads, head_dim) 74 | value_t 
= value_t.view(txt_batch_size, -1, attn.heads, head_dim) 75 | 76 | query_i = torch.cat([query_i, query_i[-origin_batch_size:]], dim=0) 77 | key_i = torch.cat([key_i, key_i[-origin_batch_size:]], dim=0) 78 | value_i = torch.cat([value_i, value_i[-origin_batch_size:]], dim=0) 79 | 80 | num_image_tokens = query_i.shape[1] 81 | num_text_tokens = query_t.shape[1] 82 | query = torch.cat([query_i, query_t], dim=1) 83 | key = torch.cat([key_i, key_t], dim=1) 84 | value = torch.cat([value_i, value_t], dim=1) 85 | 86 | else: 87 | num_text_tokens = self.encoder_hidden_states_length 88 | num_image_tokens = query_i.shape[1] - num_text_tokens 89 | query = query_i 90 | key = key_i 91 | value = value_i 92 | 93 | if query.shape[-1] == rope.shape[-3] * 2: 94 | query, key = apply_rope(query, key, rope) 95 | else: 96 | query_1, query_2 = query.chunk(2, dim=-1) 97 | key_1, key_2 = key.chunk(2, dim=-1) 98 | query_1, key_1 = apply_rope(query_1, key_1, rope) 99 | query = torch.cat([query_1, query_2], dim=-1) 100 | key = torch.cat([key_1, key_2], dim=-1) 101 | 102 | query_negative, query = query[-origin_batch_size:], query[:-origin_batch_size] 103 | key_negative, key = key[-origin_batch_size:], key[:-origin_batch_size] 104 | value_negative, value = value[-origin_batch_size:], value[:-origin_batch_size] 105 | 106 | hidden_states = attention(query, key, value) 107 | hidden_states_negative = attention(query_negative, key_negative, value_negative) 108 | del query_negative, key_negative, value_negative, query, key, value 109 | 110 | hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) 111 | hidden_states_i_negative, hidden_states_t_negative = torch.split(hidden_states_negative, [num_image_tokens, num_text_tokens], dim=1) 112 | 113 | # NAG 114 | hidden_states_i_positive = hidden_states_i[-origin_batch_size:] 115 | hidden_states_i_guidance = nag(hidden_states_i_positive, hidden_states_i_negative, self.nag_scale, self.nag_tau, self.nag_alpha) 116 | 117 | hidden_states_i[-origin_batch_size:] = hidden_states_i_guidance 118 | 119 | if not attn.single: 120 | hidden_states_i = attn.to_out(hidden_states_i) 121 | hidden_states_t = attn.to_out_t(hidden_states_t) 122 | hidden_states_t_negative = attn.to_out_t(hidden_states_t_negative) 123 | hidden_states_t = torch.cat([hidden_states_t, hidden_states_t_negative], dim=0) 124 | return hidden_states_i, hidden_states_t 125 | 126 | else: 127 | hidden_states[-origin_batch_size:, :num_image_tokens] = hidden_states_i_guidance 128 | hidden_states_negative[:, :num_image_tokens] = hidden_states_i_guidance 129 | hidden_states = attn.to_out(hidden_states) 130 | hidden_states_negative = attn.to_out(hidden_states_negative) 131 | hidden_states = torch.cat([hidden_states, hidden_states_negative], dim=0) 132 | return hidden_states 133 | 134 | 135 | class NAGHiDreamImageTransformerBlock(HiDreamImageTransformerBlock): 136 | def forward( 137 | self, 138 | image_tokens: torch.FloatTensor, 139 | image_tokens_masks: Optional[torch.FloatTensor] = None, 140 | text_tokens: Optional[torch.FloatTensor] = None, 141 | adaln_input: Optional[torch.FloatTensor] = None, 142 | rope: torch.FloatTensor = None, 143 | ) -> torch.FloatTensor: 144 | wtype = image_tokens.dtype 145 | shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ 146 | shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \ 147 | self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1) 148 | 149 | # 1. 
MM-Attention 150 | image_batch_size = image_tokens.shape[0] 151 | norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) 152 | norm_image_tokens = norm_image_tokens * (1 + scale_msa_i[:image_batch_size]) + shift_msa_i[:image_batch_size] 153 | norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype) 154 | norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t 155 | 156 | attn_output_i, attn_output_t = self.attn1( 157 | norm_image_tokens, 158 | image_tokens_masks, 159 | norm_text_tokens, 160 | rope = rope, 161 | ) 162 | 163 | image_tokens = gate_msa_i[:image_batch_size] * attn_output_i + image_tokens 164 | text_tokens = gate_msa_t * attn_output_t + text_tokens 165 | 166 | # 2. Feed-forward 167 | norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) 168 | norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i[:image_batch_size]) + shift_mlp_i[:image_batch_size] 169 | norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype) 170 | norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t 171 | 172 | ff_output_i = gate_mlp_i[:image_batch_size] * self.ff_i(norm_image_tokens) 173 | ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens) 174 | image_tokens = ff_output_i + image_tokens 175 | text_tokens = ff_output_t + text_tokens 176 | return image_tokens, text_tokens 177 | 178 | 179 | class NAGHiDreamImageTransformer2DModel(HiDreamImageTransformer2DModel): 180 | def __init__( 181 | self, 182 | *args, 183 | nag_scale: float = 1, 184 | nag_tau: float = 2.5, 185 | nag_alpha: float = 0.25, 186 | **kwargs, 187 | ): 188 | super().__init__(*args, **kwargs) 189 | self.nag_scale = nag_scale 190 | self.nag_tau = nag_tau 191 | self.nag_alpha = nag_alpha 192 | 193 | def forward_nag( 194 | self, 195 | x: torch.Tensor, 196 | t: torch.Tensor, 197 | y: Optional[torch.Tensor] = None, 198 | context: Optional[torch.Tensor] = None, 199 | encoder_hidden_states_llama3=None, 200 | image_cond=None, 201 | control = None, 202 | transformer_options = {}, 203 | ) -> torch.Tensor: 204 | bs, c, h, w = x.shape 205 | if image_cond is not None: 206 | x = torch.cat([x, image_cond], dim=-1) 207 | hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) 208 | timesteps = t 209 | pooled_embeds = y 210 | T5_encoder_hidden_states = context 211 | 212 | img_sizes = None 213 | 214 | # spatial forward 215 | batch_size = hidden_states.shape[0] 216 | txt_batch_size = T5_encoder_hidden_states.shape[0] 217 | origin_batch_size = txt_batch_size - batch_size 218 | hidden_states_type = hidden_states.dtype 219 | 220 | # 0. 
time 221 | timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) 222 | timesteps = self.t_embedder(timesteps, hidden_states_type) 223 | timesteps = torch.cat([timesteps, timesteps[-origin_batch_size:]], dim=0) 224 | p_embedder = self.p_embedder(pooled_embeds) 225 | adaln_input = timesteps + p_embedder 226 | 227 | hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) 228 | if image_tokens_masks is None: 229 | pH, pW = img_sizes[0] 230 | img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) 231 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] 232 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] 233 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) 234 | hidden_states = self.x_embedder(hidden_states) 235 | 236 | # T5_encoder_hidden_states = encoder_hidden_states[0] 237 | encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0) 238 | encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] 239 | 240 | if self.caption_projection is not None: 241 | new_encoder_hidden_states = [] 242 | for i, enc_hidden_state in enumerate(encoder_hidden_states): 243 | enc_hidden_state = self.caption_projection[i](enc_hidden_state) 244 | enc_hidden_state = enc_hidden_state.view(txt_batch_size, -1, hidden_states.shape[-1]) 245 | new_encoder_hidden_states.append(enc_hidden_state) 246 | encoder_hidden_states = new_encoder_hidden_states 247 | T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) 248 | T5_encoder_hidden_states = T5_encoder_hidden_states.view(txt_batch_size, -1, hidden_states.shape[-1]) 249 | encoder_hidden_states.append(T5_encoder_hidden_states) 250 | 251 | txt_ids = torch.zeros( 252 | batch_size, 253 | encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1], 254 | 3, 255 | device=img_ids.device, dtype=img_ids.dtype 256 | ) 257 | ids = torch.cat((img_ids, txt_ids), dim=1) 258 | ids = torch.cat([ids, ids[-origin_batch_size:]], dim=0) 259 | rope = self.pe_embedder(ids) 260 | 261 | # 2. 
Blocks 262 | block_id = 0 263 | initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) 264 | initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] 265 | for bid, block in enumerate(self.double_stream_blocks): 266 | cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] 267 | cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) 268 | hidden_states, initial_encoder_hidden_states = block( 269 | image_tokens = hidden_states, 270 | image_tokens_masks = image_tokens_masks, 271 | text_tokens = cur_encoder_hidden_states, 272 | adaln_input = adaln_input, 273 | rope = rope, 274 | ) 275 | initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] 276 | block_id += 1 277 | 278 | image_tokens_seq_len = hidden_states.shape[1] 279 | hidden_states = torch.cat([hidden_states, hidden_states[-origin_batch_size:]], dim=0) 280 | hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) 281 | hidden_states_seq_len = hidden_states.shape[1] 282 | if image_tokens_masks is not None: 283 | encoder_attention_mask_ones = torch.ones( 284 | (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), 285 | device=image_tokens_masks.device, dtype=image_tokens_masks.dtype 286 | ) 287 | image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1) 288 | 289 | for bid, block in enumerate(self.single_stream_blocks): 290 | cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] 291 | hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) 292 | hidden_states = block( 293 | image_tokens=hidden_states, 294 | image_tokens_masks=image_tokens_masks, 295 | text_tokens=None, 296 | adaln_input=adaln_input, 297 | rope=rope, 298 | ) 299 | hidden_states = hidden_states[:, :hidden_states_seq_len] 300 | block_id += 1 301 | 302 | hidden_states = hidden_states[:-origin_batch_size] 303 | hidden_states = hidden_states[:, :image_tokens_seq_len, ...] 
304 | output = self.final_layer(hidden_states, adaln_input[:-origin_batch_size]) 305 | output = self.unpatchify(output, img_sizes) 306 | return -output[:, :, :h, :w] 307 | 308 | def forward( 309 | self, 310 | x: torch.Tensor, 311 | t: torch.Tensor, 312 | y: Optional[torch.Tensor] = None, 313 | context: Optional[torch.Tensor] = None, 314 | encoder_hidden_states_llama3=None, 315 | image_cond=None, 316 | control=None, 317 | transformer_options={}, 318 | 319 | nag_negative_y=None, 320 | nag_negative_context=None, 321 | nag_negative_encoder_hidden_states_llama=None, 322 | nag_sigma_end=0., 323 | ): 324 | apply_nag = check_nag_activation(transformer_options, nag_sigma_end) 325 | if apply_nag: 326 | y = torch.cat((y, nag_negative_y.to(y)), dim=0) 327 | context = cat_context(context, nag_negative_context) 328 | encoder_hidden_states_llama3 = cat_context( 329 | encoder_hidden_states_llama3, nag_negative_encoder_hidden_states_llama, 330 | trim_context=True, 331 | dim=2, 332 | ) 333 | 334 | forward_nag_ = self.forward_nag 335 | blocks_forward = list() 336 | attn_processors = list() 337 | 338 | for module in self.modules(): 339 | if isinstance(module, HiDreamImageTransformerBlock): 340 | blocks_forward.append((module, module.forward)) 341 | module.forward = MethodType(NAGHiDreamImageTransformerBlock.forward, module) 342 | elif isinstance(module, HiDreamAttention): 343 | attn_processors.append((module, module.processor)) 344 | module.processor = NAGHiDreamAttnProcessor_flashattn( 345 | nag_scale=self.nag_scale, 346 | nag_tau=self.nag_tau, 347 | nag_alpha=self.nag_alpha, 348 | encoder_hidden_states_length=context.shape[1], 349 | origin_batch_size=nag_negative_context.shape[0], 350 | ) 351 | 352 | self.forward_nag = MethodType(NAGHiDreamImageTransformer2DModel.forward_nag, self) 353 | 354 | output = self.forward_nag(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options) 355 | 356 | if apply_nag: 357 | self.forward_nag = forward_nag_ 358 | for block, forward_fn in blocks_forward: 359 | block.forward = forward_fn 360 | for module, processor in attn_processors: 361 | module.processor = processor 362 | 363 | return output 364 | 365 | 366 | class NAGHiDreamImageTransformer2DModelSwitch(NAGSwitch): 367 | def set_nag(self): 368 | self.model.nag_scale = self.nag_scale 369 | self.model.nag_tau = self.nag_tau 370 | self.model.nag_alpha = self.nag_alpha 371 | self.model.forward_nag = self.model.forward 372 | self.model.forward = MethodType( 373 | partial( 374 | NAGHiDreamImageTransformer2DModel.forward, 375 | nag_negative_context=self.nag_negative_cond[0][0], 376 | nag_negative_y=self.nag_negative_cond[0][1]["pooled_output"], 377 | nag_negative_encoder_hidden_states_llama=self.nag_negative_cond[0][1]["conditioning_llama3"], 378 | nag_sigma_end=self.nag_sigma_end, 379 | ), 380 | self.model 381 | ) 382 | 383 | def set_origin(self): 384 | super().set_origin() 385 | if hasattr(self.model, "forward_nag"): 386 | del self.model.forward_nag 387 | -------------------------------------------------------------------------------- /node.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import comfy 3 | from comfy_extras.nodes_custom_sampler import Noise_EmptyNoise, Noise_RandomNoise 4 | import latent_preview 5 | 6 | from .samplers import NAGCFGGuider as samplers_NAGCFGGuider 7 | from .sample import sample_with_nag, sample_custom_with_nag 8 | 9 | 10 | def common_ksampler_with_nag(model, seed, steps, cfg, nag_scale, nag_tau, nag_alpha, 
nag_sigma_end, sampler_name, 11 | scheduler, positive, negative, nag_negative, latent, denoise=1.0, disable_noise=False, 12 | start_step=None, last_step=None, force_full_denoise=False): 13 | latent_image = latent["samples"] 14 | latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) 15 | 16 | if disable_noise: 17 | noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") 18 | else: 19 | batch_inds = latent["batch_index"] if "batch_index" in latent else None 20 | noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) 21 | 22 | noise_mask = None 23 | if "noise_mask" in latent: 24 | noise_mask = latent["noise_mask"] 25 | 26 | callback = latent_preview.prepare_callback(model, steps) 27 | disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED 28 | samples = sample_with_nag( 29 | model, noise, steps, cfg, nag_scale, nag_tau, nag_alpha, nag_sigma_end, sampler_name, scheduler, positive, 30 | negative, nag_negative, latent_image, 31 | denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, 32 | force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, 33 | seed=seed, 34 | ) 35 | out = latent.copy() 36 | out["samples"] = samples 37 | return (out,) 38 | 39 | 40 | class NAGGuider: 41 | @classmethod 42 | def INPUT_TYPES(s): 43 | return {"required": 44 | { 45 | "model": ("MODEL",), 46 | "conditioning": ("CONDITIONING",), 47 | "nag_negative": ("CONDITIONING",), 48 | "nag_scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 49 | "nag_tau": ("FLOAT", {"default": 2.5, "min": 1.0, "max": 10.0, "step": 0.1, "round": 0.01}), 50 | "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01}), 51 | "nag_sigma_end": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 20.0, "step": 0.01, "round": 0.01}), 52 | "latent_image": ("LATENT",), 53 | } 54 | } 55 | 56 | RETURN_TYPES = ("GUIDER",) 57 | 58 | FUNCTION = "get_guider" 59 | CATEGORY = "sampling/custom_sampling/guiders" 60 | 61 | def get_guider( 62 | self, 63 | model, 64 | conditioning, 65 | nag_negative, 66 | nag_scale, 67 | nag_tau, 68 | nag_alpha, 69 | nag_sigma_end, 70 | latent_image, 71 | ): 72 | batch_size = latent_image["samples"].shape[0] 73 | guider = samplers_NAGCFGGuider(model) 74 | guider.set_conds(conditioning) 75 | guider.set_batch_size(batch_size) 76 | guider.set_nag(nag_negative, nag_scale, nag_tau, nag_alpha, nag_sigma_end) 77 | return (guider,) 78 | 79 | 80 | class NAGCFGGuider: 81 | @classmethod 82 | def INPUT_TYPES(s): 83 | return {"required": 84 | { 85 | "model": ("MODEL",), 86 | "positive": ("CONDITIONING",), 87 | "negative": ("CONDITIONING",), 88 | "nag_negative": ("CONDITIONING",), 89 | "cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 90 | "nag_scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 91 | "nag_tau": ("FLOAT", {"default": 2.5, "min": 1.0, "max": 10.0, "step": 0.1, "round": 0.01}), 92 | "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01}), 93 | "nag_sigma_end": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 20.0, "step": 0.01, "round": 0.01}), 94 | "latent_image": ("LATENT",), 95 | } 96 | } 97 | 98 | RETURN_TYPES = ("GUIDER",) 99 | 100 | FUNCTION = "get_guider" 101 | CATEGORY = "sampling/custom_sampling/guiders" 102 | 103 | def get_guider( 104 | self, 105 | model, 106 | positive, 107 | 
negative, 108 | nag_negative, 109 | cfg, 110 | nag_scale, 111 | nag_tau, 112 | nag_alpha, 113 | nag_sigma_end, 114 | latent_image, 115 | ): 116 | batch_size = latent_image["samples"].shape[0] 117 | guider = samplers_NAGCFGGuider(model) 118 | guider.set_conds(positive, negative) 119 | guider.set_cfg(cfg) 120 | guider.set_batch_size(batch_size) 121 | guider.set_nag(nag_negative, nag_scale, nag_tau, nag_alpha, nag_sigma_end) 122 | return (guider,) 123 | 124 | 125 | class KSamplerWithNAG: 126 | @classmethod 127 | def INPUT_TYPES(s): 128 | return { 129 | "required": { 130 | "model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}), 131 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, 132 | "tooltip": "The random seed used for creating the noise."}), 133 | "steps": ("INT", {"default": 20, "min": 1, "max": 10000, 134 | "tooltip": "The number of steps used in the denoising process."}), 135 | "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01, 136 | "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt; however, values that are too high will negatively impact quality."}), 137 | "nag_scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 138 | "nag_tau": ("FLOAT", {"default": 2.5, "min": 1.0, "max": 10.0, "step": 0.1, "round": 0.01}), 139 | "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01}), 140 | "nag_sigma_end": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 20.0, "step": 0.01, "round": 0.01}), 141 | "sampler_name": (comfy.samplers.KSampler.SAMPLERS, { 142 | "tooltip": "The algorithm used when sampling; this can affect the quality, speed, and style of the generated output."}), 143 | "scheduler": (comfy.samplers.KSampler.SCHEDULERS, 144 | {"tooltip": "The scheduler controls how noise is gradually removed to form the image."}), 145 | "positive": ("CONDITIONING", { 146 | "tooltip": "The conditioning describing the attributes you want to include in the image."}), 147 | "negative": ("CONDITIONING", { 148 | "tooltip": "The conditioning describing the attributes you want to exclude from the image."}), 149 | "nag_negative": ("CONDITIONING", { 150 | "tooltip": "The conditioning describing the attributes you want to exclude from the image for NAG."}), 151 | "latent_image": ("LATENT", {"tooltip": "The latent image to denoise."}), 152 | "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, 153 | "tooltip": "The amount of denoising applied; lower values maintain the structure of the initial image, allowing for image-to-image sampling."}), 154 | } 155 | } 156 | 157 | RETURN_TYPES = ("LATENT",) 158 | OUTPUT_TOOLTIPS = ("The denoised latent.",) 159 | FUNCTION = "sample" 160 | 161 | CATEGORY = "sampling" 162 | DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image."
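# A quick reference for the NAG widgets above, as they are consumed elsewhere
# in this repo (utils.nag, utils.check_nag_activation, samplers.NAGCFGGuider):
#   nag_scale     - extrapolation strength between the positive and NAG-negative
#                   attention features; NAG is only applied when nag_scale > 1
#                   (see NAGCFGGuider.sample in samplers.py).
#   nag_tau       - cap on the L1-norm ratio between the guided and positive
#                   features before renormalization (utils.nag).
#   nag_alpha     - blend factor mixing the normalized guided features back into
#                   the positive features (utils.nag).
#   nag_sigma_end - NAG stays active while the current sigma is >= this value,
#                   so 0.0 keeps it on for the whole schedule
#                   (utils.check_nag_activation).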
163 | 164 | def sample(self, model, seed, steps, cfg, nag_scale, nag_tau, nag_alpha, nag_sigma_end, sampler_name, scheduler, 165 | positive, negative, nag_negative, latent_image, denoise=1.0): 166 | return common_ksampler_with_nag(model, seed, steps, cfg, nag_scale, nag_tau, nag_alpha, nag_sigma_end, 167 | sampler_name, scheduler, positive, negative, nag_negative, latent_image, 168 | denoise=denoise) 169 | 170 | 171 | class KSamplerAdvancedWithNAG: 172 | @classmethod 173 | def INPUT_TYPES(s): 174 | return { 175 | "required": { 176 | "model": ("MODEL",), 177 | "add_noise": (["enable", "disable"],), 178 | "noise_seed": ( 179 | "INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}), 180 | "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), 181 | "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 182 | "nag_scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 183 | "nag_tau": ("FLOAT", {"default": 2.5, "min": 1.0, "max": 10.0, "step": 0.1, "round": 0.01}), 184 | "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01}), 185 | "nag_sigma_end": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 20.0, "step": 0.01, "round": 0.01}), 186 | "sampler_name": (comfy.samplers.KSampler.SAMPLERS,), 187 | "scheduler": (comfy.samplers.KSampler.SCHEDULERS,), 188 | "positive": ("CONDITIONING",), 189 | "negative": ("CONDITIONING",), 190 | "nag_negative": ("CONDITIONING", { 191 | "tooltip": "The conditioning describing the attributes you want to exclude from the image for NAG."}), 192 | "latent_image": ("LATENT",), 193 | "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), 194 | "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), 195 | "return_with_leftover_noise": (["disable", "enable"],), 196 | } 197 | } 198 | 199 | RETURN_TYPES = ("LATENT",) 200 | FUNCTION = "sample" 201 | 202 | CATEGORY = "sampling" 203 | 204 | def sample( 205 | self, model, add_noise, noise_seed, steps, cfg, nag_scale, nag_tau, nag_alpha, nag_sigma_end, 206 | sampler_name, scheduler, positive, negative, nag_negative, 207 | latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, 208 | ): 209 | force_full_denoise = True 210 | if return_with_leftover_noise == "enable": 211 | force_full_denoise = False 212 | disable_noise = False 213 | if add_noise == "disable": 214 | disable_noise = True 215 | return common_ksampler_with_nag( 216 | model, noise_seed, steps, cfg, nag_scale, nag_tau, nag_alpha, nag_sigma_end, 217 | sampler_name, scheduler, positive, negative, nag_negative, 218 | latent_image, 219 | denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, 220 | last_step=end_at_step, force_full_denoise=force_full_denoise, 221 | ) 222 | 223 | 224 | class SamplerCustomWithNAG: 225 | @classmethod 226 | def INPUT_TYPES(s): 227 | return {"required": { 228 | "model": ("MODEL",), 229 | "add_noise": ("BOOLEAN", {"default": True}), 230 | "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}), 231 | "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 232 | "nag_scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 233 | "nag_tau": ("FLOAT", {"default": 2.5, "min": 1.0, "max": 10.0, "step": 0.1, "round": 0.01}), 234 | "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01}), 
235 | "nag_sigma_end": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 20.0, "step": 0.01, "round": 0.01}), 236 | "positive": ("CONDITIONING",), 237 | "negative": ("CONDITIONING",), 238 | "nag_negative": ("CONDITIONING", { 239 | "tooltip": "The conditioning describing the attributes you want to exclude from the image for NAG."}), 240 | "sampler": ("SAMPLER",), 241 | "sigmas": ("SIGMAS",), 242 | "latent_image": ("LATENT",), 243 | }} 244 | 245 | RETURN_TYPES = ("LATENT", "LATENT") 246 | RETURN_NAMES = ("output", "denoised_output") 247 | 248 | FUNCTION = "sample" 249 | 250 | CATEGORY = "sampling/custom_sampling" 251 | 252 | def sample( 253 | self, 254 | model, add_noise, noise_seed, cfg, nag_scale, nag_tau, nag_alpha, nag_sigma_end, 255 | positive, negative, nag_negative, 256 | sampler, sigmas, latent_image, 257 | ): 258 | latent = latent_image 259 | latent_image = latent["samples"] 260 | latent = latent.copy() 261 | latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) 262 | latent["samples"] = latent_image 263 | 264 | if not add_noise: 265 | noise = Noise_EmptyNoise().generate_noise(latent) 266 | else: 267 | noise = Noise_RandomNoise(noise_seed).generate_noise(latent) 268 | 269 | noise_mask = None 270 | if "noise_mask" in latent: 271 | noise_mask = latent["noise_mask"] 272 | 273 | x0_output = {} 274 | callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) 275 | 276 | disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED 277 | samples = sample_custom_with_nag( 278 | model, noise, cfg, nag_scale, nag_tau, nag_alpha, nag_sigma_end, 279 | sampler, sigmas, positive, negative, nag_negative, 280 | latent_image, 281 | noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed, 282 | ) 283 | 284 | out = latent.copy() 285 | out["samples"] = samples 286 | if "x0" in x0_output: 287 | out_denoised = latent.copy() 288 | out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) 289 | else: 290 | out_denoised = out 291 | return (out, out_denoised) 292 | 293 | 294 | NODE_CLASS_MAPPINGS = { 295 | "NAGGuider": NAGGuider, 296 | "NAGCFGGuider": NAGCFGGuider, 297 | "KSamplerWithNAG": KSamplerWithNAG, 298 | "KSamplerWithNAG (Advanced)": KSamplerAdvancedWithNAG, 299 | "SamplerCustomWithNAG": SamplerCustomWithNAG, 300 | } 301 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import comfy 2 | from .samplers import KSamplerWithNAG 3 | from .samplers import sample_with_nag as samplers_sample_with_nag 4 | 5 | 6 | def sample_with_nag( 7 | model, noise, steps, cfg, nag_scale, nag_tau, nag_alpha, nag_sigma_end, sampler_name, scheduler, positive, negative, nag_negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): 8 | sampler = KSamplerWithNAG(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) 9 | 10 | samples = sampler.sample( 11 | noise, positive, negative, nag_negative, 12 | cfg=cfg, nag_scale=nag_scale, nag_tau=nag_tau, nag_alpha=nag_alpha, nag_sigma_end=nag_sigma_end, 13 | latent_image=latent_image, 14 | start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, 15 | denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, 
seed=seed, 16 | ) 17 | samples = samples.to(comfy.model_management.intermediate_device()) 18 | return samples 19 | 20 | 21 | def sample_custom_with_nag( 22 | model, noise, cfg, nag_scale, nag_tau, nag_alpha, nag_sigma_end, sampler, sigmas, positive, negative, nag_negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): 23 | samples = samplers_sample_with_nag( 24 | model, noise, positive, negative, nag_negative, 25 | cfg, nag_scale, nag_tau, nag_alpha, nag_sigma_end, 26 | model.load_device, sampler, sigmas, 27 | model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, 28 | callback=callback, disable_pbar=disable_pbar, seed=seed, 29 | ) 30 | samples = samples.to(comfy.model_management.intermediate_device()) 31 | return samples 32 | -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | from typing import TYPE_CHECKING 5 | import math 6 | 7 | if TYPE_CHECKING: 8 | from comfy.model_patcher import ModelPatcher 9 | import torch 10 | from torch._dynamo.eval_frame import OptimizedModule 11 | import torch._dynamo 12 | 13 | torch._dynamo.config.suppress_errors = True 14 | 15 | from comfy.samplers import ( 16 | process_conds, 17 | preprocess_conds_hooks, 18 | cast_to_load_options, 19 | filter_registered_hooks_on_conds, 20 | get_total_hook_groups_in_conds, 21 | CFGGuider, 22 | sampler_object, 23 | KSampler, 24 | ) 25 | import comfy.sampler_helpers 26 | import comfy.model_patcher 27 | import comfy.patcher_extension 28 | import comfy.hooks 29 | from comfy.ldm.flux.model import Flux 30 | from comfy.ldm.chroma.model import Chroma 31 | from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel 32 | from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper 33 | from comfy.ldm.wan.model import WanModel, VaceWanModel 34 | from comfy.ldm.hunyuan_video.model import HunyuanVideo 35 | from comfy.ldm.hidream.model import HiDreamImageTransformer2DModel 36 | 37 | from .flux.model import NAGFluxSwitch 38 | from .chroma.model import NAGChromaSwitch 39 | from .sd.openaimodel import NAGUNetModelSwitch 40 | from .sd3.mmdit import NAGOpenAISignatureMMDITWrapperSwitch 41 | from .wan.model import NAGWanModelSwitch 42 | from .hunyuan_video.model import NAGHunyuanVideoSwitch 43 | from .hidream.model import NAGHiDreamImageTransformer2DModelSwitch 44 | 45 | 46 | def sample_with_nag( 47 | model, 48 | noise, 49 | positive, negative, nag_negative, 50 | cfg, 51 | nag_scale, nag_tau, nag_alpha, nag_sigma_end, 52 | device, 53 | sampler, 54 | sigmas, 55 | model_options={}, 56 | latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None, 57 | ): 58 | guider = NAGCFGGuider(model) 59 | guider.set_conds(positive, negative) 60 | guider.set_cfg(cfg) 61 | guider.set_batch_size(latent_image.shape[0]) 62 | guider.set_nag(nag_negative, nag_scale, nag_tau, nag_alpha, nag_sigma_end) 63 | return guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) 64 | 65 | 66 | class NAGCFGGuider(CFGGuider): 67 | def __init__(self, model_patcher: ModelPatcher): 68 | super().__init__(model_patcher=model_patcher) 69 | self.origin_nag_negative_cond = None 70 | self.nag_scale = 5.0 71 | self.nag_tau = 3.5 72 | self.nag_alpha = 0.25 73 | self.nag_sigma_end = 0. 
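# Note: the attribute values below are placeholder defaults; the nodes in
# node.py always configure this guider through set_conds()/set_batch_size()/
# set_nag() (plus set_cfg() for the CFG variant) before sampling, so these
# values -- including the nag_tau of 3.5 versus the 2.5 node default -- are
# overwritten in practice.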
74 | self.batch_size = 1 75 | 76 | def set_conds(self, positive, negative=None): 77 | self.inner_set_conds( 78 | {"positive": positive, "negative": negative} if negative is not None else {"positive": positive}) 79 | 80 | def set_batch_size(self, batch_size): 81 | self.batch_size = batch_size 82 | 83 | def set_nag(self, nag_negative_cond, nag_scale, nag_tau, nag_alpha, nag_sigma_end): 84 | self.origin_nag_negative_cond = nag_negative_cond 85 | self.nag_scale = nag_scale 86 | self.nag_tau = nag_tau 87 | self.nag_alpha = nag_alpha 88 | self.nag_sigma_end = nag_sigma_end 89 | 90 | def __call__(self, *args, **kwargs): 91 | return self.predict_noise(*args, **kwargs) 92 | 93 | def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): 94 | if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. 95 | latent_image = self.inner_model.process_latent_in(latent_image) 96 | 97 | self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) 98 | 99 | extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) 100 | extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas 101 | extra_args = {"model_options": extra_model_options, "seed": seed} 102 | 103 | executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( 104 | sampler.sample, 105 | sampler, 106 | comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True) 107 | ) 108 | samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) 109 | return self.inner_model.process_latent_out(samples.to(torch.float32)) 110 | 111 | def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): 112 | if sigmas.shape[-1] == 0: 113 | return latent_image 114 | 115 | self.conds = {} 116 | for k in self.original_conds: 117 | self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) 118 | preprocess_conds_hooks(self.conds) 119 | 120 | apply_guidance = self.nag_scale > 1. 
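# At nag_scale == 1 the extrapolation in utils.nag,
#     z_guidance = z_positive * scale - z_negative * (scale - 1),
# collapses to z_positive, so scales at or below 1 are treated as "NAG
# disabled" and the attention patching below is skipped entirely.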
121 | 122 | self.nag_negative_cond = None 123 | if apply_guidance: 124 | self.nag_negative_cond = copy.deepcopy(self.origin_nag_negative_cond) 125 | 126 | model = self.model_patcher.model.diffusion_model 127 | if isinstance(model, OptimizedModule): 128 | model = model._orig_mod 129 | model_type = type(model) 130 | if model_type == Flux: 131 | switcher_cls = NAGFluxSwitch 132 | elif model_type == Chroma: 133 | switcher_cls = NAGChromaSwitch 134 | elif model_type == UNetModel: 135 | switcher_cls = NAGUNetModelSwitch 136 | elif model_type == OpenAISignatureMMDITWrapper: 137 | switcher_cls = NAGOpenAISignatureMMDITWrapperSwitch 138 | elif model_type in [WanModel, VaceWanModel]: 139 | switcher_cls = NAGWanModelSwitch 140 | elif model_type == HunyuanVideo: 141 | switcher_cls = NAGHunyuanVideoSwitch 142 | elif model_type == HiDreamImageTransformer2DModel: 143 | switcher_cls = NAGHiDreamImageTransformer2DModelSwitch 144 | else: 145 | raise ValueError( 146 | f"Model type {model_type} is not supported for NAGCFGGuider" 147 | ) 148 | self.nag_negative_cond[0][0] = self.nag_negative_cond[0][0].expand(self.batch_size, -1, -1) 149 | if self.nag_negative_cond[0][1].get("pooled_output", None) is not None: 150 | self.nag_negative_cond[0][1]["pooled_output"] = self.nag_negative_cond[0][1]["pooled_output"].expand(self.batch_size, -1) 151 | switcher = switcher_cls( 152 | model, 153 | self.nag_negative_cond, 154 | self.nag_scale, self.nag_tau, self.nag_alpha, self.nag_sigma_end, 155 | ) 156 | switcher.set_nag() 157 | 158 | try: 159 | orig_model_options = self.model_options 160 | self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options) 161 | # if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step) 162 | orig_hook_mode = self.model_patcher.hook_mode 163 | if get_total_hook_groups_in_conds(self.conds) <= 1: 164 | self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram 165 | comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options) 166 | filter_registered_hooks_on_conds(self.conds, self.model_options) 167 | executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( 168 | self.outer_sample, 169 | self, 170 | comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) 171 | ) 172 | output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) 173 | finally: 174 | cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) 175 | self.model_options = orig_model_options 176 | self.model_patcher.hook_mode = orig_hook_mode 177 | self.model_patcher.restore_hook_patches() 178 | 179 | if apply_guidance: 180 | switcher.set_origin() 181 | 182 | del self.conds 183 | del self.nag_negative_cond 184 | return output 185 | 186 | 187 | class KSamplerWithNAG(KSampler): 188 | def sample( 189 | self, 190 | noise, 191 | positive, negative, nag_negative, 192 | cfg, 193 | nag_scale, nag_tau, nag_alpha, nag_sigma_end, 194 | latent_image=None, 195 | start_step=None, last_step=None, force_full_denoise=False, 196 | denoise_mask=None, 197 | sigmas=None, callback=None, disable_pbar=False, seed=None, 198 | ): 199 | if sigmas is None: 200 | sigmas = self.sigmas 201 | 202 | if last_step is not None and last_step < (len(sigmas) - 1): 203 | sigmas = sigmas[:last_step + 1] 204 | if force_full_denoise: 205 | sigmas[-1] = 0 206 | 207 | if start_step is not None: 208 |
if start_step < (len(sigmas) - 1): 209 | sigmas = sigmas[start_step:] 210 | else: 211 | if latent_image is not None: 212 | return latent_image 213 | else: 214 | return torch.zeros_like(noise) 215 | 216 | sampler = sampler_object(self.sampler) 217 | 218 | return sample_with_nag( 219 | self.model, 220 | noise, 221 | positive, negative, nag_negative, 222 | cfg, 223 | nag_scale, nag_tau, nag_alpha, nag_sigma_end, 224 | self.device, 225 | sampler, 226 | sigmas, 227 | self.model_options, 228 | latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed, 229 | ) 230 | 231 | -------------------------------------------------------------------------------- /sd/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from comfy.ldm.modules.attention import CrossAttention, default, optimized_attention, optimized_attention_masked 3 | from ..utils import nag 4 | 5 | 6 | class NAGCrossAttention(CrossAttention): 7 | def __init__( 8 | self, 9 | *args, 10 | nag_scale: float = 1, 11 | nag_tau: float = 2.5, 12 | nag_alpha: float = 0.25, 13 | **kwargs, 14 | ): 15 | super().__init__(*args, **kwargs) 16 | self.nag_scale = nag_scale 17 | self.nag_tau = nag_tau 18 | self.nag_alpha = nag_alpha 19 | 20 | def forward( 21 | self, 22 | x, 23 | context=None, 24 | value=None, 25 | mask=None, 26 | ): 27 | origin_bsz = len(context) - len(x) 28 | assert origin_bsz != 0 29 | 30 | q = self.to_q(x) 31 | q = torch.cat([q, q[-origin_bsz:]], dim=0) 32 | 33 | context = default(context, x) 34 | k = self.to_k(context) 35 | if value is not None: 36 | v = self.to_v(value) 37 | del value 38 | else: 39 | v = self.to_v(context) 40 | 41 | if mask is None: 42 | out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) 43 | else: 44 | out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) 45 | 46 | # NAG 47 | out_negative, out_positive = out[-origin_bsz:], out[-origin_bsz * 2:-origin_bsz] 48 | out_guidance = nag(out_positive, out_negative, self.nag_scale, self.nag_tau, self.nag_alpha) 49 | out = torch.cat([out[:-origin_bsz * 2], out_guidance], dim=0) 50 | 51 | return self.to_out(out) 52 | -------------------------------------------------------------------------------- /sd/openaimodel.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from types import MethodType 3 | 4 | import torch 5 | import comfy 6 | from comfy.ldm.modules.attention import CrossAttention 7 | from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel 8 | 9 | from .attention import NAGCrossAttention 10 | from ..utils import cat_context, check_nag_activation, NAGSwitch 11 | 12 | 13 | class NAGUNetModel(UNetModel): 14 | def forward( 15 | self, 16 | x, 17 | timesteps=None, 18 | context=None, 19 | y=None, 20 | control=None, 21 | transformer_options={}, 22 | 23 | nag_negative_context=None, 24 | nag_sigma_end=0., 25 | 26 | **kwargs, 27 | ): 28 | apply_nag = check_nag_activation(transformer_options, nag_sigma_end) 29 | if apply_nag: 30 | context = cat_context(context, nag_negative_context) 31 | cross_attns_forward = list() 32 | for name, module in self.named_modules(): 33 | if "attn2" in name and isinstance(module, CrossAttention): 34 | cross_attns_forward.append((module, module.forward)) 35 | module.forward = MethodType(NAGCrossAttention.forward, module) 36 | 37 | output = 
comfy.patcher_extension.WrapperExecutor.new_class_executor( 38 | self._forward, 39 | self, 40 | comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, 41 | transformer_options) 42 | ).execute(x, timesteps, context, y, control, transformer_options, **kwargs) 43 | 44 | if apply_nag: 45 | for mod, forward_fn in cross_attns_forward: 46 | mod.forward = forward_fn 47 | 48 | return output 49 | 50 | 51 | class NAGUNetModelSwitch(NAGSwitch): 52 | def set_nag(self): 53 | self.model.forward = MethodType( 54 | partial( 55 | NAGUNetModel.forward, 56 | nag_negative_context=self.nag_negative_cond[0][0], 57 | nag_sigma_end=self.nag_sigma_end, 58 | ), 59 | self.model 60 | ) 61 | for name, module in self.model.named_modules(): 62 | if "attn2" in name and isinstance(module, CrossAttention): 63 | module.nag_scale = self.nag_scale 64 | module.nag_tau = self.nag_tau 65 | module.nag_alpha = self.nag_alpha 66 | -------------------------------------------------------------------------------- /sd3/mmdit.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from types import MethodType 3 | from functools import partial 4 | from typing import Callable 5 | 6 | import torch 7 | from einops import repeat 8 | import comfy 9 | from comfy.ldm.modules.diffusionmodules.mmdit import ( 10 | OpenAISignatureMMDITWrapper, 11 | JointBlock, 12 | optimized_attention, 13 | default, 14 | ) 15 | 16 | from ..utils import nag, cat_context, check_nag_activation, NAGSwitch 17 | 18 | 19 | def _nag_block_mixing( 20 | context, 21 | x, 22 | context_block, 23 | x_block, 24 | c, 25 | nag_scale: float = 1.0, 26 | nag_tau: float = 2.5, 27 | nag_alpha: float = 0.5, 28 | ): 29 | origin_bsz = len(context) - len(x) 30 | assert origin_bsz != 0 31 | 32 | context_qkv, context_intermediates = context_block.pre_attention(context, c) 33 | 34 | if x_block.x_block_self_attn: 35 | x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c[:-origin_bsz]) 36 | else: 37 | x_qkv, x_intermediates = x_block.pre_attention(x, c[:-origin_bsz]) 38 | 39 | o = [] 40 | for t in range(3): 41 | o.append(torch.cat(( 42 | context_qkv[t], 43 | torch.cat([x_qkv[t], x_qkv[t][-origin_bsz:]], dim=0), 44 | ),dim=1)) 45 | qkv = tuple(o) 46 | 47 | attn = optimized_attention( 48 | qkv[0], qkv[1], qkv[2], 49 | heads=x_block.attn.num_heads, 50 | ) 51 | context_attn, x_attn = ( 52 | attn[:, : context_qkv[0].shape[1]], 53 | attn[:, context_qkv[0].shape[1] :], 54 | ) 55 | 56 | # NAG 57 | x_attn_negative, x_attn_positive = x_attn[-origin_bsz:], x_attn[-origin_bsz * 2:-origin_bsz] 58 | x_attn_guidance = nag(x_attn_positive, x_attn_negative, nag_scale, nag_tau, nag_alpha) 59 | 60 | x_attn = torch.cat([x_attn[:-origin_bsz * 2], x_attn_guidance], dim=0) 61 | 62 | if not context_block.pre_only: 63 | context = context_block.post_attention(context_attn, *context_intermediates) 64 | 65 | else: 66 | context = None 67 | if x_block.x_block_self_attn: 68 | attn2 = optimized_attention( 69 | x_qkv2[0], x_qkv2[1], x_qkv2[2], 70 | heads=x_block.attn2.num_heads, 71 | ) 72 | x = x_block.post_attention_x(x_attn, attn2, *x_intermediates) 73 | else: 74 | x = x_block.post_attention(x_attn, *x_intermediates) 75 | return context, x 76 | 77 | 78 | def nag_block_mixing(*args, use_checkpoint=True, **kwargs): 79 | if use_checkpoint: 80 | return torch.utils.checkpoint.checkpoint( 81 | _nag_block_mixing, *args, use_reentrant=False, **kwargs 82 | ) 83 | else: 84 | return _nag_block_mixing(*args, **kwargs) 85 | 86 | 
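# Batch-layout convention shared with sd/attention.py: utils.cat_context
# appends the NAG-negative conditioning along the batch dimension of `context`
# only, so
#     origin_bsz = len(context) - len(x)
# recovers the original batch size. The trailing origin_bsz image queries are
# duplicated so they also attend to the NAG-negative context, attention runs
# once over the enlarged batch, and the result is split as
#     negative = attn[-origin_bsz:]
#     positive = attn[-origin_bsz * 2:-origin_bsz]
# before utils.nag recombines them into the guided features.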
87 | class NAGJointBlock(JointBlock): 88 | def forward(self, *args, **kwargs): 89 | return nag_block_mixing( 90 | *args, context_block=self.context_block, x_block=self.x_block, **kwargs 91 | ) 92 | 93 | 94 | class NAGOpenAISignatureMMDITWrapper(OpenAISignatureMMDITWrapper): 95 | def __init__( 96 | self, 97 | *args, 98 | nag_scale: float = 1, 99 | nag_tau: float = 2.5, 100 | nag_alpha: float = 0.25, 101 | **kwargs, 102 | ): 103 | super().__init__(*args, **kwargs) 104 | self.nag_scale = nag_scale 105 | self.nag_tau = nag_tau 106 | self.nag_alpha = nag_alpha 107 | 108 | def forward_core_with_concat( 109 | self, 110 | x: torch.Tensor, 111 | c_mod: torch.Tensor, 112 | context: Optional[torch.Tensor] = None, 113 | control = None, 114 | transformer_options = {}, 115 | ) -> torch.Tensor: 116 | patches_replace = transformer_options.get("patches_replace", {}) 117 | if self.register_length > 0: 118 | context = torch.cat( 119 | ( 120 | repeat(self.register, "1 ... -> b ...", b=x.shape[0]), 121 | default(context, torch.Tensor([]).type_as(x)), 122 | ), 123 | 1, 124 | ) 125 | 126 | # context is B, L', D 127 | # x is B, L, D 128 | blocks_replace = patches_replace.get("dit", {}) 129 | blocks = len(self.joint_blocks) 130 | for i in range(blocks): 131 | if ("double_block", i) in blocks_replace: 132 | def block_wrap(args): 133 | out = {} 134 | out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"]) 135 | return out 136 | 137 | out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap}) 138 | context = out["txt"] 139 | x = out["img"] 140 | else: 141 | context, x = self.joint_blocks[i]( 142 | context, 143 | x, 144 | c=c_mod, 145 | use_checkpoint=self.use_checkpoint, 146 | ) 147 | if control is not None: 148 | control_o = control.get("output") 149 | if i < len(control_o): 150 | add = control_o[i] 151 | if add is not None: 152 | x += add 153 | 154 | x = self.final_layer(x, c_mod[:len(x)]) # (N, T, patch_size ** 2 * out_channels) 155 | return x 156 | 157 | def forward_core_with_concat_with_wavespeed( 158 | self, 159 | x: torch.Tensor, 160 | c_mod: torch.Tensor, 161 | context: Optional[torch.Tensor] = None, 162 | control = None, 163 | transformer_options = {}, 164 | use_cache: Callable = None, 165 | apply_prev_hidden_states_residual: Callable = None, 166 | set_buffer: Callable = None, 167 | ) -> torch.Tensor: 168 | patches_replace = transformer_options.get("patches_replace", {}) 169 | if self.register_length > 0: 170 | context = torch.cat( 171 | ( 172 | repeat(self.register, "1 ... 
-> b ...", b=x.shape[0]), 173 | default(context, torch.Tensor([]).type_as(x)), 174 | ), 175 | 1, 176 | ) 177 | 178 | # context is B, L', D 179 | # x is B, L, D 180 | blocks_replace = patches_replace.get("dit", {}) 181 | joint_blocks = self.joint_blocks[0].transformer_blocks 182 | blocks = len(joint_blocks) 183 | 184 | original_x = x 185 | can_use_cache = False 186 | 187 | for i in range(blocks): 188 | if i == 1: 189 | torch._dynamo.graph_break() 190 | if can_use_cache: 191 | del first_x_residual 192 | x = apply_prev_hidden_states_residual(x) 193 | break 194 | else: 195 | set_buffer("first_hidden_states_residual", first_x_residual) 196 | del first_x_residual 197 | 198 | original_x = x 199 | 200 | if ("double_block", i) in blocks_replace: 201 | def block_wrap(args): 202 | out = {} 203 | out["txt"], out["img"] = joint_blocks[i](args["txt"], args["img"], c=args["vec"]) 204 | return out 205 | 206 | out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap}) 207 | context = out["txt"] 208 | x = out["img"] 209 | else: 210 | context, x = joint_blocks[i]( 211 | context, 212 | x, 213 | c=c_mod, 214 | use_checkpoint=self.use_checkpoint, 215 | ) 216 | if control is not None: 217 | control_o = control.get("output") 218 | if i < len(control_o): 219 | add = control_o[i] 220 | if add is not None: 221 | x += add 222 | 223 | if i == 0: 224 | first_x_residual = x - original_x 225 | can_use_cache = use_cache(first_x_residual) 226 | del original_x 227 | 228 | if not can_use_cache: 229 | x = x.contiguous() 230 | x_residual = x - original_x 231 | set_buffer("hidden_states_residual", x_residual) 232 | torch._dynamo.graph_break() 233 | 234 | x = self.final_layer(x, c_mod[:len(x)]) # (N, T, patch_size ** 2 * out_channels) 235 | return x 236 | 237 | def forward( 238 | self, 239 | x: torch.Tensor, 240 | timesteps: torch.Tensor, 241 | context: Optional[torch.Tensor] = None, 242 | y: Optional[torch.Tensor] = None, 243 | control=None, 244 | transformer_options={}, 245 | 246 | nag_negative_context=None, 247 | nag_negative_y=None, 248 | nag_sigma_end=0., 249 | 250 | **kwargs, 251 | ) -> torch.Tensor: 252 | apply_nag = check_nag_activation(transformer_options, nag_sigma_end) 253 | if apply_nag: 254 | context = cat_context(context, nag_negative_context) 255 | y = torch.cat((y, nag_negative_y.to(y)), dim=0) 256 | 257 | forward_core_with_concat_ = self.forward_core_with_concat 258 | joint_blocks_forward = list() 259 | 260 | joint_blocks = self.joint_blocks 261 | is_wavespeed = "CachedTransformerBlocks" in type(joint_blocks[0]).__name__ 262 | if is_wavespeed: # chengzeyi/Comfy-WaveSpeed 263 | cached_blocks = self.joint_blocks[0] 264 | joint_blocks = cached_blocks.transformer_blocks 265 | 266 | if is_wavespeed: 267 | get_can_use_cache = cached_blocks.forward.__globals__["get_can_use_cache"] 268 | set_buffer = cached_blocks.forward.__globals__["set_buffer"] 269 | apply_prev_hidden_states_residual = cached_blocks.forward.__globals__["apply_prev_hidden_states_residual"] 270 | 271 | def use_cache(first_hidden_states_residual): 272 | return get_can_use_cache( 273 | first_hidden_states_residual, 274 | threshold=cached_blocks.residual_diff_threshold, 275 | validation_function=cached_blocks.validate_can_use_cache_function, 276 | ) 277 | 278 | self.forward_core_with_concat = MethodType( 279 | partial( 280 | NAGOpenAISignatureMMDITWrapper.forward_core_with_concat_with_wavespeed, 281 | use_cache=use_cache, 282 | apply_prev_hidden_states_residual=apply_prev_hidden_states_residual, 283 
| set_buffer=set_buffer, 284 | ), 285 | self, 286 | ) 287 | 288 | else: 289 | self.forward_core_with_concat = MethodType(NAGOpenAISignatureMMDITWrapper.forward_core_with_concat, self) 290 | 291 | for block in joint_blocks: 292 | joint_blocks_forward.append(block.forward) 293 | block.forward = MethodType( 294 | partial( 295 | NAGJointBlock.forward, 296 | nag_scale=self.nag_scale, 297 | nag_tau=self.nag_tau, 298 | nag_alpha=self.nag_alpha, 299 | ), 300 | block, 301 | ) 302 | 303 | if self.context_processor is not None: 304 | context = self.context_processor(context) 305 | 306 | hw = x.shape[-2:] 307 | x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x) 308 | c = self.t_embedder(timesteps, dtype=x.dtype) # (N, D) 309 | 310 | if apply_nag: 311 | origin_bsz = len(context) - len(x) 312 | c = torch.cat((c, c[-origin_bsz:]), dim=0) 313 | 314 | if y is not None and self.y_embedder is not None: 315 | y = self.y_embedder(y) # (N, D) 316 | c = c + y # (N, D) 317 | 318 | if context is not None: 319 | context = self.context_embedder(context) 320 | 321 | x = self.forward_core_with_concat(x, c, context, control, transformer_options) 322 | 323 | if apply_nag: 324 | self.forward_core_with_concat = forward_core_with_concat_ 325 | for block in joint_blocks: 326 | block.forward = joint_blocks_forward.pop(0) 327 | 328 | x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W) 329 | return x[:, :, :hw[-2], :hw[-1]] 330 | 331 | 332 | class NAGOpenAISignatureMMDITWrapperSwitch(NAGSwitch): 333 | def set_nag(self): 334 | self.model.nag_scale = self.nag_scale 335 | self.model.nag_tau = self.nag_tau 336 | self.model.nag_alpha = self.nag_alpha 337 | self.model.forward = MethodType( 338 | partial( 339 | NAGOpenAISignatureMMDITWrapper.forward, 340 | nag_negative_context=self.nag_negative_cond[0][0], 341 | nag_negative_y=self.nag_negative_cond[0][1]["pooled_output"], 342 | nag_sigma_end=self.nag_sigma_end, 343 | ), 344 | self.model 345 | ) 346 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def nag(z_positive, z_negative, scale, tau, alpha): 6 | z_guidance = z_positive * scale - z_negative * (scale - 1) 7 | norm_positive = torch.norm(z_positive, p=1, dim=-1, keepdim=True).expand(*z_positive.shape) 8 | norm_guidance = torch.norm(z_guidance, p=1, dim=-1, keepdim=True).expand(*z_guidance.shape) 9 | 10 | scale = norm_guidance / norm_positive 11 | z_guidance = z_guidance * torch.minimum(scale, scale.new_ones(1) * tau) / scale 12 | 13 | z_guidance = z_guidance * alpha + z_positive * (1 - alpha) 14 | 15 | return z_guidance 16 | 17 | 18 | def cat_context(context, nag_negative_context, trim_context=False, dim=1): 19 | assert dim in [1, 2] 20 | nag_negative_context = nag_negative_context.to(context) 21 | 22 | context_len = context.shape[dim] 23 | nag_neg_context_len = nag_negative_context.shape[dim] 24 | 25 | if context_len < nag_neg_context_len: 26 | if dim == 1: 27 | context = context.repeat(1, math.ceil(nag_neg_context_len / context_len), 1) 28 | if trim_context: 29 | context = context[:, -nag_neg_context_len:] 30 | else: 31 | context = context.repeat(1, 1, math.ceil(nag_neg_context_len / context_len), 1) 32 | if trim_context: 33 | context = context[:, :, -nag_neg_context_len:] 34 | 35 | context_len = context.shape[dim] 36 | 37 | if dim == 1: 38 | nag_negative_context = nag_negative_context.repeat(1, 
math.ceil(context_len / nag_neg_context_len), 1) 39 | nag_negative_context = nag_negative_context[:, -context_len:] 40 | else: 41 | nag_negative_context = nag_negative_context.repeat(1, 1, math.ceil(context_len / nag_neg_context_len), 1) 42 | nag_negative_context = nag_negative_context[:, :, -context_len:] 43 | 44 | 45 | return torch.cat([context, nag_negative_context], dim=0) 46 | 47 | 48 | def check_nag_activation(transformer_options, nag_sigma_end): 49 | apply_nag = torch.all(transformer_options["sigmas"] >= nag_sigma_end) 50 | positive_batch = 0 in transformer_options["cond_or_uncond"] 51 | return apply_nag and positive_batch 52 | 53 | 54 | def get_closure_vars(func): 55 | if func.__closure__ is None: 56 | return {} 57 | return { 58 | var: cell.cell_contents 59 | for var, cell in zip(func.__code__.co_freevars, func.__closure__) 60 | } 61 | 62 | 63 | def is_from_wavespeed(func): 64 | closure = get_closure_vars(func) 65 | return "residual_diff_threshold" in closure \ 66 | and "validate_can_use_cache_function" in closure 67 | 68 | 69 | class NAGSwitch: 70 | def __init__( 71 | self, 72 | model: torch.nn.Module, 73 | nag_negative_cond, 74 | nag_scale, nag_tau, nag_alpha, nag_sigma_end, 75 | ): 76 | self.model = model 77 | self.nag_negative_cond = nag_negative_cond 78 | self.nag_scale = nag_scale 79 | self.nag_tau = nag_tau 80 | self.nag_alpha = nag_alpha 81 | self.nag_sigma_end = nag_sigma_end 82 | self.origin_forward = model.forward 83 | 84 | def set_nag(self): 85 | pass 86 | 87 | def set_origin(self): 88 | self.model.forward = self.origin_forward 89 | 90 | 91 | # https://github.com/welltop-cn/ComfyUI-TeaCache/blob/4bca908bf53b029ea5739cb69ef2a9e6c06e6752/nodes.py 92 | def poly1d(coefficients, x): 93 | result = torch.zeros_like(x) 94 | for i, coeff in enumerate(coefficients): 95 | result += coeff * (x ** (len(coefficients) - 1 - i)) 96 | return result 97 | -------------------------------------------------------------------------------- /workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenDarYen/ComfyUI-NAG/25bd940444732e21726944d03893a2f6055f5392/workflow.png -------------------------------------------------------------------------------- /workflows/NAG-Chroma-ComfyUI-Workflow.json: -------------------------------------------------------------------------------- 1 | {"last_node_id":45,"last_link_id":85,"nodes":[{"id":33,"type":"SamplerCustomAdvanced","pos":[1290,40],"size":[355.20001220703125,106],"flags":{},"order":13,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":52,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":53,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":56,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":55,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":54,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[51],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for 
S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":34,"type":"VAEDecode","pos":[1290,200],"size":[210,46],"flags":{},"order":15,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":51},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":57}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[50],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":43,"type":"SaveImage","pos":[1700,430],"size":[450,490],"flags":{},"order":18,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":71}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":36,"type":"SaveImage","pos":[1700,-120],"size":[450,490],"flags":{},"order":17,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":50}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":41,"type":"SamplerCustomAdvanced","pos":[1290,650],"size":[355.20001220703125,106],"flags":{},"order":14,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":78,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":68,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":77,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":76,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":79,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[69],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":10,"type":"VAELoader","pos":[48,384],"size":[315,58],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"VAE","localized_name":"VAE","label":"VAE","type":"VAE","shape":3,"links":[57,70],"slot_index":0}],"properties":{"Node name for S&R":"VAELoader"},"widgets_values":["ae.safetensors"]},{"id":42,"type":"VAEDecode","pos":[1290,540],"size":[210,46],"flags":{},"order":16,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":69},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":70}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[71],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":16,"type":"KSamplerSelect","pos":[480,720],"size":[315,58],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"SAMPLER","localized_name":"SAMPLER","label":"SAMPLER","type":"SAMPLER","shape":3,"links":[56,77],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerSelect"},"widgets_values":["euler"]},{"id":44,"type":"CLIPLoader","pos":[-300,240],"size":[315,98.00001525878906],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"CLIP","localized_name":"CLIP","type":"CLIP","links":[80],"slot_index":0}],"properties":{"Node name for 
S&R":"CLIPLoader"},"widgets_values":["t5xxl_fp16.safetensors","chroma","default"]},{"id":45,"type":"T5TokenizerOptions","pos":[50,240],"size":[315,82],"flags":{},"order":6,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":80}],"outputs":[{"name":"CLIP","localized_name":"CLIP","type":"CLIP","links":[81,82,83],"slot_index":0}],"properties":{"Node name for S&R":"T5TokenizerOptions"},"widgets_values":[1,0]},{"id":12,"type":"UNETLoader","pos":[48,96],"size":[315,82],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","localized_name":"MODEL","label":"MODEL","type":"MODEL","shape":3,"links":[38,46,75],"slot_index":0}],"properties":{"Node name for S&R":"UNETLoader"},"widgets_values":["chroma-unlocked-v38-detail-calibrated.safetensors","default"],"color":"#223","bgcolor":"#335"},{"id":17,"type":"BasicScheduler","pos":[480,816],"size":[315,106],"flags":{},"order":7,"mode":0,"inputs":[{"name":"model","localized_name":"model","label":"model","type":"MODEL","link":38,"slot_index":0}],"outputs":[{"name":"SIGMAS","localized_name":"SIGMAS","label":"SIGMAS","type":"SIGMAS","shape":3,"links":[55,76],"slot_index":0}],"properties":{"Node name for S&R":"BasicScheduler"},"widgets_values":["beta",30,1]},{"id":6,"type":"CLIPTextEncode","pos":[375,221],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":8,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":81}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[47,72],"slot_index":0}],"properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["A beautiful cyborg."],"color":"#232","bgcolor":"#353"},{"id":31,"type":"CLIPTextEncode","pos":[380,-230],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":10,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":83}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[48,74],"slot_index":0}],"title":"CLIP Text Encode (Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["Low resolution, blurry, lack of details, illustration, cartoon, painting."],"color":"#322","bgcolor":"#533"},{"id":32,"type":"CLIPTextEncode","pos":[380,0],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":9,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":82}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[49,73],"slot_index":0}],"title":"CLIP Text Encode (NAG Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["warm lighting."],"color":"#322","bgcolor":"#533"},{"id":25,"type":"RandomNoise","pos":[480,576],"size":[315,82],"flags":{},"order":4,"mode":0,"inputs":[],"outputs":[{"name":"NOISE","localized_name":"NOISE","label":"NOISE","type":"NOISE","shape":3,"links":[52,78],"slot_index":0}],"properties":{"Node name for 
S&R":"RandomNoise"},"widgets_values":[124740536074849,"randomize"],"color":"#2a363b","bgcolor":"#3f5159"},{"id":40,"type":"NAGCFGGuider","pos":[930,460],"size":[315,234],"flags":{},"order":12,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":75},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":72},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":74},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":73},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":85}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[68],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[4,1,2.5,0.25,0]},{"id":5,"type":"EmptyLatentImage","pos":[480,432],"size":[315,106],"flags":{},"order":5,"mode":0,"inputs":[],"outputs":[{"name":"LATENT","localized_name":"LATENT","label":"LATENT","type":"LATENT","links":[54,79,84,85],"slot_index":0}],"properties":{"Node name for S&R":"EmptyLatentImage"},"widgets_values":[1024,1024,1],"color":"#323","bgcolor":"#535"},{"id":30,"type":"NAGCFGGuider","pos":[930,180],"size":[315,234],"flags":{},"order":11,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":46},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":47},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":48},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":49},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":84}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[53],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[4,5,2.5,0.25,0.75]}],"links":[[38,12,0,17,0,"MODEL"],[46,12,0,30,0,"MODEL"],[47,6,0,30,1,"CONDITIONING"],[48,31,0,30,2,"CONDITIONING"],[49,32,0,30,3,"CONDITIONING"],[50,34,0,36,0,"IMAGE"],[51,33,0,34,0,"LATENT"],[52,25,0,33,0,"NOISE"],[53,30,0,33,1,"GUIDER"],[54,5,0,33,4,"LATENT"],[55,17,0,33,3,"SIGMAS"],[56,16,0,33,2,"SAMPLER"],[57,10,0,34,1,"VAE"],[68,40,0,41,1,"GUIDER"],[69,41,0,42,0,"LATENT"],[70,10,0,42,1,"VAE"],[71,42,0,43,0,"IMAGE"],[72,6,0,40,1,"CONDITIONING"],[73,32,0,40,3,"CONDITIONING"],[74,31,0,40,2,"CONDITIONING"],[75,12,0,40,0,"MODEL"],[76,17,0,41,3,"SIGMAS"],[77,16,0,41,2,"SAMPLER"],[78,25,0,41,0,"NOISE"],[79,5,0,41,4,"LATENT"],[80,44,0,45,0,"CLIP"],[81,45,0,6,0,"CLIP"],[82,45,0,32,0,"CLIP"],[83,45,0,31,0,"CLIP"],[84,5,0,30,4,"LATENT"],[85,5,0,40,4,"LATENT"]],"groups":[],"config":{},"extra":{"ds":{"scale":0.8535456747772172,"offset":[245.35353773913687,316.27093101706475]}},"version":0.4} -------------------------------------------------------------------------------- /workflows/NAG-DMD2-ComfyUI-Workflow.json: -------------------------------------------------------------------------------- 1 | 
{"last_node_id":44,"last_link_id":81,"nodes":[{"id":33,"type":"SamplerCustomAdvanced","pos":[1290,40],"size":[355.20001220703125,106],"flags":{},"order":12,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":52,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":53,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":56,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":55,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":54,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[51],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":34,"type":"VAEDecode","pos":[1290,200],"size":[210,46],"flags":{},"order":14,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":51},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":57}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[50],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":43,"type":"SaveImage","pos":[1700,430],"size":[450,490],"flags":{},"order":17,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":71}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":36,"type":"SaveImage","pos":[1700,-120],"size":[450,490],"flags":{},"order":16,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":50}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":41,"type":"SamplerCustomAdvanced","pos":[1290,650],"size":[355.20001220703125,106],"flags":{},"order":13,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":78,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":68,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":77,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":76,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":79,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[69],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":42,"type":"VAEDecode","pos":[1290,540],"size":[210,46],"flags":{},"order":15,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":69},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":70}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[71],"slot_index":0}],"properties":{"Node name for 
S&R":"VAEDecode"},"widgets_values":[]},{"id":6,"type":"CLIPTextEncode","pos":[375,221],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":7,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":10}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[47,72],"slot_index":0}],"properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["A beautiful cyborg."],"color":"#232","bgcolor":"#353"},{"id":32,"type":"CLIPTextEncode","pos":[380,0],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":8,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":44}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[49,73],"slot_index":0}],"title":"CLIP Text Encode (NAG Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["Robot."],"color":"#322","bgcolor":"#533"},{"id":31,"type":"CLIPTextEncode","pos":[380,-230],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":9,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":45}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[48,74],"slot_index":0}],"title":"CLIP Text Encode (Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":[""],"color":"#322","bgcolor":"#533"},{"id":12,"type":"UNETLoader","pos":[48,96],"size":[315,82],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","localized_name":"MODEL","label":"MODEL","type":"MODEL","shape":3,"links":[38,46,75],"slot_index":0}],"properties":{"Node name for S&R":"UNETLoader"},"widgets_values":["dmd2_sdxl_4step_unet_fp16.safetensors","default"],"color":"#223","bgcolor":"#335"},{"id":11,"type":"DualCLIPLoader","pos":[48,240],"size":[315,122],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"CLIP","localized_name":"CLIP","label":"CLIP","type":"CLIP","shape":3,"links":[10,44,45],"slot_index":0}],"properties":{"Node name for S&R":"DualCLIPLoader"},"widgets_values":["clip_g.safetensors","clip_l.safetensors","sdxl","default"]},{"id":17,"type":"BasicScheduler","pos":[480,816],"size":[315,106],"flags":{},"order":6,"mode":0,"inputs":[{"name":"model","localized_name":"model","label":"model","type":"MODEL","link":38,"slot_index":0}],"outputs":[{"name":"SIGMAS","localized_name":"SIGMAS","label":"SIGMAS","type":"SIGMAS","shape":3,"links":[55,76],"slot_index":0}],"properties":{"Node name for S&R":"BasicScheduler"},"widgets_values":["simple",4,1]},{"id":16,"type":"KSamplerSelect","pos":[480,720],"size":[315,58],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"SAMPLER","localized_name":"SAMPLER","label":"SAMPLER","type":"SAMPLER","shape":3,"links":[56,77],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerSelect"},"widgets_values":["lcm"]},{"id":10,"type":"VAELoader","pos":[48,384],"size":[315,58],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"VAE","localized_name":"VAE","label":"VAE","type":"VAE","shape":3,"links":[57,70],"slot_index":0}],"properties":{"Node name for 
S&R":"VAELoader"},"widgets_values":["sdxl_vae.safetensors"]},{"id":25,"type":"RandomNoise","pos":[480,576],"size":[315,82],"flags":{},"order":4,"mode":0,"inputs":[],"outputs":[{"name":"NOISE","localized_name":"NOISE","label":"NOISE","type":"NOISE","shape":3,"links":[52,78],"slot_index":0}],"properties":{"Node name for S&R":"RandomNoise"},"widgets_values":[382920077628287,"randomize"],"color":"#2a363b","bgcolor":"#3f5159"},{"id":5,"type":"EmptyLatentImage","pos":[480,432],"size":[315,106],"flags":{},"order":5,"mode":0,"inputs":[],"outputs":[{"name":"LATENT","localized_name":"LATENT","label":"LATENT","type":"LATENT","links":[54,79,80,81],"slot_index":0}],"properties":{"Node name for S&R":"EmptyLatentImage"},"widgets_values":[1024,1024,1],"color":"#323","bgcolor":"#535"},{"id":40,"type":"NAGCFGGuider","pos":[930,460],"size":[315,234],"flags":{},"order":11,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":75},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":72},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":74},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":73},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":80}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[68],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[1,1,2.5,0.5,0]},{"id":30,"type":"NAGCFGGuider","pos":[930,180],"size":[315,234],"flags":{},"order":10,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":46},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":47},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":48},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":49},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":81}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[53],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[1,3,2.5,0.5,4]}],"links":[[10,11,0,6,0,"CLIP"],[38,12,0,17,0,"MODEL"],[44,11,0,32,0,"CLIP"],[45,11,0,31,0,"CLIP"],[46,12,0,30,0,"MODEL"],[47,6,0,30,1,"CONDITIONING"],[48,31,0,30,2,"CONDITIONING"],[49,32,0,30,3,"CONDITIONING"],[50,34,0,36,0,"IMAGE"],[51,33,0,34,0,"LATENT"],[52,25,0,33,0,"NOISE"],[53,30,0,33,1,"GUIDER"],[54,5,0,33,4,"LATENT"],[55,17,0,33,3,"SIGMAS"],[56,16,0,33,2,"SAMPLER"],[57,10,0,34,1,"VAE"],[68,40,0,41,1,"GUIDER"],[69,41,0,42,0,"LATENT"],[70,10,0,42,1,"VAE"],[71,42,0,43,0,"IMAGE"],[72,6,0,40,1,"CONDITIONING"],[73,32,0,40,3,"CONDITIONING"],[74,31,0,40,2,"CONDITIONING"],[75,12,0,40,0,"MODEL"],[76,17,0,41,3,"SIGMAS"],[77,16,0,41,2,"SAMPLER"],[78,25,0,41,0,"NOISE"],[79,5,0,41,4,"LATENT"],[80,5,0,40,4,"LATENT"],[81,5,0,30,4,"LATENT"]],"groups":[],"config":{},"extra":{"ds":{"scale":1.1360692931284764,"offset":[-161.2854704088722,-20.60295583757925]}},"version":0.4} -------------------------------------------------------------------------------- /workflows/NAG-Flux-Dev-ComfyUI-Workflow.json: -------------------------------------------------------------------------------- 1 | 
{"last_node_id":44,"last_link_id":84,"nodes":[{"id":33,"type":"SamplerCustomAdvanced","pos":[1290,40],"size":[355.20001220703125,106],"flags":{},"order":14,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":52,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":53,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":56,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":55,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":54,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[51],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":34,"type":"VAEDecode","pos":[1290,200],"size":[210,46],"flags":{},"order":16,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":51},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":57}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[50],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":43,"type":"SaveImage","pos":[1700,430],"size":[450,490],"flags":{},"order":17,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":71}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":36,"type":"SaveImage","pos":[1700,-120],"size":[450,490],"flags":{},"order":18,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":50}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":41,"type":"SamplerCustomAdvanced","pos":[1290,650],"size":[355.20001220703125,106],"flags":{},"order":13,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":78,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":68,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":77,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":76,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":79,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[69],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":10,"type":"VAELoader","pos":[48,384],"size":[315,58],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"VAE","localized_name":"VAE","label":"VAE","type":"VAE","shape":3,"links":[57,70],"slot_index":0}],"properties":{"Node name for 
S&R":"VAELoader"},"widgets_values":["ae.safetensors"]},{"id":42,"type":"VAEDecode","pos":[1290,540],"size":[210,46],"flags":{},"order":15,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":69},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":70}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[71],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":6,"type":"CLIPTextEncode","pos":[375,221],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":7,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":10}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[80],"slot_index":0}],"properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["A beautiful cyborg."],"color":"#232","bgcolor":"#353"},{"id":32,"type":"CLIPTextEncode","pos":[380,0],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":8,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":44}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[49,73],"slot_index":0}],"title":"CLIP Text Encode (NAG Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["Robot."],"color":"#322","bgcolor":"#533"},{"id":31,"type":"CLIPTextEncode","pos":[380,-230],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":9,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":45}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[48,74],"slot_index":0}],"title":"CLIP Text Encode (Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":[""],"color":"#322","bgcolor":"#533"},{"id":16,"type":"KSamplerSelect","pos":[480,720],"size":[315,58],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"SAMPLER","localized_name":"SAMPLER","label":"SAMPLER","type":"SAMPLER","shape":3,"links":[56,77],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerSelect"},"widgets_values":["euler"]},{"id":44,"type":"FluxGuidance","pos":[930,380],"size":[317.4000244140625,58],"flags":{},"order":10,"mode":0,"inputs":[{"name":"conditioning","localized_name":"conditioning","type":"CONDITIONING","link":80}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[81,82],"slot_index":0}],"properties":{"Node name for S&R":"FluxGuidance"},"widgets_values":[3.5]},{"id":12,"type":"UNETLoader","pos":[48,96],"size":[315,82],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","localized_name":"MODEL","label":"MODEL","type":"MODEL","shape":3,"links":[38,46,75],"slot_index":0}],"properties":{"Node name for S&R":"UNETLoader"},"widgets_values":["flux1-dev.safetensors","default"],"color":"#223","bgcolor":"#335"},{"id":17,"type":"BasicScheduler","pos":[480,816],"size":[315,106],"flags":{},"order":6,"mode":0,"inputs":[{"name":"model","localized_name":"model","label":"model","type":"MODEL","link":38,"slot_index":0}],"outputs":[{"name":"SIGMAS","localized_name":"SIGMAS","label":"SIGMAS","type":"SIGMAS","shape":3,"links":[55,76],"slot_index":0}],"properties":{"Node name for 
S&R":"BasicScheduler"},"widgets_values":["simple",25,1]},{"id":11,"type":"DualCLIPLoader","pos":[48,240],"size":[315,122],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"CLIP","localized_name":"CLIP","label":"CLIP","type":"CLIP","shape":3,"links":[10,44,45],"slot_index":0}],"properties":{"Node name for S&R":"DualCLIPLoader"},"widgets_values":["t5xxl_fp16.safetensors","clip_l.safetensors","flux","default"]},{"id":25,"type":"RandomNoise","pos":[480,576],"size":[315,82],"flags":{},"order":4,"mode":0,"inputs":[],"outputs":[{"name":"NOISE","localized_name":"NOISE","label":"NOISE","type":"NOISE","shape":3,"links":[52,78],"slot_index":0}],"properties":{"Node name for S&R":"RandomNoise"},"widgets_values":[569167680612941,"randomize"],"color":"#2a363b","bgcolor":"#3f5159"},{"id":5,"type":"EmptyLatentImage","pos":[480,432],"size":[315,106],"flags":{},"order":5,"mode":0,"inputs":[],"outputs":[{"name":"LATENT","localized_name":"LATENT","label":"LATENT","type":"LATENT","links":[54,79,83,84],"slot_index":0}],"properties":{"Node name for S&R":"EmptyLatentImage"},"widgets_values":[1024,1024,1],"color":"#323","bgcolor":"#535"},{"id":40,"type":"NAGCFGGuider","pos":[930,510],"size":[315,234],"flags":{},"order":11,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":75},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":81},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":74},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":73},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":84}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[68],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[1,1,2.5,0.25,0]},{"id":30,"type":"NAGCFGGuider","pos":[930,110],"size":[315,234],"flags":{},"order":12,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":46},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":82},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":48},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":49},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":83}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[53],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[1,5,2.5,0.25,0.75]}],"links":[[10,11,0,6,0,"CLIP"],[38,12,0,17,0,"MODEL"],[44,11,0,32,0,"CLIP"],[45,11,0,31,0,"CLIP"],[46,12,0,30,0,"MODEL"],[48,31,0,30,2,"CONDITIONING"],[49,32,0,30,3,"CONDITIONING"],[50,34,0,36,0,"IMAGE"],[51,33,0,34,0,"LATENT"],[52,25,0,33,0,"NOISE"],[53,30,0,33,1,"GUIDER"],[54,5,0,33,4,"LATENT"],[55,17,0,33,3,"SIGMAS"],[56,16,0,33,2,"SAMPLER"],[57,10,0,34,1,"VAE"],[68,40,0,41,1,"GUIDER"],[69,41,0,42,0,"LATENT"],[70,10,0,42,1,"VAE"],[71,42,0,43,0,"IMAGE"],[73,32,0,40,3,"CONDITIONING"],[74,31,0,40,2,"CONDITIONING"],[75,12,0,40,0,"MODEL"],[76,17,0,41,3,"SIGMAS"],[77,16,0,41,2,"SAMPLER"],[78,25,0,41,0,"NOISE"],[79,5,0,41,4,"LATENT"],[80,6,0,44,0,"CONDITIONING"],[81,44,0,40,1,"CONDITIONING"],[82,44,0,30,1,"CONDITIONING"],[83,5,0,30,4,"LATENT"],[84,5,0,40,4,"LATENT"]],"groups":[],"config":{},"extra":{"ds":{"scale":0.938900242254939,"offset":[117.51570117103262,249.50629750502912]}},"version":0.4} -------------------------------------------------------------------------------- 
/workflows/NAG-Flux-Kontext-Dev-ComfyUI-Workflow.json: -------------------------------------------------------------------------------- 1 | {"last_node_id":193,"last_link_id":306,"nodes":[{"id":173,"type":"PreviewImage","pos":[320,860],"size":[420,310],"flags":{},"order":19,"mode":0,"inputs":[{"name":"images","localized_name":"images","type":"IMAGE","link":289}],"outputs":[],"properties":{"Node name for S&R":"PreviewImage","cnr_id":"comfy-core","ver":"0.3.40"},"widgets_values":[]},{"id":136,"type":"SaveImage","pos":[760,510],"size":[650,660],"flags":{},"order":26,"mode":0,"inputs":[{"name":"images","localized_name":"images","type":"IMAGE","link":240}],"outputs":[],"properties":{"cnr_id":"comfy-core","ver":"0.3.39"},"widgets_values":["ComfyUI"]},{"id":147,"type":"LoadImageOutput","pos":[-50,770],"size":[320,374],"flags":{},"order":0,"mode":4,"inputs":[],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[250]},{"name":"MASK","localized_name":"MASK","type":"MASK","links":null}],"properties":{"Node name for S&R":"LoadImageOutput","cnr_id":"comfy-core","ver":"0.3.40"},"widgets_values":["rabbit.jpg [output]",false,"refresh"],"color":"#322","bgcolor":"#533"},{"id":8,"type":"VAEDecode","pos":[530,350],"size":[190,46],"flags":{"collapsed":false},"order":24,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","type":"LATENT","link":52},{"name":"vae","localized_name":"vae","type":"VAE","link":61}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[240],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode","cnr_id":"comfy-core","ver":"0.3.38"},"widgets_values":[]},{"id":42,"type":"FluxKontextImageScale","pos":[-50,570],"size":[270,30],"flags":{"collapsed":false},"order":16,"mode":0,"inputs":[{"name":"image","localized_name":"image","type":"IMAGE","link":251}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[222,289]}],"properties":{"Node name for S&R":"FluxKontextImageScale","cnr_id":"comfy-core","ver":"0.3.38"},"widgets_values":[]},{"id":175,"type":"MarkdownNote","pos":[-50,640],"size":[320,88],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[],"title":"About Flux Kontext Edit","properties":{},"widgets_values":["Use **Ctrl + B** to enable multiple image input."],"color":"#432","bgcolor":"#653"},{"id":146,"type":"ImageStitch","pos":[-390,570],"size":[270,150],"flags":{},"order":13,"mode":0,"inputs":[{"name":"image1","localized_name":"image1","type":"IMAGE","link":249},{"name":"image2","localized_name":"image2","type":"IMAGE","shape":7,"link":250}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[251]}],"properties":{"Node name for S&R":"ImageStitch","cnr_id":"comfy-core","ver":"0.3.40"},"widgets_values":["right",true,0,"white"]},{"id":185,"type":"MarkdownNote","pos":[-960,490],"size":[510,170],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[],"title":"About VRAM","properties":{},"widgets_values":["For reference:\n- **fp8_scaled**: Requires about 20GB of VRAM.\n- **Original**: Requires about 32GB of VRAM.\n"],"color":"#432","bgcolor":"#653"},{"id":186,"type":"MarkdownNote","pos":[-960,710],"size":[510,170],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[],"title":"Flux Kontext Prompt Techniques","properties":{},"widgets_values":["\n## Flux Kontext Prompt Techniques\n\n### 1. 
Basic Modifications\n- Simple and direct: `\"Change the car color to red\"`\n- Maintain style: `\"Change to daytime while maintaining the same style of the painting\"`\n\n### 2. Style Transfer\n**Principles:**\n- Clearly name style: `\"Transform to Bauhaus art style\"`\n- Describe characteristics: `\"Transform to oil painting with visible brushstrokes, thick paint texture\"`\n- Preserve composition: `\"Change to Bauhaus style while maintaining the original composition\"`\n\n### 3. Character Consistency\n**Framework:**\n- Specific description: `\"The woman with short black hair\"` instead of \"she\"\n- Preserve features: `\"while maintaining the same facial features, hairstyle, and expression\"`\n- Step-by-step modifications: Change background first, then actions\n\n### 4. Text Editing\n- Use quotes: `\"Replace 'joy' with 'BFL'\"`\n- Maintain format: `\"Replace text while maintaining the same font style\"`\n\n## Common Problem Solutions\n\n### Character Changes Too Much\n❌ Wrong: `\"Transform the person into a Viking\"`\n✅ Correct: `\"Change the clothes to be a viking warrior while preserving facial features\"`\n\n### Composition Position Changes\n❌ Wrong: `\"Put him on a beach\"`\n✅ Correct: `\"Change the background to a beach while keeping the person in the exact same position, scale, and pose\"`\n\n### Style Application Inaccuracy\n❌ Wrong: `\"Make it a sketch\"`\n✅ Correct: `\"Convert to pencil sketch with natural graphite lines, cross-hatching, and visible paper texture\"`\n\n## Core Principles\n\n1. **Be Specific and Clear** - Use precise descriptions, avoid vague terms\n2. **Step-by-step Editing** - Break complex modifications into multiple simple steps\n3. **Explicit Preservation** - State what should remain unchanged\n4. **Verb Selection** - Use \"change\", \"replace\" rather than \"transform\"\n\n## Best Practice Templates\n\n**Object Modification:**\n`\"Change [object] to [new state], keep [content to preserve] unchanged\"`\n\n**Style Transfer:**\n`\"Transform to [specific style], while maintaining [composition/character/other] unchanged\"`\n\n**Background Replacement:**\n`\"Change the background to [new background], keep the subject in the exact same position and pose\"`\n\n**Text Editing:**\n`\"Replace '[original text]' with '[new text]', maintain the same font style\"`\n\n> **Remember:** The more specific, the better. Kontext excels at understanding detailed instructions and maintaining consistency. "],"color":"#432","bgcolor":"#653"},
{"id":184,"type":"MarkdownNote","pos":[-960,40],"size":[510,400],"flags":{},"order":5,"mode":0,"inputs":[],"outputs":[],"title":"Model links","properties":{},"widgets_values":["[tutorial](http://docs.comfy.org/tutorials/flux/flux-1-kontext-dev) | [tutorial (zh-CN)](http://docs.comfy.org/zh-CN/tutorials/flux/flux-1-kontext-dev)\n\n**diffusion model**\n\n- [flux1-dev-kontext_fp8_scaled.safetensors](https://huggingface.co/Comfy-Org/flux1-kontext-dev_ComfyUI/resolve/main/split_files/diffusion_models/flux1-dev-kontext_fp8_scaled.safetensors)\n\n**vae**\n\n- [ae.safetensors](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/vae/ae.safetensors)\n\n**text encoder**\n\n- [clip_l.safetensors](https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/clip_l.safetensors)\n- [t5xxl_fp16.safetensors](https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp16.safetensors) or [t5xxl_fp8_e4m3fn_scaled.safetensors](https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp8_e4m3fn_scaled.safetensors)\n\nModel Storage Location\n\n```\n📂 ComfyUI/\n├── 📂 models/\n│ ├── 📂 diffusion_models/\n│ │ └── flux1-dev-kontext_fp8_scaled.safetensors\n│ ├── 📂 vae/\n│ │ └── ae.safetensors\n│ └── 📂 text_encoders/\n│ ├── clip_l.safetensors\n│ └── t5xxl_fp16.safetensors or t5xxl_fp8_e4m3fn_scaled.safetensors\n```\n"],"color":"#432","bgcolor":"#653"},{"id":180,"type":"MarkdownNote","pos":[-1430,40],"size":[450,450],"flags":{},"order":6,"mode":0,"inputs":[],"outputs":[],"title":"✨ New ComfyUI feature for Flux.1 Kontext Dev","properties":{},"widgets_values":["We have added an **Edit** button to the node's **Selection Toolbox** to support **FLUX.1 Kontext Image Edit**. When clicked, it quickly adds a **FLUX.1 Kontext Image Edit** group node to the Latent output of your current workflow. 
This enables an interactive editing experience where you can:\n\n- Create multiple editing iterations, each preserved as a separate node\n- Easily branch off from any previous edit point to explore different creative directions\n- Return to any earlier version and start a new editing branch\n- Modify parameters in earlier nodes and automatically update all downstream edits\n- Execute or re-execute any branch of edits at any time\n\nThis workflow mirrors the iterative nature of LLM conversations, but with the added advantage of visual editing and the ability to maintain multiple parallel editing paths."],"color":"#322","bgcolor":"#533"},{"id":177,"type":"ReferenceLatent","pos":[10,140],"size":[211.60000610351562,46],"flags":{},"order":20,"mode":0,"inputs":[{"name":"conditioning","localized_name":"conditioning","type":"CONDITIONING","link":294},{"name":"latent","localized_name":"latent","type":"LATENT","shape":7,"link":293}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[292]}],"properties":{"Node name for S&R":"ReferenceLatent","cnr_id":"comfy-core","ver":"0.3.41"},"widgets_values":[]},{"id":178,"type":"MarkdownNote","pos":[-30,-150],"size":[540,150],"flags":{},"order":7,"mode":0,"inputs":[],"outputs":[],"title":"About referencing multiple images","properties":{},"widgets_values":["In addition to using **Image Stitch** to combine two images at a time, you can also encode each image separately and chain the resulting latent conditions with the **ReferenceLatent** node, so that multiple images can be referenced. 
You can use the **EmptySD3LatentImage** node on the right to connect to **KSampler** and customize the size of the **latent_image**."],"color":"#432","bgcolor":"#653"},{"id":188,"type":"EmptySD3LatentImage","pos":[530,-140],"size":[310,106],"flags":{},"order":8,"mode":4,"inputs":[],"outputs":[{"name":"LATENT","localized_name":"LATENT","type":"LATENT","links":null}],"properties":{"Node name for S&R":"EmptySD3LatentImage","cnr_id":"comfy-core","ver":"0.3.41"},"widgets_values":[1024,1024,1]},{"id":142,"type":"LoadImageOutput","pos":[-390,770],"size":[320,374],"flags":{},"order":9,"mode":0,"inputs":[],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[249]},{"name":"MASK","localized_name":"MASK","type":"MASK","links":null}],"properties":{"Node name for S&R":"LoadImageOutput","cnr_id":"comfy-core","ver":"0.3.40"},"widgets_values":["rabbit.jpg [output]",false,"refresh"],"color":"#322","bgcolor":"#533"},{"id":190,"type":"VAEDecode","pos":[1230,340],"size":[190,46],"flags":{"collapsed":false},"order":25,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","type":"LATENT","link":299},{"name":"vae","localized_name":"vae","type":"VAE","link":298}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[296],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode","cnr_id":"comfy-core","ver":"0.3.38"},"widgets_values":[]},{"id":191,"type":"SaveImage","pos":[1460,500],"size":[650,660],"flags":{},"order":27,"mode":0,"inputs":[{"name":"images","localized_name":"images","type":"IMAGE","link":296}],"outputs":[],"properties":{"cnr_id":"comfy-core","ver":"0.3.39"},"widgets_values":["ComfyUI"]},{"id":39,"type":"VAELoader","pos":[-400,390],"size":[337.76861572265625,58],"flags":{},"order":10,"mode":0,"inputs":[],"outputs":[{"name":"VAE","localized_name":"VAE","type":"VAE","links":[61,223,298],"slot_index":0}],"properties":{"Node name for S&R":"VAELoader","cnr_id":"comfy-core","ver":"0.3.38","models":[{"name":"ae.safetensors","url":"https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/resolve/main/split_files/vae/ae.safetensors","directory":"vae"}]},"widgets_values":["ae.safetensors"],"color":"#322","bgcolor":"#533"},{"id":35,"type":"FluxGuidance","pos":[250,90],"size":[240,58],"flags":{"collapsed":false},"order":21,"mode":0,"inputs":[{"name":"conditioning","localized_name":"conditioning","type":"CONDITIONING","link":292}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[57,300],"slot_index":0}],"properties":{"Node name for S&R":"FluxGuidance","cnr_id":"comfy-core","ver":"0.3.38"},"widgets_values":[2.5]},{"id":135,"type":"ConditioningZeroOut","pos":[250,200],"size":[240,26],"flags":{"collapsed":false},"order":17,"mode":0,"inputs":[{"name":"conditioning","localized_name":"conditioning","type":"CONDITIONING","link":237}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[238,303],"slot_index":0}],"properties":{"Node name for 
S&R":"ConditioningZeroOut","cnr_id":"comfy-core","ver":"0.3.39"},"widgets_values":[]},{"id":124,"type":"VAEEncode","pos":[-20,400],"size":[240,50],"flags":{"collapsed":false},"order":18,"mode":0,"inputs":[{"name":"pixels","localized_name":"pixels","type":"IMAGE","link":222},{"name":"vae","localized_name":"vae","type":"VAE","link":223}],"outputs":[{"name":"LATENT","localized_name":"LATENT","type":"LATENT","links":[291,293,302],"slot_index":0}],"properties":{"Node name for S&R":"VAEEncode","cnr_id":"comfy-core","ver":"0.3.39"},"widgets_values":[]},{"id":6,"type":"CLIPTextEncode","pos":[330,560],"size":[400,220],"flags":{},"order":14,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":59}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[237,294],"slot_index":0}],"title":"CLIP Text Encode (Positive Prompt)","properties":{"Node name for S&R":"CLIPTextEncode","cnr_id":"comfy-core","ver":"0.3.38"},"widgets_values":["Using this elegant style, create a portrait of a swan wearing a pearl tiara and lace collar, maintaining the same refined quality and soft color tones."],"color":"#232","bgcolor":"#353"},{"id":38,"type":"DualCLIPLoader","pos":[-400,210],"size":[337.76861572265625,130],"flags":{},"order":11,"mode":0,"inputs":[],"outputs":[{"name":"CLIP","localized_name":"CLIP","type":"CLIP","links":[59,305],"slot_index":0}],"properties":{"Node name for S&R":"DualCLIPLoader","cnr_id":"comfy-core","ver":"0.3.38","models":[{"name":"clip_l.safetensors","url":"https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/clip_l.safetensors","directory":"text_encoders"},{"name":"t5xxl_fp8_e4m3fn_scaled.safetensors","url":"https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp8_e4m3fn_scaled.safetensors","directory":"text_encoders"}]},"widgets_values":["clip_l.safetensors","t5xxl_fp8_e4m3fn_scaled.safetensors","flux","default"],"color":"#322","bgcolor":"#533"},{"id":31,"type":"KSampler","pos":[530,40],"size":[320,262],"flags":{},"order":22,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":58},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":57},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":238},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":291}],"outputs":[{"name":"LATENT","localized_name":"LATENT","type":"LATENT","links":[52],"slot_index":0}],"properties":{"Node name for S&R":"KSampler","cnr_id":"comfy-core","ver":"0.3.38"},"widgets_values":[2025,"fixed",20,1,"euler","simple",1]},{"id":37,"type":"UNETLoader","pos":[-400,80],"size":[337.76861572265625,82],"flags":{},"order":12,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","localized_name":"MODEL","type":"MODEL","links":[58,306],"slot_index":0}],"properties":{"Node name for 
S&R":"UNETLoader","cnr_id":"comfy-core","ver":"0.3.38","models":[{"name":"flux1-dev-kontext_fp8_scaled.safetensors","url":"https://huggingface.co/Comfy-Org/flux1-kontext-dev_ComfyUI/resolve/main/split_files/diffusion_models/flux1-dev-kontext_fp8_scaled.safetensors","directory":"diffusion_models"}]},"widgets_values":["flux1-dev-kontext_fp8_scaled.safetensors","default"],"color":"#322","bgcolor":"#533"},{"id":193,"type":"CLIPTextEncode","pos":[880,60],"size":[400,220],"flags":{},"order":15,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":305}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[304],"slot_index":0}],"title":"CLIP Text Encode (NAG negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode","cnr_id":"comfy-core","ver":"0.3.38"},"widgets_values":["duck."],"color":"#322","bgcolor":"#533"},{"id":192,"type":"KSamplerWithNAG","pos":[1320,-100],"size":[315,378],"flags":{},"order":23,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":306},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":300},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":303},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":304},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":302}],"outputs":[{"name":"LATENT","localized_name":"LATENT","type":"LATENT","links":[299],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerWithNAG"},"widgets_values":[2025,"fixed",20,1,3,2.5,0.25,0.75,"euler","simple",1]}],"links":[[52,31,0,8,0,"LATENT"],[57,35,0,31,1,"CONDITIONING"],[58,37,0,31,0,"MODEL"],[59,38,0,6,0,"CLIP"],[61,39,0,8,1,"VAE"],[222,42,0,124,0,"IMAGE"],[223,39,0,124,1,"VAE"],[237,6,0,135,0,"CONDITIONING"],[238,135,0,31,2,"CONDITIONING"],[240,8,0,136,0,"IMAGE"],[249,142,0,146,0,"IMAGE"],[250,147,0,146,1,"IMAGE"],[251,146,0,42,0,"IMAGE"],[289,42,0,173,0,"IMAGE"],[291,124,0,31,3,"LATENT"],[292,177,0,35,0,"CONDITIONING"],[293,124,0,177,1,"LATENT"],[294,6,0,177,0,"CONDITIONING"],[296,190,0,191,0,"IMAGE"],[298,39,0,190,1,"VAE"],[299,192,0,190,0,"LATENT"],[300,35,0,192,1,"CONDITIONING"],[302,124,0,192,4,"LATENT"],[303,135,0,192,2,"CONDITIONING"],[304,193,0,192,3,"CONDITIONING"],[305,38,0,193,0,"CLIP"],[306,37,0,192,0,"MODEL"]],"groups":[{"id":1,"title":"Step 1- Load models","bounding":[-410,10,360,450],"color":"#3f789e","font_size":24,"flags":{}},{"id":3,"title":"Step 2 - Upload images","bounding":[-410,480,700,680],"color":"#3f789e","font_size":24,"flags":{}},{"id":5,"title":"Step 3 - Prompt","bounding":[310,480,430,330],"color":"#3f789e","font_size":24,"flags":{}},{"id":6,"title":"Conditioning","bounding":[-30,10,540,250],"color":"#3f789e","font_size":24,"flags":{}}],"config":{},"extra":{"ds":{"scale":0.7400249944258633,"offset":[904.7778913363532,217.72649672461966]},"frontendVersion":"1.23.2","groupNodes":{},"VHS_latentpreview":false,"VHS_latentpreviewrate":0,"VHS_MetadataImage":true,"VHS_KeepIntermediate":true},"version":0.4} -------------------------------------------------------------------------------- /workflows/NAG-Flux-Schnell-ComfyUI-Workflow.json: -------------------------------------------------------------------------------- 1 | 
{"last_node_id":43,"last_link_id":81,"nodes":[{"id":33,"type":"SamplerCustomAdvanced","pos":[1290,40],"size":[355.20001220703125,106],"flags":{},"order":12,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":52,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":53,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":56,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":55,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":54,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[51],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":34,"type":"VAEDecode","pos":[1290,200],"size":[210,46],"flags":{},"order":14,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":51},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":57}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[50],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":43,"type":"SaveImage","pos":[1700,430],"size":[450,490],"flags":{},"order":17,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":71}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":36,"type":"SaveImage","pos":[1700,-120],"size":[450,490],"flags":{},"order":16,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":50}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":41,"type":"SamplerCustomAdvanced","pos":[1290,650],"size":[355.20001220703125,106],"flags":{},"order":13,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":78,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":68,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":77,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":76,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":79,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[69],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":10,"type":"VAELoader","pos":[48,384],"size":[315,58],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"VAE","localized_name":"VAE","label":"VAE","type":"VAE","shape":3,"links":[57,70],"slot_index":0}],"properties":{"Node name for 
S&R":"VAELoader"},"widgets_values":["ae.safetensors"]},{"id":42,"type":"VAEDecode","pos":[1290,540],"size":[210,46],"flags":{},"order":15,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":69},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":70}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[71],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":6,"type":"CLIPTextEncode","pos":[375,221],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":7,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":10}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[47,72],"slot_index":0}],"properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["A beautiful cyborg."],"color":"#232","bgcolor":"#353"},{"id":32,"type":"CLIPTextEncode","pos":[380,0],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":8,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":44}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[49,73],"slot_index":0}],"title":"CLIP Text Encode (NAG Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["Robot."],"color":"#322","bgcolor":"#533"},{"id":31,"type":"CLIPTextEncode","pos":[380,-230],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":9,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":45}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[48,74],"slot_index":0}],"title":"CLIP Text Encode (Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":[""],"color":"#322","bgcolor":"#533"},{"id":12,"type":"UNETLoader","pos":[48,96],"size":[315,82],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","localized_name":"MODEL","label":"MODEL","type":"MODEL","shape":3,"links":[38,46,75],"slot_index":0}],"properties":{"Node name for S&R":"UNETLoader"},"widgets_values":["flux1-schnell.safetensors","default"],"color":"#223","bgcolor":"#335"},{"id":17,"type":"BasicScheduler","pos":[480,816],"size":[315,106],"flags":{},"order":6,"mode":0,"inputs":[{"name":"model","localized_name":"model","label":"model","type":"MODEL","link":38,"slot_index":0}],"outputs":[{"name":"SIGMAS","localized_name":"SIGMAS","label":"SIGMAS","type":"SIGMAS","shape":3,"links":[55,76],"slot_index":0}],"properties":{"Node name for S&R":"BasicScheduler"},"widgets_values":["simple",4,1]},{"id":16,"type":"KSamplerSelect","pos":[480,720],"size":[315,58],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"SAMPLER","localized_name":"SAMPLER","label":"SAMPLER","type":"SAMPLER","shape":3,"links":[56,77],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerSelect"},"widgets_values":["euler"]},{"id":11,"type":"DualCLIPLoader","pos":[48,240],"size":[315,122],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"CLIP","localized_name":"CLIP","label":"CLIP","type":"CLIP","shape":3,"links":[10,44,45],"slot_index":0}],"properties":{"Node name for 
S&R":"DualCLIPLoader"},"widgets_values":["t5xxl_fp16.safetensors","clip_l.safetensors","flux","default"]},{"id":25,"type":"RandomNoise","pos":[480,576],"size":[315,82],"flags":{},"order":4,"mode":0,"inputs":[],"outputs":[{"name":"NOISE","localized_name":"NOISE","label":"NOISE","type":"NOISE","shape":3,"links":[52,78],"slot_index":0}],"properties":{"Node name for S&R":"RandomNoise"},"widgets_values":[901354717069938,"randomize"],"color":"#2a363b","bgcolor":"#3f5159"},{"id":5,"type":"EmptyLatentImage","pos":[480,432],"size":[315,106],"flags":{},"order":5,"mode":0,"inputs":[],"outputs":[{"name":"LATENT","localized_name":"LATENT","label":"LATENT","type":"LATENT","links":[54,79,80,81],"slot_index":0}],"properties":{"Node name for S&R":"EmptyLatentImage"},"widgets_values":[1024,1024,1],"color":"#323","bgcolor":"#535"},{"id":40,"type":"NAGCFGGuider","pos":[930,460],"size":[315,234],"flags":{},"order":11,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":75},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":72},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":74},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":73},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":80}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[68],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[1,1,2.5,0.25,0]},{"id":30,"type":"NAGCFGGuider","pos":[930,180],"size":[315,234],"flags":{},"order":10,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":46},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":47},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":48},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":49},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":81}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[53],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[1,5,2.5,0.25,0.75]}],"links":[[10,11,0,6,0,"CLIP"],[38,12,0,17,0,"MODEL"],[44,11,0,32,0,"CLIP"],[45,11,0,31,0,"CLIP"],[46,12,0,30,0,"MODEL"],[47,6,0,30,1,"CONDITIONING"],[48,31,0,30,2,"CONDITIONING"],[49,32,0,30,3,"CONDITIONING"],[50,34,0,36,0,"IMAGE"],[51,33,0,34,0,"LATENT"],[52,25,0,33,0,"NOISE"],[53,30,0,33,1,"GUIDER"],[54,5,0,33,4,"LATENT"],[55,17,0,33,3,"SIGMAS"],[56,16,0,33,2,"SAMPLER"],[57,10,0,34,1,"VAE"],[68,40,0,41,1,"GUIDER"],[69,41,0,42,0,"LATENT"],[70,10,0,42,1,"VAE"],[71,42,0,43,0,"IMAGE"],[72,6,0,40,1,"CONDITIONING"],[73,32,0,40,3,"CONDITIONING"],[74,31,0,40,2,"CONDITIONING"],[75,12,0,40,0,"MODEL"],[76,17,0,41,3,"SIGMAS"],[77,16,0,41,2,"SAMPLER"],[78,25,0,41,0,"NOISE"],[79,5,0,41,4,"LATENT"],[80,5,0,40,4,"LATENT"],[81,5,0,30,4,"LATENT"]],"groups":[],"config":{},"extra":{"ds":{"scale":1.1360692931284764,"offset":[-314.4502750403351,13.050650279273668]}},"version":0.4} -------------------------------------------------------------------------------- /workflows/NAG-Hunyuan-ComfyUI-Workflow.json: -------------------------------------------------------------------------------- 1 | 
{"last_node_id":110,"last_link_id":304,"nodes":[{"id":11,"type":"DualCLIPLoader","pos":[0,270],"size":[350,122],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"CLIP","localized_name":"CLIP","type":"CLIP","shape":3,"links":[205,232,233],"slot_index":0}],"properties":{"Node name for S&R":"DualCLIPLoader"},"widgets_values":["clip_l.safetensors","llava_llama3_fp8_scaled.safetensors","hunyuan_video","default"]},{"id":81,"type":"SaveAnimatedWEBP","pos":[1430,620],"size":[315,366],"flags":{},"order":19,"mode":0,"inputs":[{"name":"images","localized_name":"images","type":"IMAGE","link":218}],"outputs":[],"properties":{},"widgets_values":["ComfyUI",24,false,80,"default",""]},{"id":83,"type":"SamplerCustomAdvanced","pos":[890,890],"size":[272.3617858886719,124.53733825683594],"flags":{},"order":15,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","type":"NOISE","link":224,"slot_index":0},{"name":"guider","localized_name":"guider","type":"GUIDER","link":221,"slot_index":1},{"name":"sampler","localized_name":"sampler","type":"SAMPLER","link":223,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","type":"SIGMAS","link":222,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":225,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","type":"LATENT","shape":3,"links":[220],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":84,"type":"VAEDecodeTiled","pos":[1190,900],"size":[210,150],"flags":{},"order":17,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","type":"LATENT","link":220},{"name":"vae","localized_name":"vae","type":"VAE","link":231}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[218],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecodeTiled"},"widgets_values":[128,32,32,4]},{"id":26,"type":"FluxGuidance","pos":[500,90],"size":[317.4000244140625,58],"flags":{},"order":11,"mode":0,"inputs":[{"name":"conditioning","localized_name":"conditioning","type":"CONDITIONING","link":175}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","shape":3,"links":[259],"slot_index":0}],"properties":{"Node name for S&R":"FluxGuidance"},"widgets_values":[6],"color":"#233","bgcolor":"#355"},{"id":105,"type":"SamplerCustomAdvanced","pos":[890,440],"size":[272.3617858886719,124.53733825683594],"flags":{},"order":14,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","type":"NOISE","link":289,"slot_index":0},{"name":"guider","localized_name":"guider","type":"GUIDER","link":279,"slot_index":1},{"name":"sampler","localized_name":"sampler","type":"SAMPLER","link":288,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","type":"SIGMAS","link":291,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":290,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","type":"LATENT","shape":3,"links":[280],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for 
S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":106,"type":"VAEDecodeTiled","pos":[1190,450],"size":[210,150],"flags":{},"order":16,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","type":"LATENT","link":280},{"name":"vae","localized_name":"vae","type":"VAE","link":287}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[281],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecodeTiled"},"widgets_values":[128,32,32,4]},{"id":107,"type":"SaveAnimatedWEBP","pos":[1430,170],"size":[315,366],"flags":{},"order":18,"mode":0,"inputs":[{"name":"images","localized_name":"images","type":"IMAGE","link":281}],"outputs":[],"properties":{},"widgets_values":["ComfyUI",24,false,80,"default",""]},{"id":67,"type":"ModelSamplingSD3","pos":[100,0],"size":[210,58],"flags":{},"order":9,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":303}],"outputs":[{"name":"MODEL","localized_name":"MODEL","type":"MODEL","links":[260,282],"slot_index":0}],"properties":{"Node name for S&R":"ModelSamplingSD3"},"widgets_values":[7]},{"id":85,"type":"CLIPTextEncode","pos":[410,870],"size":[422.84503173828125,164.31304931640625],"flags":{"collapsed":true},"order":7,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":232}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[227,284],"slot_index":0}],"title":"CLIP Text Encode (Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":[""],"color":"#322","bgcolor":"#533"},{"id":86,"type":"CLIPTextEncode","pos":[410,940],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":8,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":233}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[234,285],"slot_index":0}],"title":"CLIP Text Encode (NAG Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"],"color":"#322","bgcolor":"#533"},{"id":45,"type":"EmptyHunyuanLatentVideo","pos":[475.540771484375,432.673583984375],"size":[315,130],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"LATENT","localized_name":"LATENT","type":"LATENT","links":[225,226,286,290],"slot_index":0}],"properties":{"Node name for S&R":"EmptyHunyuanLatentVideo"},"widgets_values":[832,480,73,1]},{"id":16,"type":"KSamplerSelect","pos":[484,751],"size":[315,58],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"SAMPLER","localized_name":"SAMPLER","type":"SAMPLER","shape":3,"links":[223,288],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerSelect"},"widgets_values":["euler"]},{"id":17,"type":"BasicScheduler","pos":[870,0],"size":[315,106],"flags":{},"order":10,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":304,"slot_index":0}],"outputs":[{"name":"SIGMAS","localized_name":"SIGMAS","type":"SIGMAS","shape":3,"links":[222,291],"slot_index":0}],"properties":{"Node name for S&R":"BasicScheduler"},"widgets_values":["normal",25,1]},{"id":10,"type":"VAELoader","pos":[0,420],"size":[350,60],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"VAE","localized_name":"VAE","type":"VAE","shape":3,"links":[231,287],"slot_index":0}],"properties":{"Node name for 
S&R":"VAELoader"},"widgets_values":["hunyuan_video_vae_bf16.safetensors"]},{"id":25,"type":"RandomNoise","pos":[479,618],"size":[315,82],"flags":{},"order":4,"mode":0,"inputs":[],"outputs":[{"name":"NOISE","localized_name":"NOISE","type":"NOISE","shape":3,"links":[224,289],"slot_index":0}],"properties":{"Node name for S&R":"RandomNoise"},"widgets_values":[306212149788532,"randomize"],"color":"#2a363b","bgcolor":"#3f5159"},{"id":79,"type":"NAGCFGGuider","pos":[870,610],"size":[315,234],"flags":{},"order":13,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":260},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":259},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":227},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":234},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":226}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[221],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[1,1,2.5,0.12,0]},{"id":104,"type":"NAGCFGGuider","pos":[870,160],"size":[315,234],"flags":{},"order":12,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":282},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":283},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":284},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":285},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":286}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[279],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[1,9,2.5,0.12,0.9]},{"id":44,"type":"CLIPTextEncode","pos":[420,200],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":6,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":205}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[175,283],"slot_index":0}],"title":"CLIP Text Encode (Positive Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["A ginger cat passionately plays eletric guitar with intensity and emotion on a stage. The background is shrouded in deep darkness. 
Spotlights casts dramatic shadows."],"color":"#232","bgcolor":"#353"},{"id":110,"type":"UNETLoader","pos":[0,130],"size":[350,82],"flags":{},"order":5,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","localized_name":"MODEL","type":"MODEL","shape":3,"links":[303,304],"slot_index":0}],"properties":{"Node name for S&R":"UNETLoader"},"widgets_values":["hunyuan_video_t2v_720p_bf16.safetensors","default"],"color":"#223","bgcolor":"#335"}],"links":[[175,44,0,26,0,"CONDITIONING"],[205,11,0,44,0,"CLIP"],[218,84,0,81,0,"IMAGE"],[220,83,0,84,0,"LATENT"],[221,79,0,83,1,"GUIDER"],[222,17,0,83,3,"SIGMAS"],[223,16,0,83,2,"SAMPLER"],[224,25,0,83,0,"NOISE"],[225,45,0,83,4,"LATENT"],[226,45,0,79,4,"LATENT"],[227,85,0,79,2,"CONDITIONING"],[231,10,0,84,1,"VAE"],[232,11,0,85,0,"CLIP"],[233,11,0,86,0,"CLIP"],[234,86,0,79,3,"CONDITIONING"],[259,26,0,79,1,"CONDITIONING"],[260,67,0,79,0,"MODEL"],[279,104,0,105,1,"GUIDER"],[280,105,0,106,0,"LATENT"],[281,106,0,107,0,"IMAGE"],[282,67,0,104,0,"MODEL"],[283,44,0,104,1,"CONDITIONING"],[284,85,0,104,2,"CONDITIONING"],[285,86,0,104,3,"CONDITIONING"],[286,45,0,104,4,"LATENT"],[287,10,0,106,1,"VAE"],[288,16,0,105,2,"SAMPLER"],[289,25,0,105,0,"NOISE"],[290,45,0,105,4,"LATENT"],[291,17,0,105,3,"SIGMAS"],[303,110,0,67,0,"MODEL"],[304,110,0,17,0,"MODEL"]],"groups":[],"config":{},"extra":{"ds":{"scale":1.1360692931284764,"offset":[180.73724101394097,209.16119814389862]},"groupNodes":{},"VHS_latentpreview":false,"VHS_latentpreviewrate":0,"VHS_MetadataImage":true,"VHS_KeepIntermediate":true},"version":0.4} -------------------------------------------------------------------------------- /workflows/NAG-SD15-ComfyUI-Workflow.json: -------------------------------------------------------------------------------- 1 | {"last_node_id":46,"last_link_id":92,"nodes":[{"id":33,"type":"SamplerCustomAdvanced","pos":[1290,40],"size":[355.20001220703125,106],"flags":{},"order":10,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":52,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":53,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":56,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":55,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":54,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[51],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":34,"type":"VAEDecode","pos":[1290,200],"size":[210,46],"flags":{},"order":12,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":51},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":92}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[50],"slot_index":0}],"properties":{"Node name for 
S&R":"VAEDecode"},"widgets_values":[]},{"id":43,"type":"SaveImage","pos":[1700,430],"size":[450,490],"flags":{},"order":15,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":71}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":36,"type":"SaveImage","pos":[1700,-120],"size":[450,490],"flags":{},"order":14,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":50}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":41,"type":"SamplerCustomAdvanced","pos":[1290,650],"size":[355.20001220703125,106],"flags":{},"order":11,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":78,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":68,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":77,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":76,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":79,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[69],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":42,"type":"VAEDecode","pos":[1290,540],"size":[210,46],"flags":{},"order":13,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":69},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":91}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[71],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":6,"type":"CLIPTextEncode","pos":[375,221],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":7,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":90}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[47,72],"slot_index":0}],"properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["A beautiful cyborg."],"color":"#232","bgcolor":"#353"},{"id":32,"type":"CLIPTextEncode","pos":[380,0],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":6,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":89}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[49,73],"slot_index":0}],"title":"CLIP Text Encode (NAG Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["Robot."],"color":"#322","bgcolor":"#533"},{"id":31,"type":"CLIPTextEncode","pos":[380,-230],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":5,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":88}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[48,74],"slot_index":0}],"title":"CLIP Text Encode (Negative Prompt)","properties":{"Node name for 
S&R":"CLIPTextEncode"},"widgets_values":[""],"color":"#322","bgcolor":"#533"},{"id":40,"type":"NAGCFGGuider","pos":[930,460],"size":[315,234],"flags":{},"order":9,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":83},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":72},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":74},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":73},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":80}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[68],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[7,1,2.5,0.5,0]},{"id":5,"type":"EmptyLatentImage","pos":[480,432],"size":[315,106],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"LATENT","localized_name":"LATENT","label":"LATENT","type":"LATENT","links":[54,79,80,81],"slot_index":0}],"properties":{"Node name for S&R":"EmptyLatentImage"},"widgets_values":[512,512,1],"color":"#323","bgcolor":"#535"},{"id":45,"type":"CheckpointLoaderSimple","pos":[-30,-70],"size":[315,98],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","localized_name":"MODEL","type":"MODEL","links":[82,83,84],"slot_index":0},{"name":"CLIP","localized_name":"CLIP","type":"CLIP","links":[88,89,90],"slot_index":1},{"name":"VAE","localized_name":"VAE","type":"VAE","links":[91,92],"slot_index":2}],"properties":{"Node name for S&R":"CheckpointLoaderSimple"},"widgets_values":["v1-5-pruned-emaonly.safetensors"]},{"id":16,"type":"KSamplerSelect","pos":[480,720],"size":[315,58],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"SAMPLER","localized_name":"SAMPLER","label":"SAMPLER","type":"SAMPLER","shape":3,"links":[56,77],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerSelect"},"widgets_values":["euler"]},{"id":17,"type":"BasicScheduler","pos":[480,816],"size":[315,106],"flags":{},"order":4,"mode":0,"inputs":[{"name":"model","localized_name":"model","label":"model","type":"MODEL","link":84,"slot_index":0}],"outputs":[{"name":"SIGMAS","localized_name":"SIGMAS","label":"SIGMAS","type":"SIGMAS","shape":3,"links":[55,76],"slot_index":0}],"properties":{"Node name for S&R":"BasicScheduler"},"widgets_values":["simple",20,1]},{"id":30,"type":"NAGCFGGuider","pos":[930,180],"size":[315,234],"flags":{},"order":8,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":82},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":47},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":48},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":49},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":81}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[53],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[7,5,2.5,0.38,4]},{"id":25,"type":"RandomNoise","pos":[480,576],"size":[315,82],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"NOISE","localized_name":"NOISE","label":"NOISE","type":"NOISE","shape":3,"links":[52,78],"slot_index":0}],"properties":{"Node name for 
S&R":"RandomNoise"},"widgets_values":[245181149383731,"randomize"],"color":"#2a363b","bgcolor":"#3f5159"}],"links":[[47,6,0,30,1,"CONDITIONING"],[48,31,0,30,2,"CONDITIONING"],[49,32,0,30,3,"CONDITIONING"],[50,34,0,36,0,"IMAGE"],[51,33,0,34,0,"LATENT"],[52,25,0,33,0,"NOISE"],[53,30,0,33,1,"GUIDER"],[54,5,0,33,4,"LATENT"],[55,17,0,33,3,"SIGMAS"],[56,16,0,33,2,"SAMPLER"],[68,40,0,41,1,"GUIDER"],[69,41,0,42,0,"LATENT"],[71,42,0,43,0,"IMAGE"],[72,6,0,40,1,"CONDITIONING"],[73,32,0,40,3,"CONDITIONING"],[74,31,0,40,2,"CONDITIONING"],[76,17,0,41,3,"SIGMAS"],[77,16,0,41,2,"SAMPLER"],[78,25,0,41,0,"NOISE"],[79,5,0,41,4,"LATENT"],[80,5,0,40,4,"LATENT"],[81,5,0,30,4,"LATENT"],[82,45,0,30,0,"MODEL"],[83,45,0,40,0,"MODEL"],[84,45,0,17,0,"MODEL"],[88,45,1,31,0,"CLIP"],[89,45,1,32,0,"CLIP"],[90,45,1,6,0,"CLIP"],[91,45,2,42,1,"VAE"],[92,45,2,34,1,"VAE"]],"groups":[],"config":{},"extra":{"ds":{"scale":0.8264462809917354,"offset":[188.41109074559512,310.2033848763747]},"VHS_latentpreview":false,"VHS_latentpreviewrate":0,"VHS_MetadataImage":true,"VHS_KeepIntermediate":true},"version":0.4} -------------------------------------------------------------------------------- /workflows/NAG-SD3.5-Turbo-ComfyUI-Workflow.json: -------------------------------------------------------------------------------- 1 | {"last_node_id":46,"last_link_id":89,"nodes":[{"id":33,"type":"SamplerCustomAdvanced","pos":[1290,40],"size":[355.20001220703125,106],"flags":{},"order":11,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":52,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":53,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":56,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":55,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":54,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[51],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":34,"type":"VAEDecode","pos":[1290,200],"size":[210,46],"flags":{},"order":13,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":51},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":83}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[50],"slot_index":0}],"properties":{"Node name for 
S&R":"VAEDecode"},"widgets_values":[]},{"id":43,"type":"SaveImage","pos":[1700,430],"size":[450,490],"flags":{},"order":16,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":71}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":36,"type":"SaveImage","pos":[1700,-120],"size":[450,490],"flags":{},"order":15,"mode":0,"inputs":[{"name":"images","localized_name":"images","label":"images","type":"IMAGE","link":50}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":41,"type":"SamplerCustomAdvanced","pos":[1290,650],"size":[355.20001220703125,106],"flags":{},"order":12,"mode":0,"inputs":[{"name":"noise","localized_name":"noise","label":"noise","type":"NOISE","link":78,"slot_index":0},{"name":"guider","localized_name":"guider","label":"guider","type":"GUIDER","link":68,"slot_index":1},{"name":"sampler","localized_name":"sampler","label":"sampler","type":"SAMPLER","link":77,"slot_index":2},{"name":"sigmas","localized_name":"sigmas","label":"sigmas","type":"SIGMAS","link":76,"slot_index":3},{"name":"latent_image","localized_name":"latent_image","label":"latent_image","type":"LATENT","link":79,"slot_index":4}],"outputs":[{"name":"output","localized_name":"output","label":"output","type":"LATENT","shape":3,"links":[69],"slot_index":0},{"name":"denoised_output","localized_name":"denoised_output","label":"denoised_output","type":"LATENT","shape":3,"links":null}],"properties":{"Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},{"id":42,"type":"VAEDecode","pos":[1290,540],"size":[210,46],"flags":{},"order":14,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","label":"samples","type":"LATENT","link":69},{"name":"vae","localized_name":"vae","label":"vae","type":"VAE","link":84}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","label":"IMAGE","type":"IMAGE","links":[71],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":6,"type":"CLIPTextEncode","pos":[375,221],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":5,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":85}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[47,72],"slot_index":0}],"properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["A beautiful cyborg."],"color":"#232","bgcolor":"#353"},{"id":32,"type":"CLIPTextEncode","pos":[380,0],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":6,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":86}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[49,73],"slot_index":0}],"title":"CLIP Text Encode (NAG Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["Robot."],"color":"#322","bgcolor":"#533"},{"id":31,"type":"CLIPTextEncode","pos":[380,-230],"size":[422.84503173828125,164.31304931640625],"flags":{},"order":7,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","label":"clip","type":"CLIP","link":87}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","label":"CONDITIONING","type":"CONDITIONING","links":[48,74],"slot_index":0}],"title":"CLIP Text Encode (Negative Prompt)","properties":{"Node name for 
S&R":"CLIPTextEncode"},"widgets_values":[""],"color":"#322","bgcolor":"#533"},{"id":16,"type":"KSamplerSelect","pos":[480,720],"size":[315,58],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"SAMPLER","localized_name":"SAMPLER","label":"SAMPLER","type":"SAMPLER","shape":3,"links":[56,77],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerSelect"},"widgets_values":["euler"]},{"id":46,"type":"TripleCLIPLoader","pos":[30,230],"size":[315,106],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"CLIP","localized_name":"CLIP","type":"CLIP","shape":3,"links":[85,86,87],"slot_index":0}],"properties":{"Node name for S&R":"TripleCLIPLoader"},"widgets_values":["clip_l.safetensors","clip_g.safetensors","t5xxl_fp16.safetensors"]},{"id":44,"type":"CheckpointLoaderSimple","pos":[-40,80],"size":[384.75592041015625,98],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","localized_name":"MODEL","type":"MODEL","links":[80,81,82],"slot_index":0},{"name":"CLIP","localized_name":"CLIP","type":"CLIP","links":[],"slot_index":1},{"name":"VAE","localized_name":"VAE","type":"VAE","links":[83,84],"slot_index":2}],"properties":{"Node name for S&R":"CheckpointLoaderSimple"},"widgets_values":["sd3.5_large_turbo.safetensors"]},{"id":17,"type":"BasicScheduler","pos":[480,816],"size":[315,106],"flags":{},"order":8,"mode":0,"inputs":[{"name":"model","localized_name":"model","label":"model","type":"MODEL","link":82,"slot_index":0}],"outputs":[{"name":"SIGMAS","localized_name":"SIGMAS","label":"SIGMAS","type":"SIGMAS","shape":3,"links":[55,76],"slot_index":0}],"properties":{"Node name for S&R":"BasicScheduler"},"widgets_values":["simple",8,1]},{"id":25,"type":"RandomNoise","pos":[480,576],"size":[315,82],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"NOISE","localized_name":"NOISE","label":"NOISE","type":"NOISE","shape":3,"links":[52,78],"slot_index":0}],"properties":{"Node name for S&R":"RandomNoise"},"widgets_values":[638191395580604,"randomize"],"color":"#2a363b","bgcolor":"#3f5159"},{"id":5,"type":"EmptyLatentImage","pos":[480,432],"size":[315,106],"flags":{},"order":4,"mode":0,"inputs":[],"outputs":[{"name":"LATENT","localized_name":"LATENT","label":"LATENT","type":"LATENT","links":[54,79,88,89],"slot_index":0}],"properties":{"Node name for S&R":"EmptyLatentImage"},"widgets_values":[1024,1024,1],"color":"#323","bgcolor":"#535"},{"id":30,"type":"NAGCFGGuider","pos":[930,180],"size":[315,234],"flags":{},"order":9,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":80},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":47},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":48},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":49},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":89}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[53],"slot_index":0}],"properties":{"Node name for 
S&R":"NAGCFGGuider"},"widgets_values":[1,5,2.5,0.25,0.75]},{"id":40,"type":"NAGCFGGuider","pos":[930,460],"size":[315,234],"flags":{},"order":10,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":81},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":72},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":74},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":73},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":88}],"outputs":[{"name":"GUIDER","localized_name":"GUIDER","type":"GUIDER","links":[68],"slot_index":0}],"properties":{"Node name for S&R":"NAGCFGGuider"},"widgets_values":[1,1,2.5,0.25,0]}],"links":[[47,6,0,30,1,"CONDITIONING"],[48,31,0,30,2,"CONDITIONING"],[49,32,0,30,3,"CONDITIONING"],[50,34,0,36,0,"IMAGE"],[51,33,0,34,0,"LATENT"],[52,25,0,33,0,"NOISE"],[53,30,0,33,1,"GUIDER"],[54,5,0,33,4,"LATENT"],[55,17,0,33,3,"SIGMAS"],[56,16,0,33,2,"SAMPLER"],[68,40,0,41,1,"GUIDER"],[69,41,0,42,0,"LATENT"],[71,42,0,43,0,"IMAGE"],[72,6,0,40,1,"CONDITIONING"],[73,32,0,40,3,"CONDITIONING"],[74,31,0,40,2,"CONDITIONING"],[76,17,0,41,3,"SIGMAS"],[77,16,0,41,2,"SAMPLER"],[78,25,0,41,0,"NOISE"],[79,5,0,41,4,"LATENT"],[80,44,0,30,0,"MODEL"],[81,44,0,40,0,"MODEL"],[82,44,0,17,0,"MODEL"],[83,44,2,34,1,"VAE"],[84,44,2,42,1,"VAE"],[85,46,0,6,0,"CLIP"],[86,46,0,32,0,"CLIP"],[87,46,0,31,0,"CLIP"],[88,5,0,40,4,"LATENT"],[89,5,0,30,4,"LATENT"]],"groups":[],"config":{},"extra":{"ds":{"scale":0.8535456747772172,"offset":[78.4028938416164,194.71914642501017]}},"version":0.4} -------------------------------------------------------------------------------- /workflows/NAG-Wan-ComfyUI-Workflow.json: -------------------------------------------------------------------------------- 1 | {"last_node_id":64,"last_link_id":142,"nodes":[{"id":50,"type":"SaveAnimatedWEBP","pos":[1280,-490],"size":[600,460],"flags":{},"order":13,"mode":0,"inputs":[{"name":"images","localized_name":"images","type":"IMAGE","link":100}],"outputs":[],"properties":{},"widgets_values":["ComfyUI",16,false,90,"default",""]},{"id":51,"type":"VAEDecode","pos":[970,-500],"size":[210,46],"flags":{},"order":11,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","type":"LATENT","link":131},{"name":"vae","localized_name":"vae","type":"VAE","link":102}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[100],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":56,"type":"VAEDecode","pos":[970,70],"size":[210,46],"flags":{},"order":10,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","type":"LATENT","link":132},{"name":"vae","localized_name":"vae","type":"VAE","link":117}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[116],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":58,"type":"SaveAnimatedWEBP","pos":[1280,80],"size":[600,460],"flags":{},"order":12,"mode":0,"inputs":[{"name":"images","localized_name":"images","type":"IMAGE","link":116}],"outputs":[],"properties":{},"widgets_values":["ComfyUI",16,false,90,"default",""]},{"id":39,"type":"VAELoader","pos":[20,250],"size":[330,60],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"VAE","localized_name":"VAE","type":"VAE","links":[102,117],"slot_index":0}],"properties":{"Node name for 
S&R":"VAELoader","models":[{"name":"wan_2.1_vae.safetensors","url":"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors?download=true","directory":"vae"}]},"widgets_values":["wan_2.1_vae.safetensors"],"color":"#322","bgcolor":"#533"},{"id":38,"type":"CLIPLoader","pos":[20,100],"size":[330,100],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"CLIP","localized_name":"CLIP","type":"CLIP","links":[74,75,104],"slot_index":0}],"properties":{"Node name for S&R":"CLIPLoader","models":[{"name":"umt5_xxl_fp8_e4m3fn_scaled.safetensors","url":"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors?download=true","directory":"text_encoders"}]},"widgets_values":["umt5_xxl_fp8_e4m3fn_scaled.safetensors","wan","default"],"color":"#322","bgcolor":"#533"},{"id":61,"type":"UNETLoader","pos":[30,-60],"size":[315,82],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","localized_name":"MODEL","type":"MODEL","links":[128],"slot_index":0}],"properties":{"Node name for S&R":"UNETLoader"},"widgets_values":["wan2.1_t2v_1.3B_fp16.safetensors","default"]},{"id":64,"type":"KSamplerWithNAG","pos":[870,-370],"size":[315,378],"flags":{},"order":9,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":142},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":140},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":141},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":139},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":138}],"outputs":[{"name":"LATENT","localized_name":"LATENT","type":"LATENT","links":[131],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerWithNAG"},"widgets_values":[2025,"fixed",20,8,6,2.5,0.25,0,"uni_pc","simple",1]},{"id":40,"type":"EmptyHunyuanLatentVideo","pos":[30,390],"size":[340,130],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"LATENT","localized_name":"LATENT","type":"LATENT","links":[137,138],"slot_index":0}],"properties":{"Node name for S&R":"EmptyHunyuanLatentVideo"},"widgets_values":[832,480,33,1],"color":"#322","bgcolor":"#533"},{"id":52,"type":"CLIPTextEncode","pos":[450,-280],"size":[340,100],"flags":{},"order":6,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":104}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[135,139],"slot_index":0}],"title":"CLIP Text Encode (NAG Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"],"color":"#323","bgcolor":"#535"},{"id":6,"type":"CLIPTextEncode","pos":[450,90],"size":[340,120],"flags":{},"order":4,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":74}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[133,140],"slot_index":0}],"title":"CLIP Text Encode (Positive Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["a majestic old white-robed wizard casting a spell under a starlit sky, standing on an ancient stone altar in a ruined medieval forest temple, glowing magic symbols, celestial energy swirling around, long silver beard, ornate staff with glowing crystal, cinematic lighting, 
volumetric fog, fantasy atmosphere, ultra detailed, 4K, highly realistic, by greg rutkowski, artgerm, cinematic fantasy, animation of swirling energy, slow motion magical aura forming, glowing runes pulsing, cloak flowing in the wind"],"color":"#232","bgcolor":"#353"},{"id":7,"type":"CLIPTextEncode","pos":[460,250],"size":[340,100],"flags":{},"order":5,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":75}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[134,141],"slot_index":0}],"title":"CLIP Text Encode (Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["low quality, blurry, ugly, poorly drawn hands, deformed face, extra limbs, bad anatomy, low resolution, disfigured, unrealistic, cartoonish, watermark, text, signature, distorted proportions, creepy, glitch, jpeg artifacts\n"],"color":"#323","bgcolor":"#535"},{"id":48,"type":"ModelSamplingSD3","pos":[440,-30],"size":[210,58],"flags":{},"order":7,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":128}],"outputs":[{"name":"MODEL","localized_name":"MODEL","type":"MODEL","links":[136,142],"slot_index":0}],"properties":{"Node name for S&R":"ModelSamplingSD3"},"widgets_values":[5]},{"id":63,"type":"KSamplerWithNAG","pos":[880,180],"size":[315,378],"flags":{},"order":8,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":136},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":133},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":134},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":135},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":137}],"outputs":[{"name":"LATENT","localized_name":"LATENT","type":"LATENT","links":[132],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerWithNAG"},"widgets_values":[2025,"fixed",20,8,1,2.5,0.25,0,"uni_pc","simple",1]}],"links":[[74,38,0,6,0,"CLIP"],[75,38,0,7,0,"CLIP"],[100,51,0,50,0,"IMAGE"],[102,39,0,51,1,"VAE"],[104,38,0,52,0,"CLIP"],[116,56,0,58,0,"IMAGE"],[117,39,0,56,1,"VAE"],[128,61,0,48,0,"MODEL"],[131,64,0,51,0,"LATENT"],[132,63,0,56,0,"LATENT"],[133,6,0,63,1,"CONDITIONING"],[134,7,0,63,2,"CONDITIONING"],[135,52,0,63,3,"CONDITIONING"],[136,48,0,63,0,"MODEL"],[137,40,0,63,4,"LATENT"],[138,40,0,64,4,"LATENT"],[139,52,0,64,3,"CONDITIONING"],[140,6,0,64,1,"CONDITIONING"],[141,7,0,64,2,"CONDITIONING"],[142,48,0,64,0,"MODEL"]],"groups":[],"config":{},"extra":{"ds":{"scale":0.8535456747772172,"offset":[347.8670910095441,619.4181528309838]},"node_versions":{"comfy-core":"0.3.27"},"VHS_latentpreview":false,"VHS_latentpreviewrate":0,"VHS_MetadataImage":true,"VHS_KeepIntermediate":true},"version":0.4} -------------------------------------------------------------------------------- /workflows/NAG-Wan-Fast-ComfyUI-Workflow.json: -------------------------------------------------------------------------------- 1 | 
{"last_node_id":62,"last_link_id":140,"nodes":[{"id":50,"type":"SaveAnimatedWEBP","pos":[1280,-490],"size":[600,460],"flags":{},"order":14,"mode":0,"inputs":[{"name":"images","localized_name":"images","type":"IMAGE","link":100}],"outputs":[],"properties":{},"widgets_values":["ComfyUI",16,false,90,"default",""]},{"id":51,"type":"VAEDecode","pos":[970,-500],"size":[210,46],"flags":{},"order":12,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","type":"LATENT","link":135},{"name":"vae","localized_name":"vae","type":"VAE","link":102}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[100],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":56,"type":"VAEDecode","pos":[970,70],"size":[210,46],"flags":{},"order":11,"mode":0,"inputs":[{"name":"samples","localized_name":"samples","type":"LATENT","link":133},{"name":"vae","localized_name":"vae","type":"VAE","link":117}],"outputs":[{"name":"IMAGE","localized_name":"IMAGE","type":"IMAGE","links":[116],"slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":58,"type":"SaveAnimatedWEBP","pos":[1280,80],"size":[600,460],"flags":{},"order":13,"mode":0,"inputs":[{"name":"images","localized_name":"images","type":"IMAGE","link":116}],"outputs":[],"properties":{},"widgets_values":["ComfyUI",16,false,90,"default",""]},{"id":39,"type":"VAELoader","pos":[20,250],"size":[330,60],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"VAE","localized_name":"VAE","type":"VAE","links":[102,117],"slot_index":0}],"properties":{"Node name for S&R":"VAELoader","models":[{"name":"wan_2.1_vae.safetensors","url":"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors?download=true","directory":"vae"}]},"widgets_values":["wan_2.1_vae.safetensors"],"color":"#322","bgcolor":"#533"},{"id":38,"type":"CLIPLoader","pos":[20,100],"size":[330,100],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"CLIP","localized_name":"CLIP","type":"CLIP","links":[74,75,104,126],"slot_index":0}],"properties":{"Node name for S&R":"CLIPLoader","models":[{"name":"umt5_xxl_fp8_e4m3fn_scaled.safetensors","url":"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors?download=true","directory":"text_encoders"}]},"widgets_values":["umt5_xxl_fp8_e4m3fn_scaled.safetensors","wan","default"],"color":"#322","bgcolor":"#533"},{"id":60,"type":"LoraLoader","pos":[30,-80],"size":[315,126],"flags":{},"order":7,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":124},{"name":"clip","localized_name":"clip","type":"CLIP","link":126}],"outputs":[{"name":"MODEL","localized_name":"MODEL","type":"MODEL","links":[125],"slot_index":0},{"name":"CLIP","localized_name":"CLIP","type":"CLIP","links":null}],"properties":{"Node name for S&R":"LoraLoader"},"widgets_values":["wan/Wan21_CausVid_14B_T2V_lora_rank32_v1_5_no_first_block.safetensors",1,1]},{"id":59,"type":"UnetLoaderGGUF","pos":[30,-200],"size":[315,58],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","localized_name":"MODEL","type":"MODEL","links":[124],"slot_index":0}],"properties":{"Node name for 
S&R":"UnetLoaderGGUF"},"widgets_values":["wan2.1-t2v-14b-Q4_0.gguf"]},{"id":61,"type":"KSamplerWithNAG","pos":[870,200],"size":[315,378],"flags":{},"order":9,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":129},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":132},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":131},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":130},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":134}],"outputs":[{"name":"LATENT","localized_name":"LATENT","type":"LATENT","links":[133],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerWithNAG"},"widgets_values":[2025,"fixed",4,1,1,3.5,0.5,0,"uni_pc","simple",1]},{"id":52,"type":"CLIPTextEncode","pos":[450,-280],"size":[340,100],"flags":{},"order":6,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":104}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[130,136],"slot_index":0}],"title":"CLIP Text Encode (NAG Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"],"color":"#323","bgcolor":"#535"},{"id":7,"type":"CLIPTextEncode","pos":[460,250],"size":[340,100],"flags":{"collapsed":true},"order":5,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":75}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[131,137],"slot_index":0}],"title":"CLIP Text Encode (Negative Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":[""],"color":"#323","bgcolor":"#535"},{"id":6,"type":"CLIPTextEncode","pos":[450,90],"size":[340,120],"flags":{},"order":4,"mode":0,"inputs":[{"name":"clip","localized_name":"clip","type":"CLIP","link":74}],"outputs":[{"name":"CONDITIONING","localized_name":"CONDITIONING","type":"CONDITIONING","links":[132,138],"slot_index":0}],"title":"CLIP Text Encode (Positive Prompt)","properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["Enormous glowing jellyfish float slowly across a sky filled with soft clouds. Their tentacles shimmer with iridescent light as they drift above a peaceful mountain landscape. Magical and dreamlike, captured in a wide shot. 
Surreal realism style with detailed textures."],"color":"#232","bgcolor":"#353"},{"id":48,"type":"ModelSamplingSD3","pos":[440,-30],"size":[210,58],"flags":{},"order":8,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":125}],"outputs":[{"name":"MODEL","localized_name":"MODEL","type":"MODEL","links":[129,139],"slot_index":0}],"properties":{"Node name for S&R":"ModelSamplingSD3"},"widgets_values":[5]},{"id":40,"type":"EmptyHunyuanLatentVideo","pos":[30,390],"size":[340,130],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"LATENT","localized_name":"LATENT","type":"LATENT","links":[134,140],"slot_index":0}],"properties":{"Node name for S&R":"EmptyHunyuanLatentVideo"},"widgets_values":[832,480,81,1],"color":"#322","bgcolor":"#533"},{"id":62,"type":"KSamplerWithNAG","pos":[870,-380],"size":[315,378],"flags":{},"order":10,"mode":0,"inputs":[{"name":"model","localized_name":"model","type":"MODEL","link":139},{"name":"positive","localized_name":"positive","type":"CONDITIONING","link":138},{"name":"negative","localized_name":"negative","type":"CONDITIONING","link":137},{"name":"nag_negative","localized_name":"nag_negative","type":"CONDITIONING","link":136},{"name":"latent_image","localized_name":"latent_image","type":"LATENT","link":140}],"outputs":[{"name":"LATENT","localized_name":"LATENT","type":"LATENT","links":[135],"slot_index":0}],"properties":{"Node name for S&R":"KSamplerWithNAG"},"widgets_values":[2025,"fixed",4,1,11,3.5,0.5,0.75,"uni_pc","simple",1]}],"links":[[74,38,0,6,0,"CLIP"],[75,38,0,7,0,"CLIP"],[100,51,0,50,0,"IMAGE"],[102,39,0,51,1,"VAE"],[104,38,0,52,0,"CLIP"],[116,56,0,58,0,"IMAGE"],[117,39,0,56,1,"VAE"],[124,59,0,60,0,"MODEL"],[125,60,0,48,0,"MODEL"],[126,38,0,60,1,"CLIP"],[129,48,0,61,0,"MODEL"],[130,52,0,61,3,"CONDITIONING"],[131,7,0,61,2,"CONDITIONING"],[132,6,0,61,1,"CONDITIONING"],[133,61,0,56,0,"LATENT"],[134,40,0,61,4,"LATENT"],[135,62,0,51,0,"LATENT"],[136,52,0,62,3,"CONDITIONING"],[137,7,0,62,2,"CONDITIONING"],[138,6,0,62,1,"CONDITIONING"],[139,48,0,62,0,"MODEL"],[140,40,0,62,4,"LATENT"]],"groups":[],"config":{},"extra":{"ds":{"scale":0.8535456747772172,"offset":[456.03038768805357,619.3842640530432]},"node_versions":{"comfy-core":"0.3.27"},"VHS_latentpreview":false,"VHS_latentpreviewrate":0,"VHS_MetadataImage":true,"VHS_KeepIntermediate":true},"version":0.4} --------------------------------------------------------------------------------