├── .gitignore ├── py ├── inspyrenet │ ├── __init__.py │ ├── config.yaml │ ├── modules │ │ ├── decoder_module.py │ │ ├── context_module.py │ │ ├── attention_module.py │ │ └── layers.py │ ├── InSPyReNet.py │ └── utils.py ├── md.py ├── combo_setter.py ├── inspynet.py ├── color_adjustment.py ├── switch.py ├── image_size_adjustment.py ├── image_cropper.py ├── image_selector.py └── bridge_preview.py ├── requirements.txt ├── assets ├── crop.jpg ├── size.jpg ├── noise.jpg ├── switch.jpg ├── refresh.png ├── FastCanvas.png ├── color_adjust.jpg └── CachePreviewBridge.png ├── pyproject.toml ├── .github └── workflows │ └── publish.yml ├── web ├── pip_manager.js ├── install_dependencies.js ├── lib │ └── fabric-slim.min.js ├── counter.js ├── Util.js ├── multi_button_widget.js ├── PB.js ├── image_loader_counter.js ├── combo_setter.js ├── queue_shortcut.js ├── bridge_preview.js ├── lg_group_muter.js ├── image_size_adjustment.js ├── color_adjustment.js ├── upload.js └── image_cropper.js ├── LICENSE.txt ├── __init__.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /py/inspyrenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .Remover import Remover, console 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-contrib-python 2 | transparent-background 3 | -------------------------------------------------------------------------------- /assets/crop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/crop.jpg -------------------------------------------------------------------------------- /assets/size.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/size.jpg -------------------------------------------------------------------------------- /assets/noise.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/noise.jpg -------------------------------------------------------------------------------- /assets/switch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/switch.jpg -------------------------------------------------------------------------------- /assets/refresh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/refresh.png -------------------------------------------------------------------------------- /assets/FastCanvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/FastCanvas.png -------------------------------------------------------------------------------- /assets/color_adjust.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/color_adjust.jpg -------------------------------------------------------------------------------- /assets/CachePreviewBridge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/CachePreviewBridge.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | 
[project] 2 | name = "comfyui_lg_tools" 3 | description = "This is a toolset designed for ComfyUI by LAOGOU-666, providing a series of practical image processing and operation nodes, making our operation more intuitive and convenient" 4 | version = "1.3.7" 5 | license = { text = "MIT License" } 6 | dependencies = ["opencv-contrib-python"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/LAOGOU-666/Comfyui_LG_Tools" 10 | 11 | [tool.comfy] 12 | PublisherId = "laogou666" 13 | DisplayName = "Comfyui_LG_Tools" 14 | Icon = "" 15 | -------------------------------------------------------------------------------- /py/inspyrenet/config.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | url: "https://github.com/plemeri/transparent-background/releases/download/1.2.12/ckpt_base.pth" 3 | md5: "d692e3dd5fa1b9658949d452bebf1cda" 4 | ckpt_name: "ckpt_base.pth" 5 | http_proxy: NULL 6 | base_size: [1024, 1024] 7 | 8 | 9 | fast: 10 | url: "https://github.com/plemeri/transparent-background/releases/download/1.2.12/ckpt_fast.pth" 11 | md5: "9efdbfbcc49b79ef0f7891c83d2fd52f" 12 | ckpt_name: "ckpt_fast.pth" 13 | http_proxy: NULL 14 | base_size: [384, 384] 15 | 16 | base-nightly: 17 | url: "https://github.com/plemeri/transparent-background/releases/download/1.2.12/ckpt_base_nightly.pth" 18 | md5: NULL 19 | ckpt_name: "ckpt_base_nightly.pth" 20 | http_proxy: NULL 21 | base_size: [1024, 1024] 22 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'LAOGOU-666' }} 18 | steps: 
class AlwaysEqualProxy(str):
    """String that compares equal to every other object.

    ``==`` always returns True and ``!=`` always returns False, whatever
    the other operand is.
    """

    def __eq__(self, _):
        return True

    def __ne__(self, _):
        return False

    # Defining __eq__ in a class body implicitly sets __hash__ to None, which
    # made instances unhashable (unusable in sets or as dict keys) even
    # though they are strings.  Re-expose the inherited string hash.
    # Note: "equal to everything" cannot satisfy the usual hash/eq contract;
    # the hash is simply that of the underlying text.
    __hash__ = str.__hash__


class AnyType(str):
    """Special string representing "any type"; always equal in comparisons."""

    def __eq__(self, _) -> bool:
        return True

    def __ne__(self, __value: object) -> bool:
        return False

    # Same reasoning as AlwaysEqualProxy: keep the wildcard hashable.
    __hash__ = str.__hash__


# Wildcard instance used in node type declarations.  The name shadows the
# builtin ``any`` but is kept because sibling modules reference it through
# ``from .md import *``.
any = AnyType("*")
const onNodeCreated = nodeType.prototype.onNodeCreated; 12 | nodeType.prototype.onNodeCreated = async function () { 13 | const me = onNodeCreated?.apply(this); 14 | terminal.setupNode(this); 15 | return me; 16 | }; 17 | 18 | const onDrawForeground = nodeType.prototype.onDrawForeground; 19 | nodeType.prototype.onDrawForeground = function (ctx, graphcanvas) { 20 | return terminal.updateNode(this, onDrawForeground, ctx, graphcanvas); 21 | }; 22 | } 23 | }, 24 | }); -------------------------------------------------------------------------------- /web/install_dependencies.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../scripts/app.js"; 2 | import { TerminalManager } from "./Util.js"; 3 | 4 | const terminal = new TerminalManager("/lg/install_dependencies", "LG_InstallDependencies"); 5 | 6 | app.registerExtension({ 7 | name: "LG.InstallDependencies", 8 | 9 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 10 | if (nodeData.name === "LG_InstallDependencies") { 11 | const onNodeCreated = nodeType.prototype.onNodeCreated; 12 | nodeType.prototype.onNodeCreated = async function () { 13 | const me = onNodeCreated?.apply(this); 14 | terminal.setupNode(this); 15 | return me; 16 | }; 17 | 18 | const onDrawForeground = nodeType.prototype.onDrawForeground; 19 | nodeType.prototype.onDrawForeground = function (ctx, graphcanvas) { 20 | return terminal.updateNode(this, onDrawForeground, ctx, graphcanvas); 21 | }; 22 | } 23 | }, 24 | }); -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 LAOGOU-666 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to 
use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /py/inspyrenet/modules/decoder_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .layers import * 6 | class PAA_d(nn.Module): 7 | def __init__(self, in_channel, out_channel=1, depth=64, base_size=None, stage=None): 8 | super(PAA_d, self).__init__() 9 | self.conv1 = Conv2d(in_channel ,depth, 3) 10 | self.conv2 = Conv2d(depth, depth, 3) 11 | self.conv3 = Conv2d(depth, depth, 3) 12 | self.conv4 = Conv2d(depth, depth, 3) 13 | self.conv5 = Conv2d(depth, out_channel, 3, bn=False) 14 | 15 | self.base_size = base_size 16 | self.stage = stage 17 | 18 | if base_size is not None and stage is not None: 19 | self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage)) 20 | else: 21 | self.stage_size = [None, None] 22 | 23 | self.Hattn = SelfAttention(depth, 'h', self.stage_size[0]) 24 | self.Wattn = SelfAttention(depth, 'w', self.stage_size[1]) 25 | 26 | self.upsample = lambda img, size: F.interpolate(img, 
CATEGORY_TYPE = "🎈LAOGOU/Utils"


class ComboSetter:
    """Dynamic combo setter.

    Takes two multi-line strings (labels and prompts, one entry per line)
    and returns the prompt whose position matches the selected label.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "labels": ("STRING", {
                    "default": "",
                    "multiline": True,
                    "placeholder": "每行一个标签"
                }),
                "prompts": ("STRING", {
                    "default": "",
                    "multiline": True,
                    "placeholder": "每行一个提示词"
                }),
                "selected": ("STRING", {
                    "default": ""
                }),
            }
        }

    RETURN_TYPES = ("STRING", "STRING",)
    RETURN_NAMES = ("selected_label", "selected_prompt",)
    FUNCTION = "execute"
    CATEGORY = CATEGORY_TYPE

    @staticmethod
    def _non_empty_lines(text):
        """Split *text* on newlines, strip each line and drop blank ones."""
        return [line.strip() for line in text.split('\n') if line.strip()]

    def execute(self, labels, prompts, selected):
        """Return ``(selected, prompt)`` for the label named by *selected*.

        Args:
            labels: Newline-separated labels; blank lines are ignored.
            prompts: Newline-separated prompts; blank lines are ignored.
            selected: The label to look up.

        Returns:
            A 2-tuple of the unchanged *selected* value and the prompt at
            the same index as the first matching label, or ``""`` when the
            label is unknown or has no prompt at that index.
        """
        label_lines = self._non_empty_lines(labels)
        prompt_lines = self._non_empty_lines(prompts)

        # Labels are stripped above, so compare against a stripped selection
        # as well; otherwise a value with stray whitespace never matches.
        wanted = selected.strip()

        try:
            selected_prompt = prompt_lines[label_lines.index(wanted)]
        except (ValueError, IndexError):
            # ValueError: label not present; IndexError: fewer prompts than labels.
            selected_prompt = ""

        return (selected, selected_prompt)


NODE_CLASS_MAPPINGS = {
    "ComboSetter": ComboSetter,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ComboSetter": "🎈ComboSetter",
}
import importlib.util
import os
import sys
import json

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
WEB_DIRECTORY = "web"
python = sys.executable


def get_ext_dir(subpath=None, mkdir=False):
    """Return the absolute path of this package directory or a path inside it.

    Args:
        subpath: Optional path joined onto the package directory.
        mkdir: When true, create the resulting directory if it is missing.

    Returns:
        The absolute path as a string.
    """
    # Local renamed from ``dir`` so the builtin is no longer shadowed.
    path = os.path.dirname(__file__)
    if subpath is not None:
        path = os.path.join(path, subpath)

    path = os.path.abspath(path)

    if mkdir and not os.path.exists(path):
        # exist_ok guards against another process creating the directory
        # between the existence check and the call.
        os.makedirs(path, exist_ok=True)
    return path


def serialize(obj):
    """Return *obj* unchanged if it is a basic JSON-style value, else ``str(obj)``."""
    if isinstance(obj, (str, int, float, bool, list, dict, type(None))):
        return obj
    return str(obj)


py = get_ext_dir("py")
files = os.listdir(py)
all_nodes = {}
for file in files:
    if not file.endswith(".py"):
        continue
    name = os.path.splitext(file)[0]
    imported_module = importlib.import_module(".py.{}".format(name), __name__)

    # Modules without node mappings are skipped.  Only the missing-attribute
    # case is tolerated; the previous bare ``except: pass`` also hid genuine
    # errors such as malformed mapping objects.
    try:
        class_mappings = imported_module.NODE_CLASS_MAPPINGS
    except AttributeError:
        continue
    NODE_CLASS_MAPPINGS = {**NODE_CLASS_MAPPINGS, **class_mappings}

    try:
        display_mappings = imported_module.NODE_DISPLAY_NAME_MAPPINGS
    except AttributeError:
        continue
    NODE_DISPLAY_NAME_MAPPINGS = {**NODE_DISPLAY_NAME_MAPPINGS, **display_mappings}

    all_nodes[file] = {
        "NODE_CLASS_MAPPINGS": {k: serialize(v) for k, v in class_mappings.items()},
        "NODE_DISPLAY_NAME_MAPPINGS": {k: serialize(v) for k, v in display_mappings.items()},
    }


__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
t}from"./fabric.js";const fabric={};fabric.version=t.version,fabric.isTouchSupported=t.isTouchSupported,fabric.isLikelyNode=t.isLikelyNode,fabric.SHARED_ATTRIBUTES=t.SHARED_ATTRIBUTES,fabric.DPI=t.DPI,fabric.reNum=t.reNum,fabric.devicePixelRatio=t.devicePixelRatio,fabric.browserShadowBlurConstant=t.browserShadowBlurConstant,fabric.enableGLFiltering=t.enableGLFiltering,fabric.initFilterBackend=t.initFilterBackend,fabric.util={addClass:t.util.addClass,animate:t.util.animate,animateColor:t.util.animateColor,applyTransformToObject:t.util.applyTransformToObject,cos:t.util.cos,sin:t.util.sin,rotatePoint:t.util.rotatePoint,transformPoint:t.util.transformPoint,makeElement:t.util.makeElement,createCanvasElement:t.util.createCanvasElement,createImage:t.util.createImage,loadImage:t.util.loadImage,toFixed:t.util.toFixed,multiplyTransformMatrices:t.util.multiplyTransformMatrices,invertTransform:t.util.invertTransform,getRandomInt:t.util.getRandomInt,degreesToRadians:t.util.degreesToRadians,radiansToDegrees:t.util.radiansToDegrees,object:t.util.object,string:t.util.string,array:t.util.array,setImageSmoothing:t.util.setImageSmoothing,getPointer:t.util.getPointer,falseFunction:t.util.falseFunction,requestAnimFrame:t.util.requestAnimFrame,cancelAnimFrame:t.util.cancelAnimFrame},fabric.Observable=t.Observable,fabric.Collection=t.Collection,fabric.Point=t.Point,fabric.Intersection=t.Intersection,fabric.Color=t.Color,fabric.Object=t.Object,fabric.Canvas=t.Canvas,fabric.StaticCanvas=t.StaticCanvas,fabric.Rect=t.Rect,fabric.Image=t.Image,fabric.Text=t.Text,fabric.IText=t.IText,fabric.Textbox=t.Textbox,fabric.Shadow=t.Shadow;export{fabric}; -------------------------------------------------------------------------------- /py/inspyrenet/modules/context_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .layers import * 6 | 7 | class PAA_kernel(nn.Module): 8 | 
class PAA_kernel(nn.Module):
    """One context branch: three convolutions, H/W attention, then a dilated conv.

    ``Conv2d`` and ``SelfAttention`` come from the sibling ``layers`` module;
    their internals are not visible here.
    """

    def __init__(self, in_channel, out_channel, receptive_size, stage_size=None):
        super().__init__()
        k = receptive_size
        self.conv0 = Conv2d(in_channel, out_channel, 1)
        self.conv1 = Conv2d(out_channel, out_channel, kernel_size=(1, k))
        self.conv2 = Conv2d(out_channel, out_channel, kernel_size=(k, 1))
        self.conv3 = Conv2d(out_channel, out_channel, 3, dilation=k)

        # Each attention module receives its own entry of stage_size, or
        # None when no stage size was given.
        h_size = None if stage_size is None else stage_size[0]
        self.Hattn = SelfAttention(out_channel, 'h', h_size)
        w_size = None if stage_size is None else stage_size[1]
        self.Wattn = SelfAttention(out_channel, 'w', w_size)

    def forward(self, x):
        """Apply conv0 -> conv1 -> conv2, sum both attention outputs, apply conv3."""
        feat = self.conv2(self.conv1(self.conv0(x)))
        attended = self.Hattn(feat) + self.Wattn(feat)
        return self.conv3(attended)


class PAA_e(nn.Module):
    """Four parallel branches fused by a convolution, plus a residual path."""

    def __init__(self, in_channel, out_channel, base_size=None, stage=None):
        super().__init__()
        self.relu = nn.ReLU(True)

        # The stage size is the base size divided by 2**stage; it is only
        # defined when both values are supplied.
        if base_size is None or stage is None:
            self.stage_size = None
        else:
            scale = 2 ** stage
            self.stage_size = (base_size[0] // scale, base_size[1] // scale)

        self.branch0 = Conv2d(in_channel, out_channel, 1)
        self.branch1 = PAA_kernel(in_channel, out_channel, 3, self.stage_size)
        self.branch2 = PAA_kernel(in_channel, out_channel, 5, self.stage_size)
        self.branch3 = PAA_kernel(in_channel, out_channel, 7, self.stage_size)

        self.conv_cat = Conv2d(4 * out_channel, out_channel, 3)
        self.conv_res = Conv2d(in_channel, out_channel, 1)

    def forward(self, x):
        """Concatenate the branch outputs on dim 1, fuse, add the residual, apply ReLU."""
        branch_outputs = (
            self.branch0(x),
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
        )
        fused = self.conv_cat(torch.cat(branch_outputs, 1))
        return self.relu(fused + self.conv_res(x))
def tensor2pil(image):
    """Convert an image tensor scaled to [0, 1] into an 8-bit PIL image.

    Values are multiplied by 255 and clipped to [0, 255]; singleton axes
    are squeezed away before the array is handed to PIL.
    """
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))


def pil2tensor(image):
    """Convert a PIL image (or array-like) into a float32 tensor in [0, 1].

    A leading batch axis of size 1 is added.
    """
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)


class InspyrenetRembgLoader:
    """Node that builds a ``Remover`` model for the chosen mode."""

    def __init__(self):
        self.model = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mode": (["base", "fast"],),
                "torchscript_jit": (["default", "on"],),
            },
        }

    RETURN_TYPES = ("INSPYRENET_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "image"

    def load_model(self, mode, torchscript_jit):
        """Create the remover; JIT is enabled only when the option is ``"on"``."""
        jit = torchscript_jit == "on"
        self.model = Remover(mode=mode, jit=jit)
        return (self.model,)


class InspyrenetRembgProcess:
    """Node that removes the background of each image in a batch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("INSPYRENET_MODEL",),
                "image": ("IMAGE",),
                "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "background_color": ("STRING", {"default": "", "multiline": False}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "process_image"
    CATEGORY = "image"

    def process_image(self, model, image, threshold, background_color):
        """Run background removal and optionally flatten onto a solid colour.

        Args:
            model: Object exposing ``process(pil_image, type=..., threshold=...)``.
            image: Iterable of image tensors (one per batch entry).
            threshold: Forwarded to ``model.process``.
            background_color: Colour string understood by PIL, or blank to
                keep the transparent RGBA result.

        Returns:
            ``(images, masks)``: RGB images when a valid background colour
            was given, RGBA images otherwise; masks are the alpha channel
            of the RGBA result taken before any compositing.
        """
        # Validate the colour once, up front.  Previously an invalid colour
        # was reported as "using transparent background" for every image,
        # yet the result was still sliced down to RGB, silently dropping the
        # alpha channel.  Now an invalid colour behaves like a blank one.
        fill = background_color.strip()
        if fill:
            try:
                Image.new('RGBA', (1, 1), fill)
            except ValueError:
                print(f"无效的颜色值: {background_color},使用透明背景")
                fill = ""

        img_list = []
        mask_list = []
        for img in tqdm(image, "Inspyrenet Rembg"):
            cutout = model.process(tensor2pil(img), type='rgba', threshold=threshold)

            # Take the mask from the alpha channel before compositing.
            rgba_tensor = pil2tensor(cutout)
            mask_list.append(rgba_tensor[:, :, :, 3])

            if fill:
                background = Image.new('RGBA', cutout.size, fill)
                # Composite, then drop alpha by converting to RGB.
                flattened = Image.alpha_composite(background, cutout).convert('RGB')
                img_list.append(pil2tensor(flattened))
            else:
                img_list.append(rgba_tensor)

        img_stack = torch.cat(img_list, dim=0)
        mask_stack = torch.cat(mask_list, dim=0)
        return (img_stack, mask_stack)


NODE_CLASS_MAPPINGS = {
    "InspyrenetRembgLoader": InspyrenetRembgLoader,
    "InspyrenetRembgProcess": InspyrenetRembgProcess
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "InspyrenetRembgLoader": "InSPyReNet Loader",
    "InspyrenetRembgProcess": "InSPyReNet Rembg"
}
将需要的节点拖入工作区并连接使用 24 | 25 | ## 节点说明 26 | 27 | ### 图像裁剪节点 28 | ![Image](./assets/crop.jpg) 29 | - 点击左键框选指定范围进行裁剪 30 | ### 图像尺寸调整节点 31 | ![Image](./assets/size.jpg) 32 | 33 | ### 颜色调整节点 34 | ![Image](./assets/color_adjust.jpg) 35 | 36 | ### FastCanvas画布节点 37 | ![Image](./assets/FastCanvas.png) 38 | > * 1.支持实时调整构图输出图像和选中图层遮罩 39 | > 40 | > * 2.支持批量构图,切换输入图层继承上个图层的位置和缩放 41 | > 42 | > * 3.支持限制画布窗口视图大小的功能,不用担心图片较大占地方了 43 | > 44 | > * 4.图层支持右键辅助功能,变换,设置背景等,点开有惊喜 45 | > 46 | > * 5.支持通过输入端口输入图像, 47 | 支持无输入端口独立使用, 48 | > 49 | > * 6.支持复制,拖拽,以及上传方式对图像处理 50 | > 51 | 注意!fastcanvas tool动态输入节点使用方法: 52 | * bg_img输入背景RGB图片,img输入图层图片,可以输入RGB/RGBA图片 53 | * **系统自带的加载图片节点默认输出的是RGB!不是RGBA(带遮罩通道的图片)!使用加载图像输入RGBA需要合并ALPHA图层!** 54 | ### 开关节点 55 | ![Image](./assets/switch.jpg) 56 | - 一键控制单组/多组的忽略或者禁用模式 57 | - 惰性求值开关lazyswitch(仅运行指定线路,非指定线路无需加载) 58 | - 注意!点击开关节点右键有设置不同模式(忽略和禁用)的功能 59 | 60 | ### 噪波节点 61 | ![Image](./assets/noise.jpg) 62 | - 添加自定义噪波以对图像进行预处理 63 | 64 | ### 桥接预览节点 65 | ![Image](./assets/CachePreviewBridge.png) 66 | 67 | > * 当你使用input模式将图片输入到节点后,点击Cache按钮即可缓存当前图片,然后进行编辑遮罩,并且不会出现遮罩被重置的问题, 68 | > 69 | > * 在点击Cache按钮后,无论输入端口是否连接,是否刷新,都不会影响当前缓存的图片和遮罩,你可以继续在当前节点编辑遮罩并且不会重置缓存 70 | > 71 | > * 现在支持复制功能,相当于加载图片节点和桥接预览节点的集合,对于需要重复操作以及大型工作流的缓存处理能提供很大便利 72 | 73 | ### 加载图像节点和桥接预览节点(新版本) 74 | ![Image](./assets/refresh.png) 75 | > * 新增刷新功能,一键将temp/output文件夹的最新图片刷新到当前节点,对于重复处理流程可以省去复制粘贴的操作 76 | 77 | ### 图片加载器(计数器)节点 🎈LG_ImageLoaderWithCounter 78 | 一个集成计数器功能的图片加载器节点,可从指定文件夹自动加载图片,**内置计数器无需外接**。 79 | 80 | #### 主要功能 81 | - **内置计数器**:自动获取文件夹图片总数,无需外接计数器节点 82 | - **三种加载模式**: 83 | - **increase**: 递增模式,每次加载下一张(0 → 1 → 2...) 84 | - **decrease**: 递减模式,每次加载上一张(9 → 8 → 7...) 
85 | - **all**: 一次性加载所有图片为列表 86 | - **多种排序方式**: 87 | - Alphabetical (ASC/DESC) - 按字母表排序 88 | - Numerical (ASC/DESC) - 按数字排序 89 | - Datetime (ASC/DESC) - 按文件修改时间排序 90 | - **灵活的路径支持**:支持相对路径和绝对路径 91 | - **实时状态显示**:在节点标题栏实时显示当前索引(如:5/100) 92 | - **刷新功能**:节点内置刷新按钮,可重置计数器 93 | 94 | #### 输入参数 95 | - `folder_path` - 图片文件夹路径(支持相对路径和绝对路径) 96 | - `mode` - 加载模式 97 | - increase: 每次执行索引递增,自动循环 98 | - decrease: 每次执行索引递减,自动循环 99 | - all: 一次性加载文件夹内所有图片为列表 100 | - `sort_mode` - 排序模式,决定文件夹中图片的加载顺序 101 | - `keep_index` - 保持索引开关(Boolean) 102 | - True: 保持当前索引不变,暂停计数 103 | - False: 正常递增/递减(默认) 104 | 105 | #### 输出参数(列表格式) 106 | - `images` - 加载的图片列表 (IMAGE LIST) 107 | - `masks` - 图片遮罩列表 (MASK LIST) 108 | - `filenames` - 文件名列表 (STRING LIST) 109 | - `current_index` - 当前索引 (INT) 110 | - `total_images` - 图片总数 (INT) 111 | 112 | **注意**:所有输出都是列表格式,increase/decrease模式输出单元素列表,all模式输出所有图片列表。 113 | 114 | #### 使用示例 115 | 1. **批量顺序处理**:设置mode为increase,每次Queue执行自动加载下一张图片 116 | 2. **倒序处理**:设置mode为decrease,从最后一张图片开始往前处理 117 | 3. **一次性加载所有图片**:设置mode为all,将文件夹内所有图片加载为列表 118 | 4. **按时间顺序处理**:使用Datetime (ASC)排序,按时间顺序处理图片 119 | 5. 
# Per-node state shared between the executing node and the HTTP route:
# node_id -> {"event": Event, "result": tensor or None, "shape": image.shape}
node_data = {}


class ColorAdjustment:
    """Color adjustment node.

    Sends a PNG preview of the first image to the front end, then waits up
    to five seconds for adjusted pixel data to be posted back.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "adjust"
    CATEGORY = "🎈LAOGOU/Image"
    OUTPUT_NODE = True

    def adjust(self, image, unique_id):
        """Return the adjusted image, or the input if none arrives.

        The input is returned unchanged when the front end does not answer
        within the timeout, posts no usable result, or any step fails.
        """
        node_id = unique_id
        event = Event()
        node_data[node_id] = {
            "event": event,
            "result": None,
            "shape": image.shape
        }

        try:
            # Only the first image of the batch is previewed.
            preview_image = (torch.clamp(image.clone(), 0, 1) * 255).cpu().numpy().astype(np.uint8)[0]
            pil_image = Image.fromarray(preview_image)
            buffer = io.BytesIO()
            pil_image.save(buffer, format="PNG")
            base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')

            PromptServer.instance.send_sync("color_adjustment_update", {
                "node_id": node_id,
                "image_data": f"data:image/png;base64,{base64_image}"
            })

            if not event.wait(timeout=5):
                return (image,)

            result_image = node_data[node_id]["result"]
            return (result_image if result_image is not None else image,)

        except Exception:
            # Best effort: fall back to the input image, but leave a trace
            # instead of swallowing the failure silently.
            traceback.print_exc()
            return (image,)

        finally:
            # Single cleanup point replacing the four duplicated deletions.
            node_data.pop(node_id, None)


@PromptServer.instance.routes.post("/color_adjustment/apply")
async def apply_color_adjustment(request):
    """Receive adjusted RGBA pixel data for a waiting node and wake it up.

    Expects a JSON body with ``node_id`` and ``adjusted_data`` (a flat list
    of RGBA byte values).  Responds with ``{"success": bool}`` plus an
    ``error`` message on failure.
    """
    try:
        data = await request.json()
        node_id = data.get("node_id")
        adjusted_data = data.get("adjusted_data")

        if node_id not in node_data:
            return web.json_response({"success": False, "error": "节点数据不存在"})

        try:
            node_info = node_data[node_id]

            if isinstance(adjusted_data, list):
                batch, height, width, channels = node_info["shape"]
                needed = height * width * 4

                if len(adjusted_data) >= needed:
                    # The length check accepts over-long payloads, so trim to
                    # the expected size; reshaping the untrimmed list raised.
                    rgba_array = np.array(adjusted_data[:needed], dtype=np.uint8).reshape(height, width, 4)
                    rgb_array = rgba_array[:, :, :3]
                    # NOTE(review): this reshape only succeeds when the stored
                    # shape holds height*width*3 values (e.g. batch 1 with 3
                    # channels); other shapes raise and the node keeps its
                    # input. Confirm whether batches should be supported.
                    tensor_image = torch.from_numpy(rgb_array / 255.0).float().reshape(batch, height, width, channels)
                    node_info["result"] = tensor_image

            node_info["event"].set()
            return web.json_response({"success": True})

        except Exception as e:
            # Always release the waiting node, even when decoding failed.
            if node_id in node_data and "event" in node_data[node_id]:
                node_data[node_id]["event"].set()
            return web.json_response({"success": False, "error": str(e)})

    except Exception as e:
        return web.json_response({"success": False, "error": str(e)})


NODE_CLASS_MAPPINGS = {
    "ColorAdjustment": ColorAdjustment,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ColorAdjustment": "颜色调整",
}
CATEGORY_TYPE = "🎈LAOGOU/Switch"


class LazySwitch2way:
    """Two-channel switch whose inputs are declared lazy."""

    @classmethod
    def INPUT_TYPES(cls):
        lazy_any = {"lazy": True}
        return {
            "required": {
                "boolean": ("BOOLEAN", {"default": True}),
                "ON_TRUE": (any, dict(lazy_any)),
                "on_true": (any, dict(lazy_any)),
                "ON_FALSE": (any, dict(lazy_any)),
                "on_false": (any, dict(lazy_any)),
            }
        }

    RETURN_TYPES = (any, any,)
    RETURN_NAMES = ("OUTPUT", "output",)
    FUNCTION = "switch"
    CATEGORY = CATEGORY_TYPE

    def check_lazy_status(self, boolean, ON_TRUE, on_true, ON_FALSE, on_false):
        """Name the inputs of the active side that are still None, or return None."""
        if boolean:
            active = (("ON_TRUE", ON_TRUE), ("on_true", on_true))
        else:
            active = (("ON_FALSE", ON_FALSE), ("on_false", on_false))
        pending = [label for label, value in active if value is None]
        return pending or None

    def switch(self, boolean, ON_TRUE, on_true, ON_FALSE, on_false):
        """Return the pair of inputs selected by *boolean*."""
        return (ON_TRUE, on_true,) if boolean else (ON_FALSE, on_false,)


class LazySwitch1way:
    """Single-channel switch whose inputs are declared lazy."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "boolean": ("BOOLEAN", {"default": True}),
                "ON_TRUE": (any, {"lazy": True}),
                "ON_FALSE": (any, {"lazy": True}),
            }
        }

    RETURN_TYPES = (any,)
    RETURN_NAMES = ("OUTPUT",)
    FUNCTION = "switch"
    CATEGORY = CATEGORY_TYPE

    def check_lazy_status(self, boolean, ON_TRUE, ON_FALSE):
        """Name the active input if it is still None, otherwise return None."""
        label, value = ("ON_TRUE", ON_TRUE) if boolean else ("ON_FALSE", ON_FALSE)
        return [label] if value is None else None

    def switch(self, boolean, ON_TRUE, ON_FALSE):
        """Return the input selected by *boolean* as a 1-tuple."""
        chosen = ON_TRUE if boolean else ON_FALSE
        return (chosen,)


