├── .gitignore
├── py
├── inspyrenet
│ ├── __init__.py
│ ├── config.yaml
│ ├── modules
│ │ ├── decoder_module.py
│ │ ├── context_module.py
│ │ ├── attention_module.py
│ │ └── layers.py
│ ├── InSPyReNet.py
│ └── utils.py
├── md.py
├── combo_setter.py
├── inspynet.py
├── color_adjustment.py
├── switch.py
├── image_size_adjustment.py
├── image_cropper.py
├── image_selector.py
└── bridge_preview.py
├── requirements.txt
├── assets
├── crop.jpg
├── size.jpg
├── noise.jpg
├── switch.jpg
├── refresh.png
├── FastCanvas.png
├── color_adjust.jpg
└── CachePreviewBridge.png
├── pyproject.toml
├── .github
└── workflows
│ └── publish.yml
├── web
├── pip_manager.js
├── install_dependencies.js
├── lib
│ └── fabric-slim.min.js
├── counter.js
├── Util.js
├── multi_button_widget.js
├── PB.js
├── image_loader_counter.js
├── combo_setter.js
├── queue_shortcut.js
├── bridge_preview.js
├── lg_group_muter.js
├── image_size_adjustment.js
├── color_adjustment.js
├── upload.js
└── image_cropper.js
├── LICENSE.txt
├── __init__.py
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/py/inspyrenet/__init__.py:
--------------------------------------------------------------------------------
1 | from .Remover import Remover, console
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-contrib-python
2 | transparent-background
3 |
--------------------------------------------------------------------------------
/assets/crop.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/crop.jpg
--------------------------------------------------------------------------------
/assets/size.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/size.jpg
--------------------------------------------------------------------------------
/assets/noise.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/noise.jpg
--------------------------------------------------------------------------------
/assets/switch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/switch.jpg
--------------------------------------------------------------------------------
/assets/refresh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/refresh.png
--------------------------------------------------------------------------------
/assets/FastCanvas.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/FastCanvas.png
--------------------------------------------------------------------------------
/assets/color_adjust.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/color_adjust.jpg
--------------------------------------------------------------------------------
/assets/CachePreviewBridge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LAOGOU-666/Comfyui_LG_Tools/HEAD/assets/CachePreviewBridge.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyui_lg_tools"
3 | description = "This is a toolset designed for ComfyUI by LAOGOU-666, providing a series of practical image processing and operation nodes, making our operation more intuitive and convenient"
4 | version = "1.3.7"
5 | license = { text = "MIT License" }
6 | dependencies = ["opencv-contrib-python"]
7 |
8 | [project.urls]
9 | Repository = "https://github.com/LAOGOU-666/Comfyui_LG_Tools"
10 |
11 | [tool.comfy]
12 | PublisherId = "laogou666"
13 | DisplayName = "Comfyui_LG_Tools"
14 | Icon = ""
15 |
--------------------------------------------------------------------------------
/py/inspyrenet/config.yaml:
--------------------------------------------------------------------------------
1 | base:
2 | url: "https://github.com/plemeri/transparent-background/releases/download/1.2.12/ckpt_base.pth"
3 | md5: "d692e3dd5fa1b9658949d452bebf1cda"
4 | ckpt_name: "ckpt_base.pth"
5 | http_proxy: NULL
6 | base_size: [1024, 1024]
7 |
8 |
9 | fast:
10 | url: "https://github.com/plemeri/transparent-background/releases/download/1.2.12/ckpt_fast.pth"
11 | md5: "9efdbfbcc49b79ef0f7891c83d2fd52f"
12 | ckpt_name: "ckpt_fast.pth"
13 | http_proxy: NULL
14 | base_size: [384, 384]
15 |
16 | base-nightly:
17 | url: "https://github.com/plemeri/transparent-background/releases/download/1.2.12/ckpt_base_nightly.pth"
18 | md5: NULL
19 | ckpt_name: "ckpt_base_nightly.pth"
20 | http_proxy: NULL
21 | base_size: [1024, 1024]
22 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | - "pyproject.toml"
9 |
10 | permissions:
11 | issues: write
12 |
13 | jobs:
14 | publish-node:
15 | name: Publish Custom Node to registry
16 | runs-on: ubuntu-latest
17 | if: ${{ github.repository_owner == 'LAOGOU-666' }}
18 | steps:
19 | - name: Check out code
20 | uses: actions/checkout@v4
21 | with:
22 | submodules: true
23 | - name: Publish Custom Node
24 | uses: Comfy-Org/publish-node-action@v1
25 | with:
26 | ## Add your own personal access token to your Github Repository secrets and reference it here.
27 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
28 |
--------------------------------------------------------------------------------
/py/md.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import numpy as np
4 | import cv2
5 | import os
6 | import sys
7 | import folder_paths
8 | import base64
9 | import io
10 | import traceback
11 | import subprocess
12 | import hashlib
13 | try:
14 | import torchvision.transforms.v2 as T
15 | except ImportError:
16 | import torchvision.transforms as T
17 | from aiohttp import web
18 | from PIL import Image, ImageOps
19 | from io import BytesIO
20 | from threading import Event
21 | from nodes import LoadImage, PreviewImage
22 | from server import PromptServer
23 | routes = PromptServer.instance.routes
24 |
class AlwaysEqualProxy(str):
    """String subclass that compares equal to every other object.

    ``==`` always evaluates to ``True`` and ``!=`` always evaluates to
    ``False``, whatever the other operand is.
    """

    def __eq__(self, other):
        return True

    def __ne__(self, other):
        return False
31 |
class AnyType(str):
    """Special marker for "any type": type comparisons against it always match."""

    def __eq__(self, other) -> bool:
        return True

    def __ne__(self, other: object) -> bool:
        return False


# Shared wildcard instance. NOTE: this name shadows the builtin ``any`` in
# this module (and in modules that star-import it).
any = AnyType("*")
--------------------------------------------------------------------------------
/web/pip_manager.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../scripts/app.js";
2 | import { TerminalManager } from "./Util.js";
3 |
// One TerminalManager shared by every LG_PipManager node instance.
// (TerminalManager lives in ./Util.js; its behaviour is not visible here.)
const terminal = new TerminalManager("/lg/pip_manager", "LG_PipManager");

app.registerExtension({
    name: "LG.PipManager",

    /**
     * Hook the LG_PipManager node type so that each instance is handed to the
     * shared terminal manager on creation and on every foreground draw.
     */
    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeData.name !== "LG_PipManager") {
            return;
        }

        const onNodeCreated = nodeType.prototype.onNodeCreated;
        nodeType.prototype.onNodeCreated = async function () {
            // Forward the caller's arguments to the handler being wrapped
            // instead of silently dropping them.
            const me = onNodeCreated?.apply(this, arguments);
            terminal.setupNode(this);
            return me;
        };

        const onDrawForeground = nodeType.prototype.onDrawForeground;
        nodeType.prototype.onDrawForeground = function (ctx, graphcanvas) {
            // The manager receives the previous handler and decides how to chain it.
            return terminal.updateNode(this, onDrawForeground, ctx, graphcanvas);
        };
    },
});
--------------------------------------------------------------------------------
/web/install_dependencies.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../scripts/app.js";
2 | import { TerminalManager } from "./Util.js";
3 |
// One TerminalManager shared by every LG_InstallDependencies node instance.
// (TerminalManager lives in ./Util.js; its behaviour is not visible here.)
const terminal = new TerminalManager("/lg/install_dependencies", "LG_InstallDependencies");

app.registerExtension({
    name: "LG.InstallDependencies",

    /**
     * Hook the LG_InstallDependencies node type so that each instance is
     * handed to the shared terminal manager on creation and on every
     * foreground draw.
     */
    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeData.name !== "LG_InstallDependencies") {
            return;
        }

        const onNodeCreated = nodeType.prototype.onNodeCreated;
        nodeType.prototype.onNodeCreated = async function () {
            // Forward the caller's arguments to the handler being wrapped
            // instead of silently dropping them.
            const me = onNodeCreated?.apply(this, arguments);
            terminal.setupNode(this);
            return me;
        };

        const onDrawForeground = nodeType.prototype.onDrawForeground;
        nodeType.prototype.onDrawForeground = function (ctx, graphcanvas) {
            // The manager receives the previous handler and decides how to chain it.
            return terminal.updateNode(this, onDrawForeground, ctx, graphcanvas);
        };
    },
});
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 LAOGOU-666
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/py/inspyrenet/modules/decoder_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .layers import *
class PAA_d(nn.Module):
    """Decoder block that fuses several feature maps and predicts an output map.

    The incoming feature maps are resized to the spatial size of the first
    one, concatenated along the channel axis, passed through ``conv1``, then
    through an 'h' and a 'w' ``SelfAttention`` whose outputs are summed, and
    finally through ``conv2``..``conv5``.

    Args:
        in_channel: Channel count of the concatenated input features.
        out_channel: Channel count produced by the final convolution.
        depth: Channel count used by the intermediate convolutions.
        base_size: Optional ``(height, width)`` pair; together with ``stage``
            it determines the sizes handed to the attention modules.
        stage: Optional stage index; the attention sizes are
            ``base_size // 2 ** stage``.

    ``Conv2d`` and ``SelfAttention`` come from ``.layers``.
    """

    def __init__(self, in_channel, out_channel=1, depth=64, base_size=None, stage=None):
        super(PAA_d, self).__init__()
        self.conv1 = Conv2d(in_channel, depth, 3)
        self.conv2 = Conv2d(depth, depth, 3)
        self.conv3 = Conv2d(depth, depth, 3)
        self.conv4 = Conv2d(depth, depth, 3)
        self.conv5 = Conv2d(depth, out_channel, 3, bn=False)

        self.base_size = base_size
        self.stage = stage

        if base_size is not None and stage is not None:
            self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage))
        else:
            self.stage_size = [None, None]

        self.Hattn = SelfAttention(depth, 'h', self.stage_size[0])
        self.Wattn = SelfAttention(depth, 'w', self.stage_size[1])

        self.upsample = lambda img, size: F.interpolate(img, size=size, mode='bilinear', align_corners=True)

    def forward(self, fs):  # f3 f4 f5 -> f3 f2 f1
        """Fuse the feature maps in ``fs`` and return ``(features, prediction)``.

        Args:
            fs: Sequence of feature tensors. Every tensor after the first is
                resized to the spatial size of the first one.

        Returns:
            A ``(fx, out)`` tuple where ``fx`` is the output of ``conv4`` and
            ``out`` is ``conv5(fx)``.
        """
        target_size = fs[0].shape[-2:]
        # Build a new list rather than overwriting the caller's sequence in
        # place; this also lets ``fs`` be a tuple.
        resized = [fs[0]] + [self.upsample(f, target_size) for f in fs[1:]]
        # Concatenate in reverse order, as the original layout expects.
        fx = torch.cat(resized[::-1], dim=1)

        fx = self.conv1(fx)

        Hfx = self.Hattn(fx)
        Wfx = self.Wattn(fx)

        fx = self.conv2(Hfx + Wfx)
        fx = self.conv3(fx)
        fx = self.conv4(fx)
        out = self.conv5(fx)

        return fx, out
--------------------------------------------------------------------------------
/py/combo_setter.py:
--------------------------------------------------------------------------------
CATEGORY_TYPE = "🎈LAOGOU/Utils"


class ComboSetter:
    """Dynamic combo setter.

    Takes a block of labels and a block of prompts (one entry per line) and
    returns the prompt found at the same position as the selected label.
    """

    @classmethod
    def INPUT_TYPES(cls):
        labels_spec = ("STRING", {
            "default": "",
            "multiline": True,
            "placeholder": "每行一个标签",
        })
        prompts_spec = ("STRING", {
            "default": "",
            "multiline": True,
            "placeholder": "每行一个提示词",
        })
        selected_spec = ("STRING", {"default": ""})
        return {
            "required": {
                "labels": labels_spec,
                "prompts": prompts_spec,
                "selected": selected_spec,
            }
        }

    RETURN_TYPES = ("STRING", "STRING",)
    RETURN_NAMES = ("selected_label", "selected_prompt",)
    FUNCTION = "execute"
    CATEGORY = CATEGORY_TYPE

    def execute(self, labels, prompts, selected):
        """Return ``(selected, prompt)`` for the selected label.

        Blank lines are dropped from both inputs before matching. The prompt
        is an empty string when the label is unknown or has no prompt at the
        same position.
        """
        def non_blank_lines(text):
            stripped = (raw.strip() for raw in text.split('\n'))
            return [entry for entry in stripped if entry]

        label_lines = non_blank_lines(labels)
        prompt_lines = non_blank_lines(prompts)

        # Locate the selected label; an unknown label yields an empty prompt.
        try:
            position = label_lines.index(selected)
        except ValueError:
            return (selected, "")

        matched_prompt = prompt_lines[position] if position < len(prompt_lines) else ""
        return (selected, matched_prompt)
45 |
# Node type identifier -> implementing class. The package __init__ merges this
# dict into its own NODE_CLASS_MAPPINGS.
NODE_CLASS_MAPPINGS = {
    "ComboSetter": ComboSetter,
}

# Node type identifier -> display name.
NODE_DISPLAY_NAME_MAPPINGS = {
    "ComboSetter": "🎈ComboSetter",
}
53 |
54 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | import importlib.util
3 | import os
4 | import sys
5 | import json
6 |
7 | NODE_CLASS_MAPPINGS = {}
8 | NODE_DISPLAY_NAME_MAPPINGS = {}
9 | WEB_DIRECTORY = "web"
10 | python = sys.executable
11 |
def get_ext_dir(subpath=None, mkdir=False):
    """Return the absolute path of this extension's directory.

    Args:
        subpath: Optional path joined onto the directory containing this file.
        mkdir: When true, create the resulting directory if it does not exist.

    Returns:
        The absolute path as a string.
    """
    # Named ``path`` rather than ``dir`` so the builtin is not shadowed.
    path = os.path.dirname(__file__)
    if subpath is not None:
        path = os.path.join(path, subpath)

    path = os.path.abspath(path)

    if mkdir and not os.path.exists(path):
        # exist_ok covers the case where something else creates the directory
        # between the existence check and this call.
        os.makedirs(path, exist_ok=True)
    return path
22 |
def serialize(obj):
    """Return ``obj`` unchanged if it is a plain JSON-style value, else ``str(obj)``.

    Values of type ``str``, ``int``, ``float``, ``bool``, ``list``, ``dict``
    or ``None`` pass through; anything else is converted with ``str``.
    """
    passthrough_types = (str, int, float, bool, list, dict, type(None))
    return obj if isinstance(obj, passthrough_types) else str(obj)
27 |
28 |
# Import every ``*.py`` file found in the ``py`` sub-directory as a submodule
# of this package and merge the node mappings each one exposes.
py = get_ext_dir("py")
files = os.listdir(py)
all_nodes = {}
for file in files:
    if not file.endswith(".py"):
        continue
    name = os.path.splitext(file)[0]
    imported_module = importlib.import_module(".py.{}".format(name), __name__)
    try:
        NODE_CLASS_MAPPINGS = {**NODE_CLASS_MAPPINGS, **imported_module.NODE_CLASS_MAPPINGS}
        NODE_DISPLAY_NAME_MAPPINGS = {**NODE_DISPLAY_NAME_MAPPINGS, **imported_module.NODE_DISPLAY_NAME_MAPPINGS}
        serialized_CLASS_MAPPINGS = {k: serialize(v) for k, v in imported_module.NODE_CLASS_MAPPINGS.items()}
        serialized_DISPLAY_NAME_MAPPINGS = {k: serialize(v) for k, v in imported_module.NODE_DISPLAY_NAME_MAPPINGS.items()}
        all_nodes[file] = {"NODE_CLASS_MAPPINGS": serialized_CLASS_MAPPINGS, "NODE_DISPLAY_NAME_MAPPINGS": serialized_DISPLAY_NAME_MAPPINGS}
    except AttributeError:
        # Modules that do not define the mapping attributes (shared helpers)
        # are expected; skip them quietly.
        continue
    except Exception as exc:
        # Stay best-effort so one bad module cannot break the whole package,
        # but report the problem instead of hiding it. Unlike a bare
        # ``except``, this lets KeyboardInterrupt/SystemExit propagate.
        print("[Comfyui_LG_Tools] skipped node mappings from {}: {!r}".format(file, exc), file=sys.stderr)
45 |
46 |
47 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
48 |
--------------------------------------------------------------------------------
/web/lib/fabric-slim.min.js:
--------------------------------------------------------------------------------
/**
 * Fabric.js slim build (based on fabric-with-all-modules@4.6.1).
 * Minified version.
 *
 * Re-exports a hand-picked subset of the full `fabric` namespace imported
 * from ./fabric.js: selected top-level constants, a subset of `fabric.util`
 * helpers, and the core classes (Canvas, StaticCanvas, Object, Rect, Image,
 * Text, IText, Textbox, Shadow, ...). Anything not listed below is not
 * available on the exported object.
 */
import{fabric as t}from"./fabric.js";const fabric={};fabric.version=t.version,fabric.isTouchSupported=t.isTouchSupported,fabric.isLikelyNode=t.isLikelyNode,fabric.SHARED_ATTRIBUTES=t.SHARED_ATTRIBUTES,fabric.DPI=t.DPI,fabric.reNum=t.reNum,fabric.devicePixelRatio=t.devicePixelRatio,fabric.browserShadowBlurConstant=t.browserShadowBlurConstant,fabric.enableGLFiltering=t.enableGLFiltering,fabric.initFilterBackend=t.initFilterBackend,fabric.util={addClass:t.util.addClass,animate:t.util.animate,animateColor:t.util.animateColor,applyTransformToObject:t.util.applyTransformToObject,cos:t.util.cos,sin:t.util.sin,rotatePoint:t.util.rotatePoint,transformPoint:t.util.transformPoint,makeElement:t.util.makeElement,createCanvasElement:t.util.createCanvasElement,createImage:t.util.createImage,loadImage:t.util.loadImage,toFixed:t.util.toFixed,multiplyTransformMatrices:t.util.multiplyTransformMatrices,invertTransform:t.util.invertTransform,getRandomInt:t.util.getRandomInt,degreesToRadians:t.util.degreesToRadians,radiansToDegrees:t.util.radiansToDegrees,object:t.util.object,string:t.util.string,array:t.util.array,setImageSmoothing:t.util.setImageSmoothing,getPointer:t.util.getPointer,falseFunction:t.util.falseFunction,requestAnimFrame:t.util.requestAnimFrame,cancelAnimFrame:t.util.cancelAnimFrame},fabric.Observable=t.Observable,fabric.Collection=t.Collection,fabric.Point=t.Point,fabric.Intersection=t.Intersection,fabric.Color=t.Color,fabric.Object=t.Object,fabric.Canvas=t.Canvas,fabric.StaticCanvas=t.StaticCanvas,fabric.Rect=t.Rect,fabric.Image=t.Image,fabric.Text=t.Text,fabric.IText=t.IText,fabric.Textbox=t.Textbox,fabric.Shadow=t.Shadow;export{fabric};
--------------------------------------------------------------------------------
/py/inspyrenet/modules/context_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .layers import *
6 |
class PAA_kernel(nn.Module):
    """One branch of the context module.

    Applies a 1x1 convolution, a ``(1, receptive_size)`` and a
    ``(receptive_size, 1)`` convolution, sums an 'h' and a 'w'
    ``SelfAttention`` of the result, and finishes with a 3x3 convolution
    dilated by ``receptive_size``. ``Conv2d`` and ``SelfAttention`` come
    from ``.layers``.
    """

    def __init__(self, in_channel, out_channel, receptive_size, stage_size=None):
        super(PAA_kernel, self).__init__()
        self.conv0 = Conv2d(in_channel, out_channel, 1)
        self.conv1 = Conv2d(out_channel, out_channel, kernel_size=(1, receptive_size))
        self.conv2 = Conv2d(out_channel, out_channel, kernel_size=(receptive_size, 1))
        self.conv3 = Conv2d(out_channel, out_channel, 3, dilation=receptive_size)

        # Without a stage size both attention modules receive ``None``.
        if stage_size is None:
            height_size = None
            width_size = None
        else:
            height_size = stage_size[0]
            width_size = stage_size[1]
        self.Hattn = SelfAttention(out_channel, 'h', height_size)
        self.Wattn = SelfAttention(out_channel, 'w', width_size)

    def forward(self, x):
        """Run the branch on ``x`` and return the resulting tensor."""
        features = self.conv2(self.conv1(self.conv0(x)))

        along_h = self.Hattn(features)
        along_w = self.Wattn(features)

        return self.conv3(along_h + along_w)
