├── exif ├── __init__.py └── exif.py ├── imgio ├── __init__.py └── converter.py ├── webuiapi ├── __init__.py └── out_api.py ├── .gitignore ├── __init__.py ├── README.md ├── pyproject.toml ├── .github └── workflows │ └── publish.yml ├── utils └── tagger.py ├── nodes.py ├── install.py ├── conversion.py ├── external.py ├── autonode.py ├── crypto.py ├── math_nodes.py ├── logic_gates.py ├── auxilary.py ├── randomness.py ├── pystructure.py └── io_node.py /exif/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgio/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /webuiapi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Logic Utilities for someone who wants 2 | 3 | ~~prime list calculation in comfyui~~ 4 | Proper documentation is being prepared; however, there are a lot of nodes to cover. 5 | ![image](https://github.com/user-attachments/assets/8e388417-6912-41d7-98fa-798b50eacfda) 6 | -------------------------------------------------------------------------------- /pyproject.toml:
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-logicutils" 3 | description = "Logical Utils (compare, string, boolean operations) for ComfyUI" 4 | version = "1.7.2" 5 | license = "MIT" 6 | 7 | [project.urls] 8 | Repository = "https://github.com/aria1th/ComfyUI-LogicUtils" 9 | # Used by Comfy Registry https://comfyregistry.org 10 | 11 | [tool.comfy] 12 | PublisherId = "angelbottomless" 13 | DisplayName = "ComfyUI-LogicUtils" 14 | Icon = "" 15 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'aria1th' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | ## Add your own personal access token to your Github Repository secrets and reference it here. 
25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 26 | -------------------------------------------------------------------------------- /utils/tagger.py: -------------------------------------------------------------------------------- 1 | try: 2 | from imgutils.tagging import get_wd14_tags 3 | from imgutils.tagging.wd14 import MODEL_NAMES as tagger_model_names 4 | except ImportError: 5 | def get_wd14_tags(image_path): 6 | raise Exception("Tagger feature not available, please install dghs-imgutils") 7 | tagger_model_names = { 8 | "EVA02_Large": None, 9 | "ViT_Large": None, 10 | "SwinV2": None, 11 | "ConvNext": None, 12 | "ConvNextV2": None, 13 | "ViT": None, 14 | "MOAT": None, 15 | "SwinV2_v3": None, 16 | "ConvNext_v3": None, 17 | "ViT_v3": None, 18 | } 19 | from typing import Union 20 | from PIL import Image 21 | 22 | def get_rating_class(rating): 23 | # argmax 24 | return max(rating, key=rating.get) 25 | 26 | def get_tags_above_threshold(tags, threshold=0.4): 27 | return [tag for tag, score in tags.items() if score > threshold] 28 | 29 | def replace_underscore(tag): 30 | return tag.replace('_', ' ') 31 | 32 | def get_tags(image_path:Union[str, Image.Image], threshold:float = 0.4, replace:bool = False, model_name:str = "SwinV2") -> dict[str, list[str]]: 33 | result = {} 34 | rating, features, chars = get_wd14_tags(image_path, model_name) 35 | result['rating'] = get_rating_class(rating) 36 | result['tags'] = get_tags_above_threshold(features, threshold) 37 | result['chars'] = get_tags_above_threshold(chars, threshold) 38 | if replace: 39 | result['tags'] = [replace_underscore(tag) for tag in result['tags']] 40 | result['chars'] = [replace_underscore(tag) for tag in result['chars']] 41 | return result 42 | try: 43 | tagger_keys = list(tagger_model_names.keys()) 44 | except NameError: 45 | tagger_keys = [] 46 | -------------------------------------------------------------------------------- /nodes.py: 
-------------------------------------------------------------------------------- 1 | 2 | from .install import initialization 3 | 4 | initialization() 5 | 6 | from .logic_gates import CLASS_MAPPINGS as LogicMapping, CLASS_NAMES as LogicNames 7 | from .randomness import CLASS_MAPPINGS as RandomMapping, CLASS_NAMES as RandomNames 8 | from .conversion import CLASS_MAPPINGS as ConversionMapping, CLASS_NAMES as ConversionNames 9 | from .math_nodes import CLASS_MAPPINGS as MathMapping, CLASS_NAMES as MathNames 10 | from .io_node import CLASS_MAPPINGS as IOMapping, CLASS_NAMES as IONames 11 | from .auxilary import CLASS_MAPPINGS as AuxilaryMapping, CLASS_NAMES as AuxilaryNames 12 | from .external import CLASS_MAPPINGS as ExternalMapping, CLASS_NAMES as ExternalNames 13 | 14 | 15 | 16 | NODE_CLASS_MAPPINGS = { 17 | } 18 | NODE_CLASS_MAPPINGS.update(IOMapping) 19 | NODE_CLASS_MAPPINGS.update(LogicMapping) 20 | NODE_CLASS_MAPPINGS.update(RandomMapping) 21 | NODE_CLASS_MAPPINGS.update(ConversionMapping) 22 | NODE_CLASS_MAPPINGS.update(MathMapping) 23 | NODE_CLASS_MAPPINGS.update(ExternalMapping) 24 | NODE_CLASS_MAPPINGS.update(AuxilaryMapping) 25 | 26 | 27 | 28 | NODE_DISPLAY_NAME_MAPPINGS = { 29 | 30 | } 31 | NODE_DISPLAY_NAME_MAPPINGS.update(IONames) 32 | NODE_DISPLAY_NAME_MAPPINGS.update(LogicNames) 33 | NODE_DISPLAY_NAME_MAPPINGS.update(RandomNames) 34 | NODE_DISPLAY_NAME_MAPPINGS.update(ConversionNames) 35 | NODE_DISPLAY_NAME_MAPPINGS.update(MathNames) 36 | NODE_DISPLAY_NAME_MAPPINGS.update(ExternalNames) 37 | NODE_DISPLAY_NAME_MAPPINGS.update(AuxilaryNames) 38 | 39 | 40 | try: 41 | from .pystructure import CLASS_MAPPINGS as PyStructureMapping, CLASS_NAMES as PyStructureNames 42 | NODE_CLASS_MAPPINGS.update(PyStructureMapping) 43 | NODE_DISPLAY_NAME_MAPPINGS.update(PyStructureNames) 44 | except ImportError: 45 | pass 46 | try: 47 | from .crypto import CLASS_MAPPINGS as SecureMapping, CLASS_NAMES as SecureNames 48 | NODE_CLASS_MAPPINGS.update(SecureMapping) 49 | 
NODE_DISPLAY_NAME_MAPPINGS.update(SecureNames) 50 | except ImportError: 51 | pass -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | #https://github.com/ltdrdata/ComfyUI-Impact-Pack/blob/Main/install.py 2 | import sys 3 | import subprocess 4 | import threading 5 | import locale 6 | 7 | def handle_stream(stream, is_stdout): 8 | stream.reconfigure(encoding=locale.getpreferredencoding(), errors='replace') 9 | 10 | for msg in stream: 11 | if is_stdout: 12 | print(msg, end="", file=sys.stdout) 13 | else: 14 | print(msg, end="", file=sys.stderr) 15 | 16 | def process_wrap(cmd_str, cwd=None, handler=None): 17 | process = subprocess.Popen(cmd_str, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1) 18 | 19 | if handler is None: 20 | handler = handle_stream 21 | 22 | stdout_thread = threading.Thread(target=handler, args=(process.stdout, True)) 23 | stderr_thread = threading.Thread(target=handler, args=(process.stderr, False)) 24 | 25 | stdout_thread.start() 26 | stderr_thread.start() 27 | 28 | stdout_thread.join() 29 | stderr_thread.join() 30 | 31 | return process.wait() 32 | 33 | if "python_embeded" in sys.executable or "python_embedded" in sys.executable: #standalone python version 34 | pip_install = [sys.executable, '-s', '-m', 'pip', 'install', "-U"] 35 | else: 36 | pip_install = [sys.executable, '-m', 'pip', 'install', "-U"] 37 | 38 | def initialization(): 39 | try: 40 | import piexif 41 | except ImportError: 42 | run_installation("piexif") 43 | try: 44 | import chardet 45 | except ImportError: 46 | run_installation("chardet") 47 | try: 48 | from imgutils.tagging import get_wd14_tags 49 | except ImportError: 50 | run_installation("dghs-imgutils[gpu]") 51 | try: 52 | from Crypto.PublicKey import RSA 53 | except ImportError: 54 | run_installation("pycryptodome") 55 | 56 | 57 | def run_installation(pkg_name: str): 58 | 
print(f"Installing {pkg_name}...") 59 | if process_wrap(pip_install + [pkg_name]) == 0: 60 | print(f"Successfully installed {pkg_name}") 61 | else: 62 | print(f"Failed to install {pkg_name}") 63 | 64 | if __name__ == "__main__": 65 | initialization() 66 | print("Installation completed.") 67 | -------------------------------------------------------------------------------- /conversion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Converts int / float / boolean 3 | Note that it depends on the order of the conversion 4 | """ 5 | from .autonode import validate, node_wrapper, get_node_names_mappings, anytype 6 | classes = [] 7 | node = node_wrapper(classes) 8 | 9 | conversion_operators = { 10 | "Int" : int, 11 | "Float" : float, 12 | "Boolean" : bool, 13 | "String" : str 14 | } 15 | def create_class(type_to): 16 | class_name = "ConvertAny2{}".format(type_to) 17 | class CustomClass: 18 | FUNCTION = "convert" 19 | RETURN_TYPES = (type_to.upper(),) 20 | CATEGORY = "Conversion" 21 | custom_name = "Convert to {}".format(type_to) 22 | @staticmethod 23 | def convert(input1): 24 | return (conversion_operators[type_to](input1),) 25 | @classmethod 26 | def INPUT_TYPES(cls): 27 | return { 28 | "required": { 29 | "input1": (anytype, {"default": 0.0}), 30 | } 31 | } 32 | CustomClass.__name__ = class_name 33 | node(CustomClass) 34 | return CustomClass 35 | 36 | 37 | @node 38 | class StringListToCombo: 39 | """ 40 | Converts raw string, separates with separator, then picks the first element or the element at index 41 | """ 42 | RETURN_TYPES = (anytype,) 43 | @classmethod 44 | def INPUT_TYPES(s): 45 | return { 46 | "required": { 47 | "string": ("STRING", {"default": ""}), 48 | "separator": ("STRING", {"default": "$"}), 49 | }, 50 | "optional": { 51 | "index": ("INT", {"default": 0}), 52 | } 53 | } 54 | FUNCTION = "stringListToCombo" 55 | CATEGORY = "Logic Gates" 56 | custom_name = "String List to Combo" 57 | def stringListToCombo(self, 
string, separator, index = 0): 58 | if isinstance(string, (float, int, bool)): 59 | return (string,) 60 | if separator == "" or separator == None or separator not in string: 61 | return (string,) 62 | # check length 63 | splitted = string.split(separator) 64 | if index >= len(splitted): 65 | return (splitted[-1],) 66 | return (splitted[index],) 67 | 68 | @node 69 | class ConvertComboToString: 70 | """ 71 | Converts raw list to string, separated with separator 72 | """ 73 | RETURN_TYPES = ("STRING",) 74 | @classmethod 75 | def INPUT_TYPES(s): 76 | return { 77 | "required": { 78 | "combo": (anytype, {"default": []}), 79 | "separator": ("STRING", {"default": "$"}), 80 | } 81 | } 82 | FUNCTION = "convertComboToString" 83 | CATEGORY = "Logic Gates" 84 | custom_name = "Convert Combo to String" 85 | def convertComboToString(self, combo, separator): 86 | if isinstance(combo, (str, float, int, bool)): 87 | return (combo,) 88 | return (separator.join(combo),) 89 | 90 | for type_to in conversion_operators: 91 | create_class(type_to) 92 | 93 | CLASS_MAPPINGS, CLASS_NAMES = get_node_names_mappings(classes) 94 | validate(classes) 95 | -------------------------------------------------------------------------------- /external.py: -------------------------------------------------------------------------------- 1 | from .autonode import node_wrapper, get_node_names_mappings, validate, anytype 2 | from .imgio.converter import PILHandlingHodes 3 | from .webuiapi.out_api import get_image_from_prompt, get_image_from_prompt_fallback 4 | 5 | from PIL import Image 6 | 7 | 8 | external_classes = [] 9 | external_nodes = node_wrapper(external_classes) 10 | 11 | 12 | @external_nodes 13 | 14 | class SDWebuiAPINode: 15 | FUNCTION = "get_image_from_prompt" 16 | RETURN_TYPES = ("IMAGE",) 17 | @classmethod 18 | def INPUT_TYPES(cls): 19 | return { 20 | "required": { 21 | "prompt": ("STRING", {"default": ""}), 22 | "api_endpoint": ("STRING", {"default": ""}), 23 | }, 24 | "optional": { 25 | "auth": 
("STRING", {"default": ""}), 26 | "seed": ("INT", {"default": -1}), 27 | "negative_prompt": ("STRING", {"default": ""}), 28 | "steps": ("INT", {"default": 28}), 29 | "width": ("INT", {"default": 1024}), 30 | "height": ("INT", {"default": 1024}), 31 | "hr_scale": ("FLOAT", {"default": 1.5}), 32 | "hr_upscale": ("STRING", {"default": "Latent"}), 33 | "enable_hr": ("BOOLEAN", {"default": False}), 34 | "cfg_scale": ("INT", {"default": 7}), 35 | } 36 | } 37 | CATEGORY = "WebUI API" 38 | custom_name = "Get Image From Prompt" 39 | @PILHandlingHodes.output_wrapper 40 | def get_image_from_prompt(self, prompt, api_endpoint, auth="", seed=-1, negative_prompt="", steps=28, width=1024, height=1024, hr_scale=1.5, hr_upscale="Latent", enable_hr=False, cfg_scale=7): 41 | return (get_image_from_prompt(prompt, api_endpoint, auth, seed, negative_prompt, steps, width, height, hr_scale, hr_upscale, enable_hr, cfg_scale)[0],) 42 | @external_nodes 43 | class SDWebuiAPIFallbackNode: 44 | FUNCTION = "get_image_from_prompt_fallback" 45 | RETURN_TYPES = ("IMAGE",) 46 | @classmethod 47 | def INPUT_TYPES(cls): 48 | return { 49 | "required": { 50 | "prompt": ("STRING", {"default": ""}), 51 | "api_endpoint": ("STRING", {"default": ""}), 52 | }, 53 | "optional": { 54 | "auth": ("STRING", {"default": ""}), 55 | "seed": ("INT", {"default": -1}), 56 | "negative_prompt": ("STRING", {"default": ""}), 57 | "steps": ("INT", {"default": 28}), 58 | "width": ("INT", {"default": 1024}), 59 | "height": ("INT", {"default": 1024}), 60 | "hr_scale": ("FLOAT", {"default": 1.5}), 61 | "hr_upscale": ("STRING", {"default": "Latent"}), 62 | "enable_hr": ("BOOLEAN", {"default": False}), 63 | "cfg_scale": ("INT", {"default": 7}), 64 | } 65 | } 66 | CATEGORY = "WebUI API" 67 | custom_name = "Get Image From Prompt (Fallback)" 68 | @PILHandlingHodes.output_wrapper 69 | def get_image_from_prompt_fallback(self, prompt, api_endpoint, auth="", seed=-1, negative_prompt="", steps=28, width=1024, height=1024, hr_scale=1.5, 
hr_upscale="Latent", enable_hr=False, cfg_scale=7): 70 | return (get_image_from_prompt_fallback(prompt, api_endpoint, auth, seed, negative_prompt, steps, width, height, hr_scale, hr_upscale, enable_hr, cfg_scale)[0],) 71 | 72 | CLASS_MAPPINGS, CLASS_NAMES = get_node_names_mappings(external_classes) 73 | validate(external_classes) -------------------------------------------------------------------------------- /webuiapi/out_api.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | from PIL import Image 4 | import io 5 | import base64 6 | from typing import Optional 7 | 8 | 9 | def send_request(api_endpoint:str, auth:Optional[str], arguments:dict) -> list[Image.Image]: 10 | session = requests.Session() 11 | if auth: 12 | session.auth = (auth.split(":")[0], auth.split(":")[1]) 13 | api_endpoint = api_endpoint.rstrip("/") + "/sdapi/v1/txt2img" 14 | response = session.post(api_endpoint, json=arguments) 15 | response.raise_for_status() 16 | response_json = response.json() 17 | if "images" in response_json.keys(): 18 | images = [Image.open(io.BytesIO(base64.b64decode(i))) for i in response_json["images"]] 19 | elif "image" in response_json.keys(): 20 | images = [Image.open(io.BytesIO(base64.b64decode(response_json["image"])))] 21 | else: 22 | raise ValueError("No image data in response") 23 | return images 24 | 25 | def construct_args( 26 | prompt:str, 27 | seed:int=-1, 28 | negative_prompt:Optional[str] = None, 29 | steps:int = 28, 30 | width:int = 1024, 31 | height:int = 1024, 32 | hr_scale:float = 1.5, 33 | hr_upscale:str = "Latent", 34 | enable_hr:bool = False, 35 | cfg_scale:int = 7, 36 | ): 37 | arguments = { 38 | "prompt": prompt, 39 | "seed": seed, 40 | "steps": steps, 41 | "width": width, 42 | "height": height, 43 | "hr_scale": hr_scale, 44 | "hr_upscale": hr_upscale, 45 | "enable_hr": enable_hr, 46 | "cfg_scale": cfg_scale, 47 | } 48 | if negative_prompt: 49 | arguments["negative_prompt"] = 
negative_prompt 50 | else: 51 | arguments["negative_prompt"] = "" 52 | return arguments 53 | 54 | def get_image_from_prompt( 55 | prompt:str, 56 | api_endpoint:str, 57 | auth:Optional[str]=None, 58 | seed:int=-1, 59 | negative_prompt:Optional[str] = None, 60 | steps:int = 28, 61 | width:int = 1024, 62 | height:int = 1024, 63 | hr_scale:float = 1.5, 64 | hr_upscale:str = "Latent", 65 | enable_hr:bool = False, 66 | cfg_scale:int = 7, 67 | ): 68 | arguments = construct_args( 69 | prompt=prompt, 70 | seed=seed, 71 | negative_prompt=negative_prompt, 72 | steps=steps, 73 | width=width, 74 | height=height, 75 | hr_scale=hr_scale, 76 | hr_upscale=hr_upscale, 77 | enable_hr=enable_hr, 78 | cfg_scale=cfg_scale, 79 | ) 80 | return send_request(api_endpoint, auth, arguments) 81 | 82 | def get_image_from_prompt_fallback( 83 | prompt:str, 84 | api_endpoint:str, 85 | auth:Optional[str]=None, 86 | seed:int=-1, 87 | negative_prompt:Optional[str] = None, 88 | steps:int = 28, 89 | width:int = 1024, 90 | height:int = 1024, 91 | hr_scale:float = 1.5, 92 | hr_upscale:str = "Latent", 93 | enable_hr:bool = False, 94 | cfg_scale:int = 7, 95 | ): 96 | arguments = construct_args( 97 | prompt=prompt, 98 | seed=seed, 99 | negative_prompt=negative_prompt, 100 | steps=steps, 101 | width=width, 102 | height=height, 103 | hr_scale=hr_scale, 104 | hr_upscale=hr_upscale, 105 | enable_hr=enable_hr, 106 | cfg_scale=cfg_scale, 107 | ) 108 | try: 109 | return send_request(api_endpoint, auth, arguments) 110 | except Exception as e: 111 | # create blank image 112 | return [Image.new("RGB", (width, height), (255, 255, 255))] 113 | -------------------------------------------------------------------------------- /autonode.py: -------------------------------------------------------------------------------- 1 | """ 2 | AutoNode setup - AngelBottomless@github 3 | # By following the example, you can prepare "decorator" that will automatically collect required information for node registration. 
4 | fundamental_classes = [] 5 | fundamental_node = node_wrapper(fundamental_classes) 6 | 7 | # Then, you can define the classes that will be used in the node. "FUNCTION", "INPUT_TYPES", "RETURN_TYPES", "CATEGORY" attributes are used for node registration. 8 | # You can set "custom_name" attribute to set the name of the node that will be displayed in the UI. 9 | # Pleare run validate(fundamental_classes) to check if all required attributes are set. 10 | # You can also use anytype to represent any type of input. 11 | @fundamental_node 12 | class SleepNodeAny: 13 | FUNCTION = "sleep" 14 | RETURN_TYPES = (anytype,) 15 | CATEGORY = "Misc" 16 | custom_name = "SleepNode" 17 | @staticmethod 18 | def sleep(interval, inputs): 19 | time.sleep(interval) 20 | return (inputs,) 21 | @classmethod 22 | def INPUT_TYPES(cls): 23 | return { 24 | "required": { 25 | "interval": ("FLOAT", {"default": 0.0}), 26 | }, 27 | "optional": { 28 | "inputs": (anytype, {"default": 0.0}), 29 | } 30 | 31 | # Then, at the end of each node registeration class, run the following to set up static variables. 32 | CLASS_MAPPINGS, CLASS_NAMES = get_node_names_mappings(fundamental_classes) 33 | 34 | # Then, at the other script - here, nodes.py, you can import the CLASS_MAPPINGS and CLASS_NAMES to register the nodes. 35 | from .io_node import CLASS_MAPPINGS as IOMapping, CLASS_NAMES as IONames 36 | 37 | # it collects NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS, and updates them with the new mappings. Note that same keys will be overwritten. 38 | 39 | # Finally, at the __init__.py, you can import the NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS to register the nodes. 40 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 41 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 42 | 43 | # Then, you can use the registered nodes in the UI! 
44 | """ 45 | from inspect import signature 46 | 47 | def get_node_names_mappings(classes): 48 | node_names = {} 49 | node_classes = {} 50 | for cls in classes: 51 | # check if "custom_name" attribute is set 52 | if hasattr(cls, "custom_name"): 53 | node_names[cls.__name__] = cls.custom_name 54 | node_classes[cls.__name__] = cls 55 | return node_classes, node_names 56 | 57 | def node_wrapper(container): 58 | def wrap_class(cls): 59 | container.append(cls) 60 | return cls 61 | return wrap_class 62 | 63 | def validate(container): 64 | # check if "custom_name", "FUNCTION", "INPUT_TYPES", "RETURN_TYPES", "CATEGORY" attributes are set 65 | for cls in container: 66 | for attr in ["FUNCTION", "INPUT_TYPES", "RETURN_TYPES", "CATEGORY"]: 67 | if not hasattr(cls, attr): 68 | raise Exception("Class {} doesn't have attribute {}".format(cls.__name__, attr)) 69 | return_type = cls.RETURN_TYPES 70 | if not isinstance(return_type, tuple): 71 | raise Exception(f"RETURN_TYPES must be a tuple, got {type(return_type)} in {cls.__name__}") 72 | if not all(isinstance(x, str) for x in return_type): 73 | raise Exception(f"RETURN_TYPES must be a tuple of strings, got {return_type} in {cls.__name__}") 74 | input_keys = ["self"] 75 | for key in cls.INPUT_TYPES()["required"]: 76 | input_keys.append(key) 77 | for key in cls.INPUT_TYPES().get("optional", {}): 78 | input_keys.append(key) 79 | function_kwargs = signature(cls.__dict__[cls.FUNCTION]).parameters.keys() 80 | function_kwargs = list(function_kwargs) + ["self"] 81 | # if args/kwargs are in function kwargs, warn and skip 82 | if "args" in function_kwargs: 83 | #print(f"Warning: args in function arguments in {cls.__name__}, skipping argument validation") 84 | continue 85 | if "kwargs" in function_kwargs: 86 | #print(f"Warning: kwargs in function arguments in {cls.__name__}, skipping argument validation") 87 | continue 88 | # input kwargs are subset of function kwargs 89 | if not set(input_keys).issubset(function_kwargs): 90 | raise 
Exception(f"INPUT_TYPES and function arguments must match in {cls.__name__}, input_types: {input_keys}, function arguments: {function_kwargs}") 91 | # if not exact match, print warning 92 | if len(set(input_keys)) != len(set(function_kwargs)): 93 | #print(f"Warning: INPUT_TYPES and function arguments don't match in {cls.__name__}, input_types: {input_keys}, function arguments: {function_kwargs}") 94 | pass 95 | # AllTrue class hijacks the isinstance, issubclass, bool, str, jsonserializable, eq, ne methods to always return True 96 | class AllTrue(str): 97 | def __init__(self, representation=None) -> None: 98 | self.repr = representation 99 | pass 100 | def __ne__(self, __value: object) -> bool: 101 | return False 102 | # isinstance, jsonserializable hijack 103 | def __instancecheck__(self, instance): 104 | return True 105 | def __subclasscheck__(self, subclass): 106 | return True 107 | def __bool__(self): 108 | return True 109 | def __str__(self): 110 | return self.repr 111 | # jsonserializable hijack 112 | def __jsonencode__(self): 113 | return self.repr 114 | def __repr__(self) -> str: 115 | return self.repr 116 | def __eq__(self, __value: object) -> bool: 117 | return True 118 | anytype = AllTrue("*") # when a != b is called, it will always return False 119 | PILImage = object() # dummy object to represent PIL.Image -------------------------------------------------------------------------------- /crypto.py: -------------------------------------------------------------------------------- 1 | import os, io 2 | import numpy as np 3 | from .imgio.converter import PILHandlingHodes 4 | from .autonode import node_wrapper, get_node_names_mappings, validate 5 | 6 | from PIL import Image 7 | try: 8 | from Crypto.PublicKey import RSA 9 | from Crypto.Cipher import AES, PKCS1_OAEP 10 | from Crypto.Random import get_random_bytes 11 | except ImportError: 12 | print("Crypto library not found. 
Please install pycryptodome.") 13 | raise 14 | 15 | import torch 16 | from base64 import b64encode, b64decode 17 | 18 | # List of classes to register 19 | secure_classes = [] 20 | secure_node = node_wrapper(secure_classes) 21 | 22 | @secure_node 23 | class SecureBase64Encrypt: 24 | """ 25 | Encrypt an image as a base64 string using RSA public key + AES. 26 | - images: Only the first image is used. 27 | - public_key_pem: RSA public key (PEM string). 28 | Outputs: 'encrypted_base64' string that SecureWebPDecrypt can decrypt. 29 | """ 30 | @classmethod 31 | def INPUT_TYPES(cls): 32 | return { 33 | "required": { 34 | "images": ("IMAGE",), 35 | "public_key_pem": ("STRING", {"multiline": True, "default": ""}), 36 | } 37 | } 38 | 39 | RETURN_TYPES = ("STRING",) 40 | RETURN_NAMES = ("encrypted_base64",) 41 | FUNCTION = "encrypted_base64" 42 | CATEGORY = "image" 43 | custom_name = "Secure Base64 Encrypt" 44 | OUTPUT_NODE = True 45 | RESULT_NODE = True 46 | 47 | def encrypted_base64(self, images, public_key_pem): 48 | # Check input 49 | if images is None or not len(images) or images[0] is None: 50 | raise ValueError("No image provided.") 51 | # Load RSA public key 52 | rsa_key = RSA.import_key(public_key_pem) 53 | cipher_rsa = PKCS1_OAEP.new(rsa_key) 54 | 55 | # Take first image (if there's a batch dimension, pick index [0]) 56 | img_tensor = images[0].clone().detach() 57 | if img_tensor.ndim == 4: 58 | img_tensor = img_tensor[0] 59 | 60 | # Convert [0..1] float => [0..255] uint8 61 | img_array = (255.0 * img_tensor.clamp(0, 1).cpu().numpy()).astype("uint8") 62 | 63 | # If shape is (C,H,W), transpose to (H,W,C). If shape is (H,W,C), leave as is. 
64 | if img_array.ndim == 3: 65 | # If the first dimension is small (1,3,4), interpret as channels-first 66 | if img_array.shape[0] in [1,3,4] and img_array.shape[-1] not in [1,3,4]: 67 | img_array = np.transpose(img_array, (1,2,0)) 68 | 69 | # The critical fix: ensure array is contiguous 70 | img_array = np.ascontiguousarray(img_array) 71 | 72 | # Create a PIL Image from array 73 | pil_img = Image.fromarray(img_array, mode="RGB") 74 | 75 | # Save in-memory as lossless WebP 76 | buffer = io.BytesIO() 77 | pil_img.save(buffer, format="WEBP", lossless=True) 78 | image_bytes = buffer.getvalue() 79 | 80 | # Generate a random AES session key, encrypt it with RSA 81 | session_key = get_random_bytes(16) 82 | enc_session_key = cipher_rsa.encrypt(session_key) 83 | 84 | # Encrypt the image bytes with AES-EAX 85 | cipher_aes = AES.new(session_key, AES.MODE_EAX) 86 | ciphertext, tag = cipher_aes.encrypt_and_digest(image_bytes) 87 | 88 | # Build custom envelope 89 | encrypted_blob = ( 90 | b"ENCWEBP" + 91 | len(enc_session_key).to_bytes(2, "big") + 92 | enc_session_key + 93 | bytes([len(cipher_aes.nonce)]) + cipher_aes.nonce + 94 | bytes([len(tag)]) + tag + 95 | ciphertext 96 | ) 97 | 98 | # Base64-encode the blob 99 | encrypted_base64 = b64encode(encrypted_blob).decode("utf-8") 100 | 101 | return (encrypted_base64,) 102 | 103 | 104 | @secure_node 105 | class SecureWebPDecrypt: 106 | """ 107 | Decrypt an encrypted WebP image (or list of them) produced by SecureBase64Encrypt. 108 | Returns a single IMAGE (first one). 
109 | """ 110 | @classmethod 111 | def INPUT_TYPES(cls): 112 | return { 113 | "required": { 114 | "encrypted_base64": ("STRING", {"multiline": True, "default": ""}), 115 | "private_key_pem": ("STRING", {"multiline": True, "default": ""}), 116 | } 117 | } 118 | 119 | RETURN_TYPES = ("IMAGE",) 120 | RETURN_NAMES = ("Decrypted_Image",) 121 | FUNCTION = "decrypt_image" 122 | CATEGORY = "image" 123 | custom_name = "Secure WebP Decrypt" 124 | @PILHandlingHodes.output_wrapper 125 | def decrypt_image(self, encrypted_base64, private_key_pem): 126 | # Convert to list if single string 127 | if encrypted_base64 is None: 128 | encrypted_base64 = [] 129 | if isinstance(encrypted_base64, str): 130 | encrypted_base64 = [encrypted_base64] 131 | elif not isinstance(encrypted_base64, (list, tuple)): 132 | raise ValueError("encrypted_base64 must be string or list/tuple.") 133 | 134 | # Import RSA private key 135 | if isinstance(private_key_pem, str): 136 | private_key_pem = private_key_pem.encode("utf-8") 137 | rsa_key = RSA.import_key(private_key_pem) 138 | cipher_rsa = PKCS1_OAEP.new(rsa_key) 139 | 140 | for b64_item in encrypted_base64: 141 | data = b64decode(b64_item) 142 | if data[:7] != b"ENCWEBP": 143 | raise ValueError("Invalid encrypted WebP data (missing header).") 144 | 145 | idx = 7 146 | enc_key_len = int.from_bytes(data[idx : idx + 2], "big") 147 | idx += 2 148 | 149 | enc_session_key = data[idx : idx + enc_key_len] 150 | idx += enc_key_len 151 | 152 | nonce_len = data[idx] 153 | idx += 1 154 | nonce = data[idx : idx + nonce_len] 155 | idx += nonce_len 156 | 157 | tag_len = data[idx] 158 | idx += 1 159 | tag = data[idx : idx + tag_len] 160 | idx += tag_len 161 | 162 | ciphertext = data[idx:] 163 | 164 | # RSA-decrypt the AES session key 165 | session_key = cipher_rsa.decrypt(enc_session_key) 166 | 167 | # AES-EAX decrypt 168 | cipher_aes = AES.new(session_key, AES.MODE_EAX, nonce=nonce) 169 | plaintext = cipher_aes.decrypt(ciphertext) 170 | try: 171 | 
cipher_aes.verify(tag) 172 | except ValueError: 173 | raise ValueError("Decryption failed: data tampered or wrong key.") 174 | 175 | # plaintext -> PIL 176 | pil_img = Image.open(io.BytesIO(plaintext)).convert("RGB") 177 | return (pil_img, ) 178 | 179 | 180 | # Register node classes with ComfyUI 181 | CLASS_MAPPINGS, CLASS_NAMES = get_node_names_mappings(secure_classes) 182 | validate(secure_classes) 183 | -------------------------------------------------------------------------------- /exif/exif.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gzip 3 | try: 4 | import piexif 5 | except ImportError: 6 | piexif = None 7 | from PIL import Image 8 | 9 | # modules/images.py from Stable Diffusion WebUI 10 | def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]: 11 | if piexif is None: 12 | return None, {} 13 | items = (image.info or {}).copy() 14 | 15 | geninfo = items.pop('parameters', None) 16 | 17 | if "exif" in items: 18 | exif = piexif.load(items["exif"]) 19 | exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'') 20 | try: 21 | exif_comment = piexif.helper.UserComment.load(exif_comment) 22 | except ValueError: 23 | exif_comment = exif_comment.decode('utf8', errors="ignore") 24 | 25 | if exif_comment: 26 | items['exif comment'] = exif_comment 27 | geninfo = exif_comment 28 | 29 | if items.get("Software", None) == "NovelAI": 30 | try: 31 | json_info = json.loads(items["Comment"]) 32 | geninfo = f"""{items["Description"]} 33 | Negative prompt: {json_info["uc"]} 34 | Steps: {json_info["steps"]}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" 35 | except Exception: 36 | print("Failed to parse NovelAI info") 37 | 38 | return geninfo, items 39 | 40 | # https://github.com/ashen-sensored/sd_webui_stealth_pnginfo/blob/main/scripts/stealth_pnginfo.py 41 | def 
read_info_from_image_stealth(image): 42 | # if tensor, convert to PIL image 43 | if hasattr(image, 'cpu'): 44 | image = image.cpu().numpy() #((1, 1, 1280, 3), ' 120 and never_confirmed: 79 | return '' 80 | if index_a == len('stealth_pnginfo') * 8: 81 | decoded_sig = bytearray(int(buffer_a[i:i + 8], 2) for i in 82 | range(0, len(buffer_a), 8)).decode('utf-8', errors='ignore') 83 | if decoded_sig in {'stealth_pnginfo', 'stealth_pngcomp'}: 84 | #print(f"Found signature at {x}, {y}") 85 | confirming_signature = False 86 | sig_confirmed = True 87 | reading_param_len = True 88 | mode = 'alpha' 89 | if decoded_sig == 'stealth_pngcomp': 90 | compressed = True 91 | buffer_a = '' 92 | index_a = 0 93 | never_confirmed = False 94 | else: 95 | read_end = True 96 | break 97 | elif index_rgb == len('stealth_pnginfo') * 8: 98 | decoded_sig = bytearray(int(buffer_rgb[i:i + 8], 2) for i in 99 | range(0, len(buffer_rgb), 8)).decode('utf-8', errors='ignore') 100 | if decoded_sig in {'stealth_rgbinfo', 'stealth_rgbcomp'}: 101 | #print(f"Found signature at {x}, {y}") 102 | confirming_signature = False 103 | sig_confirmed = True 104 | reading_param_len = True 105 | mode = 'rgb' 106 | if decoded_sig == 'stealth_rgbcomp': 107 | compressed = True 108 | buffer_rgb = '' 109 | index_rgb = 0 110 | never_confirmed = False 111 | elif reading_param_len: 112 | if mode == 'alpha': 113 | if index_a == 32: 114 | param_len = int(buffer_a, 2) 115 | reading_param_len = False 116 | reading_param = True 117 | buffer_a = '' 118 | index_a = 0 119 | else: 120 | if index_rgb == 33: 121 | pop = buffer_rgb[-1] 122 | buffer_rgb = buffer_rgb[:-1] 123 | param_len = int(buffer_rgb, 2) 124 | reading_param_len = False 125 | reading_param = True 126 | buffer_rgb = pop 127 | index_rgb = 1 128 | elif reading_param: 129 | if mode == 'alpha': 130 | if index_a == param_len: 131 | binary_data = buffer_a 132 | read_end = True 133 | break 134 | else: 135 | if index_rgb >= param_len: 136 | diff = param_len - index_rgb 137 | if 
diff < 0: 138 | buffer_rgb = buffer_rgb[:diff] 139 | binary_data = buffer_rgb 140 | read_end = True 141 | break 142 | else: 143 | # impossible 144 | read_end = True 145 | break 146 | if read_end: 147 | break 148 | geninfo = '' 149 | if sig_confirmed and binary_data != '': 150 | # Convert binary string to UTF-8 encoded text 151 | byte_data = bytearray(int(binary_data[i:i + 8], 2) for i in range(0, len(binary_data), 8)) 152 | try: 153 | if compressed: 154 | decoded_data = gzip.decompress(bytes(byte_data)).decode('utf-8') 155 | else: 156 | decoded_data = byte_data.decode('utf-8', errors='ignore') 157 | geninfo = decoded_data 158 | except: 159 | pass 160 | return str(geninfo) 161 | 162 | -------------------------------------------------------------------------------- /math_nodes.py: -------------------------------------------------------------------------------- 1 | import random 2 | from .autonode import node_wrapper, get_node_names_mappings, validate, anytype 3 | classes = [] 4 | node = node_wrapper(classes) 5 | import math 6 | 7 | @node 8 | class MinNode: 9 | """ 10 | Returns the minimum of two values 11 | """ 12 | RETURN_TYPES = (anytype,) 13 | @classmethod 14 | def INPUT_TYPES(s): 15 | return { 16 | "required": { 17 | "input1": (anytype,), 18 | "input2": (anytype,), 19 | } 20 | } 21 | FUNCTION = "min" 22 | CATEGORY = "Math" 23 | custom_name = "Min" 24 | def min(self, input1, input2): 25 | return (min(input1, input2),) 26 | 27 | @node 28 | class MaxNode: 29 | """ 30 | Returns the maximum of two values 31 | """ 32 | RETURN_TYPES = (anytype,) 33 | @classmethod 34 | def INPUT_TYPES(s): 35 | return { 36 | "required": { 37 | "input1": (anytype,), 38 | "input2": (anytype,), 39 | } 40 | } 41 | FUNCTION = "max" 42 | CATEGORY = "Math" 43 | custom_name = "Max" 44 | def max(self, input1, input2): 45 | return (max(input1, input2),) 46 | 47 | @node 48 | class RoundNode: 49 | """ 50 | Rounds a value to the nearest integer 51 | """ 52 | RETURN_TYPES = ("INT",) 53 | @classmethod 
54 | def INPUT_TYPES(s): 55 | return { 56 | "required": { 57 | "input1": (anytype,), 58 | } 59 | } 60 | FUNCTION = "round" 61 | CATEGORY = "Math" 62 | custom_name = "Round" 63 | def round(self, input1): 64 | return (round(input1),) 65 | 66 | @node 67 | class AbsNode: 68 | """ 69 | Returns the absolute value of a number 70 | """ 71 | RETURN_TYPES = (anytype,) 72 | @classmethod 73 | def INPUT_TYPES(s): 74 | return { 75 | "required": { 76 | "input1": (anytype,), 77 | } 78 | } 79 | FUNCTION = "abs" 80 | CATEGORY = "Math" 81 | custom_name = "Abs" 82 | def abs(self, input1): 83 | return (abs(input1),) 84 | 85 | @node 86 | class FloorNode: 87 | """ 88 | Returns the floor of a number 89 | """ 90 | RETURN_TYPES = ("INT",) 91 | @classmethod 92 | def INPUT_TYPES(s): 93 | return { 94 | "required": { 95 | "input1": (anytype,), 96 | } 97 | } 98 | FUNCTION = "floor" 99 | CATEGORY = "Math" 100 | custom_name = "Floor" 101 | def floor(self, input1): 102 | return (math.floor(input1),) 103 | 104 | @node 105 | class CeilNode: 106 | """ 107 | Returns the ceiling of a number 108 | """ 109 | RETURN_TYPES = ("INT",) 110 | @classmethod 111 | def INPUT_TYPES(s): 112 | return { 113 | "required": { 114 | "input1": (anytype,), 115 | } 116 | } 117 | FUNCTION = "ceil" 118 | CATEGORY = "Math" 119 | custom_name = "Ceil" 120 | def ceil(self, input1): 121 | return (math.ceil(input1),) 122 | 123 | @node 124 | class PowerNode: 125 | """ 126 | Returns the power of a number 127 | """ 128 | RETURN_TYPES = (anytype,) 129 | @classmethod 130 | def INPUT_TYPES(s): 131 | return { 132 | "required": { 133 | "input1": (anytype,), 134 | "power": (anytype,), 135 | } 136 | } 137 | FUNCTION = "power" 138 | CATEGORY = "Math" 139 | custom_name = "Power" 140 | def power(self, input1, power): 141 | # validate power with log scale, prevent overflow 142 | log_val = math.log(abs(input1), 10) 143 | if log_val * power > 100 or log_val == 0: 144 | raise OverflowError("Power is too large, exceeds 100 digits") 145 | return 
(math.pow(input1, power),) 146 | 147 | @node 148 | class SigmoidNode: 149 | """ 150 | Returns the sigmoid of a number 151 | """ 152 | RETURN_TYPES = ("FLOAT",) 153 | @classmethod 154 | def INPUT_TYPES(s): 155 | return { 156 | "required": { 157 | "input1": ("FLOAT",), 158 | } 159 | } 160 | FUNCTION = "sigmoid" 161 | CATEGORY = "Math" 162 | custom_name = "Sigmoid" 163 | def sigmoid(self, input1): 164 | return (1 / (1 + math.exp(-input1)),) 165 | def is_prime_small(n: int) -> bool: 166 | """ 167 | Deterministic check for primality for smaller n. 168 | Skips multiples of 2 and 3, then checks i, i+2, i+4 up to sqrt(n). 169 | """ 170 | if n < 2: 171 | return False 172 | if n in (2, 3): 173 | return True 174 | if n % 2 == 0 or n % 3 == 0: 175 | return n == 2 or n == 3 176 | 177 | # 6k ± 1 optimization 178 | limit = int(math.isqrt(n)) # integer sqrt 179 | i = 5 180 | while i <= limit: 181 | if n % i == 0 or n % (i + 2) == 0: 182 | return False 183 | i += 6 184 | return True 185 | 186 | def miller_rabin_test(d: int, n: int) -> bool: 187 | """ One round of the Miller-Rabin test with a random base 'a'. """ 188 | a = random.randrange(2, n - 1) 189 | x = pow(a, d, n) # a^d % n 190 | if x == 1 or x == n - 1: 191 | return True 192 | 193 | # Keep squaring x while d does not reach n-1 194 | while d != n - 1: 195 | x = (x * x) % n 196 | d <<= 1 # d *= 2 197 | if x == 1: 198 | return False 199 | if x == n - 1: 200 | return True 201 | return False 202 | 203 | def is_prime_miller_rabin(n: int, k: int = 5) -> bool: 204 | """ 205 | Miller-Rabin primality test with k rounds (probabilistic). 206 | Good enough for big integers in practice. 
207 | """ 208 | # Handle small or trivial cases 209 | if n < 2: 210 | return False 211 | # check small primes quickly 212 | for small_prime in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]: 213 | if n == small_prime: 214 | return True 215 | if n % small_prime == 0 and n != small_prime: 216 | return False 217 | 218 | # Write n - 1 as d * 2^r 219 | d = n - 1 220 | while d % 2 == 0: 221 | d //= 2 222 | 223 | # Witness loop 224 | for _ in range(k): 225 | if not miller_rabin_test(d, n): 226 | return False 227 | return True 228 | 229 | @node 230 | class IsPrimeNode: 231 | """ 232 | Checks if an integer is prime. 233 | 234 | - If the integer |value| < threshold, uses a deterministic small-check (trial division). 235 | - Otherwise, uses a Miller-Rabin pseudoprime test for a faster check (probabilistic). 236 | 237 | Returns a BOOLEAN (True if prime, False if composite). 238 | """ 239 | FUNCTION = "is_prime" 240 | RETURN_TYPES = ("BOOLEAN",) 241 | CATEGORY = "Math" 242 | custom_name = "Is Prime?" 243 | 244 | @staticmethod 245 | def is_prime(value: int, threshold: int = 10_000_000, miller_rabin_rounds: int = 5): 246 | # handle negative or zero 247 | if value < 2: 248 | return (False,) 249 | 250 | if value < threshold: 251 | # use small prime check 252 | return (is_prime_small(value),) 253 | else: 254 | # use Miller-Rabin 255 | return (is_prime_miller_rabin(value, k=miller_rabin_rounds),) 256 | 257 | @classmethod 258 | def INPUT_TYPES(cls): 259 | return { 260 | "required": { 261 | "value": ("INT", {"default": 1, "min": -9999999999, "max": 9999999999, "step": 1}), 262 | }, 263 | "optional": { 264 | "threshold": ("INT", {"default": 10_000_000, "min": 1, "max": 9999999999, "step": 1}), 265 | "miller_rabin_rounds": ("INT", {"default": 5, "min": 1, "max": 50, "step": 1}), 266 | } 267 | } 268 | @node 269 | class RAMPNode: 270 | """ 271 | Returns the ramp of a number 272 | """ 273 | RETURN_TYPES = ("FLOAT",) 274 | @classmethod 275 | def INPUT_TYPES(s): 276 | return { 277 | "required": { 278 | 
"input1": ("FLOAT",), 279 | } 280 | } 281 | FUNCTION = "ramp" 282 | CATEGORY = "Math" 283 | custom_name = "RAMP" 284 | def ramp(self, input1): 285 | return (max(0, input1),) 286 | 287 | class ModuloNode: 288 | """ 289 | Returns the modulo of a number 290 | """ 291 | RETURN_TYPES = ("INT",) 292 | @classmethod 293 | def INPUT_TYPES(s): 294 | return { 295 | "required": { 296 | "input1": ("INT",), 297 | "modulo": ("INT",), 298 | } 299 | } 300 | FUNCTION = "modulo" 301 | CATEGORY = "Math" 302 | custom_name = "Modulo" 303 | def modulo(self, input1, modulo): 304 | return (input1 % modulo,) 305 | 306 | @node 307 | class LogNode: 308 | """ 309 | Returns the log of a number 310 | """ 311 | RETURN_TYPES = ("FLOAT",) 312 | @classmethod 313 | def INPUT_TYPES(s): 314 | return { 315 | "required": { 316 | "input1": ("FLOAT",), 317 | "base": ("FLOAT",), 318 | } 319 | } 320 | FUNCTION = "log" 321 | CATEGORY = "Math" 322 | custom_name = "Log" 323 | def log(self, input1, base): 324 | return (math.log(input1, base),) 325 | 326 | @node 327 | class MultiplyNode: 328 | """ 329 | Returns the product of two numbers 330 | """ 331 | RETURN_TYPES = (anytype,) 332 | @classmethod 333 | def INPUT_TYPES(s): 334 | return { 335 | "required": { 336 | "input1": (anytype,), 337 | "input2": (anytype,), 338 | } 339 | } 340 | FUNCTION = "multiply" 341 | CATEGORY = "Math" 342 | custom_name = "Multiply" 343 | def multiply(self, input1, input2): 344 | return (input1 * input2,) 345 | 346 | @node 347 | class DivideNode: 348 | """ 349 | Returns the quotient of two numbers 350 | """ 351 | RETURN_TYPES = ("FLOAT",) 352 | @classmethod 353 | def INPUT_TYPES(s): 354 | return { 355 | "required": { 356 | "input1": (anytype,), 357 | "input2": (anytype,), 358 | } 359 | } 360 | FUNCTION = "divide" 361 | CATEGORY = "Math" 362 | custom_name = "Divide" 363 | def divide(self, input1, input2): 364 | if input2 == 0: 365 | raise ZeroDivisionError("Cannot divide by zero") 366 | return (input1 / input2,) 367 | 368 | 
validate(classes) 369 | CLASS_MAPPINGS, CLASS_NAMES = get_node_names_mappings(classes) 370 | -------------------------------------------------------------------------------- /logic_gates.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements logic gate nodes 3 | """ 4 | import re 5 | from .autonode import node_wrapper, get_node_names_mappings, validate, anytype 6 | 7 | 8 | classes = [] 9 | node = node_wrapper(classes) 10 | 11 | @node 12 | class LogicGateCompare: 13 | """ 14 | Returns 1 if input1 > input2, 0 otherwise 15 | """ 16 | RETURN_TYPES = ("BOOLEAN",) 17 | @classmethod 18 | def INPUT_TYPES(s): 19 | return { 20 | "required": { 21 | "input1": (anytype, {"default": 0.0}), 22 | "input2": (anytype, {"default": 0.0}), 23 | } 24 | } 25 | FUNCTION = "compareFloat" 26 | CATEGORY = "Logic Gates" 27 | custom_name = "ABiggerThanB" 28 | def compareFloat(self, input1, input2): 29 | return (True if input1 > input2 else False,) 30 | @node 31 | class LogicGateInvertBasic: 32 | """ 33 | Inverts 1 to 0 and 0 to 1 34 | """ 35 | RETURN_TYPES = (anytype,) 36 | @classmethod 37 | def INPUT_TYPES(s): 38 | return { 39 | "required": { 40 | "input1": (anytype, {"default": 0}), 41 | } 42 | } 43 | FUNCTION = "invert" 44 | CATEGORY = "Logic Gates" 45 | custom_name = "Invert Basic" 46 | def invert(self, input1): 47 | return (True if not input1 else False,) 48 | @node 49 | class LogicGateNegateValue: 50 | """ 51 | Inverts x -> -x 52 | """ 53 | RETURN_TYPES = (anytype,) 54 | @classmethod 55 | def INPUT_TYPES(s): 56 | return { 57 | "required": { 58 | "input1": (anytype, {"default": 0}), 59 | } 60 | } 61 | FUNCTION = "invertValue" 62 | CATEGORY = "Logic Gates" 63 | custom_name = "Negate Value" 64 | def invertValue(self, input1): 65 | return (-input1,) 66 | @node 67 | class LogicGateBitwiseShift: 68 | """ 69 | Shifts input1 by input2 bits 70 | Only works on integers 71 | Negative input2 shifts right, positive input2 shifts left 72 | """ 73 | 
RETURN_TYPES = ("INT",) 74 | @classmethod 75 | def INPUT_TYPES(s): 76 | return { 77 | "required": { 78 | "input1": ("INT", {"default": 0}), 79 | "input2": ("INT", {"default": 0}), 80 | } 81 | } 82 | FUNCTION = "bitwiseShift" 83 | CATEGORY = "Logic Gates" 84 | custom_name = "Bitwise Shift" 85 | def bitwiseShift(self, input1, input2): 86 | # validate input2 87 | if abs(input2) > 32: 88 | raise ValueError("input2 must be between -32 and 32") 89 | return (input1 << input2,) 90 | @node 91 | class LogicGateBitwiseAnd: 92 | """ 93 | Bitwise AND of input1 and input2 94 | Only works on integers 95 | """ 96 | RETURN_TYPES = ("INT",) 97 | @classmethod 98 | def INPUT_TYPES(s): 99 | return { 100 | "required": { 101 | "input1": ("INT", {"default": 0}), 102 | "input2": ("INT", {"default": 0}), 103 | } 104 | } 105 | FUNCTION = "bitwiseAnd" 106 | CATEGORY = "Logic Gates" 107 | custom_name = "Bitwise And" 108 | def bitwiseAnd(self, input1, input2): 109 | return (input1 & input2,) 110 | @node 111 | class LogicGateBitwiseOr: 112 | """ 113 | Bitwise OR of input1 and input2 114 | Only works on integers 115 | """ 116 | RETURN_TYPES = ("INT",) 117 | @classmethod 118 | def INPUT_TYPES(s): 119 | return { 120 | "required": { 121 | "input1": ("INT", {"default": 0}), 122 | "input2": ("INT", {"default": 0}), 123 | } 124 | } 125 | FUNCTION = "bitwiseOr" 126 | CATEGORY = "Logic Gates" 127 | custom_name = "Bitwise Or" 128 | def bitwiseOr(self, input1, input2): 129 | return (input1 | input2,) 130 | @node 131 | class LogicGateBitwiseXor: 132 | """ 133 | Bitwise XOR of input1 and input2 134 | Only works on integers 135 | """ 136 | RETURN_TYPES = ("INT",) 137 | @classmethod 138 | def INPUT_TYPES(s): 139 | return { 140 | "required": { 141 | "input1": ("INT", {"default": 0}), 142 | "input2": ("INT", {"default": 0}), 143 | } 144 | } 145 | FUNCTION = "bitwiseXor" 146 | CATEGORY = "Logic Gates" 147 | custom_name = "Bitwise Xor" 148 | def bitwiseXor(self, input1, input2): 149 | return (input1 ^ input2,) 150 
| @node 151 | class LogicGateBitwiseNot: 152 | """ 153 | Bitwise NOT of input1 154 | Only works on integers 155 | """ 156 | RETURN_TYPES = ("INT",) 157 | @classmethod 158 | def INPUT_TYPES(s): 159 | return { 160 | "required": { 161 | "input1": ("INT", {"default": 0}), 162 | } 163 | } 164 | FUNCTION = "bitwiseNot" 165 | CATEGORY = "Logic Gates" 166 | custom_name = "Bitwise Not" 167 | def bitwiseNot(self, input1): 168 | return (~input1,) 169 | @node 170 | class LogicGateCompare: 171 | """ 172 | Returns 1 if input1 > input2, 0 otherwise 173 | """ 174 | RETURN_TYPES = ("BOOLEAN",) 175 | @classmethod 176 | def INPUT_TYPES(s): 177 | return { 178 | "required": { 179 | "input1": (anytype, {"default": 0}), 180 | "input2": (anytype, {"default": 0}), 181 | } 182 | } 183 | FUNCTION = "compareInt" 184 | CATEGORY = "Logic Gates" 185 | custom_name = "ABiggerThanB" 186 | def compareInt(self, input1, input2): 187 | return (True if input1 > input2 else False,) 188 | @node 189 | class LogicGateCompareString: 190 | """ 191 | Returns if given regex (1) is found in given string (2) 192 | """ 193 | RETURN_TYPES = ("BOOLEAN",) 194 | @classmethod 195 | def INPUT_TYPES(s): 196 | return { 197 | "required": { 198 | "regex": ("STRING", {"default": ""}), 199 | "input2": ("STRING", {"default": ""}), 200 | } 201 | } 202 | FUNCTION = "compareString" 203 | CATEGORY = "Logic Gates" 204 | custom_name = "AContainsB(String)" 205 | def compareString(self, regex, input2): 206 | return (True if re.search(regex, input2) else False,) 207 | 208 | @node 209 | class GetLengthString: 210 | """ 211 | Returns the length of the input string 212 | """ 213 | RETURN_TYPES = ("INT",) 214 | @classmethod 215 | def INPUT_TYPES(s): 216 | return { 217 | "required": { 218 | "string": ("STRING", {"default": ""}), 219 | } 220 | } 221 | FUNCTION = "lengthString" 222 | CATEGORY = "Logic Gates" 223 | custom_name = "Length of String" 224 | def lengthString(self, string): 225 | return (len(string),) 226 | 227 | @node 228 | class 
StaticNumberInt: 229 | """ 230 | Returns a static number 231 | """ 232 | RETURN_TYPES = ("INT",) 233 | @classmethod 234 | def INPUT_TYPES(s): 235 | return { 236 | "required": { 237 | "number": ("INT", {"default": 0}), 238 | } 239 | } 240 | FUNCTION = "staticNumber" 241 | CATEGORY = "Logic Gates" 242 | custom_name = "Static Number Int" 243 | def staticNumber(self, number): 244 | return (number,) 245 | @node 246 | class StaticNumberFloat: 247 | """ 248 | Returns a static number 249 | """ 250 | RETURN_TYPES = ("FLOAT",) 251 | @classmethod 252 | def INPUT_TYPES(s): 253 | return { 254 | "required": { 255 | "number": ("FLOAT", {"default": 0.0}), 256 | } 257 | } 258 | FUNCTION = "staticNumber" 259 | CATEGORY = "Logic Gates" 260 | custom_name = "Static Number Float" 261 | def staticNumber(self, number): 262 | return (number,) 263 | @node 264 | class StaticString: 265 | """ 266 | Returns a static string 267 | """ 268 | RETURN_TYPES = ("STRING",) 269 | @classmethod 270 | def INPUT_TYPES(s): 271 | return { 272 | "required": { 273 | "string": ("STRING", {"default": ""}), 274 | } 275 | } 276 | FUNCTION = "staticString" 277 | CATEGORY = "Logic Gates" 278 | custom_name = "Static String" 279 | def staticString(self, string): 280 | return (string,) 281 | @node 282 | class LogicGateAnd: 283 | """ 284 | Returns 1 if all inputs are True, 0 otherwise 285 | """ 286 | RETURN_TYPES = ("BOOLEAN",) 287 | @classmethod 288 | def INPUT_TYPES(s): 289 | return { 290 | "required": { 291 | "input1": (anytype, {"default": 0.0}), 292 | "input2": (anytype, {"default": 0.0}), 293 | } 294 | } 295 | FUNCTION = "and_" 296 | CATEGORY = "Logic Gates" 297 | custom_name = "AAndBGate" 298 | def and_(self, input1, input2): 299 | return (True if input1 and input2 else False,) 300 | @node 301 | class LogicGateOr: 302 | """ 303 | Returns 1 if any input is True, 0 otherwise 304 | """ 305 | RETURN_TYPES = ("BOOLEAN",) 306 | @classmethod 307 | def INPUT_TYPES(s): 308 | return { 309 | "required": { 310 | "input1": 
(anytype, {"default": 0}), 311 | "input2": (anytype, {"default": 0}), 312 | } 313 | } 314 | FUNCTION = "or_" 315 | CATEGORY = "Logic Gates" 316 | custom_name = "AOrBGate" 317 | def or_(self, input1, input2): 318 | return (True if input1 or input2 else False,) 319 | @node 320 | class LogicGateEither: 321 | """ 322 | Returns input1 if condition is true, input2 otherwise 323 | """ 324 | RETURN_TYPES = (anytype,) 325 | @classmethod 326 | def INPUT_TYPES(s): 327 | return { 328 | "required": { 329 | "condition": (anytype, {"default": 0}), 330 | "input1": (anytype, {"default": ""}), 331 | "input2": (anytype, {"default": ""}), 332 | } 333 | } 334 | FUNCTION = "either" 335 | CATEGORY = "Logic Gates" 336 | custom_name = "ReturnAorBValue" 337 | def either(self, condition, input1, input2): 338 | return (input1 if condition else input2,) 339 | @node 340 | class AddNode: 341 | """ 342 | Returns the sum of the inputs 343 | """ 344 | RETURN_TYPES = (anytype,) 345 | @classmethod 346 | def INPUT_TYPES(s): 347 | return { 348 | "required": { 349 | "input1": (anytype, {"default": 0}), 350 | "input2": (anytype, {"default": 0}), 351 | } 352 | } 353 | FUNCTION = "add" 354 | CATEGORY = "Logic Gates" 355 | custom_name = "Add Values" 356 | def add(self, input1, input2): 357 | return (input1 + input2,) 358 | @node 359 | class MergeString: 360 | """ 361 | Returns the concatenation of the inputs 362 | """ 363 | RETURN_TYPES = ("STRING",) 364 | @classmethod 365 | def INPUT_TYPES(s): 366 | return { 367 | "required": { 368 | "input1": (anytype, {"default": ""}), 369 | "input2": (anytype, {"default": ""}), 370 | } 371 | } 372 | FUNCTION = "merge" 373 | CATEGORY = "Logic Gates" 374 | custom_name = "Merge String" 375 | def merge(self, input1, input2): 376 | return (str(input1) + str(input2),) 377 | 378 | @node 379 | class ReplaceString: 380 | """ 381 | Returns the concatenation of the inputs 382 | """ 383 | RETURN_TYPES = ("STRING",) 384 | @classmethod 385 | def INPUT_TYPES(s): 386 | return { 387 | 
"required": { 388 | "String": ("STRING", {"default": ""}), # input string 389 | "Regex": ("STRING", {"default": ""}), # regex to search for 390 | "ReplaceWith": ("STRING", {"default": ""}), # string to replace with 391 | } 392 | } 393 | FUNCTION = "replace" 394 | CATEGORY = "Logic Gates" 395 | custom_name = "Replace String" 396 | def replace(self, String, Regex, ReplaceWith): 397 | # using regex 398 | return (re.sub(Regex, ReplaceWith, String),) 399 | 400 | @node 401 | class MemoryNode: 402 | """ 403 | Stores a value in memory. 404 | Flip-flop behaviour. 405 | """ 406 | def __init__(self): 407 | self.memory_value = None 408 | RETURN_TYPES = (anytype,) 409 | @classmethod 410 | def INPUT_TYPES(s): 411 | return { 412 | "required": { 413 | "input1": (anytype, {"default": ""}), 414 | "flag": (anytype, {"default": 0}), 415 | } 416 | } 417 | FUNCTION = "memory" 418 | CATEGORY = "Logic Gates" 419 | custom_name = "Memory String" 420 | def memory(self, input1, flag): 421 | if self.memory_value is None or flag: 422 | self.memory_value = input1 423 | return (self.memory_value,) 424 | 425 | CLASS_MAPPINGS, CLASS_NAMES = get_node_names_mappings(classes) 426 | validate(classes) 427 | -------------------------------------------------------------------------------- /auxilary.py: -------------------------------------------------------------------------------- 1 | from .imgio.converter import PILHandlingHodes 2 | from .autonode import node_wrapper, get_node_names_mappings, validate, anytype, PILImage 3 | from .utils.tagger import get_tags, tagger_keys 4 | from PIL import Image, ImageFilter 5 | 6 | auxilary_classes = [] 7 | auxilary_node = node_wrapper(auxilary_classes) 8 | 9 | @auxilary_node 10 | class GetRatingNode: 11 | FUNCTION = "get_rating_class" 12 | RETURN_TYPES = ("STRING",) 13 | CATEGORY = "tagger" 14 | custom_name = "Get Rating Class" 15 | @staticmethod 16 | def get_rating_class(image, model_name): 17 | image = PILHandlingHodes.handle_input(image) 18 | result_dict = 
get_tags(image, model_name=model_name) 19 | return (result_dict['rating'], ) 20 | @classmethod 21 | def INPUT_TYPES(cls): 22 | return { 23 | "required": { 24 | "image": ("IMAGE",), 25 | }, 26 | "optional": { 27 | "model_name": (tagger_keys, {"default": tagger_keys[0]}), 28 | } 29 | } 30 | 31 | @auxilary_node 32 | class GetRatingFromTextNode: 33 | FUNCTION = "get_rating_class" 34 | RETURN_TYPES = ("STRING",) 35 | CATEGORY = "tagger" 36 | custom_name = "Get Rating Class From Text" 37 | @staticmethod 38 | def get_rating_class(image, model_name): 39 | image = PILHandlingHodes.handle_input(image) 40 | result_dict = get_tags(image, model_name=model_name) 41 | return (result_dict['rating'], ) 42 | @classmethod 43 | def INPUT_TYPES(cls): 44 | return { 45 | "required": { 46 | "image": ("STRING", {"default": "/path/to/image.jpg"}), 47 | }, 48 | "optional": { 49 | "model_name": (tagger_keys, {"default": tagger_keys[0]}), 50 | } 51 | } 52 | 53 | def pixelate(image, pixelation_factor=0.1): 54 | # Downscale the image 55 | small = image.resize( 56 | (int(image.width * pixelation_factor), int(image.height * pixelation_factor)), 57 | resample=Image.NEAREST 58 | ) 59 | # Upscale back to original size 60 | return small.resize(image.size, Image.NEAREST) 61 | 62 | 63 | def pixelate_target_tiles(image, max_tiles=100): 64 | aspect_ratio = image.width / image.height 65 | # solve h * a * h = max_tiles 66 | h = int((max_tiles / aspect_ratio) ** 0.5) 67 | w = int(aspect_ratio * h) 68 | h, w = max(4, h), max(4, w) 69 | # Downscale the image 70 | small = image.resize( 71 | (w, h), 72 | resample=Image.NEAREST, 73 | ) 74 | # Upscale back to original size 75 | return small.resize(image.size, Image.NEAREST) 76 | 77 | 78 | @auxilary_node 79 | class CensorImageByRating: 80 | FUNCTION = "censor_image" 81 | RETURN_TYPES = ("IMAGE",) 82 | CATEGORY = "image" 83 | custom_name = "Censor Image by Rating" 84 | 85 | @staticmethod 86 | @PILHandlingHodes.output_wrapper 87 | def censor_image(image, 
rating_threshold, censor_method, model_name=None): 88 | # Convert input to a PIL image 89 | image = PILHandlingHodes.handle_input(image) 90 | result_dict = get_tags(image, model_name=model_name) 91 | rating = result_dict['rating'] 92 | # If rating is general, no censorship required 93 | if rating.lower() == "general": 94 | return (image,) 95 | censor_image = False 96 | if rating_threshold == "general": 97 | # censor if not general 98 | if rating.lower() != "general": 99 | censor_image = True 100 | elif rating_threshold == "sensitive": 101 | # censor if not general or sensitive 102 | if rating.lower() not in ["general", "sensitive"]: 103 | censor_image = True 104 | elif rating_threshold == "questionable": 105 | # censor if not general, sensitive or questionable 106 | if rating.lower() not in ["general", "sensitive", "questionable"]: 107 | censor_image = True 108 | elif rating_threshold == "explicit": 109 | return (image,) # why are you using this? 110 | if censor_image: 111 | if censor_method.lower() == "white": 112 | # Return a white image of the same size 113 | censored_image = Image.new("RGB", image.size, (255, 255, 255)) 114 | return (censored_image,) 115 | 116 | elif censor_method.lower() == "blur": 117 | # Apply a strong blur (you can adjust radius as needed) 118 | censored_image = image.filter(ImageFilter.GaussianBlur(radius=40)) 119 | return (censored_image,) 120 | elif censor_method.lower() == "pixelate": 121 | # first, blur 122 | censored_image = image.filter(ImageFilter.GaussianBlur(radius=20)) 123 | censored_image = pixelate_target_tiles(censored_image, max_tiles=100) 124 | return (censored_image,) 125 | 126 | # If unknown method is provided, just return the original image 127 | return (image,) 128 | 129 | @classmethod 130 | def INPUT_TYPES(cls): 131 | return { 132 | "required": { 133 | "image": ("IMAGE",), 134 | "rating_threshold": (["general", "sensitive", "questionable", "explicit"],), 135 | "censor_method": (["blur", "white","pixelate"],), 136 | }, 
137 | "optional": { 138 | "model_name": (tagger_keys, {"default": tagger_keys[0]}), 139 | } 140 | } 141 | 142 | @auxilary_node 143 | class FilterTagsNode: 144 | """ 145 | Filters tags, given a list of tags splitted by ",". 146 | We assume the input text is splittable by "," (or separator). Then, if any tags contain the filter, we remove matching tags. 147 | """ 148 | FUNCTION = "filter_tags" 149 | RETURN_TYPES = ("STRING",) 150 | CATEGORY = "safety" 151 | custom_name = "Filter Tags" 152 | @staticmethod 153 | def filter_tags(tags, filter_tags, separator): 154 | filter_tags = filter_tags.split(",") 155 | filter_tags = [tag.strip() for tag in filter_tags] 156 | tags = tags.split(separator) 157 | tags = [tag.strip() for tag in tags] 158 | filtered = [] 159 | for tag in tags: 160 | if all(filter_tag not in tag for filter_tag in filter_tags): 161 | filtered.append(tag) 162 | return (separator.join(filtered), ) 163 | @classmethod 164 | def INPUT_TYPES(cls): 165 | return { 166 | "required": { 167 | "tags": ("STRING",), 168 | "filter_tags": ("STRING",), 169 | }, 170 | # optional separator 171 | "optional": { 172 | "separator": ("STRING", {"default": ","}), 173 | } 174 | } 175 | 176 | @auxilary_node 177 | class GetTagsAboveThresholdNode: 178 | FUNCTION = "get_tags_above_threshold" 179 | RETURN_TYPES = ("STRING",) 180 | CATEGORY = "tagger" 181 | custom_name = "Get Tags Above Threshold" 182 | @staticmethod 183 | def get_tags_above_threshold(image, threshold, replace, model_name): 184 | image = PILHandlingHodes.handle_input(image) 185 | result_dict = get_tags(image, threshold=threshold, replace=replace, model_name=model_name) 186 | return (", ".join(result_dict['tags']), ) 187 | @classmethod 188 | def INPUT_TYPES(cls): 189 | return { 190 | "required": { 191 | "image": ("IMAGE",), 192 | }, 193 | "optional": { 194 | "threshold": ("FLOAT", {"default": 0.4}), 195 | "replace": ("BOOLEAN", {"default": False}), 196 | "model_name": (tagger_keys, {"default": tagger_keys[0]}), 197 | } 
198 | } 199 | 200 | @auxilary_node 201 | class GetTagsAboveThresholdFromTextNode: 202 | FUNCTION = "get_tags_above_threshold" 203 | RETURN_TYPES = ("STRING",) 204 | CATEGORY = "tagger" 205 | custom_name = "Get Tags Above Threshold From Text" 206 | @staticmethod 207 | def get_tags_above_threshold(image, threshold, replace, model_name): 208 | image = PILHandlingHodes.handle_input(image) 209 | result_dict = get_tags(image, threshold=threshold, replace=replace, model_name=model_name) 210 | return (", ".join(result_dict['tags']), ) 211 | @classmethod 212 | def INPUT_TYPES(cls): 213 | return { 214 | "required": { 215 | "image": ("IMAGE",), 216 | }, 217 | "optional": { 218 | "threshold": ("FLOAT", {"default": 0.4}), 219 | "replace": ("BOOLEAN", {"default": False}), 220 | "model_name": (tagger_keys, {"default": tagger_keys[0]}), 221 | } 222 | } 223 | 224 | @auxilary_node 225 | class GetCharactersAboveThresholdNode: 226 | FUNCTION = "get_tags_above_threshold" 227 | RETURN_TYPES = ("STRING",) 228 | CATEGORY = "tagger" 229 | custom_name = "Get Chars Above Threshold" 230 | @staticmethod 231 | def get_tags_above_threshold(image, threshold, replace, model_name): 232 | image = PILHandlingHodes.handle_input(image) 233 | result_dict = get_tags(image, threshold=threshold, replace=replace, model_name=model_name) 234 | return (", ".join(result_dict['chars']), ) 235 | @classmethod 236 | def INPUT_TYPES(cls): 237 | return { 238 | "required": { 239 | "image": ("IMAGE",), 240 | }, 241 | "optional": { 242 | "threshold": ("FLOAT", {"default": 0.4}), 243 | "replace": ("BOOLEAN", {"default": False}), 244 | "model_name": (tagger_keys, {"default": tagger_keys[0]}), 245 | } 246 | } 247 | 248 | @auxilary_node 249 | class GetCharactersAboveThresholdFromTextNode: 250 | FUNCTION = "get_tags_above_threshold" 251 | RETURN_TYPES = ("STRING",) 252 | CATEGORY = "tagger" 253 | custom_name = "Get Chars Above Threshold From Text" 254 | @staticmethod 255 | def get_tags_above_threshold(image, threshold, 
replace, model_name): 256 | image = PILHandlingHodes.handle_input(image) 257 | result_dict = get_tags(image, threshold=threshold, replace=replace, model_name=model_name) 258 | return (", ".join(result_dict['chars']), ) 259 | @classmethod 260 | def INPUT_TYPES(cls): 261 | return { 262 | "required": { 263 | "image": ("IMAGE",), 264 | }, 265 | "optional": { 266 | "threshold": ("FLOAT", {"default": 0.4}), 267 | "replace": ("BOOLEAN", {"default": False}), 268 | "model_name": (tagger_keys, {"default": tagger_keys[0]}), 269 | } 270 | } 271 | 272 | @auxilary_node 273 | class GetAllTagsAboveThresholdNode: 274 | FUNCTION = "get_tags" 275 | RETURN_TYPES = ("STRING",) 276 | CATEGORY = "tagger" 277 | custom_name = "Get All Tags Above Threshold" 278 | @staticmethod 279 | def get_tags(image, threshold, replace, model_name): 280 | image = PILHandlingHodes.handle_input(image) 281 | result = get_tags(image, threshold=threshold, replace=replace, model_name=model_name) 282 | result_list = [] 283 | result_list.append(result['rating']) 284 | result_list.extend(result['tags']) 285 | result_list.extend(result['chars']) 286 | return (", ".join(result_list), ) 287 | @classmethod 288 | def INPUT_TYPES(cls): 289 | return { 290 | "required": { 291 | "image": ("IMAGE",), 292 | }, 293 | "optional": { 294 | "threshold": ("FLOAT", {"default": 0.4}), 295 | "replace": ("BOOLEAN", {"default": False}), 296 | "model_name": (tagger_keys, {"default": tagger_keys[0]}), 297 | } 298 | } 299 | @auxilary_node 300 | class GetAllTagsExceptCharacterAboveThresholdNode: 301 | FUNCTION = "get_tags" 302 | RETURN_TYPES = ("STRING",) 303 | CATEGORY = "tagger" 304 | custom_name = "Get All Tags Above Threshold Except Characters" 305 | @staticmethod 306 | def get_tags(image, threshold, replace, model_name): 307 | image = PILHandlingHodes.handle_input(image) 308 | result = get_tags(image, threshold=threshold, replace=replace, model_name=model_name) 309 | result_list = [] 310 | result_list.append(result['rating']) 311 | 
result_list.extend(result['tags']) 312 | return (", ".join(result_list), ) 313 | @classmethod 314 | def INPUT_TYPES(cls): 315 | return { 316 | "required": { 317 | "image": ("IMAGE",), 318 | }, 319 | "optional": { 320 | "threshold": ("FLOAT", {"default": 0.4}), 321 | "replace": ("BOOLEAN", {"default": False}), 322 | "model_name": (tagger_keys, {"default": tagger_keys[0]}), 323 | } 324 | } 325 | CLASS_MAPPINGS, CLASS_NAMES = get_node_names_mappings(auxilary_classes) 326 | validate(auxilary_classes) 327 | -------------------------------------------------------------------------------- /imgio/converter.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | from PIL import Image 3 | import numpy as np 4 | import base64 5 | import torch 6 | import requests 7 | import os 8 | from io import BytesIO 9 | import gzip 10 | import re 11 | from urllib.parse import urlparse 12 | 13 | 14 | def handle_rgba_composite( 15 | image: Image.Image, background_color=(255, 255, 255), as_rgba=False 16 | ) -> Image.Image: 17 | """ 18 | Convert RGBA image to RGB image using alpha_composite. 19 | """ 20 | mode = image.mode 21 | if as_rgba: 22 | return image.convert("RGBA") # universal format 23 | if mode == "RGB": 24 | return image 25 | if mode == "RGBA": 26 | # Create a white RGBA background 27 | background = Image.new("RGBA", image.size, (255, 255, 255, 255)) 28 | # Composite the original image over the white background 29 | composed = Image.alpha_composite(background, image) 30 | # Convert back to RGB (now that background is flattened) 31 | return composed.convert("RGB") 32 | elif mode == "LA": 33 | # "LA" is 8-bit grayscale + alpha. 34 | rgba_image = image.convert("RGBA") 35 | background = Image.new("RGBA", rgba_image.size, (*background_color, 255)) 36 | composed = Image.alpha_composite(background, rgba_image) 37 | return composed.convert("RGB") 38 | 39 | # 3. 
"L" or "1" = Grayscale or Black/White, "P" = Palette 40 | elif mode in ["L", "1", "P"]: 41 | # Simply converting to "RGB" is usually enough. 42 | return image.convert("RGB") 43 | 44 | # 4. "CMYK", "YCbCr", "HSV", etc. 45 | elif mode in ["CMYK", "YCbCr", "HSV"]: 46 | # Typically, a .convert("RGB") is enough if you just need an RGB version. 47 | return image.convert("RGB") 48 | print(f"Warning: Unhandled image mode: {mode}. Converting to RGB.") 49 | return image.convert("RGB") 50 | 51 | def fetch_image_securely(image_url: str, 52 | allowed_schemes=('http', 'https'), 53 | max_file_size=5_000_000, 54 | request_timeout=30): 55 | """ 56 | Fetches an image from the given URL securely. 57 | 58 | This function: 59 | 1. Validates the URL scheme (only http/https). 60 | 2. Blocks private IP/loopback addresses to prevent SSRF attacks. 61 | 3. Streams data to avoid excessive memory usage. 62 | 4. Checks MIME type, size limits, and optionally handles form-encoded image data. 63 | 64 | :param image_url: URL of the image to retrieve (e.g., an S3-signed URL). 65 | :param allowed_schemes: A tuple of allowed URL schemes (default: ('http', 'https')). 66 | :param max_file_size: Max size (in bytes) of the file to download. 67 | :param request_timeout: Timeout (in seconds) for the request. 68 | :return: PIL Image object if successful, else raises an exception. 69 | """ 70 | 71 | # -- 1. Validate scheme to avoid unexpected protocols -- 72 | parsed = urlparse(image_url) 73 | if parsed.scheme not in allowed_schemes: 74 | raise ValueError(f"Invalid or disallowed URL scheme: {parsed.scheme}") 75 | 76 | # -- 2. Prevent local network (SSRF) attacks by blocking private or loopback addresses -- 77 | # This is a simplified check. Consider using a library for robust IP parsing if needed. 
78 | ip_like_pattern = r'^(\d{1,3}\.){3}\d{1,3}$' 79 | hostname = parsed.hostname 80 | if ( 81 | hostname is None 82 | or hostname.lower() in ("localhost", "127.0.0.1", "::1") 83 | or (re.match(ip_like_pattern, hostname) and hostname.startswith("10.")) 84 | or hostname.startswith("192.168.") 85 | or hostname.startswith("172.16.") 86 | or hostname.startswith("172.17.") 87 | or hostname.startswith("172.18.") 88 | or hostname.startswith("172.19.") 89 | or hostname.startswith("172.2") # covers 172.20 - 172.31 90 | or hostname.startswith("172.3") 91 | ): 92 | raise ValueError("URL resolves to a private or loopback address, which is disallowed.") 93 | 94 | # -- 3. Retrieve the response with a timeout and stream -- 95 | # This handles the S3 URL just like any other public HTTPS link. 96 | with requests.get(image_url, timeout=request_timeout, stream=True) as response: 97 | response.raise_for_status() 98 | 99 | # -- 4. Check Content-Type in headers -- 100 | content_type = response.headers.get('Content-Type', '').lower() 101 | 102 | # If it's a direct image... 103 | if content_type.startswith("image/"): 104 | # -- 5. Check Content-Length against max_file_size -- 105 | content_length = response.headers.get('Content-Length') 106 | if content_length and int(content_length) > max_file_size: 107 | raise ValueError( 108 | f"File is too large: {int(content_length)} bytes. " 109 | f"Max allowed is {max_file_size} bytes." 110 | ) 111 | 112 | data = BytesIO() 113 | downloaded = 0 114 | chunk_size = 8192 115 | for chunk in response.iter_content(chunk_size=chunk_size): 116 | downloaded += len(chunk) 117 | if downloaded > max_file_size: 118 | raise ValueError( 119 | f"File exceeded the maximum allowed size of {max_file_size} bytes." 
120 | ) 121 | data.write(chunk) 122 | 123 | # Reset the buffer and open with PIL 124 | data.seek(0) 125 | return Image.open(data) 126 | 127 | # If the server reports x-www-form-urlencoded, parse for embedded image data 128 | elif content_type == "application/x-www-form-urlencoded": 129 | # let PIL handle the parsing 130 | try: 131 | return Image.open(BytesIO(response.content)) 132 | except Exception as e: 133 | raise ValueError( 134 | f"Failed to parse x-www-form-urlencoded data as image: {e}" 135 | ) 136 | 137 | else: 138 | # Some other content type we don't handle 139 | raise ValueError( 140 | f"Unsupported Content-Type or not an image: {content_type}" 141 | ) 142 | 143 | class IOConverter: 144 | """ 145 | Classify the input data type. 146 | 147 | Assumes the inputs to be following: 148 | 149 | - PIL Image 150 | - numpy array 151 | - torch tensor 152 | - string (path to image) 153 | - base64 string (which can be decoded to bytes and then to image) 154 | - gzip-compressed base64 string 155 | - URL (which can be downloaded to image) 156 | 157 | Do NOT pass unsafe URLs / base64 strings, as it may cause security issues. 
158 | """ 159 | 160 | class InputType: 161 | PIL = "PIL" 162 | NUMPY = "NUMPY" 163 | TORCH = "TORCH" 164 | STRING = "STRING" 165 | BASE64 = "BASE64" 166 | GZIP_BASE64 = "GZIP_BASE64" 167 | URL = "URL" 168 | 169 | def __init__(self): 170 | raise Exception("This class should not be instantiated.") 171 | 172 | @staticmethod 173 | def classify(input_data): 174 | if isinstance(input_data, Image.Image): 175 | return IOConverter.InputType.PIL 176 | elif isinstance(input_data, np.ndarray): 177 | return IOConverter.InputType.NUMPY 178 | elif isinstance(input_data, torch.Tensor): 179 | return IOConverter.InputType.TORCH 180 | elif isinstance(input_data, str): 181 | if os.path.isfile(input_data): 182 | return IOConverter.InputType.STRING 183 | elif input_data.startswith("data:image/"): 184 | return IOConverter.InputType.BASE64 185 | elif input_data.startswith("http://") or input_data.startswith("https://"): 186 | return IOConverter.InputType.URL 187 | else: 188 | # Attempt to detect base64-encoded data 189 | try: 190 | decoded_data = base64.b64decode(input_data, validate=True) 191 | # Check for gzip magic number 192 | if decoded_data[:2] == b'\x1f\x8b': 193 | return IOConverter.InputType.GZIP_BASE64 194 | else: 195 | return IOConverter.InputType.BASE64 196 | except Exception: 197 | raise Exception(f"Invalid string input, cannot be decoded as base64.") 198 | else: 199 | raise Exception(f"Invalid input type, {type(input_data)}") 200 | @staticmethod 201 | def match_dtype(array_or_tensor, is_tensor=False): 202 | # if all value is between 0 and 1, multiply by 255 and convert to uint8 203 | # however already uint8, skip 204 | # check dtype first 205 | if array_or_tensor.dtype == np.uint8 or array_or_tensor.dtype == torch.uint8: 206 | return array_or_tensor 207 | 208 | if array_or_tensor.min() >= 0 and array_or_tensor.max() <= 1: 209 | multiplied = array_or_tensor * 255 210 | if not is_tensor: 211 | return multiplied.astype(np.uint8) 212 | else: 213 | return 
multiplied.to(torch.uint8) 214 | return array_or_tensor 215 | 216 | @staticmethod 217 | def convert_to_pil(input_data): 218 | input_type = IOConverter.classify(input_data) 219 | if input_type == IOConverter.InputType.PIL: 220 | return handle_rgba_composite(input_data) 221 | elif input_type == IOConverter.InputType.NUMPY: 222 | # [1, 1216, 832, 3], ' [1216, 832, 3], 'uint8' 223 | # if not first element is 1, then it is a batch of images so warning 224 | if input_data.shape[0] != 1: 225 | result = [] 226 | for i in range(input_data.shape[0]): 227 | np_array = IOConverter.match_dtype(input_data[i]) 228 | result.append(handle_rgba_composite(Image.fromarray(np_array))) 229 | return result # return list of PIL images 230 | input_data = IOConverter.match_dtype(input_data[0]) 231 | return handle_rgba_composite(Image.fromarray(input_data)) 232 | elif input_type == IOConverter.InputType.TORCH: 233 | # same as above 234 | if input_data.shape[0] != 1: 235 | result = [] 236 | for i in range(input_data.shape[0]): 237 | np_array = ( 238 | IOConverter.match_dtype(input_data[i], is_tensor=True) 239 | .cpu() 240 | .numpy() 241 | ) 242 | result.append(handle_rgba_composite(Image.fromarray(np_array))) 243 | return result 244 | input_data = IOConverter.match_dtype(input_data[0], is_tensor=True) 245 | np_array = input_data.cpu().numpy() 246 | return handle_rgba_composite(Image.fromarray(np_array)) 247 | elif input_type == IOConverter.InputType.STRING: 248 | return Image.open(input_data) 249 | elif input_type == IOConverter.InputType.GZIP_BASE64: 250 | decoded_data = IOConverter.read_base64(input_data) 251 | decompressed_data = gzip.decompress(decoded_data) 252 | partial_result = Image.open(BytesIO(decompressed_data)) 253 | result = handle_rgba_composite(partial_result) 254 | return result 255 | elif input_type == IOConverter.InputType.BASE64: 256 | decoded_data = IOConverter.read_base64(input_data) 257 | partial_result = Image.open(BytesIO(decoded_data)) 258 | result = 
handle_rgba_composite(partial_result) 259 | return result 260 | elif input_type == IOConverter.InputType.URL: 261 | partial_result = fetch_image_securely(input_data) 262 | result = handle_rgba_composite(partial_result) 263 | return result 264 | else: 265 | raise Exception(f"Invalid input type, {input_type}") 266 | 267 | @staticmethod 268 | def to_rgb_tensor(pil_image): 269 | if pil_image.mode == "I": 270 | pil_image = pil_image.point(lambda i: i * (1/255)) # convert to float 271 | pil_image = handle_rgba_composite(pil_image) 272 | np_array = np.array(pil_image).astype(np.float32) / 255.0 273 | tensor = torch.from_numpy(np_array) 274 | tensor = tensor.unsqueeze(0) # Add batch dimension 275 | # assert 4-dimensional tensor, B,C,H,W 276 | if len(tensor.shape) != 4: 277 | raise Exception(f"Invalid tensor shape, expected 4-dimensional tensor, got {tensor.shape}") 278 | return tensor 279 | 280 | @staticmethod 281 | def to_rgba_tensor(pil_image): 282 | if pil_image.mode == "I": 283 | pil_image = pil_image.point(lambda i: i * (1/255)) # convert to float 284 | pil_image = handle_rgba_composite(pil_image, as_rgba=True) 285 | np_array = np.array(pil_image).astype(np.float32) / 255.0 286 | tensor = torch.from_numpy(np_array) 287 | tensor = tensor.unsqueeze(0) # Add batch dimension 288 | # assert 4-dimensional tensor, B,C,H,W 289 | if len(tensor.shape) != 4: 290 | raise Exception(f"Invalid tensor shape, expected 4-dimensional tensor, got {tensor.shape}") 291 | return tensor 292 | 293 | @staticmethod 294 | def read_base64(base64_string: str) -> bytes: 295 | return base64.b64decode(base64_string) 296 | 297 | @staticmethod 298 | def read_maybe_gzip_base64(base64_string: str) -> bytes: 299 | decoded_data = base64.b64decode(base64_string) 300 | if decoded_data[:2] == b'\x1f\x8b': 301 | result = gzip.decompress(decoded_data) 302 | else: 303 | result = decoded_data 304 | # to string 305 | return result.decode('utf-8') 306 | 307 | @staticmethod 308 | def 
convert_to_rgb_tensor(input_data, rgba=False): 309 | if not rgba: 310 | output_func = IOConverter.to_rgb_tensor 311 | else: 312 | output_func = IOConverter.to_rgba_tensor 313 | input_type = IOConverter.classify(input_data) 314 | if input_type == IOConverter.InputType.PIL: 315 | return output_func(input_data) 316 | elif input_type == IOConverter.InputType.NUMPY: 317 | # if all values are 0~1, skip 318 | if input_data.min() >= 0 and input_data.max() <= 1: 319 | np_array = input_data.astype(np.float32) 320 | else: 321 | np_array = input_data.astype(np.float32) / 255.0 322 | tensor = torch.from_numpy(np_array) 323 | tensor = tensor.unsqueeze(0) # Add batch dimension 324 | return tensor 325 | elif input_type == IOConverter.InputType.TORCH: 326 | return input_data 327 | elif input_type == IOConverter.InputType.STRING: 328 | image = Image.open(input_data) 329 | return output_func(image) 330 | elif input_type == IOConverter.InputType.GZIP_BASE64: 331 | image = IOConverter.convert_to_pil(input_data) 332 | return output_func(image) 333 | elif input_type == IOConverter.InputType.BASE64: 334 | image = IOConverter.convert_to_pil(input_data) 335 | return output_func(image) 336 | elif input_type == IOConverter.InputType.URL: 337 | image = fetch_image_securely(input_data) 338 | return output_func(image) 339 | else: 340 | raise Exception(f"Invalid input type, {input_type}") 341 | 342 | @staticmethod 343 | def convert_to_base64(input_data, format="PNG", quality=100, gzip_compress=False): 344 | pil_image = IOConverter.convert_to_pil(input_data) 345 | buffered = BytesIO() 346 | save_params = {'format': format} 347 | if format.upper() in ['JPEG', 'JPG']: 348 | save_params['quality'] = quality 349 | pil_image.save(buffered, **save_params) 350 | buffered.seek(0) 351 | if gzip_compress: 352 | compressed_buffer = BytesIO() 353 | with gzip.GzipFile(fileobj=compressed_buffer, mode='wb') as f: 354 | f.write(buffered.getvalue()) 355 | compressed_buffer.seek(0) 356 | base64_data = 
base64.b64encode(compressed_buffer.getvalue()).decode('utf-8') 357 | else: 358 | base64_data = base64.b64encode(buffered.getvalue()).decode('utf-8') 359 | return base64_data 360 | 361 | @staticmethod 362 | def string_to_base64(input_string, gzip_compress=False): 363 | if gzip_compress: 364 | compressed_buffer = BytesIO() 365 | with gzip.GzipFile(fileobj=compressed_buffer, mode='wb') as f: 366 | f.write(input_string.encode()) 367 | compressed_buffer.seek(0) 368 | base64_data = base64.b64encode(compressed_buffer.getvalue()).decode('utf-8') 369 | else: 370 | base64_data = base64.b64encode(input_string.encode()).decode('utf-8') 371 | return base64_data 372 | 373 | class PILHandlingHodes: 374 | @staticmethod 375 | def handle_input(tensor_or_image) -> Union[Image.Image, List[Image.Image]]: 376 | pil_image = IOConverter.convert_to_pil(tensor_or_image) 377 | return pil_image 378 | 379 | @staticmethod 380 | def handle_output_as_pil(pil_image: Image.Image) -> Image.Image: 381 | return pil_image 382 | 383 | @staticmethod 384 | def handle_output_as_tensor(pil_image: Image.Image, rgba=False) -> torch.Tensor: 385 | return IOConverter.convert_to_rgb_tensor(pil_image, rgba=rgba) 386 | 387 | @staticmethod 388 | def handle_output_as_rgba_tensor(pil_image: Image.Image) -> torch.Tensor: 389 | return IOConverter.convert_to_rgb_tensor(pil_image, rgba=True) 390 | 391 | @staticmethod 392 | def output_wrapper(func): 393 | def wrapped(*args, **kwargs): 394 | outputs = func(*args, **kwargs) 395 | tuples_collect = [] 396 | for output in outputs: 397 | if isinstance(output, (Image.Image, torch.Tensor)): 398 | tuples_collect.append(PILHandlingHodes.handle_output_as_tensor(output)) 399 | else: 400 | tuples_collect.append(output) 401 | return tuple(tuples_collect) 402 | return wrapped 403 | 404 | @staticmethod 405 | def rgba_output_wrapper(func): 406 | def wrapped(*args, **kwargs): 407 | outputs = func(*args, **kwargs) 408 | tuples_collect = [] 409 | for output in outputs: 410 | if 
isinstance(output, (Image.Image, torch.Tensor)): 411 | tuples_collect.append(PILHandlingHodes.handle_output_as_rgba_tensor(output)) 412 | else: 413 | tuples_collect.append(output) 414 | return tuple(tuples_collect) 415 | return wrapped 416 | 417 | @staticmethod 418 | def to_base64(anything, quality=100, format="PNG", gzip_compress=False): 419 | base64_data = IOConverter.convert_to_base64(anything, format=format, quality=quality, gzip_compress=gzip_compress) 420 | return base64_data 421 | 422 | @staticmethod 423 | def string_to_base64(input_string, gzip_compress=False): 424 | base64_data = IOConverter.string_to_base64(input_string, gzip_compress=gzip_compress) 425 | return base64_data 426 | 427 | @staticmethod 428 | def maybe_gzip_base64_to_string(base64_string): 429 | return IOConverter.read_maybe_gzip_base64(base64_string) 430 | -------------------------------------------------------------------------------- /randomness.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import uuid 4 | import time 5 | from .autonode import node_wrapper, get_node_names_mappings, validate 6 | 7 | 8 | classes = [] 9 | node = node_wrapper(classes) 10 | 11 | class RandomGuaranteedClass: 12 | OUTPUT_NODE = True 13 | RESULT_NODE = True 14 | @classmethod 15 | def IS_CHANGED(s, *args, **kwargs): 16 | return float("NaN") 17 | 18 | @node 19 | class SystemRandomFloat(RandomGuaranteedClass): 20 | """ 21 | Random number generator using system randomness 22 | """ 23 | def __init__(self): 24 | pass 25 | @staticmethod 26 | def generate(min_val=0.0, max_val=1.0, precision=0): 27 | instance = random.SystemRandom(time.time()) 28 | value = instance.uniform(min_val, max_val) 29 | if precision > 0: 30 | value = round(value, precision) 31 | return (value,) 32 | RETURN_TYPES = ("FLOAT",) 33 | @classmethod 34 | def INPUT_TYPES(s): 35 | return { 36 | "required": { 37 | "min_val": ("FLOAT", { "default": 0.0, "min": -999999999, "max": 
999999999.0, "step": 0.01, "display": "number" }), 38 | "max_val": ("FLOAT", { "default": 1.0, "min": -999999999, "max": 999999999.0, "step": 0.01, "display": "number" }), 39 | "precision": ("INT", { "default": 0, "min": 0, "max": 10, "step": 1, "display": "number" }), 40 | }, 41 | } 42 | FUNCTION = "generate" 43 | CATEGORY = "Logic Gates" 44 | custom_name = "System Random Float" 45 | 46 | @node 47 | class DimensionSelectorWithSeedNode: 48 | """ 49 | Finds (width, height) such that width*height is near (resolution^2), 50 | ratio = width/height is in [min_ratio, max_ratio], 51 | both are multiples of 'multiples', and uses 'seed' for random tie-break. 52 | """ 53 | RETURN_TYPES = ("INT", "INT") 54 | 55 | @classmethod 56 | def INPUT_TYPES(s): 57 | return { 58 | "required": { 59 | "resolution": ("INT", {"default": 1024}), 60 | "min_ratio": ("FLOAT", {"default": 0.6}), 61 | "max_ratio": ("FLOAT", {"default": 1.6}), 62 | "multiples": ("INT", {"default": 32}), 63 | "seed": ("INT", {"default": 0}), 64 | } 65 | } 66 | 67 | FUNCTION = "select_dimensions" 68 | CATEGORY = "Logic Gates" 69 | custom_name = "Random Width/Height with Resolution" 70 | 71 | def select_dimensions(self, resolution, min_ratio, max_ratio, multiples, seed): 72 | # For reproducible randomness 73 | random.seed(seed) 74 | 75 | desired_area = resolution * resolution 76 | ratio = random.uniform(min_ratio, max_ratio) 77 | # width * height = resolution^2, width/height = ratio 78 | # thus h**2 = resolution^2 / ratio 79 | height = int(math.sqrt(desired_area / ratio)) 80 | width = int(desired_area / height) 81 | # round to nearest multiple 82 | div_h = height / multiples 83 | div_w = width / multiples 84 | height = round(div_h) * multiples 85 | width = round(div_w) * multiples 86 | # if width * height > resolution^2, reduce width or height 87 | if width * height > desired_area: 88 | if random.choice([True, False]): 89 | width = width - multiples 90 | else: 91 | height = height - multiples 92 | return (width, 
height) 93 | 94 | 95 | @node 96 | class SystemRandomInt(RandomGuaranteedClass): 97 | """ 98 | Random number generator using system randomness 99 | Generates an integer value between 0 and 2^32-1 100 | """ 101 | 102 | def __init__(self): 103 | pass 104 | @staticmethod 105 | def generate(min_val=0, max_val=2**63-1): 106 | instance = random.SystemRandom(time.time()) 107 | value = instance.randint(min_val, max_val) 108 | return (value,) 109 | @classmethod 110 | def INPUT_TYPES(s): 111 | return { 112 | "required": { 113 | "min_val": ( 114 | "INT", 115 | { 116 | "default": 0, 117 | "min": -(2**63-1), 118 | "max": 2**63-1, 119 | "step": 1, 120 | "display": "number", 121 | }, 122 | ), 123 | "max_val": ( 124 | "INT", 125 | { 126 | "default": 2**63-1, 127 | "min": -(2**63-1), 128 | "max": 2**63-1, 129 | "step": 1, 130 | "display": "number", 131 | }, 132 | ), 133 | }, 134 | } 135 | RETURN_TYPES = ("INT",) 136 | FUNCTION = "generate" 137 | CATEGORY = "Logic Gates" 138 | custom_name = "System Random Int" 139 | 140 | 141 | @node 142 | class SystemUUIDGenerator(RandomGuaranteedClass): 143 | """ 144 | Generates a random UUID 145 | """ 146 | def __init__(self): 147 | pass 148 | @staticmethod 149 | def generate(length=36): 150 | value = uuid.uuid4() 151 | value = str(value) 152 | if length < 36: 153 | value = value[:length] 154 | return (value,) 155 | @classmethod 156 | def INPUT_TYPES(s): 157 | return { 158 | "required": { 159 | "length": ("INT", { "default": 36, "min": 1, "max": 36, "step": 1, "display": "number" }), 160 | }, 161 | } 162 | RETURN_TYPES = ("STRING",) 163 | FUNCTION = "generate" 164 | CATEGORY = "Logic Gates" 165 | custom_name = "UUID Generator" 166 | 167 | 168 | @node 169 | class UniformRandomFloat(RandomGuaranteedClass): 170 | """ 171 | Selects a random float from min to max 172 | Fallbacks to default if min is greater than max 173 | """ 174 | def __init__(self): 175 | pass 176 | def generate(self, min_val, max_val, decimal_places, seed=0): 177 | if min_val > 
max_val: 178 | return min_val 179 | instance = random.Random(seed) 180 | value = instance.uniform(min_val, max_val) 181 | # prune to decimal places - 0 = int, 1 = 1 decimal place,... 182 | value = round(value, decimal_places) 183 | #print(f"Selected {value} from {min_val} to {max_val}") 184 | return (value,) 185 | @classmethod 186 | def INPUT_TYPES(s): 187 | return { 188 | "required": { 189 | "min_val": ("FLOAT", { "default": 0.0, "min": -999999999, "max": 999999999.0, "step": 0.02, "display": "number" }), 190 | "max_val": ("FLOAT", { "default": 1.0, "min": -999999999, "max": 999999999.0, "step": 0.02, "display": "number" }), 191 | "decimal_places": ("INT", { "default": 1, "min": 0, "max": 10, "step": 1, "display": "number" }), 192 | "seed" : ("INT", { "default": 0, "min": 0, "max": (2**63-1), "step": 1, "display": "number" }), 193 | }, 194 | } 195 | RETURN_TYPES = ("FLOAT",) 196 | FUNCTION = "generate" 197 | CATEGORY = "Logic Gates" 198 | custom_name = "Uniform Random Float" 199 | 200 | @node 201 | class TriangularRandomFloat(RandomGuaranteedClass): 202 | """ 203 | Selects a random float from min to max 204 | Fallbacks to default if min is greater than max 205 | """ 206 | def __init__(self): 207 | pass 208 | def generate(self, low, high, mode, seed=0): 209 | if low > high: 210 | return low 211 | instance = random.Random(seed) 212 | value = instance.triangular(low, high, mode) 213 | return (value,) 214 | @classmethod 215 | def INPUT_TYPES(s): 216 | return { 217 | "required": { 218 | "low": ("FLOAT", { "default": 0.0, "min": -999999999, "max": 999999999.0, "step": 0.02, "display": "number" }), 219 | "high": ("FLOAT", { "default": 1.0, "min": -999999999, "max": 999999999.0, "step": 0.02, "display": "number" }), 220 | "mode": ("FLOAT", { "default": 0.5, "min": -999999999, "max": 999999999.0, "step": 0.02, "display": "number" }), 221 | "seed" : ("INT", { "default": 0, "min": 0, "max": (2**63-1), "step": 1, "display": "number" }), 222 | }, 223 | } 224 | RETURN_TYPES = 
("FLOAT",) 225 | FUNCTION = "generate" 226 | CATEGORY = "Logic Gates" 227 | custom_name = "Triangular Random Float" 228 | 229 | @node 230 | class WeightedRandomChoice(RandomGuaranteedClass): 231 | """ 232 | Randomly choose one item from a list with weights. 233 | The input string is parsed as "value|weight$value2|weight2..." 234 | Example: "apple|10$banana|1$orange|3" 235 | """ 236 | def __init__(self): 237 | pass 238 | 239 | def generate(self, input_string, separator, seed=0): 240 | # Example input: "apple|10$banana|1$orange|3" 241 | # Split by '$' -> ["apple|10", "banana|1", "orange|3"] 242 | items = input_string.split(separator) 243 | choices = [] 244 | weights = [] 245 | for item in items: 246 | if '|' in item: 247 | val, wt = item.split('|', 1) 248 | choices.append(val) 249 | try: 250 | weights.append(float(wt)) 251 | except ValueError: 252 | weights.append(1.0) 253 | else: 254 | # fallback if no weight specified 255 | choices.append(item) 256 | weights.append(1.0) 257 | 258 | instance = random.Random(seed) 259 | chosen = instance.choices(population=choices, weights=weights, k=1)[0] 260 | return (chosen,) 261 | 262 | RETURN_TYPES = ("STRING",) 263 | 264 | @classmethod 265 | def INPUT_TYPES(cls): 266 | return { 267 | "required": { 268 | "input_string": ("STRING", {"default": "apple|10$banana|1$orange|3", "display": "text"}), 269 | "separator": ("STRING", {"default": "$", "display": "text"}), 270 | "seed": ("INT", {"default": 0, "min": 0, "max": 2**63-1, "step": 1, "display": "number"}), 271 | } 272 | } 273 | 274 | FUNCTION = "generate" 275 | CATEGORY = "Logic Gates" 276 | custom_name = "Weighted Random Choice" 277 | 278 | @node 279 | class RandomGaussianFloat(RandomGuaranteedClass): 280 | """ 281 | Generates a random float from a normal (Gaussian) distribution 282 | with specified mean and std_dev. 
283 | """ 284 | def __init__(self): 285 | pass 286 | 287 | def generate(self, mean, std_dev, decimal_places, seed=0): 288 | instance = random.Random(seed) 289 | value = instance.gauss(mean, std_dev) 290 | value = round(value, decimal_places) 291 | return (value,) 292 | 293 | RETURN_TYPES = ("FLOAT",) 294 | 295 | @classmethod 296 | def INPUT_TYPES(cls): 297 | return { 298 | "required": { 299 | "mean": ("FLOAT", {"default": 0.0, "min": -999999999, "max": 999999999.0, "step": 0.01}), 300 | "std_dev": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 999999999.0, "step": 0.01}), 301 | "decimal_places": ("INT", {"default": 2, "min": 0, "max": 10, "step": 1}), 302 | "seed": ("INT", {"default": 0, "min": 0, "max": 2**63-1, "step": 1}), 303 | }, 304 | } 305 | 306 | FUNCTION = "generate" 307 | CATEGORY = "Logic Gates" 308 | custom_name = "Random Gaussian Float" 309 | 310 | @node 311 | class SystemRandomGaussianFloat(RandomGuaranteedClass): 312 | """ 313 | Generates a random float from a normal (Gaussian) distribution 314 | with specified mean and std_dev. 315 | """ 316 | def __init__(self): 317 | pass 318 | 319 | def generate(self, mean, std_dev, decimal_places): 320 | instance = random.SystemRandom(time.time()) 321 | value = instance.gauss(mean, std_dev) 322 | value = round(value, decimal_places) 323 | return (value,) 324 | 325 | RETURN_TYPES = ("FLOAT",) 326 | 327 | @classmethod 328 | def INPUT_TYPES(cls): 329 | return { 330 | "required": { 331 | "mean": ("FLOAT", {"default": 0.0, "min": -999999999, "max": 999999999.0, "step": 0.01}), 332 | "std_dev": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 999999999.0, "step": 0.01}), 333 | "decimal_places": ("INT", {"default": 2, "min": 0, "max": 10, "step": 1}), 334 | }, 335 | } 336 | 337 | FUNCTION = "generate" 338 | CATEGORY = "Logic Gates" 339 | custom_name = "System Random Gaussian Float" 340 | 341 | @node 342 | class ProbabilityGate(RandomGuaranteedClass): 343 | """ 344 | Returns TRUE with probability p, FALSE otherwise. 
345 | """ 346 | def __init__(self): 347 | pass 348 | 349 | def generate(self, probability, seed=0): 350 | instance = random.Random(seed) 351 | value = instance.random() # uniform in [0,1) 352 | return (value < probability,) 353 | 354 | RETURN_TYPES = ("BOOLEAN",) 355 | 356 | @classmethod 357 | def INPUT_TYPES(cls): 358 | return { 359 | "required": { 360 | "probability": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 361 | "seed": ("INT", {"default": 0, "min": 0, "max": 2**63-1, "step": 1}), 362 | }, 363 | } 364 | 365 | FUNCTION = "generate" 366 | CATEGORY = "Logic Gates" 367 | custom_name = "Probability Gate" 368 | 369 | 370 | @node 371 | class UniformRandomInt(RandomGuaranteedClass): 372 | """ 373 | Selects a random int from min to max 374 | Fallbacks to default if min is greater than max 375 | """ 376 | def __init__(self): 377 | pass 378 | def generate(self, min_val, max_val, seed=0): 379 | if min_val > max_val: 380 | return min_val 381 | instance = random.Random(seed) 382 | value = instance.randint(min_val, max_val) 383 | #print(f"Selected {value} from {min_val} to {max_val}") 384 | return (value,) 385 | @classmethod 386 | def INPUT_TYPES(s): 387 | return { 388 | "required": { 389 | "min_val": ("INT", { "default": 0, "min": -999999999, "max": 999999999, "step": 1, "display": "number" }), 390 | "max_val": ("INT", { "default": 1, "min": -999999999, "max": 999999999, "step": 1, "display": "number" }), 391 | "seed" : ("INT", { "default": 0, "min": 0, "max": (2**63-1), "step": 1, "display": "number" }), 392 | }, 393 | } 394 | RETURN_TYPES = ("INT",) 395 | FUNCTION = "generate" 396 | CATEGORY = "Logic Gates" 397 | custom_name = "Uniform Random Int" 398 | @node 399 | class UniformRandomChoice(RandomGuaranteedClass): 400 | """ 401 | Parses input string with separator '$' and returns a random choice 402 | separator can be changed in the input 403 | """ 404 | def __init__(self): 405 | pass 406 | def generate(self, input_string, separator, seed=0): 407 
| instance = random.Random(seed) 408 | choices = input_string.split(separator) 409 | value = instance.choice(choices) 410 | return (value,) 411 | @classmethod 412 | def INPUT_TYPES(s): 413 | return { 414 | "required": { 415 | "input_string": ("STRING", { "default": "a$b$c", "display": "text" }), 416 | "separator": ("STRING", { "default": "$", "display": "text" }), 417 | "seed" : ("INT", { "default": 0, "min": 0, "max": (2**63-1), "step": 1, "display": "number" }), 418 | }, 419 | } 420 | RETURN_TYPES = ("STRING",) 421 | FUNCTION = "generate" 422 | CATEGORY = "Logic Gates" 423 | custom_name = "Uniform Random Choice" 424 | @node 425 | class ManualChoiceString: 426 | """ 427 | Parses input string with separator '$' and returns a random choice 428 | separator can be changed in the input 429 | Accepts index of choice as input 430 | """ 431 | def __init__(self): 432 | pass 433 | def generate(self, input_string, separator, index): 434 | choices = input_string.split(separator) 435 | value = choices[index] 436 | return (value,) 437 | @classmethod 438 | def INPUT_TYPES(s): 439 | return { 440 | "required": { 441 | "input_string": ("STRING", { "default": "a$b$c", "display": "text" }), 442 | "separator": ("STRING", { "default": "$", "display": "text" }), 443 | "index": ("INT", { "default": 0, "min": 0, "max": (2**63-1), "step": 1, "display": "number" }), 444 | }, 445 | } 446 | RETURN_TYPES = ("STRING",) 447 | FUNCTION = "generate" 448 | CATEGORY = "Logic Gates" 449 | custom_name = "Manual Choice String" 450 | @node 451 | class ManualChoiceInt: 452 | """ 453 | Parses input string with separator '$' and returns a random choice 454 | Returns as int 455 | separator can be changed in the input 456 | Accepts index of choice as input 457 | """ 458 | def __init__(self): 459 | pass 460 | def generate(self, input_string, separator, index): 461 | choices = input_string.split(separator) 462 | value = int(choices[index]) 463 | return (value,) 464 | @classmethod 465 | def INPUT_TYPES(s): 466 
| return { 467 | "required": { 468 | "input_string": ("STRING", { "default": "1$2$3", "display": "text" }), 469 | "separator": ("STRING", { "default": "$", "display": "text" }), 470 | "index": ("INT", { "default": 0, "min": 0, "max": (2**63-1), "step": 1, "display": "number" }), 471 | }, 472 | } 473 | RETURN_TYPES = ("INT",) 474 | FUNCTION = "generate" 475 | CATEGORY = "Logic Gates" 476 | custom_name = "Manual Choice Int" 477 | @node 478 | class ManualChoiceFloat: 479 | """ 480 | Parses input string with separator '$' and returns a random choice 481 | Returns as float 482 | separator can be changed in the input 483 | Accepts index of choice as input 484 | """ 485 | def __init__(self): 486 | pass 487 | def generate(self, input_string, separator, index): 488 | choices = input_string.split(separator) 489 | value = float(choices[index]) 490 | return (value,) 491 | @classmethod 492 | def INPUT_TYPES(s): 493 | return { 494 | "required": { 495 | "input_string": ("STRING", { "default": "1.0$2.0$3.0", "display": "text" }), 496 | "separator": ("STRING", { "default": "$", "display": "text" }), 497 | "index": ("INT", { "default": 0, "min": 0, "max": (2**63-1), "step": 1, "display": "number" }), 498 | }, 499 | } 500 | RETURN_TYPES = ("FLOAT",) 501 | FUNCTION = "generate" 502 | CATEGORY = "Logic Gates" 503 | custom_name = "Manual Choice Float" 504 | 505 | @node 506 | class RandomShuffleInt(RandomGuaranteedClass): 507 | """ 508 | Get the shuffled list of integers from start to end 509 | Input types and output types are lists of ints 510 | """ 511 | def __init__(self): 512 | pass 513 | def generate(self, input_string, separator, seed=0): 514 | instance = random.Random(seed) 515 | choices = input_string.split(separator) 516 | # shuffle the list 517 | instance.shuffle(choices) 518 | return (choices,) 519 | @classmethod 520 | def INPUT_TYPES(s): 521 | return { 522 | "required": { 523 | "input_string": ("STRING", { "default": "1$2$3", "display": "text" }), 524 | "separator": 
("STRING", { "default": "$", "display": "text" }), 525 | "seed" : ("INT", { "default": 0, "min": 0, "max": (2**63-1), "step": 1, "display": "number" }), 526 | }, 527 | } 528 | RETURN_TYPES = ("STRING",) 529 | FUNCTION = "generate" 530 | CATEGORY = "Logic Gates" 531 | custom_name = "Random Shuffle Int" 532 | @node 533 | class RandomShuffleFloat(RandomGuaranteedClass): 534 | """ 535 | Get the shuffled list of floats from start to end 536 | Input types and output types are lists of floats 537 | """ 538 | def __init__(self): 539 | pass 540 | def generate(self, input_string, separator, seed=0): 541 | instance = random.Random(seed) 542 | choices = input_string.split(separator) 543 | # shuffle the list 544 | instance.shuffle(choices) 545 | return (choices,) 546 | @classmethod 547 | def INPUT_TYPES(s): 548 | return { 549 | "required": { 550 | "input_string": ("STRING", { "default": "1.0$2.0$3.0", "display": "text" }), 551 | "separator": ("STRING", { "default": "$", "display": "text" }), 552 | "seed" : ("INT", { "default": 0, "min": 0, "max": (2**63-1), "step": 1, "display": "number" }), 553 | }, 554 | } 555 | RETURN_TYPES = ("STRING",) 556 | FUNCTION = "generate" 557 | CATEGORY = "Logic Gates" 558 | custom_name = "Random Shuffle Float" 559 | @node 560 | class RandomShuffleString(RandomGuaranteedClass): 561 | """ 562 | Get the shuffled list of strings from start to end 563 | Input types and output types are lists of strings 564 | """ 565 | def __init__(self): 566 | pass 567 | def generate(self, input_string, separator, seed=0): 568 | instance = random.Random(seed) 569 | choices = input_string.split(separator) 570 | # shuffle the list 571 | instance.shuffle(choices) 572 | return (choices,) 573 | @classmethod 574 | def INPUT_TYPES(s): 575 | return { 576 | "required": { 577 | "input_string": ("STRING", { "default": "a$b$c", "display": "text" }), 578 | "separator": ("STRING", { "default": "$", "display": "text" }), 579 | "seed" : ("INT", { "default": 0, "min": 0, "max": 
(2**63-1), "step": 1, "display": "number" }), 580 | }, 581 | } 582 | RETURN_TYPES = ("STRING",) 583 | FUNCTION = "generate" 584 | CATEGORY = "Logic Gates" 585 | custom_name = "Random Shuffle String" 586 | 587 | @node 588 | class CounterInteger(RandomGuaranteedClass): 589 | """ 590 | Generates a counter that increments by 1 591 | """ 592 | def __init__(self): 593 | self.counter = None 594 | def generate(self, reset, start): 595 | if self.counter is None: 596 | self.counter = start 597 | if reset: 598 | self.counter = 0 599 | self.counter += 1 600 | return (int(self.counter),) 601 | @classmethod 602 | def INPUT_TYPES(s): 603 | return { 604 | "required": { 605 | "start": ("FLOAT", { "default": 0.0, "min": -(2**63-1), "max": (2**63-1), "step": 1.0, "display": "number" }), 606 | }, 607 | "optional": { 608 | "reset": ("BOOLEAN", { "default": False }), 609 | }, 610 | } 611 | 612 | RETURN_TYPES = ("INT",) 613 | FUNCTION = "generate" 614 | CATEGORY = "Logic Gates" 615 | custom_name = "Counter Integer" 616 | 617 | @node 618 | class CounterFloat(RandomGuaranteedClass): 619 | """ 620 | Generates a counter that increments by 1 621 | """ 622 | def __init__(self): 623 | self.counter = None 624 | def generate(self, reset, start, step): 625 | if self.counter is None: 626 | self.counter = start 627 | if reset: 628 | self.counter = start 629 | self.counter += step 630 | return (self.counter,) 631 | @classmethod 632 | def INPUT_TYPES(s): 633 | return { 634 | "required": { 635 | "start": ("FLOAT", { "default": 0.0, "min": -(2**63-1), "max": (2**63-1), "step": 1.0, "display": "number" }), 636 | }, 637 | "optional": { 638 | "reset": ("BOOLEAN"), 639 | "step": ("FLOAT", { "default": 1.0, "min": -(2**63-1), "max": (2**63-1), "step": 1.0, "display": "number" }), 640 | }, 641 | } 642 | RETURN_TYPES = ("FLOAT",) 643 | FUNCTION = "generate" 644 | CATEGORY = "Logic Gates" 645 | custom_name = "Counter Float" 646 | 647 | @node 648 | class YieldableIteratorString(RandomGuaranteedClass): 649 | """ 
650 | Yields sequentially from the input list (with separator) 651 | If reset is True, then it starts from the beginning 652 | """ 653 | def __init__(self): 654 | self.index = 0 655 | def generate(self, input_string, separator, reset): 656 | choices = input_string.split(separator) 657 | if reset: 658 | self.index = 0 659 | else: 660 | self.index += 1 661 | if self.index >= len(choices): 662 | self.index = 0 663 | value = choices[self.index] 664 | return (value,) 665 | @classmethod 666 | def INPUT_TYPES(s): 667 | return { 668 | "required": { 669 | "input_string": ("STRING", { "default": "a$b$c", "display": "text" }), 670 | "separator": ("STRING", { "default": "$", "display": "text" }), 671 | "reset": ("BOOLEAN"), 672 | }, 673 | } 674 | RETURN_TYPES = ("STRING",) 675 | FUNCTION = "generate" 676 | CATEGORY = "Logic Gates" 677 | custom_name = "Yieldable Iterator String" 678 | 679 | @node 680 | class YieldableIteratorInt(RandomGuaranteedClass): 681 | """ 682 | Yields sequentially with start, end, step 683 | Resets if reset is True 684 | """ 685 | RETURN_TYPES = ("INT",) 686 | FUNCTION = "generate" 687 | CATEGORY = "Logic Gates" 688 | custom_name = "Yieldable (Sequential) Iterator Int" 689 | def __init__(self): 690 | self.iterator = None 691 | def generate(self, start, end, step, reset): 692 | if reset: 693 | self.iterator = None 694 | if self.iterator is None: 695 | self.iterator = range(start, end, step) 696 | try: 697 | value = next(self.iterator) 698 | except StopIteration: 699 | self.iterator = range(start, end, step) 700 | value = next(self.iterator) 701 | return (value,) 702 | @classmethod 703 | def INPUT_TYPES(s): 704 | return { 705 | "required": { 706 | "start": ("INT", { "default": 0, "min": -(2**63-1), "max": (2**63-1), "step": 1, "display": "number" }), 707 | "end": ("INT", { "default": 10, "min": -(2**63-1), "max": (2**63-1), "step": 1, "display": "number" }), 708 | "step": ("INT", { "default": 1, "min": -(2**63-1), "max": (2**63-1), "step": 1, "display": 
"number" }), 709 | "reset": ("BOOLEAN"), 710 | }, 711 | } 712 | 713 | CLASS_MAPPINGS, CLASS_NAMES = get_node_names_mappings(classes) 714 | validate(classes) 715 | -------------------------------------------------------------------------------- /pystructure.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from .autonode import node_wrapper, get_node_names_mappings, validate, anytype 5 | 6 | ############################################################################## 7 | # "NewPointer" BASE CLASS 8 | ############################################################################## 9 | 10 | class NewPointer: 11 | """A base class that forces ComfyUI to skip caching by returning NaN in IS_CHANGED.""" 12 | RESULT_NODE = True # Typically means the node can appear as a "result" in the graph 13 | OUTPUT_NODE = True # Typically means the node can appear as an "output" in the graph 14 | @classmethod 15 | def IS_CHANGED(cls, *args, **kwargs): 16 | return float("NaN") # Forces ComfyUI to consider it always changed 17 | 18 | ############################################################################## 19 | # HELPER: Throw if path tries to escape 20 | ############################################################################## 21 | 22 | def throw_if_parent_or_root_access(path): 23 | if ".." 
in path or path.startswith("/") or path.startswith("\\"): 24 | raise RuntimeError("Tried to access parent or root directory") 25 | if path.startswith("~"): 26 | raise RuntimeError("Tried to access home directory") 27 | if os.path.isabs(path): 28 | raise RuntimeError("Path cannot be absolute") 29 | 30 | ############################################################################## 31 | # REGISTER CLASSES BELOW 32 | ############################################################################## 33 | 34 | fundamental_classes = [] 35 | fundamental_node = node_wrapper(fundamental_classes) 36 | 37 | ############################ 38 | # GLOBALS & JSON NODES 39 | ############################ 40 | 41 | GLOBAL_STORAGE = {} # For global variable set/get 42 | 43 | @fundamental_node 44 | class JsonParseNode(NewPointer): 45 | """ 46 | Convert JSON string into a Python object (dict, list, etc.) stored in 'anytype'. 47 | """ 48 | FUNCTION = "parse_json" 49 | RETURN_TYPES = (anytype,) 50 | CATEGORY = "Data" 51 | custom_name="Pyobjects/JSON -> PyObject" 52 | 53 | @staticmethod 54 | def parse_json(json_string): 55 | try: 56 | obj = json.loads(json_string) 57 | except json.JSONDecodeError as e: 58 | raise ValueError(f"Invalid JSON: {str(e)}") 59 | return (obj,) 60 | 61 | @classmethod 62 | def INPUT_TYPES(cls): 63 | return { 64 | "required": { 65 | "json_string": ("STRING", {"default": '{"key": "value"}'}), 66 | } 67 | } 68 | 69 | @fundamental_node 70 | class JsonDumpNode(NewPointer): 71 | """ 72 | Convert a Python object (dict, list, etc.) into a JSON string. 
73 | """ 74 | FUNCTION = "dump_json" 75 | RETURN_TYPES = ("STRING",) 76 | CATEGORY = "Data" 77 | custom_name="Pyobjects/PyObject -> JSON" 78 | 79 | @staticmethod 80 | def dump_json(py_obj, indent=0): 81 | json_str = json.dumps(py_obj, indent=indent, ensure_ascii=False) 82 | return (json_str,) 83 | 84 | @classmethod 85 | def INPUT_TYPES(cls): 86 | return { 87 | "required": { 88 | "py_obj": (anytype, ), 89 | }, 90 | "optional": { 91 | "indent": ("INT", {"default": 0}), 92 | } 93 | } 94 | 95 | @fundamental_node 96 | class JsonDumpAnyStructureNode(NewPointer): 97 | """ 98 | Dump either DICT or LIST or SET (any Python structure) into JSON string. 99 | """ 100 | FUNCTION = "dump_any_struct" 101 | RETURN_TYPES = ("STRING",) 102 | CATEGORY = "Data" 103 | custom_name="Pyobjects/PyStructure -> JSON" 104 | 105 | @staticmethod 106 | def dump_any_struct(py_obj, indent=0): 107 | return (json.dumps(py_obj, indent=indent, ensure_ascii=False),) 108 | 109 | @classmethod 110 | def INPUT_TYPES(cls): 111 | return { 112 | "required": { 113 | "py_obj": (anytype,), 114 | }, 115 | "optional": { 116 | "indent": ("INT", {"default": 0}), 117 | } 118 | } 119 | 120 | ############################ 121 | # DICT NODES 122 | ############################ 123 | 124 | @fundamental_node 125 | class DictCreateNode(NewPointer): 126 | """ 127 | Creates a new empty dictionary (type DICT). 128 | """ 129 | FUNCTION = "create_dict" 130 | RETURN_TYPES = ("DICT",) 131 | CATEGORY = "Data" 132 | custom_name="Pyobjects/Create Dict" 133 | 134 | @staticmethod 135 | def create_dict(): 136 | return ({},) 137 | 138 | @classmethod 139 | def INPUT_TYPES(cls): 140 | return {"required": {}} 141 | 142 | @fundamental_node 143 | class DictSetNode(NewPointer): 144 | """ 145 | dict[key] = value. Returns the updated dict. 
146 | """ 147 | FUNCTION = "dict_set" 148 | RETURN_TYPES = ("DICT",) 149 | CATEGORY = "Data" 150 | custom_name="Pyobjects/Dict Set" 151 | 152 | @staticmethod 153 | def dict_set(py_dict, key, value): 154 | if not isinstance(py_dict, dict): 155 | raise ValueError("Input must be a Python dict") 156 | new_dict = py_dict.copy() 157 | new_dict[key] = value 158 | return (new_dict,) 159 | 160 | @classmethod 161 | def INPUT_TYPES(cls): 162 | return { 163 | "required": { 164 | "py_dict": ("DICT", ), 165 | "key": ("STRING", {"default": "some_key"}), 166 | "value": (anytype, {"default": "some_value"}) 167 | } 168 | } 169 | 170 | @fundamental_node 171 | class DictGetNode(NewPointer): 172 | """ 173 | Returns dict[key]. If key not found, returns None. 174 | """ 175 | FUNCTION = "dict_get" 176 | RETURN_TYPES = (anytype,) # The retrieved value can be anything 177 | CATEGORY = "Data" 178 | custom_name="Pyobjects/Dict Get" 179 | 180 | @staticmethod 181 | def dict_get(py_dict, key): 182 | if not isinstance(py_dict, dict): 183 | raise ValueError("Input must be a Python dict") 184 | return (py_dict.get(key, None),) 185 | 186 | @classmethod 187 | def INPUT_TYPES(cls): 188 | return { 189 | "required": { 190 | "py_dict": ("DICT", ), 191 | "key": ("STRING", {"default": "some_key"}), 192 | } 193 | } 194 | 195 | @fundamental_node 196 | class DictRemoveKeyNode(NewPointer): 197 | """ 198 | Removes a key from the dictionary (if present). Returns the updated dict. 
199 | """ 200 | FUNCTION = "dict_remove_key" 201 | RETURN_TYPES = ("DICT",) 202 | CATEGORY = "Data" 203 | custom_name="Pyobjects/Dict Remove Key" 204 | 205 | @staticmethod 206 | def dict_remove_key(py_dict, key): 207 | if not isinstance(py_dict, dict): 208 | raise ValueError("Input must be a Python dict") 209 | new_dict = py_dict.copy() 210 | new_dict.pop(key, None) 211 | return (new_dict,) 212 | 213 | @classmethod 214 | def INPUT_TYPES(cls): 215 | return { 216 | "required": { 217 | "py_dict": ("DICT",), 218 | "key": ("STRING", {"default": "some_key"}), 219 | } 220 | } 221 | 222 | @fundamental_node 223 | class DictMergeNode(NewPointer): 224 | """ 225 | Merges two dictionaries. 226 | If there are duplicate keys, the second dict's values overwrite the first. 227 | Returns a new dictionary (or modifies the first in place, see below). 228 | """ 229 | FUNCTION = "dict_merge" 230 | RETURN_TYPES = ("DICT",) 231 | CATEGORY = "Data" 232 | custom_name="Pyobjects/Dict Merge" 233 | 234 | @staticmethod 235 | def dict_merge(dict_a, dict_b, in_place=False): 236 | if not isinstance(dict_a, dict) or not isinstance(dict_b, dict): 237 | raise ValueError("Both inputs must be Python dicts") 238 | if in_place: 239 | new_dict = dict_a.copy() 240 | new_dict.update(dict_b) 241 | return (new_dict,) 242 | else: 243 | merged = {**dict_a, **dict_b} 244 | return (merged,) 245 | 246 | @classmethod 247 | def INPUT_TYPES(cls): 248 | return { 249 | "required": { 250 | "dict_a": ("DICT",), 251 | "dict_b": ("DICT",), 252 | }, 253 | "optional": { 254 | "in_place": ("BOOLEAN", {"default": False}), 255 | } 256 | } 257 | 258 | @fundamental_node 259 | class DictKeysNode(NewPointer): 260 | """ 261 | Returns the list of keys in a dictionary as type LIST. 
262 | """ 263 | FUNCTION = "dict_keys" 264 | RETURN_TYPES = ("LIST",) 265 | CATEGORY = "Data" 266 | custom_name="Pyobjects/Dict Keys" 267 | 268 | @staticmethod 269 | def dict_keys(py_dict): 270 | if not isinstance(py_dict, dict): 271 | raise ValueError("Input must be a Python dict") 272 | keys_list = list(py_dict.keys()) 273 | return (keys_list,) 274 | 275 | @classmethod 276 | def INPUT_TYPES(cls): 277 | return { 278 | "required": { 279 | "py_dict": ("DICT",), 280 | } 281 | } 282 | 283 | @fundamental_node 284 | class DictValuesNode(NewPointer): 285 | """ 286 | Returns the list of values in a dictionary as type LIST. 287 | """ 288 | FUNCTION = "dict_values" 289 | RETURN_TYPES = ("LIST",) 290 | CATEGORY = "Data" 291 | custom_name="Pyobjects/Dict Values" 292 | 293 | @staticmethod 294 | def dict_values(py_dict): 295 | if not isinstance(py_dict, dict): 296 | raise ValueError("Input must be a Python dict") 297 | values_list = list(py_dict.values()) 298 | return (values_list,) 299 | 300 | @classmethod 301 | def INPUT_TYPES(cls): 302 | return { 303 | "required": { 304 | "py_dict": ("DICT",), 305 | } 306 | } 307 | 308 | @fundamental_node 309 | class DictItemsNode(NewPointer): 310 | """ 311 | Returns the list of (key, value) pairs in a dictionary as type LIST. 312 | Each item in the list is a 2-element tuple [key, value]. 313 | """ 314 | FUNCTION = "dict_items" 315 | RETURN_TYPES = ("LIST",) 316 | CATEGORY = "Data" 317 | custom_name="Pyobjects/Dict Items" 318 | 319 | @staticmethod 320 | def dict_items(py_dict): 321 | if not isinstance(py_dict, dict): 322 | raise ValueError("Input must be a Python dict") 323 | items_list = list(py_dict.items()) 324 | return (items_list,) 325 | 326 | @classmethod 327 | def INPUT_TYPES(cls): 328 | return { 329 | "required": { 330 | "py_dict": ("DICT",), 331 | } 332 | } 333 | 334 | @fundamental_node 335 | class DictPointer(NewPointer): 336 | """ 337 | Example of a stateful node: holds onto the incoming dict until reset. 
338 | """ 339 | FUNCTION = "dict_pointer" 340 | RETURN_TYPES = ("DICT",) 341 | CATEGORY = "Data" 342 | custom_name="Pyobjects/Dict Pointer" 343 | 344 | def __init__(self): 345 | self.pointer = None 346 | 347 | def dict_pointer(self, py_dict, reset=False): 348 | if not isinstance(py_dict, dict): 349 | raise ValueError("Input must be a Python dict") 350 | if self.pointer is None or reset: 351 | self.pointer = py_dict 352 | if reset: 353 | value = self.pointer 354 | self.pointer = None 355 | return (value,) 356 | return (self.pointer,) 357 | 358 | @classmethod 359 | def INPUT_TYPES(cls): 360 | return { 361 | "required": { 362 | "py_dict": ("DICT", ), 363 | }, 364 | "optional": { 365 | "reset": ("BOOLEAN",), 366 | }, 367 | } 368 | 369 | ############################ 370 | # GLOBAL VAR NODES 371 | ############################ 372 | 373 | @fundamental_node 374 | class GlobalVarSetNode(NewPointer): 375 | """ 376 | Store a Python object in a global dictionary under the given key. 377 | Returns the same value that was stored, as anytype. 378 | """ 379 | FUNCTION = "global_var_set" 380 | RETURN_TYPES = (anytype,) 381 | CATEGORY = "Data" 382 | custom_name="Pyobjects/Global Var Set" 383 | 384 | def global_var_set(self, key, value): 385 | GLOBAL_STORAGE[key] = value 386 | print("GlobalVarSetNode:", GLOBAL_STORAGE) 387 | return (value,) 388 | 389 | @classmethod 390 | def INPUT_TYPES(cls): 391 | return { 392 | "required": { 393 | "key": ("STRING", {"default": "my_key"}), 394 | "value": (anytype, {"default": "my_value"}), 395 | } 396 | } 397 | 398 | @fundamental_node 399 | class GlobalVarSetIfNotExistsNode(NewPointer): 400 | """ 401 | Store a Python object in a global dictionary under the given key, only if the key doesn't already exist. 
402 | """ 403 | FUNCTION = "global_var_set_if_not_exists" 404 | RETURN_TYPES = (anytype,) 405 | CATEGORY = "Data" 406 | custom_name="Pyobjects/Global Var Set If Not Exists" 407 | 408 | def global_var_set_if_not_exists(self, key, value): 409 | if key not in GLOBAL_STORAGE: 410 | GLOBAL_STORAGE[key] = value 411 | print("GlobalVarSetIfNotExistsNode:", GLOBAL_STORAGE) 412 | return (GLOBAL_STORAGE[key],) # Return the final stored value 413 | 414 | @classmethod 415 | def INPUT_TYPES(cls): 416 | return { 417 | "required": { 418 | "key": ("STRING", {"default": "my_key"}), 419 | "value": (anytype, {"default": "my_value"}), 420 | } 421 | } 422 | 423 | @fundamental_node 424 | class GlobalVarGetNode(NewPointer): 425 | """ 426 | Retrieve a Python object from the global dictionary by key. 427 | If not found, returns None by default. 428 | """ 429 | FUNCTION = "global_var_get" 430 | RETURN_TYPES = (anytype,) 431 | CATEGORY = "Data" 432 | custom_name="Pyobjects/Global Var Get" 433 | 434 | def global_var_get(self, key): 435 | print("GlobalVarGetNode:", GLOBAL_STORAGE) 436 | return (GLOBAL_STORAGE.get(key, None),) 437 | 438 | @classmethod 439 | def INPUT_TYPES(cls): 440 | return { 441 | "required": { 442 | "key": ("STRING", {"default": "my_key"}), 443 | } 444 | } 445 | 446 | @fundamental_node 447 | class GlobalVarRemoveNode(NewPointer): 448 | """ 449 | Remove a key from the global dictionary (if present). 450 | Returns the value that was removed, or None if key didn't exist. 
451 | """ 452 | FUNCTION = "global_var_remove" 453 | RETURN_TYPES = (anytype,) 454 | CATEGORY = "Data" 455 | custom_name="Pyobjects/Global Var Remove" 456 | 457 | def global_var_remove(self, key): 458 | removed_value = GLOBAL_STORAGE.pop(key, None) 459 | return (removed_value,) 460 | 461 | @classmethod 462 | def INPUT_TYPES(cls): 463 | return { 464 | "required": { 465 | "key": ("STRING", {"default": "my_key"}), 466 | } 467 | } 468 | 469 | @fundamental_node 470 | class GlobalVarSaveNode(NewPointer): 471 | """ 472 | Saves the value of GLOBAL_STORAGE[key] to a JSON file on disk. 473 | If the key isn't in storage, returns an error or saves None if allow_missing=True. 474 | """ 475 | FUNCTION = "global_var_save" 476 | RETURN_TYPES = ("STRING",) # Return the filepath that was saved 477 | CATEGORY = "Data" 478 | custom_name="Pyobjects/Global Var Save" 479 | 480 | def global_var_save(self, key, filepath, allow_missing=False): 481 | throw_if_parent_or_root_access(filepath) 482 | 483 | value = GLOBAL_STORAGE.get(key, None) 484 | if value is None and not allow_missing and key not in GLOBAL_STORAGE: 485 | raise KeyError(f"Global key '{key}' not found in storage. Set allow_missing=True to save None.") 486 | 487 | dir_ = os.path.dirname(filepath) 488 | if dir_ and not os.path.exists(dir_): 489 | os.makedirs(dir_, exist_ok=True) 490 | 491 | with open(filepath, "w", encoding="utf-8") as f: 492 | json.dump(value, f, ensure_ascii=False, indent=4) 493 | return (filepath,) 494 | 495 | @classmethod 496 | def INPUT_TYPES(cls): 497 | return { 498 | "required": { 499 | "key": ("STRING", {"default": "my_key"}), 500 | "filepath": ("STRING", {"default": "my_global_var.json"}), 501 | }, 502 | "optional": { 503 | "allow_missing": ("BOOLEAN", {"default": False}) 504 | } 505 | } 506 | 507 | @fundamental_node 508 | class GlobalVarLoadNode(NewPointer): 509 | """ 510 | Loads JSON from a file and stores it in GLOBAL_STORAGE under the given key. 511 | Returns the loaded value. 
512 | """ 513 | FUNCTION = "global_var_load" 514 | RETURN_TYPES = (anytype,) 515 | CATEGORY = "Data" 516 | custom_name="Pyobjects/Global Var Load" 517 | 518 | def global_var_load(self, key, filepath, allow_missing=False): 519 | throw_if_parent_or_root_access(filepath) 520 | 521 | if not os.path.exists(filepath): 522 | if allow_missing: 523 | GLOBAL_STORAGE[key] = None 524 | return (None,) 525 | else: 526 | raise FileNotFoundError(f"File '{filepath}' does not exist.") 527 | 528 | with open(filepath, "r", encoding="utf-8") as f: 529 | value = json.load(f) 530 | GLOBAL_STORAGE[key] = value 531 | return (value,) 532 | 533 | @classmethod 534 | def INPUT_TYPES(cls): 535 | return { 536 | "required": { 537 | "key": ("STRING", {"default": "my_key"}), 538 | "filepath": ("STRING", {"default": "my_global_var.json"}), 539 | }, 540 | "optional": { 541 | "allow_missing": ("BOOLEAN", {"default": False}), 542 | } 543 | } 544 | 545 | ############################ 546 | # LIST NODES 547 | ############################ 548 | 549 | @fundamental_node 550 | class ListCreateNode(NewPointer): 551 | """ 552 | Creates a new empty list (type LIST). 553 | """ 554 | FUNCTION = "create_list" 555 | RETURN_TYPES = ("LIST",) 556 | CATEGORY = "Data" 557 | custom_name="Pyobjects/Create List" 558 | 559 | @staticmethod 560 | def create_list(): 561 | return ([],) 562 | 563 | @classmethod 564 | def INPUT_TYPES(cls): 565 | return {"required": {}} 566 | 567 | @fundamental_node 568 | class ListAppendNode(NewPointer): 569 | """ 570 | Append an item to a Python list. Returns the updated list. 
571 | """ 572 | FUNCTION = "list_append" 573 | RETURN_TYPES = ("LIST",) 574 | CATEGORY = "Data" 575 | custom_name="Pyobjects/List Append" 576 | 577 | @staticmethod 578 | def list_append(py_list, item): 579 | if not isinstance(py_list, list): 580 | raise ValueError("Input must be a Python list") 581 | new_list = py_list.copy() 582 | new_list.append(item) 583 | return (new_list,) 584 | 585 | @classmethod 586 | def INPUT_TYPES(cls): 587 | return { 588 | "required": { 589 | "py_list": ("LIST",), 590 | "item": (anytype,), 591 | } 592 | } 593 | 594 | @fundamental_node 595 | class ListGetNode(NewPointer): 596 | """ 597 | Return an element from a list by index as anytype. 598 | """ 599 | FUNCTION = "list_get" 600 | RETURN_TYPES = (anytype,) 601 | CATEGORY = "Data" 602 | custom_name="Pyobjects/List Get" 603 | 604 | @staticmethod 605 | def list_get(py_list, index): 606 | if not isinstance(py_list, list): 607 | raise ValueError("Input must be a Python list") 608 | if index < 0 or index >= len(py_list): 609 | raise IndexError("Index out of range") 610 | return (py_list[index],) 611 | 612 | @classmethod 613 | def INPUT_TYPES(cls): 614 | return { 615 | "required": { 616 | "py_list": ("LIST",), 617 | "index": ("INT", {"default": 0}), 618 | } 619 | } 620 | 621 | @fundamental_node 622 | class ListRemoveNode(NewPointer): 623 | """ 624 | Removes the first occurrence of 'item' from the list (if present). 625 | Returns the updated list. 
626 | """ 627 | FUNCTION = "list_remove" 628 | RETURN_TYPES = ("LIST",) 629 | CATEGORY = "Data" 630 | custom_name="Pyobjects/List Remove" 631 | 632 | @staticmethod 633 | def list_remove(py_list, item): 634 | if not isinstance(py_list, list): 635 | raise ValueError("Input must be a Python list") 636 | new_list = py_list.copy() 637 | if item in new_list: 638 | new_list.remove(item) 639 | return (new_list,) 640 | 641 | @classmethod 642 | def INPUT_TYPES(cls): 643 | return { 644 | "required": { 645 | "py_list": ("LIST",), 646 | "item": (anytype,), 647 | } 648 | } 649 | 650 | @fundamental_node 651 | class ListPopNode(NewPointer): 652 | """ 653 | Pop an item from the list by index. 654 | Returns (popped_item, updated_list). 655 | """ 656 | FUNCTION = "list_pop" 657 | RETURN_TYPES = (anytype, "LIST") 658 | CATEGORY = "Data" 659 | custom_name="Pyobjects/List Pop" 660 | 661 | @staticmethod 662 | def list_pop(py_list, index=-1): 663 | if not isinstance(py_list, list): 664 | raise ValueError("Input must be a Python list") 665 | if len(py_list) == 0: 666 | raise IndexError("Cannot pop from an empty list") 667 | new_list = py_list.copy() 668 | popped = new_list.pop(index) 669 | return (popped, new_list) 670 | 671 | @classmethod 672 | def INPUT_TYPES(cls): 673 | return { 674 | "required": { 675 | "py_list": ("LIST",), 676 | }, 677 | "optional": { 678 | "index": ("INT", {"default": -1}), 679 | } 680 | } 681 | 682 | @fundamental_node 683 | class ListInsertNode(NewPointer): 684 | """ 685 | Insert an item into the list at a given index. 686 | Returns the updated list. 
687 | """ 688 | FUNCTION = "list_insert" 689 | RETURN_TYPES = ("LIST",) 690 | CATEGORY = "Data" 691 | custom_name="Pyobjects/List Insert" 692 | 693 | @staticmethod 694 | def list_insert(py_list, index, item): 695 | if not isinstance(py_list, list): 696 | raise ValueError("Input must be a Python list") 697 | new_list = py_list.copy() 698 | new_list.insert(index, item) 699 | return (new_list,) 700 | 701 | @classmethod 702 | def INPUT_TYPES(cls): 703 | return { 704 | "required": { 705 | "py_list": ("LIST",), 706 | "index": ("INT", {"default": 0}), 707 | "item": (anytype,), 708 | } 709 | } 710 | 711 | @fundamental_node 712 | class ListExtendNode(NewPointer): 713 | """ 714 | Extends list A by appending elements from list B. Returns the updated list A. 715 | """ 716 | FUNCTION = "list_extend" 717 | RETURN_TYPES = ("LIST",) 718 | CATEGORY = "Data" 719 | custom_name="Pyobjects/List Extend" 720 | 721 | @staticmethod 722 | def list_extend(list_a, list_b): 723 | if not isinstance(list_a, list) or not isinstance(list_b, list): 724 | raise ValueError("Both inputs must be Python lists") 725 | new_list = list_a.copy() 726 | new_list.extend(list_b) 727 | return (new_list,) 728 | 729 | @classmethod 730 | def INPUT_TYPES(cls): 731 | return { 732 | "required": { 733 | "list_a": ("LIST",), 734 | "list_b": ("LIST",), 735 | } 736 | } 737 | 738 | @fundamental_node 739 | class ToListTypeNode(NewPointer): 740 | """ 741 | Takes any Python object that is actually iterable, returns it as type LIST. 742 | (Casts dicts/sets/tuples to list by calling list(obj).) 743 | """ 744 | FUNCTION = "to_list_type" 745 | RETURN_TYPES = ("LIST",) 746 | CATEGORY = "Data" 747 | custom_name="Pyobjects/Cast to LIST" 748 | 749 | @staticmethod 750 | def to_list_type(py_obj): 751 | if isinstance(py_obj, dict): 752 | # e.g. 
we can choose to return keys or something else 753 | return (list(py_obj.keys()),) 754 | try: 755 | return (list(py_obj),) 756 | except TypeError: 757 | raise ValueError("Object is not iterable, cannot cast to list") 758 | 759 | @classmethod 760 | def INPUT_TYPES(cls): 761 | return { 762 | "required": { 763 | "py_obj": (anytype,), 764 | } 765 | } 766 | 767 | ############################ 768 | # SET NODES 769 | ############################ 770 | 771 | @fundamental_node 772 | class ToSetTypeNode(NewPointer): 773 | """ 774 | Takes any Python object that is iterable, returns it as type SET. 775 | (Casts dict/list/tuple to set(obj). For dict, uses dict.keys().) 776 | """ 777 | FUNCTION = "to_set_type" 778 | RETURN_TYPES = ("SET",) 779 | CATEGORY = "Data" 780 | custom_name="Pyobjects/Cast to SET" 781 | 782 | @staticmethod 783 | def to_set_type(py_obj): 784 | if isinstance(py_obj, dict): 785 | return (set(py_obj.keys()),) 786 | try: 787 | return (set(py_obj),) 788 | except TypeError: 789 | raise ValueError("Object is not iterable, cannot cast to set") 790 | 791 | @classmethod 792 | def INPUT_TYPES(cls): 793 | return { 794 | "required": { 795 | "py_obj": (anytype,), 796 | } 797 | } 798 | 799 | @fundamental_node 800 | class SetCreateNode(NewPointer): 801 | """ 802 | Creates a new empty set (type SET). 803 | """ 804 | FUNCTION = "create_set" 805 | RETURN_TYPES = ("SET",) 806 | CATEGORY = "Data" 807 | custom_name="Pyobjects/Create Set" 808 | 809 | @staticmethod 810 | def create_set(): 811 | return (set(),) 812 | 813 | @classmethod 814 | def INPUT_TYPES(cls): 815 | return {"required": {}} 816 | 817 | @fundamental_node 818 | class SetAddNode(NewPointer): 819 | """ 820 | Adds an item to a set. 
821 | """ 822 | FUNCTION = "set_add" 823 | RETURN_TYPES = ("SET",) 824 | CATEGORY = "Data" 825 | custom_name="Pyobjects/Set Add" 826 | 827 | @staticmethod 828 | def set_add(py_set, item): 829 | if not isinstance(py_set, set): 830 | raise ValueError("Input must be a Python set") 831 | new_set = py_set.copy() 832 | new_set.add(item) 833 | return (new_set,) 834 | 835 | @classmethod 836 | def INPUT_TYPES(cls): 837 | return { 838 | "required": { 839 | "py_set": ("SET",), 840 | "item": (anytype,), 841 | } 842 | } 843 | 844 | @fundamental_node 845 | class SetRemoveNode(NewPointer): 846 | """ 847 | Removes an item from the set (if present). 848 | """ 849 | FUNCTION = "set_remove" 850 | RETURN_TYPES = ("SET",) 851 | CATEGORY = "Data" 852 | custom_name="Pyobjects/Set Remove" 853 | 854 | @staticmethod 855 | def set_remove(py_set, item): 856 | if not isinstance(py_set, set): 857 | raise ValueError("Input must be a Python set") 858 | new_set = py_set.copy() 859 | new_set.discard(item) 860 | return (new_set,) 861 | 862 | @classmethod 863 | def INPUT_TYPES(cls): 864 | return { 865 | "required": { 866 | "py_set": ("SET",), 867 | "item": (anytype,), 868 | } 869 | } 870 | 871 | @fundamental_node 872 | class SetUnionNode(NewPointer): 873 | """ 874 | Returns the union of two sets. 875 | """ 876 | FUNCTION = "set_union" 877 | RETURN_TYPES = ("SET",) 878 | CATEGORY = "Data" 879 | custom_name="Pyobjects/Set Union" 880 | 881 | @staticmethod 882 | def set_union(py_set_a, py_set_b): 883 | if not isinstance(py_set_a, set) or not isinstance(py_set_b, set): 884 | raise ValueError("Both inputs must be Python sets") 885 | return (py_set_a.union(py_set_b),) 886 | 887 | @classmethod 888 | def INPUT_TYPES(cls): 889 | return { 890 | "required": { 891 | "py_set_a": ("SET",), 892 | "py_set_b": ("SET",), 893 | } 894 | } 895 | 896 | @fundamental_node 897 | class SetIntersectionNode(NewPointer): 898 | """ 899 | Returns the intersection of two sets. 
900 | """ 901 | FUNCTION = "set_intersection" 902 | RETURN_TYPES = ("SET",) 903 | CATEGORY = "Data" 904 | custom_name="Pyobjects/Set Intersection" 905 | 906 | @staticmethod 907 | def set_intersection(py_set_a, py_set_b): 908 | if not isinstance(py_set_a, set) or not isinstance(py_set_b, set): 909 | raise ValueError("Both inputs must be Python sets") 910 | return (py_set_a.intersection(py_set_b),) 911 | 912 | @classmethod 913 | def INPUT_TYPES(cls): 914 | return { 915 | "required": { 916 | "py_set_a": ("SET",), 917 | "py_set_b": ("SET",), 918 | } 919 | } 920 | 921 | @fundamental_node 922 | class SetDifferenceNode(NewPointer): 923 | """ 924 | Returns the difference of two sets: A - B. 925 | """ 926 | FUNCTION = "set_difference" 927 | RETURN_TYPES = ("SET",) 928 | CATEGORY = "Data" 929 | custom_name="Pyobjects/Set Difference" 930 | 931 | @staticmethod 932 | def set_difference(py_set_a, py_set_b): 933 | if not isinstance(py_set_a, set) or not isinstance(py_set_b, set): 934 | raise ValueError("Both inputs must be Python sets") 935 | return (py_set_a.difference(py_set_b),) 936 | 937 | @classmethod 938 | def INPUT_TYPES(cls): 939 | return { 940 | "required": { 941 | "py_set_a": ("SET",), 942 | "py_set_b": ("SET",), 943 | } 944 | } 945 | 946 | @fundamental_node 947 | class SetSymDifferenceNode(NewPointer): 948 | """ 949 | Returns the symmetric difference (elements in A or B but not both). 
950 | """ 951 | FUNCTION = "set_sym_difference" 952 | RETURN_TYPES = ("SET",) 953 | CATEGORY = "Data" 954 | custom_name="Pyobjects/Set Symmetric Difference" 955 | 956 | @staticmethod 957 | def set_sym_difference(py_set_a, py_set_b): 958 | if not isinstance(py_set_a, set) or not isinstance(py_set_b, set): 959 | raise ValueError("Both inputs must be Python sets") 960 | return (py_set_a.symmetric_difference(py_set_b),) 961 | 962 | @classmethod 963 | def INPUT_TYPES(cls): 964 | return { 965 | "required": { 966 | "py_set_a": ("SET",), 967 | "py_set_b": ("SET",), 968 | } 969 | } 970 | 971 | @fundamental_node 972 | class SetClearNode(NewPointer): 973 | """ 974 | Clears all elements from the set. Returns the now-empty set. 975 | """ 976 | FUNCTION = "set_clear" 977 | RETURN_TYPES = ("SET",) 978 | CATEGORY = "Data" 979 | custom_name="Pyobjects/Set Clear" 980 | 981 | @staticmethod 982 | def set_clear(py_set): 983 | if not isinstance(py_set, set): 984 | raise ValueError("Input must be a Python set") 985 | new_set = py_set.copy() 986 | new_set.clear() 987 | return (new_set,) 988 | 989 | @classmethod 990 | def INPUT_TYPES(cls): 991 | return { 992 | "required": { 993 | "py_set": ("SET",), 994 | } 995 | } 996 | 997 | @fundamental_node 998 | class SetToListNode(NewPointer): 999 | """ 1000 | Converts a set to a list. Returns the list as type LIST. 
1001 | """ 1002 | FUNCTION = "set_to_list" 1003 | RETURN_TYPES = ("LIST",) 1004 | CATEGORY = "Data" 1005 | custom_name="Pyobjects/Set to List" 1006 | 1007 | @staticmethod 1008 | def set_to_list(py_set): 1009 | if not isinstance(py_set, set): 1010 | raise ValueError("Input must be a Python set") 1011 | return (list(py_set),) 1012 | 1013 | @classmethod 1014 | def INPUT_TYPES(cls): 1015 | return { 1016 | "required": { 1017 | "py_set": ("SET",), 1018 | } 1019 | } 1020 | 1021 | ############################################################################## 1022 | # REGISTER THE NODES 1023 | ############################################################################## 1024 | 1025 | CLASS_MAPPINGS, CLASS_NAMES = get_node_names_mappings(fundamental_classes) 1026 | validate(fundamental_classes) -------------------------------------------------------------------------------- /io_node.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import math 4 | import numpy as np 5 | import torch 6 | 7 | try: 8 | import piexif.helper 9 | import piexif 10 | from .exif.exif import read_info_from_image_stealth 11 | 12 | piexif_loaded = True 13 | except ImportError: 14 | piexif_loaded = False 15 | 16 | from .imgio.converter import PILHandlingHodes 17 | from .autonode import node_wrapper, get_node_names_mappings, validate, anytype, PILImage 18 | import time 19 | import os 20 | import shutil 21 | from PIL import Image 22 | from PIL import ImageOps 23 | from PIL import ImageEnhance 24 | from PIL.PngImagePlugin import PngInfo 25 | import folder_paths 26 | from comfy.cli_args import args 27 | import filelock 28 | import tempfile 29 | 30 | fundamental_classes = [] 31 | fundamental_node = node_wrapper(fundamental_classes) 32 | 33 | 34 | @fundamental_node 35 | class SleepNodeAny: 36 | FUNCTION = "sleep" 37 | RETURN_TYPES = (anytype,) 38 | CATEGORY = "Misc" 39 | custom_name = "SleepNode" 40 | 41 | @staticmethod 42 | def 
sleep(interval, inputs): 43 | time.sleep(interval) 44 | return (inputs,) 45 | 46 | @classmethod 47 | def INPUT_TYPES(cls): 48 | return { 49 | "required": { 50 | "interval": ("FLOAT", {"default": 0.0}), 51 | }, 52 | "optional": { 53 | "inputs": (anytype, {"default": 0.0}), 54 | }, 55 | } 56 | 57 | 58 | @fundamental_node 59 | class SleepNodeImage: 60 | FUNCTION = "sleep" 61 | RETURN_TYPES = (anytype,) 62 | CATEGORY = "Misc" 63 | custom_name = "Sleep (Image tunnel)" 64 | 65 | @staticmethod 66 | def sleep(interval, image): 67 | time.sleep(interval) 68 | return (image,) 69 | 70 | @classmethod 71 | def INPUT_TYPES(cls): 72 | return { 73 | "required": { 74 | "interval": ("FLOAT", {"default": 0.0}), 75 | "image": (anytype,), 76 | } 77 | } 78 | 79 | 80 | @fundamental_node 81 | class ErrorNode: 82 | FUNCTION = "raise_error" 83 | RETURN_TYPES = ("STRING",) 84 | CATEGORY = "Misc" 85 | custom_name = "ErrorNode" 86 | 87 | @staticmethod 88 | def raise_error(error_msg="Error"): 89 | raise Exception("Error: {}".format(error_msg)) 90 | 91 | @classmethod 92 | def INPUT_TYPES(cls): 93 | return { 94 | "required": { 95 | "error_msg": ("STRING", {"default": "Error"}), 96 | } 97 | } 98 | 99 | 100 | @fundamental_node 101 | class CurrentTimestamp: 102 | """ 103 | Returns the current Unix timestamp or a formatted time string. 
104 | """ 105 | 106 | def __init__(self): 107 | pass 108 | 109 | def generate(self, format_string): 110 | if format_string.strip() == "": 111 | # return Unix timestamp 112 | return (int(time.time()),) 113 | else: 114 | # return formatted date/time 115 | return (time.strftime(format_string, time.localtime()),) 116 | 117 | @classmethod 118 | def INPUT_TYPES(cls): 119 | return { 120 | "required": { 121 | "format_string": ( 122 | "STRING", 123 | { 124 | "default": "", 125 | "display": "text", 126 | "comment": "Leave blank for raw timestamp, or use format directives like '%Y-%m-%d %H:%M:%S'", 127 | }, 128 | ), 129 | } 130 | } 131 | 132 | RETURN_TYPES = ("STRING",) # or ("INT",) if returning raw int timestamp 133 | FUNCTION = "generate" 134 | CATEGORY = "Logic Gates" 135 | custom_name = "Current Timestamp" 136 | 137 | 138 | @fundamental_node 139 | class DebugComboInputNode: 140 | FUNCTION = "debug_combo_input" 141 | RETURN_TYPES = ("STRING",) 142 | CATEGORY = "Misc" 143 | custom_name = "Debug Combo Input" 144 | 145 | @staticmethod 146 | def debug_combo_input(input1): 147 | print(input1) 148 | return (input1,) 149 | 150 | @classmethod 151 | def INPUT_TYPES(cls): 152 | return { 153 | "required": { 154 | "input1": (["0", "1", "2"], {"default": "0"}), 155 | } 156 | } 157 | 158 | 159 | # https://github.com/comfyanonymous/ComfyUI/blob/340177e6e85d076ab9e222e4f3c6a22f1fb4031f/custom_nodes/example_node.py.example#L18 160 | @fundamental_node 161 | class TextPreviewNode: 162 | """ 163 | Can't display text but it makes always changed state 164 | """ 165 | 166 | FUNCTION = "text_preview" 167 | RETURN_TYPES = () 168 | CATEGORY = "Misc" 169 | custom_name = "Text Preview" 170 | RESULT_NODE = True 171 | OUTPUT_NODE = True 172 | 173 | def text_preview(self, text): 174 | print(text) 175 | # below does not work, why? 
176 | return {"ui": {"text": str(text)}} 177 | 178 | @classmethod 179 | def INPUT_TYPES(cls): 180 | return { 181 | "required": { 182 | "text": (anytype, {"default": "text", "type": "output"}), 183 | } 184 | } 185 | 186 | @classmethod 187 | def IS_CHANGED(s, *args, **kwargs): 188 | return float("nan") 189 | 190 | 191 | @fundamental_node 192 | class ParseExifNode: 193 | """ 194 | Parses exif data from image 195 | """ 196 | 197 | FUNCTION = "parse_exif" 198 | RETURN_TYPES = ("STRING",) 199 | CATEGORY = "Misc" 200 | custom_name = "Parse Exif" 201 | 202 | @staticmethod 203 | def parse_exif(image): 204 | return (read_info_from_image_stealth(image),) 205 | 206 | @classmethod 207 | def INPUT_TYPES(cls): 208 | return { 209 | "required": { 210 | "image": ("IMAGE",), 211 | } 212 | } 213 | 214 | 215 | def throw_if_parent_or_root_access(path): 216 | if ".." in path or path.startswith("/") or path.startswith("\\"): 217 | raise RuntimeError("Tried to access parent or root directory") 218 | if path.startswith("~"): 219 | raise RuntimeError("Tried to access home directory") 220 | if os.path.isabs(path): 221 | raise RuntimeError("Path cannot be absolute") 222 | 223 | 224 | @fundamental_node 225 | class SaveImageCustomNode: 226 | def __init__(self): 227 | self.output_dir = folder_paths.get_output_directory() 228 | self.type = "output" 229 | self.prefix_append = "" 230 | self.compress_level = 4 231 | 232 | @classmethod 233 | def INPUT_TYPES(s): 234 | return { 235 | "required": { 236 | "images": ("IMAGE",), 237 | "filename_prefix": ("STRING", {"default": "ComfyUI"}), 238 | "subfolder_dir": ("STRING", {"default": ""}), 239 | }, 240 | "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, 241 | } 242 | 243 | RETURN_TYPES = ("STRING",) # Filename 244 | FUNCTION = "save_images" 245 | 246 | OUTPUT_NODE = True 247 | RESULT_NODE = True 248 | CATEGORY = "image" 249 | custom_name = "Save Image Custom Node" 250 | 251 | def save_images( 252 | self, 253 | images, 254 | 
filename_prefix="ComfyUI", 255 | subfolder_dir="", 256 | prompt=None, 257 | extra_pnginfo=None, 258 | ): 259 | if images is None: # sometimes images is empty 260 | images = [] 261 | filename_prefix += self.prefix_append 262 | throw_if_parent_or_root_access(filename_prefix) 263 | throw_if_parent_or_root_access(subfolder_dir) 264 | output_dir = os.path.join(self.output_dir, subfolder_dir) 265 | full_output_folder, filename, counter, subfolder, filename_prefix = ( 266 | folder_paths.get_save_image_path( 267 | filename_prefix, output_dir, images[0].shape[1], images[0].shape[0] 268 | ) 269 | ) 270 | results = list() 271 | for image in images: 272 | i = 255.0 * image.cpu().numpy() 273 | img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) 274 | metadata = None 275 | if not args.disable_metadata: 276 | metadata = PngInfo() 277 | if prompt is not None: 278 | metadata.add_text("prompt", json.dumps(prompt)) 279 | if extra_pnginfo is not None: 280 | for x in extra_pnginfo: 281 | metadata.add_text(x, json.dumps(extra_pnginfo[x])) 282 | 283 | file = f"{filename}_{counter:05}_.png" 284 | img.save( 285 | os.path.join(full_output_folder, file), 286 | pnginfo=metadata, 287 | compress_level=self.compress_level, 288 | ) 289 | results.append( 290 | {"filename": file, "subfolder": subfolder, "type": self.type} 291 | ) 292 | counter += 1 293 | 294 | return {"ui": {"images": results}, "outputs": {"images": file.rstrip(".png")}} 295 | 296 | 297 | @fundamental_node 298 | class SaveTextCustomNode: 299 | def __init__(self): 300 | self.output_dir = folder_paths.get_output_directory() 301 | self.type = "output" 302 | self.prefix_append = "" 303 | self.compress_level = 4 304 | 305 | @classmethod 306 | def INPUT_TYPES(s): 307 | return { 308 | "required": { 309 | "text": (anytype,), 310 | "filename_prefix": ("STRING", {"default": "ComfyUI"}), 311 | "subfolder_dir": ("STRING", {"default": ""}), 312 | "filename": ("STRING", {"default": ""}), 313 | }, 314 | } 315 | 316 | RETURN_TYPES = 
("STRING",) # Filename 317 | FUNCTION = "save_text" 318 | custom_name = "Save Text Custom Node" 319 | CATEGORY = "text" 320 | RESULT_NODE = True 321 | OUTPUT_NODE = True 322 | 323 | def save_text(self, text, filename_prefix="ComfyUI", subfolder_dir="", filename=""): 324 | text = str(text) 325 | throw_if_parent_or_root_access(filename_prefix) 326 | throw_if_parent_or_root_access(subfolder_dir) 327 | assert ( 328 | len(text) > 0 and len(filename) > 0 329 | ), "Text and filename must be non-empty" 330 | filename_prefix += self.prefix_append 331 | output_dir = os.path.join(self.output_dir, subfolder_dir) 332 | filename_merged = filename_prefix + filename + ".txt" 333 | full_output_folder, subfolder, actual_filename = output_dir, "", filename_merged 334 | results = list() 335 | file = actual_filename 336 | with open(os.path.join(full_output_folder, file), "w") as f: 337 | f.write(text) 338 | results.append({"filename": file, "subfolder": subfolder, "type": self.type}) 339 | 340 | return {"ui": {"texts": results}, "outputs": {"images": file.rstrip(".txt")}} 341 | 342 | 343 | @fundamental_node 344 | class DumpTextJsonlNode: 345 | """ 346 | Appends text to a JSONL file (one JSON object per line). 347 | Each line will have the structure: { "": "" } 348 | 349 | For concurrency safety, this node uses filelock to block 350 | concurrent writes to the same file. 351 | """ 352 | 353 | FUNCTION = "dump_text_jsonl" 354 | RETURN_TYPES = ("STRING",) # We return the filename for convenience 355 | CATEGORY = "text" 356 | custom_name = "Dump Text JSONL Node" 357 | RESULT_NODE = True 358 | OUTPUT_NODE = True 359 | 360 | def __init__(self): 361 | self.output_dir = folder_paths.get_output_directory() 362 | self.type = "output" # for consistent UI listing 363 | self.prefix_append = "" 364 | 365 | @classmethod 366 | def INPUT_TYPES(cls): 367 | """ 368 | text can be a single string or a list of strings. 369 | If it's a list, each item is appended as a separate line. 
370 | """ 371 | return { 372 | "required": { 373 | "text": (anytype,), # Single string or list of strings 374 | "filename_prefix": ("STRING", {"default": "ComfyUI"}), 375 | "subfolder_dir": ("STRING", {"default": ""}), 376 | "filename": ("STRING", {"default": "dump.jsonl"}), 377 | "keyname": ("STRING", {"default": "text"}), 378 | }, 379 | } 380 | 381 | def dump_text_jsonl( 382 | self, 383 | text, 384 | filename_prefix="ComfyUI", 385 | subfolder_dir="", 386 | filename="dump.jsonl", 387 | keyname="text", 388 | ): 389 | # Security checks to avoid writing outside of the ComfyUI output folder 390 | throw_if_parent_or_root_access(filename_prefix) 391 | throw_if_parent_or_root_access(subfolder_dir) 392 | 393 | # Build the actual output path 394 | filename_prefix += self.prefix_append # If you want to append something 395 | output_dir = os.path.join(self.output_dir, subfolder_dir) 396 | os.makedirs(output_dir, exist_ok=True) 397 | 398 | final_filename = filename_prefix + "_" + filename 399 | full_path = os.path.join(output_dir, final_filename) 400 | lock_path = full_path + ".lock" 401 | 402 | # Ensure we can safely write concurrently 403 | with filelock.FileLock(lock_path, timeout=10): 404 | with open(full_path, "a", encoding="utf-8") as f: 405 | # If `text` is a list, write each element as its own JSON line 406 | if isinstance(text, list): 407 | for item in text: 408 | # Convert each item to string, just to be safe 409 | line = {keyname: str(item)} 410 | f.write(json.dumps(line, ensure_ascii=False) + "\n") 411 | else: 412 | # Single string input 413 | line = {keyname: str(text)} 414 | f.write(json.dumps(line, ensure_ascii=False) + "\n") 415 | 416 | # Return data for UI usage 417 | results = [ 418 | {"filename": final_filename, "subfolder": subfolder_dir, "type": self.type} 419 | ] 420 | return { 421 | "ui": {"texts": results}, 422 | "outputs": {"filename": final_filename}, 423 | } 424 | 425 | 426 | @fundamental_node 427 | class ConcatGridNode: 428 | """ 429 | Concatenate 
multiple images in a row, a column, or a square-like grid 430 | using either resizing or padding to match dimensions. 431 | 432 | direction: 433 | - "horizontal": line up side by side 434 | - "vertical": stack top to bottom 435 | - "square-like": arrange images in an NxN grid (where N = ceil(sqrt(#images))) 436 | 437 | match_method: 438 | - "resize": scale images so their matching dimension is the same 439 | (height for horizontal, width for vertical, or cell-size for square-like) 440 | - "pad": keep original size but add transparent padding so the matching dimension is the same 441 | """ 442 | 443 | FUNCTION = "concat_grid" 444 | RETURN_TYPES = ("IMAGE",) 445 | CATEGORY = "image" 446 | custom_name = "Concat Grid (Batch to single grid)" 447 | 448 | @classmethod 449 | def INPUT_TYPES(cls): 450 | return { 451 | "required": { 452 | "images": ("IMAGE",), 453 | "direction": ( 454 | ["horizontal", "vertical", "square-like"], 455 | {"default": "horizontal"}, 456 | ), 457 | "match_method": (["resize", "pad"], {"default": "resize"}), 458 | } 459 | } 460 | 461 | @staticmethod 462 | @PILHandlingHodes.output_wrapper 463 | def concat_grid(images, direction="horizontal", match_method="resize"): 464 | # 1) Convert images input to a list of PIL RGBA images 465 | # - If it's a torch.Tensor with shape (B, C, H, W) or a single image, unify into list. 
466 | if not ( 467 | isinstance(images, torch.Tensor) and len(images.shape) == 4 468 | ) and not isinstance(images, (list, tuple)): 469 | images = [images] 470 | 471 | converted = PILHandlingHodes.handle_input(images) # returns PIL or list of PIL 472 | if isinstance(converted, list): 473 | pil_images = [img.convert("RGBA") for img in converted] 474 | else: 475 | pil_images = [converted.convert("RGBA")] 476 | 477 | if len(pil_images) == 0: 478 | raise RuntimeError("No images provided to Concat Grid") 479 | 480 | # 2) Handle the three layout directions 481 | if direction == "horizontal": 482 | # --- Horizontal layout --- 483 | max_height = max(img.height for img in pil_images) 484 | processed = [] 485 | for img in pil_images: 486 | if match_method == "resize": 487 | # Scale the image so that height == max_height 488 | if img.height == 0: 489 | raise RuntimeError("Encountered an image of zero height.") 490 | ratio = max_height / float(img.height) 491 | new_w = int(img.width * ratio) 492 | new_h = max_height 493 | new_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) 494 | else: # "pad" 495 | # Create a new image with the same width but max_height 496 | new_img = Image.new("RGBA", (img.width, max_height), (0, 0, 0, 0)) 497 | new_img.paste(img, (0, 0)) 498 | processed.append(new_img) 499 | 500 | total_width = sum(im.width for im in processed) 501 | out = Image.new("RGBA", (total_width, max_height), (0, 0, 0, 0)) 502 | x_offset = 0 503 | for im in processed: 504 | out.paste(im, (x_offset, 0)) 505 | x_offset += im.width 506 | 507 | elif direction == "vertical": 508 | # --- Vertical layout --- 509 | max_width = max(img.width for img in pil_images) 510 | processed = [] 511 | for img in pil_images: 512 | if match_method == "resize": 513 | if img.width == 0: 514 | raise RuntimeError("Encountered an image of zero width.") 515 | ratio = max_width / float(img.width) 516 | new_w = max_width 517 | new_h = int(img.height * ratio) 518 | new_img = img.resize((new_w, new_h), 
Image.Resampling.LANCZOS) 519 | else: # "pad" 520 | new_img = Image.new("RGBA", (max_width, img.height), (0, 0, 0, 0)) 521 | new_img.paste(img, (0, 0)) 522 | processed.append(new_img) 523 | 524 | total_height = sum(im.height for im in processed) 525 | out = Image.new("RGBA", (max_width, total_height), (0, 0, 0, 0)) 526 | y_offset = 0 527 | for im in processed: 528 | out.paste(im, (0, y_offset)) 529 | y_offset += im.height 530 | 531 | else: # direction == "square-like" 532 | # --- Square-like NxN grid --- 533 | count = len(pil_images) 534 | # Determine grid size 535 | num_cols = int(math.ceil(math.sqrt(count))) 536 | num_rows = int(math.ceil(count / num_cols)) 537 | 538 | # Find maximum width/height among images 539 | max_width = max(img.width for img in pil_images) 540 | max_height = max(img.height for img in pil_images) 541 | 542 | processed = [] 543 | for img in pil_images: 544 | if match_method == "resize": 545 | # Here we forcibly resize each image to (max_width, max_height) 546 | # (which may distort if aspect ratios differ). 
547 | new_img = img.resize( 548 | (max_width, max_height), Image.Resampling.LANCZOS 549 | ) 550 | else: # "pad" 551 | # Keep original size but create a new RGBA canvas so each cell is (max_w, max_h) 552 | new_img = Image.new("RGBA", (max_width, max_height), (0, 0, 0, 0)) 553 | new_img.paste(img, (0, 0)) 554 | processed.append(new_img) 555 | 556 | # Create the final output canvas 557 | grid_width = num_cols * max_width 558 | grid_height = num_rows * max_height 559 | out = Image.new("RGBA", (grid_width, grid_height), (0, 0, 0, 0)) 560 | 561 | # Paste images in row-major order 562 | idx = 0 563 | for row in range(num_rows): 564 | for col in range(num_cols): 565 | if idx >= count: 566 | break # no more images 567 | x_offset = col * max_width 568 | y_offset = row * max_height 569 | out.paste(processed[idx], (x_offset, y_offset)) 570 | idx += 1 571 | 572 | return (out,) 573 | 574 | 575 | @fundamental_node 576 | class ConcatTwoImagesNode: 577 | """ 578 | Concatenate exactly two images (imageA, imageB). 
579 | 580 | direction: 581 | - "horizontal": line them up side by side 582 | - "vertical": place them top to bottom 583 | 584 | match_method: 585 | - "resize": scale images so their matching dimension is the same 586 | (height for horizontal, width for vertical) 587 | - "pad": keep original size but pad them so the matching dimension is the same 588 | """ 589 | 590 | FUNCTION = "concat_two_images" 591 | RETURN_TYPES = ("IMAGE",) 592 | CATEGORY = "image" 593 | custom_name = "Concat 2 Images to Grid" 594 | 595 | @classmethod 596 | def INPUT_TYPES(cls): 597 | return { 598 | "required": { 599 | "imageA": ("IMAGE",), 600 | "imageB": ("IMAGE",), 601 | "direction": (["horizontal", "vertical"], {"default": "horizontal"}), 602 | "match_method": (["resize", "pad"], {"default": "resize"}), 603 | } 604 | } 605 | 606 | @staticmethod 607 | @PILHandlingHodes.output_wrapper 608 | def concat_two_images( 609 | imageA, imageB, direction="horizontal", match_method="resize" 610 | ): 611 | # Convert input to PIL images (RGBA to preserve alpha if needed) 612 | pilA = PILHandlingHodes.handle_input(imageA) 613 | if isinstance(pilA, list): 614 | raise RuntimeError( 615 | "Expected a single image for imageA, grid only supports two images" 616 | ) 617 | pilB = PILHandlingHodes.handle_input(imageB) 618 | if isinstance(pilB, list): 619 | raise RuntimeError( 620 | "Expected a single image for imageB, grid only supports two images" 621 | ) 622 | if direction == "horizontal": 623 | # We want to unify heights 624 | max_h = max(pilA.height, pilB.height) 625 | 626 | if match_method == "resize": 627 | # Scale each image so their heights match 628 | def scale_height(img, target_h): 629 | if img.height == 0: 630 | raise RuntimeError("Encountered an image with zero height.") 631 | ratio = target_h / float(img.height) 632 | new_w = int(img.width * ratio) 633 | new_h = target_h 634 | return img.resize((new_w, new_h), Image.Resampling.LANCZOS) 635 | 636 | pilA = scale_height(pilA, max_h) 637 | pilB = 
scale_height(pilB, max_h) 638 | 639 | else: # match_method == "pad" 640 | # Pad images with transparent background so they share the same height 641 | def pad_height(img, target_h): 642 | new_img = Image.new("RGBA", (img.width, target_h), (0, 0, 0, 0)) 643 | new_img.paste(img, (0, 0)) 644 | return new_img 645 | 646 | pilA = pad_height(pilA, max_h) 647 | pilB = pad_height(pilB, max_h) 648 | 649 | total_width = pilA.width + pilB.width 650 | out = Image.new("RGBA", (total_width, max_h), (0, 0, 0, 0)) 651 | # Paste images side by side 652 | out.paste(pilA, (0, 0)) 653 | out.paste(pilB, (pilA.width, 0)) 654 | 655 | else: 656 | # direction == "vertical" 657 | # We want to unify widths 658 | max_w = max(pilA.width, pilB.width) 659 | 660 | if match_method == "resize": 661 | # Scale each image so their widths match 662 | def scale_width(img, target_w): 663 | if img.width == 0: 664 | raise RuntimeError("Encountered an image with zero width.") 665 | ratio = target_w / float(img.width) 666 | new_w = target_w 667 | new_h = int(img.height * ratio) 668 | return img.resize((new_w, new_h), Image.Resampling.LANCZOS) 669 | 670 | pilA = scale_width(pilA, max_w) 671 | pilB = scale_width(pilB, max_w) 672 | 673 | else: # match_method == "pad" 674 | # Pad images with transparent background so they share the same width 675 | def pad_width(img, target_w): 676 | new_img = Image.new("RGBA", (target_w, img.height), (0, 0, 0, 0)) 677 | new_img.paste(img, (0, 0)) 678 | return new_img 679 | 680 | pilA = pad_width(pilA, max_w) 681 | pilB = pad_width(pilB, max_w) 682 | 683 | total_height = pilA.height + pilB.height 684 | out = Image.new("RGBA", (max_w, total_height), (0, 0, 0, 0)) 685 | # Paste images top to bottom 686 | out.paste(pilA, (0, 0)) 687 | out.paste(pilB, (0, pilA.height)) 688 | 689 | return (out,) 690 | 691 | 692 | @fundamental_node 693 | class SaveCustomJPGNode: 694 | def __init__(self): 695 | self.output_dir = folder_paths.get_output_directory() 696 | self.type = "output" 697 | 
self.prefix_append = "" 698 | 699 | @classmethod 700 | def INPUT_TYPES(s): 701 | return { 702 | "required": { 703 | "images": ("IMAGE",), 704 | "filename_prefix": ("STRING", {"default": "ComfyUI"}), 705 | "subfolder_dir": ("STRING", {"default": ""}), 706 | }, 707 | "optional": { 708 | "quality": ("INT", {"default": 95}), 709 | "optimize": ("BOOLEAN", {"default": True}), 710 | "metadata_string": ("STRING", {"default": ""}), 711 | }, 712 | "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, 713 | } 714 | 715 | RETURN_TYPES = ("STRING",) # Filename 716 | FUNCTION = "save_images" 717 | 718 | OUTPUT_NODE = True 719 | RESULT_NODE = True 720 | 721 | CATEGORY = "image" 722 | custom_name = "Save Custom JPG Node" 723 | 724 | def save_images( 725 | self, 726 | images, 727 | filename_prefix="ComfyUI", 728 | subfolder_dir="", 729 | prompt=None, 730 | extra_pnginfo=None, 731 | quality=95, 732 | optimize=True, 733 | metadata_string="", 734 | ): 735 | if images is None: 736 | images = [] 737 | if not isinstance(images, (list, tuple, torch.Tensor)): 738 | images = [images] 739 | 740 | throw_if_parent_or_root_access(filename_prefix) 741 | throw_if_parent_or_root_access(subfolder_dir) 742 | 743 | filename_prefix += self.prefix_append 744 | output_dir = os.path.join(self.output_dir, subfolder_dir) 745 | filelock_path = os.path.join(output_dir, filename_prefix + ".lock") 746 | 747 | results = [] 748 | for image in images: 749 | if isinstance(image, torch.Tensor): 750 | if image.device.type != "cpu": 751 | image = image.cpu() 752 | image = 255.0 * image.numpy() 753 | clipped = np.clip(image, 0, 255).astype(np.uint8) 754 | if clipped.shape[0] <= 3: 755 | clipped = np.transpose(clipped, (1, 2, 0)) 756 | img = Image.fromarray(clipped) 757 | else: 758 | img = PILHandlingHodes.handle_input(image) 759 | 760 | metadata = {} 761 | if not args.disable_metadata: 762 | if prompt is not None: 763 | metadata["prompt"] = json.dumps(prompt) 764 | if extra_pnginfo is not None: 765 | for x 
in extra_pnginfo: 766 | metadata[x] = json.dumps(extra_pnginfo[x]) 767 | 768 | if metadata_string: 769 | metadata = {"metadata": metadata_string} 770 | 771 | exif_bytes = None 772 | if piexif_loaded: 773 | exif_bytes = piexif.dump( 774 | { 775 | "Exif": { 776 | piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump( 777 | json.dumps(metadata), encoding="unicode" 778 | ) 779 | }, 780 | } 781 | ) 782 | 783 | with filelock.FileLock(filelock_path, timeout=10): 784 | full_output_folder, filename, counter, subfolder, filename_prefix = ( 785 | folder_paths.get_save_image_path( 786 | filename_prefix, output_dir, img.size[1], img.size[0] 787 | ) 788 | ) 789 | counter_len = len(str(len(images))) 790 | file = f"{filename}_{str(counter).zfill(max(5, counter_len))}_.jpg" 791 | 792 | with tempfile.NamedTemporaryFile( 793 | suffix=".jpg", delete=False 794 | ) as tmpfile: 795 | tmp_path = tmpfile.name 796 | img.save(tmp_path, "JPEG", quality=quality, optimize=optimize) 797 | 798 | if piexif_loaded and exif_bytes: 799 | piexif.insert(exif_bytes, tmp_path) 800 | 801 | final_path = os.path.join(full_output_folder, file) 802 | shutil.copy2(tmp_path, final_path) 803 | os.remove(tmp_path) 804 | 805 | results.append( 806 | { 807 | "filename": os.path.join(full_output_folder, file), 808 | "subfolder": subfolder_dir, 809 | "type": self.type, 810 | } 811 | ) 812 | 813 | return { 814 | "ui": {"images": results}, 815 | "outputs": { 816 | "images": os.path.join(full_output_folder, file).rstrip(".jpg") 817 | }, 818 | } 819 | 820 | 821 | @fundamental_node 822 | class SaveImageWebpCustomNode: 823 | def __init__(self): 824 | self.output_dir = folder_paths.get_output_directory() 825 | self.type = "output" 826 | self.prefix_append = "" 827 | 828 | @classmethod 829 | def INPUT_TYPES(s): 830 | return { 831 | "required": { 832 | "images": ("IMAGE",), 833 | "filename_prefix": ("STRING", {"default": "ComfyUI"}), 834 | "subfolder_dir": ("STRING", {"default": ""}), 835 | }, 836 | "optional": { 837 | 
"quality": ("INT", {"default": 100}), 838 | "lossless": ("BOOLEAN", {"default": False}), 839 | "compression": ("INT", {"default": 4}), 840 | "optimize": ("BOOLEAN", {"default": False}), 841 | "metadata_string": ("STRING", {"default": ""}), 842 | "optional_additional_metadata": ("STRING", {"default": ""}), 843 | }, 844 | "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, 845 | } 846 | 847 | RETURN_TYPES = ("STRING",) # Filename 848 | FUNCTION = "save_images" 849 | 850 | OUTPUT_NODE = True 851 | RESULT_NODE = True 852 | 853 | CATEGORY = "image" 854 | custom_name = "Save Image Webp Node" 855 | 856 | def save_images( 857 | self, 858 | images, 859 | filename_prefix="ComfyUI", 860 | subfolder_dir="", 861 | prompt=None, 862 | extra_pnginfo=None, 863 | quality=100, 864 | lossless=False, 865 | compression=4, 866 | optimize=False, 867 | metadata_string="", 868 | optional_additional_metadata="", 869 | ): 870 | if images is None: # sometimes images is empty 871 | images = [] 872 | if not isinstance(images, (list, tuple, torch.Tensor)): 873 | images = [images] 874 | throw_if_parent_or_root_access(filename_prefix) 875 | throw_if_parent_or_root_access(subfolder_dir) 876 | filename_prefix += self.prefix_append 877 | output_dir = os.path.join(self.output_dir, subfolder_dir) 878 | filelock_path = os.path.join(output_dir, filename_prefix + ".lock") 879 | 880 | results = list() 881 | for image in images: 882 | if isinstance(image, torch.Tensor): 883 | if image.device.type != "cpu": 884 | image = image.cpu() 885 | image = 255.0 * image.numpy() 886 | clipped = np.clip(image, 0, 255).astype(np.uint8) 887 | if clipped.shape[0] == 3: 888 | clipped = np.transpose(clipped, (1, 2, 0)) # [1216, 832, 3] 889 | # if len(shape) is 4 and first dimension is 1, remove it (batch size) 890 | if clipped.shape[0] == 1 and len(clipped.shape) == 4: 891 | clipped = clipped[0] 892 | # print(clipped.shape) 893 | img = Image.fromarray(clipped) 894 | else: 895 | img = 
PILHandlingHodes.handle_input(image) 896 | metadata = None 897 | if not args.disable_metadata: 898 | metadata = {} 899 | if prompt is not None: 900 | metadata["prompt"] = json.dumps(prompt) 901 | if extra_pnginfo is not None: 902 | for x in extra_pnginfo: 903 | metadata[x] = json.dumps(extra_pnginfo[x]) 904 | if metadata_string: # override metadata 905 | metadata = {} 906 | metadata["metadata"] = metadata_string 907 | if optional_additional_metadata: 908 | metadata["optional_additional_metadata"] = optional_additional_metadata 909 | if piexif_loaded: 910 | exif_bytes = piexif.dump( 911 | { 912 | "Exif": { 913 | piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump( 914 | json.dumps(metadata) or "", encoding="unicode" 915 | ) 916 | }, 917 | } 918 | ) 919 | 920 | with filelock.FileLock( 921 | filelock_path, timeout=10 922 | ): # timeout 10 seconds should be enough for most cases 923 | full_output_folder, filename, counter, subfolder, filename_prefix = ( 924 | folder_paths.get_save_image_path( 925 | filename_prefix, output_dir, img.size[1], img.size[0] 926 | ) 927 | ) 928 | counter_len = len(str(len(images))) # for padding 929 | # file = f"{filename}_{counter:05}_.webp" 930 | file = f"{filename}_{str(counter).zfill(max(5, counter_len))}_.webp" 931 | with tempfile.NamedTemporaryFile( 932 | suffix=".webp", delete=False 933 | ) as tmpfile: 934 | tmp_path = tmpfile.name 935 | img.save( 936 | tmp_path, 937 | "WEBP", 938 | pnginfo=metadata, 939 | compress_level=compression, 940 | quality=quality, 941 | lossless=lossless, 942 | optimize=optimize, 943 | ) 944 | if piexif_loaded: 945 | piexif.insert(exif_bytes, tmp_path) 946 | final_path = os.path.join(full_output_folder, file) 947 | shutil.copy2(tmp_path, final_path) 948 | os.remove(tmp_path) 949 | 950 | results.append( 951 | { 952 | "filename": os.path.join(full_output_folder, file), 953 | "subfolder": subfolder_dir, 954 | "type": self.type, 955 | } 956 | ) 957 | 958 | return { 959 | "ui": {"images": results}, 960 | 
"outputs": { 961 | "images": os.path.join(full_output_folder, file).rstrip(".webp") 962 | }, 963 | } 964 | 965 | 966 | @fundamental_node 967 | class ComposeRGBAImageFromMask: 968 | @classmethod 969 | def INPUT_TYPES(s): 970 | return { 971 | "required": { 972 | "image": ("IMAGE",), 973 | "mask": ("MASK",), 974 | "invert": ("BOOLEAN", {"default": False}), 975 | } 976 | } 977 | 978 | RETURN_TYPES = ("IMAGE",) 979 | FUNCTION = "compose" 980 | CATEGORY = "image" 981 | custom_name = "Compose RGBA Image From Mask" 982 | 983 | @staticmethod 984 | def compose(image, mask, invert): 985 | if invert: 986 | mask = 1.0 - mask 987 | 988 | # Ensure mask has shape (batch_size, height, width, 1) 989 | mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1], 1)) 990 | # check devices, move to cpu 991 | if hasattr(image, "device"): 992 | image = image.cpu() 993 | if hasattr(mask, "device"): 994 | mask = mask.cpu() 995 | # Resize mask to match image dimensions if necessary 996 | if ( 997 | image.shape[0] != mask.shape[0] 998 | or image.shape[1] != mask.shape[1] 999 | or image.shape[2] != mask.shape[2] 1000 | ): 1001 | # Resize mask to match image dimensions 1002 | mask = torch.nn.functional.interpolate( 1003 | mask.permute(0, 3, 1, 2), 1004 | size=(image.shape[1], image.shape[2]), 1005 | mode="bilinear", 1006 | align_corners=False, 1007 | ).permute(0, 2, 3, 1) 1008 | 1009 | num_channels = image.shape[-1] 1010 | if num_channels == 3: 1011 | rgba_image = torch.cat((image, mask), dim=-1) 1012 | elif num_channels == 4: 1013 | rgba_image = image.clone() 1014 | rgba_image[:, :, :, 3:] = mask 1015 | else: 1016 | raise ValueError("Image must have 3 (RGB) or 4 (RGBA) channels") 1017 | 1018 | return (rgba_image,) 1019 | 1020 | 1021 | @fundamental_node 1022 | class ResizeImageNode: 1023 | FUNCTION = "resize_image" 1024 | RETURN_TYPES = ("IMAGE",) 1025 | CATEGORY = "image" 1026 | custom_name = "Resize Image" 1027 | 1028 | constants = { 1029 | "NEAREST": Image.Resampling.NEAREST, 1030 | "LANCZOS": 
Image.Resampling.LANCZOS, 1031 | "BICUBIC": Image.Resampling.BICUBIC, 1032 | } 1033 | 1034 | @staticmethod 1035 | @PILHandlingHodes.output_wrapper 1036 | def resize_image(image, width, height, method): 1037 | image = PILHandlingHodes.handle_input(image) 1038 | return (image.resize((width, height), ResizeImageNode.constants[method]),) 1039 | 1040 | @classmethod 1041 | def INPUT_TYPES(cls): 1042 | return { 1043 | "required": { 1044 | "image": ("IMAGE",), 1045 | "width": ("INT", {"default": 512}), 1046 | "height": ("INT", {"default": 512}), 1047 | "method": (["NEAREST", "LANCZOS", "BICUBIC"],), 1048 | }, 1049 | } 1050 | 1051 | 1052 | @fundamental_node 1053 | class ResizeImageResolution: 1054 | FUNCTION = "resize_image_resolution" 1055 | RETURN_TYPES = ("IMAGE",) 1056 | CATEGORY = "image" 1057 | custom_name = "Resize Image With Resolution" 1058 | 1059 | constants = { 1060 | "NEAREST": Image.Resampling.NEAREST, 1061 | "LANCZOS": Image.Resampling.LANCZOS, 1062 | "BICUBIC": Image.Resampling.BICUBIC, 1063 | } 1064 | 1065 | @staticmethod 1066 | @PILHandlingHodes.output_wrapper 1067 | def resize_image_resolution(image, resolution, method): 1068 | image = PILHandlingHodes.handle_input(image) 1069 | image_width, image_height = image.size 1070 | total_pixels = image_width * image_height 1071 | if total_pixels == 0: 1072 | raise RuntimeError("Image has no pixels") 1073 | if resolution < 256: 1074 | raise RuntimeError("Resolution must be positive and at least 256") 1075 | # get ratio 1076 | target_pixels = resolution**2 1077 | ratio = (target_pixels / total_pixels) ** 0.5 1078 | target_width = int(image_width * ratio) 1079 | target_height = int(image_height * ratio) 1080 | return ( 1081 | image.resize( 1082 | (target_width, target_height), ResizeImageResolution.constants[method] 1083 | ), 1084 | ) 1085 | 1086 | @classmethod 1087 | def INPUT_TYPES(cls): 1088 | return { 1089 | "required": { 1090 | "image": ("IMAGE",), 1091 | "resolution": ("INT", {"default": 512}), 1092 | "method": 
(["NEAREST", "LANCZOS", "BICUBIC"],), 1093 | }, 1094 | } 1095 | 1096 | 1097 | @fundamental_node 1098 | class ResizeImageEnsuringMultiple: 1099 | FUNCTION = "resize_image_ensuring_multiple" 1100 | RETURN_TYPES = ("IMAGE",) 1101 | CATEGORY = "image" 1102 | custom_name = "Resize Image Ensuring W/H Multiple" 1103 | 1104 | constants = { 1105 | "NEAREST": Image.Resampling.NEAREST, 1106 | "LANCZOS": Image.Resampling.LANCZOS, 1107 | "BICUBIC": Image.Resampling.BICUBIC, 1108 | } 1109 | 1110 | @staticmethod 1111 | @PILHandlingHodes.output_wrapper 1112 | def resize_image_ensuring_multiple(image, multiple, method): 1113 | image = PILHandlingHodes.handle_input(image) 1114 | image_width, image_height = image.size 1115 | total_pixels = image_width * image_height 1116 | if total_pixels == 0: 1117 | raise RuntimeError("Image has no pixels") 1118 | target_width = (image_width // multiple) * multiple 1119 | target_height = (image_height // multiple) * multiple 1120 | return ( 1121 | image.resize( 1122 | (target_width, target_height), ResizeImageResolution.constants[method] 1123 | ), 1124 | ) 1125 | 1126 | @classmethod 1127 | def INPUT_TYPES(cls): 1128 | return { 1129 | "required": { 1130 | "image": ("IMAGE",), 1131 | "multiple": ("INT", {"default": 32}), 1132 | "method": (["NEAREST", "LANCZOS", "BICUBIC"],), 1133 | }, 1134 | } 1135 | 1136 | 1137 | @fundamental_node 1138 | class ResizeImageResolutionIfBigger: 1139 | FUNCTION = "resize_image_resolution_if_bigger" 1140 | RETURN_TYPES = ("IMAGE",) 1141 | CATEGORY = "image" 1142 | custom_name = "Resize Image With Resolution If Bigger" 1143 | 1144 | constants = { 1145 | "NEAREST": Image.Resampling.NEAREST, 1146 | "LANCZOS": Image.Resampling.LANCZOS, 1147 | "BICUBIC": Image.Resampling.BICUBIC, 1148 | } 1149 | 1150 | @staticmethod 1151 | @PILHandlingHodes.output_wrapper 1152 | def resize_image_resolution_if_bigger(image, resolution, method): 1153 | image = PILHandlingHodes.handle_input(image) 1154 | image_width, image_height = image.size 
1155 | total_pixels = image_width * image_height 1156 | if total_pixels == 0: 1157 | raise RuntimeError("Image has no pixels") 1158 | if total_pixels <= resolution**2: 1159 | return (image,) 1160 | # get ratio 1161 | target_pixels = resolution**2 1162 | ratio = target_pixels / total_pixels 1163 | target_width = int(image_width * ratio) 1164 | target_height = int(image_height * ratio) 1165 | return ( 1166 | image.resize( 1167 | (target_width, target_height), 1168 | ResizeImageResolutionIfBigger.constants[method], 1169 | ), 1170 | ) 1171 | 1172 | @classmethod 1173 | def INPUT_TYPES(cls): 1174 | return { 1175 | "required": { 1176 | "image": ("IMAGE",), 1177 | "resolution": ("INT", {"default": 512}), 1178 | "method": (["NEAREST", "LANCZOS", "BICUBIC"],), 1179 | }, 1180 | } 1181 | 1182 | 1183 | @fundamental_node 1184 | class ResizeImageResolutionIfSmaller: 1185 | FUNCTION = "resize_image_resolution_if_smaller" 1186 | RETURN_TYPES = ("IMAGE",) 1187 | CATEGORY = "image" 1188 | custom_name = "Resize Image With Resolution If Smaller" 1189 | 1190 | constants = { 1191 | "NEAREST": Image.Resampling.NEAREST, 1192 | "LANCZOS": Image.Resampling.LANCZOS, 1193 | "BICUBIC": Image.Resampling.BICUBIC, 1194 | } 1195 | 1196 | @staticmethod 1197 | @PILHandlingHodes.output_wrapper 1198 | def resize_image_resolution_if_smaller(image, resolution, method): 1199 | image = PILHandlingHodes.handle_input(image) 1200 | image_width, image_height = image.size 1201 | total_pixels = image_width * image_height 1202 | if total_pixels == 0: 1203 | raise RuntimeError("Image has no pixels") 1204 | if total_pixels >= resolution**2: 1205 | return (image,) 1206 | # get ratio 1207 | target_pixels = resolution**2 1208 | ratio = target_pixels / total_pixels 1209 | target_width = int(image_width * ratio) 1210 | target_height = int(image_height * ratio) 1211 | return ( 1212 | image.resize( 1213 | (target_width, target_height), 1214 | ResizeImageResolutionIfSmaller.constants[method], 1215 | ), 1216 | ) 1217 | 1218 
| @classmethod 1219 | def INPUT_TYPES(cls): 1220 | return { 1221 | "required": { 1222 | "image": ("IMAGE",), 1223 | "resolution": ("INT", {"default": 512}), 1224 | "method": (["NEAREST", "LANCZOS", "BICUBIC"],), 1225 | }, 1226 | } 1227 | 1228 | 1229 | @fundamental_node 1230 | class Base64DecodeNode: 1231 | FUNCTION = "base64_decode" 1232 | RETURN_TYPES = ("IMAGE",) 1233 | CATEGORY = "image" 1234 | custom_name = "Base64 Decode to Image" 1235 | 1236 | @staticmethod 1237 | @PILHandlingHodes.output_wrapper 1238 | def base64_decode(base64_string): 1239 | image = PILHandlingHodes.handle_input( 1240 | base64_string 1241 | ) # automatically converts to PIL image 1242 | return (image,) 1243 | 1244 | @classmethod 1245 | def INPUT_TYPES(cls): 1246 | return { 1247 | "required": { 1248 | "base64_string": ("STRING",), 1249 | } 1250 | } 1251 | 1252 | 1253 | @fundamental_node 1254 | class ImageFromURLNode: 1255 | FUNCTION = "url_download" 1256 | RETURN_TYPES = ("IMAGE",) 1257 | CATEGORY = "image" 1258 | custom_name = "Download Image from URL" 1259 | 1260 | @staticmethod 1261 | @PILHandlingHodes.output_wrapper 1262 | def url_download(url): 1263 | if not url.startswith("http"): # for security reasons 1264 | raise RuntimeError( 1265 | "Strict URL check is required, however the URL does not start with http" 1266 | ) 1267 | image = PILHandlingHodes.handle_input(url) # automatically downloads image 1268 | return (image,) 1269 | 1270 | @classmethod 1271 | def INPUT_TYPES(cls): 1272 | return { 1273 | "required": { 1274 | "url": ("STRING",), 1275 | } 1276 | } 1277 | 1278 | 1279 | @fundamental_node 1280 | class Base64EncodeNode: 1281 | FUNCTION = "base64_encode" 1282 | RETURN_TYPES = ("STRING",) 1283 | CATEGORY = "image" 1284 | custom_name = "Image to Base64 Encode" 1285 | 1286 | @staticmethod 1287 | def base64_encode(image, quality, format, gzip_compress): 1288 | image = PILHandlingHodes.to_base64(image, quality, format, gzip_compress) 1289 | return (image,) 1290 | 1291 | @classmethod 1292 
| def INPUT_TYPES(cls): 1293 | return { 1294 | "required": { 1295 | "image": ("IMAGE",), 1296 | }, 1297 | "optional": { 1298 | "quality": ("INT", {"default": 100}), 1299 | "format": (["PNG", "WEBP", "JPG"], {"default": "PNG"}), 1300 | "gzip_compress": ("BOOLEAN", {"default": False}), 1301 | }, 1302 | } 1303 | 1304 | 1305 | @fundamental_node 1306 | class StringToBase64Node: 1307 | FUNCTION = "string_to_base64" 1308 | RETURN_TYPES = ("STRING",) 1309 | CATEGORY = "image" 1310 | custom_name = "String to Base64 Encode" 1311 | 1312 | @staticmethod 1313 | def string_to_base64(string, gzip_compress): 1314 | return (PILHandlingHodes.string_to_base64(string, gzip_compress),) 1315 | 1316 | @classmethod 1317 | def INPUT_TYPES(cls): 1318 | return { 1319 | "required": { 1320 | "string": ("STRING",), 1321 | }, 1322 | "optional": { 1323 | "gzip_compress": ("BOOLEAN", {"default": False}), 1324 | }, 1325 | } 1326 | 1327 | 1328 | @fundamental_node 1329 | class Base64ToStringNode: 1330 | FUNCTION = "base64_to_string" 1331 | RETURN_TYPES = ("STRING",) 1332 | CATEGORY = "image" 1333 | custom_name = "Base64 to String Decode" 1334 | 1335 | @staticmethod 1336 | def base64_to_string(base64_string): 1337 | return (PILHandlingHodes.maybe_gzip_base64_to_string(base64_string),) 1338 | 1339 | @classmethod 1340 | def INPUT_TYPES(cls): 1341 | return { 1342 | "required": { 1343 | "base64_string": ("STRING",), 1344 | } 1345 | } 1346 | 1347 | 1348 | @fundamental_node 1349 | class InvertImageNode: 1350 | FUNCTION = "invert_image" 1351 | RETURN_TYPES = ("IMAGE",) 1352 | CATEGORY = "image" 1353 | custom_name = "Invert Image" 1354 | 1355 | @staticmethod 1356 | @PILHandlingHodes.output_wrapper 1357 | def invert_image(image): 1358 | image = PILHandlingHodes.handle_input(image) 1359 | return (ImageOps.invert(image),) 1360 | 1361 | @classmethod 1362 | def INPUT_TYPES(cls): 1363 | return { 1364 | "required": { 1365 | "image": ("IMAGE",), 1366 | } 1367 | } 1368 | 1369 | 1370 | @fundamental_node 1371 | class 
ResizeScaleImageNode: 1372 | FUNCTION = "resize_scale_image" 1373 | RETURN_TYPES = ("IMAGE",) 1374 | CATEGORY = "image" 1375 | custom_name = "Resize Scale Image" 1376 | 1377 | constants = { 1378 | "NEAREST": Image.Resampling.NEAREST, 1379 | "LANCZOS": Image.Resampling.LANCZOS, 1380 | "BICUBIC": Image.Resampling.BICUBIC, 1381 | } 1382 | 1383 | @staticmethod 1384 | @PILHandlingHodes.output_wrapper 1385 | def resize_scale_image(image, scale, method): 1386 | image = PILHandlingHodes.handle_input(image) 1387 | if scale < 0: 1388 | raise RuntimeError("Scale must be positive") 1389 | return ( 1390 | image.resize( 1391 | (int(image.width * scale), int(image.height * scale)), 1392 | ResizeScaleImageNode.constants[method], 1393 | ), 1394 | ) 1395 | 1396 | @classmethod 1397 | def INPUT_TYPES(cls): 1398 | return { 1399 | "required": { 1400 | "image": ("IMAGE",), 1401 | "scale": ("INT", {"default": 2}), 1402 | "method": (["NEAREST", "LANCZOS", "BICUBIC"],), 1403 | }, 1404 | } 1405 | 1406 | 1407 | @fundamental_node 1408 | class ResizeShortestToNode: 1409 | FUNCTION = "resize_shortest_to" 1410 | RETURN_TYPES = ("IMAGE",) 1411 | CATEGORY = "image" 1412 | custom_name = "Resize Shortest To" 1413 | 1414 | constants = { 1415 | "NEAREST": Image.Resampling.NEAREST, 1416 | "LANCZOS": Image.Resampling.LANCZOS, 1417 | "BICUBIC": Image.Resampling.BICUBIC, 1418 | } 1419 | 1420 | @staticmethod 1421 | @PILHandlingHodes.output_wrapper 1422 | def resize_shortest_to(image, size, method): 1423 | image = PILHandlingHodes.handle_input(image) 1424 | if size < 0: 1425 | raise RuntimeError("Size must be positive") 1426 | if image.width < image.height: 1427 | return ( 1428 | image.resize( 1429 | (size, int(image.height * size / image.width)), 1430 | ResizeShortestToNode.constants[method], 1431 | ), 1432 | ) 1433 | else: 1434 | return ( 1435 | image.resize( 1436 | (int(image.width * size / image.height), size), 1437 | ResizeShortestToNode.constants[method], 1438 | ), 1439 | ) 1440 | 1441 | @classmethod 
1442 | def INPUT_TYPES(cls): 1443 | return { 1444 | "required": { 1445 | "image": ("IMAGE",), 1446 | "size": ("INT", {"default": 512}), 1447 | "method": (["NEAREST", "LANCZOS", "BICUBIC"],), 1448 | }, 1449 | } 1450 | 1451 | 1452 | @fundamental_node 1453 | class ResizeLongestToNode: 1454 | FUNCTION = "resize_longest_to" 1455 | RETURN_TYPES = ("IMAGE",) 1456 | CATEGORY = "image" 1457 | custom_name = "Resize Longest To" 1458 | 1459 | constants = { 1460 | "NEAREST": Image.Resampling.NEAREST, 1461 | "LANCZOS": Image.Resampling.LANCZOS, 1462 | "BICUBIC": Image.Resampling.BICUBIC, 1463 | } 1464 | 1465 | @staticmethod 1466 | @PILHandlingHodes.output_wrapper 1467 | def resize_longest_to(image, size, method): 1468 | image = PILHandlingHodes.handle_input(image) 1469 | if size < 0: 1470 | raise RuntimeError("Size must be positive") 1471 | if image.width > image.height: 1472 | return ( 1473 | image.resize( 1474 | (size, int(image.height * size / image.width)), 1475 | ResizeLongestToNode.constants[method], 1476 | ), 1477 | ) 1478 | else: 1479 | return ( 1480 | image.resize( 1481 | (int(image.width * size / image.height), size), 1482 | ResizeLongestToNode.constants[method], 1483 | ), 1484 | ) 1485 | 1486 | @classmethod 1487 | def INPUT_TYPES(cls): 1488 | return { 1489 | "required": { 1490 | "image": ("IMAGE",), 1491 | "size": ("INT", {"default": 512}), 1492 | "method": (["NEAREST", "LANCZOS", "BICUBIC"],), 1493 | }, 1494 | } 1495 | 1496 | 1497 | @fundamental_node 1498 | class ConvertGreyscaleNode: 1499 | FUNCTION = "convert_greyscale" 1500 | RETURN_TYPES = ("IMAGE",) 1501 | CATEGORY = "image" 1502 | custom_name = "Convert Greyscale" 1503 | 1504 | @staticmethod 1505 | @PILHandlingHodes.output_wrapper 1506 | def convert_greyscale(image): 1507 | image = PILHandlingHodes.handle_input(image) 1508 | greyscale_image = image.convert("L") 1509 | # 3 channel greyscale image 1510 | return (greyscale_image.convert("RGB"),) 1511 | 1512 | @classmethod 1513 | def INPUT_TYPES(cls): 1514 | return 
{ 1515 | "required": { 1516 | "image": ("IMAGE",), 1517 | } 1518 | } 1519 | 1520 | 1521 | @fundamental_node 1522 | class RotateImageNode: 1523 | FUNCTION = "rotate_image" 1524 | RETURN_TYPES = ("IMAGE",) 1525 | CATEGORY = "image" 1526 | custom_name = "Rotate Image" 1527 | 1528 | @staticmethod 1529 | @PILHandlingHodes.output_wrapper 1530 | def rotate_image(image, angle): 1531 | image = PILHandlingHodes.handle_input(image) 1532 | return (image.rotate(angle),) 1533 | 1534 | @classmethod 1535 | def INPUT_TYPES(cls): 1536 | return { 1537 | "required": { 1538 | "image": ("IMAGE",), 1539 | "angle": ("INT", {"default": 0}), 1540 | } 1541 | } 1542 | 1543 | 1544 | @fundamental_node 1545 | class BrightnessNode: 1546 | FUNCTION = "brightness" 1547 | RETURN_TYPES = ("IMAGE",) 1548 | CATEGORY = "image" 1549 | custom_name = "Brightness" 1550 | 1551 | @staticmethod 1552 | @PILHandlingHodes.output_wrapper 1553 | def brightness(image, factor): 1554 | image = PILHandlingHodes.handle_input(image) 1555 | enhancer = ImageEnhance.Brightness(image) 1556 | return (enhancer.enhance(factor),) 1557 | 1558 | @classmethod 1559 | def INPUT_TYPES(cls): 1560 | return { 1561 | "required": { 1562 | "image": ("IMAGE",), 1563 | "factor": ("FLOAT", {"default": 1.0}), 1564 | } 1565 | } 1566 | 1567 | 1568 | @fundamental_node 1569 | class ContrastNode: 1570 | FUNCTION = "contrast" 1571 | RETURN_TYPES = ("IMAGE",) 1572 | CATEGORY = "image" 1573 | custom_name = "Contrast" 1574 | 1575 | @staticmethod 1576 | @PILHandlingHodes.output_wrapper 1577 | def contrast(image, factor): 1578 | image = PILHandlingHodes.handle_input(image) 1579 | enhancer = ImageEnhance.Contrast(image) 1580 | return (enhancer.enhance(factor),) 1581 | 1582 | @classmethod 1583 | def INPUT_TYPES(cls): 1584 | return { 1585 | "required": { 1586 | "image": ("IMAGE",), 1587 | "factor": ("FLOAT", {"default": 1.0}), 1588 | } 1589 | } 1590 | 1591 | 1592 | @fundamental_node 1593 | class SharpnessNode: 1594 | FUNCTION = "sharpness" 1595 | 
RETURN_TYPES = ("IMAGE",) 1596 | CATEGORY = "image" 1597 | custom_name = "Sharpness" 1598 | 1599 | @staticmethod 1600 | @PILHandlingHodes.output_wrapper 1601 | def sharpness(image, factor): 1602 | image = PILHandlingHodes.handle_input(image) 1603 | enhancer = ImageEnhance.Sharpness(image) 1604 | return (enhancer.enhance(factor),) 1605 | 1606 | @classmethod 1607 | def INPUT_TYPES(cls): 1608 | return { 1609 | "required": { 1610 | "image": ("IMAGE",), 1611 | "factor": ("FLOAT", {"default": 1.0}), 1612 | } 1613 | } 1614 | 1615 | 1616 | @fundamental_node 1617 | class ColorNode: 1618 | FUNCTION = "color" 1619 | RETURN_TYPES = ("IMAGE",) 1620 | CATEGORY = "image" 1621 | custom_name = "Color" 1622 | 1623 | @staticmethod 1624 | @PILHandlingHodes.output_wrapper 1625 | def color(image, factor): 1626 | image = PILHandlingHodes.handle_input(image) 1627 | enhancer = ImageEnhance.Color(image) 1628 | return (enhancer.enhance(factor),) 1629 | 1630 | @classmethod 1631 | def INPUT_TYPES(cls): 1632 | return { 1633 | "required": { 1634 | "image": ("IMAGE",), 1635 | "factor": ("FLOAT", {"default": 1.0}), 1636 | } 1637 | } 1638 | 1639 | 1640 | @fundamental_node 1641 | class ConvertRGBNode: 1642 | FUNCTION = "convert_rgb" 1643 | RETURN_TYPES = ("IMAGE",) 1644 | CATEGORY = "image" 1645 | custom_name = "Convert RGB" 1646 | 1647 | @staticmethod 1648 | @PILHandlingHodes.output_wrapper 1649 | def convert_rgb(image): 1650 | image = PILHandlingHodes.handle_input(image) 1651 | return (image.convert("RGB"),) 1652 | 1653 | @classmethod 1654 | def INPUT_TYPES(cls): 1655 | return { 1656 | "required": { 1657 | "image": ("IMAGE",), 1658 | } 1659 | } 1660 | 1661 | 1662 | @fundamental_node 1663 | class FFTNode: 1664 | FUNCTION = "fft_image" 1665 | RETURN_TYPES = ("IMAGE",) 1666 | CATEGORY = "image" 1667 | custom_name = "FFT Image" 1668 | 1669 | @staticmethod 1670 | def fft_image(image: Image.Image, mask_radius: int) -> Image.Image: 1671 | """ 1672 | Applies an FFT-based low-pass filter to an input PIL 
image. 1673 | 1674 | Args: 1675 | image (Image.Image): Input PIL image to filter. 1676 | mask_radius (int): Radius of the low-pass circular mask. 1677 | 1678 | Returns: 1679 | Image.Image: The filtered image as a PIL Image. 1680 | """ 1681 | # Convert image to numpy array 1682 | images = PILHandlingHodes.handle_input(image) 1683 | results = [] 1684 | if isinstance(images, list): 1685 | for image in images: 1686 | image_np = np.array(image.convert("RGB")) 1687 | 1688 | # Compute FFT for each channel and shift to center 1689 | fft_channels = [ 1690 | np.fft.fftshift(np.fft.fft2(image_np[:, :, channel])) 1691 | for channel in range(3) 1692 | ] 1693 | 1694 | # Create low-pass filter mask 1695 | rows, cols = image_np.shape[:2] 1696 | crow, ccol = rows // 2, cols // 2 1697 | mask = np.zeros((rows, cols), dtype=np.uint8) 1698 | y, x = np.ogrid[-crow : rows - crow, -ccol : cols - ccol] 1699 | mask_area = x**2 + y**2 <= mask_radius**2 1700 | mask[mask_area] = 1 1701 | 1702 | # Apply mask and perform inverse FFT 1703 | filtered_channels = [ 1704 | np.abs(np.fft.ifft2(np.fft.ifftshift(channel * mask))) 1705 | for channel in fft_channels 1706 | ] 1707 | 1708 | # Combine channels and convert back to image format 1709 | filtered_image_np = np.stack(filtered_channels, axis=-1) 1710 | filtered_image_np = np.clip(filtered_image_np, 0, 255).astype(np.uint8) 1711 | results.append( 1712 | PILHandlingHodes.handle_output_as_tensor( 1713 | Image.fromarray(filtered_image_np) 1714 | ) 1715 | ) 1716 | return (results,) 1717 | else: 1718 | image_np = np.array(image.convert("RGB")) 1719 | 1720 | # Compute FFT for each channel and shift to center 1721 | fft_channels = [ 1722 | np.fft.fftshift(np.fft.fft2(image_np[:, :, channel])) 1723 | for channel in range(3) 1724 | ] 1725 | 1726 | # Create low-pass filter mask 1727 | rows, cols = image_np.shape[:2] 1728 | crow, ccol = rows // 2, cols // 2 1729 | mask = np.zeros((rows, cols), dtype=np.uint8) 1730 | y, x = np.ogrid[-crow : rows - crow, -ccol : 
cols - ccol] 1731 | mask_area = x**2 + y**2 <= mask_radius**2 1732 | mask[mask_area] = 1 1733 | 1734 | # Apply mask and perform inverse FFT 1735 | filtered_channels = [ 1736 | np.abs(np.fft.ifft2(np.fft.ifftshift(channel * mask))) 1737 | for channel in fft_channels 1738 | ] 1739 | 1740 | # Combine channels and convert back to image format 1741 | filtered_image_np = np.stack(filtered_channels, axis=-1) 1742 | filtered_image_np = np.clip(filtered_image_np, 0, 255).astype(np.uint8) 1743 | return ( 1744 | PILHandlingHodes.handle_output_as_tensor( 1745 | Image.fromarray(filtered_image_np) 1746 | ), 1747 | ) 1748 | 1749 | @classmethod 1750 | def INPUT_TYPES(cls): 1751 | return { 1752 | "required": { 1753 | "image": ("IMAGE",), 1754 | "mask_radius": ("INT", {"default": 50}), 1755 | } 1756 | } 1757 | 1758 | 1759 | @fundamental_node 1760 | class GetImageInfoNode: 1761 | FUNCTION = "get_image_info" 1762 | RETURN_TYPES = ("WIDTH", "HEIGHT", "TOTAL_PIXELS") 1763 | CATEGORY = "image" 1764 | custom_name = "Get Image Info" 1765 | 1766 | @staticmethod 1767 | def get_image_info(image): 1768 | image = PILHandlingHodes.handle_input(image) 1769 | width, height = image.size 1770 | return (width, height, width * height) 1771 | 1772 | @classmethod 1773 | def INPUT_TYPES(cls): 1774 | return { 1775 | "required": { 1776 | "image": ("IMAGE",), 1777 | } 1778 | } 1779 | 1780 | 1781 | @fundamental_node 1782 | class ThresholdNode: 1783 | FUNCTION = "threshold" 1784 | RETURN_TYPES = ("IMAGE",) 1785 | CATEGORY = "image" 1786 | custom_name = "Threshold image with value" 1787 | 1788 | @staticmethod 1789 | @PILHandlingHodes.output_wrapper 1790 | def threshold(image, threshold): 1791 | image = PILHandlingHodes.handle_input(image) 1792 | return (image.point(lambda p: p > threshold and 255),) 1793 | 1794 | @classmethod 1795 | def INPUT_TYPES(cls): 1796 | return { 1797 | "required": { 1798 | "image": ("IMAGE",), 1799 | "threshold": ("INT", {"default": 128}), 1800 | } 1801 | } 1802 | 1803 | 1804 | 
# Module-level registration for this file's nodes: derive CLASS_MAPPINGS /
# CLASS_NAMES from `fundamental_classes` and then pass the same collection to
# `validate`.
# NOTE(review): `fundamental_classes` is presumably populated by the
# @fundamental_node decorator applied to each node class in this module, and
# `validate` is assumed to sanity-check those class definitions — confirm
# against their definitions elsewhere in the project.
CLASS_MAPPINGS, CLASS_NAMES = get_node_names_mappings(fundamental_classes) 1805 | validate(fundamental_classes) 1806 | --------------------------------------------------------------------------------