├── .github └── workflows │ └── publish.yml ├── .gitignore ├── .prettierrc ├── README.md ├── __init__.py ├── config.py ├── modules ├── animatediff │ └── __init__.py ├── controlnet │ ├── __init__.py │ ├── advanced.py │ └── preprocessor.py ├── fooocus │ ├── __init__.py │ ├── anisotropic.py │ ├── efficient.py │ └── patch.py ├── image_utils.py ├── impact │ ├── __init__.py │ └── facedetailer.py ├── inpaint │ ├── __init__.py │ ├── lama │ │ └── __init__.py │ ├── nodes.py │ └── sam │ │ └── nodes.py ├── interrogate │ ├── .gitignore │ ├── __init__.py │ ├── blip_node.py │ ├── configs │ │ └── med_config.json │ ├── danbooru.py │ ├── models │ │ ├── __init__.py │ │ ├── blip.py │ │ ├── blip_itm.py │ │ ├── blip_nlvr.py │ │ ├── blip_pretrain.py │ │ ├── blip_retrieval.py │ │ ├── blip_vqa.py │ │ ├── deepbooru_model.py │ │ ├── med.py │ │ ├── nlvr_encoder.py │ │ └── vit.py │ └── transform │ │ └── randaugment.py ├── ip_adapter_nodes.py ├── isnet │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── isnet.py │ │ └── isnet_dis.py │ └── segmenter.py ├── llm │ ├── __init__.py │ └── chat.py ├── logger.py ├── masking.py ├── model_utils.py ├── nodes.py ├── postprocessing │ ├── __init__.py │ ├── color_blend.py │ └── color_correct.py ├── sdxl_prompt_styler │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── sdxl_prompt_styler.py │ ├── sdxl_styles_base.json │ ├── sdxl_styles_sai.json │ └── sdxl_styles_twri.json ├── utility_nodes.py ├── utils.py └── video │ └── __init__.py ├── pyproject.toml ├── requirements.txt └── web ├── text-switch-case.js ├── upload.js └── utils.js /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | jobs: 11 | publish-node: 12 | name: Publish Custom Node to registry 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Check out code 16 | uses: actions/checkout@v4 17 | - name: Publish Custom Node 18 | uses: Comfy-Org/publish-node-action@main 19 | with: 20 | ## Add your own personal access token to your Github Repository secrets and reference it here. 
21 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.ckpt 3 | *.safetensors 4 | *.pth 5 | /ESRGAN/* 6 | /SwinIR/* 7 | /repositories 8 | /venv 9 | /tmp 10 | /model.ckpt 11 | /models/**/* 12 | /GFPGANv1.3.pth 13 | /gfpgan/weights/*.pth 14 | /ui-config.json 15 | /outputs 16 | /config.json 17 | /log 18 | /webui.settings.bat 19 | /embeddings 20 | /styles.csv 21 | /params.txt 22 | /styles.csv.bak 23 | /webui-user.bat 24 | /webui-user.sh 25 | /interrogate 26 | /user.css 27 | /.idea 28 | notification.mp3 29 | /SwinIR 30 | /textual_inversion 31 | .vscode 32 | /extensions 33 | /test/stdout.txt 34 | /test/stderr.txt 35 | /cache.json 36 | *.sql 37 | *.db 38 | *.sqlite 39 | *.sqlite3 40 | tailwind.* 41 | workflow.json 42 | workflow-*.json 43 | -------------------------------------------------------------------------------- /.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "singleQuote": true, 3 | "jsxSingleQuote": false, 4 | "arrowParens": "always", 5 | "trailingComma": "all", 6 | "semi": true, 7 | "tabWidth": 2, 8 | "printWidth": 100 9 | } 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ArtVenture Custom Nodes 2 | 3 | A comprehensive set of custom nodes for ComfyUI, focusing on utilities for image processing, JSON manipulation, model operations and working with object via URLs 4 | 5 | ### Image Nodes 6 | 7 | #### LoadImageFromUrl 8 | 9 | Loads images from URLs. 10 | 11 | **Inputs:** 12 | 13 | - `image`: List of URLs or base64 image data, separated by new lines 14 | - `keep_alpha_channel`: Preserve alpha channel 15 | - `output_mode`: List or batch output. Use `List` if you have different resolutions. 16 | 17 | ![load image from url](https://github.com/user-attachments/assets/9da4840c-925e-4e0c-984a-5412282aee79) 18 | 19 | ### JSON Nodes 20 | 21 | #### LoadJsonFromUrl 22 | 23 | Loads JSON data from URLs. 24 | 25 | **Inputs:** 26 | 27 | - `url`: JSON URL 28 | - `print_to_console`: Print JSON to console 29 | 30 | #### LoadJsonFromText 31 | 32 | Loads JSON data from text. 33 | 34 | **Inputs:** 35 | 36 | - `data`: JSON text 37 | - `print_to_console`: Print JSON to console 38 | 39 | #### Get<\*>FromJson 40 | 41 | Includes `GetObjectFromJson`, `GetTextFromJson`, `GetFloatFromJson`, `GetIntFromJson`, `GetBoolFromJson`. 42 | 43 | Use key format `key.[index].subkey.[sub_index]` to access nested objects. 44 | 45 | ![get data from json](https://github.com/user-attachments/assets/a71793d6-9661-441c-a15c-66b2dcaa7972) 46 | 47 | ### Utility Nodes 48 | 49 | #### StringToNumber 50 | 51 | Converts strings to numbers. 52 | 53 | **Inputs:** 54 | 55 | - `string`: Input string 56 | - `rounding`: Rounding method 57 | 58 | #### TextRandomMultiline 59 | 60 | Randomizes the order of lines in a multiline string. 61 | 62 | **Inputs:** 63 | 64 | - `text`: Input text 65 | - `amount`: Number of lines to randomize 66 | - `seed`: Random seed 67 | 68 | ![text random multiline](https://github.com/user-attachments/assets/86f811e3-579e-4ccc-81a3-e216cd851d3c) 69 | 70 | #### TextSwitchCase 71 | 72 | Switch between multiple cases based on a condition. 
73 | 
74 | **Inputs:**
75 | 
76 | - `switch_cases`: Switch cases, separated by new lines
77 | - `condition`: Condition to switch on
78 | - `default_value`: Default value when no condition matches
79 | - `delimiter`: Delimiter between case and value, default is `:`
80 | 
81 | Each `switch_cases` entry has the format `case<delimiter>value` (e.g. `case:value` with the default `:` delimiter), where `case` is the condition to match and `value` is the value to return when the condition matches. The value can contain new lines to return multiple lines.
82 | 
83 | ![text switch case](https://github.com/user-attachments/assets/4c5450a8-6a3a-4d3c-8c2a-c6e3a33cb95f)
84 | 
85 | ### Inpainting Nodes
86 | 
87 | #### PrepareImageAndMaskForInpaint
88 | 
89 | Prepares images and masks for inpainting operations, mimicking the inpainting behavior of A1111.
90 | 
91 | **Inputs:**
92 | 
93 | - `image`: Input image tensor
94 | - `mask`: Input mask tensor
95 | - `mask_blur`: Blur amount for mask (0-64)
96 | - `inpaint_masked`: Whether to inpaint only the masked regions; otherwise the whole image is inpainted.
97 | - `mask_padding`: Padding around mask (0-256)
98 | - `width`: Manually set the inpaint area width. Leave at 0 to default to the masked area plus padding. (0-2048)
99 | - `height`: Manually set the inpaint area height. (0-2048)
100 | 
101 | **Outputs:**
102 | 
103 | - `inpaint_image`: Processed image for inpainting
104 | - `inpaint_mask`: Processed mask
105 | - `overlay_image`: Preview overlay
106 | - `crop_region`: Crop coordinates (used as input for OverlayInpaintedImage)
107 | 
108 | ![inpainting prepare](https://github.com/user-attachments/assets/38e87c04-7a64-4a62-a462-054396b3de14)
109 | 
110 | #### OverlayInpaintedImage
111 | 
112 | Overlays inpainted images onto the original images.
113 | 
114 | **Inputs:**
115 | 
116 | - `inpainted`: Inpainted image
117 | - `overlay_image`: Original image
118 | - `crop_region`: Crop region coordinates
119 | 
120 | **Outputs:**
121 | 
122 | - `IMAGE`: Final composited image
123 | 
124 | #### LaMaInpaint
125 | 
126 | Removes objects from images using the LaMa model.
127 | 
128 | ![lama remove object](https://github.com/user-attachments/assets/c28bbd8b-d55f-4fa5-bbc9-ace267382bd0)
129 | 
130 | ### LLM Nodes
131 | 
132 | #### LLMApiConfig
133 | 
134 | Configures generic LLM API parameters.
135 | 
136 | **Inputs:**
137 | 
138 | - `model`: Model name (e.g. GPT-3.5, GPT-4)
139 | - `max_token`: Maximum number of tokens
140 | - `temperature`: Temperature parameter
141 | 
142 | #### OpenAIApi
143 | 
144 | Configures OpenAI API access.
145 | 
146 | **Inputs:**
147 | 
148 | - `openai_api_key`: OpenAI API key
149 | - `endpoint`: API endpoint URL
150 | 
151 | ### Claude API Nodes
152 | 
153 | #### ClaudeApi
154 | 
155 | Configures Anthropic Claude API access.
156 | 
157 | **Inputs:**
158 | 
159 | - `claude_api_key`: Claude API key
160 | - `endpoint`: API endpoint
161 | - `version`: API version
162 | 
163 | #### AwsBedrockClaudeApi
164 | 
165 | Configures AWS Bedrock Claude API access.
166 | 
167 | **Inputs:**
168 | 
169 | - `aws_access_key_id`: AWS access key
170 | - `aws_secret_access_key`: AWS secret key
171 | - `region`: AWS region
172 | - `version`: API version
173 | 
174 | #### AwsBedrockMistralApi
175 | 
176 | Configures AWS Bedrock Mistral API access.
177 | 
178 | **Inputs:**
179 | 
180 | - `aws_access_key_id`: AWS access key
181 | - `aws_secret_access_key`: AWS secret key
182 | - `region`: AWS region
183 | 
184 | #### LLMMessage
185 | 
186 | Creates a message for LLM interaction.
187 | 188 | **Inputs:** 189 | 190 | - `role`: Message role (system/user/assistant) 191 | - `text`: Message content 192 | - `image`: Optional image input 193 | - `messages`: Previous message history 194 | 195 | #### LLMChat 196 | 197 | Handles chat interactions with LLMs. 198 | 199 | **Inputs:** 200 | 201 | - `messages`: Message history 202 | - `api`: LLM API configuration 203 | - `config`: Model configuration 204 | - `seed`: Random seed 205 | 206 | #### LLMCompletion 207 | 208 | Handles completion requests to LLMs. 209 | 210 | **Inputs:** 211 | 212 | - `prompt`: Input prompt 213 | - `api`: LLM API configuration 214 | - `config`: Model configuration 215 | - `seed`: Random seed 216 | 217 | ![Screenshot 2024-10-30 at 11 20 12](https://github.com/user-attachments/assets/45b8d4fd-57cd-4bd9-8274-d3e6ac4ef938) 218 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import folder_paths 3 | 4 | if not 'saved_prompts' in folder_paths.folder_names_and_paths: 5 | folder_paths.folder_names_and_paths['saved_prompts'] = ([], set(['.txt'])) 6 | 7 | custom_nodes = folder_paths.get_folder_paths("custom_nodes") 8 | for dir in custom_nodes: 9 | if dir not in sys.path: 10 | print("Adding", dir, "to sys.path") 11 | sys.path.append(dir) 12 | 13 | from .modules.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 14 | 15 | WEB_DIRECTORY = "./web" 16 | 17 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"] 18 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import inspect 4 | from typing import Dict 5 | 6 | from server import PromptServer 7 | 8 | from .modules.logger import logger 9 | 10 | comfy_dir = os.path.dirname(inspect.getfile(PromptServer)) 11 | ext_dir = os.path.dirname(os.path.realpath(__file__)) 12 | config_path = os.path.join(ext_dir, "config.json") 13 | 14 | 15 | def __get_dir(root: str, subpath=None, mkdir=False): 16 | dir = root 17 | if subpath is not None: 18 | dir = os.path.join(dir, subpath) 19 | 20 | dir = os.path.abspath(dir) 21 | 22 | if mkdir and not os.path.exists(dir): 23 | os.makedirs(dir) 24 | return dir 25 | 26 | 27 | def get_ext_dir(subpath=None, mkdir=False): 28 | return __get_dir(ext_dir, subpath, mkdir) 29 | 30 | 31 | def get_comfy_dir(subpath=None, mkdir=False): 32 | return __get_dir(comfy_dir, subpath, mkdir) 33 | 34 | 35 | def write_config(config): 36 | with open(config_path, "w") as f: 37 | json.dump(config, f, indent=4) 38 | 39 | 40 | def load_config() -> Dict: 41 | default_config = { 42 | "av_endpoint": "https://api.artventure.ai", 43 | "av_token": "", 44 | "runner_enabled": False, 45 | "remove_runner_images_after_upload": False, 46 | } 47 | 48 | if not os.path.isfile(config_path): 49 | logger.info("Config file not found, creating...") 50 | write_config(default_config) 51 | 52 | with open(config_path, "r") as f: 53 | config = json.load(f) 54 | 55 | need_update = False 56 | for key, value in default_config.items(): 57 | if key not in config: 58 | config[key] = value 59 | need_update = True 60 | 61 | if need_update: 62 | write_config(config) 63 | 64 | logger.debug(f"Loaded config {config}") 65 | return config 66 | 67 | 68 | config = load_config() 69 | -------------------------------------------------------------------------------- 
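A minimal usage sketch of the config module above (assumptions: the snippet lives in a file under `modules/`, so the relative import goes one level up; `cache` is a hypothetical subfolder used only for illustration):

from ..config import config, get_ext_dir  # assumption: this file sits under modules/, next to logger.py

endpoint = config.get("av_endpoint")          # "https://api.artventure.ai" unless overridden in config.json
token = config.get("av_token", "")            # empty by default; set it in config.json to authenticate
cache_dir = get_ext_dir("cache", mkdir=True)  # hypothetical subfolder, created inside the extension directory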
/modules/animatediff/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable 3 | 4 | import folder_paths 5 | 6 | from ..utils import load_module 7 | 8 | custom_nodes = folder_paths.get_folder_paths("custom_nodes") 9 | animatediff_dir_names = ["AnimateDiff", "comfyui-animatediff"] 10 | 11 | NODE_CLASS_MAPPINGS = {} 12 | NODE_DISPLAY_NAME_MAPPINGS = {} 13 | 14 | try: 15 | module_path = None 16 | 17 | for custom_node in custom_nodes: 18 | custom_node = custom_node if not os.path.islink(custom_node) else os.readlink(custom_node) 19 | for module_dir in animatediff_dir_names: 20 | if module_dir in os.listdir(custom_node): 21 | module_path = os.path.abspath(os.path.join(custom_node, module_dir)) 22 | break 23 | 24 | if module_path is None: 25 | raise Exception("Could not find AnimateDiff nodes") 26 | 27 | module_path = os.path.join(module_path, "animatediff/sliding_schedule.py") 28 | module = load_module(module_path) 29 | print("Loaded AnimateDiff from", module_path) 30 | 31 | get_context_scheduler: Callable = module.get_context_scheduler 32 | 33 | class AnimateDiffFrameCalculator: 34 | @classmethod 35 | def INPUT_TYPES(s): 36 | return { 37 | "required": { 38 | "frame_rate": ("INT", {"default": 8, "min": 1, "max": 24, "step": 1}), 39 | "duration": ("INT", {"default": 2, "min": 1, "max": 10000, "step": 1}), 40 | "sliding_window": ("SLIDING_WINDOW_OPTS",), 41 | } 42 | } 43 | 44 | RETURN_TYPES = ("INT", "INT", "INT", "INT") 45 | RETURN_NAMES = ("frame_number", "_1/2-1_index", "_1/2_index", "end_index") 46 | FUNCTION = "calculate" 47 | CATEGORY = "Animate Diff" 48 | 49 | def get_batch_count(self, frame_number, context_scheduler, ctx): 50 | batches = list( 51 | context_scheduler( 52 | 0, 53 | 0, 54 | frame_number, 55 | ctx.context_length, 56 | ctx.context_stride, 57 | ctx.context_overlap, 58 | ctx.closed_loop, 59 | ) 60 | ) 61 | batch_count = len(batches) 62 | if len(batches[-1]) == 0: 63 | batch_count -= 1 64 | 65 | return batch_count 66 | 67 | def calculate(self, frame_rate: int, duration: int, sliding_window): 68 | frame_number = frame_rate * duration 69 | 70 | ctx = sliding_window 71 | context_scheduler = get_context_scheduler(ctx.context_schedule) 72 | batch_count = self.get_batch_count(frame_number, context_scheduler, ctx) 73 | 74 | while True: 75 | next_batch_count = self.get_batch_count(frame_number + 1, context_scheduler, ctx) 76 | if next_batch_count > batch_count: 77 | break 78 | 79 | frame_number += 1 80 | 81 | snd_half_start = frame_number // 2 + frame_number % 2 82 | fst_half_end = snd_half_start - 1 83 | return (frame_number, fst_half_end, snd_half_start, frame_number - 1) 84 | 85 | NODE_CLASS_MAPPINGS["AnimateDiffFrameCalculator"] = AnimateDiffFrameCalculator 86 | NODE_DISPLAY_NAME_MAPPINGS["AnimateDiffFrameCalculator"] = "Animate Diff Frame Calculator" 87 | 88 | except Exception as e: 89 | print(e) 90 | -------------------------------------------------------------------------------- /modules/controlnet/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import folder_paths 4 | from nodes import ControlNetLoader, ControlNetApply, ControlNetApplyAdvanced 5 | 6 | from .preprocessor import preprocessors, apply_preprocessor 7 | from .advanced import comfy_load_controlnet 8 | 9 | 10 | def load_controlnet(control_net_name, control_net_override="None", timestep_keyframe=None): 11 | if control_net_override != "None": 12 | if 
control_net_override not in folder_paths.get_filename_list("controlnet"): 13 | print(f"Warning: Not found ControlNet model {control_net_override}. Use {control_net_name} instead.") 14 | else: 15 | control_net_name = control_net_override 16 | 17 | if control_net_name == "None": 18 | return None 19 | 20 | return comfy_load_controlnet(control_net_name, timestep_keyframe=timestep_keyframe) 21 | 22 | 23 | def detect_controlnet(preprocessor: str, sd_version: str): 24 | controlnets = folder_paths.get_filename_list("controlnet") 25 | controlnets = filter(lambda x: sd_version in x, controlnets) 26 | if sd_version == "sdxl": 27 | controlnets = filter(lambda x: "t2i" not in x, controlnets) 28 | controlnets = filter(lambda x: "lllite" not in x, controlnets) 29 | 30 | control_net_name = "None" 31 | if preprocessor in {"canny", "scribble", "mlsd"}: 32 | control_net_name = next((c for c in controlnets if preprocessor in c), "None") 33 | if preprocessor in {"scribble", "scribble_hed"}: 34 | control_net_name = next((c for c in controlnets if "scribble" in c), "None") 35 | if preprocessor in {"lineart", "lineart_coarse"}: 36 | control_net_name = next((c for c in controlnets if "lineart." in c), "None") 37 | if preprocessor in {"lineart_anime", "lineart_manga"}: 38 | control_net_name = next((c for c in controlnets if "lineart_anime" in c), "None") 39 | if preprocessor in {"hed", "hed_safe", "pidi", "pidi_safe"}: 40 | control_net_name = next((c for c in controlnets if "softedge" in c), "None") 41 | if preprocessor in {"pose", "openpose", "dwpose"}: 42 | control_net_name = next((c for c in controlnets if "openpose" in c), "None") 43 | if preprocessor in {"normalmap_bae", "normalmap_midas"}: 44 | control_net_name = next((c for c in controlnets if "normalbae" in c), "None") 45 | if preprocessor in {"depth", "depth_midas", "depth_zoe"}: 46 | control_net_name = next((c for c in controlnets if "depth" in c), "None") 47 | if preprocessor in {"seg_ofcoco", "seg_ofade20k", "seg_ufade20k"}: 48 | control_net_name = next((c for c in controlnets if "seg" in c), "None") 49 | 50 | if preprocessor in {"tile"}: 51 | control_net_name = next((c for c in controlnets if "tile" in c), "None") 52 | 53 | return control_net_name 54 | 55 | 56 | class AVControlNetLoader(ControlNetLoader): 57 | @classmethod 58 | def INPUT_TYPES(s): 59 | return { 60 | "required": {"control_net_name": (folder_paths.get_filename_list("controlnet"),)}, 61 | "optional": { 62 | "control_net_override": ("STRING", {"default": "None"}), 63 | "timestep_keyframe": ("TIMESTEP_KEYFRAME",), 64 | }, 65 | } 66 | 67 | RETURN_TYPES = ("CONTROL_NET",) 68 | FUNCTION = "load_controlnet" 69 | CATEGORY = "Art Venture/Loaders" 70 | 71 | def load_controlnet(self, control_net_name, control_net_override="None", timestep_keyframe=None): 72 | return load_controlnet(control_net_name, control_net_override, timestep_keyframe=timestep_keyframe) 73 | 74 | 75 | class AV_ControlNetPreprocessor: 76 | @classmethod 77 | def INPUT_TYPES(s): 78 | return { 79 | "required": { 80 | "image": ("IMAGE",), 81 | "preprocessor": (["None"] + preprocessors,), 82 | "sd_version": (["sd15", "sdxl"],), 83 | }, 84 | "optional": { 85 | "resolution": ("INT", {"default": 512, "min": 64, "max": 2048, "step": 64}), 86 | "preprocessor_override": ("STRING", {"default": "None"}), 87 | }, 88 | } 89 | 90 | RETURN_TYPES = ("IMAGE", "STRING") 91 | RETURN_NAMES = ("IMAGE", "CNET_NAME") 92 | FUNCTION = "detect_controlnet" 93 | CATEGORY = "Art Venture/Loaders" 94 | 95 | def detect_controlnet(self, image, preprocessor, 
sd_version, resolution=512, preprocessor_override="None"): 96 | if preprocessor_override != "None": 97 | if preprocessor_override not in preprocessors: 98 | print( 99 | f"Warning: Not found ControlNet preprocessor {preprocessor_override}. Use {preprocessor} instead." 100 | ) 101 | else: 102 | preprocessor = preprocessor_override 103 | 104 | image = apply_preprocessor(image, preprocessor, resolution=resolution) 105 | control_net_name = detect_controlnet(preprocessor, sd_version) 106 | 107 | return (image, control_net_name) 108 | 109 | 110 | class AVControlNetEfficientStacker: 111 | controlnets = folder_paths.get_filename_list("controlnet") 112 | 113 | @classmethod 114 | def INPUT_TYPES(s): 115 | return { 116 | "required": { 117 | "control_net_name": (["None", "Auto: sd15", "Auto: sdxl", "Auto: sdxl_t2i"] + s.controlnets,), 118 | "image": ("IMAGE",), 119 | "strength": ( 120 | "FLOAT", 121 | {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}, 122 | ), 123 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), 124 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), 125 | "preprocessor": (["None"] + preprocessors,), 126 | }, 127 | "optional": { 128 | "cnet_stack": ("CONTROL_NET_STACK",), 129 | "control_net_override": ("STRING", {"default": "None"}), 130 | "timestep_keyframe": ("TIMESTEP_KEYFRAME",), 131 | "resolution": ("INT", {"default": 512, "min": 64, "max": 2048, "step": 64}), 132 | "enabled": ("BOOLEAN", {"default": True}), 133 | }, 134 | } 135 | 136 | RETURN_TYPES = ("CONTROL_NET_STACK",) 137 | RETURN_NAMES = ("CNET_STACK",) 138 | FUNCTION = "control_net_stacker" 139 | CATEGORY = "Art Venture/Loaders" 140 | 141 | def control_net_stacker( 142 | self, 143 | control_net_name: str, 144 | image, 145 | strength, 146 | start_percent, 147 | end_percent, 148 | preprocessor, 149 | cnet_stack=None, 150 | control_net_override="None", 151 | timestep_keyframe=None, 152 | resolution=512, 153 | enabled=True, 154 | ): 155 | if not enabled: 156 | return (cnet_stack,) 157 | 158 | # If control_net_stack is None, initialize as an empty list 159 | if cnet_stack is None: 160 | cnet_stack = [] 161 | 162 | if control_net_name.startswith("Auto: "): 163 | assert preprocessor != "None", "preprocessor must be set when using Auto mode" 164 | 165 | sd_version = control_net_name[len("Auto: ") :] 166 | control_net_name = detect_controlnet(preprocessor, sd_version) 167 | 168 | control_net = load_controlnet(control_net_name, control_net_override, timestep_keyframe=timestep_keyframe) 169 | 170 | # Extend the control_net_stack with the new tuple 171 | if control_net is not None: 172 | image = apply_preprocessor(image, preprocessor, resolution=resolution) 173 | cnet_stack.extend([(control_net, image, strength, start_percent, end_percent)]) 174 | 175 | return (cnet_stack,) 176 | 177 | 178 | class AVControlNetEfficientStackerSimple(AVControlNetEfficientStacker): 179 | @classmethod 180 | def INPUT_TYPES(s): 181 | return { 182 | "required": { 183 | "control_net_name": (["None", "Auto: sd15", "Auto: sdxl", "Auto: sdxl_t2i"] + s.controlnets,), 184 | "image": ("IMAGE",), 185 | "strength": ( 186 | "FLOAT", 187 | {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}, 188 | ), 189 | "preprocessor": (["None"] + preprocessors,), 190 | }, 191 | "optional": { 192 | "cnet_stack": ("CONTROL_NET_STACK",), 193 | "control_net_override": ("STRING", {"default": "None"}), 194 | "timestep_keyframe": ("TIMESTEP_KEYFRAME",), 195 | "resolution": ("INT", {"default": 512, "min": 64, 
"max": 2048, "step": 64}), 196 | "enabled": ("BOOLEAN", {"default": True}), 197 | }, 198 | } 199 | 200 | FUNCTION = "control_net_stacker_simple" 201 | 202 | def control_net_stacker_simple( 203 | self, 204 | *args, 205 | **kwargs, 206 | ): 207 | return self.control_net_stacker(*args, start_percent=0.0, end_percent=1.0, **kwargs) 208 | 209 | 210 | class AVControlNetEfficientLoader(ControlNetApply): 211 | controlnets = folder_paths.get_filename_list("controlnet") 212 | 213 | @classmethod 214 | def INPUT_TYPES(s): 215 | return { 216 | "required": { 217 | "control_net_name": (["None"] + s.controlnets,), 218 | "conditioning": ("CONDITIONING",), 219 | "image": ("IMAGE",), 220 | "strength": ( 221 | "FLOAT", 222 | {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}, 223 | ), 224 | "preprocessor": (["None"] + preprocessors,), 225 | }, 226 | "optional": { 227 | "control_net_override": ("STRING", {"default": "None"}), 228 | "timestep_keyframe": ("TIMESTEP_KEYFRAME",), 229 | "resolution": ("INT", {"default": 512, "min": 64, "max": 2048, "step": 64}), 230 | "enabled": ("BOOLEAN", {"default": True}), 231 | }, 232 | } 233 | 234 | RETURN_TYPES = ("CONDITIONING",) 235 | FUNCTION = "load_controlnet" 236 | CATEGORY = "Art Venture/Loaders" 237 | 238 | def load_controlnet( 239 | self, 240 | control_net_name, 241 | conditioning, 242 | image, 243 | strength, 244 | preprocessor, 245 | control_net_override="None", 246 | timestep_keyframe=None, 247 | resolution=512, 248 | enabled=True, 249 | ): 250 | if not enabled: 251 | return (conditioning,) 252 | 253 | control_net = load_controlnet(control_net_name, control_net_override, timestep_keyframe=timestep_keyframe) 254 | if control_net is None: 255 | return (conditioning,) 256 | 257 | image = apply_preprocessor(image, preprocessor, resolution=resolution) 258 | 259 | return super().apply_controlnet(conditioning, control_net, image, strength) 260 | 261 | 262 | class AVControlNetEfficientLoaderAdvanced(ControlNetApplyAdvanced): 263 | controlnets = folder_paths.get_filename_list("controlnet") 264 | 265 | @classmethod 266 | def INPUT_TYPES(s): 267 | return { 268 | "required": { 269 | "control_net_name": (["None"] + s.controlnets,), 270 | "positive": ("CONDITIONING",), 271 | "negative": ("CONDITIONING",), 272 | "image": ("IMAGE",), 273 | "strength": ( 274 | "FLOAT", 275 | {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}, 276 | ), 277 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), 278 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), 279 | "preprocessor": (["None"] + preprocessors,), 280 | }, 281 | "optional": { 282 | "control_net_override": ("STRING", {"default": "None"}), 283 | "timestep_keyframe": ("TIMESTEP_KEYFRAME",), 284 | "resolution": ("INT", {"default": 512, "min": 64, "max": 2048, "step": 64}), 285 | "enabled": ("BOOLEAN", {"default": True}), 286 | }, 287 | } 288 | 289 | RETURN_TYPES = ("CONDITIONING", "CONDITIONING") 290 | RETURN_NAMES = ("positive", "negative") 291 | FUNCTION = "load_controlnet" 292 | CATEGORY = "Art Venture/Loaders" 293 | 294 | def load_controlnet( 295 | self, 296 | control_net_name, 297 | positive, 298 | negative, 299 | image, 300 | strength, 301 | start_percent, 302 | end_percent, 303 | preprocessor, 304 | control_net_override="None", 305 | timestep_keyframe=None, 306 | resolution=512, 307 | enabled=True, 308 | ): 309 | if not enabled: 310 | return (positive, negative) 311 | 312 | control_net = load_controlnet(control_net_name, control_net_override, 
timestep_keyframe=timestep_keyframe) 313 | if control_net is None: 314 | return (positive, negative) 315 | 316 | image = apply_preprocessor(image, preprocessor, resolution=resolution) 317 | 318 | return super().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent) 319 | 320 | 321 | NODE_CLASS_MAPPINGS = { 322 | "AV_ControlNetLoader": AVControlNetLoader, 323 | "AV_ControlNetEfficientLoader": AVControlNetEfficientLoader, 324 | "AV_ControlNetEfficientLoaderAdvanced": AVControlNetEfficientLoaderAdvanced, 325 | "AV_ControlNetEfficientStacker": AVControlNetEfficientStacker, 326 | "AV_ControlNetEfficientStackerSimple": AVControlNetEfficientStackerSimple, 327 | "AV_ControlNetPreprocessor": AV_ControlNetPreprocessor, 328 | } 329 | 330 | NODE_DISPLAY_NAME_MAPPINGS = { 331 | "AV_ControlNetLoader": "ControlNet Loader", 332 | "AV_ControlNetEfficientLoader": "ControlNet Loader", 333 | "AV_ControlNetEfficientLoaderAdvanced": "ControlNet Loader Adv.", 334 | "AV_ControlNetEfficientStacker": "ControlNet Stacker Adv.", 335 | "AV_ControlNetEfficientStackerSimple": "ControlNet Stacker", 336 | "AV_ControlNetPreprocessor": "ControlNet Preprocessor", 337 | } 338 | -------------------------------------------------------------------------------- /modules/controlnet/advanced.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | import folder_paths 5 | import comfy.sd 6 | import comfy.controlnet 7 | 8 | from ..utils import load_module 9 | 10 | custom_nodes = folder_paths.get_folder_paths("custom_nodes") 11 | advanced_cnet_dir_names = ["AdvancedControlNet", "ComfyUI-Advanced-ControlNet"] 12 | 13 | 14 | def comfy_load_controlnet(control_net_name: str, **_): 15 | controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) 16 | return comfy.controlnet.load_controlnet(controlnet_path) 17 | 18 | 19 | try: 20 | module_path = None 21 | 22 | for custom_node in custom_nodes: 23 | custom_node = custom_node if not os.path.islink(custom_node) else os.readlink(custom_node) 24 | for module_dir in advanced_cnet_dir_names: 25 | if module_dir in os.listdir(custom_node): 26 | module_path = os.path.abspath(os.path.join(custom_node, module_dir)) 27 | break 28 | 29 | if module_path is None: 30 | raise Exception("Could not find AdvancedControlNet nodes") 31 | 32 | module = load_module(module_path) 33 | print("Loaded AdvancedControlNet nodes from", module_path) 34 | 35 | nodes: Dict = getattr(module, "NODE_CLASS_MAPPINGS") 36 | ControlNetLoaderAdvanced = nodes["ControlNetLoaderAdvanced"] 37 | 38 | loader = ControlNetLoaderAdvanced() 39 | 40 | def comfy_load_controlnet(control_net_name: str, timestep_keyframe=None): 41 | return loader.load_controlnet(control_net_name, timestep_keyframe=timestep_keyframe)[0] 42 | 43 | except Exception as e: 44 | print(e) 45 | -------------------------------------------------------------------------------- /modules/controlnet/preprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | import folder_paths 5 | 6 | from ..utils import load_module 7 | 8 | custom_nodes = folder_paths.get_folder_paths("custom_nodes") 9 | preprocessors_dir_names = ["ControlNetPreprocessors", "comfyui_controlnet_aux"] 10 | 11 | preprocessors: list[str] = [] 12 | _preprocessors_map = { 13 | "canny": "CannyEdgePreprocessor", 14 | "canny_pyra": "PyraCannyPreprocessor", 15 | "lineart": "LineArtPreprocessor", 16 | 
"lineart_anime": "AnimeLineArtPreprocessor", 17 | "lineart_manga": "Manga2Anime_LineArt_Preprocessor", 18 | "lineart_any": "AnyLineArtPreprocessor_aux", 19 | "scribble": "ScribblePreprocessor", 20 | "scribble_xdog": "Scribble_XDoG_Preprocessor", 21 | "scribble_pidi": "Scribble_PiDiNet_Preprocessor", 22 | "scribble_hed": "FakeScribblePreprocessor", 23 | "hed": "HEDPreprocessor", 24 | "pidi": "PiDiNetPreprocessor", 25 | "mlsd": "M-LSDPreprocessor", 26 | "pose": "DWPreprocessor", 27 | "openpose": "OpenposePreprocessor", 28 | "dwpose": "DWPreprocessor", 29 | "pose_dense": "DensePosePreprocessor", 30 | "pose_animal": "AnimalPosePreprocessor", 31 | "normalmap_bae": "BAE-NormalMapPreprocessor", 32 | "normalmap_dsine": "DSINE-NormalMapPreprocessor", 33 | "normalmap_midas": "MiDaS-NormalMapPreprocessor", 34 | "depth": "DepthAnythingV2Preprocessor", 35 | "depth_anything": "DepthAnythingPreprocessor", 36 | "depth_anything_v2": "DepthAnythingV2Preprocessor", 37 | "depth_anything_zoe": "Zoe_DepthAnythingPreprocessor", 38 | "depth_zoe": "Zoe-DepthMapPreprocessor", 39 | "depth_midas": "MiDaS-DepthMapPreprocessor", 40 | "depth_leres": "LeReS-DepthMapPreprocessor", 41 | "depth_metric3d": "Metric3D-DepthMapPreprocessor", 42 | "depth_meshgraphormer": "MeshGraphormer-DepthMapPreprocessor", 43 | "seg_ofcoco": "OneFormer-COCO-SemSegPreprocessor", 44 | "seg_ofade20k": "OneFormer-ADE20K-SemSegPreprocessor", 45 | "seg_ufade20k": "UniFormer-SemSegPreprocessor", 46 | "seg_animeface": "AnimeFace_SemSegPreprocessor", 47 | "shuffle": "ShufflePreprocessor", 48 | "teed": "TEEDPreprocessor", 49 | "color": "ColorPreprocessor", 50 | "sam": "SAMPreprocessor", 51 | "tile": "TilePreprocessor" 52 | } 53 | 54 | 55 | def apply_preprocessor(image, preprocessor, resolution=512): 56 | raise NotImplementedError("apply_preprocessor is not implemented") 57 | 58 | 59 | try: 60 | module_path = None 61 | 62 | for custom_node in custom_nodes: 63 | custom_node = custom_node if not os.path.islink(custom_node) else os.readlink(custom_node) 64 | for module_dir in preprocessors_dir_names: 65 | if module_dir in os.listdir(custom_node): 66 | module_path = os.path.abspath(os.path.join(custom_node, module_dir)) 67 | break 68 | 69 | if module_path is None: 70 | raise Exception("Could not find ControlNetPreprocessors nodes") 71 | 72 | module = load_module(module_path) 73 | print("Loaded ControlNetPreprocessors nodes from", module_path) 74 | 75 | nodes: Dict = getattr(module, "NODE_CLASS_MAPPINGS") 76 | available_preprocessors: list[str] = getattr(module, "PREPROCESSOR_OPTIONS") 77 | 78 | AIO_Preprocessor = nodes.get("AIO_Preprocessor", None) 79 | if AIO_Preprocessor is None: 80 | raise Exception("Could not find AIO_Preprocessor node") 81 | 82 | for name, preprocessor in _preprocessors_map.items(): 83 | if preprocessor in available_preprocessors: 84 | preprocessors.append(name) 85 | 86 | aio_preprocessor = AIO_Preprocessor() 87 | 88 | def apply_preprocessor(image, preprocessor, resolution=512): 89 | if preprocessor == "None": 90 | return image 91 | 92 | if preprocessor not in preprocessors: 93 | raise Exception(f"Preprocessor {preprocessor} is not implemented") 94 | 95 | preprocessor_cls = _preprocessors_map[preprocessor] 96 | args = {"preprocessor": preprocessor_cls, "image": image, "resolution": resolution} 97 | 98 | function_name = AIO_Preprocessor.FUNCTION 99 | res = getattr(aio_preprocessor, function_name)(**args) 100 | if isinstance(res, dict): 101 | res = res["result"] 102 | 103 | return res[0] 104 | 105 | except Exception as e: 106 | print(e) 
107 | 
--------------------------------------------------------------------------------
/modules/fooocus/__init__.py:
--------------------------------------------------------------------------------
1 | from nodes import KSampler, KSamplerAdvanced
2 | 
3 | from . import patch  # bind the patch module so `patch.sharpness` below updates its module-level setting
4 | from .efficient import (
5 |     NODE_CLASS_MAPPINGS as EFFICIENCY_NODE_CLASS_MAPPINGS,
6 |     NODE_DISPLAY_NAME_MAPPINGS as EFFICIENCY_NODE_DISPLAY_NAME_MAPPINGS,
7 | )
8 | 
9 | 
10 | class KSamplerWithSharpness(KSampler):
11 |     @classmethod
12 |     def INPUT_TYPES(cls):
13 |         inputs = KSampler.INPUT_TYPES()
14 |         inputs["optional"] = {
15 |             "sharpness": (
16 |                 "FLOAT",
17 |                 {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.01},
18 |             )
19 |         }
20 | 
21 |         return inputs
22 | 
23 |     CATEGORY = "Art Venture/Sampling"
24 | 
25 |     def sample(self, *args, sharpness=2.0, **kwargs):
26 |         patch.sharpness = sharpness
27 |         patch.patch_all()
28 |         results = super().sample(*args, **kwargs)
29 |         patch.unpatch_all()
30 |         return results
31 | 
32 | 
33 | class KSamplerAdvancedWithSharpness(KSamplerAdvanced):
34 |     @classmethod
35 |     def INPUT_TYPES(cls):
36 |         inputs = KSamplerAdvanced.INPUT_TYPES()
37 |         inputs["optional"] = {
38 |             "sharpness": (
39 |                 "FLOAT",
40 |                 {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.01},
41 |             )
42 |         }
43 | 
44 |         return inputs
45 | 
46 |     CATEGORY = "Art Venture/Sampling"
47 | 
48 |     def sample(self, *args, sharpness=2.0, **kwargs):
49 |         patch.sharpness = sharpness
50 |         patch.patch_all()
51 |         results = super().sample(*args, **kwargs)
52 |         patch.unpatch_all()
53 |         return results
54 | 
55 | 
56 | NODE_CLASS_MAPPINGS = {
57 |     "Fooocus_KSampler": KSamplerWithSharpness,
58 |     "Fooocus_KSamplerAdvanced": KSamplerAdvancedWithSharpness,
59 | }
60 | 
61 | NODE_DISPLAY_NAME_MAPPINGS = {
62 |     "Fooocus_KSampler": "KSampler Fooocus",
63 |     "Fooocus_KSamplerAdvanced": "KSampler Adv. Fooocus",
64 | }
65 | 
66 | NODE_CLASS_MAPPINGS.update(EFFICIENCY_NODE_CLASS_MAPPINGS)
67 | NODE_DISPLAY_NAME_MAPPINGS.update(EFFICIENCY_NODE_DISPLAY_NAME_MAPPINGS)
68 | 
--------------------------------------------------------------------------------
/modules/fooocus/anisotropic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | Tensor = torch.Tensor
5 | Device = torch.DeviceObjType
6 | Dtype = torch.Type
7 | pad = torch.nn.functional.pad
8 | 
9 | 
10 | def _compute_zero_padding(kernel_size: tuple[int, int] | int) -> tuple[int, int]:
11 |     ky, kx = _unpack_2d_ks(kernel_size)
12 |     return (ky - 1) // 2, (kx - 1) // 2
13 | 
14 | 
15 | def _unpack_2d_ks(kernel_size: tuple[int, int] | int) -> tuple[int, int]:
16 |     if isinstance(kernel_size, int):
17 |         ky = kx = kernel_size
18 |     else:
19 |         assert len(kernel_size) == 2, "2D Kernel size should have a length of 2."
20 | ky, kx = kernel_size 21 | 22 | ky = int(ky) 23 | kx = int(kx) 24 | return ky, kx 25 | 26 | 27 | def gaussian( 28 | window_size: int, 29 | sigma: Tensor | float, 30 | *, 31 | device: Device | None = None, 32 | dtype: Dtype | None = None, 33 | ) -> Tensor: 34 | batch_size = sigma.shape[0] 35 | 36 | x = ( 37 | torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) 38 | - window_size // 2 39 | ).expand(batch_size, -1) 40 | 41 | if window_size % 2 == 0: 42 | x = x + 0.5 43 | 44 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) 45 | 46 | return gauss / gauss.sum(-1, keepdim=True) 47 | 48 | 49 | def get_gaussian_kernel1d( 50 | kernel_size: int, 51 | sigma: float | Tensor, 52 | force_even: bool = False, 53 | *, 54 | device: Device | None = None, 55 | dtype: Dtype | None = None, 56 | ) -> Tensor: 57 | return gaussian(kernel_size, sigma, device=device, dtype=dtype) 58 | 59 | 60 | def get_gaussian_kernel2d( 61 | kernel_size: tuple[int, int] | int, 62 | sigma: tuple[float, float] | Tensor, 63 | force_even: bool = False, 64 | *, 65 | device: Device | None = None, 66 | dtype: Dtype | None = None, 67 | ) -> Tensor: 68 | sigma = torch.Tensor([[sigma, sigma]]).to(device=device, dtype=dtype) 69 | 70 | ksize_y, ksize_x = _unpack_2d_ks(kernel_size) 71 | sigma_y, sigma_x = sigma[:, 0, None], sigma[:, 1, None] 72 | 73 | kernel_y = get_gaussian_kernel1d( 74 | ksize_y, sigma_y, force_even, device=device, dtype=dtype 75 | )[..., None] 76 | kernel_x = get_gaussian_kernel1d( 77 | ksize_x, sigma_x, force_even, device=device, dtype=dtype 78 | )[..., None] 79 | 80 | return kernel_y * kernel_x.view(-1, 1, ksize_x) 81 | 82 | 83 | def _bilateral_blur( 84 | input: Tensor, 85 | guidance: Tensor | None, 86 | kernel_size: tuple[int, int] | int, 87 | sigma_color: float | Tensor, 88 | sigma_space: tuple[float, float] | Tensor, 89 | border_type: str = "reflect", 90 | color_distance_type: str = "l1", 91 | ) -> Tensor: 92 | if isinstance(sigma_color, Tensor): 93 | sigma_color = sigma_color.to(device=input.device, dtype=input.dtype).view( 94 | -1, 1, 1, 1, 1 95 | ) 96 | 97 | ky, kx = _unpack_2d_ks(kernel_size) 98 | pad_y, pad_x = _compute_zero_padding(kernel_size) 99 | 100 | padded_input = pad(input, (pad_x, pad_x, pad_y, pad_y), mode=border_type) 101 | unfolded_input = ( 102 | padded_input.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2) 103 | ) # (B, C, H, W, Ky x Kx) 104 | 105 | if guidance is None: 106 | guidance = input 107 | unfolded_guidance = unfolded_input 108 | else: 109 | padded_guidance = pad(guidance, (pad_x, pad_x, pad_y, pad_y), mode=border_type) 110 | unfolded_guidance = ( 111 | padded_guidance.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2) 112 | ) # (B, C, H, W, Ky x Kx) 113 | 114 | diff = unfolded_guidance - guidance.unsqueeze(-1) 115 | if color_distance_type == "l1": 116 | color_distance_sq = diff.abs().sum(1, keepdim=True).square() 117 | elif color_distance_type == "l2": 118 | color_distance_sq = diff.square().sum(1, keepdim=True) 119 | else: 120 | raise ValueError("color_distance_type only acceps l1 or l2") 121 | color_kernel = ( 122 | -0.5 / sigma_color**2 * color_distance_sq 123 | ).exp() # (B, 1, H, W, Ky x Kx) 124 | 125 | space_kernel = get_gaussian_kernel2d( 126 | kernel_size, sigma_space, device=input.device, dtype=input.dtype 127 | ) 128 | space_kernel = space_kernel.view(-1, 1, 1, 1, kx * ky) 129 | 130 | kernel = space_kernel * color_kernel 131 | out = (unfolded_input * kernel).sum(-1) / kernel.sum(-1) 132 | return out 133 | 134 | 135 | def bilateral_blur( 136 | input: Tensor, 137 | 
kernel_size: tuple[int, int] | int = (13, 13), 138 | sigma_color: float | Tensor = 3.0, 139 | sigma_space: tuple[float, float] | Tensor = 3.0, 140 | border_type: str = "reflect", 141 | color_distance_type: str = "l1", 142 | ) -> Tensor: 143 | return _bilateral_blur( 144 | input, 145 | None, 146 | kernel_size, 147 | sigma_color, 148 | sigma_space, 149 | border_type, 150 | color_distance_type, 151 | ) 152 | 153 | 154 | def joint_bilateral_blur( 155 | input: Tensor, 156 | guidance: Tensor, 157 | kernel_size: tuple[int, int] | int, 158 | sigma_color: float | Tensor, 159 | sigma_space: tuple[float, float] | Tensor, 160 | border_type: str = "reflect", 161 | color_distance_type: str = "l1", 162 | ) -> Tensor: 163 | return _bilateral_blur( 164 | input, 165 | guidance, 166 | kernel_size, 167 | sigma_color, 168 | sigma_space, 169 | border_type, 170 | color_distance_type, 171 | ) 172 | 173 | 174 | class _BilateralBlur(torch.nn.Module): 175 | def __init__( 176 | self, 177 | kernel_size: tuple[int, int] | int, 178 | sigma_color: float | Tensor, 179 | sigma_space: tuple[float, float] | Tensor, 180 | border_type: str = "reflect", 181 | color_distance_type: str = "l1", 182 | ) -> None: 183 | super().__init__() 184 | self.kernel_size = kernel_size 185 | self.sigma_color = sigma_color 186 | self.sigma_space = sigma_space 187 | self.border_type = border_type 188 | self.color_distance_type = color_distance_type 189 | 190 | def __repr__(self) -> str: 191 | return ( 192 | f"{self.__class__.__name__}" 193 | f"(kernel_size={self.kernel_size}, " 194 | f"sigma_color={self.sigma_color}, " 195 | f"sigma_space={self.sigma_space}, " 196 | f"border_type={self.border_type}, " 197 | f"color_distance_type={self.color_distance_type})" 198 | ) 199 | 200 | 201 | class BilateralBlur(_BilateralBlur): 202 | def forward(self, input: Tensor) -> Tensor: 203 | return bilateral_blur( 204 | input, 205 | self.kernel_size, 206 | self.sigma_color, 207 | self.sigma_space, 208 | self.border_type, 209 | self.color_distance_type, 210 | ) 211 | 212 | 213 | class JointBilateralBlur(_BilateralBlur): 214 | def forward(self, input: Tensor, guidance: Tensor) -> Tensor: 215 | return joint_bilateral_blur( 216 | input, 217 | guidance, 218 | self.kernel_size, 219 | self.sigma_color, 220 | self.sigma_space, 221 | self.border_type, 222 | self.color_distance_type, 223 | ) 224 | -------------------------------------------------------------------------------- /modules/fooocus/efficient.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | import folder_paths 5 | 6 | from . 
import patch 7 | from ..utils import load_module 8 | 9 | custom_nodes = folder_paths.get_folder_paths("custom_nodes") 10 | efficieny_dir_names = ["Efficiency", "efficiency-nodes-comfyui"] 11 | 12 | NODE_CLASS_MAPPINGS = {} 13 | NODE_DISPLAY_NAME_MAPPINGS = {} 14 | 15 | try: 16 | module_path = None 17 | 18 | for custom_node in custom_nodes: 19 | custom_node = ( 20 | custom_node if not os.path.islink(custom_node) else os.readlink(custom_node) 21 | ) 22 | for module_dir in efficieny_dir_names: 23 | if module_dir in os.listdir(custom_node): 24 | module_path = os.path.abspath( 25 | os.path.join(custom_node, module_dir) 26 | ) 27 | break 28 | 29 | if module_path is None: 30 | raise Exception("Could not find efficiency nodes") 31 | 32 | module = load_module(module_path) 33 | print("Loaded Efficiency nodes from", module_path) 34 | 35 | nodes: Dict = getattr(module, "NODE_CLASS_MAPPINGS") 36 | 37 | TSC_KSampler = nodes["KSampler (Efficient)"] 38 | TSC_KSamplerAdvanced = nodes["KSampler Adv. (Efficient)"] 39 | TSC_EfficientLoader = nodes["Efficient Loader"] 40 | 41 | class KSamplerEfficientWithSharpness(TSC_KSampler): 42 | @classmethod 43 | def INPUT_TYPES(cls): 44 | inputs = TSC_KSampler.INPUT_TYPES() 45 | inputs["optional"]["sharpness"] = ( 46 | "FLOAT", 47 | {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.01}, 48 | ) 49 | 50 | return inputs 51 | 52 | CATEGORY = "Art Venture/Sampling" 53 | 54 | def sample(self, *args, sharpness=2.0, **kwargs): 55 | patch.sharpness = sharpness 56 | patch.patch_all() 57 | results = super().sample(*args, **kwargs) 58 | patch.unpatch_all() 59 | return results 60 | 61 | class KSamplerEfficientAdvancedWithSharpness(TSC_KSamplerAdvanced): 62 | @classmethod 63 | def INPUT_TYPES(cls): 64 | inputs = TSC_KSampler.INPUT_TYPES() 65 | inputs["optional"]["sharpness"] = ( 66 | "FLOAT", 67 | {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.01}, 68 | ) 69 | 70 | return inputs 71 | 72 | CATEGORY = "Art Venture/Sampling" 73 | 74 | def sampleadv(self, *args, sharpness=2.0, **kwargs): 75 | patch.sharpness = sharpness 76 | patch.patch_all() 77 | results = super().sampleadv(*args, **kwargs) 78 | patch.unpatch_all() 79 | return results 80 | 81 | class AVCheckpointLoader(TSC_EfficientLoader): 82 | @classmethod 83 | def INPUT_TYPES(cls): 84 | inputs = TSC_EfficientLoader.INPUT_TYPES() 85 | inputs["optional"]["ckpt_override"] = ("STRING", {"default": "None"}) 86 | inputs["optional"]["vae_override"] = ("STRING", {"default": "None"}) 87 | inputs["optional"]["lora_override"] = ("STRING", {"default": "None"}) 88 | return inputs 89 | 90 | CATEGORY = "Art Venture/Loaders" 91 | 92 | def efficientloader( 93 | self, 94 | ckpt_name, 95 | vae_name, 96 | clip_skip, 97 | lora_name, 98 | *args, 99 | ckpt_override="None", 100 | vae_override="None", 101 | lora_override="None", 102 | **kwargs 103 | ): 104 | if ckpt_override != "None": 105 | ckpt_name = ckpt_override 106 | if vae_override != "None": 107 | vae_name = vae_override 108 | if lora_override != "None": 109 | lora_name = lora_override 110 | 111 | return super().efficientloader( 112 | ckpt_name, vae_name, clip_skip, lora_name, *args, **kwargs 113 | ) 114 | 115 | NODE_CLASS_MAPPINGS.update( 116 | { 117 | "Fooocus_KSamplerEfficient": KSamplerEfficientWithSharpness, 118 | "Fooocus_KSamplerEfficientAdvanced": KSamplerEfficientAdvancedWithSharpness, 119 | "AV_CheckpointLoader": AVCheckpointLoader, 120 | } 121 | ) 122 | NODE_DISPLAY_NAME_MAPPINGS.update( 123 | { 124 | "Fooocus_KSamplerEfficient": "KSampler Efficient Fooocus", 125 | 
"Fooocus_KSamplerEfficientAdvanced": "KSampler Adv. Efficient Fooocus", 126 | "AV_CheckpointLoader": "Checkpoint Loader", 127 | } 128 | ) 129 | 130 | 131 | except Exception as e: 132 | print(e) 133 | -------------------------------------------------------------------------------- /modules/fooocus/patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import comfy.model_base 3 | import comfy.ldm.modules.diffusionmodules.openaimodel 4 | import comfy.samplers 5 | 6 | 7 | from .anisotropic import bilateral_blur 8 | 9 | sharpness = 2.0 10 | 11 | original_unet_forward = comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward 12 | original_sdxl_encode_adm = comfy.model_base.SDXL.encode_adm 13 | 14 | 15 | def unet_forward_patched( 16 | self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs 17 | ): 18 | x0 = original_unet_forward( 19 | self, 20 | x, 21 | timesteps=timesteps, 22 | context=context, 23 | y=y, 24 | control=control, 25 | transformer_options=transformer_options, 26 | **kwargs 27 | ) 28 | uc_mask = torch.Tensor(transformer_options["cond_or_uncond"]).to(x0).float()[:, None, None, None] 29 | 30 | alpha = 1.0 - (timesteps / 999.0)[:, None, None, None].clone() 31 | alpha *= 0.001 * sharpness 32 | degraded_x0 = bilateral_blur(x0) * alpha + x0 * (1.0 - alpha) 33 | 34 | # FIX: uc_mask is not always the same size as x0 35 | if uc_mask.shape[0] < x0.shape[0]: 36 | uc_mask = uc_mask.repeat(int(x0.shape[0] / uc_mask.shape[0]), 1, 1, 1) 37 | 38 | x0 = x0 * uc_mask + degraded_x0 * (1.0 - uc_mask) 39 | 40 | return x0 41 | 42 | 43 | def sdxl_encode_adm_patched(self, **kwargs): 44 | clip_pooled = kwargs["pooled_output"] 45 | width = kwargs.get("width", 768) 46 | height = kwargs.get("height", 768) 47 | crop_w = kwargs.get("crop_w", 0) 48 | crop_h = kwargs.get("crop_h", 0) 49 | target_width = kwargs.get("target_width", width) 50 | target_height = kwargs.get("target_height", height) 51 | 52 | if kwargs.get("prompt_type", "") == "negative": 53 | width *= 0.8 54 | height *= 0.8 55 | elif kwargs.get("prompt_type", "") == "positive": 56 | width *= 1.5 57 | height *= 1.5 58 | 59 | out = [] 60 | out.append(self.embedder(torch.Tensor([height]))) 61 | out.append(self.embedder(torch.Tensor([width]))) 62 | out.append(self.embedder(torch.Tensor([crop_h]))) 63 | out.append(self.embedder(torch.Tensor([crop_w]))) 64 | out.append(self.embedder(torch.Tensor([target_height]))) 65 | out.append(self.embedder(torch.Tensor([target_width]))) 66 | flat = torch.flatten(torch.cat(out))[None,] 67 | return torch.cat((clip_pooled.to(flat.device), flat), dim=1) 68 | 69 | 70 | def patch_all(): 71 | comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched 72 | comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = unet_forward_patched 73 | 74 | 75 | def unpatch_all(): 76 | comfy.model_base.SDXL.encode_adm = original_sdxl_encode_adm 77 | comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = original_unet_forward 78 | -------------------------------------------------------------------------------- /modules/image_utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from PIL import Image 3 | 4 | 5 | class ResizeMode(Enum): 6 | RESIZE = 0 # just resize 7 | RESIZE_TO_FILL = 1 # crop and resize 8 | RESIZE_TO_FIT = 2 # resize and fill 9 | 10 | 11 | def resize_image(im: Image.Image, width: int, height: int, resize_mode=ResizeMode.RESIZE_TO_FIT): 12 | 
""" 13 | Resizes an image with the specified resize_mode, width, and height. 14 | 15 | Args: 16 | resize_mode: The mode to use when resizing the image. 17 | 0: Resize the image to the specified width and height. 18 | 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. 19 | 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. 20 | im: The image to resize. 21 | width: The width to resize the image to. 22 | height: The height to resize the image to. 23 | upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img. 24 | """ 25 | 26 | def resize(im: Image.Image, w, h): 27 | return im.resize((w, h), resample=Image.LANCZOS) 28 | 29 | if resize_mode == ResizeMode.RESIZE: 30 | res = resize(im, width, height) 31 | 32 | elif resize_mode == ResizeMode.RESIZE_TO_FILL: 33 | ratio = width / height 34 | src_ratio = im.width / im.height 35 | 36 | src_w = width if ratio > src_ratio else im.width * height // im.height 37 | src_h = height if ratio <= src_ratio else im.height * width // im.width 38 | 39 | resized = resize(im, src_w, src_h) 40 | res = Image.new("RGB", (width, height)) 41 | res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) 42 | 43 | else: 44 | ratio = width / height 45 | src_ratio = im.width / im.height 46 | 47 | src_w = width if ratio < src_ratio else im.width * height // im.height 48 | src_h = height if ratio >= src_ratio else im.height * width // im.width 49 | 50 | resized = resize(im, src_w, src_h) 51 | res = Image.new("RGB", (width, height)) 52 | res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) 53 | 54 | if ratio < src_ratio: 55 | fill_height = height // 2 - src_h // 2 56 | if fill_height > 0: 57 | res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) 58 | res.paste( 59 | resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), 60 | box=(0, fill_height + src_h), 61 | ) 62 | elif ratio > src_ratio: 63 | fill_width = width // 2 - src_w // 2 64 | if fill_width > 0: 65 | res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) 66 | res.paste( 67 | resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), 68 | box=(fill_width + src_w, 0), 69 | ) 70 | 71 | return res 72 | 73 | 74 | def flatten_image(im: Image.Image, bgcolor="#ffffff"): 75 | """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency""" 76 | 77 | if im.mode == "RGBA": 78 | background = Image.new("RGBA", im.size, bgcolor) 79 | background.paste(im, mask=im) 80 | im = background 81 | 82 | return im.convert("RGB") 83 | -------------------------------------------------------------------------------- /modules/impact/__init__.py: -------------------------------------------------------------------------------- 1 | from .facedetailer import ( 2 | NODE_CLASS_MAPPINGS as FACE_DETAILER_NODE_CLASS_MAPPINGS, 3 | NODE_DISPLAY_NAME_MAPPINGS as FACE_DETAILER_NODE_DISPLAY_NAME_MAPPINGS, 4 | ) 5 | 6 | NODE_CLASS_MAPPINGS = {} 7 | NODE_DISPLAY_NAME_MAPPINGS = {} 8 | 9 | NODE_CLASS_MAPPINGS.update(FACE_DETAILER_NODE_CLASS_MAPPINGS) 10 | NODE_DISPLAY_NAME_MAPPINGS.update(FACE_DETAILER_NODE_DISPLAY_NAME_MAPPINGS) 11 | 12 | __all__ = ["NODE_CLASS_MAPPINGS", 
"NODE_DISPLAY_NAME_MAPPINGS"] 13 | -------------------------------------------------------------------------------- /modules/impact/facedetailer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import inspect 3 | from typing import Dict 4 | 5 | import folder_paths 6 | 7 | from ..utils import load_module 8 | 9 | custom_nodes = folder_paths.get_folder_paths("custom_nodes") 10 | efficieny_dir_names = ["ImpactPack", "ComfyUI-Impact-Pack"] 11 | 12 | NODE_CLASS_MAPPINGS = {} 13 | NODE_DISPLAY_NAME_MAPPINGS = {} 14 | 15 | try: 16 | module_path = None 17 | 18 | for custom_node in custom_nodes: 19 | custom_node = custom_node if not os.path.islink(custom_node) else os.readlink(custom_node) 20 | for module_dir in efficieny_dir_names: 21 | if module_dir in os.listdir(custom_node): 22 | module_path = os.path.abspath(os.path.join(custom_node, module_dir)) 23 | break 24 | 25 | if module_path is None: 26 | raise Exception("Could not find ImpactPack nodes") 27 | 28 | module = load_module(module_path) 29 | print("Loaded ImpactPack nodes from", module_path) 30 | 31 | nodes: Dict = getattr(module, "NODE_CLASS_MAPPINGS") 32 | FaceDetailer = nodes["FaceDetailer"] 33 | FaceDetailerPipe = nodes["FaceDetailerPipe"] 34 | 35 | class AV_FaceDetailer(FaceDetailer): 36 | @classmethod 37 | def INPUT_TYPES(s): 38 | inputs = FaceDetailer.INPUT_TYPES() 39 | inputs["optional"]["enabled"] = ( 40 | "BOOLEAN", 41 | {"default": True, "label_on": "enabled", "label_off": "disabled"}, 42 | ) 43 | return inputs 44 | 45 | CATEGORY = "ArtVenture/Detailer" 46 | 47 | def args_to_pipe(self, args: dict): 48 | hook_args = [ 49 | "model", 50 | "clip", 51 | "vae", 52 | "positive", 53 | "negative", 54 | "wildcard", 55 | "bbox_detector", 56 | "segm_detector_opt", 57 | "sam_model_opt", 58 | "detailer_hook", 59 | ] 60 | 61 | pipe_args = [] 62 | for arg in hook_args: 63 | pipe_args.append(args.get(arg, None)) 64 | 65 | return tuple(pipe_args + [None, None, None, None]) 66 | 67 | def doit(self, image, *args, enabled=True, **kwargs): 68 | if enabled: 69 | return super().doit(image, *args, **kwargs) 70 | else: 71 | pipe = self.args_to_pipe(kwargs) 72 | return (image, [], [], None, pipe, []) 73 | 74 | class AV_FaceDetailerPipe(FaceDetailerPipe): 75 | @classmethod 76 | def INPUT_TYPES(s): 77 | inputs = FaceDetailerPipe.INPUT_TYPES() 78 | inputs["optional"]["enabled"] = ( 79 | "BOOLEAN", 80 | {"default": True, "label_on": "enabled", "label_off": "disabled"}, 81 | ) 82 | return inputs 83 | 84 | CATEGORY = "ArtVenture/Detailer" 85 | 86 | def doit(self, image, detailer_pipe, *args, enabled=True, **kwargs): 87 | if enabled: 88 | return super().doit(image, detailer_pipe, *args, **kwargs) 89 | else: 90 | return (image, [], [], None, detailer_pipe, []) 91 | 92 | NODE_CLASS_MAPPINGS.update( 93 | { 94 | "AV_FaceDetailer": AV_FaceDetailer, 95 | "AV_FaceDetailerPipe": AV_FaceDetailerPipe, 96 | } 97 | ) 98 | NODE_DISPLAY_NAME_MAPPINGS.update( 99 | {"AV_FaceDetailer": "FaceDetailer (AV)", "AV_FaceDetailerPipe": "FaceDetailerPipe (AV)"} 100 | ) 101 | 102 | except Exception as e: 103 | print("Could not load ImpactPack nodes", e) 104 | -------------------------------------------------------------------------------- /modules/inpaint/__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS -------------------------------------------------------------------------------- /modules/inpaint/lama/__init__.py: 
-------------------------------------------------------------------------------- 1 | # https://github.com/advimman/lama 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import folder_paths 7 | import comfy.model_management as model_management 8 | 9 | from ...model_utils import download_file 10 | 11 | 12 | lama = None 13 | gpu = model_management.get_torch_device() 14 | cpu = torch.device("cpu") 15 | model_dir = os.path.join(folder_paths.models_dir, "lama") 16 | model_url = "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt" 17 | model_sha = "344c77bbcb158f17dd143070d1e789f38a66c04202311ae3a258ef66667a9ea9" 18 | 19 | 20 | def ceil_modulo(x, mod): 21 | if x % mod == 0: 22 | return x 23 | return (x // mod + 1) * mod 24 | 25 | 26 | def pad_tensor_to_modulo(img, mod): 27 | height, width = img.shape[-2:] 28 | out_height = ceil_modulo(height, mod) 29 | out_width = ceil_modulo(width, mod) 30 | return F.pad(img, pad=(0, out_width - width, 0, out_height - height), mode="reflect") 31 | 32 | 33 | def load_model(): 34 | global lama 35 | if lama is None: 36 | model_path = os.path.join(model_dir, "big-lama.pt") 37 | download_file(model_url, model_path, model_sha) 38 | 39 | lama = torch.jit.load(model_path, map_location="cpu") 40 | lama.eval() 41 | 42 | return lama 43 | 44 | 45 | class LaMaInpaint: 46 | @classmethod 47 | def INPUT_TYPES(s): 48 | return { 49 | "required": { 50 | "image": ("IMAGE",), 51 | "mask": ("MASK",), 52 | }, 53 | "optional": {"device_mode": (["AUTO", "Prefer GPU", "CPU"],)}, 54 | } 55 | 56 | RETURN_TYPES = ("IMAGE",) 57 | CATEGORY = "Art Venture/Inpainting" 58 | FUNCTION = "lama_inpaint" 59 | 60 | def lama_inpaint( 61 | self, 62 | image: torch.Tensor, 63 | mask: torch.Tensor, 64 | device_mode="AUTO", 65 | ): 66 | if image.shape[0] != mask.shape[0]: 67 | raise Exception("Image and mask must have the same batch size") 68 | 69 | device = gpu if device_mode != "CPU" else cpu 70 | 71 | model = load_model() 72 | model.to(device) 73 | 74 | try: 75 | inpainted = [] 76 | orig_w = image.shape[2] 77 | orig_h = image.shape[1] 78 | 79 | for i, img in enumerate(image): 80 | img = img.permute(2, 0, 1).unsqueeze(0) 81 | msk = mask[i].detach().cpu() 82 | msk = (msk > 0) * 1.0 83 | msk = msk.unsqueeze(0).unsqueeze(0) 84 | 85 | src_image = pad_tensor_to_modulo(img, 8).to(device) 86 | src_mask = pad_tensor_to_modulo(msk, 8).to(device) 87 | 88 | res = model(src_image, src_mask) 89 | res = res[0].permute(1, 2, 0).detach().cpu() 90 | res = res[:orig_h, :orig_w] 91 | 92 | inpainted.append(res) 93 | 94 | return (torch.stack(inpainted),) 95 | finally: 96 | if device_mode == "AUTO": 97 | model.to(cpu) 98 | -------------------------------------------------------------------------------- /modules/inpaint/nodes.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from PIL import Image, ImageOps 5 | from typing import Dict 6 | 7 | from .sam.nodes import SAMLoader, GetSAMEmbedding, SAMEmbeddingToImage 8 | from .lama import LaMaInpaint 9 | 10 | from ..masking import get_crop_region, expand_crop_region 11 | from ..image_utils import ResizeMode, resize_image, flatten_image 12 | from ..utils import numpy2pil, tensor2pil, pil2tensor 13 | 14 | 15 | class PrepareImageAndMaskForInpaint: 16 | @classmethod 17 | def INPUT_TYPES(s): 18 | return { 19 | "required": { 20 | "image": ("IMAGE",), 21 | "mask": ("MASK",), 22 | "mask_blur": ("INT", {"default": 4, "min": 0, "max": 64}), 23 | 
"inpaint_masked": ("BOOLEAN", {"default": False}), 24 | "mask_padding": ("INT", {"default": 32, "min": 0, "max": 256}), 25 | "width": ("INT", {"default": 0, "min": 0, "max": 2048}), 26 | "height": ("INT", {"default": 0, "min": 0, "max": 2048}), 27 | } 28 | } 29 | 30 | RETURN_TYPES = ("IMAGE", "MASK", "IMAGE", "CROP_REGION") 31 | RETURN_NAMES = ("inpaint_image", "inpaint_mask", "overlay_image", "crop_region") 32 | CATEGORY = "Art Venture/Inpainting" 33 | FUNCTION = "prepare" 34 | 35 | def prepare( 36 | self, 37 | image: torch.Tensor, 38 | mask: torch.Tensor, 39 | # resize_mode: str, 40 | mask_blur: int, 41 | inpaint_masked: bool, 42 | mask_padding: int, 43 | width: int, 44 | height: int, 45 | ): 46 | if image.shape[0] != mask.shape[0]: 47 | raise ValueError("image and mask must have same batch size") 48 | 49 | if image.shape[1] != mask.shape[1] or image.shape[2] != mask.shape[2]: 50 | raise ValueError("image and mask must have same dimensions") 51 | 52 | if width == 0 and height == 0: 53 | height, width = image.shape[1:3] 54 | 55 | sourceheight, sourcewidth = image.shape[1:3] 56 | 57 | masks = [] 58 | images = [] 59 | overlay_masks = [] 60 | overlay_images = [] 61 | crop_regions = [] 62 | 63 | for img, msk in zip(image, mask): 64 | np_mask: np.ndarray = msk.cpu().numpy() 65 | 66 | if mask_blur > 0: 67 | kernel_size = 2 * int(2.5 * mask_blur + 0.5) + 1 68 | np_mask = cv2.GaussianBlur(np_mask, (kernel_size, kernel_size), mask_blur) 69 | 70 | pil_mask = numpy2pil(np_mask, "L") 71 | crop_region = None 72 | 73 | if inpaint_masked: 74 | crop_region = get_crop_region(np_mask, mask_padding) 75 | crop_region = expand_crop_region(crop_region, width, height, sourcewidth, sourceheight) 76 | # crop mask 77 | overlay_mask = pil_mask 78 | pil_mask = resize_image(pil_mask.crop(crop_region), width, height, ResizeMode.RESIZE_TO_FIT) 79 | pil_mask = pil_mask.convert("L") 80 | else: 81 | np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) 82 | overlay_mask = numpy2pil(np_mask, "L") 83 | 84 | pil_img = tensor2pil(img) 85 | pil_img = flatten_image(pil_img) 86 | 87 | image_masked = Image.new("RGBa", (pil_img.width, pil_img.height)) 88 | image_masked.paste(pil_img.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(overlay_mask)) 89 | overlay_images.append(pil2tensor(image_masked.convert("RGBA"))) 90 | overlay_masks.append(pil2tensor(overlay_mask)) 91 | 92 | if crop_region is not None: 93 | pil_img = resize_image(pil_img.crop(crop_region), width, height, ResizeMode.RESIZE_TO_FIT) 94 | else: 95 | crop_region = (0, 0, 0, 0) 96 | 97 | images.append(pil2tensor(pil_img)) 98 | masks.append(pil2tensor(pil_mask)) 99 | crop_regions.append(torch.tensor(crop_region, dtype=torch.int64)) 100 | 101 | return ( 102 | torch.cat(images, dim=0), 103 | torch.cat(masks, dim=0), 104 | torch.cat(overlay_images, dim=0), 105 | torch.stack(crop_regions), 106 | ) 107 | 108 | 109 | class OverlayInpaintedLatent: 110 | @classmethod 111 | def INPUT_TYPES(s): 112 | return { 113 | "required": { 114 | "original": ("LATENT",), 115 | "inpainted": ("LATENT",), 116 | "mask": ("MASK",), 117 | } 118 | } 119 | 120 | RETURN_TYPES = ("LATENT",) 121 | CATEGORY = "Art Venture/Inpainting" 122 | FUNCTION = "overlay" 123 | 124 | def overlay(self, original: Dict, inpainted: Dict, mask: torch.Tensor): 125 | s_original: torch.Tensor = original["samples"] 126 | s_inpainted: torch.Tensor = inpainted["samples"] 127 | 128 | if s_original.shape[0] != s_inpainted.shape[0]: 129 | raise ValueError("original and inpainted must have same batch 
size") 130 | 131 | if s_original.shape[0] != mask.shape[0]: 132 | raise ValueError("original and mask must have same batch size") 133 | 134 | overlays = [] 135 | 136 | for org, inp, msk in zip(s_original, s_inpainted, mask): 137 | latmask = tensor2pil(msk.unsqueeze(0), "L").convert("RGB").resize((org.shape[2], org.shape[1])) 138 | latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 139 | latmask = latmask[0] 140 | latmask = np.around(latmask) 141 | latmask = np.tile(latmask[None], (4, 1, 1)) 142 | 143 | msk = torch.asarray(1.0 - latmask) 144 | nmask = torch.asarray(latmask) 145 | 146 | overlayed = inp * nmask + org * msk 147 | overlays.append(overlayed) 148 | 149 | samples = torch.stack(overlays) 150 | return ({"samples": samples},) 151 | 152 | 153 | class OverlayInpaintedImage: 154 | @classmethod 155 | def INPUT_TYPES(s): 156 | return { 157 | "required": { 158 | "inpainted": ("IMAGE",), 159 | "overlay_image": ("IMAGE",), 160 | "crop_region": ("CROP_REGION",), 161 | } 162 | } 163 | 164 | RETURN_TYPES = ("IMAGE",) 165 | CATEGORY = "Art Venture/Inpainting" 166 | FUNCTION = "overlay" 167 | 168 | def overlay(self, inpainted: torch.Tensor, overlay_image: torch.Tensor, crop_region: torch.Tensor): 169 | if inpainted.shape[0] != overlay_image.shape[0]: 170 | raise ValueError("inpainted and overlay_image must have same batch size") 171 | if inpainted.shape[0] != crop_region.shape[0]: 172 | raise ValueError("inpainted and crop_region must have same batch size") 173 | 174 | images = [] 175 | for image, overlay, region in zip(inpainted, overlay_image, crop_region): 176 | image = tensor2pil(image.unsqueeze(0)) 177 | overlay = tensor2pil(overlay.unsqueeze(0), mode="RGBA") 178 | 179 | x1, y1, x2, y2 = region.tolist() 180 | if (x1, y1, x2, y2) == (0, 0, 0, 0): 181 | pass 182 | else: 183 | base_image = Image.new("RGBA", (overlay.width, overlay.height)) 184 | image = resize_image(image, x2 - x1, y2 - y1, ResizeMode.RESIZE_TO_FILL) 185 | base_image.paste(image, (x1, y1)) 186 | image = base_image 187 | 188 | image = image.convert("RGBA") 189 | image.alpha_composite(overlay) 190 | image = image.convert("RGB") 191 | 192 | images.append(pil2tensor(image)) 193 | 194 | return (torch.cat(images, dim=0),) 195 | 196 | 197 | NODE_CLASS_MAPPINGS = { 198 | "AV_SAMLoader": SAMLoader, 199 | "GetSAMEmbedding": GetSAMEmbedding, 200 | "SAMEmbeddingToImage": SAMEmbeddingToImage, 201 | "LaMaInpaint": LaMaInpaint, 202 | "PrepareImageAndMaskForInpaint": PrepareImageAndMaskForInpaint, 203 | "OverlayInpaintedLatent": OverlayInpaintedLatent, 204 | "OverlayInpaintedImage": OverlayInpaintedImage, 205 | } 206 | 207 | NODE_DISPLAY_NAME_MAPPINGS = { 208 | "AV_SAMLoader": "SAM Loader", 209 | "GetSAMEmbedding": "Get SAM Embedding", 210 | "SAMEmbeddingToImage": "SAM Embedding to Image", 211 | "LaMaInpaint": "LaMa Remove Object", 212 | "PrepareImageAndMaskForInpaint": "Prepare Image & Mask for Inpaint", 213 | "OverlayInpaintedLatent": "Overlay Inpainted Latent", 214 | "OverlayInpaintedImage": "Overlay Inpainted Image", 215 | } 216 | -------------------------------------------------------------------------------- /modules/inpaint/sam/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import folder_paths 7 | import comfy.model_management as model_management 8 | import comfy.utils 9 | 10 | from ...utils import ensure_package, tensor2pil, pil2tensor 11 | 12 | 
folder_paths.folder_names_and_paths["sams"] = ( 13 | [ 14 | os.path.join(folder_paths.models_dir, "sams"), 15 | ], 16 | folder_paths.supported_pt_extensions, 17 | ) 18 | 19 | gpu = model_management.get_torch_device() 20 | cpu = torch.device("cpu") 21 | 22 | 23 | class SAMLoader: 24 | @classmethod 25 | def INPUT_TYPES(cls): 26 | return { 27 | "required": { 28 | "model_name": (folder_paths.get_filename_list("sams"),), 29 | } 30 | } 31 | 32 | RETURN_TYPES = ("AV_SAM_MODEL",) 33 | RETURN_NAMES = ("sam_model",) 34 | FUNCTION = "load_model" 35 | CATEGORY = "Art Venture/Segmentation" 36 | 37 | def load_model(self, model_name): 38 | modelname = folder_paths.get_full_path("sams", model_name) 39 | 40 | state_dict = comfy.utils.load_torch_file(modelname) 41 | encoder_size = state_dict["image_encoder.patch_embed.proj.bias"].shape[0] 42 | 43 | if encoder_size == 1280: 44 | model_kind = "vit_h" 45 | elif encoder_size == 1024: 46 | model_kind = "vit_l" 47 | else: 48 | model_kind = "vit_b" 49 | 50 | ensure_package("segment_anything") 51 | from segment_anything import sam_model_registry 52 | 53 | sam = sam_model_registry[model_kind]() 54 | sam.load_state_dict(state_dict) 55 | 56 | return (sam,) 57 | 58 | 59 | class GetSAMEmbedding: 60 | @classmethod 61 | def INPUT_TYPES(s): 62 | return { 63 | "required": { 64 | "sam_model": ("AV_SAM_MODEL",), 65 | "image": ("IMAGE",), 66 | }, 67 | "optional": {"device_mode": (["AUTO", "Prefer GPU", "CPU"],)}, 68 | } 69 | 70 | RETURN_TYPES = ("SAM_EMBEDDING",) 71 | CATEGORY = "Art Venture/Segmentation" 72 | FUNCTION = "get_sam_embedding" 73 | 74 | def get_sam_embedding(self, image, sam_model, device_mode="AUTO"): 75 | device = gpu if device_mode != "CPU" else cpu 76 | sam_model.to(device) 77 | 78 | ensure_package("segment_anything") 79 | from segment_anything import SamPredictor 80 | 81 | try: 82 | predictor = SamPredictor(sam_model) 83 | image = tensor2pil(image) 84 | image = image.convert("RGB") 85 | image = np.array(image) 86 | predictor.set_image(image, "RGB") 87 | embedding = predictor.get_image_embedding().cpu().numpy() 88 | 89 | return (embedding,) 90 | finally: 91 | if device_mode == "AUTO": 92 | sam_model.to(cpu) 93 | 94 | 95 | class SAMEmbeddingToImage: 96 | @classmethod 97 | def INPUT_TYPES(s): 98 | return { 99 | "required": { 100 | "embedding": ("SAM_EMBEDDING",), 101 | }, 102 | } 103 | 104 | RETURN_TYPES = ("IMAGE",) 105 | CATEGORY = "Art Venture/Segmentation" 106 | FUNCTION = "sam_embedding_to_noise_image" 107 | 108 | def sam_embedding_to_noise_image(self, embedding: np.ndarray): 109 | # Flatten the array to a 1D array 110 | flat_arr = embedding.flatten() 111 | # Convert the 1D array to bytes 112 | bytes_arr = flat_arr.astype(np.float32).tobytes() 113 | # Convert bytes to RGBA PIL Image 114 | size = (embedding.shape[1] * 4, int(embedding.shape[2] * embedding.shape[3] / 4)) 115 | 116 | img = Image.frombytes("RGBA", size, bytes_arr) 117 | 118 | return (pil2tensor(img),) 119 | -------------------------------------------------------------------------------- /modules/interrogate/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | models/__pycache__/ 3 | *.pyc 4 | *.pth -------------------------------------------------------------------------------- /modules/interrogate/__init__.py: -------------------------------------------------------------------------------- 1 | from .blip_node import BlipLoader, BlipCaption, DownloadAndLoadBlip 2 | from .danbooru import DeepDanbooruCaption 3 | 4 | NODE_CLASS_MAPPINGS 
= { 5 | "BLIPLoader": BlipLoader, 6 | "BLIPCaption": BlipCaption, 7 | "DownloadAndLoadBlip": DownloadAndLoadBlip, 8 | "DeepDanbooruCaption": DeepDanbooruCaption, 9 | } 10 | NODE_DISPLAY_NAME_MAPPINGS = { 11 | "BLIPLoader": "BLIP Loader", 12 | "BLIPCaption": "BLIP Caption", 13 | "DownloadAndLoadBlip": "Download and Load BLIP Model", 14 | "DeepDanbooruCaption": "Deep Danbooru Caption", 15 | } 16 | 17 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 18 | -------------------------------------------------------------------------------- /modules/interrogate/blip_node.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torchvision import transforms 5 | from torchvision.transforms.functional import InterpolationMode 6 | 7 | import folder_paths 8 | from comfy.model_management import text_encoder_device, text_encoder_offload_device, soft_empty_cache 9 | 10 | from ..model_utils import download_file 11 | from ..utils import tensor2pil 12 | 13 | blips = {} 14 | blip_size = 384 15 | gpu = text_encoder_device() 16 | cpu = text_encoder_offload_device() 17 | model_dir = os.path.join(folder_paths.models_dir, "blip") 18 | models = { 19 | "model_base_caption_capfilt_large.pth": { 20 | "url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth", 21 | "sha": "96ac8749bd0a568c274ebe302b3a3748ab9be614c737f3d8c529697139174086", 22 | }, 23 | "model_base_capfilt_large.pth": { 24 | "url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth", 25 | "sha": "8f5187458d4d47bb87876faf3038d5947eff17475edf52cf47b62e84da0b235f", 26 | }, 27 | } 28 | 29 | 30 | folder_paths.folder_names_and_paths["blip"] = ( 31 | [model_dir], 32 | folder_paths.supported_pt_extensions, 33 | ) 34 | 35 | 36 | def packages(versions=False): 37 | import subprocess 38 | import sys 39 | 40 | return [ 41 | (r.decode().split("==")[0] if not versions else r.decode()) 42 | for r in subprocess.check_output([sys.executable, "-m", "pip", "freeze"]).split() 43 | ] 44 | 45 | 46 | def transformImage(input_image): 47 | raw_image = input_image.convert("RGB") 48 | raw_image = raw_image.resize((blip_size, blip_size)) 49 | transform = transforms.Compose( 50 | [ 51 | transforms.Resize(raw_image.size, interpolation=InterpolationMode.BICUBIC), 52 | transforms.ToTensor(), 53 | transforms.Normalize( 54 | (0.48145466, 0.4578275, 0.40821073), 55 | (0.26862954, 0.26130258, 0.27577711), 56 | ), 57 | ] 58 | ) 59 | image = transform(raw_image).unsqueeze(0).to(gpu) 60 | return image.view(1, -1, blip_size, blip_size) # Change the shape of the output tensor 61 | 62 | 63 | def load_blip(model_name): 64 | if model_name not in blips: 65 | blip_path = folder_paths.get_full_path("blip", model_name) 66 | 67 | from .models.blip import blip_decoder 68 | 69 | current_dir = os.path.dirname(os.path.realpath(__file__)) 70 | med_config = os.path.join(current_dir, "configs", "med_config.json") 71 | blip = blip_decoder( 72 | pretrained=blip_path, 73 | image_size=blip_size, 74 | vit="base", 75 | med_config=med_config, 76 | ) 77 | blip.eval() 78 | blips[model_name] = blip 79 | 80 | return blips[model_name] 81 | 82 | 83 | def unload_blip(): 84 | global blips 85 | if blips is not None and blips.is_auto_mode: 86 | blips = blips.to(cpu) 87 | 88 | soft_empty_cache() 89 | 90 | 91 | def join_caption(caption, prefix, suffix): 92 | if prefix: 93 | caption = prefix + ", " + caption 94 | if suffix: 95 | caption 
= caption + ", " + suffix 96 | return caption 97 | 98 | 99 | def blip_caption(model, image, min_length, max_length): 100 | image = tensor2pil(image) 101 | tensor = transformImage(image) 102 | 103 | with torch.no_grad(): 104 | caption = model.generate( 105 | tensor, 106 | sample=False, 107 | num_beams=1, 108 | min_length=min_length, 109 | max_length=max_length, 110 | ) 111 | return caption[0] 112 | 113 | 114 | class BlipLoader: 115 | @classmethod 116 | def INPUT_TYPES(s): 117 | return { 118 | "required": { 119 | "model_name": (folder_paths.get_filename_list("blip"),), 120 | }, 121 | } 122 | 123 | RETURN_TYPES = ("BLIP_MODEL",) 124 | FUNCTION = "load_blip" 125 | CATEGORY = "Art Venture/Captioning" 126 | 127 | def load_blip(self, model_name): 128 | return (load_blip(model_name),) 129 | 130 | 131 | class DownloadAndLoadBlip: 132 | @classmethod 133 | def INPUT_TYPES(s): 134 | return { 135 | "required": { 136 | "model_name": (list(models.keys()),), 137 | }, 138 | } 139 | 140 | RETURN_TYPES = ("BLIP_MODEL",) 141 | FUNCTION = "download_and_load_blip" 142 | CATEGORY = "Art Venture/Captioning" 143 | 144 | def download_and_load_blip(self, model_name): 145 | if model_name not in folder_paths.get_filename_list("blip"): 146 | model_info = models[model_name] 147 | download_file( 148 | model_info["url"], 149 | os.path.join(model_dir, model_name), 150 | model_info["sha"], 151 | ) 152 | 153 | return (load_blip(model_name),) 154 | 155 | 156 | class BlipCaption: 157 | @classmethod 158 | def INPUT_TYPES(s): 159 | return { 160 | "required": { 161 | "image": ("IMAGE",), 162 | "min_length": ( 163 | "INT", 164 | { 165 | "default": 24, 166 | "min": 0, # minimum value 167 | "max": 200, # maximum value 168 | "step": 1, # slider's step 169 | }, 170 | ), 171 | "max_length": ( 172 | "INT", 173 | { 174 | "default": 48, 175 | "min": 0, # minimum value 176 | "max": 200, # maximum value 177 | "step": 1, # slider's step 178 | }, 179 | ), 180 | }, 181 | "optional": { 182 | "device_mode": (["AUTO", "Prefer GPU", "CPU"],), 183 | "prefix": ("STRING", {"default": ""}), 184 | "suffix": ("STRING", {"default": ""}), 185 | "enabled": ("BOOLEAN", {"default": True}), 186 | "blip_model": ("BLIP_MODEL",), 187 | }, 188 | } 189 | 190 | RETURN_TYPES = ("STRING",) 191 | RETURN_NAMES = ("caption",) 192 | OUTPUT_IS_LIST = (True,) 193 | FUNCTION = "blip_caption" 194 | CATEGORY = "Art Venture/Captioning" 195 | 196 | def blip_caption( 197 | self, image, min_length, max_length, device_mode="AUTO", prefix="", suffix="", enabled=True, blip_model=None 198 | ): 199 | if not enabled: 200 | return ([join_caption("", prefix, suffix)],) 201 | 202 | if blip_model is None: 203 | downloader = DownloadAndLoadBlip() 204 | blip_model = downloader.download_and_load_blip("model_base_caption_capfilt_large.pth")[0] 205 | 206 | device = gpu if device_mode != "CPU" else cpu 207 | blip_model = blip_model.to(device) 208 | 209 | try: 210 | captions = [] 211 | 212 | with torch.no_grad(): 213 | for img in image: 214 | img = tensor2pil(img) 215 | tensor = transformImage(img) 216 | caption = blip_model.generate( 217 | tensor, 218 | sample=False, 219 | num_beams=1, 220 | min_length=min_length, 221 | max_length=max_length, 222 | ) 223 | 224 | caption = join_caption(caption[0], prefix, suffix) 225 | captions.append(caption) 226 | 227 | return (captions,) 228 | except: 229 | raise 230 | finally: 231 | if device_mode == "AUTO": 232 | blip_model = blip_model.to(cpu) 233 | -------------------------------------------------------------------------------- 
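A minimal usage sketch for the captioning helpers defined in modules/interrogate/blip_node.py above, shown here only as an illustration: it assumes a working ComfyUI runtime (so folder_paths and comfy.model_management resolve), that the repository root is importable under the illustrative modules.interrogate.blip_node path, and network access for the first checkpoint download.

import torch

# Illustrative import path; adjust to however this package is mounted in your ComfyUI install.
from modules.interrogate.blip_node import BlipCaption, DownloadAndLoadBlip

# Download (if missing) and load the default captioning checkpoint.
blip_model = DownloadAndLoadBlip().download_and_load_blip("model_base_caption_capfilt_large.pth")[0]

# Dummy batch in ComfyUI's [batch, height, width, channel] float32 [0, 1] layout.
images = torch.rand(1, 384, 384, 3)

(captions,) = BlipCaption().blip_caption(
    images,
    min_length=24,
    max_length=48,
    prefix="",
    suffix="high quality",
    blip_model=blip_model,
)
print(captions)  # e.g. ["a man riding a horse, high quality"]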
/modules/interrogate/configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /modules/interrogate/danbooru.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import torch 4 | import numpy as np 5 | 6 | import folder_paths 7 | from comfy.model_management import text_encoder_device, text_encoder_offload_device, soft_empty_cache 8 | 9 | from ..image_utils import resize_image 10 | from ..model_utils import download_file 11 | from ..utils import is_junction, tensor2pil 12 | from .blip_node import join_caption 13 | 14 | danbooru = None 15 | blip_size = 384 16 | gpu = text_encoder_device() 17 | cpu = text_encoder_offload_device() 18 | model_dir = os.path.join(folder_paths.models_dir, "blip") 19 | model_url = "https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt" 20 | model_sha = "3841542cda4dd037da12a565e854b3347bb2eec8fbcd95ea3941b2c68990a355" 21 | re_special = re.compile(r"([\\()])") 22 | 23 | 24 | def load_danbooru(device_mode): 25 | global danbooru 26 | if danbooru is None: 27 | if not os.path.exists(model_dir) and not is_junction(model_dir): 28 | os.makedirs(model_dir, exist_ok=True) 29 | 30 | model_path = os.path.join(model_dir, "model-resnet_custom_v3.pt") 31 | download_file(model_url, model_path, model_sha) 32 | 33 | from .models.deepbooru_model import DeepDanbooruModel 34 | 35 | danbooru = DeepDanbooruModel() 36 | danbooru.load_state_dict(torch.load(model_path, map_location="cpu")) 37 | danbooru.eval() 38 | 39 | if device_mode != "CPU": 40 | danbooru = danbooru.to(gpu) 41 | 42 | danbooru.is_auto_mode = device_mode == "AUTO" 43 | 44 | return danbooru 45 | 46 | 47 | def unload_danbooru(): 48 | global danbooru 49 | if danbooru is not None and danbooru.is_auto_mode: 50 | danbooru = danbooru.to(cpu) 51 | 52 | soft_empty_cache() 53 | 54 | 55 | class DeepDanbooruCaption: 56 | def __init__(self): 57 | pass 58 | 59 | @classmethod 60 | def INPUT_TYPES(s): 61 | return { 62 | "required": { 63 | "image": ("IMAGE",), 64 | "threshold": ("FLOAT", {"default": 0.5, "min": 0, "max": 1, "step": 0.01}), 65 | "sort_alpha": ("BOOLEAN", {"default": True}), 66 | "use_spaces": ("BOOLEAN", {"default": True}), 67 | "escape": ("BOOLEAN", {"default": True}), 68 | "filter_tags": ("STRING", {"default": "blacklist", "multiline": True}), 69 | }, 70 | "optional": { 71 | "device_mode": (["AUTO", "Prefer GPU", "CPU"],), 72 | "prefix": ("STRING", {"default": ""}), 73 | "suffix": ("STRING", {"default": ""}), 74 | "enabled": ("BOOLEAN", {"default": True}), 75 | }, 76 | } 77 | 78 | RETURN_TYPES = ("STRING",) 79 | RETURN_NAMES = ("caption",) 80 | OUTPUT_IS_LIST = (True,) 81 | FUNCTION = "caption" 82 | CATEGORY = "Art Venture/Utils" 83 | 84 | def caption( 85 | self, 86 | image, 87 | threshold, 88 | sort_alpha, 89 | use_spaces, 90 | escape, 91 | filter_tags, 92 
| device_mode="AUTO", 93 | prefix="", 94 | suffix="", 95 | enabled=True, 96 | ): 97 | if not enabled: 98 | return ([join_caption("", prefix, suffix)],) 99 | 100 | model = load_danbooru(device_mode) 101 | 102 | try: 103 | captions = [] 104 | 105 | for img in image: 106 | img = tensor2pil(img) 107 | img = resize_image(img.convert("RGB"), 512, 512, resize_mode=2) 108 | arr = np.expand_dims(np.array(img, dtype=np.float32), 0) / 255 109 | 110 | with torch.no_grad(): 111 | x = torch.from_numpy(arr).to(gpu) 112 | y = model(x)[0].detach().cpu().numpy() 113 | 114 | probability_dict = {} 115 | 116 | for tag, probability in zip(model.tags, y): 117 | if probability < threshold: 118 | continue 119 | 120 | if tag.startswith("rating:"): 121 | continue 122 | 123 | probability_dict[tag] = probability 124 | 125 | if sort_alpha: 126 | tags = sorted(probability_dict) 127 | else: 128 | tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])] 129 | 130 | res = [] 131 | filtertags = {x.strip().replace(" ", "_") for x in filter_tags.split(",")} 132 | 133 | for tag in [x for x in tags if x not in filtertags]: 134 | probability = probability_dict[tag] 135 | tag_outformat = tag 136 | if use_spaces: 137 | tag_outformat = tag_outformat.replace("_", " ") 138 | if escape: 139 | tag_outformat = re.sub(re_special, r"\\\1", tag_outformat) 140 | 141 | res.append(tag_outformat) 142 | 143 | caption = ", ".join(res) 144 | caption = join_caption(caption, prefix, suffix) 145 | captions.append(caption) 146 | 147 | return (captions,) 148 | except: 149 | raise 150 | finally: 151 | unload_danbooru() 152 | -------------------------------------------------------------------------------- /modules/interrogate/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sipherxyz/comfyui-art-venture/d78b709e3164c5cc410b55fe822df3904216f546/modules/interrogate/models/__init__.py -------------------------------------------------------------------------------- /modules/interrogate/models/blip.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | """ 8 | import warnings 9 | 10 | warnings.filterwarnings("ignore") 11 | 12 | import os 13 | from urllib.parse import urlparse 14 | 15 | import torch 16 | from torch import nn 17 | from timm.models.hub import download_cached_file 18 | from transformers import BertTokenizer 19 | 20 | from .med import BertConfig, BertLMHeadModel, BertModel 21 | from .vit import VisionTransformer, interpolate_pos_embed 22 | 23 | 24 | class BLIP_Base(nn.Module): 25 | def __init__( 26 | self, 27 | med_config="configs/med_config.json", 28 | image_size=224, 29 | vit="base", 30 | vit_grad_ckpt=False, 31 | vit_ckpt_layer=0, 32 | ): 33 | """ 34 | Args: 35 | med_config (str): path for the mixture of encoder-decoder model's configuration file 36 | image_size (int): input image size 37 | vit (str): model size of vision transformer 38 | """ 39 | super().__init__() 40 | 41 | self.visual_encoder, vision_width = create_vit( 42 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer 43 | ) 44 | self.tokenizer = init_tokenizer() 45 | med_config = BertConfig.from_json_file(med_config) 46 | med_config.encoder_width = vision_width 47 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 48 | 49 | def forward(self, image, caption, mode): 50 | assert mode in [ 51 | "image", 52 | "text", 53 | "multimodal", 54 | ], "mode parameter must be image, text, or multimodal" 55 | text = self.tokenizer(caption, return_tensors="pt").to(image.device) 56 | 57 | if mode == "image": 58 | # return image features 59 | image_embeds = self.visual_encoder(image) 60 | return image_embeds 61 | 62 | elif mode == "text": 63 | # return text features 64 | text_output = self.text_encoder( 65 | text.input_ids, 66 | attention_mask=text.attention_mask, 67 | return_dict=True, 68 | mode="text", 69 | ) 70 | return text_output.last_hidden_state 71 | 72 | elif mode == "multimodal": 73 | # return multimodel features 74 | image_embeds = self.visual_encoder(image) 75 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 76 | image.device 77 | ) 78 | 79 | text.input_ids[:, 0] = self.tokenizer.enc_token_id 80 | output = self.text_encoder( 81 | text.input_ids, 82 | attention_mask=text.attention_mask, 83 | encoder_hidden_states=image_embeds, 84 | encoder_attention_mask=image_atts, 85 | return_dict=True, 86 | ) 87 | return output.last_hidden_state 88 | 89 | 90 | class BLIP_Decoder(nn.Module): 91 | def __init__( 92 | self, 93 | med_config="configs/med_config.json", 94 | image_size=384, 95 | vit="base", 96 | vit_grad_ckpt=False, 97 | vit_ckpt_layer=0, 98 | prompt="a picture of ", 99 | ): 100 | """ 101 | Args: 102 | med_config (str): path for the mixture of encoder-decoder model's configuration file 103 | image_size (int): input image size 104 | vit (str): model size of vision transformer 105 | """ 106 | super().__init__() 107 | 108 | self.visual_encoder, vision_width = create_vit( 109 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer 110 | ) 111 | self.tokenizer = init_tokenizer() 112 | med_config = BertConfig.from_json_file(med_config) 113 | med_config.encoder_width = vision_width 114 | self.text_decoder = BertLMHeadModel(config=med_config) 115 | 116 | self.prompt = prompt 117 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 118 | 119 | def forward(self, image, caption): 120 | image_embeds = self.visual_encoder(image) 121 | image_atts = 
torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 122 | image.device 123 | ) 124 | 125 | text = self.tokenizer( 126 | caption, 127 | padding="longest", 128 | truncation=True, 129 | max_length=40, 130 | return_tensors="pt", 131 | ).to(image.device) 132 | 133 | text.input_ids[:, 0] = self.tokenizer.bos_token_id 134 | 135 | decoder_targets = text.input_ids.masked_fill( 136 | text.input_ids == self.tokenizer.pad_token_id, -100 137 | ) 138 | decoder_targets[:, : self.prompt_length] = -100 139 | 140 | decoder_output = self.text_decoder( 141 | text.input_ids, 142 | attention_mask=text.attention_mask, 143 | encoder_hidden_states=image_embeds, 144 | encoder_attention_mask=image_atts, 145 | labels=decoder_targets, 146 | return_dict=True, 147 | ) 148 | loss_lm = decoder_output.loss 149 | 150 | return loss_lm 151 | 152 | def generate( 153 | self, 154 | image, 155 | sample=False, 156 | num_beams=3, 157 | max_length=30, 158 | min_length=10, 159 | top_p=0.9, 160 | repetition_penalty=1.0, 161 | ): 162 | image_embeds = self.visual_encoder(image) 163 | 164 | if not sample: 165 | image_embeds = image_embeds.repeat_interleave(num_beams, dim=0) 166 | 167 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 168 | image.device 169 | ) 170 | model_kwargs = { 171 | "encoder_hidden_states": image_embeds, 172 | "encoder_attention_mask": image_atts, 173 | } 174 | 175 | prompt = [self.prompt] * image.size(0) 176 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( 177 | image.device 178 | ) 179 | input_ids[:, 0] = self.tokenizer.bos_token_id 180 | input_ids = input_ids[:, :-1] 181 | 182 | if sample: 183 | # nucleus sampling 184 | outputs = self.text_decoder.generate( 185 | input_ids=input_ids, 186 | max_length=max_length, 187 | min_length=min_length, 188 | do_sample=True, 189 | top_p=top_p, 190 | num_return_sequences=1, 191 | eos_token_id=self.tokenizer.sep_token_id, 192 | pad_token_id=self.tokenizer.pad_token_id, 193 | repetition_penalty=1.1, 194 | **model_kwargs 195 | ) 196 | else: 197 | # beam search 198 | outputs = self.text_decoder.generate( 199 | input_ids=input_ids, 200 | max_length=max_length, 201 | min_length=min_length, 202 | num_beams=num_beams, 203 | eos_token_id=self.tokenizer.sep_token_id, 204 | pad_token_id=self.tokenizer.pad_token_id, 205 | repetition_penalty=repetition_penalty, 206 | **model_kwargs 207 | ) 208 | 209 | captions = [] 210 | for output in outputs: 211 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 212 | captions.append(caption[len(self.prompt) :]) 213 | return captions 214 | 215 | 216 | def blip_decoder(pretrained="", **kwargs): 217 | model = BLIP_Decoder(**kwargs) 218 | if pretrained: 219 | model, msg = load_checkpoint(model, pretrained) 220 | assert len(msg.missing_keys) == 0 221 | return model 222 | 223 | 224 | def blip_feature_extractor(pretrained="", **kwargs): 225 | model = BLIP_Base(**kwargs) 226 | if pretrained: 227 | model, msg = load_checkpoint(model, pretrained) 228 | assert len(msg.missing_keys) == 0 229 | return model 230 | 231 | 232 | def init_tokenizer(): 233 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 234 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 235 | tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]}) 236 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 237 | return tokenizer 238 | 239 | 240 | def create_vit( 241 | vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0 242 | ): 243 | assert vit in ["base", 
"large"], "vit parameter must be base or large" 244 | if vit == "base": 245 | vision_width = 768 246 | visual_encoder = VisionTransformer( 247 | img_size=image_size, 248 | patch_size=16, 249 | embed_dim=vision_width, 250 | depth=12, 251 | num_heads=12, 252 | use_grad_checkpointing=use_grad_checkpointing, 253 | ckpt_layer=ckpt_layer, 254 | drop_path_rate=0 or drop_path_rate, 255 | ) 256 | elif vit == "large": 257 | vision_width = 1024 258 | visual_encoder = VisionTransformer( 259 | img_size=image_size, 260 | patch_size=16, 261 | embed_dim=vision_width, 262 | depth=24, 263 | num_heads=16, 264 | use_grad_checkpointing=use_grad_checkpointing, 265 | ckpt_layer=ckpt_layer, 266 | drop_path_rate=0.1 or drop_path_rate, 267 | ) 268 | return visual_encoder, vision_width 269 | 270 | 271 | def is_url(url_or_filename): 272 | parsed = urlparse(url_or_filename) 273 | return parsed.scheme in ("http", "https") 274 | 275 | 276 | def load_checkpoint(model, url_or_filename): 277 | if is_url(url_or_filename): 278 | cached_file = download_cached_file( 279 | url_or_filename, check_hash=False, progress=True 280 | ) 281 | checkpoint = torch.load(cached_file, map_location="cpu") 282 | elif os.path.isfile(url_or_filename): 283 | checkpoint = torch.load(url_or_filename, map_location="cpu") 284 | else: 285 | raise RuntimeError("checkpoint url or path is invalid") 286 | 287 | state_dict = checkpoint["model"] 288 | 289 | state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed( 290 | state_dict["visual_encoder.pos_embed"], model.visual_encoder 291 | ) 292 | if "visual_encoder_m.pos_embed" in model.state_dict().keys(): 293 | state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed( 294 | state_dict["visual_encoder_m.pos_embed"], model.visual_encoder_m 295 | ) 296 | for key in model.state_dict().keys(): 297 | if key in state_dict.keys(): 298 | if state_dict[key].shape != model.state_dict()[key].shape: 299 | del state_dict[key] 300 | 301 | msg = model.load_state_dict(state_dict, strict=False) 302 | print("load checkpoint from %s" % url_or_filename) 303 | return model, msg 304 | -------------------------------------------------------------------------------- /modules/interrogate/models/blip_itm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | from .blip import create_vit, init_tokenizer, load_checkpoint 7 | from .med import BertConfig, BertModel 8 | 9 | 10 | class BLIP_ITM(nn.Module): 11 | def __init__( 12 | self, 13 | med_config="configs/med_config.json", 14 | image_size=384, 15 | vit="base", 16 | vit_grad_ckpt=False, 17 | vit_ckpt_layer=0, 18 | embed_dim=256, 19 | ): 20 | """ 21 | Args: 22 | med_config (str): path for the mixture of encoder-decoder model's configuration file 23 | image_size (int): input image size 24 | vit (str): model size of vision transformer 25 | """ 26 | super().__init__() 27 | 28 | self.visual_encoder, vision_width = create_vit( 29 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer 30 | ) 31 | self.tokenizer = init_tokenizer() 32 | med_config = BertConfig.from_json_file(med_config) 33 | med_config.encoder_width = vision_width 34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 35 | 36 | text_width = self.text_encoder.config.hidden_size 37 | 38 | self.vision_proj = nn.Linear(vision_width, embed_dim) 39 | self.text_proj = nn.Linear(text_width, embed_dim) 40 | 41 | self.itm_head = nn.Linear(text_width, 2) 42 | 43 | def forward(self, image, 
caption, match_head="itm"): 44 | image_embeds = self.visual_encoder(image) 45 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 46 | image.device 47 | ) 48 | 49 | text = self.tokenizer( 50 | caption, 51 | padding="max_length", 52 | truncation=True, 53 | max_length=35, 54 | return_tensors="pt", 55 | ).to(image.device) 56 | 57 | if match_head == "itm": 58 | output = self.text_encoder( 59 | text.input_ids, 60 | attention_mask=text.attention_mask, 61 | encoder_hidden_states=image_embeds, 62 | encoder_attention_mask=image_atts, 63 | return_dict=True, 64 | ) 65 | itm_output = self.itm_head(output.last_hidden_state[:, 0, :]) 66 | return itm_output 67 | 68 | elif match_head == "itc": 69 | text_output = self.text_encoder( 70 | text.input_ids, 71 | attention_mask=text.attention_mask, 72 | return_dict=True, 73 | mode="text", 74 | ) 75 | image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) 76 | text_feat = F.normalize( 77 | self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1 78 | ) 79 | 80 | sim = image_feat @ text_feat.t() 81 | return sim 82 | 83 | 84 | def blip_itm(pretrained="", **kwargs): 85 | model = BLIP_ITM(**kwargs) 86 | if pretrained: 87 | model, msg = load_checkpoint(model, pretrained) 88 | assert len(msg.missing_keys) == 0 89 | return model 90 | -------------------------------------------------------------------------------- /modules/interrogate/models/blip_nlvr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from timm.models.hub import download_cached_file 5 | 6 | 7 | from .med import BertConfig 8 | from .nlvr_encoder import BertModel 9 | from .vit import interpolate_pos_embed 10 | from .blip import create_vit, init_tokenizer, is_url 11 | 12 | 13 | class BLIP_NLVR(nn.Module): 14 | def __init__( 15 | self, 16 | med_config="configs/med_config.json", 17 | image_size=480, 18 | vit="base", 19 | vit_grad_ckpt=False, 20 | vit_ckpt_layer=0, 21 | ): 22 | """ 23 | Args: 24 | med_config (str): path for the mixture of encoder-decoder model's configuration file 25 | image_size (int): input image size 26 | vit (str): model size of vision transformer 27 | """ 28 | super().__init__() 29 | 30 | self.visual_encoder, vision_width = create_vit( 31 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1 32 | ) 33 | self.tokenizer = init_tokenizer() 34 | med_config = BertConfig.from_json_file(med_config) 35 | med_config.encoder_width = vision_width 36 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 37 | 38 | self.cls_head = nn.Sequential( 39 | nn.Linear( 40 | self.text_encoder.config.hidden_size, 41 | self.text_encoder.config.hidden_size, 42 | ), 43 | nn.ReLU(), 44 | nn.Linear(self.text_encoder.config.hidden_size, 2), 45 | ) 46 | 47 | def forward(self, image, text, targets, train=True): 48 | image_embeds = self.visual_encoder(image) 49 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 50 | image.device 51 | ) 52 | image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0)) 53 | 54 | text = self.tokenizer(text, padding="longest", return_tensors="pt").to( 55 | image.device 56 | ) 57 | text.input_ids[:, 0] = self.tokenizer.enc_token_id 58 | 59 | output = self.text_encoder( 60 | text.input_ids, 61 | attention_mask=text.attention_mask, 62 | encoder_hidden_states=[image0_embeds, image1_embeds], 63 | encoder_attention_mask=[ 64 | image_atts[: image0_embeds.size(0)], 65 
| image_atts[image0_embeds.size(0) :], 66 | ], 67 | return_dict=True, 68 | ) 69 | hidden_state = output.last_hidden_state[:, 0, :] 70 | prediction = self.cls_head(hidden_state) 71 | 72 | if train: 73 | loss = F.cross_entropy(prediction, targets) 74 | return loss 75 | else: 76 | return prediction 77 | 78 | 79 | def blip_nlvr(pretrained="", **kwargs): 80 | model = BLIP_NLVR(**kwargs) 81 | if pretrained: 82 | model, msg = load_checkpoint(model, pretrained) 83 | print("missing keys:") 84 | print(msg.missing_keys) 85 | return model 86 | 87 | 88 | def load_checkpoint(model, url_or_filename): 89 | if is_url(url_or_filename): 90 | cached_file = download_cached_file( 91 | url_or_filename, check_hash=False, progress=True 92 | ) 93 | checkpoint = torch.load(cached_file, map_location="cpu") 94 | elif os.path.isfile(url_or_filename): 95 | checkpoint = torch.load(url_or_filename, map_location="cpu") 96 | else: 97 | raise RuntimeError("checkpoint url or path is invalid") 98 | state_dict = checkpoint["model"] 99 | 100 | state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed( 101 | state_dict["visual_encoder.pos_embed"], model.visual_encoder 102 | ) 103 | 104 | for key in list(state_dict.keys()): 105 | if "crossattention.self." in key: 106 | new_key0 = key.replace("self", "self0") 107 | new_key1 = key.replace("self", "self1") 108 | state_dict[new_key0] = state_dict[key] 109 | state_dict[new_key1] = state_dict[key] 110 | elif "crossattention.output.dense." in key: 111 | new_key0 = key.replace("dense", "dense0") 112 | new_key1 = key.replace("dense", "dense1") 113 | state_dict[new_key0] = state_dict[key] 114 | state_dict[new_key1] = state_dict[key] 115 | 116 | msg = model.load_state_dict(state_dict, strict=False) 117 | print("load checkpoint from %s" % url_or_filename) 118 | return model, msg 119 | -------------------------------------------------------------------------------- /modules/interrogate/models/blip_vqa.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | from .med import BertConfig, BertModel, BertLMHeadModel 8 | from .blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | 11 | class BLIP_VQA(nn.Module): 12 | def __init__( 13 | self, 14 | med_config="configs/med_config.json", 15 | image_size=480, 16 | vit="base", 17 | vit_grad_ckpt=False, 18 | vit_ckpt_layer=0, 19 | ): 20 | """ 21 | Args: 22 | med_config (str): path for the mixture of encoder-decoder model's configuration file 23 | image_size (int): input image size 24 | vit (str): model size of vision transformer 25 | """ 26 | super().__init__() 27 | 28 | self.visual_encoder, vision_width = create_vit( 29 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1 30 | ) 31 | self.tokenizer = init_tokenizer() 32 | 33 | encoder_config = BertConfig.from_json_file(med_config) 34 | encoder_config.encoder_width = vision_width 35 | self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) 36 | 37 | decoder_config = BertConfig.from_json_file(med_config) 38 | self.text_decoder = BertLMHeadModel(config=decoder_config) 39 | 40 | def forward( 41 | self, 42 | image, 43 | question, 44 | answer=None, 45 | n=None, 46 | weights=None, 47 | train=True, 48 | inference="rank", 49 | k_test=128, 50 | ): 51 | image_embeds = self.visual_encoder(image) 52 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 53 | image.device 54 | ) 55 | 56 | question = 
self.tokenizer( 57 | question, 58 | padding="longest", 59 | truncation=True, 60 | max_length=35, 61 | return_tensors="pt", 62 | ).to(image.device) 63 | question.input_ids[:, 0] = self.tokenizer.enc_token_id 64 | 65 | if train: 66 | """ 67 | n: number of answers for each question 68 | weights: weight for each answer 69 | """ 70 | answer = self.tokenizer(answer, padding="longest", return_tensors="pt").to( 71 | image.device 72 | ) 73 | answer.input_ids[:, 0] = self.tokenizer.bos_token_id 74 | answer_targets = answer.input_ids.masked_fill( 75 | answer.input_ids == self.tokenizer.pad_token_id, -100 76 | ) 77 | 78 | question_output = self.text_encoder( 79 | question.input_ids, 80 | attention_mask=question.attention_mask, 81 | encoder_hidden_states=image_embeds, 82 | encoder_attention_mask=image_atts, 83 | return_dict=True, 84 | ) 85 | 86 | question_states = [] 87 | question_atts = [] 88 | for b, n in enumerate(n): 89 | question_states += [question_output.last_hidden_state[b]] * n 90 | question_atts += [question.attention_mask[b]] * n 91 | question_states = torch.stack(question_states, 0) 92 | question_atts = torch.stack(question_atts, 0) 93 | 94 | answer_output = self.text_decoder( 95 | answer.input_ids, 96 | attention_mask=answer.attention_mask, 97 | encoder_hidden_states=question_states, 98 | encoder_attention_mask=question_atts, 99 | labels=answer_targets, 100 | return_dict=True, 101 | reduction="none", 102 | ) 103 | 104 | loss = weights * answer_output.loss 105 | loss = loss.sum() / image.size(0) 106 | 107 | return loss 108 | 109 | else: 110 | question_output = self.text_encoder( 111 | question.input_ids, 112 | attention_mask=question.attention_mask, 113 | encoder_hidden_states=image_embeds, 114 | encoder_attention_mask=image_atts, 115 | return_dict=True, 116 | ) 117 | 118 | if inference == "generate": 119 | num_beams = 3 120 | question_states = question_output.last_hidden_state.repeat_interleave( 121 | num_beams, dim=0 122 | ) 123 | question_atts = torch.ones( 124 | question_states.size()[:-1], dtype=torch.long 125 | ).to(question_states.device) 126 | model_kwargs = { 127 | "encoder_hidden_states": question_states, 128 | "encoder_attention_mask": question_atts, 129 | } 130 | 131 | bos_ids = torch.full( 132 | (image.size(0), 1), 133 | fill_value=self.tokenizer.bos_token_id, 134 | device=image.device, 135 | ) 136 | 137 | outputs = self.text_decoder.generate( 138 | input_ids=bos_ids, 139 | max_length=10, 140 | min_length=1, 141 | num_beams=num_beams, 142 | eos_token_id=self.tokenizer.sep_token_id, 143 | pad_token_id=self.tokenizer.pad_token_id, 144 | **model_kwargs 145 | ) 146 | 147 | answers = [] 148 | for output in outputs: 149 | answer = self.tokenizer.decode(output, skip_special_tokens=True) 150 | answers.append(answer) 151 | return answers 152 | 153 | elif inference == "rank": 154 | max_ids = self.rank_answer( 155 | question_output.last_hidden_state, 156 | question.attention_mask, 157 | answer.input_ids, 158 | answer.attention_mask, 159 | k_test, 160 | ) 161 | return max_ids 162 | 163 | def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k): 164 | num_ques = question_states.size(0) 165 | start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token 166 | 167 | start_output = self.text_decoder( 168 | start_ids, 169 | encoder_hidden_states=question_states, 170 | encoder_attention_mask=question_atts, 171 | return_dict=True, 172 | reduction="none", 173 | ) 174 | logits = start_output.logits[:, 0, :] # first token's logit 175 | 176 | # topk_probs: top-k probability 
177 | # topk_ids: [num_question, k] 178 | answer_first_token = answer_ids[:, 1] 179 | prob_first_token = F.softmax(logits, dim=1).index_select( 180 | dim=1, index=answer_first_token 181 | ) 182 | topk_probs, topk_ids = prob_first_token.topk(k, dim=1) 183 | 184 | # answer input: [num_question*k, answer_len] 185 | input_ids = [] 186 | input_atts = [] 187 | for b, topk_id in enumerate(topk_ids): 188 | input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) 189 | input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) 190 | input_ids = torch.cat(input_ids, dim=0) 191 | input_atts = torch.cat(input_atts, dim=0) 192 | 193 | targets_ids = input_ids.masked_fill( 194 | input_ids == self.tokenizer.pad_token_id, -100 195 | ) 196 | 197 | # repeat encoder's output for top-k answers 198 | question_states = tile(question_states, 0, k) 199 | question_atts = tile(question_atts, 0, k) 200 | 201 | output = self.text_decoder( 202 | input_ids, 203 | attention_mask=input_atts, 204 | encoder_hidden_states=question_states, 205 | encoder_attention_mask=question_atts, 206 | labels=targets_ids, 207 | return_dict=True, 208 | reduction="none", 209 | ) 210 | 211 | log_probs_sum = -output.loss 212 | log_probs_sum = log_probs_sum.view(num_ques, k) 213 | 214 | max_topk_ids = log_probs_sum.argmax(dim=1) 215 | max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids] 216 | 217 | return max_ids 218 | 219 | 220 | def blip_vqa(pretrained="", **kwargs): 221 | model = BLIP_VQA(**kwargs) 222 | if pretrained: 223 | model, msg = load_checkpoint(model, pretrained) 224 | # assert(len(msg.missing_keys)==0) 225 | return model 226 | 227 | 228 | def tile(x, dim, n_tile): 229 | init_dim = x.size(dim) 230 | repeat_idx = [1] * x.dim() 231 | repeat_idx[dim] = n_tile 232 | x = x.repeat(*(repeat_idx)) 233 | order_index = torch.LongTensor( 234 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 235 | ) 236 | return torch.index_select(x, dim, order_index.to(x.device)) 237 | -------------------------------------------------------------------------------- /modules/interrogate/transform/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24 | low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | scale = (n_bins - 1) / (high - low) 31 | offset = -low * scale 32 | table = np.arange(n_bins) * scale + offset 33 | table[table < 0] = 0 34 | table[table > n_bins - 1] = n_bins - 1 35 | table = table.clip(0, 255).astype(np.uint8) 36 | return table[ch] 37 | 38 | channels = [tune_channel(ch) for ch in cv2.split(img)] 39 | out = cv2.merge(channels) 40 | return out 41 | 42 | 43 | def equalize_func(img): 44 | ''' 45 | same output as PIL.ImageOps.equalize 46 | PIL's implementation is different from cv2.equalize 47 | ''' 48 | n_bins = 256 49 | 50 | def tune_channel(ch): 51 | hist = cv2.calcHist([ch], [0], None, 
[n_bins], [0, n_bins]) 52 | non_zero_hist = hist[hist != 0].reshape(-1) 53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 54 | if step == 0: return ch 55 | n = np.empty_like(hist) 56 | n[0] = step // 2 57 | n[1:] = hist[:-1] 58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 59 | return table[ch] 60 | 61 | channels = [tune_channel(ch) for ch in cv2.split(img)] 62 | out = cv2.merge(channels) 63 | return out 64 | 65 | 66 | def rotate_func(img, degree, fill=(0, 0, 0)): 67 | ''' 68 | like PIL, rotate by degree, not radians 69 | ''' 70 | H, W = img.shape[0], img.shape[1] 71 | center = W / 2, H / 2 72 | M = cv2.getRotationMatrix2D(center, degree, 1) 73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 74 | return out 75 | 76 | 77 | def solarize_func(img, thresh=128): 78 | ''' 79 | same output as PIL.ImageOps.posterize 80 | ''' 81 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 82 | table = table.clip(0, 255).astype(np.uint8) 83 | out = table[img] 84 | return out 85 | 86 | 87 | def color_func(img, factor): 88 | ''' 89 | same output as PIL.ImageEnhance.Color 90 | ''' 91 | ## implementation according to PIL definition, quite slow 92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 93 | # out = blend(degenerate, img, factor) 94 | # M = ( 95 | # np.eye(3) * factor 96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) 97 | # )[np.newaxis, np.newaxis, :] 98 | M = ( 99 | np.float32([ 100 | [0.886, -0.114, -0.114], 101 | [-0.587, 0.413, -0.587], 102 | [-0.299, -0.299, 0.701]]) * factor 103 | + np.float32([[0.114], [0.587], [0.299]]) 104 | ) 105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 106 | return out 107 | 108 | 109 | def contrast_func(img, factor): 110 | """ 111 | same output as PIL.ImageEnhance.Contrast 112 | """ 113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 114 | table = np.array([( 115 | el - mean) * factor + mean 116 | for el in range(256) 117 | ]).clip(0, 255).astype(np.uint8) 118 | out = table[img] 119 | return out 120 | 121 | 122 | def brightness_func(img, factor): 123 | ''' 124 | same output as PIL.ImageEnhance.Contrast 125 | ''' 126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 127 | out = table[img] 128 | return out 129 | 130 | 131 | def sharpness_func(img, factor): 132 | ''' 133 | The differences the this result and PIL are all on the 4 boundaries, the center 134 | areas are same 135 | ''' 136 | kernel = np.ones((3, 3), dtype=np.float32) 137 | kernel[1][1] = 5 138 | kernel /= 13 139 | degenerate = cv2.filter2D(img, -1, kernel) 140 | if factor == 0.0: 141 | out = degenerate 142 | elif factor == 1.0: 143 | out = img 144 | else: 145 | out = img.astype(np.float32) 146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 148 | out = out.astype(np.uint8) 149 | return out 150 | 151 | 152 | def shear_x_func(img, factor, fill=(0, 0, 0)): 153 | H, W = img.shape[0], img.shape[1] 154 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 156 | return out 157 | 158 | 159 | def translate_x_func(img, offset, fill=(0, 0, 0)): 160 | ''' 161 | same output as PIL.Image.transform 162 | ''' 163 | H, W = img.shape[0], img.shape[1] 164 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, 
flags=cv2.INTER_LINEAR).astype(np.uint8) 166 | return out 167 | 168 | 169 | def translate_y_func(img, offset, fill=(0, 0, 0)): 170 | ''' 171 | same output as PIL.Image.transform 172 | ''' 173 | H, W = img.shape[0], img.shape[1] 174 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 176 | return out 177 | 178 | 179 | def posterize_func(img, bits): 180 | ''' 181 | same output as PIL.ImageOps.posterize 182 | ''' 183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 184 | return out 185 | 186 | 187 | def shear_y_func(img, factor, fill=(0, 0, 0)): 188 | H, W = img.shape[0], img.shape[1] 189 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 191 | return out 192 | 193 | 194 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 195 | replace = np.array(replace, dtype=np.uint8) 196 | H, W = img.shape[0], img.shape[1] 197 | rh, rw = np.random.random(2) 198 | pad_size = pad_size // 2 199 | ch, cw = int(rh * H), int(rw * W) 200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 202 | out = img.copy() 203 | out[x1:x2, y1:y2, :] = replace 204 | return out 205 | 206 | 207 | ### level to args 208 | def enhance_level_to_args(MAX_LEVEL): 209 | def level_to_args(level): 210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 211 | return level_to_args 212 | 213 | 214 | def shear_level_to_args(MAX_LEVEL, replace_value): 215 | def level_to_args(level): 216 | level = (level / MAX_LEVEL) * 0.3 217 | if np.random.random() > 0.5: level = -level 218 | return (level, replace_value) 219 | 220 | return level_to_args 221 | 222 | 223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 224 | def level_to_args(level): 225 | level = (level / MAX_LEVEL) * float(translate_const) 226 | if np.random.random() > 0.5: level = -level 227 | return (level, replace_value) 228 | 229 | return level_to_args 230 | 231 | 232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 233 | def level_to_args(level): 234 | level = int((level / MAX_LEVEL) * cutout_const) 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def solarize_level_to_args(MAX_LEVEL): 241 | def level_to_args(level): 242 | level = int((level / MAX_LEVEL) * 256) 243 | return (level, ) 244 | return level_to_args 245 | 246 | 247 | def none_level_to_args(level): 248 | return () 249 | 250 | 251 | def posterize_level_to_args(MAX_LEVEL): 252 | def level_to_args(level): 253 | level = int((level / MAX_LEVEL) * 4) 254 | return (level, ) 255 | return level_to_args 256 | 257 | 258 | def rotate_level_to_args(MAX_LEVEL, replace_value): 259 | def level_to_args(level): 260 | level = (level / MAX_LEVEL) * 30 261 | if np.random.random() < 0.5: 262 | level = -level 263 | return (level, replace_value) 264 | 265 | return level_to_args 266 | 267 | 268 | func_dict = { 269 | 'Identity': identity_func, 270 | 'AutoContrast': autocontrast_func, 271 | 'Equalize': equalize_func, 272 | 'Rotate': rotate_func, 273 | 'Solarize': solarize_func, 274 | 'Color': color_func, 275 | 'Contrast': contrast_func, 276 | 'Brightness': brightness_func, 277 | 'Sharpness': sharpness_func, 278 | 'ShearX': shear_x_func, 279 | 'TranslateX': translate_x_func, 280 | 'TranslateY': translate_y_func, 281 | 'Posterize': posterize_func, 282 | 'ShearY': shear_y_func, 283 | } 284 | 285 | translate_const = 10 
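# These module-level defaults feed the *_level_to_args factories below:
# translate_const caps TranslateX/TranslateY at +/-10 pixels (level / MAX_LEVEL * translate_const),
# MAX_LEVEL is the magnitude scale that RandomAugment's M is measured against, and
# replace_value is the neutral grey fill for pixels exposed by Rotate/Shear/Translate.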
286 | MAX_LEVEL = 10 287 | replace_value = (128, 128, 128) 288 | arg_dict = { 289 | 'Identity': none_level_to_args, 290 | 'AutoContrast': none_level_to_args, 291 | 'Equalize': none_level_to_args, 292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 293 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 294 | 'Color': enhance_level_to_args(MAX_LEVEL), 295 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 296 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 299 | 'TranslateX': translate_level_to_args( 300 | translate_const, MAX_LEVEL, replace_value 301 | ), 302 | 'TranslateY': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 307 | } 308 | 309 | 310 | class RandomAugment(object): 311 | 312 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 313 | self.N = N 314 | self.M = M 315 | self.isPIL = isPIL 316 | if augs: 317 | self.augs = augs 318 | else: 319 | self.augs = list(arg_dict.keys()) 320 | 321 | def get_random_ops(self): 322 | sampled_ops = np.random.choice(self.augs, self.N) 323 | return [(op, 0.5, self.M) for op in sampled_ops] 324 | 325 | def __call__(self, img): 326 | if self.isPIL: 327 | img = np.array(img) 328 | ops = self.get_random_ops() 329 | for name, prob, level in ops: 330 | if np.random.random() > prob: 331 | continue 332 | args = arg_dict[name](level) 333 | img = func_dict[name](img, *args) 334 | return img 335 | 336 | 337 | if __name__ == '__main__': 338 | a = RandomAugment() 339 | img = np.random.randn(32, 32, 3) 340 | a(img) -------------------------------------------------------------------------------- /modules/ip_adapter_nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Tuple 3 | 4 | import folder_paths 5 | import comfy.clip_vision 6 | import comfy.controlnet 7 | import comfy.utils 8 | import comfy.model_management 9 | 10 | from .utils import load_module 11 | 12 | custom_nodes = folder_paths.get_folder_paths("custom_nodes") 13 | ip_adapter_dir_names = ["IPAdapter", "ComfyUI_IPAdapter_plus"] 14 | 15 | NODE_CLASS_MAPPINGS = {} 16 | NODE_DISPLAY_NAME_MAPPINGS = {} 17 | 18 | try: 19 | module_path = None 20 | 21 | for custom_node in custom_nodes: 22 | custom_node = custom_node if not os.path.islink(custom_node) else os.readlink(custom_node) 23 | for module_dir in ip_adapter_dir_names: 24 | if module_dir in os.listdir(custom_node): 25 | module_path = os.path.abspath(os.path.join(custom_node, module_dir)) 26 | break 27 | 28 | if module_path is None: 29 | raise Exception("Could not find IPAdapter nodes") 30 | 31 | module = load_module(module_path) 32 | print("Loaded IPAdapter nodes from", module_path) 33 | 34 | nodes: Dict = getattr(module, "NODE_CLASS_MAPPINGS") 35 | IPAdapterModelLoader = nodes.get("IPAdapterModelLoader") 36 | IPAdapterSimple = nodes.get("IPAdapter") 37 | 38 | loader = IPAdapterModelLoader() 39 | apply = IPAdapterSimple() 40 | 41 | class AV_IPAdapterPipe: 42 | @classmethod 43 | def INPUT_TYPES(cls): 44 | return { 45 | "required": { 46 | "ip_adapter_name": (folder_paths.get_filename_list("ipadapter"),), 47 | "clip_name": (folder_paths.get_filename_list("clip_vision"),), 48 | } 49 | } 50 | 51 | RETURN_TYPES = ("IPADAPTER",) 52 | RETURN_NAMES = ("pipeline",) 53 | CATEGORY = "Art Venture/IP Adapter" 54 | 
FUNCTION = "load_ip_adapter" 55 | 56 | def load_ip_adapter(self, ip_adapter_name, clip_name): 57 | ip_adapter = loader.load_ipadapter_model(ip_adapter_name)[0] 58 | 59 | clip_path = folder_paths.get_full_path("clip_vision", clip_name) 60 | clip_vision = comfy.clip_vision.load(clip_path) 61 | 62 | pipeline = {"ipadapter": {"model": ip_adapter}, "clipvision": {"model": clip_vision}} 63 | return (pipeline,) 64 | 65 | class AV_IPAdapter: 66 | @classmethod 67 | def INPUT_TYPES(cls): 68 | inputs = IPAdapterSimple.INPUT_TYPES() 69 | 70 | return { 71 | "required": { 72 | "ip_adapter_name": (["None"] + folder_paths.get_filename_list("ipadapter"),), 73 | "clip_name": (["None"] + folder_paths.get_filename_list("clip_vision"),), 74 | "model": ("MODEL",), 75 | "image": ("IMAGE",), 76 | "weight": ("FLOAT", {"default": 1.0, "min": -1, "max": 3, "step": 0.05}), 77 | }, 78 | "optional": { 79 | "ip_adapter_opt": ("IPADAPTER",), 80 | "clip_vision_opt": ("CLIP_VISION",), 81 | "attn_mask": ("MASK",), 82 | "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), 83 | "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), 84 | "weight_type": inputs["required"]["weight_type"], 85 | "enabled": ("BOOLEAN", {"default": True}), 86 | }, 87 | } 88 | 89 | RETURN_TYPES = ("MODEL", "IPADAPTER", "CLIP_VISION") 90 | RETURN_NAMES = ("model", "pipeline", "clip_vision") 91 | CATEGORY = "Art Venture/IP Adapter" 92 | FUNCTION = "apply_ip_adapter" 93 | 94 | def apply_ip_adapter( 95 | self, 96 | ip_adapter_name, 97 | clip_name, 98 | model, 99 | image, 100 | weight, 101 | ip_adapter_opt=None, 102 | clip_vision_opt=None, 103 | enabled=True, 104 | **kwargs, 105 | ): 106 | if not enabled: 107 | return (model, None, None) 108 | 109 | if ip_adapter_opt: 110 | if "ipadapter" in ip_adapter_opt: 111 | ip_adapter = ip_adapter_opt["ipadapter"]["model"] 112 | else: 113 | ip_adapter = ip_adapter_opt 114 | else: 115 | assert ip_adapter_name != "None", "IP Adapter name must be specified" 116 | ip_adapter = loader.load_ipadapter_model(ip_adapter_name)[0] 117 | 118 | if clip_vision_opt: 119 | clip_vision = clip_vision_opt 120 | elif ip_adapter_opt and "clipvision" in ip_adapter_opt: 121 | clip_vision = ip_adapter_opt["clipvision"]["model"] 122 | else: 123 | assert clip_name != "None", "Clip vision name must be specified" 124 | clip_path = folder_paths.get_full_path("clip_vision", clip_name) 125 | clip_vision = comfy.clip_vision.load(clip_path) 126 | 127 | pipeline = {"ipadapter": {"model": ip_adapter}, "clipvision": {"model": clip_vision}} 128 | 129 | res: Tuple = apply.apply_ipadapter(model, pipeline, image=image, weight=weight, **kwargs) 130 | res += (pipeline, clip_vision) 131 | 132 | return res 133 | 134 | NODE_CLASS_MAPPINGS.update( 135 | { 136 | "AV_IPAdapter": AV_IPAdapter, 137 | "AV_IPAdapterPipe": AV_IPAdapterPipe, 138 | } 139 | ) 140 | NODE_DISPLAY_NAME_MAPPINGS.update( 141 | { 142 | "AV_IPAdapter": "IP Adapter Apply", 143 | "AV_IPAdapterPipe": "IP Adapter Pipe", 144 | } 145 | ) 146 | 147 | except Exception as e: 148 | print(e) 149 | -------------------------------------------------------------------------------- /modules/isnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .segmenter import ISNetLoader, ISNetSegment, DownloadISNetModel 2 | 3 | NODE_CLASS_MAPPINGS = { 4 | "ISNetLoader": ISNetLoader, 5 | "ISNetSegment": ISNetSegment, 6 | "DownloadISNetModel": DownloadISNetModel, 7 | } 8 | NODE_DISPLAY_NAME_MAPPINGS = { 9 | "ISNetLoader": 
"ISNet Loader", 10 | "ISNetSegment": "ISNet Segment", 11 | "DownloadISNetModel": "Download and Load ISNet Model", 12 | } 13 | 14 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 15 | -------------------------------------------------------------------------------- /modules/isnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .isnet import ISNetBase 2 | from .isnet_dis import ISNetDIS 3 | -------------------------------------------------------------------------------- /modules/isnet/models/isnet_dis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .isnet import RSU4, RSU4F, RSU5, RSU6, RSU7, _upsample_like, ISNetBase 6 | 7 | 8 | bce_loss = nn.BCELoss(size_average=True) 9 | fea_loss = nn.MSELoss(size_average=True) 10 | kl_loss = nn.KLDivLoss(size_average=True) 11 | l1_loss = nn.L1Loss(size_average=True) 12 | smooth_l1_loss = nn.SmoothL1Loss(size_average=True) 13 | 14 | 15 | def muti_loss_fusion(preds, target): 16 | loss0 = 0.0 17 | loss = 0.0 18 | 19 | for i in range(0, len(preds)): 20 | # print("i: ", i, preds[i].shape) 21 | if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]: 22 | # tmp_target = _upsample_like(target,preds[i]) 23 | tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode="bilinear", align_corners=True) 24 | loss = loss + bce_loss(preds[i], tmp_target) 25 | else: 26 | loss = loss + bce_loss(preds[i], target) 27 | if i == 0: 28 | loss0 = loss 29 | return loss0, loss 30 | 31 | 32 | def muti_loss_fusion_kl(preds, target, dfs, fs, mode="MSE"): 33 | loss0 = 0.0 34 | loss = 0.0 35 | 36 | for i in range(0, len(preds)): 37 | # print("i: ", i, preds[i].shape) 38 | if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]: 39 | # tmp_target = _upsample_like(target,preds[i]) 40 | tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode="bilinear", align_corners=True) 41 | loss = loss + bce_loss(preds[i], tmp_target) 42 | else: 43 | loss = loss + bce_loss(preds[i], target) 44 | if i == 0: 45 | loss0 = loss 46 | 47 | for i in range(0, len(dfs)): 48 | if mode == "MSE": 49 | loss = loss + fea_loss(dfs[i], fs[i]) ### add the mse loss of features as additional constraints 50 | # print("fea_loss: ", fea_loss(dfs[i],fs[i]).item()) 51 | elif mode == "KL": 52 | loss = loss + kl_loss(F.log_softmax(dfs[i], dim=1), F.softmax(fs[i], dim=1)) 53 | # print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item()) 54 | elif mode == "MAE": 55 | loss = loss + l1_loss(dfs[i], fs[i]) 56 | # print("ls_loss: ", l1_loss(dfs[i],fs[i])) 57 | elif mode == "SmoothL1": 58 | loss = loss + smooth_l1_loss(dfs[i], fs[i]) 59 | # print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item()) 60 | 61 | return loss0, loss 62 | 63 | 64 | class ISNetDIS(ISNetBase): 65 | def __init__(self, in_ch=3, out_ch=1): 66 | super(ISNetDIS, self).__init__(in_ch=in_ch, out_ch=out_ch) 67 | 68 | self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) 69 | self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) 70 | self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) 71 | self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) 72 | self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) 73 | 74 | def compute_loss_kl(self, preds, targets, dfs, fs, mode="MSE"): 75 | # return muti_loss_fusion(preds,targets) 76 | return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode) 77 | 78 
| def compute_loss(self, preds, targets): 79 | # return muti_loss_fusion(preds,targets) 80 | return muti_loss_fusion(preds, targets) 81 | 82 | def forward(self, x): 83 | hx = x 84 | 85 | hxin = self.conv_in(hx) 86 | # hx = self.pool_in(hxin) 87 | 88 | # stage 1 89 | hx1 = self.stage1(hxin) 90 | hx = self.pool12(hx1) 91 | 92 | # stage 2 93 | hx2 = self.stage2(hx) 94 | hx = self.pool23(hx2) 95 | 96 | # stage 3 97 | hx3 = self.stage3(hx) 98 | hx = self.pool34(hx3) 99 | 100 | # stage 4 101 | hx4 = self.stage4(hx) 102 | hx = self.pool45(hx4) 103 | 104 | # stage 5 105 | hx5 = self.stage5(hx) 106 | hx = self.pool56(hx5) 107 | 108 | # stage 6 109 | hx6 = self.stage6(hx) 110 | hx6up = _upsample_like(hx6, hx5) 111 | 112 | # -------------------- decoder -------------------- 113 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 114 | hx5dup = _upsample_like(hx5d, hx4) 115 | 116 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 117 | hx4dup = _upsample_like(hx4d, hx3) 118 | 119 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 120 | hx3dup = _upsample_like(hx3d, hx2) 121 | 122 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 123 | hx2dup = _upsample_like(hx2d, hx1) 124 | 125 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 126 | 127 | # side output 128 | d1 = self.side1(hx1d) 129 | d1 = _upsample_like(d1, x) 130 | 131 | d2 = self.side2(hx2d) 132 | d2 = _upsample_like(d2, x) 133 | 134 | d3 = self.side3(hx3d) 135 | d3 = _upsample_like(d3, x) 136 | 137 | d4 = self.side4(hx4d) 138 | d4 = _upsample_like(d4, x) 139 | 140 | d5 = self.side5(hx5d) 141 | d5 = _upsample_like(d5, x) 142 | 143 | d6 = self.side6(hx6) 144 | d6 = _upsample_like(d6, x) 145 | 146 | return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [ 147 | hx1d, 148 | hx2d, 149 | hx3d, 150 | hx4d, 151 | hx5d, 152 | hx6, 153 | ] 154 | -------------------------------------------------------------------------------- /modules/isnet/segmenter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from PIL import Image 6 | from torch import nn 7 | from torch.autograd import Variable 8 | from torchvision import transforms 9 | from torchvision.transforms.functional import normalize 10 | 11 | import folder_paths 12 | import comfy.model_management as model_management 13 | import comfy.utils 14 | 15 | from ..model_utils import download_file 16 | from ..utils import pil2tensor, tensor2pil 17 | from ..logger import logger 18 | 19 | 20 | isnets = {} 21 | cache_size = [1024, 1024] 22 | gpu = model_management.get_torch_device() 23 | cpu = torch.device("cpu") 24 | model_dir = os.path.join(folder_paths.models_dir, "isnet") 25 | models = { 26 | "isnet-general-use.pth": { 27 | "url": "https://huggingface.co/NimaBoscarino/IS-Net_DIS-general-use/resolve/main/isnet-general-use.pth", 28 | "sha": "9e1aafea58f0b55d0c35077e0ceade6ba1ba2bce372fd4f8f77215391f3fac13", 29 | }, 30 | "isnetis.pth": { 31 | "url": "https://github.com/Sanster/models/releases/download/isnetis/isnetis.pth", 32 | "sha": "90a970badbd99ca7839b4e0beb09a36565d24edba7e4a876de23c761981e79e0", 33 | }, 34 | "RMBG-1.4.bin": { 35 | "url": "https://huggingface.co/briaai/RMBG-1.4/resolve/main/pytorch_model.bin", 36 | "sha": "59569acdb281ac9fc9f78f9d33b6f9f17f68e25086b74f9025c35bb5f2848967", 37 | }, 38 | } 39 | 40 | folder_paths.folder_names_and_paths["isnet"] = ( 41 | [model_dir], 42 | folder_paths.supported_pt_extensions, 43 | ) 44 | 45 | 46 | 
class GOSNormalize(object): 47 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 48 | self.mean = mean 49 | self.std = std 50 | 51 | def __call__(self, image): 52 | image = normalize(image, self.mean, self.std) 53 | return image 54 | 55 | 56 | transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])]) 57 | 58 | 59 | def load_isnet_model(model_name): 60 | if model_name not in isnets: 61 | isnet_path = folder_paths.get_full_path("isnet", model_name) 62 | state_dict = comfy.utils.load_torch_file(isnet_path) 63 | 64 | from .models import ISNetBase, ISNetDIS 65 | 66 | if "side2.weight" in state_dict: 67 | isnet = ISNetDIS() 68 | else: 69 | isnet = ISNetBase() 70 | 71 | # convert to half precision 72 | isnet.is_fp16 = model_management.should_use_fp16() 73 | if isnet.is_fp16: 74 | isnet.half() 75 | for layer in isnet.modules(): 76 | if isinstance(layer, nn.BatchNorm2d): 77 | layer.float() 78 | 79 | isnet.load_state_dict(state_dict) 80 | isnet.eval() 81 | isnets[model_name] = isnet 82 | 83 | return isnets[model_name] 84 | 85 | 86 | def im_preprocess(im: torch.Tensor, size): 87 | im = im.clone() 88 | 89 | # Ensure the image has three channels 90 | if len(im.shape) < 3: 91 | im = im.unsqueeze(2) 92 | if im.shape[2] == 1: 93 | im = im.repeat(1, 1, 3) 94 | 95 | # Permute dimensions to match the model input format (C, H, W) 96 | im_tensor = im.permute(2, 0, 1) 97 | 98 | # Resize the image 99 | im_tensor = F.interpolate(im_tensor.unsqueeze(0), size=size, mode="bilinear").squeeze(0) 100 | 101 | # Normalize the image 102 | im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]) 103 | 104 | # Return the processed image tensor with a batch dimension and original size 105 | return im_tensor.unsqueeze(0), im.shape[0:2] 106 | 107 | 108 | def predict(model, im: torch.Tensor, device): 109 | image, orig_size = im_preprocess(im, cache_size) 110 | 111 | if model.is_fp16: 112 | image = image.type(torch.HalfTensor) 113 | else: 114 | image = image.type(torch.FloatTensor) 115 | 116 | image_v = Variable(image, requires_grad=False).to(device) 117 | ds_val = model(image_v) # list of 6 results 118 | 119 | if isinstance(ds_val, tuple): 120 | ds_val = ds_val[0] 121 | 122 | if isinstance(ds_val, list): 123 | ds_val = ds_val[0] 124 | 125 | if len(ds_val.shape) < 4: 126 | ds_val = torch.unsqueeze(ds_val, 0) 127 | 128 | # B x 1 x H x W # we want the first one which is the most accurate prediction 129 | pred_val = ds_val[0, :, :, :] 130 | 131 | # recover the prediction spatial size to the orignal image size 132 | # pred_val = torch.squeeze(F.interpolate(pred_val, size=orig_size, mode='bilinear'), 0) 133 | # pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0), (orig_size[0], orig_size[1]), mode="bilinear")) 134 | pred_val = F.interpolate(pred_val.unsqueeze(0), size=orig_size, mode="bilinear").squeeze(0) 135 | 136 | ma = torch.max(pred_val) 137 | mi = torch.min(pred_val) 138 | pred_val = (pred_val - mi) / (ma - mi) # max = 1 139 | 140 | # it is the mask we need 141 | return pred_val.detach().cpu() 142 | 143 | 144 | class ISNetLoader: 145 | @classmethod 146 | def INPUT_TYPES(s): 147 | return { 148 | "required": { 149 | "model_name": (folder_paths.get_filename_list("isnet"),), 150 | }, 151 | } 152 | 153 | RETURN_TYPES = ("ISNET_MODEL",) 154 | FUNCTION = "load_isnet" 155 | CATEGORY = "Art Venture/Segmentation" 156 | 157 | def load_isnet(self, model_name): 158 | return (load_isnet_model(model_name),) 159 | 160 | 161 | class DownloadISNetModel: 162 | @classmethod 
163 | def INPUT_TYPES(s): 164 | return { 165 | "required": { 166 | "model_name": (list(models.keys()),), 167 | }, 168 | } 169 | 170 | RETURN_TYPES = ("ISNET_MODEL",) 171 | FUNCTION = "download_isnet" 172 | CATEGORY = "Art Venture/Segmentation" 173 | 174 | def download_isnet(self, model_name): 175 | if model_name not in folder_paths.get_filename_list("isnet"): 176 | model_info = models[model_name] 177 | download_file( 178 | model_info["url"], 179 | os.path.join(model_dir, model_name), 180 | model_info["sha"], 181 | ) 182 | 183 | return (load_isnet_model(model_name),) 184 | 185 | 186 | class ISNetSegment: 187 | @classmethod 188 | def INPUT_TYPES(s): 189 | return { 190 | "required": { 191 | "images": ("IMAGE",), 192 | "threshold": ("FLOAT", {"default": 0.5, "min": 0, "max": 1, "step": 0.001}), 193 | }, 194 | "optional": { 195 | "device_mode": (["AUTO", "Prefer GPU", "CPU"],), 196 | "enabled": ("BOOLEAN", {"default": True}), 197 | "isnet_model": ("ISNET_MODEL",), 198 | }, 199 | } 200 | 201 | RETURN_TYPES = ("IMAGE", "MASK") 202 | RETURN_NAMES = ("segmented", "mask") 203 | CATEGORY = "Art Venture/Segmentation" 204 | FUNCTION = "segment_isnet" 205 | 206 | def segment_isnet(self, images: torch.Tensor, threshold, device_mode="AUTO", enabled=True, isnet_model=None): 207 | if not enabled: 208 | masks = torch.zeros((len(images), 64, 64), dtype=torch.float32) 209 | return (images, masks) 210 | 211 | if isnet_model is None: 212 | downloader = DownloadISNetModel() 213 | isnet_model = downloader.download_isnet("isnet-general-use.pth")[0] 214 | 215 | device = gpu if device_mode != "CPU" else cpu 216 | isnet_model = isnet_model.to(device) 217 | 218 | try: 219 | segments = [] 220 | masks = [] 221 | for image in images: 222 | mask = predict(isnet_model, image, device) 223 | mask_im = tensor2pil(mask.permute(1, 2, 0)) 224 | cropped = Image.new("RGBA", mask_im.size, (0, 0, 0, 0)) 225 | cropped.paste(tensor2pil(image), mask=mask_im) 226 | 227 | masks.append(mask) 228 | segments.append(pil2tensor(cropped)) 229 | 230 | return (torch.cat(segments, dim=0), torch.stack(masks)) 231 | finally: 232 | if device_mode == "AUTO": 233 | isnet_model = isnet_model.to(cpu) 234 | -------------------------------------------------------------------------------- /modules/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .chat import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 4 | -------------------------------------------------------------------------------- /modules/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import logging 4 | 5 | 6 | class ColoredFormatter(logging.Formatter): 7 | COLORS = { 8 | "DEBUG": "\033[0;36m", # CYAN 9 | "INFO": "\033[0;32m", # GREEN 10 | "WARNING": "\033[0;33m", # YELLOW 11 | "ERROR": "\033[0;31m", # RED 12 | "CRITICAL": "\033[0;37;41m", # WHITE ON RED 13 | "RESET": "\033[0m", # RESET COLOR 14 | } 15 | 16 | def format(self, record): 17 | colored_record = copy.copy(record) 18 | levelname = colored_record.levelname 19 | seq = self.COLORS.get(levelname, self.COLORS["RESET"]) 20 | colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" 21 | return super().format(colored_record) 22 | 23 | 24 | # Create a new logger 25 | logger = logging.getLogger("ArtVenture") 26 | logger.propagate = False 27 | 28 | # Add handler if we don't have one. 
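# (Illustrative: with the formatter configured below, a call such as
#  logger.info("model loaded") prints "[ArtVenture] - INFO - model loaded"
#  to stdout, with the level name wrapped in the green ANSI escape from
#  ColoredFormatter.COLORS.)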
29 | if not logger.handlers: 30 | handler = logging.StreamHandler(sys.stdout) 31 | handler.setFormatter(ColoredFormatter("[%(name)s] - %(levelname)s - %(message)s")) 32 | logger.addHandler(handler) 33 | 34 | # Configure logger 35 | loglevel = logging.INFO 36 | logger.setLevel(loglevel) 37 | -------------------------------------------------------------------------------- /modules/masking.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageFilter, ImageOps 3 | 4 | 5 | def get_crop_region(mask: np.ndarray, pad=0): 6 | """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. 7 | For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" 8 | 9 | h, w = mask.shape 10 | 11 | crop_left = 0 12 | for i in range(w): 13 | if not (mask[:, i] == 0).all(): 14 | break 15 | crop_left += 1 16 | 17 | crop_right = 0 18 | for i in reversed(range(w)): 19 | if not (mask[:, i] == 0).all(): 20 | break 21 | crop_right += 1 22 | 23 | crop_top = 0 24 | for i in range(h): 25 | if not (mask[i] == 0).all(): 26 | break 27 | crop_top += 1 28 | 29 | crop_bottom = 0 30 | for i in reversed(range(h)): 31 | if not (mask[i] == 0).all(): 32 | break 33 | crop_bottom += 1 34 | 35 | return ( 36 | int(max(crop_left - pad, 0)), 37 | int(max(crop_top - pad, 0)), 38 | int(min(w - crop_right + pad, w)), 39 | int(min(h - crop_bottom + pad, h)), 40 | ) 41 | 42 | 43 | def expand_crop_region(crop_region: np.ndarray, processing_width, processing_height, image_width, image_height): 44 | """expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region 45 | for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128. 46 | """ 47 | 48 | x1, y1, x2, y2 = crop_region 49 | 50 | ratio_crop_region = (x2 - x1) / (y2 - y1) 51 | ratio_processing = processing_width / processing_height 52 | 53 | if ratio_crop_region > ratio_processing: 54 | desired_height = (x2 - x1) / ratio_processing 55 | desired_height_diff = int(desired_height - (y2 - y1)) 56 | y1 -= desired_height_diff // 2 57 | y2 += desired_height_diff - desired_height_diff // 2 58 | if y2 >= image_height: 59 | diff = y2 - image_height 60 | y2 -= diff 61 | y1 -= diff 62 | if y1 < 0: 63 | y2 -= y1 64 | y1 -= y1 65 | if y2 >= image_height: 66 | y2 = image_height 67 | else: 68 | desired_width = (y2 - y1) * ratio_processing 69 | desired_width_diff = int(desired_width - (x2 - x1)) 70 | x1 -= desired_width_diff // 2 71 | x2 += desired_width_diff - desired_width_diff // 2 72 | if x2 >= image_width: 73 | diff = x2 - image_width 74 | x2 -= diff 75 | x1 -= diff 76 | if x1 < 0: 77 | x2 -= x1 78 | x1 -= x1 79 | if x2 >= image_width: 80 | x2 = image_width 81 | 82 | return x1, y1, x2, y2 83 | 84 | 85 | def fill(image, mask): 86 | """fills masked regions with colors from image using blur. 
Not extremely effective.""" 87 | 88 | image_mod = Image.new("RGBA", (image.width, image.height)) 89 | 90 | image_masked = Image.new("RGBa", (image.width, image.height)) 91 | image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L"))) 92 | 93 | image_masked = image_masked.convert("RGBa") 94 | 95 | for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]: 96 | blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert("RGBA") 97 | for _ in range(repeats): 98 | image_mod.alpha_composite(blurred) 99 | 100 | return image_mod.convert("RGB") 101 | -------------------------------------------------------------------------------- /modules/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import torch 4 | import hashlib 5 | import urllib.request 6 | import urllib.error 7 | from tqdm import tqdm 8 | from urllib.parse import urlparse 9 | from typing import Dict, Optional 10 | 11 | 12 | def natural_sort_key(s, regex=re.compile("([0-9]+)")): 13 | return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)] 14 | 15 | 16 | def walk_files(path, allowed_extensions=None): 17 | if not os.path.exists(path): 18 | return 19 | 20 | if allowed_extensions is not None: 21 | allowed_extensions = set(allowed_extensions) 22 | 23 | items = list(os.walk(path, followlinks=True)) 24 | items = sorted(items, key=lambda x: natural_sort_key(x[0])) 25 | 26 | for root, _, files in items: 27 | for filename in sorted(files, key=natural_sort_key): 28 | if allowed_extensions is not None: 29 | _, ext = os.path.splitext(filename) 30 | if ext not in allowed_extensions: 31 | continue 32 | 33 | # Skip hidden files 34 | if "/." in root or "\\." in root: 35 | continue 36 | 37 | yield os.path.join(root, filename) 38 | 39 | 40 | def load_file_from_url( 41 | url: str, 42 | *, 43 | model_dir: str, 44 | progress: bool = True, 45 | file_name: str | None = None, 46 | ) -> str: 47 | """Download a file from `url` into `model_dir`, using the file present if possible. 48 | 49 | Returns the path to the downloaded file. 
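    Example (illustrative; the URL and directory below are placeholders):
        path = load_file_from_url(
            "https://example.com/weights/model.pth",
            model_dir="models/my_models",
        )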
50 | """ 51 | os.makedirs(model_dir, exist_ok=True) 52 | if not file_name: 53 | parts = urlparse(url) 54 | file_name = os.path.basename(parts.path) 55 | cached_file = os.path.abspath(os.path.join(model_dir, file_name)) 56 | if not os.path.exists(cached_file): 57 | print(f'Downloading: "{url}" to {cached_file}\n') 58 | from torch.hub import download_url_to_file 59 | 60 | download_url_to_file(url, cached_file, progress=progress) 61 | return cached_file 62 | 63 | 64 | def calculate_sha(file: str, force=False) -> Optional[str]: 65 | sha_file = f"{file}.sha" 66 | 67 | # Check if the .sha file exists 68 | if not force and os.path.exists(sha_file): 69 | try: 70 | with open(sha_file, "r") as f: 71 | stored_hash = f.read().strip() 72 | if stored_hash: 73 | return stored_hash 74 | except IOError as e: 75 | print(f"Failed to read hash: {e}") 76 | 77 | # Calculate the hash if the .sha file doesn't exist or is empty 78 | try: 79 | with open(file, "rb") as fp: 80 | file_hash = hashlib.sha256() 81 | while chunk := fp.read(8192): 82 | file_hash.update(chunk) 83 | calculated_hash = file_hash.hexdigest() 84 | 85 | # Write the calculated hash to the .sha file 86 | try: 87 | with open(sha_file, "w") as f: 88 | f.write(calculated_hash) 89 | except IOError as e: 90 | print(f"Failed to write hash to {sha_file}: {e}") 91 | 92 | return calculated_hash 93 | except IOError as e: 94 | print(f"Failed to read file {file}: {e}") 95 | return None 96 | 97 | 98 | def download_file(url: str, dst: str, sha256sum: Optional[str] = None) -> Dict[str, Optional[str]]: 99 | """ 100 | Downloads a file from a URL to a destination path, optionally verifying its SHA-256 checksum. 101 | 102 | :param url: URL of the file to download 103 | :param dst: Destination path to save the downloaded file 104 | :param sha256sum: Optional SHA-256 checksum to verify the downloaded file 105 | :return: Dictionary with file path, download status, calculated checksum, and checksum match status 106 | """ 107 | # Ensure the directory exists 108 | os.makedirs(os.path.dirname(dst), exist_ok=True) 109 | 110 | file_exists = os.path.isfile(dst) 111 | file_checksum = None 112 | checksum_match = None 113 | downloaded = False 114 | 115 | try: 116 | if file_exists: 117 | file_checksum = calculate_sha(dst) 118 | if sha256sum: 119 | checksum_match = file_checksum == sha256sum 120 | if not checksum_match: 121 | os.remove(dst) 122 | 123 | if not file_exists or checksum_match == False: 124 | with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=dst.split("/")[-1]) as t: 125 | 126 | def reporthook(blocknum, blocksize, totalsize): 127 | if t.total is None and totalsize > 0: 128 | t.total = totalsize 129 | read_so_far = blocknum * blocksize 130 | t.update(max(0, read_so_far - t.n)) 131 | 132 | urllib.request.urlretrieve(url, dst, reporthook=reporthook) 133 | downloaded = True 134 | 135 | file_checksum = calculate_sha(dst, force=True) 136 | if sha256sum: 137 | checksum_match = file_checksum == sha256sum 138 | 139 | except urllib.error.URLError as ex: 140 | print("Download failed:", ex) 141 | if os.path.isfile(dst): 142 | os.remove(dst) 143 | except Exception as ex: 144 | print("An error occurred:", ex) 145 | finally: 146 | return {"file": dst, "downloaded": downloaded, "sha": file_checksum, "match": checksum_match} 147 | 148 | 149 | def load_jit_torch_file(model_path: str): 150 | model = torch.jit.load(model_path, map_location="cpu") 151 | model.eval() 152 | return model 153 | 
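# Illustrative usage of download_file() above (URL, destination and checksum are placeholders):
#
#   result = download_file(
#       "https://example.com/weights/model.pth",
#       dst="models/isnet/model.pth",
#       sha256sum="<expected sha256 hex digest>",
#   )
#   # result -> {"file": "<dst>", "downloaded": True/False, "sha": "<hex digest>", "match": True/False/None}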
-------------------------------------------------------------------------------- /modules/postprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .color_blend import ColorBlend 2 | from .color_correct import ColorCorrect 3 | 4 | NODE_CLASS_MAPPINGS = { 5 | "ColorBlend": ColorBlend, 6 | "ColorCorrect": ColorCorrect, 7 | } 8 | 9 | NODE_DISPLAY_NAME_MAPPINGS = { 10 | "ColorBlend": "Color Blend", 11 | "ColorCorrect": "Color Correct", 12 | } 13 | -------------------------------------------------------------------------------- /modules/postprocessing/color_blend.py: -------------------------------------------------------------------------------- 1 | # Adapt from https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/color_blend.py 2 | 3 | # Color blend node by Yam Levi 4 | # Property of Stability AI 5 | import cv2 6 | import torch 7 | import numpy as np 8 | 9 | import comfy.utils 10 | 11 | 12 | def color_blend(bw_layer, color_layer): 13 | # Convert the color layer to LAB color space 14 | color_lab = cv2.cvtColor(color_layer, cv2.COLOR_BGR2Lab) 15 | # Convert the black and white layer to grayscale 16 | bw_layer_gray = cv2.cvtColor(bw_layer, cv2.COLOR_BGR2GRAY) 17 | # Replace the luminosity (L) channel in the color image with the black and white luminosity 18 | _, color_a, color_b = cv2.split(color_lab) 19 | blended_lab = cv2.merge((bw_layer_gray, color_a, color_b)) 20 | # Convert the blended LAB image back to BGR color space 21 | blended_result = cv2.cvtColor(blended_lab, cv2.COLOR_Lab2BGR) 22 | return blended_result 23 | 24 | 25 | class ColorBlend: 26 | def __init__(self): 27 | pass 28 | 29 | @classmethod 30 | def INPUT_TYPES(s): 31 | return { 32 | "required": { 33 | "bw_layer": ("IMAGE",), 34 | "color_layer": ("IMAGE",), 35 | }, 36 | } 37 | 38 | RETURN_TYPES = ("IMAGE",) 39 | FUNCTION = "color_blending_mode" 40 | CATEGORY = "Art Venture/Post Processing" 41 | 42 | def color_blending_mode(self, bw_layer, color_layer): 43 | if bw_layer.shape[0] < color_layer.shape[0]: 44 | bw_layer = bw_layer.repeat(color_layer.shape[0], 1, 1, 1)[ 45 | : color_layer.shape[0] 46 | ] 47 | if bw_layer.shape[0] > color_layer.shape[0]: 48 | color_layer = color_layer.repeat(bw_layer.shape[0], 1, 1, 1)[ 49 | : bw_layer.shape[0] 50 | ] 51 | 52 | batch_size, *_ = bw_layer.shape 53 | tensor_output = torch.empty_like(bw_layer) 54 | 55 | image1 = bw_layer.cpu() 56 | image2 = color_layer.cpu() 57 | if image1.shape != image2.shape: 58 | image2 = image2.permute(0, 3, 1, 2) 59 | image2 = comfy.utils.common_upscale( 60 | image2, 61 | image1.shape[2], 62 | image1.shape[1], 63 | upscale_method="bicubic", 64 | crop="center", 65 | ) 66 | image2 = image2.permute(0, 2, 3, 1) 67 | image1 = (image1 * 255).to(torch.uint8).numpy() 68 | image2 = (image2 * 255).to(torch.uint8).numpy() 69 | 70 | for i in range(batch_size): 71 | blend = color_blend(image1[i], image2[i]) 72 | blend = np.stack([blend]) 73 | tensor_output[i : i + 1] = ( 74 | torch.from_numpy(blend.transpose(0, 3, 1, 2)) / 255.0 75 | ).permute(0, 2, 3, 1) 76 | 77 | return (tensor_output,) 78 | -------------------------------------------------------------------------------- /modules/postprocessing/color_correct.py: -------------------------------------------------------------------------------- 1 | # Adapt from https://github.com/EllangoK/ComfyUI-post-processing-nodes/blob/master/post_processing/color_correct.py 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from PIL import Image, ImageEnhance 
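# Note on parameter scaling (illustrative): the percentage sliders declared below
# (-100..100) are divided by 100 and shifted into multiplicative factors inside
# color_correct(), e.g. brightness = 20 becomes an enhancement factor of 1.2 and
# contrast = -50 becomes 0.5, while gamma is applied directly as a power curve.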
7 | 8 | 9 | class ColorCorrect: 10 | @classmethod 11 | def INPUT_TYPES(s): 12 | return { 13 | "required": { 14 | "image": ("IMAGE",), 15 | "temperature": ( 16 | "FLOAT", 17 | {"default": 0, "min": -100, "max": 100, "step": 5}, 18 | ), 19 | "hue": ("FLOAT", {"default": 0, "min": -90, "max": 90, "step": 5}), 20 | "brightness": ( 21 | "FLOAT", 22 | {"default": 0, "min": -100, "max": 100, "step": 5}, 23 | ), 24 | "contrast": ( 25 | "FLOAT", 26 | {"default": 0, "min": -100, "max": 100, "step": 5}, 27 | ), 28 | "saturation": ( 29 | "FLOAT", 30 | {"default": 0, "min": -100, "max": 100, "step": 5}, 31 | ), 32 | "gamma": ("FLOAT", {"default": 1, "min": 0.2, "max": 2.2, "step": 0.1}), 33 | }, 34 | } 35 | 36 | RETURN_TYPES = ("IMAGE",) 37 | FUNCTION = "color_correct" 38 | 39 | CATEGORY = "Art Venture/Post Processing" 40 | 41 | def color_correct( 42 | self, 43 | image: torch.Tensor, 44 | temperature: float, 45 | hue: float, 46 | brightness: float, 47 | contrast: float, 48 | saturation: float, 49 | gamma: float, 50 | ): 51 | batch_size, height, width, _ = image.shape 52 | result = torch.zeros_like(image) 53 | 54 | brightness /= 100 55 | contrast /= 100 56 | saturation /= 100 57 | temperature /= 100 58 | 59 | brightness = 1 + brightness 60 | contrast = 1 + contrast 61 | saturation = 1 + saturation 62 | 63 | for b in range(batch_size): 64 | tensor_image = image[b].numpy() 65 | 66 | modified_image = Image.fromarray((tensor_image * 255).astype(np.uint8)) 67 | 68 | # brightness 69 | modified_image = ImageEnhance.Brightness(modified_image).enhance(brightness) 70 | 71 | # contrast 72 | modified_image = ImageEnhance.Contrast(modified_image).enhance(contrast) 73 | modified_image = np.array(modified_image).astype(np.float32) 74 | 75 | # temperature 76 | if temperature > 0: 77 | modified_image[:, :, 0] *= 1 + temperature 78 | modified_image[:, :, 1] *= 1 + temperature * 0.4 79 | elif temperature < 0: 80 | modified_image[:, :, 2] *= 1 - temperature 81 | modified_image = np.clip(modified_image, 0, 255) / 255 82 | 83 | # gamma 84 | modified_image = np.clip(np.power(modified_image, gamma), 0, 1) 85 | 86 | # saturation 87 | hls_img = cv2.cvtColor(modified_image, cv2.COLOR_RGB2HLS) 88 | hls_img[:, :, 2] = np.clip(saturation * hls_img[:, :, 2], 0, 1) 89 | modified_image = cv2.cvtColor(hls_img, cv2.COLOR_HLS2RGB) * 255 90 | 91 | # hue 92 | hsv_img = cv2.cvtColor(modified_image, cv2.COLOR_RGB2HSV) 93 | hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue) % 360 94 | modified_image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB) 95 | 96 | modified_image = modified_image.astype(np.uint8) 97 | modified_image = modified_image / 255 98 | modified_image = torch.from_numpy(modified_image).unsqueeze(0) 99 | result[b] = modified_image 100 | 101 | return (result,) 102 | -------------------------------------------------------------------------------- /modules/sdxl_prompt_styler/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ -------------------------------------------------------------------------------- /modules/sdxl_prompt_styler/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 twri 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /modules/sdxl_prompt_styler/README.md: -------------------------------------------------------------------------------- 1 | SDXL Prompt Styler 2 | ======= 3 | Custom node for ComfyUI 4 | ----------- 5 | ![SDXL Prompt Styler Screenshot](examples/sdxl_prompt_styler.png) 6 | 7 | SDXL Prompt Styler is a node that enables you to style prompts based on predefined templates stored in multiple JSON files. The node specifically replaces a {prompt} placeholder in the 'prompt' field of each template with provided positive text. 8 | 9 | The node also effectively manages negative prompts. If negative text is provided, the node combines this with the 'negative_prompt' field from the template. If no negative text is supplied, the system defaults to using the 'negative_prompt' from the JSON template. This flexibility enables the creation of a diverse and specific range of negative prompts. 10 | 11 | ## Important Update: 12 | With the latest changes, the file structure and naming convention for style JSONs have been modified. If you've added or made changes to the `sdxl_styles.json` file in the past, follow these steps to ensure your styles remain intact: 13 | 14 | 1. **Backup**: Before pulling the latest changes, back up your `sdxl_styles.json` to a safe location. 15 | 2. **Migration**: After updating the repository, create a new JSON file in the styles directory. Move your custom styles from the backup of `sdxl_styles.json` into this new file. 16 | 3. **Unique Style Names**: While the system now detects duplicates and appends a suffix to ensure uniqueness, it's a best practice to ensure your style names are originally unique to prevent any potential confusion. 17 | 4. **Managing Included JSON Files**: If you prefer not to load specific included JSON files, consider renaming or moving them to a different location outside of the styles directory. The system will load all JSON files present in the specified directory. 18 | 19 | ## New Features: 20 | 21 | 1. **Loading from Multiple JSON Files:** The system can now load styles from multiple JSON files present in the specified directory, ensuring the uniqueness of style names by appending a suffix to duplicates. 22 | 2. **Enhanced Error Handling:** Improved error handling for file reading, data validity, and template replacement functions. 23 | 24 | --- 25 | 26 | ### Usage Example with SDXL Prompt Styler 27 | 28 | Template example from a JSON file: 29 | 30 | ```json 31 | [ 32 | { 33 | "name": "base", 34 | "prompt": "{prompt}", 35 | "negative_prompt": "" 36 | }, 37 | { 38 | "name": "enhance", 39 | "prompt": "breathtaking {prompt} . 
award-winning, professional, highly detailed", 40 | "negative_prompt": "ugly, deformed, noisy, blurry, distorted, grainy" 41 | } 42 | ] 43 | ``` 44 | 45 | ```python 46 | style = "enhance" 47 | positive_prompt = "a futuristic pop up tent in a forest" 48 | negative_prompt = "dark" 49 | ``` 50 | 51 | This will generate the following styled prompts as outputs: 52 | 53 | ``` 54 | breathtaking a futuristic pop up tent in a forest . award-winning, professional, highly detailed 55 | ugly, deformed, noisy, blurry, distorted, grainy, dark 56 | ``` 57 | 58 | ### Installation 59 | 60 | To install and use the SDXL Prompt Styler nodes, follow these steps: 61 | 62 | 1. Open a terminal or command line interface. 63 | 2. Navigate to the `ComfyUI/custom_nodes/` directory. 64 | 3. Run the following command: 65 | ```git clone https://github.com/twri/sdxl_prompt_styler.git``` 66 | 4. Restart ComfyUI. 67 | 68 | This command clones the SDXL Prompt Styler repository into your `ComfyUI/custom_nodes/` directory. You should now be able to access and use the nodes from this repository. 69 | 70 | ### Inputs 71 | 72 | * **text_positive** - text for the positive base prompt G 73 | * **text_negative** - text for the negative base prompt G 74 | * **log_prompt** - print inputs and outputs to stdout 75 | 76 | ### Outputs 77 | 78 | * **positive_prompt_text_g** - combined prompt with style for positive promt G 79 | * **negative_prompt_text_g** - combined prompt with style for negative promt G -------------------------------------------------------------------------------- /modules/sdxl_prompt_styler/__init__.py: -------------------------------------------------------------------------------- 1 | from .sdxl_prompt_styler import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 4 | -------------------------------------------------------------------------------- /modules/sdxl_prompt_styler/sdxl_prompt_styler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | def read_json_file(file_path): 6 | """ 7 | Reads the content of a JSON file and returns it as a Python data structure. 8 | """ 9 | if not os.access(file_path, os.R_OK): 10 | print(f"Warning: No read permissions for file {file_path}") 11 | return None 12 | 13 | try: 14 | with open(file_path, "r", encoding="utf-8") as file: 15 | content = json.load(file) 16 | # Check if the content matches the expected format. 17 | if not all( 18 | [ 19 | "name" in item and "prompt" in item and "negative_prompt" in item 20 | for item in content 21 | ] 22 | ): 23 | print(f"Warning: Invalid content in file {file_path}") 24 | return None 25 | return content 26 | except Exception as e: 27 | print(f"An error occurred while reading {file_path}: {str(e)}") 28 | return None 29 | 30 | 31 | def read_sdxl_styles(json_data): 32 | """ 33 | Extracts style names from the provided data. 34 | """ 35 | if not isinstance(json_data, list): 36 | print("Error: input data must be a list") 37 | return [] 38 | 39 | return [ 40 | item["name"] for item in json_data if isinstance(item, dict) and "name" in item 41 | ] 42 | 43 | 44 | def get_all_json_files(directory): 45 | """ 46 | Retrieves all JSON files present in the specified directory. 
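    For example (illustrative), get_all_json_files("/path/to/styles") returns the full
    paths of files such as sdxl_styles_base.json and sdxl_styles_sai.json found in that
    directory.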
47 | """ 48 | return [ 49 | os.path.join(directory, file) 50 | for file in os.listdir(directory) 51 | if file.endswith(".json") and os.path.isfile(os.path.join(directory, file)) 52 | ] 53 | 54 | 55 | def load_styles_from_directory(directory): 56 | """ 57 | Loads style names and combined data from all JSON files in the directory. 58 | Ensures style names are unique by appending a suffix to duplicates. 59 | """ 60 | json_files = get_all_json_files(directory) 61 | combined_data = [] 62 | seen = set() 63 | 64 | for json_file in json_files: 65 | json_data = read_json_file(json_file) 66 | if json_data: 67 | for item in json_data: 68 | original_style = item["name"] 69 | style = original_style 70 | suffix = 1 71 | while style in seen: 72 | style = f"{original_style}_{suffix}" 73 | suffix += 1 74 | item["name"] = style 75 | seen.add(style) 76 | combined_data.append(item) 77 | 78 | unique_style_names = [ 79 | item["name"] 80 | for item in combined_data 81 | if isinstance(item, dict) and "name" in item 82 | ] 83 | 84 | return combined_data, unique_style_names 85 | 86 | 87 | def read_sdxl_templates_replace_and_combine( 88 | json_data, template_name, positive_prompt, negative_prompt 89 | ): 90 | try: 91 | # Check if json_data is a list 92 | if not isinstance(json_data, list): 93 | raise ValueError("Invalid JSON data. Expected a list of templates.") 94 | 95 | for template in json_data: 96 | # Check if template contains 'name' and 'prompt' fields 97 | if "name" not in template or "prompt" not in template: 98 | raise ValueError("Invalid template. Missing 'name' or 'prompt' field.") 99 | 100 | # Replace {prompt} in the matching template 101 | if template["name"] == template_name: 102 | positive_prompt = template["prompt"].replace( 103 | "{prompt}", positive_prompt 104 | ) 105 | 106 | json_negative_prompt = template.get("negative_prompt", "") 107 | if negative_prompt: 108 | negative_prompt = ( 109 | f"{json_negative_prompt}, {negative_prompt}" 110 | if json_negative_prompt 111 | else negative_prompt 112 | ) 113 | else: 114 | negative_prompt = json_negative_prompt 115 | 116 | return positive_prompt, negative_prompt 117 | 118 | # If function hasn't returned yet, no matching template was found 119 | raise ValueError(f"No template found with name '{template_name}'.") 120 | 121 | except Exception as e: 122 | print(f"An error occurred: {str(e)}") 123 | 124 | 125 | class SDXLPromptStyler: 126 | def __init__(self): 127 | pass 128 | 129 | @classmethod 130 | def INPUT_TYPES(self): 131 | current_directory = os.path.dirname(os.path.realpath(__file__)) 132 | self.json_data, self.styles = load_styles_from_directory(current_directory) 133 | 134 | return { 135 | "required": { 136 | "text_positive": ("STRING", {"default": "", "multiline": True}), 137 | "text_negative": ("STRING", {"default": "", "multiline": True}), 138 | "style": ((self.styles),), 139 | "log_prompt": (["No", "Yes"], {"default": "No"}), 140 | }, 141 | "optional": { 142 | "style_name": ("STRING", {"multiline": False}), 143 | }, 144 | } 145 | 146 | RETURN_TYPES = ( 147 | "STRING", 148 | "STRING", 149 | ) 150 | RETURN_NAMES = ( 151 | "positive_prompt_text_g", 152 | "negative_prompt_text_g", 153 | ) 154 | FUNCTION = "prompt_styler" 155 | CATEGORY = "utils" 156 | 157 | def prompt_styler( 158 | self, text_positive, text_negative, style, log_prompt, style_name=None 159 | ): 160 | if style_name and style_name not in self.styles: 161 | print(f"Warning: Style '{style_name}' not found. 
Using '{style}' instead.") 162 | style_name = None 163 | 164 | if style_name: 165 | style = style_name 166 | 167 | # Process and combine prompts in templates 168 | # The function replaces the positive prompt placeholder in the template, 169 | # and combines the negative prompt with the template's negative prompt, if they exist. 170 | positive_prompt, negative_prompt = read_sdxl_templates_replace_and_combine( 171 | self.json_data, style, text_positive, text_negative 172 | ) 173 | 174 | # If logging is enabled (log_prompt is set to "Yes"), 175 | # print the style, positive and negative text, and positive and negative prompts to the console 176 | if log_prompt == "Yes": 177 | print(f"style: {style}") 178 | print(f"text_positive: {text_positive}") 179 | print(f"text_negative: {text_negative}") 180 | print(f"positive_prompt: {positive_prompt}") 181 | print(f"negative_prompt: {negative_prompt}") 182 | 183 | return positive_prompt, negative_prompt 184 | 185 | 186 | NODE_CLASS_MAPPINGS = { 187 | "SDXLPromptStyler": SDXLPromptStyler, 188 | } 189 | 190 | NODE_DISPLAY_NAME_MAPPINGS = { 191 | "SDXLPromptStyler": "SDXL Prompt Styler", 192 | } 193 | -------------------------------------------------------------------------------- /modules/sdxl_prompt_styler/sdxl_styles_base.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "base", 4 | "prompt": "{prompt}", 5 | "negative_prompt": "" 6 | } 7 | ] -------------------------------------------------------------------------------- /modules/sdxl_prompt_styler/sdxl_styles_sai.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "3d-model", 4 | "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting", 5 | "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting" 6 | }, 7 | { 8 | "name": "analog film", 9 | "prompt": "analog film photo {prompt} . faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage", 10 | "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured" 11 | }, 12 | { 13 | "name": "anime", 14 | "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed", 15 | "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast" 16 | }, 17 | { 18 | "name": "cinematic", 19 | "prompt": "cinematic film still {prompt} . shallow depth of field, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", 20 | "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured" 21 | }, 22 | { 23 | "name": "comic book", 24 | "prompt": "comic {prompt} . graphic illustration, comic art, graphic novel art, vibrant, highly detailed", 25 | "negative_prompt": "photograph, deformed, glitch, noisy, realistic, stock photo" 26 | }, 27 | { 28 | "name": "craft clay", 29 | "prompt": "play-doh style {prompt} . sculpture, clay art, centered composition, Claymation", 30 | "negative_prompt": "sloppy, messy, grainy, highly detailed, ultra textured, photo" 31 | }, 32 | { 33 | "name": "digital art", 34 | "prompt": "concept art {prompt} . 
digital artwork, illustrative, painterly, matte painting, highly detailed", 35 | "negative_prompt": "photo, photorealistic, realism, ugly" 36 | }, 37 | { 38 | "name": "enhance", 39 | "prompt": "breathtaking {prompt} . award-winning, professional, highly detailed", 40 | "negative_prompt": "ugly, deformed, noisy, blurry, distorted, grainy" 41 | }, 42 | { 43 | "name": "fantasy art", 44 | "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy", 45 | "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white" 46 | }, 47 | { 48 | "name": "isometric", 49 | "prompt": "isometric style {prompt} . vibrant, beautiful, crisp, detailed, ultra detailed, intricate", 50 | "negative_prompt": "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy, realistic, photographic" 51 | }, 52 | { 53 | "name": "line art", 54 | "prompt": "line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics", 55 | "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic" 56 | }, 57 | { 58 | "name": "lowpoly", 59 | "prompt": "low-poly style {prompt} . low-poly game art, polygon mesh, jagged, blocky, wireframe edges, centered composition", 60 | "negative_prompt": "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo" 61 | }, 62 | { 63 | "name": "neonpunk", 64 | "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional", 65 | "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured" 66 | }, 67 | { 68 | "name": "origami", 69 | "prompt": "origami style {prompt} . paper art, pleated paper, folded, origami art, pleats, cut and fold, centered composition", 70 | "negative_prompt": "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo" 71 | }, 72 | { 73 | "name": "photographic", 74 | "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed", 75 | "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly" 76 | }, 77 | { 78 | "name": "pixel art", 79 | "prompt": "pixel-art {prompt} . 
low-res, blocky, pixel art style, 8-bit graphics", 80 | "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic" 81 | }, 82 | { 83 | "name": "texture", 84 | "prompt": "texture {prompt} top down close-up", 85 | "negative_prompt": "ugly, deformed, noisy, blurry" 86 | } 87 | ] -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import sys 4 | import torch 5 | import base64 6 | import numpy as np 7 | import importlib 8 | import subprocess 9 | import pkg_resources 10 | from pkg_resources import parse_version 11 | from PIL import Image 12 | 13 | from .logger import logger 14 | 15 | 16 | class AnyType(str): 17 | def __ne__(self, __value: object) -> bool: 18 | return False 19 | 20 | 21 | any_type = AnyType("*") 22 | 23 | 24 | def ensure_package(package, version=None, install_package_name=None): 25 | # Try to import the package 26 | try: 27 | module = importlib.import_module(package) 28 | except ImportError: 29 | logger.info(f"Package {package} is not installed. Installing now...") 30 | install_command = _construct_pip_command(install_package_name or package, version) 31 | subprocess.check_call(install_command) 32 | else: 33 | # If a specific version is required, check the version 34 | if version: 35 | installed_version = pkg_resources.get_distribution(package).version 36 | if parse_version(installed_version) < parse_version(version): 37 | logger.info( 38 | f"Package {package} is outdated (installed: {installed_version}, required: {version}). Upgrading now..." 39 | ) 40 | install_command = _construct_pip_command(install_package_name or package, version) 41 | subprocess.check_call(install_command) 42 | 43 | 44 | def _construct_pip_command(package_name, version=None): 45 | if "python_embeded" in sys.executable or "python_embedded" in sys.executable: 46 | pip_install = [sys.executable, "-s", "-m", "pip", "install"] 47 | else: 48 | pip_install = [sys.executable, "-m", "pip", "install"] 49 | 50 | # Include the version in the package name if specified 51 | if version: 52 | package_name = f"{package_name}=={version}" 53 | 54 | return pip_install + [package_name] 55 | 56 | 57 | def get_dict_attribute(dict_inst: dict, name_string: str, default=None): 58 | nested_keys = name_string.split(".") 59 | value = dict_inst 60 | 61 | for key in nested_keys: 62 | # Handle array indexing 63 | if key.startswith("[") and key.endswith("]"): 64 | try: 65 | index = int(key[1:-1]) 66 | if not isinstance(value, (list, tuple)) or index >= len(value): 67 | return default 68 | value = value[index] 69 | except (ValueError, TypeError): 70 | return default 71 | else: 72 | if not isinstance(value, dict): 73 | return default 74 | value = value.get(key, None) 75 | 76 | if value is None: 77 | return default 78 | 79 | return value 80 | 81 | 82 | def set_dict_attribute(dict_inst: dict, name_string: str, value): 83 | """ 84 | Set an attribute to a dictionary using dot notation. 85 | If the attribute does not already exist, it will create a nested dictionary. 
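    For example (illustrative), calling set_dict_attribute(data, "images[0].url",
    "https://example.com/a.png") on an empty dict turns it into
    {"images": [{"url": "https://example.com/a.png"}]}.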
86 | 87 | Parameters: 88 | - dict_inst: the dictionary instance to set the attribute 89 | - name_string: the attribute name in dot notation (ex: 'attributes[1].name') 90 | - value: the value to set for the attribute 91 | 92 | Returns: 93 | None 94 | """ 95 | # Split the attribute names by dot 96 | name_list = name_string.split(".") 97 | 98 | # Traverse the dictionary and create a nested dictionary if necessary 99 | current_dict = dict_inst 100 | for name in name_list[:-1]: 101 | is_array = name.endswith("]") 102 | if is_array: 103 | open_bracket_index = name.index("[") 104 | idx = int(name[open_bracket_index + 1 : -1]) 105 | name = name[:open_bracket_index] 106 | 107 | if name not in current_dict: 108 | current_dict[name] = [] if is_array else {} 109 | 110 | current_dict = current_dict[name] 111 | if is_array: 112 | while len(current_dict) <= idx: 113 | current_dict.append({}) 114 | current_dict = current_dict[idx] 115 | 116 | # Set the final attribute to its value 117 | name = name_list[-1] 118 | if name.endswith("]"): 119 | open_bracket_index = name.index("[") 120 | idx = int(name[open_bracket_index + 1 : -1]) 121 | name = name[:open_bracket_index] 122 | 123 | if name not in current_dict: 124 | current_dict[name] = [] 125 | 126 | while len(current_dict[name]) <= idx: 127 | current_dict[name].append(None) 128 | 129 | current_dict[name][idx] = value 130 | else: 131 | current_dict[name] = value 132 | 133 | 134 | def is_junction(src: str) -> bool: 135 | import subprocess 136 | 137 | child = subprocess.Popen('fsutil reparsepoint query "{}"'.format(src), stdout=subprocess.PIPE) 138 | streamdata = child.communicate()[0] 139 | rc = child.returncode 140 | return rc == 0 141 | 142 | 143 | def load_module(module_path, module_name=None): 144 | import importlib.util 145 | 146 | if module_name is None: 147 | module_name = os.path.basename(module_path) 148 | if os.path.isdir(module_path): 149 | module_path = os.path.join(module_path, "__init__.py") 150 | 151 | module_spec = importlib.util.spec_from_file_location(module_name, module_path) 152 | 153 | module = importlib.util.module_from_spec(module_spec) 154 | module_spec.loader.exec_module(module) 155 | 156 | return module 157 | 158 | 159 | def pil2numpy(image: Image.Image): 160 | return np.array(image).astype(np.float32) / 255.0 161 | 162 | 163 | def numpy2pil(image: np.ndarray, mode=None): 164 | return Image.fromarray(np.clip(255.0 * image, 0, 255).astype(np.uint8), mode) 165 | 166 | 167 | def pil2tensor(image: Image.Image): 168 | return torch.from_numpy(pil2numpy(image)).unsqueeze(0) 169 | 170 | 171 | def tensor2pil(image: torch.Tensor, mode=None): 172 | return numpy2pil(image.cpu().numpy().squeeze(), mode=mode) 173 | 174 | 175 | def tensor2bytes(image: torch.Tensor) -> bytes: 176 | return tensor2pil(image).tobytes() 177 | 178 | 179 | def pil2base64(image: Image.Image): 180 | buffered = io.BytesIO() 181 | image.save(buffered, format="PNG") 182 | img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") 183 | return img_str 184 | -------------------------------------------------------------------------------- /modules/video/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import folder_paths 5 | from comfy.utils import common_upscale 6 | 7 | from ..utils import load_module, pil2tensor 8 | 9 | custom_nodes = folder_paths.get_folder_paths("custom_nodes") 10 | video_dir_names = ["VideoHelperSuite", "ComfyUI-VideoHelperSuite"] 11 | 12 | output_dir = 
folder_paths.get_output_directory() 13 | input_dir = folder_paths.get_input_directory() 14 | temp_dir = folder_paths.get_temp_directory() 15 | 16 | NODE_CLASS_MAPPINGS = {} 17 | NODE_DISPLAY_NAME_MAPPINGS = {} 18 | 19 | try: 20 | module_path = None 21 | 22 | for custom_node in custom_nodes: 23 | custom_node = custom_node if not os.path.islink(custom_node) else os.readlink(custom_node) 24 | for module_dir in video_dir_names: 25 | if module_dir in os.listdir(custom_node): 26 | module_path = os.path.abspath(os.path.join(custom_node, module_dir)) 27 | break 28 | 29 | if module_path is None: 30 | raise Exception("Could not find VideoHelperSuite nodes") 31 | 32 | module_path = os.path.join(module_path) 33 | module = load_module(module_path) 34 | print("Loaded VideoHelperSuite from", module_path) 35 | 36 | LoadVideoPath = module.NODE_CLASS_MAPPINGS["VHS_LoadVideoPath"] 37 | 38 | def target_size(width, height, force_size) -> tuple[int, int]: 39 | force_size = force_size.split("x") 40 | if force_size[0] == "?": 41 | width = (width * int(force_size[1])) // height 42 | width = int(width) + 4 & ~7 43 | height = int(force_size[1]) 44 | elif force_size[1] == "?": 45 | height = (height * int(force_size[0])) // width 46 | height = int(height) + 4 & ~7 47 | width = int(force_size[0]) 48 | 49 | return (width, height) 50 | 51 | class UtilLoadVideoFromUrl(LoadVideoPath): 52 | @classmethod 53 | def INPUT_TYPES(s): 54 | inputs = LoadVideoPath.INPUT_TYPES() 55 | inputs["required"]["video"] = ("STRING", {"default": ""}) 56 | return inputs 57 | 58 | CATEGORY = "Art Venture/Loaders" 59 | FUNCTION = "load" 60 | RETURN_TYPES = ("IMAGE", "INT", "BOOLEAN") 61 | RETURN_NAMES = ("frames", "frame_count", "has_video") 62 | OUTPUT_IS_LIST = (True, True, False) 63 | 64 | def load_gif( 65 | self, 66 | gif_path: str, 67 | force_rate: int, 68 | force_size: str, 69 | skip_first_frames: int, 70 | frame_load_cap: int, 71 | select_every_nth: int, 72 | ): 73 | from PIL import Image, ImageSequence 74 | 75 | image = Image.open(gif_path) 76 | frames = [] 77 | total_frames_evaluated = -1 78 | 79 | if force_rate != 0: 80 | print(f"Force rate is not supported for gifs/webps") 81 | if frame_load_cap == 0: 82 | frame_load_cap = 999999999 83 | 84 | for i, frame in enumerate(ImageSequence.Iterator(image)): 85 | if i < skip_first_frames: 86 | continue 87 | elif i >= skip_first_frames + frame_load_cap: 88 | break 89 | else: 90 | total_frames_evaluated += 1 91 | if total_frames_evaluated % select_every_nth == 0: 92 | frames.append(pil2tensor(frame.copy().convert("RGB"))) 93 | 94 | images = torch.cat(frames, dim=0) 95 | 96 | if force_size != "Disabled": 97 | height = images.shape[1] 98 | width = images.shape[2] 99 | new_size = target_size(width, height, force_size) 100 | if new_size[0] != width or new_size[1] != height: 101 | s = images.movedim(-1, 1) 102 | s = common_upscale(s, new_size[0], new_size[1], "lanczos", "disabled") 103 | images = s.movedim(1, -1) 104 | 105 | return (images, len(frames)) 106 | 107 | def load_url(self, video: str, **kwargs): 108 | url = video.strip('"') 109 | 110 | if url == "": 111 | return (None, 0) 112 | 113 | if os.path.isfile(url): 114 | pass 115 | elif url.startswith("file://"): 116 | url = url[7:] 117 | url = os.path.abspath(url) 118 | 119 | if not os.path.isfile(url): 120 | raise Exception(f"File {url} does not exist") 121 | 122 | if url.startswith(input_dir): 123 | video = url[len(input_dir) + 1 :] + " [input]" 124 | elif url.startswith(output_dir): 125 | video = url[len(output_dir) + 1 :] + " [output]" 126 | 
elif url.startswith(temp_dir): 127 | video = url[len(temp_dir) + 1 :] + " [temp]" 128 | else: 129 | # move file to temp_dir 130 | import shutil 131 | 132 | tempdir = os.path.join(temp_dir, "video") 133 | if not os.path.exists(tempdir): 134 | os.makedirs(tempdir, exist_ok=True) 135 | 136 | filename = os.path.basename(url) 137 | filepath = os.path.join(tempdir, filename) 138 | 139 | i = 1 140 | split = os.path.splitext(filename) 141 | while os.path.exists(filepath): 142 | filename = f"{split[0]} ({i}){split[1]}" 143 | filepath = os.path.join(tempdir, filename) 144 | i += 1 145 | 146 | shutil.copy(url, filepath) 147 | video = "video/" + filename + " [temp]" 148 | elif url.startswith("http://") or url.startswith("https://"): 149 | from torch.hub import download_url_to_file 150 | from urllib.parse import urlparse 151 | 152 | parts = urlparse(url) 153 | filename = os.path.basename(parts.path) 154 | tempfile = os.path.join(temp_dir, "video") 155 | if not os.path.exists(tempfile): 156 | os.makedirs(tempfile, exist_ok=True) 157 | tempfile = os.path.join(tempfile, filename) 158 | 159 | print(f'Downloading: "{url}" to {tempfile}\n') 160 | download_url_to_file(url, tempfile, progress=True) 161 | 162 | video = "video/" + filename + " [temp]" 163 | elif url.startswith(("/view?", "/api/view?")): 164 | from urllib.parse import parse_qs 165 | 166 | qs_idx = url.find("?") 167 | qs = parse_qs(url[qs_idx + 1:]) 168 | filename = qs.get("name", qs.get("filename", None)) 169 | if filename is None: 170 | raise Exception(f"Invalid url: {url}") 171 | 172 | filename = filename[0] 173 | subfolder = qs.get("subfolder", None) 174 | if subfolder is not None: 175 | filename = os.path.join(subfolder[0], filename) 176 | 177 | dirtype = qs.get("type", ["input"]) 178 | video = f"{filename} [{dirtype[0]}]" 179 | else: 180 | raise Exception(f"Invalid url: {url}") 181 | 182 | if ".gif [" in video.lower() or ".webp [" in video.lower(): 183 | gif_path = folder_paths.get_annotated_filepath(video.strip('"')) 184 | res = self.load_gif(gif_path, **kwargs) 185 | else: 186 | res = self.load_video(video=video, **kwargs) 187 | 188 | return res 189 | 190 | def load(self, video: str, **kwargs): 191 | urls = video.strip().split("\n") 192 | 193 | videos = [] 194 | frame_counts = [] 195 | 196 | for url in urls: 197 | images, frame_count = self.load_url(url, **kwargs) 198 | if images is not None and frame_count > 0: 199 | videos.append(images) 200 | frame_counts.append(frame_count) 201 | 202 | has_video = len(videos) > 0 203 | if not has_video: 204 | image = torch.zeros((1, 64, 64, 3), dtype=torch.float32, device="cpu") 205 | videos.append(image) 206 | frame_counts.append(1) 207 | 208 | return (videos, frame_counts, has_video) 209 | 210 | @classmethod 211 | def IS_CHANGED(s, video: str, **kwargs): 212 | return video 213 | 214 | @classmethod 215 | def VALIDATE_INPUTS(s, **kwargs): 216 | return True 217 | 218 | NODE_CLASS_MAPPINGS["LoadVideoFromUrl"] = UtilLoadVideoFromUrl 219 | NODE_DISPLAY_NAME_MAPPINGS["LoadVideoFromUrl"] = "Load Video From Url" 220 | 221 | 222 | except Exception as e: 223 | print(e) 224 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-art-venture" 3 | description = "A comprehensive set of custom nodes for ComfyUI, focusing on utilities for image processing, JSON manipulation, model operations and working with objects via URLs" 4 | version = "1.0.7" 5 | license = 
"LICENSE" 6 | dependencies = ["timm==0.6.13", "transformers", "fairscale", "pycocoevalcap", "opencv-python", "qrcode[pil]", "pytorch_lightning", "kornia", "pydantic", "segment_anything", "boto3>=1.34.101"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/sipherxyz/comfyui-art-venture" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "protogaia" 14 | DisplayName = "ComfyUI ArtVenture" 15 | Icon = "https://cdn.protogaia.com/assets/gaia.png" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm 2 | transformers 3 | fairscale 4 | pycocoevalcap 5 | opencv-python 6 | qrcode[pil] 7 | pytorch_lightning 8 | kornia 9 | pydantic 10 | segment_anything 11 | omegaconf 12 | boto3>=1.34.101 13 | -------------------------------------------------------------------------------- /web/text-switch-case.js: -------------------------------------------------------------------------------- 1 | import { app } from '../../../scripts/app.js'; 2 | import { ComfyWidgets } from '../../../scripts/widgets.js'; 3 | 4 | import { 5 | addKVState, 6 | chainCallback, 7 | hideWidgetForGood, 8 | addWidgetChangeCallback, 9 | } from './utils.js'; 10 | 11 | function addTextSwitchCaseWidget(nodeType) { 12 | chainCallback(nodeType.prototype, 'onNodeCreated', function () { 13 | const dataWidget = this.widgets.find((w) => w.name === 'switch_cases'); 14 | const delimiterWidget = this.widgets.find((w) => w.name === 'delimiter'); 15 | this.widgets = this.widgets.filter((w) => w.name !== 'condition'); 16 | 17 | let conditionCombo = null; 18 | 19 | const updateConditionCombo = () => { 20 | if (!delimiterWidget.value) return; 21 | 22 | const cases = (dataWidget.value ?? '') 23 | .split('\n') 24 | .filter((line) => line.includes(delimiterWidget.value)) 25 | .map((line) => line.split(delimiterWidget.value)[0]); 26 | 27 | if (!conditionCombo) { 28 | conditionCombo = ComfyWidgets['COMBO'](this, 'condition', [ 29 | ['__default__', ...(cases ?? 
[])], 30 | ]).widget; 31 | } else { 32 | conditionCombo.options.values = ['__default__', ...cases]; 33 | } 34 | }; 35 | 36 | updateConditionCombo(); 37 | dataWidget.inputEl.addEventListener('input', updateConditionCombo); 38 | addWidgetChangeCallback(delimiterWidget, updateConditionCombo); 39 | }); 40 | } 41 | 42 | app.registerExtension({ 43 | name: 'ArtVenture.TextSwitchCase', 44 | async beforeRegisterNodeDef(nodeType, nodeData) { 45 | if (!nodeData) return; 46 | if (nodeData.name !== 'TextSwitchCase') { 47 | return; 48 | } 49 | 50 | addKVState(nodeType); 51 | addTextSwitchCaseWidget(nodeType); 52 | }, 53 | }); 54 | -------------------------------------------------------------------------------- /web/upload.js: -------------------------------------------------------------------------------- 1 | import { app } from '../../../scripts/app.js'; 2 | import { api } from '../../../scripts/api.js'; 3 | import { $el } from '../../../scripts/ui.js'; 4 | import { addWidget, DOMWidgetImpl } from '../../../scripts/domWidget.js'; 5 | import { ComfyWidgets } from '../../../scripts/widgets.js' 6 | 7 | import { chainCallback, addKVState } from './utils.js'; 8 | 9 | const style = ` 10 | .comfy-img-preview video { 11 | object-fit: contain; 12 | width: var(--comfy-img-preview-width); 13 | height: var(--comfy-img-preview-height); 14 | } 15 | `; 16 | 17 | const supportedNodes = ['LoadImageFromUrl', 'LoadImageAsMaskFromUrl']; 18 | 19 | const formatUrl = (url) => { 20 | if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("blob:")) return url 21 | 22 | let type = "output" 23 | if (url.endsWith(']')) { 24 | const openBracketIndex = url.lastIndexOf('[') 25 | type = url.slice(openBracketIndex + 1, url.length - 1).trim() 26 | url = url.slice(0, openBracketIndex).trim() 27 | } 28 | 29 | const parts = url.split('/') 30 | const filename = parts.pop() 31 | const subfolder = parts.join('/') 32 | 33 | const params = [ 34 | 'filename=' + encodeURIComponent(filename), 35 | 'type=' + type, 36 | 'subfolder=' + subfolder, 37 | app.getRandParam().substring(1) 38 | ].join('&') 39 | 40 | return api.apiURL(`/view?${params}`) 41 | } 42 | 43 | // copied from ComfyUI_frontend/src/composables/widgets/useStringWidget.ts 44 | // remove the Object.defineProperty(widget, 'value') part 45 | function addUrlWidget(node, name, options) { 46 | const inputEl = document.createElement('textarea') 47 | inputEl.className = 'comfy-multiline-input' 48 | inputEl.value = options.default 49 | inputEl.placeholder = options.placeholder || name 50 | inputEl.spellcheck = false 51 | 52 | const widget = new DOMWidgetImpl({ 53 | node, 54 | name, 55 | type: 'customtext', 56 | element: inputEl, 57 | options: { 58 | hideOnZoom: true, 59 | getValue() { 60 | return inputEl.value 61 | }, 62 | setValue(v) { 63 | inputEl.value = v 64 | } 65 | } 66 | }) 67 | addWidget(node, widget) 68 | 69 | widget.inputEl = inputEl 70 | widget.options.minNodeSize = [400, 200] 71 | 72 | inputEl.addEventListener('input', () => { 73 | widget.value = inputEl.value 74 | widget.callback?.(inputEl.value, true) 75 | }) 76 | 77 | // Allow middle mouse button panning 78 | inputEl.addEventListener('pointerdown', (event) => { 79 | if (event.button === 1) { 80 | app.canvas.processMouseDown(event) 81 | } 82 | }) 83 | 84 | inputEl.addEventListener('pointermove', (event) => { 85 | if ((event.buttons & 4) === 4) { 86 | app.canvas.processMouseMove(event) 87 | } 88 | }) 89 | 90 | inputEl.addEventListener('pointerup', (event) => { 91 | if (event.button === 1) { 92 | 
app.canvas.processMouseUp(event) 93 | } 94 | }) 95 | 96 | /** Timer reference. `null` when the timer completes. */ 97 | let ignoreEventsTimer = null 98 | /** Total number of events ignored since the timer started. */ 99 | let ignoredEvents = 0 100 | 101 | // Pass wheel events to the canvas when appropriate 102 | inputEl.addEventListener('wheel', (event) => { 103 | if (!Object.is(event.deltaX, -0)) return 104 | 105 | // If the textarea has focus, require more effort to activate pass-through 106 | const multiplier = document.activeElement === inputEl ? 2 : 1 107 | const maxScrollHeight = inputEl.scrollHeight - inputEl.clientHeight 108 | 109 | if ( 110 | (event.deltaY < 0 && inputEl.scrollTop === 0) || 111 | (event.deltaY > 0 && inputEl.scrollTop === maxScrollHeight) 112 | ) { 113 | // Attempting to scroll past the end of the textarea 114 | if (!ignoreEventsTimer || ignoredEvents > 25 * multiplier) { 115 | app.canvas.processMouseWheel(event) 116 | } else { 117 | ignoredEvents++ 118 | } 119 | } else if (event.deltaY !== 0) { 120 | // Start timer whenever a successful scroll occurs 121 | ignoredEvents = 0 122 | if (ignoreEventsTimer) clearTimeout(ignoreEventsTimer) 123 | 124 | ignoreEventsTimer = setTimeout(() => { 125 | ignoreEventsTimer = null 126 | }, 800 * multiplier) 127 | } 128 | }) 129 | 130 | return widget 131 | } 132 | 133 | function addImageUploadWidget(nodeType, nodeData, imageInputName) { 134 | const { input } = nodeData ?? {} 135 | const { required } = input ?? {} 136 | if (!required) return 137 | 138 | const imageOptions = required.image 139 | delete required.image 140 | 141 | chainCallback(nodeType.prototype, "onNodeCreated", function () { 142 | const urlWidget = addUrlWidget(this, imageInputName, imageOptions[1]) 143 | ComfyWidgets.IMAGEUPLOAD( 144 | this, 145 | 'upload', 146 | ["IMAGEUPLOAD", { "image_upload": true, imageInputName }], 147 | ) 148 | 149 | const safeLoadImageFromUrl = (url) => { 150 | return new Promise((resolve) => { 151 | const img = new Image(); 152 | img.onload = () => resolve(img); 153 | img.onerror = () => resolve(null); 154 | img.src = url; 155 | }); 156 | } 157 | 158 | const setImagesFromUrl = (value = "") => { 159 | this.imageIndex = null; 160 | 161 | const urls = value.split("\n").filter(Boolean).map(formatUrl); 162 | if (!urls.length) { 163 | this.imgs = undefined; 164 | this.widgets = this.widgets.filter((w) => w.name !== "$$canvas-image-preview"); 165 | return 166 | } 167 | 168 | return Promise.all( 169 | urls.map(safeLoadImageFromUrl) 170 | ).then((imgs) => { 171 | const initialImgs = imgs.filter(Boolean); 172 | this.imgs = initialImgs.length > 0 ? 
initialImgs : undefined; 173 | app.graph.setDirtyCanvas(true); 174 | return initialImgs; 175 | }) 176 | } 177 | 178 | const originalUrlCallback = urlWidget.callback 179 | urlWidget.callback = (value, isProgrammatic = false) => { 180 | if (!isProgrammatic) { 181 | originalUrlCallback?.(value) 182 | urlWidget.options.setValue(value) 183 | } else { 184 | setImagesFromUrl(value) 185 | } 186 | } 187 | }) 188 | } 189 | 190 | app.registerExtension({ 191 | name: 'ArtVenture.Upload', 192 | init() { 193 | $el('style', { 194 | textContent: style, 195 | parent: document.head, 196 | }); 197 | }, 198 | async beforeRegisterNodeDef(nodeType, nodeData) { 199 | if (!nodeData) return; 200 | if (!supportedNodes.includes(nodeData?.name)) { 201 | return; 202 | } 203 | 204 | if (nodeData.name === 'LoadImageFromUrl' || nodeData.name === 'LoadImageAsMaskFromUrl') { 205 | addImageUploadWidget(nodeType, nodeData, 'image'); 206 | } 207 | 208 | addKVState(nodeType); 209 | }, 210 | }); 211 | -------------------------------------------------------------------------------- /web/utils.js: -------------------------------------------------------------------------------- 1 | export const CONVERTED_TYPE = "converted-widget"; 2 | 3 | export function hideWidgetForGood(node, widget, suffix = "") { 4 | widget.origType = widget.type; 5 | widget.origComputeSize = widget.computeSize; 6 | widget.computeSize = () => [0, -4]; // -4 is due to the gap litegraph adds between widgets automatically 7 | widget.type = CONVERTED_TYPE + suffix; 8 | 9 | // Hide any linked widgets, e.g. seed+seedControl 10 | if (widget.linkedWidgets) { 11 | for (const w of widget.linkedWidgets) { 12 | hideWidgetForGood(node, w, ":" + widget.name); 13 | } 14 | } 15 | } 16 | 17 | const doesInputWithNameExist = (node, name) => { 18 | return node.inputs ? node.inputs.some((input) => input.name === name) : false; 19 | }; 20 | 21 | const HIDDEN_TAG = "tschide"; 22 | const origProps = {}; 23 | 24 | // Toggle Widget + change size 25 | export function toggleWidget(node, widget, show = false, suffix = "", updateSize = true) { 26 | if (!widget || doesInputWithNameExist(node, widget.name)) return; 27 | 28 | // Store the original properties of the widget if not already stored 29 | if (!origProps[widget.name]) { 30 | origProps[widget.name] = { 31 | origType: widget.type, 32 | origComputeSize: widget.computeSize, 33 | }; 34 | } 35 | 36 | const origSize = node.size; 37 | 38 | // Set the widget type and computeSize based on the show flag 39 | widget.type = show ? origProps[widget.name].origType : HIDDEN_TAG + suffix; 40 | widget.computeSize = show 41 | ? origProps[widget.name].origComputeSize 42 | : () => [0, -4]; 43 | 44 | // Recursively handle linked widgets if they exist 45 | widget.linkedWidgets?.forEach((w) => 46 | toggleWidget(node, w, show, ":" + widget.name) 47 | ); 48 | 49 | // Calculate the new height for the node based on its computeSize method 50 | if (updateSize) { 51 | const newHeight = node.computeSize()[1]; 52 | node.setSize([node.size[0], newHeight]); 53 | } 54 | } 55 | 56 | export function addWidgetChangeCallback(widget, callback) { 57 | let widgetValue = widget.value; 58 | let originalDescriptor = Object.getOwnPropertyDescriptor(widget, "value"); 59 | Object.defineProperty(widget, "value", { 60 | get() { 61 | return originalDescriptor && originalDescriptor.get 62 | ? 
originalDescriptor.get.call(widget) 63 | : widgetValue; 64 | }, 65 | set(newVal) { 66 | if (originalDescriptor && originalDescriptor.set) { 67 | originalDescriptor.set.call(widget, newVal); 68 | } else { 69 | widgetValue = newVal; 70 | } 71 | 72 | callback(newVal); 73 | }, 74 | }); 75 | } 76 | 77 | export function chainCallback(object, property, callback) { 78 | if (object == undefined) { 79 | //This should not happen. 80 | console.error("Tried to add callback to non-existent object"); 81 | return; 82 | } 83 | if (property in object) { 84 | const callback_orig = object[property]; 85 | object[property] = function () { 86 | const r = callback_orig.apply(this, arguments); 87 | callback.apply(this, arguments); 88 | return r; 89 | }; 90 | } else { 91 | object[property] = callback; 92 | } 93 | } 94 | 95 | export function addKVState(nodeType) { 96 | chainCallback(nodeType.prototype, 'onNodeCreated', function () { 97 | chainCallback(this, 'onConfigure', function (info) { 98 | if (!this.widgets) { 99 | //Node has no widgets, there is nothing to restore 100 | return; 101 | } 102 | if (typeof info.widgets_values != 'object') { 103 | //widgets_values is in some unknown, non-actionable format 104 | return; 105 | } 106 | let widgetDict = info.widgets_values; 107 | if (widgetDict.length == undefined) { 108 | for (let w of this.widgets) { 109 | if (w.name in widgetDict) { 110 | w.value = widgetDict[w.name]; 111 | w.callback?.(w.value) 112 | } else { 113 | //attempt to restore default value 114 | let inputs = LiteGraph.getNodeType(this.type).nodeData.input; 115 | let initialValue = null; 116 | if (inputs?.required?.hasOwnProperty(w.name)) { 117 | if (inputs.required[w.name][1]?.hasOwnProperty('default')) { 118 | initialValue = inputs.required[w.name][1].default; 119 | } else if (inputs.required[w.name][0].length) { 120 | initialValue = inputs.required[w.name][0][0]; 121 | } 122 | } else if (inputs?.optional?.hasOwnProperty(w.name)) { 123 | if (inputs.optional[w.name][1]?.hasOwnProperty('default')) { 124 | initialValue = inputs.optional[w.name][1].default; 125 | } else if (inputs.optional[w.name][0].length) { 126 | initialValue = inputs.optional[w.name][0][0]; 127 | } 128 | } 129 | if (initialValue) { 130 | w.value = initialValue; 131 | w.callback?.(w.value) 132 | } 133 | } 134 | } 135 | } 136 | }); 137 | chainCallback(this, 'onSerialize', function (info) { 138 | info.widgets_values = {}; 139 | if (!this.widgets) { 140 | //object has no widgets, there is nothing to store 141 | return; 142 | } 143 | for (let w of this.widgets) { 144 | info.widgets_values[w.name] = w.value; 145 | } 146 | }); 147 | }); 148 | } --------------------------------------------------------------------------------