class GroupSwitcher:
    """
    Group switching node controlled by a boolean value.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "boolean": ("BOOLEAN", {"default": True}),  # boolean control
            }
        }

    RETURN_TYPES = ()  # no output ports
    FUNCTION = "switch"
    CATEGORY = CATEGORY_TYPE

    def switch(self, boolean):
        """Produce no outputs; the Python side ignores *boolean*."""
        return ()


class MuterSwitcher:
    """Two-channel switch with optional, non-lazy inputs."""

    @classmethod
    def INPUT_TYPES(cls):
        optional_inputs = {
            "ON_TRUE": (any,),
            "on_true": (any,),
            "ON_FALSE": (any,),
            "on_false": (any,),
        }
        return {
            "required": {
                "boolean": ("BOOLEAN", {"default": True}),
            },
            "optional": optional_inputs,
        }

    RETURN_TYPES = (any, any,)
    RETURN_NAMES = ("OUTPUT", "output",)
    FUNCTION = "switch"
    CATEGORY = CATEGORY_TYPE

    def switch(self, ON_TRUE=None, on_true=None, ON_FALSE=None, on_false=None, boolean=True):
        """Select the output pair according to *boolean*; unconnected inputs stay None."""
        if not boolean:
            return (ON_FALSE, on_false,)
        return (ON_TRUE, on_true,)


NODE_CLASS_MAPPINGS = {
    "LazySwitch2way": LazySwitch2way,
    "LazySwitch1way": LazySwitch1way,
    "GroupSwitcher": GroupSwitcher,
    "MuterSwitcher": MuterSwitcher,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LazySwitch2way": "🎈LazySwitch2way",
    "LazySwitch1way": "🎈LazySwitch1way",
    "GroupSwitcher": "🎈GroupSwitcher",
    "MuterSwitcher": "🎈MuterSwitcher"
}
24 | self.conv_key = nn.Sequential(Conv2d(in_channel, depth, 1, relu=True), 25 | Conv2d(depth, depth, 1, relu=True)) 26 | self.conv_value = nn.Sequential(Conv2d(in_channel, depth, 1, relu=True), 27 | Conv2d(depth, depth, 1, relu=True)) 28 | 29 | if self.lmap_in is True: 30 | self.ctx = 5 31 | else: 32 | self.ctx = 3 33 | 34 | self.conv_out1 = Conv2d(depth, depth, 3, relu=True) 35 | self.conv_out2 = Conv2d(in_channel + depth, depth, 3, relu=True) 36 | self.conv_out3 = Conv2d(depth, depth, 3, relu=True) 37 | self.conv_out4 = Conv2d(depth, out_channel, 1) 38 | 39 | self.threshold = Parameter(torch.tensor([0.5])) 40 | 41 | if self.lmap_in is True: 42 | self.lthreshold = Parameter(torch.tensor([0.5])) 43 | 44 | def forward(self, x, smap, lmap: Optional[torch.Tensor]=None): 45 | assert not xor(self.lmap_in is True, lmap is not None) 46 | b, c, h, w = x.shape 47 | 48 | # compute class probability 49 | smap = F.interpolate(smap, size=x.shape[-2:], mode='bilinear', align_corners=False) 50 | smap = torch.sigmoid(smap) 51 | p = smap - self.threshold 52 | 53 | fg = torch.clip(p, 0, 1) # foreground 54 | bg = torch.clip(-p, 0, 1) # background 55 | cg = self.threshold - torch.abs(p) # confusion area 56 | 57 | if self.lmap_in is True and lmap is not None: 58 | lmap = F.interpolate(lmap, size=x.shape[-2:], mode='bilinear', align_corners=False) 59 | lmap = torch.sigmoid(lmap) 60 | lp = lmap - self.lthreshold 61 | fp = torch.clip(lp, 0, 1) # foreground 62 | bp = torch.clip(-lp, 0, 1) # background 63 | 64 | prob = [fg, bg, cg, fp, bp] 65 | else: 66 | prob = [fg, bg, cg] 67 | 68 | prob = torch.cat(prob, dim=1) 69 | 70 | # reshape feature & prob 71 | if self.stage_size is not None: 72 | shape = self.stage_size 73 | shape_mul = self.stage_size[0] * self.stage_size[1] 74 | else: 75 | shape = (h, w) 76 | shape_mul = h * w 77 | 78 | f = F.interpolate(x, size=shape, mode='bilinear', align_corners=False).view(b, shape_mul, -1) 79 | prob = F.interpolate(prob, size=shape, mode='bilinear', 
align_corners=False).view(b, self.ctx, shape_mul) 80 | 81 | # compute context vector 82 | context = torch.bmm(prob, f).permute(0, 2, 1).unsqueeze(3) # b, 3, c 83 | 84 | # k q v compute 85 | query = self.conv_query(x).view(b, self.depth, -1).permute(0, 2, 1) 86 | key = self.conv_key(context).view(b, self.depth, -1) 87 | value = self.conv_value(context).view(b, self.depth, -1).permute(0, 2, 1) 88 | 89 | # compute similarity map 90 | sim = torch.bmm(query, key) # b, hw, c x b, c, 2 91 | sim = (self.depth ** -.5) * sim 92 | sim = F.softmax(sim, dim=-1) 93 | 94 | # compute refined feature 95 | context = torch.bmm(sim, value).permute(0, 2, 1).contiguous().view(b, -1, h, w) 96 | context = self.conv_out1(context) 97 | 98 | x = torch.cat([x, context], dim=1) 99 | x = self.conv_out2(x) 100 | x = self.conv_out3(x) 101 | out = self.conv_out4(x) 102 | 103 | return x, out -------------------------------------------------------------------------------- /web/counter.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../scripts/app.js"; 2 | import { api } from "../../scripts/api.js"; 3 | 4 | app.registerExtension({ 5 | name: "LG.Counter", 6 | 7 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 8 | if (nodeData.name === "LG_Counter") { 9 | 10 | // 保存原始的 onNodeCreated 方法 11 | const onNodeCreated = nodeType.prototype.onNodeCreated; 12 | 13 | nodeType.prototype.onNodeCreated = function() { 14 | const r = onNodeCreated ? 
onNodeCreated.apply(this, arguments) : undefined; 15 | 16 | // 初始化显示文本 17 | this.currentCountText = ""; 18 | 19 | // 添加刷新按钮 20 | this.addWidget("button", "refresh", "刷新计数器", () => { 21 | this.resetCounter(); 22 | }); 23 | 24 | return r; 25 | }; 26 | 27 | // 重写 onDrawForeground 方法来显示计数 28 | const onDrawForeground = nodeType.prototype.onDrawForeground; 29 | nodeType.prototype.onDrawForeground = function(ctx) { 30 | const r = onDrawForeground?.apply?.(this, arguments); 31 | 32 | // 如果有计数文本要显示 33 | if (this.currentCountText) { 34 | ctx.save(); 35 | 36 | // 设置文本样式 - 在标题栏右上角显示 37 | ctx.font = "bold 20px sans-serif"; 38 | ctx.textAlign = "right"; 39 | ctx.textBaseline = "top"; 40 | ctx.fillStyle = "#00ff00"; // 绿色 41 | 42 | // 在标题栏右上角显示计数 43 | const rightX = this.size[0] - 20; // 距离右边10像素 44 | const topY = -25; // 距离顶部6像素,标题栏内 45 | 46 | // 添加文本阴影以提高可读性 47 | ctx.shadowColor = "rgba(0, 0, 0, 0.8)"; 48 | ctx.shadowBlur = 3; 49 | ctx.shadowOffsetX = 1; 50 | ctx.shadowOffsetY = 1; 51 | 52 | ctx.fillText(this.currentCountText, rightX, topY); 53 | 54 | ctx.restore(); 55 | } 56 | 57 | return r; 58 | }; 59 | 60 | // 添加重置计数器的方法 61 | nodeType.prototype.resetCounter = async function() { 62 | try { 63 | const response = await api.fetchApi("/counter/reset", { 64 | method: "POST", 65 | headers: { 66 | "Content-Type": "application/json", 67 | }, 68 | body: JSON.stringify({ 69 | node_id: this.id.toString() 70 | }) 71 | }); 72 | 73 | const result = await response.json(); 74 | 75 | if (result.status === "success") { 76 | console.log("计数器已重置:", result.message); 77 | } else { 78 | console.error("重置计数器失败:", result.message); 79 | } 80 | } catch (error) { 81 | console.error("重置计数器时发生错误:", error); 82 | } 83 | }; 84 | 85 | // 添加右键菜单选项 86 | const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; 87 | nodeType.prototype.getExtraMenuOptions = function(_, options) { 88 | const r = getExtraMenuOptions ? 
getExtraMenuOptions.apply(this, arguments) : undefined; 89 | 90 | options.unshift({ 91 | content: "重置计数器", 92 | callback: () => { 93 | this.resetCounter(); 94 | } 95 | }); 96 | 97 | return r; 98 | }; 99 | } 100 | }, 101 | 102 | // 监听后端发送的计数更新事件 103 | async setup() { 104 | api.addEventListener("counter_update", ({ detail }) => { 105 | if (!detail || !detail.node_id) return; 106 | 107 | const node = app.graph._nodes_by_id[detail.node_id]; 108 | if (node && node.type === "LG_Counter") { 109 | // 更新显示文本 110 | node.currentCountText = detail.count.toString(); 111 | // 触发重绘 112 | node.setDirtyCanvas(true, true); 113 | 114 | console.log(`[Counter] 节点 ${detail.node_id} 计数更新: ${detail.count}`); 115 | } 116 | }); 117 | } 118 | }); 119 | 120 | -------------------------------------------------------------------------------- /web/Util.js: -------------------------------------------------------------------------------- 1 | import { api } from "../../scripts/api.js"; 2 | import { app } from "../../scripts/app.js"; 3 | 4 | export class Util { 5 | 6 | // Server 7 | static AddMessageListener(messagePath, handlerFunc) { 8 | api.addEventListener(messagePath, handlerFunc); 9 | } 10 | 11 | // Widget 12 | static SetTextAreaContent(widget, text) { 13 | widget.element.textContent = text 14 | } 15 | 16 | static SetTextAreaScrollPos(widget, pos01) { 17 | widget.element.scroll(0, widget.element.scrollHeight * pos01) 18 | } 19 | 20 | static AddReadOnlyTextArea(node, name, text, placeholder = "") { 21 | const inputEl = document.createElement("textarea"); 22 | inputEl.className = "comfy-multiline-input"; 23 | inputEl.placeholder = placeholder 24 | inputEl.spellcheck = false 25 | inputEl.readOnly = true 26 | inputEl.textContent = text 27 | return node.addDOMWidget(name, "", inputEl, { 28 | serialize: false, 29 | }); 30 | } 31 | 32 | static AddButtonWidget(node, label, callback, value = null) { 33 | return node.addWidget("button", label, value, callback); 34 | } 35 | 36 | } 37 | 38 | // 通用终端管理器 39 
| export class TerminalManager { 40 | constructor(messagePath, nodeType) { 41 | this.messagePath = messagePath; 42 | this.nodeType = nodeType; 43 | this.textVersion = 0; 44 | this.lines = new Array(); 45 | 46 | // 监听消息 47 | Util.AddMessageListener(messagePath, (event) => { 48 | this.textVersion++; 49 | if (event.detail.clear) { 50 | this.lines.length = 0; 51 | } 52 | let totalText = String(event.detail.text || ""); 53 | this.lines.push(...(totalText.split("\n"))); 54 | if (this.lines.length > 1024) { 55 | this.lines = this.lines.slice(-1024); 56 | } 57 | // 刷新所有相关节点 58 | for (let i = 0; i < app.graph._nodes.length; i++) { 59 | var node = app.graph._nodes[i]; 60 | if (node.type == nodeType && node.setDirtyCanvas) { 61 | node.setDirtyCanvas(true); 62 | } 63 | } 64 | }); 65 | } 66 | 67 | // 清空终端 68 | clearTerminal() { 69 | this.lines.length = 0; 70 | this.textVersion++; 71 | } 72 | 73 | // 获取终端内容 74 | getContent() { 75 | return this.lines.join("\n"); 76 | } 77 | 78 | // 创建节点时的设置 79 | async setupNode(node) { 80 | var textArea = Util.AddReadOnlyTextArea(node, "terminal", ""); 81 | setTimeout(() => { 82 | if (textArea.element) { 83 | textArea.element.style.backgroundColor = "#000000ff"; 84 | textArea.element.style.color = "#ffffffff"; 85 | textArea.element.style.fontFamily = "monospace"; 86 | } 87 | }, 0); 88 | 89 | // 动态导入避免重复声明问题 90 | try { 91 | const { MultiButtonWidget } = await import("./multi_button_widget.js"); 92 | const { queueSelectedOutputNodes } = await import("./queue_shortcut.js"); 93 | 94 | // 添加多按钮组件:清理日志 + 执行 95 | const buttons = [ 96 | { 97 | text: "清理日志", 98 | callback: () => { 99 | this.clearTerminal(); 100 | } 101 | }, 102 | { 103 | text: "执行", 104 | callback: () => { 105 | queueSelectedOutputNodes(); 106 | } 107 | } 108 | ]; 109 | 110 | const multiButtonWidget = MultiButtonWidget(app, "", { 111 | labelWidth: 0, 112 | buttonSpacing: 4 113 | }, buttons); 114 | 115 | node.addCustomWidget(multiButtonWidget); 116 | } catch (error) { 117 | 
console.error("Failed to load button components:", error); 118 | // 如果动态导入失败,回退到原来的单按钮 119 | let clearBtn = Util.AddButtonWidget(node, "清空日志", () => { 120 | this.clearTerminal(); 121 | }); 122 | clearBtn.width = 128; 123 | } 124 | 125 | node.terminalVersion = -1; 126 | return node; 127 | } 128 | 129 | 130 | // 绘制时的更新 131 | updateNode(node, onDrawForeground, ctx, graphcanvas) { 132 | if (node.terminalVersion != this.textVersion) { 133 | node.terminalVersion = this.textVersion; 134 | for (var i = 0; i < node.widgets.length; i++) { 135 | var wid = node.widgets[i]; 136 | if (wid.name == "terminal") { 137 | Util.SetTextAreaContent(wid, this.getContent()); 138 | Util.SetTextAreaScrollPos(wid, 1.0); 139 | break; 140 | } 141 | } 142 | } 143 | return onDrawForeground?.apply(node, [ctx, graphcanvas]); 144 | } 145 | } -------------------------------------------------------------------------------- /web/multi_button_widget.js: -------------------------------------------------------------------------------- 1 | 2 | // 通用多按钮组件定义 3 | const MultiButtonWidget = (app, inputName, options, buttons) => { 4 | const widget = { 5 | name: inputName, 6 | type: "multi_button", 7 | y: 0, 8 | value: null, 9 | options: options || {}, 10 | clicked_button: null, 11 | click_time: 0 12 | }; 13 | 14 | // 使用 ComfyUI 原生样式常量 15 | const margin = 15; 16 | const button_height = LiteGraph.NODE_WIDGET_HEIGHT || 20; 17 | const label_width = options.labelWidth !== undefined ? 
options.labelWidth : 80; 18 | const button_spacing = options.buttonSpacing || 4; 19 | const button_count = buttons.length; 20 | 21 | // 原生 ComfyUI 按钮颜色 22 | const BUTTON_BGCOLOR = "#222"; 23 | const BUTTON_OUTLINE_COLOR = "#666"; 24 | const BUTTON_TEXT_COLOR = "#DDD"; 25 | const BUTTON_CLICKED_COLOR = "#AAA"; 26 | 27 | widget.draw = function(ctx, node, width, Y, height) { 28 | if (app.canvas.ds.scale < 0.50) return; 29 | 30 | ctx.save(); 31 | ctx.lineWidth = 1; 32 | 33 | // 绘制标签(仅当有标签时) 34 | if (label_width > 0 && inputName) { 35 | ctx.fillStyle = LiteGraph.WIDGET_SECONDARY_TEXT_COLOR; 36 | ctx.textAlign = "left"; 37 | ctx.fillText(inputName, margin, Y + height * 0.7); 38 | } 39 | 40 | // 计算按钮区域 41 | const label_space = label_width > 0 ? label_width : 0; 42 | const left_margin = label_width > 0 ? margin : 10; // 没有标签时使用小边距避免超出节点 43 | const right_margin = label_width > 0 ? margin : 10; // 右边也保持小边距 44 | const available_width = width - label_space - left_margin - right_margin; 45 | const total_spacing = button_spacing * (button_count - 1); 46 | const button_width = (available_width - total_spacing) / button_count; 47 | const start_x = label_space + left_margin; 48 | 49 | // 检查点击高亮是否需要清除 50 | const now = Date.now(); 51 | if (widget.click_time && now - widget.click_time > 150) { 52 | widget.clicked_button = null; 53 | widget.click_time = 0; 54 | } 55 | 56 | // 循环绘制所有按钮 57 | for (let i = 0; i < button_count; i++) { 58 | const button = buttons[i]; 59 | const button_x = start_x + i * (button_width + button_spacing); 60 | 61 | // 确定按钮颜色(点击高亮或默认) 62 | const is_clicked = (widget.clicked_button === i); 63 | const button_color = is_clicked ? 
BUTTON_CLICKED_COLOR : 64 | (button.color || BUTTON_BGCOLOR); 65 | 66 | // 绘制按钮背景 67 | ctx.fillStyle = button_color; 68 | ctx.fillRect(button_x, Y, button_width, button_height); 69 | 70 | // 绘制按钮边框 71 | ctx.strokeStyle = BUTTON_OUTLINE_COLOR; 72 | ctx.strokeRect(button_x, Y, button_width, button_height); 73 | 74 | // 绘制按钮文字 75 | ctx.fillStyle = BUTTON_TEXT_COLOR; 76 | ctx.textAlign = "center"; 77 | ctx.fillText(button.text, button_x + button_width / 2, Y + button_height * 0.7); 78 | } 79 | 80 | ctx.restore(); 81 | }; 82 | 83 | widget.onPointerDown = function(pointer, node) { 84 | const e = pointer.eDown; 85 | const label_space = label_width > 0 ? label_width : 0; 86 | const left_margin = label_width > 0 ? margin : 10; 87 | const right_margin = label_width > 0 ? margin : 10; 88 | const x = e.canvasX - node.pos[0] - label_space - left_margin; 89 | const available_width = node.size[0] - label_space - left_margin - right_margin; 90 | const total_spacing = button_spacing * (button_count - 1); 91 | const button_width = (available_width - total_spacing) / button_count; 92 | 93 | pointer.onClick = () => { 94 | // 计算点击了哪个按钮 95 | for (let i = 0; i < button_count; i++) { 96 | const button_start = i * (button_width + button_spacing); 97 | const button_end = button_start + button_width; 98 | 99 | if (x >= button_start && x <= button_end) { 100 | // 点击了第 i 个按钮 101 | widget.clicked_button = i; 102 | widget.click_time = Date.now(); 103 | 104 | // 执行回调函数 105 | if (buttons[i] && buttons[i].callback) { 106 | buttons[i].callback(); 107 | } 108 | 109 | // 触发重绘以显示点击效果 110 | app.graph.setDirtyCanvas(true); 111 | break; 112 | } 113 | } 114 | }; 115 | }; 116 | 117 | widget.computeSize = function() { 118 | return [0, button_height]; 119 | }; 120 | 121 | widget.serializeValue = async () => { 122 | return null; 123 | }; 124 | 125 | return widget; 126 | }; 127 | 128 | // 导出组件 129 | export { MultiButtonWidget }; 130 | 
-------------------------------------------------------------------------------- /web/PB.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../scripts/app.js"; 2 | import { api } from "../../scripts/api.js"; 3 | import { MultiButtonWidget } from "./multi_button_widget.js"; 4 | 5 | async function loadLatestImage(node, folder_type) { 6 | // 获取指定目录中的最新图片 7 | const res = await api.fetchApi(`/lg/get/latest_image?type=${folder_type}`, { cache: "no-store" }); 8 | if (res.status == 200) { 9 | const item = await res.json(); 10 | if (item && item.filename) { 11 | const imageWidget = node.widgets.find(w => w.name === 'image'); 12 | if (!imageWidget) return false; 13 | 14 | // 保存文件信息的JSON字符串到节点属性 15 | const fileInfo = JSON.stringify({ 16 | filename: item.filename, 17 | subfolder: item.subfolder || '', 18 | type: item.type || 'temp' 19 | }); 20 | node._latestFileInfo = fileInfo; 21 | 22 | // 设置 widget 值为 ComfyUI 期望的格式: "filename [type]" 23 | const displayValue = `${item.filename} [${item.type}]`; 24 | imageWidget.value = displayValue; 25 | 26 | // 加载并显示图像 27 | const image = new Image(); 28 | image.src = api.apiURL(`/view?filename=${item.filename}&type=${item.type}&subfolder=${item.subfolder || ''}`); 29 | node._imgs = [image]; 30 | return true; 31 | } 32 | } 33 | return false; 34 | } 35 | 36 | app.registerExtension({ 37 | name: "Comfy.LG.CachePreview", 38 | 39 | nodeCreated(node, app) { 40 | if (node.comfyClass !== "CachePreviewBridge") return; 41 | 42 | let imageWidget = node.widgets.find(w => w.name === 'image'); 43 | if (!imageWidget) return; 44 | 45 | // 存储当前文件信息 46 | node._latestFileInfo = null; 47 | 48 | // 重写序列化方法,确保执行时使用最新值 49 | imageWidget.serializeValue = function(nodeId, widgetIndex) { 50 | if (node._latestFileInfo) { 51 | return node._latestFileInfo; 52 | } 53 | return this.value || ""; 54 | }; 55 | 56 | node._imgs = [new Image()]; 57 | node.imageIndex = 0; 58 | 59 | // 使用多按钮组件创建刷新按钮 60 | const 
refreshWidget = node.addCustomWidget(MultiButtonWidget(app, "Refresh From", { 61 | labelWidth: 80, 62 | buttonSpacing: 4 63 | }, [ 64 | { 65 | text: "Temp", 66 | callback: () => { 67 | loadLatestImage(node, "temp").then(success => { 68 | if (success) { 69 | app.graph.setDirtyCanvas(true); 70 | } 71 | }); 72 | } 73 | }, 74 | { 75 | text: "Output", 76 | callback: () => { 77 | loadLatestImage(node, "output").then(success => { 78 | if (success) { 79 | app.graph.setDirtyCanvas(true); 80 | } 81 | }); 82 | } 83 | } 84 | ])); 85 | refreshWidget.serialize = false; 86 | 87 | // 重写 imgs 属性,处理来自 MaskEditor 的粘贴 88 | Object.defineProperty(node, 'imgs', { 89 | set(v) { 90 | if (!v || v.length === 0) return; 91 | 92 | const stackTrace = new Error().stack; 93 | 94 | // 来自 MaskEditor 的粘贴 95 | if (stackTrace.includes('pasteFromClipspace')) { 96 | if (v[0] && v[0].src) { 97 | const urlParts = v[0].src.split("?"); 98 | if (urlParts.length > 1) { 99 | const sp = new URLSearchParams(urlParts[1]); 100 | const filename = sp.get('filename'); 101 | const type = sp.get('type') || 'input'; 102 | const subfolder = sp.get('subfolder') || ''; 103 | 104 | if (filename) { 105 | // 保存文件信息的JSON字符串 106 | const fileInfo = JSON.stringify({ 107 | filename: filename, 108 | subfolder: subfolder, 109 | type: type 110 | }); 111 | 112 | // 保存到节点属性,序列化时会使用 113 | node._latestFileInfo = fileInfo; 114 | imageWidget.value = fileInfo; 115 | 116 | // 直接使用传入的图像 117 | node._imgs = v; 118 | app.graph.setDirtyCanvas(true); 119 | 120 | return; 121 | } 122 | } 123 | } 124 | } 125 | 126 | // 其他情况直接设置 127 | node._imgs = v; 128 | }, 129 | get() { 130 | return node._imgs; 131 | } 132 | }); 133 | } 134 | }); 135 | -------------------------------------------------------------------------------- /web/image_loader_counter.js: -------------------------------------------------------------------------------- 1 | /** 2 | * LG_ImageLoaderWithCounter 前端扩展 3 | * 为带计数器的图片加载器节点提供前端UI支持 4 | */ 5 | import { app } from 
"../../scripts/app.js"; 6 | import { api } from "../../scripts/api.js"; 7 | 8 | app.registerExtension({ 9 | name: "LG.ImageLoaderWithCounter", 10 | 11 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 12 | if (nodeData.name === "LG_ImageLoaderWithCounter") { 13 | 14 | // 保存原始的 onNodeCreated 方法 15 | const onNodeCreated = nodeType.prototype.onNodeCreated; 16 | 17 | nodeType.prototype.onNodeCreated = function() { 18 | const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; 19 | 20 | // 初始化显示文本 21 | this.currentCountText = ""; 22 | this.totalImagesText = ""; 23 | 24 | // 添加刷新按钮 25 | this.addWidget("button", "refresh", "🔄 刷新计数器", () => { 26 | this.resetImageLoaderCounter(); 27 | }); 28 | 29 | return r; 30 | }; 31 | 32 | // 重写 onDrawForeground 方法来显示计数和总数 33 | const onDrawForeground = nodeType.prototype.onDrawForeground; 34 | nodeType.prototype.onDrawForeground = function(ctx) { 35 | const r = onDrawForeground?.apply?.(this, arguments); 36 | 37 | // 如果有计数文本要显示 38 | if (this.currentCountText) { 39 | ctx.save(); 40 | 41 | // 设置文本样式 - 在标题栏右上角显示 42 | ctx.font = "bold 20px sans-serif"; 43 | ctx.textAlign = "right"; 44 | ctx.textBaseline = "top"; 45 | ctx.fillStyle = "#00ff00"; // 绿色 46 | 47 | // 构建显示文本:当前索引/总数 48 | let displayText = this.currentCountText; 49 | if (this.totalImagesText) { 50 | displayText = `${this.currentCountText}/${this.totalImagesText}`; 51 | } 52 | 53 | // 在标题栏右上角显示计数 54 | const rightX = this.size[0] - 20; // 距离右边20像素 55 | const topY = -25; // 距离顶部25像素,标题栏内 56 | 57 | // 添加文本阴影以提高可读性 58 | ctx.shadowColor = "rgba(0, 0, 0, 0.8)"; 59 | ctx.shadowBlur = 3; 60 | ctx.shadowOffsetX = 1; 61 | ctx.shadowOffsetY = 1; 62 | 63 | ctx.fillText(displayText, rightX, topY); 64 | 65 | ctx.restore(); 66 | } 67 | 68 | return r; 69 | }; 70 | 71 | // 添加重置图片加载器计数器的方法 72 | nodeType.prototype.resetImageLoaderCounter = async function() { 73 | try { 74 | const response = await api.fetchApi("/image_loader_counter/reset", { 75 | method: "POST", 76 | headers: 
{ 77 | "Content-Type": "application/json", 78 | }, 79 | body: JSON.stringify({ 80 | node_id: this.id.toString() 81 | }) 82 | }); 83 | 84 | const result = await response.json(); 85 | 86 | if (result.status === "success") { 87 | console.log("图片加载器计数器已重置:", result.message); 88 | // 更新显示 89 | this.currentCountText = result.current.toString(); 90 | this.setDirtyCanvas(true, true); 91 | } else { 92 | console.error("重置计数器失败:", result.message); 93 | } 94 | } catch (error) { 95 | console.error("重置计数器时发生错误:", error); 96 | } 97 | }; 98 | 99 | // 添加右键菜单选项 100 | const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; 101 | nodeType.prototype.getExtraMenuOptions = function(_, options) { 102 | const r = getExtraMenuOptions ? getExtraMenuOptions.apply(this, arguments) : undefined; 103 | 104 | options.unshift({ 105 | content: "🔄 重置计数器", 106 | callback: () => { 107 | this.resetImageLoaderCounter(); 108 | } 109 | }); 110 | 111 | return r; 112 | }; 113 | } 114 | }, 115 | 116 | /** 117 | * 监听后端发送的计数更新事件 118 | */ 119 | async setup() { 120 | api.addEventListener("counter_update", ({ detail }) => { 121 | if (!detail || !detail.node_id) return; 122 | 123 | const node = app.graph._nodes_by_id[detail.node_id]; 124 | if (node && node.type === "LG_ImageLoaderWithCounter") { 125 | // 更新显示文本 126 | node.currentCountText = detail.count.toString(); 127 | // 如果有总数信息,也更新 128 | if (detail.total !== undefined) { 129 | node.totalImagesText = detail.total.toString(); 130 | } 131 | // 触发重绘 132 | node.setDirtyCanvas(true, true); 133 | 134 | console.log(`[ImageLoaderCounter] 节点 ${detail.node_id} 索引更新: ${detail.count}/${detail.total || '?'}`); 135 | } 136 | }); 137 | } 138 | }); 139 | 140 | -------------------------------------------------------------------------------- /py/inspyrenet/modules/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import cv2 6 | import numpy 
class ImagePyramid:
    """Gaussian/Laplacian pyramid helper for 4-D image tensors.

    One 2-D Gaussian kernel is built at construction time and replicated
    ``channels`` times so it can be applied as a grouped convolution.
    """

    def __init__(self, ksize=7, sigma=1, channels=1):
        self.ksize = ksize
        self.sigma = sigma
        self.channels = channels

        # The outer product of the 1-D Gaussian with itself is the 2-D kernel.
        k = cv2.getGaussianKernel(ksize, sigma)
        k = np.outer(k, k)
        k = torch.tensor(k).float()
        self.kernel = k.repeat(channels, 1, 1, 1)

    def to(self, device):
        """Move the kernel to *device* and return ``self``."""
        self.kernel = self.kernel.to(device)
        return self

    def cuda(self, idx=None):
        """Move the kernel to CUDA device *idx* (current device when None)."""
        if idx is None:
            idx = torch.cuda.current_device()

        self.to(device="cuda:{}".format(idx))
        return self

    def _filter(self, x, kernel):
        """Reflect-pad *x* by half the kernel size and convolve per channel."""
        x = F.pad(x, (self.ksize // 2, ) * 4, mode='reflect')
        return F.conv2d(x, kernel, groups=self.channels)

    def expand(self, x):
        """Upsample *x* by two: zero-interleave, then blur with 4x gain."""
        z = torch.zeros_like(x)
        # NOTE(review): concatenating on the channel axis before
        # pixel_shuffle looks correct only for channels == 1 (the default);
        # confirm before using a multi-channel pyramid.
        x = torch.cat([x, z, z, z], dim=1)
        x = F.pixel_shuffle(x, 2)
        # Three of every four samples are zero, hence the factor of 4.
        return self._filter(x, self.kernel * 4)

    def reduce(self, x):
        """Blur *x* and keep every second row and column."""
        x = self._filter(x, self.kernel)
        return x[:, :, ::2, ::2]

    def deconstruct(self, x):
        """Split *x* into ``(reduced, laplacian)``.

        ``laplacian`` is the detail lost by reducing and re-expanding *x*.
        """
        reduced_x = self.reduce(x)
        expanded_reduced_x = self.expand(reduced_x)

        # Odd input sizes do not survive the round trip; resize to match.
        if x.shape != expanded_reduced_x.shape:
            expanded_reduced_x = F.interpolate(expanded_reduced_x, x.shape[-2:])

        laplacian_x = x - expanded_reduced_x
        return reduced_x, laplacian_x

    def reconstruct(self, x, laplacian_x):
        """Expand *x* by two and add the detail band *laplacian_x*.

        The detail band is bilinearly resized only when its shape differs
        from the expanded image.  (The previous code compared a
        ``torch.Size`` with the tensor itself, which is always unequal, so
        the resample ran on every call.)
        """
        expanded_x = self.expand(x)
        if laplacian_x.shape != expanded_x.shape:
            laplacian_x = F.interpolate(laplacian_x, expanded_x.shape[-2:], mode='bilinear', align_corners=True)
        return expanded_x + laplacian_x


class Transition:
    """Binary edge-band mask built from a morphological gradient."""

    def __init__(self, k=3):
        # Elliptical k x k structuring element used by dilation and erosion.
        self.kernel = torch.tensor(cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))).float()

    def to(self, device):
        """Move the structuring element to *device* and return ``self``."""
        self.kernel = self.kernel.to(device)
        return self

    def cuda(self, idx=0):
        """Move the structuring element to CUDA device *idx*."""
        self.to(device="cuda:{}".format(idx))
        return self

    def __call__(self, x):
        """Return a float mask: 1 where dilation - erosion of sigmoid(x) > 0.5."""
        x = torch.sigmoid(x)
        dx = dilation(x, self.kernel)
        ex = erosion(x, self.kernel)

        return ((dx - ex) > .5).float()


class Conv2d(nn.Module):
    """``nn.Conv2d`` with automatic padding plus optional BatchNorm and ReLU.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, groups,
            bias: forwarded to ``nn.Conv2d``; scalars are expanded to pairs.
        padding: ``'same'`` derives the padding from the dilated kernel
            extent, ``'valid'`` uses none, anything else is taken as an
            explicit amount (scalar or pair).
        bn: append ``nn.BatchNorm2d`` only when this is exactly ``True``.
        relu: append an in-place ``nn.ReLU`` only when this is exactly ``True``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, padding='same', bias=False, bn=True, relu=False):
        super(Conv2d, self).__init__()
        kernel_size = self._as_pair(kernel_size)
        stride = self._as_pair(stride)
        dilation = self._as_pair(dilation)

        # Index 0 is the height axis and index 1 the width axis, matching
        # the (H, W) order nn.Conv2d expects for kernel_size and padding.
        if padding == 'same':
            extents = (
                kernel_size[0] + (kernel_size[0] - 1) * (dilation[0] - 1),
                kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1),
            )
        elif padding == 'valid':
            extents = (0, 0)
        elif hasattr(padding, '__iter__'):
            extents = (padding[0] * 2, padding[1] * 2)
        else:
            extents = (padding * 2, padding * 2)

        # Clamp at zero: 'valid' (and an explicit 0) used to come out as -1,
        # which the convolution rejects.
        # NOTE(review): an explicit padding p >= 1 still yields p - 1, as it
        # always has; confirm no caller relies on that before changing it.
        pad_size = tuple(max(e // 2 + (e % 2 - 1), 0) for e in extents)

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_size, dilation, groups, bias=bias)
        self.reset_parameters()

        if bn is True:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = None

        if relu is True:
            self.relu = nn.ReLU(inplace=True)
        else:
            self.relu = None

    @staticmethod
    def _as_pair(value):
        """Return *value* unchanged if iterable, else ``(value, value)``."""
        if hasattr(value, '__iter__'):
            return value
        return (value, value)

    def forward(self, x):
        """Apply convolution, then BatchNorm and ReLU when configured."""
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

    def reset_parameters(self):
        """Re-initialise the convolution weight with Kaiming-normal values."""
        nn.init.kaiming_normal_(self.conv.weight)


class SelfAttention(nn.Module):
    """Self-attention over the axes named in *mode* with a learned residual gate.

    Args:
        in_channels: channels of the input; query/key use ``in_channels // 8``.
        mode: contains ``'h'`` and/or ``'w'`` to choose the attended axes.
        stage_size: stored on the module; not read by ``forward``.
    """

    def __init__(self, in_channels, mode='hw', stage_size=None):
        super(SelfAttention, self).__init__()

        self.mode = mode

        self.query_conv = Conv2d(in_channels, in_channels // 8, kernel_size=(1, 1))
        self.key_conv = Conv2d(in_channels, in_channels // 8, kernel_size=(1, 1))
        self.value_conv = Conv2d(in_channels, in_channels, kernel_size=(1, 1))

        # Starts at zero, so the block is an identity mapping until trained.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

        self.stage_size = stage_size

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as *x*."""
        batch_size, channel, height, width = x.size()

        # Number of attended positions: product of the selected axes.
        axis = 1
        if 'h' in self.mode:
            axis *= height
        if 'w' in self.mode:
            axis *= width

        view = (batch_size, -1, axis)

        projected_query = self.query_conv(x).view(*view).permute(0, 2, 1)
        projected_key = self.key_conv(x).view(*view)

        attention_map = torch.bmm(projected_query, projected_key)
        attention = self.softmax(attention_map)
        projected_value = self.value_conv(x).view(*view)

        out = torch.bmm(projected_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channel, height, width)

        out = self.gamma * out + x
        return out
class InSPyReNet(nn.Module):
    """Saliency network that refines a coarse map through a Laplacian pyramid.

    Args:
        backbone: module returning five feature maps for an image batch.
        in_channels: channel counts of those five feature maps.
        depth: channel width used by the context, decoder and attention parts.
        base_size: size forwarded to every sub-module.
        threshold: side length above which ``forward_inference`` blends a
            low-resolution and a high-resolution pass; ``None`` disables it.
    """

    def __init__(self, backbone, in_channels, depth=64, base_size=[384, 384], threshold=512, **kwargs):
        super(InSPyReNet, self).__init__()
        self.backbone = backbone
        self.in_channels = in_channels
        self.depth = depth
        self.base_size = base_size
        self.threshold = threshold

        # Attribute names are kept as-is: they are the state_dict keys.
        self.context1 = PAA_e(self.in_channels[0], self.depth, base_size=self.base_size, stage=0)
        self.context2 = PAA_e(self.in_channels[1], self.depth, base_size=self.base_size, stage=1)
        self.context3 = PAA_e(self.in_channels[2], self.depth, base_size=self.base_size, stage=2)
        self.context4 = PAA_e(self.in_channels[3], self.depth, base_size=self.base_size, stage=3)
        self.context5 = PAA_e(self.in_channels[4], self.depth, base_size=self.base_size, stage=4)

        self.decoder = PAA_d(self.depth * 3, depth=self.depth, base_size=base_size, stage=2)

        self.attention0 = SICA(self.depth, depth=self.depth, base_size=self.base_size, stage=0, lmap_in=True)
        self.attention1 = SICA(self.depth * 2, depth=self.depth, base_size=self.base_size, stage=1, lmap_in=True)
        self.attention2 = SICA(self.depth * 2, depth=self.depth, base_size=self.base_size, stage=2)

        # Resize helpers: to a target tensor's size / to an explicit size.
        self.ret = lambda x, target: F.interpolate(x, size=target.shape[-2:], mode='bilinear', align_corners=False)
        self.res = lambda x, size: F.interpolate(x, size=size, mode='bilinear', align_corners=False)
        self.des = lambda x, size: F.interpolate(x, size=size, mode='nearest')

        self.image_pyramid = ImagePyramid(7, 1)

        self.transition0 = Transition(17)
        self.transition1 = Transition(9)
        self.transition2 = Transition(5)

        self.forward = self.forward_inference

    def to(self, device):
        """Move the module and its non-Module helpers to *device*."""
        # The pyramid and transitions are plain objects, so nn.Module.to
        # would not reach their tensors.
        self.image_pyramid.to(device)
        self.transition0.to(device)
        self.transition1.to(device)
        self.transition2.to(device)
        super(InSPyReNet, self).to(device)
        return self

    def cuda(self, idx=None):
        """Move everything to CUDA device *idx* (current device when None)."""
        if idx is None:
            idx = torch.cuda.current_device()

        self.to(device="cuda:{}".format(idx))
        return self

    def eval(self):
        """Switch to evaluation mode and route ``forward`` to inference."""
        super(InSPyReNet, self).train(False)
        self.forward = self.forward_inference
        return self

    def forward_inspyre(self, x):
        """Run one pass and return the saliency and Laplacian pyramids.

        Returns:
            dict with ``'saliency'`` = ``[d3, d2, d1, d0]`` (coarse to fine)
            and ``'laplacian'`` = ``[p2, p1, p0]``.
        """
        _, _, H, W = x.shape

        x1, x2, x3, x4, x5 = self.backbone(x)

        x1 = self.context1(x1)
        x2 = self.context2(x2)
        x3 = self.context3(x3)
        x4 = self.context4(x4)
        x5 = self.context5(x5)

        f3, d3 = self.decoder([x3, x4, x5])

        # Each stage predicts a detail band and adds it to the detached,
        # expanded map of the previous (coarser) stage.
        f3 = self.res(f3, (H // 4, W // 4))
        f2, p2 = self.attention2(torch.cat([x2, f3], dim=1), d3.detach())
        d2 = self.image_pyramid.reconstruct(d3.detach(), p2)

        x1 = self.res(x1, (H // 2, W // 2))
        f2 = self.res(f2, (H // 2, W // 2))
        f1, p1 = self.attention1(torch.cat([x1, f2], dim=1), d2.detach(), p2.detach())
        d1 = self.image_pyramid.reconstruct(d2.detach(), p1)

        f1 = self.res(f1, (H, W))
        _, p0 = self.attention0(f1, d1.detach(), p1.detach())
        d0 = self.image_pyramid.reconstruct(d1.detach(), p0)

        out = dict()
        out['saliency'] = [d3, d2, d1, d0]
        out['laplacian'] = [p2, p1, p0]

        return out

    def _blend_pyramids(self, img, img_lr):
        """Combine the low-res saliency with the high-res detail bands."""
        # Low-resolution pass first, then high-resolution (original order).
        lr_d0 = self.forward_inspyre(img_lr)['saliency'][-1]
        hr_out = self.forward_inspyre(img)
        hr_d3 = hr_out['saliency'][0]
        hr_p2, hr_p1, hr_p0 = hr_out['laplacian']

        # Start from the finest low-res map, resized to the coarsest high-res one.
        d = self.ret(lr_d0, hr_d3)
        stages = (
            (self.transition2, hr_p2),
            (self.transition1, hr_p1),
            (self.transition0, hr_p0),
        )
        for transition, hr_p in stages:
            # Keep high-res detail only inside the edge band of the current map.
            mask = self.ret(transition(d), hr_p)
            d = self.image_pyramid.reconstruct(d, mask * hr_p)
        return d

    def forward_inference(self, img, img_lr=None):
        """Return a min-max normalised saliency map for *img*.

        Args:
            img: image batch; its last two dimensions are height and width.
            img_lr: optional low-resolution copy of the same batch.

        With no threshold, a single pass runs on *img*.  When either side is
        within the threshold, a single pass runs on *img_lr* if given, else
        on *img*.  Larger images blend both passes; if *img_lr* is missing
        there, a single pass on *img* is used instead of failing on
        ``None.shape`` as before.
        """
        _, _, H, W = img.shape

        if self.threshold is None:
            d0 = self.forward_inspyre(img)['saliency'][-1]
        elif H <= self.threshold or W <= self.threshold:
            source = img if img_lr is None else img_lr
            d0 = self.forward_inspyre(source)['saliency'][-1]
        elif img_lr is None:
            d0 = self.forward_inspyre(img)['saliency'][-1]
        else:
            d0 = self._blend_pyramids(img, img_lr)

        pred = torch.sigmoid(d0)
        # The epsilon guards the division when the map is constant.
        pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

        return pred


def InSPyReNet_SwinB(depth, pretrained, base_size, **kwargs):
    """Build an ``InSPyReNet`` on a SwinB backbone with its channel layout."""
    return InSPyReNet(SwinB(pretrained=pretrained), [128, 128, 256, 512, 1024], depth, base_size, **kwargs)
app.registerExtension({
    name: "ComboSetter",

    /**
     * Patch the "ComboSetter" node type so that its "selected" widget is a
     * combo whose options come from the lines of the "labels" widget.
     */
    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeData.name !== "ComboSetter") {
            return;
        }

        const onNodeCreated = nodeType.prototype.onNodeCreated;

        nodeType.prototype.onNodeCreated = function () {
            const result = onNodeCreated?.apply(this, arguments);

            // Look up the widgets this extension works with.
            const labelsWidget = this.widgets?.find(w => w.name === "labels");
            const promptsWidget = this.widgets?.find(w => w.name === "prompts");
            let selectedWidget = this.widgets?.find(w => w.name === "selected");

            if (!labelsWidget || !promptsWidget || !selectedWidget) {
                console.error("ComboSetter: 无法找到所需的widgets");
                return result;
            }

            // Replace "selected" with a combo widget at the same position.
            if (selectedWidget.type !== "combo") {
                const selectedIndex = this.widgets.indexOf(selectedWidget);
                this.widgets.splice(selectedIndex, 1);
                this.addWidget("combo", "selected", "", () => {}, {
                    values: [""]
                });
                // addWidget appends; move the new widget back into place.
                selectedWidget = this.widgets.pop();
                this.widgets.splice(selectedIndex, 0, selectedWidget);
            }

            // Rebuild the combo options from the non-empty label lines.
            const updateComboOptions = () => {
                const labelLines = (labelsWidget.value || "")
                    .split('\n')
                    .map(line => line.trim())
                    .filter(line => line.length > 0);

                if (labelLines.length === 0) {
                    console.warn("ComboSetter: labels为空");
                    selectedWidget.options.values = [""];
                    selectedWidget.value = "";
                    // Redraw here too, otherwise the cleared combo stays stale.
                    this.setDirtyCanvas(true, false);
                    return;
                }

                selectedWidget.options.values = labelLines;

                // Fall back to the first option when the current value is gone.
                if (!labelLines.includes(selectedWidget.value)) {
                    selectedWidget.value = labelLines[0];
                }

                this.setDirtyCanvas(true, false);

                console.log("ComboSetter: 已更新Combo选项", labelLines);
            };

            // "Set Combo" button; it carries no state, so it is not serialized.
            const setComboBtn = this.addWidget("button", "Set Combo", null, () => {
                updateComboOptions();
            });
            setComboBtn.serialize = false;

            // Size the node from the number of lines in the two text widgets.
            const originalComputeSize = this.computeSize;
            this.computeSize = function (out) {
                const size = originalComputeSize
                    ? originalComputeSize.apply(this, arguments)
                    : [200, 100];

                const widgetHeight = 40;
                const buttonHeight = 30;
                const padding = 20;
                const textHeight = (widget) =>
                    Math.max((widget.value || "").split('\n').length * 20, 60);

                let totalHeight = padding;
                totalHeight += textHeight(labelsWidget);
                totalHeight += textHeight(promptsWidget);
                // Room for the combo and the button.
                totalHeight += widgetHeight + buttonHeight + padding;

                size[1] = totalHeight;
                size[0] = Math.max(size[0], 300);

                return size;
            };

            return result;
        };

        // Persist the combo options alongside the node.
        const onSerialize = nodeType.prototype.onSerialize;
        nodeType.prototype.onSerialize = function (o) {
            const result = onSerialize?.apply(this, arguments);

            const selectedWidget = this.widgets?.find(w => w.name === "selected");
            if (selectedWidget && selectedWidget.options && selectedWidget.options.values) {
                o.selected_options = selectedWidget.options.values;
            }

            return result;
        };

        // Restore the combo options when a saved node is loaded.
        const onConfigure = nodeType.prototype.onConfigure;
        nodeType.prototype.onConfigure = function (o) {
            const result = onConfigure?.apply(this, arguments);

            // Only an array is a usable option list; ignore anything else.
            if (Array.isArray(o?.selected_options)) {
                const selectedWidget = this.widgets?.find(w => w.name === "selected");
                if (selectedWidget) {
                    selectedWidget.options = selectedWidget.options || {};
                    selectedWidget.options.values = o.selected_options;
                }
            }

            return result;
        };
    }
});
# Per-node rendezvous state shared by the node and the HTTP handler:
# node_id -> {"event": Event, "result": tensor or None}.
size_data = {}


def _encode_preview(image):
    """Return the first image of the batch *image* as a PNG data URI."""
    frame = (torch.clamp(image[0], 0, 1) * 255).cpu().numpy().astype(np.uint8)
    buffer = io.BytesIO()
    Image.fromarray(frame).save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"


def _decode_image_tensor(image_data):
    """Decode encoded image bytes into a float tensor shaped [1, H, W, 3].

    Every non-RGB mode (RGBA, L, P, ...) is converted to RGB, so the result
    always has three channels.
    """
    pil_image = Image.open(io.BytesIO(image_data))
    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')
    np_image = np.array(pil_image)
    return torch.from_numpy(np_image / 255.0).float().unsqueeze(0)


async def _read_preview_request(request, content_type):
    """Return ``(node_id, width, height, image_bytes)`` from an apply request.

    Accepts multipart form data (fields ``node_id``, ``width``, ``height``,
    ``image_data``) or JSON (same keys, with the image in
    ``adjusted_data_base64``, optionally as a data URI).
    """
    node_id = None
    new_width = None
    new_height = None
    image_data = None

    if 'multipart/form-data' in content_type:
        reader = await request.multipart()
        while True:
            part = await reader.next()
            if part is None:
                break

            if part.name == 'node_id':
                node_id = await part.text()
            elif part.name == 'width':
                new_width = int(await part.text())
            elif part.name == 'height':
                new_height = int(await part.text())
            elif part.name == 'image_data':
                # Raw bytes of the encoded image.
                image_data = await part.read(decode=False)
    else:
        data = await request.json()
        node_id = data.get("node_id")
        new_width = data.get("width")
        new_height = data.get("height")

        adjusted_data_base64 = data.get("adjusted_data_base64")
        if adjusted_data_base64:
            if adjusted_data_base64.startswith('data:image'):
                # Drop the "data:image/...;base64," prefix.
                base64_data = adjusted_data_base64.split(',')[1]
            else:
                base64_data = adjusted_data_base64
            image_data = base64.b64decode(base64_data)

    return node_id, new_width, new_height, image_data


class ImageSizeAdjustment:
    """Node that previews an image and waits for a front-end size adjustment."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "adjust"
    CATEGORY = "🎈LAOGOU/Image"
    OUTPUT_NODE = True

    def adjust(self, image, unique_id):
        """Send a preview to the front end and wait up to 15 s for a result.

        Returns a one-tuple holding the adjusted image, or the unchanged
        input when the wait times out, no result arrives, or anything fails.
        """
        node_id = unique_id
        try:
            event = Event()
            # Assigning replaces any stale entry left under this id.
            size_data[node_id] = {
                "event": event,
                "result": None
            }

            PromptServer.instance.send_sync("image_preview_update", {
                "node_id": node_id,
                "image_data": _encode_preview(image)
            })

            # Block until the apply handler sets the event.
            if not event.wait(timeout=15):
                return (image,)

            result_image = size_data[node_id]["result"]
            return (result_image if result_image is not None else image,)

        except Exception as e:
            # Failures used to be swallowed silently; log them but still
            # fall back to the untouched input.
            print(f"[ImagePreview] 节点执行出错: {str(e)}")
            traceback.print_exc()
            return (image,)

        finally:
            # Single cleanup point for every exit path.
            size_data.pop(node_id, None)


@PromptServer.instance.routes.post("/image_preview/apply")
async def apply_image_preview(request):
    """Store the adjusted image for a waiting node and release it."""
    try:
        content_type = request.headers.get('Content-Type', '')
        print(f"[ImagePreview] 请求内容类型: {content_type}")

        node_id, new_width, new_height, image_data = await _read_preview_request(request, content_type)

        print(f"[ImagePreview] 接收到数据 - 节点ID: {node_id}")
        print(f"[ImagePreview] 接收到的尺寸: {new_width}x{new_height}")

        node_info = size_data.get(node_id)
        if node_info is None and node_id is not None:
            # JSON clients may send a numeric id for a string key.
            node_info = size_data.get(str(node_id))
        if node_info is None:
            return web.json_response({"success": False, "error": "节点数据不存在"})

        try:
            if image_data:
                try:
                    tensor_image = _decode_image_tensor(image_data)
                    print(f"[ImagePreview] 从二进制数据创建的张量形状: {tensor_image.shape}")
                    node_info["result"] = tensor_image
                except Exception as e:
                    # A bad image leaves "result" as None, so the node
                    # falls back to its input.
                    print(f"[ImagePreview] 处理图像数据时出错: {str(e)}")
                    traceback.print_exc()

            node_info["processed"] = True
            node_info["event"].set()
            return web.json_response({"success": True})

        except Exception as e:
            print(f"[ImagePreview] 处理数据时出错: {str(e)}")
            traceback.print_exc()
            # Never leave the node blocked on an error.
            if "event" in node_info:
                node_info["event"].set()
            return web.json_response({"success": False, "error": str(e)})

    except Exception as e:
        print(f"[ImagePreview] 请求处理出错: {str(e)}")
        traceback.print_exc()
        return web.json_response({"success": False, "error": str(e)})


NODE_CLASS_MAPPINGS = {
    "ImageSizeAdjustment": ImageSizeAdjustment,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageSizeAdjustment": "图像尺寸调整",
}
# Per-node rendezvous state shared by the node and the HTTP handlers:
# node_id -> {"event": Event, "result": tensor or None, "processing_complete": bool}.
crop_node_data = {}


def _encode_preview(image):
    """Return the first image of the batch *image* as a PNG data URI."""
    frame = (torch.clamp(image[0], 0, 1) * 255).cpu().numpy().astype(np.uint8)
    buffer = io.BytesIO()
    Image.fromarray(frame).save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"


def _decode_image_tensor(image_data):
    """Decode encoded image bytes into a float tensor shaped [1, H, W, 3].

    Every non-RGB mode is converted to RGB.  Previously only RGBA was
    converted, and any other mode was rejected without releasing the node.
    """
    pil_image = Image.open(io.BytesIO(image_data))
    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')
    np_image = np.array(pil_image)
    return torch.from_numpy(np_image / 255.0).float().unsqueeze(0)


async def _read_crop_request(request, content_type):
    """Return ``(node_id, image_bytes)`` from an apply request.

    Accepts multipart form data (fields ``node_id`` and ``image_data``) or
    JSON (``node_id`` plus ``cropped_data_base64``, optionally a data URI).
    The ``width``/``height`` fields are not needed and are skipped.
    """
    node_id = None
    image_data = None

    if 'multipart/form-data' in content_type:
        reader = await request.multipart()
        while True:
            part = await reader.next()
            if part is None:
                break

            if part.name == 'node_id':
                node_id = await part.text()
            elif part.name == 'image_data':
                image_data = await part.read(decode=False)
            else:
                # Drain fields we do not use so the reader can advance.
                await part.release()
    else:
        data = await request.json()
        node_id = data.get("node_id")

        cropped_data_base64 = data.get("cropped_data_base64")
        if cropped_data_base64:
            if cropped_data_base64.startswith('data:image'):
                # Drop the "data:image/...;base64," prefix.
                base64_data = cropped_data_base64.split(',')[1]
            else:
                base64_data = cropped_data_base64
            image_data = base64.b64decode(base64_data)

    return node_id, image_data


class ImageCropper:
    """Node that previews an image and waits for a front-end crop."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("裁剪图像",)
    FUNCTION = "crop"
    CATEGORY = "🎈LAOGOU/Image"

    def crop(self, image, unique_id):
        """Send a preview to the front end and wait up to 30 s for a crop.

        Returns a one-tuple holding the cropped image, or the unchanged
        input on timeout, cancel, or any failure.
        """
        node_id = unique_id
        try:
            event = Event()
            crop_node_data[node_id] = {
                "event": event,
                "result": None,
                "processing_complete": False
            }

            PromptServer.instance.send_sync("image_cropper_update", {
                "node_id": node_id,
                "image_data": _encode_preview(image)
            })

            # Released by the apply or cancel handler.
            if not event.wait(timeout=30):
                print(f"[ImageCropper] 等待超时: 节点ID {node_id}")
                return (image,)

            entry = crop_node_data.get(node_id)
            result_image = entry["result"] if entry is not None else None
            return (result_image if result_image is not None else image,)

        except Exception as e:
            print(f"[ImageCropper] 节点执行出错: {str(e)}")
            traceback.print_exc()
            return (image,)

        finally:
            # Single cleanup point; the outer error path used to leave the
            # entry behind.
            crop_node_data.pop(node_id, None)


@PromptServer.instance.routes.post("/image_cropper/apply")
async def apply_image_cropper(request):
    """Store the cropped image for a waiting node and release it."""
    try:
        content_type = request.headers.get('Content-Type', '')
        print(f"[ImageCropper] 请求内容类型: {content_type}")

        node_id, image_data = await _read_crop_request(request, content_type)

        # Only a node that is currently waiting has an entry.  Creating one
        # here (as before) left an orphan that nothing ever removed.
        node_info = crop_node_data.get(node_id)
        if node_info is None:
            return web.json_response({"success": False, "error": "节点未找到"})

        try:
            if image_data:
                node_info["result"] = _decode_image_tensor(image_data)
        except Exception as e:
            print(f"[ImageCropper] 处理图像数据时出错: {str(e)}")
            traceback.print_exc()
        finally:
            # Always release the node; with no usable image it falls back
            # to its input instead of waiting for the timeout.
            node_info["event"].set()

        return web.json_response({"success": True})

    except Exception as e:
        print(f"[ImageCropper] 请求处理出错: {str(e)}")
        traceback.print_exc()
        return web.json_response({"success": False, "error": str(e)})


@PromptServer.instance.routes.post("/image_cropper/cancel")
async def cancel_crop(request):
    """Release a waiting node without a result so it returns its input."""
    try:
        data = await request.json()
        node_id = data.get("node_id")

        node_info = crop_node_data.get(node_id)
        if node_info is not None:
            # Setting the event lets the node continue.
            node_info["event"].set()
            print(f"[ImageCropper] 取消裁剪操作: 节点ID {node_id}")
            return web.json_response({"success": True})

        return web.json_response({"success": False, "error": "节点未找到"})

    except Exception as e:
        print(f"[ImageCropper] 取消请求处理出错: {str(e)}")
        traceback.print_exc()
        return web.json_response({"success": False, "error": str(e)})


NODE_CLASS_MAPPINGS = {
    "ImageCropper": ImageCropper,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageCropper": "图像裁剪",
}
/** Minimal publish/subscribe registry keyed by event name. */
class EventManager {
    constructor() {
        this.listeners = new Map();
    }

    addEventListener(event, callback) {
        if (!this.listeners.has(event)) {
            this.listeners.set(event, []);
        }
        this.listeners.get(event).push(callback);
    }

    removeEventListener(event, callback) {
        const callbacks = this.listeners.get(event);
        if (!callbacks) {
            return;
        }
        const index = callbacks.indexOf(callback);
        if (index > -1) {
            callbacks.splice(index, 1);
        }
    }

    /** Call every listener with `{ detail }`; one failing listener does not stop the rest. */
    dispatchEvent(event, detail = {}) {
        const callbacks = this.listeners.get(event);
        if (!callbacks) {
            return;
        }
        // Iterate over a copy so a listener may unsubscribe while running.
        for (const callback of [...callbacks]) {
            try {
                callback({ detail });
            } catch (error) {
                console.error(`Error in event listener for ${event}:`, error);
            }
        }
    }
}

/**
 * Wraps the app/api queue entry points so that a chosen set of output nodes
 * (plus everything they depend on) can be queued on its own, and tracks the
 * most recent canvas mouse event.
 */
class QueueManager {
    constructor() {
        this.eventManager = new EventManager();
        this.queueNodeIds = null;
        this.processingQueue = false;
        this.lastAdjustedMouseEvent = null;
        this.isLGTriggered = false; // true only while this extension is queueing
        this.initializeHooks();
    }

    initializeHooks() {
        const manager = this;
        const originalQueuePrompt = app.queuePrompt;
        const originalGraphToPrompt = app.graphToPrompt;
        const originalApiQueuePrompt = api.queuePrompt;

        app.queuePrompt = async function (...args) {
            manager.processingQueue = true;
            manager.eventManager.dispatchEvent("queue");
            try {
                // Hand the original's result back to the caller.
                return await originalQueuePrompt.apply(app, args);
            } finally {
                manager.processingQueue = false;
                manager.eventManager.dispatchEvent("queue-end");
            }
        };

        app.graphToPrompt = async function (...args) {
            manager.eventManager.dispatchEvent("graph-to-prompt");
            try {
                return await originalGraphToPrompt.apply(app, args);
            } finally {
                // Fire the end event on failure as well.
                manager.eventManager.dispatchEvent("graph-to-prompt-end");
            }
        };

        api.queuePrompt = async function (index, prompt, ...args) {
            // Trim prompt.output only when this extension started the queue.
            if (manager.isLGTriggered && manager.queueNodeIds?.length && prompt.output) {
                const oldOutput = prompt.output;
                const newOutput = {};
                for (const queueNodeId of manager.queueNodeIds) {
                    manager.recursiveAddNodes(String(queueNodeId), oldOutput, newOutput);
                }
                prompt.output = newOutput;
            }

            manager.eventManager.dispatchEvent("comfy-api-queue-prompt-before", {
                workflow: prompt.workflow,
                output: prompt.output,
            });

            try {
                return await originalApiQueuePrompt.apply(api, [index, prompt, ...args]);
            } finally {
                // Dispatch only once the request has actually settled.
                manager.eventManager.dispatchEvent("comfy-api-queue-prompt-end");
            }
        };

        const originalProcessMouseDown = LGraphCanvas.prototype.processMouseDown;
        const originalAdjustMouseEvent = LGraphCanvas.prototype.adjustMouseEvent;
        const originalProcessMouseMove = LGraphCanvas.prototype.processMouseMove;

        LGraphCanvas.prototype.processMouseDown = function (e) {
            const result = originalProcessMouseDown.apply(this, arguments);
            manager.lastAdjustedMouseEvent = e;
            return result;
        };

        LGraphCanvas.prototype.adjustMouseEvent = function (e) {
            // Pass the original's return value through.
            const result = originalAdjustMouseEvent.apply(this, arguments);
            manager.lastAdjustedMouseEvent = e;
            return result;
        };

        LGraphCanvas.prototype.processMouseMove = function (e) {
            const result = originalProcessMouseMove.apply(this, arguments);
            // Fill in canvas coordinates when missing; 0 is a valid value.
            if (e && (e.canvasX == null || e.canvasY == null)) {
                const offset = app.canvas.convertEventToCanvasOffset(e);
                e.canvasX = offset[0];
                e.canvasY = offset[1];
            }
            manager.lastAdjustedMouseEvent = e;
            return result;
        };
    }

    /**
     * Copy `nodeId` and everything it depends on from `oldOutput` into
     * `newOutput`.  Ids missing from `oldOutput` are skipped rather than
     * throwing.
     */
    recursiveAddNodes(nodeId, oldOutput, newOutput) {
        const currentId = String(nodeId);
        const currentNode = oldOutput[currentId];
        if (currentNode == null || newOutput[currentId] != null) {
            return newOutput;
        }
        newOutput[currentId] = currentNode;
        for (const inputValue of Object.values(currentNode.inputs || {})) {
            // An array input value is followed as a link; element 0 is the source id.
            if (Array.isArray(inputValue)) {
                this.recursiveAddNodes(inputValue[0], oldOutput, newOutput);
            }
        }
        return newOutput;
    }

    /** Queue only the given output nodes and their dependencies. */
    async queueOutputNodes(nodeIds) {
        try {
            this.queueNodeIds = nodeIds;
            this.isLGTriggered = true;
            await app.queuePrompt();
        } catch (e) {
            console.error("队列节点时出错:", e);
        } finally {
            this.queueNodeIds = null;
            this.isLGTriggered = false;
        }
    }

    getLastMouseEvent() {
        return this.lastAdjustedMouseEvent;
    }

    addEventListener(event, callback) {
        this.eventManager.addEventListener(event, callback);
    }

    removeEventListener(event, callback) {
        this.eventManager.removeEventListener(event, callback);
    }
}

/** Return the nodes that are output nodes and whose mode is not NEVER. */
function getOutputNodes(nodes) {
    return (nodes?.filter((n) => {
        return (n.mode != LiteGraph.NEVER &&
            n.constructor.nodeData?.output_node);
    }) || []);
}

const queueManager = new QueueManager();

/** Queue the output nodes among the currently selected nodes. */
function queueSelectedOutputNodes() {
    const selectedNodes = app.canvas.selected_nodes;
    if (!selectedNodes || Object.keys(selectedNodes).length === 0) {
        console.log("[LG]队列: 没有选中的节点");
        return;
    }

    const outputNodes = getOutputNodes(Object.values(selectedNodes));
    if (outputNodes.length === 0) {
        console.log("[LG]队列: 选中的节点中没有输出节点");
        return;
    }

    console.log(`[LG]队列: 执行 ${outputNodes.length} 个输出节点`);
    queueManager.queueOutputNodes(outputNodes.map((n) => n.id));
}

/** Queue the output nodes of the group under the last known mouse position. */
function queueGroupOutputNodes() {
    const lastMouseEvent = queueManager.getLastMouseEvent();
    if (!lastMouseEvent) {
        return;
    }

    let canvasX = lastMouseEvent.canvasX;
    let canvasY = lastMouseEvent.canvasY;

    // Only fall back when a coordinate is really absent; 0 is valid.
    if (canvasX == null || canvasY == null) {
        const mousePos = app.canvas.getMousePos();
        canvasX = mousePos[0];
        canvasY = mousePos[1];
    }

    const group = app.graph.getGroupOnPos(canvasX, canvasY);
    if (!group) {
        return;
    }

    group.recomputeInsideNodes();

    if (!group._nodes || group._nodes.length === 0) {
        return;
    }

    const outputNodes = getOutputNodes(group._nodes);
    if (outputNodes.length === 0) {
        return;
    }

    queueManager.queueOutputNodes(outputNodes.map((n) => n.id));
}

app.registerExtension({
    name: "LG.QueueNodes",
    commands: [
        {
            id: "LG.QueueSelectedOutputNodes",
            icon: "pi pi-play",
            label: "执行选中的输出节点",
            function: queueSelectedOutputNodes
        },
        {
            id: "LG.QueueGroupOutputNodes",
            icon: "pi pi-sitemap",
            label: "执行组内输出节点",
            function: queueGroupOutputNodes
        }
    ]
});

export { queueManager, getOutputNodes, queueSelectedOutputNodes, queueGroupOutputNodes };
-------------------------------------------------------------------------------- /web/bridge_preview.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../scripts/app.js"; 2 | import { api } from "../../scripts/api.js"; 3 | const waitingNodes = new Map(); 4 | const processedNodes = new Set(); 5 | function updateNodePreview(nodeId, fileInfo) { 6 | const node = app.graph.getNodeById(parseInt(nodeId)); 7 | if (!node || !fileInfo) return; 8 | const imageUrl = fileInfo.subfolder 9 | ? `/view?filename=${encodeURIComponent(fileInfo.filename)}&subfolder=${encodeURIComponent(fileInfo.subfolder)}&type=${fileInfo.type || 'input'}` 10 | : `/view?filename=${encodeURIComponent(fileInfo.filename)}&type=${fileInfo.type || 'input'}`; 11 | node.images = [{ 12 | filename: fileInfo.filename, 13 | subfolder: fileInfo.subfolder || "", 14 | type: fileInfo.type || "input", 15 | url: imageUrl 16 | }]; 17 | if (node.onDrawBackground) { 18 | const img = new Image(); 19 | img.onload = () => { 20 | node.imgs = [img]; 21 | app.graph.setDirtyCanvas(true); 22 | }; 23 | img.src = imageUrl; 24 | } 25 | updateStatusWidget(node, `${nodeId}_${fileInfo.filename}`); 26 | } 27 | function updateStatusWidget(node, statusText) { 28 | const statusWidget = node.widgets?.find(w => w.name === "file_info"); 29 | if (statusWidget) { 30 | statusWidget.value = statusText; 31 | app.graph.setDirtyCanvas(true); 32 | app.graph.change(); 33 | } 34 | } 35 | api.addEventListener("bridge_preview_update", (event) => { 36 | const { node_id, urls } = event.detail; 37 | const node = app.graph._nodes_by_id[node_id]; 38 | if (!node || !urls?.length) return; 39 | waitingNodes.set(node_id, { urls, timestamp: Date.now(), node }); 40 | const imageData = urls.map((url, index) => ({ 41 | index, 42 | filename: url.filename, 43 | subfolder: url.subfolder, 44 | type: url.type 45 | })); 46 | node.imageData = imageData; 47 | node.imgs = []; 48 | imageData.forEach((imgData, i) => 
{ 49 | const img = new Image(); 50 | img.onload = () => { 51 | app.graph.setDirtyCanvas(true); 52 | if (i === imageData.length - 1) { 53 | setupClipspace(node_id, urls); 54 | } 55 | }; 56 | img.src = `/view?filename=${encodeURIComponent(imgData.filename)}&type=${imgData.type}&subfolder=${imgData.subfolder || ''}&${app.getPreviewFormatParam()}`; 57 | node.imgs.push(img); 58 | }); 59 | node.setSizeForImage?.(); 60 | node.update?.(); 61 | }); 62 | function setupClipspace(nodeId, urls) { 63 | const ComfyApp = app.constructor; 64 | if (!ComfyApp.clipspace) ComfyApp.clipspace = {}; 65 | if (!app.clipspace) app.clipspace = {}; 66 | const images = urls.map(url => ({ 67 | filename: url.filename, 68 | subfolder: url.subfolder || "", 69 | type: url.type || "output" 70 | })); 71 | const imgs = urls.map(url => ({ 72 | src: `${window.location.origin}/view?filename=${encodeURIComponent(url.filename)}&type=${url.type}&subfolder=${encodeURIComponent(url.subfolder || '')}`, 73 | filename: url.filename, 74 | subfolder: url.subfolder || "", 75 | type: url.type || "output" 76 | })); 77 | [ComfyApp.clipspace, app.clipspace].forEach(clipspace => { 78 | clipspace.images = images; 79 | clipspace.imgs = imgs; 80 | clipspace.selectedIndex = 0; 81 | }); 82 | setTimeout(() => { 83 | const node = app.graph.getNodeById(parseInt(nodeId)); 84 | if (!node) return; 85 | 86 | app.canvas.selectNode(node); 87 | 88 | let success = false; 89 | 90 | try { 91 | if (app.extensionManager?.command?.execute) { 92 | app.extensionManager.command.execute('Comfy.MaskEditor.OpenMaskEditor'); 93 | success = true; 94 | } 95 | } catch (error) { 96 | // Silently fail and try next method 97 | } 98 | 99 | if (!success) { 100 | try { 101 | const ComfyApp = app.constructor; 102 | const openMaskEditor = ComfyApp.open_maskeditor || app.open_maskeditor; 103 | if (openMaskEditor && typeof openMaskEditor === 'function') { 104 | openMaskEditor(); 105 | success = true; 106 | } 107 | } catch (error) { 108 | // Silently fail 109 | 
/**
 * Locate the DOM container of the currently open mask editor.
 *
 * Strategies are tried from most to least specific:
 *   1. the `.mask-editor-dialog` element;
 *   2. any modal-like element that contains a <canvas> and is not hidden
 *      through an inline `display: none`;
 *   3. a heuristic: the nearest ancestor of a cancel button that also holds
 *      a <canvas> and a save button.
 *
 * The heuristic deliberately walks *up* from each cancel button. Scanning
 * every element in document order (as was done before) always matches the
 * outermost ancestor first, i.e. <html>, which made callers treat every
 * "cancel" button on the page as the mask editor's cancel button.
 *
 * @returns {Element|null} the editor container, or null when none is found.
 */
function findMaskEditor() {
    const dialog = document.querySelector('.mask-editor-dialog');
    if (dialog) {
        return dialog;
    }

    const modals = document.querySelectorAll('div.comfy-modal, .comfy-modal, [class*="modal"]');
    for (const modal of modals) {
        if (modal.querySelector('canvas') && modal.style.display !== 'none') {
            return modal;
        }
    }

    const labelOf = (btn) => btn.textContent.trim().toLowerCase();
    const isCancel = (btn) => {
        const label = labelOf(btn);
        return label.includes('cancel') || label.includes('取消');
    };
    const isSave = (btn) => {
        const label = labelOf(btn);
        return label.includes('save') || label.includes('保存');
    };

    for (const cancelButton of document.querySelectorAll('button')) {
        if (!isCancel(cancelButton)) continue;

        // Climb until an ancestor contains both a canvas and a save button;
        // that is the smallest element that looks like the editor.
        for (let el = cancelButton.parentElement; el; el = el.parentElement) {
            if (!el.querySelector('canvas')) continue;
            if (Array.from(el.querySelectorAll('button')).some(isSave)) {
                return el;
            }
        }
    }

    return null;
}
/**
 * Tell the backend that the bridge-preview wait for `nodeId` was cancelled.
 *
 * Posts `{ node_id: "<id>" }` as JSON to `/bridge_preview/cancel`. Any error
 * thrown by the request is logged and swallowed, so this never rejects.
 *
 * @param {string|number} nodeId - id of the waiting node.
 * @returns {Promise<void>}
 */
async function sendCancelSignal(nodeId) {
    const payload = { node_id: String(nodeId) };
    const requestOptions = {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify(payload)
    };

    try {
        await api.fetchApi("/bridge_preview/cancel", requestOptions);
    } catch (error) {
        console.error(`[BridgePreview] Failed to send cancel signal:`, error);
    }
}
console.error(`[BridgePreview] Failed to process node ${latestNodeId}:`, error); 229 | } 230 | } catch (error) { 231 | console.error(`[BridgePreview] Failed to handle mask result:`, error); 232 | } 233 | } 234 | setInterval(() => { 235 | const now = Date.now(); 236 | for (const [nodeId, nodeInfo] of waitingNodes) { 237 | if (now - nodeInfo.timestamp > 30000) { 238 | waitingNodes.delete(nodeId); 239 | processedNodes.delete(String(nodeId)); 240 | } 241 | } 242 | }, 5000); 243 | console.log("[BridgePreview] Bridge preview module loaded"); -------------------------------------------------------------------------------- /py/inspyrenet/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import cv2 4 | import yaml 5 | import torch 6 | import hashlib 7 | import argparse 8 | 9 | import albumentations as A 10 | from albumentations.core.transforms_interface import ImageOnlyTransform 11 | 12 | import numpy as np 13 | 14 | from PIL import Image 15 | from threading import Thread 16 | from easydict import EasyDict 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--source', '-s', type=str, help="Path to the source. Single image, video, directory of images, directory of videos is supported.") 21 | parser.add_argument('--dest', '-d', type=str, default=None, help="Path to destination. Results will be stored in current directory if not specified.") 22 | parser.add_argument('--type', '-t', type=str, default='rgba', help="Specify output type. If not specified, output results will make the background transparent. 
def get_backend():
    """Return the preferred torch device string for this machine.

    Preference order is CUDA (``"cuda:0"``), then Apple MPS (``"mps:0"``),
    falling back to ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda:0"
    # ``torch.backends.mps`` is missing on older torch builds; treat that as
    # "MPS unavailable" instead of raising AttributeError.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps:0"
    return "cpu"


def load_config(config_dir, easy=True):
    """Load a YAML configuration file.

    Args:
        config_dir: Path of the YAML file (despite the name, a file path).
        easy: When exactly ``True``, wrap the parsed mapping in ``EasyDict``
            so keys can be read as attributes.

    Returns:
        The parsed configuration, as an ``EasyDict`` or as whatever object
        the YAML loader produced.
    """
    # NOTE(review): FullLoader is not meant for untrusted input; switch to
    # yaml.safe_load if configs can ever come from outside the project.
    # The context manager closes the handle that used to be leaked.
    with open(config_dir) as config_file:
        cfg = yaml.load(config_file, yaml.FullLoader)
    if easy is True:
        cfg = EasyDict(cfg)
    return cfg


def get_format(source):
    """Classify a collection of file names as images or videos.

    Args:
        source: Iterable of file names / paths.

    Returns:
        ``'Image'`` when only image extensions (.jpg, .png, .jpeg) occur,
        ``'Video'`` when only video extensions (.mp4, .avi, .mov) occur, and
        ``''`` when both kinds or neither kind is present.
    """
    image_exts = ('.jpg', '.png', '.jpeg')
    video_exts = ('.mp4', '.avi', '.mov')

    # Materialise once so one-shot iterables are not exhausted by the first
    # scan before the second one runs.
    names = [name.lower() for name in source]
    has_images = any(name.endswith(image_exts) for name in names)
    has_videos = any(name.endswith(video_exts) for name in names)

    # Mixed input and empty/unknown input are both reported as ''.
    if has_images == has_videos:
        return ''
    return 'Image' if has_images else 'Video'
def download_and_unzip(filename, url, dest, unzip=True, **kwargs):
    """Ensure ``dest/filename`` exists (downloading it if needed) and optionally unzip it.

    The file is downloaded when it is missing, or when an ``md5`` keyword is
    given and does not match the digest of the existing file.

    Args:
        filename: Name of the file inside ``dest``.
        url: Location to download the file from.
        dest: Target directory; created when it does not exist.
        unzip: When true, extract the archive into ``dest`` (overwriting
            existing members) and delete the archive afterwards.
        **kwargs: Optional ``md5`` hex digest used to validate a file that is
            already present.

    Raises:
        urllib.error.URLError: If the download fails.
        zipfile.BadZipFile: If ``unzip`` is true and the file is not a zip.
    """
    # Function-scope imports keep this block self-contained; the module only
    # imports os/hashlib at file level.
    import shutil
    import urllib.request
    import zipfile

    os.makedirs(dest, exist_ok=True)
    target = os.path.join(dest, filename)

    def _md5_of(path):
        # Hash in chunks so large checkpoints are not loaded into memory,
        # and close the handle deterministically.
        digest = hashlib.md5()
        with open(path, 'rb') as handle:
            for chunk in iter(lambda: handle.read(1 << 20), b''):
                digest.update(chunk)
        return digest.hexdigest()

    needs_download = not os.path.isfile(target)
    if not needs_download and 'md5' in kwargs:
        needs_download = kwargs['md5'] != _md5_of(target)

    if needs_download:
        # No shell involved: the previous `os.system("wget ...")` broke on
        # paths with spaces, allowed command injection through url/filename
        # and required wget to be installed. Download to a temporary name so
        # an interrupted transfer never leaves a truncated file in place.
        partial = target + '.part'
        with urllib.request.urlopen(url) as response, open(partial, 'wb') as out:
            shutil.copyfileobj(response, out)
        os.replace(partial, target)

    if unzip:
        with zipfile.ZipFile(target) as archive:
            archive.extractall(dest)
        os.remove(target)
class normalize:
    """Scale and standardise an image array: ``(img / div - mean) / std``."""

    def __init__(self, mean=None, std=None, div=255):
        self.mean = 0.0 if mean is None else mean
        self.std = 1.0 if std is None else std
        self.div = div

    def __call__(self, img):
        # Augmented assignments are intentional: the array passed in is
        # modified in place and the same object is returned.
        # presumably a float array such as the one produced by `tonumpy`.
        img /= self.div
        img -= self.mean
        img /= self.std

        return img


class tonumpy:
    """Convert an image-like object to a ``float32`` numpy array."""

    def __init__(self):
        pass

    def __call__(self, img):
        return np.array(img, dtype=np.float32)


class totensor:
    """Convert an ``(H, W, C)`` numpy array into a ``(C, H, W)`` float tensor."""

    def __init__(self):
        pass

    def __call__(self, img):
        channels_first = img.transpose((2, 0, 1))
        return torch.from_numpy(channels_first).float()


class ImageLoader:
    """Iterate over image files, yielding ``(PIL.Image in RGB, file name)``.

    ``root`` may be a directory (every .jpg/.png/.jpeg inside it, in natural
    sort order) or the path of a single file.

    Raises:
        FileNotFoundError: If ``root`` is neither an existing directory nor
            an existing file. Previously this surfaced later as an
            ``AttributeError`` on ``self.images``.
    """

    _EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, root):
        if os.path.isdir(root):
            found = [os.path.join(root, f) for f in os.listdir(root)
                     if f.lower().endswith(self._EXTENSIONS)]
            self.images = sort(found)
        elif os.path.isfile(root):
            self.images = [root]
        else:
            raise FileNotFoundError(
                "ImageLoader source does not exist: {!r}".format(root))
        self.size = len(self.images)
        # Initialised here so next() works even without a prior iter().
        self.index = 0

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index >= self.size:
            raise StopIteration

        path = self.images[self.index]
        img = Image.open(path).convert('RGB')
        name = os.path.split(path)[-1]

        self.index += 1
        return img, name

    def __len__(self):
        return self.size
class WebcamLoader:
    """Iterate over the latest webcam frame as ``(PIL.Image in RGB, None)``.

    A daemon thread keeps reading frames from the capture device; every
    ``next()`` returns the most recent one. Iteration stops when the reader
    thread has ended or when ``q`` is pressed in an OpenCV window.

    ``len()`` is always 0 because the stream has no known length (note that
    this also makes an instance falsy).
    """

    def __init__(self, ID):
        self.ID = int(ID)
        self.cap = cv2.VideoCapture(self.ID)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        self._stop_requested = False
        self.imgs = []
        # Only keep the first frame when the read succeeded; a failed read
        # returns None, which would later crash the colour conversion.
        ret, frame = self.cap.read()
        if ret is True:
            self.imgs.append(frame)
        self.thread = Thread(target=self.update, daemon=True)
        self.thread.start()

    def update(self):
        """Reader loop: keep only the newest frame until stopped or read fails."""
        try:
            while not self._stop_requested and self.cap.isOpened():
                ret, frame = self.cap.read()
                if ret is not True:
                    break
                self.imgs.append(frame)
                # Drop stale frames here as well so the buffer cannot grow
                # without bound when the consumer is slower than the camera.
                del self.imgs[:-1]
        finally:
            # The reader thread owns the device once started, so it is the
            # one that releases it (avoids releasing during a read).
            self.cap.release()

    def __iter__(self):
        return self

    def __next__(self):
        if self.thread.is_alive() is False or cv2.waitKey(1) == ord('q'):
            self._stop_requested = True
            cv2.destroyAllWindows()
            raise StopIteration

        if len(self.imgs) > 0:
            frame = cv2.cvtColor(self.imgs[-1], cv2.COLOR_BGR2RGB)
        else:
            # No frame captured yet: hand out a black 640x480 frame. It must
            # stay a numpy array here; it is converted to PIL below.
            frame = np.zeros((480, 640, 3), dtype=np.uint8)

        del self.imgs[:-1]
        return Image.fromarray(frame).convert('RGB'), None

    def __len__(self):
        return 0
def get_selector_storage():
    """Return the dict shared by image-selector nodes, creating it on first use.

    The dict is stored on ``PromptServer.instance`` as ``_selector_node_data``.
    """
    server = PromptServer.instance
    try:
        return server._selector_node_data
    except AttributeError:
        server._selector_node_data = {}
        return server._selector_node_data
image_list.append(images.unsqueeze(0)) 64 | else: 65 | raise ValueError(f"不支持的图像维度: {images.shape}") 66 | else: 67 | raise ValueError(f"不支持的输入类型: {type(images)}") 68 | 69 | preview_images = [] 70 | for i, img in enumerate(image_list): 71 | try: 72 | result = self.save_images(images=img, prompt=prompt) 73 | if 'ui' in result and 'images' in result['ui']: 74 | preview_images.extend(result['ui']['images']) 75 | except Exception as e: 76 | continue 77 | 78 | try: 79 | PromptServer.instance.send_sync("image_selector_update", { 80 | "id": node_id, 81 | "urls": preview_images 82 | }) 83 | except Exception as e: 84 | pass 85 | 86 | if actual_mode == "passthrough": 87 | self.cleanup_session_data(node_id) 88 | all_indices = ','.join(str(i) for i in range(len(image_list))) 89 | return {"result": (image_list, all_indices)} 90 | 91 | if actual_mode == "keep_last_selection": 92 | if node_id in node_data and "last_selection" in node_data[node_id]: 93 | last_selection = node_data[node_id]["last_selection"] 94 | if last_selection and len(last_selection) > 0: 95 | valid_indices = [idx for idx in last_selection if 0 <= idx < len(image_list)] 96 | if valid_indices: 97 | try: 98 | PromptServer.instance.send_sync("image_selector_selection", { 99 | "id": node_id, 100 | "selected_indices": valid_indices 101 | }) 102 | except Exception as e: 103 | pass 104 | self.cleanup_session_data(node_id) 105 | indices_str = ','.join(str(i) for i in valid_indices) 106 | return {"result": ([image_list[idx] for idx in valid_indices], indices_str)} 107 | 108 | if node_id in node_data: 109 | del node_data[node_id] 110 | 111 | event = Event() 112 | node_data[node_id] = { 113 | "event": event, 114 | "selected_indices": None, 115 | "images": image_list, 116 | "total_count": len(image_list), 117 | "cancelled": False 118 | } 119 | 120 | while node_id in node_data: 121 | node_info = node_data[node_id] 122 | if node_info.get("cancelled", False): 123 | self.cleanup_session_data(node_id) 124 | raise 
ImageSelectorCancelled("用户取消选择") 125 | 126 | if "selected_indices" in node_info and node_info["selected_indices"] is not None: 127 | break 128 | 129 | time.sleep(0.1) 130 | 131 | if node_id in node_data: 132 | node_info = node_data[node_id] 133 | selected_indices = node_info.get("selected_indices") 134 | 135 | if selected_indices is not None and len(selected_indices) > 0: 136 | valid_indices = [idx for idx in selected_indices if 0 <= idx < len(image_list)] 137 | if valid_indices: 138 | selected_images = [image_list[idx] for idx in valid_indices] 139 | 140 | if node_id not in node_data: 141 | node_data[node_id] = {} 142 | node_data[node_id]["last_selection"] = valid_indices 143 | 144 | self.cleanup_session_data(node_id) 145 | indices_str = ','.join(str(i) for i in valid_indices) 146 | return {"result": (selected_images, indices_str)} 147 | else: 148 | self.cleanup_session_data(node_id) 149 | return {"result": ([image_list[0]] if len(image_list) > 0 else [], "0" if len(image_list) > 0 else "")} 150 | else: 151 | self.cleanup_session_data(node_id) 152 | return {"result": ([image_list[0]] if len(image_list) > 0 else [], "0" if len(image_list) > 0 else "")} 153 | else: 154 | return {"result": ([image_list[0]] if len(image_list) > 0 else [], "0" if len(image_list) > 0 else "")} 155 | 156 | except ImageSelectorCancelled: 157 | raise comfy.model_management.InterruptProcessingException() 158 | except Exception as e: 159 | node_data = get_selector_storage() 160 | if node_id in node_data: 161 | self.cleanup_session_data(node_id) 162 | if 'image_list' in locals() and len(image_list) > 0: 163 | return {"result": ([image_list[0]], "0")} 164 | else: 165 | return {"result": ([], "")} 166 | 167 | def cleanup_session_data(self, node_id): 168 | """清理会话数据""" 169 | node_data = get_selector_storage() 170 | if node_id in node_data: 171 | session_keys = ["event", "selected_indices", "images", "total_count", "cancelled"] 172 | for key in session_keys: 173 | if key in node_data[node_id]: 
@PromptServer.instance.routes.post("/image_selector/select")
async def select_image_handler(request):
    """Handle the front end's selection or cancellation for a waiting node.

    Expected JSON body: ``node_id``, ``action`` (``"select"`` or
    ``"cancel"``) and, for ``"select"``, ``selected_indices`` (list of ints).

    Returns:
        A JSON response ``{"success": bool, "error": str?}``. Errors are
        reported in the body; this handler does not raise.
    """
    import logging  # function-scope: the module's imports are not all visible here

    logger = logging.getLogger(__name__)
    try:
        data = await request.json()
        # select_image() stores sessions under str(unique_id); normalise so
        # a numeric id sent by the client still finds its session.
        raw_node_id = data.get("node_id")
        node_id = None if raw_node_id is None else str(raw_node_id)
        selected_indices = data.get("selected_indices", [])
        action = data.get("action")

        # Shared storage of all selector sessions.
        node_data = get_selector_storage()

        if node_id not in node_data:
            return web.json_response({"success": False, "error": "节点数据不存在"})

        try:
            node_info = node_data[node_id]

            # Session keys are removed once the node has finished.
            if "total_count" not in node_info:
                return web.json_response({"success": False, "error": "节点已完成处理"})

            if action == "cancel":
                node_info["cancelled"] = True
                node_info["selected_indices"] = []
            elif action == "select" and isinstance(selected_indices, list):
                total = node_info["total_count"]
                # bool is a subclass of int; exclude it so `true` is not
                # silently treated as index 1.
                valid_indices = [
                    idx for idx in selected_indices
                    if isinstance(idx, int) and not isinstance(idx, bool) and 0 <= idx < total
                ]
                if valid_indices:
                    node_info["selected_indices"] = valid_indices
                    node_info["cancelled"] = False
                else:
                    return web.json_response({"success": False, "error": "选择索引无效"})
            else:
                return web.json_response({"success": False, "error": "无效操作"})

            node_info["event"].set()
            return web.json_response({"success": True})

        except Exception:
            logger.exception("[ImageSelector] failed to apply selection for node %s", node_id)
            # Still wake up anything waiting on the event so it cannot hang.
            event = node_data.get(node_id, {}).get("event")
            if event is not None:
                event.set()
            return web.json_response({"success": False, "error": "处理失败"})

    except Exception:
        logger.exception("[ImageSelector] malformed selection request")
        return web.json_response({"success": False, "error": "请求失败"})
/**
 * Rebuild `group._nodes` with every graph node whose bounding box overlaps
 * the group's bounding box. Does nothing for a group without a graph.
 *
 * @param {object} group - group exposing `graph`, `_bounding`, `_nodes`.
 */
function recomputeInsideNodesForGroup(group) {
    const graph = group.graph;
    if (!graph) return;

    group._nodes = [];

    // Phase 1: capture the bounding box of every node, keyed by node id.
    const boundsById = Object.fromEntries(
        graph._nodes.map((node) => [node.id, node.getBounding()])
    );

    // Phase 2: keep the nodes that overlap the group's own bounds.
    graph._nodes.forEach((node) => {
        const bounds = boundsById[node.id];
        if (!bounds) return;
        if (LiteGraph.overlapBounding(group._bounding, bounds)) {
            group._nodes.push(node);
        }
    });
}
[...this.size] : [300, 60]; 53 | 54 | if (!this.inputs?.find(i => i.name === name)) { 55 | this.addInput(name, "BOOLEAN"); 56 | } 57 | const input = this.inputs[this.inputs.length - 1]; 58 | 59 | const widgetResult = ComfyWidgets?.BOOLEAN(this, name, ["BOOLEAN", { default: true }], app); 60 | widget = widgetResult?.widget; 61 | 62 | if (!widget) return; 63 | 64 | // Ensure value is explicitly boolean 65 | if (widget.value !== true && widget.value !== false) { 66 | widget.value = true; 67 | } 68 | 69 | input.widget = widget; 70 | 71 | // Ensure config symbols are set for PS_Config compatibility 72 | widget[Symbol.for("GET_CONFIG")] = () => ["BOOLEAN", { default: true }]; 73 | widget[Symbol.for("CONFIG")] = ["BOOLEAN", { default: true }]; 74 | 75 | const originalCallback = widget.callback; 76 | widget.callback = (value) => { 77 | // Max One模式:开启一个组时,关闭其他所有组 78 | if (this.properties.maxOne && value === true) { 79 | for (const [otherName, otherWidget] of this.groupWidgetMap) { 80 | if (otherName !== name && otherWidget.value === true) { 81 | otherWidget.value = false; 82 | this.applyGroupMode(otherName, false); 83 | const otherGroup = app.graph?._groups.find(g => g.title === otherName); 84 | if (otherGroup) otherGroup.lgtools_isActive = false; 85 | } 86 | } 87 | } 88 | 89 | this.applyGroupMode(name, value); 90 | group.lgtools_isActive = value; 91 | originalCallback?.call(widget, value); 92 | }; 93 | 94 | this.groupWidgetMap.set(name, widget); 95 | group.lgtools_isActive = widget.value; 96 | this.setSize(this.computeSize()); 97 | } else { 98 | this.groupWidgetMap.set(name, widget); 99 | const input = this.inputs?.find(i => i.name === name); 100 | if (input && !input.widget) input.widget = widget; 101 | 102 | // Ensure config symbols exist for existing widgets 103 | if (!widget[Symbol.for("GET_CONFIG")]) { 104 | widget[Symbol.for("GET_CONFIG")] = () => ["BOOLEAN", { default: true }]; 105 | widget[Symbol.for("CONFIG")] = ["BOOLEAN", { default: true }]; 106 | } 107 | 108 | 
if (group.lgtools_isActive != null && widget.value !== group.lgtools_isActive) { 109 | widget.value = group.lgtools_isActive; 110 | } 111 | } 112 | 113 | if (this.widgets?.[index] !== widget) { 114 | const oldIndex = this.widgets.indexOf(widget); 115 | if (oldIndex !== -1) { 116 | this.widgets.splice(index, 0, this.widgets.splice(oldIndex, 1)[0]); 117 | } 118 | } 119 | }); 120 | 121 | const toRemove = [...this.groupWidgetMap.keys()].filter(name => !existingGroups.has(name)); 122 | if (toRemove.length) { 123 | const widgetIndices = []; 124 | const inputIndices = []; 125 | 126 | toRemove.forEach(name => { 127 | const widget = this.groupWidgetMap.get(name); 128 | if (widget) widgetIndices.push(this.widgets.indexOf(widget)); 129 | const inputIdx = this.inputs?.findIndex(i => i.name === name); 130 | if (inputIdx !== -1) inputIndices.push(inputIdx); 131 | }); 132 | 133 | widgetIndices.sort((a, b) => b - a).forEach(i => this.widgets.splice(i, 1)); 134 | inputIndices.sort((a, b) => b - a).forEach(i => this.removeInput(i)); 135 | 136 | toRemove.forEach(name => { 137 | this.groupWidgetMap.delete(name); 138 | }); 139 | 140 | this.setSize(this.computeSize()); 141 | } 142 | 143 | this.setDirtyCanvas(true, false); 144 | } 145 | 146 | applyGroupMode(groupName, enabled) { 147 | const group = app.graph?._groups.find(g => g.title === groupName); 148 | if (!group) return; 149 | 150 | const targetMode = enabled ? MODE_ALWAYS : (this.properties.mode === "mute" ? 
MODE_MUTE : MODE_BYPASS); 151 | recomputeInsideNodesForGroup(group); 152 | 153 | (group._nodes || []).forEach(node => { 154 | if (node.mode !== targetMode) node.mode = targetMode; 155 | }); 156 | 157 | app.graph.setDirtyCanvas(true, false); 158 | } 159 | 160 | computeSize(out) { 161 | const widgetCount = this.widgets?.length || 0; 162 | let size = [200, (LiteGraph.NODE_TITLE_HEIGHT || 30) + widgetCount * (LiteGraph.NODE_WIDGET_HEIGHT || 20)]; 163 | 164 | if (this.tempSize) { 165 | size = [Math.max(this.tempSize[0], size[0]), Math.max(this.tempSize[1], size[1])]; 166 | clearTimeout(this.debouncerTempWidth); 167 | this.debouncerTempWidth = setTimeout(() => this.tempSize = null, 32); 168 | } 169 | 170 | if (out) { 171 | out[0] = size[0]; 172 | out[1] = size[1]; 173 | } 174 | return size; 175 | } 176 | 177 | getExtraMenuOptions(canvas, options) { 178 | const currentMode = this.properties.mode || "mute"; 179 | const nextMode = currentMode === "mute" ? "bypass" : "mute"; 180 | 181 | options.push({ 182 | content: `Switch to ${nextMode === "mute" ? "Mute" : "Bypass"} mode`, 183 | callback: () => { 184 | this.properties.mode = nextMode; 185 | for (const [name, widget] of this.groupWidgetMap) { 186 | this.applyGroupMode(name, widget.value); 187 | } 188 | } 189 | }); 190 | 191 | options.push({ 192 | content: this.properties.maxOne ? 
"✓ Max One Mode" : "Max One Mode", 193 | callback: () => { 194 | this.properties.maxOne = !this.properties.maxOne; 195 | // 如果启用Max One模式,确保最多只有一个组是开启的 196 | if (this.properties.maxOne) { 197 | let firstActiveFound = false; 198 | for (const [name, widget] of this.groupWidgetMap) { 199 | if (widget.value === true) { 200 | if (firstActiveFound) { 201 | widget.value = false; 202 | this.applyGroupMode(name, false); 203 | const group = app.graph?._groups.find(g => g.title === name); 204 | if (group) group.lgtools_isActive = false; 205 | } else { 206 | firstActiveFound = true; 207 | } 208 | } 209 | } 210 | } 211 | this.setDirtyCanvas(true, false); 212 | } 213 | }); 214 | 215 | options.push({ 216 | content: "Toggle All Groups", 217 | callback: () => { 218 | const newValue = !(this.widgets || []).every(w => w.value); 219 | (this.widgets || []).forEach(w => { 220 | w.value = newValue; 221 | w.callback?.(newValue); 222 | }); 223 | } 224 | }); 225 | } 226 | 227 | onDrawBackground() { 228 | if (!app.graph?._groups) return; 229 | 230 | const signature = app.graph._groups.map(g => g.title).sort().join(','); 231 | if (signature !== this._lastGroupSignature) { 232 | this._lastGroupSignature = signature; 233 | this.setupGroups(); 234 | } 235 | 236 | for (const [name, widget] of this.groupWidgetMap) { 237 | const group = app.graph._groups.find(g => g.title === name); 238 | if (!group) continue; 239 | 240 | recomputeInsideNodesForGroup(group); 241 | if (!group._nodes?.length) continue; 242 | 243 | const isActive = group._nodes.some(n => n.mode === MODE_ALWAYS); 244 | group.lgtools_isActive = isActive; 245 | 246 | if (widget.value !== isActive) { 247 | widget.value = isActive; 248 | this.setDirtyCanvas(true, false); 249 | } 250 | } 251 | } 252 | 253 | static setUp() { 254 | LiteGraph.registerNodeType(this.type, this); 255 | if (this._category) { 256 | this.category = this._category; 257 | } 258 | } 259 | } 260 | 261 | LG_GroupMuterNode.type = "🎈LG_GroupMuter"; 262 | 
LG_GroupMuterNode.title = "🎈LG Group Muter"; 263 | LG_GroupMuterNode._category = "🎈LAOGOU/Switch"; 264 | LG_GroupMuterNode["@mode"] = { type: "combo", values: ["mute", "bypass"] }; 265 | 266 | app.registerExtension({ 267 | name: "LG.GroupMuter", 268 | registerCustomNodes() { 269 | LG_GroupMuterNode.setUp(); 270 | }, 271 | loadedGraphNode(node) { 272 | if (node.type === LG_GroupMuterNode.type) { 273 | node.tempSize = node.size ? [...node.size] : null; 274 | } 275 | } 276 | }); 277 | -------------------------------------------------------------------------------- /web/image_size_adjustment.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../scripts/app.js"; 2 | import { api } from "../../scripts/api.js"; 3 | import { ComfyWidgets } from "../../scripts/widgets.js"; 4 | 5 | app.registerExtension({ 6 | name: "ImagePreview.Adjustment", 7 | async beforeRegisterNodeDef(nodeType, nodeData) { 8 | if (nodeData.name === "ImageSizeAdjustment") { 9 | const onNodeCreated = nodeType.prototype.onNodeCreated; 10 | nodeType.prototype.onNodeCreated = function() { 11 | const result = onNodeCreated?.apply(this, arguments); 12 | this.widgets_start_y = 30; 13 | this.setupWebSocket(); 14 | 15 | // 创建偏移量控制器 16 | const createOffsetWidget = (name, defaultValue = 0) => { 17 | const widget = ComfyWidgets.INT(this, name, ["INT", { 18 | default: defaultValue, 19 | min: -9999, 20 | max: 9999, 21 | step: 8 22 | }]); 23 | 24 | // 设置值变化回调 25 | widget.widget.callback = (value) => { 26 | // 将值调整为8的倍数 27 | value = Math.floor(value / 8) * 8; 28 | 29 | if (this.originalImageData) { 30 | const size = name === "x_offset" 31 | ? 
// Subscribe this node to backend "image_preview_update" events.
//
// The listener is registered on the shared `api` object, so it must be
// removed again when the node is deleted; otherwise every node ever created
// keeps a live handler (and a reference to itself) for the whole session.
nodeType.prototype.setupWebSocket = function() {
    // Calling this twice must not stack listeners.
    if (this._previewUpdateHandler) {
        api.removeEventListener("image_preview_update", this._previewUpdateHandler);
    }

    const handler = (event) => {
        const data = event.detail;
        // Compare as strings so numeric and string ids both match.
        if (data && String(data.node_id) === String(this.id)) {
            console.log(`[ImagePreview] 节点 ${this.id} 接收到更新数据`);
            if (data.image_data) {
                this.loadImageFromBase64(data.image_data);
            }
        }
    };
    this._previewUpdateHandler = handler;
    api.addEventListener("image_preview_update", handler);

    // Chain onto any existing onRemoved hook to detach the listener.
    const previousOnRemoved = this.onRemoved;
    this.onRemoved = function() {
        api.removeEventListener("image_preview_update", handler);
        if (this._previewUpdateHandler === handler) {
            this._previewUpdateHandler = null;
        }
        return previousOnRemoved?.apply(this, arguments);
    };
};
// Redraw the preview canvas with the current offsets applied.
// onlyPreview = true refreshes the canvas without uploading the result.
nodeType.prototype.updatePreview = function (onlyPreview = false) {
    if (!this.originalImageData || !this.canvas) {
        return;
    }

    // Debounce: only the most recent request is rendered.
    if (this.updateTimeout) {
        clearTimeout(this.updateTimeout);
    }

    this.updateTimeout = setTimeout(() => {
        const ctx = this.canvas.getContext("2d");
        const originalWidth = this.originalImageData[0].length;
        const originalHeight = this.originalImageData.length;

        const x_offset = this.x_offset || 0;
        const y_offset = this.y_offset || 0;

        // The target size never drops below 1x1.
        const newWidth = Math.max(1, originalWidth + x_offset);
        const newHeight = Math.max(1, originalHeight + y_offset);

        this.canvas.width = newWidth;
        this.canvas.height = newHeight;
        ctx.clearRect(0, 0, newWidth, newHeight);

        if (!this.tempCanvas) {
            this.tempCanvas = document.createElement('canvas');
        }

        // Rasterise the source pixels once per loaded image; later redraws
        // only rescale the cached canvas.
        if (!this.originalImageRendered) {
            this.tempCanvas.width = originalWidth;
            this.tempCanvas.height = originalHeight;
            const tempCtx = this.tempCanvas.getContext('2d');

            const imgData = new ImageData(originalWidth, originalHeight);
            for (let y = 0; y < originalHeight; y++) {
                for (let x = 0; x < originalWidth; x++) {
                    const dstIdx = (y * originalWidth + x) * 4;
                    const srcPixel = this.originalImageData[y][x];
                    imgData.data[dstIdx] = srcPixel[0];     // R
                    imgData.data[dstIdx + 1] = srcPixel[1]; // G
                    imgData.data[dstIdx + 2] = srcPixel[2]; // B
                    imgData.data[dstIdx + 3] = 255;         // A
                }
            }
            tempCtx.putImageData(imgData, 0, 0);
            this.originalImageRendered = true;
        }

        // Stretch the original image to the new size.
        ctx.drawImage(
            this.tempCanvas,
            0, 0,
            originalWidth, originalHeight,
            0, 0,
            newWidth, newHeight
        );

        if (!onlyPreview && !this.isAdjusting) {
            const adjustedData = ctx.getImageData(0, 0, newWidth, newHeight);
            this.sendAdjustedData(adjustedData);
        }
    }, this.isAdjusting ? 50 : 0);
};

// Encode the resized pixels as JPEG and POST them to /image_preview/apply.
// Failures (encoding, network, non-2xx status) are logged, never thrown: the
// request is awaited so that rejections actually reach the catch block.
nodeType.prototype.sendAdjustedData = async function (adjustedData) {
    try {
        const endpoint = '/image_preview/apply';
        const nodeId = String(this.id);

        console.log(`[ImagePreview] 发送调整后的数据 - 尺寸: ${adjustedData.width}x${adjustedData.height}`);

        const canvas = document.createElement('canvas');
        canvas.width = adjustedData.width;
        canvas.height = adjustedData.height;
        const ctx = canvas.getContext('2d');
        ctx.putImageData(adjustedData, 0, 0);

        const blob = await new Promise(resolve => {
            canvas.toBlob(resolve, 'image/jpeg', 0.9);
        });
        // toBlob hands back null when the image cannot be encoded.
        if (!blob) {
            throw new Error('canvas.toBlob produced no data');
        }

        const formData = new FormData();
        formData.append('node_id', nodeId);
        formData.append('width', adjustedData.width);
        formData.append('height', adjustedData.height);
        formData.append('image_data', blob, 'adjusted_image.jpg');

        const response = await api.fetchApi(endpoint, {
            method: 'POST',
            body: formData
        });
        if (!response.ok) {
            throw new Error(`HTTP ${response.status}`);
        }
    } catch (error) {
        console.error('发送数据时出错:', error);
    }
};
// Decode an image URL (base64 data URL) into a rows-of-[R, G, B] array,
// store it on the node and redraw the preview.
nodeType.prototype.loadImageFromBase64 = function (base64Data) {
    const image = new Image();

    image.onload = () => {
        // A new source image invalidates the cached rasterisation.
        this.originalImageRendered = false;

        const { width, height } = image;
        const scratch = document.createElement('canvas');
        scratch.width = width;
        scratch.height = height;

        const scratchCtx = scratch.getContext('2d');
        scratchCtx.drawImage(image, 0, 0);
        const { data } = scratchCtx.getImageData(0, 0, width, height);

        // One inner array per pixel, alpha dropped.
        this.originalImageData = Array.from({ length: height }, (_, y) =>
            Array.from({ length: width }, (_, x) => {
                const i = (y * width + x) * 4;
                return [data[i], data[i + 1], data[i + 2]];
            })
        );
        this.updatePreview();
    };

    image.src = base64Data;
};

// When the node joins a graph, create its preview canvas once and attach it
// as a DOM widget on the next animation frame.
const onAdded = nodeType.prototype.onAdded;
nodeType.prototype.onAdded = function () {
    const result = onAdded?.apply(this, arguments);

    const hasValidId = this.id !== undefined && this.id !== -1;
    if (this.previewElement || !hasValidId) {
        return result;
    }

    const container = document.createElement("div");
    Object.assign(container.style, {
        position: "relative",
        width: "100%",
        height: "100%",
        backgroundColor: "#333",
        borderRadius: "8px",
        overflow: "hidden"
    });

    const previewCanvas = document.createElement("canvas");
    Object.assign(previewCanvas.style, {
        width: "100%",
        height: "100%",
        objectFit: "contain"
    });
    container.appendChild(previewCanvas);

    this.canvas = previewCanvas;
    this.previewElement = container;

    this.widgets ||= [];
    this.widgets_up = true;

    requestAnimationFrame(() => {
        if (!this.widgets) {
            return;
        }
        this.previewWidget = this.addDOMWidget("preview", "preview", container);
        this.setDirtyCanvas(true, true);
    });

    return result;
};
// Hook node creation: add the brightness/contrast/saturation sliders and
// start listening for images pushed by the server.
const onNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function () {
    const result = onNodeCreated?.apply(this, arguments);

    // Leave room for the input/output slots above the widgets.
    this.widgets_start_y = 30;

    this.setupWebSocket();

    // While dragging only the local preview is refreshed; the adjusted image
    // is sent to the server once the drag ends.
    const sliderConfig = {
        min: 0,
        max: 2,
        step: 0.01,
        drag_start: () => this.isAdjusting = true,
        drag_end: () => {
            this.isAdjusting = false;
            this.updatePreview(false);
        }
    };

    const createSlider = (name) => {
        this.addWidget("slider", name, 1.0, (value) => {
            this[name] = value;
            this.updatePreview(true);
        }, sliderConfig);
    };

    ["brightness", "contrast", "saturation"].forEach(createSlider);

    return result;
};

// Subscribe this node to "color_adjustment_update" events addressed to its
// id. The handler is kept on the node so it can be unregistered again;
// otherwise every node ever created stays reachable from `api` and keeps
// decoding images after it was deleted.
nodeType.prototype.setupWebSocket = function () {
    console.log(`[ColorAdjustment] 节点 ${this.id} 设置WebSocket监听`);
    if (this._colorUpdateListener) {
        api.removeEventListener("color_adjustment_update", this._colorUpdateListener);
    }
    this._colorUpdateListener = (event) => {
        const data = event.detail;

        if (data && data.node_id && data.node_id === String(this.id)) {
            console.log(`[ColorAdjustment] 节点 ${this.id} 接收到更新数据`);
            if (data.image_data) {
                console.log("[ColorAdjustment] 接收到base64数据:", {
                    nodeId: this.id,
                    dataLength: data.image_data.length,
                    dataPreview: data.image_data.substring(0, 50) + "...", // first 50 chars only
                    isBase64: data.image_data.startsWith("data:image"),
                    timestamp: new Date().toISOString()
                });

                this.loadImageFromBase64(data.image_data);
            } else {
                console.warn("[ColorAdjustment] 接收到空的图像数据");
            }
        }
    };
    api.addEventListener("color_adjustment_update", this._colorUpdateListener);
};

// Unregister the listener when the node is removed from the graph.
const onRemovedBeforeListenerCleanup = nodeType.prototype.onRemoved;
nodeType.prototype.onRemoved = function () {
    if (this._colorUpdateListener) {
        api.removeEventListener("color_adjustment_update", this._colorUpdateListener);
        this._colorUpdateListener = null;
    }
    return onRemovedBeforeListenerCleanup?.apply(this, arguments);
};
// Decode an image URL (base64 data URL) into a rows-of-[R, G, B] array,
// store it on the node and redraw the preview.
nodeType.prototype.loadImageFromBase64 = function (base64Data) {
    console.log(`[ColorAdjustment] 节点 ${this.id} 开始加载base64图像数据`);
    const image = new Image();

    image.onload = () => {
        console.log(`[ColorAdjustment] 节点 ${this.id} 图像加载完成: ${image.width}x${image.height}`);

        // Draw onto a scratch canvas to get at the raw pixels.
        const { width, height } = image;
        const scratch = document.createElement('canvas');
        scratch.width = width;
        scratch.height = height;

        const scratchCtx = scratch.getContext('2d');
        scratchCtx.drawImage(image, 0, 0);
        const { data } = scratchCtx.getImageData(0, 0, width, height);

        // One inner array per pixel, alpha dropped.
        this.originalImageData = Array.from({ length: height }, (_, y) =>
            Array.from({ length: width }, (_, x) => {
                const i = (y * width + x) * 4;
                return [data[i], data[i + 1], data[i + 2]];
            })
        );
        this.updatePreview();
    };

    image.src = base64Data;
};

// When the node joins a graph, create its preview canvas once and attach it
// as a DOM widget on the next animation frame.
const onAdded = nodeType.prototype.onAdded;
nodeType.prototype.onAdded = function () {
    const result = onAdded?.apply(this, arguments);

    const hasValidId = this.id !== undefined && this.id !== -1;
    if (this.previewElement || !hasValidId) {
        return result;
    }

    const container = document.createElement("div");
    Object.assign(container.style, {
        position: "relative",
        width: "100%",
        height: "100%",
        backgroundColor: "#333",
        borderRadius: "8px",
        overflow: "hidden"
    });

    const previewCanvas = document.createElement("canvas");
    Object.assign(previewCanvas.style, {
        width: "100%",
        height: "100%",
        objectFit: "contain"
    });
    container.appendChild(previewCanvas);

    this.canvas = previewCanvas;
    this.previewElement = container;

    this.widgets ||= [];
    this.widgets_up = true;

    requestAnimationFrame(() => {
        if (!this.widgets) {
            return;
        }
        this.previewWidget = this.addDOMWidget("preview", "preview", container);
        this.setDirtyCanvas(true, true);
    });

    return result;
};
// Re-render the preview with the current slider values on the next frame.
// Unless onlyPreview is set (or a drag is in progress) the adjusted pixels
// are also sent to the server.
nodeType.prototype.updatePreview = function (onlyPreview = false) {
    if (!(this.originalImageData && this.canvas)) {
        return;
    }

    requestAnimationFrame(() => {
        const ctx = this.canvas.getContext("2d");
        const rows = this.originalImageData;
        const width = rows[0].length;
        const height = rows.length;
        const shouldSend = !onlyPreview && !this.isAdjusting;

        if (shouldSend) {
            console.log(`[ColorAdjustment] 节点 ${this.id} 更新预览并准备发送数据 (${width}x${height})`);
        } else {
            console.log(`[ColorAdjustment] 节点 ${this.id} 仅更新预览 (${width}x${height})`);
        }

        // Flatten the rows-of-[R, G, B] data into an opaque RGBA buffer.
        const frame = new ImageData(width, height);
        const out = frame.data;
        for (let y = 0, offset = 0; y < height; y++) {
            const row = rows[y];
            for (let x = 0; x < width; x++, offset += 4) {
                const pixel = row[x];
                out[offset] = pixel[0];
                out[offset + 1] = pixel[1];
                out[offset + 2] = pixel[2];
                out[offset + 3] = 255;
            }
        }

        const adjustedData = this.adjustColors(frame);

        // Resizing the canvas clears it; paint the adjusted pixels afterwards.
        this.canvas.width = width;
        this.canvas.height = height;
        ctx.putImageData(adjustedData, 0, 0);

        if (shouldSend) {
            this.lastAdjustedData = adjustedData;
            this.sendAdjustedData(adjustedData);
        }
    });
};
// Apply brightness, contrast and saturation to a copy of imageData and
// return the result as a new ImageData. Alpha is left untouched.
nodeType.prototype.adjustColors = function (imageData) {
    // `??` rather than `||`: the sliders go down to 0, and 0 is a valid
    // setting (black / flat grey / greyscale) that must not fall back to 1.0.
    const brightness = this.brightness ?? 1.0;
    const contrast = this.contrast ?? 1.0;
    const saturation = this.saturation ?? 1.0;

    const result = new Uint8ClampedArray(imageData.data);
    const len = result.length;

    // Contrast pivots around mid-grey: c' = c * contrast + 128 * (1 - contrast).
    const contrastFactor = contrast;
    const contrastOffset = 128 * (1 - contrast);

    for (let i = 0; i < len; i += 4) {
        let r = Math.min(255, result[i] * brightness);
        let g = Math.min(255, result[i + 1] * brightness);
        let b = Math.min(255, result[i + 2] * brightness);

        r = r * contrastFactor + contrastOffset;
        g = g * contrastFactor + contrastOffset;
        b = b * contrastFactor + contrastOffset;

        // Blend each channel towards/away from the weighted luma.
        if (saturation !== 1.0) {
            const avg = r * 0.299 + g * 0.587 + b * 0.114;
            r = avg + (r - avg) * saturation;
            g = avg + (g - avg) * saturation;
            b = avg + (b - avg) * saturation;
        }

        result[i] = Math.min(255, Math.max(0, r));
        result[i + 1] = Math.min(255, Math.max(0, g));
        result[i + 2] = Math.min(255, Math.max(0, b));
    }

    return new ImageData(result, imageData.width, imageData.height);
};

// POST the adjusted RGBA pixels to /color_adjustment/apply as JSON.
// Errors are logged, never thrown.
nodeType.prototype.sendAdjustedData = async function (adjustedData) {
    try {
        const endpoint = '/color_adjustment/apply';
        const nodeId = String(this.id);

        api.fetchApi(endpoint, {
            method: 'POST',
            body: JSON.stringify({
                node_id: nodeId,
                adjusted_data: Array.from(adjustedData.data),
                width: adjustedData.width,
                height: adjustedData.height
            })
        }).then(response => {
            if (!response.ok) {
                throw new Error(`服务器返回错误: ${response.status}`);
            }
            return response.json();
        }).catch(error => {
            console.error('数据发送失败:', error);
        });
    } catch (error) {
        // Synchronous failures only (e.g. while serialising the payload).
        console.error('发送数据时出错:', error);
    }
};

// Release the preview canvas when the node is removed.
const onRemoved = nodeType.prototype.onRemoved;
nodeType.prototype.onRemoved = function () {
    const result = onRemoved?.apply(this, arguments);

    if (this && this.canvas) {
        const ctx = this.canvas.getContext("2d");
        if (ctx) {
            ctx.clearRect(0, 0, this.canvas.width, this.canvas.height);
        }
        this.canvas = null;
    }
    if (this) {
        this.previewElement = null;
    }

    return result;
};
def get_bridge_storage():
    """Return the dict of pending bridge nodes kept on the server instance.

    The dict is created on first use and stored as
    ``PromptServer.instance._bridge_node_data``.
    """
    server = PromptServer.instance
    try:
        return server._bridge_node_data
    except AttributeError:
        server._bridge_node_data = {}
        return server._bridge_node_data


def get_bridge_cache():
    """Return the dict of cached bridge-node results kept on the server instance.

    The dict is created on first use and stored as
    ``PromptServer.instance._bridge_node_cache``.
    """
    server = PromptServer.instance
    try:
        return server._bridge_node_cache
    except AttributeError:
        server._bridge_node_cache = {}
        return server._bridge_node_cache
class BridgePreviewNode(PreviewImage):
    """Preview node that blocks until the front end confirms a mask edit.

    The node sends preview URLs to the browser, waits for the
    ``/bridge_preview/confirm`` or ``/bridge_preview/cancel`` route to set its
    event, and then returns the edited image together with the inverted mask.
    Results are cached per node id and input hash.
    """

    # Seconds to wait for the front end before falling back to the input.
    WAIT_TIMEOUT_SECONDS = 30

    def __init__(self):
        super().__init__()
        self.prefix_append = "_bridge_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE", ),
                "file_info": ("STRING", {"default": "", "readonly": True, "multiline": False}),
                "skip": ("BOOLEAN", {"default": False, "label_on": "Skip", "label_off": "Open dialog"}),
            },
            "hidden": {
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO",
                "unique_id": "UNIQUE_ID"
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("处理后图像", "遮罩")
    FUNCTION = "process_image"
    OUTPUT_NODE = True
    CATEGORY = "🎈LAOGOU/Image"

    @staticmethod
    def _passthrough(images):
        """Return ``images`` unchanged with an all-ones mask of matching size."""
        batch_size, height, width, _channels = images.shape
        return (images, torch.ones((batch_size, height, width), dtype=torch.float32))

    @staticmethod
    def _ui_images(save_result):
        """Extract the ``ui.images`` list from a save result, or ``[]``."""
        return save_result["ui"]["images"] if save_result and "ui" in save_result else []

    def calculate_image_hash(self, images):
        """Return the MD5 hex digest of the tensor's bytes, or None on failure.

        The digest is only used to detect whether the input changed.
        """
        try:
            return hashlib.md5(images.cpu().numpy().tobytes()).hexdigest()
        except Exception:
            return None

    def save_output_images(self, images, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
        """Save each image with its mask as alpha (RGBA PNG) in the output dir.

        Returns ``{"ui": {"images": [...]}}`` describing the written files, or
        None when the batch sizes differ or saving fails (best effort: callers
        treat None as "no output URLs").
        """
        try:
            if images.shape[0] != mask.shape[0]:
                return None
            output_dir = folder_paths.get_output_directory()
            saved_images = []
            # The counter carries over between batch items: every number below
            # it is already known to be taken.
            counter = 1
            for i in range(images.shape[0]):
                np_image = (images[i].cpu().numpy() * 255).astype(np.uint8)
                np_mask = (mask[i].cpu().numpy() * 255).astype(np.uint8)
                h, w, c = np_image.shape
                rgba_image = np.zeros((h, w, 4), dtype=np.uint8)
                rgba_image[:, :, :3] = np_image
                rgba_image[:, :, 3] = np_mask
                pil_image = Image.fromarray(rgba_image, 'RGBA')

                # First unused "<prefix>_NNNNN_.png" name.
                while True:
                    filename = f"{filename_prefix}_{counter:05d}_.png"
                    full_path = os.path.join(output_dir, filename)
                    if not os.path.exists(full_path):
                        break
                    counter += 1

                metadata = PngInfo()
                if prompt is not None:
                    metadata.add_text("prompt", json.dumps(prompt))
                if extra_pnginfo is not None:
                    for key, value in extra_pnginfo.items():
                        metadata.add_text(key, json.dumps(value))
                pil_image.save(full_path, pnginfo=metadata, compress_level=4)
                saved_images.append({
                    "filename": filename,
                    "subfolder": "",
                    "type": "output"
                })
            return {"ui": {"images": saved_images}}
        except Exception:
            return None

    def process_image(self, images, file_info="", skip=False, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, unique_id=None):
        """Show the images in the front end and wait for the edited result.

        Returns ``(images, mask)``. With ``skip`` set, the cached result for
        an unchanged input is returned (mask inverted), otherwise the input
        with an all-ones mask. Without ``skip`` the node waits up to
        ``WAIT_TIMEOUT_SECONDS`` for the front end; on timeout, missing result
        or any error it falls back to the input with an all-ones mask.
        """
        try:
            node_id = str(unique_id)
            current_hash = self.calculate_image_hash(images)
            cache = get_bridge_cache()

            # A cache entry only counts when the input could be hashed and the
            # hash matches.
            cached = cache.get(node_id) if current_hash else None
            hash_matches = cached is not None and cached.get("input_hash") == current_hash

            if skip:
                if hash_matches and cached.get("final_result"):
                    # The cache holds the raw mask; invert it exactly like the
                    # interactive path does.
                    cached_images, cached_mask = cached["final_result"]
                    return (cached_images, 1 - cached_mask)
                return self._passthrough(images)

            # Show the previously saved output when the input is unchanged,
            # otherwise save a fresh preview of the input.
            if hash_matches and cached.get("output_urls"):
                preview_urls = cached["output_urls"]
            else:
                preview_urls = self._ui_images(
                    self.save_images(images, filename_prefix, prompt, extra_pnginfo))

            event = Event()
            storage = get_bridge_storage()
            storage[node_id] = {
                "event": event,
                "result": None,
                "original_images": images,
                "processing_complete": False,
                "input_hash": current_hash
            }

            try:
                PromptServer.instance.send_sync("bridge_preview_update", {
                    "node_id": node_id,
                    "urls": preview_urls
                })

                if not event.wait(timeout=self.WAIT_TIMEOUT_SECONDS):
                    storage.pop(node_id, None)
                    return self._passthrough(images)

                entry = storage.pop(node_id, None)
                result_data = entry["result"] if entry is not None else None

                if isinstance(result_data, tuple) and len(result_data) == 2:
                    final_images, final_mask = result_data
                    if current_hash:
                        # Cache the raw (non-inverted) mask.
                        output_urls = self._ui_images(self.save_output_images(
                            final_images, final_mask, filename_prefix + "_output", prompt, extra_pnginfo))
                        cache[node_id] = {
                            "input_hash": current_hash,
                            "output_urls": output_urls,
                            "final_result": (final_images, final_mask)
                        }
                    # Only the returned mask is inverted; the cache is not.
                    return (final_images, 1 - final_mask)
                return self._passthrough(images)

            except Exception:
                storage.pop(node_id, None)
                return self._passthrough(images)

        except Exception:
            # Best effort: never fail the workflow because of the preview.
            return self._passthrough(images)
@PromptServer.instance.routes.post("/bridge_preview/confirm")
async def confirm_bridge_preview(request):
    """Handle the front end confirming that mask editing is finished.

    Loads the edited image named by ``file_info`` (if any), stores it as the
    waiting node's result and releases the node's event. Always answers with
    a JSON object containing ``success``.
    """
    try:
        payload = await request.json()
        node_id = str(payload.get("node_id"))
        file_info = payload.get("file_info")
        storage = get_bridge_storage()
        if node_id not in storage:
            return web.json_response({"success": False, "error": "节点未找到或已超时"})
        try:
            pending = storage[node_id]
            if file_info:
                processed_image, mask_image = load_processed_image(file_info)
                if not (processed_image is None or mask_image is None):
                    pending["result"] = (processed_image, mask_image)
            pending["processing_complete"] = True
            pending["event"].set()
            return web.json_response({"success": True})
        except Exception as e:
            # Release the waiting node even though loading failed.
            entry = storage.get(node_id)
            if entry is not None and "event" in entry:
                entry["event"].set()
            return web.json_response({"success": False, "error": str(e)})
    except Exception as e:
        return web.json_response({"success": False, "error": str(e)})


@PromptServer.instance.routes.post("/bridge_preview/cancel")
async def cancel_bridge_preview(request):
    """Handle the front end cancelling the edit.

    If a cached final result exists for the node it becomes the node's result
    (so the previously edited mask is kept); then the node's event is
    released.
    """
    try:
        payload = await request.json()
        node_id = str(payload.get("node_id"))
        storage = get_bridge_storage()
        if node_id not in storage:
            return web.json_response({"success": False, "error": f"节点 {node_id} 未找到或已超时"})

        pending = storage[node_id]
        cache = get_bridge_cache()
        if node_id in cache and cache[node_id].get("final_result"):
            cached_images, cached_mask = cache[node_id]["final_result"]
            pending["result"] = (cached_images, cached_mask)

        pending["event"].set()
        return web.json_response({"success": True, "message": f"节点 {node_id} 已取消"})
    except Exception as e:
        return web.json_response({"success": False, "error": str(e)})
def _split_file_info(file_info):
    """Parse ``file_info`` into ``(filename, subfolder, file_type)``.

    Accepts a dict with ``filename`` / ``subfolder`` / ``type`` keys, or a
    string of the form ``"sub/folder/name.ext [type]"`` (the ``[type]`` suffix
    and the sub-folder are optional). Returns None for any other input.
    """
    if isinstance(file_info, dict):
        return (
            file_info.get("filename"),
            file_info.get("subfolder", ""),
            file_info.get("type", "output"),
        )
    if not isinstance(file_info, str):
        return None

    text = file_info.strip()
    file_type = "output"
    if "[" in text and text.endswith("]"):
        # Split off the trailing "[type]" annotation.
        bracket = text.rfind("[")
        file_type = text[bracket + 1:-1].strip().lower() or "output"
        text = text[:bracket].strip()

    # Everything before the last "/" is the sub-folder.
    subfolder, _, filename = text.rpartition("/")
    return filename, subfolder, file_type


def load_processed_image(file_info):
    """Load the image named by ``file_info`` and return ``(image, mask)``.

    ``image`` is a float tensor of shape (1, H, W, 3) with values in [0, 1]
    and ``mask`` a float tensor of shape (1, H, W) taken from the alpha
    channel. Returns ``(None, None)`` when the input cannot be parsed, the
    file does not exist, the path would leave the selected base directory, or
    loading fails.

    ``file_info`` comes from an HTTP request, so the resolved path is checked
    to stay inside the input/output/temp directory it was requested from.
    """
    try:
        parsed = _split_file_info(file_info)
        if parsed is None:
            return None, None
        filename, subfolder, file_type = parsed
        if not filename:
            return None, None

        if file_type == "input":
            base_dir = folder_paths.get_input_directory()
        elif file_type == "temp":
            base_dir = folder_paths.get_temp_directory()
        else:
            # "output" and every unknown type use the output directory.
            base_dir = folder_paths.get_output_directory()

        # Reject "../" tricks and absolute names that escape base_dir.
        base_dir = os.path.realpath(base_dir)
        file_path = os.path.realpath(os.path.join(base_dir, subfolder or "", filename))
        if os.path.commonpath([base_dir, file_path]) != base_dir:
            return None, None
        if not os.path.isfile(file_path):
            return None, None

        # The context manager closes the file handle once the pixels are copied.
        with Image.open(file_path) as image:
            np_image = np.array(image if image.mode == 'RGBA' else image.convert('RGBA'))

        rgb_image = np_image[:, :, :3]
        alpha_channel = np_image[:, :, 3]
        tensor_image = torch.from_numpy(rgb_image / 255.0).float().unsqueeze(0)
        mask_tensor = torch.from_numpy(alpha_channel / 255.0).float().unsqueeze(0)
        return tensor_image, mask_tensor
    except Exception:
        # Best effort: callers treat (None, None) as "no edited image".
        return None, None
// Return the list of files selectable in the stock LoadImage node (the
// contents of the input directory), or [] when it cannot be determined.
async function getInputFileList() {
    try {
        // Go through api.fetchApi like the other requests in this file; a
        // bare fetch('/object_info') bypasses the ComfyUI API helper.
        const response = await api.fetchApi('/object_info');
        if (!response.ok) {
            throw new Error(`HTTP ${response.status}`);
        }
        const data = await response.json();
        const files = data?.LoadImage?.input?.required?.image?.[0];
        // Only hand back a real list; callers assign it to a combo widget.
        return Array.isArray(files) ? files : [];
    } catch (error) {
        console.error("获取文件列表失败:", error);
        return [];
    }
}

// Delete an image via /lg/delete_image (no confirmation dialog).
// Resolves to true on success and false in every other case.
async function deleteImageFile(filename) {
    try {
        const response = await api.fetchApi('/lg/delete_image', {
            method: 'DELETE',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ filename: filename })
        });

        if (response.status === 200) {
            const result = await response.json();
            if (result.success) {
                console.log(`文件 ${filename} 删除成功`);
                return true;
            }
            // 200 without success: report the failure explicitly instead of
            // falling off the end and returning undefined.
            console.error(`删除失败: ${result.error || '未知错误'}`);
            return false;
        }

        const error = await response.json();
        console.error(`删除失败: ${error.error || '未知错误'}`);
        return false;
    } catch (error) {
        console.error(`删除文件失败: ${error}`);
        return false;
    }
}
// Select the newest image of the given folder type in the node's "image"
// widget. Resolves to true when the widget was updated, false otherwise.
async function loadLatestImage(node, folder_type) {
    try {
        const res = await api.fetchApi(`/lg/get/latest_image?type=${folder_type}`, { cache: "no-store" });
        if (res.status !== 200) {
            return false;
        }

        const latest = await res.json();
        if (!latest || !latest.filename) {
            return false;
        }

        const widget = node.widgets.find(w => w.name === 'image');
        if (!widget) {
            return false;
        }

        // Reference the file where it is; nothing is copied into input.
        const annotated = `${latest.filename} [${folder_type}]`;
        widget.value = annotated;

        // The widget callback refreshes the node's preview.
        if (typeof widget.callback === "function") {
            widget.callback(annotated);
        }

        app.graph.setDirtyCanvas(true);
        return true;
    } catch (error) {
        console.error(`加载图像失败: ${error}`);
    }
    return false;
}

// Wrap LiteGraph.ContextMenu so that the drop-down of an LG_LoadImage node's
// "image" widget gets a thumbnail and a delete button on every entry.
function extendContextMenuForThumbnails() {
    const originalContextMenu = LiteGraph.ContextMenu;

    LiteGraph.ContextMenu = function (values, options) {
        const ctx = originalContextMenu.call(this, values, options);

        const looksLikeCombo = options?.className === 'dark' && values?.length > 0;
        if (looksLikeCombo) {
            // Decorate after the menu's DOM has been built.
            requestAnimationFrame(() => {
                const currentNode = LGraphCanvas.active_canvas?.current_node;
                if (currentNode?.comfyClass !== "LG_LoadImage") {
                    return;
                }

                // Only treat it as the image drop-down when the entry count
                // matches the widget's value list.
                const imageWidget = currentNode.widgets?.find(w => w.name === 'image');
                if (!(imageWidget && imageWidget.options?.values?.length === values.length)) {
                    return;
                }

                ctx.root.style.maxWidth = '400px';
                ctx.root.style.minWidth = '350px';

                const entries = ctx.root.querySelectorAll('.litemenu-entry');
                entries.forEach((entry, position) => {
                    if (position >= values.length) {
                        return;
                    }
                    addThumbnailAndDeleteToMenuItem(entry, values[position], currentNode, ctx);
                });
            });
        }

        return ctx;
    };

    // Keep the prototype chain intact.
    LiteGraph.ContextMenu.prototype = originalContextMenu.prototype;
}
return ctx; 119 | }; 120 | 121 | // 保持原型链 122 | LiteGraph.ContextMenu.prototype = originalContextMenu.prototype; 123 | } 124 | 125 | // 为菜单项添加缩略图和删除按钮 126 | function addThumbnailAndDeleteToMenuItem(menuItem, filename, node, contextMenu) { 127 | // 避免重复添加 128 | if (menuItem.querySelector('.thumbnail-container')) { 129 | return; 130 | } 131 | 132 | // 保存原始文本内容 133 | const originalText = menuItem.textContent; 134 | 135 | // 清空菜单项内容 136 | menuItem.innerHTML = ''; 137 | 138 | // 设置菜单项样式为flex布局 139 | menuItem.style.cssText += ` 140 | display: flex; 141 | align-items: center; 142 | padding: 6px 12px; 143 | min-height: 48px; 144 | position: relative; 145 | `; 146 | 147 | // 创建缩略图容器 148 | const thumbnailContainer = document.createElement('div'); 149 | thumbnailContainer.className = 'thumbnail-container'; 150 | thumbnailContainer.style.cssText = ` 151 | width: 40px; 152 | height: 40px; 153 | margin-right: 10px; 154 | border-radius: 4px; 155 | overflow: hidden; 156 | background: #222; 157 | display: flex; 158 | align-items: center; 159 | justify-content: center; 160 | flex-shrink: 0; 161 | border: 1px solid #444; 162 | `; 163 | 164 | // 创建缩略图 165 | const thumbnail = document.createElement('img'); 166 | thumbnail.style.cssText = ` 167 | max-width: 100%; 168 | max-height: 100%; 169 | object-fit: cover; 170 | `; 171 | 172 | // 设置图片源 173 | thumbnail.src = `/view?filename=${encodeURIComponent(filename)}&type=input&subfolder=`; 174 | thumbnail.alt = filename; 175 | 176 | // 处理图片加载失败 177 | thumbnail.onerror = function() { 178 | thumbnailContainer.innerHTML = ` 179 |
预览
185 | `; 186 | }; 187 | 188 | thumbnailContainer.appendChild(thumbnail); 189 | 190 | // 创建文件名标签 191 | const textLabel = document.createElement('span'); 192 | 193 | // 截断长文件名 - 保留前10位和后10位文件名及扩展名 194 | let displayName = originalText; 195 | if (displayName.length > 35) { 196 | // 保留文件扩展名 197 | const lastDotIndex = displayName.lastIndexOf('.'); 198 | if (lastDotIndex > 0) { 199 | const name = displayName.substring(0, lastDotIndex); 200 | const extension = displayName.substring(lastDotIndex); 201 | if (name.length > 20) { 202 | // 保留前10位 + ... + 后10位 + 扩展名 203 | const firstPart = name.substring(0, 10); 204 | const lastPart = name.substring(name.length - 10); 205 | displayName = firstPart + '...' + lastPart + extension; 206 | } 207 | } else { 208 | // 没有扩展名的情况,保留前10位和后10位 209 | if (displayName.length > 20) { 210 | const firstPart = displayName.substring(0, 10); 211 | const lastPart = displayName.substring(displayName.length - 10); 212 | displayName = firstPart + '...' + lastPart; 213 | } 214 | } 215 | } 216 | 217 | textLabel.textContent = displayName; 218 | textLabel.title = originalText; // 悬停时显示完整文件名 219 | textLabel.style.cssText = ` 220 | color: inherit; 221 | font-size: inherit; 222 | overflow: hidden; 223 | text-overflow: ellipsis; 224 | white-space: nowrap; 225 | flex: 1; 226 | cursor: pointer; 227 | max-width: 280px; 228 | min-width: 0; 229 | `; 230 | 231 | // 创建删除按钮 - 扩大点击范围 232 | const deleteButton = document.createElement('button'); 233 | deleteButton.innerHTML = '✕'; 234 | deleteButton.title = `删除 ${filename}`; 235 | deleteButton.style.cssText = ` 236 | width: 28px; 237 | height: 28px; 238 | border: none; 239 | background: transparent; 240 | color: #888; 241 | border-radius: 4px; 242 | cursor: pointer; 243 | font-size: 14px; 244 | font-weight: bold; 245 | margin-left: 8px; 246 | flex-shrink: 0; 247 | display: flex; 248 | align-items: center; 249 | justify-content: center; 250 | opacity: 0.7; 251 | transition: all 0.15s; 252 | padding: 0; 253 | `; 254 | 255 | 
// 删除按钮悬停效果 - 更明显的反馈 256 | deleteButton.addEventListener('mouseenter', () => { 257 | deleteButton.style.opacity = '1'; 258 | deleteButton.style.color = '#fff'; 259 | deleteButton.style.background = 'rgba(255, 255, 255, 0.15)'; 260 | deleteButton.style.transform = 'scale(1.05)'; 261 | }); 262 | 263 | deleteButton.addEventListener('mouseleave', () => { 264 | deleteButton.style.opacity = '0.7'; 265 | deleteButton.style.color = '#888'; 266 | deleteButton.style.background = 'transparent'; 267 | deleteButton.style.transform = 'scale(1)'; 268 | }); 269 | 270 | // 删除按钮点击事件 - 快速删除,无动画延迟 271 | deleteButton.addEventListener('click', async (e) => { 272 | e.stopPropagation(); // 阻止触发菜单项选择 273 | e.preventDefault(); 274 | 275 | // 立即显示删除中状态 276 | deleteButton.innerHTML = '⋯'; 277 | deleteButton.style.pointerEvents = 'none'; 278 | deleteButton.style.opacity = '0.5'; 279 | 280 | // 直接执行删除操作,无确认弹窗 281 | const deleted = await deleteImageFile(filename); 282 | 283 | if (deleted) { 284 | // 立即移除菜单项,无动画延迟 285 | if (menuItem.parentNode) { 286 | menuItem.parentNode.removeChild(menuItem); 287 | } 288 | 289 | // 更新节点的文件列表 290 | const imageWidget = node.widgets.find(w => w.name === 'image'); 291 | if (imageWidget) { 292 | // 重新获取文件列表 293 | const fileList = await getInputFileList(); 294 | imageWidget.options.values = fileList; 295 | 296 | // 如果删除的是当前选中的文件,选择第一个可用文件 297 | if (imageWidget.value === filename) { 298 | imageWidget.value = fileList.length > 0 ? 
fileList[0] : ''; 299 | 300 | // Fire the callback to refresh the preview 301 | if (typeof imageWidget.callback === "function") { 302 | imageWidget.callback(imageWidget.value); 303 | } 304 | } 305 | 306 | // Redraw the canvas 307 | app.graph.setDirtyCanvas(true); 308 | } 309 | 310 | // Close the menu when no entries remain 311 | const remainingItems = contextMenu.root.querySelectorAll('.litemenu-entry'); 312 | if (remainingItems.length === 0) { 313 | contextMenu.close(); 314 | } 315 | } else { 316 | // Deletion failed: restore the button state 317 | deleteButton.innerHTML = '✕'; 318 | deleteButton.style.pointerEvents = 'auto'; 319 | deleteButton.style.opacity = '0.7'; 320 | } 321 | }); 322 | 323 | // Build the clickable area (everything except the delete button) 324 | const clickableArea = document.createElement('div'); 325 | clickableArea.style.cssText = ` 326 | display: flex; 327 | align-items: center; 328 | flex: 1; 329 | cursor: pointer; 330 | `; 331 | 332 | clickableArea.appendChild(thumbnailContainer); 333 | clickableArea.appendChild(textLabel); 334 | 335 | // Select the file when the clickable area is clicked 336 | clickableArea.addEventListener('click', () => { 337 | // Mimic the original menu-entry click 338 | const imageWidget = node.widgets.find(w => w.name === 'image'); 339 | if (imageWidget) { 340 | imageWidget.value = filename; 341 | 342 | // Fire the callback 343 | if (typeof imageWidget.callback === "function") { 344 | imageWidget.callback(filename); 345 | } 346 | 347 | // Redraw the canvas 348 | app.graph.setDirtyCanvas(true); 349 | } 350 | 351 | // Close the menu 352 | contextMenu.close(); 353 | }); 354 | 355 | // Assemble the menu entry 356 | menuItem.appendChild(clickableArea); 357 | menuItem.appendChild(deleteButton); 358 | 359 | // Drop the original click handler; the custom handlers above replace it 360 | menuItem.onclick = null; 361 | } 362 | 363 | app.registerExtension({ 364 | name: "Comfy.LG.LoadImageButtons", 365 | 366 | init() { 367 | // Extend ContextMenu with thumbnail and delete support. NOTE(review): the "Comfy.LG.LoadImage_V2" extension calls this helper from its init() too - confirm the helper is safe to run twice 368 | extendContextMenuForThumbnails(); 369 | }, 370 | 371 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 372 | if (nodeData.name !== "LG_LoadImage") return; 373 | 374 | const onNodeCreated = nodeType.prototype.onNodeCreated; 375 | 376 |
nodeType.prototype.onNodeCreated = function() { 377 | const result = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; 378 | 379 | // Build the refresh buttons with the multi-button widget 380 | const refreshWidget = this.addCustomWidget(MultiButtonWidget(app, "Refresh From", { 381 | labelWidth: 80, 382 | buttonSpacing: 4 383 | }, [ 384 | { 385 | text: "Temp", 386 | callback: () => { 387 | loadLatestImage(this, "temp").then(success => { 388 | if (success) { 389 | app.graph.setDirtyCanvas(true); 390 | } 391 | }); 392 | } 393 | }, 394 | { 395 | text: "Output", 396 | callback: () => { 397 | loadLatestImage(this, "output").then(success => { 398 | if (success) { 399 | app.graph.setDirtyCanvas(true); 400 | } 401 | }); 402 | } 403 | } 404 | ])); 405 | refreshWidget.serialize = false; 406 | 407 | return result; 408 | }; 409 | } 410 | }); 411 | 412 | app.registerExtension({ 413 | name: "Comfy.LG.LoadImage_V2", 414 | 415 | init() { 416 | // Extend ContextMenu with thumbnail and delete support. NOTE(review): the "Comfy.LG.LoadImageButtons" extension calls this helper from its init() too - confirm the helper is safe to run twice 417 | extendContextMenuForThumbnails(); 418 | }, 419 | 420 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 421 | if (nodeData.name !== "LG_LoadImage_V2") return; 422 | 423 | // Keep the node's stock behavior; no extra front-end features are added 424 | // auto_refresh is now handled entirely by the backend 425 | } 426 | }); -------------------------------------------------------------------------------- /web/image_cropper.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../scripts/app.js"; 2 | import { api } from "../../scripts/api.js"; 3 | 4 | // Build the HTML structure of the crop modal 5 | function createCropperModal() { 6 | const modal = document.createElement("dialog"); 7 | modal.id = "image-cropper-modal"; 8 | modal.innerHTML = ` 9 |
10 |
11 |

图像裁剪

12 | 13 |
14 |
15 |
16 | 17 |
18 |
19 |
20 | 21 | 22 |
23 |
24 |
25 | `; 26 | document.body.appendChild(modal); 27 | return modal; 28 | } 29 | 30 | // 修改样式 31 | const style = document.createElement("style"); 32 | style.textContent = ` 33 | #image-cropper-modal { 34 | border: none; 35 | border-radius: 8px; 36 | padding: 0; 37 | background: #2a2a2a; 38 | max-width: 90vw; /* 限制最大宽度 */ 39 | max-height: 90vh; /* 限制最大高度 */ 40 | } 41 | 42 | .cropper-container { 43 | width: fit-content; /* 根据内容自适应 */ 44 | height: fit-content; 45 | min-width: 400px; /* 减小最小尺寸 */ 46 | min-height: 300px; 47 | display: flex; 48 | flex-direction: column; 49 | } 50 | 51 | .cropper-header { 52 | display: flex; 53 | justify-content: space-between; 54 | align-items: center; 55 | padding: 10px 20px; 56 | background: #333; 57 | border-bottom: 1px solid #444; 58 | } 59 | 60 | .cropper-header h3 { 61 | margin: 0; 62 | color: #fff; 63 | } 64 | 65 | .close-button { 66 | background: none; 67 | border: none; 68 | color: #fff; 69 | font-size: 24px; 70 | cursor: pointer; 71 | } 72 | 73 | .cropper-content { 74 | padding: 20px; 75 | display: flex; 76 | flex-direction: column; 77 | gap: 20px; 78 | overflow: auto; /* 添加滚动条 */ 79 | } 80 | 81 | .cropper-wrapper { 82 | position: relative; 83 | overflow: hidden; 84 | background: #1a1a1a; 85 | display: flex; /* 使用flex布局 */ 86 | justify-content: center; 87 | align-items: center; 88 | } 89 | 90 | #crop-canvas { 91 | max-width: 100%; 92 | max-height: 70vh; 93 | object-fit: contain; /* 保持比例 */ 94 | } 95 | 96 | .crop-selection { 97 | position: absolute; 98 | border: 2px solid #00ff00; 99 | background: rgba(0, 255, 0, 0.1); 100 | pointer-events: none; 101 | transform-origin: 0 0; /* 添加变换原点 */ 102 | } 103 | 104 | .cropper-controls { 105 | display: flex; 106 | gap: 10px; 107 | justify-content: flex-end; 108 | } 109 | 110 | .cropper-controls button { 111 | padding: 8px 16px; 112 | border: none; 113 | border-radius: 4px; 114 | cursor: pointer; 115 | } 116 | 117 | #apply-crop { 118 | background: #2a8af6; 119 | color: white; 120 | } 121 | 122 
| #cancel-crop { 123 | background: #666; 124 | color: white; 125 | } 126 | `; 127 | document.head.appendChild(style); 128 | 129 | // 裁剪功能类 130 | class ImageCropper { 131 | constructor() { 132 | this.modal = createCropperModal(); 133 | this.canvas = this.modal.querySelector("#crop-canvas"); 134 | this.ctx = this.canvas.getContext("2d"); 135 | this.selection = this.modal.querySelector(".crop-selection"); 136 | 137 | this.isDrawing = false; 138 | this.startX = 0; 139 | this.startY = 0; 140 | 141 | this.hasFixedSeed = false; 142 | 143 | this.setupEventListeners(); 144 | } 145 | 146 | setupEventListeners() { 147 | // 关闭按钮事件 148 | const closeButton = this.modal.querySelector(".close-button"); 149 | if (closeButton) { 150 | closeButton.addEventListener("click", () => { 151 | this.cleanupAndClose(true); // true 表示是取消操作 152 | }); 153 | } 154 | 155 | // 取消按钮事件 156 | const cancelButton = this.modal.querySelector("#cancel-crop"); 157 | if (cancelButton) { 158 | cancelButton.addEventListener("click", () => { 159 | this.cleanupAndClose(true); // true 表示是取消操作 160 | }); 161 | } 162 | 163 | // 应用裁剪按钮事件 164 | const applyButton = this.modal.querySelector("#apply-crop"); 165 | if (applyButton) { 166 | applyButton.addEventListener("click", () => this.applyCrop()); 167 | } 168 | 169 | // ESC键关闭 170 | this.modal.addEventListener("keydown", (e) => { 171 | if (e.key === "Escape") { 172 | this.cleanupAndClose(true); // true 表示是取消操作 173 | } 174 | }); 175 | 176 | // 画布鼠标事件 177 | this.canvas.addEventListener("mousedown", (e) => this.startDrawing(e)); 178 | this.canvas.addEventListener("mousemove", (e) => this.draw(e)); 179 | this.canvas.addEventListener("mouseup", () => this.endDrawing()); 180 | 181 | // 添加调试日志 182 | console.log("事件监听器已设置:", { 183 | closeButton: !!closeButton, 184 | cancelButton: !!cancelButton, 185 | applyButton: !!applyButton 186 | }); 187 | } 188 | 189 | async cleanupAndClose(cancelled = false) { 190 | // 如果是取消操作,通知后端 191 | if (cancelled && this.currentNodeId) { 192 | try { 
193 | await api.fetchApi("/image_cropper/cancel", { 194 | method: "POST", 195 | headers: { 196 | "Content-Type": "application/json", 197 | }, 198 | body: JSON.stringify({ 199 | node_id: this.currentNodeId 200 | }) 201 | }); 202 | } catch (error) { 203 | console.error("发送取消信号失败:", error); 204 | } 205 | } 206 | 207 | // 清理选择框 208 | if (this.selection) { 209 | this.selection.style.display = 'none'; 210 | this.selection.style.width = '0'; 211 | this.selection.style.height = '0'; 212 | } 213 | 214 | // 清理画布 215 | if (this.ctx) { 216 | this.ctx.clearRect(0, 0, this.canvas.width, this.canvas.height); 217 | } 218 | 219 | // 重置状态 220 | this.isDrawing = false; 221 | this.startX = 0; 222 | this.startY = 0; 223 | 224 | // 关闭窗口 225 | this.modal.close(); 226 | } 227 | 228 | startDrawing(e) { 229 | const rect = this.canvas.getBoundingClientRect(); 230 | this.isDrawing = true; 231 | this.startX = e.clientX - rect.left; 232 | this.startY = e.clientY - rect.top; 233 | 234 | this.selection.style.left = `${this.startX}px`; 235 | this.selection.style.top = `${this.startY}px`; 236 | this.selection.style.width = "0px"; 237 | this.selection.style.height = "0px"; 238 | this.selection.style.display = "block"; 239 | } 240 | 241 | draw(e) { 242 | if (!this.isDrawing) return; 243 | 244 | const rect = this.canvas.getBoundingClientRect(); 245 | const currentX = e.clientX - rect.left; 246 | const currentY = e.clientY - rect.top; 247 | 248 | const width = currentX - this.startX; 249 | const height = currentY - this.startY; 250 | 251 | this.selection.style.width = `${Math.abs(width)}px`; 252 | this.selection.style.height = `${Math.abs(height)}px`; 253 | this.selection.style.left = `${width < 0 ? currentX : this.startX}px`; 254 | this.selection.style.top = `${height < 0 ? 
currentY : this.startY}px`; 255 | } 256 | 257 | endDrawing() { 258 | this.isDrawing = false; 259 | } 260 | 261 | calculateScale() { 262 | const canvas = this.canvas; 263 | const wrapper = canvas.parentElement; 264 | 265 | // 获取实际显示尺寸 266 | const displayRect = canvas.getBoundingClientRect(); 267 | 268 | // 计算缩放比例 269 | this.scaleX = canvas.width / displayRect.width; 270 | this.scaleY = canvas.height / displayRect.height; 271 | 272 | console.log("缩放比例:", { 273 | scaleX: this.scaleX, 274 | scaleY: this.scaleY, 275 | canvasWidth: canvas.width, 276 | canvasHeight: canvas.height, 277 | displayWidth: displayRect.width, 278 | displayHeight: displayRect.height 279 | }); 280 | } 281 | 282 | async applyCrop() { 283 | // 检查是否有选择区域 284 | if (!this.selection || 285 | !this.selection.style.width || 286 | !this.selection.style.height || 287 | parseInt(this.selection.style.width) <= 0 || 288 | parseInt(this.selection.style.height) <= 0) { 289 | console.warn("未选择有效的裁剪区域"); 290 | this.cleanupAndClose(); 291 | return; 292 | } 293 | 294 | const rect = this.selection.getBoundingClientRect(); 295 | const canvasRect = this.canvas.getBoundingClientRect(); 296 | 297 | // 计算实际坐标(考虑缩放) 298 | let x = (rect.left - canvasRect.left) * this.scaleX; 299 | let y = (rect.top - canvasRect.top) * this.scaleY; 300 | let width = rect.width * this.scaleX; 301 | let height = rect.height * this.scaleY; 302 | 303 | console.log("裁剪参数:", { 304 | x, y, width, height, 305 | originalRect: rect, 306 | canvasRect: canvasRect 307 | }); 308 | 309 | // 确保坐标和尺寸在有效范围内 310 | x = Math.max(0, Math.min(x, this.canvas.width)); 311 | y = Math.max(0, Math.min(y, this.canvas.height)); 312 | width = Math.min(width, this.canvas.width - x); 313 | height = Math.min(height, this.canvas.height - y); 314 | 315 | // 检查最终尺寸是否有效 316 | if (width <= 0 || height <= 0) { 317 | console.error("裁剪区域无效"); 318 | this.cleanupAndClose(); 319 | return; 320 | } 321 | 322 | try { 323 | // 创建临时画布进行裁剪 324 | const tempCanvas = 
document.createElement("canvas"); 325 | tempCanvas.width = width; 326 | tempCanvas.height = height; 327 | const tempCtx = tempCanvas.getContext("2d"); 328 | 329 | // 添加错误处理 330 | try { 331 | tempCtx.drawImage(this.canvas, 332 | x, y, width, height, 333 | 0, 0, width, height 334 | ); 335 | } catch (drawError) { 336 | console.error("裁剪绘制失败:", drawError); 337 | this.cleanupAndClose(); 338 | return; 339 | } 340 | 341 | // 转换为base64 342 | let croppedImage; 343 | try { 344 | croppedImage = tempCanvas.toDataURL("image/png"); 345 | } catch (dataUrlError) { 346 | console.error("图像转换失败:", dataUrlError); 347 | this.cleanupAndClose(); 348 | return; 349 | } 350 | 351 | console.log("准备发送请求,参数:", { 352 | node_id: this.currentNodeId, 353 | width: Math.round(width), 354 | height: Math.round(height), 355 | imageLength: croppedImage.length 356 | }); 357 | 358 | // 简化请求处理 359 | await api.fetchApi("/image_cropper/apply", { 360 | method: "POST", 361 | headers: { 362 | "Content-Type": "application/json", 363 | }, 364 | body: JSON.stringify({ 365 | node_id: this.currentNodeId, 366 | width: Math.round(width), 367 | height: Math.round(height), 368 | cropped_data_base64: croppedImage, 369 | }) 370 | }); 371 | 372 | // 简单关闭窗口 373 | this.cleanupAndClose(); 374 | 375 | } catch (error) { 376 | console.error("裁剪操作失败:", error); 377 | this.cleanupAndClose(); 378 | } 379 | } 380 | 381 | show(nodeId, imageData, node) { 382 | this.currentNodeId = nodeId; 383 | this.currentNode = node; 384 | 385 | const img = new Image(); 386 | img.onload = () => { 387 | this.canvas.width = img.width; 388 | this.canvas.height = img.height; 389 | this.ctx.drawImage(img, 0, 0); 390 | this.modal.showModal(); 391 | 392 | // 计算初始缩放比例 393 | this.calculateScale(); 394 | }; 395 | img.src = imageData; 396 | } 397 | } 398 | 399 | // 注册节点 400 | app.registerExtension({ 401 | name: "Comfy.ImageCropper", 402 | async setup() { 403 | const cropper = new ImageCropper(); 404 | 405 | // 监听裁剪更新事件 406 | 
api.addEventListener("image_cropper_update", ({ detail }) => { 407 | const { node_id, image_data } = detail; 408 | const node = app.graph.getNodeById(node_id); 409 | cropper.show(node_id, image_data, node); 410 | }); 411 | }, 412 | 413 | async beforeRegisterNodeDef(nodeType, nodeData) { 414 | // 只处理 ImageCropper 节点 415 | if (nodeData.name === "ImageCropper") { 416 | const onNodeCreated = nodeType.prototype.onNodeCreated; 417 | 418 | // 重写节点创建方法 419 | nodeType.prototype.onNodeCreated = function() { 420 | if (onNodeCreated) { 421 | onNodeCreated.apply(this, arguments); 422 | } 423 | 424 | // 创建种子值组件 425 | const seedWidget = this.addWidget( 426 | "number", 427 | "seed", 428 | 0, 429 | (value) => { 430 | this.seed = value; 431 | }, 432 | { 433 | min: 0, 434 | max: Number.MAX_SAFE_INTEGER, 435 | step: 1, 436 | precision: 0 437 | } 438 | ); 439 | 440 | // 创建种子模式控制组件 441 | const seed_modeWidget = this.addWidget( 442 | "combo", 443 | "seed_mode", 444 | "randomize", 445 | () => {}, 446 | { 447 | values: ["fixed", "increment", "decrement", "randomize"], 448 | serialize: false 449 | } 450 | ); 451 | 452 | // 添加控制逻辑 - 自动运行时的行为 453 | seed_modeWidget.beforeQueued = () => { 454 | const mode = seed_modeWidget.value; 455 | let newValue = seedWidget.value; 456 | 457 | if (mode === "randomize") { 458 | // 随机模式:每次执行都随机化 459 | newValue = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER); 460 | } else if (mode === "increment") { 461 | // 递增模式:每次+1 462 | newValue += 1; 463 | } else if (mode === "decrement") { 464 | // 递减模式:每次-1 465 | newValue -= 1; 466 | } else if (mode === "fixed") { 467 | // fixed模式:如果还没有固定种子,则生成一次,然后保持不变 468 | if (!this.hasFixedSeed) { 469 | newValue = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER); 470 | this.hasFixedSeed = true; 471 | } 472 | // 已经有固定种子时,不做任何改变 473 | } 474 | 475 | seedWidget.value = newValue; 476 | this.seed = newValue; 477 | }; 478 | 479 | // 模式变更时重置fixed标志 480 | seed_modeWidget.callback = (value) => { 481 | if (value !== "fixed") { 482 | // 
如果切换到非fixed模式,重置标志 483 | this.hasFixedSeed = false; 484 | } 485 | }; 486 | 487 | // 创建更新按钮 488 | const updateButton = this.addWidget("button", "更新种子", null, () => { 489 | const mode = seed_modeWidget.value; 490 | let newValue = seedWidget.value; 491 | 492 | if (mode === "randomize") { 493 | newValue = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER); 494 | } else if (mode === "increment") { 495 | newValue += 1; 496 | } else if (mode === "decrement") { 497 | newValue -= 1; 498 | } else if (mode === "fixed") { 499 | // fixed模式下点击按钮也更新种子,并重置标志 500 | newValue = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER); 501 | this.hasFixedSeed = true; // 标记为已设置固定种子 502 | } 503 | 504 | seedWidget.value = newValue; 505 | seedWidget.callback(newValue); 506 | 507 | }); 508 | }; 509 | } 510 | } 511 | }); 512 | --------------------------------------------------------------------------------