├── .gitignore
├── README.md
├── __init__.py
├── comfyui_batch_io.py
├── comfyui_color_ops.py
├── comfyui_datetime.py
├── comfyui_debug.py
├── comfyui_default.py
├── comfyui_group_io.py
├── comfyui_image_channel_ops.py
├── comfyui_image_ops.py
├── comfyui_image_sequence.py
├── comfyui_info_hash.py
├── comfyui_jw.py
├── comfyui_mask_sequence_ops.py
├── comfyui_primitive_ops.py
├── comfyui_raft.py
├── comfyui_rc.py
├── comfyui_sound.py
├── comfyui_string_list.py
└── comfyui_uncrop.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | *.log
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Various ComfyUI Nodes by Type
2 | 
3 | This repository provides various nodes for use in ComfyUI.
4 | 
5 | The nodes are grouped into separate files, so you can download just the specific file you want and avoid filling your node list with nodes you don't need.
6 | 
7 | ## Installation
8 | 
9 | ### Method 1 (Recommended): Download each file individually
10 | 
11 | Go through each file and see which nodes you want to use. Download the corresponding file and put it in:
12 | 
13 | ```
14 | ComfyUI/custom_nodes
15 | ```
16 | 
17 | If you want to use `RAFTLoadFlowFromEXRChannels` from `comfyui_raft.py`, you must install `OpenEXR` in your ComfyUI Python environment.
18 | 
19 | ```sh
20 | # Activate your Python environment first.
21 | pip install OpenEXR
22 | ```
23 | 
24 | ### Method 2: Clone the repo
25 | 
26 | This method populates your node list with most of the nodes in this repository, which may be annoying to some (e.g. me), so I recommend using **Method 1** to keep your node list organised.
27 | 
28 | If you're happy to install most of the nodes in this repository, clone it into your `custom_nodes` folder:
29 | 
30 | ```
31 | cd ComfyUI/custom_nodes
32 | git clone https://github.com/jamesWalker55/comfyui-various
33 | ```
34 | 
35 | ## Available Nodes
36 | 
37 | Each `comfyui_*.py` file contains a group of nodes of similar purpose. This repo is still in its early stages, so I haven't written documentation for each file yet - have a look at the code in each file to see what it does.
38 | 
39 | ```
40 | comfyui_image_ops
41 |   JWImageLoadRGB: Image Load RGB
42 |   JWImageLoadRGBA: Image Load RGBA
43 |   JWLoadImagesFromString: Load Images From String
44 |   JWImageSaveToPath: Image Save To Path
45 |   JWImageExtractFromBatch: Image Extract From Batch
46 |   JWImageBatchCount: Get Image Batch Count
47 |   JWImageResize: Image Resize
48 |   JWImageFlip: Image Flip
49 |   JWMaskResize: Mask Resize
50 |   JWMaskLikeImageSize: Mask Like Image Size
51 |   JWImageResizeToSquare: Image Resize to Square
52 |   JWImageResizeByFactor: Image Resize by Factor
53 |   JWImageResizeByShorterSide: Image Resize by Shorter Side
54 |   JWImageResizeByLongerSide: Image Resize by Longer Side
55 |   JWImageResizeToClosestSDXLResolution: Image Resize to Closest SDXL Resolution
56 |   JWImageLoadRGBFromClipboard: Image Load RGB From Clipboard
57 |   JWImageLoadRGBAFromClipboard: Image Load RGBA From Clipboard
58 | 
59 | comfyui_primitive_ops
60 |   JWInteger: Integer
61 |   JWIntegerToFloat: Integer to Float
62 |   JWIntegerToString: Integer to String
63 |   JWIntegerAdd: Integer Add
64 |   JWIntegerSub: Integer Subtract
65 |   JWIntegerMul: Integer Multiply
66 |   JWIntegerDiv: Integer Divide
67 |   JWIntegerAbsolute: Integer Absolute Value
68 |   JWIntegerMin: Integer Minimum
69 |   JWIntegerMax: Integer Maximum
70 |   JWFloat: Float
71 |   JWFloatToInteger: Float to Integer
72 |   JWFloatToString: Float to String
73 |   JWFloatAdd: Float Add
74 |   JWFloatSub: Float Subtract
75 |   JWFloatMul: Float Multiply
76 |   JWFloatDiv: Float Divide
77 |   JWFloatAbsolute: Float Absolute Value
78 |   JWFloatMin: Float Minimum
79 |   JWFloatMax: Float Maximum
80 |   JWString: String
81 |   JWStringToInteger: String to Integer
82 |   JWStringToFloat: String to Float
83 |   JWStringMultiline: String (Multiline)
84 |   JWStringConcat: String Concatenate
85 |   JWStringReplace: String Replace
86 |   JWStringSplit: String Split
87 |   JWStringGetLine: String Get Line
88 |   JWStringUnescape: String Unescape
89 | 
90 | comfyui_raft
91 |   RAFTEstimate: RAFT Estimate
92 |   RAFTFlowToImage: RAFT Flow to Image
93 |   RAFTLoadFlowFromEXRChannels: RAFT Load Flow from EXR Channels
94 | 
95 | comfyui_image_channel_ops
96 |   JWImageStackChannels: Image Stack Channels
97 | 
98 | comfyui_color_ops
99 |   JWImageMix: Image Mix
100 |   JWImageContrast: Image Contrast
101 |   JWImageSaturation: Image Saturation
102 |   JWImageLevels: Image Levels
103 | 
104 | comfyui_datetime
105 |   JWDatetimeString: Datetime String
106 | 
107 | comfyui_image_sequence
108 |   JWLoadImageSequence: Batch Load Image Sequence
109 |   JWLoadImageSequenceWithStopIndex: Batch Load Image Sequence With Stop Index
110 |   JWImageSequenceExtractFromBatch: Extract Image Sequence From Batch
111 |   JWSaveImageSequence: Batch Save Image Sequence
112 |   JWLoopImageSequence: Loop Image Sequence
113 | 
114 | comfyui_mask_sequence_ops
115 |   JWMaskSequenceFromMask: Mask Sequence From Mask
116 |   JWMaskSequenceJoin: Join Mask Sequence
117 |   JWMaskSequenceApplyToLatent: Apply Mask Sequence to Latent
118 | ```
119 | 
120 | ### Other nodes
121 | 
122 | Some files contain nodes for my own personal use, and are likely completely useless to anyone else. These nodes are hidden by default but can be enabled by setting the environment variable `COMFYUI_JW_ENABLE_EXTRA_NODES` to `true`.
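For example, assuming you launch ComfyUI with `python main.py`, set the variable in the same shell before starting it (use `set` instead of `export` on Windows):

```sh
export COMFYUI_JW_ENABLE_EXTRA_NODES=true
python main.py
```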
These files are hidden by default: 123 | 124 | - `comfyui_batch_io.py` 125 | - `comfyui_group_io.py` 126 | - `comfyui_cn_preprocessors.py` _(Use [Fannovel16's preprocessor nodes](https://github.com/Fannovel16/comfy_controlnet_preprocessors) instead, they're way better)_ 127 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | NODE_CLASS_MAPPINGS = {} 5 | NODE_DISPLAY_NAME_MAPPINGS = {} 6 | 7 | # Main nodes for all users 8 | NODE_MODULES = [ 9 | ".comfyui_image_ops", 10 | ".comfyui_primitive_ops", 11 | ".comfyui_raft", 12 | ".comfyui_image_channel_ops", 13 | ".comfyui_color_ops", 14 | ".comfyui_datetime", 15 | ".comfyui_image_sequence", 16 | ".comfyui_mask_sequence_ops", 17 | ".comfyui_default", 18 | ".comfyui_sound", 19 | ] 20 | 21 | # Extra nodes for my own use 22 | if ( 23 | "COMFYUI_JW_ENABLE_EXTRA_NODES" in os.environ 24 | and os.environ["COMFYUI_JW_ENABLE_EXTRA_NODES"].lower() == "true" 25 | ): 26 | NODE_MODULES.extend( 27 | [ 28 | ".comfyui_batch_io", 29 | ".comfyui_group_io", 30 | ".comfyui_jw", 31 | ".comfyui_info_hash", 32 | ".comfyui_debug", 33 | ".comfyui_string_list", 34 | ".comfyui_uncrop", 35 | ".comfyui_rc", 36 | ] 37 | ) 38 | 39 | 40 | def load_nodes(module_name: str): 41 | global NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 42 | 43 | module = importlib.import_module(module_name, package=__name__) 44 | 45 | NODE_CLASS_MAPPINGS = { 46 | **NODE_CLASS_MAPPINGS, 47 | **module.NODE_CLASS_MAPPINGS, 48 | } 49 | NODE_DISPLAY_NAME_MAPPINGS = { 50 | **NODE_DISPLAY_NAME_MAPPINGS, 51 | **module.NODE_DISPLAY_NAME_MAPPINGS, 52 | } 53 | 54 | 55 | def write_nodes_list(module_names: list[str]): 56 | this_dir = os.path.dirname(os.path.abspath(__file__)) 57 | path = os.path.join(this_dir, "nodes.log") 58 | 59 | lines = [] 60 | 61 | for module_name in module_names: 62 | module = importlib.import_module(module_name, package=__name__) 63 | 64 | lines.append(module_name.strip(".")) 65 | 66 | for identifier, display_name in module.NODE_DISPLAY_NAME_MAPPINGS.items(): 67 | lines.append(f" {identifier}: {display_name}") 68 | 69 | lines.append("") 70 | 71 | lines = "\n".join(lines) 72 | 73 | with open(path, "w", encoding="utf8") as f: 74 | f.write(lines) 75 | 76 | 77 | for module_name in NODE_MODULES: 78 | load_nodes(module_name) 79 | 80 | # write_nodes_list(NODE_MODULES) 81 | -------------------------------------------------------------------------------- /comfyui_batch_io.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import subprocess 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | from PIL.PngImagePlugin import PngInfo 11 | 12 | NODE_CLASS_MAPPINGS = {} 13 | NODE_DISPLAY_NAME_MAPPINGS = {} 14 | 15 | 16 | def register_node(identifier: str, display_name: str): 17 | def decorator(cls): 18 | NODE_CLASS_MAPPINGS[identifier] = cls 19 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 20 | 21 | return cls 22 | 23 | return decorator 24 | 25 | 26 | def load_image(path): 27 | img = Image.open(path).convert("RGB") 28 | img = np.array(img).astype(np.float32) / 255.0 29 | img = torch.from_numpy(img).unsqueeze(0) 30 | return img 31 | 32 | 33 | @register_node("BatchLoadImage", "[DEPRECATED] Batch Load Image") 34 | class _: 35 | """ 36 | Batch-load images in a given folder. 
To avoid loading too many images at once,
37 | you can use `paginate_size` and `paginate_page` to load a subset of the images.
38 | 
39 | To disable pagination, leave `paginate_size` and `paginate_page` at 0.
40 |     """
41 | 
42 |     CATEGORY = "jamesWalker55"
43 |     INPUT_TYPES = lambda: {
44 |         "required": {
45 |             "image_dir": ("STRING", {"default": "./images", "multiline": False}),
46 |             "glob_pattern": ("STRING", {"default": "*.png", "multiline": False}),
47 |             "paginate_size": ("INT", {"default": 0, "min": 0}),
48 |             "paginate_page": ("INT", {"default": 0, "min": 0}),
49 |         }
50 |     }
51 |     RETURN_NAMES = ("IMAGE", "FRAME_COUNT", "FILENAMES")
52 |     RETURN_TYPES = ("IMAGE", "INT", "STRING")
53 |     FUNCTION = "execute"
54 | 
55 |     def execute(
56 |         self, image_dir: str, glob_pattern: str, paginate_size: int, paginate_page: int
57 |     ):
58 |         assert isinstance(image_dir, str)
59 |         assert isinstance(glob_pattern, str)
60 |         assert isinstance(paginate_size, int)
61 |         assert isinstance(paginate_page, int)
62 | 
63 |         # get paths relative to root dir
64 |         paths = glob.glob(glob_pattern, root_dir=image_dir, recursive=True)
65 |         # convert paths to be relative to here
66 |         paths = [os.path.join(image_dir, x) for x in paths]
67 |         # sort paths alphabetically
68 |         paths.sort()
69 | 
70 |         if len(paths) == 0:
71 |             raise FileNotFoundError(
72 |                 f"No images found in folder matching pattern {glob_pattern!r}"
73 |             )
74 | 
75 |         if paginate_size > 0:
76 |             start_offset = paginate_page * paginate_size
77 |             if start_offset >= len(paths):
78 |                 raise StopIteration(
79 |                     f"No more images in folder at page {paginate_page}!"
80 |                 )
81 |             paths = paths[start_offset : start_offset + paginate_size]
82 | 
83 |         filenames = [os.path.splitext(os.path.basename(x))[0] for x in paths]
84 | 
85 |         imgs = []
86 |         for p in paths:
87 |             img = load_image(p)
88 |             # img.shape => torch.Size([1, 768, 768, 3])
89 |             imgs.append(img)
90 | 
91 |         imgs = torch.cat(imgs, dim=0)
92 | 
93 |         assert len(imgs) == len(filenames)
94 | 
95 |         return (imgs, len(imgs), "\n".join(filenames))
96 | 
97 | 
98 | @register_node("BatchSaveImage", "[DEPRECATED] Batch Save Image")
99 | class _:
100 |     CATEGORY = "jamesWalker55"
101 |     INPUT_TYPES = lambda: {
102 |         "required": {
103 |             "images": ("IMAGE",),
104 |             "output_dir": ("STRING", {"default": "./", "multiline": False}),
105 |             "name_prefix": ("STRING", {"default": ""}),
106 |             "name_suffix": ("STRING", {"default": ""}),
107 |             "numbering_start": (
108 |                 "INT",
109 |                 {"default": 1, "min": 0, "step": 1},
110 |             ),
111 |             "numbering_digits": ("INT", {"default": 4, "min": 1, "step": 1}),
112 |             "render_video_fps": ("INT", {"default": 8, "min": 0, "step": 1}),
113 |         },
114 |         "optional": {
115 |             "filenames": ("STRING", {"multiline": True, "dynamicPrompts": False}),
116 |         },
117 |         "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
118 |     }
119 |     RETURN_TYPES = ()
120 |     OUTPUT_NODE = True
121 |     FUNCTION = "main"
122 | 
123 |     def main(
124 |         self,
125 |         images: torch.Tensor,
126 |         output_dir: str,
127 |         name_prefix: str,
128 |         name_suffix: str,
129 |         numbering_start: int,
130 |         numbering_digits: int,
131 |         render_video_fps: int,
132 |         filenames: str | None = None,
133 |         prompt=None,
134 |         extra_pnginfo=None,
135 |     ):
136 |         if filenames is not None:
137 |             filenames = [x.strip() for x in filenames.splitlines()]
138 |             filenames = [x for x in filenames if len(x) > 0]
139 |             if len(filenames) != len(images):
140 |                 raise ValueError(
141 |                     f"Number of images ({len(images)}) and filenames ({len(filenames)}) must be the same"
142 |                 )
143 |
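            # the copy is reversed so that pop() in the loop below yields
            # the filenames in their original (top-to-bottom) order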
filenames = filenames.copy() 144 | filenames.reverse() 145 | 146 | output_dir: Path = Path(output_dir) 147 | output_dir.mkdir(exist_ok=True) 148 | 149 | ui_results = [] 150 | 151 | for i, img in enumerate(images): 152 | num = i + numbering_start 153 | if filenames is not None: 154 | filename = filenames.pop() 155 | filename = f"{filename}.png" 156 | else: 157 | filename = f"{name_prefix}{num:0{numbering_digits}d}{name_suffix}.png" 158 | output_path = output_dir / filename 159 | ui = self.save_image( 160 | img, output_path, prompt=prompt, extra_pnginfo=extra_pnginfo 161 | ) 162 | ui_results.append(ui) 163 | 164 | if render_video_fps > 0: 165 | subprocess.run( 166 | [ 167 | "python", 168 | R"D:\Programming\bin\render-img-sequence.py", 169 | "-i", 170 | str(output_dir), 171 | "-r", 172 | str(render_video_fps), 173 | ] 174 | ) 175 | 176 | return {"ui": {"images": ui_results}} 177 | 178 | @staticmethod 179 | def save_image(img: torch.Tensor, path, prompt=None, extra_pnginfo: dict = None): 180 | path = str(path) 181 | 182 | img = 255.0 * img.cpu().numpy() 183 | img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8)) 184 | 185 | metadata = PngInfo() 186 | 187 | if prompt is not None: 188 | metadata.add_text("prompt", json.dumps(prompt)) 189 | 190 | if extra_pnginfo is not None: 191 | for k, v in extra_pnginfo.items(): 192 | metadata.add_text(k, json.dumps(v)) 193 | 194 | img.save(path, pnginfo=metadata, compress_level=4) 195 | 196 | subfolder, filename = os.path.split(path) 197 | 198 | return {"filename": filename, "subfolder": subfolder, "type": "output"} 199 | -------------------------------------------------------------------------------- /comfyui_color_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | 4 | NODE_CLASS_MAPPINGS = {} 5 | NODE_DISPLAY_NAME_MAPPINGS = {} 6 | 7 | 8 | def register_node(identifier: str, display_name: str): 9 | def decorator(cls): 10 | NODE_CLASS_MAPPINGS[identifier] = cls 11 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 12 | 13 | return cls 14 | 15 | return decorator 16 | 17 | 18 | @register_node("JWImageMix", "Image Mix") 19 | class _: 20 | CATEGORY = "jamesWalker55" 21 | BLEND_TYPES = ("mix", "multiply") 22 | 23 | @classmethod 24 | def INPUT_TYPES(cls): 25 | return { 26 | "required": { 27 | "blend_type": (cls.BLEND_TYPES, {"default": "mix"}), 28 | "factor": ("FLOAT", {"min": 0, "max": 1, "step": 0.01, "default": 0.5}), 29 | "image_a": ("IMAGE",), 30 | "image_b": ("IMAGE",), 31 | } 32 | } 33 | 34 | RETURN_TYPES = ("IMAGE",) 35 | FUNCTION = "execute" 36 | 37 | def execute( 38 | self, 39 | blend_type: str, 40 | factor: float, 41 | image_a: torch.Tensor, 42 | image_b: torch.Tensor, 43 | ): 44 | assert blend_type in self.BLEND_TYPES 45 | assert isinstance(factor, float) 46 | assert isinstance(image_a, torch.Tensor) 47 | assert isinstance(image_b, torch.Tensor) 48 | 49 | assert image_a.shape == image_b.shape 50 | 51 | if blend_type == "mix": 52 | mixed = image_a * (1 - factor) + image_b * factor 53 | elif blend_type == "multiply": 54 | mixed = image_a * (1 - factor + image_b * factor) 55 | else: 56 | raise NotImplementedError(f"Blend type not yet implemented: {blend_type}") 57 | 58 | return (mixed,) 59 | 60 | 61 | @register_node("JWImageContrast", "Image Contrast") 62 | class _: 63 | CATEGORY = "jamesWalker55" 64 | INPUT_TYPES = lambda: { 65 | "required": { 66 | "image": ("IMAGE",), 67 | "factor": ( 68 | "FLOAT", 69 | {"default": 1.0, "min": 0.0, 
"max": 2.0, "step": 0.01}, 70 | ), 71 | } 72 | } 73 | RETURN_TYPES = ("IMAGE",) 74 | FUNCTION = "execute" 75 | 76 | def execute( 77 | self, 78 | image: torch.Tensor, 79 | factor: float, 80 | ): 81 | assert isinstance(image, torch.Tensor) 82 | assert isinstance(factor, float) 83 | 84 | image = image.permute(0, 3, 1, 2) 85 | image = F.adjust_contrast(image, factor) 86 | image = image.permute(0, 2, 3, 1) 87 | 88 | return (image,) 89 | 90 | 91 | @register_node("JWImageSaturation", "Image Saturation") 92 | class _: 93 | CATEGORY = "jamesWalker55" 94 | INPUT_TYPES = lambda: { 95 | "required": { 96 | "image": ("IMAGE",), 97 | "factor": ( 98 | "FLOAT", 99 | {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.01}, 100 | ), 101 | } 102 | } 103 | RETURN_TYPES = ("IMAGE",) 104 | FUNCTION = "execute" 105 | 106 | def execute( 107 | self, 108 | image: torch.Tensor, 109 | factor: float, 110 | ): 111 | assert isinstance(image, torch.Tensor) 112 | assert isinstance(factor, float) 113 | 114 | image = image.permute(0, 3, 1, 2) 115 | image = F.adjust_saturation(image, factor) 116 | image = image.permute(0, 2, 3, 1) 117 | 118 | return (image,) 119 | 120 | 121 | @register_node("JWImageLevels", "Image Levels") 122 | class _: 123 | CATEGORY = "jamesWalker55" 124 | INPUT_TYPES = lambda: { 125 | "required": { 126 | "image": ("IMAGE",), 127 | "min": ( 128 | "FLOAT", 129 | {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}, 130 | ), 131 | "max": ( 132 | "FLOAT", 133 | {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, 134 | ), 135 | } 136 | } 137 | RETURN_TYPES = ("IMAGE",) 138 | FUNCTION = "execute" 139 | 140 | def execute( 141 | self, 142 | image: torch.Tensor, 143 | min: float, 144 | max: float, 145 | ): 146 | assert isinstance(image, torch.Tensor) 147 | assert isinstance(min, float) 148 | assert isinstance(max, float) 149 | 150 | image = (image - min) / (max - min) 151 | image = torch.clamp(image, 0.0, 1.0) 152 | 153 | return (image,) 154 | -------------------------------------------------------------------------------- /comfyui_datetime.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | NODE_CLASS_MAPPINGS = {} 4 | NODE_DISPLAY_NAME_MAPPINGS = {} 5 | 6 | 7 | def register_node(identifier: str, display_name: str): 8 | def decorator(cls): 9 | NODE_CLASS_MAPPINGS[identifier] = cls 10 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 11 | 12 | return cls 13 | 14 | return decorator 15 | 16 | 17 | @register_node("JWDatetimeString", "Datetime String") 18 | class _: 19 | CATEGORY = "jamesWalker55" 20 | INPUT_TYPES = lambda: { 21 | "required": { 22 | "format": ("STRING", {"default": "%Y-%m-%dT%H:%M:%S"}), 23 | } 24 | } 25 | RETURN_TYPES = ("STRING",) 26 | FUNCTION = "execute" 27 | 28 | def execute(self, format: str): 29 | now = datetime.now() 30 | return (now.strftime(format),) 31 | 32 | @classmethod 33 | def IS_CHANGED(cls, *args): 34 | # This value will be compared with previous 'IS_CHANGED' outputs 35 | # If inequal, then this node will be considered as modified 36 | # NaN is never equal to itself 37 | return float("NaN") 38 | -------------------------------------------------------------------------------- /comfyui_debug.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | from pprint import pformat, pprint 3 | 4 | import torch 5 | 6 | NODE_CLASS_MAPPINGS = {} 7 | NODE_DISPLAY_NAME_MAPPINGS = {} 8 | 9 | 10 | def register_node(identifier: str, display_name: str): 11 | def 
decorator(cls): 12 | NODE_CLASS_MAPPINGS[identifier] = cls 13 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 14 | 15 | return cls 16 | 17 | return decorator 18 | 19 | 20 | @register_node("JWPrintInteger", "Print Integer") 21 | class _: 22 | CATEGORY = "jamesWalker55" 23 | INPUT_TYPES = lambda: { 24 | "required": { 25 | "value": ("INT", {"default": 0, "min": -99999999999, "max": 99999999999}), 26 | "name": ( 27 | "STRING", 28 | {"default": "integer", "multiline": True, "dynamicPrompts": False}, 29 | ), 30 | } 31 | } 32 | RETURN_TYPES = ("INT",) 33 | OUTPUT_NODE = True 34 | FUNCTION = "execute" 35 | 36 | def execute(self, value, name: str): 37 | print(f"{name} = {pformat(value)}") 38 | 39 | return (value,) 40 | 41 | @classmethod 42 | def IS_CHANGED(cls, *args): 43 | # Always recalculate 44 | return float("NaN") 45 | 46 | 47 | @register_node("JWPrintFloat", "Print Float") 48 | class _: 49 | CATEGORY = "jamesWalker55" 50 | INPUT_TYPES = lambda: { 51 | "required": { 52 | "value": ("FLOAT", {"default": 0, "min": -99999999999, "max": 99999999999}), 53 | "name": ( 54 | "STRING", 55 | {"default": "float", "multiline": True, "dynamicPrompts": False}, 56 | ), 57 | } 58 | } 59 | RETURN_TYPES = ("FLOAT",) 60 | OUTPUT_NODE = True 61 | FUNCTION = "execute" 62 | 63 | def execute(self, value, name: str): 64 | print(f"{name} = {pformat(value)}") 65 | 66 | return (value,) 67 | 68 | @classmethod 69 | def IS_CHANGED(cls, *args): 70 | # Always recalculate 71 | return float("NaN") 72 | 73 | 74 | @register_node("JWPrintString", "Print String") 75 | class _: 76 | CATEGORY = "jamesWalker55" 77 | INPUT_TYPES = lambda: { 78 | "required": { 79 | "value": ("STRING", {"default": "text", "multiline": False}), 80 | "name": ( 81 | "STRING", 82 | {"default": "string", "multiline": True, "dynamicPrompts": False}, 83 | ), 84 | } 85 | } 86 | RETURN_TYPES = ("STRING",) 87 | OUTPUT_NODE = True 88 | FUNCTION = "execute" 89 | 90 | def execute(self, value, name: str): 91 | print(f"{name} = {pformat(value)}") 92 | 93 | return (value,) 94 | 95 | @classmethod 96 | def IS_CHANGED(cls, *args): 97 | # Always recalculate 98 | return float("NaN") 99 | 100 | 101 | @register_node("JWPrintImage", "Print Image") 102 | class _: 103 | CATEGORY = "jamesWalker55" 104 | INPUT_TYPES = lambda: { 105 | "required": { 106 | "value": ("IMAGE",), 107 | "name": ( 108 | "STRING", 109 | {"default": "image", "multiline": True, "dynamicPrompts": False}, 110 | ), 111 | } 112 | } 113 | RETURN_TYPES = ("IMAGE",) 114 | OUTPUT_NODE = True 115 | FUNCTION = "execute" 116 | 117 | def execute(self, value: torch.Tensor, name: str): 118 | lines = [ 119 | f"{name} =", 120 | f" {name}.shape = {value.shape}", 121 | f" {name}.min() = {value.min()}", 122 | f" {name}.max() = {value.max()}", 123 | f" {name}.mean() = {value.mean()}", 124 | f" {name}.std() = {value.std()}", 125 | f" {name}.dtype = {value.dtype}", 126 | ] 127 | lines = "\n".join(lines) 128 | print(lines) 129 | 130 | return (value,) 131 | 132 | @classmethod 133 | def IS_CHANGED(cls, *args): 134 | # Always recalculate 135 | return float("NaN") 136 | 137 | 138 | @register_node("JWPrintMask", "Print Mask") 139 | class _: 140 | CATEGORY = "jamesWalker55" 141 | INPUT_TYPES = lambda: { 142 | "required": { 143 | "value": ("MASK",), 144 | "name": ( 145 | "STRING", 146 | {"default": "mask", "multiline": True, "dynamicPrompts": False}, 147 | ), 148 | } 149 | } 150 | RETURN_TYPES = ("MASK",) 151 | OUTPUT_NODE = True 152 | FUNCTION = "execute" 153 | 154 | def execute(self, value: torch.Tensor, name: str): 155 | lines 
= [ 156 | f"{name} =", 157 | f" {name}.shape = {value.shape}", 158 | f" {name}.min() = {value.min()}", 159 | f" {name}.max() = {value.max()}", 160 | f" {name}.mean() = {value.mean()}", 161 | f" {name}.std() = {value.std()}", 162 | f" {name}.dtype = {value.dtype}", 163 | ] 164 | lines = "\n".join(lines) 165 | print(lines) 166 | 167 | return (value,) 168 | 169 | @classmethod 170 | def IS_CHANGED(cls, *args): 171 | # Always recalculate 172 | return float("NaN") 173 | 174 | 175 | def serialise_obj(obj): 176 | if isinstance(obj, dict): 177 | text = ["{"] 178 | for k, v in obj.items(): 179 | subtext = [ 180 | textwrap.indent(f"{k!r}:", " "), 181 | textwrap.indent(serialise_obj(v), " "), 182 | ] 183 | text.append("\n".join(subtext)) 184 | text.append("}") 185 | text = "\n".join(text) 186 | elif isinstance(obj, list): 187 | text = [] 188 | for x in obj: 189 | subtext = serialise_obj(x) 190 | subtext = textwrap.indent(subtext, " ") 191 | subtext = f"-{subtext[1:]}" 192 | text.append(subtext) 193 | text = "\n".join(text) 194 | elif isinstance(obj, torch.Tensor): 195 | text = "\n".join( 196 | [ 197 | f"Tensor", 198 | f" .shape = {obj.shape}", 199 | f" .min() = {obj.min()}", 200 | f" .max() = {obj.max()}", 201 | f" .mean() = {obj.mean()}", 202 | f" .std() = {obj.std()}", 203 | f" .dtype = {obj.dtype}", 204 | ] 205 | ) 206 | else: 207 | text = pformat(obj) 208 | return text 209 | 210 | 211 | @register_node("JWPrintLatent", "Print Latent") 212 | class _: 213 | CATEGORY = "jamesWalker55" 214 | INPUT_TYPES = lambda: { 215 | "required": { 216 | "value": ("LATENT",), 217 | "name": ( 218 | "STRING", 219 | {"default": "latent", "multiline": True, "dynamicPrompts": False}, 220 | ), 221 | } 222 | } 223 | RETURN_TYPES = ("LATENT",) 224 | OUTPUT_NODE = True 225 | FUNCTION = "execute" 226 | 227 | def execute(self, value: dict, name: str): 228 | print(f"{name} = {serialise_obj(value)}") 229 | 230 | return (value,) 231 | 232 | @classmethod 233 | def IS_CHANGED(cls, *args): 234 | # Always recalculate 235 | return float("NaN") 236 | -------------------------------------------------------------------------------- /comfyui_default.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms.functional as F 8 | from PIL import Image 9 | from PIL.PngImagePlugin import PngInfo 10 | from torchvision.transforms import InterpolationMode 11 | 12 | NODE_CLASS_MAPPINGS = {} 13 | NODE_DISPLAY_NAME_MAPPINGS = {} 14 | 15 | 16 | def register_node(identifier: str, display_name: str): 17 | def decorator(cls): 18 | NODE_CLASS_MAPPINGS[identifier] = cls 19 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 20 | 21 | return cls 22 | 23 | return decorator 24 | 25 | 26 | def load_image(path, convert="RGB"): 27 | img = Image.open(path).convert(convert) 28 | img = np.array(img).astype(np.float32) / 255.0 29 | img = torch.from_numpy(img).unsqueeze(0) 30 | return img 31 | 32 | 33 | def save_image(img: torch.Tensor, path, prompt=None, extra_pnginfo: dict = None): 34 | path = str(path) 35 | 36 | if len(img.shape) != 3: 37 | raise ValueError(f"can't take image batch as input, got {img.shape[0]} images") 38 | 39 | img = img.permute(2, 0, 1) 40 | if img.shape[0] != 3: 41 | raise ValueError(f"image must have 3 channels, but got {img.shape[0]} channels") 42 | 43 | img = img.clamp(0, 1) 44 | img = F.to_pil_image(img) 45 | 46 | metadata = PngInfo() 47 | 48 | if prompt is not None: 49 | 
metadata.add_text("prompt", json.dumps(prompt)) 50 | 51 | if extra_pnginfo is not None: 52 | for k, v in extra_pnginfo.items(): 53 | metadata.add_text(k, json.dumps(v)) 54 | 55 | img.save(path, pnginfo=metadata, compress_level=4) 56 | 57 | subfolder, filename = os.path.split(path) 58 | 59 | return {"filename": filename, "subfolder": subfolder, "type": "output"} 60 | 61 | 62 | @register_node("JWImageLoadRGBIfExists", "Image Load RGB If Exists") 63 | class _: 64 | CATEGORY = "jamesWalker55" 65 | INPUT_TYPES = lambda: { 66 | "required": { 67 | "default": ("IMAGE",), 68 | "path": ("STRING", {"default": "./image.png"}), 69 | } 70 | } 71 | RETURN_TYPES = ("IMAGE",) 72 | FUNCTION = "execute" 73 | 74 | def execute(self, path: str, default: torch.Tensor): 75 | assert isinstance(path, str) 76 | assert isinstance(default, torch.Tensor) 77 | 78 | if not os.path.exists(path): 79 | return (default,) 80 | 81 | img = load_image(path) 82 | return (img,) 83 | 84 | @classmethod 85 | def IS_CHANGED(cls, path: str, default: torch.Tensor): 86 | if os.path.exists(path): 87 | mtime = os.path.getmtime(path) 88 | else: 89 | mtime = None 90 | return (mtime, default.__hash__()) 91 | -------------------------------------------------------------------------------- /comfyui_group_io.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import glob 3 | import os 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | import yaml 9 | from PIL import Image 10 | 11 | NODE_CLASS_MAPPINGS = {} 12 | NODE_DISPLAY_NAME_MAPPINGS = {} 13 | 14 | 15 | def register_node(identifier: str, display_name: str): 16 | def decorator(cls): 17 | NODE_CLASS_MAPPINGS[identifier] = cls 18 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 19 | 20 | return cls 21 | 22 | return decorator 23 | 24 | 25 | def load_image(path): 26 | img = Image.open(path).convert("RGB") 27 | img = np.array(img).astype(np.float32) / 255.0 28 | img = torch.from_numpy(img).unsqueeze(0) 29 | return img 30 | 31 | 32 | @register_node("JamesLoadImageGroup", "[DEPRECATED] James: Load Image Group") 33 | class _: 34 | """ 35 | An opinionated batch image loader. This is used for loading groups for batch processing. 36 | 37 | Folder structure: 38 | 39 | ```plain 40 | groups/ 41 | baseprompt.txt 42 | g1/ 43 | 0001.png 44 | 0002.png 45 | subprompt.txt 46 | g2/ 47 | 0003.png 48 | 0004.png 49 | subprompt.txt 50 | ... 
51 | ``` 52 | """ 53 | 54 | CATEGORY = "jamesWalker55" 55 | INPUT_TYPES = lambda: { 56 | "required": { 57 | "groups_dir": ("STRING", {"default": "./groups", "multiline": False}), 58 | "groups_id": ("INT", {"default": 1, "min": 0, "step": 1, "max": 9999}), 59 | "base_prompt_name": ( 60 | "STRING", 61 | {"default": "baseprompt.txt", "multiline": False}, 62 | ), 63 | "sub_prompt_name": ( 64 | "STRING", 65 | {"default": "subprompt.txt", "multiline": False}, 66 | ), 67 | "negative_prompt_delimiter": ( 68 | "STRING", 69 | {"default": "---", "multiline": False}, 70 | ), 71 | "image_glob": ( 72 | "STRING", 73 | {"default": "*.png", "multiline": False}, 74 | ), 75 | } 76 | } 77 | RETURN_NAMES = ( 78 | "POSITIVE_PROMPT", 79 | "NEGATIVE_PROMPT", 80 | "IMAGES", 81 | "FRAME_COUNT", 82 | "FILENAMES", 83 | ) 84 | RETURN_TYPES = ("STRING", "STRING", "IMAGE", "INT", "STRING") 85 | FUNCTION = "execute" 86 | 87 | def execute( 88 | self, 89 | groups_dir: str, 90 | groups_id: int, 91 | base_prompt_name: str, 92 | sub_prompt_name: str, 93 | negative_prompt_delimiter: str, 94 | image_glob: str, 95 | ): 96 | assert isinstance(groups_dir, str) 97 | assert isinstance(groups_id, int) 98 | assert isinstance(base_prompt_name, str) 99 | assert isinstance(sub_prompt_name, str) 100 | assert isinstance(negative_prompt_delimiter, str) 101 | 102 | pos_prompt, neg_prompt = self.get_group_prompt( 103 | groups_dir, 104 | groups_id, 105 | base_prompt_name, 106 | sub_prompt_name, 107 | negative_prompt_delimiter, 108 | ) 109 | images, filenames = self.load_group_images(groups_dir, groups_id, image_glob) 110 | 111 | print( 112 | f"JamesLoadImageGroup: {(pos_prompt, neg_prompt, len(filenames), filenames)!r}" 113 | ) 114 | return (pos_prompt, neg_prompt, images, len(filenames), "\n".join(filenames)) 115 | 116 | def get_base_prompt( 117 | self, 118 | groups_dir: str, 119 | base_prompt_name: str, 120 | negative_prompt_delimiter: str, 121 | ): 122 | """Get the base prompt of the group""" 123 | path = os.path.join(groups_dir, base_prompt_name) 124 | with open(path, "r", encoding="utf8") as f: 125 | prompt = f.read() 126 | 127 | match prompt.split(negative_prompt_delimiter, 1): 128 | case pos_prompt, neg_prompt: 129 | return pos_prompt, neg_prompt 130 | case [pos_prompt]: 131 | return pos_prompt, "" 132 | case _: 133 | raise ValueError("Invalid base prompt, more than 1 delimiter found") 134 | 135 | def get_group_path(self, groups_dir: str, groups_id: int): 136 | return os.path.join(groups_dir, f"g{groups_id}") 137 | 138 | def get_sub_prompt( 139 | self, 140 | groups_dir: str, 141 | groups_id: int, 142 | sub_prompt_name: str, 143 | negative_prompt_delimiter: str, 144 | ): 145 | """Get the sub prompt of the group""" 146 | group_path = self.get_group_path(groups_dir, groups_id) 147 | path = os.path.join(group_path, sub_prompt_name) 148 | with open(path, "r", encoding="utf8") as f: 149 | prompt = f.read() 150 | 151 | match prompt.split(negative_prompt_delimiter, 1): 152 | case pos_prompt, neg_prompt: 153 | return pos_prompt, neg_prompt 154 | case [pos_prompt]: 155 | return pos_prompt, "" 156 | case _: 157 | raise ValueError("Invalid sub prompt, more than 1 delimiter found") 158 | 159 | def get_group_prompt( 160 | self, 161 | groups_dir: str, 162 | groups_id: int, 163 | base_prompt_name: str, 164 | sub_prompt_name: str, 165 | negative_prompt_delimiter: str, 166 | ): 167 | """Generate the final combined prompt of the group""" 168 | base_pos, base_neg = self.get_base_prompt( 169 | groups_dir, base_prompt_name, negative_prompt_delimiter 170 | ) 
171 |         sub_pos, sub_neg = self.get_sub_prompt(
172 |             groups_dir, groups_id, sub_prompt_name, negative_prompt_delimiter
173 |         )
174 |         group_pos = base_pos.format(sub_pos)
175 |         group_neg = base_neg.format(sub_neg)
176 |         return group_pos, group_neg
177 | 
178 |     def load_group_images(self, groups_dir: str, groups_id: int, image_glob: str):
179 |         """Get all images for the group"""
180 |         group_path = self.get_group_path(groups_dir, groups_id)
181 | 
182 |         # get paths relative to the group dir
183 |         paths = glob.glob(image_glob, root_dir=group_path, recursive=True)
184 |         # convert paths to be relative to here
185 |         paths = [os.path.join(group_path, x) for x in paths]
186 |         # sort paths alphabetically
187 |         paths.sort()
188 | 
189 |         # extract filenames without extension
190 |         filenames = [os.path.splitext(os.path.basename(x))[0] for x in paths]
191 | 
192 |         # must have at least 1 image
193 |         if len(paths) == 0:
194 |             raise FileNotFoundError(
195 |                 f"No images found in folder matching pattern {image_glob!r}"
196 |             )
197 | 
198 |         # load images
199 |         imgs = []
200 |         for p in paths:
201 |             img = load_image(p)
202 |             # img.shape => torch.Size([1, 768, 768, 3])
203 |             imgs.append(img)
204 | 
205 |         imgs = torch.cat(imgs, dim=0)
206 | 
207 |         # sanity check, image count == filename count
208 |         assert len(imgs) == len(filenames)
209 | 
210 |         return imgs, filenames
211 | 
212 | 
213 | class GroupedWorkspace:
214 |     """
215 |     YAML structure:
216 | 
217 |     ```yaml
218 |     positive: |
219 |       {positive},
220 |       simple background, white background,
221 | 
222 |     negative: |
223 |       {negative},
224 |       low quality,
225 | 
226 |     image_pattern: '{frame_id:04d}.png'
227 | 
228 |     groups:
229 |       - start_id: 1
230 |         positive: ...
231 |         negative: ...
232 |       - start_id: 5
233 |         positive: ...
234 |         negative: ...
235 |     ...
236 | ``` 237 | """ 238 | 239 | _original_definition: dict 240 | _base_path: Path 241 | _base_pos: str 242 | _base_neg: str 243 | _image_pattern: str 244 | _groups: list[dict] 245 | 246 | def __init__(self, base_path: Path, definition: dict): 247 | self._validate_definition(definition) 248 | self._original_definition = definition 249 | self._base_path = base_path 250 | self._parse_groups(definition) 251 | 252 | @classmethod 253 | def open(cls, path, base_path=None): 254 | if base_path is None: 255 | base_path = Path(path).parent 256 | else: 257 | base_path = Path(base_path) 258 | 259 | with open(path, "r", encoding="utf8") as f: 260 | definition = yaml.safe_load(f) 261 | return cls(base_path, definition) 262 | 263 | @staticmethod 264 | def _validate_definition(definition): 265 | assert isinstance(definition, dict), "file must be a dict" 266 | 267 | assert "positive" in definition, "missing key: positive" 268 | assert isinstance(definition["positive"], str), "positive must be a string" 269 | 270 | assert "negative" in definition, "missing key: negative" 271 | assert isinstance(definition["negative"], str), "negative must be a string" 272 | 273 | assert "image_pattern" in definition, "missing key: image_pattern" 274 | assert isinstance(definition["image_pattern"], str), "pattern must be a string" 275 | 276 | assert "groups" in definition, "missing key: groups" 277 | assert isinstance(definition["groups"], list), "groups must be a list" 278 | assert len(definition["groups"]) > 0, "must have at least 1 group" 279 | 280 | assert "start_id" not in definition, "'start_id' not allowed at root" 281 | assert "group_id" not in definition, "'group_id' not allowed in definition" 282 | 283 | prev_start_id = -1 284 | 285 | for gp in definition["groups"]: 286 | assert isinstance(gp, dict), "group must be a dict" 287 | 288 | assert "start_id" in gp, "group missing key: start_id" 289 | assert "group_id" not in gp, "'group_id' not allowed in definition" 290 | 291 | start_id = gp["start_id"] 292 | assert isinstance(start_id, int), "start_id must be a number" 293 | assert start_id >= 0, "start_id cannot be negative" 294 | assert prev_start_id < start_id, "start_id must be in ascending order" 295 | 296 | prev_start_id = start_id 297 | 298 | def _parse_groups(self, definition: dict): 299 | definition = copy.deepcopy(definition) 300 | 301 | self._base_pos = definition.pop("positive") 302 | self._base_neg = definition.pop("negative") 303 | self._image_pattern = definition.pop("image_pattern") 304 | raw_groups = definition.pop("groups") 305 | assert "start_id" not in definition 306 | 307 | self._groups = [] 308 | 309 | for group in raw_groups: 310 | assert "start_id" in group 311 | assert isinstance(group["start_id"], int) 312 | 313 | # add extra keys in definition to group info 314 | group = {**definition, **group} 315 | 316 | self._groups.append(group) 317 | 318 | def _get_group_info(self, group_id: int): 319 | group = self._groups[group_id] 320 | return {**group, "group_id": group_id} 321 | 322 | def get_group_info(self, group_id: int): 323 | return copy.deepcopy(self._get_group_info(group_id)) 324 | 325 | def _get_frame_info(self, group_id: int, frame_id: int): 326 | info = self._get_group_info(group_id) 327 | return {**info, "frame_id": frame_id} 328 | 329 | def get_frame_info(self, frame_id: int): 330 | group_id = self._frame_id_to_group_id(frame_id) 331 | return copy.deepcopy(self._get_frame_info(group_id, frame_id)) 332 | 333 | def _get_positive_prompt(self, group_id: int): 334 | prompt = 
self._base_pos.format(**self._get_group_info(group_id))
335 |         return prompt
336 | 
337 |     def _get_negative_prompt(self, group_id: int):
338 |         prompt = self._base_neg.format(**self._get_group_info(group_id))
339 |         return prompt
340 | 
341 |     def _get_image_path(self, group_id: int, frame_id: int):
342 |         relpath = self._image_pattern.format(**self._get_frame_info(group_id, frame_id))
343 |         return self._base_path / relpath
344 | 
345 |     def _get_group_frame_range(self, group_id: int) -> tuple[int, int | None]:
346 |         start_frame_id: int = self._groups[group_id]["start_id"]
347 | 
348 |         if group_id < len(self._groups) - 1:
349 |             # not the last group: this group ends where the next group starts
350 |             return start_frame_id, self._groups[group_id + 1]["start_id"]
351 |         else:
352 |             # last group: the end frame must be determined dynamically
353 |             return start_frame_id, None
354 | 
355 |     def _frame_id_to_group_id(self, frame_id: int):
356 |         for i, group in enumerate(self._groups):
357 |             if frame_id >= group["start_id"]:
358 |                 # frame ID is higher than this group
359 |                 continue
360 | 
361 |             # frame ID belongs to previous group
362 |             if i == 0:
363 |                 raise ValueError(f"Frame ID {frame_id} is not covered by any group")
364 | 
365 |             return i - 1
366 | 
367 |         # return last group
368 |         return len(self._groups) - 1
369 | 
370 |     def get_frame_image(self, frame_id: int):
371 |         group_id = self._frame_id_to_group_id(frame_id)
372 |         image_path = self._get_image_path(group_id, frame_id)
373 | 
374 |         img = load_image(image_path)
375 |         filename = os.path.splitext(os.path.basename(image_path))[0]
376 | 
377 |         return img, filename
378 | 
379 |     def get_frame_prompts(self, frame_id: int):
380 |         group_id = self._frame_id_to_group_id(frame_id)
381 |         return self._get_positive_prompt(group_id), self._get_negative_prompt(group_id)
382 | 
383 |     def get_group_prompts(self, group_id: int):
384 |         return self._get_positive_prompt(group_id), self._get_negative_prompt(group_id)
385 | 
386 |     def get_group_images(self, group_id: int):
387 |         start_frame, end_frame = self._get_group_frame_range(group_id)
388 | 
389 |         images = []
390 |         filenames: list[str] = []
391 | 
392 |         i = start_frame
393 |         while True:
394 |             image_path = self._get_image_path(group_id, i)
395 | 
396 |             # check for end of sequence
397 |             if end_frame is not None and i >= end_frame:
398 |                 # reached end of sequence
399 |                 break
400 |             elif end_frame is None and not os.path.exists(image_path):
401 |                 # unknown end frame, and this frame is missing
402 |                 # assume this is the end of sequence
403 |                 break
404 | 
405 |             try:
406 |                 img = load_image(image_path)
407 | 
408 |                 images.append(img)
409 |                 filenames.append(image_path.stem)
410 |             except FileNotFoundError:
411 |                 print(f"WARNING: Image missing from sequence: {image_path}")
412 | 
413 |             i += 1
414 | 
415 |         images = torch.cat(images, dim=0)
416 | 
417 |         # sanity check, image count == filename count
418 |         assert len(images) == len(filenames)
419 | 
420 |         return images, filenames
421 | 
422 | 
423 | @register_node("GroupLoadBatchImages", "[DEPRECATED] Group Load Batch Images")
424 | class __:
425 |     """
426 |     An opinionated batch image loader. This is used for loading groups for batch processing.
427 | 
428 |     "base_path" controls the directory that images are loaded relative to. Defaults
429 |     to the folder containing the definition file.
430 | """ 431 | 432 | CATEGORY = "jamesWalker55" 433 | INPUT_TYPES = lambda: { 434 | "required": { 435 | "definition_path": ( 436 | "STRING", 437 | {"default": "./groups.yml", "multiline": False}, 438 | ), 439 | "group_id": ("INT", {"default": 1, "min": 0, "step": 1, "max": 9999}), 440 | "base_path": ("STRING", {"default": ""}), 441 | } 442 | } 443 | RETURN_NAMES = ( 444 | "POSITIVE_PROMPT", 445 | "NEGATIVE_PROMPT", 446 | "IMAGES", 447 | "FRAME_COUNT", 448 | "FILENAMES", 449 | "GROUP_INFO", 450 | ) 451 | RETURN_TYPES = ("STRING", "STRING", "IMAGE", "INT", "STRING", "GROUP_INFO") 452 | FUNCTION = "execute" 453 | 454 | def execute(self, definition_path: str, group_id: int, base_path: str): 455 | assert isinstance(definition_path, str) 456 | assert isinstance(group_id, int) 457 | assert isinstance(base_path, str) 458 | 459 | base_path = base_path.strip() 460 | if len(base_path) == 0: 461 | base_path = None 462 | 463 | workspace = GroupedWorkspace.open(definition_path, base_path=base_path) 464 | 465 | images, filenames = workspace.get_group_images(group_id) 466 | pos_prompt, neg_prompt = workspace.get_group_prompts(group_id) 467 | group_info = workspace.get_group_info(group_id) 468 | 469 | return ( 470 | pos_prompt, 471 | neg_prompt, 472 | images, 473 | len(filenames), 474 | "\n".join(filenames), 475 | group_info, 476 | ) 477 | 478 | 479 | @register_node("GroupLoadImage", "[DEPRECATED] Group Load Image") 480 | class _: 481 | """ 482 | An opinionated image loader. This is used for loading groups for batch processing. 483 | 484 | "base_path" controls where the images are loaded relative from. Defaults to the 485 | folder containing the definition file. 486 | """ 487 | 488 | CATEGORY = "jamesWalker55" 489 | INPUT_TYPES = lambda: { 490 | "required": { 491 | "definition_path": ( 492 | "STRING", 493 | {"default": "./groups.yml", "multiline": False}, 494 | ), 495 | "frame_id": ("INT", {"default": 1, "min": 0, "step": 1, "max": 9999}), 496 | "base_path": ("STRING", {"default": ""}), 497 | } 498 | } 499 | RETURN_NAMES = ( 500 | "POSITIVE_PROMPT", 501 | "NEGATIVE_PROMPT", 502 | "IMAGE", 503 | "FILENAME", 504 | "GROUP_INFO", 505 | ) 506 | RETURN_TYPES = ("STRING", "STRING", "IMAGE", "STRING", "GROUP_INFO") 507 | FUNCTION = "execute" 508 | 509 | def execute(self, definition_path: str, frame_id: int, base_path: str): 510 | assert isinstance(definition_path, str) 511 | assert isinstance(frame_id, int) 512 | assert isinstance(base_path, str) 513 | 514 | base_path = base_path.strip() 515 | if len(base_path) == 0: 516 | base_path = None 517 | 518 | workspace = GroupedWorkspace.open(definition_path, base_path=base_path) 519 | 520 | image, filename = workspace.get_frame_image(frame_id) 521 | pos_prompt, neg_prompt = workspace.get_frame_prompts(frame_id) 522 | group_info = workspace.get_frame_info(frame_id) 523 | 524 | return (pos_prompt, neg_prompt, image, filename, group_info) 525 | 526 | 527 | @register_node("GroupInfoExtractInt", "[DEPRECATED] Group Info Extract Integer") 528 | class _: 529 | CATEGORY = "jamesWalker55" 530 | INPUT_TYPES = lambda: { 531 | "required": { 532 | "group_info": ("GROUP_INFO",), 533 | "key": ("STRING", {"default": ""}), 534 | } 535 | } 536 | RETURN_TYPES = ("INT",) 537 | FUNCTION = "execute" 538 | 539 | def execute(self, group_info: dict, key: str): 540 | assert isinstance(group_info, dict) 541 | assert isinstance(key, str) 542 | 543 | val = int(group_info[key]) 544 | 545 | return (val,) 546 | 547 | 548 | @register_node("GroupInfoExtractFloat", "[DEPRECATED] Group Info Extract 
Float") 549 | class _: 550 | CATEGORY = "jamesWalker55" 551 | INPUT_TYPES = lambda: { 552 | "required": { 553 | "group_info": ("GROUP_INFO",), 554 | "key": ("STRING", {"default": ""}), 555 | } 556 | } 557 | RETURN_TYPES = ("FLOAT",) 558 | FUNCTION = "execute" 559 | 560 | def execute(self, group_info: dict, key: str): 561 | assert isinstance(group_info, dict) 562 | assert isinstance(key, str) 563 | 564 | val = float(group_info[key]) 565 | 566 | return (val,) 567 | -------------------------------------------------------------------------------- /comfyui_image_channel_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | NODE_CLASS_MAPPINGS = {} 4 | NODE_DISPLAY_NAME_MAPPINGS = {} 5 | 6 | 7 | def register_node(identifier: str, display_name: str): 8 | def decorator(cls): 9 | NODE_CLASS_MAPPINGS[identifier] = cls 10 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 11 | 12 | return cls 13 | 14 | return decorator 15 | 16 | 17 | @register_node("JWImageStackChannels", "Image Stack Channels") 18 | class _: 19 | CATEGORY = "jamesWalker55" 20 | INPUT_TYPES = lambda: { 21 | "required": { 22 | "image_a": ("IMAGE",), 23 | "image_b": ("IMAGE",), 24 | } 25 | } 26 | RETURN_TYPES = ("IMAGE",) 27 | FUNCTION = "execute" 28 | 29 | def execute(self, image_a: torch.Tensor, image_b: torch.Tensor): 30 | assert isinstance(image_a, torch.Tensor) 31 | assert isinstance(image_b, torch.Tensor) 32 | 33 | stacked = torch.cat((image_a, image_b), dim=3) 34 | 35 | return (stacked,) 36 | -------------------------------------------------------------------------------- /comfyui_image_ops.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | from pathlib import Path 5 | from typing import Optional, Union 6 | 7 | import numpy as np 8 | import torch 9 | import torchvision.transforms.functional as F 10 | from PIL import Image, ImageGrab 11 | from PIL.PngImagePlugin import PngInfo 12 | from torchvision.transforms import InterpolationMode 13 | 14 | NODE_CLASS_MAPPINGS = {} 15 | NODE_DISPLAY_NAME_MAPPINGS = {} 16 | 17 | 18 | def register_node(identifier: str, display_name: str): 19 | def decorator(cls): 20 | NODE_CLASS_MAPPINGS[identifier] = cls 21 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 22 | 23 | return cls 24 | 25 | return decorator 26 | 27 | 28 | def load_image(path, convert="RGB"): 29 | img = Image.open(path).convert(convert) 30 | img = np.array(img).astype(np.float32) / 255.0 31 | img = torch.from_numpy(img).unsqueeze(0) 32 | return img 33 | 34 | 35 | def save_image(img: torch.Tensor, path, prompt=None, extra_pnginfo: dict = None): 36 | path = str(path) 37 | 38 | if len(img.shape) != 3: 39 | raise ValueError(f"can't take image batch as input, got {img.shape[0]} images") 40 | 41 | img = img.permute(2, 0, 1) 42 | if img.shape[0] not in (3, 4): 43 | raise ValueError( 44 | f"image must have 3 or 4 channels, but got {img.shape[0]} channels" 45 | ) 46 | 47 | img = img.clamp(0, 1) 48 | img = F.to_pil_image(img) 49 | 50 | metadata = PngInfo() 51 | 52 | if prompt is not None: 53 | metadata.add_text("prompt", json.dumps(prompt)) 54 | 55 | if extra_pnginfo is not None: 56 | for k, v in extra_pnginfo.items(): 57 | metadata.add_text(k, json.dumps(v)) 58 | 59 | img.save(path, pnginfo=metadata, compress_level=4) 60 | 61 | subfolder, filename = os.path.split(path) 62 | 63 | return {"filename": filename, "subfolder": subfolder, "type": "output"} 64 | 65 | 66 | 
@register_node("JWImageLoadRGB", "Image Load RGB") 67 | class _: 68 | CATEGORY = "jamesWalker55" 69 | INPUT_TYPES = lambda: { 70 | "required": { 71 | "path": ("STRING", {"default": "./image.png"}), 72 | } 73 | } 74 | RETURN_TYPES = ("IMAGE",) 75 | FUNCTION = "execute" 76 | 77 | def execute(self, path: str): 78 | assert isinstance(path, str) 79 | 80 | img = load_image(path) 81 | return (img,) 82 | 83 | 84 | @register_node("JWImageLoadRGBA", "Image Load RGBA") 85 | class _: 86 | CATEGORY = "jamesWalker55" 87 | INPUT_TYPES = lambda: { 88 | "required": { 89 | "path": ("STRING", {"default": "./image.png"}), 90 | } 91 | } 92 | RETURN_TYPES = ("IMAGE", "MASK") 93 | FUNCTION = "execute" 94 | 95 | def execute(self, path: str): 96 | assert isinstance(path, str) 97 | 98 | img = load_image(path, convert="RGBA") 99 | color = img[:, :, :, 0:3] 100 | mask = img[0, :, :, 3] 101 | mask = 1 - mask # invert mask 102 | 103 | return (color, mask) 104 | 105 | 106 | @register_node("JWLoadImagesFromString", "Load Images From String") 107 | class _: 108 | CATEGORY = "jamesWalker55" 109 | INPUT_TYPES = lambda: { 110 | "required": { 111 | "paths": ( 112 | "STRING", 113 | { 114 | "default": "./frame000001.png\n./frame000002.png\n./frame000003.png", 115 | "multiline": True, 116 | "dynamicPrompts": False, 117 | }, 118 | ), 119 | "ignore_missing_images": (("false", "true"), {"default": "false"}), 120 | } 121 | } 122 | RETURN_TYPES = ("IMAGE",) 123 | FUNCTION = "execute" 124 | 125 | def execute(self, paths, ignore_missing_images: str): 126 | assert isinstance(paths, str) 127 | assert isinstance(ignore_missing_images, str) 128 | 129 | ignore_missing_images: bool = ignore_missing_images == "true" 130 | 131 | paths = [p.strip() for p in paths.splitlines()] 132 | paths = [p for p in paths if len(p) != 0] 133 | 134 | if ignore_missing_images: 135 | # remove missing images 136 | paths = [p for p in paths if os.path.exists(p)] 137 | else: 138 | # early check for missing images 139 | for path in paths: 140 | if not os.path.exists(path): 141 | raise FileNotFoundError(f"Image does not exist: {path}") 142 | 143 | if len(paths) == 0: 144 | raise RuntimeError("Image sequence empty - no images to load") 145 | 146 | imgs = [] 147 | for path in paths: 148 | img = load_image(path) 149 | # img.shape => torch.Size([1, 768, 768, 3]) 150 | imgs.append(img) 151 | 152 | imgs = torch.cat(imgs, dim=0) 153 | 154 | return (imgs,) 155 | 156 | 157 | @register_node("JWImageSaveToPath", "Image Save To Path") 158 | class _: 159 | CATEGORY = "jamesWalker55" 160 | INPUT_TYPES = lambda: { 161 | "required": { 162 | "path": ("STRING", {"default": "./image.png"}), 163 | "image": ("IMAGE",), 164 | "overwrite": (("false", "true"), {"default": "true"}), 165 | }, 166 | "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, 167 | } 168 | RETURN_TYPES = () 169 | OUTPUT_NODE = True 170 | FUNCTION = "execute" 171 | 172 | def execute( 173 | self, 174 | path: str, 175 | image: torch.Tensor, 176 | overwrite: str, 177 | prompt=None, 178 | extra_pnginfo=None, 179 | ): 180 | assert isinstance(path, str) 181 | assert isinstance(image, torch.Tensor) 182 | assert isinstance(overwrite, str) 183 | 184 | overwrite: bool = overwrite == "true" 185 | 186 | path: Path = Path(path) 187 | if not overwrite and path.exists(): 188 | return () 189 | 190 | path.parent.mkdir(exist_ok=True) 191 | 192 | if image.shape[0] == 1: 193 | # batch has 1 image only 194 | save_image( 195 | image[0], 196 | path, 197 | prompt=prompt, 198 | extra_pnginfo=extra_pnginfo, 199 | ) 200 | else: 201 
| # batch has multiple images 202 | for i, img in enumerate(image): 203 | subpath = path.with_stem(f"{path.stem}-{i}") 204 | save_image( 205 | img, 206 | subpath, 207 | prompt=prompt, 208 | extra_pnginfo=extra_pnginfo, 209 | ) 210 | 211 | return () 212 | 213 | 214 | @register_node("JWImageExtractFromBatch", "Image Extract From Batch") 215 | class _: 216 | CATEGORY = "jamesWalker55" 217 | INPUT_TYPES = lambda: { 218 | "required": { 219 | "images": ("IMAGE",), 220 | "index": ("INT", {"default": 0, "min": 0, "step": 1}), 221 | } 222 | } 223 | RETURN_TYPES = ("IMAGE",) 224 | FUNCTION = "execute" 225 | 226 | def execute(self, images: torch.Tensor, index: int): 227 | assert isinstance(images, torch.Tensor) 228 | assert isinstance(index, int) 229 | 230 | img = images[index].unsqueeze(0) 231 | 232 | return (img,) 233 | 234 | 235 | @register_node("JWImageBatchCount", "Get Image Batch Count") 236 | class _: 237 | CATEGORY = "jamesWalker55" 238 | INPUT_TYPES = lambda: { 239 | "required": { 240 | "images": ("IMAGE",), 241 | } 242 | } 243 | RETURN_TYPES = ("INT",) 244 | FUNCTION = "execute" 245 | 246 | def execute(self, images: torch.Tensor): 247 | assert isinstance(images, torch.Tensor) 248 | 249 | batch_count = len(images) 250 | 251 | return (batch_count,) 252 | 253 | 254 | @register_node("JWImageResize", "Image Resize") 255 | class _: 256 | CATEGORY = "jamesWalker55" 257 | INPUT_TYPES = lambda: { 258 | "required": { 259 | "image": ("IMAGE",), 260 | "height": ("INT", {"default": 512, "min": 0, "step": 1, "max": 99999}), 261 | "width": ("INT", {"default": 512, "min": 0, "step": 1, "max": 99999}), 262 | "interpolation_mode": ( 263 | ["bicubic", "bilinear", "nearest", "nearest exact"], 264 | ), 265 | } 266 | } 267 | RETURN_TYPES = ("IMAGE",) 268 | FUNCTION = "execute" 269 | 270 | def execute( 271 | self, 272 | image: torch.Tensor, 273 | width: int, 274 | height: int, 275 | interpolation_mode: str, 276 | ): 277 | assert isinstance(image, torch.Tensor) 278 | assert isinstance(height, int) 279 | assert isinstance(width, int) 280 | assert isinstance(interpolation_mode, str) 281 | 282 | interpolation_mode = interpolation_mode.upper().replace(" ", "_") 283 | interpolation_mode = getattr(InterpolationMode, interpolation_mode) 284 | 285 | image = image.permute(0, 3, 1, 2) 286 | image = F.resize( 287 | image, 288 | (height, width), 289 | interpolation=interpolation_mode, 290 | antialias=True, 291 | ) 292 | image = image.permute(0, 2, 3, 1) 293 | 294 | return (image,) 295 | 296 | 297 | @register_node("JWImageFlip", "Image Flip") 298 | class _: 299 | CATEGORY = "jamesWalker55" 300 | INPUT_TYPES = lambda: { 301 | "required": { 302 | "image": ("IMAGE",), 303 | "direction": (("horizontal", "vertical"), {"default": "horizontal"}), 304 | } 305 | } 306 | RETURN_TYPES = ("IMAGE",) 307 | FUNCTION = "execute" 308 | 309 | def execute( 310 | self, 311 | image: torch.Tensor, 312 | direction: str, 313 | ): 314 | assert isinstance(image, torch.Tensor) 315 | assert direction in ("horizontal", "vertical") 316 | 317 | image = image.permute(0, 3, 1, 2) 318 | if direction == "horizontal": 319 | image = F.hflip(image) 320 | else: 321 | image = F.vflip(image) 322 | image = image.permute(0, 2, 3, 1) 323 | 324 | return (image,) 325 | 326 | 327 | @register_node("JWMaskResize", "Mask Resize") 328 | class _: 329 | CATEGORY = "jamesWalker55" 330 | INPUT_TYPES = lambda: { 331 | "required": { 332 | "mask": ("MASK",), 333 | "height": ("INT", {"default": 512, "min": 0, "step": 1, "max": 99999}), 334 | "width": ("INT", {"default": 512, "min": 0, 
"step": 1, "max": 99999}), 335 | "interpolation_mode": ( 336 | ["bicubic", "bilinear", "nearest", "nearest exact"], 337 | ), 338 | } 339 | } 340 | RETURN_TYPES = ("MASK",) 341 | FUNCTION = "execute" 342 | 343 | def execute( 344 | self, 345 | mask: torch.Tensor, 346 | width: int, 347 | height: int, 348 | interpolation_mode: str, 349 | ): 350 | assert isinstance(mask, torch.Tensor) 351 | assert isinstance(height, int) 352 | assert isinstance(width, int) 353 | assert isinstance(interpolation_mode, str) 354 | 355 | interpolation_mode = interpolation_mode.upper().replace(" ", "_") 356 | interpolation_mode = getattr(InterpolationMode, interpolation_mode) 357 | 358 | mask = mask.unsqueeze(0) 359 | mask = F.resize( 360 | mask, 361 | (height, width), 362 | interpolation=interpolation_mode, 363 | antialias=True, 364 | ) 365 | mask = mask[0] 366 | 367 | return (mask,) 368 | 369 | 370 | @register_node("JWMaskLikeImageSize", "Mask Like Image Size") 371 | class _: 372 | CATEGORY = "jamesWalker55" 373 | INPUT_TYPES = lambda: { 374 | "required": { 375 | "image": ("IMAGE",), 376 | "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), 377 | } 378 | } 379 | RETURN_TYPES = ("MASK",) 380 | FUNCTION = "execute" 381 | 382 | def execute( 383 | self, 384 | image: torch.Tensor, 385 | value: float, 386 | ): 387 | assert isinstance(image, torch.Tensor) 388 | assert isinstance(value, float) 389 | 390 | _, h, w, _ = image.shape 391 | mask_shape = (h, w) 392 | # code copied from: 393 | # comfy_extras\nodes_mask.py 394 | mask = torch.full(mask_shape, value, dtype=torch.float32, device="cpu") 395 | 396 | return (mask,) 397 | 398 | 399 | @register_node("JWImageResizeToSquare", "Image Resize to Square") 400 | class _: 401 | CATEGORY = "jamesWalker55" 402 | INPUT_TYPES = lambda: { 403 | "required": { 404 | "image": ("IMAGE",), 405 | "size": ("INT", {"default": 512, "min": 0, "step": 1, "max": 99999}), 406 | "interpolation_mode": ( 407 | ["bicubic", "bilinear", "nearest", "nearest exact"], 408 | ), 409 | } 410 | } 411 | RETURN_TYPES = ("IMAGE",) 412 | FUNCTION = "execute" 413 | 414 | def execute( 415 | self, 416 | image: torch.Tensor, 417 | size: int, 418 | interpolation_mode: str, 419 | ): 420 | assert isinstance(image, torch.Tensor) 421 | assert isinstance(size, int) 422 | assert isinstance(interpolation_mode, str) 423 | 424 | interpolation_mode = interpolation_mode.upper().replace(" ", "_") 425 | interpolation_mode = getattr(InterpolationMode, interpolation_mode) 426 | 427 | image = image.permute(0, 3, 1, 2) 428 | image = F.resize( 429 | image, 430 | (size, size), 431 | interpolation=interpolation_mode, 432 | antialias=True, 433 | ) 434 | image = image.permute(0, 2, 3, 1) 435 | 436 | return (image,) 437 | 438 | 439 | @register_node("JWImageResizeByFactor", "Image Resize by Factor") 440 | class _: 441 | CATEGORY = "jamesWalker55" 442 | INPUT_TYPES = lambda: { 443 | "required": { 444 | "image": ("IMAGE",), 445 | "factor": ("FLOAT", {"default": 1, "min": 0, "step": 0.01, "max": 99999}), 446 | "interpolation_mode": ( 447 | ["bicubic", "bilinear", "nearest", "nearest exact"], 448 | ), 449 | } 450 | } 451 | RETURN_TYPES = ("IMAGE",) 452 | FUNCTION = "execute" 453 | 454 | def execute( 455 | self, 456 | image: torch.Tensor, 457 | factor: float, 458 | interpolation_mode: str, 459 | ): 460 | assert isinstance(image, torch.Tensor) 461 | assert isinstance(factor, float) 462 | assert isinstance(interpolation_mode, str) 463 | 464 | interpolation_mode = interpolation_mode.upper().replace(" ", "_") 465 | 
interpolation_mode = getattr(InterpolationMode, interpolation_mode) 466 | 467 | new_height = round(image.shape[1] * factor) 468 | new_width = round(image.shape[2] * factor) 469 | 470 | image = image.permute(0, 3, 1, 2) 471 | image = F.resize( 472 | image, 473 | (new_height, new_width), 474 | interpolation=interpolation_mode, 475 | antialias=True, 476 | ) 477 | image = image.permute(0, 2, 3, 1) 478 | 479 | return (image,) 480 | 481 | 482 | @register_node("JWImageResizeByShorterSide", "Image Resize by Shorter Side") 483 | class _: 484 | CATEGORY = "jamesWalker55" 485 | INPUT_TYPES = lambda: { 486 | "required": { 487 | "image": ("IMAGE",), 488 | "size": ("INT", {"default": 512, "min": 0, "step": 1, "max": 99999}), 489 | "interpolation_mode": ( 490 | ["bicubic", "bilinear", "nearest", "nearest exact"], 491 | ), 492 | } 493 | } 494 | RETURN_TYPES = ("IMAGE",) 495 | FUNCTION = "execute" 496 | 497 | def execute( 498 | self, 499 | image: torch.Tensor, 500 | size: int, 501 | interpolation_mode: str, 502 | ): 503 | assert isinstance(image, torch.Tensor) 504 | assert isinstance(size, int) 505 | assert isinstance(interpolation_mode, str) 506 | 507 | interpolation_mode = interpolation_mode.upper().replace(" ", "_") 508 | interpolation_mode = getattr(InterpolationMode, interpolation_mode) 509 | 510 | image = image.permute(0, 3, 1, 2) 511 | image = F.resize( 512 | image, 513 | size, 514 | interpolation=interpolation_mode, 515 | antialias=True, 516 | ) 517 | image = image.permute(0, 2, 3, 1) 518 | 519 | return (image,) 520 | 521 | 522 | @register_node("JWImageResizeByLongerSide", "Image Resize by Longer Side") 523 | class _: 524 | CATEGORY = "jamesWalker55" 525 | INPUT_TYPES = lambda: { 526 | "required": { 527 | "image": ("IMAGE",), 528 | "size": ("INT", {"default": 512, "min": 0, "step": 1, "max": 99999}), 529 | "interpolation_mode": ( 530 | ["bicubic", "bilinear", "nearest", "nearest exact"], 531 | ), 532 | } 533 | } 534 | RETURN_TYPES = ("IMAGE",) 535 | FUNCTION = "execute" 536 | 537 | def execute( 538 | self, 539 | image: torch.Tensor, 540 | size: int, 541 | interpolation_mode: str, 542 | ): 543 | assert isinstance(image, torch.Tensor) 544 | assert isinstance(size, int) 545 | assert isinstance(interpolation_mode, str) 546 | 547 | interpolation_mode = interpolation_mode.upper().replace(" ", "_") 548 | interpolation_mode = getattr(InterpolationMode, interpolation_mode) 549 | 550 | _, h, w, _ = image.shape 551 | 552 | if h >= w: 553 | new_h = size 554 | new_w = round(w * new_h / h) 555 | else: # h < w 556 | new_w = size 557 | new_h = round(h * new_w / w) 558 | 559 | image = image.permute(0, 3, 1, 2) 560 | image = F.resize( 561 | image, 562 | (new_h, new_w), 563 | interpolation=interpolation_mode, 564 | antialias=True, 565 | ) 566 | image = image.permute(0, 2, 3, 1) 567 | 568 | return (image,) 569 | 570 | 571 | @register_node( 572 | "JWImageResizeToClosestSDXLResolution", "Image Resize to Closest SDXL Resolution" 573 | ) 574 | class _: 575 | CATEGORY = "jamesWalker55" 576 | INPUT_TYPES = lambda: { 577 | "required": { 578 | "image": ("IMAGE",), 579 | "interpolation_mode": ( 580 | ["bicubic", "bilinear", "nearest", "nearest exact"], 581 | ), 582 | } 583 | } 584 | RETURN_TYPES = ("IMAGE", "INT", "INT") 585 | RETURN_NAMES = ("IMAGE", "WIDTH", "HEIGHT") 586 | FUNCTION = "execute" 587 | 588 | # tuples of (height x width) 589 | SDXL_RESOLUTIONS = ( 590 | (1024, 1024), 591 | (1152, 896), 592 | (896, 1152), 593 | (1216, 832), 594 | (832, 1216), 595 | (1344, 768), 596 | (768, 1344), 597 | (1536, 640), 598 | (640, 
1536), 599 | ) 600 | 601 | @staticmethod 602 | def compare_fn(img_w: int, img_h: int, resolution: tuple[int, int]): 603 | img_deg = math.atan(img_h / img_w) 604 | xl_deg = math.atan(resolution[0] / resolution[1]) 605 | return abs(img_deg - xl_deg) 606 | 607 | def execute( 608 | self, 609 | image: torch.Tensor, 610 | interpolation_mode: str, 611 | ): 612 | interpolation_mode = interpolation_mode.upper().replace(" ", "_") 613 | interpolation_mode = getattr(InterpolationMode, interpolation_mode) 614 | 615 | _, h, w, _ = image.shape 616 | 617 | closest_resolution = min( 618 | self.SDXL_RESOLUTIONS, key=lambda res: self.compare_fn(w, h, res) 619 | ) 620 | 621 | image = image.permute(0, 3, 1, 2) 622 | image = F.resize( 623 | image, 624 | closest_resolution, # type: ignore 625 | interpolation=interpolation_mode, # type: ignore 626 | antialias=True, 627 | ) 628 | image = image.permute(0, 2, 3, 1) 629 | 630 | return (image, closest_resolution[1], closest_resolution[0]) 631 | 632 | 633 | @register_node( 634 | "JWImageCropToClosestSDXLResolution", "Image Crop to Closest SDXL Resolution" 635 | ) 636 | class _: 637 | CATEGORY = "jamesWalker55" 638 | INPUT_TYPES = lambda: { 639 | "required": { 640 | "image": ("IMAGE",), 641 | "interpolation_mode": ( 642 | ["bicubic", "bilinear", "nearest", "nearest exact"], 643 | ), 644 | } 645 | } 646 | RETURN_TYPES = ("IMAGE", "INT", "INT") 647 | RETURN_NAMES = ("IMAGE", "WIDTH", "HEIGHT") 648 | FUNCTION = "execute" 649 | 650 | # tuples of (height x width) 651 | SDXL_RESOLUTIONS = ( 652 | (1024, 1024), 653 | (1152, 896), 654 | (896, 1152), 655 | (1216, 832), 656 | (832, 1216), 657 | (1344, 768), 658 | (768, 1344), 659 | (1536, 640), 660 | (640, 1536), 661 | ) 662 | 663 | @staticmethod 664 | def angle(w: int, h: int): 665 | return math.atan(h / w) 666 | 667 | @staticmethod 668 | def compare_fn(img_w: int, img_h: int, resolution: tuple[int, int]): 669 | img_deg = math.atan(img_h / img_w) 670 | xl_deg = math.atan(resolution[0] / resolution[1]) 671 | return abs(img_deg - xl_deg) 672 | 673 | def execute( 674 | self, 675 | image: torch.Tensor, 676 | interpolation_mode: str, 677 | ): 678 | interpolation_mode = interpolation_mode.upper().replace(" ", "_") 679 | interpolation_mode = getattr(InterpolationMode, interpolation_mode) 680 | 681 | _, h, w, _ = image.shape 682 | 683 | closest_resolution = min( 684 | self.SDXL_RESOLUTIONS, key=lambda res: self.compare_fn(w, h, res) 685 | ) 686 | 687 | img_deg = self.angle(w, h) 688 | target_deg = self.angle(closest_resolution[1], closest_resolution[0]) 689 | 690 | if img_deg > target_deg: 691 | # image is taller and narrower than target 692 | w_scaled = closest_resolution[1] 693 | h_scaled = max(round(closest_resolution[1] / w * h), 0) 694 | else: 695 | # image is wider and shorter than target 696 | h_scaled = closest_resolution[0] 697 | w_scaled = max(round(closest_resolution[0] / h * w), 0) 698 | 699 | scaled_deg = self.angle(w_scaled, h_scaled) 700 | print(f"{[h, w] = }") 701 | print(f"{closest_resolution = }") 702 | print(f"{[h_scaled, w_scaled] = }") 703 | print(f"{img_deg = }") 704 | print(f"{target_deg = }") 705 | print(f"{scaled_deg = }") 706 | 707 | image = image.permute(0, 3, 1, 2) 708 | image = F.resize( 709 | image, 710 | [h_scaled, w_scaled], 711 | interpolation=interpolation_mode, # type: ignore 712 | antialias=True, 713 | ) 714 | image = F.center_crop( 715 | image, 716 | closest_resolution, # type: ignore 717 | ) 718 | image = image.permute(0, 2, 3, 1) 719 | 720 | return (image, closest_resolution[1], 
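# compare_fn() above scores aspect-ratio distance as a difference of angles:
# atan(h/w) is scale-invariant, so images that differ only in size map to the
# same angle. A quick hand check (plain math, no ComfyUI required):
#
#     import math
#     img = math.atan(768 / 512)    # 512w x 768h image, ratio 1.5
#     sq = math.atan(1024 / 1024)   # square bucket, ratio 1.0
#     tall = math.atan(1216 / 832)  # portrait bucket, ratio ~1.46
#     assert abs(img - tall) < abs(img - sq)  # picks the 832x1216 bucket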
closest_resolution[0]) 721 | 722 | 723 | def get_image_from_clipboard(rgba=False) -> Optional[torch.Tensor]: 724 | rv = ImageGrab.grabclipboard() 725 | if rv is None: 726 | return None 727 | 728 | if isinstance(rv, list): 729 | if len(rv) == 0: 730 | return None 731 | 732 | img = Image.open(rv[0]).convert("RGBA" if rgba else "RGB") 733 | else: 734 | # rv is some kind of image 735 | img = rv.convert("RGBA" if rgba else "RGB") 736 | 737 | img = np.array(img).astype(np.float32) / 255.0 738 | img = torch.from_numpy(img).unsqueeze(0) 739 | 740 | return img 741 | 742 | 743 | @register_node("JWImageLoadRGBFromClipboard", "Image Load RGB From Clipboard") 744 | class _: 745 | CATEGORY = "jamesWalker55" 746 | INPUT_TYPES = lambda: {"required": {}} 747 | RETURN_TYPES = ("IMAGE",) 748 | FUNCTION = "execute" 749 | 750 | def execute(self): 751 | img = get_image_from_clipboard(rgba=False) 752 | if img is None: 753 | raise ValueError("failed to get image from clipboard") 754 | return (img,) 755 | 756 | def IS_CHANGED(self, *args): 757 | # This value will be compared with previous 'IS_CHANGED' outputs 758 | # If not equal, then this node will be considered as modified 759 | return get_image_from_clipboard(rgba=False) 760 | 761 | 762 | @register_node("JWImageLoadRGBA From Clipboard", "Image Load RGBA From Clipboard") 763 | class _: 764 | CATEGORY = "jamesWalker55" 765 | INPUT_TYPES = lambda: {"required": {}} 766 | RETURN_TYPES = ("IMAGE", "MASK") 767 | FUNCTION = "execute" 768 | 769 | def execute(self): 770 | img = get_image_from_clipboard(rgba=True) 771 | if img is None: 772 | raise ValueError("failed to get image from clipboard") 773 | 774 | color = img[:, :, :, 0:3] 775 | mask = img[0, :, :, 3] 776 | mask = 1 - mask # invert mask 777 | 778 | return (color, mask) 779 | 780 | def IS_CHANGED(self, *args): 781 | # This value will be compared with previous 'IS_CHANGED' outputs 782 | # If not equal, then this node will be considered as modified 783 | return get_image_from_clipboard(rgba=True) 784 | -------------------------------------------------------------------------------- /comfyui_image_sequence.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from PIL.PngImagePlugin import PngInfo 9 | 10 | NODE_CLASS_MAPPINGS = {} 11 | NODE_DISPLAY_NAME_MAPPINGS = {} 12 | 13 | 14 | def register_node(identifier: str, display_name: str): 15 | def decorator(cls): 16 | NODE_CLASS_MAPPINGS[identifier] = cls 17 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 18 | 19 | return cls 20 | 21 | return decorator 22 | 23 | 24 | def load_image(path): 25 | img = Image.open(path).convert("RGB") 26 | img = np.array(img).astype(np.float32) / 255.0 27 | img = torch.from_numpy(img).unsqueeze(0) 28 | return img 29 | 30 | 31 | @register_node("JWLoadImageSequence", "Batch Load Image Sequence") 32 | class _: 33 | CATEGORY = "jamesWalker55" 34 | INPUT_TYPES = lambda: { 35 | "required": { 36 | "path_pattern": ( 37 | "STRING", 38 | {"default": "./frame{:06d}.png", "multiline": False}, 39 | ), 40 | "start_index": ("INT", {"default": 0, "min": 0, "step": 1}), 41 | "frame_count": ("INT", {"default": 16, "min": 1, "step": 1}), 42 | "ignore_missing_images": (("false", "true"), {"default": "false"}), 43 | } 44 | } 45 | RETURN_TYPES = ("IMAGE",) 46 | FUNCTION = "execute" 47 | 48 | def execute( 49 | self, 50 | path_pattern: str, 51 | start_index: int, 52 | frame_count: int, 53 |
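# path_pattern accepts positional or named fields; execute() below first tries
# pattern.format(i) and falls back to pattern.format(i=i) on KeyError, so both
# spellings resolve the same frame:
#
#     "./frame{:06d}.png".format(7)     # -> "./frame000007.png"
#     "./frame{i:06d}.png".format(i=7)  # -> "./frame000007.png"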
ignore_missing_images: str, 54 | ): 55 | ignore_missing_images: bool = ignore_missing_images == "true" 56 | 57 | # generate image paths to load 58 | image_paths = [] 59 | for i in range(start_index, start_index + frame_count): 60 | try: 61 | image_paths.append(path_pattern.format(i)) 62 | except KeyError: 63 | image_paths.append(path_pattern.format(i=i)) 64 | 65 | if ignore_missing_images: 66 | # remove missing images 67 | image_paths = [p for p in image_paths if os.path.exists(p)] 68 | else: 69 | # early check for missing images 70 | for path in image_paths: 71 | if not os.path.exists(path): 72 | raise FileNotFoundError(f"Image does not exist: {path}") 73 | 74 | if len(image_paths) == 0: 75 | raise RuntimeError("Image sequence empty - no images to load") 76 | 77 | imgs = [] 78 | for path in image_paths: 79 | img = load_image(path) 80 | # img.shape => torch.Size([1, 768, 768, 3]) 81 | imgs.append(img) 82 | 83 | imgs = torch.cat(imgs, dim=0) 84 | 85 | return (imgs,) 86 | 87 | 88 | @register_node( 89 | "JWLoadImageSequenceWithStopIndex", 90 | "Batch Load Image Sequence With Stop Index", 91 | ) 92 | class _: 93 | CATEGORY = "jamesWalker55" 94 | INPUT_TYPES = lambda: { 95 | "required": { 96 | "path_pattern": ( 97 | "STRING", 98 | {"default": "./frame{:06d}.png", "multiline": False}, 99 | ), 100 | "start_index": ("INT", {"default": 0, "min": 0, "step": 1, "max": 999999}), 101 | "stop_index": ("INT", {"default": 16, "min": 1, "step": 1, "max": 999999}), 102 | "inclusive": (("false", "true"), {"default": "false"}), 103 | "ignore_missing_images": (("false", "true"), {"default": "false"}), 104 | } 105 | } 106 | RETURN_TYPES = ("IMAGE",) 107 | FUNCTION = "execute" 108 | 109 | def execute( 110 | self, 111 | path_pattern: str, 112 | start_index: int, 113 | stop_index: int, 114 | inclusive: str, 115 | ignore_missing_images: str, 116 | ): 117 | inclusive: bool = inclusive == "true" 118 | ignore_missing_images: bool = ignore_missing_images == "true" 119 | 120 | # generate image paths to load 121 | image_paths = [] 122 | for i in range(start_index, stop_index + 1 if inclusive else stop_index): 123 | try: 124 | image_paths.append(path_pattern.format(i)) 125 | except KeyError: 126 | image_paths.append(path_pattern.format(i=i)) 127 | 128 | if ignore_missing_images: 129 | # remove missing images 130 | image_paths = [p for p in image_paths if os.path.exists(p)] 131 | else: 132 | # early check for missing images 133 | for path in image_paths: 134 | if not os.path.exists(path): 135 | raise FileNotFoundError(f"Image does not exist: {path}") 136 | 137 | if len(image_paths) == 0: 138 | raise RuntimeError("Image sequence empty - no images to load") 139 | 140 | imgs = [] 141 | for path in image_paths: 142 | img = load_image(path) 143 | # img.shape => torch.Size([1, 768, 768, 3]) 144 | imgs.append(img) 145 | 146 | imgs = torch.cat(imgs, dim=0) 147 | 148 | return (imgs,) 149 | 150 | 151 | def generate_non_conflicting_path(path: Path): 152 | if not path.exists(): 153 | return path 154 | 155 | i = -1 156 | while True: 157 | i += 1 158 | new_path = path.with_stem(f"{path.stem}-{i}") 159 | if new_path.exists(): 160 | continue 161 | 162 | return new_path 163 | 164 | 165 | def save_image(img: torch.Tensor, path, prompt=None, extra_pnginfo: dict = None): 166 | path = str(path) 167 | 168 | img = 255.0 * img.cpu().numpy() 169 | img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8)) 170 | 171 | metadata = PngInfo() 172 | 173 | if prompt is not None: 174 | metadata.add_text("prompt", json.dumps(prompt)) 175 | 176 | if 
extra_pnginfo is not None: 177 | for k, v in extra_pnginfo.items(): 178 | metadata.add_text(k, json.dumps(v)) 179 | 180 | img.save(path, pnginfo=metadata, compress_level=4) 181 | 182 | 183 | @register_node("JWImageSequenceExtractFromBatch", "Extract Image Sequence From Batch") 184 | class _: 185 | CATEGORY = "jamesWalker55" 186 | INPUT_TYPES = lambda: { 187 | "required": { 188 | "images": ("IMAGE",), 189 | "i_start": ("INT", {"default": 0, "min": 0, "step": 1}), 190 | "i_stop": ("INT", {"default": 0, "min": 0, "step": 1}), 191 | "inclusive": (("false", "true"), {"default": "false"}), 192 | } 193 | } 194 | RETURN_TYPES = ("IMAGE",) 195 | FUNCTION = "execute" 196 | 197 | def execute(self, images: torch.Tensor, i_start: int, i_stop: int, inclusive: str): 198 | assert isinstance(images, torch.Tensor) 199 | assert isinstance(i_start, int) 200 | assert isinstance(i_stop, int) 201 | assert isinstance(inclusive, str) 202 | 203 | inclusive: bool = inclusive == "true" 204 | 205 | img = images[i_start : i_stop + 1 if inclusive else i_stop] 206 | 207 | return (img,) 208 | 209 | 210 | @register_node("JWSaveImageSequence", "Batch Save Image Sequence") 211 | class _: 212 | CATEGORY = "jamesWalker55" 213 | INPUT_TYPES = lambda: { 214 | "required": { 215 | "images": ("IMAGE",), 216 | "path_pattern": ( 217 | "STRING", 218 | {"default": "./frame{:06d}.png", "multiline": False}, 219 | ), 220 | "start_index": ("INT", {"default": 0, "min": 0, "step": 1}), 221 | "overwrite": (("false", "true"), {"default": "true"}), 222 | }, 223 | "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, 224 | } 225 | RETURN_TYPES = () 226 | OUTPUT_NODE = True 227 | FUNCTION = "execute" 228 | 229 | def execute( 230 | self, 231 | images: torch.Tensor, 232 | path_pattern: str, 233 | start_index: int, 234 | overwrite: str, 235 | prompt=None, 236 | extra_pnginfo=None, 237 | ): 238 | overwrite: bool = overwrite == "true" 239 | 240 | image_range = range(start_index, start_index + len(images)) 241 | 242 | for i, img in zip(image_range, images): 243 | try: 244 | path = Path(path_pattern.format(i)) 245 | except KeyError: 246 | path = Path(path_pattern.format(i=i)) 247 | 248 | # Create containing folder for output path 249 | path.parent.mkdir(exist_ok=True) 250 | 251 | if not overwrite and path.exists(): 252 | print("JWSaveImageSequence: [WARNING]") 253 | print(f"JWSaveImageSequence: Image already exists: {path}") 254 | path = generate_non_conflicting_path(path) 255 | print(f"JWSaveImageSequence: Saving to new path instead: {path}") 256 | 257 | save_image( 258 | img, 259 | path, 260 | prompt=prompt, 261 | extra_pnginfo=extra_pnginfo, 262 | ) 263 | 264 | return () 265 | 266 | 267 | @register_node("JWLoopImageSequence", "Loop Image Sequence") 268 | class LoopImageSequence: 269 | CATEGORY = "jamesWalker55" 270 | INPUT_TYPES = lambda: { 271 | "required": { 272 | "images": ("IMAGE",), 273 | "target_frames": ("INT", {"default": 16, "step": 1}), 274 | } 275 | } 276 | RETURN_TYPES = ("IMAGE",) 277 | FUNCTION = "execute" 278 | 279 | def execute(self, images: torch.Tensor, target_frames: int): 280 | if len(images) > target_frames: 281 | images = images[0:target_frames] 282 | elif len(images) < target_frames: 283 | to_cat = [] 284 | 285 | for _ in range(target_frames // len(images)): 286 | to_cat.append(images) 287 | 288 | extra_frames = target_frames % len(images) 289 | if extra_frames > 0: 290 | to_cat.append(images[0:extra_frames]) 291 | 292 | images = torch.cat(to_cat, dim=0) 293 | assert len(images) == target_frames 294 | else: # 
len(images) == target_frames 295 | pass 296 | 297 | return (images,) 298 | -------------------------------------------------------------------------------- /comfyui_info_hash.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import glob 3 | import os 4 | import textwrap 5 | from pathlib import Path 6 | from typing import Any, TypedDict 7 | 8 | import numpy as np 9 | import torch 10 | import yaml 11 | from PIL import Image 12 | 13 | NODE_CLASS_MAPPINGS = {} 14 | NODE_DISPLAY_NAME_MAPPINGS = {} 15 | 16 | 17 | def register_node(identifier: str, display_name: str): 18 | def decorator(cls): 19 | NODE_CLASS_MAPPINGS[identifier] = cls 20 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 21 | 22 | return cls 23 | 24 | return decorator 25 | 26 | 27 | def load_image(path, convert="RGB"): 28 | img = Image.open(path).convert(convert) 29 | img = np.array(img).astype(np.float32) / 255.0 30 | img = torch.from_numpy(img).unsqueeze(0) 31 | return img 32 | 33 | 34 | class RangedConfig: 35 | def __init__(self, definition: dict[str, Any], range_key: str = "ranges") -> None: 36 | self.definition = definition 37 | self.range_key = range_key 38 | self._validate() 39 | 40 | def _validate(self): 41 | for k, v in self.definition.items(): 42 | if k == self.range_key: 43 | assert isinstance(v, dict), f"{type(v)!r}" 44 | for kk, vv in v.items(): 45 | # sub prompts ranges 46 | assert isinstance(kk, int), f"{type(kk)!r}" 47 | assert isinstance(vv, dict), f"{type(vv)!r}" 48 | for kkk, vvv in vv.items(): 49 | # actual sub prompts 50 | assert isinstance(kkk, str), f"{type(kkk)!r}" 51 | if vvv is None: 52 | vvv = "" 53 | assert isinstance(vvv, (int, float, str)), f"{type(vvv)!r}" 54 | else: 55 | # base prompts 56 | assert isinstance(k, str), f"{type(k)!r}" 57 | if v is None: 58 | v = "" 59 | assert isinstance(v, (int, float, str)), f"{type(v)!r}" 60 | 61 | def get_ranges(self): 62 | return sorted(self.definition[self.range_key].keys()) 63 | 64 | def _get_range_start(self, i: int) -> int | None: 65 | if len(self.definition[self.range_key]) == 0: 66 | return None 67 | 68 | range_starts = sorted(self.definition[self.range_key].keys()) 69 | 70 | for range_start_idx, range_start in enumerate(range_starts): 71 | if i < range_start: 72 | if range_start_idx == 0: 73 | return None 74 | else: 75 | return range_starts[range_start_idx - 1] 76 | 77 | return range_starts[-1] 78 | 79 | def _get_raw_sub_prompt(self, i: int): 80 | range_start = self._get_range_start(i) 81 | if range_start is None: 82 | # not in range, just use base definition 83 | return {**self.definition, "i": i} 84 | else: 85 | raw_sub_prompt = self.definition[self.range_key][self._get_range_start(i)] 86 | return {**self.definition, **raw_sub_prompt, "i": i} 87 | 88 | def get_sub_prompt(self, i: int): 89 | raw_sub_prompt = self._get_raw_sub_prompt(i) 90 | sub_prompt = {} 91 | for k, v in raw_sub_prompt.items(): 92 | if k == self.range_key: 93 | continue 94 | 95 | if isinstance(v, str): 96 | v = v.format(**raw_sub_prompt) 97 | 98 | sub_prompt[k] = v 99 | return sub_prompt 100 | 101 | 102 | DEFAULT_CONFIG = """\ 103 | p: | 104 | masterpiece, best quality, 105 | {sp}, 106 | 107 | n: | 108 | {sn}, 109 | embedding:EasyNegative, embedding:bad-artist, embedding:bad-hands-5, embedding:bad-image-v2-39000, 110 | lowres, ((bad anatomy)), ((bad hands)), text, missing finger, extra digits, fewer digits, blurry, ((mutated hands and fingers)), (poorly drawn face), ((mutation)), ((deformed face)), (ugly), ((bad proportions)), 
((extra limbs)), extra face, (double head), (extra head), ((extra feet)), monster, logo, cropped, worst quality, low quality, normal quality, jpeg, humpbacked, long body, long neck, ((jpeg artifacts)), 111 | 112 | path: "{i:04d}.png" 113 | 114 | example: 0 115 | 116 | ranges: 117 | 1: 118 | sp: positive subprompt for 1-4 119 | sn: negative subprompt for 1-4 120 | 5: 121 | sp: positive subprompt for 5-... 122 | sn: negative subprompt for 5-... 123 | example: 1 124 | """ 125 | 126 | 127 | @register_node("JWInfoHashFromRangedInfo", "Info Hash From Ranged Config") 128 | class _: 129 | CATEGORY = "jamesWalker55" 130 | INPUT_TYPES = lambda: { 131 | "required": { 132 | "config": ( 133 | "STRING", 134 | {"default": DEFAULT_CONFIG, "multiline": True, "dynamicPrompts": False}, 135 | ), 136 | "i": ("INT", {"default": 1, "min": 0, "step": 1, "max": 999999}), 137 | "ranges_key": ("STRING", {"default": "ranges", "multiline": False}), 138 | } 139 | } 140 | RETURN_TYPES = ("INFO_HASH",) 141 | FUNCTION = "execute" 142 | 143 | def execute(self, config: str, i: int, ranges_key: str): 144 | config = yaml.safe_load(config) 145 | 146 | info = RangedConfig(config, range_key=ranges_key) 147 | 148 | return (info.get_sub_prompt(i),) 149 | 150 | 151 | @register_node("JWInfoHashListFromRangedInfo", "Info Hash List From Ranged Config") 152 | class _: 153 | CATEGORY = "jamesWalker55" 154 | INPUT_TYPES = lambda: { 155 | "required": { 156 | "config": ( 157 | "STRING", 158 | {"default": DEFAULT_CONFIG, "multiline": True, "dynamicPrompts": False}, 159 | ), 160 | "i_start": ("INT", {"default": 0, "min": 0, "step": 1, "max": 999999}), 161 | "i_stop": ("INT", {"default": 16, "min": 0, "step": 1, "max": 999999}), 162 | "ranges_key": ("STRING", {"default": "ranges", "multiline": False}), 163 | "inclusive": (("false", "true"), {"default": "false"}), 164 | } 165 | } 166 | RETURN_TYPES = ("INFO_HASH_LIST",) 167 | FUNCTION = "execute" 168 | 169 | def execute( 170 | self, config: str, i_start: int, i_stop: int, ranges_key: str, inclusive: str 171 | ): 172 | inclusive: bool = inclusive == "true" 173 | 174 | config = yaml.safe_load(config) 175 | 176 | info = RangedConfig(config, range_key=ranges_key) 177 | subinfos = [ 178 | info.get_sub_prompt(i) 179 | for i in range(i_start, i_stop + 1 if inclusive else i_stop) 180 | ] 181 | 182 | return (subinfos,) 183 | 184 | 185 | def calculate_batches( 186 | i_start: int, # start of i 187 | i_stop: int, # end of i, excludes end 188 | range_starts: int, # scene cuts, batch will be terminated before this 189 | max_batch_size: int, # maximum length of batch 190 | ): 191 | """ 192 | :param int i_start: start of i 193 | :param int i_stop: end of i, excludes end 194 | :param int range_starts: scene cuts, batch will be terminated before this 195 | :param int max_batch_size: maximum length of batch 196 | :return: a list of 2-tuples, each represents (batch start frame, batch stop frame), where stop frame is exclusive 197 | """ 198 | batch_starts: list[int] = [] # also includes end frame 199 | i = i_start - 1 200 | counter = -1 201 | while True: 202 | i += 1 203 | counter += 1 204 | if i >= i_stop: 205 | batch_starts.append(i) 206 | break 207 | 208 | if i in range_starts: 209 | batch_starts.append(i) 210 | counter = 0 211 | continue 212 | 213 | if counter >= max_batch_size: 214 | batch_starts.append(i) 215 | counter = 0 216 | continue 217 | 218 | if counter == 0: 219 | batch_starts.append(i) 220 | continue 221 | 222 | batches = list(zip(batch_starts[:-1], batch_starts[1:])) 223 | 224 | return batches 225 
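# calculate_batches() above splits [i_start, i_stop) into runs of at most
# max_batch_size frames, starting a fresh batch at every range start ("scene
# cut"). Note that range_starts is used as a collection of ints, despite the
# narrower int annotation. A quick hand check:
#
#     calculate_batches(0, 10, {4}, 3)
#     # -> [(0, 3), (3, 4), (4, 7), (7, 10)]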
| 226 | 227 | @register_node("JWRangedInfoCalculateSubBatch", "Calculate Sub Batch for Ranged Info") 228 | class _: 229 | CATEGORY = "jamesWalker55" 230 | INPUT_TYPES = lambda: { 231 | "required": { 232 | "config": ( 233 | "STRING", 234 | {"default": DEFAULT_CONFIG, "multiline": True, "dynamicPrompts": False}, 235 | ), 236 | "ranges_key": ("STRING", {"default": "ranges", "multiline": False}), 237 | "batch_idx": ("INT", {"default": 0, "min": 0, "step": 1, "max": 999999}), 238 | "i_start": ("INT", {"default": 1, "min": 0, "step": 1, "max": 999999}), 239 | "i_stop": ("INT", {"default": 100, "min": 0, "step": 1, "max": 999999}), 240 | "max_batch_size": ( 241 | "INT", 242 | {"default": 16, "min": 1, "step": 1, "max": 999999}, 243 | ), 244 | "inclusive": (("false", "true"), {"default": "false"}), 245 | } 246 | } 247 | RETURN_NAMES = ("BATCH_I_START", "BATCH_I_STOP") 248 | RETURN_TYPES = ("INT", "INT") 249 | FUNCTION = "execute" 250 | 251 | def execute( 252 | self, 253 | config: str, 254 | ranges_key: str, 255 | batch_idx: int, 256 | i_start: int, 257 | i_stop: int, 258 | max_batch_size: int, 259 | inclusive: str, 260 | ): 261 | inclusive: bool = inclusive == "true" 262 | 263 | config = yaml.safe_load(config) 264 | 265 | info = RangedConfig(config, range_key=ranges_key) 266 | 267 | range_starts = set(info.get_ranges()) 268 | 269 | # get images in selected batch 270 | batches = calculate_batches( 271 | i_start, i_stop + 1 if inclusive else i_stop, range_starts, max_batch_size 272 | ) 273 | batch = batches[batch_idx] 274 | 275 | return (batch[0], batch[1]) 276 | 277 | 278 | @register_node( 279 | "JWInfoHashFromRangedInfoAndLoadSubsequences", 280 | "Info Hash From Ranged Config and Load Batch", 281 | ) 282 | class _: 283 | CATEGORY = "jamesWalker55" 284 | INPUT_TYPES = lambda: { 285 | "required": { 286 | "config": ( 287 | "STRING", 288 | {"default": DEFAULT_CONFIG, "multiline": True, "dynamicPrompts": False}, 289 | ), 290 | "ranges_key": ("STRING", {"default": "ranges", "multiline": False}), 291 | "path_key": ("STRING", {"default": "path", "multiline": False}), 292 | "batch_idx": ("INT", {"default": 0, "min": 0, "step": 1, "max": 999999}), 293 | "i_start": ("INT", {"default": 1, "min": 0, "step": 1, "max": 999999}), 294 | "i_stop": ("INT", {"default": 100, "min": 0, "step": 1, "max": 999999}), 295 | "max_batch_size": ( 296 | "INT", 297 | {"default": 16, "min": 1, "step": 1, "max": 999999}, 298 | ), 299 | "inclusive": (("false", "true"), {"default": "false"}), 300 | } 301 | } 302 | RETURN_NAMES = ("INFO_HASH", "IMAGE", "BATCH_I_START", "BATCH_I_STOP") 303 | RETURN_TYPES = ("INFO_HASH", "IMAGE", "INT", "INT") 304 | FUNCTION = "execute" 305 | 306 | def execute( 307 | self, 308 | config: str, 309 | ranges_key: str, 310 | path_key: str, 311 | batch_idx: int, 312 | i_start: int, 313 | i_stop: int, 314 | max_batch_size: int, 315 | inclusive: str, 316 | ): 317 | inclusive: bool = inclusive == "true" 318 | 319 | config = yaml.safe_load(config) 320 | 321 | info = RangedConfig(config, range_key=ranges_key) 322 | 323 | range_starts = set(info.get_ranges()) 324 | 325 | # get images in selected batch 326 | batches = calculate_batches( 327 | i_start, i_stop + 1 if inclusive else i_stop, range_starts, max_batch_size 328 | ) 329 | batch = batches[batch_idx] 330 | 331 | print(f"Getting images in batch: {batch}") 332 | 333 | images = [] 334 | for i in range(batch[0], batch[1]): 335 | subinfo = info.get_sub_prompt(i) 336 | path = subinfo[path_key] 337 | print(f" Loading: {path}") 338 | img = load_image(path) 339 | 
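# load_image() yields a [1, H, W, C] float tensor per frame, so the
# torch.cat(..., dim=0) just below folds N frames into a single [N, H, W, C]
# IMAGE batch. Standalone sketch:
#
#     frames = [torch.zeros((1, 512, 512, 3)) for _ in range(4)]
#     batch = torch.cat(frames, dim=0)
#     # batch.shape -> torch.Size([4, 512, 512, 3])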
images.append(img) 340 | images = torch.cat(images, dim=0) 341 | 342 | return (info.get_sub_prompt(batch[0]), images, batch[0], batch[1]) 343 | 344 | 345 | @register_node("JWInfoHashExtractInteger", "Info Hash Extract Integer") 346 | class _: 347 | CATEGORY = "jamesWalker55" 348 | INPUT_TYPES = lambda: { 349 | "required": { 350 | "info_hash": ("INFO_HASH",), 351 | "key": ("STRING", {"default": "i", "multiline": False}), 352 | } 353 | } 354 | RETURN_TYPES = ("INT",) 355 | FUNCTION = "execute" 356 | 357 | def execute(self, info_hash: dict, key: str): 358 | val = int(info_hash[key]) 359 | return (val,) 360 | 361 | 362 | @register_node("JWInfoHashExtractFloat", "Info Hash Extract Float") 363 | class _: 364 | CATEGORY = "jamesWalker55" 365 | INPUT_TYPES = lambda: { 366 | "required": { 367 | "info_hash": ("INFO_HASH",), 368 | "key": ("STRING", {"default": "key", "multiline": False}), 369 | } 370 | } 371 | RETURN_TYPES = ("FLOAT",) 372 | FUNCTION = "execute" 373 | 374 | def execute(self, info_hash: dict, key: str): 375 | val = float(info_hash[key]) 376 | return (val,) 377 | 378 | 379 | @register_node("JWInfoHashExtractString", "Info Hash Extract String") 380 | class _: 381 | CATEGORY = "jamesWalker55" 382 | INPUT_TYPES = lambda: { 383 | "required": { 384 | "info_hash": ("INFO_HASH",), 385 | "key": ("STRING", {"default": "p", "multiline": False}), 386 | } 387 | } 388 | RETURN_TYPES = ("STRING",) 389 | FUNCTION = "execute" 390 | 391 | def execute(self, info_hash: dict, key: str): 392 | val = str(info_hash[key]) 393 | return (val,) 394 | 395 | 396 | @register_node("JWInfoHashListExtractStringList", "Info Hash List Extract String List") 397 | class _: 398 | CATEGORY = "jamesWalker55" 399 | INPUT_TYPES = lambda: { 400 | "required": { 401 | "info_hash_list": ("INFO_HASH_LIST",), 402 | "key": ("STRING", {"default": "p", "multiline": False}), 403 | } 404 | } 405 | RETURN_TYPES = ("STRING_LIST",) 406 | FUNCTION = "execute" 407 | 408 | def execute(self, info_hash_list: list[dict], key: str): 409 | val = [str(info_hash[key]) for info_hash in info_hash_list] 410 | return (val,) 411 | 412 | 413 | @register_node("JWInfoHashFromInfoHashList", "Extract Info Hash From Info Hash List") 414 | class _: 415 | CATEGORY = "jamesWalker55" 416 | INPUT_TYPES = lambda: { 417 | "required": { 418 | "info_hash_list": ("INFO_HASH_LIST",), 419 | "i": ("INT", {"default": 0, "step": 1, "min": -99999999, "max": 99999999}), 420 | } 421 | } 422 | RETURN_TYPES = ("INFO_HASH",) 423 | FUNCTION = "execute" 424 | 425 | def execute(self, info_hash_list: list[dict], i: int): 426 | return (info_hash_list[i],) 427 | 428 | 429 | @register_node("JWInfoHashPrint", "Print Info Hash (Debug)") 430 | class _: 431 | CATEGORY = "jamesWalker55" 432 | INPUT_TYPES = lambda: { 433 | "required": { 434 | "info_hash": ("INFO_HASH",), 435 | } 436 | } 437 | RETURN_TYPES = () 438 | OUTPUT_NODE = True 439 | FUNCTION = "execute" 440 | 441 | def execute(self, info_hash: dict): 442 | from pprint import pformat, pprint 443 | 444 | pprint(info_hash) 445 | raise ValueError(pformat(info_hash)) 446 | -------------------------------------------------------------------------------- /comfyui_jw.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | 5 | NODE_CLASS_MAPPINGS = {} 6 | NODE_DISPLAY_NAME_MAPPINGS = {} 7 | 8 | 9 | def register_node(identifier: str, display_name: str): 10 | def decorator(cls): 11 | NODE_CLASS_MAPPINGS[identifier] = cls 12 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = 
display_name 13 | 14 | return cls 15 | 16 | return decorator 17 | 18 | 19 | @register_node("JWReferenceOnly", "James: Reference Only") 20 | class ReferenceOnlySimple: 21 | CATEGORY = "jamesWalker55" 22 | INPUT_TYPES = lambda: { 23 | "required": { 24 | "model": ("MODEL",), 25 | "reference": ("LATENT",), 26 | "initial_latent": ("LATENT",), 27 | "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), 28 | } 29 | } 30 | RETURN_TYPES = ("MODEL", "LATENT") 31 | FUNCTION = "execute" 32 | 33 | def execute(self, model, reference, initial_latent, batch_size): 34 | model_reference = model.clone() 35 | size_latent = list(reference["samples"].shape) 36 | size_latent[0] = batch_size 37 | latent = {} 38 | latent["samples"] = initial_latent["samples"] 39 | 40 | batch = latent["samples"].shape[0] + reference["samples"].shape[0] 41 | 42 | def reference_apply(q, k, v, extra_options): 43 | k = k.clone().repeat(1, 2, 1) 44 | 45 | for o in range(0, q.shape[0], batch): 46 | for x in range(1, batch): 47 | k[x + o, q.shape[1] :] = q[o, :] 48 | 49 | return q, k, k 50 | 51 | model_reference.set_model_attn1_patch(reference_apply) 52 | 53 | out_latent = torch.cat((reference["samples"], latent["samples"])) 54 | if "noise_mask" in latent: 55 | mask = latent["noise_mask"] 56 | else: 57 | mask = torch.ones((64, 64), dtype=torch.float32, device="cpu") 58 | 59 | if len(mask.shape) < 3: 60 | mask = mask.unsqueeze(0) 61 | if mask.shape[0] < latent["samples"].shape[0]: 62 | print(latent["samples"].shape, mask.shape) 63 | mask = mask.repeat(latent["samples"].shape[0], 1, 1) 64 | 65 | out_mask = torch.zeros( 66 | (1, mask.shape[1], mask.shape[2]), dtype=torch.float32, device="cpu" 67 | ) 68 | return ( 69 | model_reference, 70 | {"samples": out_latent, "noise_mask": torch.cat((out_mask, mask))}, 71 | ) 72 | 73 | 74 | @register_node( 75 | "JWSetLastControlNetStrengthForBatch", 76 | "Set Last ControlNet Strength For Batch", 77 | ) 78 | class _: 79 | """ 80 | Set the strength of the previously-added ControlNet, number of values must be 81 | equal to batch size. 
82 | """ 83 | 84 | CATEGORY = "jamesWalker55" 85 | INPUT_TYPES = lambda: { 86 | "required": { 87 | "conditioning": ("CONDITIONING",), 88 | "strengths": ( 89 | "STRING", 90 | { 91 | "default": "0.25, 0.5, 0.75, 1.0", 92 | "multiline": True, 93 | "dynamicPrompts": False, 94 | }, 95 | ), 96 | } 97 | } 98 | RETURN_TYPES = ("CONDITIONING",) 99 | FUNCTION = "execute" 100 | 101 | def execute( 102 | self, 103 | conditioning: list[list[torch.Tensor | dict[str, Any]]], 104 | strengths, 105 | ): 106 | strengths = [float(x.strip()) for x in strengths.split(",")] 107 | strengths = torch.tensor(strengths).reshape((-1, 1, 1, 1)) 108 | strengths = torch.cat((strengths, strengths)) 109 | strengths = strengths.to("cuda") 110 | 111 | new_conditioning = [] 112 | for old_cond in conditioning: 113 | cond = old_cond.copy() 114 | cond[1] = cond[1].copy() 115 | 116 | if cond[1].get("control", None): 117 | # new_cond[1]["control"]: comfy.controlnet.ControlNet 118 | cond[1]["control"].strength = strengths 119 | 120 | new_conditioning.append(cond) 121 | 122 | return (new_conditioning,) 123 | -------------------------------------------------------------------------------- /comfyui_mask_sequence_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | NODE_CLASS_MAPPINGS = {} 4 | NODE_DISPLAY_NAME_MAPPINGS = {} 5 | 6 | 7 | def register_node(identifier: str, display_name: str): 8 | def decorator(cls): 9 | NODE_CLASS_MAPPINGS[identifier] = cls 10 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 11 | 12 | return cls 13 | 14 | return decorator 15 | 16 | 17 | @register_node("JWMaskSequenceFromMask", "Mask Sequence From Mask") 18 | class _: 19 | CATEGORY = "jamesWalker55" 20 | INPUT_TYPES = lambda: { 21 | "required": { 22 | "mask": ("MASK",), 23 | "batch_size": ("INT", {"default": 1, "min": 1, "step": 1}), 24 | } 25 | } 26 | RETURN_TYPES = ("MASK_SEQUENCE",) 27 | FUNCTION = "execute" 28 | 29 | def execute( 30 | self, 31 | mask: torch.Tensor, 32 | batch_size: int, 33 | ): 34 | assert isinstance(mask, torch.Tensor) 35 | assert isinstance(batch_size, int) 36 | assert len(mask.shape) == 2 37 | 38 | mask_seq = mask.reshape((1, 1, *mask.shape)) 39 | mask_seq = mask_seq.repeat(batch_size, 1, 1, 1) 40 | 41 | return (mask_seq,) 42 | 43 | 44 | @register_node("JWMaskSequenceJoin", "Join Mask Sequence") 45 | class _: 46 | CATEGORY = "jamesWalker55" 47 | INPUT_TYPES = lambda: { 48 | "required": { 49 | "mask_sequence_1": ("MASK_SEQUENCE",), 50 | "mask_sequence_2": ("MASK_SEQUENCE",), 51 | } 52 | } 53 | RETURN_TYPES = ("MASK_SEQUENCE",) 54 | FUNCTION = "execute" 55 | 56 | def execute( 57 | self, 58 | mask_sequence_1: torch.Tensor, 59 | mask_sequence_2: torch.Tensor, 60 | ): 61 | assert isinstance(mask_sequence_1, torch.Tensor) 62 | assert isinstance(mask_sequence_2, torch.Tensor) 63 | 64 | mask_seq = torch.cat((mask_sequence_1, mask_sequence_2), dim=0) 65 | 66 | return (mask_seq,) 67 | 68 | 69 | @register_node("JWMaskSequenceApplyToLatent", "Apply Mask Sequence to Latent") 70 | class _: 71 | CATEGORY = "jamesWalker55" 72 | INPUT_TYPES = lambda: { 73 | "required": { 74 | "samples": ("LATENT",), 75 | "mask_sequence": ("MASK_SEQUENCE",), 76 | } 77 | } 78 | RETURN_TYPES = ("LATENT",) 79 | FUNCTION = "execute" 80 | 81 | def execute( 82 | self, 83 | samples: dict, 84 | mask_sequence: torch.Tensor, 85 | ): 86 | assert isinstance(samples, dict) 87 | assert isinstance(mask_sequence, torch.Tensor) 88 | 89 | samples = samples.copy() 90 | 91 | samples["noise_mask"] = mask_sequence 92 | 93 
| return (samples,) 94 | -------------------------------------------------------------------------------- /comfyui_primitive_ops.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | import typing 4 | from typing import Literal 5 | 6 | NODE_CLASS_MAPPINGS = {} 7 | NODE_DISPLAY_NAME_MAPPINGS = {} 8 | 9 | 10 | def register_node(identifier: str, display_name: str): 11 | def decorator(cls): 12 | NODE_CLASS_MAPPINGS[identifier] = cls 13 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 14 | 15 | return cls 16 | 17 | return decorator 18 | 19 | 20 | def generate_functional_node( 21 | category: str, 22 | identifier: str, 23 | display_name: str, 24 | *, 25 | multiline_string: bool = False, 26 | output_node: bool = False, 27 | ): 28 | def decorator(func): 29 | signature = inspect.signature(func) 30 | 31 | # generate INPUT_TYPES 32 | required_inputs = {} 33 | 34 | for name, param in signature.parameters.items(): 35 | has_default = param.default is not param.empty 36 | param_type = param.annotation 37 | if param_type is int: 38 | if not has_default: 39 | raise TypeError("INT input must have a default value") 40 | 41 | required_inputs[name] = ( 42 | "INT", 43 | { 44 | "default": param.default, 45 | "min": -0xFFFFFFFFFFFFFFFF, 46 | "max": 0xFFFFFFFFFFFFFFFF, 47 | }, 48 | ) 49 | elif param_type is float: 50 | if not has_default: 51 | raise TypeError("FLOAT input must have a default value") 52 | 53 | required_inputs[name] = ( 54 | "FLOAT", 55 | { 56 | "default": param.default, 57 | "min": -99999999999999999.0, 58 | "max": 99999999999999999.0, 59 | }, 60 | ) 61 | elif param_type is str: 62 | if not has_default: 63 | raise TypeError("STRING input must have a default value") 64 | 65 | required_inputs[name] = ( 66 | "STRING", 67 | { 68 | "default": param.default, 69 | "multiline": multiline_string, 70 | "dynamicPrompts": False, 71 | }, 72 | ) 73 | elif isinstance(param_type, str): 74 | if has_default: 75 | raise TypeError("Custom input types cannot have default values") 76 | 77 | required_inputs[name] = (param_type,) 78 | elif typing.get_origin(param_type) is Literal: 79 | choices = typing.get_args(param_type) 80 | if param.default is not None: 81 | raise TypeError( 82 | "Choice input types must have default value set to None" 83 | ) 84 | 85 | required_inputs[name] = (choices,) 86 | else: 87 | raise NotImplementedError( 88 | f"Unsupported functional node type: {param_type}" 89 | ) 90 | 91 | # generate RETURN_TYPES 92 | if signature.return_annotation is signature.empty: 93 | raise TypeError("Functional node must have annotation for return type") 94 | elif typing.get_origin(signature.return_annotation) is not tuple: 95 | raise TypeError("Functional node must return a tuple") 96 | 97 | return_types = [] 98 | for annot in typing.get_args(signature.return_annotation): 99 | if isinstance(annot, str): 100 | return_types.append(annot) 101 | elif annot is int: 102 | return_types.append("INT") 103 | elif annot is float: 104 | return_types.append("FLOAT") 105 | elif annot is str: 106 | return_types.append("STRING") 107 | else: 108 | raise NotImplementedError(f"Unsupported return type: {annot}") 109 | 110 | @register_node(identifier, display_name) 111 | class _: 112 | CATEGORY = category 113 | INPUT_TYPES = lambda: {"required": required_inputs} 114 | RETURN_TYPES = tuple(return_types) 115 | OUTPUT_NODE = output_node 116 | FUNCTION = "execute" 117 | 118 | def execute(self, *args, **kwargs): 119 | return func(*args, **kwargs) 120 | 121 | return 
func 122 | 123 | return decorator 124 | 125 | 126 | @generate_functional_node("jamesWalker55", "JWInteger", "Integer") 127 | def _(value: int = 0) -> tuple[int]: 128 | return (value,) 129 | 130 | 131 | @generate_functional_node("jamesWalker55", "JWIntegerToFloat", "Integer to Float") 132 | def _(value: int = 0) -> tuple[float]: 133 | return (float(value),) 134 | 135 | 136 | @generate_functional_node("jamesWalker55", "JWIntegerToString", "Integer to String") 137 | def _(value: int = 0, format_string: str = "{:04d}") -> tuple[str]: 138 | return (format_string.format(value),) 139 | 140 | 141 | @generate_functional_node("jamesWalker55", "JWIntegerAdd", "Integer Add") 142 | def _(a: int = 0, b: int = 0) -> tuple[int]: 143 | return (a + b,) 144 | 145 | 146 | @generate_functional_node("jamesWalker55", "JWIntegerSub", "Integer Subtract") 147 | def _(a: int = 0, b: int = 0) -> tuple[int]: 148 | return (a - b,) 149 | 150 | 151 | @generate_functional_node("jamesWalker55", "JWIntegerMul", "Integer Multiply") 152 | def _(a: int = 0, b: int = 0) -> tuple[int]: 153 | return (a * b,) 154 | 155 | 156 | @generate_functional_node("jamesWalker55", "JWIntegerDiv", "Integer Divide") 157 | def _(a: int = 0, b: int = 0) -> tuple[float]: 158 | return (a / b,) 159 | 160 | 161 | @generate_functional_node( 162 | "jamesWalker55", "JWIntegerAbsolute", "Integer Absolute Value" 163 | ) 164 | def _(value: int = 0) -> tuple[int]: 165 | return (abs(value),) 166 | 167 | 168 | @generate_functional_node("jamesWalker55", "JWIntegerMin", "Integer Minimum") 169 | def _(a: int = 0, b: int = 0) -> tuple[int]: 170 | return (min(a, b),) 171 | 172 | 173 | @generate_functional_node("jamesWalker55", "JWIntegerMax", "Integer Maximum") 174 | def _(a: int = 0, b: int = 0) -> tuple[int]: 175 | return (max(a, b),) 176 | 177 | 178 | @generate_functional_node("jamesWalker55", "JWFloat", "Float") 179 | def _(value: float = 0) -> tuple[float]: 180 | return (value,) 181 | 182 | 183 | @generate_functional_node("jamesWalker55", "JWFloatToInteger", "Float to Integer") 184 | def _( 185 | value: float = 0, mode: Literal["round", "floor", "ceiling"] = None 186 | ) -> tuple[int]: 187 | if mode == "round": 188 | return (round(value),) 189 | elif mode == "floor": 190 | return (math.floor(value),) 191 | elif mode == "ceiling": 192 | return (math.ceil(value),) 193 | else: 194 | raise NotImplementedError(f"Unsupported mode: {mode}") 195 | 196 | 197 | @generate_functional_node("jamesWalker55", "JWFloatToString", "Float to String") 198 | def _(value: float = 0, format_string: str = "{:.6g}") -> tuple[str]: 199 | return (format_string.format(value),) 200 | 201 | 202 | @generate_functional_node("jamesWalker55", "JWFloatAdd", "Float Add") 203 | def _(a: float = 0, b: float = 0) -> tuple[float]: 204 | return (a + b,) 205 | 206 | 207 | @generate_functional_node("jamesWalker55", "JWFloatSub", "Float Subtract") 208 | def _(a: float = 0, b: float = 0) -> tuple[float]: 209 | return (a - b,) 210 | 211 | 212 | @generate_functional_node("jamesWalker55", "JWFloatMul", "Float Multiply") 213 | def _(a: float = 0, b: float = 0) -> tuple[float]: 214 | return (a * b,) 215 | 216 | 217 | @generate_functional_node("jamesWalker55", "JWFloatDiv", "Float Divide") 218 | def _(a: float = 0, b: float = 0) -> tuple[float]: 219 | return (a / b,) 220 | 221 | 222 | @generate_functional_node("jamesWalker55", "JWFloatAbsolute", "Float Absolute Value") 223 | def _(value: float = 0) -> tuple[float]: 224 | return (abs(value),) 225 | 226 | 227 | @generate_functional_node("jamesWalker55", 
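# generate_functional_node() (defined above) inspects the decorated function's
# signature to derive INPUT_TYPES and RETURN_TYPES, so adding a node is just
# writing a typed function. A hypothetical node in the same style (the
# identifier "JWIntegerModulo" does not exist in this repo):
#
#     @generate_functional_node("jamesWalker55", "JWIntegerModulo", "Integer Modulo")
#     def _(a: int = 0, b: int = 1) -> tuple[int]:
#         return (a % b,)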
"JWFloatMin", "Float Minimum") 228 | def _(a: float = 0, b: float = 0) -> tuple[float]: 229 | return (min(a, b),) 230 | 231 | 232 | @generate_functional_node("jamesWalker55", "JWFloatMax", "Float Maximum") 233 | def _(a: float = 0, b: float = 0) -> tuple[float]: 234 | return (max(a, b),) 235 | 236 | 237 | @generate_functional_node("jamesWalker55", "JWString", "String") 238 | def _(text: str = "") -> tuple[str]: 239 | return (text,) 240 | 241 | 242 | @generate_functional_node("jamesWalker55", "JWStringToInteger", "String to Integer") 243 | def _(text: str = "0") -> tuple[int]: 244 | return (int(text),) 245 | 246 | 247 | @generate_functional_node("jamesWalker55", "JWStringToFloat", "String to Float") 248 | def _(text: str = "0.0") -> tuple[float]: 249 | return (float(text),) 250 | 251 | 252 | @generate_functional_node( 253 | "jamesWalker55", "JWStringMultiline", "String (Multiline)", multiline_string=True 254 | ) 255 | def _(text: str = "") -> tuple[str]: 256 | return (text,) 257 | 258 | 259 | @generate_functional_node("jamesWalker55", "JWStringConcat", "String Concatenate") 260 | def _(a: str = "", b: str = "") -> tuple[str]: 261 | return (a + b,) 262 | 263 | 264 | @generate_functional_node("jamesWalker55", "JWStringReplace", "String Replace") 265 | def _(source: str = "", to_replace: str = "", replace_with: str = "") -> tuple[str]: 266 | return (source.replace(to_replace, replace_with),) 267 | 268 | 269 | @generate_functional_node("jamesWalker55", "JWStringSplit", "String Split") 270 | def _( 271 | source: str = "a,b", 272 | split_by: str = ",", 273 | from_right: Literal["false", "true"] = None, 274 | ) -> tuple[str, str]: 275 | from_right = from_right == "true" 276 | if from_right: 277 | splits = source.rsplit(split_by, 1) 278 | else: 279 | splits = source.split(split_by, 1) 280 | match splits: 281 | case a, b: 282 | return (a, b) 283 | case a: 284 | return (a, "") 285 | 286 | 287 | @generate_functional_node("jamesWalker55", "JWStringGetLine", "String Get Line") 288 | def _(source: str = "", line_index: int = 0) -> tuple[str]: 289 | return (source.splitlines()[line_index],) 290 | 291 | 292 | @generate_functional_node("jamesWalker55", "JWStringUnescape", "String Unescape") 293 | def _(text: str = "") -> tuple[str]: 294 | """parses '\\n' literals in a string to actual '\n' characters""" 295 | # convert to bytes, while converting unicode to escaped literals 296 | text_bytes = text.encode("ascii", "backslashreplace") 297 | # convert back to string, parsing backslash escapes 298 | text = text_bytes.decode("unicode-escape") 299 | return (text,) 300 | -------------------------------------------------------------------------------- /comfyui_raft.py: -------------------------------------------------------------------------------- 1 | import comfy.model_management as model_management 2 | import numpy as np 3 | import torch 4 | import torchvision.transforms.functional as F 5 | from torchvision.models.optical_flow import Raft_Large_Weights, raft_large 6 | from torchvision.utils import flow_to_image 7 | 8 | NODE_CLASS_MAPPINGS = {} 9 | NODE_DISPLAY_NAME_MAPPINGS = {} 10 | 11 | 12 | def register_node(identifier: str, display_name: str): 13 | def decorator(cls): 14 | NODE_CLASS_MAPPINGS[identifier] = cls 15 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 16 | 17 | return cls 18 | 19 | return decorator 20 | 21 | 22 | def comfyui_to_native_torch(imgs: torch.Tensor): 23 | """ 24 | Convert images in NHWC format to NCHW format. 25 | 26 | Use this to convert ComfyUI images to torch-native images. 
27 | """ 28 | return imgs.permute(0, 3, 1, 2) 29 | 30 | 31 | def native_torch_to_comfyui(imgs: torch.Tensor): 32 | """ 33 | Convert images in NCHW format to NHWC format. 34 | 35 | Use this to convert torch-native images to ComfyUI images. 36 | """ 37 | return imgs.permute(0, 2, 3, 1) 38 | 39 | 40 | _model = None 41 | 42 | 43 | def load_model(): 44 | global _model 45 | 46 | if _model is not None: 47 | return _model 48 | 49 | try: 50 | offload_device = model_management.unet_offload_device() 51 | 52 | _model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).eval() 53 | _model = _model.to(offload_device) 54 | 55 | return _model 56 | except Exception as e: 57 | _model = None 58 | raise e 59 | 60 | 61 | def preprocess_image(img: torch.Tensor): 62 | # Image size must be divisible by 8 63 | _, _, h, w = img.shape 64 | assert h % 8 == 0, "Image height must be divisible by 8" 65 | assert w % 8 == 0, "Image width must be divisible by 8" 66 | 67 | img = F.convert_image_dtype(img, torch.float) 68 | 69 | # map [0, 1] into [-1, 1] 70 | img = F.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 71 | 72 | img = img.contiguous() 73 | 74 | return img 75 | 76 | 77 | @register_node("RAFTEstimate", "RAFT Estimate") 78 | class _: 79 | """ 80 | https://pytorch.org/vision/main/auto_examples/plot_optical_flow.html 81 | """ 82 | 83 | CATEGORY = "jamesWalker55" 84 | INPUT_TYPES = lambda: { 85 | "required": { 86 | "image_a": ("IMAGE",), 87 | "image_b": ("IMAGE",), 88 | } 89 | } 90 | RETURN_TYPES = ("RAFT_FLOW",) 91 | FUNCTION = "execute" 92 | 93 | def execute(self, image_a: torch.Tensor, image_b: torch.Tensor): 94 | """ 95 | Code derived from: 96 | https://pytorch.org/vision/main/auto_examples/plot_optical_flow.html 97 | """ 98 | 99 | assert isinstance(image_a, torch.Tensor) 100 | assert isinstance(image_b, torch.Tensor) 101 | 102 | torch_device = model_management.get_torch_device() 103 | offload_device = model_management.unet_offload_device() 104 | 105 | image_a = comfyui_to_native_torch(image_a).to(torch_device) 106 | image_b = comfyui_to_native_torch(image_b).to(torch_device) 107 | model = load_model().to(torch_device) 108 | 109 | image_a = preprocess_image(image_a) 110 | image_b = preprocess_image(image_b) 111 | 112 | all_flows = model(image_a, image_b) 113 | best_flow = all_flows[-1] 114 | # best_flow.shape => torch.Size([1, 2, 512, 512]) 115 | 116 | model.to(offload_device) 117 | image_a = image_a.to("cpu") 118 | image_b = image_b.to("cpu") 119 | best_flow = best_flow.to("cpu") 120 | 121 | return (best_flow,) 122 | 123 | 124 | @register_node("RAFTFlowToImage", "RAFT Flow to Image") 125 | class _: 126 | """ 127 | https://pytorch.org/vision/main/auto_examples/plot_optical_flow.html 128 | """ 129 | 130 | CATEGORY = "jamesWalker55" 131 | INPUT_TYPES = lambda: { 132 | "required": { 133 | "raft_flow": ("RAFT_FLOW",), 134 | } 135 | } 136 | RETURN_TYPES = ("IMAGE",) 137 | FUNCTION = "execute" 138 | 139 | def execute(self, raft_flow: torch.Tensor): 140 | assert isinstance(raft_flow, torch.Tensor) 141 | assert raft_flow.shape[1] == 2 142 | 143 | images = flow_to_image(raft_flow) 144 | # pixel range is [0, 255], dtype=torch.uint8 145 | 146 | images = images / 255 147 | 148 | images = native_torch_to_comfyui(images) 149 | 150 | return (images,) 151 | 152 | 153 | def depth_exr_to_numpy(exr_path, typemap={"HALF": np.float16, "FLOAT": np.float32}): 154 | # Code stolen from: 155 | # https://gist.github.com/andres-fr/4ddbb300d418ed65951ce88766236f9c 156 | 157 | import OpenEXR 158 | 159 | # load EXR and 
extract shape 160 | exr = OpenEXR.InputFile(exr_path) 161 | print(exr.header()) 162 | dw = exr.header()["dataWindow"] 163 | shape = (dw.max.y - dw.min.y + 1, dw.max.x - dw.min.x + 1) 164 | # 165 | arr_maps = {} 166 | for ch_name, ch in exr.header()["channels"].items(): 167 | print("reading channel", ch_name) 168 | # This, and __str__ seem to be the only ways to get typename 169 | exr_typename = ch.type.names[ch.type.v] 170 | np_type = typemap[exr_typename] 171 | # convert channel to np array 172 | bytestring = exr.channel(ch_name, ch.type) 173 | arr = np.frombuffer(bytestring, dtype=np_type).reshape(shape) 174 | arr_maps[ch_name] = arr 175 | 176 | return arr_maps 177 | 178 | 179 | @register_node("RAFTLoadFlowFromEXRChannels", "RAFT Load Flow from EXR Channels") 180 | class _: 181 | """ 182 | This is a utility function for loading motion flows from an EXR image file. 183 | This is intended for use with Blender's vector pass in the Cycles renderer. 184 | 185 | In Blender, enable the vector pass. In the compositor, use "Separate Color" to 186 | extract the "Blue" and "Alpha" channels of the vector pass. Then, combine them 187 | using "Combine Color" to two of the RGB channels. Finally, render to the "OpenEXR" 188 | format. 189 | 190 | https://gist.github.com/andres-fr/4ddbb300d418ed65951ce88766236f9c 191 | """ 192 | 193 | CATEGORY = "jamesWalker55" 194 | INPUT_TYPES = lambda: { 195 | "required": { 196 | "path": ("STRING", {"default": ""}), 197 | "x_channel": (("R", "G", "B", "A"), {"default": "R"}), 198 | "y_channel": (("R", "G", "B", "A"), {"default": "G"}), 199 | "invert_x": (("false", "true"), {"default": "true"}), 200 | "invert_y": (("false", "true"), {"default": "false"}), 201 | } 202 | } 203 | RETURN_TYPES = ("RAFT_FLOW",) 204 | FUNCTION = "execute" 205 | 206 | def execute( 207 | self, path: str, x_channel: str, y_channel: str, invert_x: str, invert_y: str 208 | ): 209 | assert isinstance(path, str) 210 | assert x_channel in ("R", "G", "B", "A") 211 | assert y_channel in ("R", "G", "B", "A") 212 | assert invert_x in ("true", "false") 213 | assert invert_y in ("true", "false") 214 | 215 | invert_x: bool = invert_x == "true" 216 | invert_y: bool = invert_y == "true" 217 | 218 | maps = depth_exr_to_numpy(path) 219 | 220 | x = torch.from_numpy(maps[x_channel]) 221 | y = torch.from_numpy(maps[y_channel]) 222 | 223 | if invert_x: 224 | x = x * -1 225 | 226 | if invert_y: 227 | y = y * -1 228 | 229 | return (torch.stack((x, y)).unsqueeze(0),) 230 | -------------------------------------------------------------------------------- /comfyui_rc.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import lzma 4 | from io import BytesIO 5 | 6 | import torch 7 | 8 | NODE_CLASS_MAPPINGS = {} 9 | NODE_DISPLAY_NAME_MAPPINGS = {} 10 | 11 | 12 | def register_node(identifier: str, display_name: str): 13 | def decorator(cls): 14 | NODE_CLASS_MAPPINGS[identifier] = cls 15 | NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name 16 | 17 | return cls 18 | 19 | return decorator 20 | 21 | 22 | def compress(x: bytes): 23 | comp = lzma.LZMACompressor() 24 | out = comp.compress(x) 25 | return out + comp.flush() 26 | 27 | 28 | def decompress(x: bytes): 29 | decomp = lzma.LZMADecompressor() 30 | return decomp.decompress(x) 31 | 32 | 33 | def base85_encode(x: bytes): 34 | return base64.b85encode(x) 35 | 36 | 37 | def base85_decode(x: bytes): 38 | return base64.b85decode(x) 39 | 40 | 41 | def torch_save_to_bytes(obj): 42 | with BytesIO() as f: 43 | 
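# torch_save_to_blob()/torch_load_from_blob() below chain this serialisation
# with lzma and base85 so a tensor can round-trip through a plain text field:
#
#     blob = torch_save_to_blob(torch.zeros((1, 4, 64, 64)))  # ASCII-safe bytes
#     restored = torch_load_from_blob(blob)
#     assert restored.shape == (1, 4, 64, 64)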
torch.save(obj, f) 44 | return f.getvalue() 45 | 46 | 47 | def torch_load_from_bytes(text: bytes): 48 | with BytesIO(text) as f: 49 | return torch.load(f) 50 | 51 | 52 | def torch_save_to_blob(obj): 53 | return base85_encode(compress(torch_save_to_bytes(obj))) 54 | 55 | 56 | def torch_load_from_blob(text: bytes): 57 | return torch_load_from_bytes(decompress(base85_decode(text))) 58 | 59 | 60 | @register_node("RCReceiveLatent", "Remote Call: Receive Latent") 61 | class _: 62 | CATEGORY = "jamesWalker55/rc" 63 | INPUT_TYPES = lambda: { 64 | "required": { 65 | "key": ( 66 | "STRING", 67 | {"default": "input_latent", "multiline": False}, 68 | ), 69 | "value": ( 70 | "STRING", 71 | {"default": "Don't touch this field!", "multiline": False}, 72 | ), 73 | } 74 | } 75 | RETURN_TYPES = ("LATENT",) 76 | FUNCTION = "execute" 77 | 78 | def execute(self, key: str, value: str): 79 | latent = torch_load_from_blob(value) 80 | val = {"samples": latent} 81 | # { "samples": } 82 | return (val,) 83 | 84 | 85 | @register_node("RCReceiveInt", "Remote Call: Receive Integer") 86 | class _: 87 | CATEGORY = "jamesWalker55/rc" 88 | INPUT_TYPES = lambda: { 89 | "required": { 90 | "key": ( 91 | "STRING", 92 | {"default": "input_integer", "multiline": False}, 93 | ), 94 | "value": ("INT", {"default": 0, "min": -99999999999, "max": 99999999999}), 95 | } 96 | } 97 | RETURN_TYPES = ("INT",) 98 | FUNCTION = "execute" 99 | 100 | def execute(self, key: str, value): 101 | return (value,) 102 | 103 | 104 | @register_node("RCReceiveFloat", "Remote Call: Receive Float") 105 | class _: 106 | CATEGORY = "jamesWalker55/rc" 107 | INPUT_TYPES = lambda: { 108 | "required": { 109 | "key": ( 110 | "STRING", 111 | {"default": "input_float", "multiline": False}, 112 | ), 113 | "value": ("FLOAT", {"default": 0, "min": -99999999999, "max": 99999999999}), 114 | } 115 | } 116 | RETURN_TYPES = ("FLOAT",) 117 | FUNCTION = "execute" 118 | 119 | def execute(self, key: str, value): 120 | return (value,) 121 | 122 | 123 | @register_node("RCReceiveIntList", "Remote Call: Receive Integer List") 124 | class _: 125 | CATEGORY = "jamesWalker55/rc" 126 | INPUT_TYPES = lambda: { 127 | "required": { 128 | "key": ( 129 | "STRING", 130 | {"default": "input_integer_list", "multiline": False}, 131 | ), 132 | "value": ( 133 | "STRING", 134 | {"default": "[1, 2, 3]", "multiline": False}, 135 | ), 136 | } 137 | } 138 | RETURN_TYPES = ("INT_LIST",) 139 | FUNCTION = "execute" 140 | 141 | def execute(self, key: str, value): 142 | value = json.loads(value) 143 | return (value,) 144 | 145 | 146 | @register_node("RCReceiveFloatList", "Remote Call: Receive Float List") 147 | class _: 148 | CATEGORY = "jamesWalker55/rc" 149 | INPUT_TYPES = lambda: { 150 | "required": { 151 | "key": ( 152 | "STRING", 153 | {"default": "input_float_list", "multiline": False}, 154 | ), 155 | "value": ( 156 | "STRING", 157 | {"default": "[1.0, 2.0, 3.0]", "multiline": False}, 158 | ), 159 | } 160 | } 161 | RETURN_TYPES = ("FLOAT_LIST",) 162 | FUNCTION = "execute" 163 | 164 | def execute(self, key: str, value): 165 | value = json.loads(value) 166 | return (value,) 167 | 168 | 169 | @register_node("RCSendLatent", "Remote Call: Send Latent") 170 | class _: 171 | CATEGORY = "jamesWalker55/rc" 172 | INPUT_TYPES = lambda: { 173 | "required": { 174 | "key": ( 175 | "STRING", 176 | {"default": "input_latent", "multiline": False}, 177 | ), 178 | "latent": ("LATENT",), 179 | } 180 | } 181 | FUNCTION = "execute" 182 | RETURN_TYPES = () 183 | OUTPUT_NODE = True 184 | 185 | def execute(self, key: str, 
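# The "Remote Call" nodes share a key/value convention: an external caller
# presumably finds nodes in the workflow JSON by their "key" widget and
# overwrites the "value" widget before queueing the prompt, e.g. (caller-side
# sketch; the exact JSON layout is an assumption):
#
#     import json
#     node["widgets_values"] = ["input_integer_list", json.dumps([0, 2, 4])]
#
# execute() below then returns the latent blob through its "ui" result so the
# caller can read it back from the execution outputs.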


def load_audio(
    path: str | Path,
    sr: None | int | float = None,
    offset: float = 0.0,
    duration: float | None = None,
    make_stereo: bool = True,
) -> Audio:
    import librosa

    mix, sr = librosa.load(path, sr=sr, mono=False, offset=offset, duration=duration)
    mix = torch.from_numpy(mix)

    # If stereo, shape will be:
    # torch.Size([2, 1403182])
    # If mono, shape will be:
    # torch.Size([1403182])
    #
    # Ensure shape is [channels, data]
    if len(mix.shape) == 1:
        mix = torch.stack([mix], dim=0)
    assert len(mix.shape) == 2

    # Convert mono to stereo if needed
    if make_stereo:
        if mix.shape[0] == 1:
            mix = torch.cat([mix, mix], dim=0)
        elif mix.shape[0] == 2:
            pass
        else:
            raise ValueError(
                f"Input audio has {mix.shape[0]} channels, cannot convert to stereo (2 channels)"
            )

    # Add extra dimension for batch size

    # shape => torch.Size([2, 1403182])
    mix = torch.unsqueeze(mix, 0)
    # shape => torch.Size([1, 2, 1403182])

    return {
        "sample_rate": round(sr),
        "waveform": mix,
    }


def save_audio(path: str | Path, mix: torch.Tensor, sr):
    path = str(path)

    # make sure tensor has shape [channels, data]
    if len(mix.shape) == 3:
        if mix.shape[0] > 1:
            raise ValueError("Audio batch size is more than 1")
        mix = mix[0]
    elif len(mix.shape) == 2:
        pass
    elif len(mix.shape) == 1:
        mix = torch.unsqueeze(mix, 0)
    else:
        raise ValueError(f"Invalid tensor shape: {mix.shape}")

    subtype = "FLOAT" if path.lower().endswith("wav") else None
    sf.write(path, mix.T, sr, subtype=subtype)


def write_audio_comment(path: str | Path, comment: str):
    try:
        from mediafile import MediaFile
    except ImportError:
        print(
            "[WARN] Failed to import `mediafile`, saved audio files will not have metadata"
        )
        return

    f = MediaFile(path)
    f.comments = comment
    f.save()


@register_node("JWLoadAudio", "Audio Load")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "path": ("STRING", {"default": "./audio.mp3"}),
            "gain_db": ("FLOAT", {"default": 0, "min": -100, "max": 100}),
            "offset_seconds": ("FLOAT", {"default": 0, "min": 0, "max": FLOAT_MAX}),
            "duration_seconds": ("FLOAT", {"default": 0, "min": 0, "max": FLOAT_MAX}),
            "resample_to_hz": ("FLOAT", {"default": 0, "min": 0, "max": FLOAT_MAX}),
            "make_stereo": ("BOOLEAN", {"default": True}),
        }
    }
    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "execute"

    def execute(
        self,
        path: str,
        gain_db: float,
        offset_seconds: float,
        duration_seconds: float,
        resample_to_hz: float,
        make_stereo: bool,
    ) -> tuple[Audio]:
        # a duration / resample value of 0 means "load everything" / "keep the
        # native sample rate"
        rv = load_audio(
            path,
            sr=resample_to_hz if resample_to_hz > 0 else None,
            offset=offset_seconds,
            duration=duration_seconds if duration_seconds > 0 else None,
            make_stereo=make_stereo,
        )
        if gain_db != 0.0:
            gain_scalar = db_to_scalar(gain_db)
            rv["waveform"] = gain_scalar * rv["waveform"]

        return (rv,)

    @classmethod
    def IS_CHANGED(
        cls,
        path: str,
        *args,
    ):
        # include the file's mtime so ComfyUI re-runs this node when the audio
        # file changes on disk
        if os.path.exists(path):
            mtime = os.path.getmtime(path)
        else:
            mtime = None

        return (mtime, path, *args)


@register_node("JWAudioBlend", "Audio Blend")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "a": ("AUDIO",),
            "b": ("AUDIO",),
            "ratio": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0}),
            "if_durations_differ": (
                ("use_longest", "use_shortest"),
                {"default": "use_longest"},
            ),
            "if_samplerates_differ": (
                ("use_highest", "use_lowest"),
                {"default": "use_highest"},
            ),
        }
    }
    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "execute"

    def execute(
        self,
        a: Audio,
        b: Audio,
        ratio: float,
        if_durations_differ: Literal["use_longest", "use_shortest"],
        if_samplerates_differ: Literal["use_highest", "use_lowest"],
    ) -> tuple[Audio]:
        import librosa

        # shallow clone audios
        a = {**a}
        b = {**b}

        # # if they have different batch size, attempt to resolve them
        # if a["waveform"].shape[0] != b["waveform"].shape[0]:
        #     pass

        # if they have different channels, attempt to resolve them
        if a["waveform"].shape[1] != b["waveform"].shape[1]:
            # if one of them is mono, distribute it across the other's channels
            # (waveforms have shape [batch, channels, samples])
            if a["waveform"].shape[1] == 1:
                a["waveform"] = a["waveform"].expand(-1, b["waveform"].shape[1], -1)
            elif b["waveform"].shape[1] == 1:
                b["waveform"] = b["waveform"].expand(-1, a["waveform"].shape[1], -1)

        # ensure audio has same sample rate; default to a's rate when they
        # already match
        sr = a["sample_rate"]
        if a["sample_rate"] != b["sample_rate"]:
            # determine which rate to use
            if if_samplerates_differ == "use_highest":
                sr = max(a["sample_rate"], b["sample_rate"])
            elif if_samplerates_differ == "use_lowest":
                sr = min(a["sample_rate"], b["sample_rate"])
            else:
                raise NotImplementedError(if_samplerates_differ)

            # do the resampling
            if a["sample_rate"] != sr:
                a["waveform"] = torchaudio.functional.resample(
                    a["waveform"], a["sample_rate"], sr
                )
            if b["sample_rate"] != sr:
                b["waveform"] = torchaudio.functional.resample(
                    b["waveform"], b["sample_rate"], sr
                )

        # ensure input has same length
        if a["waveform"].shape[-1] != b["waveform"].shape[-1]:
            # determine which duration to use
            if if_durations_differ == "use_longest":
                duration = max(a["waveform"].shape[-1], b["waveform"].shape[-1])
            elif if_durations_differ == "use_shortest":
                duration = min(a["waveform"].shape[-1], b["waveform"].shape[-1])
            else:
                raise NotImplementedError(if_durations_differ)

            def waveform_with_duration(wave: torch.Tensor, new_duration: int):
                batch, channels, original_duration = wave.shape
                if original_duration >= new_duration:
                    # truncate to the new duration
                    return wave[:, :, :new_duration]
                else:
                    # pad with silence up to the new duration
                    rv = torch.zeros(batch, channels, new_duration)
                    rv[:, :, :original_duration] = wave
                    return rv

            # do the chopping
            if a["waveform"].shape[-1] != duration:
                a["waveform"] = waveform_with_duration(a["waveform"], duration)
            if b["waveform"].shape[-1] != duration:
                b["waveform"] = waveform_with_duration(b["waveform"], duration)

        rv: Audio = {
            "sample_rate": sr,
            "waveform": a["waveform"] * (1.0 - ratio) + b["waveform"] * ratio,
        }

        return (rv,)
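

# A minimal sketch (illustration only, never called) of the linear crossfade
# used by JWAudioBlend above: ratio 0.0 keeps `a`, ratio 1.0 keeps `b`.
def _blend_example():
    a = torch.ones(1, 2, 4)
    b = torch.zeros(1, 2, 4)
    mixed = a * (1.0 - 0.25) + b * 0.25
    assert torch.allclose(mixed, torch.full((1, 2, 4), 0.75))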


class ResultItem(TypedDict):
    filename: str
    subfolder: str
    type: Literal["output"]


@register_node("JWAudioSaveToPath", "Audio Save to Path")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "audio": ("AUDIO",),
            "path": ("STRING", {"default": "./audio.mp3"}),
            "overwrite": ("BOOLEAN", {"default": True}),
        },
        "hidden": {
            "prompt": "PROMPT",
            "extra_pnginfo": "EXTRA_PNGINFO",
        },
    }
    RETURN_TYPES = ()
    FUNCTION = "execute"
    OUTPUT_NODE = True

    def execute(
        self,
        path: str | Path,
        audio: Audio,
        overwrite: bool,
        prompt=None,
        extra_pnginfo=None,
    ):
        path = Path(path)

        path.parent.mkdir(parents=True, exist_ok=True)

        # store the workflow/prompt as a JSON comment in the audio metadata
        metadata = {**(extra_pnginfo or {})}
        if prompt is not None:
            metadata["prompt"] = prompt
        metadata_str = json.dumps(metadata)

        results: list[ResultItem] = []

        if audio["waveform"].shape[0] == 1:
            # batch has 1 audio clip only
            if overwrite or not path.exists():
                save_audio(
                    path,
                    audio["waveform"][0],
                    audio["sample_rate"],
                )
                write_audio_comment(path, metadata_str)
                results.append(
                    {
                        "filename": path.name,
                        "subfolder": str(path.parent),
                        "type": "output",
                    }
                )
        else:
            # batch has multiple audio clips; append "-{i}" to each filename
            for i, subwaveform in enumerate(audio["waveform"]):
                subpath = path.with_stem(f"{path.stem}-{i}")
                if overwrite or not subpath.exists():
                    save_audio(
                        subpath,
                        subwaveform,
                        audio["sample_rate"],
                    )
                    write_audio_comment(subpath, metadata_str)
                    results.append(
                        {
                            "filename": subpath.name,
                            "subfolder": str(subpath.parent),
                            "type": "output",
                        }
                    )

        return {"ui": {"audio": results}}
--------------------------------------------------------------------------------
/comfyui_string_list.py:
--------------------------------------------------------------------------------
import torch


NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}


def register_node(identifier: str, display_name: str):
    def decorator(cls):
        NODE_CLASS_MAPPINGS[identifier] = cls
        NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name

        return cls

    return decorator


@register_node("JWStringListFromString", "String List From String")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "value": ("STRING", {"default": "", "multiline": False}),
        }
    }
    RETURN_TYPES = ("STRING_LIST",)
    FUNCTION = "execute"

    def execute(self, value: str):
        val = [value]
        return (val,)


@register_node("JWStringListFromStrings", "String List From Strings")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "a": ("STRING", {"default": "", "multiline": False}),
            "b": ("STRING", {"default": "", "multiline": False}),
        }
    }
    RETURN_TYPES = ("STRING_LIST",)
    FUNCTION = "execute"

    def execute(self, a: str, b: str):
        val = [a, b]
        return (val,)


@register_node("JWStringListJoin", "Join String List")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "a": ("STRING_LIST",),
            "b": ("STRING_LIST",),
        }
    }
    RETURN_TYPES = ("STRING_LIST",)
    FUNCTION = "execute"

    def execute(self, a: list[str], b: list[str]):
        val = a + b
        return (val,)


@register_node("JWStringListRepeat", "Repeat String List")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "string_list": ("STRING_LIST",),
            "repeats": ("INT", {"default": 1, "min": 0}),
        }
    }
    RETURN_TYPES = ("STRING_LIST",)
    FUNCTION = "execute"

    def execute(self, string_list: list[str], repeats: int):
        val = string_list * repeats
        return (val,)


@register_node("JWStringListToString", "String List To String")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "string_list": ("STRING_LIST",),
            "join": (
                "STRING",
                {"default": "\n", "multiline": True, "dynamicPrompts": False},
            ),
        }
    }
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, string_list: list[str], join: str):
        val = join.join(string_list)
        return (val,)


@register_node("JWStringListToFormatedString", "String List To Formatted String")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "string_list": ("STRING_LIST",),
            "template": (
                "STRING",
                {"default": "{}, {}, {}", "multiline": True, "dynamicPrompts": False},
            ),
        }
    }
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, string_list: list[str], template: str):
        # fill the template's "{}" placeholders positionally from the list
        val = template.format(*string_list)
        return (val,)


@register_node("JWStringListCLIPEncode", "String List CLIP Encode")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "string_list": ("STRING_LIST",),
            "clip": ("CLIP",),
        }
    }
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "execute"

    def execute(self, string_list: list[str], clip):
        all_cond = []
        all_pooled = []

        for text in string_list:
            tokens = clip.tokenize(text)
            cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
            # cond.shape => torch.Size([1, 77, 768])
            # pooled.shape => torch.Size([1, 768])
            all_cond.append(cond)
            all_pooled.append(pooled)

        all_cond = torch.cat(all_cond, dim=0)
        all_pooled = torch.cat(all_pooled, dim=0)
        return ([[all_cond, {"pooled_output": all_pooled}]],)
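

# A minimal sketch (illustration only, never called) of what the
# formatted-string node above computes: positional placeholder filling.
def _format_example():
    assert "{}, {}, {}".format(*["a", "b", "c"]) == "a, b, c"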
--------------------------------------------------------------------------------
/comfyui_uncrop.py:
--------------------------------------------------------------------------------
from typing import NamedTuple

import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}


def register_node(identifier: str, display_name: str):
    def decorator(cls):
        NODE_CLASS_MAPPINGS[identifier] = cls
        NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name

        return cls

    return decorator


MAX_RESOLUTION = 8192


def validate_bounds(img: torch.Tensor, x: int, y: int, w: int, h: int):
    _, img_h, img_w, _ = img.shape

    assert x >= 0
    assert y >= 0

    assert (
        x + w <= img_w
    ), f"crop region out of bounds: crop {(x, y, w, h)} from image {(img_w, img_h)}"
    assert (
        y + h <= img_h
    ), f"crop region out of bounds: crop {(x, y, w, h)} from image {(img_w, img_h)}"


def crop_image(img: torch.Tensor, x: int, y: int, w: int, h: int):
    validate_bounds(img, x, y, w, h)

    to_x = x + w
    to_y = y + h
    return img[:, y:to_y, x:to_x, :]


def resize_image(img: torch.Tensor, w: int, h: int):
    img = img.permute(0, 3, 1, 2)
    img = F.resize(
        img,
        (h, w),  # type: ignore
        interpolation=InterpolationMode.BILINEAR,
        antialias=True,
    )
    img = img.permute(0, 2, 3, 1)
    return img


class CropRect(NamedTuple):
    x: int
    y: int
    width: int
    height: int


@register_node("JWUncropNewRect", "Uncrop: New rect")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            "width": (
                "INT",
                {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1},
            ),
            "height": (
                "INT",
                {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1},
            ),
        }
    }
    RETURN_TYPES = ("CROP_RECT",)
    FUNCTION = "execute"

    def execute(
        self,
        x: int,
        y: int,
        width: int,
        height: int,
    ) -> tuple[CropRect]:
        return (CropRect(x, y, width, height),)


@register_node("JWUncropCrop", "Uncrop: Crop")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "image": ("IMAGE",),
            "resize_length": ("INT", {"default": 512, "min": 8, "step": 8}),
            "crop_rect": ("CROP_RECT",),
        }
    }
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    def execute(
        self,
        image: torch.Tensor,
        resize_length: int,
        crop_rect: CropRect,
    ) -> tuple[torch.Tensor]:
        x, y, width, height = crop_rect

        # crop the image
        image = crop_image(image, x, y, width, height)

        # scale the crop so its shorter side equals resize_length, snapping
        # both sides to multiples of 8
        shortest_side = min(width, height)
        scale_ratio = resize_length / shortest_side
        new_width = round(width * scale_ratio / 8) * 8
        new_height = round(height * scale_ratio / 8) * 8

        image = resize_image(image, new_width, new_height)

        return (image,)
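

# Worked example (illustration only) of the rounding in JWUncropCrop above,
# assuming a hypothetical crop_rect of width=300, height=500 and
# resize_length=512: the shorter side (300) scales up to 512, and both sides
# snap to multiples of 8:
#   scale_ratio = 512 / 300                        # ~1.7067
#   new_width = round(300 * scale_ratio / 8) * 8   # = 512
#   new_height = round(500 * scale_ratio / 8) * 8  # = 856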


@register_node("JWUncropUncrop", "Uncrop: Uncrop")
class _:
    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "original_image": ("IMAGE",),
            "cropped_image": ("IMAGE",),
            "cropped_mask": ("MASK",),
            "crop_rect": ("CROP_RECT",),
        }
    }
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    def execute(
        self,
        original_image: torch.Tensor,
        cropped_image: torch.Tensor,
        cropped_mask: torch.Tensor,
        crop_rect: CropRect,
    ) -> tuple[torch.Tensor]:
        x, y, width, height = crop_rect

        validate_bounds(original_image, x, y, width, height)

        # resize cropped image if needed
        _, _h, _w, _ = cropped_image.shape
        if _w != width or _h != height:
            cropped_image = resize_image(cropped_image, width, height)

        # resize cropped mask if needed
        _h, _w = cropped_mask.shape[-2:]
        if _w != width or _h != height:
            cropped_mask = torch.reshape(cropped_mask, (1, _h, _w, 1))
            cropped_mask = resize_image(cropped_mask, width, height)
            cropped_mask = torch.reshape(cropped_mask, (height, width))

        to_x = x + width
        to_y = y + height

        # https://easings.net/#easeOutQuint
        weighted_mask = 1 - (1 - cropped_mask) ** 5

        # blend original image with cropped image using mask
        cropped_image = original_image[:, y:to_y, x:to_x, :] * (
            1 - weighted_mask.view(1, *weighted_mask.shape, 1)
        ) + cropped_image * weighted_mask.view(1, *weighted_mask.shape, 1)

        # paste cropped image into original image
        original_image = original_image.clone()
        original_image[:, y:to_y, x:to_x, :] = cropped_image

        return (original_image,)
--------------------------------------------------------------------------------