├── luts └── put_luts_files_here.txt ├── fonts ├── put_font_files_here.txt └── ShareTechMono-Regular.ttf ├── requirements.txt ├── .gitignore ├── README.md ├── pyproject.toml ├── .github └── workflows │ └── publish.yml ├── LICENSE ├── js ├── DisplayAny.js └── FluxAttentionSeeker.js ├── __init__.py ├── utils.py ├── segmentation.py ├── histogram_matching.py ├── text.py ├── conditioning.py ├── carve.py ├── workflow_all_nodes.json ├── misc.py ├── mask.py └── sampling.py /luts/put_luts_files_here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fonts/put_font_files_here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numba 2 | colour-science 3 | rembg 4 | pixeloe 5 | transparent-background -------------------------------------------------------------------------------- /fonts/ShareTechMono-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubiq/ComfyUI_essentials/HEAD/fonts/ShareTechMono-Regular.ttf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /__pycache__/ 2 | /luts/*.cube 3 | /luts/*.CUBE 4 | /fonts/*.ttf 5 | /fonts/*.otf 6 | !/fonts/ShareTechMono-Regular.ttf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # :wrench: ComfyUI Essentials 2 | 3 | Essential nodes that are weirdly missing from ComfyUI core. With few exceptions they are new features and not commodities. I hope this will be just a temporary repository until the nodes get included in ComfyUI. 4 | 5 | > [!IMPORTANT] 6 | > **2025.04.14** - I do not use ComfyUI as my main way to interact with Gen AI anymore; as a result, I'm setting the repository to "maintenance only" mode. If there are crucial updates or PRs, I might still consider merging them, but I do not plan any consistent work on this repo. 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_essentials" 3 | description = "Essential nodes that are weirdly missing from ComfyUI core. With few exceptions they are new features and not commodities."
4 | version = "1.1.0" 5 | license = { file = "LICENSE" } 6 | dependencies = ["numba", "colour-science", "rembg", "pixeloe"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/cubiq/ComfyUI_essentials" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "matteo" 14 | DisplayName = "ComfyUI_essentials" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Check out code 17 | uses: actions/checkout@v4 18 | - name: Publish Custom Node 19 | uses: Comfy-Org/publish-node-action@main 20 | with: 21 | ## Add your own personal access token to your Github Repository secrets and reference it here. 22 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Matteo Spinelli 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /js/DisplayAny.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../scripts/app.js"; 2 | import { ComfyWidgets } from "../../scripts/widgets.js"; 3 | 4 | app.registerExtension({ 5 | name: "essentials.DisplayAny", 6 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 7 | if (!nodeData?.category?.startsWith("essentials")) { 8 | return; 9 | } 10 | 11 | if (nodeData.name === "DisplayAny") { 12 | const onExecuted = nodeType.prototype.onExecuted; 13 | 14 | nodeType.prototype.onExecuted = function (message) { 15 | onExecuted?.apply(this, arguments); 16 | 17 | if (this.widgets) { 18 | for (let i = 1; i < this.widgets.length; i++) { 19 | this.widgets[i].onRemove?.(); 20 | } 21 | this.widgets.length = 1; 22 | } 23 | 24 | // Check if the "text" widget already exists. 
25 | let textWidget = this.widgets && this.widgets.find(w => w.name === "displaytext"); 26 | if (!textWidget) { 27 | textWidget = ComfyWidgets["STRING"](this, "displaytext", ["STRING", { multiline: true }], app).widget; 28 | textWidget.inputEl.readOnly = true; 29 | textWidget.inputEl.style.border = "none"; 30 | textWidget.inputEl.style.backgroundColor = "transparent"; 31 | } 32 | textWidget.value = message["text"].join(""); 33 | }; 34 | } 35 | }, 36 | }); -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | #from .essentials import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | from .image import IMAGE_CLASS_MAPPINGS, IMAGE_NAME_MAPPINGS 3 | from .mask import MASK_CLASS_MAPPINGS, MASK_NAME_MAPPINGS 4 | from .sampling import SAMPLING_CLASS_MAPPINGS, SAMPLING_NAME_MAPPINGS 5 | from .segmentation import SEG_CLASS_MAPPINGS, SEG_NAME_MAPPINGS 6 | from .misc import MISC_CLASS_MAPPINGS, MISC_NAME_MAPPINGS 7 | from .conditioning import COND_CLASS_MAPPINGS, COND_NAME_MAPPINGS 8 | from .text import TEXT_CLASS_MAPPINGS, TEXT_NAME_MAPPINGS 9 | 10 | WEB_DIRECTORY = "./js" 11 | 12 | NODE_CLASS_MAPPINGS = {} 13 | NODE_DISPLAY_NAME_MAPPINGS = {} 14 | 15 | NODE_CLASS_MAPPINGS.update(COND_CLASS_MAPPINGS) 16 | NODE_DISPLAY_NAME_MAPPINGS.update(COND_NAME_MAPPINGS) 17 | 18 | NODE_CLASS_MAPPINGS.update(IMAGE_CLASS_MAPPINGS) 19 | NODE_DISPLAY_NAME_MAPPINGS.update(IMAGE_NAME_MAPPINGS) 20 | 21 | NODE_CLASS_MAPPINGS.update(MASK_CLASS_MAPPINGS) 22 | NODE_DISPLAY_NAME_MAPPINGS.update(MASK_NAME_MAPPINGS) 23 | 24 | NODE_CLASS_MAPPINGS.update(SAMPLING_CLASS_MAPPINGS) 25 | NODE_DISPLAY_NAME_MAPPINGS.update(SAMPLING_NAME_MAPPINGS) 26 | 27 | NODE_CLASS_MAPPINGS.update(SEG_CLASS_MAPPINGS) 28 | NODE_DISPLAY_NAME_MAPPINGS.update(SEG_NAME_MAPPINGS) 29 | 30 | NODE_CLASS_MAPPINGS.update(TEXT_CLASS_MAPPINGS) 31 | NODE_DISPLAY_NAME_MAPPINGS.update(TEXT_NAME_MAPPINGS) 32 | 33 | NODE_CLASS_MAPPINGS.update(MISC_CLASS_MAPPINGS) 34 | NODE_DISPLAY_NAME_MAPPINGS.update(MISC_NAME_MAPPINGS) 35 | 36 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', "WEB_DIRECTORY"] 37 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import scipy 4 | import os 5 | #import re 6 | from pathlib import Path 7 | import folder_paths 8 | 9 | FONTS_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "fonts") 10 | 11 | SCRIPT_DIR = Path(__file__).parent 12 | folder_paths.add_model_folder_path("luts", (SCRIPT_DIR / "luts").as_posix()) 13 | folder_paths.add_model_folder_path( 14 | "luts", (Path(folder_paths.models_dir) / "luts").as_posix() 15 | ) 16 | 17 | # from https://github.com/pythongosssss/ComfyUI-Custom-Scripts 18 | class AnyType(str): 19 | def __ne__(self, __value: object) -> bool: 20 | return False 21 | 22 | def min_(tensor_list): 23 | # return the element-wise min of the tensor list. 24 | x = torch.stack(tensor_list) 25 | mn = x.min(axis=0)[0] 26 | return torch.clamp(mn, min=0) 27 | 28 | def max_(tensor_list): 29 | # return the element-wise max of the tensor list. 
30 | x = torch.stack(tensor_list) 31 | mx = x.max(axis=0)[0] 32 | return torch.clamp(mx, max=1) 33 | 34 | def expand_mask(mask, expand, tapered_corners): 35 | c = 0 if tapered_corners else 1 36 | kernel = np.array([[c, 1, c], 37 | [1, 1, 1], 38 | [c, 1, c]]) 39 | mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) 40 | out = [] 41 | for m in mask: 42 | output = m.numpy() 43 | for _ in range(abs(expand)): 44 | if expand < 0: 45 | output = scipy.ndimage.grey_erosion(output, footprint=kernel) 46 | else: 47 | output = scipy.ndimage.grey_dilation(output, footprint=kernel) 48 | output = torch.from_numpy(output) 49 | out.append(output) 50 | 51 | return torch.stack(out, dim=0) 52 | 53 | def parse_string_to_list(s): 54 | elements = s.split(',') 55 | result = [] 56 | 57 | def parse_number(s): 58 | try: 59 | if '.' in s: 60 | return float(s) 61 | else: 62 | return int(s) 63 | except ValueError: 64 | return 0 65 | 66 | def decimal_places(s): 67 | if '.' in s: 68 | return len(s.split('.')[1]) 69 | return 0 70 | 71 | for element in elements: 72 | element = element.strip() 73 | if '...' in element: 74 | start, rest = element.split('...') 75 | end, step = rest.split('+') 76 | decimals = decimal_places(step) 77 | start = parse_number(start) 78 | end = parse_number(end) 79 | step = parse_number(step) 80 | current = start 81 | if (start > end and step > 0) or (start < end and step < 0): 82 | step = -step 83 | while current <= end: 84 | result.append(round(current, decimals)) 85 | current += step 86 | else: 87 | result.append(round(parse_number(element), decimal_places(element))) 88 | 89 | return result -------------------------------------------------------------------------------- /segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.v2 as T 3 | import torch.nn.functional as F 4 | from .utils import expand_mask 5 | 6 | class LoadCLIPSegModels: 7 | @classmethod 8 | def INPUT_TYPES(s): 9 | return { 10 | "required": {}, 11 | } 12 | 13 | RETURN_TYPES = ("CLIP_SEG",) 14 | FUNCTION = "execute" 15 | CATEGORY = "essentials/segmentation" 16 | 17 | def execute(self): 18 | from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation 19 | processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") 20 | model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined") 21 | 22 | return ((processor, model),) 23 | 24 | class ApplyCLIPSeg: 25 | @classmethod 26 | def INPUT_TYPES(s): 27 | return { 28 | "required": { 29 | "clip_seg": ("CLIP_SEG",), 30 | "image": ("IMAGE",), 31 | "prompt": ("STRING", { "multiline": False, "default": "" }), 32 | "threshold": ("FLOAT", { "default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05 }), 33 | "smooth": ("INT", { "default": 9, "min": 0, "max": 32, "step": 1 }), 34 | "dilate": ("INT", { "default": 0, "min": -32, "max": 32, "step": 1 }), 35 | "blur": ("INT", { "default": 0, "min": 0, "max": 64, "step": 1 }), 36 | }, 37 | } 38 | 39 | RETURN_TYPES = ("MASK",) 40 | FUNCTION = "execute" 41 | CATEGORY = "essentials/segmentation" 42 | 43 | def execute(self, image, clip_seg, prompt, threshold, smooth, dilate, blur): 44 | processor, model = clip_seg 45 | 46 | imagenp = image.mul(255).clamp(0, 255).byte().cpu().numpy() 47 | 48 | outputs = [] 49 | for i in imagenp: 50 | inputs = processor(text=prompt, images=[i], return_tensors="pt") 51 | out = model(**inputs) 52 | out = out.logits.unsqueeze(1) 53 | out = torch.sigmoid(out[0][0]) 54 | out = (out > threshold) 55 | 
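# CLIPSeg returns one low-resolution logit map per image: sigmoid converts the logits to probabilities and the threshold turns them into a binary mask (it is interpolated back to the input resolution further below).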
outputs.append(out) 56 | 57 | del imagenp 58 | 59 | outputs = torch.stack(outputs, dim=0) 60 | 61 | if smooth > 0: 62 | if smooth % 2 == 0: 63 | smooth += 1 64 | outputs = T.functional.gaussian_blur(outputs, smooth) 65 | 66 | outputs = outputs.float() 67 | 68 | if dilate != 0: 69 | outputs = expand_mask(outputs, dilate, True) 70 | 71 | if blur > 0: 72 | if blur % 2 == 0: 73 | blur += 1 74 | outputs = T.functional.gaussian_blur(outputs, blur) 75 | 76 | # resize to original size 77 | outputs = F.interpolate(outputs.unsqueeze(1), size=(image.shape[1], image.shape[2]), mode='bicubic').squeeze(1) 78 | 79 | return (outputs,) 80 | 81 | SEG_CLASS_MAPPINGS = { 82 | "ApplyCLIPSeg+": ApplyCLIPSeg, 83 | "LoadCLIPSegModels+": LoadCLIPSegModels, 84 | } 85 | 86 | SEG_NAME_MAPPINGS = { 87 | "ApplyCLIPSeg+": "🔧 Apply CLIPSeg", 88 | "LoadCLIPSegModels+": "🔧 Load CLIPSeg Models", 89 | } -------------------------------------------------------------------------------- /histogram_matching.py: -------------------------------------------------------------------------------- 1 | # from MIT licensed https://github.com/nemodleo/pytorch-histogram-matching 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class Histogram_Matching(nn.Module): 7 | def __init__(self, differentiable=False): 8 | super(Histogram_Matching, self).__init__() 9 | self.differentiable = differentiable 10 | 11 | def forward(self, dst, ref): 12 | # B C 13 | B, C, H, W = dst.size() 14 | # assertion 15 | assert dst.device == ref.device 16 | # [B*C 256] 17 | hist_dst = self.cal_hist(dst) 18 | hist_ref = self.cal_hist(ref) 19 | # [B*C 256] 20 | tables = self.cal_trans_batch(hist_dst, hist_ref) 21 | # [B C H W] 22 | rst = dst.clone() 23 | for b in range(B): 24 | for c in range(C): 25 | rst[b,c] = tables[b*c, (dst[b,c] * 255).long()] 26 | # [B C H W] 27 | rst /= 255. 
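# Each destination pixel (scaled to 0-255) is looked up in the per-channel transfer table built from the cumulative histograms (classic CDF matching); dividing by 255 brings the values back into the 0-1 range.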
28 | return rst 29 | 30 | def cal_hist(self, img): 31 | B, C, H, W = img.size() 32 | # [B*C 256] 33 | if self.differentiable: 34 | hists = self.soft_histc_batch(img * 255, bins=256, min=0, max=256, sigma=3*25) 35 | else: 36 | hists = torch.stack([torch.histc(img[b,c] * 255, bins=256, min=0, max=255) for b in range(B) for c in range(C)]) 37 | hists = hists.float() 38 | hists = F.normalize(hists, p=1) 39 | # BC 256 40 | bc, n = hists.size() 41 | # [B*C 256 256] 42 | triu = torch.ones(bc, n, n, device=hists.device).triu() 43 | # [B*C 256] 44 | hists = torch.bmm(hists[:,None,:], triu)[:,0,:] 45 | return hists 46 | 47 | def soft_histc_batch(self, x, bins=256, min=0, max=256, sigma=3*25): 48 | # B C H W 49 | B, C, H, W = x.size() 50 | # [B*C H*W] 51 | x = x.view(B*C, -1) 52 | # 1 53 | delta = float(max - min) / float(bins) 54 | # [256] 55 | centers = float(min) + delta * (torch.arange(bins, device=x.device, dtype=torch.bfloat16) + 0.5) 56 | # [B*C 1 H*W] 57 | x = torch.unsqueeze(x, 1) 58 | # [1 256 1] 59 | centers = centers[None,:,None] 60 | # [B*C 256 H*W] 61 | x = x - centers 62 | # [B*C 256 H*W] 63 | x = x.type(torch.bfloat16) 64 | # [B*C 256 H*W] 65 | x = torch.sigmoid(sigma * (x + delta/2)) - torch.sigmoid(sigma * (x - delta/2)) 66 | # [B*C 256] 67 | x = x.sum(dim=2) 68 | # [B*C 256] 69 | x = x.type(torch.float32) 70 | # prevent oom 71 | # torch.cuda.empty_cache() 72 | return x 73 | 74 | def cal_trans_batch(self, hist_dst, hist_ref): 75 | # [B*C 256 256] 76 | hist_dst = hist_dst[:,None,:].repeat(1,256,1) 77 | # [B*C 256 256] 78 | hist_ref = hist_ref[:,:,None].repeat(1,1,256) 79 | # [B*C 256 256] 80 | table = hist_dst - hist_ref 81 | # [B*C 256 256] 82 | table = torch.where(table>=0, 1., 0.) 83 | # [B*C 256] 84 | table = torch.sum(table, dim=1) - 1 85 | # [B*C 256] 86 | table = torch.clamp(table, min=0, max=255) 87 | return table 88 | -------------------------------------------------------------------------------- /text.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from nodes import MAX_RESOLUTION 4 | import torchvision.transforms.v2 as T 5 | from .utils import FONTS_DIR 6 | 7 | class DrawText: 8 | @classmethod 9 | def INPUT_TYPES(s): 10 | return { 11 | "required": { 12 | "text": ("STRING", { "multiline": True, "dynamicPrompts": True, "default": "Hello, World!" 
}), 13 | "font": (sorted([f for f in os.listdir(FONTS_DIR) if f.endswith('.ttf') or f.endswith('.otf')]), ), 14 | "size": ("INT", { "default": 56, "min": 1, "max": 9999, "step": 1 }), 15 | "color": ("STRING", { "multiline": False, "default": "#FFFFFF" }), 16 | "background_color": ("STRING", { "multiline": False, "default": "#00000000" }), 17 | "shadow_distance": ("INT", { "default": 0, "min": 0, "max": 100, "step": 1 }), 18 | "shadow_blur": ("INT", { "default": 0, "min": 0, "max": 100, "step": 1 }), 19 | "shadow_color": ("STRING", { "multiline": False, "default": "#000000" }), 20 | "horizontal_align": (["left", "center", "right"],), 21 | "vertical_align": (["top", "center", "bottom"],), 22 | "offset_x": ("INT", { "default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1 }), 23 | "offset_y": ("INT", { "default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1 }), 24 | "direction": (["ltr", "rtl"],), 25 | }, 26 | "optional": { 27 | "img_composite": ("IMAGE",), 28 | }, 29 | } 30 | 31 | RETURN_TYPES = ("IMAGE", "MASK",) 32 | FUNCTION = "execute" 33 | CATEGORY = "essentials/text" 34 | 35 | def execute(self, text, font, size, color, background_color, shadow_distance, shadow_blur, shadow_color, horizontal_align, vertical_align, offset_x, offset_y, direction, img_composite=None): 36 | from PIL import Image, ImageDraw, ImageFont, ImageColor, ImageFilter 37 | 38 | font = ImageFont.truetype(os.path.join(FONTS_DIR, font), size) 39 | 40 | lines = text.split("\n") 41 | if direction == "rtl": 42 | lines = [line[::-1] for line in lines] 43 | 44 | # Calculate the width and height of the text 45 | text_width = max(font.getbbox(line)[2] for line in lines) 46 | line_height = font.getmask(text).getbbox()[3] + font.getmetrics()[1] # add descent to height 47 | text_height = line_height * len(lines) 48 | 49 | if img_composite is not None: 50 | img_composite = T.ToPILImage()(img_composite.permute([0,3,1,2])[0]).convert('RGBA') 51 | width = img_composite.width 52 | height = img_composite.height 53 | image = Image.new('RGBA', (width, height), color=background_color) 54 | else: 55 | width = text_width 56 | height = text_height 57 | background_color = ImageColor.getrgb(background_color) 58 | image = Image.new('RGBA', (width + shadow_distance, height + shadow_distance), color=background_color) 59 | 60 | image_shadow = None 61 | if shadow_distance > 0: 62 | image_shadow = image.copy() 63 | #image_shadow = Image.new('RGBA', (width + shadow_distance, height + shadow_distance), color=background_color) 64 | 65 | for i, line in enumerate(lines): 66 | line_width = font.getbbox(line)[2] 67 | #text_height =font.getbbox(line)[3] 68 | if horizontal_align == "left": 69 | x = 0 70 | elif horizontal_align == "center": 71 | x = (width - line_width) / 2 72 | elif horizontal_align == "right": 73 | x = width - line_width 74 | 75 | if vertical_align == "top": 76 | y = 0 77 | elif vertical_align == "center": 78 | y = (height - text_height) / 2 79 | elif vertical_align == "bottom": 80 | y = height - text_height 81 | 82 | x += offset_x 83 | y += i * line_height + offset_y 84 | 85 | draw = ImageDraw.Draw(image) 86 | draw.text((x, y), line, font=font, fill=color) 87 | 88 | if image_shadow is not None: 89 | draw = ImageDraw.Draw(image_shadow) 90 | draw.text((x + shadow_distance, y + shadow_distance), line, font=font, fill=shadow_color) 91 | 92 | if image_shadow is not None: 93 | image_shadow = image_shadow.filter(ImageFilter.GaussianBlur(shadow_blur)) 94 | image = Image.alpha_composite(image_shadow, image) 95 | 96 | 
#image = T.ToTensor()(image).unsqueeze(0).permute([0,2,3,1]) 97 | mask = T.ToTensor()(image).unsqueeze(0).permute([0,2,3,1]) 98 | mask = mask[:, :, :, 3] if mask.shape[3] == 4 else torch.ones_like(mask[:, :, :, 0]) 99 | 100 | if img_composite is not None: 101 | image = Image.alpha_composite(img_composite, image) 102 | 103 | image = T.ToTensor()(image).unsqueeze(0).permute([0,2,3,1]) 104 | 105 | return (image[:, :, :, :3], mask,) 106 | 107 | TEXT_CLASS_MAPPINGS = { 108 | "DrawText+": DrawText, 109 | } 110 | 111 | TEXT_NAME_MAPPINGS = { 112 | "DrawText+": "🔧 Draw Text", 113 | } -------------------------------------------------------------------------------- /js/FluxAttentionSeeker.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../scripts/app.js"; 2 | 3 | app.registerExtension({ 4 | name: "essentials.FluxAttentionSeeker", 5 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 6 | if (!nodeData?.category?.startsWith("essentials")) { 7 | return; 8 | } 9 | 10 | if (nodeData.name === "FluxAttentionSeeker+") { 11 | const onCreated = nodeType.prototype.onNodeCreated; 12 | 13 | nodeType.prototype.onNodeCreated = function () { 14 | this.addWidget("button", "RESET ALL", null, () => { 15 | this.widgets.forEach(w => { 16 | if (w.type === "slider") { 17 | w.value = 1.0; 18 | } 19 | }); 20 | }); 21 | 22 | this.addWidget("button", "ZERO ALL", null, () => { 23 | this.widgets.forEach(w => { 24 | if (w.type === "slider") { 25 | w.value = 0.0; 26 | } 27 | }); 28 | }); 29 | 30 | this.addWidget("button", "REPEAT FIRST", null, () => { 31 | var clip_value = undefined; 32 | var t5_value = undefined; 33 | this.widgets.forEach(w => { 34 | if (w.name.startsWith('clip_l')) { 35 | if (clip_value === undefined) { 36 | clip_value = w.value; 37 | } 38 | w.value = clip_value; 39 | } else if (w.name.startsWith('t5')) { 40 | if (t5_value === undefined) { 41 | t5_value = w.value; 42 | } 43 | w.value = t5_value; 44 | } 45 | }); 46 | }); 47 | }; 48 | } 49 | }, 50 | }); 51 | 52 | app.registerExtension({ 53 | name: "essentials.SD3AttentionSeekerLG", 54 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 55 | if (!nodeData?.category?.startsWith("essentials")) { 56 | return; 57 | } 58 | 59 | if (nodeData.name === "SD3AttentionSeekerLG+") { 60 | const onCreated = nodeType.prototype.onNodeCreated; 61 | 62 | nodeType.prototype.onNodeCreated = function () { 63 | this.addWidget("button", "RESET L", null, () => { 64 | this.widgets.forEach(w => { 65 | if (w.type === "slider" && w.name.startsWith('clip_l')) { 66 | w.value = 1.0; 67 | } 68 | }); 69 | }); 70 | this.addWidget("button", "RESET G", null, () => { 71 | this.widgets.forEach(w => { 72 | if (w.type === "slider" && w.name.startsWith('clip_g')) { 73 | w.value = 1.0; 74 | } 75 | }); 76 | }); 77 | 78 | this.addWidget("button", "REPEAT FIRST", null, () => { 79 | var clip_l_value = undefined; 80 | var clip_g_value = undefined; 81 | this.widgets.forEach(w => { 82 | if (w.name.startsWith('clip_l')) { 83 | if (clip_l_value === undefined) { 84 | clip_l_value = w.value; 85 | } 86 | w.value = clip_l_value; 87 | } else if (w.name.startsWith('clip_g')) { 88 | if (clip_g_value === undefined) { 89 | clip_g_value = w.value; 90 | } 91 | w.value = clip_g_value; 92 | } 93 | }); 94 | }); 95 | }; 96 | } 97 | }, 98 | }); 99 | 100 | app.registerExtension({ 101 | name: "essentials.SD3AttentionSeekerT5", 102 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 103 | if (!nodeData?.category?.startsWith("essentials")) { 104 | return; 
105 | } 106 | 107 | if (nodeData.name === "SD3AttentionSeekerT5+") { 108 | const onCreated = nodeType.prototype.onNodeCreated; 109 | 110 | nodeType.prototype.onNodeCreated = function () { 111 | this.addWidget("button", "RESET ALL", null, () => { 112 | this.widgets.forEach(w => { 113 | if (w.type === "slider") { 114 | w.value = 1.0; 115 | } 116 | }); 117 | }); 118 | 119 | this.addWidget("button", "REPEAT FIRST", null, () => { 120 | var t5_value = undefined; 121 | this.widgets.forEach(w => { 122 | if (w.name.startsWith('t5')) { 123 | if (t5_value === undefined) { 124 | t5_value = w.value; 125 | } 126 | w.value = t5_value; 127 | } 128 | }); 129 | }); 130 | }; 131 | } 132 | }, 133 | }); -------------------------------------------------------------------------------- /conditioning.py: -------------------------------------------------------------------------------- 1 | from nodes import MAX_RESOLUTION, ConditioningZeroOut, ConditioningSetTimestepRange, ConditioningCombine 2 | import re 3 | 4 | class CLIPTextEncodeSDXLSimplified: 5 | @classmethod 6 | def INPUT_TYPES(s): 7 | return {"required": { 8 | "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), 9 | "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), 10 | "size_cond_factor": ("INT", {"default": 4, "min": 1, "max": 16 }), 11 | "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": ""}), 12 | "clip": ("CLIP", ), 13 | }} 14 | RETURN_TYPES = ("CONDITIONING",) 15 | FUNCTION = "execute" 16 | CATEGORY = "essentials/conditioning" 17 | 18 | def execute(self, clip, width, height, size_cond_factor, text): 19 | crop_w = 0 20 | crop_h = 0 21 | width = width*size_cond_factor 22 | height = height*size_cond_factor 23 | target_width = width 24 | target_height = height 25 | text_g = text_l = text 26 | 27 | tokens = clip.tokenize(text_g) 28 | tokens["l"] = clip.tokenize(text_l)["l"] 29 | if len(tokens["l"]) != len(tokens["g"]): 30 | empty = clip.tokenize("") 31 | while len(tokens["l"]) < len(tokens["g"]): 32 | tokens["l"] += empty["l"] 33 | while len(tokens["l"]) > len(tokens["g"]): 34 | tokens["g"] += empty["g"] 35 | cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) 36 | return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], ) 37 | 38 | class ConditioningCombineMultiple: 39 | @classmethod 40 | def INPUT_TYPES(s): 41 | return { 42 | "required": { 43 | "conditioning_1": ("CONDITIONING",), 44 | "conditioning_2": ("CONDITIONING",), 45 | }, "optional": { 46 | "conditioning_3": ("CONDITIONING",), 47 | "conditioning_4": ("CONDITIONING",), 48 | "conditioning_5": ("CONDITIONING",), 49 | }, 50 | } 51 | RETURN_TYPES = ("CONDITIONING",) 52 | FUNCTION = "execute" 53 | CATEGORY = "essentials/conditioning" 54 | 55 | def execute(self, conditioning_1, conditioning_2, conditioning_3=None, conditioning_4=None, conditioning_5=None): 56 | c = conditioning_1 + conditioning_2 57 | 58 | if conditioning_3 is not None: 59 | c += conditioning_3 60 | if conditioning_4 is not None: 61 | c += conditioning_4 62 | if conditioning_5 is not None: 63 | c += conditioning_5 64 | 65 | return (c,) 66 | 67 | class SD3NegativeConditioning: 68 | @classmethod 69 | def INPUT_TYPES(s): 70 | return {"required": { 71 | "conditioning": ("CONDITIONING",), 72 | "end": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.001 }), 73 | }} 74 | RETURN_TYPES = ("CONDITIONING",) 75 | FUNCTION = "execute" 76 | 
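    # Applies the real negative conditioning only for the first `end` fraction of the sampling range and a zeroed-out conditioning for the remainder; with end == 0 the negative prompt is zeroed out entirely.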
CATEGORY = "essentials/conditioning" 77 | 78 | def execute(self, conditioning, end): 79 | zero_c = ConditioningZeroOut().zero_out(conditioning)[0] 80 | 81 | if end == 0: 82 | return (zero_c, ) 83 | 84 | c = ConditioningSetTimestepRange().set_range(conditioning, 0, end)[0] 85 | zero_c = ConditioningSetTimestepRange().set_range(zero_c, end, 1.0)[0] 86 | c = ConditioningCombine().combine(zero_c, c)[0] 87 | 88 | return (c, ) 89 | 90 | class FluxAttentionSeeker: 91 | @classmethod 92 | def INPUT_TYPES(s): 93 | return {"required": { 94 | "clip": ("CLIP",), 95 | "apply_to_query": ("BOOLEAN", { "default": True }), 96 | "apply_to_key": ("BOOLEAN", { "default": True }), 97 | "apply_to_value": ("BOOLEAN", { "default": True }), 98 | "apply_to_out": ("BOOLEAN", { "default": True }), 99 | **{f"clip_l_{s}": ("FLOAT", { "display": "slider", "default": 1.0, "min": 0, "max": 5, "step": 0.05 }) for s in range(12)}, 100 | **{f"t5xxl_{s}": ("FLOAT", { "display": "slider", "default": 1.0, "min": 0, "max": 5, "step": 0.05 }) for s in range(24)}, 101 | }} 102 | 103 | RETURN_TYPES = ("CLIP",) 104 | FUNCTION = "execute" 105 | 106 | CATEGORY = "essentials/conditioning" 107 | 108 | def execute(self, clip, apply_to_query, apply_to_key, apply_to_value, apply_to_out, **values): 109 | if not apply_to_key and not apply_to_query and not apply_to_value and not apply_to_out: 110 | return (clip, ) 111 | 112 | m = clip.clone() 113 | sd = m.patcher.model_state_dict() 114 | 115 | for k in sd: 116 | if "self_attn" in k: 117 | layer = re.search(r"\.layers\.(\d+)\.", k) 118 | layer = int(layer.group(1)) if layer else None 119 | 120 | if layer is not None and values[f"clip_l_{layer}"] != 1.0: 121 | if (apply_to_query and "q_proj" in k) or (apply_to_key and "k_proj" in k) or (apply_to_value and "v_proj" in k) or (apply_to_out and "out_proj" in k): 122 | m.add_patches({k: (None,)}, 0.0, values[f"clip_l_{layer}"]) 123 | elif "SelfAttention" in k: 124 | block = re.search(r"\.block\.(\d+)\.", k) 125 | block = int(block.group(1)) if block else None 126 | 127 | if block is not None and values[f"t5xxl_{block}"] != 1.0: 128 | if (apply_to_query and ".q." in k) or (apply_to_key and ".k." in k) or (apply_to_value and ".v." in k) or (apply_to_out and ".o." 
in k): 129 | m.add_patches({k: (None,)}, 0.0, values[f"t5xxl_{block}"]) 130 | 131 | return (m, ) 132 | 133 | class SD3AttentionSeekerLG: 134 | @classmethod 135 | def INPUT_TYPES(s): 136 | return {"required": { 137 | "clip": ("CLIP",), 138 | "apply_to_query": ("BOOLEAN", { "default": True }), 139 | "apply_to_key": ("BOOLEAN", { "default": True }), 140 | "apply_to_value": ("BOOLEAN", { "default": True }), 141 | "apply_to_out": ("BOOLEAN", { "default": True }), 142 | **{f"clip_l_{s}": ("FLOAT", { "display": "slider", "default": 1.0, "min": 0, "max": 5, "step": 0.05 }) for s in range(12)}, 143 | **{f"clip_g_{s}": ("FLOAT", { "display": "slider", "default": 1.0, "min": 0, "max": 5, "step": 0.05 }) for s in range(32)}, 144 | }} 145 | 146 | RETURN_TYPES = ("CLIP",) 147 | FUNCTION = "execute" 148 | 149 | CATEGORY = "essentials/conditioning" 150 | 151 | def execute(self, clip, apply_to_query, apply_to_key, apply_to_value, apply_to_out, **values): 152 | if not apply_to_key and not apply_to_query and not apply_to_value and not apply_to_out: 153 | return (clip, ) 154 | 155 | m = clip.clone() 156 | sd = m.patcher.model_state_dict() 157 | 158 | for k in sd: 159 | if "self_attn" in k: 160 | layer = re.search(r"\.layers\.(\d+)\.", k) 161 | layer = int(layer.group(1)) if layer else None 162 | 163 | if layer is not None: 164 | if "clip_l" in k and values[f"clip_l_{layer}"] != 1.0: 165 | if (apply_to_query and "q_proj" in k) or (apply_to_key and "k_proj" in k) or (apply_to_value and "v_proj" in k) or (apply_to_out and "out_proj" in k): 166 | m.add_patches({k: (None,)}, 0.0, values[f"clip_l_{layer}"]) 167 | elif "clip_g" in k and values[f"clip_g_{layer}"] != 1.0: 168 | if (apply_to_query and "q_proj" in k) or (apply_to_key and "k_proj" in k) or (apply_to_value and "v_proj" in k) or (apply_to_out and "out_proj" in k): 169 | m.add_patches({k: (None,)}, 0.0, values[f"clip_g_{layer}"]) 170 | 171 | return (m, ) 172 | 173 | class SD3AttentionSeekerT5: 174 | @classmethod 175 | def INPUT_TYPES(s): 176 | return {"required": { 177 | "clip": ("CLIP",), 178 | "apply_to_query": ("BOOLEAN", { "default": True }), 179 | "apply_to_key": ("BOOLEAN", { "default": True }), 180 | "apply_to_value": ("BOOLEAN", { "default": True }), 181 | "apply_to_out": ("BOOLEAN", { "default": True }), 182 | **{f"t5xxl_{s}": ("FLOAT", { "display": "slider", "default": 1.0, "min": 0, "max": 5, "step": 0.05 }) for s in range(24)}, 183 | }} 184 | 185 | RETURN_TYPES = ("CLIP",) 186 | FUNCTION = "execute" 187 | 188 | CATEGORY = "essentials/conditioning" 189 | 190 | def execute(self, clip, apply_to_query, apply_to_key, apply_to_value, apply_to_out, **values): 191 | if not apply_to_key and not apply_to_query and not apply_to_value and not apply_to_out: 192 | return (clip, ) 193 | 194 | m = clip.clone() 195 | sd = m.patcher.model_state_dict() 196 | 197 | for k in sd: 198 | if "SelfAttention" in k: 199 | block = re.search(r"\.block\.(\d+)\.", k) 200 | block = int(block.group(1)) if block else None 201 | 202 | if block is not None and values[f"t5xxl_{block}"] != 1.0: 203 | if (apply_to_query and ".q." in k) or (apply_to_key and ".k." in k) or (apply_to_value and ".v." in k) or (apply_to_out and ".o." 
in k): 204 | m.add_patches({k: (None,)}, 0.0, values[f"t5xxl_{block}"]) 205 | 206 | return (m, ) 207 | 208 | class FluxBlocksBuster: 209 | @classmethod 210 | def INPUT_TYPES(s): 211 | return {"required": { 212 | "model": ("MODEL",), 213 | "blocks": ("STRING", {"default": "## 0 = 1.0\n## 1 = 1.0\n## 2 = 1.0\n## 3 = 1.0\n## 4 = 1.0\n## 5 = 1.0\n## 6 = 1.0\n## 7 = 1.0\n## 8 = 1.0\n## 9 = 1.0\n## 10 = 1.0\n## 11 = 1.0\n## 12 = 1.0\n## 13 = 1.0\n## 14 = 1.0\n## 15 = 1.0\n## 16 = 1.0\n## 17 = 1.0\n## 18 = 1.0\n# 0 = 1.0\n# 1 = 1.0\n# 2 = 1.0\n# 3 = 1.0\n# 4 = 1.0\n# 5 = 1.0\n# 6 = 1.0\n# 7 = 1.0\n# 8 = 1.0\n# 9 = 1.0\n# 10 = 1.0\n# 11 = 1.0\n# 12 = 1.0\n# 13 = 1.0\n# 14 = 1.0\n# 15 = 1.0\n# 16 = 1.0\n# 17 = 1.0\n# 18 = 1.0\n# 19 = 1.0\n# 20 = 1.0\n# 21 = 1.0\n# 22 = 1.0\n# 23 = 1.0\n# 24 = 1.0\n# 25 = 1.0\n# 26 = 1.0\n# 27 = 1.0\n# 28 = 1.0\n# 29 = 1.0\n# 30 = 1.0\n# 31 = 1.0\n# 32 = 1.0\n# 33 = 1.0\n# 34 = 1.0\n# 35 = 1.0\n# 36 = 1.0\n# 37 = 1.0", "multiline": True, "dynamicPrompts": True}), 214 | #**{f"double_block_{s}": ("FLOAT", { "display": "slider", "default": 1.0, "min": 0, "max": 5, "step": 0.05 }) for s in range(19)}, 215 | #**{f"single_block_{s}": ("FLOAT", { "display": "slider", "default": 1.0, "min": 0, "max": 5, "step": 0.05 }) for s in range(38)}, 216 | }} 217 | RETURN_TYPES = ("MODEL", "STRING") 218 | RETURN_NAMES = ("MODEL", "patched_blocks") 219 | FUNCTION = "patch" 220 | 221 | CATEGORY = "essentials/conditioning" 222 | 223 | def patch(self, model, blocks): 224 | if blocks == "": 225 | return (model, ) 226 | 227 | m = model.clone() 228 | sd = model.model_state_dict() 229 | patched_blocks = [] 230 | 231 | """ 232 | Also compatible with the following format: 233 | 234 | double_blocks\.0\.(img|txt)_(mod|attn|mlp)\.(lin|qkv|proj|0|2)\.(weight|bias)=1.1 235 | single_blocks\.0\.(linear[12]|modulation\.lin)\.(weight|bias)=1.1 236 | 237 | The regex is used to match the block names 238 | """ 239 | 240 | blocks = blocks.split("\n") 241 | blocks = [b.strip() for b in blocks if b.strip()] 242 | 243 | for k in sd: 244 | for block in blocks: 245 | block = block.split("=") 246 | value = float(block[1].strip()) if len(block) > 1 else 1.0 247 | block = block[0].strip() 248 | if block.startswith("##"): 249 | block = r"double_blocks\." + block[2:].strip() + r"\.(img|txt)_(mod|attn|mlp)\.(lin|qkv|proj|0|2)\.(weight|bias)" 250 | elif block.startswith("#"): 251 | block = r"single_blocks\." 
+ block[1:].strip() + r"\.(linear[12]|modulation\.lin)\.(weight|bias)" 252 | 253 | if value != 1.0 and re.search(block, k): 254 | m.add_patches({k: (None,)}, 0.0, value) 255 | patched_blocks.append(f"{k}: {value}") 256 | 257 | patched_blocks = "\n".join(patched_blocks) 258 | 259 | return (m, patched_blocks,) 260 | 261 | 262 | COND_CLASS_MAPPINGS = { 263 | "CLIPTextEncodeSDXL+": CLIPTextEncodeSDXLSimplified, 264 | "ConditioningCombineMultiple+": ConditioningCombineMultiple, 265 | "SD3NegativeConditioning+": SD3NegativeConditioning, 266 | "FluxAttentionSeeker+": FluxAttentionSeeker, 267 | "SD3AttentionSeekerLG+": SD3AttentionSeekerLG, 268 | "SD3AttentionSeekerT5+": SD3AttentionSeekerT5, 269 | "FluxBlocksBuster+": FluxBlocksBuster, 270 | } 271 | 272 | COND_NAME_MAPPINGS = { 273 | "CLIPTextEncodeSDXL+": "🔧 SDXL CLIPTextEncode", 274 | "ConditioningCombineMultiple+": "🔧 Cond Combine Multiple", 275 | "SD3NegativeConditioning+": "🔧 SD3 Negative Conditioning", 276 | "FluxAttentionSeeker+": "🔧 Flux Attention Seeker", 277 | "SD3AttentionSeekerLG+": "🔧 SD3 Attention Seeker L/G", 278 | "SD3AttentionSeekerT5+": "🔧 SD3 Attention Seeker T5", 279 | "FluxBlocksBuster+": "🔧 Flux Model Blocks Buster", 280 | } -------------------------------------------------------------------------------- /carve.py: -------------------------------------------------------------------------------- 1 | # MIT licensed code from https://github.com/li-plus/seam-carving/ 2 | 3 | from enum import Enum 4 | from typing import Optional, Tuple 5 | 6 | import numba as nb 7 | import numpy as np 8 | from scipy.ndimage import sobel 9 | 10 | DROP_MASK_ENERGY = 1e5 11 | KEEP_MASK_ENERGY = 1e3 12 | 13 | 14 | class OrderMode(str, Enum): 15 | WIDTH_FIRST = "width-first" 16 | HEIGHT_FIRST = "height-first" 17 | 18 | 19 | class EnergyMode(str, Enum): 20 | FORWARD = "forward" 21 | BACKWARD = "backward" 22 | 23 | 24 | def _list_enum(enum_class) -> Tuple: 25 | return tuple(x.value for x in enum_class) 26 | 27 | 28 | def _rgb2gray(rgb: np.ndarray) -> np.ndarray: 29 | """Convert an RGB image to a grayscale image""" 30 | coeffs = np.array([0.2125, 0.7154, 0.0721], dtype=np.float32) 31 | return (rgb @ coeffs).astype(rgb.dtype) 32 | 33 | 34 | def _get_seam_mask(src: np.ndarray, seam: np.ndarray) -> np.ndarray: 35 | """Convert a list of seam column indices to a mask""" 36 | return np.eye(src.shape[1], dtype=bool)[seam] 37 | 38 | 39 | def _remove_seam_mask(src: np.ndarray, seam_mask: np.ndarray) -> np.ndarray: 40 | """Remove a seam from the source image according to the given seam_mask""" 41 | if src.ndim == 3: 42 | h, w, c = src.shape 43 | seam_mask = np.broadcast_to(seam_mask[:, :, None], src.shape) 44 | dst = src[~seam_mask].reshape((h, w - 1, c)) 45 | else: 46 | h, w = src.shape 47 | dst = src[~seam_mask].reshape((h, w - 1)) 48 | return dst 49 | 50 | 51 | def _get_energy(gray: np.ndarray) -> np.ndarray: 52 | """Get backward energy map from the source image""" 53 | assert gray.ndim == 2 54 | 55 | gray = gray.astype(np.float32) 56 | grad_x = sobel(gray, axis=1) 57 | grad_y = sobel(gray, axis=0) 58 | energy = np.abs(grad_x) + np.abs(grad_y) 59 | return energy 60 | 61 | 62 | @nb.njit(nb.int32[:](nb.float32[:, :]), cache=True) 63 | def _get_backward_seam(energy: np.ndarray) -> np.ndarray: 64 | """Compute the minimum vertical seam from the backward energy map""" 65 | h, w = energy.shape 66 | inf = np.array([np.inf], dtype=np.float32) 67 | cost = np.concatenate((inf, energy[0], inf)) 68 | parent = np.empty((h, w), dtype=np.int32) 69 | base_idx = np.arange(-1, w 
- 1, dtype=np.int32) 70 | 71 | for r in range(1, h): 72 | choices = np.vstack((cost[:-2], cost[1:-1], cost[2:])) 73 | min_idx = np.argmin(choices, axis=0) + base_idx 74 | parent[r] = min_idx 75 | cost[1:-1] = cost[1:-1][min_idx] + energy[r] 76 | 77 | c = np.argmin(cost[1:-1]) 78 | seam = np.empty(h, dtype=np.int32) 79 | for r in range(h - 1, -1, -1): 80 | seam[r] = c 81 | c = parent[r, c] 82 | 83 | return seam 84 | 85 | 86 | def _get_backward_seams( 87 | gray: np.ndarray, num_seams: int, aux_energy: Optional[np.ndarray] 88 | ) -> np.ndarray: 89 | """Compute the minimum N vertical seams using backward energy""" 90 | h, w = gray.shape 91 | seams = np.zeros((h, w), dtype=bool) 92 | rows = np.arange(h, dtype=np.int32) 93 | idx_map = np.broadcast_to(np.arange(w, dtype=np.int32), (h, w)) 94 | energy = _get_energy(gray) 95 | if aux_energy is not None: 96 | energy += aux_energy 97 | for _ in range(num_seams): 98 | seam = _get_backward_seam(energy) 99 | seams[rows, idx_map[rows, seam]] = True 100 | 101 | seam_mask = _get_seam_mask(gray, seam) 102 | gray = _remove_seam_mask(gray, seam_mask) 103 | idx_map = _remove_seam_mask(idx_map, seam_mask) 104 | if aux_energy is not None: 105 | aux_energy = _remove_seam_mask(aux_energy, seam_mask) 106 | 107 | # Only need to re-compute the energy in the bounding box of the seam 108 | _, cur_w = energy.shape 109 | lo = max(0, np.min(seam) - 1) 110 | hi = min(cur_w, np.max(seam) + 1) 111 | pad_lo = 1 if lo > 0 else 0 112 | pad_hi = 1 if hi < cur_w - 1 else 0 113 | mid_block = gray[:, lo - pad_lo : hi + pad_hi] 114 | _, mid_w = mid_block.shape 115 | mid_energy = _get_energy(mid_block)[:, pad_lo : mid_w - pad_hi] 116 | if aux_energy is not None: 117 | mid_energy += aux_energy[:, lo:hi] 118 | energy = np.hstack((energy[:, :lo], mid_energy, energy[:, hi + 1 :])) 119 | 120 | return seams 121 | 122 | 123 | @nb.njit( 124 | [ 125 | nb.int32[:](nb.float32[:, :], nb.none), 126 | nb.int32[:](nb.float32[:, :], nb.float32[:, :]), 127 | ], 128 | cache=True, 129 | ) 130 | def _get_forward_seam(gray: np.ndarray, aux_energy: Optional[np.ndarray]) -> np.ndarray: 131 | """Compute the minimum vertical seam using forward energy""" 132 | h, w = gray.shape 133 | 134 | gray = np.hstack((gray[:, :1], gray, gray[:, -1:])) 135 | 136 | inf = np.array([np.inf], dtype=np.float32) 137 | dp = np.concatenate((inf, np.abs(gray[0, 2:] - gray[0, :-2]), inf)) 138 | 139 | parent = np.empty((h, w), dtype=np.int32) 140 | base_idx = np.arange(-1, w - 1, dtype=np.int32) 141 | 142 | inf = np.array([np.inf], dtype=np.float32) 143 | for r in range(1, h): 144 | curr_shl = gray[r, 2:] 145 | curr_shr = gray[r, :-2] 146 | cost_mid = np.abs(curr_shl - curr_shr) 147 | if aux_energy is not None: 148 | cost_mid += aux_energy[r] 149 | 150 | prev_mid = gray[r - 1, 1:-1] 151 | cost_left = cost_mid + np.abs(prev_mid - curr_shr) 152 | cost_right = cost_mid + np.abs(prev_mid - curr_shl) 153 | 154 | dp_mid = dp[1:-1] 155 | dp_left = dp[:-2] 156 | dp_right = dp[2:] 157 | 158 | choices = np.vstack( 159 | (cost_left + dp_left, cost_mid + dp_mid, cost_right + dp_right) 160 | ) 161 | min_idx = np.argmin(choices, axis=0) 162 | parent[r] = min_idx + base_idx 163 | # numba does not support specifying axis in np.min, below loop is equivalent to: 164 | # `dp_mid[:] = np.min(choices, axis=0)` or `dp_mid[:] = choices[min_idx, np.arange(w)]` 165 | for j, i in enumerate(min_idx): 166 | dp_mid[j] = choices[i, j] 167 | 168 | c = np.argmin(dp[1:-1]) 169 | seam = np.empty(h, dtype=np.int32) 170 | for r in range(h - 1, -1, -1): 171 | seam[r] 
= c 172 | c = parent[r, c] 173 | 174 | return seam 175 | 176 | 177 | def _get_forward_seams( 178 | gray: np.ndarray, num_seams: int, aux_energy: Optional[np.ndarray] 179 | ) -> np.ndarray: 180 | """Compute minimum N vertical seams using forward energy""" 181 | h, w = gray.shape 182 | seams = np.zeros((h, w), dtype=bool) 183 | rows = np.arange(h, dtype=np.int32) 184 | idx_map = np.broadcast_to(np.arange(w, dtype=np.int32), (h, w)) 185 | for _ in range(num_seams): 186 | seam = _get_forward_seam(gray, aux_energy) 187 | seams[rows, idx_map[rows, seam]] = True 188 | seam_mask = _get_seam_mask(gray, seam) 189 | gray = _remove_seam_mask(gray, seam_mask) 190 | idx_map = _remove_seam_mask(idx_map, seam_mask) 191 | if aux_energy is not None: 192 | aux_energy = _remove_seam_mask(aux_energy, seam_mask) 193 | 194 | return seams 195 | 196 | 197 | def _get_seams( 198 | gray: np.ndarray, num_seams: int, energy_mode: str, aux_energy: Optional[np.ndarray] 199 | ) -> np.ndarray: 200 | """Get the minimum N seams from the grayscale image""" 201 | gray = np.asarray(gray, dtype=np.float32) 202 | if energy_mode == EnergyMode.BACKWARD: 203 | return _get_backward_seams(gray, num_seams, aux_energy) 204 | elif energy_mode == EnergyMode.FORWARD: 205 | return _get_forward_seams(gray, num_seams, aux_energy) 206 | else: 207 | raise ValueError( 208 | f"expect energy_mode to be one of {_list_enum(EnergyMode)}, got {energy_mode}" 209 | ) 210 | 211 | 212 | def _reduce_width( 213 | src: np.ndarray, 214 | delta_width: int, 215 | energy_mode: str, 216 | aux_energy: Optional[np.ndarray], 217 | ) -> Tuple[np.ndarray, Optional[np.ndarray]]: 218 | """Reduce the width of image by delta_width pixels""" 219 | assert src.ndim in (2, 3) and delta_width >= 0 220 | if src.ndim == 2: 221 | gray = src 222 | src_h, src_w = src.shape 223 | dst_shape: Tuple[int, ...] 
= (src_h, src_w - delta_width) 224 | else: 225 | gray = _rgb2gray(src) 226 | src_h, src_w, src_c = src.shape 227 | dst_shape = (src_h, src_w - delta_width, src_c) 228 | 229 | to_keep = ~_get_seams(gray, delta_width, energy_mode, aux_energy) 230 | dst = src[to_keep].reshape(dst_shape) 231 | if aux_energy is not None: 232 | aux_energy = aux_energy[to_keep].reshape(dst_shape[:2]) 233 | return dst, aux_energy 234 | 235 | 236 | @nb.njit( 237 | nb.float32[:, :, :](nb.float32[:, :, :], nb.boolean[:, :], nb.int32), cache=True 238 | ) 239 | def _insert_seams_kernel( 240 | src: np.ndarray, seams: np.ndarray, delta_width: int 241 | ) -> np.ndarray: 242 | """The numba kernel for inserting seams""" 243 | src_h, src_w, src_c = src.shape 244 | dst = np.empty((src_h, src_w + delta_width, src_c), dtype=src.dtype) 245 | for row in range(src_h): 246 | dst_col = 0 247 | for src_col in range(src_w): 248 | if seams[row, src_col]: 249 | left = src[row, max(src_col - 1, 0)] 250 | right = src[row, src_col] 251 | dst[row, dst_col] = (left + right) / 2 252 | dst_col += 1 253 | dst[row, dst_col] = src[row, src_col] 254 | dst_col += 1 255 | return dst 256 | 257 | 258 | def _insert_seams(src: np.ndarray, seams: np.ndarray, delta_width: int) -> np.ndarray: 259 | """Insert multiple seams into the source image""" 260 | dst = src.astype(np.float32) 261 | if dst.ndim == 2: 262 | dst = dst[:, :, None] 263 | dst = _insert_seams_kernel(dst, seams, delta_width).astype(src.dtype) 264 | if src.ndim == 2: 265 | dst = dst.squeeze(-1) 266 | return dst 267 | 268 | 269 | def _expand_width( 270 | src: np.ndarray, 271 | delta_width: int, 272 | energy_mode: str, 273 | aux_energy: Optional[np.ndarray], 274 | step_ratio: float, 275 | ) -> Tuple[np.ndarray, Optional[np.ndarray]]: 276 | """Expand the width of image by delta_width pixels""" 277 | assert src.ndim in (2, 3) and delta_width >= 0 278 | if not 0 < step_ratio <= 1: 279 | raise ValueError(f"expect `step_ratio` to be between (0,1], got {step_ratio}") 280 | 281 | dst = src 282 | while delta_width > 0: 283 | max_step_size = max(1, round(step_ratio * dst.shape[1])) 284 | step_size = min(max_step_size, delta_width) 285 | gray = dst if dst.ndim == 2 else _rgb2gray(dst) 286 | seams = _get_seams(gray, step_size, energy_mode, aux_energy) 287 | dst = _insert_seams(dst, seams, step_size) 288 | if aux_energy is not None: 289 | aux_energy = _insert_seams(aux_energy, seams, step_size) 290 | delta_width -= step_size 291 | 292 | return dst, aux_energy 293 | 294 | 295 | def _resize_width( 296 | src: np.ndarray, 297 | width: int, 298 | energy_mode: str, 299 | aux_energy: Optional[np.ndarray], 300 | step_ratio: float, 301 | ) -> Tuple[np.ndarray, Optional[np.ndarray]]: 302 | """Resize the width of image by removing vertical seams""" 303 | assert src.size > 0 and src.ndim in (2, 3) 304 | assert width > 0 305 | 306 | src_w = src.shape[1] 307 | if src_w < width: 308 | dst, aux_energy = _expand_width( 309 | src, width - src_w, energy_mode, aux_energy, step_ratio 310 | ) 311 | else: 312 | dst, aux_energy = _reduce_width(src, src_w - width, energy_mode, aux_energy) 313 | return dst, aux_energy 314 | 315 | 316 | def _transpose_image(src: np.ndarray) -> np.ndarray: 317 | """Transpose a source image in rgb or grayscale format""" 318 | if src.ndim == 3: 319 | dst = src.transpose((1, 0, 2)) 320 | else: 321 | dst = src.T 322 | return dst 323 | 324 | 325 | def _resize_height( 326 | src: np.ndarray, 327 | height: int, 328 | energy_mode: str, 329 | aux_energy: Optional[np.ndarray], 330 | step_ratio: float, 331 | ) 
-> Tuple[np.ndarray, Optional[np.ndarray]]: 332 | """Resize the height of image by removing horizontal seams""" 333 | assert src.ndim in (2, 3) and height > 0 334 | if aux_energy is not None: 335 | aux_energy = aux_energy.T 336 | src = _transpose_image(src) 337 | src, aux_energy = _resize_width(src, height, energy_mode, aux_energy, step_ratio) 338 | src = _transpose_image(src) 339 | if aux_energy is not None: 340 | aux_energy = aux_energy.T 341 | return src, aux_energy 342 | 343 | 344 | def _check_mask(mask: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray: 345 | """Ensure the mask to be a 2D grayscale map of specific shape""" 346 | mask = np.asarray(mask, dtype=bool) 347 | if mask.ndim != 2: 348 | raise ValueError(f"expect mask to be a 2d binary map, got shape {mask.shape}") 349 | if mask.shape != shape: 350 | raise ValueError( 351 | f"expect the shape of mask to match the image, got {mask.shape} vs {shape}" 352 | ) 353 | return mask 354 | 355 | 356 | def _check_src(src: np.ndarray) -> np.ndarray: 357 | """Ensure the source to be RGB or grayscale""" 358 | src = np.asarray(src) 359 | if src.size == 0 or src.ndim not in (2, 3): 360 | raise ValueError( 361 | f"expect a 3d rgb image or a 2d grayscale image, got image in shape {src.shape}" 362 | ) 363 | return src 364 | 365 | 366 | def seam_carving( 367 | src: np.ndarray, 368 | size: Optional[Tuple[int, int]] = None, 369 | energy_mode: str = "backward", 370 | order: str = "width-first", 371 | keep_mask: Optional[np.ndarray] = None, 372 | drop_mask: Optional[np.ndarray] = None, 373 | step_ratio: float = 0.5, 374 | ) -> np.ndarray: 375 | """Resize the image using the content-aware seam-carving algorithm. 376 | 377 | :param src: A source image in RGB or grayscale format. 378 | :param size: The target size in pixels, as a 2-tuple (width, height). 379 | :param energy_mode: Policy to compute energy for the source image. Could be 380 | one of ``backward`` or ``forward``. If ``backward``, compute the energy 381 | as the gradient at each pixel. If ``forward``, compute the energy as the 382 | distances between adjacent pixels after each pixel is removed. 383 | :param order: The order to remove horizontal and vertical seams. Could be 384 | one of ``width-first`` or ``height-first``. In ``width-first`` mode, we 385 | remove or insert all vertical seams first, then the horizontal ones, 386 | while ``height-first`` is the opposite. 387 | :param keep_mask: An optional mask where the foreground is protected from 388 | seam removal. If not specified, no area will be protected. 389 | :param drop_mask: An optional binary object mask to remove. If given, the 390 | object will be removed before resizing the image to the target size. 391 | :param step_ratio: The maximum size expansion ratio in one seam carving step. 392 | The image will be expanded in multiple steps if target size is too large. 393 | :return: A resized copy of the source image. 
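
    A minimal usage sketch (array and target sizes are chosen purely for
    illustration)::

        import numpy as np

        img = np.random.randint(0, 255, (100, 200, 3), dtype=np.uint8)
        out = seam_carving(img, size=(150, 100), energy_mode="backward")
        assert out.shape == (100, 150, 3)  # width 200 -> 150, height unchanged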
394 | """ 395 | src = _check_src(src) 396 | 397 | if order not in _list_enum(OrderMode): 398 | raise ValueError( 399 | f"expect order to be one of {_list_enum(OrderMode)}, got {order}" 400 | ) 401 | 402 | aux_energy = None 403 | 404 | if keep_mask is not None: 405 | keep_mask = _check_mask(keep_mask, src.shape[:2]) 406 | 407 | aux_energy = np.zeros(src.shape[:2], dtype=np.float32) 408 | aux_energy[keep_mask] += KEEP_MASK_ENERGY 409 | 410 | # remove object if `drop_mask` is given 411 | if drop_mask is not None: 412 | drop_mask = _check_mask(drop_mask, src.shape[:2]) 413 | 414 | if aux_energy is None: 415 | aux_energy = np.zeros(src.shape[:2], dtype=np.float32) 416 | aux_energy[drop_mask] -= DROP_MASK_ENERGY 417 | 418 | if order == OrderMode.HEIGHT_FIRST: 419 | src = _transpose_image(src) 420 | aux_energy = aux_energy.T 421 | 422 | num_seams = (aux_energy < 0).sum(1).max() 423 | while num_seams > 0: 424 | src, aux_energy = _reduce_width(src, num_seams, energy_mode, aux_energy) 425 | num_seams = (aux_energy < 0).sum(1).max() 426 | 427 | if order == OrderMode.HEIGHT_FIRST: 428 | src = _transpose_image(src) 429 | aux_energy = aux_energy.T 430 | 431 | # resize image if `size` is given 432 | if size is not None: 433 | width, height = size 434 | width = round(width) 435 | height = round(height) 436 | if width <= 0 or height <= 0: 437 | raise ValueError(f"expect target size to be positive, got {size}") 438 | 439 | if order == OrderMode.WIDTH_FIRST: 440 | src, aux_energy = _resize_width( 441 | src, width, energy_mode, aux_energy, step_ratio 442 | ) 443 | src, aux_energy = _resize_height( 444 | src, height, energy_mode, aux_energy, step_ratio 445 | ) 446 | else: 447 | src, aux_energy = _resize_height( 448 | src, height, energy_mode, aux_energy, step_ratio 449 | ) 450 | src, aux_energy = _resize_width( 451 | src, width, energy_mode, aux_energy, step_ratio 452 | ) 453 | 454 | return src 455 | -------------------------------------------------------------------------------- /workflow_all_nodes.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 42, 3 | "last_link_id": 61, 4 | "nodes": [ 5 | { 6 | "id": 9, 7 | "type": "ConsoleDebug+", 8 | "pos": [ 9 | 720, 10 | 140 11 | ], 12 | "size": { 13 | "0": 210, 14 | "1": 60 15 | }, 16 | "flags": {}, 17 | "order": 12, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "value", 22 | "type": "*", 23 | "link": 3 24 | } 25 | ], 26 | "properties": { 27 | "Node name for S&R": "ConsoleDebug+" 28 | }, 29 | "widgets_values": [ 30 | "Height:" 31 | ] 32 | }, 33 | { 34 | "id": 28, 35 | "type": "PreviewImage", 36 | "pos": [ 37 | 860, 38 | 1180 39 | ], 40 | "size": { 41 | "0": 210, 42 | "1": 246 43 | }, 44 | "flags": {}, 45 | "order": 17, 46 | "mode": 0, 47 | "inputs": [ 48 | { 49 | "name": "images", 50 | "type": "IMAGE", 51 | "link": 23 52 | } 53 | ], 54 | "properties": { 55 | "Node name for S&R": "PreviewImage" 56 | } 57 | }, 58 | { 59 | "id": 12, 60 | "type": "PreviewImage", 61 | "pos": [ 62 | 860, 63 | 580 64 | ], 65 | "size": { 66 | "0": 210, 67 | "1": 246 68 | }, 69 | "flags": {}, 70 | "order": 15, 71 | "mode": 0, 72 | "inputs": [ 73 | { 74 | "name": "images", 75 | "type": "IMAGE", 76 | "link": 11 77 | } 78 | ], 79 | "properties": { 80 | "Node name for S&R": "PreviewImage" 81 | } 82 | }, 83 | { 84 | "id": 14, 85 | "type": "PreviewImage", 86 | "pos": [ 87 | 860, 88 | 880 89 | ], 90 | "size": { 91 | "0": 210, 92 | "1": 246 93 | }, 94 | "flags": {}, 95 | "order": 16, 96 | "mode": 0, 97 | "inputs": [ 98 | { 99 | 
"name": "images", 100 | "type": "IMAGE", 101 | "link": 13 102 | } 103 | ], 104 | "properties": { 105 | "Node name for S&R": "PreviewImage" 106 | } 107 | }, 108 | { 109 | "id": 18, 110 | "type": "MaskPreview+", 111 | "pos": [ 112 | 2100, 113 | 90 114 | ], 115 | "size": { 116 | "0": 210, 117 | "1": 246 118 | }, 119 | "flags": {}, 120 | "order": 20, 121 | "mode": 0, 122 | "inputs": [ 123 | { 124 | "name": "mask", 125 | "type": "MASK", 126 | "link": 19 127 | } 128 | ], 129 | "properties": { 130 | "Node name for S&R": "MaskPreview+" 131 | } 132 | }, 133 | { 134 | "id": 1, 135 | "type": "GetImageSize+", 136 | "pos": [ 137 | 450, 138 | 80 139 | ], 140 | "size": { 141 | "0": 210, 142 | "1": 46 143 | }, 144 | "flags": {}, 145 | "order": 2, 146 | "mode": 0, 147 | "inputs": [ 148 | { 149 | "name": "image", 150 | "type": "IMAGE", 151 | "link": 1 152 | } 153 | ], 154 | "outputs": [ 155 | { 156 | "name": "width", 157 | "type": "INT", 158 | "links": [ 159 | 2 160 | ], 161 | "shape": 3, 162 | "slot_index": 0 163 | }, 164 | { 165 | "name": "height", 166 | "type": "INT", 167 | "links": [ 168 | 3 169 | ], 170 | "shape": 3, 171 | "slot_index": 1 172 | } 173 | ], 174 | "properties": { 175 | "Node name for S&R": "GetImageSize+" 176 | } 177 | }, 178 | { 179 | "id": 8, 180 | "type": "ConsoleDebug+", 181 | "pos": [ 182 | 720, 183 | 40 184 | ], 185 | "size": { 186 | "0": 210, 187 | "1": 60 188 | }, 189 | "flags": {}, 190 | "order": 11, 191 | "mode": 0, 192 | "inputs": [ 193 | { 194 | "name": "value", 195 | "type": "*", 196 | "link": 2 197 | } 198 | ], 199 | "properties": { 200 | "Node name for S&R": "ConsoleDebug+" 201 | }, 202 | "widgets_values": [ 203 | "Width:" 204 | ] 205 | }, 206 | { 207 | "id": 10, 208 | "type": "PreviewImage", 209 | "pos": [ 210 | 860, 211 | 280 212 | ], 213 | "size": { 214 | "0": 210, 215 | "1": 246 216 | }, 217 | "flags": {}, 218 | "order": 13, 219 | "mode": 0, 220 | "inputs": [ 221 | { 222 | "name": "images", 223 | "type": "IMAGE", 224 | "link": 9 225 | } 226 | ], 227 | "properties": { 228 | "Node name for S&R": "PreviewImage" 229 | } 230 | }, 231 | { 232 | "id": 36, 233 | "type": "SimpleMath+", 234 | "pos": [ 235 | 1650, 236 | 780 237 | ], 238 | "size": { 239 | "0": 210, 240 | "1": 80 241 | }, 242 | "flags": {}, 243 | "order": 14, 244 | "mode": 0, 245 | "inputs": [ 246 | { 247 | "name": "a", 248 | "type": "INT,FLOAT", 249 | "link": 44 250 | }, 251 | { 252 | "name": "b", 253 | "type": "INT,FLOAT", 254 | "link": 45 255 | } 256 | ], 257 | "outputs": [ 258 | { 259 | "name": "INT", 260 | "type": "INT", 261 | "links": [ 262 | 46 263 | ], 264 | "shape": 3, 265 | "slot_index": 0 266 | }, 267 | { 268 | "name": "FLOAT", 269 | "type": "FLOAT", 270 | "links": null, 271 | "shape": 3 272 | } 273 | ], 274 | "properties": { 275 | "Node name for S&R": "SimpleMath+" 276 | }, 277 | "widgets_values": [ 278 | "a*b" 279 | ] 280 | }, 281 | { 282 | "id": 23, 283 | "type": "ConsoleDebug+", 284 | "pos": [ 285 | 1920, 286 | 780 287 | ], 288 | "size": { 289 | "0": 210, 290 | "1": 60 291 | }, 292 | "flags": {}, 293 | "order": 22, 294 | "mode": 0, 295 | "inputs": [ 296 | { 297 | "name": "value", 298 | "type": "*", 299 | "link": 46 300 | } 301 | ], 302 | "properties": { 303 | "Node name for S&R": "ConsoleDebug+" 304 | }, 305 | "widgets_values": [ 306 | "Value:" 307 | ] 308 | }, 309 | { 310 | "id": 2, 311 | "type": "ImageResize+", 312 | "pos": [ 313 | 430, 314 | 340 315 | ], 316 | "size": { 317 | "0": 310, 318 | "1": 170 319 | }, 320 | "flags": {}, 321 | "order": 3, 322 | "mode": 0, 323 | "inputs": [ 324 | { 325 | 
"name": "image", 326 | "type": "IMAGE", 327 | "link": 4 328 | } 329 | ], 330 | "outputs": [ 331 | { 332 | "name": "IMAGE", 333 | "type": "IMAGE", 334 | "links": [ 335 | 9 336 | ], 337 | "shape": 3, 338 | "slot_index": 0 339 | }, 340 | { 341 | "name": "width", 342 | "type": "INT", 343 | "links": [ 344 | 44 345 | ], 346 | "shape": 3, 347 | "slot_index": 1 348 | }, 349 | { 350 | "name": "height", 351 | "type": "INT", 352 | "links": [ 353 | 45 354 | ], 355 | "shape": 3, 356 | "slot_index": 2 357 | } 358 | ], 359 | "properties": { 360 | "Node name for S&R": "ImageResize+" 361 | }, 362 | "widgets_values": [ 363 | 256, 364 | 64, 365 | "lanczos", 366 | true 367 | ] 368 | }, 369 | { 370 | "id": 4, 371 | "type": "ImageFlip+", 372 | "pos": [ 373 | 430, 374 | 800 375 | ], 376 | "size": { 377 | "0": 310, 378 | "1": 60 379 | }, 380 | "flags": {}, 381 | "order": 4, 382 | "mode": 0, 383 | "inputs": [ 384 | { 385 | "name": "image", 386 | "type": "IMAGE", 387 | "link": 6 388 | } 389 | ], 390 | "outputs": [ 391 | { 392 | "name": "IMAGE", 393 | "type": "IMAGE", 394 | "links": [ 395 | 11 396 | ], 397 | "shape": 3, 398 | "slot_index": 0 399 | } 400 | ], 401 | "properties": { 402 | "Node name for S&R": "ImageFlip+" 403 | }, 404 | "widgets_values": [ 405 | "xy" 406 | ] 407 | }, 408 | { 409 | "id": 6, 410 | "type": "ImagePosterize+", 411 | "pos": [ 412 | 430, 413 | 1000 414 | ], 415 | "size": { 416 | "0": 310, 417 | "1": 60 418 | }, 419 | "flags": {}, 420 | "order": 5, 421 | "mode": 0, 422 | "inputs": [ 423 | { 424 | "name": "image", 425 | "type": "IMAGE", 426 | "link": 8 427 | } 428 | ], 429 | "outputs": [ 430 | { 431 | "name": "IMAGE", 432 | "type": "IMAGE", 433 | "links": [ 434 | 13 435 | ], 436 | "shape": 3, 437 | "slot_index": 0 438 | } 439 | ], 440 | "properties": { 441 | "Node name for S&R": "ImagePosterize+" 442 | }, 443 | "widgets_values": [ 444 | 0.5 445 | ] 446 | }, 447 | { 448 | "id": 27, 449 | "type": "ImageCASharpening+", 450 | "pos": [ 451 | 430, 452 | 1110 453 | ], 454 | "size": { 455 | "0": 310.79998779296875, 456 | "1": 60 457 | }, 458 | "flags": {}, 459 | "order": 6, 460 | "mode": 0, 461 | "inputs": [ 462 | { 463 | "name": "image", 464 | "type": "IMAGE", 465 | "link": 22 466 | } 467 | ], 468 | "outputs": [ 469 | { 470 | "name": "IMAGE", 471 | "type": "IMAGE", 472 | "links": [ 473 | 23 474 | ], 475 | "shape": 3, 476 | "slot_index": 0 477 | } 478 | ], 479 | "properties": { 480 | "Node name for S&R": "ImageCASharpening+" 481 | }, 482 | "widgets_values": [ 483 | 0.8 484 | ] 485 | }, 486 | { 487 | "id": 15, 488 | "type": "MaskBlur+", 489 | "pos": [ 490 | 1690, 491 | 130 492 | ], 493 | "size": { 494 | "0": 310, 495 | "1": 82 496 | }, 497 | "flags": {}, 498 | "order": 9, 499 | "mode": 0, 500 | "inputs": [ 501 | { 502 | "name": "mask", 503 | "type": "MASK", 504 | "link": 14 505 | } 506 | ], 507 | "outputs": [ 508 | { 509 | "name": "MASK", 510 | "type": "MASK", 511 | "links": [ 512 | 19 513 | ], 514 | "shape": 3, 515 | "slot_index": 0 516 | } 517 | ], 518 | "properties": { 519 | "Node name for S&R": "MaskBlur+" 520 | }, 521 | "widgets_values": [ 522 | 45, 523 | 28.5 524 | ] 525 | }, 526 | { 527 | "id": 16, 528 | "type": "MaskFlip+", 529 | "pos": [ 530 | 1690, 531 | 270 532 | ], 533 | "size": { 534 | "0": 310, 535 | "1": 60 536 | }, 537 | "flags": {}, 538 | "order": 10, 539 | "mode": 0, 540 | "inputs": [ 541 | { 542 | "name": "mask", 543 | "type": "MASK", 544 | "link": 15 545 | } 546 | ], 547 | "outputs": [ 548 | { 549 | "name": "MASK", 550 | "type": "MASK", 551 | "links": [ 552 | 18 553 | ], 554 | 
"shape": 3, 555 | "slot_index": 0 556 | } 557 | ], 558 | "properties": { 559 | "Node name for S&R": "MaskFlip+" 560 | }, 561 | "widgets_values": [ 562 | "xy" 563 | ] 564 | }, 565 | { 566 | "id": 13, 567 | "type": "PreviewImage", 568 | "pos": [ 569 | 1100, 570 | 760 571 | ], 572 | "size": { 573 | "0": 210, 574 | "1": 246 575 | }, 576 | "flags": {}, 577 | "order": 18, 578 | "mode": 0, 579 | "inputs": [ 580 | { 581 | "name": "images", 582 | "type": "IMAGE", 583 | "link": 49 584 | } 585 | ], 586 | "properties": { 587 | "Node name for S&R": "PreviewImage" 588 | } 589 | }, 590 | { 591 | "id": 37, 592 | "type": "ImageDesaturate+", 593 | "pos": [ 594 | 500, 595 | 920 596 | ], 597 | "size": { 598 | "0": 190, 599 | "1": 30 600 | }, 601 | "flags": {}, 602 | "order": 7, 603 | "mode": 0, 604 | "inputs": [ 605 | { 606 | "name": "image", 607 | "type": "IMAGE", 608 | "link": 48 609 | } 610 | ], 611 | "outputs": [ 612 | { 613 | "name": "IMAGE", 614 | "type": "IMAGE", 615 | "links": [ 616 | 49 617 | ], 618 | "shape": 3, 619 | "slot_index": 0 620 | } 621 | ], 622 | "properties": { 623 | "Node name for S&R": "ImageDesaturate+" 624 | } 625 | }, 626 | { 627 | "id": 7, 628 | "type": "LoadImage", 629 | "pos": [ 630 | -90, 631 | 650 632 | ], 633 | "size": { 634 | "0": 315, 635 | "1": 314 636 | }, 637 | "flags": {}, 638 | "order": 0, 639 | "mode": 0, 640 | "outputs": [ 641 | { 642 | "name": "IMAGE", 643 | "type": "IMAGE", 644 | "links": [ 645 | 1, 646 | 4, 647 | 6, 648 | 8, 649 | 22, 650 | 48, 651 | 57 652 | ], 653 | "shape": 3, 654 | "slot_index": 0 655 | }, 656 | { 657 | "name": "MASK", 658 | "type": "MASK", 659 | "links": null, 660 | "shape": 3 661 | } 662 | ], 663 | "properties": { 664 | "Node name for S&R": "LoadImage" 665 | }, 666 | "widgets_values": [ 667 | "venere.jpg", 668 | "image" 669 | ] 670 | }, 671 | { 672 | "id": 11, 673 | "type": "PreviewImage", 674 | "pos": [ 675 | 1100, 676 | 450 677 | ], 678 | "size": { 679 | "0": 210, 680 | "1": 246 681 | }, 682 | "flags": {}, 683 | "order": 19, 684 | "mode": 0, 685 | "inputs": [ 686 | { 687 | "name": "images", 688 | "type": "IMAGE", 689 | "link": 58 690 | } 691 | ], 692 | "properties": { 693 | "Node name for S&R": "PreviewImage" 694 | } 695 | }, 696 | { 697 | "id": 40, 698 | "type": "ImageCrop+", 699 | "pos": [ 700 | 430, 701 | 560 702 | ], 703 | "size": { 704 | "0": 310, 705 | "1": 194 706 | }, 707 | "flags": {}, 708 | "order": 8, 709 | "mode": 0, 710 | "inputs": [ 711 | { 712 | "name": "image", 713 | "type": "IMAGE", 714 | "link": 57 715 | } 716 | ], 717 | "outputs": [ 718 | { 719 | "name": "IMAGE", 720 | "type": "IMAGE", 721 | "links": [ 722 | 58 723 | ], 724 | "shape": 3, 725 | "slot_index": 0 726 | }, 727 | { 728 | "name": "x", 729 | "type": "INT", 730 | "links": null, 731 | "shape": 3 732 | }, 733 | { 734 | "name": "y", 735 | "type": "INT", 736 | "links": null, 737 | "shape": 3 738 | } 739 | ], 740 | "properties": { 741 | "Node name for S&R": "ImageCrop+" 742 | }, 743 | "widgets_values": [ 744 | 256, 745 | 256, 746 | "center", 747 | 0, 748 | 0 749 | ] 750 | }, 751 | { 752 | "id": 20, 753 | "type": "LoadImageMask", 754 | "pos": [ 755 | 1400, 756 | 260 757 | ], 758 | "size": { 759 | "0": 220.70516967773438, 760 | "1": 318 761 | }, 762 | "flags": {}, 763 | "order": 1, 764 | "mode": 0, 765 | "outputs": [ 766 | { 767 | "name": "MASK", 768 | "type": "MASK", 769 | "links": [ 770 | 14, 771 | 15 772 | ], 773 | "shape": 3, 774 | "slot_index": 0 775 | } 776 | ], 777 | "properties": { 778 | "Node name for S&R": "LoadImageMask" 779 | }, 780 | "widgets_values": [ 781 | 
"cwf_inpaint_example_mask.png", 782 | "alpha", 783 | "image" 784 | ] 785 | }, 786 | { 787 | "id": 21, 788 | "type": "MaskPreview+", 789 | "pos": [ 790 | 2100, 791 | 380 792 | ], 793 | "size": { 794 | "0": 210, 795 | "1": 246 796 | }, 797 | "flags": {}, 798 | "order": 21, 799 | "mode": 0, 800 | "inputs": [ 801 | { 802 | "name": "mask", 803 | "type": "MASK", 804 | "link": 18 805 | } 806 | ], 807 | "properties": { 808 | "Node name for S&R": "MaskPreview+" 809 | } 810 | } 811 | ], 812 | "links": [ 813 | [ 814 | 1, 815 | 7, 816 | 0, 817 | 1, 818 | 0, 819 | "IMAGE" 820 | ], 821 | [ 822 | 2, 823 | 1, 824 | 0, 825 | 8, 826 | 0, 827 | "*" 828 | ], 829 | [ 830 | 3, 831 | 1, 832 | 1, 833 | 9, 834 | 0, 835 | "*" 836 | ], 837 | [ 838 | 4, 839 | 7, 840 | 0, 841 | 2, 842 | 0, 843 | "IMAGE" 844 | ], 845 | [ 846 | 6, 847 | 7, 848 | 0, 849 | 4, 850 | 0, 851 | "IMAGE" 852 | ], 853 | [ 854 | 8, 855 | 7, 856 | 0, 857 | 6, 858 | 0, 859 | "IMAGE" 860 | ], 861 | [ 862 | 9, 863 | 2, 864 | 0, 865 | 10, 866 | 0, 867 | "IMAGE" 868 | ], 869 | [ 870 | 11, 871 | 4, 872 | 0, 873 | 12, 874 | 0, 875 | "IMAGE" 876 | ], 877 | [ 878 | 13, 879 | 6, 880 | 0, 881 | 14, 882 | 0, 883 | "IMAGE" 884 | ], 885 | [ 886 | 14, 887 | 20, 888 | 0, 889 | 15, 890 | 0, 891 | "MASK" 892 | ], 893 | [ 894 | 15, 895 | 20, 896 | 0, 897 | 16, 898 | 0, 899 | "MASK" 900 | ], 901 | [ 902 | 18, 903 | 16, 904 | 0, 905 | 21, 906 | 0, 907 | "MASK" 908 | ], 909 | [ 910 | 19, 911 | 15, 912 | 0, 913 | 18, 914 | 0, 915 | "MASK" 916 | ], 917 | [ 918 | 22, 919 | 7, 920 | 0, 921 | 27, 922 | 0, 923 | "IMAGE" 924 | ], 925 | [ 926 | 23, 927 | 27, 928 | 0, 929 | 28, 930 | 0, 931 | "IMAGE" 932 | ], 933 | [ 934 | 44, 935 | 2, 936 | 1, 937 | 36, 938 | 0, 939 | "INT,FLOAT" 940 | ], 941 | [ 942 | 45, 943 | 2, 944 | 2, 945 | 36, 946 | 1, 947 | "INT,FLOAT" 948 | ], 949 | [ 950 | 46, 951 | 36, 952 | 0, 953 | 23, 954 | 0, 955 | "*" 956 | ], 957 | [ 958 | 48, 959 | 7, 960 | 0, 961 | 37, 962 | 0, 963 | "IMAGE" 964 | ], 965 | [ 966 | 49, 967 | 37, 968 | 0, 969 | 13, 970 | 0, 971 | "IMAGE" 972 | ], 973 | [ 974 | 57, 975 | 7, 976 | 0, 977 | 40, 978 | 0, 979 | "IMAGE" 980 | ], 981 | [ 982 | 58, 983 | 40, 984 | 0, 985 | 11, 986 | 0, 987 | "IMAGE" 988 | ] 989 | ], 990 | "groups": [], 991 | "config": {}, 992 | "extra": {}, 993 | "version": 0.4 994 | } -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from .utils import AnyType 4 | import comfy.model_management 5 | from nodes import MAX_RESOLUTION 6 | import time 7 | 8 | any = AnyType("*") 9 | 10 | class SimpleMathFloat: 11 | @classmethod 12 | def INPUT_TYPES(s): 13 | return { 14 | "required": { 15 | "value": ("FLOAT", { "default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.05 }), 16 | }, 17 | } 18 | 19 | RETURN_TYPES = ("FLOAT", ) 20 | FUNCTION = "execute" 21 | CATEGORY = "essentials/utilities" 22 | 23 | def execute(self, value): 24 | return (float(value), ) 25 | 26 | class SimpleMathPercent: 27 | @classmethod 28 | def INPUT_TYPES(s): 29 | return { 30 | "required": { 31 | "value": ("FLOAT", { "default": 0.0, "min": 0, "max": 1, "step": 0.05 }), 32 | }, 33 | } 34 | 35 | RETURN_TYPES = ("FLOAT", ) 36 | FUNCTION = "execute" 37 | CATEGORY = "essentials/utilities" 38 | 39 | def execute(self, value): 40 | return (float(value), ) 41 | 42 | class SimpleMathInt: 43 | @classmethod 44 | def INPUT_TYPES(s): 45 | return { 46 | "required": { 47 | "value": ("INT", { 
"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1 }), 48 | }, 49 | } 50 | 51 | RETURN_TYPES = ("INT",) 52 | FUNCTION = "execute" 53 | CATEGORY = "essentials/utilities" 54 | 55 | def execute(self, value): 56 | return (int(value), ) 57 | 58 | class SimpleMathSlider: 59 | @classmethod 60 | def INPUT_TYPES(s): 61 | return { 62 | "required": { 63 | "value": ("FLOAT", { "display": "slider", "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001 }), 64 | "min": ("FLOAT", { "default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.001 }), 65 | "max": ("FLOAT", { "default": 1.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.001 }), 66 | "rounding": ("INT", { "default": 0, "min": 0, "max": 10, "step": 1 }), 67 | }, 68 | } 69 | 70 | RETURN_TYPES = ("FLOAT", "INT",) 71 | FUNCTION = "execute" 72 | CATEGORY = "essentials/utilities" 73 | 74 | def execute(self, value, min, max, rounding): 75 | value = min + value * (max - min) 76 | 77 | if rounding > 0: 78 | value = round(value, rounding) 79 | 80 | return (value, int(value), ) 81 | 82 | class SimpleMathSliderLowRes: 83 | @classmethod 84 | def INPUT_TYPES(s): 85 | return { 86 | "required": { 87 | "value": ("INT", { "display": "slider", "default": 5, "min": 0, "max": 10, "step": 1 }), 88 | "min": ("FLOAT", { "default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.001 }), 89 | "max": ("FLOAT", { "default": 1.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.001 }), 90 | "rounding": ("INT", { "default": 0, "min": 0, "max": 10, "step": 1 }), 91 | }, 92 | } 93 | 94 | RETURN_TYPES = ("FLOAT", "INT",) 95 | FUNCTION = "execute" 96 | CATEGORY = "essentials/utilities" 97 | 98 | def execute(self, value, min, max, rounding): 99 | value = 0.1 * value 100 | value = min + value * (max - min) 101 | if rounding > 0: 102 | value = round(value, rounding) 103 | 104 | return (value, ) 105 | 106 | class SimpleMathBoolean: 107 | @classmethod 108 | def INPUT_TYPES(s): 109 | return { 110 | "required": { 111 | "value": ("BOOLEAN", { "default": False }), 112 | }, 113 | } 114 | 115 | RETURN_TYPES = ("BOOLEAN",) 116 | FUNCTION = "execute" 117 | CATEGORY = "essentials/utilities" 118 | 119 | def execute(self, value): 120 | return (value, int(value), ) 121 | 122 | class SimpleMath: 123 | @classmethod 124 | def INPUT_TYPES(s): 125 | return { 126 | "optional": { 127 | "a": (any, { "default": 0.0 }), 128 | "b": (any, { "default": 0.0 }), 129 | "c": (any, { "default": 0.0 }), 130 | }, 131 | "required": { 132 | "value": ("STRING", { "multiline": False, "default": "" }), 133 | }, 134 | } 135 | 136 | RETURN_TYPES = ("INT", "FLOAT", ) 137 | FUNCTION = "execute" 138 | CATEGORY = "essentials/utilities" 139 | 140 | def execute(self, value, a = 0.0, b = 0.0, c = 0.0, d = 0.0): 141 | import ast 142 | import operator as op 143 | 144 | h, w = 0.0, 0.0 145 | if hasattr(a, 'shape'): 146 | a = list(a.shape) 147 | if hasattr(b, 'shape'): 148 | b = list(b.shape) 149 | if hasattr(c, 'shape'): 150 | c = list(c.shape) 151 | if hasattr(d, 'shape'): 152 | d = list(d.shape) 153 | 154 | if isinstance(a, str): 155 | a = float(a) 156 | if isinstance(b, str): 157 | b = float(b) 158 | if isinstance(c, str): 159 | c = float(c) 160 | if isinstance(d, str): 161 | d = float(d) 162 | 163 | operators = { 164 | ast.Add: op.add, 165 | ast.Sub: op.sub, 166 | ast.Mult: op.mul, 167 | ast.Div: op.truediv, 168 | ast.FloorDiv: op.floordiv, 169 | ast.Pow: op.pow, 170 | #ast.BitXor: op.xor, 171 | #ast.BitOr: op.or_, 172 | 
#ast.BitAnd: op.and_, 173 | ast.USub: op.neg, 174 | ast.Mod: op.mod, 175 | ast.Eq: op.eq, 176 | ast.NotEq: op.ne, 177 | ast.Lt: op.lt, 178 | ast.LtE: op.le, 179 | ast.Gt: op.gt, 180 | ast.GtE: op.ge, 181 | ast.And: lambda x, y: x and y, 182 | ast.Or: lambda x, y: x or y, 183 | ast.Not: op.not_ 184 | } 185 | 186 | op_functions = { 187 | 'min': min, 188 | 'max': max, 189 | 'round': round, 190 | 'sum': sum, 191 | 'len': len, 192 | } 193 | 194 | def eval_(node): 195 | if isinstance(node, ast.Num): # number 196 | return node.n 197 | elif isinstance(node, ast.Name): # variable 198 | if node.id == "a": 199 | return a 200 | if node.id == "b": 201 | return b 202 | if node.id == "c": 203 | return c 204 | if node.id == "d": 205 | return d 206 | elif isinstance(node, ast.BinOp): # 207 | return operators[type(node.op)](eval_(node.left), eval_(node.right)) 208 | elif isinstance(node, ast.UnaryOp): # e.g., -1 209 | return operators[type(node.op)](eval_(node.operand)) 210 | elif isinstance(node, ast.Compare): # comparison operators 211 | left = eval_(node.left) 212 | for op, comparator in zip(node.ops, node.comparators): 213 | if not operators[type(op)](left, eval_(comparator)): 214 | return 0 215 | return 1 216 | elif isinstance(node, ast.BoolOp): # boolean operators (And, Or) 217 | values = [eval_(value) for value in node.values] 218 | return operators[type(node.op)](*values) 219 | elif isinstance(node, ast.Call): # custom function 220 | if node.func.id in op_functions: 221 | args =[eval_(arg) for arg in node.args] 222 | return op_functions[node.func.id](*args) 223 | elif isinstance(node, ast.Subscript): # indexing or slicing 224 | value = eval_(node.value) 225 | if isinstance(node.slice, ast.Constant): 226 | return value[node.slice.value] 227 | else: 228 | return 0 229 | else: 230 | return 0 231 | 232 | result = eval_(ast.parse(value, mode='eval').body) 233 | 234 | if math.isnan(result): 235 | result = 0.0 236 | 237 | return (round(result), result, ) 238 | 239 | class SimpleMathDual: 240 | @classmethod 241 | def INPUT_TYPES(s): 242 | return { 243 | "optional": { 244 | "a": (any, { "default": 0.0 }), 245 | "b": (any, { "default": 0.0 }), 246 | "c": (any, { "default": 0.0 }), 247 | "d": (any, { "default": 0.0 }), 248 | }, 249 | "required": { 250 | "value_1": ("STRING", { "multiline": False, "default": "" }), 251 | "value_2": ("STRING", { "multiline": False, "default": "" }), 252 | }, 253 | } 254 | 255 | RETURN_TYPES = ("INT", "FLOAT", "INT", "FLOAT", ) 256 | RETURN_NAMES = ("int_1", "float_1", "int_2", "float_2" ) 257 | FUNCTION = "execute" 258 | CATEGORY = "essentials/utilities" 259 | 260 | def execute(self, value_1, value_2, a = 0.0, b = 0.0, c = 0.0, d = 0.0): 261 | return SimpleMath().execute(value_1, a, b, c, d) + SimpleMath().execute(value_2, a, b, c, d) 262 | 263 | class SimpleMathCondition: 264 | @classmethod 265 | def INPUT_TYPES(s): 266 | return { 267 | "optional": { 268 | "a": (any, { "default": 0.0 }), 269 | "b": (any, { "default": 0.0 }), 270 | "c": (any, { "default": 0.0 }), 271 | }, 272 | "required": { 273 | "evaluate": (any, {"default": 0}), 274 | "on_true": ("STRING", { "multiline": False, "default": "" }), 275 | "on_false": ("STRING", { "multiline": False, "default": "" }), 276 | }, 277 | } 278 | 279 | RETURN_TYPES = ("INT", "FLOAT", ) 280 | FUNCTION = "execute" 281 | CATEGORY = "essentials/utilities" 282 | 283 | def execute(self, evaluate, on_true, on_false, a = 0.0, b = 0.0, c = 0.0): 284 | return SimpleMath().execute(on_true if evaluate else on_false, a, b, c) 285 | 286 | class 
SimpleCondition: 287 | def __init__(self): 288 | pass 289 | 290 | @classmethod 291 | def INPUT_TYPES(cls): 292 | return { 293 | "required": { 294 | "evaluate": (any, {"default": 0}), 295 | "on_true": (any, {"default": 0}), 296 | }, 297 | "optional": { 298 | "on_false": (any, {"default": None}), 299 | }, 300 | } 301 | 302 | RETURN_TYPES = (any,) 303 | RETURN_NAMES = ("result",) 304 | FUNCTION = "execute" 305 | 306 | CATEGORY = "essentials/utilities" 307 | 308 | def execute(self, evaluate, on_true, on_false=None): 309 | from comfy_execution.graph import ExecutionBlocker 310 | if not evaluate: 311 | return (on_false if on_false is not None else ExecutionBlocker(None),) 312 | 313 | return (on_true,) 314 | 315 | class SimpleComparison: 316 | def __init__(self): 317 | pass 318 | 319 | @classmethod 320 | def INPUT_TYPES(cls): 321 | return { 322 | "required": { 323 | "a": (any, {"default": 0}), 324 | "b": (any, {"default": 0}), 325 | "comparison": (["==", "!=", "<", "<=", ">", ">="],), 326 | }, 327 | } 328 | 329 | RETURN_TYPES = ("BOOLEAN",) 330 | FUNCTION = "execute" 331 | 332 | CATEGORY = "essentials/utilities" 333 | 334 | def execute(self, a, b, comparison): 335 | if comparison == "==": 336 | return (a == b,) 337 | elif comparison == "!=": 338 | return (a != b,) 339 | elif comparison == "<": 340 | return (a < b,) 341 | elif comparison == "<=": 342 | return (a <= b,) 343 | elif comparison == ">": 344 | return (a > b,) 345 | elif comparison == ">=": 346 | return (a >= b,) 347 | 348 | class ConsoleDebug: 349 | @classmethod 350 | def INPUT_TYPES(s): 351 | return { 352 | "required": { 353 | "value": (any, {}), 354 | }, 355 | "optional": { 356 | "prefix": ("STRING", { "multiline": False, "default": "Value:" }) 357 | } 358 | } 359 | 360 | RETURN_TYPES = () 361 | FUNCTION = "execute" 362 | CATEGORY = "essentials/utilities" 363 | OUTPUT_NODE = True 364 | 365 | def execute(self, value, prefix): 366 | print(f"\033[96m{prefix} {value}\033[0m") 367 | 368 | return (None,) 369 | 370 | class DebugTensorShape: 371 | @classmethod 372 | def INPUT_TYPES(s): 373 | return { 374 | "required": { 375 | "tensor": (any, {}), 376 | }, 377 | } 378 | 379 | RETURN_TYPES = () 380 | FUNCTION = "execute" 381 | CATEGORY = "essentials/utilities" 382 | OUTPUT_NODE = True 383 | 384 | def execute(self, tensor): 385 | shapes = [] 386 | def tensorShape(tensor): 387 | if isinstance(tensor, dict): 388 | for k in tensor: 389 | tensorShape(tensor[k]) 390 | elif isinstance(tensor, list): 391 | for i in range(len(tensor)): 392 | tensorShape(tensor[i]) 393 | elif hasattr(tensor, 'shape'): 394 | shapes.append(list(tensor.shape)) 395 | 396 | tensorShape(tensor) 397 | 398 | print(f"\033[96mShapes found: {shapes}\033[0m") 399 | 400 | return (None,) 401 | 402 | class BatchCount: 403 | @classmethod 404 | def INPUT_TYPES(s): 405 | return { 406 | "required": { 407 | "batch": (any, {}), 408 | }, 409 | } 410 | 411 | RETURN_TYPES = ("INT",) 412 | FUNCTION = "execute" 413 | CATEGORY = "essentials/utilities" 414 | 415 | def execute(self, batch): 416 | count = 0 417 | if hasattr(batch, 'shape'): 418 | count = batch.shape[0] 419 | elif isinstance(batch, dict) and 'samples' in batch: 420 | count = batch['samples'].shape[0] 421 | elif isinstance(batch, list) or isinstance(batch, dict): 422 | count = len(batch) 423 | 424 | return (count, ) 425 | 426 | class ModelCompile(): 427 | @classmethod 428 | def INPUT_TYPES(s): 429 | return { 430 | "required": { 431 | "model": ("MODEL",), 432 | "fullgraph": ("BOOLEAN", { "default": False }), 433 | "dynamic": ("BOOLEAN", 
{ "default": False }), 434 | "mode": (["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],), 435 | }, 436 | } 437 | 438 | RETURN_TYPES = ("MODEL", ) 439 | FUNCTION = "execute" 440 | CATEGORY = "essentials/utilities" 441 | 442 | def execute(self, model, fullgraph, dynamic, mode): 443 | work_model = model.clone() 444 | torch._dynamo.config.suppress_errors = True 445 | work_model.add_object_patch("diffusion_model", torch.compile(model=work_model.get_model_object("diffusion_model"), dynamic=dynamic, fullgraph=fullgraph, mode=mode)) 446 | return (work_model, ) 447 | 448 | class RemoveLatentMask: 449 | @classmethod 450 | def INPUT_TYPES(s): 451 | return {"required": { "samples": ("LATENT",),}} 452 | RETURN_TYPES = ("LATENT",) 453 | FUNCTION = "execute" 454 | 455 | CATEGORY = "essentials/utilities" 456 | 457 | def execute(self, samples): 458 | s = samples.copy() 459 | if "noise_mask" in s: 460 | del s["noise_mask"] 461 | 462 | return (s,) 463 | 464 | class SDXLEmptyLatentSizePicker: 465 | def __init__(self): 466 | self.device = comfy.model_management.intermediate_device() 467 | 468 | @classmethod 469 | def INPUT_TYPES(s): 470 | return {"required": { 471 | "resolution": (["704x1408 (0.5)","704x1344 (0.52)","768x1344 (0.57)","768x1280 (0.6)","832x1216 (0.68)","832x1152 (0.72)","896x1152 (0.78)","896x1088 (0.82)","960x1088 (0.88)","960x1024 (0.94)","1024x1024 (1.0)","1024x960 (1.07)","1088x960 (1.13)","1088x896 (1.21)","1152x896 (1.29)","1152x832 (1.38)","1216x832 (1.46)","1280x768 (1.67)","1344x768 (1.75)","1344x704 (1.91)","1408x704 (2.0)","1472x704 (2.09)","1536x640 (2.4)","1600x640 (2.5)","1664x576 (2.89)","1728x576 (3.0)",], {"default": "1024x1024 (1.0)"}), 472 | "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), 473 | "width_override": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), 474 | "height_override": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), 475 | }} 476 | 477 | RETURN_TYPES = ("LATENT","INT","INT",) 478 | RETURN_NAMES = ("LATENT","width","height",) 479 | FUNCTION = "execute" 480 | CATEGORY = "essentials/utilities" 481 | 482 | def execute(self, resolution, batch_size, width_override=0, height_override=0): 483 | width, height = resolution.split(" ")[0].split("x") 484 | width = width_override if width_override > 0 else int(width) 485 | height = height_override if height_override > 0 else int(height) 486 | 487 | latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) 488 | 489 | return ({"samples":latent}, width, height,) 490 | 491 | class DisplayAny: 492 | def __init__(self): 493 | pass 494 | 495 | @classmethod 496 | def INPUT_TYPES(s): 497 | return { 498 | "required": { 499 | "input": (("*",{})), 500 | "mode": (["raw value", "tensor shape"],), 501 | }, 502 | } 503 | 504 | @classmethod 505 | def VALIDATE_INPUTS(s, input_types): 506 | return True 507 | 508 | RETURN_TYPES = ("STRING",) 509 | FUNCTION = "execute" 510 | OUTPUT_NODE = True 511 | 512 | CATEGORY = "essentials/utilities" 513 | 514 | def execute(self, input, mode): 515 | if mode == "tensor shape": 516 | text = [] 517 | def tensorShape(tensor): 518 | if isinstance(tensor, dict): 519 | for k in tensor: 520 | tensorShape(tensor[k]) 521 | elif isinstance(tensor, list): 522 | for i in range(len(tensor)): 523 | tensorShape(tensor[i]) 524 | elif hasattr(tensor, 'shape'): 525 | text.append(list(tensor.shape)) 526 | 527 | tensorShape(input) 528 | input = text 529 | 530 | text = str(input) 531 | 532 | return {"ui": {"text": text}, 
"result": (text,)} 533 | 534 | MISC_CLASS_MAPPINGS = { 535 | "BatchCount+": BatchCount, 536 | "ConsoleDebug+": ConsoleDebug, 537 | "DebugTensorShape+": DebugTensorShape, 538 | "DisplayAny": DisplayAny, 539 | "ModelCompile+": ModelCompile, 540 | "RemoveLatentMask+": RemoveLatentMask, 541 | "SDXLEmptyLatentSizePicker+": SDXLEmptyLatentSizePicker, 542 | "SimpleComparison+": SimpleComparison, 543 | "SimpleCondition+": SimpleCondition, 544 | "SimpleMath+": SimpleMath, 545 | "SimpleMathDual+": SimpleMathDual, 546 | "SimpleMathCondition+": SimpleMathCondition, 547 | "SimpleMathBoolean+": SimpleMathBoolean, 548 | "SimpleMathFloat+": SimpleMathFloat, 549 | "SimpleMathInt+": SimpleMathInt, 550 | "SimpleMathPercent+": SimpleMathPercent, 551 | "SimpleMathSlider+": SimpleMathSlider, 552 | "SimpleMathSliderLowRes+": SimpleMathSliderLowRes, 553 | } 554 | 555 | MISC_NAME_MAPPINGS = { 556 | "BatchCount+": "🔧 Batch Count", 557 | "ConsoleDebug+": "🔧 Console Debug", 558 | "DebugTensorShape+": "🔧 Debug Tensor Shape", 559 | "DisplayAny": "🔧 Display Any", 560 | "ModelCompile+": "🔧 Model Compile", 561 | "RemoveLatentMask+": "🔧 Remove Latent Mask", 562 | "SDXLEmptyLatentSizePicker+": "🔧 Empty Latent Size Picker", 563 | "SimpleComparison+": "🔧 Simple Comparison", 564 | "SimpleCondition+": "🔧 Simple Condition", 565 | "SimpleMath+": "🔧 Simple Math", 566 | "SimpleMathDual+": "🔧 Simple Math Dual", 567 | "SimpleMathCondition+": "🔧 Simple Math Condition", 568 | "SimpleMathBoolean+": "🔧 Simple Math Boolean", 569 | "SimpleMathFloat+": "🔧 Simple Math Float", 570 | "SimpleMathInt+": "🔧 Simple Math Int", 571 | "SimpleMathPercent+": "🔧 Simple Math Percent", 572 | "SimpleMathSlider+": "🔧 Simple Math Slider", 573 | "SimpleMathSliderLowRes+": "🔧 Simple Math Slider low-res", 574 | } -------------------------------------------------------------------------------- /mask.py: -------------------------------------------------------------------------------- 1 | from nodes import SaveImage 2 | import torch 3 | import torchvision.transforms.v2 as T 4 | import random 5 | import folder_paths 6 | import comfy.utils 7 | from .image import ImageExpandBatch 8 | from .utils import AnyType 9 | import numpy as np 10 | import scipy 11 | from PIL import Image 12 | from nodes import MAX_RESOLUTION 13 | import math 14 | 15 | any = AnyType("*") 16 | 17 | class MaskBlur: 18 | @classmethod 19 | def INPUT_TYPES(s): 20 | return { 21 | "required": { 22 | "mask": ("MASK",), 23 | "amount": ("INT", { "default": 6, "min": 0, "max": 256, "step": 1, }), 24 | "device": (["auto", "cpu", "gpu"],), 25 | } 26 | } 27 | 28 | RETURN_TYPES = ("MASK",) 29 | FUNCTION = "execute" 30 | CATEGORY = "essentials/mask" 31 | 32 | def execute(self, mask, amount, device): 33 | if amount == 0: 34 | return (mask,) 35 | 36 | if "gpu" == device: 37 | mask = mask.to(comfy.model_management.get_torch_device()) 38 | elif "cpu" == device: 39 | mask = mask.to('cpu') 40 | 41 | if amount % 2 == 0: 42 | amount+= 1 43 | 44 | if mask.dim() == 2: 45 | mask = mask.unsqueeze(0) 46 | 47 | mask = T.functional.gaussian_blur(mask.unsqueeze(1), amount).squeeze(1) 48 | 49 | if "gpu" == device or "cpu" == device: 50 | mask = mask.to(comfy.model_management.intermediate_device()) 51 | 52 | return(mask,) 53 | 54 | class MaskFlip: 55 | @classmethod 56 | def INPUT_TYPES(s): 57 | return { 58 | "required": { 59 | "mask": ("MASK",), 60 | "axis": (["x", "y", "xy"],), 61 | } 62 | } 63 | 64 | RETURN_TYPES = ("MASK",) 65 | FUNCTION = "execute" 66 | CATEGORY = "essentials/mask" 67 | 68 | def execute(self, mask, axis): 69 
| if mask.dim() == 2: 70 | mask = mask.unsqueeze(0) 71 | 72 | dim = () 73 | if "y" in axis: 74 | dim += (1,) 75 | if "x" in axis: 76 | dim += (2,) 77 | mask = torch.flip(mask, dims=dim) 78 | 79 | return(mask,) 80 | 81 | class MaskPreview(SaveImage): 82 | def __init__(self): 83 | self.output_dir = folder_paths.get_temp_directory() 84 | self.type = "temp" 85 | self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) 86 | self.compress_level = 4 87 | 88 | @classmethod 89 | def INPUT_TYPES(s): 90 | return { 91 | "required": {"mask": ("MASK",), }, 92 | "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, 93 | } 94 | 95 | FUNCTION = "execute" 96 | CATEGORY = "essentials/mask" 97 | 98 | def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): 99 | preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) 100 | return self.save_images(preview, filename_prefix, prompt, extra_pnginfo) 101 | 102 | class MaskBatch: 103 | @classmethod 104 | def INPUT_TYPES(s): 105 | return { 106 | "required": { 107 | "mask1": ("MASK",), 108 | "mask2": ("MASK",), 109 | } 110 | } 111 | 112 | RETURN_TYPES = ("MASK",) 113 | FUNCTION = "execute" 114 | CATEGORY = "essentials/mask batch" 115 | 116 | def execute(self, mask1, mask2): 117 | if mask1.shape[1:] != mask2.shape[1:]: 118 | mask2 = comfy.utils.common_upscale(mask2.unsqueeze(1).expand(-1,3,-1,-1), mask1.shape[2], mask1.shape[1], upscale_method='bicubic', crop='center')[:,0,:,:] 119 | 120 | return (torch.cat((mask1, mask2), dim=0),) 121 | 122 | class MaskExpandBatch: 123 | @classmethod 124 | def INPUT_TYPES(s): 125 | return { 126 | "required": { 127 | "mask": ("MASK",), 128 | "size": ("INT", { "default": 16, "min": 1, "step": 1, }), 129 | "method": (["expand", "repeat all", "repeat first", "repeat last"],) 130 | } 131 | } 132 | 133 | RETURN_TYPES = ("MASK",) 134 | FUNCTION = "execute" 135 | CATEGORY = "essentials/mask batch" 136 | 137 | def execute(self, mask, size, method): 138 | return (ImageExpandBatch().execute(mask.unsqueeze(1).expand(-1,3,-1,-1), size, method)[0][:,0,:,:],) 139 | 140 | 141 | class MaskBoundingBox: 142 | @classmethod 143 | def INPUT_TYPES(s): 144 | return { 145 | "required": { 146 | "mask": ("MASK",), 147 | "padding": ("INT", { "default": 0, "min": 0, "max": 4096, "step": 1, }), 148 | "blur": ("INT", { "default": 0, "min": 0, "max": 256, "step": 1, }), 149 | }, 150 | "optional": { 151 | "image_optional": ("IMAGE",), 152 | } 153 | } 154 | 155 | RETURN_TYPES = ("MASK", "IMAGE", "INT", "INT", "INT", "INT") 156 | RETURN_NAMES = ("MASK", "IMAGE", "x", "y", "width", "height") 157 | FUNCTION = "execute" 158 | CATEGORY = "essentials/mask" 159 | 160 | def execute(self, mask, padding, blur, image_optional=None): 161 | if mask.dim() == 2: 162 | mask = mask.unsqueeze(0) 163 | 164 | if image_optional is None: 165 | image_optional = mask.unsqueeze(3).repeat(1, 1, 1, 3) 166 | 167 | # resize the image if it's not the same size as the mask 168 | if image_optional.shape[1:] != mask.shape[1:]: 169 | image_optional = comfy.utils.common_upscale(image_optional.permute([0,3,1,2]), mask.shape[2], mask.shape[1], upscale_method='bicubic', crop='center').permute([0,2,3,1]) 170 | 171 | # match batch size 172 | if image_optional.shape[0] < mask.shape[0]: 173 | image_optional = torch.cat((image_optional, image_optional[-1].unsqueeze(0).repeat(mask.shape[0]-image_optional.shape[0], 1, 1, 1)), dim=0) 174 | elif image_optional.shape[0] > mask.shape[0]: 
175 | image_optional = image_optional[:mask.shape[0]] 176 | 177 | # blur the mask 178 | if blur > 0: 179 | if blur % 2 == 0: 180 | blur += 1 181 | mask = T.functional.gaussian_blur(mask.unsqueeze(1), blur).squeeze(1) 182 | 183 | _, y, x = torch.where(mask) 184 | x1 = max(0, x.min().item() - padding) 185 | x2 = min(mask.shape[2], x.max().item() + 1 + padding) 186 | y1 = max(0, y.min().item() - padding) 187 | y2 = min(mask.shape[1], y.max().item() + 1 + padding) 188 | 189 | # crop the mask 190 | mask = mask[:, y1:y2, x1:x2] 191 | image_optional = image_optional[:, y1:y2, x1:x2, :] 192 | 193 | return (mask, image_optional, x1, y1, x2 - x1, y2 - y1) 194 | 195 | 196 | class MaskFromColor: 197 | @classmethod 198 | def INPUT_TYPES(s): 199 | return { 200 | "required": { 201 | "image": ("IMAGE", ), 202 | "red": ("INT", { "default": 255, "min": 0, "max": 255, "step": 1, }), 203 | "green": ("INT", { "default": 255, "min": 0, "max": 255, "step": 1, }), 204 | "blue": ("INT", { "default": 255, "min": 0, "max": 255, "step": 1, }), 205 | "threshold": ("INT", { "default": 0, "min": 0, "max": 127, "step": 1, }), 206 | } 207 | } 208 | 209 | RETURN_TYPES = ("MASK",) 210 | FUNCTION = "execute" 211 | CATEGORY = "essentials/mask" 212 | 213 | def execute(self, image, red, green, blue, threshold): 214 | temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int) 215 | color = torch.tensor([red, green, blue]) 216 | lower_bound = (color - threshold).clamp(min=0) 217 | upper_bound = (color + threshold).clamp(max=255) 218 | lower_bound = lower_bound.view(1, 1, 1, 3) 219 | upper_bound = upper_bound.view(1, 1, 1, 3) 220 | mask = (temp >= lower_bound) & (temp <= upper_bound) 221 | mask = mask.all(dim=-1) 222 | mask = mask.float() 223 | 224 | return (mask, ) 225 | 226 | 227 | class MaskFromSegmentation: 228 | @classmethod 229 | def INPUT_TYPES(s): 230 | return { 231 | "required": { 232 | "image": ("IMAGE", ), 233 | "segments": ("INT", { "default": 6, "min": 1, "max": 16, "step": 1, }), 234 | "remove_isolated_pixels": ("INT", { "default": 0, "min": 0, "max": 32, "step": 1, }), 235 | "remove_small_masks": ("FLOAT", { "default": 0.0, "min": 0., "max": 1., "step": 0.01, }), 236 | "fill_holes": ("BOOLEAN", { "default": False }), 237 | } 238 | } 239 | 240 | RETURN_TYPES = ("MASK",) 241 | FUNCTION = "execute" 242 | CATEGORY = "essentials/mask" 243 | 244 | def execute(self, image, segments, remove_isolated_pixels, fill_holes, remove_small_masks): 245 | im = image[0] # we only work on the first image in the batch 246 | im = Image.fromarray((im * 255).to(torch.uint8).cpu().numpy(), mode="RGB") 247 | im = im.quantize(palette=im.quantize(colors=segments), dither=Image.Dither.NONE) 248 | im = torch.tensor(np.array(im.convert("RGB"))).float() / 255.0 249 | 250 | colors = im.reshape(-1, im.shape[-1]) 251 | colors = torch.unique(colors, dim=0) 252 | 253 | masks = [] 254 | for color in colors: 255 | mask = (im == color).all(dim=-1).float() 256 | # remove isolated pixels 257 | if remove_isolated_pixels > 0: 258 | mask = torch.from_numpy(scipy.ndimage.binary_opening(mask.cpu().numpy(), structure=np.ones((remove_isolated_pixels, remove_isolated_pixels)))) 259 | 260 | # fill holes 261 | if fill_holes: 262 | mask = torch.from_numpy(scipy.ndimage.binary_fill_holes(mask.cpu().numpy())) 263 | 264 | # if the mask is too small, it's probably noise 265 | if mask.sum() / (mask.shape[0]*mask.shape[1]) > remove_small_masks: 266 | masks.append(mask) 267 | 268 | if masks == []: 269 | masks.append(torch.zeros_like(im)[:,:,0]) # return an empty mask 
if no masks were found, prevents errors 270 | 271 | mask = torch.stack(masks, dim=0).float() 272 | 273 | return (mask, ) 274 | 275 | 276 | class MaskFix: 277 | @classmethod 278 | def INPUT_TYPES(s): 279 | return { 280 | "required": { 281 | "mask": ("MASK",), 282 | "erode_dilate": ("INT", { "default": 0, "min": -256, "max": 256, "step": 1, }), 283 | "fill_holes": ("INT", { "default": 0, "min": 0, "max": 128, "step": 1, }), 284 | "remove_isolated_pixels": ("INT", { "default": 0, "min": 0, "max": 32, "step": 1, }), 285 | "smooth": ("INT", { "default": 0, "min": 0, "max": 256, "step": 1, }), 286 | "blur": ("INT", { "default": 0, "min": 0, "max": 256, "step": 1, }), 287 | } 288 | } 289 | 290 | RETURN_TYPES = ("MASK",) 291 | FUNCTION = "execute" 292 | CATEGORY = "essentials/mask" 293 | 294 | def execute(self, mask, erode_dilate, smooth, remove_isolated_pixels, blur, fill_holes): 295 | masks = [] 296 | for m in mask: 297 | # erode and dilate 298 | if erode_dilate != 0: 299 | if erode_dilate < 0: 300 | m = torch.from_numpy(scipy.ndimage.grey_erosion(m.cpu().numpy(), size=(-erode_dilate, -erode_dilate))) 301 | else: 302 | m = torch.from_numpy(scipy.ndimage.grey_dilation(m.cpu().numpy(), size=(erode_dilate, erode_dilate))) 303 | 304 | # fill holes 305 | if fill_holes > 0: 306 | #m = torch.from_numpy(scipy.ndimage.binary_fill_holes(m.cpu().numpy(), structure=np.ones((fill_holes,fill_holes)))).float() 307 | m = torch.from_numpy(scipy.ndimage.grey_closing(m.cpu().numpy(), size=(fill_holes, fill_holes))) 308 | 309 | # remove isolated pixels 310 | if remove_isolated_pixels > 0: 311 | m = torch.from_numpy(scipy.ndimage.grey_opening(m.cpu().numpy(), size=(remove_isolated_pixels, remove_isolated_pixels))) 312 | 313 | # smooth the mask 314 | if smooth > 0: 315 | if smooth % 2 == 0: 316 | smooth += 1 317 | m = T.functional.gaussian_blur((m > 0.5).unsqueeze(0), smooth).squeeze(0) 318 | 319 | # blur the mask 320 | if blur > 0: 321 | if blur % 2 == 0: 322 | blur += 1 323 | m = T.functional.gaussian_blur(m.float().unsqueeze(0), blur).squeeze(0) 324 | 325 | masks.append(m.float()) 326 | 327 | masks = torch.stack(masks, dim=0).float() 328 | 329 | return (masks, ) 330 | 331 | class MaskSmooth: 332 | @classmethod 333 | def INPUT_TYPES(s): 334 | return { 335 | "required": { 336 | "mask": ("MASK",), 337 | "amount": ("INT", { "default": 0, "min": 0, "max": 127, "step": 1, }), 338 | } 339 | } 340 | 341 | RETURN_TYPES = ("MASK",) 342 | FUNCTION = "execute" 343 | CATEGORY = "essentials/mask" 344 | 345 | def execute(self, mask, amount): 346 | if amount == 0: 347 | return (mask,) 348 | 349 | if amount % 2 == 0: 350 | amount += 1 351 | 352 | mask = mask > 0.5 353 | mask = T.functional.gaussian_blur(mask.unsqueeze(1), amount).squeeze(1).float() 354 | 355 | return (mask,) 356 | 357 | class MaskFromBatch: 358 | @classmethod 359 | def INPUT_TYPES(s): 360 | return { 361 | "required": { 362 | "mask": ("MASK", ), 363 | "start": ("INT", { "default": 0, "min": 0, "step": 1, }), 364 | "length": ("INT", { "default": 1, "min": 1, "step": 1, }), 365 | } 366 | } 367 | 368 | RETURN_TYPES = ("MASK",) 369 | FUNCTION = "execute" 370 | CATEGORY = "essentials/mask batch" 371 | 372 | def execute(self, mask, start, length): 373 | if length > mask.shape[0]: 374 | length = mask.shape[0] 375 | 376 | start = min(start, mask.shape[0]-1) 377 | length = min(mask.shape[0]-start, length) 378 | return (mask[start:start + length], ) 379 | 380 | class MaskFromList: 381 | @classmethod 382 | def INPUT_TYPES(s): 383 | return { 384 | "required": { 385 | "width": 
("INT", { "default": 32, "min": 0, "max": MAX_RESOLUTION, "step": 8, }), 386 | "height": ("INT", { "default": 32, "min": 0, "max": MAX_RESOLUTION, "step": 8, }), 387 | }, "optional": { 388 | "values": (any, { "default": 0.0, "min": 0.0, "max": 1.0, }), 389 | "str_values": ("STRING", { "default": "", "multiline": True, "placeholder": "0.0, 0.5, 1.0",}), 390 | } 391 | } 392 | 393 | RETURN_TYPES = ("MASK",) 394 | FUNCTION = "execute" 395 | CATEGORY = "essentials/mask" 396 | 397 | def execute(self, width, height, values=None, str_values=""): 398 | out = [] 399 | 400 | if values is not None: 401 | if not isinstance(values, list): 402 | out = [values] 403 | else: 404 | out.extend([float(v) for v in values]) 405 | 406 | if str_values != "": 407 | str_values = [float(v) for v in str_values.split(",")] 408 | out.extend(str_values) 409 | 410 | if out == []: 411 | raise ValueError("No values provided") 412 | 413 | out = torch.tensor(out).float().clamp(0.0, 1.0) 414 | out = out.view(-1, 1, 1).expand(-1, height, width) 415 | 416 | values = None 417 | str_values = "" 418 | 419 | return (out, ) 420 | 421 | class MaskFromRGBCMYBW: 422 | @classmethod 423 | def INPUT_TYPES(s): 424 | return { 425 | "required": { 426 | "image": ("IMAGE", ), 427 | "threshold_r": ("FLOAT", { "default": 0.15, "min": 0.0, "max": 1, "step": 0.01, }), 428 | "threshold_g": ("FLOAT", { "default": 0.15, "min": 0.0, "max": 1, "step": 0.01, }), 429 | "threshold_b": ("FLOAT", { "default": 0.15, "min": 0.0, "max": 1, "step": 0.01, }), 430 | } 431 | } 432 | 433 | RETURN_TYPES = ("MASK","MASK","MASK","MASK","MASK","MASK","MASK","MASK",) 434 | RETURN_NAMES = ("red","green","blue","cyan","magenta","yellow","black","white",) 435 | FUNCTION = "execute" 436 | CATEGORY = "essentials/mask" 437 | 438 | def execute(self, image, threshold_r, threshold_g, threshold_b): 439 | red = ((image[..., 0] >= 1-threshold_r) & (image[..., 1] < threshold_g) & (image[..., 2] < threshold_b)).float() 440 | green = ((image[..., 0] < threshold_r) & (image[..., 1] >= 1-threshold_g) & (image[..., 2] < threshold_b)).float() 441 | blue = ((image[..., 0] < threshold_r) & (image[..., 1] < threshold_g) & (image[..., 2] >= 1-threshold_b)).float() 442 | 443 | cyan = ((image[..., 0] < threshold_r) & (image[..., 1] >= 1-threshold_g) & (image[..., 2] >= 1-threshold_b)).float() 444 | magenta = ((image[..., 0] >= 1-threshold_r) & (image[..., 1] < threshold_g) & (image[..., 2] > 1-threshold_b)).float() 445 | yellow = ((image[..., 0] >= 1-threshold_r) & (image[..., 1] >= 1-threshold_g) & (image[..., 2] < threshold_b)).float() 446 | 447 | black = ((image[..., 0] <= threshold_r) & (image[..., 1] <= threshold_g) & (image[..., 2] <= threshold_b)).float() 448 | white = ((image[..., 0] >= 1-threshold_r) & (image[..., 1] >= 1-threshold_g) & (image[..., 2] >= 1-threshold_b)).float() 449 | 450 | return (red, green, blue, cyan, magenta, yellow, black, white,) 451 | 452 | class TransitionMask: 453 | @classmethod 454 | def INPUT_TYPES(s): 455 | return { 456 | "required": { 457 | "width": ("INT", { "default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1, }), 458 | "height": ("INT", { "default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1, }), 459 | "frames": ("INT", { "default": 16, "min": 1, "max": 9999, "step": 1, }), 460 | "start_frame": ("INT", { "default": 0, "min": 0, "step": 1, }), 461 | "end_frame": ("INT", { "default": 9999, "min": 0, "step": 1, }), 462 | "transition_type": (["horizontal slide", "vertical slide", "horizontal bar", "vertical bar", "center box", "horizontal door", 
"vertical door", "circle", "fade"],), 463 | "timing_function": (["linear", "in", "out", "in-out"],) 464 | } 465 | } 466 | 467 | RETURN_TYPES = ("MASK",) 468 | FUNCTION = "execute" 469 | CATEGORY = "essentials/mask" 470 | 471 | def linear(self, i, t): 472 | return i/t 473 | def ease_in(self, i, t): 474 | return pow(i/t, 2) 475 | def ease_out(self, i, t): 476 | return 1 - pow(1 - i/t, 2) 477 | def ease_in_out(self, i, t): 478 | if i < t/2: 479 | return pow(i/(t/2), 2) / 2 480 | else: 481 | return 1 - pow(1 - (i - t/2)/(t/2), 2) / 2 482 | 483 | def execute(self, width, height, frames, start_frame, end_frame, transition_type, timing_function): 484 | if timing_function == 'in': 485 | timing_function = self.ease_in 486 | elif timing_function == 'out': 487 | timing_function = self.ease_out 488 | elif timing_function == 'in-out': 489 | timing_function = self.ease_in_out 490 | else: 491 | timing_function = self.linear 492 | 493 | out = [] 494 | 495 | end_frame = min(frames, end_frame) 496 | transition = end_frame - start_frame 497 | 498 | if start_frame > 0: 499 | out = out + [torch.full((height, width), 0.0, dtype=torch.float32, device="cpu")] * start_frame 500 | 501 | for i in range(transition): 502 | frame = torch.full((height, width), 0.0, dtype=torch.float32, device="cpu") 503 | progress = timing_function(i, transition-1) 504 | 505 | if "horizontal slide" in transition_type: 506 | pos = round(width*progress) 507 | frame[:, :pos] = 1.0 508 | elif "vertical slide" in transition_type: 509 | pos = round(height*progress) 510 | frame[:pos, :] = 1.0 511 | elif "box" in transition_type: 512 | box_w = round(width*progress) 513 | box_h = round(height*progress) 514 | x1 = (width - box_w) // 2 515 | y1 = (height - box_h) // 2 516 | x2 = x1 + box_w 517 | y2 = y1 + box_h 518 | frame[y1:y2, x1:x2] = 1.0 519 | elif "circle" in transition_type: 520 | radius = math.ceil(math.sqrt(pow(width,2)+pow(height,2))*progress/2) 521 | c_x = width // 2 522 | c_y = height // 2 523 | # is this real life? Am I hallucinating? 
524 | x = torch.arange(0, width, dtype=torch.float32, device="cpu") 525 | y = torch.arange(0, height, dtype=torch.float32, device="cpu") 526 | y, x = torch.meshgrid((y, x), indexing="ij") 527 | circle = ((x - c_x) ** 2 + (y - c_y) ** 2) <= (radius ** 2) 528 | frame[circle] = 1.0 529 | elif "horizontal bar" in transition_type: 530 | bar = round(height*progress) 531 | y1 = (height - bar) // 2 532 | y2 = y1 + bar 533 | frame[y1:y2, :] = 1.0 534 | elif "vertical bar" in transition_type: 535 | bar = round(width*progress) 536 | x1 = (width - bar) // 2 537 | x2 = x1 + bar 538 | frame[:, x1:x2] = 1.0 539 | elif "horizontal door" in transition_type: 540 | bar = math.ceil(height*progress/2) 541 | if bar > 0: 542 | frame[:bar, :] = 1.0 543 | frame[-bar:, :] = 1.0 544 | elif "vertical door" in transition_type: 545 | bar = math.ceil(width*progress/2) 546 | if bar > 0: 547 | frame[:, :bar] = 1.0 548 | frame[:, -bar:] = 1.0 549 | elif "fade" in transition_type: 550 | frame[:,:] = progress 551 | 552 | out.append(frame) 553 | 554 | if end_frame < frames: 555 | out = out + [torch.full((height, width), 1.0, dtype=torch.float32, device="cpu")] * (frames - end_frame) 556 | 557 | out = torch.stack(out, dim=0) 558 | 559 | return (out, ) 560 | 561 | MASK_CLASS_MAPPINGS = { 562 | "MaskBlur+": MaskBlur, 563 | "MaskBoundingBox+": MaskBoundingBox, 564 | "MaskFix+": MaskFix, 565 | "MaskFlip+": MaskFlip, 566 | "MaskFromColor+": MaskFromColor, 567 | "MaskFromList+": MaskFromList, 568 | "MaskFromRGBCMYBW+": MaskFromRGBCMYBW, 569 | "MaskFromSegmentation+": MaskFromSegmentation, 570 | "MaskPreview+": MaskPreview, 571 | "MaskSmooth+": MaskSmooth, 572 | "TransitionMask+": TransitionMask, 573 | 574 | # Batch 575 | "MaskBatch+": MaskBatch, 576 | "MaskExpandBatch+": MaskExpandBatch, 577 | "MaskFromBatch+": MaskFromBatch, 578 | } 579 | 580 | MASK_NAME_MAPPINGS = { 581 | "MaskBlur+": "🔧 Mask Blur", 582 | "MaskFix+": "🔧 Mask Fix", 583 | "MaskFlip+": "🔧 Mask Flip", 584 | "MaskFromColor+": "🔧 Mask From Color", 585 | "MaskFromList+": "🔧 Mask From List", 586 | "MaskFromRGBCMYBW+": "🔧 Mask From RGB/CMY/BW", 587 | "MaskFromSegmentation+": "🔧 Mask From Segmentation", 588 | "MaskPreview+": "🔧 Mask Preview", 589 | "MaskBoundingBox+": "🔧 Mask Bounding Box", 590 | "MaskSmooth+": "🔧 Mask Smooth", 591 | "TransitionMask+": "🔧 Transition Mask", 592 | 593 | "MaskBatch+": "🔧 Mask Batch", 594 | "MaskExpandBatch+": "🔧 Mask Expand Batch", 595 | "MaskFromBatch+": "🔧 Mask From Batch", 596 | } 597 | -------------------------------------------------------------------------------- /sampling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import comfy.samplers 3 | import comfy.sample 4 | import torch 5 | from nodes import common_ksampler, CLIPTextEncode 6 | from comfy.utils import ProgressBar 7 | from .utils import expand_mask, FONTS_DIR, parse_string_to_list 8 | import torchvision.transforms.v2 as T 9 | import torch.nn.functional as F 10 | import logging 11 | import folder_paths 12 | 13 | # From https://github.com/BlenderNeko/ComfyUI_Noise/ 14 | def slerp(val, low, high): 15 | dims = low.shape 16 | 17 | low = low.reshape(dims[0], -1) 18 | high = high.reshape(dims[0], -1) 19 | 20 | low_norm = low/torch.norm(low, dim=1, keepdim=True) 21 | high_norm = high/torch.norm(high, dim=1, keepdim=True) 22 | 23 | low_norm[low_norm != low_norm] = 0.0 24 | high_norm[high_norm != high_norm] = 0.0 25 | 26 | omega = torch.acos((low_norm*high_norm).sum(1)) 27 | so = torch.sin(omega) 28 | res = 
(torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high 29 | 30 | return res.reshape(dims) 31 | 32 | class KSamplerVariationsWithNoise: 33 | @classmethod 34 | def INPUT_TYPES(s): 35 | return {"required": { 36 | "model": ("MODEL", ), 37 | "latent_image": ("LATENT", ), 38 | "main_seed": ("INT:seed", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 39 | "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), 40 | "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), 41 | "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), 42 | "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), 43 | "positive": ("CONDITIONING", ), 44 | "negative": ("CONDITIONING", ), 45 | "variation_strength": ("FLOAT", {"default": 0.17, "min": 0.0, "max": 1.0, "step":0.01, "round": 0.01}), 46 | #"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), 47 | #"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), 48 | #"return_with_leftover_noise": (["disable", "enable"], ), 49 | "variation_seed": ("INT:seed", {"default": 12345, "min": 0, "max": 0xffffffffffffffff}), 50 | "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01, "round": 0.01}), 51 | }} 52 | 53 | RETURN_TYPES = ("LATENT",) 54 | FUNCTION = "execute" 55 | CATEGORY = "essentials/sampling" 56 | 57 | def prepare_mask(self, mask, shape): 58 | mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") 59 | mask = mask.expand((-1,shape[1],-1,-1)) 60 | if mask.shape[0] < shape[0]: 61 | mask = mask.repeat((shape[0] -1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]] 62 | return mask 63 | 64 | def execute(self, model, latent_image, main_seed, steps, cfg, sampler_name, scheduler, positive, negative, variation_strength, variation_seed, denoise): 65 | if main_seed == variation_seed: 66 | variation_seed += 1 67 | 68 | end_at_step = steps #min(steps, end_at_step) 69 | start_at_step = round(end_at_step - end_at_step * denoise) 70 | 71 | force_full_denoise = True 72 | disable_noise = True 73 | 74 | device = comfy.model_management.get_torch_device() 75 | 76 | # Generate base noise 77 | batch_size, _, height, width = latent_image["samples"].shape 78 | generator = torch.manual_seed(main_seed) 79 | base_noise = torch.randn((1, 4, height, width), dtype=torch.float32, device="cpu", generator=generator).repeat(batch_size, 1, 1, 1).cpu() 80 | 81 | # Generate variation noise 82 | generator = torch.manual_seed(variation_seed) 83 | variation_noise = torch.randn((batch_size, 4, height, width), dtype=torch.float32, device="cpu", generator=generator).cpu() 84 | 85 | slerp_noise = slerp(variation_strength, base_noise, variation_noise) 86 | 87 | # Calculate sigma 88 | comfy.model_management.load_model_gpu(model) 89 | sampler = comfy.samplers.KSampler(model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=1.0, model_options=model.model_options) 90 | sigmas = sampler.sigmas 91 | sigma = sigmas[start_at_step] - sigmas[end_at_step] 92 | sigma /= model.model.latent_format.scale_factor 93 | sigma = sigma.detach().cpu().item() 94 | 95 | work_latent = latent_image.copy() 96 | work_latent["samples"] = latent_image["samples"].clone() + slerp_noise * sigma 97 | 98 | # if there's a mask we need to expand it to avoid artifacts, 5 pixels should be enough 99 | if "noise_mask" in latent_image: 100 | noise_mask = self.prepare_mask(latent_image["noise_mask"], 
latent_image['samples'].shape) 101 | work_latent["samples"] = noise_mask * work_latent["samples"] + (1-noise_mask) * latent_image["samples"] 102 | work_latent['noise_mask'] = expand_mask(latent_image["noise_mask"].clone(), 5, True) 103 | 104 | return common_ksampler(model, main_seed, steps, cfg, sampler_name, scheduler, positive, negative, work_latent, denoise=1.0, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) 105 | 106 | 107 | class KSamplerVariationsStochastic: 108 | @classmethod 109 | def INPUT_TYPES(s): 110 | return {"required":{ 111 | "model": ("MODEL",), 112 | "latent_image": ("LATENT", ), 113 | "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 114 | "steps": ("INT", {"default": 25, "min": 1, "max": 10000}), 115 | "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), 116 | "sampler": (comfy.samplers.KSampler.SAMPLERS, ), 117 | "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), 118 | "positive": ("CONDITIONING", ), 119 | "negative": ("CONDITIONING", ), 120 | "variation_seed": ("INT:seed", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 121 | "variation_strength": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step":0.05, "round": 0.01}), 122 | #"variation_sampler": (comfy.samplers.KSampler.SAMPLERS, ), 123 | "cfg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.05, "round": 0.01}), 124 | }} 125 | 126 | RETURN_TYPES = ("LATENT", ) 127 | FUNCTION = "execute" 128 | CATEGORY = "essentials/sampling" 129 | 130 | def execute(self, model, latent_image, noise_seed, steps, cfg, sampler, scheduler, positive, negative, variation_seed, variation_strength, cfg_scale, variation_sampler="dpmpp_2m_sde"): 131 | # Stage 1: composition sampler 132 | force_full_denoise = False # return with leftover noise = "enable" 133 | disable_noise = False # add noise = "enable" 134 | 135 | end_at_step = max(int(steps * (1-variation_strength)), 1) 136 | start_at_step = 0 137 | 138 | work_latent = latent_image.copy() 139 | batch_size = work_latent["samples"].shape[0] 140 | work_latent["samples"] = work_latent["samples"][0].unsqueeze(0) 141 | 142 | stage1 = common_ksampler(model, noise_seed, steps, cfg, sampler, scheduler, positive, negative, work_latent, denoise=1.0, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)[0] 143 | 144 | if batch_size > 1: 145 | stage1["samples"] = stage1["samples"].clone().repeat(batch_size, 1, 1, 1) 146 | 147 | # Stage 2: variation sampler 148 | force_full_denoise = True 149 | disable_noise = True 150 | cfg = max(cfg * cfg_scale, 1.0) 151 | start_at_step = end_at_step 152 | end_at_step = steps 153 | 154 | return common_ksampler(model, variation_seed, steps, cfg, variation_sampler, scheduler, positive, negative, stage1, denoise=1.0, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) 155 | 156 | class InjectLatentNoise: 157 | @classmethod 158 | def INPUT_TYPES(s): 159 | return {"required": { 160 | "latent": ("LATENT", ), 161 | "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 162 | "noise_strength": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step":0.01, "round": 0.01}), 163 | "normalize": (["false", "true"], {"default": "false"}), 164 | }, 165 | "optional": { 166 | "mask": ("MASK", ), 167 | }} 168 | 169 | RETURN_TYPES = ("LATENT",) 170 | FUNCTION = "execute" 171 | 
CATEGORY = "essentials/sampling" 172 | 173 | def execute(self, latent, noise_seed, noise_strength, normalize="false", mask=None): 174 | torch.manual_seed(noise_seed) 175 | noise_latent = latent.copy() 176 | original_samples = noise_latent["samples"].clone() 177 | random_noise = torch.randn_like(original_samples) 178 | 179 | if normalize == "true": 180 | mean = original_samples.mean() 181 | std = original_samples.std() 182 | random_noise = random_noise * std + mean 183 | 184 | random_noise = original_samples + random_noise * noise_strength 185 | 186 | if mask is not None: 187 | mask = F.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(random_noise.shape[2], random_noise.shape[3]), mode="bilinear") 188 | mask = mask.expand((-1,random_noise.shape[1],-1,-1)).clamp(0.0, 1.0) 189 | if mask.shape[0] < random_noise.shape[0]: 190 | mask = mask.repeat((random_noise.shape[0] -1) // mask.shape[0] + 1, 1, 1, 1)[:random_noise.shape[0]] 191 | elif mask.shape[0] > random_noise.shape[0]: 192 | mask = mask[:random_noise.shape[0]] 193 | random_noise = mask * random_noise + (1-mask) * original_samples 194 | 195 | noise_latent["samples"] = random_noise 196 | 197 | return (noise_latent, ) 198 | 199 | class TextEncodeForSamplerParams: 200 | @classmethod 201 | def INPUT_TYPES(s): 202 | return { 203 | "required": { 204 | "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "Separate prompts with at least three dashes\n---\nLike so"}), 205 | "clip": ("CLIP", ) 206 | }} 207 | 208 | RETURN_TYPES = ("CONDITIONING", ) 209 | FUNCTION = "execute" 210 | CATEGORY = "essentials/sampling" 211 | 212 | def execute(self, text, clip): 213 | import re 214 | output_text = [] 215 | output_encoded = [] 216 | text = re.sub(r'[-*=~]{4,}\n', '---\n', text) 217 | text = text.split("---\n") 218 | 219 | for t in text: 220 | t = t.strip() 221 | if t: 222 | output_text.append(t) 223 | output_encoded.append(CLIPTextEncode().encode(clip, t)[0]) 224 | 225 | #if len(output_encoded) == 1: 226 | # output = output_encoded[0] 227 | #else: 228 | output = {"text": output_text, "encoded": output_encoded} 229 | 230 | return (output, ) 231 | 232 | class SamplerSelectHelper: 233 | @classmethod 234 | def INPUT_TYPES(s): 235 | return {"required": { 236 | **{s: ("BOOLEAN", { "default": False }) for s in comfy.samplers.KSampler.SAMPLERS}, 237 | }} 238 | 239 | RETURN_TYPES = ("STRING", ) 240 | FUNCTION = "execute" 241 | CATEGORY = "essentials/sampling" 242 | 243 | def execute(self, **values): 244 | values = [v for v in values if values[v]] 245 | values = ", ".join(values) 246 | 247 | return (values, ) 248 | 249 | class SchedulerSelectHelper: 250 | @classmethod 251 | def INPUT_TYPES(s): 252 | return {"required": { 253 | **{s: ("BOOLEAN", { "default": False }) for s in comfy.samplers.KSampler.SCHEDULERS}, 254 | }} 255 | 256 | RETURN_TYPES = ("STRING", ) 257 | FUNCTION = "execute" 258 | CATEGORY = "essentials/sampling" 259 | 260 | def execute(self, **values): 261 | values = [v for v in values if values[v]] 262 | values = ", ".join(values) 263 | 264 | return (values, ) 265 | 266 | class LorasForFluxParams: 267 | @classmethod 268 | def INPUT_TYPES(s): 269 | optional_loras = ['none'] + folder_paths.get_filename_list("loras") 270 | return { 271 | "required": { 272 | "lora_1": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), 273 | "strength_model_1": ("STRING", { "multiline": False, "dynamicPrompts": False, "default": "1.0" }), 274 | }, 275 | #"optional": { 276 | # "lora_2": (optional_loras, 
), 277 | # "strength_lora_2": ("STRING", { "multiline": False, "dynamicPrompts": False }), 278 | # "lora_3": (optional_loras, ), 279 | # "strength_lora_3": ("STRING", { "multiline": False, "dynamicPrompts": False }), 280 | # "lora_4": (optional_loras, ), 281 | # "strength_lora_4": ("STRING", { "multiline": False, "dynamicPrompts": False }), 282 | #} 283 | } 284 | 285 | RETURN_TYPES = ("LORA_PARAMS", ) 286 | FUNCTION = "execute" 287 | CATEGORY = "essentials/sampling" 288 | 289 | def execute(self, lora_1, strength_model_1, lora_2="none", strength_lora_2="", lora_3="none", strength_lora_3="", lora_4="none", strength_lora_4=""): 290 | output = { "loras": [], "strengths": [] } 291 | output["loras"].append(lora_1) 292 | output["strengths"].append(parse_string_to_list(strength_model_1)) 293 | 294 | if lora_2 != "none": 295 | output["loras"].append(lora_2) 296 | if strength_lora_2 == "": 297 | strength_lora_2 = "1.0" 298 | output["strengths"].append(parse_string_to_list(strength_lora_2)) 299 | if lora_3 != "none": 300 | output["loras"].append(lora_3) 301 | if strength_lora_3 == "": 302 | strength_lora_3 = "1.0" 303 | output["strengths"].append(parse_string_to_list(strength_lora_3)) 304 | if lora_4 != "none": 305 | output["loras"].append(lora_4) 306 | if strength_lora_4 == "": 307 | strength_lora_4 = "1.0" 308 | output["strengths"].append(parse_string_to_list(strength_lora_4)) 309 | 310 | return (output,) 311 | 312 | 313 | class FluxSamplerParams: 314 | def __init__(self): 315 | self.loraloader = None 316 | self.lora = (None, None) 317 | 318 | @classmethod 319 | def INPUT_TYPES(s): 320 | return {"required": { 321 | "model": ("MODEL", ), 322 | "conditioning": ("CONDITIONING", ), 323 | "latent_image": ("LATENT", ), 324 | 325 | "seed": ("STRING", { "multiline": False, "dynamicPrompts": False, "default": "?" }), 326 | "sampler": ("STRING", { "multiline": False, "dynamicPrompts": False, "default": "euler" }), 327 | "scheduler": ("STRING", { "multiline": False, "dynamicPrompts": False, "default": "simple" }), 328 | "steps": ("STRING", { "multiline": False, "dynamicPrompts": False, "default": "20" }), 329 | "guidance": ("STRING", { "multiline": False, "dynamicPrompts": False, "default": "3.5" }), 330 | "max_shift": ("STRING", { "multiline": False, "dynamicPrompts": False, "default": "" }), 331 | "base_shift": ("STRING", { "multiline": False, "dynamicPrompts": False, "default": "" }), 332 | "denoise": ("STRING", { "multiline": False, "dynamicPrompts": False, "default": "1.0" }), 333 | }, 334 | "optional": { 335 | "loras": ("LORA_PARAMS",), 336 | }} 337 | 338 | RETURN_TYPES = ("LATENT","SAMPLER_PARAMS") 339 | RETURN_NAMES = ("latent", "params") 340 | FUNCTION = "execute" 341 | CATEGORY = "essentials/sampling" 342 | 343 | def execute(self, model, conditioning, latent_image, seed, sampler, scheduler, steps, guidance, max_shift, base_shift, denoise, loras=None): 344 | import random 345 | import time 346 | from comfy_extras.nodes_custom_sampler import Noise_RandomNoise, BasicScheduler, BasicGuider, SamplerCustomAdvanced 347 | from comfy_extras.nodes_latent import LatentBatch 348 | from comfy_extras.nodes_model_advanced import ModelSamplingFlux, ModelSamplingAuraFlow 349 | from node_helpers import conditioning_set_values 350 | from nodes import LoraLoader 351 | 352 | is_schnell = model.model.model_type == comfy.model_base.ModelType.FLOW 353 | 354 | noise = seed.replace("\n", ",").split(",") 355 | noise = [random.randint(0, 999999) if "?" 
in n else int(n) for n in noise] 356 | if not noise: 357 | noise = [random.randint(0, 999999)] 358 | 359 | if sampler == '*': 360 | sampler = comfy.samplers.KSampler.SAMPLERS 361 | elif sampler.startswith("!"): 362 | sampler = sampler.replace("\n", ",").split(",") 363 | sampler = [s.strip("! ") for s in sampler] 364 | sampler = [s for s in comfy.samplers.KSampler.SAMPLERS if s not in sampler] 365 | else: 366 | sampler = sampler.replace("\n", ",").split(",") 367 | sampler = [s.strip() for s in sampler if s.strip() in comfy.samplers.KSampler.SAMPLERS] 368 | if not sampler: 369 | sampler = ['ipndm'] 370 | 371 | if scheduler == '*': 372 | scheduler = comfy.samplers.KSampler.SCHEDULERS 373 | elif scheduler.startswith("!"): 374 | scheduler = scheduler.replace("\n", ",").split(",") 375 | scheduler = [s.strip("! ") for s in scheduler] 376 | scheduler = [s for s in comfy.samplers.KSampler.SCHEDULERS if s not in scheduler] 377 | else: 378 | scheduler = scheduler.replace("\n", ",").split(",") 379 | scheduler = [s.strip() for s in scheduler] 380 | scheduler = [s for s in scheduler if s in comfy.samplers.KSampler.SCHEDULERS] 381 | if not scheduler: 382 | scheduler = ['simple'] 383 | 384 | if steps == "": 385 | if is_schnell: 386 | steps = "4" 387 | else: 388 | steps = "20" 389 | steps = parse_string_to_list(steps) 390 | 391 | denoise = "1.0" if denoise == "" else denoise 392 | denoise = parse_string_to_list(denoise) 393 | 394 | guidance = "3.5" if guidance == "" else guidance 395 | guidance = parse_string_to_list(guidance) 396 | 397 | if not is_schnell: 398 | max_shift = "1.15" if max_shift == "" else max_shift 399 | base_shift = "0.5" if base_shift == "" else base_shift 400 | else: 401 | max_shift = "0" 402 | base_shift = "1.0" if base_shift == "" else base_shift 403 | 404 | max_shift = parse_string_to_list(max_shift) 405 | base_shift = parse_string_to_list(base_shift) 406 | 407 | cond_text = None 408 | if isinstance(conditioning, dict) and "encoded" in conditioning: 409 | cond_text = conditioning["text"] 410 | cond_encoded = conditioning["encoded"] 411 | else: 412 | cond_encoded = [conditioning] 413 | 414 | out_latent = None 415 | out_params = [] 416 | 417 | basicschedueler = BasicScheduler() 418 | basicguider = BasicGuider() 419 | samplercustomadvanced = SamplerCustomAdvanced() 420 | latentbatch = LatentBatch() 421 | modelsamplingflux = ModelSamplingFlux() if not is_schnell else ModelSamplingAuraFlow() 422 | width = latent_image["samples"].shape[3]*8 423 | height = latent_image["samples"].shape[2]*8 424 | 425 | lora_strength_len = 1 426 | if loras: 427 | lora_model = loras["loras"] 428 | lora_strength = loras["strengths"] 429 | lora_strength_len = sum(len(i) for i in lora_strength) 430 | 431 | if self.loraloader is None: 432 | self.loraloader = LoraLoader() 433 | 434 | # count total number of samples 435 | total_samples = len(cond_encoded) * len(noise) * len(max_shift) * len(base_shift) * len(guidance) * len(sampler) * len(scheduler) * len(steps) * len(denoise) * lora_strength_len 436 | current_sample = 0 437 | if total_samples > 1: 438 | pbar = ProgressBar(total_samples) 439 | 440 | lora_strength_len = 1 441 | if loras: 442 | lora_strength_len = len(lora_strength[0]) 443 | 444 | for los in range(lora_strength_len): 445 | if loras: 446 | patched_model = self.loraloader.load_lora(model, None, lora_model[0], lora_strength[0][los], 0)[0] 447 | else: 448 | patched_model = model 449 | 450 | for i in range(len(cond_encoded)): 451 | conditioning = cond_encoded[i] 452 | ct = cond_text[i] if cond_text else 
None 453 | for n in noise: 454 | randnoise = Noise_RandomNoise(n) 455 | for ms in max_shift: 456 | for bs in base_shift: 457 | if is_schnell: 458 | work_model = modelsamplingflux.patch_aura(patched_model, bs)[0] 459 | else: 460 | work_model = modelsamplingflux.patch(patched_model, ms, bs, width, height)[0] 461 | for g in guidance: 462 | cond = conditioning_set_values(conditioning, {"guidance": g}) 463 | guider = basicguider.get_guider(work_model, cond)[0] 464 | for s in sampler: 465 | samplerobj = comfy.samplers.sampler_object(s) 466 | for sc in scheduler: 467 | for st in steps: 468 | for d in denoise: 469 | sigmas = basicschedueler.get_sigmas(work_model, sc, st, d)[0] 470 | current_sample += 1 471 | log = f"Sampling {current_sample}/{total_samples} with seed {n}, sampler {s}, scheduler {sc}, steps {st}, guidance {g}, max_shift {ms}, base_shift {bs}, denoise {d}" 472 | lora_name = None 473 | lora_str = 0 474 | if loras: 475 | lora_name = lora_model[0] 476 | lora_str = lora_strength[0][los] 477 | log += f", lora {lora_name}, lora_strength {lora_str}" 478 | logging.info(log) 479 | start_time = time.time() 480 | latent = samplercustomadvanced.sample(randnoise, guider, samplerobj, sigmas, latent_image)[1] 481 | elapsed_time = time.time() - start_time 482 | out_params.append({"time": elapsed_time, 483 | "seed": n, 484 | "width": width, 485 | "height": height, 486 | "sampler": s, 487 | "scheduler": sc, 488 | "steps": st, 489 | "guidance": g, 490 | "max_shift": ms, 491 | "base_shift": bs, 492 | "denoise": d, 493 | "prompt": ct, 494 | "lora": lora_name, 495 | "lora_strength": lora_str}) 496 | 497 | if out_latent is None: 498 | out_latent = latent 499 | else: 500 | out_latent = latentbatch.batch(out_latent, latent)[0] 501 | if total_samples > 1: 502 | pbar.update(1) 503 | 504 | return (out_latent, out_params) 505 | 506 | class PlotParameters: 507 | @classmethod 508 | def INPUT_TYPES(s): 509 | return {"required": { 510 | "images": ("IMAGE", ), 511 | "params": ("SAMPLER_PARAMS", ), 512 | "order_by": (["none", "time", "seed", "steps", "denoise", "sampler", "scheduler", "guidance", "max_shift", "base_shift", "lora_strength"], ), 513 | "cols_value": (["none", "time", "seed", "steps", "denoise", "sampler", "scheduler", "guidance", "max_shift", "base_shift", "lora_strength"], ), 514 | "cols_num": ("INT", {"default": -1, "min": -1, "max": 1024 }), 515 | "add_prompt": (["false", "true", "excerpt"], ), 516 | "add_params": (["false", "true", "changes only"], {"default": "true"}), 517 | }} 518 | 519 | RETURN_TYPES = ("IMAGE", ) 520 | FUNCTION = "execute" 521 | CATEGORY = "essentials/sampling" 522 | 523 | def execute(self, images, params, order_by, cols_value, cols_num, add_prompt, add_params): 524 | from PIL import Image, ImageDraw, ImageFont 525 | import math 526 | import textwrap 527 | 528 | if images.shape[0] != len(params): 529 | raise ValueError("Number of images and number of parameters do not match.") 530 | 531 | _params = params.copy() 532 | 533 | if order_by != "none": 534 | sorted_params = sorted(_params, key=lambda x: x[order_by]) 535 | indices = [_params.index(item) for item in sorted_params] 536 | images = images[torch.tensor(indices)] 537 | _params = sorted_params 538 | 539 | if cols_value != "none" and cols_num > -1: 540 | groups = {} 541 | for p in _params: 542 | value = p[cols_value] 543 | if value not in groups: 544 | groups[value] = [] 545 | groups[value].append(p) 546 | cols_num = len(groups) 547 | 548 | sorted_params = [] 549 | groups = list(groups.values()) 550 | for g in zip(*groups): 
551 | sorted_params.extend(g) 552 | 553 | indices = [_params.index(item) for item in sorted_params] 554 | images = images[torch.tensor(indices)] 555 | _params = sorted_params 556 | elif cols_num == 0: 557 | cols_num = int(math.sqrt(images.shape[0])) 558 | cols_num = max(1, min(cols_num, 1024)) 559 | 560 | width = images.shape[2] 561 | out_image = [] 562 | 563 | font = ImageFont.truetype(os.path.join(FONTS_DIR, 'ShareTechMono-Regular.ttf'), min(48, int(32*(width/1024)))) 564 | text_padding = 3 565 | line_height = font.getmask('Q').getbbox()[3] + font.getmetrics()[1] + text_padding*2 566 | char_width = font.getbbox('M')[2]+1 # using monospace font 567 | 568 | if add_params == "changes only": 569 | value_tracker = {} 570 | for p in _params: 571 | for key, value in p.items(): 572 | if key != "time": 573 | if key not in value_tracker: 574 | value_tracker[key] = set() 575 | value_tracker[key].add(value) 576 | changing_keys = {key for key, values in value_tracker.items() if len(values) > 1 or key == "prompt"} 577 | 578 | result = [] 579 | for p in _params: 580 | changing_params = {key: value for key, value in p.items() if key in changing_keys} 581 | result.append(changing_params) 582 | 583 | _params = result 584 | 585 | for (image, param) in zip(images, _params): 586 | image = image.permute(2, 0, 1) 587 | 588 | if add_params != "false": 589 | if add_params == "changes only": 590 | text = "\n".join([f"{key}: {value}" for key, value in param.items() if key != "prompt"]) 591 | else: 592 | text = f"time: {param['time']:.2f}s, seed: {param['seed']}, steps: {param['steps']}, size: {param['width']}×{param['height']}\ndenoise: {param['denoise']}, sampler: {param['sampler']}, sched: {param['scheduler']}\nguidance: {param['guidance']}, max/base shift: {param['max_shift']}/{param['base_shift']}" 593 | if 'lora' in param and param['lora']: 594 | text += f"\nLoRA: {param['lora'][:32]}, str: {param['lora_strength']}" 595 | 596 | lines = text.split("\n") 597 | text_height = line_height * len(lines) 598 | text_image = Image.new('RGB', (width, text_height), color=(0, 0, 0)) 599 | 600 | for i, line in enumerate(lines): 601 | draw = ImageDraw.Draw(text_image) 602 | draw.text((text_padding, i * line_height + text_padding), line, font=font, fill=(255, 255, 255)) 603 | 604 | text_image = T.ToTensor()(text_image).to(image.device) 605 | image = torch.cat([image, text_image], 1) 606 | 607 | if 'prompt' in param and param['prompt'] and add_prompt != "false": 608 | prompt = param['prompt'] 609 | if add_prompt == "excerpt": 610 | prompt = " ".join(param['prompt'].split()[:64]) 611 | prompt += "..." 
612 | 613 | cols = math.ceil(width / char_width) 614 | prompt_lines = textwrap.wrap(prompt, width=cols) 615 | prompt_height = line_height * len(prompt_lines) 616 | prompt_image = Image.new('RGB', (width, prompt_height), color=(0, 0, 0)) 617 | 618 | for i, line in enumerate(prompt_lines): 619 | draw = ImageDraw.Draw(prompt_image) 620 | draw.text((text_padding, i * line_height + text_padding), line, font=font, fill=(255, 255, 255)) 621 | 622 | prompt_image = T.ToTensor()(prompt_image).to(image.device) 623 | image = torch.cat([image, prompt_image], 1) 624 | 625 | # a little cleanup 626 | image = torch.nan_to_num(image, nan=0.0).clamp(0.0, 1.0) 627 | out_image.append(image) 628 | 629 | # ensure all images have the same height 630 | if add_prompt != "false" or add_params == "changes only": 631 | max_height = max([image.shape[1] for image in out_image]) 632 | out_image = [F.pad(image, (0, 0, 0, max_height - image.shape[1])) for image in out_image] 633 | 634 | out_image = torch.stack(out_image, 0).permute(0, 2, 3, 1) 635 | 636 | # merge images 637 | if cols_num > -1: 638 | cols = min(cols_num, out_image.shape[0]) 639 | b, h, w, c = out_image.shape 640 | rows = math.ceil(b / cols) 641 | 642 | # Pad the tensor if necessary 643 | if b % cols != 0: 644 | padding = cols - (b % cols) 645 | out_image = F.pad(out_image, (0, 0, 0, 0, 0, 0, 0, padding)) 646 | b = out_image.shape[0] 647 | 648 | # Reshape and transpose 649 | out_image = out_image.reshape(rows, cols, h, w, c) 650 | out_image = out_image.permute(0, 2, 1, 3, 4) 651 | out_image = out_image.reshape(rows * h, cols * w, c).unsqueeze(0) 652 | 653 | """ 654 | width = out_image.shape[2] 655 | # add the title and notes on top 656 | if title and export_labels: 657 | title_font = ImageFont.truetype(os.path.join(FONTS_DIR, 'ShareTechMono-Regular.ttf'), 48) 658 | title_width = title_font.getbbox(title)[2] 659 | title_padding = 6 660 | title_line_height = title_font.getmask(title).getbbox()[3] + title_font.getmetrics()[1] + title_padding*2 661 | title_text_height = title_line_height 662 | title_text_image = Image.new('RGB', (width, title_text_height), color=(0, 0, 0, 0)) 663 | 664 | draw = ImageDraw.Draw(title_text_image) 665 | draw.text((width//2 - title_width//2, title_padding), title, font=title_font, fill=(255, 255, 255)) 666 | 667 | title_text_image = T.ToTensor()(title_text_image).unsqueeze(0).permute([0,2,3,1]).to(out_image.device) 668 | out_image = torch.cat([title_text_image, out_image], 1) 669 | """ 670 | 671 | return (out_image, ) 672 | 673 | class GuidanceTimestepping: 674 | @classmethod 675 | def INPUT_TYPES(s): 676 | return { 677 | "required": { 678 | "model": ("MODEL",), 679 | "value": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.05}), 680 | "start_at": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01}), 681 | "end_at": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.01}), 682 | } 683 | } 684 | 685 | RETURN_TYPES = ("MODEL",) 686 | FUNCTION = "execute" 687 | CATEGORY = "essentials/sampling" 688 | 689 | def execute(self, model, value, start_at, end_at): 690 | sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at) 691 | sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at) 692 | 693 | def apply_apg(args): 694 | cond = args["cond"] 695 | uncond = args["uncond"] 696 | cond_scale = args["cond_scale"] 697 | sigma = args["sigma"] 698 | 699 | sigma = sigma.detach().cpu()[0].item() 700 | 701 | if sigma <= sigma_start and sigma > sigma_end: 702 | cond_scale 
= value 703 | 704 | return uncond + (cond - uncond) * cond_scale 705 | 706 | m = model.clone() 707 | m.set_model_sampler_cfg_function(apply_apg) 708 | return (m,) 709 | 710 | class ModelSamplingDiscreteFlowCustom(torch.nn.Module): 711 | def __init__(self, model_config=None): 712 | super().__init__() 713 | if model_config is not None: 714 | sampling_settings = model_config.sampling_settings 715 | else: 716 | sampling_settings = {} 717 | 718 | self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000)) 719 | 720 | def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000, cut_off=1.0, shift_multiplier=0): 721 | self.shift = shift 722 | self.multiplier = multiplier 723 | self.cut_off = cut_off 724 | self.shift_multiplier = shift_multiplier 725 | ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier) 726 | self.register_buffer('sigmas', ts) 727 | 728 | @property 729 | def sigma_min(self): 730 | return self.sigmas[0] 731 | 732 | @property 733 | def sigma_max(self): 734 | return self.sigmas[-1] 735 | 736 | def timestep(self, sigma): 737 | return sigma * self.multiplier 738 | 739 | def sigma(self, timestep): 740 | shift = self.shift 741 | if timestep.dim() == 0: 742 | t = timestep.cpu().item() / self.multiplier 743 | if t <= self.cut_off: 744 | shift = shift * self.shift_multiplier 745 | 746 | return comfy.model_sampling.time_snr_shift(shift, timestep / self.multiplier) 747 | 748 | def percent_to_sigma(self, percent): 749 | if percent <= 0.0: 750 | return 1.0 751 | if percent >= 1.0: 752 | return 0.0 753 | return 1.0 - percent 754 | 755 | class ModelSamplingSD3Advanced: 756 | @classmethod 757 | def INPUT_TYPES(s): 758 | return {"required": { "model": ("MODEL",), 759 | "shift": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step":0.01}), 760 | "cut_off": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step":0.05}), 761 | "shift_multiplier": ("FLOAT", {"default": 2, "min": 0, "max": 10, "step":0.05}), 762 | }} 763 | 764 | RETURN_TYPES = ("MODEL",) 765 | FUNCTION = "execute" 766 | 767 | CATEGORY = "essentials/sampling" 768 | 769 | def execute(self, model, shift, multiplier=1000, cut_off=1.0, shift_multiplier=0): 770 | m = model.clone() 771 | 772 | 773 | sampling_base = ModelSamplingDiscreteFlowCustom 774 | sampling_type = comfy.model_sampling.CONST 775 | 776 | class ModelSamplingAdvanced(sampling_base, sampling_type): 777 | pass 778 | 779 | model_sampling = ModelSamplingAdvanced(model.model.model_config) 780 | model_sampling.set_parameters(shift=shift, multiplier=multiplier, cut_off=cut_off, shift_multiplier=shift_multiplier) 781 | m.add_object_patch("model_sampling", model_sampling) 782 | 783 | return (m, ) 784 | 785 | SAMPLING_CLASS_MAPPINGS = { 786 | "KSamplerVariationsStochastic+": KSamplerVariationsStochastic, 787 | "KSamplerVariationsWithNoise+": KSamplerVariationsWithNoise, 788 | "InjectLatentNoise+": InjectLatentNoise, 789 | "FluxSamplerParams+": FluxSamplerParams, 790 | "GuidanceTimestepping+": GuidanceTimestepping, 791 | "PlotParameters+": PlotParameters, 792 | "TextEncodeForSamplerParams+": TextEncodeForSamplerParams, 793 | "SamplerSelectHelper+": SamplerSelectHelper, 794 | "SchedulerSelectHelper+": SchedulerSelectHelper, 795 | "LorasForFluxParams+": LorasForFluxParams, 796 | "ModelSamplingSD3Advanced+": ModelSamplingSD3Advanced, 797 | } 798 | 799 | SAMPLING_NAME_MAPPINGS = { 800 | "KSamplerVariationsStochastic+": "🔧 KSampler Stochastic Variations", 801 | "KSamplerVariationsWithNoise+": "🔧 
KSampler Variations with Noise Injection", 802 | "InjectLatentNoise+": "🔧 Inject Latent Noise", 803 | "FluxSamplerParams+": "🔧 Flux Sampler Parameters", 804 | "GuidanceTimestepping+": "🔧 Guidance Timestep (experimental)", 805 | "PlotParameters+": "🔧 Plot Sampler Parameters", 806 | "TextEncodeForSamplerParams+": "🔧 Text Encode for Sampler Params", 807 | "SamplerSelectHelper+": "🔧 Sampler Select Helper", 808 | "SchedulerSelectHelper+": "🔧 Scheduler Select Helper", 809 | "LorasForFluxParams+": "🔧 LoRA for Flux Parameters", 810 | "ModelSamplingSD3Advanced+": "🔧 Model Sampling SD3 Advanced", 811 | } --------------------------------------------------------------------------------
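
FluxSamplerParams above treats seed, sampler, scheduler, steps, guidance, max_shift, base_shift and denoise as comma- or newline-separated lists and samples every combination. A minimal sketch of how the grid size works out, mirroring the total_samples computation in execute(); split_values is a simplified stand-in for the repository's parse_string_to_list helper (assumed to live in utils.py), not the real implementation:

# Simplified stand-in for parse_string_to_list (assumption: the real helper in
# utils.py is more flexible about number formats); used here only to show grid size.
def split_values(text, cast=float):
    return [cast(v.strip()) for v in text.replace("\n", ",").split(",") if v.strip()]

seeds      = [42, 123]                    # seed      = "42, 123"
samplers   = ["euler", "dpmpp_2m"]        # sampler   = "euler, dpmpp_2m"
schedulers = ["simple"]                   # scheduler = "simple"
steps      = split_values("20, 30", int)  # -> [20, 30]
guidance   = split_values("3.5")          # -> [3.5]
denoise    = split_values("1.0")          # -> [1.0]

# Mirrors the total_samples count in FluxSamplerParams.execute (max_shift,
# base_shift and LoRA strengths each contribute another factor of 1 here).
total = (len(seeds) * len(samplers) * len(schedulers)
         * len(steps) * len(guidance) * len(denoise))
print(total)  # 2 * 2 * 1 * 2 * 1 * 1 = 8 latents batched into one output

A handful of values per field multiplies quickly, which is why the node only creates a ProgressBar when total_samples is greater than 1.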
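
PlotParameters assembles the labelled images into a rows × cols grid by padding the batch and reshaping. A small sketch of that reshape on toy tensors, assuming the same (batch, height, width, channels) layout used above:

import math
import torch
import torch.nn.functional as F

b, h, w, c = 5, 2, 3, 1                  # five tiny placeholder "images"
images = torch.arange(b * h * w * c, dtype=torch.float32).reshape(b, h, w, c)

cols = 2
rows = math.ceil(b / cols)
if b % cols != 0:                        # pad the batch so it fills the grid
    images = F.pad(images, (0, 0, 0, 0, 0, 0, 0, cols - (b % cols)))

grid = (images.reshape(rows, cols, h, w, c)
              .permute(0, 2, 1, 3, 4)    # -> rows, h, cols, w, c
              .reshape(rows * h, cols * w, c)
              .unsqueeze(0))
print(grid.shape)                        # torch.Size([1, 6, 6, 1])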
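
GuidanceTimestepping overrides the CFG scale with a fixed value only while the current sigma lies inside the window derived from start_at/end_at. A self-contained sketch of the same classifier-free-guidance mix on plain tensors; the sigma_start/sigma_end numbers are illustrative, not values taken from a real model:

import torch

def cfg_mix(cond, uncond, cond_scale, sigma, value, sigma_start, sigma_end):
    # Inside the sigma window the configured scale is replaced by `value`,
    # as in apply_apg above; outside it the original cond_scale is kept.
    if sigma_end < sigma <= sigma_start:
        cond_scale = value
    return uncond + (cond - uncond) * cond_scale

cond, uncond = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
mid  = cfg_mix(cond, uncond, cond_scale=7.0, sigma=0.7, value=2.0,
               sigma_start=0.9, sigma_end=0.3)  # window active -> scale 2.0
tail = cfg_mix(cond, uncond, cond_scale=7.0, sigma=0.1, value=2.0,
               sigma_start=0.9, sigma_end=0.3)  # window inactive -> scale 7.0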
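
SAMPLING_CLASS_MAPPINGS and SAMPLING_NAME_MAPPINGS follow ComfyUI's usual node-registration pattern. A plausible way such per-module dictionaries get merged into the exported tables in __init__.py (a hypothetical excerpt, not verified against this repository's actual file):

# Hypothetical excerpt of an __init__.py that aggregates per-module mappings
# (names and structure assumed, not taken from this repository's file).
from .sampling import SAMPLING_CLASS_MAPPINGS, SAMPLING_NAME_MAPPINGS

NODE_CLASS_MAPPINGS = {**SAMPLING_CLASS_MAPPINGS}
NODE_DISPLAY_NAME_MAPPINGS = {**SAMPLING_NAME_MAPPINGS}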