27 |
class PAA_e(nn.Module):
    """Context module built from four parallel branches and a residual path.

    ``branch0`` is a 1x1 convolution; ``branch1``..``branch3`` are
    ``PAA_kernel`` instances with receptive sizes 3, 5 and 7. Their outputs
    are concatenated on the channel axis, reduced by ``conv_cat`` and added
    to ``conv_res(x)`` before a ReLU.
    """

    def __init__(self, in_channel, out_channel, base_size=None, stage=None):
        super(PAA_e, self).__init__()
        self.relu = nn.ReLU(True)

        # The stage size is only known when both base size and stage are given.
        have_size = base_size is not None and stage is not None
        if have_size:
            scale = 2 ** stage
            self.stage_size = (base_size[0] // scale, base_size[1] // scale)
        else:
            self.stage_size = None

        self.branch0 = Conv2d(in_channel, out_channel, 1)
        self.branch1 = PAA_kernel(in_channel, out_channel, 3, self.stage_size)
        self.branch2 = PAA_kernel(in_channel, out_channel, 5, self.stage_size)
        self.branch3 = PAA_kernel(in_channel, out_channel, 7, self.stage_size)

        self.conv_cat = Conv2d(4 * out_channel, out_channel, 3)
        self.conv_res = Conv2d(in_channel, out_channel, 1)

    def forward(self, x):
        """Return the fused, residual-connected features for ``x``."""
        branch_outputs = (
            self.branch0(x),
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
        )

        fused = self.conv_cat(torch.cat(branch_outputs, 1))
        return self.relu(fused + self.conv_res(x))
55 |
--------------------------------------------------------------------------------
/py/inspynet.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 | import numpy as np
4 | from .inspyrenet import Remover
5 | from tqdm import tqdm
6 |
7 |
def tensor2pil(image):
    """Convert a tensor to a PIL image.

    The tensor is moved to the CPU, squeezed, scaled by 255, clipped to
    ``[0, 255]`` and cast to ``uint8`` before being wrapped in a PIL image.
    """
    pixels = image.cpu().numpy().squeeze()
    pixels = np.clip(255. * pixels, 0, 255)
    return Image.fromarray(pixels.astype(np.uint8))
10 |
def pil2tensor(image):
    """Convert a PIL image to a float32 tensor with a leading axis of size 1.

    Pixel values are divided by 255.
    """
    scaled = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
class InspyrenetRembgLoader:
    """Node that builds a ``Remover`` model and outputs it as INSPYRENET_MODEL."""

    def __init__(self):
        self.model = None

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "mode": (["base", "fast"],),
            "torchscript_jit": (["default", "on"],),
        }
        return {"required": required}

    RETURN_TYPES = ("INSPYRENET_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "image"

    def load_model(self, mode, torchscript_jit):
        """Create the remover for ``mode``; JIT is enabled only when set to "on"."""
        self.model = Remover(mode=mode, jit=(torchscript_jit == "on"))
        return (self.model,)
35 |
class InspyrenetRembgProcess:
    """Node that removes the background of every image in a batch.

    Outputs the processed images and, as mask, the alpha channel returned by
    the model for each image.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("INSPYRENET_MODEL",),
            "image": ("IMAGE",),
            "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            "background_color": ("STRING", {"default": "", "multiline": False}),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "process_image"
    CATEGORY = "image"

    def process_image(self, model, image, threshold, background_color):
        """Process each image with ``model`` and return ``(images, masks)``.

        When ``background_color`` is non-blank the cut-out is composited over
        that colour and the image output is limited to its first three
        channels; otherwise the RGBA result is returned as is.
        """
        use_background = bool(background_color.strip())
        frames = []
        masks = []
        for frame in tqdm(image, "Inspyrenet Rembg"):
            cutout = model.process(tensor2pil(frame), type='rgba', threshold=threshold)

            # Capture the alpha channel as the mask before any compositing.
            masks.append(pil2tensor(cutout)[:, :, :, 3])

            if use_background:
                try:
                    backdrop = Image.new('RGBA', cutout.size, background_color)
                    cutout = Image.alpha_composite(backdrop, cutout)
                    # Switch to RGB, which drops the alpha channel.
                    cutout = cutout.convert('RGB')
                except ValueError:
                    print(f"无效的颜色值: {background_color},使用透明背景")

            frames.append(pil2tensor(cutout))

        img_stack = torch.cat(frames, dim=0)
        mask_stack = torch.cat(masks, dim=0)

        # With a background colour only the first three channels are returned.
        if use_background:
            img_stack = img_stack[:, :, :, :3]
        return (img_stack, mask_stack)
84 |
# Node type identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "InspyrenetRembgLoader": InspyrenetRembgLoader,
    "InspyrenetRembgProcess": InspyrenetRembgProcess
}
# Node type identifier -> display name.
NODE_DISPLAY_NAME_MAPPINGS = {
    "InspyrenetRembgLoader": "InSPyReNet Loader",
    "InspyrenetRembgProcess": "InSPyReNet Rembg"
}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Comfyui_LG_Tools
2 |
3 | 这是***LG_老狗的学习笔记***为ComfyUI设计的工具集,提供了一系列实用的图像处理和操作节点,让我们的操作变得更加直观方便
4 |
5 | ## 安装说明
6 |
7 | 1. 确保已安装ComfyUI
8 | 2. 将此仓库克隆到ComfyUI的`custom_nodes`目录下:
9 | ```bash
10 | cd ComfyUI/custom_nodes
11 | git clone https://github.com/LAOGOU-666/Comfyui_LG_Tools.git
12 | ```
13 |
14 | 3. 安装依赖:
15 | ```bash
16 | pip install -r requirements.txt
17 | ```
18 |
19 | ## 使用方法
20 |
21 | 1. 启动ComfyUI
22 | 2. 在右键添加节点中,您可以在"🎈LAOGOU"类别下找到所有工具节点
23 | 3. 将需要的节点拖入工作区并连接使用
24 |
25 | ## 节点说明
26 |
27 | ### 图像裁剪节点
28 | 
29 | - 点击左键框选指定范围进行裁剪
30 | ### 图像尺寸调整节点
31 | 
32 |
33 | ### 颜色调整节点
34 | 
35 |
36 | ### FastCanvas画布节点
37 | 
38 | > * 1.支持实时调整构图输出图像和选中图层遮罩
39 | >
40 | > * 2.支持批量构图,切换输入图层继承上个图层的位置和缩放
41 | >
42 | > * 3.支持限制画布窗口视图大小的功能,不用担心图片较大占地方了
43 | >
44 | > * 4.图层支持右键辅助功能,变换,设置背景等,点开有惊喜
45 | >
46 | > * 5.支持通过输入端口输入图像,
47 | 支持无输入端口独立使用,
48 | >
49 | > * 6.支持复制,拖拽,以及上传方式对图像处理
50 | >
51 | 注意!fastcanvas tool动态输入节点使用方法:
52 | * bg_img输入背景RGB图片,img输入图层图片,可以输入RGB/RGBA图片
53 | * **系统自带的加载图片节点默认输出的是RGB!不是RGBA(带遮罩通道的图片)!使用加载图像输入RGBA需要合并ALPHA图层!**
54 | ### 开关节点
55 | 
56 | - 一键控制单组/多组的忽略或者禁用模式
57 | - 惰性求值开关lazyswitch(仅运行指定线路,非指定线路无需加载)
58 | - 注意!点击开关节点右键有设置不同模式(忽略和禁用)的功能
59 |
60 | ### 噪波节点
61 | 
62 | - 添加自定义噪波以对图像进行预处理
63 |
64 | ### 桥接预览节点
65 | 
66 |
67 | > * 当你使用input模式将图片输入到节点后,点击Cache按钮即可缓存当前图片,然后进行编辑遮罩,并且不会出现遮罩被重置的问题,
68 | >
69 | > * 在点击Cache按钮后,无论输入端口是否连接,是否刷新,都不会影响当前缓存的图片和遮罩,你可以继续在当前节点编辑遮罩并且不会重置缓存
70 | >
71 | > * 现在支持复制功能,相当于加载图片节点和桥接预览节点的集合,对于需要重复操作以及大型工作流的缓存处理能提供很大便利
72 |
73 | ### 加载图像节点和桥接预览节点(新版本)
74 | 
75 | > * 新增刷新功能,一键将temp/output文件夹的最新图片刷新到当前节点,对于重复处理流程可以省去复制粘贴的操作
76 |
77 | ### 图片加载器(计数器)节点 🎈LG_ImageLoaderWithCounter
78 | 一个集成计数器功能的图片加载器节点,可从指定文件夹自动加载图片,**内置计数器无需外接**。
79 |
80 | #### 主要功能
81 | - **内置计数器**:自动获取文件夹图片总数,无需外接计数器节点
82 | - **三种加载模式**:
83 | - **increase**: 递增模式,每次加载下一张(0 → 1 → 2...)
84 | - **decrease**: 递减模式,每次加载上一张(9 → 8 → 7...)
85 | - **all**: 一次性加载所有图片为列表
86 | - **多种排序方式**:
87 | - Alphabetical (ASC/DESC) - 按字母表排序
88 | - Numerical (ASC/DESC) - 按数字排序
89 | - Datetime (ASC/DESC) - 按文件修改时间排序
90 | - **灵活的路径支持**:支持相对路径和绝对路径
91 | - **实时状态显示**:在节点标题栏实时显示当前索引(如:5/100)
92 | - **刷新功能**:节点内置刷新按钮,可重置计数器
93 |
94 | #### 输入参数
95 | - `folder_path` - 图片文件夹路径(支持相对路径和绝对路径)
96 | - `mode` - 加载模式
97 | - increase: 每次执行索引递增,自动循环
98 | - decrease: 每次执行索引递减,自动循环
99 | - all: 一次性加载文件夹内所有图片为列表
100 | - `sort_mode` - 排序模式,决定文件夹中图片的加载顺序
101 | - `keep_index` - 保持索引开关(Boolean)
102 | - True: 保持当前索引不变,暂停计数
103 | - False: 正常递增/递减(默认)
104 |
105 | #### 输出参数(列表格式)
106 | - `images` - 加载的图片列表 (IMAGE LIST)
107 | - `masks` - 图片遮罩列表 (MASK LIST)
108 | - `filenames` - 文件名列表 (STRING LIST)
109 | - `current_index` - 当前索引 (INT)
110 | - `total_images` - 图片总数 (INT)
111 |
112 | **注意**:所有输出都是列表格式,increase/decrease模式输出单元素列表,all模式输出所有图片列表。
113 |
114 | #### 使用示例
115 | 1. **批量顺序处理**:设置mode为increase,每次Queue执行自动加载下一张图片
116 | 2. **倒序处理**:设置mode为decrease,从最后一张图片开始往前处理
117 | 3. **一次性加载所有图片**:设置mode为all,将文件夹内所有图片加载为列表
118 | 4. **按时间顺序处理**:使用Datetime (ASC)排序,按时间顺序处理图片
119 | 5. **暂停在特定图片**:在递增/递减模式下,开启keep_index,可以暂停在当前图片进行测试
120 |
121 | #### 注意事项
122 | - 支持的图片格式:jpg, jpeg, png, bmp, gif, webp, tiff, tif
123 | - 相对路径基于ComfyUI根目录
124 | - 递增/递减模式会自动循环(到达末尾后回到开头)
125 | - 节点右键菜单或刷新按钮可重置计数器
126 | - all模式适合配合批处理节点使用
127 |
128 | ## 注意
129 | * 因为该库的节点是从LG_Node里面拆分出来的,之前购买过LG_Node的如果需要使用这个节点包,请联系我获取新的版本以避免出现节点冲突
130 |
131 | ## 合作/定制/0基础插件教程
132 | - **wechat:** wenrulaogou2033
133 | - **Bilibili:** 老狗_学习笔记
--------------------------------------------------------------------------------
/py/color_adjustment.py:
--------------------------------------------------------------------------------
1 | from .md import *
2 | node_data = {}
class ColorAdjustment:
    """Color adjustment node.

    Sends a preview of the input image to the frontend, waits briefly for an
    adjusted image to be posted back, and outputs that image (or the
    unmodified input when nothing arrives).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "adjust"
    CATEGORY = "🎈LAOGOU/Image"
    OUTPUT_NODE = True

    def adjust(self, image, unique_id):
        """Return the adjusted image, falling back to ``image``.

        Registers an entry in the module-level ``node_data`` under
        ``unique_id``, publishes the first image of the batch as a base64 PNG
        through the ``color_adjustment_update`` event and waits up to five
        seconds for the entry's event to be set. The input image is returned
        on timeout, when no result was stored, or when any step raises.

        Args:
            image: Image tensor; only index 0 along the first axis is previewed.
            unique_id: Identifier of this node, used as the ``node_data`` key.

        Returns:
            A one-element tuple holding the resulting image.
        """
        node_id = unique_id
        try:
            event = Event()
            node_data[node_id] = {
                "event": event,
                "result": None,
                "shape": image.shape
            }

            preview_image = (torch.clamp(image.clone(), 0, 1) * 255).cpu().numpy().astype(np.uint8)[0]
            pil_image = Image.fromarray(preview_image)
            buffer = io.BytesIO()
            pil_image.save(buffer, format="PNG")
            base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')

            PromptServer.instance.send_sync("color_adjustment_update", {
                "node_id": node_id,
                "image_data": f"data:image/png;base64,{base64_image}"
            })

            # No answer from the frontend in time: pass the input through.
            if not event.wait(timeout=5):
                return (image,)

            result_image = node_data[node_id]["result"]
            return (result_image if result_image is not None else image,)

        except Exception:
            # Best effort: a failed preview must not break the workflow.
            return (image,)

        finally:
            # Single cleanup point, so the pending entry is removed on every
            # exit path.
            node_data.pop(node_id, None)
62 |
def _decode_adjusted_pixels(adjusted_data, shape):
    """Turn a flat RGBA value list into an image tensor shaped like ``shape``.

    Returns ``None`` when the list holds fewer than ``height * width * 4``
    values.
    """
    batch, height, width, channels = shape
    if len(adjusted_data) < height * width * 4:
        return None
    rgba_array = np.array(adjusted_data, dtype=np.uint8).reshape(height, width, 4)
    rgb_array = rgba_array[:, :, :3]
    # NOTE(review): this reshape only succeeds when batch * channels == 3
    # (e.g. a single RGB image); other shapes raise and are reported as an
    # error by the caller. Confirm whether that is intended.
    return torch.from_numpy(rgb_array / 255.0).float().reshape(batch, height, width, channels)


@PromptServer.instance.routes.post("/color_adjustment/apply")
async def apply_color_adjustment(request):
    """Receive adjusted pixel data from the frontend and wake the waiting node.

    Expects a JSON body with ``node_id`` and ``adjusted_data``. Always
    answers with a JSON object containing ``success`` and, on failure,
    ``error``.
    """
    try:
        payload = await request.json()
        node_id = payload.get("node_id")
        adjusted_data = payload.get("adjusted_data")

        if node_id not in node_data:
            return web.json_response({"success": False, "error": "节点数据不存在"})

        try:
            node_info = node_data[node_id]

            if isinstance(adjusted_data, list):
                decoded = _decode_adjusted_pixels(adjusted_data, node_info["shape"])
                if decoded is not None:
                    node_info["result"] = decoded

            # Wake the node whether or not a usable result was stored.
            node_info["event"].set()
            return web.json_response({"success": True})

        except Exception as exc:
            # Still release the waiting node so it does not sit out the timeout.
            if node_id in node_data and "event" in node_data[node_id]:
                node_data[node_id]["event"].set()
            return web.json_response({"success": False, "error": str(exc)})

    except Exception as exc:
        return web.json_response({"success": False, "error": str(exc)})
95 |
# Node type identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "ColorAdjustment": ColorAdjustment,
}

# Node type identifier -> display name.
NODE_DISPLAY_NAME_MAPPINGS = {
    "ColorAdjustment": "颜色调整",
}
103 |
--------------------------------------------------------------------------------
/py/switch.py:
--------------------------------------------------------------------------------
1 | from .md import *
2 | CATEGORY_TYPE = "🎈LAOGOU/Switch"
class LazySwitch2way:
    """Two-channel switch whose inputs are declared lazy.

    Depending on ``boolean`` it forwards either the ``ON_TRUE``/``on_true``
    pair or the ``ON_FALSE``/``on_false`` pair.
    """

    @classmethod
    def INPUT_TYPES(cls):
        lazy_any = (any, {"lazy": True})
        return {
            "required": {
                "boolean": ("BOOLEAN", {"default": True}),
                "ON_TRUE": lazy_any,
                "on_true": (any, {"lazy": True}),
                "ON_FALSE": (any, {"lazy": True}),
                "on_false": (any, {"lazy": True}),
            }
        }

    RETURN_TYPES = (any, any,)
    RETURN_NAMES = ("OUTPUT", "output",)
    FUNCTION = "switch"
    CATEGORY = CATEGORY_TYPE

    def check_lazy_status(self, boolean, ON_TRUE, on_true, ON_FALSE, on_false):
        """Return the names of the selected inputs that are still ``None``.

        Only the pair chosen by ``boolean`` is inspected. ``None`` is
        returned when nothing is missing.
        """
        if boolean:
            candidates = (("ON_TRUE", ON_TRUE), ("on_true", on_true))
        else:
            candidates = (("ON_FALSE", ON_FALSE), ("on_false", on_false))
        missing = [input_name for input_name, value in candidates if value is None]
        return missing or None

    def switch(self, boolean, ON_TRUE, on_true, ON_FALSE, on_false):
        """Return the pair of inputs selected by ``boolean``."""
        return (ON_TRUE, on_true,) if boolean else (ON_FALSE, on_false,)
40 |
class LazySwitch1way:
    """Single-channel switch whose inputs are declared lazy.

    Forwards ``ON_TRUE`` when ``boolean`` is true and ``ON_FALSE`` otherwise.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "boolean": ("BOOLEAN", {"default": True}),
            "ON_TRUE": (any, {"lazy": True}),
            "ON_FALSE": (any, {"lazy": True}),
        }
        return {"required": required}

    RETURN_TYPES = (any,)
    RETURN_NAMES = ("OUTPUT",)
    FUNCTION = "switch"
    CATEGORY = CATEGORY_TYPE

    def check_lazy_status(self, boolean, ON_TRUE, ON_FALSE):
        """Return ``[name]`` of the selected input if it is ``None``, else ``None``."""
        if boolean:
            input_name, value = "ON_TRUE", ON_TRUE
        else:
            input_name, value = "ON_FALSE", ON_FALSE
        return [input_name] if value is None else None

    def switch(self, boolean, ON_TRUE, ON_FALSE):
        """Return the input selected by ``boolean`` as a one-element tuple."""
        chosen = ON_TRUE if boolean else ON_FALSE
        return (chosen,)
72 |
class GroupSwitcher:
    """Group switch node controlled by a boolean value.

    It has no outputs and its ``switch`` method does no work.
    """

    @classmethod
    def INPUT_TYPES(cls):
        # Boolean control is the only input.
        required = {"boolean": ("BOOLEAN", {"default": True})}
        return {"required": required}

    RETURN_TYPES = ()  # no output ports
    FUNCTION = "switch"
    CATEGORY = CATEGORY_TYPE

    def switch(self, boolean):
        """Ignore ``boolean`` and return an empty tuple."""
        return ()
92 |
class MuterSwitcher:
    """Two-channel switch with optional inputs.

    Forwards the ``ON_TRUE``/``on_true`` pair when ``boolean`` is true and
    the ``ON_FALSE``/``on_false`` pair otherwise; unconnected inputs default
    to ``None``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "boolean": ("BOOLEAN", {"default": True}),
        }
        optional = {
            "ON_TRUE": (any,),
            "on_true": (any,),
            "ON_FALSE": (any,),
            "on_false": (any,),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = (any, any,)
    RETURN_NAMES = ("OUTPUT", "output",)
    FUNCTION = "switch"
    CATEGORY = CATEGORY_TYPE

    def switch(self, ON_TRUE=None, on_true=None, ON_FALSE=None, on_false=None, boolean=True):
        """Return the pair of inputs selected by ``boolean``."""
        return (ON_TRUE, on_true,) if boolean else (ON_FALSE, on_false,)
119 |
# Node identifier -> implementing class, for the four switch nodes of this module.
NODE_CLASS_MAPPINGS = dict(
    LazySwitch2way=LazySwitch2way,
    LazySwitch1way=LazySwitch1way,
    GroupSwitcher=GroupSwitcher,
    MuterSwitcher=MuterSwitcher,
)

# Node identifier -> display label (same keys as NODE_CLASS_MAPPINGS).
NODE_DISPLAY_NAME_MAPPINGS = dict(
    LazySwitch2way="🎈LazySwitch2way",
    LazySwitch1way="🎈LazySwitch1way",
    GroupSwitcher="🎈GroupSwitcher",
    MuterSwitcher="🎈MuterSwitcher",
)
133 |
134 |
135 |
--------------------------------------------------------------------------------
/py/inspyrenet/modules/attention_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch.nn.parameter import Parameter
6 | from operator import xor
7 | from typing import Optional
8 |
9 | from .layers import *
10 |
class SICA(nn.Module):
    """Attention block that refines features with context pooled from saliency maps.

    ``forward`` turns the saliency map (and, when ``lmap_in`` is set, a second
    map) into soft region masks relative to learnable thresholds, pools the
    input features with those masks into ``self.ctx`` context vectors, and lets
    every position of the query map attend over these context vectors.
    """

    def __init__(self, in_channel, out_channel=1, depth=64, base_size=None, stage=None, lmap_in=False):
        """Build the projection convolutions and the learnable thresholds.

        Args:
            in_channel: channel count of the feature map passed to ``forward``.
            out_channel: channel count of the prediction head ``conv_out4``.
            depth: channel count used by the query/key/value projections.
            base_size: optional ``(height, width)`` pair; together with
                ``stage`` it fixes the resolution used for context pooling.
            stage: number of times ``base_size`` is halved.
            lmap_in: when ``True``, ``forward`` requires the extra ``lmap``
                argument and five context vectors are used instead of three.
        """
        super(SICA, self).__init__()
        self.in_channel = in_channel
        self.depth = depth
        self.lmap_in = lmap_in
        # Fixed pooling resolution: base_size halved `stage` times.
        # None means forward() pools at the runtime size of x.
        if base_size is not None and stage is not None:
            self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage))
        else:
            self.stage_size = None

        # The query is projected from x; key and value are projected from the
        # pooled context tensor (see forward()).
        self.conv_query = nn.Sequential(Conv2d(in_channel, depth, 3, relu=True),
                                        Conv2d(depth, depth, 3, relu=True))
        self.conv_key = nn.Sequential(Conv2d(in_channel, depth, 1, relu=True),
                                      Conv2d(depth, depth, 1, relu=True))
        self.conv_value = nn.Sequential(Conv2d(in_channel, depth, 1, relu=True),
                                        Conv2d(depth, depth, 1, relu=True))

        # Number of region masks / context vectors:
        # fg, bg, cg, plus fp and bp when a second map is supplied.
        if self.lmap_in is True:
            self.ctx = 5
        else:
            self.ctx = 3

        self.conv_out1 = Conv2d(depth, depth, 3, relu=True)
        self.conv_out2 = Conv2d(in_channel + depth, depth, 3, relu=True)
        self.conv_out3 = Conv2d(depth, depth, 3, relu=True)
        self.conv_out4 = Conv2d(depth, out_channel, 1)

        # Learnable split point applied to sigmoid(smap); starts at 0.5.
        self.threshold = Parameter(torch.tensor([0.5]))

        # Separate learnable split point for sigmoid(lmap); only exists when lmap_in is set.
        if self.lmap_in is True:
            self.lthreshold = Parameter(torch.tensor([0.5]))

    def forward(self, x, smap, lmap: Optional[torch.Tensor]=None):
        """Refine ``x`` using context vectors derived from ``smap`` (and ``lmap``).

        Args:
            x: 4-D feature tensor ``(b, c, h, w)``.
            smap: map that is resized to the spatial size of ``x`` and passed
                through a sigmoid.  Assumed to have a single channel (the
                ``view`` to ``self.ctx`` channels below relies on it) - TODO confirm.
            lmap: second map; must be given exactly when ``lmap_in`` is ``True``.

        Returns:
            ``(x, out)``: the output of ``conv_out3`` and the output of
            ``conv_out4`` applied to it.

        Raises:
            AssertionError: if ``lmap`` is supplied without ``lmap_in`` or
                omitted with ``lmap_in``.
        """
        # lmap must be present if and only if the block was built with lmap_in=True.
        assert not xor(self.lmap_in is True, lmap is not None)
        b, c, h, w = x.shape

        # compute class probability: signed distance of sigmoid(smap) to the threshold
        smap = F.interpolate(smap, size=x.shape[-2:], mode='bilinear', align_corners=False)
        smap = torch.sigmoid(smap)
        p = smap - self.threshold

        fg = torch.clip(p, 0, 1) # foreground: amount above the threshold
        bg = torch.clip(-p, 0, 1) # background: amount below the threshold
        cg = self.threshold - torch.abs(p) # confusion area: largest where smap is close to the threshold

        if self.lmap_in is True and lmap is not None:
            # Same split for the second map, using its own threshold.
            lmap = F.interpolate(lmap, size=x.shape[-2:], mode='bilinear', align_corners=False)
            lmap = torch.sigmoid(lmap)
            lp = lmap - self.lthreshold
            fp = torch.clip(lp, 0, 1) # foreground
            bp = torch.clip(-lp, 0, 1) # background

            prob = [fg, bg, cg, fp, bp]
        else:
            prob = [fg, bg, cg]

        # Stack the masks along the channel axis.
        prob = torch.cat(prob, dim=1)

        # reshape feature & prob to the pooling resolution
        if self.stage_size is not None:
            shape = self.stage_size
            shape_mul = self.stage_size[0] * self.stage_size[1]
        else:
            shape = (h, w)
            shape_mul = h * w

        # NOTE(review): the feature map is reinterpreted as (b, positions, channels)
        # with a plain view(), not a permute - left untouched; confirm this is intended.
        f = F.interpolate(x, size=shape, mode='bilinear', align_corners=False).view(b, shape_mul, -1)
        prob = F.interpolate(prob, size=shape, mode='bilinear', align_corners=False).view(b, self.ctx, shape_mul)

        # compute context vector: (b, ctx, shape_mul) x (b, shape_mul, c) -> (b, ctx, c),
        # then rearranged to (b, c, ctx, 1) so the key/value convolutions can consume it
        context = torch.bmm(prob, f).permute(0, 2, 1).unsqueeze(3)

        # k q v compute: each projection is flattened to `depth` channels
        query = self.conv_query(x).view(b, self.depth, -1).permute(0, 2, 1)
        key = self.conv_key(context).view(b, self.depth, -1)
        value = self.conv_value(context).view(b, self.depth, -1).permute(0, 2, 1)

        # compute similarity map: (b, positions, depth) x (b, depth, contexts),
        # scaled by depth ** -0.5 and normalised over the context axis
        sim = torch.bmm(query, key)
        sim = (self.depth ** -.5) * sim
        sim = F.softmax(sim, dim=-1)

        # compute refined feature: weighted sum of the values, folded back to (b, -1, h, w)
        context = torch.bmm(sim, value).permute(0, 2, 1).contiguous().view(b, -1, h, w)
        context = self.conv_out1(context)

        # Fuse the attended context with the original features.
        x = torch.cat([x, context], dim=1)
        x = self.conv_out2(x)
        x = self.conv_out3(x)
        out = self.conv_out4(x)

        return x, out
--------------------------------------------------------------------------------
/web/counter.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../scripts/app.js";
2 | import { api } from "../../scripts/api.js";
3 |
app.registerExtension({
    name: "LG.Counter",

    /**
     * Patch the LG_Counter node type: add a refresh button, draw the current
     * count next to the title, and expose a "reset counter" context-menu entry.
     */
    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeData.name !== "LG_Counter") return;

        // Keep the original onNodeCreated so it still runs first.
        const onNodeCreated = nodeType.prototype.onNodeCreated;

        nodeType.prototype.onNodeCreated = function() {
            const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;

            // Text drawn by onDrawForeground; empty means "draw nothing".
            this.currentCountText = "";

            // Refresh button.
            this.addWidget("button", "refresh", "刷新计数器", () => {
                this.resetCounter();
            });

            return r;
        };

        // Wrap onDrawForeground to paint the count.
        const onDrawForeground = nodeType.prototype.onDrawForeground;
        nodeType.prototype.onDrawForeground = function(ctx) {
            const r = onDrawForeground?.apply?.(this, arguments);

            if (this.currentCountText) {
                ctx.save();

                // Text style: bold green, right-aligned.
                ctx.font = "bold 20px sans-serif";
                ctx.textAlign = "right";
                ctx.textBaseline = "top";
                ctx.fillStyle = "#00ff00";

                // Anchor 20px left of the node's right edge and 25px above the
                // node origin (negative y).
                const rightX = this.size[0] - 20;
                const topY = -25;

                // Drop shadow for readability.
                ctx.shadowColor = "rgba(0, 0, 0, 0.8)";
                ctx.shadowBlur = 3;
                ctx.shadowOffsetX = 1;
                ctx.shadowOffsetY = 1;

                ctx.fillText(this.currentCountText, rightX, topY);

                ctx.restore();
            }

            return r;
        };

        // POST this node's id to /counter/reset and log the outcome.
        // Failures are logged, never thrown.
        nodeType.prototype.resetCounter = async function() {
            try {
                const response = await api.fetchApi("/counter/reset", {
                    method: "POST",
                    headers: {
                        "Content-Type": "application/json",
                    },
                    body: JSON.stringify({
                        node_id: this.id.toString()
                    })
                });

                const result = await response.json();

                if (result.status === "success") {
                    console.log("计数器已重置:", result.message);
                } else {
                    console.error("重置计数器失败:", result.message);
                }
            } catch (error) {
                console.error("重置计数器时发生错误:", error);
            }
        };

        // Context-menu entry that triggers the same reset.
        const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
        nodeType.prototype.getExtraMenuOptions = function(_, options) {
            const r = getExtraMenuOptions ? getExtraMenuOptions.apply(this, arguments) : undefined;

            options.unshift({
                content: "重置计数器",
                callback: () => {
                    this.resetCounter();
                }
            });

            return r;
        };
    },

    /**
     * Listen for "counter_update" events and refresh the matching LG_Counter node.
     */
    async setup() {
        api.addEventListener("counter_update", ({ detail }) => {
            if (!detail || !detail.node_id) return;
            // A payload without a count cannot be displayed; ignore it instead
            // of throwing inside the event listener.
            if (detail.count === undefined || detail.count === null) return;

            const node = app.graph._nodes_by_id[detail.node_id];
            if (node && node.type === "LG_Counter") {
                // Update the text and request a redraw.
                node.currentCountText = detail.count.toString();
                node.setDirtyCanvas(true, true);

                console.log(`[Counter] 节点 ${detail.node_id} 计数更新: ${detail.count}`);
            }
        });
    }
});
119 |
120 |
--------------------------------------------------------------------------------
/web/Util.js:
--------------------------------------------------------------------------------
1 | import { api } from "../../scripts/api.js";
2 | import { app } from "../../scripts/app.js";
3 |
/** Static helpers for server messages and node widgets. */
export class Util {

    // ---- Server ----

    /** Register handlerFunc for messagePath events on the api object. */
    static AddMessageListener(messagePath, handlerFunc) {
        api.addEventListener(messagePath, handlerFunc);
    }

    // ---- Widget ----

    /** Replace the text shown by a DOM widget's element. */
    static SetTextAreaContent(widget, text) {
        const element = widget.element;
        element.textContent = text;
    }

    /** Scroll a DOM widget's element to pos01 times its scrollHeight. */
    static SetTextAreaScrollPos(widget, pos01) {
        const element = widget.element;
        const targetY = element.scrollHeight * pos01;
        element.scroll(0, targetY);
    }

    /**
     * Create a read-only <textarea> and attach it to the node as a DOM widget
     * that is excluded from serialization. Returns the created widget.
     */
    static AddReadOnlyTextArea(node, name, text, placeholder = "") {
        const textArea = document.createElement("textarea");
        textArea.className = "comfy-multiline-input";
        textArea.placeholder = placeholder;
        textArea.spellcheck = false;
        textArea.readOnly = true;
        textArea.textContent = text;

        const widgetOptions = { serialize: false };
        return node.addDOMWidget(name, "", textArea, widgetOptions);
    }

