├── requirements.txt ├── imports ├── AdvancedControlNet │ ├── control_lllite.py │ ├── reference_nodes.py │ ├── logger.py │ ├── nodes_sparsectrl.py │ ├── control_sparsectrl.py │ ├── deprecated_nodes.py │ ├── nodes.py │ ├── weight_nodes.py │ ├── latent_keyframe_nodes.py │ └── control.py └── IPAdapterPlus.py ├── demo ├── input.png ├── model.png ├── output.gif └── IPAnimate-demo.json ├── __init__.py ├── README.md ├── .gitignore ├── main.py └── LICENSE /requirements.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imports/AdvancedControlNet/control_lllite.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /demo/input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chan-0312/ComfyUI-IPAnimate/HEAD/demo/input.png -------------------------------------------------------------------------------- /demo/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chan-0312/ComfyUI-IPAnimate/HEAD/demo/model.png -------------------------------------------------------------------------------- /demo/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chan-0312/ComfyUI-IPAnimate/HEAD/demo/output.gif -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .main import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 4 | -------------------------------------------------------------------------------- /imports/AdvancedControlNet/reference_nodes.py: -------------------------------------------------------------------------------- 1 | class AnimateDiffLoaderWithContext: 2 | @classmethod 3 | def INPUT_TYPES(s): 4 | return { 5 | "required": { 6 | "model": ("MODEL",), 7 | "image": ("IMAGE",), 8 | }, 9 | } 10 | 11 | RETURN_TYPES = ("MODEL",) 12 | CATEGORY = "" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-IPAnimate 2 | - This project generates videos frame by frame with IPAdapter + ControlNet. Unlike [Steerable-motion](https://github.com/banodoco/Steerable-Motion), it does not rely on AnimateDiff, mainly because videos generated with AnimateDiff are currently rather blurry. Frame-by-frame control through IPAdapter + ControlNet lets us produce sharper, more controllable videos. 3 | ![Input image](./demo/input.png) 4 | ![Generated video](./demo/output.gif) 5 | 6 | ## Usage 7 | - Structurally similar to [Steerable-motion](https://github.com/banodoco/Steerable-Motion), we provide both a linear and a dynamic control mode, and also accept external ControlNet images as input for extra flexibility. 8 | 9 | ![Workflow](./demo/model.png) 10 | - Main parameters: 11 | - transition frame length 12 | - influence strength range 13 | - relative IPAdapter (IPA) and ControlNet (CN) strengths 14 | - For the full workflow, see the [demo](./demo/IPAnimate-demo.json); a short sketch of how these parameters become per-frame strengths follows below. 15 | 16 | ## References 17 | - [Steerable-motion](https://github.com/banodoco/Steerable-Motion) 18 | - [Kosinkadink's ComfyUI-Advanced-ControlNet](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet) 19 | - [IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) 20 | -------------------------------------------------------------------------------- /imports/AdvancedControlNet/logger.py: --------------------------------------------------------------------------------
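A note on the README parameters above: `main.py` turns the frame-distribution and strength-range settings into a per-frame schedule, where each image-to-image transition cross-fades the outgoing image out and the incoming image in along a cosine/sine ramp, and the results are then scaled by the relative IPAdapter and ControlNet strengths. The following is a minimal standalone sketch of that schedule (not part of the repository; the parameter values shown are just the node defaults):

```python
import numpy as np

def transition_lengths(num_images, mode="linear", linear_value=8, dynamic_values="16,8,8,16"):
    # Number of frames per image-to-image transition, mirroring get_flow_image_index in main.py.
    if mode == "dynamic":
        return [int(v.strip()) for v in dynamic_values.split(",")]
    return [linear_value] * (num_images - 1)

def crossfade(v_min, v_max, swap_num):
    # Per-frame (outgoing, incoming) strengths for one transition,
    # following the cos/sin ramp used by get_flow_strengths in main.py.
    x = np.pi * np.linspace(0, swap_num, swap_num) / swap_num / 2
    outgoing = v_min + (v_max - v_min) * np.cos(x)  # fades from v_max down to v_min
    incoming = v_min + (v_max - v_min) * np.sin(x)  # fades from v_min up to v_max
    return list(zip(outgoing, incoming))

# Example: 3 input images, linear distribution of 8 frames per transition,
# strength range (0.1, 0.9) as in the node's default linear_strength_value.
for swap_num in transition_lengths(3, mode="linear", linear_value=8):
    for s_out, s_in in crossfade(0.1, 0.9, swap_num):
        print(f"outgoing={s_out:.3f}  incoming={s_in:.3f}")
```

Each per-frame pair is multiplied by `relative_ipadapter_strength` on the IPAdapter path and by `relative_cn_strength` on the ControlNet path, so setting either factor to 0 disables that path for the run.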
1 | import sys 2 | import copy 3 | import logging 4 | 5 | 6 | class ColoredFormatter(logging.Formatter): 7 | COLORS = { 8 | "DEBUG": "\033[0;36m", # CYAN 9 | "INFO": "\033[0;32m", # GREEN 10 | "WARNING": "\033[0;33m", # YELLOW 11 | "ERROR": "\033[0;31m", # RED 12 | "CRITICAL": "\033[0;37;41m", # WHITE ON RED 13 | "RESET": "\033[0m", # RESET COLOR 14 | } 15 | 16 | def format(self, record): 17 | colored_record = copy.copy(record) 18 | levelname = colored_record.levelname 19 | seq = self.COLORS.get(levelname, self.COLORS["RESET"]) 20 | colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" 21 | return super().format(colored_record) 22 | 23 | 24 | # Create a new logger 25 | logger = logging.getLogger("Advanced-ControlNet") 26 | logger.propagate = False 27 | 28 | # Add handler if we don't have one. 29 | if not logger.handlers: 30 | handler = logging.StreamHandler(sys.stdout) 31 | handler.setFormatter(ColoredFormatter("[%(name)s] - %(levelname)s - %(message)s")) 32 | logger.addHandler(handler) 33 | 34 | # Configure logger 35 | loglevel = logging.INFO 36 | logger.setLevel(loglevel) 37 | -------------------------------------------------------------------------------- /imports/AdvancedControlNet/nodes_sparsectrl.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | import folder_paths 4 | from nodes import VAEEncode 5 | import comfy.utils 6 | 7 | # from .utils import TimestepKeyframeGroup 8 | from .control_sparsectrl import SparseIndexMethodImport 9 | # from .control import load_sparsectrl, load_controlnet, ControlNetAdvanced, SparseCtrlAdvanced 10 | 11 | 12 | 13 | class SparseIndexMethodNodeImport: 14 | @classmethod 15 | def INPUT_TYPES(s): 16 | return { 17 | "required": { 18 | "indexes": ("STRING", {"default": "0"}), 19 | } 20 | } 21 | 22 | RETURN_TYPES = ("SPARSE_METHOD",) 23 | FUNCTION = "get_method" 24 | 25 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl" 26 | 27 | def get_method(self, indexes: str): 28 | idxs = [] 29 | unique_idxs = set() 30 | # get indeces from string 31 | str_idxs = [x.strip() for x in indexes.strip().split(",")] 32 | for str_idx in str_idxs: 33 | try: 34 | idx = int(str_idx) 35 | if idx in unique_idxs: 36 | raise ValueError(f"'{idx}' is duplicated; indexes must be unique.") 37 | idxs.append(idx) 38 | unique_idxs.add(idx) 39 | except ValueError: 40 | raise ValueError(f"'{str_idx}' is not a valid integer index.") 41 | if len(idxs) == 0: 42 | raise ValueError(f"No indexes were listed in Sparse Index Method.") 43 | return (SparseIndexMethodImport(idxs),) 44 | 45 | -------------------------------------------------------------------------------- /imports/AdvancedControlNet/control_sparsectrl.py: -------------------------------------------------------------------------------- 1 | #taken from: https://github.com/lllyasviel/ControlNet 2 | #and modified 3 | #and then taken from comfy/cldm/cldm.py and modified again 4 | 5 | from abc import ABC, abstractmethod 6 | import math 7 | import numpy as np 8 | from typing import Iterable, Union 9 | import torch 10 | import torch as th 11 | import torch.nn as nn 12 | from torch import Tensor 13 | from einops import rearrange, repeat 14 | 15 | from comfy.ldm.modules.diffusionmodules.util import ( 16 | zero_module, 17 | timestep_embedding, 18 | ) 19 | 20 | from comfy.cldm.cldm import ControlNet as ControlNetCLDM 21 | from comfy.ldm.modules.attention import SpatialTransformer 22 | from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, 
ResBlock, Downsample 23 | from comfy.ldm.util import exists 24 | from comfy.ldm.modules.attention import default, optimized_attention 25 | from comfy.ldm.modules.attention import FeedForward, SpatialTransformer 26 | from comfy.controlnet import broadcast_image_to 27 | from comfy.utils import repeat_to_batch_size 28 | import comfy.ops 29 | 30 | # from .utils import TimestepKeyframeGroup, disable_weight_init_clean_groupnorm, prepare_mask_batch 31 | 32 | 33 | 34 | 35 | 36 | class SparseMethodImport(ABC): 37 | SPREAD = "spread" 38 | INDEX = "index" 39 | def __init__(self, method: str): 40 | self.method = method 41 | 42 | @abstractmethod 43 | def get_indexes(self, hint_length: int, full_length: int) -> list[int]: 44 | pass 45 | 46 | 47 | 48 | class SparseIndexMethodImport(SparseMethodImport): 49 | def __init__(self, idxs: list[int]): 50 | super().__init__(self.INDEX) 51 | self.idxs = idxs 52 | 53 | def get_indexes(self, hint_length: int, full_length: int) -> list[int]: 54 | orig_hint_length = hint_length 55 | if hint_length > full_length: 56 | hint_length = full_length 57 | # if idxs is less than hint_length, throw error 58 | if len(self.idxs) < hint_length: 59 | err_msg = f"There are not enough indexes ({len(self.idxs)}) provided to fit the usable {hint_length} input images." 60 | if orig_hint_length != hint_length: 61 | err_msg = f"{err_msg} (original input images: {orig_hint_length})" 62 | raise ValueError(err_msg) 63 | # cap idxs to hint_length 64 | idxs = self.idxs[:hint_length] 65 | new_idxs = [] 66 | real_idxs = set() 67 | for idx in idxs: 68 | if idx < 0: 69 | real_idx = full_length+idx 70 | if real_idx in real_idxs: 71 | raise ValueError(f"Index '{idx}' maps to '{real_idx}' and is duplicate - indexes in Sparse Index Method must be unique.") 72 | else: 73 | real_idx = idx 74 | if real_idx in real_idxs: 75 | raise ValueError(f"Index '{idx}' is duplicate (or a negative index is equivalent) - indexes in Sparse Index Method must be unique.") 76 | real_idxs.add(real_idx) 77 | new_idxs.append(real_idx) 78 | return new_idxs 79 | 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /imports/AdvancedControlNet/deprecated_nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | import numpy as np 6 | from PIL import Image, ImageOps 7 | from .control import ControlWeights, LatentKeyframeGroup, TimestepKeyframeGroup, TimestepKeyframe 8 | from .logger import logger 9 | 10 | 11 | class LoadImagesFromDirectory: 12 | @classmethod 13 | def INPUT_TYPES(s): 14 | return { 15 | "required": { 16 | "directory": ("STRING", {"default": ""}), 17 | }, 18 | "optional": { 19 | "image_load_cap": ("INT", {"default": 0, "min": 0, "step": 1}), 20 | "start_index": ("INT", {"default": 0, "min": 0, "step": 1}), 21 | } 22 | } 23 | 24 | RETURN_TYPES = ("IMAGE", "MASK", "INT") 25 | FUNCTION = "load_images" 26 | 27 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/deprecated" 28 | 29 | def load_images(self, directory: str, image_load_cap: int = 0, start_index: int = 0): 30 | if not os.path.isdir(directory): 31 | raise FileNotFoundError(f"Directory '{directory} cannot be found.'") 32 | dir_files = os.listdir(directory) 33 | if len(dir_files) == 0: 34 | raise FileNotFoundError(f"No files in directory '{directory}'.") 35 | 36 | dir_files = sorted(dir_files) 37 | dir_files = [os.path.join(directory, x) for x in dir_files] 38 | # start at start_index 39 | dir_files = dir_files[start_index:] 40 | 41 | images = [] 42 | masks = [] 43 | 44 | limit_images = False 45 | if image_load_cap > 0: 46 | limit_images = True 47 | image_count = 0 48 | 49 | for image_path in dir_files: 50 | if os.path.isdir(image_path): 51 | continue 52 | if limit_images and image_count >= image_load_cap: 53 | break 54 | i = Image.open(image_path) 55 | i = ImageOps.exif_transpose(i) 56 | image = i.convert("RGB") 57 | image = np.array(image).astype(np.float32) / 255.0 58 | image = torch.from_numpy(image)[None,] 59 | if 'A' in i.getbands(): 60 | mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 61 | mask = 1. 
- torch.from_numpy(mask) 62 | else: 63 | mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") 64 | images.append(image) 65 | masks.append(mask) 66 | image_count += 1 67 | 68 | if len(images) == 0: 69 | raise FileNotFoundError(f"No images could be loaded from directory '{directory}'.") 70 | 71 | return (torch.cat(images, dim=0), torch.stack(masks, dim=0), image_count) 72 | 73 | 74 | class TimestepKeyframeNodeDeprecated: 75 | @classmethod 76 | def INPUT_TYPES(s): 77 | return { 78 | "required": { 79 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}, ), 80 | }, 81 | "optional": { 82 | "control_net_weights": ("CONTROL_NET_WEIGHTS", ), 83 | "t2i_adapter_weights": ("T2I_ADAPTER_WEIGHTS", ), 84 | "latent_keyframe": ("LATENT_KEYFRAME", ), 85 | "prev_timestep_keyframe": ("TIMESTEP_KEYFRAME", ), 86 | } 87 | } 88 | 89 | RETURN_TYPES = ("TIMESTEP_KEYFRAME", ) 90 | FUNCTION = "load_keyframe" 91 | 92 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes" 93 | 94 | def load_keyframe(self, 95 | start_percent: float, 96 | control_net_weights: ControlWeights=None, 97 | latent_keyframe: LatentKeyframeGroup=None, 98 | prev_timestep_keyframe: TimestepKeyframeGroup=None): 99 | if not prev_timestep_keyframe: 100 | prev_timestep_keyframe = TimestepKeyframeGroup() 101 | keyframe = TimestepKeyframe(start_percent, control_net_weights, latent_keyframe) 102 | prev_timestep_keyframe.add(keyframe) 103 | return (prev_timestep_keyframe,) 104 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from .imports.IPAdapterPlus import IPAdapterApplyImport, prep_image, IPAdapterEncoderImport 4 | from .imports.AdvancedControlNet.nodes import AdvancedControlNetApplyImport 5 | 6 | 7 | class IPAdapterAnimateNode: 8 | @classmethod 9 | def IS_CHANGED(cls, **kwargs): 10 | return float("NaN") 11 | 12 | @classmethod 13 | def INPUT_TYPES(s): 14 | return { 15 | "required": { 16 | "images": ("IMAGE", ), 17 | "model": ("MODEL", ), 18 | "positive": ("CONDITIONING", ), 19 | "negative": ("CONDITIONING", ), 20 | "ipadapter": ("IPADAPTER", ), 21 | "clip_vision": ("CLIP_VISION",), 22 | "index": ("INT", {"default": 0, "min": 0, "max": 99999, "step": 1}), 23 | "type_of_frame_distribution": (["linear", "dynamic"],), 24 | "linear_frame_distribution_value": ("INT", {"default": 8, "min": 4, "max": 64, "step": 1}), 25 | "dynamic_frame_distribution_values": ("STRING", {"multiline": True, "default": "16,8,8,16"}), 26 | "type_of_strength_distribution": (["linear", "dynamic"],), 27 | "linear_strength_value": ("STRING", {"multiline": False, "default": "(0.1,0.9)"}), 28 | "dynamic_strength_values": ("STRING", {"multiline": True, "default": "(0.0,1.0),(0.0,1.0),(0.0,1.0),(0.0,1.0)"}), 29 | "relative_ipadapter_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), 30 | "relative_cn_strength": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 10.0, "step": 0.01}), 31 | }, 32 | "optional": { 33 | "control_net": ("CONTROL_NET", {"default": None}), 34 | "control_images": ("IMAGE", {"default": None}), 35 | } 36 | } 37 | 38 | RETURN_TYPES = ("MODEL","CONDITIONING", "CONDITIONING", "INT", "STRING") 39 | RETURN_NAMES = ("MODEL", "POSITIVE", "NEGATIVE", "NET_INDEX", "LOGS") 40 | FUNCTION = "apply" 41 | 42 | CATEGORY = "IPAnimate" 43 | 44 | def apply(self, images, model, positive, negative, ipadapter, clip_vision, index, 45 | type_of_frame_distribution, 
linear_frame_distribution_value, dynamic_frame_distribution_values, 46 | type_of_strength_distribution, linear_strength_value, dynamic_strength_values, 47 | relative_ipadapter_strength, relative_cn_strength, control_net=None, control_images=None 48 | ): 49 | 50 | def get_flow_image_index(images, type_of_frame_distribution, dynamic_frame_distribution_values, linear_frame_distribution_value): 51 | "Build the per-frame (from, to) image index pairs for the processing flow" 52 | if type_of_frame_distribution == "dynamic": 53 | if isinstance(dynamic_frame_distribution_values, str): 54 | flow_swap_nums = [int(kf.strip()) for kf in dynamic_frame_distribution_values.split(',')] 55 | elif isinstance(dynamic_frame_distribution_values, list): 56 | flow_swap_nums = [int(kf) for kf in dynamic_frame_distribution_values] 57 | 58 | else: 59 | flow_swap_nums = [linear_frame_distribution_value for i in range(len(images)-1)] 60 | 61 | flow_image_indexs = [] 62 | for i, n in enumerate(flow_swap_nums): 63 | flow_image_indexs.extend([(i, i+1)]*n) 64 | return flow_image_indexs, flow_swap_nums 65 | 66 | def get_flow_strengths(flow_swap_nums, type_of_strength_distribution, dynamic_strength_values, linear_strength_value): 67 | "Build the per-frame strength pairs for the processing flow" 68 | if type_of_strength_distribution == "dynamic": 69 | if isinstance(dynamic_strength_values[0], str) and dynamic_strength_values[0] == "(": 70 | string_representation = ''.join(dynamic_strength_values) 71 | values = eval(f'[{string_representation}]') 72 | else: 73 | values = dynamic_strength_values if isinstance(dynamic_strength_values, list) else [dynamic_strength_values] 74 | else: 75 | values = [eval(linear_strength_value) for _ in flow_swap_nums] 76 | 77 | flow_strengths = [] 78 | for i, v in enumerate(values): 79 | v_min = min(v[0], v[1]) 80 | v_max = max(v[0], v[1]) 81 | v_min = v_min if v_min >= 0 else 0 82 | v_max = v_max if v_max <=1.0 else 1.0 83 | swap_num = flow_swap_nums[i] 84 | x = np.pi*(np.linspace(0, swap_num, swap_num)) / swap_num / 2 85 | strengths_2 = v_min + (v_max-v_min)*np.sin(x) 86 | strengths_1 = v_min + (v_max-v_min)*np.cos(x) 87 | 88 | flow_strengths.extend(list(zip(strengths_1, strengths_2))) 89 | 90 | return flow_strengths 91 | 92 | 93 | # whether to use ControlNet 94 | use_cn = True 95 | if control_net is None or control_images is None: 96 | use_cn = False 97 | relative_cn_strength = 0 98 | 99 | assert len(images) > 1 100 | if use_cn: 101 | assert len(images) == len(control_images) 102 | 103 | # build the per-frame image index pairs 104 | flow_image_indexs, flow_swap_nums = get_flow_image_index(images, type_of_frame_distribution, dynamic_frame_distribution_values, linear_frame_distribution_value) 105 | # build the per-frame strength pairs 106 | flow_strengths = get_flow_strengths(flow_swap_nums, type_of_strength_distribution, dynamic_strength_values, linear_strength_value) 107 | 108 | assert len(flow_image_indexs) == len(flow_strengths) 109 | assert index < len(flow_image_indexs) 110 | 111 | if use_cn: 112 | apply_advanced_control_net = AdvancedControlNetApplyImport() 113 | ipadapter_application = IPAdapterApplyImport() 114 | ipadapter_encoder = IPAdapterEncoderImport() 115 | 116 | logs = { 117 | "index": index, 118 | "use_controlnet": use_cn, 119 | "image_index": flow_image_indexs[index], 120 | "ipadapter_strength": [i*relative_ipadapter_strength for i in flow_strengths[index]], 121 | "contronet_strength": [i*relative_cn_strength for i in flow_strengths[index]] 122 | } 123 | 124 | for i in range(2): 125 | if relative_ipadapter_strength > 0: 126 | # IPAdapter processing 127 | image = images[flow_image_indexs[index][i]] 128 | # crop the center region 129 | prepped_image =
prep_image(image=image.unsqueeze(0), interpolation="LANCZOS", crop_position="pad", sharpening=0.0)[0] 130 | # apply the IPAdapter 131 | embed, = ipadapter_encoder.preprocess(clip_vision, prepped_image, True, 0.0, 1.0) 132 | model, = ipadapter_application.apply_ipadapter(ipadapter=ipadapter, model=model, weight=flow_strengths[index][i]*relative_ipadapter_strength, image=None, weight_type="linear", 133 | noise=0.0, embeds=embed, attn_mask=None, start_at=0.0, end_at=1.0, unfold_batch=False) 134 | 135 | if relative_cn_strength > 0 and use_cn: 136 | # ControlNet processing 137 | control_image = control_images[flow_image_indexs[index][i]] 138 | positive, negative = apply_advanced_control_net.apply_controlnet(positive, negative, control_net, control_image.unsqueeze(0), flow_strengths[index][i]*relative_cn_strength, 0.0, 1.0) 139 | 140 | 141 | return model, positive, negative, index+1, str(logs) 142 | 143 | 144 | # NODE MAPPING 145 | NODE_CLASS_MAPPINGS = { 146 | "IPAdapterAnimate": IPAdapterAnimateNode, 147 | } 148 | 149 | NODE_DISPLAY_NAME_MAPPINGS = { 150 | "IPAdapterAnimate": "IPAdapterAnimate by Chan" 151 | } 152 | -------------------------------------------------------------------------------- /imports/AdvancedControlNet/nodes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import Tensor 3 | 4 | import folder_paths 5 | 6 | from .control import load_controlnet, convert_to_advanced, ControlWeightsImport, ControlWeightTypeImport,\ 7 | LatentKeyframeGroupImport, TimestepKeyframeImport, TimestepKeyframeGroupImport, is_advanced_controlnet 8 | from .control import StrengthInterpolationImport as SI 9 | from .weight_nodes import DefaultWeightsImport, ScaledSoftMaskedUniversalWeightsImport, ScaledSoftUniversalWeightsImport, SoftControlNetWeightsImport, CustomControlNetWeightsImport, \ 10 | SoftT2IAdapterWeightsImport, CustomT2IAdapterWeightsImport 11 | from .latent_keyframe_nodes import LatentKeyframeGroupNodeImport, LatentKeyframeInterpolationNodeImport, LatentKeyframeBatchedGroupNodeImport, LatentKeyframeNodeImport 12 | from .logger import logger 13 | 14 | 15 | class TimestepKeyframeNodeImport: 16 | @classmethod 17 | def INPUT_TYPES(s): 18 | return { 19 | "required": { 20 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}, ), 21 | }, 22 | "optional": { 23 | "prev_timestep_kf": ("TIMESTEP_KEYFRAME", ), 24 | "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 25 | "cn_weights": ("CONTROL_NET_WEIGHTS", ), 26 | "latent_keyframe": ("LATENT_KEYFRAME", ), 27 | "null_latent_kf_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 28 | "inherit_missing": ("BOOLEAN", {"default": True}, ), 29 | "guarantee_usage": ("BOOLEAN", {"default": True}, ), 30 | "mask_optional": ("MASK", ), 31 | #"interpolation": ([SI.LINEAR, SI.EASE_IN, SI.EASE_OUT, SI.EASE_IN_OUT, SI.NONE], {"default": SI.NONE}, ), 32 | } 33 | } 34 | 35 | RETURN_NAMES = ("TIMESTEP_KF", ) 36 | RETURN_TYPES = ("TIMESTEP_KEYFRAME", ) 37 | FUNCTION = "load_keyframe" 38 | 39 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes" 40 | 41 | def load_keyframe(self, 42 | start_percent: float, 43 | strength: float=1.0, 44 | cn_weights: ControlWeightsImport=None, control_net_weights: ControlWeightsImport=None, # old name 45 | latent_keyframe: LatentKeyframeGroupImport=None, 46 | prev_timestep_kf: TimestepKeyframeGroupImport=None, prev_timestep_keyframe: TimestepKeyframeGroupImport=None, # old name 47 |
null_latent_kf_strength: float=0.0, 48 | inherit_missing=True, 49 | guarantee_usage=True, 50 | mask_optional=None, 51 | interpolation: str=SI.NONE,): 52 | control_net_weights = control_net_weights if control_net_weights else cn_weights 53 | prev_timestep_keyframe = prev_timestep_keyframe if prev_timestep_keyframe else prev_timestep_kf 54 | if not prev_timestep_keyframe: 55 | prev_timestep_keyframe = TimestepKeyframeGroupImport() 56 | else: 57 | prev_timestep_keyframe = prev_timestep_keyframe.clone() 58 | keyframe = TimestepKeyframeImport(start_percent=start_percent, strength=strength, interpolation=interpolation, null_latent_kf_strength=null_latent_kf_strength, 59 | control_weights=control_net_weights, latent_keyframes=latent_keyframe, inherit_missing=inherit_missing, guarantee_usage=guarantee_usage, 60 | mask_hint_orig=mask_optional) 61 | prev_timestep_keyframe.add(keyframe) 62 | return (prev_timestep_keyframe,) 63 | 64 | 65 | class ControlNetLoaderAdvancedImport: 66 | @classmethod 67 | def INPUT_TYPES(s): 68 | return { 69 | "required": { 70 | "control_net_name": (folder_paths.get_filename_list("controlnet"), ), 71 | }, 72 | "optional": { 73 | "timestep_keyframe": ("TIMESTEP_KEYFRAME", ), 74 | } 75 | } 76 | 77 | RETURN_TYPES = ("CONTROL_NET", ) 78 | FUNCTION = "load_controlnet" 79 | 80 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝" 81 | 82 | def load_controlnet(self, control_net_name, 83 | timestep_keyframe: TimestepKeyframeGroupImport=None 84 | ): 85 | controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) 86 | controlnet = load_controlnet(controlnet_path, timestep_keyframe) 87 | return (controlnet,) 88 | 89 | 90 | class DiffControlNetLoaderAdvancedImport: 91 | @classmethod 92 | def INPUT_TYPES(s): 93 | return { 94 | "required": { 95 | "model": ("MODEL",), 96 | "control_net_name": (folder_paths.get_filename_list("controlnet"), ) 97 | }, 98 | "optional": { 99 | "timestep_keyframe": ("TIMESTEP_KEYFRAME", ), 100 | } 101 | } 102 | 103 | RETURN_TYPES = ("CONTROL_NET", ) 104 | FUNCTION = "load_controlnet" 105 | 106 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝" 107 | 108 | def load_controlnet(self, control_net_name, model, 109 | timestep_keyframe: TimestepKeyframeGroupImport=None 110 | ): 111 | controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) 112 | controlnet = load_controlnet(controlnet_path, timestep_keyframe, model) 113 | if is_advanced_controlnet(controlnet): 114 | controlnet.verify_all_weights() 115 | return (controlnet,) 116 | 117 | 118 | class AdvancedControlNetApplyImport: 119 | @classmethod 120 | def INPUT_TYPES(s): 121 | return { 122 | "required": { 123 | "positive": ("CONDITIONING", ), 124 | "negative": ("CONDITIONING", ), 125 | "control_net": ("CONTROL_NET", ), 126 | "image": ("IMAGE", ), 127 | "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), 128 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), 129 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) 130 | }, 131 | "optional": { 132 | "mask_optional": ("MASK", ), 133 | "timestep_kf": ("TIMESTEP_KEYFRAME", ), 134 | "latent_kf_override": ("LATENT_KEYFRAME", ), 135 | "weights_override": ("CONTROL_NET_WEIGHTS", ), 136 | } 137 | } 138 | 139 | RETURN_TYPES = ("CONDITIONING","CONDITIONING") 140 | RETURN_NAMES = ("positive", "negative") 141 | FUNCTION = "apply_controlnet" 142 | 143 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝" 144 | 145 | def apply_controlnet(self, positive, negative, control_net, image, strength, 
start_percent, end_percent, 146 | mask_optional: Tensor=None, 147 | timestep_kf: TimestepKeyframeGroupImport=None, latent_kf_override: LatentKeyframeGroupImport=None, 148 | weights_override: ControlWeightsImport=None): 149 | if strength == 0: 150 | return (positive, negative) 151 | 152 | control_hint = image.movedim(-1,1) 153 | cnets = {} 154 | 155 | out = [] 156 | for conditioning in [positive, negative]: 157 | c = [] 158 | for t in conditioning: 159 | d = t[1].copy() 160 | 161 | prev_cnet = d.get('control', None) 162 | if prev_cnet in cnets: 163 | c_net = cnets[prev_cnet] 164 | else: 165 | # copy, convert to advanced if needed, and set cond 166 | c_net = convert_to_advanced(control_net.copy()).set_cond_hint(control_hint, strength, (start_percent, end_percent)) 167 | if is_advanced_controlnet(c_net): 168 | # apply optional parameters and overrides, if provided 169 | if timestep_kf is not None: 170 | c_net.set_timestep_keyframes(timestep_kf) 171 | if latent_kf_override is not None: 172 | c_net.latent_keyframe_override = latent_kf_override 173 | if weights_override is not None: 174 | c_net.weights_override = weights_override 175 | # verify weights are compatible 176 | c_net.verify_all_weights() 177 | # set cond hint mask 178 | if mask_optional is not None: 179 | mask_optional = mask_optional.clone() 180 | # if not in the form of a batch, make it so 181 | if len(mask_optional.shape) < 3: 182 | mask_optional = mask_optional.unsqueeze(0) 183 | c_net.set_cond_hint_mask(mask_optional) 184 | c_net.set_previous_controlnet(prev_cnet) 185 | cnets[prev_cnet] = c_net 186 | 187 | d['control'] = c_net 188 | d['control_apply_to_uncond'] = False 189 | n = [t[0], d] 190 | c.append(n) 191 | out.append(c) 192 | return (out[0], out[1]) 193 | 194 | 195 | -------------------------------------------------------------------------------- /imports/AdvancedControlNet/weight_nodes.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | import torch 3 | from .control import TimestepKeyframeImport, TimestepKeyframeGroupImport, ControlWeightsImport, get_properly_arranged_t2i_weights, linear_conversion 4 | from .logger import logger 5 | 6 | 7 | WEIGHTS_RETURN_NAMES = ("CN_WEIGHTS", "TK_SHORTCUT") 8 | 9 | 10 | class DefaultWeightsImport: 11 | @classmethod 12 | def INPUT_TYPES(s): 13 | return { 14 | } 15 | 16 | RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",) 17 | RETURN_NAMES = WEIGHTS_RETURN_NAMES 18 | FUNCTION = "load_weights" 19 | 20 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights" 21 | 22 | def load_weights(self): 23 | weights = ControlWeightsImport.default() 24 | return (weights, TimestepKeyframeGroupImport.default(TimestepKeyframeImport(control_weights=weights))) 25 | 26 | 27 | class ScaledSoftMaskedUniversalWeightsImport: 28 | @classmethod 29 | def INPUT_TYPES(s): 30 | return { 31 | "required": { 32 | "mask": ("MASK", ), 33 | "min_base_multiplier": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}, ), 34 | "max_base_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}, ), 35 | #"lock_min": ("BOOLEAN", {"default": False}, ), 36 | #"lock_max": ("BOOLEAN", {"default": False}, ), 37 | }, 38 | } 39 | 40 | RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",) 41 | RETURN_NAMES = WEIGHTS_RETURN_NAMES 42 | FUNCTION = "load_weights" 43 | 44 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights" 45 | 46 | def load_weights(self, mask: Tensor, min_base_multiplier: float, max_base_multiplier: float, lock_min=False, 
lock_max=False): 47 | # normalize mask 48 | mask = mask.clone() 49 | x_min = 0.0 if lock_min else mask.min() 50 | x_max = 1.0 if lock_max else mask.max() 51 | if x_min == x_max: 52 | mask = torch.ones_like(mask) * max_base_multiplier 53 | else: 54 | mask = linear_conversion(mask, x_min, x_max, min_base_multiplier, max_base_multiplier) 55 | weights = ControlWeightsImport.universal_mask(weight_mask=mask) 56 | return (weights, TimestepKeyframeGroupImport.default(TimestepKeyframeImport(control_weights=weights))) 57 | 58 | 59 | class ScaledSoftUniversalWeightsImport: 60 | @classmethod 61 | def INPUT_TYPES(s): 62 | return { 63 | "required": { 64 | "base_multiplier": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 1.0, "step": 0.001}, ), 65 | "flip_weights": ("BOOLEAN", {"default": False}), 66 | }, 67 | } 68 | 69 | RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",) 70 | RETURN_NAMES = WEIGHTS_RETURN_NAMES 71 | FUNCTION = "load_weights" 72 | 73 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights" 74 | 75 | def load_weights(self, base_multiplier, flip_weights): 76 | weights = ControlWeightsImport.universal(base_multiplier=base_multiplier, flip_weights=flip_weights) 77 | return (weights, TimestepKeyframeGroupImport.default(TimestepKeyframeImport(control_weights=weights))) 78 | 79 | 80 | class SoftControlNetWeightsImport: 81 | @classmethod 82 | def INPUT_TYPES(s): 83 | return { 84 | "required": { 85 | "weight_00": ("FLOAT", {"default": 0.09941396206337118, "min": 0.0, "max": 10.0, "step": 0.001}, ), 86 | "weight_01": ("FLOAT", {"default": 0.12050177219802567, "min": 0.0, "max": 10.0, "step": 0.001}, ), 87 | "weight_02": ("FLOAT", {"default": 0.14606275417942507, "min": 0.0, "max": 10.0, "step": 0.001}, ), 88 | "weight_03": ("FLOAT", {"default": 0.17704576264172736, "min": 0.0, "max": 10.0, "step": 0.001}, ), 89 | "weight_04": ("FLOAT", {"default": 0.214600924414215, "min": 0.0, "max": 10.0, "step": 0.001}, ), 90 | "weight_05": ("FLOAT", {"default": 0.26012233262329093, "min": 0.0, "max": 10.0, "step": 0.001}, ), 91 | "weight_06": ("FLOAT", {"default": 0.3152997971191405, "min": 0.0, "max": 10.0, "step": 0.001}, ), 92 | "weight_07": ("FLOAT", {"default": 0.3821815722656249, "min": 0.0, "max": 10.0, "step": 0.001}, ), 93 | "weight_08": ("FLOAT", {"default": 0.4632503906249999, "min": 0.0, "max": 10.0, "step": 0.001}, ), 94 | "weight_09": ("FLOAT", {"default": 0.561515625, "min": 0.0, "max": 10.0, "step": 0.001}, ), 95 | "weight_10": ("FLOAT", {"default": 0.6806249999999999, "min": 0.0, "max": 10.0, "step": 0.001}, ), 96 | "weight_11": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 10.0, "step": 0.001}, ), 97 | "weight_12": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 98 | "flip_weights": ("BOOLEAN", {"default": False}), 99 | }, 100 | } 101 | 102 | RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",) 103 | RETURN_NAMES = WEIGHTS_RETURN_NAMES 104 | FUNCTION = "load_weights" 105 | 106 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/ControlNet" 107 | 108 | def load_weights(self, weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06, 109 | weight_07, weight_08, weight_09, weight_10, weight_11, weight_12, flip_weights): 110 | weights = [weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06, 111 | weight_07, weight_08, weight_09, weight_10, weight_11, weight_12] 112 | weights = ControlWeightsImport.controlnet(weights, flip_weights=flip_weights) 113 | return (weights, 
TimestepKeyframeGroupImport.default(TimestepKeyframeImport(control_weights=weights))) 114 | 115 | 116 | class CustomControlNetWeightsImport: 117 | @classmethod 118 | def INPUT_TYPES(s): 119 | return { 120 | "required": { 121 | "weight_00": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 122 | "weight_01": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 123 | "weight_02": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 124 | "weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 125 | "weight_04": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 126 | "weight_05": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 127 | "weight_06": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 128 | "weight_07": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 129 | "weight_08": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 130 | "weight_09": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 131 | "weight_10": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 132 | "weight_11": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 133 | "weight_12": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 134 | "flip_weights": ("BOOLEAN", {"default": False}), 135 | } 136 | } 137 | 138 | RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",) 139 | RETURN_NAMES = WEIGHTS_RETURN_NAMES 140 | FUNCTION = "load_weights" 141 | 142 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/ControlNet" 143 | 144 | def load_weights(self, weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06, 145 | weight_07, weight_08, weight_09, weight_10, weight_11, weight_12, flip_weights): 146 | weights = [weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06, 147 | weight_07, weight_08, weight_09, weight_10, weight_11, weight_12] 148 | weights = ControlWeightsImport.controlnet(weights, flip_weights=flip_weights) 149 | return (weights, TimestepKeyframeGroupImport.default(TimestepKeyframeImport(control_weights=weights))) 150 | 151 | 152 | class SoftT2IAdapterWeightsImport: 153 | @classmethod 154 | def INPUT_TYPES(s): 155 | return { 156 | "required": { 157 | "weight_00": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.001}, ), 158 | "weight_01": ("FLOAT", {"default": 0.62, "min": 0.0, "max": 10.0, "step": 0.001}, ), 159 | "weight_02": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 10.0, "step": 0.001}, ), 160 | "weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 161 | "flip_weights": ("BOOLEAN", {"default": False}), 162 | }, 163 | } 164 | 165 | RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",) 166 | RETURN_NAMES = WEIGHTS_RETURN_NAMES 167 | FUNCTION = "load_weights" 168 | 169 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/T2IAdapter" 170 | 171 | def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights): 172 | weights = [weight_00, weight_01, weight_02, weight_03] 173 | weights = get_properly_arranged_t2i_weights(weights) 174 | weights = ControlWeightsImport.t2iadapter(weights, flip_weights=flip_weights) 175 | return (weights, TimestepKeyframeGroupImport.default(TimestepKeyframeImport(control_weights=weights))) 176 | 177 | 178 | class CustomT2IAdapterWeightsImport: 179 | @classmethod 180 | def INPUT_TYPES(s): 181 | return { 182 | "required": { 
183 | "weight_00": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 184 | "weight_01": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 185 | "weight_02": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 186 | "weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), 187 | "flip_weights": ("BOOLEAN", {"default": False}), 188 | }, 189 | } 190 | 191 | RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",) 192 | RETURN_NAMES = WEIGHTS_RETURN_NAMES 193 | FUNCTION = "load_weights" 194 | 195 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/T2IAdapter" 196 | 197 | def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights): 198 | weights = [weight_00, weight_01, weight_02, weight_03] 199 | weights = get_properly_arranged_t2i_weights(weights) 200 | weights = ControlWeightsImport.t2iadapter(weights, flip_weights=flip_weights) 201 | return (weights, TimestepKeyframeGroupImport.default(TimestepKeyframeImport(control_weights=weights))) 202 | -------------------------------------------------------------------------------- /demo/IPAnimate-demo.json: -------------------------------------------------------------------------------- 1 | {"last_node_id":67,"last_link_id":148,"nodes":[{"id":46,"type":"ShowText|pysssss","pos":[1925,959],"size":{"0":319.0539245605469,"1":116.03609466552734},"flags":{},"order":22,"mode":0,"inputs":[{"name":"text","type":"STRING","link":123,"widget":{"name":"text"},"label":"文本"}],"outputs":[{"name":"STRING","type":"STRING","links":null,"shape":6,"label":"字符串"}],"properties":{"Node name for S&R":"ShowText|pysssss"},"widgets_values":["","{'index': 7, 'use_controlnet': True, 'image_index': (0, 1), 'ipadapter_strength': [0.3736161146605351, 0.8517540966287267], 'contronet_strength': [0.11208483439816054, 0.255526228988618]}"]},{"id":37,"type":"EmptyLatentImage","pos":[1931,1123],"size":{"0":315,"1":106},"flags":{},"order":0,"mode":0,"outputs":[{"name":"LATENT","type":"LATENT","links":[133],"shape":3,"label":"Latent","slot_index":0}],"properties":{"Node name for S&R":"EmptyLatentImage"},"widgets_values":[512,512,1]},{"id":38,"type":"VAEDecode","pos":[2686,729],"size":{"0":210,"1":46},"flags":{},"order":23,"mode":0,"inputs":[{"name":"samples","type":"LATENT","link":134,"label":"Latent"},{"name":"vae","type":"VAE","link":128,"label":"VAE","slot_index":1}],"outputs":[{"name":"IMAGE","type":"IMAGE","links":[129],"shape":3,"label":"图像","slot_index":0}],"properties":{"Node name for S&R":"VAEDecode"}},{"id":51,"type":"SaveImage","pos":[2950,738],"size":{"0":800,"1":640},"flags":{},"order":24,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":129,"label":"图像"},{"name":"filename_prefix","type":"STRING","link":127,"widget":{"name":"filename_prefix"},"slot_index":1,"label":"文件名前缀"}],"properties":{},"widgets_values":["IPA/11/0"]},{"id":24,"type":"CLIPVisionLoader","pos":[1022,1113],"size":{"0":315,"1":58},"flags":{},"order":1,"mode":0,"outputs":[{"name":"CLIP_VISION","type":"CLIP_VISION","links":[113],"shape":3,"label":"CLIP视觉","slot_index":0}],"properties":{"Node name for S&R":"CLIPVisionLoader"},"widgets_values":["clip_model.safetensors"]},{"id":23,"type":"IPAdapterModelLoader","pos":[1017,1004],"size":{"0":315,"1":58},"flags":{},"order":2,"mode":0,"outputs":[{"name":"IPADAPTER","type":"IPADAPTER","links":[114],"shape":3,"label":"IP适配","slot_index":0}],"properties":{"Node name for 
S&R":"IPAdapterModelLoader"},"widgets_values":["ip-adapter_sd15_plus.pth"]},{"id":56,"type":"CheckpointLoaderSimple","pos":[465,736],"size":{"0":400,"1":120},"flags":{},"order":3,"mode":0,"outputs":[{"name":"MODEL","type":"MODEL","links":[140],"shape":3,"slot_index":0,"label":"模型"},{"name":"CLIP","type":"CLIP","links":[141],"shape":3,"slot_index":1,"label":"CLIP"},{"name":"VAE","type":"VAE","links":[],"shape":3,"slot_index":2,"label":"VAE"}],"properties":{"Node name for S&R":"CheckpointLoaderSimple"},"widgets_values":["v1-5-pruned.safetensors"],"color":"#223","bgcolor":"#335"},{"id":55,"type":"CLIPSetLastLayer","pos":[1016,762],"size":{"0":315,"1":58},"flags":{},"order":14,"mode":0,"inputs":[{"name":"clip","type":"CLIP","link":141,"label":"CLIP"}],"outputs":[{"name":"CLIP","type":"CLIP","links":[139],"shape":3,"label":"CLIP","slot_index":0}],"properties":{"Node name for S&R":"CLIPSetLastLayer"},"widgets_values":[-2]},{"id":60,"type":"TilePreprocessor","pos":[1024,1335],"size":{"0":315,"1":82},"flags":{},"order":17,"mode":0,"inputs":[{"name":"image","type":"IMAGE","link":147,"label":"图像"}],"outputs":[{"name":"IMAGE","type":"IMAGE","links":[146],"shape":3,"label":"图像","slot_index":0}],"properties":{"Node name for S&R":"TilePreprocessor"},"widgets_values":[3,512]},{"id":61,"type":"Note","pos":[1378,1469],"size":{"0":210,"1":58},"flags":{},"order":4,"mode":0,"properties":{"text":""},"widgets_values":["注意:每次启动需要设置为value=0"],"color":"#432","bgcolor":"#653"},{"id":62,"type":"Note","pos":[2687,979],"size":{"0":210,"1":58},"flags":{},"order":5,"mode":0,"properties":{"text":""},"widgets_values":["注意:请固定随机种子"],"color":"#432","bgcolor":"#653"},{"id":63,"type":"VHS_VideoCombine","pos":[2338.4211396703395,1385.7629311208566],"size":[210,454],"flags":{},"order":15,"mode":4,"inputs":[{"name":"images","type":"IMAGE","link":148,"label":"图像"},{"name":"audio","type":"VHS_AUDIO","link":null,"label":"audio"}],"outputs":[{"name":"GIF","type":"GIF","links":null,"shape":3,"label":"GIF"}],"properties":{"Node name for S&R":"VHS_VideoCombine"},"widgets_values":{"frame_rate":10,"loop_count":0,"filename_prefix":"ipa","format":"video/h264-mp4","pix_fmt":"yuv420p","crf":19,"save_metadata":true,"pingpong":false,"save_output":true,"videopreview":{"hidden":false,"paused":false,"params":{"filename":"inpippop_00004.mp4","subfolder":"","type":"output","format":"video/h264-mp4"}}},"color":"#233","bgcolor":"#355"},{"id":64,"type":"VHS_LoadImagesPath","pos":[1950,1385],"size":[315,194],"flags":{},"order":6,"mode":4,"outputs":[{"name":"IMAGE","type":"IMAGE","links":[148],"shape":3,"slot_index":0,"label":"图像"},{"name":"MASK","type":"MASK","links":null,"shape":3,"label":"遮罩"},{"name":"INT","type":"INT","links":null,"shape":3,"label":"整数"}],"properties":{"Node name for S&R":"VHS_LoadImagesPath"},"widgets_values":{"directory":"/root/autodl-tmp/outputs/IPA/01/","image_load_cap":0,"skip_first_images":0,"select_every_nth":1,"choose folder to upload":"image","videopreview":{"hidden":false,"paused":false,"params":{"frame_load_cap":0,"skip_first_images":0,"filename":"/root/autodl-tmp/outputs/IPA/01/","type":"path","format":"folder","select_every_nth":1}}},"color":"#233","bgcolor":"#355"},{"id":65,"type":"Note","pos":[2692,1088],"size":{"0":210,"1":58},"flags":{},"order":7,"mode":0,"properties":{"text":""},"widgets_values":["注意:请设置auto 
Queue"],"color":"#432","bgcolor":"#653"},{"id":66,"type":"Note","pos":[1956,1652],"size":{"0":210,"1":58},"flags":{},"order":8,"mode":4,"properties":{"text":""},"widgets_values":["注意:生成完视频后,可基于该流合并成视频"],"color":"#432","bgcolor":"#653"},{"id":54,"type":"KSampler //Inspire","pos":[2285,729],"size":{"0":311.197509765625,"1":362.55767822265625},"flags":{},"order":20,"mode":0,"inputs":[{"name":"model","type":"MODEL","link":130,"label":"模型"},{"name":"positive","type":"CONDITIONING","link":131,"label":"正面条件"},{"name":"negative","type":"CONDITIONING","link":132,"label":"负面条件"},{"name":"latent_image","type":"LATENT","link":133,"label":"Latent"}],"outputs":[{"name":"LATENT","type":"LATENT","links":[134],"shape":3,"slot_index":0,"label":"Latent"}],"properties":{"Node name for S&R":"KSampler //Inspire"},"widgets_values":[123456,"fixed",18,6,"dpmpp_2m","karras",1,"GPU(=A1111)","incremental",0,0]},{"id":49,"type":"IPAdapterAnimate","pos":[1463,726],"size":{"0":400,"1":434},"flags":{},"order":19,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":142,"label":"images","slot_index":0},{"name":"model","type":"MODEL","link":140,"label":"model"},{"name":"positive","type":"CONDITIONING","link":115,"label":"positive"},{"name":"negative","type":"CONDITIONING","link":116,"label":"negative"},{"name":"ipadapter","type":"IPADAPTER","link":114,"label":"ipadapter"},{"name":"clip_vision","type":"CLIP_VISION","link":113,"label":"clip_vision"},{"name":"control_net","type":"CONTROL_NET","link":126,"label":"control_net","slot_index":6},{"name":"control_images","type":"IMAGE","link":146,"label":"control_images"},{"name":"index","type":"INT","link":121,"widget":{"name":"index"},"label":"index"}],"outputs":[{"name":"MODEL","type":"MODEL","links":[130],"shape":3,"label":"MODEL","slot_index":0},{"name":"POSITIVE","type":"CONDITIONING","links":[131],"shape":3,"label":"POSITIVE","slot_index":1},{"name":"NEGATIVE","type":"CONDITIONING","links":[132],"shape":3,"label":"NEGATIVE","slot_index":2},{"name":"NET_INDEX","type":"INT","links":[122],"shape":3,"label":"NET_INDEX","slot_index":3},{"name":"LOGS","type":"STRING","links":[123],"shape":3,"label":"LOGS","slot_index":4}],"properties":{"Node name for S&R":"IPAdapterAnimate"},"widgets_values":[0,"linear",10,"8,4","linear","(0.1,0.9)","(0.0,0.5),(0.0,1.0),",1,0.3]},{"id":53,"type":"VAELoader","pos":[2275,1162],"size":{"0":315,"1":58},"flags":{},"order":9,"mode":0,"outputs":[{"name":"VAE","type":"VAE","links":[128],"shape":3,"slot_index":0}],"properties":{"Node name for S&R":"VAELoader"},"widgets_values":["vae-ft-mse-840000-ema-pruned.safetensors"]},{"id":44,"type":"ControlNetLoader","pos":[1027,1224],"size":{"0":315,"1":58},"flags":{},"order":10,"mode":0,"outputs":[{"name":"CONTROL_NET","type":"CONTROL_NET","links":[126],"shape":3,"label":"ControlNet","slot_index":0}],"properties":{"Node name for S&R":"ControlNetLoader"},"widgets_values":["control_v11f1e_sd15_tile_fp16.safetensors"]},{"id":57,"type":"VHS_LoadImagesPath","pos":[495,465],"size":[315,194],"flags":{},"order":11,"mode":0,"outputs":[{"name":"IMAGE","type":"IMAGE","links":[142,143,147],"shape":3,"slot_index":0,"label":"图像"},{"name":"MASK","type":"MASK","links":null,"shape":3,"label":"遮罩"},{"name":"INT","type":"INT","links":[],"shape":3,"slot_index":2,"label":"整数"}],"properties":{"Node name for S&R":"VHS_LoadImagesPath"},"widgets_values":{"directory":"/root/autodl-tmp/ipa/2/","image_load_cap":0,"skip_first_images":0,"select_every_nth":1,"choose folder to 
upload":"image","videopreview":{"hidden":false,"paused":false,"params":{"frame_load_cap":0,"skip_first_images":0,"filename":"/root/autodl-tmp/ipa/2/","type":"path","format":"folder","select_every_nth":1}}},"color":"#332922","bgcolor":"#593930"},{"id":58,"type":"PreviewImage","pos":[1046,329],"size":{"0":1125.7149658203125,"1":297.4338073730469},"flags":{},"order":16,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":143,"label":"图像"}],"properties":{"Node name for S&R":"PreviewImage"}},{"id":52,"type":"PrimitiveNode","pos":[2693,856],"size":{"0":210,"1":58},"flags":{},"order":12,"mode":0,"outputs":[{"name":"STRING","type":"STRING","links":[127],"widget":{"name":"filename_prefix"},"label":"STRING"}],"properties":{"Run widget replace on values":false},"widgets_values":["IPA/11/0"],"color":"#323","bgcolor":"#535"},{"id":20,"type":"CLIPTextEncode","pos":[1023,870],"size":{"0":299.464599609375,"1":85.85469055175781},"flags":{},"order":18,"mode":0,"inputs":[{"name":"clip","type":"CLIP","link":139,"label":"CLIP","slot_index":0}],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[115,116],"shape":3,"label":"条件","slot_index":0}],"properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":[""]},{"id":12,"type":"ImpactValueSender","pos":[1929,828],"size":{"0":315,"1":78},"flags":{},"order":21,"mode":0,"inputs":[{"name":"value","type":"*","link":122,"label":"值","slot_index":0},{"name":"signal_opt","type":"*","link":null,"label":"signal_opt"}],"outputs":[{"name":"signal","type":"*","links":null,"shape":3,"label":"signal"}],"properties":{"Node name for S&R":"ImpactValueSender"},"widgets_values":[1]},{"id":10,"type":"ImpactValueReceiver","pos":[1019,1463],"size":{"0":315,"1":106},"flags":{},"order":13,"mode":0,"outputs":[{"name":"*","type":"*","links":[121],"shape":3,"slot_index":0,"label":"输出"}],"properties":{"Node name for S&R":"ImpactValueReceiver"},"widgets_values":["INT",8,1],"color":"#432","bgcolor":"#653"}],"links":[[113,24,0,49,5,"CLIP_VISION"],[114,23,0,49,4,"IPADAPTER"],[115,20,0,49,2,"CONDITIONING"],[116,20,0,49,3,"CONDITIONING"],[121,10,0,49,8,"INT"],[122,49,3,12,0,"*"],[123,49,4,46,0,"STRING"],[126,44,0,49,6,"CONTROL_NET"],[127,52,0,51,1,"STRING"],[128,53,0,38,1,"VAE"],[129,38,0,51,0,"IMAGE"],[130,49,0,54,0,"MODEL"],[131,49,1,54,1,"CONDITIONING"],[132,49,2,54,2,"CONDITIONING"],[133,37,0,54,3,"LATENT"],[134,54,0,38,0,"LATENT"],[139,55,0,20,0,"CLIP"],[140,56,0,49,1,"MODEL"],[141,56,1,55,0,"CLIP"],[142,57,0,49,0,"IMAGE"],[143,57,0,58,0,"IMAGE"],[146,60,0,49,7,"IMAGE"],[147,57,0,60,0,"IMAGE"],[148,64,0,63,0,"IMAGE"]],"groups":[{"title":"合成视频","bounding":[1931,1288,646,568],"color":"#3f789e","font_size":24,"locked":false}],"config":{},"extra":{"workspace_info":{"id":"168361da-0903-423f-8546-a316d74035f6","name":"IPAnimate-demo"}},"version":0.4} 2 | -------------------------------------------------------------------------------- /imports/AdvancedControlNet/latent_keyframe_nodes.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from collections.abc import Iterable 4 | 5 | from .control import LatentKeyframeImport, LatentKeyframeGroupImport 6 | from .control import StrengthInterpolationImport as SI 7 | from .logger import logger 8 | 9 | 10 | class LatentKeyframeNodeImport: 11 | @classmethod 12 | def INPUT_TYPES(s): 13 | return { 14 | "required": { 15 | "batch_index": ("INT", {"default": 0, "min": -1000, "max": 1000, "step": 1}), 16 | "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, 
"step": 0.001}, ), 17 | }, 18 | "optional": { 19 | "prev_latent_kf": ("LATENT_KEYFRAME", ), 20 | } 21 | } 22 | 23 | RETURN_NAMES = ("LATENT_KF", ) 24 | RETURN_TYPES = ("LATENT_KEYFRAME", ) 25 | FUNCTION = "load_keyframe" 26 | 27 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes" 28 | 29 | def load_keyframe(self, 30 | batch_index: int, 31 | strength: float, 32 | prev_latent_kf: LatentKeyframeGroupImport=None, 33 | prev_latent_keyframe: LatentKeyframeGroupImport=None, # old name 34 | ): 35 | prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf 36 | if not prev_latent_keyframe: 37 | prev_latent_keyframe = LatentKeyframeGroupImport() 38 | else: 39 | prev_latent_keyframe = prev_latent_keyframe.clone() 40 | keyframe = LatentKeyframeImport(batch_index, strength) 41 | prev_latent_keyframe.add(keyframe) 42 | return (prev_latent_keyframe,) 43 | 44 | 45 | class LatentKeyframeGroupNodeImport: 46 | @classmethod 47 | def INPUT_TYPES(s): 48 | return { 49 | "required": { 50 | "index_strengths": ("STRING", {"multiline": True, "default": ""}), 51 | }, 52 | "optional": { 53 | "prev_latent_kf": ("LATENT_KEYFRAME", ), 54 | "latent_optional": ("LATENT", ), 55 | "print_keyframes": ("BOOLEAN", {"default": False}) 56 | } 57 | } 58 | 59 | RETURN_NAMES = ("LATENT_KF", ) 60 | RETURN_TYPES = ("LATENT_KEYFRAME", ) 61 | FUNCTION = "load_keyframes" 62 | 63 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes" 64 | 65 | def validate_index(self, index: int, latent_count: int = 0, is_range: bool = False, allow_negative = False) -> int: 66 | # if part of range, do nothing 67 | if is_range: 68 | return index 69 | # otherwise, validate index 70 | # validate not out of range - only when latent_count is passed in 71 | if latent_count > 0 and index > latent_count-1: 72 | raise IndexError(f"Index '{index}' out of range for the total {latent_count} latents.") 73 | # if negative, validate not out of range 74 | if index < 0: 75 | if not allow_negative: 76 | raise IndexError(f"Negative indeces not allowed, but was {index}.") 77 | conv_index = latent_count+index 78 | if conv_index < 0: 79 | raise IndexError(f"Index '{index}', converted to '{conv_index}' out of range for the total {latent_count} latents.") 80 | index = conv_index 81 | return index 82 | 83 | def convert_to_index_int(self, raw_index: str, latent_count: int = 0, is_range: bool = False, allow_negative = False) -> int: 84 | try: 85 | return self.validate_index(int(raw_index), latent_count=latent_count, is_range=is_range, allow_negative=allow_negative) 86 | except ValueError as e: 87 | raise ValueError(f"index '{raw_index}' must be an integer.", e) 88 | 89 | def convert_to_latent_keyframes(self, latent_indeces: str, latent_count: int) -> set[LatentKeyframeImport]: 90 | if not latent_indeces: 91 | return set() 92 | int_latent_indeces = [i for i in range(0, latent_count)] 93 | allow_negative = latent_count > 0 94 | chosen_indeces = set() 95 | # parse string - allow positive ints, negative ints, and ranges separated by ':' 96 | groups = latent_indeces.split(",") 97 | groups = [g.strip() for g in groups] 98 | for g in groups: 99 | # parse strengths - default to 1.0 if no strength given 100 | strength = 1.0 101 | if '=' in g: 102 | g, strength_str = g.split("=", 1) 103 | g = g.strip() 104 | try: 105 | strength = float(strength_str.strip()) 106 | except ValueError as e: 107 | raise ValueError(f"strength '{strength_str}' must be a float.", e) 108 | if strength < 0: 109 | raise ValueError(f"Strength '{strength}' cannot be negative.") 110 | # parse range of indeces 
(e.g. 2:16) 111 | if ':' in g: 112 | index_range = g.split(":", 1) 113 | index_range = [r.strip() for r in index_range] 114 | start_index = self.convert_to_index_int(index_range[0], latent_count=latent_count, is_range=True, allow_negative=allow_negative) 115 | end_index = self.convert_to_index_int(index_range[1], latent_count=latent_count, is_range=True, allow_negative=allow_negative) 116 | # if latents were passed in, base indeces on known latent count 117 | if len(int_latent_indeces) > 0: 118 | for i in int_latent_indeces[start_index:end_index]: 119 | chosen_indeces.add(LatentKeyframeImport(i, strength)) 120 | # otherwise, assume indeces are valid 121 | else: 122 | for i in range(start_index, end_index): 123 | chosen_indeces.add(LatentKeyframeImport(i, strength)) 124 | # parse individual indeces 125 | else: 126 | chosen_indeces.add(LatentKeyframeImport(self.convert_to_index_int(g, latent_count=latent_count, allow_negative=allow_negative), strength)) 127 | return chosen_indeces 128 | 129 | def load_keyframes(self, 130 | index_strengths: str, 131 | prev_latent_kf: LatentKeyframeGroupImport=None, 132 | prev_latent_keyframe: LatentKeyframeGroupImport=None, # old name 133 | latent_image_opt=None, 134 | print_keyframes=False): 135 | prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf 136 | if not prev_latent_keyframe: 137 | prev_latent_keyframe = LatentKeyframeGroupImport() 138 | else: 139 | prev_latent_keyframe = prev_latent_keyframe.clone() 140 | curr_latent_keyframe = LatentKeyframeGroupImport() 141 | 142 | latent_count = -1 143 | if latent_image_opt: 144 | latent_count = latent_image_opt['samples'].size()[0] 145 | latent_keyframes = self.convert_to_latent_keyframes(index_strengths, latent_count=latent_count) 146 | 147 | for latent_keyframe in latent_keyframes: 148 | curr_latent_keyframe.add(latent_keyframe) 149 | 150 | if print_keyframes: 151 | for keyframe in curr_latent_keyframe.keyframes: 152 | logger.info(f"keyframe {keyframe.batch_index}:{keyframe.strength}") 153 | 154 | # replace values with prev_latent_keyframes 155 | for latent_keyframe in prev_latent_keyframe.keyframes: 156 | curr_latent_keyframe.add(latent_keyframe) 157 | 158 | return (curr_latent_keyframe,) 159 | 160 | 161 | class LatentKeyframeInterpolationNodeImport: 162 | 163 | @classmethod 164 | def INPUT_TYPES(s): 165 | return { 166 | "required": { 167 | "batch_index_from": ("INT", {"default": 0, "min": -10000, "max": 10000, "step": 1}), 168 | "batch_index_to_excl": ("INT", {"default": 0, "min": -10000, "max": 10000, "step": 1}), 169 | "strength_from": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.0001}, ), 170 | "strength_to": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.0001}, ), 171 | "interpolation": (["linear", "ease-in", "ease-out", "ease-in-out"], ), 172 | "revert_direction_at_midpoint": ("BOOLEAN", {"default": False}), 173 | }, 174 | "optional": { 175 | "prev_latent_keyframe": ("LATENT_KEYFRAME", ), 176 | } 177 | } 178 | 179 | RETURN_TYPES = ("LATENT_KEYFRAME", ) 180 | FUNCTION = "load_keyframe" 181 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes" 182 | 183 | def load_keyframe(self, 184 | weights: int, 185 | frame_numbers: float): 186 | 187 | 188 | curr_latent_keyframe = LatentKeyframeGroupImport() 189 | 190 | for i, frame_number in enumerate(frame_numbers): 191 | keyframe = LatentKeyframeImport(frame_number, float(weights[i])) 192 | curr_latent_keyframe.add(keyframe) 193 | 194 | return (curr_latent_keyframe,) 195 | 196 | class 
LatentKeyframeBatchedGroupNodeImport: 197 | @classmethod 198 | def INPUT_TYPES(s): 199 | return { 200 | "required": { 201 | "float_strengths": ("FLOAT", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}), 202 | }, 203 | "optional": { 204 | "prev_latent_kf": ("LATENT_KEYFRAME", ), 205 | "print_keyframes": ("BOOLEAN", {"default": False}) 206 | } 207 | } 208 | 209 | RETURN_NAMES = ("LATENT_KF", ) 210 | RETURN_TYPES = ("LATENT_KEYFRAME", ) 211 | FUNCTION = "load_keyframe" 212 | CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes" 213 | 214 | def load_keyframe(self, float_strengths: Union[float, list[float]], 215 | prev_latent_kf: LatentKeyframeGroupImport=None, 216 | prev_latent_keyframe: LatentKeyframeGroupImport=None, # old name 217 | print_keyframes=False): 218 | prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf 219 | if not prev_latent_keyframe: 220 | prev_latent_keyframe = LatentKeyframeGroupImport() 221 | else: 222 | prev_latent_keyframe = prev_latent_keyframe.clone() 223 | curr_latent_keyframe = LatentKeyframeGroupImport() 224 | 225 | # if received a normal float input, do nothing 226 | if type(float_strengths) in (float, int): 227 | logger.info("No batched float_strengths passed into Latent Keyframe Batch Group node; will not create any new keyframes.") 228 | # if iterable, attempt to create LatentKeyframes with chosen strengths 229 | elif isinstance(float_strengths, Iterable): 230 | for idx, strength in enumerate(float_strengths): 231 | keyframe = LatentKeyframeImport(idx, strength) 232 | curr_latent_keyframe.add(keyframe) 233 | else: 234 | raise ValueError(f"Expected strengths to be an iterable input, but was {type(float_strengths).__repr__}.") 235 | 236 | if print_keyframes: 237 | for keyframe in curr_latent_keyframe.keyframes: 238 | logger.info(f"keyframe {keyframe.batch_index}:{keyframe.strength}") 239 | 240 | # replace values with prev_latent_keyframes 241 | for latent_keyframe in prev_latent_keyframe.keyframes: 242 | curr_latent_keyframe.add(latent_keyframe) 243 | 244 | return (curr_latent_keyframe,) 245 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /imports/IPAdapterPlus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import contextlib 3 | import os 4 | import math 5 | 6 | import comfy.utils 7 | import comfy.model_management 8 | from comfy.clip_vision import clip_preprocess 9 | from comfy.ldm.modules.attention import optimized_attention 10 | import folder_paths 11 | 12 | from torch import nn 13 | from PIL import Image 14 | import torch.nn.functional as F 15 | import torchvision.transforms as TT 16 | 17 | # set the models directory backward compatible 18 | GLOBAL_MODELS_DIR = os.path.join(folder_paths.models_dir, "ipadapter") 19 | MODELS_DIR = GLOBAL_MODELS_DIR if os.path.isdir(GLOBAL_MODELS_DIR) else os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") 20 | if "ipadapter" not in folder_paths.folder_names_and_paths: 21 | folder_paths.folder_names_and_paths["ipadapter"] = ([MODELS_DIR], folder_paths.supported_pt_extensions) 22 | else: 23 | folder_paths.folder_names_and_paths["ipadapter"][1].update(folder_paths.supported_pt_extensions) 24 | 25 | class MLPProjModelImport(torch.nn.Module): 26 | """SD model with image prompt""" 27 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): 28 | super().__init__() 29 | 30 | self.proj = torch.nn.Sequential( 31 | torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), 32 | torch.nn.GELU(), 33 | torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), 34 | torch.nn.LayerNorm(cross_attention_dim) 35 | ) 36 | 37 | def forward(self, image_embeds): 38 | clip_extra_context_tokens = self.proj(image_embeds) 39 | return clip_extra_context_tokens 40 | 41 | class ImageProjModelImport(nn.Module): 42 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 43 | super().__init__() 44 | 45 | self.cross_attention_dim = cross_attention_dim 46 | self.clip_extra_context_tokens = clip_extra_context_tokens 47 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 48 | self.norm = nn.LayerNorm(cross_attention_dim) 49 | 50 | def forward(self, image_embeds): 51 | embeds = image_embeds 52 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 53 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 54 | return clip_extra_context_tokens 55 | 56 | class To_KVImport(nn.Module): 57 | def __init__(self, state_dict): 58 | super().__init__() 59 | 60 | self.to_kvs = nn.ModuleDict() 61 | for key, value in state_dict.items(): 62 | self.to_kvs[key.replace(".weight", "").replace(".", "_")] = nn.Linear(value.shape[1], value.shape[0], bias=False) 63 | self.to_kvs[key.replace(".weight", "").replace(".", "_")].weight.data = value 64 | 65 | def FeedForward(dim, mult=4): 66 | inner_dim = int(dim * mult) 67 | return nn.Sequential( 68 | nn.LayerNorm(dim), 69 | nn.Linear(dim, inner_dim, bias=False), 70 | nn.GELU(), 71 | nn.Linear(inner_dim, dim, bias=False), 72 | ) 73 | 74 | 75 | class PerceiverAttention(nn.Module): 76 | def __init__(self, *, dim, dim_head=64, heads=8): 77 | super().__init__() 78 | self.scale = dim_head**-0.5 79 | self.dim_head = dim_head 80 | self.heads = heads 81 | inner_dim = dim_head * heads 82 | 83 | self.norm1 = nn.LayerNorm(dim) 84 | self.norm2 = nn.LayerNorm(dim) 85 | 86 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 87 | self.to_kv = nn.Linear(dim, inner_dim * 2, 
bias=False) 88 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 89 | 90 | 91 | def forward(self, x, latents): 92 | """ 93 | Args: 94 | x (torch.Tensor): image features 95 | shape (b, n1, D) 96 | latent (torch.Tensor): latent features 97 | shape (b, n2, D) 98 | """ 99 | x = self.norm1(x) 100 | latents = self.norm2(latents) 101 | 102 | b, l, _ = latents.shape 103 | 104 | q = self.to_q(latents) 105 | kv_input = torch.cat((x, latents), dim=-2) 106 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 107 | 108 | q = reshape_tensor(q, self.heads) 109 | k = reshape_tensor(k, self.heads) 110 | v = reshape_tensor(v, self.heads) 111 | 112 | # attention 113 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 114 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 115 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 116 | out = weight @ v 117 | 118 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 119 | 120 | return self.to_out(out) 121 | 122 | def reshape_tensor(x, heads): 123 | bs, length, width = x.shape 124 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 125 | x = x.view(bs, length, heads, -1) 126 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 127 | x = x.transpose(1, 2) 128 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 129 | x = x.reshape(bs, heads, length, -1) 130 | return x 131 | 132 | def set_model_patch_replace(model, patch_kwargs, key): 133 | to = model.model_options["transformer_options"] 134 | if "patches_replace" not in to: 135 | to["patches_replace"] = {} 136 | if "attn2" not in to["patches_replace"]: 137 | to["patches_replace"]["attn2"] = {} 138 | if key not in to["patches_replace"]["attn2"]: 139 | patch = CrossAttentionPatchImport(**patch_kwargs) 140 | to["patches_replace"]["attn2"][key] = patch 141 | else: 142 | to["patches_replace"]["attn2"][key].set_new_condition(**patch_kwargs) 143 | 144 | def image_add_noise(image, noise): 145 | image = image.permute([0,3,1,2]) 146 | torch.manual_seed(0) # use a fixed random for reproducible results 147 | transforms = TT.Compose([ 148 | TT.CenterCrop(min(image.shape[2], image.shape[3])), 149 | TT.Resize((224, 224), interpolation=TT.InterpolationMode.BICUBIC, antialias=True), 150 | TT.ElasticTransform(alpha=75.0, sigma=noise*3.5), # shuffle the image 151 | TT.RandomVerticalFlip(p=1.0), # flip the image to change the geometry even more 152 | TT.RandomHorizontalFlip(p=1.0), 153 | ]) 154 | image = transforms(image.cpu()) 155 | image = image.permute([0,2,3,1]) 156 | image = image + ((0.25*(1-noise)+0.05) * torch.randn_like(image) ) # add further random noise 157 | return image 158 | 159 | def zeroed_hidden_states(clip_vision, batch_size): 160 | image = torch.zeros([batch_size, 224, 224, 3]) 161 | comfy.model_management.load_model_gpu(clip_vision.patcher) 162 | pixel_values = clip_preprocess(image.to(clip_vision.load_device)) 163 | 164 | if clip_vision.dtype != torch.float32: 165 | precision_scope = torch.autocast 166 | else: 167 | precision_scope = lambda a, b: contextlib.nullcontext(a) 168 | 169 | with precision_scope(comfy.model_management.get_autocast_device(clip_vision.load_device), torch.float32): 170 | outputs = clip_vision.model(pixel_values, intermediate_output=-2) 171 | 172 | # we only need the penultimate hidden states 173 | outputs = outputs[1].to(comfy.model_management.intermediate_device()) 174 | 175 | return outputs 176 | 177 | def min_(tensor_list): 178 | # return the element-wise min of the 
tensor list. 179 | x = torch.stack(tensor_list) 180 | mn = x.min(axis=0)[0] 181 | return torch.clamp(mn, min=0) 182 | 183 | def max_(tensor_list): 184 | # return the element-wise max of the tensor list. 185 | x = torch.stack(tensor_list) 186 | mx = x.max(axis=0)[0] 187 | return torch.clamp(mx, max=1) 188 | 189 | # From https://github.com/Jamy-L/Pytorch-Contrast-Adaptive-Sharpening/ 190 | def contrast_adaptive_sharpening(image, amount): 191 | img = F.pad(image, pad=(1, 1, 1, 1)).cpu() 192 | 193 | a = img[..., :-2, :-2] 194 | b = img[..., :-2, 1:-1] 195 | c = img[..., :-2, 2:] 196 | d = img[..., 1:-1, :-2] 197 | e = img[..., 1:-1, 1:-1] 198 | f = img[..., 1:-1, 2:] 199 | g = img[..., 2:, :-2] 200 | h = img[..., 2:, 1:-1] 201 | i = img[..., 2:, 2:] 202 | 203 | # Computing contrast 204 | cross = (b, d, e, f, h) 205 | mn = min_(cross) 206 | mx = max_(cross) 207 | 208 | diag = (a, c, g, i) 209 | mn2 = min_(diag) 210 | mx2 = max_(diag) 211 | mx = mx + mx2 212 | mn = mn + mn2 213 | 214 | # Computing local weight 215 | inv_mx = torch.reciprocal(mx) 216 | amp = inv_mx * torch.minimum(mn, (2 - mx)) 217 | 218 | # scaling 219 | amp = torch.sqrt(amp) 220 | w = - amp * (amount * (1/5 - 1/8) + 1/8) 221 | div = torch.reciprocal(1 + 4*w) 222 | 223 | output = ((b + d + f + h)*w + e) * div 224 | output = output.clamp(0, 1) 225 | output = torch.nan_to_num(output) 226 | 227 | return (output) 228 | 229 | class IPAdapterImport(nn.Module): 230 | def __init__(self, ipadapter_model, cross_attention_dim=1024, output_cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4, is_sdxl=False, is_plus=False, is_full=False): 231 | super().__init__() 232 | 233 | self.clip_embeddings_dim = clip_embeddings_dim 234 | self.cross_attention_dim = cross_attention_dim 235 | self.output_cross_attention_dim = output_cross_attention_dim 236 | self.clip_extra_context_tokens = clip_extra_context_tokens 237 | self.is_sdxl = is_sdxl 238 | self.is_full = is_full 239 | 240 | self.image_proj_model = self.init_proj() if not is_plus else self.init_proj_plus() 241 | self.image_proj_model.load_state_dict(ipadapter_model["image_proj"]) 242 | self.ip_layers = To_KVImport(ipadapter_model["ip_adapter"]) 243 | 244 | def init_proj(self): 245 | image_proj_model = ImageProjModelImport( 246 | cross_attention_dim=self.cross_attention_dim, 247 | clip_embeddings_dim=self.clip_embeddings_dim, 248 | clip_extra_context_tokens=self.clip_extra_context_tokens 249 | ) 250 | return image_proj_model 251 | 252 | def init_proj_plus(self): 253 | if self.is_full: 254 | image_proj_model = MLPProjModelImport( 255 | cross_attention_dim=self.cross_attention_dim, 256 | clip_embeddings_dim=self.clip_embeddings_dim 257 | ) 258 | else: 259 | image_proj_model = ResamplerImport( 260 | dim=self.cross_attention_dim, 261 | depth=4, 262 | dim_head=64, 263 | heads=20 if self.is_sdxl else 12, 264 | num_queries=self.clip_extra_context_tokens, 265 | embedding_dim=self.clip_embeddings_dim, 266 | output_dim=self.output_cross_attention_dim, 267 | ff_mult=4 268 | ) 269 | return image_proj_model 270 | 271 | @torch.inference_mode() 272 | def get_image_embeds(self, clip_embed, clip_embed_zeroed): 273 | image_prompt_embeds = self.image_proj_model(clip_embed) 274 | uncond_image_prompt_embeds = self.image_proj_model(clip_embed_zeroed) 275 | return image_prompt_embeds, uncond_image_prompt_embeds 276 | 277 | class CrossAttentionPatchImport: 278 | # forward for patching 279 | def __init__(self, weight, ipadapter, device, dtype, number, cond, uncond, weight_type, mask=None, 
sigma_start=0.0, sigma_end=1.0, unfold_batch=False): 280 | self.weights = [weight] 281 | self.ipadapters = [ipadapter] 282 | self.conds = [cond] 283 | self.unconds = [uncond] 284 | self.device = 'cuda' if 'cuda' in device.type else 'cpu' 285 | self.dtype = dtype if 'cuda' in self.device else torch.bfloat16 286 | self.number = number 287 | self.weight_type = [weight_type] 288 | self.masks = [mask] 289 | self.sigma_start = [sigma_start] 290 | self.sigma_end = [sigma_end] 291 | self.unfold_batch = [unfold_batch] 292 | 293 | self.k_key = str(self.number*2+1) + "_to_k_ip" 294 | self.v_key = str(self.number*2+1) + "_to_v_ip" 295 | 296 | def set_new_condition(self, weight, ipadapter, device, dtype, number, cond, uncond, weight_type, mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False): 297 | self.weights.append(weight) 298 | self.ipadapters.append(ipadapter) 299 | self.conds.append(cond) 300 | self.unconds.append(uncond) 301 | self.masks.append(mask) 302 | self.device = 'cuda' if 'cuda' in device.type else 'cpu' 303 | self.dtype = dtype if 'cuda' in self.device else torch.bfloat16 304 | self.weight_type.append(weight_type) 305 | self.sigma_start.append(sigma_start) 306 | self.sigma_end.append(sigma_end) 307 | self.unfold_batch.append(unfold_batch) 308 | 309 | def __call__(self, n, context_attn2, value_attn2, extra_options): 310 | org_dtype = n.dtype 311 | cond_or_uncond = extra_options["cond_or_uncond"] 312 | sigma = extra_options["sigmas"][0].item() if 'sigmas' in extra_options else 999999999.9 313 | 314 | # extra options for AnimateDiff 315 | ad_params = extra_options['ad_params'] if "ad_params" in extra_options else None 316 | 317 | with torch.autocast(device_type=self.device, dtype=self.dtype): 318 | q = n 319 | k = context_attn2 320 | v = value_attn2 321 | b = q.shape[0] 322 | qs = q.shape[1] 323 | batch_prompt = b // len(cond_or_uncond) 324 | out = optimized_attention(q, k, v, extra_options["n_heads"]) 325 | _, _, lh, lw = extra_options["original_shape"] 326 | 327 | for weight, cond, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks, self.weight_type, self.sigma_start, self.sigma_end, self.unfold_batch): 328 | if sigma > sigma_start or sigma < sigma_end: 329 | continue 330 | 331 | if unfold_batch and cond.shape[0] > 1: 332 | # Check AnimateDiff context window 333 | if ad_params is not None and ad_params["sub_idxs"] is not None: 334 | # if images length matches or exceeds full_length get sub_idx images 335 | if cond.shape[0] >= ad_params["full_length"]: 336 | cond = torch.Tensor(cond[ad_params["sub_idxs"]]) 337 | uncond = torch.Tensor(uncond[ad_params["sub_idxs"]]) 338 | # otherwise, need to do more to get proper sub_idxs masks 339 | else: 340 | # check if images length matches full_length - if not, make it match 341 | if cond.shape[0] < ad_params["full_length"]: 342 | cond = torch.cat((cond, cond[-1:].repeat((ad_params["full_length"]-cond.shape[0], 1, 1))), dim=0) 343 | uncond = torch.cat((uncond, uncond[-1:].repeat((ad_params["full_length"]-uncond.shape[0], 1, 1))), dim=0) 344 | # if we have too many remove the excess (should not happen, but just in case) 345 | if cond.shape[0] > ad_params["full_length"]: 346 | cond = cond[:ad_params["full_length"]] 347 | uncond = uncond[:ad_params["full_length"]] 348 | cond = cond[ad_params["sub_idxs"]] 349 | uncond = uncond[ad_params["sub_idxs"]] 350 | 351 | # if we don't have enough reference images repeat the last one until we reach the right 
size 352 | if cond.shape[0] < batch_prompt: 353 | cond = torch.cat((cond, cond[-1:].repeat((batch_prompt-cond.shape[0], 1, 1))), dim=0) 354 | uncond = torch.cat((uncond, uncond[-1:].repeat((batch_prompt-uncond.shape[0], 1, 1))), dim=0) 355 | # if we have too many remove the exceeding 356 | elif cond.shape[0] > batch_prompt: 357 | cond = cond[:batch_prompt] 358 | uncond = uncond[:batch_prompt] 359 | 360 | k_cond = ipadapter.ip_layers.to_kvs[self.k_key](cond) 361 | k_uncond = ipadapter.ip_layers.to_kvs[self.k_key](uncond) 362 | v_cond = ipadapter.ip_layers.to_kvs[self.v_key](cond) 363 | v_uncond = ipadapter.ip_layers.to_kvs[self.v_key](uncond) 364 | else: 365 | k_cond = ipadapter.ip_layers.to_kvs[self.k_key](cond).repeat(batch_prompt, 1, 1) 366 | k_uncond = ipadapter.ip_layers.to_kvs[self.k_key](uncond).repeat(batch_prompt, 1, 1) 367 | v_cond = ipadapter.ip_layers.to_kvs[self.v_key](cond).repeat(batch_prompt, 1, 1) 368 | v_uncond = ipadapter.ip_layers.to_kvs[self.v_key](uncond).repeat(batch_prompt, 1, 1) 369 | 370 | if weight_type.startswith("linear"): 371 | ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0) * weight 372 | ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0) * weight 373 | else: 374 | ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0) 375 | ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0) 376 | 377 | if weight_type.startswith("channel"): 378 | # code by Lvmin Zhang at Stanford University as also seen on Fooocus IPAdapter implementation 379 | # please read licensing notes https://github.com/lllyasviel/Fooocus/blob/main/fooocus_extras/ip_adapter.py#L225 380 | ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True) 381 | ip_v_offset = ip_v - ip_v_mean 382 | _, _, C = ip_k.shape 383 | channel_penalty = float(C) / 1280.0 384 | W = weight * channel_penalty 385 | ip_k = ip_k * W 386 | ip_v = ip_v_offset + ip_v_mean * W 387 | 388 | out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"]) 389 | if weight_type.startswith("original"): 390 | out_ip = out_ip * weight 391 | 392 | if mask is not None: 393 | # TODO: needs checking 394 | mask_h = max(1, round(lh / math.sqrt(lh * lw / qs))) 395 | mask_w = qs // mask_h 396 | 397 | # check if using AnimateDiff and sliding context window 398 | if (mask.shape[0] > 1 and ad_params is not None and ad_params["sub_idxs"] is not None): 399 | # if mask length matches or exceeds full_length, just get sub_idx masks, resize, and continue 400 | if mask.shape[0] >= ad_params["full_length"]: 401 | mask_downsample = torch.Tensor(mask[ad_params["sub_idxs"]]) 402 | mask_downsample = F.interpolate(mask_downsample.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1) 403 | # otherwise, need to do more to get proper sub_idxs masks 404 | else: 405 | # resize to needed attention size (to save on memory) 406 | mask_downsample = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1) 407 | # check if mask length matches full_length - if not, make it match 408 | if mask_downsample.shape[0] < ad_params["full_length"]: 409 | mask_downsample = torch.cat((mask_downsample, mask_downsample[-1:].repeat((ad_params["full_length"]-mask_downsample.shape[0], 1, 1))), dim=0) 410 | # if we have too many remove the excess (should not happen, but just in case) 411 | if mask_downsample.shape[0] > ad_params["full_length"]: 412 | mask_downsample = mask_downsample[:ad_params["full_length"]] 413 | # now, select sub_idxs masks 414 | mask_downsample = 
mask_downsample[ad_params["sub_idxs"]] 415 | # otherwise, perform usual mask interpolation 416 | else: 417 | mask_downsample = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1) 418 | 419 | # if we don't have enough masks repeat the last one until we reach the right size 420 | if mask_downsample.shape[0] < batch_prompt: 421 | mask_downsample = torch.cat((mask_downsample, mask_downsample[-1:, :, :].repeat((batch_prompt-mask_downsample.shape[0], 1, 1))), dim=0) 422 | # if we have too many remove the exceeding 423 | elif mask_downsample.shape[0] > batch_prompt: 424 | mask_downsample = mask_downsample[:batch_prompt, :, :] 425 | 426 | # repeat the masks 427 | mask_downsample = mask_downsample.repeat(len(cond_or_uncond), 1, 1) 428 | mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat(1, 1, out.shape[2]) 429 | 430 | out_ip = out_ip * mask_downsample 431 | 432 | out = out + out_ip 433 | 434 | return out.to(dtype=org_dtype) 435 | 436 | 437 | 438 | class IPAdapterApplyImport: 439 | @classmethod 440 | def INPUT_TYPES(s): 441 | return { 442 | "required": { 443 | "ipadapter": ("IPADAPTER", ), 444 | "clip_vision": ("CLIP_VISION",), 445 | "image": ("IMAGE",), 446 | "model": ("MODEL", ), 447 | "weight": ("FLOAT", { "default": 1.0, "min": -1, "max": 3, "step": 0.05 }), 448 | "noise": ("FLOAT", { "default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01 }), 449 | "weight_type": (["original", "linear", "channel penalty"], ), 450 | "start_at": ("FLOAT", { "default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }), 451 | "end_at": ("FLOAT", { "default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }), 452 | "unfold_batch": ("BOOLEAN", { "default": False }), 453 | }, 454 | "optional": { 455 | "attn_mask": ("MASK",), 456 | } 457 | } 458 | 459 | RETURN_TYPES = ("MODEL",) 460 | FUNCTION = "apply_ipadapter" 461 | CATEGORY = "ipadapter" 462 | 463 | def apply_ipadapter(self, ipadapter, model, weight, clip_vision=None, image=None, weight_type="original", noise=None, embeds=None, attn_mask=None, start_at=0.0, end_at=1.0, unfold_batch=False): 464 | self.dtype = model.model.diffusion_model.dtype 465 | self.device = comfy.model_management.get_torch_device() 466 | self.weight = weight 467 | self.is_full = "proj.0.weight" in ipadapter["image_proj"] 468 | self.is_plus = self.is_full or "latents" in ipadapter["image_proj"] 469 | 470 | output_cross_attention_dim = ipadapter["ip_adapter"]["1.to_k_ip.weight"].shape[1] 471 | self.is_sdxl = output_cross_attention_dim == 2048 472 | cross_attention_dim = 1280 if self.is_plus and self.is_sdxl else output_cross_attention_dim 473 | clip_extra_context_tokens = 16 if self.is_plus else 4 474 | 475 | if embeds is not None: 476 | embeds = torch.unbind(embeds) 477 | clip_embed = embeds[0].cpu() 478 | clip_embed_zeroed = embeds[1].cpu() 479 | else: 480 | if image.shape[1] != image.shape[2]: 481 | print("\033[33mINFO: the IPAdapter reference image is not a square, CLIPImageProcessor will resize and crop it at the center. 
If the main focus of the picture is not in the middle the result might not be what you are expecting.\033[0m") 482 | 483 | clip_embed = clip_vision.encode_image(image) 484 | neg_image = image_add_noise(image, noise) if noise > 0 else None 485 | 486 | if self.is_plus: 487 | clip_embed = clip_embed.penultimate_hidden_states 488 | if noise > 0: 489 | clip_embed_zeroed = clip_vision.encode_image(neg_image).penultimate_hidden_states 490 | else: 491 | clip_embed_zeroed = zeroed_hidden_states(clip_vision, image.shape[0]) 492 | else: 493 | clip_embed = clip_embed.image_embeds 494 | if noise > 0: 495 | clip_embed_zeroed = clip_vision.encode_image(neg_image).image_embeds 496 | else: 497 | clip_embed_zeroed = torch.zeros_like(clip_embed) 498 | 499 | clip_embeddings_dim = clip_embed.shape[-1] 500 | 501 | self.ipadapter = IPAdapterImport( 502 | ipadapter, 503 | cross_attention_dim=cross_attention_dim, 504 | output_cross_attention_dim=output_cross_attention_dim, 505 | clip_embeddings_dim=clip_embeddings_dim, 506 | clip_extra_context_tokens=clip_extra_context_tokens, 507 | is_sdxl=self.is_sdxl, 508 | is_plus=self.is_plus, 509 | is_full=self.is_full, 510 | ) 511 | 512 | self.ipadapter.to(self.device, dtype=self.dtype) 513 | 514 | image_prompt_embeds, uncond_image_prompt_embeds = self.ipadapter.get_image_embeds(clip_embed.to(self.device, self.dtype), clip_embed_zeroed.to(self.device, self.dtype)) 515 | image_prompt_embeds = image_prompt_embeds.to(self.device, dtype=self.dtype) 516 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.to(self.device, dtype=self.dtype) 517 | 518 | work_model = model.clone() 519 | 520 | if attn_mask is not None: 521 | attn_mask = attn_mask.to(self.device) 522 | 523 | sigma_start = model.model.model_sampling.percent_to_sigma(start_at) 524 | sigma_end = model.model.model_sampling.percent_to_sigma(end_at) 525 | 526 | patch_kwargs = { 527 | "number": 0, 528 | "weight": self.weight, 529 | "ipadapter": self.ipadapter, 530 | "device": self.device, 531 | "dtype": self.dtype, 532 | "cond": image_prompt_embeds, 533 | "uncond": uncond_image_prompt_embeds, 534 | "weight_type": weight_type, 535 | "mask": attn_mask, 536 | "sigma_start": sigma_start, 537 | "sigma_end": sigma_end, 538 | "unfold_batch": unfold_batch, 539 | } 540 | 541 | if not self.is_sdxl: 542 | for id in [1,2,4,5,7,8]: # id of input_blocks that have cross attention 543 | set_model_patch_replace(work_model, patch_kwargs, ("input", id)) 544 | patch_kwargs["number"] += 1 545 | for id in [3,4,5,6,7,8,9,10,11]: # id of output_blocks that have cross attention 546 | set_model_patch_replace(work_model, patch_kwargs, ("output", id)) 547 | patch_kwargs["number"] += 1 548 | set_model_patch_replace(work_model, patch_kwargs, ("middle", 0)) 549 | else: 550 | for id in [4,5,7,8]: # id of input_blocks that have cross attention 551 | block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth 552 | for index in block_indices: 553 | set_model_patch_replace(work_model, patch_kwargs, ("input", id, index)) 554 | patch_kwargs["number"] += 1 555 | for id in range(6): # id of output_blocks that have cross attention 556 | block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth 557 | for index in block_indices: 558 | set_model_patch_replace(work_model, patch_kwargs, ("output", id, index)) 559 | patch_kwargs["number"] += 1 560 | for index in range(10): 561 | set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index)) 562 | patch_kwargs["number"] += 1 563 | 564 | return (work_model, ) 565 | 566 | 
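# Illustrative usage sketch — not from the upstream IPAdapter_plus file. It shows how
# IPAdapterApplyImport (defined above) can be driven directly from Python; the names
# `ipadapter_sd`, `clip_vision`, `image` and `model` are placeholder inputs assumed to
# come from the usual ComfyUI loader nodes, only the call signature and return shape
# are taken from the class above.
#
#   applier = IPAdapterApplyImport()
#   (patched_model,) = applier.apply_ipadapter(
#       ipadapter=ipadapter_sd,      # dict with "image_proj" / "ip_adapter" entries
#       model=model,
#       weight=1.0,
#       clip_vision=clip_vision,
#       image=image,                 # BHWC image tensor, ideally square
#       weight_type="original",      # or "linear" / "channel penalty"
#       noise=0.0,
#       start_at=0.0,
#       end_at=1.0,
#       unfold_batch=False,
#   )
#   # patched_model now carries CrossAttentionPatchImport entries in its
#   # transformer_options and can be sampled like any other patched MODEL.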
def prep_image(image, interpolation="LANCZOS", crop_position="center", sharpening=0.0): 567 | _, oh, ow, _ = image.shape 568 | output = image.permute([0,3,1,2]) 569 | 570 | if "pad" in crop_position: 571 | target_length = max(oh, ow) 572 | pad_l = (target_length - ow) // 2 573 | pad_r = (target_length - ow) - pad_l 574 | pad_t = (target_length - oh) // 2 575 | pad_b = (target_length - oh) - pad_t 576 | output = F.pad(output, (pad_l, pad_r, pad_t, pad_b), value=0, mode="constant") 577 | else: 578 | crop_size = min(oh, ow) 579 | x = (ow-crop_size) // 2 580 | y = (oh-crop_size) // 2 581 | if "top" in crop_position: 582 | y = 0 583 | elif "bottom" in crop_position: 584 | y = oh-crop_size 585 | elif "left" in crop_position: 586 | x = 0 587 | elif "right" in crop_position: 588 | x = ow-crop_size 589 | 590 | x2 = x+crop_size 591 | y2 = y+crop_size 592 | 593 | # crop 594 | output = output[:, :, y:y2, x:x2] 595 | 596 | # resize (apparently PIL resize is better than tourchvision interpolate) 597 | imgs = [] 598 | for i in range(output.shape[0]): 599 | img = TT.ToPILImage()(output[i]) 600 | img = img.resize((224,224), resample=Image.Resampling[interpolation]) 601 | imgs.append(TT.ToTensor()(img)) 602 | output = torch.stack(imgs, dim=0) 603 | 604 | if sharpening > 0: 605 | output = contrast_adaptive_sharpening(output, sharpening) 606 | 607 | output = output.permute([0,2,3,1]) 608 | 609 | return (output,) 610 | 611 | class ResamplerImport(nn.Module): 612 | def __init__( 613 | self, 614 | dim=1024, 615 | depth=8, 616 | dim_head=64, 617 | heads=16, 618 | num_queries=8, 619 | embedding_dim=768, 620 | output_dim=1024, 621 | ff_mult=4, 622 | ): 623 | super().__init__() 624 | 625 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 626 | 627 | self.proj_in = nn.Linear(embedding_dim, dim) 628 | 629 | self.proj_out = nn.Linear(dim, output_dim) 630 | self.norm_out = nn.LayerNorm(output_dim) 631 | 632 | self.layers = nn.ModuleList([]) 633 | for _ in range(depth): 634 | self.layers.append( 635 | nn.ModuleList( 636 | [ 637 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 638 | FeedForward(dim=dim, mult=ff_mult), 639 | ] 640 | ) 641 | ) 642 | 643 | def forward(self, x): 644 | 645 | latents = self.latents.repeat(x.size(0), 1, 1) 646 | 647 | x = self.proj_in(x) 648 | 649 | for attn, ff in self.layers: 650 | latents = attn(x, latents) + latents 651 | latents = ff(latents) + latents 652 | 653 | latents = self.proj_out(latents) 654 | return self.norm_out(latents) 655 | 656 | 657 | class IPAdapterEncoderImport: 658 | @classmethod 659 | def INPUT_TYPES(s): 660 | return {"required": { 661 | "clip_vision": ("CLIP_VISION",), 662 | "image_1": ("IMAGE",), 663 | "ipadapter_plus": ("BOOLEAN", { "default": False }), 664 | "noise": ("FLOAT", { "default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01 }), 665 | "weight_1": ("FLOAT", { "default": 1.0, "min": 0, "max": 1.0, "step": 0.01 }), 666 | }, 667 | "optional": { 668 | "image_2": ("IMAGE",), 669 | "image_3": ("IMAGE",), 670 | "image_4": ("IMAGE",), 671 | "weight_2": ("FLOAT", { "default": 1.0, "min": 0, "max": 1.0, "step": 0.01 }), 672 | "weight_3": ("FLOAT", { "default": 1.0, "min": 0, "max": 1.0, "step": 0.01 }), 673 | "weight_4": ("FLOAT", { "default": 1.0, "min": 0, "max": 1.0, "step": 0.01 }), 674 | } 675 | } 676 | 677 | RETURN_TYPES = ("EMBEDS",) 678 | FUNCTION = "preprocess" 679 | CATEGORY = "ipadapter" 680 | 681 | def preprocess(self, clip_vision, image_1, ipadapter_plus, noise, weight_1, image_2=None, image_3=None, image_4=None, 
weight_2=1.0, weight_3=1.0, weight_4=1.0): 682 | weight_1 *= (0.1 + (weight_1 - 0.1)) 683 | weight_1 = 1.19e-05 if weight_1 <= 1.19e-05 else weight_1 684 | weight_2 *= (0.1 + (weight_2 - 0.1)) 685 | weight_2 = 1.19e-05 if weight_2 <= 1.19e-05 else weight_2 686 | weight_3 *= (0.1 + (weight_3 - 0.1)) 687 | weight_3 = 1.19e-05 if weight_3 <= 1.19e-05 else weight_3 688 | weight_4 *= (0.1 + (weight_4 - 0.1)) 689 | weight_5 = 1.19e-05 if weight_4 <= 1.19e-05 else weight_4 690 | 691 | image = image_1 692 | weight = [weight_1]*image_1.shape[0] 693 | 694 | if image_2 is not None: 695 | if image_1.shape[1:] != image_2.shape[1:]: 696 | image_2 = comfy.utils.common_upscale(image_2.movedim(-1,1), image.shape[2], image.shape[1], "bilinear", "center").movedim(1,-1) 697 | image = torch.cat((image, image_2), dim=0) 698 | weight += [weight_2]*image_2.shape[0] 699 | if image_3 is not None: 700 | if image.shape[1:] != image_3.shape[1:]: 701 | image_3 = comfy.utils.common_upscale(image_3.movedim(-1,1), image.shape[2], image.shape[1], "bilinear", "center").movedim(1,-1) 702 | image = torch.cat((image, image_3), dim=0) 703 | weight += [weight_3]*image_3.shape[0] 704 | if image_4 is not None: 705 | if image.shape[1:] != image_4.shape[1:]: 706 | image_4 = comfy.utils.common_upscale(image_4.movedim(-1,1), image.shape[2], image.shape[1], "bilinear", "center").movedim(1,-1) 707 | image = torch.cat((image, image_4), dim=0) 708 | weight += [weight_4]*image_4.shape[0] 709 | 710 | clip_embed = clip_vision.encode_image(image) 711 | neg_image = image_add_noise(image, noise) if noise > 0 else None 712 | 713 | if ipadapter_plus: 714 | clip_embed = clip_embed.penultimate_hidden_states 715 | if noise > 0: 716 | clip_embed_zeroed = clip_vision.encode_image(neg_image).penultimate_hidden_states 717 | else: 718 | clip_embed_zeroed = zeroed_hidden_states(clip_vision, image.shape[0]) 719 | else: 720 | clip_embed = clip_embed.image_embeds 721 | if noise > 0: 722 | clip_embed_zeroed = clip_vision.encode_image(neg_image).image_embeds 723 | else: 724 | clip_embed_zeroed = torch.zeros_like(clip_embed) 725 | 726 | if any(e != 1.0 for e in weight): 727 | weight = torch.tensor(weight).unsqueeze(-1) if not ipadapter_plus else torch.tensor(weight).unsqueeze(-1).unsqueeze(-1) 728 | clip_embed = clip_embed * weight 729 | 730 | output = torch.stack((clip_embed, clip_embed_zeroed)) 731 | 732 | return( output, ) 733 | 734 | 735 | 736 | 737 | class IPAdapterBatchEmbedsImport: 738 | @classmethod 739 | def INPUT_TYPES(s): 740 | return {"required": { 741 | "embed1": ("EMBEDS",), 742 | "embed2": ("EMBEDS",), 743 | }} 744 | 745 | RETURN_TYPES = ("EMBEDS",) 746 | FUNCTION = "batch" 747 | CATEGORY = "ipadapter" 748 | 749 | def batch(self, embed1, embed2): 750 | output = torch.cat((embed1, embed2), dim=1) 751 | return (output, ) -------------------------------------------------------------------------------- /imports/AdvancedControlNet/control.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from torch import Tensor 3 | import torch 4 | 5 | import comfy.utils 6 | import comfy.controlnet as comfy_cn 7 | from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter, broadcast_image_to 8 | 9 | 10 | def get_properly_arranged_t2i_weights(initial_weights: list[float]): 11 | new_weights = [] 12 | new_weights.extend([initial_weights[0]]*3) 13 | new_weights.extend([initial_weights[1]]*3) 14 | new_weights.extend([initial_weights[2]]*3) 15 | new_weights.extend([initial_weights[3]]*3) 
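    # Each of the 4 incoming T2I-Adapter stage weights is repeated 3 times above,
    # producing the 12 per-block weights that the rest of this module expects
    # (see ControlWeightsImport.t2iadapter below, which defaults to [1.0]*12).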
16 | return new_weights 17 | 18 | 19 | class ControlWeightTypeImport: 20 | DEFAULT = "default" 21 | UNIVERSAL = "universal" 22 | T2IADAPTER = "t2iadapter" 23 | CONTROLNET = "controlnet" 24 | CONTROLLORA = "controllora" 25 | CONTROLLLLITE = "controllllite" 26 | 27 | 28 | class ControlWeightsImport: 29 | def __init__(self, weight_type: str, base_multiplier: float=1.0, flip_weights: bool=False, weights: list[float]=None, weight_mask: Tensor=None): 30 | self.weight_type = weight_type 31 | self.base_multiplier = base_multiplier 32 | self.flip_weights = flip_weights 33 | self.weights = weights 34 | if self.weights is not None and self.flip_weights: 35 | self.weights.reverse() 36 | self.weight_mask = weight_mask 37 | 38 | def get(self, idx: int) -> Union[float, Tensor]: 39 | # if weights is not none, return index 40 | if self.weights is not None: 41 | return self.weights[idx] 42 | return 1.0 43 | 44 | @classmethod 45 | def default(cls): 46 | return cls(ControlWeightTypeImport.DEFAULT) 47 | 48 | @classmethod 49 | def universal(cls, base_multiplier: float, flip_weights: bool=False): 50 | return cls(ControlWeightTypeImport.UNIVERSAL, base_multiplier=base_multiplier, flip_weights=flip_weights) 51 | 52 | @classmethod 53 | def universal_mask(cls, weight_mask: Tensor): 54 | return cls(ControlWeightTypeImport.UNIVERSAL, weight_mask=weight_mask) 55 | 56 | @classmethod 57 | def t2iadapter(cls, weights: list[float]=None, flip_weights: bool=False): 58 | if weights is None: 59 | weights = [1.0]*12 60 | return cls(ControlWeightTypeImport.T2IADAPTER, weights=weights,flip_weights=flip_weights) 61 | 62 | @classmethod 63 | def controlnet(cls, weights: list[float]=None, flip_weights: bool=False): 64 | if weights is None: 65 | weights = [1.0]*13 66 | return cls(ControlWeightTypeImport.CONTROLNET, weights=weights, flip_weights=flip_weights) 67 | 68 | @classmethod 69 | def controllora(cls, weights: list[float]=None, flip_weights: bool=False): 70 | if weights is None: 71 | weights = [1.0]*10 72 | return cls(ControlWeightTypeImport.CONTROLLORA, weights=weights, flip_weights=flip_weights) 73 | 74 | @classmethod 75 | def controllllite(cls, weights: list[float]=None, flip_weights: bool=False): 76 | if weights is None: 77 | # TODO: make this have a real value 78 | weights = [1.0]*200 79 | return cls(ControlWeightTypeImport.CONTROLLLLITE, weights=weights, flip_weights=flip_weights) 80 | 81 | 82 | class StrengthInterpolationImport: 83 | LINEAR = "linear" 84 | EASE_IN = "ease-in" 85 | EASE_OUT = "ease-out" 86 | EASE_IN_OUT = "ease-in-out" 87 | NONE = "none" 88 | 89 | 90 | class LatentKeyframeImport: 91 | def __init__(self, batch_index: int, strength: float) -> None: 92 | self.batch_index = batch_index 93 | self.strength = strength 94 | 95 | 96 | # always maintain sorted state (by batch_index of LatentKeyframe) 97 | class LatentKeyframeGroupImport: 98 | def __init__(self) -> None: 99 | self.keyframes: list[LatentKeyframeImport] = [] 100 | 101 | def add(self, keyframe: LatentKeyframeImport) -> None: 102 | added = False 103 | # replace existing keyframe if same batch_index 104 | for i in range(len(self.keyframes)): 105 | if self.keyframes[i].batch_index == keyframe.batch_index: 106 | self.keyframes[i] = keyframe 107 | added = True 108 | break 109 | if not added: 110 | self.keyframes.append(keyframe) 111 | self.keyframes.sort(key=lambda k: k.batch_index) 112 | 113 | def get_index(self, index: int) -> Union[LatentKeyframeImport, None]: 114 | try: 115 | return self.keyframes[index] 116 | except IndexError: 117 | return None 118 | 
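    # Illustrative example (not in the upstream file) of how add() above keeps the
    # group sorted by batch_index and replaces keyframes that share an index:
    #
    #   group = LatentKeyframeGroupImport()
    #   group.add(LatentKeyframeImport(5, 0.8))
    #   group.add(LatentKeyframeImport(0, 1.0))
    #   group.add(LatentKeyframeImport(5, 0.4))   # replaces the earlier index-5 keyframe
    #   [(kf.batch_index, kf.strength) for kf in group.keyframes]  # -> [(0, 1.0), (5, 0.4)]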
119 | def __getitem__(self, index) -> LatentKeyframeImport: 120 | return self.keyframes[index] 121 | 122 | def is_empty(self) -> bool: 123 | return len(self.keyframes) == 0 124 | 125 | def clone(self) -> 'LatentKeyframeGroupImport': 126 | cloned = LatentKeyframeGroupImport() 127 | for tk in self.keyframes: 128 | cloned.add(tk) 129 | return cloned 130 | 131 | 132 | class TimestepKeyframeImport: 133 | def __init__(self, 134 | start_percent: float = 0.0, 135 | strength: float = 1.0, 136 | interpolation: str = StrengthInterpolationImport.NONE, 137 | control_weights: ControlWeightsImport = None, 138 | latent_keyframes: LatentKeyframeGroupImport = None, 139 | null_latent_kf_strength: float = 0.0, 140 | inherit_missing: bool = True, 141 | guarantee_usage: bool = True, 142 | mask_hint_orig: Tensor = None) -> None: 143 | self.start_percent = start_percent 144 | self.start_t = 999999999.9 145 | self.strength = strength 146 | self.interpolation = interpolation 147 | self.control_weights = control_weights 148 | self.latent_keyframes = latent_keyframes 149 | self.null_latent_kf_strength = null_latent_kf_strength 150 | self.inherit_missing = inherit_missing 151 | self.guarantee_usage = guarantee_usage 152 | self.mask_hint_orig = mask_hint_orig 153 | 154 | def has_control_weights(self): 155 | return self.control_weights is not None 156 | 157 | def has_latent_keyframes(self): 158 | return self.latent_keyframes is not None 159 | 160 | def has_mask_hint(self): 161 | return self.mask_hint_orig is not None 162 | 163 | 164 | @classmethod 165 | def default(cls) -> 'TimestepKeyframeImport': 166 | return cls(0.0) 167 | 168 | 169 | # always maintain sorted state (by start_percent of TimestepKeyFrame) 170 | class TimestepKeyframeGroupImport: 171 | def __init__(self) -> None: 172 | self.keyframes: list[TimestepKeyframeImport] = [] 173 | self.keyframes.append(TimestepKeyframeImport.default()) 174 | 175 | def add(self, keyframe: TimestepKeyframeImport) -> None: 176 | added = False 177 | # replace existing keyframe if same start_percent 178 | for i in range(len(self.keyframes)): 179 | if self.keyframes[i].start_percent == keyframe.start_percent: 180 | self.keyframes[i] = keyframe 181 | added = True 182 | break 183 | if not added: 184 | self.keyframes.append(keyframe) 185 | self.keyframes.sort(key=lambda k: k.start_percent) 186 | 187 | def get_index(self, index: int) -> Union[TimestepKeyframeImport, None]: 188 | try: 189 | return self.keyframes[index] 190 | except IndexError: 191 | return None 192 | 193 | def has_index(self, index: int) -> int: 194 | return index >=0 and index < len(self.keyframes) 195 | 196 | def __getitem__(self, index) -> TimestepKeyframeImport: 197 | return self.keyframes[index] 198 | 199 | def __len__(self) -> int: 200 | return len(self.keyframes) 201 | 202 | def is_empty(self) -> bool: 203 | return len(self.keyframes) == 0 204 | 205 | def clone(self) -> 'TimestepKeyframeGroupImport': 206 | cloned = TimestepKeyframeGroupImport() 207 | for tk in self.keyframes: 208 | cloned.add(tk) 209 | return cloned 210 | 211 | @classmethod 212 | def default(cls, keyframe: TimestepKeyframeImport) -> 'TimestepKeyframeGroupImport': 213 | group = cls() 214 | group.keyframes[0] = keyframe 215 | return group 216 | 217 | 218 | # used to inject ControlNetAdvancedImport and T2IAdapterAdvancedImport control_merge function 219 | 220 | 221 | class AdvancedControlBaseImport: 222 | def __init__(self, base: ControlBase, timestep_keyframes: TimestepKeyframeGroupImport, weights_default: ControlWeightsImport): 223 | self.base = 
base 224 | self.compatible_weights = [ControlWeightTypeImport.UNIVERSAL] 225 | self.add_compatible_weight(weights_default.weight_type) 226 | # mask for which parts of controlnet output to keep 227 | self.mask_cond_hint_original = None 228 | self.mask_cond_hint = None 229 | self.tk_mask_cond_hint_original = None 230 | self.tk_mask_cond_hint = None 231 | self.weight_mask_cond_hint = None 232 | # actual index values 233 | self.sub_idxs = None 234 | self.full_latent_length = 0 235 | self.context_length = 0 236 | # timesteps 237 | self.t: Tensor = None 238 | self.batched_number: int = None 239 | # weights + override 240 | self.weights: ControlWeightsImport = None 241 | self.weights_default: ControlWeightsImport = weights_default 242 | self.weights_override: ControlWeightsImport = None 243 | # latent keyframe + override 244 | self.latent_keyframes: LatentKeyframeGroupImport = None 245 | self.latent_keyframe_override: LatentKeyframeGroupImport = None 246 | # initialize timestep_keyframes 247 | self.set_timestep_keyframes(timestep_keyframes) 248 | # override some functions 249 | self.get_control = self.get_control_inject 250 | self.control_merge = self.control_merge_inject#.__get__(self, type(self)) 251 | self.pre_run = self.pre_run_inject 252 | self.cleanup = self.cleanup_inject 253 | 254 | def add_compatible_weight(self, control_weight_type: str): 255 | self.compatible_weights.append(control_weight_type) 256 | 257 | def verify_all_weights(self, throw_error=True): 258 | # first, check if override exists - if so, only need to check the override 259 | if self.weights_override is not None: 260 | if self.weights_override.weight_type not in self.compatible_weights: 261 | msg = f"Weight override is type {self.weights_override.weight_type}, but loaded {type(self).__name__}" + \ 262 | f"only supports {self.compatible_weights} weights." 263 | raise WeightTypeExceptionImport(msg) 264 | # otherwise, check all timestep keyframe weights 265 | else: 266 | for tk in self.timestep_keyframes.keyframes: 267 | if tk.has_control_weights() and tk.control_weights.weight_type not in self.compatible_weights: 268 | msg = f"Weight on Timestep Keyframe with start_percent={tk.start_percent} is type" + \ 269 | f"{tk.control_weights.weight_type}, but loaded {type(self).__name__} only supports {self.compatible_weights} weights." 
270 |                     raise WeightTypeExceptionImport(msg)
271 | 
272 |     def set_timestep_keyframes(self, timestep_keyframes: TimestepKeyframeGroupImport):
273 |         self.timestep_keyframes = timestep_keyframes if timestep_keyframes else TimestepKeyframeGroupImport()
274 |         # prepare first timestep_keyframe related stuff
275 |         self.current_timestep_keyframe = None
276 |         self.current_timestep_index = -1
277 |         self.next_timestep_keyframe = None
278 |         self.weights = None
279 |         self.latent_keyframes = None
280 | 
281 |     def prepare_current_timestep(self, t: Tensor, batched_number: int):
282 |         self.t = t
283 |         self.batched_number = batched_number
284 |         # get current step percent
285 |         curr_t: float = t[0]
286 |         prev_index = self.current_timestep_index
287 |         # if has next index, loop through and see if need to switch
288 |         if self.timestep_keyframes.has_index(self.current_timestep_index+1):
289 |             for i in range(self.current_timestep_index+1, len(self.timestep_keyframes)):
290 |                 eval_tk = self.timestep_keyframes[i]
291 |                 # a keyframe becomes active once its start_t has been reached (start_t >= curr_t, i.e. its start_percent has passed)
292 |                 if eval_tk.start_t >= curr_t:
293 |                     self.current_timestep_index = i
294 |                     self.current_timestep_keyframe = eval_tk
295 |                     # keep track of control weights, latent keyframes, and masks,
296 |                     # accounting for inherit_missing
297 |                     if self.current_timestep_keyframe.has_control_weights():
298 |                         self.weights = self.current_timestep_keyframe.control_weights
299 |                     elif not self.current_timestep_keyframe.inherit_missing:
300 |                         self.weights = self.weights_default
301 |                     if self.current_timestep_keyframe.has_latent_keyframes():
302 |                         self.latent_keyframes = self.current_timestep_keyframe.latent_keyframes
303 |                     elif not self.current_timestep_keyframe.inherit_missing:
304 |                         self.latent_keyframes = None
305 |                     if self.current_timestep_keyframe.has_mask_hint():
306 |                         self.tk_mask_cond_hint_original = self.current_timestep_keyframe.mask_hint_orig
307 |                     elif not self.current_timestep_keyframe.inherit_missing:
308 |                         del self.tk_mask_cond_hint_original
309 |                         self.tk_mask_cond_hint_original = None
310 |                     # if guarantee_usage, stop searching for other TKs
311 |                     if self.current_timestep_keyframe.guarantee_usage:
312 |                         break
313 |                 # if eval_tk is not active yet, stop looking further
314 |                 else:
315 |                     break
316 | 
317 |         # if index changed, apply overrides
318 |         if prev_index != self.current_timestep_index:
319 |             if self.weights_override is not None:
320 |                 self.weights = self.weights_override
321 |             if self.latent_keyframe_override is not None:
322 |                 self.latent_keyframes = self.latent_keyframe_override
323 | 
324 |         # make sure weights and latent_keyframes are in a workable state
325 |         # Note: each AdvancedControlBaseImport subclass should provide its own get_universal_weights method
326 |         self.prepare_weights()
327 | 
328 |     def prepare_weights(self):
329 |         if self.weights is None or self.weights.weight_type == ControlWeightTypeImport.DEFAULT:
330 |             self.weights = self.weights_default
331 |         elif self.weights.weight_type == ControlWeightTypeImport.UNIVERSAL:
332 |             # if universal and weight_mask present, no need to convert
333 |             if self.weights.weight_mask is not None:
334 |                 return
335 |             self.weights = self.get_universal_weights()
336 | 
337 |     def get_universal_weights(self) -> ControlWeightsImport:
338 |         return self.weights
339 | 
340 |     def set_cond_hint_mask(self, mask_hint):
341 |         self.mask_cond_hint_original = mask_hint
342 |         return self
343 | 
344 |     def pre_run_inject(self, model, percent_to_timestep_function):
345 |         self.base.pre_run(model, percent_to_timestep_function)
346 |         
self.pre_run_advanced(model, percent_to_timestep_function) 347 | 348 | def pre_run_advanced(self, model, percent_to_timestep_function): 349 | # for each timestep keyframe, calculate the start_t 350 | for tk in self.timestep_keyframes.keyframes: 351 | tk.start_t = percent_to_timestep_function(tk.start_percent) 352 | # clear variables 353 | self.cleanup_advanced() 354 | 355 | def get_control_inject(self, x_noisy, t, cond, batched_number): 356 | # prepare timestep and everything related 357 | self.prepare_current_timestep(t=t, batched_number=batched_number) 358 | # if should not perform any actions for the controlnet, exit without doing any work 359 | if self.strength == 0.0 or self.current_timestep_keyframe.strength == 0.0: 360 | control_prev = None 361 | if self.previous_controlnet is not None: 362 | control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) 363 | if control_prev is not None: 364 | return control_prev 365 | else: 366 | return None 367 | # otherwise, perform normal function 368 | return self.get_control_advanced(x_noisy, t, cond, batched_number) 369 | 370 | def get_control_advanced(self, x_noisy, t, cond, batched_number): 371 | pass 372 | 373 | def calc_weight(self, idx: int, x: Tensor, layers: int) -> Union[float, Tensor]: 374 | if self.weights.weight_mask is not None: 375 | # prepare weight mask 376 | self.prepare_weight_mask_cond_hint(x, self.batched_number) 377 | # adjust mask for current layer and return 378 | return torch.pow(self.weight_mask_cond_hint, self.get_calc_pow(idx=idx, layers=layers)) 379 | return self.weights.get(idx=idx) 380 | 381 | def get_calc_pow(self, idx: int, layers: int) -> int: 382 | return (layers-1)-idx 383 | 384 | def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int): 385 | # apply strengths, and get batch indeces to null out 386 | # AKA latents that should not be influenced by ControlNet 387 | if self.latent_keyframes is not None: 388 | latent_count = x.size(0)//batched_number 389 | indeces_to_null = set(range(latent_count)) 390 | mapped_indeces = None 391 | # if expecting subdivision, will need to translate between subset and actual idx values 392 | if self.sub_idxs: 393 | mapped_indeces = {} 394 | for i, actual in enumerate(self.sub_idxs): 395 | mapped_indeces[actual] = i 396 | for keyframe in self.latent_keyframes: 397 | real_index = keyframe.batch_index 398 | # if negative, count from end 399 | if real_index < 0: 400 | real_index += latent_count if self.sub_idxs is None else self.full_latent_length 401 | 402 | # if not mapping indeces, what you see is what you get 403 | if mapped_indeces is None: 404 | if real_index in indeces_to_null: 405 | indeces_to_null.remove(real_index) 406 | # otherwise, see if batch_index is even included in this set of latents 407 | else: 408 | real_index = mapped_indeces.get(real_index, None) 409 | if real_index is None: 410 | continue 411 | indeces_to_null.remove(real_index) 412 | 413 | # if real_index is outside the bounds of latents, don't apply 414 | if real_index >= latent_count or real_index < 0: 415 | continue 416 | 417 | # apply strength for each batched cond/uncond 418 | for b in range(batched_number): 419 | x[(latent_count*b)+real_index] = x[(latent_count*b)+real_index] * keyframe.strength 420 | 421 | # null them out by multiplying by null_latent_kf_strength 422 | for batch_index in indeces_to_null: 423 | # apply null for each batched cond/uncond 424 | for b in range(batched_number): 425 | x[(latent_count*b)+batch_index] = x[(latent_count*b)+batch_index] 
* self.current_timestep_keyframe.null_latent_kf_strength 426 | # apply masks, resizing mask to required dims 427 | if self.mask_cond_hint is not None: 428 | masks = prepare_mask_batch(self.mask_cond_hint, x.shape) 429 | x[:] = x[:] * masks 430 | if self.tk_mask_cond_hint is not None: 431 | masks = prepare_mask_batch(self.tk_mask_cond_hint, x.shape) 432 | x[:] = x[:] * masks 433 | # apply timestep keyframe strengths 434 | if self.current_timestep_keyframe.strength != 1.0: 435 | x[:] *= self.current_timestep_keyframe.strength 436 | 437 | def control_merge_inject(self: 'AdvancedControlBaseImport', control_input, control_output, control_prev, output_dtype): 438 | out = {'input':[], 'middle':[], 'output': []} 439 | 440 | if control_input is not None: 441 | for i in range(len(control_input)): 442 | key = 'input' 443 | x = control_input[i] 444 | if x is not None: 445 | self.apply_advanced_strengths_and_masks(x, self.batched_number) 446 | 447 | x *= self.strength * self.calc_weight(i, x, len(control_input)) 448 | if x.dtype != output_dtype: 449 | x = x.to(output_dtype) 450 | out[key].insert(0, x) 451 | 452 | if control_output is not None: 453 | for i in range(len(control_output)): 454 | if i == (len(control_output) - 1): 455 | key = 'middle' 456 | index = 0 457 | else: 458 | key = 'output' 459 | index = i 460 | x = control_output[i] 461 | if x is not None: 462 | self.apply_advanced_strengths_and_masks(x, self.batched_number) 463 | 464 | if self.global_average_pooling: 465 | x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) 466 | 467 | x *= self.strength * self.calc_weight(i, x, len(control_output)) 468 | if x.dtype != output_dtype: 469 | x = x.to(output_dtype) 470 | 471 | out[key].append(x) 472 | if control_prev is not None: 473 | for x in ['input', 'middle', 'output']: 474 | o = out[x] 475 | for i in range(len(control_prev[x])): 476 | prev_val = control_prev[x][i] 477 | if i >= len(o): 478 | o.append(prev_val) 479 | elif prev_val is not None: 480 | if o[i] is None: 481 | o[i] = prev_val 482 | else: 483 | o[i] += prev_val 484 | return out 485 | 486 | def prepare_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None): 487 | self._prepare_mask("mask_cond_hint", self.mask_cond_hint_original, x_noisy, t, cond, batched_number, dtype) 488 | self.prepare_tk_mask_cond_hint(x_noisy, t, cond, batched_number, dtype) 489 | 490 | def prepare_tk_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None): 491 | return self._prepare_mask("tk_mask_cond_hint", self.current_timestep_keyframe.mask_hint_orig, x_noisy, t, cond, batched_number, dtype) 492 | 493 | def prepare_weight_mask_cond_hint(self, x_noisy: Tensor, batched_number, dtype=None): 494 | return self._prepare_mask("weight_mask_cond_hint", self.weights.weight_mask, x_noisy, t=None, cond=None, batched_number=batched_number, dtype=dtype, direct_attn=True) 495 | 496 | def _prepare_mask(self, attr_name, orig_mask: Tensor, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False): 497 | # make mask appropriate dimensions, if present 498 | if orig_mask is not None: 499 | out_mask = getattr(self, attr_name) 500 | if self.sub_idxs is not None or out_mask is None or x_noisy.shape[2] * 8 != out_mask.shape[1] or x_noisy.shape[3] * 8 != out_mask.shape[2]: 501 | self._reset_attr(attr_name) 502 | del out_mask 503 | # TODO: perform upscale on only the sub_idxs masks at a time instead of all to conserve RAM 504 | # resize mask and match batch count 505 | multiplier = 1 if direct_attn else 8 
506 | out_mask = prepare_mask_batch(orig_mask, x_noisy.shape, multiplier=multiplier) 507 | actual_latent_length = x_noisy.shape[0] // batched_number 508 | out_mask = comfy.utils.repeat_to_batch_size(out_mask, actual_latent_length if self.sub_idxs is None else self.full_latent_length) 509 | if self.sub_idxs is not None: 510 | out_mask = out_mask[self.sub_idxs] 511 | # make cond_hint_mask length match x_noise 512 | if x_noisy.shape[0] != out_mask.shape[0]: 513 | out_mask = broadcast_image_to(out_mask, x_noisy.shape[0], batched_number) 514 | # default dtype to be same as x_noisy 515 | if dtype is None: 516 | dtype = x_noisy.dtype 517 | setattr(self, attr_name, out_mask.to(dtype=dtype).to(self.device)) 518 | del out_mask 519 | 520 | def _reset_attr(self, attr_name, new_value=None): 521 | if hasattr(self, attr_name): 522 | delattr(self, attr_name) 523 | setattr(self, attr_name, new_value) 524 | 525 | def cleanup_inject(self): 526 | self.base.cleanup() 527 | self.cleanup_advanced() 528 | 529 | def cleanup_advanced(self): 530 | self.sub_idxs = None 531 | self.full_latent_length = 0 532 | self.context_length = 0 533 | self.t = None 534 | self.batched_number = None 535 | self.weights = None 536 | self.latent_keyframes = None 537 | # timestep stuff 538 | self.current_timestep_keyframe = None 539 | self.next_timestep_keyframe = None 540 | self.current_timestep_index = -1 541 | # clear mask hints 542 | if self.mask_cond_hint is not None: 543 | del self.mask_cond_hint 544 | self.mask_cond_hint = None 545 | if self.tk_mask_cond_hint_original is not None: 546 | del self.tk_mask_cond_hint_original 547 | self.tk_mask_cond_hint_original = None 548 | if self.tk_mask_cond_hint is not None: 549 | del self.tk_mask_cond_hint 550 | self.tk_mask_cond_hint = None 551 | if self.weight_mask_cond_hint is not None: 552 | del self.weight_mask_cond_hint 553 | self.weight_mask_cond_hint = None 554 | 555 | def copy_to_advanced(self, copied: 'AdvancedControlBaseImport'): 556 | copied.mask_cond_hint_original = self.mask_cond_hint_original 557 | copied.weights_override = self.weights_override 558 | copied.latent_keyframe_override = self.latent_keyframe_override 559 | 560 | 561 | class ControlNetAdvancedImport(ControlNet, AdvancedControlBaseImport): 562 | def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroupImport, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): 563 | super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype) 564 | AdvancedControlBaseImport.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeightsImport.controlnet()) 565 | 566 | def get_universal_weights(self) -> ControlWeightsImport: 567 | raw_weights = [(self.weights.base_multiplier ** float(12 - i)) for i in range(13)] 568 | return ControlWeightsImport.controlnet(raw_weights, self.weights.flip_weights) 569 | 570 | def get_control_advanced(self, x_noisy, t, cond, batched_number): 571 | # perform special version of get_control that supports sliding context and masks 572 | return self.sliding_get_control(x_noisy, t, cond, batched_number) 573 | 574 | def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number): 575 | control_prev = None 576 | if self.previous_controlnet is not None: 577 | control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) 578 | 579 | if self.timestep_range is not None: 580 | if t[0] > 
self.timestep_range[0] or t[0] < self.timestep_range[1]: 581 | if control_prev is not None: 582 | return control_prev 583 | else: 584 | return None 585 | 586 | dtype = self.control_model.dtype 587 | if self.manual_cast_dtype is not None: 588 | dtype = self.manual_cast_dtype 589 | 590 | output_dtype = x_noisy.dtype 591 | # make cond_hint appropriate dimensions 592 | # TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present 593 | if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: 594 | if self.cond_hint is not None: 595 | del self.cond_hint 596 | self.cond_hint = None 597 | # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling 598 | if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length: 599 | self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device) 600 | else: 601 | self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device) 602 | if x_noisy.shape[0] != self.cond_hint.shape[0]: 603 | self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) 604 | 605 | # prepare mask_cond_hint 606 | self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype) 607 | 608 | context = cond['c_crossattn'] 609 | # uses 'y' in new ComfyUI update 610 | y = cond.get('y', None) 611 | if y is None: # TODO: remove this in the future since no longer used by newest ComfyUI 612 | y = cond.get('c_adm', None) 613 | if y is not None: 614 | y = y.to(dtype) 615 | timestep = self.model_sampling_current.timestep(t) 616 | x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) 617 | 618 | control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) 619 | return self.control_merge(None, control, control_prev, output_dtype) 620 | 621 | def copy(self): 622 | c = ControlNetAdvancedImport(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) 623 | self.copy_to(c) 624 | self.copy_to_advanced(c) 625 | return c 626 | 627 | @staticmethod 628 | def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroupImport=None) -> 'ControlNetAdvancedImport': 629 | return ControlNetAdvancedImport(control_model=v.control_model, timestep_keyframes=timestep_keyframe, 630 | global_average_pooling=v.global_average_pooling, device=v.device, load_device=v.load_device, manual_cast_dtype=v.manual_cast_dtype) 631 | 632 | 633 | class T2IAdapterAdvancedImport(T2IAdapter, AdvancedControlBaseImport): 634 | def __init__(self, t2i_model, timestep_keyframes: TimestepKeyframeGroupImport, channels_in, device=None): 635 | super().__init__(t2i_model=t2i_model, channels_in=channels_in, device=device) 636 | AdvancedControlBaseImport.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeightsImport.t2iadapter()) 637 | 638 | def get_universal_weights(self) -> ControlWeightsImport: 639 | raw_weights = [(self.weights.base_multiplier ** float(7 - i)) for i in range(8)] 640 | raw_weights = [raw_weights[-8], 
raw_weights[-3], raw_weights[-2], raw_weights[-1]] 641 | raw_weights = get_properly_arranged_t2i_weights(raw_weights) 642 | return ControlWeightsImport.t2iadapter(raw_weights, self.weights.flip_weights) 643 | 644 | def get_calc_pow(self, idx: int, layers: int) -> int: 645 | # match how T2IAdapterAdvancedImport deals with universal weights 646 | indeces = [7 - i for i in range(8)] 647 | indeces = [indeces[-8], indeces[-3], indeces[-2], indeces[-1]] 648 | indeces = get_properly_arranged_t2i_weights(indeces) 649 | return indeces[idx] 650 | 651 | def get_control_advanced(self, x_noisy, t, cond, batched_number): 652 | # prepare timestep and everything related 653 | self.prepare_current_timestep(t=t, batched_number=batched_number) 654 | try: 655 | # if sub indexes present, replace original hint with subsection 656 | if self.sub_idxs is not None: 657 | # cond hints 658 | full_cond_hint_original = self.cond_hint_original 659 | del self.cond_hint 660 | self.cond_hint = None 661 | self.cond_hint_original = full_cond_hint_original[self.sub_idxs] 662 | # mask hints 663 | self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number) 664 | return super().get_control(x_noisy, t, cond, batched_number) 665 | finally: 666 | if self.sub_idxs is not None: 667 | # replace original cond hint 668 | self.cond_hint_original = full_cond_hint_original 669 | del full_cond_hint_original 670 | 671 | def copy(self): 672 | c = T2IAdapterAdvancedImport(self.t2i_model, self.timestep_keyframes, self.channels_in) 673 | self.copy_to(c) 674 | self.copy_to_advanced(c) 675 | return c 676 | 677 | def cleanup(self): 678 | super().cleanup() 679 | self.cleanup_advanced() 680 | 681 | @staticmethod 682 | def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroupImport=None) -> 'T2IAdapterAdvancedImport': 683 | return T2IAdapterAdvancedImport(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in, device=v.device) 684 | 685 | 686 | class ControlLoraAdvancedImport(ControlLora, AdvancedControlBaseImport): 687 | def __init__(self, control_weights, timestep_keyframes: TimestepKeyframeGroupImport, global_average_pooling=False, device=None): 688 | super().__init__(control_weights=control_weights, global_average_pooling=global_average_pooling, device=device) 689 | AdvancedControlBaseImport.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeightsImport.controllora()) 690 | # use some functions from ControlNetAdvancedImport 691 | self.get_control_advanced = ControlNetAdvancedImport.get_control_advanced.__get__(self, type(self)) 692 | self.sliding_get_control = ControlNetAdvancedImport.sliding_get_control.__get__(self, type(self)) 693 | 694 | def get_universal_weights(self) -> ControlWeightsImport: 695 | raw_weights = [(self.weights.base_multiplier ** float(9 - i)) for i in range(10)] 696 | return ControlWeightsImport.controllora(raw_weights, self.weights.flip_weights) 697 | 698 | def copy(self): 699 | c = ControlLoraAdvancedImport(self.control_weights, self.timestep_keyframes, global_average_pooling=self.global_average_pooling) 700 | self.copy_to(c) 701 | self.copy_to_advanced(c) 702 | return c 703 | 704 | def cleanup(self): 705 | super().cleanup() 706 | self.cleanup_advanced() 707 | 708 | @staticmethod 709 | def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroupImport=None) -> 'ControlLoraAdvancedImport': 710 | return ControlLoraAdvancedImport(control_weights=v.control_weights, 
timestep_keyframes=timestep_keyframe,
711 |                                           global_average_pooling=v.global_average_pooling, device=v.device)
712 | 
713 | 
714 | class ControlLLLiteAdvancedImport(ControlNet, AdvancedControlBaseImport):  # stub only - LLLite support is not wired up yet (see TODO in load_controlnet)
715 |     def __init__(self, control_weights, timestep_keyframes: TimestepKeyframeGroupImport, device=None):
716 |         AdvancedControlBaseImport.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeightsImport.controllllite())
717 | 
718 | 
719 | def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroupImport=None, model=None):
720 |     control = comfy_cn.load_controlnet(ckpt_path, model=model)
721 |     # TODO: support controlnet-lllite
722 |     # if is None, see if is a non-vanilla ControlNet
723 |     # if control is None:
724 |     #     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
725 |     #     # check if lllite
726 |     #     if "lllite_unet" in controlnet_data:
727 |     #         pass
728 |     return convert_to_advanced(control, timestep_keyframe=timestep_keyframe)
729 | 
730 | 
731 | def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroupImport=None):
732 |     # if already advanced, leave it be
733 |     if is_advanced_controlnet(control):
734 |         return control
735 |     # if exactly ControlNet returned, transform it into ControlNetAdvancedImport
736 |     if type(control) == ControlNet:
737 |         return ControlNetAdvancedImport.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
738 |     # if exactly ControlLora returned, transform it into ControlLoraAdvancedImport
739 |     elif type(control) == ControlLora:
740 |         return ControlLoraAdvancedImport.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
741 |     # if T2IAdapter returned, transform it into T2IAdapterAdvancedImport
742 |     elif isinstance(control, T2IAdapter):
743 |         return T2IAdapterAdvancedImport.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
744 |     # otherwise, leave it be - might be something I am not supporting yet
745 |     return control
746 | 
747 | 
748 | def is_advanced_controlnet(input_object):
749 |     return hasattr(input_object, "sub_idxs")
750 | 
751 | 
752 | # adapted from comfy/sample.py
753 | def prepare_mask_batch(mask: Tensor, shape: torch.Size, multiplier: int=1, match_dim1=False):
754 |     mask = mask.clone()
755 |     mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[2]*multiplier, shape[3]*multiplier), mode="bilinear")
756 |     if match_dim1:
757 |         mask = torch.cat([mask] * shape[1], dim=1)
758 |     return mask
759 | 
760 | 
761 | # applies min-max normalization, from:
762 | # https://stackoverflow.com/questions/68791508/min-max-normalization-of-a-tensor-in-pytorch
763 | def normalize_min_max(x: Tensor, new_min = 0.0, new_max = 1.0):
764 |     x_min, x_max = x.min(), x.max()
765 |     return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
766 | 
767 | def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0):
768 |     return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
769 | 
770 | 
771 | class WeightTypeExceptionImport(TypeError):
772 |     "Raised when weight not compatible with AdvancedControlBaseImport object"
773 |     pass
774 | 
--------------------------------------------------------------------------------