    /** Add a button widget to the node and return it. */
    static AddButtonWidget(node, label, callback, value = null) {
        const button = node.addWidget("button", label, value, callback);
        return button;
    }

}
37 |
/**
 * Generic terminal manager: buffers text lines received on one message path
 * and mirrors them into the "terminal" text-area widget of matching nodes.
 */
export class TerminalManager {
    constructor(messagePath, nodeType) {
        this.messagePath = messagePath;
        this.nodeType = nodeType;
        // Bumped on every change; nodes compare it with their own copy.
        this.textVersion = 0;
        this.lines = [];

        // Listen for incoming messages.
        Util.AddMessageListener(messagePath, (event) => {
            this.textVersion += 1;

            const payload = event.detail;
            if (payload.clear) {
                this.lines.length = 0;
            }

            const incoming = String(payload.text || "").split("\n");
            this.lines.push(...incoming);

            // Keep only the newest 1024 lines.
            if (this.lines.length > 1024) {
                this.lines = this.lines.slice(-1024);
            }

            // Ask every node of the managed type to redraw.
            for (const candidate of app.graph._nodes) {
                if (candidate.type == nodeType && candidate.setDirtyCanvas) {
                    candidate.setDirtyCanvas(true);
                }
            }
        });
    }

    /** Drop all buffered lines and mark the content as changed. */
    clearTerminal() {
        this.lines.length = 0;
        this.textVersion += 1;
    }

    /** Return the buffered lines joined with newlines. */
    getContent() {
        return this.lines.join("\n");
    }

    /**
     * Add the terminal text area and its buttons to a freshly created node.
     * Falls back to a single clear button if the button modules fail to load.
     */
    async setupNode(node) {
        const terminalWidget = Util.AddReadOnlyTextArea(node, "terminal", "");
        setTimeout(() => {
            const element = terminalWidget.element;
            if (!element) return;
            const style = element.style;
            style.backgroundColor = "#000000ff";
            style.color = "#ffffffff";
            style.fontFamily = "monospace";
        }, 0);

        // Dynamic imports avoid duplicate-declaration problems.
        try {
            const { MultiButtonWidget } = await import("./multi_button_widget.js");
            const { queueSelectedOutputNodes } = await import("./queue_shortcut.js");

            // Two buttons side by side: clear the log, run the queue.
            const clearButton = {
                text: "清理日志",
                callback: () => {
                    this.clearTerminal();
                }
            };
            const runButton = {
                text: "执行",
                callback: () => {
                    queueSelectedOutputNodes();
                }
            };
            const widgetOptions = { labelWidth: 0, buttonSpacing: 4 };

            const buttonRow = MultiButtonWidget(app, "", widgetOptions, [clearButton, runButton]);
            node.addCustomWidget(buttonRow);
        } catch (error) {
            console.error("Failed to load button components:", error);
            // Fallback: the plain single clear button.
            const fallbackButton = Util.AddButtonWidget(node, "清空日志", () => {
                this.clearTerminal();
            });
            fallbackButton.width = 128;
        }

        // -1 never equals textVersion, so the first draw always syncs the text.
        node.terminalVersion = -1;
        return node;
    }


    /**
     * Draw-time hook: sync the node's terminal widget when the buffer changed,
     * then delegate to the node's original onDrawForeground (if any).
     */
    updateNode(node, onDrawForeground, ctx, graphcanvas) {
        const isStale = node.terminalVersion != this.textVersion;
        if (isStale) {
            node.terminalVersion = this.textVersion;
            const terminalWidget = node.widgets.find((w) => w.name == "terminal");
            if (terminalWidget) {
                Util.SetTextAreaContent(terminalWidget, this.getContent());
                Util.SetTextAreaScrollPos(terminalWidget, 1.0);
            }
        }
        return onDrawForeground?.apply(node, [ctx, graphcanvas]);
    }
}
--------------------------------------------------------------------------------
/web/multi_button_widget.js:
--------------------------------------------------------------------------------
1 |
/**
 * Generic multi-button widget: a single widget row holding an optional label
 * followed by equally wide buttons.
 *
 * @param app        application object; `app.canvas.ds.scale` and
 *                   `app.graph.setDirtyCanvas` are used.
 * @param inputName  widget name, also drawn as the label when a label is shown.
 * @param options    optional `{ labelWidth, buttonSpacing }`; may be omitted.
 * @param buttons    optional array of `{ text, callback, color }` descriptors.
 * @returns the widget object (draw / onPointerDown / computeSize / serializeValue).
 */
const MultiButtonWidget = (app, inputName, options, buttons) => {
    // Normalise once so a missing options/buttons argument cannot throw below.
    const opts = options || {};
    const button_list = buttons || [];

    const widget = {
        name: inputName,
        type: "multi_button",
        y: 0,
        value: null,
        options: opts,
        clicked_button: null,
        click_time: 0
    };

    // Native ComfyUI style constants.
    const margin = 15;
    const button_height = LiteGraph.NODE_WIDGET_HEIGHT || 20;
    const label_width = opts.labelWidth !== undefined ? opts.labelWidth : 80;
    const button_spacing = opts.buttonSpacing || 4;
    const button_count = button_list.length;

    // Native ComfyUI button colours.
    const BUTTON_BGCOLOR = "#222";
    const BUTTON_OUTLINE_COLOR = "#666";
    const BUTTON_TEXT_COLOR = "#DDD";
    const BUTTON_CLICKED_COLOR = "#AAA";

    // Shared by drawing and hit-testing so both always agree on the geometry.
    const computeLayout = (total_width) => {
        const has_label = label_width > 0;
        const label_space = has_label ? label_width : 0;
        // Without a label a smaller margin keeps the buttons inside the node.
        const side_margin = has_label ? margin : 10;
        const available_width = total_width - label_space - side_margin - side_margin;
        const total_spacing = button_spacing * (button_count - 1);
        return {
            start_x: label_space + side_margin,
            button_width: (available_width - total_spacing) / button_count
        };
    };

    widget.draw = function(ctx, node, width, Y, height) {
        // Nothing is drawn when the canvas scale is below 0.5.
        if (app.canvas.ds.scale < 0.50) return;

        ctx.save();
        ctx.lineWidth = 1;

        // Label (only when there is one).
        if (label_width > 0 && inputName) {
            ctx.fillStyle = LiteGraph.WIDGET_SECONDARY_TEXT_COLOR;
            ctx.textAlign = "left";
            ctx.fillText(inputName, margin, Y + height * 0.7);
        }

        const { start_x, button_width } = computeLayout(width);

        // Drop the click highlight once it is older than 150 ms.
        const now = Date.now();
        if (widget.click_time && now - widget.click_time > 150) {
            widget.clicked_button = null;
            widget.click_time = 0;
        }

        for (let i = 0; i < button_count; i++) {
            const button = button_list[i];
            const button_x = start_x + i * (button_width + button_spacing);

            // Clicked highlight wins over the button's own colour.
            const is_clicked = (widget.clicked_button === i);
            const button_color = is_clicked ? BUTTON_CLICKED_COLOR :
                                 (button.color || BUTTON_BGCOLOR);

            // Background.
            ctx.fillStyle = button_color;
            ctx.fillRect(button_x, Y, button_width, button_height);

            // Outline.
            ctx.strokeStyle = BUTTON_OUTLINE_COLOR;
            ctx.strokeRect(button_x, Y, button_width, button_height);

            // Caption.
            ctx.fillStyle = BUTTON_TEXT_COLOR;
            ctx.textAlign = "center";
            ctx.fillText(button.text, button_x + button_width / 2, Y + button_height * 0.7);
        }

        ctx.restore();
    };

    widget.onPointerDown = function(pointer, node) {
        const e = pointer.eDown;
        const { start_x, button_width } = computeLayout(node.size[0]);
        // Pointer x relative to the left edge of the first button.
        const x = e.canvasX - node.pos[0] - start_x;

        pointer.onClick = () => {
            // Find the button under the pointer.
            for (let i = 0; i < button_count; i++) {
                const button_start = i * (button_width + button_spacing);
                const button_end = button_start + button_width;

                if (x >= button_start && x <= button_end) {
                    widget.clicked_button = i;
                    widget.click_time = Date.now();

                    if (button_list[i] && button_list[i].callback) {
                        button_list[i].callback();
                    }

                    // Redraw so the click highlight becomes visible.
                    app.graph.setDirtyCanvas(true);
                    break;
                }
            }
        };
    };

    widget.computeSize = function() {
        return [0, button_height];
    };

    // The widget carries no value of its own.
    widget.serializeValue = async () => {
        return null;
    };

    return widget;
};

// Export the component.
export { MultiButtonWidget };
130 |
--------------------------------------------------------------------------------
/web/PB.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../scripts/app.js";
2 | import { api } from "../../scripts/api.js";
3 | import { MultiButtonWidget } from "./multi_button_widget.js";
4 |
/**
 * Ask the backend for the newest image of a folder type and show it on the node.
 *
 * On success the node's `image` widget value, `node._latestFileInfo` (JSON
 * string with filename / subfolder / type) and `node._imgs` are updated.
 *
 * @param node         node that owns an `image` widget.
 * @param folder_type  folder type sent as the `type` query parameter.
 * @returns {Promise<boolean>} true when an image was found and applied.
 */
async function loadLatestImage(node, folder_type) {
    // Fetch the newest image of the requested directory.
    const res = await api.fetchApi(
        `/lg/get/latest_image?type=${encodeURIComponent(folder_type)}`,
        { cache: "no-store" }
    );
    if (res.status != 200) return false;

    const item = await res.json();
    if (!item || !item.filename) return false;

    const imageWidget = node.widgets.find(w => w.name === 'image');
    if (!imageWidget) return false;

    // Resolve the defaults once so the serialized info, the widget value and
    // the preview URL all describe the same file.
    const subfolder = item.subfolder || '';
    const type = item.type || 'temp';

    // JSON string kept on the node as _latestFileInfo.
    node._latestFileInfo = JSON.stringify({
        filename: item.filename,
        subfolder: subfolder,
        type: type
    });

    // Widget value in the "filename [type]" format.
    imageWidget.value = `${item.filename} [${type}]`;

    // URLSearchParams escapes characters such as '&', '#', '+' or spaces that
    // would otherwise corrupt the query string.
    const query = new URLSearchParams({
        filename: item.filename,
        type: type,
        subfolder: subfolder
    });

    // Load and show the image.
    const image = new Image();
    image.src = api.apiURL(`/view?${query.toString()}`);
    node._imgs = [image];
    return true;
}
35 |
app.registerExtension({
    name: "Comfy.LG.CachePreview",

    /**
     * Wire up CachePreviewBridge nodes: custom serialization of the image
     * widget, "Refresh From" buttons, and an `imgs` property that recognises
     * images pasted from the clipspace.
     */
    nodeCreated(node, app) {
        if (node.comfyClass !== "CachePreviewBridge") return;

        const imageWidget = node.widgets.find(w => w.name === 'image');
        if (!imageWidget) return;

        // JSON description of the file currently shown (null until one is chosen).
        node._latestFileInfo = null;

        // Serialize the latest file info when there is one, else the widget value.
        imageWidget.serializeValue = function(nodeId, widgetIndex) {
            return node._latestFileInfo ? node._latestFileInfo : (this.value || "");
        };

        node._imgs = [new Image()];
        node.imageIndex = 0;

        // Builds a button callback that reloads the newest image of one folder type.
        const refreshFrom = (folderType) => () => {
            loadLatestImage(node, folderType).then(success => {
                if (success) {
                    app.graph.setDirtyCanvas(true);
                }
            });
        };

        // Refresh buttons built with the multi-button widget.
        const buttonRow = MultiButtonWidget(
            app,
            "Refresh From",
            { labelWidth: 80, buttonSpacing: 4 },
            [
                { text: "Temp", callback: refreshFrom("temp") },
                { text: "Output", callback: refreshFrom("output") }
            ]
        );
        const refreshWidget = node.addCustomWidget(buttonRow);
        refreshWidget.serialize = false;

        // Extract { filename, subfolder, type } from an image URL's query
        // string; null when the image has no usable filename parameter.
        const describePastedImage = (img) => {
            if (!img || !img.src) return null;
            const urlParts = img.src.split("?");
            if (urlParts.length < 2) return null;

            const params = new URLSearchParams(urlParts[1]);
            const filename = params.get('filename');
            if (!filename) return null;

            return {
                filename: filename,
                subfolder: params.get('subfolder') || '',
                type: params.get('type') || 'input'
            };
        };

        // Override `imgs` so pastes coming from the MaskEditor are handled.
        Object.defineProperty(node, 'imgs', {
            set(v) {
                if (!v || v.length === 0) return;

                // The paste is detected by looking for its function name in
                // the current call stack.
                const cameFromClipspace = new Error().stack.includes('pasteFromClipspace');
                const pasted = cameFromClipspace ? describePastedImage(v[0]) : null;

                if (pasted) {
                    // Remember the file; serializeValue picks this up.
                    const fileInfo = JSON.stringify(pasted);
                    node._latestFileInfo = fileInfo;
                    imageWidget.value = fileInfo;

                    // Use the incoming images directly.
                    node._imgs = v;
                    app.graph.setDirtyCanvas(true);
                    return;
                }

                // Every other case: store as-is.
                node._imgs = v;
            },
            get() {
                return node._imgs;
            }
        });
    }
});
135 |
--------------------------------------------------------------------------------
/web/image_loader_counter.js:
--------------------------------------------------------------------------------
1 | /**
2 | * LG_ImageLoaderWithCounter 前端扩展
3 | * 为带计数器的图片加载器节点提供前端UI支持
4 | */
5 | import { app } from "../../scripts/app.js";
6 | import { api } from "../../scripts/api.js";
7 |
app.registerExtension({
    name: "LG.ImageLoaderWithCounter",

    /**
     * Patch the LG_ImageLoaderWithCounter node type: add a refresh button,
     * draw "index/total" next to the title, and expose a reset menu entry.
     */
    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeData.name !== "LG_ImageLoaderWithCounter") return;

        // Keep the original onNodeCreated so it still runs first.
        const onNodeCreated = nodeType.prototype.onNodeCreated;

        nodeType.prototype.onNodeCreated = function() {
            const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;

            // Texts drawn by onDrawForeground; empty means "draw nothing".
            this.currentCountText = "";
            this.totalImagesText = "";

            // Refresh button.
            this.addWidget("button", "refresh", "🔄 刷新计数器", () => {
                this.resetImageLoaderCounter();
            });

            return r;
        };

        // Wrap onDrawForeground to paint the index and the total.
        const onDrawForeground = nodeType.prototype.onDrawForeground;
        nodeType.prototype.onDrawForeground = function(ctx) {
            const r = onDrawForeground?.apply?.(this, arguments);

            if (this.currentCountText) {
                ctx.save();

                // Text style: bold green, right-aligned.
                ctx.font = "bold 20px sans-serif";
                ctx.textAlign = "right";
                ctx.textBaseline = "top";
                ctx.fillStyle = "#00ff00";

                // "index" alone, or "index/total" when the total is known.
                let displayText = this.currentCountText;
                if (this.totalImagesText) {
                    displayText = `${this.currentCountText}/${this.totalImagesText}`;
                }

                // Anchor 20px left of the node's right edge and 25px above the
                // node origin (negative y).
                const rightX = this.size[0] - 20;
                const topY = -25;

                // Drop shadow for readability.
                ctx.shadowColor = "rgba(0, 0, 0, 0.8)";
                ctx.shadowBlur = 3;
                ctx.shadowOffsetX = 1;
                ctx.shadowOffsetY = 1;

                ctx.fillText(displayText, rightX, topY);

                ctx.restore();
            }

            return r;
        };

        // POST this node's id to /image_loader_counter/reset and show the
        // index reported back. Failures are logged, never thrown.
        nodeType.prototype.resetImageLoaderCounter = async function() {
            try {
                const response = await api.fetchApi("/image_loader_counter/reset", {
                    method: "POST",
                    headers: {
                        "Content-Type": "application/json",
                    },
                    body: JSON.stringify({
                        node_id: this.id.toString()
                    })
                });

                const result = await response.json();

                if (result.status === "success") {
                    console.log("图片加载器计数器已重置:", result.message);
                    // A successful reply without `current` must not be turned
                    // into an error by stringifying undefined.
                    if (result.current !== undefined && result.current !== null) {
                        this.currentCountText = result.current.toString();
                    }
                    this.setDirtyCanvas(true, true);
                } else {
                    console.error("重置计数器失败:", result.message);
                }
            } catch (error) {
                console.error("重置计数器时发生错误:", error);
            }
        };

        // Context-menu entry that triggers the same reset.
        const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
        nodeType.prototype.getExtraMenuOptions = function(_, options) {
            const r = getExtraMenuOptions ? getExtraMenuOptions.apply(this, arguments) : undefined;

            options.unshift({
                content: "🔄 重置计数器",
                callback: () => {
                    this.resetImageLoaderCounter();
                }
            });

            return r;
        };
    },

    /**
     * Listen for "counter_update" events and refresh the matching
     * LG_ImageLoaderWithCounter node.
     */
    async setup() {
        api.addEventListener("counter_update", ({ detail }) => {
            if (!detail || !detail.node_id) return;
            // A payload without a count cannot be displayed; ignore it instead
            // of throwing inside the event listener.
            if (detail.count === undefined || detail.count === null) return;

            const node = app.graph._nodes_by_id[detail.node_id];
            if (node && node.type === "LG_ImageLoaderWithCounter") {
                node.currentCountText = detail.count.toString();
                // Update the total only when the payload carries one.
                if (detail.total !== undefined && detail.total !== null) {
                    node.totalImagesText = detail.total.toString();
                }
                node.setDirtyCanvas(true, true);

                // `??` keeps a legitimate total of 0 instead of printing '?'.
                console.log(`[ImageLoaderCounter] 节点 ${detail.node_id} 索引更新: ${detail.count}/${detail.total ?? '?'}`);
            }
        });
    }
});
139 |
140 |
--------------------------------------------------------------------------------
/py/inspyrenet/modules/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import cv2
6 | import numpy as np
7 |
8 | from kornia.morphology import dilation, erosion
9 | from torch.nn.parameter import Parameter
10 |
class ImagePyramid:
    """Gaussian/Laplacian pyramid helper built around one separable Gaussian kernel.

    The kernel is a plain tensor attribute (this is not an ``nn.Module``), so
    it has to be moved between devices explicitly with ``to`` / ``cuda``.
    """

    def __init__(self, ksize=7, sigma=1, channels=1):
        """Build a ``(channels, 1, ksize, ksize)`` Gaussian kernel.

        Args:
            ksize: side length of the Gaussian kernel.
            sigma: standard deviation passed to ``cv2.getGaussianKernel``.
            channels: number of channels the kernel is repeated for; also used
                as the ``groups`` argument of the convolutions.
        """
        self.ksize = ksize
        self.sigma = sigma
        self.channels = channels

        # 1-D Gaussian -> 2-D kernel via the outer product.
        k = cv2.getGaussianKernel(ksize, sigma)
        k = np.outer(k, k)
        k = torch.tensor(k).float()
        self.kernel = k.repeat(channels, 1, 1, 1)

    def to(self, device):
        """Move the kernel to ``device`` and return ``self``."""
        self.kernel = self.kernel.to(device)
        return self

    def cuda(self, idx=None):
        """Move the kernel to CUDA device ``idx`` (current device when ``None``)."""
        if idx is None:
            idx = torch.cuda.current_device()

        self.to(device="cuda:{}".format(idx))
        return self

    def expand(self, x):
        """Upsample ``x`` by 2: zero-interleave the pixels, then blur with 4x the kernel."""
        z = torch.zeros_like(x)
        # pixel_shuffle spreads [x, 0, 0, 0] over each 2x2 output cell.
        x = torch.cat([x, z, z, z], dim=1)
        x = F.pixel_shuffle(x, 2)
        x = F.pad(x, (self.ksize // 2, ) * 4, mode='reflect')
        # The kernel is scaled by 4 because three of every four samples are zero.
        x = F.conv2d(x, self.kernel * 4, groups=self.channels)
        return x

    def reduce(self, x):
        """Blur ``x`` with the kernel and keep every second row and column."""
        x = F.pad(x, (self.ksize // 2, ) * 4, mode='reflect')
        x = F.conv2d(x, self.kernel, groups=self.channels)
        x = x[:, :, ::2, ::2]
        return x

    def deconstruct(self, x):
        """Split ``x`` into ``(reduced_x, laplacian_x)``.

        ``laplacian_x`` is ``x`` minus the re-expanded reduced image, so it has
        the same shape as ``x``.
        """
        reduced_x = self.reduce(x)
        expanded_reduced_x = self.expand(reduced_x)

        # Odd input sizes do not survive reduce/expand exactly; resize to match x.
        if x.shape != expanded_reduced_x.shape:
            expanded_reduced_x = F.interpolate(expanded_reduced_x, x.shape[-2:])

        laplacian_x = x - expanded_reduced_x
        return reduced_x, laplacian_x

    def reconstruct(self, x, laplacian_x):
        """Expand ``x`` by 2 and add ``laplacian_x`` (resized when needed).

        Returns:
            A tensor with the spatial size of ``expand(x)``.
        """
        expanded_x = self.expand(x)
        # Compare spatial sizes. The previous check compared a shape against
        # the tensor itself, so it never reflected the actual sizes.
        if laplacian_x.shape[-2:] != expanded_x.shape[-2:]:
            laplacian_x = F.interpolate(laplacian_x, expanded_x.shape[-2:], mode='bilinear', align_corners=True)
        return expanded_x + laplacian_x
62 |
class Transition:
    """Binary mask of the band where dilation and erosion of ``sigmoid(x)`` differ.

    The structuring element is a plain tensor attribute, moved between
    devices explicitly with ``to`` / ``cuda``.
    """

    def __init__(self, k=3):
        """Create a ``k`` x ``k`` elliptical structuring element as a float tensor."""
        element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
        self.kernel = torch.tensor(element).float()

    def to(self, device):
        """Move the structuring element to ``device`` and return ``self``."""
        moved = self.kernel.to(device)
        self.kernel = moved
        return self

    def cuda(self, idx=0):
        """Move the structuring element to CUDA device ``idx`` and return ``self``."""
        target = "cuda:{}".format(idx)
        self.to(device=target)
        return self

    def __call__(self, x):
        """Return a float mask that is 1 where dilation minus erosion exceeds 0.5."""
        prob = torch.sigmoid(x)
        dilated = dilation(prob, self.kernel)
        eroded = erosion(prob, self.kernel)
        band = dilated - eroded
        return (band > .5).float()
81 |
class Conv2d(nn.Module):
    """Convolution followed by optional batch normalisation and optional ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, padding='same', bias=False, bn=True, relu=False):
        """Build the convolution and its optional BN / ReLU.

        Args:
            in_channels, out_channels, stride, dilation, groups, bias:
                forwarded to ``nn.Conv2d``.
            kernel_size: int or pair; an int is used for both dimensions.
            padding: ``'same'`` (padding derived from the dilated kernel
                extent), ``'valid'`` (no padding), or an explicit int / pair
                that is used as given.
            bn: add ``nn.BatchNorm2d(out_channels)`` only when exactly ``True``.
            relu: add an in-place ``nn.ReLU`` only when exactly ``True``.
        """
        super(Conv2d, self).__init__()
        kernel_size = self._as_pair(kernel_size)
        stride = self._as_pair(stride)
        dilation = self._as_pair(dilation)

        if padding == 'same':
            # Extent of the dilated kernel per dimension, reduced to the
            # per-side padding: (extent - 1) / 2 for odd extents.
            pad_size = tuple(
                self._same_padding(kernel_size[i] + (kernel_size[i] - 1) * (dilation[i] - 1))
                for i in range(2)
            )
        elif padding == 'valid':
            # The 'same' reduction used to run here as well and produced -1.
            pad_size = (0, 0)
        else:
            # Explicit padding is taken literally; it used to be reduced by
            # one because of the same shared step.
            explicit = self._as_pair(padding)
            pad_size = (explicit[0], explicit[1])

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_size, dilation, groups, bias=bias)
        self.reset_parameters()

        if bn is True:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = None

        if relu is True:
            self.relu = nn.ReLU(inplace=True)
        else:
            self.relu = None

    @staticmethod
    def _as_pair(value):
        """Return ``value`` unchanged if iterable, else ``(value, value)``."""
        if hasattr(value, '__iter__'):
            return value
        return (value, value)

    @staticmethod
    def _same_padding(extent):
        """Per-side padding for a kernel covering ``extent`` pixels."""
        return extent // 2 + (extent % 2 - 1)

    def forward(self, x):
        """Apply the convolution, then BN and ReLU when they are configured."""
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

    def reset_parameters(self):
        """Re-initialise the convolution weight with Kaiming-normal values."""
        nn.init.kaiming_normal_(self.conv.weight)
132 |
133 |
class SelfAttention(nn.Module):
    """Self-attention over the height axis, the width axis, or both.

    ``mode`` selects the attended axes: the letter ``'h'`` includes the
    height, ``'w'`` the width.  The attended result is scaled by the learnable
    ``gamma`` (initialised to zero) and added to the input.
    """

    def __init__(self, in_channels, mode='hw', stage_size=None):
        """Create the 1x1 query/key/value projections.

        Args:
            in_channels: channel count of the input feature map.
            mode: string containing ``'h'`` and/or ``'w'``.
            stage_size: stored on the instance; not used by ``forward``.
        """
        super(SelfAttention, self).__init__()

        self.mode = mode

        reduced_channels = in_channels // 8
        self.query_conv = Conv2d(in_channels, reduced_channels, kernel_size=(1, 1))
        self.key_conv = Conv2d(in_channels, reduced_channels, kernel_size=(1, 1))
        self.value_conv = Conv2d(in_channels, in_channels, kernel_size=(1, 1))

        self.gamma = Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

        self.stage_size = stage_size

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        batch_size, channel, height, width = x.size()

        # Length of the attended axis: product of the dimensions named in mode.
        span = 1
        for letter, extent in (('h', height), ('w', width)):
            if letter in self.mode:
                span *= extent

        flat_shape = (batch_size, -1, span)

        queries = self.query_conv(x).view(*flat_shape).permute(0, 2, 1)
        keys = self.key_conv(x).view(*flat_shape)

        # (b, span, k) x (b, k, span) -> (b, span, span), normalised over the last axis.
        weights = self.softmax(torch.bmm(queries, keys))
        values = self.value_conv(x).view(*flat_shape)

        attended = torch.bmm(values, weights.permute(0, 2, 1))
        attended = attended.view(batch_size, channel, height, width)

        return self.gamma * attended + x
172 |
--------------------------------------------------------------------------------
/py/inspyrenet/InSPyReNet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 | filepath = os.path.abspath(__file__)
9 | repopath = os.path.split(filepath)[0]
10 | sys.path.append(repopath)
11 |
12 | from .modules.layers import *
13 | from .modules.context_module import *
14 | from .modules.attention_module import *
15 | from .modules.decoder_module import *
16 |
17 | from .backbones.SwinTransformer import SwinB
18 |
class InSPyReNet(nn.Module):
    """Saliency predictor built from a backbone, context modules and a pyramid decoder.

    The five feature maps returned by ``backbone`` are reduced by the context
    modules, decoded into a coarse saliency map and then refined by three
    attention stages. Each stage yields a residual map (stored under the
    ``'laplacian'`` key) that ``ImagePyramid.reconstruct`` folds into the
    previous saliency map. Calling the module runs :meth:`forward_inference`.
    """

    def __init__(self, backbone, in_channels, depth=64, base_size=[384, 384], threshold=512, **kwargs):
        """Assemble the sub-modules.

        Args:
            backbone: Module returning five feature maps when called.
            in_channels: Channel count of each of the five backbone outputs.
            depth: Channel width shared by the context/attention/decoder modules.
            base_size: Size forwarded to every sub-module; also the size a
                low-resolution input is derived at when none is supplied.
                The default list is never mutated here.
            threshold: Spatial size above which :meth:`forward_inference`
                blends a low- and a high-resolution pass; ``None`` disables
                the check.
            **kwargs: Accepted and ignored.
        """
        super(InSPyReNet, self).__init__()
        self.backbone = backbone
        self.in_channels = in_channels
        self.depth = depth
        self.base_size = base_size
        self.threshold = threshold

        # One context module per backbone stage.
        self.context1 = PAA_e(self.in_channels[0], self.depth, base_size=self.base_size, stage=0)
        self.context2 = PAA_e(self.in_channels[1], self.depth, base_size=self.base_size, stage=1)
        self.context3 = PAA_e(self.in_channels[2], self.depth, base_size=self.base_size, stage=2)
        self.context4 = PAA_e(self.in_channels[3], self.depth, base_size=self.base_size, stage=3)
        self.context5 = PAA_e(self.in_channels[4], self.depth, base_size=self.base_size, stage=4)

        # The decoder consumes the three deepest context outputs.
        self.decoder = PAA_d(self.depth * 3, depth=self.depth, base_size=base_size, stage=2)

        self.attention0 = SICA(self.depth, depth=self.depth, base_size=self.base_size, stage=0, lmap_in=True)
        self.attention1 = SICA(self.depth * 2, depth=self.depth, base_size=self.base_size, stage=1, lmap_in=True)
        self.attention2 = SICA(self.depth * 2, depth=self.depth, base_size=self.base_size, stage=2)

        # Resize helpers: ``ret`` matches another tensor's spatial size,
        # ``res``/``des`` resize to an explicit size (bilinear / nearest).
        self.ret = lambda x, target: F.interpolate(x, size=target.shape[-2:], mode='bilinear', align_corners=False)
        self.res = lambda x, size: F.interpolate(x, size=size, mode='bilinear', align_corners=False)
        self.des = lambda x, size: F.interpolate(x, size=size, mode='nearest')

        self.image_pyramid = ImagePyramid(7, 1)

        self.transition0 = Transition(17)
        self.transition1 = Transition(9)
        self.transition2 = Transition(5)

        self.forward = self.forward_inference

    def to(self, device):
        """Move the module and its pyramid/transition helpers to ``device``.

        The helpers are moved explicitly before delegating to
        ``nn.Module.to``. Returns ``self``.
        """
        self.image_pyramid.to(device)
        self.transition0.to(device)
        self.transition1.to(device)
        self.transition2.to(device)
        super(InSPyReNet, self).to(device)
        return self

    def cuda(self, idx=None):
        """Move the module to CUDA device ``idx`` (the current device when ``None``)."""
        if idx is None:
            idx = torch.cuda.current_device()

        self.to(device="cuda:{}".format(idx))
        return self

    def eval(self):
        """Switch to evaluation mode and (re)bind ``forward`` to inference."""
        super(InSPyReNet, self).train(False)
        self.forward = self.forward_inference
        return self

    def forward_inspyre(self, x):
        """Run one pass of the network on ``x``.

        Returns:
            dict with ``'saliency'`` -> ``[d3, d2, d1, d0]`` (coarse to fine)
            and ``'laplacian'`` -> ``[p2, p1, p0]``.
        """
        _, _, H, W = x.shape

        x1, x2, x3, x4, x5 = self.backbone(x)

        # NOTE(review): the trailing numbers below are kept from the original
        # code and look like per-stage scale factors - TODO confirm.
        x1 = self.context1(x1)  # 4
        x2 = self.context2(x2)  # 4
        x3 = self.context3(x3)  # 8
        x4 = self.context4(x4)  # 16
        x5 = self.context5(x5)  # 32

        f3, d3 = self.decoder([x3, x4, x5])  # 16

        # Each stage receives the previous saliency/residual maps detached,
        # so no gradient flows back through them.
        f3 = self.res(f3, (H // 4, W // 4))
        f2, p2 = self.attention2(torch.cat([x2, f3], dim=1), d3.detach())
        d2 = self.image_pyramid.reconstruct(d3.detach(), p2)  # 4

        x1 = self.res(x1, (H // 2, W // 2))
        f2 = self.res(f2, (H // 2, W // 2))
        f1, p1 = self.attention1(torch.cat([x1, f2], dim=1), d2.detach(), p2.detach())  # 2
        d1 = self.image_pyramid.reconstruct(d2.detach(), p1)  # 2

        f1 = self.res(f1, (H, W))
        _, p0 = self.attention0(f1, d1.detach(), p1.detach())  # 2
        d0 = self.image_pyramid.reconstruct(d1.detach(), p0)  # 2

        out = dict()
        out['saliency'] = [d3, d2, d1, d0]
        out['laplacian'] = [p2, p1, p0]

        return out

    def forward_inference(self, img, img_lr=None):
        """Predict a min-max normalised saliency map for ``img``.

        Args:
            img: Input batch; its last two dimensions are read as height and width.
            img_lr: Optional low-resolution version of ``img``. When the image
                exceeds ``threshold`` in both dimensions and this is ``None``,
                it is derived by bilinearly resizing ``img`` to ``base_size``
                (previously this case raised on ``None``).

        Returns:
            Sigmoid of the finest saliency map, rescaled to span [0, 1].
        """
        _, _, H, W = img.shape

        if self.threshold is None:
            out = self.forward_inspyre(img)
            d3, d2, d1, d0 = out['saliency']

        elif (H <= self.threshold or W <= self.threshold):
            # Small input: a single pass, preferring the low-res image if given.
            if img_lr is not None:
                out = self.forward_inspyre(img_lr)
            else:
                out = self.forward_inspyre(img)
            d3, d2, d1, d0 = out['saliency']

        else:
            if img_lr is None:
                img_lr = self.res(img, tuple(self.base_size))

            # LR Saliency Pyramid
            lr_out = self.forward_inspyre(img_lr)
            lr_d3, lr_d2, lr_d1, lr_d0 = lr_out['saliency']

            # HR Saliency Pyramid
            hr_out = self.forward_inspyre(img)
            hr_d3, hr_d2, hr_d1, hr_d0 = hr_out['saliency']
            hr_p2, hr_p1, hr_p0 = hr_out['laplacian']

            # Pyramid Blending: start from the finest low-res saliency map and
            # add the high-res residuals, each weighted by a transition map
            # computed from the current estimate.
            d3 = self.ret(lr_d0, hr_d3)

            t2 = self.ret(self.transition2(d3), hr_p2)
            p2 = t2 * hr_p2
            d2 = self.image_pyramid.reconstruct(d3, p2)

            t1 = self.ret(self.transition1(d2), hr_p1)
            p1 = t1 * hr_p1
            d1 = self.image_pyramid.reconstruct(d2, p1)

            t0 = self.ret(self.transition0(d1), hr_p0)
            p0 = t0 * hr_p0
            d0 = self.image_pyramid.reconstruct(d1, p0)

        pred = torch.sigmoid(d0)
        # Epsilon keeps the division defined for a constant prediction.
        pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

        return pred
150 |
def InSPyReNet_SwinB(depth, pretrained, base_size, **kwargs):
    """Build an ``InSPyReNet`` on a ``SwinB`` backbone.

    ``pretrained`` is forwarded to ``SwinB``; ``depth``, ``base_size`` and any
    extra keyword arguments are forwarded to ``InSPyReNet``.
    """
    swin_backbone = SwinB(pretrained=pretrained)
    # Channel counts handed to InSPyReNet as ``in_channels``.
    stage_channels = [128, 128, 256, 512, 1024]
    return InSPyReNet(swin_backbone, stage_channels, depth, base_size, **kwargs)
--------------------------------------------------------------------------------
/web/combo_setter.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 |
// Front-end extension for the "ComboSetter" node: turns the "selected" widget
// into a combo box whose options are the non-empty lines of the "labels" widget.
app.registerExtension({
    name: "ComboSetter",
    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeData.name === "ComboSetter") {
            const onNodeCreated = nodeType.prototype.onNodeCreated;

            nodeType.prototype.onNodeCreated = function() {
                const result = onNodeCreated?.apply(this, arguments);

                // Look up the widgets this extension works with.
                const labelsWidget = this.widgets?.find(w => w.name === "labels");
                const promptsWidget = this.widgets?.find(w => w.name === "prompts");
                let selectedWidget = this.widgets?.find(w => w.name === "selected");

                if (!labelsWidget || !promptsWidget || !selectedWidget) {
                    console.error("ComboSetter: 无法找到所需的widgets");
                    return result;
                }

                // Replace the "selected" widget with a combo widget, keeping its position.
                const selectedIndex = this.widgets.indexOf(selectedWidget);
                if (selectedWidget.type !== "combo") {
                    this.widgets.splice(selectedIndex, 1);
                    this.addWidget("combo", "selected", "", () => {}, {
                        values: [""]
                    });
                    // addWidget appends; move the new widget back to the original slot.
                    const newWidget = this.widgets.pop();
                    this.widgets.splice(selectedIndex, 0, newWidget);
                    selectedWidget = newWidget;
                }

                // Rebuild the combo options from the current "labels" text.
                const updateComboOptions = function() {
                    const labelsText = labelsWidget.value || "";

                    // One option per non-blank line.
                    const labelLines = labelsText.split('\n')
                        .map(line => line.trim())
                        .filter(line => line.length > 0);

                    if (labelLines.length === 0) {
                        console.warn("ComboSetter: labels为空");
                        selectedWidget.options.values = [""];
                        selectedWidget.value = "";
                        // Redraw so the cleared selection is shown immediately.
                        this.setDirtyCanvas(true, false);
                        return;
                    }

                    selectedWidget.options.values = labelLines;

                    // Fall back to the first option when the current value is gone.
                    if (!labelLines.includes(selectedWidget.value)) {
                        selectedWidget.value = labelLines[0];
                    }

                    this.setDirtyCanvas(true, false);

                    console.log("ComboSetter: 已更新Combo选项", labelLines);
                };

                // "Set Combo" button: applies the labels to the combo on demand.
                const setComboBtn = this.addWidget("button", "Set Combo", null, () => {
                    updateComboOptions.call(this);
                });

                // The button carries no state worth saving.
                setComboBtn.serialize = false;

                // Size the node from the line counts of the two text widgets.
                const originalComputeSize = this.computeSize;
                this.computeSize = function(out) {
                    let size = originalComputeSize ? originalComputeSize.apply(this, arguments) : [200, 100];

                    const widgetHeight = 40;
                    const buttonHeight = 30;
                    const padding = 20;

                    let totalHeight = padding;

                    // 20px per text line, at least 60px per text widget.
                    if (labelsWidget) {
                        const lines = (labelsWidget.value || "").split('\n').length;
                        totalHeight += Math.max(lines * 20, 60);
                    }

                    if (promptsWidget) {
                        const lines = (promptsWidget.value || "").split('\n').length;
                        totalHeight += Math.max(lines * 20, 60);
                    }

                    // Room for the combo and the button.
                    totalHeight += widgetHeight + buttonHeight + padding;

                    size[1] = totalHeight;
                    size[0] = Math.max(size[0], 300);

                    return size;
                };

                return result;
            };

            // Save the combo's option list with the node.
            const onSerialize = nodeType.prototype.onSerialize;
            nodeType.prototype.onSerialize = function(o) {
                const result = onSerialize?.apply(this, arguments);

                const selectedWidget = this.widgets?.find(w => w.name === "selected");
                if (selectedWidget && selectedWidget.options && selectedWidget.options.values) {
                    o.selected_options = selectedWidget.options.values;
                }

                return result;
            };

            // Restore the combo's option list when the node is configured.
            const onConfigure = nodeType.prototype.onConfigure;
            nodeType.prototype.onConfigure = function(o) {
                const result = onConfigure?.apply(this, arguments);

                // Only accept a real array; anything else would break the combo.
                if (Array.isArray(o.selected_options)) {
                    const selectedWidget = this.widgets?.find(w => w.name === "selected");
                    if (selectedWidget) {
                        selectedWidget.options = selectedWidget.options || {};
                        selectedWidget.options.values = o.selected_options;
                    }
                }

                return result;
            };
        }
    }
});
140 |
141 |
--------------------------------------------------------------------------------
/py/image_size_adjustment.py:
--------------------------------------------------------------------------------
1 | from .md import *
# Per-node rendezvous state shared between ``ImageSizeAdjustment.adjust`` and
# the ``/image_preview/apply`` route: node id -> {"event": Event, "result": tensor or None}.
size_data = {}


class ImageSizeAdjustment:
    """Image preview and stretch-adjustment node."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "adjust"
    CATEGORY = "🎈LAOGOU/Image"
    OUTPUT_NODE = True

    def adjust(self, image, unique_id):
        """Send a preview to the front end and wait for an adjusted image.

        The first image of the batch is PNG-encoded and pushed as an
        ``image_preview_update`` message. The call then waits up to 15 seconds
        for the ``/image_preview/apply`` route to set this node's event.

        Returns:
            A one-element tuple holding the adjusted image, or the unchanged
            input on timeout, on error, or when no result was stored.
        """
        node_id = unique_id
        event = Event()
        # Assigning replaces any stale entry left by an earlier run.
        size_data[node_id] = {
            "event": event,
            "result": None
        }

        try:
            # Encode the first image of the batch as a base64 PNG preview.
            preview_image = (torch.clamp(image.clone(), 0, 1) * 255).cpu().numpy().astype(np.uint8)[0]
            pil_image = Image.fromarray(preview_image)
            buffer = io.BytesIO()
            pil_image.save(buffer, format="PNG")
            base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')

            PromptServer.instance.send_sync("image_preview_update", {
                "node_id": node_id,
                "image_data": f"data:image/png;base64,{base64_image}"
            })

            # Wait for the front end to finish adjusting.
            if not event.wait(timeout=15):
                return (image,)

            result_image = size_data[node_id]["result"]
            return (result_image if result_image is not None else image,)

        except Exception as e:
            # Best effort: report the problem and pass the input through.
            print(f"[ImagePreview] 节点执行出错: {str(e)}")
            traceback.print_exc()
            return (image,)

        finally:
            # Single cleanup point for every exit path.
            size_data.pop(node_id, None)
68 |
@PromptServer.instance.routes.post("/image_preview/apply")
async def apply_image_preview(request):
    """Receive the adjusted image for a waiting ``ImageSizeAdjustment`` node.

    Accepts either ``multipart/form-data`` (fields ``node_id``, ``width``,
    ``height``, ``image_data``) or JSON (``node_id``, ``width``, ``height``,
    ``adjusted_data_base64``). A decoded image is stored as a float tensor
    with a leading batch dimension in the node's ``size_data`` entry, and the
    node's event is set so the waiting ``adjust`` call can continue.

    Returns:
        JSON ``{"success": bool}`` plus an ``error`` string on failure.
    """
    try:
        # Pick the parser from the content type.
        content_type = request.headers.get('Content-Type', '')
        print(f"[ImagePreview] 请求内容类型: {content_type}")

        node_id = None
        new_width = None
        new_height = None
        image_data = None

        if 'multipart/form-data' in content_type:
            reader = await request.multipart()

            # Consume the form fields one by one.
            while True:
                part = await reader.next()
                if part is None:
                    break

                if part.name == 'node_id':
                    node_id = await part.text()
                elif part.name == 'width':
                    new_width = int(await part.text())
                elif part.name == 'height':
                    new_height = int(await part.text())
                elif part.name == 'image_data':
                    # Raw binary image payload.
                    image_data = await part.read(decode=False)
        else:
            data = await request.json()
            node_id = data.get("node_id")
            new_width = data.get("width")
            new_height = data.get("height")

            # Optional base64 image, with or without a data-URL prefix.
            adjusted_data_base64 = data.get("adjusted_data_base64")
            if adjusted_data_base64:
                if adjusted_data_base64.startswith('data:image'):
                    base64_data = adjusted_data_base64.split(',')[1]
                else:
                    base64_data = adjusted_data_base64
                image_data = base64.b64decode(base64_data)

        print(f"[ImagePreview] 接收到数据 - 节点ID: {node_id}")
        print(f"[ImagePreview] 接收到的尺寸: {new_width}x{new_height}")

        # A JSON client may send the id as a number while the registry is
        # keyed by its string form; accept either spelling.
        if node_id not in size_data and str(node_id) in size_data:
            node_id = str(node_id)

        if node_id not in size_data:
            return web.json_response({"success": False, "error": "节点数据不存在"})

        try:
            node_info = size_data[node_id]

            if image_data:
                try:
                    buffer = io.BytesIO(image_data)
                    pil_image = Image.open(buffer)

                    # Normalise every mode (RGBA, palette, grayscale, ...) to
                    # RGB so the array below always has three channels.
                    if pil_image.mode != 'RGB':
                        pil_image = pil_image.convert('RGB')

                    np_image = np.array(pil_image)

                    # Scale to [0, 1] and add the batch dimension: [B, H, W, C].
                    tensor_image = torch.from_numpy(np_image / 255.0).float().unsqueeze(0)
                    print(f"[ImagePreview] 从二进制数据创建的张量形状: {tensor_image.shape}")
                    node_info["result"] = tensor_image
                except Exception as e:
                    # Decoding failed: the node falls back to its input image.
                    print(f"[ImagePreview] 处理图像数据时出错: {str(e)}")
                    traceback.print_exc()

            # Release the waiting node whether or not an image was stored.
            node_info["processed"] = True
            node_info["event"].set()
            return web.json_response({"success": True})

        except Exception as e:
            print(f"[ImagePreview] 处理数据时出错: {str(e)}")
            traceback.print_exc()
            if node_id in size_data and "event" in size_data[node_id]:
                size_data[node_id]["event"].set()
            return web.json_response({"success": False, "error": str(e)})

    except Exception as e:
        print(f"[ImagePreview] 请求处理出错: {str(e)}")
        traceback.print_exc()
        return web.json_response({"success": False, "error": str(e)})
164 |
# Node identifier -> implementing class.
# NOTE(review): presumably read by the host application's node loader - confirm.
NODE_CLASS_MAPPINGS = {
    "ImageSizeAdjustment": ImageSizeAdjustment,
}

# Node identifier -> display label.
NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageSizeAdjustment": "图像尺寸调整",
}
--------------------------------------------------------------------------------
/py/image_cropper.py:
--------------------------------------------------------------------------------
1 |
2 | from .md import *
# Per-node rendezvous state shared between ``ImageCropper.crop`` and the
# ``/image_cropper/*`` routes:
# node id -> {"event": Event, "result": tensor or None, "processing_complete": bool}.
crop_node_data = {}


class ImageCropper:
    """Dedicated image cropping node."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("裁剪图像",)
    FUNCTION = "crop"
    CATEGORY = "🎈LAOGOU/Image"

    def crop(self, image, unique_id):
        """Send a preview to the front end and wait for the cropped image.

        The first image of the batch is PNG-encoded and pushed as an
        ``image_cropper_update`` message. The call then waits up to 30 seconds
        for one of the ``/image_cropper/*`` routes to set this node's event.

        Returns:
            A one-element tuple holding the cropped image, or the unchanged
            input on timeout, on cancel, or on error.
        """
        node_id = unique_id
        event = Event()

        # Register this run so the HTTP routes can find it.
        crop_node_data[node_id] = {
            "event": event,
            "result": None,
            "processing_complete": False
        }

        try:
            # Encode the first image of the batch as a base64 PNG preview.
            preview_image = (torch.clamp(image.clone(), 0, 1) * 255).cpu().numpy().astype(np.uint8)[0]
            pil_image = Image.fromarray(preview_image)
            buffer = io.BytesIO()
            pil_image.save(buffer, format="PNG")
            base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')

            PromptServer.instance.send_sync("image_cropper_update", {
                "node_id": node_id,
                "image_data": f"data:image/png;base64,{base64_image}"
            })

            # Wait for the front end to finish cropping.
            if not event.wait(timeout=30):
                print(f"[ImageCropper] 等待超时: 节点ID {node_id}")
                return (image,)

            result_image = crop_node_data.get(node_id, {}).get("result")
            return (result_image if result_image is not None else image,)

        except Exception as e:
            print(f"[ImageCropper] 节点执行出错: {str(e)}")
            traceback.print_exc()
            return (image,)

        finally:
            # Always drop the entry, including when preview encoding failed
            # (that path used to leave it behind).
            crop_node_data.pop(node_id, None)
75 |
@PromptServer.instance.routes.post("/image_cropper/apply")
async def apply_image_cropper(request):
    """Receive the cropped image for a waiting ``ImageCropper`` node.

    Accepts either ``multipart/form-data`` (fields ``node_id``, ``width``,
    ``height``, ``image_data``) or JSON (``node_id``, ``width``, ``height``,
    ``cropped_data_base64``). When image data is supplied the node's event is
    always set, whether or not the image could be used, so the waiting
    ``crop`` call does not stall until its timeout. A request without image
    data leaves the event untouched, as before.

    Returns:
        JSON ``{"success": bool}`` plus an ``error`` string on failure.
    """
    try:
        # Pick the parser from the content type.
        content_type = request.headers.get('Content-Type', '')
        print(f"[ImageCropper] 请求内容类型: {content_type}")

        node_id = None
        crop_width = None
        crop_height = None
        image_data = None

        if 'multipart/form-data' in content_type:
            reader = await request.multipart()

            # Consume the form fields one by one.
            while True:
                part = await reader.next()
                if part is None:
                    break

                if part.name == 'node_id':
                    node_id = await part.text()
                elif part.name == 'width':
                    crop_width = int(await part.text())
                elif part.name == 'height':
                    crop_height = int(await part.text())
                elif part.name == 'image_data':
                    image_data = await part.read(decode=False)
        else:
            data = await request.json()
            node_id = data.get("node_id")
            crop_width = data.get("width")
            crop_height = data.get("height")

            # Optional base64 image, with or without a data-URL prefix.
            cropped_data_base64 = data.get("cropped_data_base64")
            if cropped_data_base64:
                if cropped_data_base64.startswith('data:image'):
                    base64_data = cropped_data_base64.split(',')[1]
                else:
                    base64_data = cropped_data_base64
                image_data = base64.b64decode(base64_data)

        # A JSON client may send the id as a number while the registry is
        # keyed by its string form; accept either spelling.
        if node_id not in crop_node_data and str(node_id) in crop_node_data:
            node_id = str(node_id)

        # Kept from the original behaviour: an unknown node gets a fresh entry.
        if node_id not in crop_node_data:
            crop_node_data[node_id] = {
                "event": Event(),
                "result": None,
                "processing_complete": False
            }

        try:
            node_info = crop_node_data[node_id]

            if image_data:
                try:
                    buffer = io.BytesIO(image_data)
                    pil_image = Image.open(buffer)

                    # Normalise every mode (RGBA, palette, grayscale, ...) to RGB.
                    if pil_image.mode != 'RGB':
                        pil_image = pil_image.convert('RGB')

                    np_image = np.array(pil_image)

                    if len(np_image.shape) == 3 and np_image.shape[2] == 3:
                        # Scale to [0, 1] and add the batch dimension.
                        tensor_image = torch.from_numpy(np_image / 255.0).float().unsqueeze(0)
                        node_info["result"] = tensor_image
                    else:
                        print(f"[ImageCropper] 警告: 图像数组形状不符合预期: {np_image.shape}")
                except Exception as e:
                    print(f"[ImageCropper] 处理图像数据时出错: {str(e)}")
                    traceback.print_exc()
                finally:
                    # Release the node on success, on an unexpected shape and
                    # on decode errors alike.
                    node_info["event"].set()

            return web.json_response({"success": True})

        except Exception as e:
            print(f"[ImageCropper] 处理数据时出错: {str(e)}")
            traceback.print_exc()
            if node_id in crop_node_data and "event" in crop_node_data[node_id]:
                crop_node_data[node_id]["event"].set()
            return web.json_response({"success": False, "error": str(e)})

    except Exception as e:
        print(f"[ImageCropper] 请求处理出错: {str(e)}")
        traceback.print_exc()
        return web.json_response({"success": False, "error": str(e)})
165 |
@PromptServer.instance.routes.post("/image_cropper/cancel")
async def cancel_crop(request):
    """Release a waiting crop node without supplying a result.

    Reads ``node_id`` from the JSON body and sets that node's event when it is
    registered in ``crop_node_data``.

    Returns:
        JSON ``{"success": bool}`` plus an ``error`` string on failure.
    """
    try:
        payload = await request.json()
        node_id = payload.get("node_id")

        if node_id not in crop_node_data:
            return web.json_response({"success": False, "error": "节点未找到"})

        # Setting the event lets the node continue executing.
        crop_node_data[node_id]["event"].set()
        print(f"[ImageCropper] 取消裁剪操作: 节点ID {node_id}")
        return web.json_response({"success": True})

    except Exception as e:
        print(f"[ImageCropper] 取消请求处理出错: {str(e)}")
        traceback.print_exc()
        return web.json_response({"success": False, "error": str(e)})
184 |
# Node identifier -> implementing class.
# NOTE(review): presumably read by the host application's node loader - confirm.
NODE_CLASS_MAPPINGS = {
    "ImageCropper": ImageCropper,
}

# Node identifier -> display label.
NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageCropper": "图像裁剪",
}
192 |
--------------------------------------------------------------------------------
/web/queue_shortcut.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../scripts/app.js";
2 | import { api } from "../../scripts/api.js";
3 |
/**
 * Minimal publish/subscribe registry keyed by event name.
 * Listeners are called with a single `{ detail }` argument.
 */
class EventManager {
    constructor() {
        // event name -> array of callbacks, in registration order
        this.listeners = new Map();
    }

    /** Register `callback` for `event`. */
    addEventListener(event, callback) {
        if (!this.listeners.has(event)) {
            this.listeners.set(event, []);
        }
        this.listeners.get(event).push(callback);
    }

    /** Remove the first registration of `callback` for `event`, if any. */
    removeEventListener(event, callback) {
        if (this.listeners.has(event)) {
            const callbacks = this.listeners.get(event);
            const index = callbacks.indexOf(callback);
            if (index > -1) {
                callbacks.splice(index, 1);
            }
        }
    }

    /**
     * Call every listener registered for `event` with `{ detail }`.
     * A throwing listener is logged and does not stop the others.
     */
    dispatchEvent(event, detail = {}) {
        if (this.listeners.has(event)) {
            // Iterate over a snapshot: a listener that unregisters itself (or
            // another listener) while running must not cause the remaining
            // listeners to be skipped.
            const callbacks = [...this.listeners.get(event)];
            for (const callback of callbacks) {
                try {
                    callback({ detail });
                } catch (error) {
                    console.error(`Error in event listener for ${event}:`, error);
                }
            }
        }
    }
}
39 |
/**
 * Wraps the app/api queue entry points and the canvas mouse handlers so that
 * a chosen set of output nodes (plus everything they depend on) can be queued,
 * and so the most recent canvas mouse event is remembered.
 */
class QueueManager {
    constructor() {
        this.eventManager = new EventManager();
        this.queueNodeIds = null;
        this.processingQueue = false;
        this.lastAdjustedMouseEvent = null;
        this.isLGTriggered = false; // true only while a queue run was started by this extension
        this.initializeHooks();
    }

    /** Install the wrappers around app, api and LGraphCanvas methods. */
    initializeHooks() {
        const originalQueuePrompt = app.queuePrompt;
        const originalGraphToPrompt = app.graphToPrompt;
        const originalApiQueuePrompt = api.queuePrompt;

        app.queuePrompt = async function() {
            this.processingQueue = true;
            this.eventManager.dispatchEvent("queue");
            try {
                // Hand the original result back to the caller instead of
                // dropping it.
                return await originalQueuePrompt.apply(app, [...arguments]);
            } finally {
                this.processingQueue = false;
                this.eventManager.dispatchEvent("queue-end");
            }
        }.bind(this);

        app.graphToPrompt = async function() {
            this.eventManager.dispatchEvent("graph-to-prompt");
            let promise = originalGraphToPrompt.apply(app, [...arguments]);
            await promise;
            this.eventManager.dispatchEvent("graph-to-prompt-end");
            return promise;
        }.bind(this);

        api.queuePrompt = async function(index, prompt, ...args) {
            // Only trim prompt.output when this extension started the run.
            if (this.isLGTriggered && this.queueNodeIds && this.queueNodeIds.length && prompt.output) {
                const oldOutput = prompt.output;
                let newOutput = {};
                for (const queueNodeId of this.queueNodeIds) {
                    this.recursiveAddNodes(String(queueNodeId), oldOutput, newOutput);
                }
                prompt.output = newOutput;
            }

            this.eventManager.dispatchEvent("comfy-api-queue-prompt-before", {
                workflow: prompt.workflow,
                output: prompt.output,
            });

            // The "end" event is dispatched as soon as the request has been
            // issued, not when it completes.
            const response = originalApiQueuePrompt.apply(api, [index, prompt, ...args]);
            this.eventManager.dispatchEvent("comfy-api-queue-prompt-end");
            return response;
        }.bind(this);

        const originalProcessMouseDown = LGraphCanvas.prototype.processMouseDown;
        const originalAdjustMouseEvent = LGraphCanvas.prototype.adjustMouseEvent;
        const originalProcessMouseMove = LGraphCanvas.prototype.processMouseMove;

        // The canvas wrappers record the event on the module-level
        // `queueManager`, because `this` inside them is the canvas.
        LGraphCanvas.prototype.processMouseDown = function(e) {
            const result = originalProcessMouseDown.apply(this, [...arguments]);
            queueManager.lastAdjustedMouseEvent = e;
            return result;
        };

        LGraphCanvas.prototype.adjustMouseEvent = function(e) {
            originalAdjustMouseEvent.apply(this, [...arguments]);
            queueManager.lastAdjustedMouseEvent = e;
        };

        LGraphCanvas.prototype.processMouseMove = function(e) {
            const result = originalProcessMouseMove.apply(this, [...arguments]);
            // Fill in canvas coordinates when the event carries none.
            if (e && !e.canvasX && !e.canvasY) {
                const canvas = app.canvas;
                const offset = canvas.convertEventToCanvasOffset(e);
                e.canvasX = offset[0];
                e.canvasY = offset[1];
            }
            queueManager.lastAdjustedMouseEvent = e;
            return result;
        };
    }

    /**
     * Copy `nodeId` and, transitively, every node its inputs link to from
     * `oldOutput` into `newOutput`. Ids missing from `oldOutput` are skipped.
     * Returns `newOutput`.
     */
    recursiveAddNodes(nodeId, oldOutput, newOutput) {
        const currentId = nodeId;
        const currentNode = oldOutput[currentId];
        if (currentNode == null) {
            // Not part of the prompt: nothing to copy (reading .inputs of
            // undefined used to throw here).
            return newOutput;
        }
        if (newOutput[currentId] == null) {
            newOutput[currentId] = currentNode;
            for (const inputValue of Object.values(currentNode.inputs || [])) {
                // Array-valued inputs are links; element 0 is the source node id.
                if (Array.isArray(inputValue)) {
                    this.recursiveAddNodes(inputValue[0], oldOutput, newOutput);
                }
            }
        }
        return newOutput;
    }

    /** Queue a run restricted to `nodeIds` and their dependencies. */
    async queueOutputNodes(nodeIds) {
        try {
            this.queueNodeIds = nodeIds;
            this.isLGTriggered = true; // mark the run as started by this extension
            await app.queuePrompt();
        } catch (e) {
            console.error("队列节点时出错:", e);
        } finally {
            this.queueNodeIds = null;
            this.isLGTriggered = false; // clear the marker
        }
    }

    /** Most recent canvas mouse event seen by the wrappers, or null. */
    getLastMouseEvent() {
        return this.lastAdjustedMouseEvent;
    }

    addEventListener(event, callback) {
        this.eventManager.addEventListener(event, callback);
    }

    removeEventListener(event, callback) {
        this.eventManager.removeEventListener(event, callback);
    }
}
157 |
/**
 * Return the nodes that are not in LiteGraph.NEVER mode and whose node
 * definition is flagged `output_node`. A null/undefined list gives [].
 */
function getOutputNodes(nodes) {
    if (nodes == null) {
        return [];
    }
    const isActiveOutput = (node) =>
        node.mode != LiteGraph.NEVER && node.constructor.nodeData?.output_node;
    return nodes.filter(isActiveOutput) || [];
}
// Single shared manager; creating it installs the queue and canvas hooks.
const queueManager = new QueueManager();

/** Queue the output nodes among the nodes currently selected on the canvas. */
function queueSelectedOutputNodes() {
    const selection = app.canvas.selected_nodes;
    const hasSelection = selection && Object.keys(selection).length > 0;
    if (!hasSelection) {
        console.log("[LG]队列: 没有选中的节点");
        return;
    }

    const targets = getOutputNodes(Object.values(selection));
    if (!targets?.length) {
        console.log("[LG]队列: 选中的节点中没有输出节点");
        return;
    }

    console.log(`[LG]队列: 执行 ${targets.length} 个输出节点`);
    const ids = targets.map(({ id }) => id);
    queueManager.queueOutputNodes(ids);
}
181 |
/**
 * Queue the output nodes of the group under the last recorded mouse position.
 * Does nothing when there is no recorded event, no group at that position,
 * or no output node inside the group.
 */
function queueGroupOutputNodes() {
    const mouseEvent = queueManager.getLastMouseEvent();
    if (!mouseEvent) return;

    let x = mouseEvent.canvasX;
    let y = mouseEvent.canvasY;

    // Fall back to the canvas' own mouse position when a coordinate is falsy.
    if (!x || !y) {
        const fallback = app.canvas.getMousePos();
        x = fallback[0];
        y = fallback[1];
    }

    const group = app.graph.getGroupOnPos(x, y);
    if (!group) return;

    // Refresh the group's member list before reading it.
    group.recomputeInsideNodes();

    const members = group._nodes;
    if (!members?.length) return;

    const targets = getOutputNodes(members);
    if (!targets?.length) return;

    queueManager.queueOutputNodes(targets.map(({ id }) => id));
}
217 |
// Register the extension and expose the two queue actions as commands.
app.registerExtension({
    name: "LG.QueueNodes",
    commands: [
        {
            // Runs queueSelectedOutputNodes.
            id: "LG.QueueSelectedOutputNodes",
            icon: "pi pi-play",
            label: "执行选中的输出节点",
            function: queueSelectedOutputNodes
        },
        {
            // Runs queueGroupOutputNodes.
            id: "LG.QueueGroupOutputNodes",
            icon: "pi pi-sitemap",
            label: "执行组内输出节点",
            function: queueGroupOutputNodes
        }
    ]
});

// Exported so other modules can reuse the manager and the queue helpers.
export { queueManager, getOutputNodes, queueSelectedOutputNodes, queueGroupOutputNodes };
237 |
238 |
--------------------------------------------------------------------------------
/web/bridge_preview.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../scripts/app.js";
2 | import { api } from "../../scripts/api.js";
// node id -> { urls, timestamp, node } for nodes waiting on the mask editor.
const waitingNodes = new Map();
// String ids of nodes whose confirmation request succeeded.
const processedNodes = new Set();

/**
 * Point a node's preview at an uploaded file and record it in the node's
 * "file_info" widget. Does nothing when the node or the file info is missing.
 */
function updateNodePreview(nodeId, fileInfo) {
    const node = app.graph.getNodeById(parseInt(nodeId));
    if (!node || !fileInfo) return;

    const type = fileInfo.type || 'input';
    // The subfolder parameter is only included when a subfolder is set.
    let imageUrl = `/view?filename=${encodeURIComponent(fileInfo.filename)}`;
    if (fileInfo.subfolder) {
        imageUrl += `&subfolder=${encodeURIComponent(fileInfo.subfolder)}`;
    }
    imageUrl += `&type=${type}`;

    node.images = [{
        filename: fileInfo.filename,
        subfolder: fileInfo.subfolder || "",
        type: fileInfo.type || "input",
        url: imageUrl
    }];

    if (node.onDrawBackground) {
        // Swap in the new image once it has loaded, then request a redraw.
        const preview = new Image();
        preview.onload = () => {
            node.imgs = [preview];
            app.graph.setDirtyCanvas(true);
        };
        preview.src = imageUrl;
    }

    updateStatusWidget(node, `${nodeId}_${fileInfo.filename}`);
}
/** Write `statusText` into the node's "file_info" widget, if it has one. */
function updateStatusWidget(node, statusText) {
    const widget = (node.widgets ?? undefined)?.find(({ name }) => name === "file_info");
    if (!widget) return;

    widget.value = statusText;
    app.graph.setDirtyCanvas(true);
    app.graph.change();
}
// Handle "bridge_preview_update": remember the node as waiting, load its
// preview images, and set up the clipspace once the last image has loaded.
api.addEventListener("bridge_preview_update", (event) => {
    const { node_id, urls } = event.detail;
    const node = app.graph._nodes_by_id[node_id];
    if (!node || !urls?.length) return;
    // The timestamp is used by the periodic cleanup and by handleMaskUpload.
    waitingNodes.set(node_id, { urls, timestamp: Date.now(), node });
    const imageData = urls.map((url, index) => ({
        index,
        filename: url.filename,
        subfolder: url.subfolder,
        type: url.type
    }));
    node.imageData = imageData;
    node.imgs = [];
    imageData.forEach((imgData, i) => {
        const img = new Image();
        img.onload = () => {
            app.graph.setDirtyCanvas(true);
            // Only the load of the last image in the list triggers the setup;
            // there is no handler for a failed load.
            if (i === imageData.length - 1) {
                setupClipspace(node_id, urls);
            }
        };
        img.src = `/view?filename=${encodeURIComponent(imgData.filename)}&type=${imgData.type}&subfolder=${imgData.subfolder || ''}&${app.getPreviewFormatParam()}`;
        node.imgs.push(img);
    });
    // Optional node methods; called only when present.
    node.setSizeForImage?.();
    node.update?.();
});
/**
 * Fill both clipspace objects with the given images, then (after 100 ms)
 * select the node, try to open the mask editor and hook its cancel button.
 */
function setupClipspace(nodeId, urls) {
    const ComfyApp = app.constructor;
    if (!ComfyApp.clipspace) ComfyApp.clipspace = {};
    if (!app.clipspace) app.clipspace = {};

    const images = urls.map(url => ({
        filename: url.filename,
        subfolder: url.subfolder || "",
        type: url.type || "output"
    }));
    // Build each `src` from the same defaulted values as the entry's own
    // fields, so a missing type no longer produces "type=undefined".
    const imgs = images.map(image => ({
        src: `${window.location.origin}/view?filename=${encodeURIComponent(image.filename)}&type=${image.type}&subfolder=${encodeURIComponent(image.subfolder)}`,
        filename: image.filename,
        subfolder: image.subfolder,
        type: image.type
    }));

    [ComfyApp.clipspace, app.clipspace].forEach(clipspace => {
        clipspace.images = images;
        clipspace.imgs = imgs;
        clipspace.selectedIndex = 0;
    });

    setTimeout(() => {
        const node = app.graph.getNodeById(parseInt(nodeId));
        if (!node) return;

        app.canvas.selectNode(node);

        let success = false;

        // First choice: the command registry, when available.
        try {
            if (app.extensionManager?.command?.execute) {
                app.extensionManager.command.execute('Comfy.MaskEditor.OpenMaskEditor');
                success = true;
            }
        } catch (error) {
            // Silently fail and try next method
        }

        // Fallback: a direct open_maskeditor function on the app class or instance.
        if (!success) {
            try {
                const openMaskEditor = ComfyApp.open_maskeditor || app.open_maskeditor;
                if (openMaskEditor && typeof openMaskEditor === 'function') {
                    openMaskEditor();
                    success = true;
                }
            } catch (error) {
                // Silently fail
            }
        }

        bindCancelButton();
    }, 100);
}
/**
 * Poll every 300 ms (for at most 10 s) until the mask editor shows a cancel
 * button, then attach a click handler that reports the cancellation.
 */
function bindCancelButton() {
    const cancelWords = ['cancel', '取消'];
    const looksLikeCancel = (button) => {
        const label = button.textContent.trim().toLowerCase();
        return cancelWords.some(word => label.includes(word));
    };

    const poller = setInterval(() => {
        const editor = findMaskEditor();
        if (!editor) return;

        const targets = [...editor.querySelectorAll('button')].filter(looksLikeCancel);
        if (targets.length === 0) return;

        for (const button of targets) {
            // The marker attribute prevents binding the same button twice.
            if (button.hasAttribute('data-bridge-bound')) continue;
            button.setAttribute('data-bridge-bound', 'true');
            button.addEventListener('click', () => {
                setTimeout(handleMaskEditorCancel, 50);
            }, { capture: true });
        }
        clearInterval(poller);
    }, 300);

    // Give up after ten seconds.
    setTimeout(() => clearInterval(poller), 10000);
}
/**
 * Locate the mask editor element, trying in order: the `.mask-editor-dialog`
 * element, a visible modal containing a canvas, and finally any element with
 * a canvas plus both a cancel-like and a save-like button. Returns null when
 * nothing matches.
 */
function findMaskEditor() {
    const dialog = document.querySelector('.mask-editor-dialog');
    if (dialog) return dialog;

    const modalSelector = 'div.comfy-modal, .comfy-modal, [class*="modal"]';
    const visibleModal = Array.from(document.querySelectorAll(modalSelector))
        .find(modal => modal.querySelector('canvas') && modal.style.display !== 'none');
    if (visibleModal) return visibleModal;

    // True when at least one label contains at least one of the words.
    const mentions = (labels, ...words) =>
        labels.some(label => words.some(word => label.includes(word)));

    const candidate = Array.from(document.querySelectorAll('*')).find(element => {
        const buttons = element.querySelectorAll('button');
        if (buttons.length < 2 || !element.querySelector('canvas')) return false;
        const labels = Array.from(buttons, btn => btn.textContent.trim().toLowerCase());
        return mentions(labels, 'cancel', '取消') && mentions(labels, 'save', '保存');
    });

    return candidate ?? null;
}
/** Send a cancel signal for every waiting node, then reset both registries. */
function handleMaskEditorCancel() {
    for (const nodeId of waitingNodes.keys()) {
        // Started without awaiting; sendCancelSignal logs its own failures.
        sendCancelSignal(nodeId);
    }

    waitingNodes.clear();
    processedNodes.clear();
}
/** POST the node id to /bridge_preview/cancel; failures are only logged. */
async function sendCancelSignal(nodeId) {
    try {
        const request = {
            method: "POST",
            headers: { "Content-Type": "application/json" },
            body: JSON.stringify({ node_id: String(nodeId) })
        };
        await api.fetchApi("/bridge_preview/cancel", request);
    } catch (err) {
        console.error(`[BridgePreview] Failed to send cancel signal:`, err);
    }
}
// Keep a handle on the unpatched implementation; the wrapper below and
// handleMaskUpload both call through it.
const originalFetch = api.fetchApi;

// Wrap api.fetchApi so that every successful "/upload/mask" request is also
// passed to handleMaskUpload before the response is returned to the caller.
api.fetchApi = async function(url, options) {
    const response = await originalFetch.call(this, url, options);
    const isMaskUpload = url === "/upload/mask";
    if (isMaskUpload && response.ok) {
        await handleMaskUpload(response);
    }
    return response;
};
/**
 * Forward the outcome of a mask upload to the most recently registered
 * waiting node.
 *
 * The response body is read from a clone, so the caller can still consume
 * the original. If the body carries a `name`, the waiting node with the
 * largest timestamp gets its preview updated and a confirmation is POSTed to
 * `/bridge_preview/confirm`. On a successful confirmation the node is marked
 * as processed and removed from `waitingNodes` one second later.
 *
 * All errors are logged and swallowed.
 *
 * @param {Response} result Response of the "/upload/mask" request.
 */
async function handleMaskUpload(result) {
    try {
        const payload = await result.clone().json();
        if (!payload?.name) return;

        const fileInfo = {
            filename: payload.name,
            subfolder: payload.subfolder || "clipspace",
            type: payload.type || "input"
        };

        // Pick the waiting node with the newest timestamp.
        let latestNodeId = null;
        let newestStamp = 0;
        waitingNodes.forEach((info, candidateId) => {
            if (info.timestamp > newestStamp) {
                newestStamp = info.timestamp;
                latestNodeId = candidateId;
            }
        });
        if (!latestNodeId) return;

        try {
            updateNodePreview(latestNodeId, fileInfo);

            const confirmBody = JSON.stringify({
                node_id: String(latestNodeId),
                file_info: fileInfo
            });
            // Use the unpatched fetch so the wrapper is not re-entered.
            const confirmResponse = await originalFetch.call(api, "/bridge_preview/confirm", {
                method: "POST",
                headers: { "Content-Type": "application/json" },
                body: confirmBody
            });

            if (confirmResponse.ok) {
                processedNodes.add(String(latestNodeId));
                setTimeout(() => {
                    waitingNodes.delete(latestNodeId);
                }, 1000);
            }
        } catch (err) {
            console.error(`[BridgePreview] Failed to process node ${latestNodeId}:`, err);
        }
    } catch (err) {
        console.error(`[BridgePreview] Failed to handle mask result:`, err);
    }
}
// Every 5 seconds, drop waiting nodes whose timestamp is more than
// 30 seconds old, together with their "processed" marker.
setInterval(() => {
    const cutoff = Date.now() - 30000;
    waitingNodes.forEach((info, staleId) => {
        if (info.timestamp < cutoff) {
            waitingNodes.delete(staleId);
            processedNodes.delete(String(staleId));
        }
    });
}, 5000);

console.log("[BridgePreview] Bridge preview module loaded");
--------------------------------------------------------------------------------
/py/inspyrenet/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import cv2
4 | import yaml
5 | import torch
6 | import hashlib
7 | import argparse
8 |
9 | import albumentations as A
10 | from albumentations.core.transforms_interface import ImageOnlyTransform
11 |
12 | import numpy as np
13 |
14 | from PIL import Image
15 | from threading import Thread
16 | from easydict import EasyDict
17 |
def parse_args():
    """Parse the command-line options of the background remover.

    Returns:
        argparse.Namespace with the attributes ``source``, ``dest``, ``type``,
        ``fast``, ``jit``, ``device``, ``mode``, ``ckpt`` and ``threshold``.
    """
    # (flags, keyword arguments) for every option, in declaration order.
    option_table = [
        (('--source', '-s'), dict(type=str, help="Path to the source. Single image, video, directory of images, directory of videos is supported.")),
        (('--dest', '-d'), dict(type=str, default=None, help="Path to destination. Results will be stored in current directory if not specified.")),
        (('--type', '-t'), dict(type=str, default='rgba', help="Specify output type. If not specified, output results will make the background transparent. Please refer to the documentation for other types.")),
        (('--fast', '-f'), dict(action='store_true', help="(Deprecated) Speed up inference speed by using small scale, but decreases output quality.")),
        (('--jit', '-j'), dict(action='store_true', help="Speed up inference speed by using torchscript, but decreases output quality.")),
        (('--device', '-D'), dict(type=str, default=None, help="Designate device. If not specified, it will find available device.")),
        (('--mode', '-m'), dict(type=str, default='base', help="choose between base and fast mode. Also, use base-nightly for nightly release checkpoint.")),
        (('--ckpt', '-c'), dict(type=str, default=None, help="Designate checkpoint. If not specified, it will download or load pre-downloaded default checkpoint.")),
        (('--threshold', '-th'), dict(type=str, default=None, help="Designate threshold. If specified, it will output hard prediction above threshold. If not specified, it will output soft prediction.")),
    ]

    parser = argparse.ArgumentParser()
    for flags, options in option_table:
        parser.add_argument(*flags, **options)
    return parser.parse_args()
30 |
def get_backend():
    """Return the preferred torch device string for this machine.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        ``"cuda:0"``, ``"mps:0"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda:0"

    # torch builds without the MPS backend have no ``torch.backends.mps``
    # attribute; treat that as "MPS not available" instead of raising.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps:0"

    return "cpu"
38 |
def load_config(config_dir, easy=True):
    """Load a YAML configuration file.

    Args:
        config_dir: Path of the YAML file to read.
        easy: When exactly ``True``, wrap the parsed mapping in an
            ``EasyDict``; otherwise return the object produced by the YAML
            loader unchanged.

    Returns:
        The parsed configuration.
    """
    # Use a context manager so the file handle is closed deterministically.
    # NOTE(review): FullLoader is meant for trusted config files only.
    with open(config_dir) as config_file:
        cfg = yaml.load(config_file, yaml.FullLoader)
    if easy is True:
        cfg = EasyDict(cfg)
    return cfg
44 |
def get_format(source):
    """Classify a collection of file names as images or videos.

    Args:
        source: Iterable of file names; matching is case-insensitive on the
            extension.

    Returns:
        ``'Image'`` when only image extensions are present, ``'Video'`` when
        only video extensions are present, and ``''`` when both kinds are
        mixed or neither kind is found.
    """
    image_exts = ('.jpg', '.png', '.jpeg')
    video_exts = ('.mp4', '.avi', '.mov')

    image_hits = sum(1 for entry in source if entry.lower().endswith(image_exts))
    video_hits = sum(1 for entry in source if entry.lower().endswith(video_exts))

    if image_hits and video_hits:
        return ''
    if image_hits:
        return 'Image'
    if video_hits:
        return 'Video'
    return ''
57 |
def sort(x):
    """Return the strings of *x* in natural (human) order.

    Runs of ASCII digits are compared as integers and everything else is
    compared case-insensitively, so ``"a2"`` sorts before ``"a10"``.

    Args:
        x: Iterable of strings.

    Returns:
        A new sorted list.
    """
    def natural_key(item):
        # The capturing group keeps the digit runs in the split result.
        pieces = re.split('([0-9]+)', item)
        return [int(piece) if piece.isdigit() else piece.lower() for piece in pieces]

    return sorted(x, key=natural_key)
62 |
def download_and_unzip(filename, url, dest, unzip=True, **kwargs):
    """Fetch *url* into ``dest/filename`` and optionally unpack it.

    The file is downloaded with ``wget`` when it is missing, or when an
    ``md5`` keyword is given and does not match the digest of the existing
    file. When *unzip* is true the archive is extracted into *dest* with
    ``unzip -o`` and then deleted.

    Args:
        filename: Name of the file inside *dest*.
        url: Download location.
        dest: Target directory; created if necessary.
        unzip: Whether to extract and remove the downloaded archive.
        **kwargs: Optional ``md5`` hex digest the existing file must match.

    Raises:
        OSError: If the ``wget`` or ``unzip`` executable cannot be started.
    """
    # Local import keeps this block self-contained.
    import subprocess

    os.makedirs(dest, exist_ok=True)
    target = os.path.join(dest, filename)

    if not os.path.isfile(target):
        needs_download = True
    elif 'md5' in kwargs:
        with open(target, 'rb') as existing:
            needs_download = kwargs['md5'] != hashlib.md5(existing.read()).hexdigest()
    else:
        needs_download = False

    # Argument lists (no shell) keep paths/URLs containing spaces or shell
    # metacharacters intact. Non-zero exit codes are ignored, as before.
    if needs_download:
        subprocess.run(["wget", "-O", target, url], check=False)

    if unzip:
        subprocess.run(["unzip", "-o", target, "-d", dest], check=False)
        if os.path.isfile(target):
            os.remove(target)
75 |
class dynamic_resize:
    """Resize an image so its shorter side does not exceed ``L``.

    After the optional down-scaling, both sides are rounded to the nearest
    multiple of 32 and the image is resized with bilinear interpolation.
    """

    def __init__(self, L=1280):
        # Upper bound for the shorter image side.
        self.L = L

    def __call__(self, img):
        """Return a resized copy of *img* (uses ``img.size`` / ``img.resize``)."""
        first, second = img.size
        limit = self.L

        # Only shrink when the shorter side is larger than the limit; the
        # other side is scaled by the same factor.
        if first >= second and second > limit:
            first = first / (second / limit)
            second = limit
        elif second > first and first > limit:
            second = second / (first / limit)
            first = limit

        target = (int(round(first / 32)) * 32, int(round(second / 32)) * 32)
        return img.resize(target, Image.BILINEAR)
91 |
class dynamic_resize_a(ImageOnlyTransform):
    """Albumentations transform version of ``dynamic_resize`` for arrays.

    Shrinks the image so that its shorter side is at most ``L`` and rounds
    both sides to the nearest multiple of 32.
    """

    def __init__(self, L=1280, always_apply=False, p=1.0):
        super().__init__(always_apply, p)
        # Upper bound for the shorter image side.
        self.L = L

    def apply(self, img, **params):
        """Resize *img*; the first two entries of ``img.shape`` are height, width."""
        rows, cols = img.shape[:2]
        limit = self.L

        if rows >= cols and cols > limit:
            rows = rows / (cols / limit)
            cols = limit
        elif cols > rows and rows > limit:
            cols = cols / (rows / limit)
            rows = limit

        new_rows = int(round(rows / 32)) * 32
        new_cols = int(round(cols / 32)) * 32
        return A.resize(img, height=new_rows, width=new_cols)

    def get_transform_init_args_names(self):
        """Names of the constructor arguments to serialize with the transform."""
        return ("L",)
111 |
class static_resize:
    """Resize an image to a fixed size with bilinear interpolation."""

    def __init__(self, size=None):
        """Store the target size.

        Args:
            size: Target size passed to ``img.resize``. Defaults to a fresh
                ``[1024, 1024]`` list per instance, so mutating one
                instance's ``size`` never affects another (the previous
                list default was shared between all instances).
        """
        self.size = [1024, 1024] if size is None else size

    def __call__(self, img):
        """Return *img* resized to ``self.size``."""
        return img.resize(self.size, Image.BILINEAR)
118 |
class normalize:
    """Scale and standardize an array in place: ``(img / div - mean) / std``."""

    def __init__(self, mean=None, std=None, div=255):
        # Missing statistics fall back to the identity transform.
        if mean is None:
            mean = 0.0
        if std is None:
            std = 1.0
        self.mean, self.std, self.div = mean, std, div

    def __call__(self, img):
        """Normalize *img* with in-place operators and return the same object."""
        img /= self.div
        img -= self.mean
        img /= self.std
        return img
131 |
class tonumpy:
    """Callable that converts its input to a ``float32`` NumPy array."""

    def __call__(self, img):
        """Return ``np.array(img)`` with dtype ``float32``."""
        return np.array(img, dtype=np.float32)
139 |
class totensor:
    """Callable that turns a 3-axis NumPy array into a float torch tensor."""

    def __call__(self, img):
        """Move the last axis to the front and return a ``float`` tensor."""
        channel_first = img.transpose((2, 0, 1))
        return torch.from_numpy(channel_first).float()
149 |
class ImageLoader:
    """Iterate over the images of a directory, or over a single image file.

    Each iteration step yields ``(PIL image converted to RGB, file name)``.
    """

    def __init__(self, root):
        """Collect the image paths to iterate over.

        Args:
            root: Directory containing ``.jpg``/``.png``/``.jpeg`` files
                (visited in natural sort order), or the path of one file.
                A path that is neither a directory nor a file yields an
                empty loader instead of an ``AttributeError`` on
                ``self.images``.
        """
        if os.path.isdir(root):
            found = [
                os.path.join(root, f)
                for f in os.listdir(root)
                if f.lower().endswith(('.jpg', '.png', '.jpeg'))
            ]
            self.images = sort(found)
        elif os.path.isfile(root):
            self.images = [root]
        else:
            self.images = []
        self.size = len(self.images)
        # Defined up front so ``next(loader)`` works without ``iter(loader)``.
        self.index = 0

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index == self.size:
            raise StopIteration

        path = self.images[self.index]
        img = Image.open(path).convert('RGB')
        # The name keeps its extension.
        name = os.path.split(path)[-1]

        self.index += 1
        return img, name

    def __len__(self):
        return self.size
176 |
class VideoLoader:
    """Iterate frame by frame over the videos of a directory or a single file.

    Each step yields ``(frame, file name)`` where ``frame`` is an RGB PIL
    image. When a video runs out of frames one ``(None, file name)`` item is
    yielded for it before the loader moves on to the next video.
    """

    def __init__(self, root):
        """Collect the video paths to iterate over.

        Args:
            root: Directory containing ``.mp4``/``.avi``/``.mov`` files, or
                the path of one file. A path that is neither yields an empty
                loader instead of an ``AttributeError`` on ``self.videos``.
        """
        if os.path.isdir(root):
            # '.mov' needs its dot, otherwise any name merely ending in
            # "mov" (e.g. "notes.xmov") would be picked up.
            self.videos = [
                os.path.join(root, f)
                for f in os.listdir(root)
                if f.lower().endswith(('.mp4', '.avi', '.mov'))
            ]
        elif os.path.isfile(root):
            self.videos = [root]
        else:
            self.videos = []
        self.size = len(self.videos)
        # Iteration state, defined up front so attribute access is safe.
        self.index = 0
        self.cap = None
        self.fps = None

    def __iter__(self):
        # Release a capture left open by an unfinished previous iteration.
        if self.cap is not None:
            self.cap.release()
        self.index = 0
        self.cap = None
        self.fps = None
        return self

    def __next__(self):
        if self.index == self.size:
            raise StopIteration

        # Lazily open the current video on its first frame request.
        if self.cap is None:
            self.cap = cv2.VideoCapture(self.videos[self.index])
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)

        ret, frame = self.cap.read()
        # The name keeps its extension.
        name = os.path.split(self.videos[self.index])[-1]

        if ret is False:
            # End of this video: signal it with a None frame and advance.
            self.cap.release()
            self.cap = None
            img = None
            self.index += 1
        else:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(frame).convert('RGB')

        return img, name

    def __len__(self):
        # Number of videos, not number of frames.
        return self.size
215 |
class WebcamLoader:
    """Endless iterator over the most recent frame of a webcam.

    A daemon thread keeps reading frames into ``self.imgs``; each iteration
    step returns the newest one as ``(RGB PIL image, None)``. Iteration stops
    when the reader thread has ended or ``q`` is pressed in an OpenCV window.
    """

    def __init__(self, ID):
        """Open camera *ID*, request 640x480 frames and start the reader thread."""
        self.ID = int(ID)
        self.cap = cv2.VideoCapture(self.ID)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        # Seed with the first frame; this is None when the first read fails.
        self.imgs = []
        self.imgs.append(self.cap.read()[1])
        self.thread = Thread(target=self.update, daemon=True)
        self.thread.start()

    def update(self):
        """Reader loop: append frames until the capture closes or a read fails."""
        while self.cap.isOpened():
            ret, frame = self.cap.read()
            if ret is True:
                self.imgs.append(frame)
            else:
                break

    def __iter__(self):
        return self

    def __next__(self):
        frame = self.imgs[-1] if len(self.imgs) > 0 else None

        if self.thread.is_alive() is False or cv2.waitKey(1) == ord('q'):
            cv2.destroyAllWindows()
            raise StopIteration

        # No usable frame (empty buffer or a failed first read): fall back
        # to a black BGR array. cv2.cvtColor needs an ndarray; the previous
        # PIL fallback and a stored None both made it raise.
        if frame is None:
            frame = np.zeros((480, 640, 3), dtype=np.uint8)

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = Image.fromarray(frame).convert('RGB')

        # Drop everything but the newest frame so the buffer stays small.
        del self.imgs[:-1]
        return frame, None

    def __len__(self):
        return 0
257 |
--------------------------------------------------------------------------------
/py/image_selector.py:
--------------------------------------------------------------------------------
1 | from .md import *
2 | from threading import Event
3 | import comfy.model_management
4 |
class ImageSelectorCancelled(Exception):
    """Raised inside the selector node when the user cancels the selection."""
7 |
def get_selector_storage():
    """Return the dict shared by all image-selector nodes.

    The dict is stored on ``PromptServer.instance`` under the attribute
    ``_selector_node_data`` and is created on first use.
    """
    server = PromptServer.instance
    try:
        return server._selector_node_data
    except AttributeError:
        server._selector_node_data = {}
        return server._selector_node_data
13 |
class ImageSelector(PreviewImage):
    """Output node that pauses the workflow until the user picks images.

    Incoming images are split into single-image batches, saved through
    ``save_images`` for preview and announced to the front end with an
    ``image_selector_update`` event. Depending on ``mode`` the node then

    * ``passthrough``: returns every image immediately;
    * ``keep_last_selection``: re-uses the previously stored selection when
      one exists and is still valid, otherwise pauses;
    * ``always_pause``: waits until the HTTP handler stores a selection or a
      cancel flag in the shared storage entry of this node.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "mode": (["always_pause", "keep_last_selection", "passthrough"], {"default": "always_pause"}),
            },
            "hidden": {
                "prompt": "PROMPT",
                "unique_id": "UNIQUE_ID",
                "extra_pnginfo": "EXTRA_PNGINFO"
            }
        }

    RETURN_TYPES = ("IMAGE", "STRING")
    RETURN_NAMES = ("selected_images", "selected_indices")
    FUNCTION = "select_image"
    CATEGORY = "🎈LAOGOU/Image"
    OUTPUT_NODE = True
    OUTPUT_IS_LIST = (True, False)
    INPUT_IS_LIST = True

    # Keys that belong to one pause only; "last_selection" is deliberately
    # absent so it survives cleanup_session_data().
    _SESSION_KEYS = ("event", "selected_indices", "images", "total_count", "cancelled")

    @classmethod
    def IS_CHANGED(cls, images, **kwargs):
        # A fresh timestamp on every call.
        return float(time.time())

    @staticmethod
    def _as_single_images(tensor):
        """Split *tensor* into a list of single-image (4-D) batches.

        Returns ``None`` when the tensor is neither 3-D nor 4-D.
        """
        rank = len(tensor.shape)
        if rank == 4:
            return [tensor[i:i + 1] for i in range(tensor.shape[0])]
        if rank == 3:
            return [tensor.unsqueeze(0)]
        return None

    @staticmethod
    def _fallback_result(image_list):
        """Result used when no valid selection exists: the first image, if any."""
        if len(image_list) > 0:
            return {"result": ([image_list[0]], "0")}
        return {"result": ([], "")}

    def _collect_images(self, images):
        """Flatten the node input into a list of single-image batches.

        Raises:
            ValueError: For a bare tensor of unsupported rank or an
                unsupported input type. Inside a list, unsupported entries
                are skipped silently.
        """
        image_list = []
        if isinstance(images, list):
            for img in images:
                if isinstance(img, torch.Tensor):
                    image_list.extend(self._as_single_images(img) or [])
        elif isinstance(images, torch.Tensor):
            pieces = self._as_single_images(images)
            if pieces is None:
                raise ValueError(f"不支持的图像维度: {images.shape}")
            image_list.extend(pieces)
        else:
            raise ValueError(f"不支持的输入类型: {type(images)}")
        return image_list

    def select_image(self, images, mode, prompt=None, unique_id=None, extra_pnginfo=None):
        """Run the node; see the class docstring for the three modes.

        Returns:
            ``{"result": (list of selected single-image batches,
            comma-separated index string)}``.

        Raises:
            comfy.model_management.InterruptProcessingException: When the
                user cancels the selection.
        """
        # Bound before the try block so the error handler can always use them.
        node_id = None
        image_list = []
        try:
            node_id = str(unique_id[0]) if isinstance(unique_id, list) else str(unique_id)
            actual_mode = mode[0] if isinstance(mode, list) else mode

            # Shared storage used to talk to the HTTP handler.
            node_data = get_selector_storage()

            image_list.extend(self._collect_images(images))

            # Save previews one by one; a failing image is skipped
            # (best effort).
            preview_images = []
            for img in image_list:
                try:
                    result = self.save_images(images=img, prompt=prompt)
                    if 'ui' in result and 'images' in result['ui']:
                        preview_images.extend(result['ui']['images'])
                except Exception:
                    continue

            try:
                PromptServer.instance.send_sync("image_selector_update", {
                    "id": node_id,
                    "urls": preview_images
                })
            except Exception:
                pass

            if actual_mode == "passthrough":
                self.cleanup_session_data(node_id)
                all_indices = ','.join(str(i) for i in range(len(image_list)))
                return {"result": (image_list, all_indices)}

            if actual_mode == "keep_last_selection":
                last_selection = node_data.get(node_id, {}).get("last_selection")
                if last_selection:
                    valid_indices = [idx for idx in last_selection if 0 <= idx < len(image_list)]
                    if valid_indices:
                        try:
                            PromptServer.instance.send_sync("image_selector_selection", {
                                "id": node_id,
                                "selected_indices": valid_indices
                            })
                        except Exception:
                            pass
                        self.cleanup_session_data(node_id)
                        indices_str = ','.join(str(i) for i in valid_indices)
                        return {"result": ([image_list[idx] for idx in valid_indices], indices_str)}

            # Start a fresh session for this pause.
            if node_id in node_data:
                del node_data[node_id]

            event = Event()
            node_data[node_id] = {
                "event": event,
                "selected_indices": None,
                "images": image_list,
                "total_count": len(image_list),
                "cancelled": False
            }

            while node_id in node_data:
                node_info = node_data[node_id]
                if node_info.get("cancelled", False):
                    self.cleanup_session_data(node_id)
                    raise ImageSelectorCancelled("用户取消选择")

                if node_info.get("selected_indices") is not None:
                    break

                # Wake up as soon as the handler signals, but still re-check
                # every 100 ms. Clearing keeps a signal that carried no
                # decision from turning this loop into a busy spin.
                event.wait(0.1)
                event.clear()

            if node_id not in node_data:
                return self._fallback_result(image_list)

            selected_indices = node_data[node_id].get("selected_indices")
            valid_indices = [idx for idx in (selected_indices or []) if 0 <= idx < len(image_list)]
            if not valid_indices:
                self.cleanup_session_data(node_id)
                return self._fallback_result(image_list)

            selected_images = [image_list[idx] for idx in valid_indices]
            # Remember the choice for the "keep_last_selection" mode.
            node_data.setdefault(node_id, {})["last_selection"] = valid_indices

            self.cleanup_session_data(node_id)
            indices_str = ','.join(str(i) for i in valid_indices)
            return {"result": (selected_images, indices_str)}

        except ImageSelectorCancelled:
            raise comfy.model_management.InterruptProcessingException()
        except Exception as e:
            # Keep the workflow running with a fallback result, but do not
            # hide the failure completely.
            print(f"[ImageSelector] select_image failed, using fallback result: {e}")
            if node_id is not None and node_id in get_selector_storage():
                self.cleanup_session_data(node_id)
            return self._fallback_result(image_list)

    def cleanup_session_data(self, node_id):
        """Remove the per-pause keys of *node_id*; ``last_selection`` is kept."""
        node_data = get_selector_storage()
        if node_id in node_data:
            for key in self._SESSION_KEYS:
                node_data[node_id].pop(key, None)
175 |
@PromptServer.instance.routes.post("/image_selector/select")
async def select_image_handler(request):
    """HTTP endpoint through which the front end answers a paused selector.

    Expected JSON body: ``node_id``, ``action`` (``"select"`` or
    ``"cancel"``) and, for ``"select"``, ``selected_indices`` (list of ints).
    The decision is written into the node's shared storage entry and its
    event is set.

    Returns:
        JSON ``{"success": True}`` or ``{"success": False, "error": ...}``.
    """
    try:
        data = await request.json()
        node_id = data.get("node_id")
        # Storage keys are strings (the node uses str(unique_id)), so a
        # numeric id sent by a client must be normalised before lookup.
        if node_id is not None:
            node_id = str(node_id)
        selected_indices = data.get("selected_indices", [])
        action = data.get("action")

        # Shared storage used to talk to the waiting node.
        node_data = get_selector_storage()

        if node_id not in node_data:
            return web.json_response({"success": False, "error": "节点数据不存在"})

        try:
            node_info = node_data[node_id]

            # Without "total_count" there is no active session for the node.
            if "total_count" not in node_info:
                return web.json_response({"success": False, "error": "节点已完成处理"})

            if action == "cancel":
                node_info["cancelled"] = True
                node_info["selected_indices"] = []
            elif action == "select" and isinstance(selected_indices, list):
                total = node_info["total_count"]
                # bool is a subclass of int; reject JSON true/false explicitly.
                valid_indices = [
                    idx for idx in selected_indices
                    if isinstance(idx, int) and not isinstance(idx, bool) and 0 <= idx < total
                ]
                if not valid_indices:
                    return web.json_response({"success": False, "error": "选择索引无效"})
                node_info["selected_indices"] = valid_indices
                node_info["cancelled"] = False
            else:
                return web.json_response({"success": False, "error": "无效操作"})

            node_info["event"].set()
            return web.json_response({"success": True})

        except Exception:
            # Still signal the event if one is registered for the node.
            if node_id in node_data and "event" in node_data[node_id]:
                node_data[node_id]["event"].set()
            return web.json_response({"success": False, "error": "处理失败"})

    except Exception:
        return web.json_response({"success": False, "error": "请求失败"})
219 |
# Node type name -> implementing class.
NODE_CLASS_MAPPINGS = {"ImageSelector": ImageSelector}

# Node type name -> display name.
NODE_DISPLAY_NAME_MAPPINGS = {"ImageSelector": "🎈LG_图像选择器"}
--------------------------------------------------------------------------------
/web/lg_group_muter.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { ComfyWidgets } from "../../../scripts/widgets.js";
3 |
// Values written to `node.mode` by the group muter.
const MODE_ALWAYS = 0,
    MODE_MUTE = 2,
    MODE_BYPASS = 4;
7 |
/**
 * Rebuild `group._nodes` with every graph node whose bounding box overlaps
 * the group's bounding box. Does nothing when the group has no graph.
 *
 * @param {object} group Group exposing `graph` and `_bounding`.
 */
function recomputeInsideNodesForGroup(group) {
    const graph = group.graph;
    if (!graph) return;

    group._nodes = [];

    // First pass: record every node's bounding box, keyed by node id.
    const boundsById = {};
    graph._nodes.forEach(candidate => {
        boundsById[candidate.id] = candidate.getBounding();
    });

    // Second pass: keep the nodes that overlap the group.
    graph._nodes.forEach(candidate => {
        const bounds = boundsById[candidate.id];
        if (bounds && LiteGraph.overlapBounding(group._bounding, bounds)) {
            group._nodes.push(candidate);
        }
    });
}
22 |
/**
 * Virtual node that shows one boolean toggle per graph group and switches
 * all nodes inside a group between "always" and "mute"/"bypass".
 */
class LG_GroupMuterNode extends LGraphNode {
    constructor(title) {
        super(title);
        this.isVirtualNode = true;
        this.serialize_widgets = true;
        this.properties = { mode: "mute", maxOne: false };
        // group title -> toggle widget
        this.groupWidgetMap = new Map();
        this.tempSize = null;
        this.debouncerTempWidth = 0;
        this._lastGroupSignature = "";
    }

    onAdded() {
        // Deferred by 10 ms before the first widget build.
        setTimeout(() => this.setupGroups(), 10);
    }

    /**
     * Synchronise widgets and inputs with the graph's current groups:
     * create toggles for new groups, reorder toggles alphabetically by
     * group title and remove toggles of groups that no longer exist.
     */
    setupGroups() {
        if (!app.graph?._groups) return;

        // Sort a copy: Array.prototype.sort works in place, and reordering
        // the graph's own group array is not this node's business.
        const groups = [...app.graph._groups].sort((a, b) => a.title.localeCompare(b.title));
        const existingGroups = new Set();

        groups.forEach((group, index) => {
            const name = group.title;
            existingGroups.add(name);

            let widget = this.groupWidgetMap.get(name) || this.widgets?.find(w => w.name === name);

            if (!widget) {
                this.tempSize = this.size ? [...this.size] : [300, 60];

                // Use the input that carries this group's name; only a newly
                // added input is guaranteed to be the last one.
                let input = this.inputs?.find(i => i.name === name);
                if (!input) {
                    this.addInput(name, "BOOLEAN");
                    input = this.inputs[this.inputs.length - 1];
                }

                const widgetResult = ComfyWidgets?.BOOLEAN(this, name, ["BOOLEAN", { default: true }], app);
                widget = widgetResult?.widget;

                if (!widget) return;

                // Ensure value is explicitly boolean
                if (widget.value !== true && widget.value !== false) {
                    widget.value = true;
                }

                input.widget = widget;

                // Ensure config symbols are set for PS_Config compatibility
                widget[Symbol.for("GET_CONFIG")] = () => ["BOOLEAN", { default: true }];
                widget[Symbol.for("CONFIG")] = ["BOOLEAN", { default: true }];

                const originalCallback = widget.callback;
                widget.callback = (value) => {
                    // Max One mode: enabling one group disables all others.
                    if (this.properties.maxOne && value === true) {
                        for (const [otherName, otherWidget] of this.groupWidgetMap) {
                            if (otherName !== name && otherWidget.value === true) {
                                otherWidget.value = false;
                                this.applyGroupMode(otherName, false);
                                const otherGroup = app.graph?._groups.find(g => g.title === otherName);
                                if (otherGroup) otherGroup.lgtools_isActive = false;
                            }
                        }
                    }

                    this.applyGroupMode(name, value);
                    group.lgtools_isActive = value;
                    originalCallback?.call(widget, value);
                };

                this.groupWidgetMap.set(name, widget);
                group.lgtools_isActive = widget.value;
                this.setSize(this.computeSize());
            } else {
                this.groupWidgetMap.set(name, widget);
                const input = this.inputs?.find(i => i.name === name);
                if (input && !input.widget) input.widget = widget;

                // Ensure config symbols exist for existing widgets
                if (!widget[Symbol.for("GET_CONFIG")]) {
                    widget[Symbol.for("GET_CONFIG")] = () => ["BOOLEAN", { default: true }];
                    widget[Symbol.for("CONFIG")] = ["BOOLEAN", { default: true }];
                }

                if (group.lgtools_isActive != null && widget.value !== group.lgtools_isActive) {
                    widget.value = group.lgtools_isActive;
                }
            }

            // Move the widget to the slot matching the sorted group order.
            if (this.widgets?.[index] !== widget) {
                const oldIndex = this.widgets.indexOf(widget);
                if (oldIndex !== -1) {
                    this.widgets.splice(index, 0, this.widgets.splice(oldIndex, 1)[0]);
                }
            }
        });

        const toRemove = [...this.groupWidgetMap.keys()].filter(name => !existingGroups.has(name));
        if (toRemove.length) {
            // Only keep real indices: splice(-1, 1) would delete the last
            // widget and removeInput(undefined) is meaningless.
            const widgetIndices = toRemove
                .map(name => this.widgets?.indexOf(this.groupWidgetMap.get(name)) ?? -1)
                .filter(i => i >= 0);
            const inputIndices = toRemove
                .map(name => this.inputs?.findIndex(i => i.name === name) ?? -1)
                .filter(i => i >= 0);

            // Remove from the back so earlier indices stay valid.
            widgetIndices.sort((a, b) => b - a).forEach(i => this.widgets.splice(i, 1));
            inputIndices.sort((a, b) => b - a).forEach(i => this.removeInput(i));

            toRemove.forEach(name => {
                this.groupWidgetMap.delete(name);
            });

            this.setSize(this.computeSize());
        }

        this.setDirtyCanvas(true, false);
    }

    /**
     * Set the mode of every node inside the named group.
     *
     * @param {string} groupName Title of the group.
     * @param {boolean} enabled True for "always"; false for mute or bypass,
     *     depending on `this.properties.mode`.
     */
    applyGroupMode(groupName, enabled) {
        const group = app.graph?._groups.find(g => g.title === groupName);
        if (!group) return;

        const targetMode = enabled ? MODE_ALWAYS : (this.properties.mode === "mute" ? MODE_MUTE : MODE_BYPASS);
        recomputeInsideNodesForGroup(group);

        (group._nodes || []).forEach(node => {
            if (node.mode !== targetMode) node.mode = targetMode;
        });

        app.graph.setDirtyCanvas(true, false);
    }

    /**
     * Size needed for the title plus one row per widget, never smaller than
     * a recently remembered `tempSize` (which is forgotten after 32 ms).
     *
     * @param {number[]} [out] Optional array that receives the size.
     * @returns {number[]} `[width, height]`.
     */
    computeSize(out) {
        const widgetCount = this.widgets?.length || 0;
        let size = [200, (LiteGraph.NODE_TITLE_HEIGHT || 30) + widgetCount * (LiteGraph.NODE_WIDGET_HEIGHT || 20)];

        if (this.tempSize) {
            size = [Math.max(this.tempSize[0], size[0]), Math.max(this.tempSize[1], size[1])];
            clearTimeout(this.debouncerTempWidth);
            this.debouncerTempWidth = setTimeout(() => this.tempSize = null, 32);
        }

        if (out) {
            out[0] = size[0];
            out[1] = size[1];
        }
        return size;
    }

    /** Add the mute/bypass switch, Max One toggle and "Toggle All" entries. */
    getExtraMenuOptions(canvas, options) {
        const currentMode = this.properties.mode || "mute";
        const nextMode = currentMode === "mute" ? "bypass" : "mute";

        options.push({
            content: `Switch to ${nextMode === "mute" ? "Mute" : "Bypass"} mode`,
            callback: () => {
                this.properties.mode = nextMode;
                for (const [name, widget] of this.groupWidgetMap) {
                    this.applyGroupMode(name, widget.value);
                }
            }
        });

        options.push({
            content: this.properties.maxOne ? "✓ Max One Mode" : "Max One Mode",
            callback: () => {
                this.properties.maxOne = !this.properties.maxOne;
                // When Max One mode is switched on, keep only the first
                // active group enabled and disable the rest.
                if (this.properties.maxOne) {
                    let firstActiveFound = false;
                    for (const [name, widget] of this.groupWidgetMap) {
                        if (widget.value === true) {
                            if (firstActiveFound) {
                                widget.value = false;
                                this.applyGroupMode(name, false);
                                const group = app.graph?._groups.find(g => g.title === name);
                                if (group) group.lgtools_isActive = false;
                            } else {
                                firstActiveFound = true;
                            }
                        }
                    }
                }
                this.setDirtyCanvas(true, false);
            }
        });

        options.push({
            content: "Toggle All Groups",
            callback: () => {
                // If every toggle is on, switch all off; otherwise all on.
                const newValue = !(this.widgets || []).every(w => w.value);
                (this.widgets || []).forEach(w => {
                    w.value = newValue;
                    w.callback?.(newValue);
                });
            }
        });
    }

    /**
     * Rebuild the widgets when the set of group titles changed, then mirror
     * each group's actual state (any node in "always" mode) into its toggle.
     */
    onDrawBackground() {
        if (!app.graph?._groups) return;

        const signature = app.graph._groups.map(g => g.title).sort().join(',');
        if (signature !== this._lastGroupSignature) {
            this._lastGroupSignature = signature;
            this.setupGroups();
        }

        for (const [name, widget] of this.groupWidgetMap) {
            const group = app.graph._groups.find(g => g.title === name);
            if (!group) continue;

            recomputeInsideNodesForGroup(group);
            if (!group._nodes?.length) continue;

            const isActive = group._nodes.some(n => n.mode === MODE_ALWAYS);
            group.lgtools_isActive = isActive;

            if (widget.value !== isActive) {
                widget.value = isActive;
                this.setDirtyCanvas(true, false);
            }
        }
    }

    /** Register this class with LiteGraph and apply its category. */
    static setUp() {
        LiteGraph.registerNodeType(this.type, this);
        if (this._category) {
            this.category = this._category;
        }
    }
}
260 |
// Static metadata read by setUp() and by the "mode" property editor.
Object.assign(LG_GroupMuterNode, {
    type: "🎈LG_GroupMuter",
    title: "🎈LG Group Muter",
    _category: "🎈LAOGOU/Switch",
    "@mode": { type: "combo", values: ["mute", "bypass"] },
});
265 |
// Extension entry point: registers the node type and, for group-muter
// nodes loaded from a saved graph, remembers their size in `tempSize`.
const lgGroupMuterExtension = {
    name: "LG.GroupMuter",

    registerCustomNodes() {
        LG_GroupMuterNode.setUp();
    },

    loadedGraphNode(node) {
        if (node.type !== LG_GroupMuterNode.type) return;
        node.tempSize = node.size ? [...node.size] : null;
    },
};

app.registerExtension(lgGroupMuterExtension);
277 |
--------------------------------------------------------------------------------
/web/image_size_adjustment.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../scripts/app.js";
2 | import { api } from "../../scripts/api.js";
3 | import { ComfyWidgets } from "../../scripts/widgets.js";
4 |
5 | app.registerExtension({
6 | name: "ImagePreview.Adjustment",
7 | async beforeRegisterNodeDef(nodeType, nodeData) {
8 | if (nodeData.name === "ImageSizeAdjustment") {
9 | const onNodeCreated = nodeType.prototype.onNodeCreated;
10 | nodeType.prototype.onNodeCreated = function() {
11 | const result = onNodeCreated?.apply(this, arguments);
12 | this.widgets_start_y = 30;
13 | this.setupWebSocket();
14 |
15 | // 创建偏移量控制器
16 | const createOffsetWidget = (name, defaultValue = 0) => {
17 | const widget = ComfyWidgets.INT(this, name, ["INT", {
18 | default: defaultValue,
19 | min: -9999,
20 | max: 9999,
21 | step: 8
22 | }]);
23 |
24 | // 设置值变化回调
25 | widget.widget.callback = (value) => {
26 | // 将值调整为8的倍数
27 | value = Math.floor(value / 8) * 8;
28 |
29 | if (this.originalImageData) {
30 | const size = name === "x_offset"
31 | ? this.originalImageData[0]?.length
32 | : this.originalImageData.length;
33 |
34 | if (value < 0 && Math.abs(value) > size) {
35 | value = -size;
36 | this.widgets.find(w => w.name === name).value = value;
37 | }
38 | }
39 | this[name] = value;
40 | this.updatePreview(true);
41 | };
42 |
43 | // 设置拖动事件
44 | widget.widget.onDragStart = function() {
45 | this.node.isAdjusting = true;
46 | };
47 |
48 | widget.widget.onDragEnd = function() {
49 | this.node.isAdjusting = false;
50 | this.node.updatePreview(false);
51 | };
52 |
53 | return widget;
54 | };
55 |
56 | // 创建 x 和 y 偏移控制器
57 | createOffsetWidget("x_offset");
58 | createOffsetWidget("y_offset");
59 |
60 | return result;
61 | };
62 |
63 | // 添加WebSocket设置方法
64 | nodeType.prototype.setupWebSocket = function() {
65 | api.addEventListener("image_preview_update", async (event) => {
66 | const data = event.detail;
67 | if (data && data.node_id === this.id.toString()) {
68 | console.log(`[ImagePreview] 节点 ${this.id} 接收到更新数据`);
69 | if (data.image_data) {
70 | this.loadImageFromBase64(data.image_data);
71 | }
72 | }
73 | });
74 | };
75 |
76 | // 更新预览方法
/**
 * Redraw the preview canvas with the source image stretched to the
 * offset-adjusted size.
 *
 * Requests are debounced through `this.updateTimeout`: a pending redraw is
 * cancelled and a new one is queued with a 50 ms delay while
 * `this.isAdjusting` is set, or a 0 ms delay otherwise.
 *
 * @param {boolean} [onlyPreview=false] When true, only the canvas is redrawn;
 *     the resized pixels are not passed to `sendAdjustedData`.
 */
nodeType.prototype.updatePreview = function(onlyPreview = false) {
    const hasInputs = this.originalImageData && this.canvas;
    if (!hasInputs) {
        return;
    }

    // Drop any redraw that is still waiting so only the latest request runs.
    if (this.updateTimeout) {
        clearTimeout(this.updateTimeout);
    }

    const render = () => {
        const ctx = this.canvas.getContext("2d");
        const srcW = this.originalImageData[0].length;
        const srcH = this.originalImageData.length;

        const dx = this.x_offset || 0;
        const dy = this.y_offset || 0;

        // Target size never collapses below one pixel.
        const dstW = Math.max(1, srcW + dx);
        const dstH = Math.max(1, srcH + dy);

        this.canvas.width = dstW;
        this.canvas.height = dstH;
        ctx.clearRect(0, 0, dstW, dstH);

        if (!this.tempCanvas) {
            this.tempCanvas = document.createElement('canvas');
        }

        // The source pixels are rasterised once per loaded image and reused
        // for every subsequent resize.
        if (!this.originalImageRendered) {
            this.tempCanvas.width = srcW;
            this.tempCanvas.height = srcH;
            const tempCtx = this.tempCanvas.getContext('2d');

            const frame = new ImageData(srcW, srcH);
            const out = frame.data;
            for (let y = 0; y < srcH; y++) {
                for (let x = 0; x < srcW; x++) {
                    const pixel = this.originalImageData[y][x];
                    const base = (y * srcW + x) * 4;
                    out[base] = pixel[0];     // R
                    out[base + 1] = pixel[1]; // G
                    out[base + 2] = pixel[2]; // B
                    out[base + 3] = 255;      // A (opaque)
                }
            }
            tempCtx.putImageData(frame, 0, 0);
            this.originalImageRendered = true;
        }

        // Scale the cached source bitmap onto the visible canvas.
        ctx.drawImage(this.tempCanvas, 0, 0, srcW, srcH, 0, 0, dstW, dstH);

        const shouldSend = !onlyPreview && !this.isAdjusting;
        if (shouldSend) {
            this.sendAdjustedData(ctx.getImageData(0, 0, dstW, dstH));
        }
    };

    const delay = this.isAdjusting ? 50 : 0;
    this.updateTimeout = setTimeout(render, delay);
};
142 |
143 | // 发送调整后的数据
/**
 * Encode the adjusted pixels as a JPEG and POST them to `/image_preview/apply`.
 *
 * The request is awaited so that network failures and non-2xx responses are
 * reported by the `catch` below instead of surfacing as unhandled promise
 * rejections.
 *
 * @param {ImageData} adjustedData Pixels to upload; its width and height are
 *     sent alongside the encoded image.
 * @returns {Promise<void>} Resolves once the upload finished or failed; errors
 *     are logged, never thrown.
 */
nodeType.prototype.sendAdjustedData = async function(adjustedData) {
    try {
        const endpoint = '/image_preview/apply';
        const nodeId = String(this.id);

        console.log(`[ImagePreview] 发送调整后的数据 - 尺寸: ${adjustedData.width}x${adjustedData.height}`);

        // Paint the pixels on a scratch canvas so the browser can JPEG-encode them.
        const canvas = document.createElement('canvas');
        canvas.width = adjustedData.width;
        canvas.height = adjustedData.height;
        const ctx = canvas.getContext('2d');
        ctx.putImageData(adjustedData, 0, 0);

        const blob = await new Promise(resolve => {
            canvas.toBlob(resolve, 'image/jpeg', 0.9);
        });
        // toBlob hands back null when the image cannot be encoded.
        if (!blob) {
            throw new Error('canvas.toBlob produced no image data');
        }

        const formData = new FormData();
        formData.append('node_id', nodeId);
        formData.append('width', adjustedData.width);
        formData.append('height', adjustedData.height);
        formData.append('image_data', blob, 'adjusted_image.jpg');

        const response = await api.fetchApi(endpoint, {
            method: 'POST',
            body: formData
        });
        if (!response.ok) {
            throw new Error(`服务器返回错误: ${response.status}`);
        }
    } catch (error) {
        console.error('发送数据时出错:', error);
    }
};
175 |
176 | // 添加从base64加载图像的方法
/**
 * Decode an image source string into `this.originalImageData` and refresh the
 * preview.
 *
 * The decoded pixels are stored as rows of `[R, G, B]` triples (alpha is
 * dropped). `this.originalImageRendered` is reset so the next preview
 * re-rasterises the new image.
 *
 * @param {string} base64Data Value assigned to the image element's `src`.
 */
nodeType.prototype.loadImageFromBase64 = function(base64Data) {
    const img = new Image();

    const onDecoded = () => {
        this.originalImageRendered = false;

        // Draw onto a scratch canvas to get at the raw RGBA bytes.
        const scratch = document.createElement('canvas');
        scratch.width = img.width;
        scratch.height = img.height;
        const scratchCtx = scratch.getContext('2d');
        scratchCtx.drawImage(img, 0, 0);

        const rgba = scratchCtx.getImageData(0, 0, img.width, img.height).data;
        const cols = img.width;

        const rows = Array.from({ length: img.height }, (_, y) =>
            Array.from({ length: cols }, (_, x) => {
                const at = (y * cols + x) * 4;
                return [rgba[at], rgba[at + 1], rgba[at + 2]];
            })
        );

        this.originalImageData = rows;
        this.updatePreview();
    };

    img.onload = onDecoded;
    img.src = base64Data;
};
211 |
212 | // 添加节点时的处理
/**
 * Chain onto any existing `onAdded` hook and, the first time the node has a
 * real id, build the preview container and canvas and register them as a DOM
 * widget on the next animation frame.
 */
const onAdded = nodeType.prototype.onAdded;
nodeType.prototype.onAdded = function() {
    const result = onAdded?.apply(this, arguments);

    // Nothing to do when the preview already exists or the node has no id yet.
    const alreadyBuilt = Boolean(this.previewElement);
    const hasNoId = this.id === undefined || this.id === -1;
    if (alreadyBuilt || hasNoId) {
        return result;
    }

    const previewContainer = document.createElement("div");
    Object.assign(previewContainer.style, {
        position: "relative",
        width: "100%",
        height: "100%",
        backgroundColor: "#333",
        borderRadius: "8px",
        overflow: "hidden"
    });

    const canvas = document.createElement("canvas");
    Object.assign(canvas.style, {
        width: "100%",
        height: "100%",
        objectFit: "contain"
    });
    previewContainer.appendChild(canvas);

    this.canvas = canvas;
    this.previewElement = previewContainer;

    this.widgets ||= [];
    this.widgets_up = true;

    // Widget registration is deferred by one frame.
    requestAnimationFrame(() => {
        if (!this.widgets) {
            return;
        }
        this.previewWidget = this.addDOMWidget("preview", "preview", previewContainer);
        this.setDirtyCanvas(true, true);
    });

    return result;
};
249 | }
250 | }
251 | });
--------------------------------------------------------------------------------
/web/color_adjustment.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../scripts/app.js";
2 | import { api } from "../../scripts/api.js";
3 |
4 |
5 | app.registerExtension({
6 | name: "ColorAdjustment.Preview",
7 | async beforeRegisterNodeDef(nodeType, nodeData) {
8 | if (nodeData.name === "ColorAdjustment") {
9 |
10 | // 扩展节点的构造函数
/**
 * Chain onto any existing `onNodeCreated` hook, start listening for backend
 * updates and add the brightness / contrast / saturation sliders.
 *
 * Each slider stores its value on the node under its own name and triggers a
 * preview-only redraw; the shared `drag_end` option triggers a full update.
 */
const onNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function() {
    const result = onNodeCreated?.apply(this, arguments);

    // Vertical offset at which widgets start, leaving room for the ports.
    this.widgets_start_y = 30;

    this.setupWebSocket();

    // One options object shared by all three sliders.
    const sliderConfig = {
        min: 0,
        max: 2,
        step: 0.01,
        drag_start: () => this.isAdjusting = true,
        drag_end: () => {
            this.isAdjusting = false;
            this.updatePreview(false);
        }
    };

    for (const name of ["brightness", "contrast", "saturation"]) {
        const onChange = (value) => {
            this[name] = value;
            this.updatePreview(true);
        };
        this.addWidget("slider", name, 1.0, onChange, sliderConfig);
    }

    return result;
};
42 |
43 | // 添加WebSocket设置方法
/**
 * Subscribe to `color_adjustment_update` events and load the image carried by
 * events whose `node_id` matches this node's id (compared as a string).
 */
nodeType.prototype.setupWebSocket = function() {
    console.log(`[ColorAdjustment] 节点 ${this.id} 设置WebSocket监听`);

    const onUpdate = async (event) => {
        const data = event.detail;

        // Ignore events without a payload or addressed to another node.
        if (!data || !data.node_id || data.node_id !== this.id.toString()) {
            return;
        }

        console.log(`[ColorAdjustment] 节点 ${this.id} 接收到更新数据`);

        if (!data.image_data) {
            console.warn("[ColorAdjustment] 接收到空的图像数据");
            return;
        }

        // Log a short summary rather than the full payload.
        const summary = {
            nodeId: this.id,
            dataLength: data.image_data.length,
            dataPreview: data.image_data.substring(0, 50) + "...",
            isBase64: data.image_data.startsWith("data:image"),
            timestamp: new Date().toISOString()
        };
        console.log("[ColorAdjustment] 接收到base64数据:", summary);

        this.loadImageFromBase64(data.image_data);
    };

    api.addEventListener("color_adjustment_update", onUpdate);
};
68 |
69 | // 添加从base64加载图像的方法
/**
 * Decode an image source string into `this.originalImageData` (rows of
 * `[R, G, B]` triples, alpha dropped) and then refresh the preview.
 *
 * @param {string} base64Data Value assigned to the image element's `src`.
 */
nodeType.prototype.loadImageFromBase64 = function(base64Data) {
    console.log(`[ColorAdjustment] 节点 ${this.id} 开始加载base64图像数据`);
    const img = new Image();

    img.onload = () => {
        console.log(`[ColorAdjustment] 节点 ${this.id} 图像加载完成: ${img.width}x${img.height}`);

        // A scratch canvas gives access to the decoded RGBA bytes.
        const scratch = document.createElement('canvas');
        scratch.width = img.width;
        scratch.height = img.height;
        const scratchCtx = scratch.getContext('2d');
        scratchCtx.drawImage(img, 0, 0);

        const bytes = scratchCtx.getImageData(0, 0, img.width, img.height).data;

        const rows = [];
        let cursor = 0; // byte offset of the current pixel in `bytes`
        for (let y = 0; y < img.height; y++) {
            const row = new Array(img.width);
            for (let x = 0; x < img.width; x++, cursor += 4) {
                row[x] = [bytes[cursor], bytes[cursor + 1], bytes[cursor + 2]];
            }
            rows.push(row);
        }

        this.originalImageData = rows;
        this.updatePreview();
    };

    img.src = base64Data;
};
113 |
114 | // 添加节点时的处理
/**
 * Chain onto any existing `onAdded` hook and create the preview container,
 * canvas and DOM widget once the node has a real id.
 */
const onAdded = nodeType.prototype.onAdded;
nodeType.prototype.onAdded = function() {
    const result = onAdded?.apply(this, arguments);

    const needsPreview = !this.previewElement && this.id !== undefined && this.id !== -1;
    if (needsPreview) {
        // Small helper: create an element and apply inline styles in order.
        const make = (tag, styles) => {
            const el = document.createElement(tag);
            for (const [prop, value] of Object.entries(styles)) {
                el.style[prop] = value;
            }
            return el;
        };

        const previewContainer = make("div", {
            position: "relative",
            width: "100%",
            height: "100%",
            backgroundColor: "#333",
            borderRadius: "8px",
            overflow: "hidden"
        });

        const canvas = make("canvas", {
            width: "100%",
            height: "100%",
            objectFit: "contain"
        });

        previewContainer.appendChild(canvas);
        this.canvas = canvas;
        this.previewElement = previewContainer;

        this.widgets ||= [];
        this.widgets_up = true;

        // Register the DOM widget on the next frame.
        requestAnimationFrame(() => {
            if (this.widgets) {
                this.previewWidget = this.addDOMWidget("preview", "preview", previewContainer);
                this.setDirtyCanvas(true, true);
            }
        });
    }

    return result;
};
153 |
154 | // 更新预览方法
/**
 * Re-render the preview canvas from `this.originalImageData` with the current
 * colour settings applied, on the next animation frame.
 *
 * The canvas and pixel data are re-checked inside the frame callback because
 * the node can be removed (which clears `this.canvas`) between scheduling and
 * execution.
 *
 * @param {boolean} [onlyPreview=false] When true, or while `this.isAdjusting`
 *     is set, the adjusted pixels are only drawn and not sent to the backend.
 */
nodeType.prototype.updatePreview = function(onlyPreview = false) {
    if (!this.originalImageData || !this.canvas) {
        return;
    }

    requestAnimationFrame(() => {
        const canvas = this.canvas;
        const pixels = this.originalImageData;
        // The node may have been removed while the frame was pending.
        if (!canvas || !pixels) {
            return;
        }

        const ctx = canvas.getContext("2d");
        const width = pixels[0].length;
        const height = pixels.length;
        const shouldSend = !onlyPreview && !this.isAdjusting;

        if (shouldSend) {
            console.log(`[ColorAdjustment] 节点 ${this.id} 更新预览并准备发送数据 (${width}x${height})`);
        } else {
            console.log(`[ColorAdjustment] 节点 ${this.id} 仅更新预览 (${width}x${height})`);
        }

        // Rebuild an opaque RGBA buffer from the stored RGB triples.
        const imgData = new ImageData(width, height);
        for (let y = 0; y < height; y++) {
            const row = pixels[y];
            for (let x = 0; x < width; x++) {
                const idx = (y * width + x) * 4;
                const px = row[x];
                imgData.data[idx] = px[0];     // R
                imgData.data[idx + 1] = px[1]; // G
                imgData.data[idx + 2] = px[2]; // B
                imgData.data[idx + 3] = 255;   // A
            }
        }

        const adjustedData = this.adjustColors(imgData);

        // Resize the canvas to the image and paint the adjusted pixels.
        canvas.width = width;
        canvas.height = height;
        ctx.putImageData(adjustedData, 0, 0);

        // Only a finished adjustment is pushed to the backend.
        if (shouldSend) {
            this.lastAdjustedData = adjustedData;
            this.sendAdjustedData(adjustedData);
        }
    });
};
200 |
201 | // 优化颜色调整方法,提高性能
/**
 * Apply brightness, contrast and saturation to a copy of `imageData`.
 *
 * Per pixel: RGB is multiplied by the brightness factor (capped at 255),
 * contrast is applied around the mid-point 128, and saturation interpolates
 * between the weighted luma and the colour. Alpha is left untouched.
 *
 * Unset factors default to 1.0. A factor of 0 is a legitimate slider value and
 * is honoured, hence `??` instead of `||`.
 *
 * @param {ImageData} imageData Source pixels; not modified.
 * @returns {ImageData} New image of the same dimensions with adjusted colours.
 */
nodeType.prototype.adjustColors = function(imageData) {
    const brightness = this.brightness ?? 1.0;
    const contrast = this.contrast ?? 1.0;
    const saturation = this.saturation ?? 1.0;

    const result = new Uint8ClampedArray(imageData.data);
    const len = result.length;

    // Loop-invariant contrast terms: out = in * factor + offset.
    const contrastFactor = contrast;
    const contrastOffset = 128 * (1 - contrast);

    for (let i = 0; i < len; i += 4) {
        // Brightness first, capped before contrast is applied.
        let r = Math.min(255, result[i] * brightness);
        let g = Math.min(255, result[i + 1] * brightness);
        let b = Math.min(255, result[i + 2] * brightness);

        r = r * contrastFactor + contrastOffset;
        g = g * contrastFactor + contrastOffset;
        b = b * contrastFactor + contrastOffset;

        // Saturation: move each channel towards/away from the weighted luma.
        if (saturation !== 1.0) {
            const avg = r * 0.299 + g * 0.587 + b * 0.114;
            r = avg + (r - avg) * saturation;
            g = avg + (g - avg) * saturation;
            b = avg + (b - avg) * saturation;
        }

        // Keep every channel inside 0..255.
        result[i] = Math.min(255, Math.max(0, r));
        result[i + 1] = Math.min(255, Math.max(0, g));
        result[i + 2] = Math.min(255, Math.max(0, b));
    }

    return new ImageData(result, imageData.width, imageData.height);
};
240 |
241 | // 添加发送调整后数据的方法,优化为异步
/**
 * POST the adjusted pixels as JSON to `/color_adjustment/apply`.
 *
 * The request is fire-and-forget: the returned promise does not wait for the
 * response. Failures are logged, never thrown.
 *
 * @param {ImageData} adjustedData Pixels to send together with their width
 *     and height.
 */
nodeType.prototype.sendAdjustedData = async function(adjustedData) {
    try {
        const payload = {
            node_id: String(this.id),
            adjusted_data: Array.from(adjustedData.data),
            width: adjustedData.width,
            height: adjustedData.height
        };

        const pending = api.fetchApi('/color_adjustment/apply', {
            method: 'POST',
            body: JSON.stringify(payload)
        });

        // Deliberately not awaited; the chain only reports problems.
        pending
            .then(response => {
                if (response.ok) {
                    return response.json();
                }
                throw new Error(`服务器返回错误: ${response.status}`);
            })
            .catch(error => {
                console.error('数据发送失败:', error);
            });
    } catch (error) {
        console.error('发送数据时出错:', error);
    }
};
267 |
268 | // 节点移除时的处理
/**
 * Chain onto any existing `onRemoved` hook, then clear the preview canvas and
 * drop the node's references to the canvas and preview element.
 */
const onRemoved = nodeType.prototype.onRemoved;
nodeType.prototype.onRemoved = function() {
    const result = onRemoved?.apply(this, arguments);

    // Without a receiver there is nothing to clean up.
    if (!this) {
        return result;
    }

    const canvas = this.canvas;
    if (canvas) {
        const ctx = canvas.getContext("2d");
        if (ctx) {
            ctx.clearRect(0, 0, canvas.width, canvas.height);
        }
        this.canvas = null;
    }
    this.previewElement = null;

    return result;
};
286 | }
287 | }
288 | });
289 |
290 |
291 |
--------------------------------------------------------------------------------
/py/bridge_preview.py:
--------------------------------------------------------------------------------
1 | from .md import *
2 | import torch
3 | import json
4 | import os
5 | import random
6 | import numpy as np
7 | import hashlib
8 | from PIL import Image
9 | from PIL.PngImagePlugin import PngInfo
10 | import folder_paths
11 | import comfy.utils
12 | from threading import Event
13 | import threading
14 | from server import PromptServer
15 | from aiohttp import web
def get_bridge_storage():
    """Return the dict kept on ``PromptServer.instance._bridge_node_data``.

    The dict is created and attached on first use.
    """
    server = PromptServer.instance
    try:
        return server._bridge_node_data
    except AttributeError:
        server._bridge_node_data = {}
        return server._bridge_node_data
def get_bridge_cache():
    """Return the dict kept on ``PromptServer.instance._bridge_node_cache``.

    The dict is created and attached on first use.
    """
    server = PromptServer.instance
    if hasattr(server, '_bridge_node_cache'):
        return server._bridge_node_cache
    server._bridge_node_cache = {}
    return server._bridge_node_cache
class BridgePreviewNode(PreviewImage):
    """Preview node that waits for the frontend to finish a mask edit.

    ``process_image`` publishes preview URLs through the
    ``bridge_preview_update`` event and then blocks (up to 30 seconds) on an
    ``Event`` stored in the bridge storage. Timeouts and errors fall back to
    returning the input images with an all-ones mask.
    """

    def __init__(self):
        super().__init__()
        # "_bridge_" followed by five random characters from the literal below.
        self.prefix_append = "_bridge_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE", ),
                "file_info": ("STRING", {"default": "", "readonly": True, "multiline": False}),
                "skip": ("BOOLEAN", {"default": False, "label_on": "Skip", "label_off": "Open dialog"}),
            },
            "hidden": {
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO",
                "unique_id": "UNIQUE_ID"
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("处理后图像", "遮罩")
    FUNCTION = "process_image"
    OUTPUT_NODE = True
    CATEGORY = "🎈LAOGOU/Image"

    @staticmethod
    def _passthrough(images):
        """Return ``(images, mask)`` where mask is all ones with the images' batch/height/width."""
        batch_size, height, width, _channels = images.shape
        default_mask = torch.ones((batch_size, height, width), dtype=torch.float32)
        return (images, default_mask)

    def calculate_image_hash(self, images):
        """Return the MD5 hex digest of the image tensor's bytes, or None if hashing fails."""
        try:
            np_images = images.cpu().numpy()
            return hashlib.md5(np_images.tobytes()).hexdigest()
        except Exception:
            # Hashing is best-effort; callers treat None as "no cache key".
            return None

    def save_output_images(self, images, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
        """Save each image with its mask as the alpha channel of an RGBA PNG.

        Files are written to the output directory under the first unused
        ``{filename_prefix}_{counter:05d}_.png`` name.

        Returns:
            ``{"ui": {"images": [...]}}`` describing the saved files, or None
            when the batch sizes differ or saving fails.
        """
        try:
            if images.shape[0] != mask.shape[0]:
                return None
            output_dir = folder_paths.get_output_directory()
            saved_images = []
            for i in range(images.shape[0]):
                np_image = (images[i].cpu().numpy() * 255).astype(np.uint8)
                np_mask = (mask[i].cpu().numpy() * 255).astype(np.uint8)
                h, w, _ = np_image.shape
                rgba_image = np.zeros((h, w, 4), dtype=np.uint8)
                rgba_image[:, :, :3] = np_image
                rgba_image[:, :, 3] = np_mask
                pil_image = Image.fromarray(rgba_image, 'RGBA')

                # Find the first file name that is not taken yet.
                counter = 1
                while True:
                    filename = f"{filename_prefix}_{counter:05d}_.png"
                    full_path = os.path.join(output_dir, filename)
                    if not os.path.exists(full_path):
                        break
                    counter += 1

                metadata = PngInfo()
                if prompt is not None:
                    metadata.add_text("prompt", json.dumps(prompt))
                if extra_pnginfo is not None:
                    for key, value in extra_pnginfo.items():
                        metadata.add_text(key, json.dumps(value))
                pil_image.save(full_path, pnginfo=metadata, compress_level=4)
                saved_images.append({
                    "filename": filename,
                    "subfolder": "",
                    "type": "output"
                })
            return {"ui": {"images": saved_images}}
        except Exception:
            # Saving the display copy is best-effort; the caller copes with None.
            return None

    def process_image(self, images, file_info="", skip=False, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, unique_id=None):
        """Return ``(images, mask)`` after the frontend confirmed or cancelled editing.

        Args:
            images: Input image batch; its 4-element ``shape`` is unpacked as
                (batch, height, width, channels).
            file_info: Accepted for interface compatibility; not read here.
            skip: When true, do not wait for the frontend. Return the cached
                result for an unchanged input if there is one, otherwise the
                input with an all-ones mask.
            filename_prefix, prompt, extra_pnginfo: Forwarded to the image
                saving helpers.
            unique_id: Node id; used (as a string) as the storage/cache key.

        Returns:
            A 2-tuple. For a stored result ``(imgs, m)`` the tuple is
            ``(imgs, 1 - m)``; the cache keeps the uninverted mask. On timeout,
            error or missing result it is the input with an all-ones mask.
        """
        try:
            node_id = str(unique_id)
            current_hash = self.calculate_image_hash(images)
            cache = get_bridge_cache()

            # A cache entry only counts when the input hash is known and unchanged.
            cached_data = cache[node_id] if (current_hash and node_id in cache) else None
            cache_valid = cached_data is not None and cached_data.get("input_hash") == current_hash

            if skip:
                if cache_valid and cached_data.get("final_result"):
                    cached_images, cached_mask = cached_data["final_result"]
                    return (cached_images, 1 - cached_mask)
                return self._passthrough(images)

            # Reuse the previously saved output previews for an unchanged input.
            if cache_valid and cached_data.get("output_urls"):
                preview_urls = cached_data["output_urls"]
            else:
                preview_result = self.save_images(images, filename_prefix, prompt, extra_pnginfo)
                preview_urls = preview_result["ui"]["images"] if preview_result and "ui" in preview_result else []

            event = Event()
            storage = get_bridge_storage()
            storage[node_id] = {
                "event": event,
                "result": None,
                "original_images": images,
                "processing_complete": False,
                "input_hash": current_hash
            }

            try:
                PromptServer.instance.send_sync("bridge_preview_update", {
                    "node_id": node_id,
                    "urls": preview_urls
                })

                signalled = event.wait(timeout=30)
                # The pending entry is removed whether or not the wait succeeded.
                pending = storage.pop(node_id, None)
                if not signalled:
                    return self._passthrough(images)

                result_data = pending["result"] if pending is not None else None
                if isinstance(result_data, tuple) and len(result_data) == 2:
                    final_images, final_mask = result_data
                    if current_hash:
                        output_result = self.save_output_images(final_images, final_mask, filename_prefix + "_output", prompt, extra_pnginfo)
                        output_urls = output_result["ui"]["images"] if output_result and "ui" in output_result else []
                        cache[node_id] = {
                            "input_hash": current_hash,
                            "output_urls": output_urls,
                            "final_result": (final_images, final_mask)  # uninverted mask
                        }
                    # The returned mask is inverted; the cached one is not.
                    return (final_images, 1 - final_mask)
                return self._passthrough(images)

            except Exception:
                storage.pop(node_id, None)
                return self._passthrough(images)

        except Exception:
            return self._passthrough(images)
@PromptServer.instance.routes.post("/bridge_preview/confirm")
async def confirm_bridge_preview(request):
    """Handle the frontend's confirmation that mask editing is finished.

    Reads ``node_id`` and ``file_info`` from the JSON body. When the node has
    a pending entry, the image named by ``file_info`` (if it loads) is stored
    as that entry's result and the entry's event is set.
    """
    try:
        payload = await request.json()
        node_id = str(payload.get("node_id"))
        file_info = payload.get("file_info")

        storage = get_bridge_storage()
        if node_id not in storage:
            return web.json_response({"success": False, "error": "节点未找到或已超时"})

        try:
            node_info = storage[node_id]
            if file_info:
                processed_image, mask_image = load_processed_image(file_info)
                if not (processed_image is None or mask_image is None):
                    node_info["result"] = (processed_image, mask_image)
            node_info["processing_complete"] = True
            node_info["event"].set()
            return web.json_response({"success": True})
        except Exception as exc:
            # Release the waiting node even though the confirmation failed.
            if node_id in storage and "event" in storage[node_id]:
                storage[node_id]["event"].set()
            return web.json_response({"success": False, "error": str(exc)})
    except Exception as exc:
        return web.json_response({"success": False, "error": str(exc)})
@PromptServer.instance.routes.post("/bridge_preview/cancel")
async def cancel_bridge_preview(request):
    """Cancel a pending bridge preview.

    If the cache holds a ``final_result`` for the node, it is stored as the
    pending entry's result before the entry's event is set.
    """
    try:
        payload = await request.json()
        node_id = str(payload.get("node_id"))

        storage = get_bridge_storage()
        if node_id not in storage:
            return web.json_response({"success": False, "error": f"节点 {node_id} 未找到或已超时"})

        node_info = storage[node_id]
        cache = get_bridge_cache()

        # Fall back to the last saved edit when there is one.
        if node_id in cache and cache[node_id].get("final_result"):
            cached_images, cached_mask = cache[node_id]["final_result"]
            node_info["result"] = (cached_images, cached_mask)

        node_info["event"].set()
        return web.json_response({"success": True, "message": f"节点 {node_id} 已取消"})
    except Exception as exc:
        return web.json_response({"success": False, "error": str(exc)})
def load_processed_image(file_info):
    """Load the image described by *file_info* and split it into RGB and alpha.

    Args:
        file_info: Either a dict with ``filename`` and optional ``subfolder``
            and ``type`` keys, or a string of the form
            ``"sub/dir/name.ext [type]"``. The ``[type]`` suffix is optional;
            ``input``, ``output`` and ``temp`` select the matching directory
            and anything else falls back to the output directory.

    Returns:
        ``(image, mask)``: float tensors built from the RGB channels and from
        the alpha channel, both divided by 255 and given a leading dimension
        of size 1. ``(None, None)`` when the description is invalid, the
        resolved path lies outside the selected base directory, the file does
        not exist, or loading fails.
    """
    import logging

    logger = logging.getLogger(__name__)
    try:
        if isinstance(file_info, dict):
            filename = file_info.get("filename")
            subfolder = file_info.get("subfolder", "")
            file_type = file_info.get("type", "output")
        elif isinstance(file_info, str):
            path_part = file_info.strip()
            file_type = "output"
            # Optional trailing "[type]" marker.
            if '[' in path_part and path_part.endswith(']'):
                path_part, _, type_part = path_part[:-1].rpartition('[')
                path_part = path_part.strip()
                type_part = type_part.strip()
                if type_part:
                    file_type = type_part.lower()
            # Everything before the last "/" is the subfolder.
            subfolder, _, filename = path_part.rpartition('/')
        else:
            return None, None

        if not filename:
            return None, None

        if file_type == "input":
            base_dir = folder_paths.get_input_directory()
        elif file_type == "temp":
            base_dir = folder_paths.get_temp_directory()
        else:
            base_dir = folder_paths.get_output_directory()

        base_dir = os.path.abspath(base_dir)
        if subfolder:
            file_path = os.path.abspath(os.path.join(base_dir, subfolder, filename))
        else:
            file_path = os.path.abspath(os.path.join(base_dir, filename))

        # file_info originates from an HTTP request: refuse anything that
        # resolves outside the selected base directory (e.g. "../" segments).
        try:
            inside_base = os.path.commonpath((base_dir, file_path)) == base_dir
        except ValueError:
            # Raised e.g. for paths on different drives.
            inside_base = False
        if not inside_base:
            logger.warning("Rejected image path outside base directory: %r", file_info)
            return None, None

        if not os.path.exists(file_path):
            return None, None

        # The context manager closes the file handle once the pixels are copied.
        with Image.open(file_path) as image:
            if image.mode != 'RGBA':
                image = image.convert('RGBA')
            np_image = np.array(image)

        rgb_image = np_image[:, :, :3]
        alpha_channel = np_image[:, :, 3]
        tensor_image = torch.from_numpy(rgb_image / 255.0).float().unsqueeze(0)
        mask_tensor = torch.from_numpy(alpha_channel / 255.0).float().unsqueeze(0)
        return tensor_image, mask_tensor
    except Exception:
        # Best-effort loader: report the failure and let the caller fall back.
        logger.exception("Failed to load processed image %r", file_info)
    return None, None
# Maps the node type identifier to the class that implements it.
NODE_CLASS_MAPPINGS = {
    "BridgePreviewNode": BridgePreviewNode
}
# Maps the same node type identifier to its display name.
NODE_DISPLAY_NAME_MAPPINGS = {
    "BridgePreviewNode": "🎈LG_PreviewBridge_V2"
}
--------------------------------------------------------------------------------
/web/upload.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../scripts/app.js";
2 | import { api } from "../../scripts/api.js";
3 | import { MultiButtonWidget } from "./multi_button_widget.js";
4 |
5 | // 获取input目录的文件列表
/**
 * Fetch the list of selectable input files from the `LoadImage` node
 * definition returned by `/object_info`.
 *
 * Uses `api.fetchApi` like the other requests in this file, so the request is
 * resolved the same way as the rest of the extension's API calls.
 *
 * @returns {Promise<string[]>} The file list, or an empty array when the
 *     request fails or the definition has no usable list.
 */
async function getInputFileList() {
    try {
        const response = await api.fetchApi('/object_info');
        if (!response.ok) {
            throw new Error(`HTTP ${response.status}`);
        }
        const data = await response.json();
        // The first element of LoadImage's required "image" input is the file list.
        const fileList = data?.LoadImage?.input?.required?.image?.[0];
        return Array.isArray(fileList) ? fileList : [];
    } catch (error) {
        console.error("获取文件列表失败:", error);
        return [];
    }
}
21 |
22 | // 删除图片文件(直接删除,无确认弹窗)
/**
 * Ask the backend to delete an image file (no confirmation dialog).
 *
 * @param {string} filename Name of the file to delete.
 * @returns {Promise<boolean>} true only when the server answered 200 with
 *     `success` set; false for every other outcome, including a 200 response
 *     that reports failure.
 */
async function deleteImageFile(filename) {
    try {
        const response = await api.fetchApi('/lg/delete_image', {
            method: 'DELETE',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ filename: filename })
        });

        const result = await response.json();

        if (response.status === 200 && result.success) {
            console.log(`文件 ${filename} 删除成功`);
            return true;
        }

        // Either a non-200 status or a 200 body that reports failure.
        console.error(`删除失败: ${result.error || '未知错误'}`);
        return false;
    } catch (error) {
        console.error(`删除文件失败: ${error}`);
        return false;
    }
}
47 |
48 | // 加载最新图片 - 参考PB.js的实现方式
/**
 * Point the node's `image` widget at the newest image of the given folder.
 *
 * @param {object} node Node whose `widgets` contain a widget named `image`.
 * @param {string} folder_type Folder selector passed as the `type` query
 *     parameter and echoed in the widget value as `"<filename> [<type>]"`.
 * @returns {Promise<boolean>} true when the widget was updated, false
 *     otherwise (request failed, no image reported, or widget missing).
 */
async function loadLatestImage(node, folder_type) {
    try {
        const res = await api.fetchApi(`/lg/get/latest_image?type=${folder_type}`, { cache: "no-store" });
        if (res.status !== 200) {
            return false;
        }

        const item = await res.json();
        if (!(item && item.filename)) {
            return false;
        }

        const imageWidget = node.widgets.find(w => w.name === 'image');
        if (!imageWidget) {
            return false;
        }

        // Reference the file where it is; nothing is copied to the input folder.
        const displayValue = `${item.filename} [${folder_type}]`;
        imageWidget.value = displayValue;

        // The widget callback refreshes the preview image.
        if (typeof imageWidget.callback === "function") {
            imageWidget.callback(displayValue);
        }

        app.graph.setDirtyCanvas(true);
        return true;
    } catch (error) {
        console.error(`加载图像失败: ${error}`);
    }
    return false;
}
81 |
82 | // 扩展ContextMenu以支持图片缩略图和删除功能
/**
 * Wrap `LiteGraph.ContextMenu` so that the combo dropdown of an
 * `LG_LoadImage` node's `image` widget gets a thumbnail and a delete button
 * on every entry.
 *
 * The wrapper keeps the original prototype and returns whatever the wrapped
 * constructor returned.
 */
function extendContextMenuForThumbnails() {
    const originalContextMenu = LiteGraph.ContextMenu;

    LiteGraph.ContextMenu = function(values, options) {
        const ctx = originalContextMenu.call(this, values, options);
        // A plain constructor function returns undefined when invoked through
        // .call(); in that case the menu instance is `this`.
        const menu = ctx ?? this;

        // Only "dark" menus with entries are candidates (the combo dropdown).
        if (options?.className === 'dark' && values?.length > 0) {
            // Wait one frame so the menu's DOM exists.
            requestAnimationFrame(() => {
                if (!menu?.root) {
                    return;
                }

                const currentNode = LGraphCanvas.active_canvas?.current_node;
                if (currentNode?.comfyClass !== "LG_LoadImage") {
                    return;
                }

                // The menu belongs to the image widget when the entry counts match.
                const imageWidget = currentNode.widgets?.find(w => w.name === 'image');
                if (!imageWidget || imageWidget.options?.values?.length !== values.length) {
                    return;
                }

                menu.root.style.maxWidth = '400px';
                menu.root.style.minWidth = '350px';

                const menuItems = menu.root.querySelectorAll('.litemenu-entry');
                menuItems.forEach((item, index) => {
                    if (index < values.length) {
                        addThumbnailAndDeleteToMenuItem(item, values[index], currentNode, menu);
                    }
                });
            });
        }

        return ctx;
    };

    // Keep the prototype chain intact.
    LiteGraph.ContextMenu.prototype = originalContextMenu.prototype;
}
124 |
125 | // 为菜单项添加缩略图和删除按钮
126 | function addThumbnailAndDeleteToMenuItem(menuItem, filename, node, contextMenu) {
127 | // 避免重复添加
128 | if (menuItem.querySelector('.thumbnail-container')) {
129 | return;
130 | }
131 |
132 | // 保存原始文本内容
133 | const originalText = menuItem.textContent;
134 |
135 | // 清空菜单项内容
136 | menuItem.innerHTML = '';
137 |
138 | // 设置菜单项样式为flex布局
139 | menuItem.style.cssText += `
140 | display: flex;
141 | align-items: center;
142 | padding: 6px 12px;
143 | min-height: 48px;
144 | position: relative;
145 | `;
146 |
147 | // 创建缩略图容器
148 | const thumbnailContainer = document.createElement('div');
149 | thumbnailContainer.className = 'thumbnail-container';
150 | thumbnailContainer.style.cssText = `
151 | width: 40px;
152 | height: 40px;
153 | margin-right: 10px;
154 | border-radius: 4px;
155 | overflow: hidden;
156 | background: #222;
157 | display: flex;
158 | align-items: center;
159 | justify-content: center;
160 | flex-shrink: 0;
161 | border: 1px solid #444;
162 | `;
163 |
164 | // 创建缩略图
165 | const thumbnail = document.createElement('img');
166 | thumbnail.style.cssText = `
167 | max-width: 100%;
168 | max-height: 100%;
169 | object-fit: cover;
170 | `;
171 |
172 | // 设置图片源
173 | thumbnail.src = `/view?filename=${encodeURIComponent(filename)}&type=input&subfolder=`;
174 | thumbnail.alt = filename;
175 |
176 | // 处理图片加载失败
177 | thumbnail.onerror = function() {
178 | thumbnailContainer.innerHTML = `
179 | 无
预览
185 | `;
186 | };
187 |
188 | thumbnailContainer.appendChild(thumbnail);
189 |
190 | // 创建文件名标签
191 | const textLabel = document.createElement('span');
192 |
193 | // 截断长文件名 - 保留前10位和后10位文件名及扩展名
194 | let displayName = originalText;
195 | if (displayName.length > 35) {
196 | // 保留文件扩展名
197 | const lastDotIndex = displayName.lastIndexOf('.');
198 | if (lastDotIndex > 0) {
199 | const name = displayName.substring(0, lastDotIndex);
200 | const extension = displayName.substring(lastDotIndex);
201 | if (name.length > 20) {
202 | // 保留前10位 + ... + 后10位 + 扩展名
203 | const firstPart = name.substring(0, 10);
204 | const lastPart = name.substring(name.length - 10);
205 | displayName = firstPart + '...' + lastPart + extension;
206 | }
207 | } else {
208 | // 没有扩展名的情况,保留前10位和后10位
209 | if (displayName.length > 20) {
210 | const firstPart = displayName.substring(0, 10);
211 | const lastPart = displayName.substring(displayName.length - 10);
212 | displayName = firstPart + '...' + lastPart;
213 | }
214 | }
215 | }
216 |
217 | textLabel.textContent = displayName;
218 | textLabel.title = originalText; // 悬停时显示完整文件名
219 | textLabel.style.cssText = `
220 | color: inherit;
221 | font-size: inherit;
222 | overflow: hidden;
223 | text-overflow: ellipsis;
224 | white-space: nowrap;
225 | flex: 1;
226 | cursor: pointer;
227 | max-width: 280px;
228 | min-width: 0;
229 | `;
230 |
231 | // 创建删除按钮 - 扩大点击范围
232 | const deleteButton = document.createElement('button');
233 | deleteButton.innerHTML = '✕';
234 | deleteButton.title = `删除 ${filename}`;
235 | deleteButton.style.cssText = `
236 | width: 28px;
237 | height: 28px;
238 | border: none;
239 | background: transparent;
240 | color: #888;
241 | border-radius: 4px;
242 | cursor: pointer;
243 | font-size: 14px;
244 | font-weight: bold;
245 | margin-left: 8px;
246 | flex-shrink: 0;
247 | display: flex;
248 | align-items: center;
249 | justify-content: center;
250 | opacity: 0.7;
251 | transition: all 0.15s;
252 | padding: 0;
253 | `;
254 |
255 | // 删除按钮悬停效果 - 更明显的反馈
256 | deleteButton.addEventListener('mouseenter', () => {
257 | deleteButton.style.opacity = '1';
258 | deleteButton.style.color = '#fff';
259 | deleteButton.style.background = 'rgba(255, 255, 255, 0.15)';
260 | deleteButton.style.transform = 'scale(1.05)';
261 | });
262 |
263 | deleteButton.addEventListener('mouseleave', () => {
264 | deleteButton.style.opacity = '0.7';
265 | deleteButton.style.color = '#888';
266 | deleteButton.style.background = 'transparent';
267 | deleteButton.style.transform = 'scale(1)';
268 | });
269 |
270 | // 删除按钮点击事件 - 快速删除,无动画延迟
271 | deleteButton.addEventListener('click', async (e) => {
272 | e.stopPropagation(); // 阻止触发菜单项选择
273 | e.preventDefault();
274 |
275 | // 立即显示删除中状态
276 | deleteButton.innerHTML = '⋯';
277 | deleteButton.style.pointerEvents = 'none';
278 | deleteButton.style.opacity = '0.5';
279 |
280 | // 直接执行删除操作,无确认弹窗
281 | const deleted = await deleteImageFile(filename);
282 |
283 | if (deleted) {
284 | // 立即移除菜单项,无动画延迟
285 | if (menuItem.parentNode) {
286 | menuItem.parentNode.removeChild(menuItem);
287 | }
288 |
289 | // 更新节点的文件列表
290 | const imageWidget = node.widgets.find(w => w.name === 'image');
291 | if (imageWidget) {
292 | // 重新获取文件列表
293 | const fileList = await getInputFileList();
294 | imageWidget.options.values = fileList;
295 |
296 | // 如果删除的是当前选中的文件,选择第一个可用文件
297 | if (imageWidget.value === filename) {
298 | imageWidget.value = fileList.length > 0 ? fileList[0] : '';
299 |
300 | // 触发回调更新预览
301 | if (typeof imageWidget.callback === "function") {
302 | imageWidget.callback(imageWidget.value);
303 | }
304 | }
305 |
306 | // 更新画布
307 | app.graph.setDirtyCanvas(true);
308 | }
309 |
310 | // 检查是否还有剩余菜单项,如果没有则关闭菜单
311 | const remainingItems = contextMenu.root.querySelectorAll('.litemenu-entry');
312 | if (remainingItems.length === 0) {
313 | contextMenu.close();
314 | }
315 | } else {
316 | // 删除失败,恢复按钮状态
317 | deleteButton.innerHTML = '✕';
318 | deleteButton.style.pointerEvents = 'auto';
319 | deleteButton.style.opacity = '0.7';
320 | }
321 | });
322 |
323 | // 创建可点击区域(除了删除按钮)
324 | const clickableArea = document.createElement('div');
325 | clickableArea.style.cssText = `
326 | display: flex;
327 | align-items: center;
328 | flex: 1;
329 | cursor: pointer;
330 | `;
331 |
332 | clickableArea.appendChild(thumbnailContainer);
333 | clickableArea.appendChild(textLabel);
334 |
335 | // 为可点击区域添加选择事件
336 | clickableArea.addEventListener('click', () => {
337 | // 模拟原始菜单项点击
338 | const imageWidget = node.widgets.find(w => w.name === 'image');
339 | if (imageWidget) {
340 | imageWidget.value = filename;
341 |
342 | // 触发回调
343 | if (typeof imageWidget.callback === "function") {
344 | imageWidget.callback(filename);
345 | }
346 |
347 | // 更新画布
348 | app.graph.setDirtyCanvas(true);
349 | }
350 |
351 | // 关闭菜单
352 | contextMenu.close();
353 | });
354 |
355 | // 组装菜单项
356 | menuItem.appendChild(clickableArea);
357 | menuItem.appendChild(deleteButton);
358 |
359 | // 移除原有的点击事件,因为我们现在有自定义的点击处理
360 | menuItem.onclick = null;
361 | }
362 |
/**
 * Extension for the "LG_LoadImage" node.
 *
 * - On init, patches the context menu (thumbnails + delete support).
 * - On node creation, appends a non-serialized "Refresh From" multi-button
 *   widget with "Temp" and "Output" buttons that call loadLatestImage().
 */
app.registerExtension({
    name: "Comfy.LG.LoadImageButtons",

    init() {
        // Extend ContextMenu so entries can show thumbnails and a delete button.
        // NOTE(review): the "Comfy.LG.LoadImage_V2" extension in this file calls
        // the same function from its own init(); confirm it is safe to run twice.
        extendContextMenuForThumbnails();
    },

    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        // Only the LG_LoadImage node type is patched.
        if (nodeData.name !== "LG_LoadImage") {
            return;
        }

        const previousOnNodeCreated = nodeType.prototype.onNodeCreated;

        nodeType.prototype.onNodeCreated = function (...creationArgs) {
            // Run any previously registered handler first and keep its result.
            let result;
            if (previousOnNodeCreated) {
                result = previousOnNodeCreated.apply(this, creationArgs);
            }

            // Builds one button descriptor; `source` is forwarded verbatim to
            // loadLatestImage (presumably the ComfyUI temp/output folder - confirm).
            const makeRefreshButton = (label, source) => ({
                text: label,
                callback: () => {
                    loadLatestImage(this, source).then((loaded) => {
                        // Redraw the canvas only when the load reported success.
                        if (loaded) {
                            app.graph.setDirtyCanvas(true);
                        }
                    });
                },
            });

            const widgetOptions = {
                labelWidth: 80,
                buttonSpacing: 4,
            };
            const buttons = [
                makeRefreshButton("Temp", "temp"),
                makeRefreshButton("Output", "output"),
            ];

            // Create the refresh buttons through the shared multi-button widget.
            const refreshWidget = this.addCustomWidget(
                MultiButtonWidget(app, "Refresh From", widgetOptions, buttons)
            );
            // Mark the widget as non-serialized.
            refreshWidget.serialize = false;

            return result;
        };
    },
});
411 |
/**
 * Extension for the "LG_LoadImage_V2" node.
 *
 * Only patches the context menu on init; the node itself receives no extra
 * frontend behaviour.
 */
app.registerExtension({
    name: "Comfy.LG.LoadImage_V2",

    init() {
        // Extend ContextMenu so entries can show thumbnails and a delete button.
        // NOTE(review): "Comfy.LG.LoadImageButtons" in this file calls the same
        // function from its own init(); confirm it is safe to run twice.
        extendContextMenuForThumbnails();
    },

    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        const isTargetNode = nodeData.name === "LG_LoadImage_V2";
        if (!isTargetNode) {
            return;
        }

        // Intentionally empty: the node keeps its original behaviour and no
        // additional frontend features are added. The auto_refresh logic is
        // now handled entirely by the backend.
    },
});
--------------------------------------------------------------------------------
/web/image_cropper.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../scripts/app.js";
2 | import { api } from "../../scripts/api.js";
3 |
4 | // 创建裁剪模态窗口的HTML结构
5 | function createCropperModal() {
6 | const modal = document.createElement("dialog");
7 | modal.id = "image-cropper-modal";
8 | modal.innerHTML = `
9 |