├── src
│   ├── __init__.py
│   ├── architectures
│   │   ├── __init__.py
│   │   ├── general_architecture.py
│   │   └── dit_lora.py
│   ├── experimental
│   │   ├── __init__.py
│   │   ├── lora_attention_logger.py
│   │   ├── lora_analyzer.py
│   │   └── checkpoint_merge.py
│   ├── device
│   │   ├── __init__.py
│   │   └── manager.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── progress.py
│   │   └── config.py
│   ├── comfy_util.py
│   ├── mergekit_utils.py
│   ├── validation
│   │   └── __init__.py
│   ├── decomposition
│   │   ├── __init__.py
│   │   └── svd.py
│   ├── lora_apply.py
│   ├── merge
│   │   ├── __init__.py
│   │   ├── dispatcher.py
│   │   ├── base_node.py
│   │   └── utils.py
│   ├── lora_selector.py
│   ├── lora_save.py
│   ├── lora_dir_stacker.py
│   ├── nodes_lora_modifier.py
│   ├── lora_stack_sampler.py
│   ├── lora_decompose.py
│   ├── lora_block_sampler.py
│   ├── lora_power_stacker.py
│   ├── lora_resize.py
│   └── lora_mergekit_merge.py
├── requirements.txt
├── .gitmodules
├── assets
│   ├── pm-lora_apply.png
│   ├── pm-lora_merger.png
│   ├── pm-save_lora.png
│   ├── pm-block_sampler.png
│   ├── pm-lora_modifier.png
│   ├── pm-lora_stacker.png
│   ├── pm-merge_methods.png
│   ├── pm-basic_workflow.png
│   ├── pm-lora_decomposer.png
│   ├── pm-stack_from_dir.png
│   ├── pm-lora_stack_sampler.png
│   ├── pm-workflow_lora_resize.png
│   └── pm-paramter-sweep-sampler.png
├── fonts
│   └── ShareTechMono-Regular.ttf
├── tests
│   ├── __init__.py
│   ├── lora_keys
│   │   ├── wan_2_2_lora.txt
│   │   ├── zImage_lora.txt
│   │   └── qwen_image_edit_lora.txt
│   ├── conftest.py
│   ├── test_algorithms.py
│   ├── test_types.py
│   └── test_decomposition.py
├── preset.txt
├── pyproject.toml
├── requirements-dev.txt
├── .github
│   └── workflows
│       └── publish.yml
├── run_pytest.py
├── pytest.ini
├── LICENSE
├── debug_keys.py
├── __init__.py
└── js
    ├── dit.svg
    └── lora-merge.js

--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
git+https://github.com/arcee-ai/mergekit.git#egg=mergekit
lxml

--------------------------------------------------------------------------------
/src/architectures/__init__.py:
--------------------------------------------------------------------------------
from .general_architecture import LORA_STACK, LORA_KEY_DICT, LORA_WEIGHTS

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "mergekit"]
    path = mergekit
    url = https://github.com/arcee-ai/mergekit.git

--------------------------------------------------------------------------------
/assets/pm-lora_apply.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-lora_apply.png

--------------------------------------------------------------------------------
/assets/pm-lora_merger.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-lora_merger.png

--------------------------------------------------------------------------------
/assets/pm-save_lora.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-save_lora.png

--------------------------------------------------------------------------------
/assets/pm-block_sampler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-block_sampler.png

--------------------------------------------------------------------------------
/assets/pm-lora_modifier.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-lora_modifier.png

--------------------------------------------------------------------------------
/assets/pm-lora_stacker.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-lora_stacker.png

--------------------------------------------------------------------------------
/assets/pm-merge_methods.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-merge_methods.png

--------------------------------------------------------------------------------
/assets/pm-basic_workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-basic_workflow.png

--------------------------------------------------------------------------------
/assets/pm-lora_decomposer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-lora_decomposer.png

--------------------------------------------------------------------------------
/assets/pm-stack_from_dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-stack_from_dir.png

--------------------------------------------------------------------------------
/assets/pm-lora_stack_sampler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-lora_stack_sampler.png

--------------------------------------------------------------------------------
/fonts/ShareTechMono-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/fonts/ShareTechMono-Regular.ttf

--------------------------------------------------------------------------------
/assets/pm-workflow_lora_resize.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-workflow_lora_resize.png

--------------------------------------------------------------------------------
/assets/pm-paramter-sweep-sampler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/larsupb/LoRA-Merger-ComfyUI/HEAD/assets/pm-paramter-sweep-sampler.png

--------------------------------------------------------------------------------
/src/experimental/__init__.py:
--------------------------------------------------------------------------------
from .checkpoint_merge import checkpoint_process, MergeMethod
from .lora_attention_logger import LoRAAttentionLogger
from .lora_analyzer import LoRAAnalyzer

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
"""
Test suite for LoRA Power-Merger.

Unit tests for core functionality. Integration tests are not included
as they require a running ComfyUI instance.
"""

--------------------------------------------------------------------------------
/src/device/__init__.py:
--------------------------------------------------------------------------------
"""
Device management module for LoRA Power-Merger.

Provides unified device and dtype handling across the codebase.
"""

from .manager import DeviceManager

__all__ = ['DeviceManager']

--------------------------------------------------------------------------------
/src/architectures/general_architecture.py:
--------------------------------------------------------------------------------
# Re-export types from centralized types module.
# This module maintains backward compatibility for existing imports.
from ..types import (
    LORA_KEY_DICT,
    LORA_STACK,
    LORA_WEIGHTS,
)

__all__ = ['LORA_KEY_DICT', 'LORA_STACK', 'LORA_WEIGHTS']

--------------------------------------------------------------------------------
/preset.txt:
--------------------------------------------------------------------------------
INALL:1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0
OUTALL:1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1
MIDD:1,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,1,1,1
INALLXL:1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0
OUTALLXL:1,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1
MIDDXL:1,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0

--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
"""
Utilities package for LoRA Power-Merger.

Provides helper utilities including:
- Layer filtering
- Architecture detection
- Progress tracking
- Configuration constants
"""

from .layer_filter import LayerFilter, detect_lora_architecture
from .progress import ThreadSafeProgressBar
from .config import *

__all__ = [
    'LayerFilter',
    'detect_lora_architecture',
    'ThreadSafeProgressBar',
]

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[project]
name = "lora-merger-comfyui"
description = "An extension for merging LoRAs. Offers a wide range of LoRA merge techniques (including dare) and XY plots. XY plots require efficiency nodes."
version = "2.1.0"
license = "MIT"

[project.urls]
Repository = "https://github.com/larsupb/LoRA-Merger-ComfyUI"
# Used by Comfy Registry https://comfyregistry.org

[tool.comfy]
PublisherId = "larsupb"
DisplayName = "LoRA-Merger-ComfyUI"
Icon = ""

--------------------------------------------------------------------------------
/src/comfy_util.py:
--------------------------------------------------------------------------------
import comfy
import comfy.lora
import comfy.model_management


def load_as_comfy_lora(lora: dict, model):
    if 'lora_raw' not in lora or lora['lora_raw'] is None:
        raise ValueError("LoRA data is missing. Please provide a valid LoRA dictionary with 'lora_raw' key.")
    key_map = {}
    if model is not None:
        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
    patch_dict = comfy.lora.load_lora(lora['lora_raw'], key_map)
    return patch_dict

--------------------------------------------------------------------------------
/src/mergekit_utils.py:
--------------------------------------------------------------------------------
from typing import Dict, Literal

import torch
from mergekit.common import ModelReference

MERGEKIT_GTA_MODES = Literal[
    "della", "breadcrumbs", "dare", "ties", "task_arithmetic", "linear"]


def load_on_device(tensors: Dict[ModelReference, torch.Tensor],
                   tensor_weights: Dict[ModelReference, torch.Tensor], device, dtype):
    for k, v in tensors.items():
        tensors[k] = v.to(device=device, dtype=dtype)
    for k, v in tensor_weights.items():
        tensor_weights[k] = v.to(device=device, dtype=dtype)

--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
# Development dependencies for testing LoRA Power-Merger

# Main dependencies (from requirements.txt)
git+https://github.com/arcee-ai/mergekit.git#egg=mergekit
lxml

# Testing framework
pytest>=7.0.0
pytest-cov>=4.0.0     # Coverage reporting
pytest-mock>=3.10.0   # Advanced mocking utilities
pytest-xdist>=3.0.0   # Parallel test execution

# Code quality
black>=23.0.0         # Code formatting
flake8>=6.0.0         # Linting
mypy>=1.0.0           # Type checking
isort>=5.12.0         # Import sorting

# Utilities for testing
numpy>=1.21.0         # For numerical comparisons

--------------------------------------------------------------------------------
/src/validation/__init__.py:
--------------------------------------------------------------------------------
"""
Validation module for LoRA Power-Merger.

Provides comprehensive input validation for merge operations including:
- LoRA stack validation
- Tensor shape compatibility checks
- Method parameter validation
- Weight/strength validation
"""

from .validators import (
    LoRAStackValidator,
    TensorShapeValidator,
    MergeParameterValidator,
    validate_lora_stack_for_merge,
    validate_tensor_shapes_compatible,
)

__all__ = [
    'LoRAStackValidator',
    'TensorShapeValidator',
    'MergeParameterValidator',
    'validate_lora_stack_for_merge',
    'validate_tensor_shapes_compatible',
]

--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
name: Publish to Comfy registry
on:
  workflow_dispatch:
  push:
    branches:
      - main
      - master
    paths:
      - "pyproject.toml"

permissions:
  issues: write

jobs:
  publish-node:
    name: Publish Custom Node to registry
    runs-on: ubuntu-latest
    if: ${{ github.repository_owner == 'larsupb' }}
    steps:
      - name: Check out code
        uses: actions/checkout@v4
      - name: Publish Custom Node
        uses: Comfy-Org/publish-node-action@v1
        with:
          ## Add your own personal access token to your GitHub repository secrets and reference it here.
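          ## (For example, create the secret under the repository's Settings -> Secrets and
          ## variables -> Actions; the name must match the one referenced below.)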
          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}

--------------------------------------------------------------------------------
/src/decomposition/__init__.py:
--------------------------------------------------------------------------------
"""
Decomposition module for LoRA Power-Merger.

Handles LoRA decomposition operations including:
- Decomposition of LoRAs into (up, down, alpha) tuples
- Caching layer for expensive decomposition operations
- Base classes for decomposition strategies (SVD, QR, etc.)
"""

from .base import (
    TensorDecomposer,
    DecompositionMethod,
    SingularValueDistribution,
)
from .svd import (
    SVDDecomposer,
    RandomizedSVDDecomposer,
    EnergyBasedRandomizedSVDDecomposer,
    QRDecomposer,
)

__all__ = [
    # Base classes
    'TensorDecomposer',
    'DecompositionMethod',
    'SingularValueDistribution',
    # Decomposers
    'SVDDecomposer',
    'RandomizedSVDDecomposer',
    'EnergyBasedRandomizedSVDDecomposer',
    'QRDecomposer',
]

--------------------------------------------------------------------------------
/run_pytest.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""
Test runner that sets up mocking before pytest runs
"""

import sys
from unittest.mock import MagicMock
from typing import Tuple
import torch

# Mock ComfyUI modules BEFORE any imports
sys.modules['comfy'] = MagicMock()
sys.modules['comfy.utils'] = MagicMock()
sys.modules['comfy.model_management'] = MagicMock()
sys.modules['comfy.lora'] = MagicMock()

# Mock architectures module
architectures_mock = MagicMock()
architectures_mock.sd_lora.UP_DOWN_ALPHA_TUPLE = Tuple[torch.Tensor, torch.Tensor, float]
sys.modules['architectures'] = architectures_mock
sys.modules['architectures.sd_lora'] = architectures_mock.sd_lora
sys.modules['architectures.general_architecture'] = MagicMock()
sys.modules['architectures.wan_lora'] = MagicMock()

# Now run pytest
import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(sys.argv[1:] or ["test_utility.py", "-v"]))

--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
[pytest]
# Pytest configuration for LoRA Power-Merger tests

# Test discovery patterns
python_files = test_*.py
python_classes = Test*
python_functions = test_*

# Output options
addopts =
    -v
    --tb=short
    --strict-markers
    --disable-warnings
    --color=yes
    --import-mode=importlib

# Markers for organizing tests
markers =
    unit: Unit tests for individual functions
    integration: Integration tests combining multiple components
    slow: Tests that take longer to run
    edge_case: Edge case and boundary condition tests

# Test paths
testpaths = tests

# Minimum Python version
minversion = 3.8

# Coverage options (if using pytest-cov)
# Uncomment these if you install pytest-cov
# addopts = --cov=. --cov-report=html --cov-report=term

# Ignore patterns
norecursedirs =
    .git
    .idea
    __pycache__
    deprecated
    experimental
    js
    fonts
    src

--------------------------------------------------------------------------------
/src/lora_apply.py:
--------------------------------------------------------------------------------
import logging

from .comfy_util import load_as_comfy_lora


class LoraApply:
    def __init__(self):
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "lora": ("LoRABundle",),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_merged_lora"
    CATEGORY = "LoRA PowerMerge"

    def apply_merged_lora(self, model, lora):
        strength_model = lora["strength_model"]

        if strength_model == 0:
            return (model,)

        if 'lora' not in lora or lora['lora'] is None:
            lora['lora'] = load_as_comfy_lora(lora, model)

        new_model_patcher = model.clone()
        k = new_model_patcher.add_patches(lora['lora'], strength_model)

        k = set(k)
        for x in lora["lora"]:
            if x not in k:
                logging.warning("PM LoraApply: NOT LOADED {}".format(x))

        return (new_model_patcher,)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 LoRA Power-Merger Contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/src/merge/__init__.py:
--------------------------------------------------------------------------------
"""
Merge module for LoRA Power-Merger.

This module contains all merge-related functionality:
- Core merger logic (LoraMergerMergekit node)
- Merge algorithm implementations
- Helper utilities for merging operations
- Algorithm dispatcher
- Base classes for merge method nodes
"""

from .utils import (
    create_map,
    create_tensor_param,
    parse_layer_filter,
    apply_layer_filter,
    apply_weights_to_tensors,
)
from .dispatcher import get_merge_method, prepare_method_args
from .algorithms import MERGE_ALGORITHMS, get_merge_algorithm
from .base_node import BaseMergeMethodNode, BaseTaskArithmeticNode

__all__ = [
    # Utils
    'create_map',
    'create_tensor_param',
    'parse_layer_filter',
    'apply_layer_filter',
    'apply_weights_to_tensors',
    # Dispatcher
    'get_merge_method',
    'prepare_method_args',
    # Algorithms
    'MERGE_ALGORITHMS',
    'get_merge_algorithm',
    # Base classes
    'BaseMergeMethodNode',
    'BaseTaskArithmeticNode',
]

--------------------------------------------------------------------------------
/src/lora_selector.py:
--------------------------------------------------------------------------------
from .types import LORA_STACK


class LoRASelect:
    """
    Select one LoRA out of a LoRAStack by its index.
    Optionally accepts a raw LoRA dict to preserve CLIP weights.
    """
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "key_dicts": ("LoRAStack",),
                "index": ("INT", {"default": 0, "min": 0, "max": 1000, "tooltip": "Index of the LoRA to select."}),
            },
            "optional": {
                "lora_raw_dict": ("LoRARawDict", {"tooltip": "Optional raw LoRA dict to preserve CLIP weights"}),
            },
        }

    RETURN_TYPES = ("LoRABundle",)
    FUNCTION = "select_lora"
    CATEGORY = "LoRA PowerMerge"
    DESCRIPTION = "Select one LoRA from stack by index. Preserves CLIP weights if raw dict is provided."
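    # Illustrative output (hypothetical stack and values, not part of the node API):
    # for key_dicts = {"styleA": ..., "styleB": ...} and index=1, select_lora returns
    #     ({"lora": key_dicts["styleB"], "strength_model": 1.0, "name": "styleB"},)
    # with an extra "lora_raw" entry only when lora_raw_dict contains "styleB".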

    def select_lora(self, key_dicts: LORA_STACK, index: int, lora_raw_dict: dict = None) -> (dict,):
        keys = list(key_dicts.keys())
        if index < 0 or index >= len(keys):
            raise IndexError(f"Index {index} out of range for LoRAStack with {len(keys)} items.")
        selected_key = keys[index]

        bundle = {
            "lora": key_dicts[selected_key],
            "strength_model": 1.0,
            "name": selected_key
        }

        # Add raw LoRA data if available (for preserving CLIP weights)
        if lora_raw_dict is not None and selected_key in lora_raw_dict:
            bundle["lora_raw"] = lora_raw_dict[selected_key]

        return (bundle,)

--------------------------------------------------------------------------------
/src/lora_save.py:
--------------------------------------------------------------------------------
import os

import comfy
import comfy.utils
import folder_paths

from .architectures.sd_lora import convert_to_regular_lora


class LoraSave:
    def __init__(self):
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "lora": ("LoRABundle",),
            "file_name": ("STRING", {"multiline": False, "default": "merged"}),
            "extension": (["safetensors"],),
        }}

    RETURN_TYPES = ()
    FUNCTION = "lora_save"
    CATEGORY = "LoRA PowerMerge"

    OUTPUT_NODE = True

    def lora_save(self, model, lora, file_name, extension):
        save_path = os.path.join(folder_paths.folder_names_and_paths["loras"][0][0], file_name + "." + extension)

        # Convert model weights from ComfyUI format to regular LoRA format
        state_dict = lora['lora']
        new_state_dict = convert_to_regular_lora(model, state_dict)

        # If lora_raw exists, extract and merge CLIP weights
        if 'lora_raw' in lora and lora['lora_raw'] is not None:
            lora_raw = lora['lora_raw']

            # Extract CLIP weights from the original LoRA.
            # CLIP weights have keys starting with "lora_te" or "lora_te1_text_model" etc.
            for key in lora_raw.keys():
                # Check if this is a CLIP/text encoder key
                if any(clip_prefix in key for clip_prefix in [
                    'lora_te', 'text_encoder', 'lora_te1_text_model', 'lora_te2_text_model',
                    'text_model', 'transformer.text_model'
                ]):
                    # Copy the CLIP weight from the original to the new state dict.
                    # This preserves the unmodified CLIP weights.
                    new_state_dict[key] = lora_raw[key]

        print(f"Saving LoRA to {save_path}")
        comfy.utils.save_torch_file(new_state_dict, save_path)

        return {}

--------------------------------------------------------------------------------
/src/merge/dispatcher.py:
--------------------------------------------------------------------------------
"""
Merge method dispatcher for LoRA Power-Merger.

Handles routing merge requests to the appropriate algorithm implementation
based on method name. Replaces the large if-elif chain with a registry pattern.
"""

from typing import get_args, Callable

from ..mergekit_utils import MERGEKIT_GTA_MODES
from .algorithms import (
    MERGE_ALGORITHMS,
    generalized_task_arithmetic_merge,
)


# Extended registry including GTA methods
def get_merge_method(method_name: str) -> Callable:
    """
    Get a merge method function by name.

    Supports both direct algorithm names and GTA mode variants.

    Args:
        method_name: Name of the merge method

    Returns:
        Merge method function

    Raises:
        ValueError: If the method name is unknown/unsupported

    Examples:
        >>> get_merge_method("linear")
        <function ...>
        >>> get_merge_method("dare")  # GTA mode
        <function generalized_task_arithmetic_merge at 0x...>
    """
    # Check direct algorithm registry first
    if method_name in MERGE_ALGORITHMS:
        return MERGE_ALGORITHMS[method_name]

    # Check if it's a GTA mode
    if method_name in get_args(MERGEKIT_GTA_MODES):
        return generalized_task_arithmetic_merge

    # Unknown method
    available_methods = list(MERGE_ALGORITHMS.keys()) + list(get_args(MERGEKIT_GTA_MODES))
    raise ValueError(
        f"Invalid / unsupported merge method: {method_name}. "
        f"Available methods: {', '.join(sorted(available_methods))}"
    )


def prepare_method_args(method_name: str, method_settings: dict) -> dict:
    """
    Prepare the method arguments dictionary for merge execution.

    Combines method name, default settings, and user-provided settings.

    Args:
        method_name: Name of the merge method
        method_settings: User-provided method settings

    Returns:
        Complete method arguments dictionary

    Example:
        >>> prepare_method_args("dare", {"density": 0.8})
        {
            "mode": "dare",
            "int8_mask": False,
            "lambda_": 1.0,
            "density": 0.8
        }
    """
    # Base method arguments
    method_args = {
        "mode": method_name,
        "int8_mask": False,
        "lambda_": 1.0,  # Internal GTA processing (applied separately at the end)
    }

    # Merge with user settings
    method_args.update(method_settings)

    return method_args

--------------------------------------------------------------------------------
/src/utils/progress.py:
--------------------------------------------------------------------------------
"""
Thread-safe progress tracking for LoRA Power-Merger.

Provides the ThreadSafeProgressBar wrapper for ComfyUI progress bars.
"""

import threading
from typing import Optional
import comfy.utils


class ThreadSafeProgressBar:
    """
    Thread-safe wrapper for ComfyUI progress bars.

    Ensures progress updates from multiple threads don't cause race conditions.
    """

    def __init__(self, total: int, desc: str = "Processing"):
        """
        Initialize the thread-safe progress bar.

        Args:
            total: Total number of steps
            desc: Description for the progress bar
        """
        self.total = total
        self.desc = desc
        self._lock = threading.Lock()
        self._current = 0
        self._pbar = comfy.utils.ProgressBar(total)

    def update(self, n: int = 1):
        """
        Update progress by n steps (thread-safe).

        Args:
            n: Number of steps to increment
        """
        with self._lock:
            self._current += n
            self._pbar.update(n)

    def set_description(self, desc: str):
        """
        Update the progress bar description (thread-safe).

        Args:
            desc: New description
        """
        with self._lock:
            self.desc = desc
            # ComfyUI progress bars don't support dynamic descriptions,
            # but we store it for potential logging

    def reset(self):
        """Reset progress to zero (thread-safe)."""
        with self._lock:
            self._current = 0
            self._pbar = comfy.utils.ProgressBar(self.total)

    def close(self):
        """Close the progress bar (thread-safe)."""
        with self._lock:
            # Ensure we're at 100%
            if self._current < self.total:
                remaining = self.total - self._current
                self._pbar.update(remaining)

    @property
    def current(self) -> int:
        """Get the current progress value (thread-safe)."""
        with self._lock:
            return self._current

    @property
    def percentage(self) -> float:
        """Get the current progress percentage (thread-safe)."""
        with self._lock:
            if self.total == 0:
                return 0.0
            return (self._current / self.total) * 100.0

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.close()
        return False

--------------------------------------------------------------------------------
/tests/lora_keys/wan_2_2_lora.txt:
--------------------------------------------------------------------------------
Analyzing LoRA: LoRA/WAN_2_2-LoRA.safetensors

Total keys: 800

Sample keys (first 20):
  diffusion_model.blocks.0.cross_attn.k.lora_A.weight
  diffusion_model.blocks.0.cross_attn.k.lora_B.weight
  diffusion_model.blocks.0.cross_attn.o.lora_A.weight
  diffusion_model.blocks.0.cross_attn.o.lora_B.weight
  diffusion_model.blocks.0.cross_attn.q.lora_A.weight
  diffusion_model.blocks.0.cross_attn.q.lora_B.weight
  diffusion_model.blocks.0.cross_attn.v.lora_A.weight
  diffusion_model.blocks.0.cross_attn.v.lora_B.weight
  diffusion_model.blocks.0.ffn.0.lora_A.weight
  diffusion_model.blocks.0.ffn.0.lora_B.weight
  diffusion_model.blocks.0.ffn.2.lora_A.weight
  diffusion_model.blocks.0.ffn.2.lora_B.weight
  diffusion_model.blocks.0.self_attn.k.lora_A.weight
  diffusion_model.blocks.0.self_attn.k.lora_B.weight
  diffusion_model.blocks.0.self_attn.o.lora_A.weight
  diffusion_model.blocks.0.self_attn.o.lora_B.weight
  diffusion_model.blocks.0.self_attn.q.lora_A.weight
  diffusion_model.blocks.0.self_attn.q.lora_B.weight
  diffusion_model.blocks.0.self_attn.v.lora_A.weight
  diffusion_model.blocks.0.self_attn.v.lora_B.weight

================================================================================

Unique layer components found:
  0
  1
  10
  11
  12
  13
  14
  15
  16
  17
  18
  19
  2
  20
  21
  22
  23
  24
  25
  26
  27
  28
  29
  3
  30
  31
  32
  33
  34
  35
  36
  37
  38
  39
  4
  5
  6
  7
  8
  9
  blocks
  cross_attn
  diffusion_model
  ffn
  k
  lora_A
  lora_B
  o
  q
  self_attn
  v

================================================================================

Pattern matching analysis:

ATTENTION (640 keys):
  diffusion_model.blocks.0.cross_attn.k.lora_A.weight
  diffusion_model.blocks.0.cross_attn.k.lora_B.weight
  diffusion_model.blocks.0.cross_attn.o.lora_A.weight
  diffusion_model.blocks.0.cross_attn.o.lora_B.weight
  diffusion_model.blocks.0.cross_attn.q.lora_A.weight
  ... and 635 more

MLP/FEEDFORWARD (800 keys):
  diffusion_model.blocks.0.cross_attn.k.lora_A.weight
  diffusion_model.blocks.0.cross_attn.k.lora_B.weight
  diffusion_model.blocks.0.cross_attn.o.lora_A.weight
  diffusion_model.blocks.0.cross_attn.o.lora_B.weight
  diffusion_model.blocks.0.cross_attn.q.lora_A.weight
  ... and 795 more

================================================================================

Architecture detection:
  -> Unknown architecture

--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
"""
Pytest configuration and fixtures for LoRA Power-Merger tests.

This file sets up the test environment by mocking ComfyUI dependencies
that may not be available in the test environment.
"""

import sys
import os
from unittest.mock import MagicMock
from typing import Tuple
import torch

# Add the src directory to the path to allow imports
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
src_path = os.path.join(project_root, 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Prevent pytest from treating the project root as a package
# by ensuring tests are collected from the tests directory only
import pytest

def pytest_ignore_collect(collection_path, config):
    """Ignore collection of __init__.py in the project root"""
    # Convert to string for comparison
    path_str = str(collection_path)
    # Ignore the root __init__.py
    if path_str.endswith('__init__.py') and 'tests' not in path_str:
        return True
    # Ignore the src directory
    if '/src/' in path_str or path_str.endswith('/src'):
        return True
    return False

# Mock ComfyUI modules before any test imports
def pytest_configure(config):
    """Configure pytest and mock unavailable modules"""

    # Mock comfy module
    comfy_mock = MagicMock()
    sys.modules['comfy'] = comfy_mock
    sys.modules['comfy.utils'] = MagicMock()
    sys.modules['comfy.model_management'] = MagicMock()
    sys.modules['comfy.lora'] = MagicMock()
    sys.modules['comfy.weight_adapter'] = MagicMock()
    sys.modules['comfy.model_patcher'] = MagicMock()
    sys.modules['comfy.sd'] = MagicMock()

    # Create a LoRAAdapter mock class
    class LoRAAdapterMock:
        def __init__(self, *args, **kwargs):
            self.state_dict = {}

    comfy_mock.weight_adapter.LoRAAdapter = LoRAAdapterMock

    # Mock comfy_util module
    comfy_util_mock = MagicMock()
    sys.modules['comfy_util'] = comfy_util_mock

    # Mock nodes module (ComfyUI nodes)
    nodes_mock = MagicMock()
    sys.modules['nodes'] = nodes_mock

    # Mock folder_paths module (ComfyUI utility)
    folder_paths_mock = MagicMock()
    folder_paths_mock.get_folder_paths.return_value = []
    folder_paths_mock.folder_names_and_paths = {}
    sys.modules['folder_paths'] = folder_paths_mock

    # Mock comfy_extras module (ComfyUI extras)
    comfy_extras_mock = MagicMock()
    sys.modules['comfy_extras'] = comfy_extras_mock
    sys.modules['comfy_extras.nodes_custom_sampler'] = MagicMock()

    # Mock architectures module to avoid relative import issues
    architectures_mock = MagicMock()
    # Define the UP_DOWN_ALPHA_TUPLE type alias
    architectures_mock.sd_lora.UP_DOWN_ALPHA_TUPLE = Tuple[torch.Tensor, torch.Tensor, float]
    sys.modules['architectures'] = architectures_mock
    sys.modules['architectures.sd_lora'] = architectures_mock.sd_lora
    sys.modules['architectures.general_architecture'] = MagicMock()
    sys.modules['architectures.wan_lora'] = MagicMock()

--------------------------------------------------------------------------------
/tests/lora_keys/zImage_lora.txt:
--------------------------------------------------------------------------------
Analyzing LoRA: loras/zImage-LoRA.safetensors

Total keys: 480

Sample keys (first 20):
  diffusion_model.layers.0.adaLN_modulation.0.lora_A.weight
  diffusion_model.layers.0.adaLN_modulation.0.lora_B.weight
  diffusion_model.layers.0.attention.to_k.lora_A.weight
  diffusion_model.layers.0.attention.to_k.lora_B.weight
  diffusion_model.layers.0.attention.to_out.0.lora_A.weight
  diffusion_model.layers.0.attention.to_out.0.lora_B.weight
  diffusion_model.layers.0.attention.to_q.lora_A.weight
  diffusion_model.layers.0.attention.to_q.lora_B.weight
  diffusion_model.layers.0.attention.to_v.lora_A.weight
  diffusion_model.layers.0.attention.to_v.lora_B.weight
  diffusion_model.layers.0.feed_forward.w1.lora_A.weight
  diffusion_model.layers.0.feed_forward.w1.lora_B.weight
  diffusion_model.layers.0.feed_forward.w2.lora_A.weight
  diffusion_model.layers.0.feed_forward.w2.lora_B.weight
  diffusion_model.layers.0.feed_forward.w3.lora_A.weight
  diffusion_model.layers.0.feed_forward.w3.lora_B.weight
  diffusion_model.layers.1.adaLN_modulation.0.lora_A.weight
  diffusion_model.layers.1.adaLN_modulation.0.lora_B.weight
  diffusion_model.layers.1.attention.to_k.lora_A.weight
  diffusion_model.layers.1.attention.to_k.lora_B.weight

================================================================================

Unique layer components found:
  0
  1
  10
  11
  12
  13
  14
  15
  16
  17
  18
  19
  2
  20
  21
  22
  23
  24
  25
  26
  27
  28
  29
  3
  4
  5
  6
  7
  8
  9
  adaLN_modulation
  attention
  diffusion_model
  feed_forward
  layers
  lora_A
  lora_B
  to_k
  to_out
  to_q
  to_v
  w1
  w2
  w3

================================================================================

Pattern matching analysis:

ATTENTION (240 keys):
  diffusion_model.layers.0.attention.to_k.lora_A.weight
  diffusion_model.layers.0.attention.to_k.lora_B.weight
  diffusion_model.layers.0.attention.to_out.0.lora_A.weight
  diffusion_model.layers.0.attention.to_out.0.lora_B.weight
  diffusion_model.layers.0.attention.to_q.lora_A.weight
  ... and 235 more

MLP/FEEDFORWARD (480 keys):
  diffusion_model.layers.0.adaLN_modulation.0.lora_A.weight
  diffusion_model.layers.0.adaLN_modulation.0.lora_B.weight
  diffusion_model.layers.0.attention.to_k.lora_A.weight
  diffusion_model.layers.0.attention.to_k.lora_B.weight
  diffusion_model.layers.0.attention.to_out.0.lora_A.weight
  ... and 475 more

NORM (60 keys):
  diffusion_model.layers.0.adaLN_modulation.0.lora_A.weight
  diffusion_model.layers.0.adaLN_modulation.0.lora_B.weight
  diffusion_model.layers.1.adaLN_modulation.0.lora_A.weight
  diffusion_model.layers.1.adaLN_modulation.0.lora_B.weight
  diffusion_model.layers.10.adaLN_modulation.0.lora_A.weight
  ... and 55 more

================================================================================

Architecture detection:
  -> Unknown architecture

--------------------------------------------------------------------------------
/tests/lora_keys/qwen_image_edit_lora.txt:
--------------------------------------------------------------------------------
Analyzing LoRA: loras/Qwen-LoRA.safetensors

Total keys: 2160

Sample keys (first 20):
  transformer_blocks.0.attn.add_k_proj.alpha
  transformer_blocks.0.attn.add_k_proj.lora_down.weight
  transformer_blocks.0.attn.add_k_proj.lora_up.weight
  transformer_blocks.0.attn.add_q_proj.alpha
  transformer_blocks.0.attn.add_q_proj.lora_down.weight
  transformer_blocks.0.attn.add_q_proj.lora_up.weight
  transformer_blocks.0.attn.add_v_proj.alpha
  transformer_blocks.0.attn.add_v_proj.lora_down.weight
  transformer_blocks.0.attn.add_v_proj.lora_up.weight
  transformer_blocks.0.attn.to_add_out.alpha
  transformer_blocks.0.attn.to_add_out.lora_down.weight
  transformer_blocks.0.attn.to_add_out.lora_up.weight
  transformer_blocks.0.attn.to_k.alpha
  transformer_blocks.0.attn.to_k.lora_down.weight
  transformer_blocks.0.attn.to_k.lora_up.weight
  transformer_blocks.0.attn.to_out.0.alpha
  transformer_blocks.0.attn.to_out.0.lora_down.weight
  transformer_blocks.0.attn.to_out.0.lora_up.weight
  transformer_blocks.0.attn.to_q.alpha
  transformer_blocks.0.attn.to_q.lora_down.weight

================================================================================

Unique layer components found:
  0
  1
  10
  11
  12
  13
  14
  15
  16
  17
  18
  19
  2
  20
  21
  22
  23
  24
  25
  26
  27
  28
  29
  3
  30
  31
  32
  33
  34
  35
  36
  37
  38
  39
  4
  40
  41
  42
  43
  44
  45
  46
  47
  48
  49
  5
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  6
  7
  8
  9
  add_k_proj
  add_q_proj
  add_v_proj
  attn
  img_mlp
  net
  proj
  to_add_out
  to_k
  to_out
  to_q
  to_v
  transformer_blocks
  txt_mlp

================================================================================

Pattern matching analysis:

ATTENTION (1440 keys):
  transformer_blocks.0.attn.add_k_proj.alpha
  transformer_blocks.0.attn.add_k_proj.lora_down.weight
  transformer_blocks.0.attn.add_k_proj.lora_up.weight
  transformer_blocks.0.attn.add_q_proj.alpha
  transformer_blocks.0.attn.add_q_proj.lora_down.weight
  ... and 1435 more

MLP/FEEDFORWARD (720 keys):
  transformer_blocks.0.img_mlp.net.0.proj.alpha
  transformer_blocks.0.img_mlp.net.0.proj.lora_down.weight
  transformer_blocks.0.img_mlp.net.0.proj.lora_up.weight
  transformer_blocks.0.img_mlp.net.2.alpha
  transformer_blocks.0.img_mlp.net.2.lora_down.weight
  ... and 715 more

PROJECTION (900 keys):
  transformer_blocks.0.attn.add_k_proj.alpha
  transformer_blocks.0.attn.add_k_proj.lora_down.weight
  transformer_blocks.0.attn.add_k_proj.lora_up.weight
  transformer_blocks.0.attn.add_q_proj.alpha
  transformer_blocks.0.attn.add_q_proj.lora_down.weight
  ... and 895 more

================================================================================

Architecture detection:
  -> Unknown architecture

--------------------------------------------------------------------------------
/debug_keys.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""
Debug script to analyze LoRA layer keys from a Wan 2.2 model.
Run this to see what layer keys are actually present in your LoRA.
"""

import sys
import safetensors.torch

def analyze_lora_keys(lora_path):
    """Load and analyze the layer keys in a LoRA file."""
    print(f"\nAnalyzing LoRA: {lora_path}\n")

    try:
        # Load the LoRA
        if lora_path.endswith('.safetensors'):
            lora_dict = safetensors.torch.load_file(lora_path)
        else:
            import torch
            lora_dict = torch.load(lora_path, map_location='cpu')

        keys = list(lora_dict.keys())
        print(f"Total keys: {len(keys)}\n")

        # Analyze key patterns
        print("Sample keys (first 20):")
        for key in keys[:20]:
            print(f"  {key}")

        print("\n" + "="*80)

        # Extract unique layer components
        components = set()
        for key in keys:
            parts = key.split('.')
            for part in parts:
                if part not in ['lora_up', 'lora_down', 'alpha', 'weight', 'bias']:
                    components.add(part)

        print("\nUnique layer components found:")
        for comp in sorted(components):
            print(f"  {comp}")

        print("\n" + "="*80)

        # Check for common patterns
        patterns = {
            'attention': ['attn', 'attention', 'self_attn', 'cross_attn'],
            'mlp/feedforward': ['mlp', 'ff', 'feed_forward', 'ffn', 'fc'],
            'projection': ['proj', 'projection'],
            'norm': ['norm', 'ln', 'layer_norm'],
            'embedding': ['emb', 'embedding', 'token'],
        }

        print("\nPattern matching analysis:")
        for category, pattern_list in patterns.items():
            matching_keys = []
            for key in keys:
                if any(pattern in key.lower() for pattern in pattern_list):
                    matching_keys.append(key)

            if matching_keys:
                print(f"\n{category.upper()} ({len(matching_keys)} keys):")
                for key in matching_keys[:5]:
                    print(f"  {key}")
                if len(matching_keys) > 5:
                    print(f"  ... and {len(matching_keys) - 5} more")

        print("\n" + "="*80)

        # Try to identify architecture
        print("\nArchitecture detection:")
        key_str = ' '.join(keys).lower()

        if 'double_blocks' in key_str or 'single_blocks' in key_str:
            print("  -> Likely a Flux/DiT architecture (double_blocks/single_blocks)")
        elif 'joint_blocks' in key_str:
            print("  -> Likely a SD3/MMDiT architecture (joint_blocks)")
        elif 'input_blocks' in key_str or 'output_blocks' in key_str:
            print("  -> Likely a SD1.5/SDXL architecture (UNet)")
        else:
            print("  -> Unknown architecture")

    except Exception as e:
        print(f"Error loading LoRA: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python debug_keys.py <path_to_lora>")
        print("\nExample:")
        print("  python debug_keys.py /path/to/your/wan2.2_lora.safetensors")
        sys.exit(1)

    analyze_lora_keys(sys.argv[1])

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
from .src.lora_apply import LoraApply
from .src.lora_block_sampler import LoRABlockSampler
from .src.lora_mergekit_merge import LoraMergerMergekit
from .src.lora_selector import LoRASelect
from .src.lora_dir_stacker import LoraStackFromDir
from .src.lora_decompose import LoraDecompose
from .src.lora_parameter_sweep_sampler import LoRAParameterSweepSampler
from .src.lora_power_stacker import LoraPowerStacker
from .src.lora_resize import LoraResizer
from .src.lora_save import LoraSave
from .src.lora_stack_sampler import LoRAStackSampler
from .src.nodes_lora_modifier import LoRAModifier
from .src.nodes_merge_methods import TaskArithmeticMergeMethod, NearSwapMergeMethod, SCEMergeMethod, BreadcrumbsMergeMethod, \
    TIESMergeMethod, DAREMergeMethod, DELLAMergeMethod, SLERPMergeMethod, LinearMergeMethod, NuSlerpMergeMethod, \
    ArceeFusionMergeMethod, KArcherMergeMethod

version_code = [2, 1, 0]
version_str = f"V{version_code[0]}.{version_code[1]}" + (f'.{version_code[2]}' if len(version_code) > 2 else '')
print(f"### Loading: ComfyUI LoRA-PowerMerge ({version_str})")

NODE_CLASS_MAPPINGS = {
    "PM LoRA Merger (Mergekit)": LoraMergerMergekit,

    "PM LoRA Power Stacker": LoraPowerStacker,
    "PM LoRA Stacker (from Directory)": LoraStackFromDir,
    "PM LoRA Select": LoRASelect,
    "PM LoRA Stack Decompose": LoraDecompose,

    "PM LoRA Block Sampler": LoRABlockSampler,
    "PM LoRA Stack Sampler": LoRAStackSampler,
    "PM LoRA Parameter Sweep Sampler": LoRAParameterSweepSampler,

    "PM Slerp (Mergekit)": SLERPMergeMethod,
    "PM NuSlerp (Mergekit)": NuSlerpMergeMethod,
    "PM NearSwap (Mergekit)": NearSwapMergeMethod,
    "PM Arcee Fusion (Mergekit)": ArceeFusionMergeMethod,

    "PM Linear (Mergekit)": LinearMergeMethod,
    "PM SCE (Mergekit)": SCEMergeMethod,
    "PM KArcher (Mergekit)": KArcherMergeMethod,

    "PM Task Arithmetic (Mergekit)": TaskArithmeticMergeMethod,
    "PM Ties (Mergekit)": TIESMergeMethod,
    "PM Breadcrumbs (Mergekit)": BreadcrumbsMergeMethod,
    "PM Dare (Mergekit)": DAREMergeMethod,
    "PM Della (Mergekit)": DELLAMergeMethod,

    "PM LoRA Modifier": LoRAModifier,

    "PM LoRA Resizer": LoraResizer,
    "PM LoRA Apply": LoraApply,
    "PM LoRA Save": LoraSave,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PM LoRA Power Stacker": "PM LoRA Power Stacker",
    "PM LoRA Stacker (from Directory)": "PM LoRA Stacker (from Directory)",
    "PM LoRA Select": "PM LoRA Select",
    "PM LoRA Stack Decompose": "PM LoRA Stack Decompose",

    "PM LoRA Merger (Mergekit)": "PM LoRA Merger (Mergekit)",

    "PM LoRA Block Sampler": "PM LoRA Block Sampler",
    "PM LoRA Stack Sampler": "PM LoRA Stack Sampler",
    "PM LoRA Parameter Sweep Sampler": "PM LoRA Parameter Sweep Sampler",

    "PM Slerp (Mergekit)": "PM Slerp (Mergekit)",
    "PM NuSlerp (Mergekit)": "PM NuSlerp (Mergekit)",
    "PM NearSwap (Mergekit)": "PM NearSwap (Mergekit)",
    "PM Arcee Fusion (Mergekit)": "PM Arcee Fusion (Mergekit)",

    "PM Linear (Mergekit)": "PM Linear (Mergekit)",
    "PM SCE (Mergekit)": "PM SCE (Mergekit)",
    "PM KArcher (Mergekit)": "PM KArcher (Mergekit)",

    "PM Task Arithmetic (Mergekit)": "PM Task Arithmetic (Mergekit)",
    "PM Ties (Mergekit)": "PM TIES (Mergekit)",
    "PM Dare (Mergekit)": "PM DARE (Mergekit)",
    "PM Breadcrumbs (Mergekit)": "PM Breadcrumbs (Mergekit)",
    "PM Della (Mergekit)": "PM Della (Mergekit)",

    "PM LoRA Modifier": "PM LoRA Modifier",

    "PM LoRA Resizer": "PM Resize LoRA",
    "PM LoRA Apply": "PM Apply LoRA",
    "PM LoRA Save": "PM Save LoRA",
}


WEB_DIRECTORY = "./js"

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]

--------------------------------------------------------------------------------
/src/lora_dir_stacker.py:
--------------------------------------------------------------------------------
import os
import comfy.lora
import comfy.utils

from .merge import parse_layer_filter, apply_layer_filter
from .types import LORA_STACK, LORA_WEIGHTS
from .utils import LayerFilter


class LoraStackFromDir:
    """
    Node for loading LoRA weights from a directory into a stack.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
                "directory": ("STRING",),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "General model strength applied to all LoRAs."}),
                "strength_clip": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "General CLIP strength applied to all LoRAs."}),
                "layer_filter": (
                    list(LayerFilter.PRESETS.keys()), {"default": "full", "tooltip": "Filter for specific layers."}),
                "sort_by": (["name", "name descending", "date", "date descending"],
                            {"default": "name", "tooltip": "Sort LoRAs by name or date."}),
                "limit": ("INT", {"default": -1, "min": -1, "max": 1000, "tooltip": "Limit the number of LoRAs to load."}),
            },
        }

    RETURN_TYPES = ("LoRAStack", "LoRAWeights", "LoRARawDict",)
    FUNCTION = "stack_loras"
    CATEGORY = "LoRA PowerMerge"
    DESCRIPTION = "Stacks LoRA weights from the given directory and applies them to the model."
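    # Output sketch (file names are hypothetical): for a directory holding a.safetensors
    # and b.safetensors, stack_loras returns three aligned dicts keyed by LoRA name:
    #     ({"a": patch_dict_a, "b": patch_dict_b},
    #      {"a": {"strength_model": 1.0, "strength_clip": 1.0}, "b": {...}},
    #      {"a": raw_state_dict_a, "b": raw_state_dict_b})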

    def stack_loras(self, model, directory, strength_model: float = 1.0, strength_clip: float = 1.0,
                    layer_filter=None, sort_by: str = None, limit: int = 0) -> \
            (LORA_STACK, LORA_WEIGHTS, dict):
        # check if directory exists
        if not os.path.isdir(directory):
            raise FileNotFoundError(f"Directory {directory} does not exist.")

        key_map = {}
        if model is not None:
            key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)

        layer_filter = parse_layer_filter(layer_filter)

        # Load LoRAs and patch key names
        lora_patch_dicts = {}
        lora_strengths = {}
        lora_raw_dicts = {}  # Store raw LoRA state dicts for CLIP weights

        # Load LoRAs from the directory:
        # walk over files in the directory
        for root, _, files in os.walk(directory):
            # Sort files based on the specified criteria
            if sort_by == "name":
                files = sorted(files)
            elif sort_by == "name descending":
                files = sorted(files, reverse=True)
            elif sort_by == "date":
                files = sorted(files, key=lambda f: os.path.getmtime(os.path.join(root, f)))
            elif sort_by == "date descending":
                files = sorted(files, key=lambda f: os.path.getmtime(os.path.join(root, f)), reverse=True)
            # Limit the number of LoRAs to load
            if limit > 0:
                files = files[:limit]

            for file in files:
                if file.endswith(".safetensors") or file.endswith(".ckpt"):
                    lora_path = os.path.join(root, file)
                    lora_name = os.path.splitext(file)[0]
                    lora_raw = comfy.utils.load_torch_file(lora_path, safe_load=True)
                    patch_dict = comfy.lora.load_lora(lora_raw, key_map)
                    patch_dict = apply_layer_filter(patch_dict, layer_filter)
                    lora_patch_dicts[lora_name] = patch_dict
                    lora_strengths[lora_name] = {
                        'strength_model': strength_model,
                        'strength_clip': strength_clip,
                    }
                    lora_raw_dicts[lora_name] = lora_raw  # Store raw state dict

        return lora_patch_dicts, lora_strengths, lora_raw_dicts,

--------------------------------------------------------------------------------
/src/nodes_lora_modifier.py:
--------------------------------------------------------------------------------
import json
import logging
from typing import Dict, Any

import torch

# init logging to log with a prefix
logging.basicConfig(level=logging.INFO, format='[LoRAModifier] %(message)s')

from comfy.weight_adapter import LoRAAdapter
from .architectures import sd_lora, dit_lora
from .architectures.general_architecture import LORA_STACK, LORA_KEY_DICT


class LoRAModifier:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "key_dicts": ("LoRAStack", {"tooltip": "The dictionary containing LoRA names and key weights."}),
                "blocks_store": ("STRING", {"multiline": False}),
            },
        }

    RETURN_TYPES = ("LoRAStack",)

    FUNCTION = "run"
    CATEGORY = "LoRA PowerMerge"

    def run(self, key_dicts: LORA_STACK, blocks_store: str):
        widget_data: dict = self.parse(blocks_store)
        arch: str = widget_data.get("mode", "sdxl_unet")
        block_scale_dict = widget_data.get("blockScales", {})

        # Workaround for middle_block expected but middle_blocks provided
        if "middle_blocks.1" in block_scale_dict:
            block_scale_dict["middle_block.1"] = block_scale_dict.pop("middle_blocks.1")
        print("Block scale dict:", block_scale_dict)

        new_key_dicts = {}
        for lora_name, patch_dict in key_dicts.items():
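            # Example (hypothetical widget payload): with blocks_store set to
            #   {"mode": "sdxl_unet", "blockScales": {"input_blocks.4": 0.5}}
            # every layer whose detected main block is input_blocks.4 gets its
            # up/down weights multiplied by 0.5; unmatched layers pass through unchanged.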
            patch_dict_modified = self.apply(patch_dict, block_scale_dict, architecture=arch)
            new_key_dicts[lora_name] = patch_dict_modified

        return (new_key_dicts,)

    def apply(self, patch_dict: LORA_KEY_DICT, block_scale_dict: dict, architecture: str):
        # Iterate over keys in the LoRA adapter.
        # Sum up the total weight of tensors for debugging.
        total_weight = 0.0
        total_weight_after = 0.0
        patch_dict_filtered = {}
        for layer_key, adapter in patch_dict.items():
            total_weight += torch.sum(adapter.weights[0]).item() + torch.sum(adapter.weights[1]).item()

            # copy the weights to avoid modifying the original adapter
            new_weights = []
            for weight in adapter.weights:
                # copy if tensor
                if isinstance(weight, torch.Tensor):
                    new_weights.append(weight.clone())
                else:
                    new_weights.append(weight)

            # Select the appropriate detect function based on architecture
            if "dit" in architecture:
                block_names = dit_lora.detect_block_names(layer_key)
            else:  # sd/sdxl
                block_names = sd_lora.detect_block_names(layer_key)
            if (block_names is None or "main_block" not in block_names
                    or block_names["main_block"] not in block_scale_dict):
                # Skip scaling for this layer
                logging.info(f"Skipping layer {layer_key} as it was not mentioned by the block scale dict.")
            else:
                scale_factor = float(block_scale_dict[block_names["main_block"]])
                # Apply the scale factor to the weights
                new_weights[0] *= scale_factor
                new_weights[1] *= scale_factor
            # Sum up the total weight of tensors for debugging
            total_weight_after += torch.sum(new_weights[0]).item() + torch.sum(new_weights[1]).item()

            # Convert list to tuple to match LoRAAdapter expectations:
            # LoRAAdapter expects (loaded_keys, weights) where weights is a tuple
            patch_dict_filtered[layer_key] = LoRAAdapter(loaded_keys=adapter.loaded_keys, weights=tuple(new_weights))
        logging.info(f"Modified LoRA: {len(patch_dict_filtered)} layers after scaling.")
        logging.info(f"Total weight before scaling: {total_weight}, after scaling: {total_weight_after}")
        return patch_dict_filtered

    def parse(self, stringified: str) -> Dict[str, Any]:
        try:
            return json.loads(stringified)  # This will now be a proper JSON string
        except (json.JSONDecodeError, TypeError):
            print(f"Failed to parse JSON string: {stringified}.\n Returning empty dictionary.")
            return {}

--------------------------------------------------------------------------------
/src/architectures/dit_lora.py:
--------------------------------------------------------------------------------
import re
from typing import Dict, Optional

from ..types import DiTBlockNameInfo


def detect_block_names(layer_key: str, layers_per_group: int = 5) -> Optional[DiTBlockNameInfo]:
    """
    Detect block names for the DiT (Diffusion Transformer) architecture.

    DiT models have a flat structure with sequential transformer layers (e.g., layers.0 through layers.39).
    This function groups them into "main blocks" for easier block-wise manipulation.

    Args:
        layer_key: The layer key (string or tuple). If a tuple, the first element is used as the key string.
16 | (e.g., "diffusion_model.layers.13.attention.qkv.weight" or 17 | ("diffusion_model.layers.13.attention.qkv.weight", 0)) 18 | layers_per_group: Number of layers to group into one main block (default: 5) 19 | For a 40-layer model, this creates 8 main blocks 20 | 21 | Returns: 22 | Dictionary with block information, or None if the key doesn't match DiT pattern 23 | 24 | Examples: 25 | >>> detect_block_names("diffusion_model.layers.13.attention.qkv.weight", layers_per_group=5) 26 | { 27 | "layer_idx": "13", 28 | "component": "attention", 29 | "main_block": "layers_group.2", # layers 10-14 = group 2 30 | "sub_block": "layers.13" 31 | } 32 | """ 33 | # Convert tuple keys to strings (ComfyUI uses tuple keys) 34 | if isinstance(layer_key, tuple): 35 | layer_key = layer_key[0] 36 | 37 | # DiT pattern: diffusion_model.layers.{idx}.{component} 38 | # Components typically include: attention, norm1, norm2, mlp, feed_forward, etc. 39 | exp_dit = re.compile(r""" 40 | (?:diffusion_model\.)? # optional prefix 41 | layers\. 42 | (?P\d+) # layer index (0-39 for 40-layer model) 43 | \. 44 | (?P[a-zA-Z_][a-zA-Z0-9_]*) # component type (any valid identifier) 45 | (?:\..+)? # allow nested submodules (e.g. .weight, .to_q.weight) 46 | """, re.VERBOSE) 47 | 48 | match = exp_dit.search(layer_key) 49 | if match: 50 | layer_idx = int(match.group("layer_idx")) 51 | component = match.group("component") 52 | 53 | # Calculate which group this layer belongs to 54 | # For layers_per_group=5: layers 0-4 -> group 0, layers 5-9 -> group 1, etc. 55 | group_idx = layer_idx // layers_per_group 56 | 57 | out = { 58 | "layer_idx": match.group("layer_idx"), 59 | "component": component, 60 | "main_block": f"layers_group.{group_idx}", 61 | "sub_block": f"layers.{layer_idx}", 62 | "group_idx": group_idx, 63 | "group_start": group_idx * layers_per_group, 64 | "group_end": (group_idx + 1) * layers_per_group - 1, 65 | } 66 | return out 67 | 68 | return None 69 | 70 | 71 | def get_group_count(total_layers: int, layers_per_group: int = 5) -> int: 72 | """ 73 | Calculate the number of groups for a given number of layers. 74 | 75 | Args: 76 | total_layers: Total number of transformer layers in the model 77 | layers_per_group: Number of layers per group 78 | 79 | Returns: 80 | Number of groups 81 | 82 | Examples: 83 | >>> get_group_count(40, 5) 84 | 8 85 | >>> get_group_count(28, 5) 86 | 6 87 | """ 88 | return (total_layers + layers_per_group - 1) // layers_per_group 89 | 90 | 91 | def detect_architecture(patch_dict: Dict) -> Optional[str]: 92 | """ 93 | Auto-detect if a LoRA uses DiT architecture by examining its keys. 94 | 95 | Args: 96 | patch_dict: Dictionary of LoRA patches 97 | 98 | Returns: 99 | "dit" if DiT architecture is detected, None otherwise 100 | """ 101 | # Sample a few keys to check 102 | sample_size = min(10, len(patch_dict)) 103 | sample_keys = list(patch_dict.keys())[:sample_size] 104 | 105 | dit_pattern = re.compile(r"(?:diffusion_model\.)?layers\.\d+\.") 106 | 107 | dit_matches = sum(1 for key in sample_keys if dit_pattern.search(str(key))) 108 | 109 | # If more than 50% of sampled keys match DiT pattern, consider it DiT 110 | if dit_matches / sample_size > 0.5: 111 | return "dit" 112 | 113 | return None 114 | -------------------------------------------------------------------------------- /src/decomposition/svd.py: -------------------------------------------------------------------------------- 1 | """ 2 | SVD-based tensor decomposition implementations. 
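Each decomposer factors a 2D weight tensor into (U, S, Vh); rank truncation itself is
handled by the shared TensorDecomposer base class.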
3 | 4 | Provides SVD, randomized SVD, and energy-based randomized SVD decomposers. 5 | """ 6 | 7 | import logging 8 | from typing import Tuple 9 | 10 | import torch 11 | 12 | from .base import TensorDecomposer, SingularValueDistribution 13 | 14 | 15 | class SVDDecomposer(TensorDecomposer): 16 | """ 17 | Standard SVD decomposition. 18 | 19 | Uses torch.linalg.svd for full singular value decomposition. 20 | Most accurate but slower for large matrices. 21 | """ 22 | 23 | def _decompose_2d( 24 | self, 25 | weight_2d: torch.Tensor 26 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 27 | """ 28 | Perform full SVD decomposition. 29 | 30 | Args: 31 | weight_2d: 2D weight tensor 32 | 33 | Returns: 34 | Tuple of (U, S, Vh) 35 | 36 | Raises: 37 | RuntimeError: If SVD computation fails 38 | """ 39 | try: 40 | U, S, Vh = torch.linalg.svd(weight_2d, full_matrices=False) 41 | return U, S, Vh 42 | except RuntimeError as e: 43 | if "singular value decomposition" in str(e).lower(): 44 | raise RuntimeError( 45 | f"SVD failed for tensor shape {weight_2d.shape}. " 46 | "Matrix may be singular or ill-conditioned." 47 | ) from e 48 | raise 49 | 50 | 51 | class RandomizedSVDDecomposer(TensorDecomposer): 52 | """ 53 | Randomized SVD decomposition. 54 | 55 | Faster approximation of SVD for large matrices. 56 | Uses randomized algorithm with power iterations. 57 | """ 58 | 59 | def __init__( 60 | self, 61 | n_oversamples: int = 10, 62 | n_iter: int = 2, 63 | **kwargs 64 | ): 65 | """ 66 | Initialize randomized SVD decomposer. 67 | 68 | Args: 69 | n_oversamples: Number of additional samples for randomization 70 | n_iter: Number of power iterations for accuracy 71 | **kwargs: Passed to parent TensorDecomposer 72 | """ 73 | super().__init__(**kwargs) 74 | self.n_oversamples = n_oversamples 75 | self.n_iter = n_iter 76 | 77 | def _decompose_2d( 78 | self, 79 | weight_2d: torch.Tensor 80 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 81 | """ 82 | Perform randomized SVD. 83 | 84 | Args: 85 | weight_2d: 2D weight tensor 86 | 87 | Returns: 88 | Tuple of (U, S, Vh) 89 | """ 90 | # For small matrices, use standard SVD 91 | if min(weight_2d.shape) < 100: 92 | return torch.linalg.svd(weight_2d, full_matrices=False) 93 | 94 | # Randomized SVD implementation 95 | m, n = weight_2d.shape 96 | rank = min(m, n) 97 | 98 | # Determine sketch size 99 | sketch_size = min(rank, rank + self.n_oversamples) 100 | 101 | # Random projection 102 | Omega = torch.randn( 103 | n, sketch_size, 104 | dtype=weight_2d.dtype, 105 | device=weight_2d.device 106 | ) 107 | 108 | # Compute sketch 109 | Y = weight_2d @ Omega 110 | 111 | # Power iterations for better approximation 112 | for _ in range(self.n_iter): 113 | Y = weight_2d @ (weight_2d.T @ Y) 114 | 115 | # Orthogonalize 116 | Q, _ = torch.linalg.qr(Y) 117 | 118 | # Project and decompose 119 | B = Q.T @ weight_2d 120 | U_b, S, Vh = torch.linalg.svd(B, full_matrices=False) 121 | 122 | # Recover U 123 | U = Q @ U_b 124 | 125 | return U, S, Vh 126 | 127 | 128 | class EnergyBasedRandomizedSVDDecomposer(RandomizedSVDDecomposer): 129 | """ 130 | Energy-based randomized SVD. 131 | 132 | Adaptive randomized SVD that adjusts sketch size based on 133 | spectral energy distribution. 134 | """ 135 | 136 | def __init__( 137 | self, 138 | energy_threshold: float = 0.99, 139 | **kwargs 140 | ): 141 | """ 142 | Initialize energy-based randomized SVD. 
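        The threshold is advisory at this stage: _decompose_2d() logs how many components
        reach it, while the actual truncation still happens in the TensorDecomposer base class.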
143 | 144 | Args: 145 | energy_threshold: Target energy retention (0.0 to 1.0) 146 | **kwargs: Passed to parent RandomizedSVDDecomposer 147 | """ 148 | super().__init__(**kwargs) 149 | self.energy_threshold = energy_threshold 150 | 151 | def _decompose_2d( 152 | self, 153 | weight_2d: torch.Tensor 154 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 155 | """ 156 | Perform energy-based randomized SVD. 157 | 158 | First estimates spectral energy, then adapts sketch size. 159 | 160 | Args: 161 | weight_2d: 2D weight tensor 162 | 163 | Returns: 164 | Tuple of (U, S, Vh) 165 | """ 166 | # Perform initial randomized SVD with moderate sketch size 167 | U, S, Vh = super()._decompose_2d(weight_2d) 168 | 169 | # Calculate energy retention 170 | S_squared = S.pow(2) 171 | cumulative_energy = torch.cumsum(S_squared, dim=0) / torch.sum(S_squared) 172 | 173 | # Find rank that meets energy threshold 174 | energy_rank = torch.searchsorted( 175 | cumulative_energy, 176 | self.energy_threshold 177 | ).item() + 1 178 | 179 | logging.debug( 180 | f"Energy-based SVD: {energy_rank}/{len(S)} components " 181 | f"retain {self.energy_threshold*100}% energy" 182 | ) 183 | 184 | # Return with energy-based rank suggestion (truncation happens in base class) 185 | return U, S, Vh -------------------------------------------------------------------------------- /src/merge/base_node.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for merge method nodes. 3 | 4 | Provides common functionality for all merge method nodes to eliminate boilerplate code. 5 | Each merge method node only needs to define INPUT_TYPES and get_settings(). 6 | """ 7 | 8 | from abc import ABC, abstractmethod 9 | from typing import Dict, Any, ClassVar 10 | 11 | 12 | class BaseMergeMethodNode(ABC): 13 | """ 14 | Abstract base class for merge method nodes. 15 | 16 | Subclasses should: 17 | 1. Define INPUT_TYPES() classmethod with method-specific parameters 18 | 2. Implement get_settings(**kwargs) to return settings dictionary 19 | 3. Set METHOD_NAME class variable 20 | 21 | The base class provides: 22 | - Common RETURN_TYPES, FUNCTION, CATEGORY 23 | - get_method() implementation 24 | - Consistent structure across all merge method nodes 25 | 26 | Example: 27 | class LinearMergeMethod(BaseMergeMethodNode): 28 | METHOD_NAME = "linear" 29 | CATEGORY = "LoRA PowerMerge" 30 | 31 | @classmethod 32 | def INPUT_TYPES(cls): 33 | return { 34 | "required": { 35 | "normalize": ("BOOLEAN", {"default": True}), 36 | } 37 | } 38 | 39 | def get_settings(self, normalize: bool): 40 | return {"normalize": normalize} 41 | """ 42 | 43 | # Must be set by subclass 44 | METHOD_NAME: ClassVar[str] = "" 45 | 46 | # ComfyUI node configuration 47 | RETURN_TYPES = ("MergeMethod",) 48 | FUNCTION = "get_method" 49 | CATEGORY = "LoRA PowerMerge" # Can be overridden by subclass 50 | 51 | @classmethod 52 | @abstractmethod 53 | def INPUT_TYPES(cls) -> Dict[str, Any]: 54 | """ 55 | Define input types for this merge method. 56 | 57 | Should return a dictionary with "required" and optionally "optional" keys. 58 | 59 | Returns: 60 | Dictionary defining input parameters 61 | """ 62 | pass 63 | 64 | @abstractmethod 65 | def get_settings(self, **kwargs) -> Dict[str, Any]: 66 | """ 67 | Convert input parameters to settings dictionary. 
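        Implementations typically repackage their INPUT_TYPES parameters into a plain
        dict; get_method() then attaches that dict to METHOD_NAME.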
68 | 69 | Args: 70 | **kwargs: Input parameters from ComfyUI 71 | 72 | Returns: 73 | Settings dictionary for this merge method 74 | """ 75 | pass 76 | 77 | def get_method(self, **kwargs) -> tuple: 78 | """ 79 | Create merge method definition dictionary. 80 | 81 | This method is called by ComfyUI. It combines the method name 82 | with the settings returned by get_settings(). 83 | 84 | Args: 85 | **kwargs: Input parameters from ComfyUI 86 | 87 | Returns: 88 | Tuple containing method definition dictionary 89 | """ 90 | if not self.METHOD_NAME: 91 | raise NotImplementedError( 92 | f"{self.__class__.__name__} must set METHOD_NAME class variable" 93 | ) 94 | 95 | method_def = { 96 | "name": self.METHOD_NAME, 97 | "settings": self.get_settings(**kwargs) 98 | } 99 | 100 | return (method_def,) 101 | 102 | 103 | class BaseTaskArithmeticNode(BaseMergeMethodNode): 104 | """ 105 | Base class for task arithmetic merge method nodes. 106 | 107 | Provides common parameters shared by GTA methods (TIES, DARE, DELLA, etc.). 108 | 109 | Subclasses only need to set METHOD_NAME and optionally override get_extra_inputs() 110 | for method-specific parameters. 111 | """ 112 | 113 | CATEGORY = "LoRA PowerMerge/Task Arithmetic" 114 | 115 | @classmethod 116 | def get_extra_inputs(cls) -> Dict[str, Any]: 117 | """ 118 | Define extra method-specific inputs. 119 | 120 | Override this in subclasses to add method-specific parameters. 121 | 122 | Returns: 123 | Dictionary of extra input parameters 124 | """ 125 | return {} 126 | 127 | @classmethod 128 | def INPUT_TYPES(cls) -> Dict[str, Any]: 129 | """ 130 | Define common task arithmetic inputs. 131 | 132 | Includes rescale_norm and normalize, plus any extra inputs 133 | from get_extra_inputs(). 134 | """ 135 | base_inputs = { 136 | "required": { 137 | "rescale_norm": ( 138 | ["default", "l1", "none"], 139 | { 140 | "default": "default", 141 | "tooltip": "Norm rescaling strategy for task vectors" 142 | } 143 | ), 144 | "normalize": ( 145 | "BOOLEAN", 146 | { 147 | "default": True, 148 | "tooltip": "Normalize weights to sum to 1" 149 | } 150 | ), 151 | } 152 | } 153 | 154 | # Merge with extra inputs 155 | extra_inputs = cls.get_extra_inputs() 156 | if extra_inputs: 157 | base_inputs["required"].update(extra_inputs) 158 | 159 | return base_inputs 160 | 161 | def get_settings( 162 | self, 163 | rescale_norm: str = "default", 164 | normalize: bool = True, 165 | **kwargs 166 | ) -> Dict[str, Any]: 167 | """ 168 | Convert inputs to settings dictionary. 
169 | 
170 |         Args:
171 |             rescale_norm: Norm rescaling strategy
172 |             normalize: Whether to normalize weights
173 |             **kwargs: Extra method-specific parameters
174 | 
175 |         Returns:
176 |             Settings dictionary
177 |         """
178 |         settings = {
179 |             "rescale_norm": rescale_norm,
180 |             "normalize": normalize,
181 |         }
182 | 
183 |         # Add any extra settings
184 |         settings.update(kwargs)
185 | 
186 |         return settings
187 | 
--------------------------------------------------------------------------------
/js/dit.svg:
--------------------------------------------------------------------------------
1 | [SVG figure: "DiT Architecture" block diagram of layer groups 0-7 (Layers 0-4 through 35-39, five layers per group), each annotated with its scale factor "SF: 1.00"]
--------------------------------------------------------------------------------
/js/lora-merge.js:
--------------------------------------------------------------------------------
1 | import {app} from "../../../scripts/app.js";
2 | 
3 | app.registerExtension({
4 |     name: "Comfy.LoRAMerger",
5 |     async beforeRegisterNodeDef(nodeType, nodeData, app) {
6 | 
7 |         if (nodeData.name === 'PM LoRA Merger' || nodeData.name === 'PM LoRA SVD Merger') {
8 |             nodeType.prototype.onConnectionsChange = function (type, index, connected) {
9 |                 if (type !== 1) return;
10 | 
11 |                 this.inputs.forEach((input, i) => input.name = `lora${i + 1}`);
12 | 
13 |                 if (connected && this.inputs[this.inputs.length - 1].link !== null) {
14 |                     this.addInput(`lora${this.inputs.length + 1}`, this.inputs[0].type);
15 |                 } else {
16 |                     if (this.inputs.length > 1 && this.inputs[this.inputs.length - 2].link == null)
17 |                         this.removeInput(this.inputs.length - 1);
18 |                 }
19 |             }
20 |         }
21 | 
22 |         if (nodeData.name === 'PM LoRA Stacker') {
23 |             // Add initialization to ensure proper state on load
24 |             nodeType.prototype.onNodeCreated = function() {
25 |                 this.ensureExactlyOneFreeLoRASlot = () => {
26 |                     const first_lora_idx = 1; // model, lora1, ..., layer_filter
27 |                     const last_input_idx = this.inputs.length - 1;
28 |                     const layer_filter_idx = last_input_idx;
29 | 
30 |                     // Find all LoRA inputs (between first_lora_idx and layer_filter)
31 |                     let lora_inputs = [];
32 |                     for (let i = first_lora_idx; i < layer_filter_idx; i++) {
33 |                         lora_inputs.push(i);
34 |                     }
35 | 
36 |                     // Count connected and disconnected LoRA slots
37 |                     let connected_count = 0;
38 |                     let disconnected_count = 0;
39 | 
40 |                     for (let idx of lora_inputs) {
41 |                         if (this.inputs[idx].link !== null && this.inputs[idx].link !== undefined) {
42 |                             connected_count++;
43 |                         } else {
44 |                             disconnected_count++;
45 |                         }
46 |                     }
47 | 
48 |                     // Ensure exactly one free slot
49 |                     if (disconnected_count === 0) {
50 |                         // Need to add a slot - insert before layer_filter
51 |                         const new_lora_num = lora_inputs.length + 1;
52 |                         const lora_type = this.inputs[first_lora_idx].type;
53 |                         this.addInput(`lora${new_lora_num}`, lora_type);
54 | 
55 |                         // Move layer_filter to the end
56 |                         let temp =
this.inputs[this.inputs.length - 1]; 57 | this.inputs[this.inputs.length - 1] = this.inputs[this.inputs.length - 2]; 58 | this.inputs[this.inputs.length - 2] = temp; 59 | } else if (disconnected_count > 1) { 60 | // Too many free slots - remove extras from the end (but keep layer_filter) 61 | let to_remove = disconnected_count - 1; 62 | for (let i = layer_filter_idx - 1; i >= first_lora_idx && to_remove > 0; i--) { 63 | if (this.inputs[i].link === null || this.inputs[i].link === undefined) { 64 | this.removeInput(i); 65 | to_remove--; 66 | } 67 | } 68 | } 69 | 70 | // Renumber all LoRA inputs sequentially 71 | let lora_num = 1; 72 | for (let i = first_lora_idx; i < this.inputs.length - 1; i++) { 73 | this.inputs[i].name = `lora${lora_num}`; 74 | lora_num++; 75 | } 76 | 77 | this.computeSize(); 78 | }; 79 | }; 80 | 81 | nodeType.prototype.onConnectionsChange = function (type, index, connected) { 82 | // Check if the event type is 1 (input) 83 | if (type !== 1) return; 84 | 85 | const first_lora_idx = 1; 86 | const last_input_idx = this.inputs.length - 1; 87 | 88 | // Ignore changes to model, clip, or layer_filter 89 | if (index < first_lora_idx || index === last_input_idx) return; 90 | 91 | // Ensure the layer_filter is always at the end 92 | if (this.inputs[index].name === "layer_filter" && index !== last_input_idx) { 93 | // Move the layer_filter input to the end 94 | let temp_type = this.inputs[index].type; 95 | let temp_widget = this.inputs[index].widget; 96 | this.removeInput(index); 97 | this.addInput(`layer_filter`, temp_type); 98 | this.inputs[last_input_idx].widget = temp_widget; 99 | this.computeSize(); 100 | return; 101 | } 102 | 103 | // Use the helper function to maintain exactly one free slot 104 | if (this.ensureExactlyOneFreeLoRASlot) { 105 | this.ensureExactlyOneFreeLoRASlot(); 106 | } 107 | }; 108 | 109 | // Call initialization on graph load to fix state after reload 110 | const origOnGraphConfigured = nodeType.prototype.onGraphConfigured; 111 | nodeType.prototype.onGraphConfigured = function() { 112 | if (origOnGraphConfigured) { 113 | origOnGraphConfigured.apply(this, arguments); 114 | } 115 | // Ensure proper state after loading from workflow 116 | if (this.ensureExactlyOneFreeLoRASlot) { 117 | setTimeout(() => { 118 | this.ensureExactlyOneFreeLoRASlot(); 119 | }, 100); 120 | } 121 | }; 122 | } 123 | }, 124 | }); -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration constants for LoRA Power-Merger. 3 | 4 | Centralizes all magic numbers and configuration values used throughout the codebase. 
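Typical usage is to import these constants rather than hard-code values, for example
(import path illustrative, relative to the package layout):

    from .utils.config import DEFAULT_SVD_RANK, ATTENTION_LAYERS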
5 | """ 6 | 7 | # ============================================================================ 8 | # SVD/Decomposition Constants 9 | # ============================================================================ 10 | 11 | # Minimum singular value threshold for SVD operations 12 | MIN_SINGULAR_VALUE = 1e-6 13 | 14 | # Default SVD parameters 15 | DEFAULT_SVD_RANK = 16 16 | DEFAULT_SVD_DISTRIBUTION = "symmetric" # or "asymmetric" 17 | 18 | # Dynamic rank selection defaults 19 | DEFAULT_SV_RATIO = 100.0 20 | DEFAULT_SV_CUMULATIVE = 0.95 21 | DEFAULT_SV_FRO = 0.99 22 | 23 | 24 | # ============================================================================ 25 | # Merge Operation Constants 26 | # ============================================================================ 27 | 28 | # Maximum number of worker threads for parallel processing 29 | MAX_MERGE_WORKERS = 8 30 | 31 | # Default lambda scaling factor 32 | DEFAULT_LAMBDA = 1.0 33 | 34 | # Default normalization setting 35 | DEFAULT_NORMALIZE = True 36 | 37 | 38 | # ============================================================================ 39 | # Validation Constants 40 | # ============================================================================ 41 | 42 | # Minimum number of LoRAs required for merge 43 | MIN_LORAS_FOR_MERGE = 2 44 | 45 | # Minimum key overlap ratio to avoid warnings 46 | MIN_KEY_OVERLAP_RATIO = 0.5 47 | 48 | # Typical strength value range (for warnings) 49 | TYPICAL_STRENGTH_MIN = 0.0 50 | TYPICAL_STRENGTH_MAX = 1.0 51 | 52 | 53 | # ============================================================================ 54 | # Device and Memory Constants 55 | # ============================================================================ 56 | 57 | # Supported device types 58 | SUPPORTED_DEVICES = ["cpu", "cuda", "mps", "auto"] 59 | 60 | # Supported dtype strings 61 | SUPPORTED_DTYPES = [ 62 | "float16", "float32", "float64", 63 | "bfloat16", 64 | "int8", "int16", "int32", "int64", 65 | ] 66 | 67 | # Default device for computation 68 | DEFAULT_DEVICE = "cpu" 69 | 70 | # Default dtype for computation 71 | DEFAULT_DTYPE = "float32" 72 | 73 | 74 | # ============================================================================ 75 | # Progress Bar Constants 76 | # ============================================================================ 77 | 78 | # Update frequency for progress bars (in seconds) 79 | PROGRESS_UPDATE_INTERVAL = 0.1 80 | 81 | # Default progress bar format 82 | PROGRESS_BAR_FORMAT = "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt}" 83 | 84 | 85 | # ============================================================================ 86 | # Layer Filter Constants 87 | # ============================================================================ 88 | 89 | # Architecture-agnostic layer filter sets (works for SD, DiT, Flux, and Wan) 90 | # Note: 'attn' is added as a general pattern to catch Flux keys like 'img_attn_proj' 91 | ATTENTION_LAYERS = {"attn", "attn1", "attn2", "attention", "self_attn", "cross_attn"} 92 | MLP_LAYERS = {"ff", "mlp", "feed_forward", "ffn"} 93 | ATTENTION_MLP_LAYERS = {"attn", "attn1", "attn2", "attention", "self_attn", "cross_attn", "ff", "mlp", "feed_forward", "ffn"} 94 | 95 | # Legacy architecture-specific constants (deprecated, kept for backward compatibility) 96 | SD_ATTENTION_LAYERS = {"attn1", "attn2"} 97 | SD_MLP_LAYERS = {"ff"} 98 | SD_ATTENTION_MLP_LAYERS = {"attn1", "attn2", "ff"} 99 | SD_PROJECTION_LAYERS = {"proj_in", "proj_out"} 100 | DIT_ATTENTION_LAYERS = {"attention"} 101 | 
DIT_MLP_LAYERS = {"mlp", "feed_forward"} 102 | WAN_ATTENTION_LAYERS = {"self_attn", "cross_attn"} 103 | WAN_MLP_LAYERS = {"ffn"} 104 | 105 | 106 | # ============================================================================ 107 | # File I/O Constants 108 | # ============================================================================ 109 | 110 | # Supported LoRA file extensions 111 | LORA_FILE_EXTENSIONS = [".safetensors", ".pt", ".pth", ".ckpt"] 112 | 113 | # Default LoRA save format 114 | DEFAULT_LORA_SAVE_FORMAT = "safetensors" 115 | 116 | 117 | # ============================================================================ 118 | # Caching Constants 119 | # ============================================================================ 120 | 121 | # Maximum cache size for decomposition results (number of entries) 122 | MAX_DECOMPOSITION_CACHE_SIZE = 100 123 | 124 | # Cache TTL in seconds (time-to-live) 125 | CACHE_TTL_SECONDS = 3600 # 1 hour 126 | 127 | 128 | # ============================================================================ 129 | # Logging Constants 130 | # ============================================================================ 131 | 132 | # Default logging level 133 | DEFAULT_LOG_LEVEL = "INFO" 134 | 135 | # Log format 136 | LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 137 | 138 | 139 | # ============================================================================ 140 | # Export All Constants 141 | # ============================================================================ 142 | 143 | __all__ = [ 144 | # SVD/Decomposition 145 | 'MIN_SINGULAR_VALUE', 146 | 'DEFAULT_SVD_RANK', 147 | 'DEFAULT_SVD_DISTRIBUTION', 148 | 'DEFAULT_SV_RATIO', 149 | 'DEFAULT_SV_CUMULATIVE', 150 | 'DEFAULT_SV_FRO', 151 | 152 | # Merge Operations 153 | 'MAX_MERGE_WORKERS', 154 | 'DEFAULT_LAMBDA', 155 | 'DEFAULT_NORMALIZE', 156 | 157 | # Validation 158 | 'MIN_LORAS_FOR_MERGE', 159 | 'MIN_KEY_OVERLAP_RATIO', 160 | 'TYPICAL_STRENGTH_MIN', 161 | 'TYPICAL_STRENGTH_MAX', 162 | 163 | # Device and Memory 164 | 'SUPPORTED_DEVICES', 165 | 'SUPPORTED_DTYPES', 166 | 'DEFAULT_DEVICE', 167 | 'DEFAULT_DTYPE', 168 | 169 | # Progress 170 | 'PROGRESS_UPDATE_INTERVAL', 171 | 'PROGRESS_BAR_FORMAT', 172 | 173 | # Layer Filters 174 | 'ATTENTION_LAYERS', 175 | 'MLP_LAYERS', 176 | 'ATTENTION_MLP_LAYERS', 177 | # Legacy (deprecated) 178 | 'SD_ATTENTION_LAYERS', 179 | 'SD_MLP_LAYERS', 180 | 'SD_ATTENTION_MLP_LAYERS', 181 | 'SD_PROJECTION_LAYERS', 182 | 'DIT_ATTENTION_LAYERS', 183 | 'DIT_MLP_LAYERS', 184 | 'WAN_ATTENTION_LAYERS', 185 | 'WAN_MLP_LAYERS', 186 | 187 | # File I/O 188 | 'LORA_FILE_EXTENSIONS', 189 | 'DEFAULT_LORA_SAVE_FORMAT', 190 | 191 | # Caching 192 | 'MAX_DECOMPOSITION_CACHE_SIZE', 193 | 'CACHE_TTL_SECONDS', 194 | 195 | # Logging 196 | 'DEFAULT_LOG_LEVEL', 197 | 'LOG_FORMAT', 198 | ] 199 | -------------------------------------------------------------------------------- /src/device/manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unified device and dtype management. 3 | 4 | Consolidates all device/dtype conversion logic into a single class. 5 | Replaces scattered `map_device()` and `str_to_dtype()` functions. 6 | """ 7 | 8 | import torch 9 | from typing import Tuple, Optional, Union 10 | 11 | from ..types import DeviceType, DtypeType 12 | 13 | 14 | class DeviceManager: 15 | """ 16 | Unified device and dtype management for tensor operations. 
17 | 18 | Provides methods for: 19 | - Converting string representations to torch objects 20 | - Moving tensors between devices 21 | - Checking device availability 22 | - Selecting appropriate devices for operations 23 | """ 24 | 25 | # Supported device strings 26 | SUPPORTED_DEVICES = ["cpu", "cuda", "mps", "auto"] 27 | 28 | # Supported dtype strings 29 | SUPPORTED_DTYPES = [ 30 | "float16", "float32", "float64", 31 | "bfloat16", 32 | "int8", "int16", "int32", "int64", 33 | ] 34 | 35 | @staticmethod 36 | def parse( 37 | device: DeviceType, 38 | dtype: DtypeType 39 | ) -> Tuple[torch.device, torch.dtype]: 40 | """ 41 | Convert device and dtype from string representation to torch objects. 42 | 43 | Backward compatible with the old `map_device()` function. 44 | 45 | Args: 46 | device: Device specification as string or torch.device 47 | Supports: "cpu", "cuda", "cuda:0", "auto", or torch.device object 48 | dtype: Data type specification as string or torch.dtype 49 | Supports: "float16", "float32", "bfloat16", etc., or torch.dtype object 50 | 51 | Returns: 52 | Tuple of (torch.device, torch.dtype) 53 | 54 | Raises: 55 | ValueError: If device or dtype string is invalid 56 | 57 | Examples: 58 | >>> DeviceManager.parse("cuda", "float16") 59 | (device(type='cuda'), torch.float16) 60 | 61 | >>> DeviceManager.parse(torch.device("cpu"), torch.float32) 62 | (device(type='cpu'), torch.float32) 63 | """ 64 | # Convert device 65 | if isinstance(device, str): 66 | if device == "auto": 67 | device = DeviceManager.get_default_device() 68 | else: 69 | try: 70 | device = torch.device(device) 71 | except RuntimeError as e: 72 | raise ValueError(f"Invalid device string: {device}") from e 73 | elif not isinstance(device, torch.device): 74 | raise TypeError(f"Device must be str or torch.device, got {type(device)}") 75 | 76 | # Convert dtype 77 | if isinstance(dtype, str): 78 | if not hasattr(torch, dtype): 79 | raise ValueError( 80 | f"Invalid dtype string: {dtype}. " 81 | f"Supported: {', '.join(DeviceManager.SUPPORTED_DTYPES)}" 82 | ) 83 | dtype = getattr(torch, dtype) 84 | elif not isinstance(dtype, torch.dtype): 85 | raise TypeError(f"Dtype must be str or torch.dtype, got {type(dtype)}") 86 | 87 | return device, dtype 88 | 89 | @staticmethod 90 | def get_default_device() -> torch.device: 91 | """ 92 | Get the default device for operations. 93 | 94 | Checks availability in order: CUDA > MPS > CPU 95 | 96 | Returns: 97 | torch.device for the best available device 98 | """ 99 | if torch.cuda.is_available(): 100 | return torch.device("cuda") 101 | elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): 102 | return torch.device("mps") 103 | else: 104 | return torch.device("cpu") 105 | 106 | @staticmethod 107 | def is_device_available(device: Union[str, torch.device]) -> bool: 108 | """ 109 | Check if a device is available. 
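        Only the device type is inspected, so "cuda:1" is treated like "cuda";
        per-index availability is not checked.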
110 | 111 | Args: 112 | device: Device to check (string or torch.device) 113 | 114 | Returns: 115 | True if device is available, False otherwise 116 | """ 117 | if isinstance(device, str): 118 | device = torch.device(device) 119 | 120 | if device.type == "cpu": 121 | return True 122 | elif device.type == "cuda": 123 | return torch.cuda.is_available() 124 | elif device.type == "mps": 125 | return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() 126 | else: 127 | return False 128 | 129 | @staticmethod 130 | def to_device( 131 | tensor: torch.Tensor, 132 | device: DeviceType, 133 | dtype: Optional[DtypeType] = None, 134 | non_blocking: bool = False 135 | ) -> torch.Tensor: 136 | """ 137 | Move tensor to device with optional dtype conversion. 138 | 139 | Args: 140 | tensor: Tensor to move 141 | device: Target device 142 | dtype: Target dtype (optional, keeps current dtype if None) 143 | non_blocking: Whether to use non-blocking transfer 144 | 145 | Returns: 146 | Tensor on target device 147 | 148 | Example: 149 | >>> tensor = torch.randn(10, 10) 150 | >>> DeviceManager.to_device(tensor, "cuda", "float16") 151 | """ 152 | device_obj, _ = DeviceManager.parse(device, dtype or tensor.dtype) 153 | 154 | if dtype is not None: 155 | _, dtype_obj = DeviceManager.parse(device, dtype) 156 | return tensor.to(device=device_obj, dtype=dtype_obj, non_blocking=non_blocking) 157 | else: 158 | return tensor.to(device=device_obj, non_blocking=non_blocking) 159 | 160 | @staticmethod 161 | def get_device_memory_info(device: Union[str, torch.device]) -> Optional[dict]: 162 | """ 163 | Get memory information for a device. 164 | 165 | Args: 166 | device: Device to query 167 | 168 | Returns: 169 | Dictionary with memory info, or None if not available 170 | Keys: "allocated", "reserved", "total" (all in bytes) 171 | """ 172 | if isinstance(device, str): 173 | device = torch.device(device) 174 | 175 | if device.type == "cuda" and torch.cuda.is_available(): 176 | return { 177 | "allocated": torch.cuda.memory_allocated(device), 178 | "reserved": torch.cuda.memory_reserved(device), 179 | "total": torch.cuda.get_device_properties(device).total_memory, 180 | } 181 | else: 182 | return None 183 | 184 | @staticmethod 185 | def empty_cache(device: Optional[Union[str, torch.device]] = None): 186 | """ 187 | Empty the cache for a device. 
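        Only CUDA caches are cleared; for CPU and MPS devices the call is currently a no-op.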
188 | 189 | Args: 190 | device: Device to clear cache for (defaults to CUDA if available) 191 | """ 192 | if device is None: 193 | if torch.cuda.is_available(): 194 | torch.cuda.empty_cache() 195 | else: 196 | if isinstance(device, str): 197 | device = torch.device(device) 198 | 199 | if device.type == "cuda" and torch.cuda.is_available(): 200 | torch.cuda.empty_cache() 201 | -------------------------------------------------------------------------------- /src/lora_stack_sampler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image, ImageDraw 6 | from torch import Tensor 7 | 8 | from comfy.sd import VAE 9 | from comfy_extras.nodes_custom_sampler import SamplerCustom 10 | from .architectures import LORA_STACK, LORA_WEIGHTS 11 | from .utility import load_font 12 | 13 | 14 | class LoRAStackSampler: 15 | @classmethod 16 | def INPUT_TYPES(cls): 17 | return { 18 | "required": { 19 | "model": ("MODEL",), 20 | "vae": ("VAE",), 21 | "add_noise": ("BOOLEAN", {"default": True}), 22 | "noise_seed": ( 23 | "INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True} 24 | ), 25 | "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 26 | "positive": ("CONDITIONING",), 27 | "negative": ("CONDITIONING",), 28 | "sampler": ("SAMPLER",), 29 | "sigmas": ("SIGMAS",), 30 | "latent_image": ("LATENT",), 31 | "lora_key_dicts": ("LoRAStack", {"tooltip": "The dictionary containing LoRA names and key weights."}), 32 | "lora_strengths": ("LoRAWeights", {"tooltip": "The LoRA weighting to apply."}), 33 | } 34 | } 35 | RETURN_TYPES = ("LATENT", "IMAGE", "IMAGE") 36 | RETURN_NAMES = ("latents", "images", "image_grid") 37 | FUNCTION = "sample" 38 | CATEGORY = "LoRA PowerMerge/sampling" 39 | DESCRIPTION = "Samples images by iterating over the given LoRA key dictionary and applying the LoRA weights." 40 | 41 | def sample(self, model, vae: VAE, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, 42 | latent_image, lora_key_dicts: LORA_STACK = None, lora_strengths: LORA_WEIGHTS = None): 43 | if lora_key_dicts is None or lora_strengths is None: 44 | raise ValueError("key_dicts and lora_weighting must be provided.") 45 | 46 | latents_out = self.do_sample(add_noise, cfg, lora_key_dicts, lora_strengths, latent_image, model, noise_seed, 47 | positive, negative, sampler, sigmas) 48 | 49 | # Create a grid of images with LoRA names and strengths 50 | names = list(lora_key_dicts.keys()) 51 | weights = list(lora_strengths.values()) 52 | grid_single_images, image_grid = self.image_grid(names, weights, [s['samples'] for s in latents_out], vae) 53 | 54 | # Repack the output 55 | latents_out = { 56 | "samples": torch.cat([s['samples'] for s in latents_out], dim=0) 57 | } 58 | return latents_out, grid_single_images, image_grid 59 | 60 | @staticmethod 61 | def image_grid(names, strengths, batch_latents, vae) -> tuple[Tensor, Tensor]: 62 | """ 63 | Create an image grid with batches on Y-axis and LoRA names on X-axis. 64 | Args: 65 | names: List of LoRA names. 66 | strengths: List of strengths corresponding to each LoRA name. 67 | batch_latents: List of latents, where each tensor is a batch of images. 68 | vae: The VAE model used for decoding the latents. 
69 | Returns: 70 | Tuple of (all images tensor, grid image tensor) 71 | """ 72 | grid_images = [] 73 | 74 | for n, w, l in zip(names, strengths, batch_latents): 75 | n = n.split('/')[-1].split('.')[0] 76 | 77 | images = vae.decode(l) 78 | if len(images.shape) == 5: 79 | images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) 80 | 81 | for img in images.squeeze(dim=0): 82 | out_image = LoRAStackSampler.annotate_image(n, w, img) 83 | grid_images.append(out_image) 84 | 85 | num_loras = len(names) 86 | num_batches = len(grid_images) // num_loras 87 | 88 | if not grid_images: 89 | return torch.tensor([]), torch.tensor([]) 90 | 91 | first_image = grid_images[0] 92 | 93 | img_height = first_image.shape[0] 94 | img_width = first_image.shape[1] 95 | num_img_chans = first_image.shape[2] 96 | 97 | img_grid = torch.zeros(1, num_batches * img_height, num_loras * img_width, num_img_chans) 98 | if len(grid_images) > 0: 99 | for i in range(num_loras): 100 | for j in range(num_batches): 101 | idx = i * num_batches + j 102 | img_grid[:, j * img_height:(j + 1) * img_height, 103 | i * img_width:(i + 1) * img_width] = grid_images[idx] 104 | return torch.stack(grid_images), img_grid 105 | 106 | @staticmethod 107 | def annotate_image(name, weighting, img_tensor): 108 | title_font = load_font() 109 | 110 | i = 255. * img_tensor.cpu().numpy() 111 | img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) 112 | title = f"{name}\nStrength: {weighting['strength_model']:.2f}" 113 | max_text_width = img.width - 100 # small margin 114 | # Wrap the text 115 | lines = [] 116 | for word in title.split(): 117 | if not lines: 118 | lines.append(word) 119 | else: 120 | test_line = lines[-1] + ' ' + word 121 | test_width = title_font.getbbox(test_line)[2] 122 | if test_width <= max_text_width: 123 | lines[-1] = test_line 124 | else: 125 | lines.append(word) 126 | title_padding = 6 127 | line_height = title_font.getbbox("A")[3] + title_padding 128 | title_text_height = line_height * len(lines) + title_padding 129 | title_text_image = Image.new('RGB', (img.width, title_text_height), color=(0, 0, 0)) 130 | draw = ImageDraw.Draw(title_text_image) 131 | for i, line in enumerate(lines): 132 | line_width = title_font.getbbox(line)[2] 133 | draw.text( 134 | ((img.width - line_width) // 2, i * line_height + title_padding // 2), 135 | line, 136 | font=title_font, 137 | fill=(255, 255, 255) 138 | ) 139 | title_text_image_tensor = torch.tensor(np.array(title_text_image).astype(np.float32) / 255.0) 140 | out_image = torch.cat([title_text_image_tensor, img_tensor], 0) 141 | return out_image 142 | 143 | @staticmethod 144 | def do_sample(add_noise, cfg, key_dicts, lora_strengths, latent_image, model, noise_seed, positive, 145 | negative, sampler, sigmas): 146 | latents_out = [] 147 | 148 | kSampler = SamplerCustom() 149 | for lora_name, patch_dict in key_dicts.items(): 150 | strengths = lora_strengths[lora_name] 151 | logging.info(f"PM LoRAStackSampler: Applying LoRA {lora_name} with weights {strengths}") 152 | 153 | new_model_patcher = model.clone() 154 | new_model_patcher.add_patches(patch_dict, strengths['strength_model']) 155 | 156 | denoised, _ = kSampler.sample( 157 | model=new_model_patcher, 158 | add_noise=add_noise, 159 | noise_seed=noise_seed, 160 | cfg=cfg, 161 | positive=positive, 162 | negative=negative, 163 | sampler=sampler, 164 | sigmas=sigmas, 165 | latent_image=latent_image, 166 | ) 167 | latents_out.append(denoised) 168 | return latents_out 169 | 
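# Minimal usage sketch outside the graph UI (a rough illustration, not part of the
# node API): all upstream objects (model, vae, conditioning, sampler, sigmas, latent,
# plus the stack/weights produced by the stacker nodes) are assumed to come from the
# usual ComfyUI nodes, and the variable names below are illustrative only.
#
#   node = LoRAStackSampler()
#   latents, images, grid = node.sample(
#       model, vae, add_noise=True, noise_seed=0, cfg=8.0,
#       positive=positive, negative=negative, sampler=sampler, sigmas=sigmas,
#       latent_image=latent, lora_key_dicts=stack, lora_strengths=weights,
#   )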
-------------------------------------------------------------------------------- /src/merge/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for merge operations. 3 | 4 | Contains helper functions used throughout the merging pipeline: 5 | - WeightInfo mapping creation 6 | - Tensor parameter creation 7 | - Layer filtering logic 8 | """ 9 | 10 | import logging 11 | from typing import Dict, Optional, Any 12 | 13 | import torch 14 | from mergekit.architecture import WeightInfo 15 | from mergekit.common import ModelReference, ImmutableMap 16 | 17 | from ..types import ( 18 | LORA_KEY_DICT, 19 | LayerFilterType, 20 | LayerComponentSet, 21 | ) 22 | 23 | 24 | def create_map( 25 | key: str, 26 | tensors: Dict[ModelReference, torch.Tensor], 27 | dtype: torch.dtype 28 | ) -> ImmutableMap[ModelReference, WeightInfo]: 29 | """ 30 | Create an ImmutableMap of WeightInfo objects for mergekit operations. 31 | 32 | Args: 33 | key: Layer key identifier 34 | tensors: Dictionary mapping model references to tensors 35 | dtype: Data type for the weight info 36 | 37 | Returns: 38 | ImmutableMap of WeightInfo objects for each model reference 39 | """ 40 | return ImmutableMap({ 41 | r: WeightInfo(name=f'model{i}.{key}', dtype=dtype) 42 | for i, r in enumerate(tensors.keys()) 43 | }) 44 | 45 | 46 | def create_tensor_param(tensor_weight: float, method_args: Dict[str, Any]) -> Dict[str, Any]: 47 | """ 48 | Create tensor parameter dictionary for merge operations. 49 | 50 | Combines weight value with method-specific arguments into a single 51 | parameter dictionary. 52 | 53 | Args: 54 | tensor_weight: Weight/strength value for the tensor 55 | method_args: Additional method-specific arguments (density, epsilon, etc.) 56 | 57 | Returns: 58 | Dictionary with weight and method args merged 59 | 60 | Example: 61 | >>> create_tensor_param(0.5, {"density": 0.8}) 62 | {"weight": 0.5, "density": 0.8} 63 | """ 64 | out = {"weight": tensor_weight} 65 | out.update(method_args) 66 | return out 67 | 68 | 69 | def parse_layer_filter(layer_filter: LayerFilterType) -> Optional[LayerComponentSet]: 70 | """ 71 | Parse layer filter string into set of component names. 72 | 73 | Converts high-level filter specification into specific component name sets. 74 | Architecture-agnostic: works for both Stable Diffusion and DiT LoRAs. 75 | 76 | Args: 77 | layer_filter: Filter type specification 78 | - "full": No filtering (returns None) 79 | - "attn-only": Only attention layers (SD: attn1/attn2, DiT: attention) 80 | - "mlp-only": Only MLP/feedforward layers (SD: ff, DiT: mlp/feed_forward) 81 | - "attn-mlp": Both attention and MLP layers 82 | 83 | Returns: 84 | Set of layer component names to keep, or None for no filtering 85 | 86 | Example: 87 | >>> parse_layer_filter("attn-mlp") 88 | {"attn1", "attn2", "attention", "ff", "mlp", "feed_forward"} 89 | 90 | Note: 91 | This function delegates to LayerFilter.PRESETS for consistency. 92 | Direct use of the LayerFilter class is recommended for new code. 93 | """ 94 | # Import here to avoid circular imports 95 | from ..utils.layer_filter import LayerFilter 96 | 97 | # Delegate to LayerFilter.PRESETS for single source of truth 98 | return LayerFilter.PRESETS.get(layer_filter, None) 99 | 100 | 101 | def apply_layer_filter( 102 | patch_dict: LORA_KEY_DICT, 103 | layer_filter: Optional[LayerComponentSet], 104 | detect_architecture: bool = True 105 | ) -> LORA_KEY_DICT: 106 | """ 107 | Apply layer component filter to patch dictionary. 
108 | 109 | Filters LoRA patches to include only specified component types. 110 | Used for selective merging (e.g., merge only attention layers). 111 | 112 | Args: 113 | patch_dict: Dictionary of layer key -> LoRAAdapter 114 | layer_filter: Set of component names to keep, or None for no filtering 115 | detect_architecture: Whether to detect and log architecture (default: True) 116 | 117 | Returns: 118 | Filtered patch dictionary containing only matching layers 119 | 120 | Note: 121 | Uses component-based matching (split by '.') to avoid false positives 122 | from substring matches (e.g., 'ff' in 'diffusion_model'). 123 | 124 | Example: 125 | >>> patches = {"model.attn1.weight": adapter1, "model.ff.weight": adapter2} 126 | >>> filtered = apply_layer_filter(patches, {"attn1", "attn2"}) 127 | >>> # Returns only {"model.attn1.weight": adapter1} 128 | """ 129 | from ..utils.layer_filter import detect_lora_architecture 130 | 131 | num_keys = len(patch_dict.keys()) 132 | 133 | # Detect architecture before filtering 134 | if detect_architecture and patch_dict: 135 | arch_name, arch_meta = detect_lora_architecture(patch_dict) 136 | total_keys = arch_meta.get("total_keys", num_keys) 137 | 138 | if arch_name != "Unknown": 139 | logging.info(f"Detected {arch_name} architecture ({total_keys} keys)") 140 | else: 141 | logging.debug(f"Processing LoRA ({total_keys} keys, architecture: {arch_name})") 142 | 143 | if layer_filter: 144 | import re 145 | 146 | def matches_filter(key) -> bool: 147 | """ 148 | Check if key matches any filter component. 149 | 150 | Uses word-boundary aware matching to avoid false positives. 151 | For example, 'ff' will match 'ff_net' or '.ff.' but not 'diffusion'. 152 | """ 153 | # Handle tuple keys from ComfyUI 154 | key_str = str(key[0]) if isinstance(key, tuple) else str(key) 155 | key_lower = key_str.lower() 156 | 157 | for filter_pattern in layer_filter: 158 | pattern_lower = filter_pattern.lower() 159 | 160 | # Create regex pattern with word boundaries 161 | # Match pattern when surrounded by dots, underscores, or at start/end 162 | regex_pattern = r'(?:^|[._])' + re.escape(pattern_lower) + r'(?:[._]|$)' 163 | 164 | if re.search(regex_pattern, key_lower): 165 | return True 166 | 167 | return False 168 | 169 | patch_dict = { 170 | k0: v0 171 | for k0, v0 in patch_dict.items() 172 | if matches_filter(k0) 173 | } 174 | 175 | logging.info( 176 | f"Stacking {len(patch_dict)} keys with {num_keys - len(patch_dict)} " 177 | f"filtered out by filter method {layer_filter}." 178 | ) 179 | 180 | return patch_dict 181 | 182 | 183 | def apply_weights_to_tensors( 184 | tensors: Dict[str, torch.Tensor], 185 | tensor_parameters: Dict[str, Dict[str, Any]] 186 | ) -> Dict[str, torch.Tensor]: 187 | """ 188 | Apply strength weights to tensors. 189 | 190 | This is a common pattern across multiple merge algorithms (slerp, karcher, 191 | nearswap, arcee_fusion). Extracted to eliminate code duplication. 
192 | 193 | Args: 194 | tensors: Dictionary mapping LoRA names to their tensors 195 | tensor_parameters: Dictionary mapping LoRA names to their parameters 196 | (must contain "weight" key) 197 | 198 | Returns: 199 | Dictionary mapping LoRA names to weighted tensors 200 | 201 | Example: 202 | >>> tensors = {"lora1": torch.ones(10, 10), "lora2": torch.ones(10, 10)} 203 | >>> params = {"lora1": {"weight": 0.5}, "lora2": {"weight": 0.8}} 204 | >>> weighted = apply_weights_to_tensors(tensors, params) 205 | >>> # lora1 scaled by 0.5, lora2 scaled by 0.8 206 | """ 207 | return { 208 | ref: tensor_parameters[ref]["weight"] * tensors[ref] 209 | for ref in tensors.keys() 210 | } 211 | -------------------------------------------------------------------------------- /tests/test_algorithms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for merge algorithms. 3 | 4 | Tests individual merge algorithm functions from src/merge/algorithms.py. 5 | Mock mergekit dependencies to test algorithm logic in isolation. 6 | """ 7 | 8 | import pytest 9 | import torch 10 | import sys 11 | from pathlib import Path 12 | from unittest.mock import Mock, MagicMock, patch 13 | 14 | # Add src to path for imports 15 | sys.path.insert(0, str(Path(__file__).parent.parent / "src")) 16 | 17 | from merge.algorithms import ( 18 | linear_merge, 19 | get_merge_algorithm, 20 | MERGE_ALGORITHMS, 21 | ) 22 | from merge.utils import apply_weights_to_tensors 23 | 24 | 25 | class TestApplyWeightsToTensors: 26 | """Tests for the apply_weights_to_tensors utility function.""" 27 | 28 | def test_basic_weighting(self): 29 | """Test that weights are applied correctly to tensors.""" 30 | tensors = { 31 | "lora1": torch.ones(10, 10), 32 | "lora2": torch.ones(10, 10) * 2, 33 | } 34 | tensor_parameters = { 35 | "lora1": {"weight": 0.5}, 36 | "lora2": {"weight": 0.8}, 37 | } 38 | 39 | result = apply_weights_to_tensors(tensors, tensor_parameters) 40 | 41 | assert "lora1" in result 42 | assert "lora2" in result 43 | assert torch.allclose(result["lora1"], torch.ones(10, 10) * 0.5) 44 | assert torch.allclose(result["lora2"], torch.ones(10, 10) * 1.6) 45 | 46 | def test_zero_weight(self): 47 | """Test that zero weight produces zero tensor.""" 48 | tensors = {"lora1": torch.ones(5, 5)} 49 | tensor_parameters = {"lora1": {"weight": 0.0}} 50 | 51 | result = apply_weights_to_tensors(tensors, tensor_parameters) 52 | 53 | assert torch.allclose(result["lora1"], torch.zeros(5, 5)) 54 | 55 | def test_preserves_tensor_shape(self): 56 | """Test that tensor shapes are preserved.""" 57 | tensors = { 58 | "lora1": torch.randn(10, 20), 59 | "lora2": torch.randn(5, 15, 3), 60 | } 61 | tensor_parameters = { 62 | "lora1": {"weight": 0.7}, 63 | "lora2": {"weight": 0.3}, 64 | } 65 | 66 | result = apply_weights_to_tensors(tensors, tensor_parameters) 67 | 68 | assert result["lora1"].shape == (10, 20) 69 | assert result["lora2"].shape == (5, 15, 3) 70 | 71 | 72 | class TestAlgorithmRegistry: 73 | """Tests for the algorithm registry and dispatcher.""" 74 | 75 | def test_all_algorithms_registered(self): 76 | """Test that all expected algorithms are in the registry.""" 77 | expected_algorithms = [ 78 | "linear", 79 | "generalized_task_arithmetic", 80 | "sce", 81 | "karcher", 82 | "slerp", 83 | "nuslerp", 84 | "nearswap", 85 | "arcee_fusion", 86 | ] 87 | 88 | for alg in expected_algorithms: 89 | assert alg in MERGE_ALGORITHMS, f"{alg} not in registry" 90 | 91 | def test_get_merge_algorithm_valid(self): 92 | """Test getting 
valid algorithm from registry.""" 93 | alg = get_merge_algorithm("linear") 94 | assert callable(alg) 95 | assert alg == MERGE_ALGORITHMS["linear"] 96 | 97 | def test_get_merge_algorithm_invalid(self): 98 | """Test that invalid algorithm name raises error.""" 99 | with pytest.raises(ValueError, match="Unknown merge algorithm"): 100 | get_merge_algorithm("nonexistent_algorithm") 101 | 102 | def test_algorithm_signature(self): 103 | """Test that all algorithms have the expected signature.""" 104 | # All merge algorithms should accept these parameters 105 | expected_params = ["tensors", "gather_tensors", "weight_info", "tensor_parameters", "method_args"] 106 | 107 | for name, func in MERGE_ALGORITHMS.items(): 108 | # Check function has the right parameter names 109 | import inspect 110 | sig = inspect.signature(func) 111 | param_names = list(sig.parameters.keys()) 112 | 113 | for expected in expected_params: 114 | assert expected in param_names, f"{name} missing parameter {expected}" 115 | 116 | 117 | class TestLinearMerge: 118 | """Tests for linear merge algorithm.""" 119 | 120 | @patch('merge.algorithms.LinearMergeTask') 121 | def test_linear_merge_calls_task(self, mock_task_class): 122 | """Test that linear merge creates and executes LinearMergeTask.""" 123 | # Setup mocks 124 | mock_task = Mock() 125 | mock_task.execute.return_value = torch.ones(10, 10) 126 | mock_task_class.return_value = mock_task 127 | 128 | mock_tensors = {"lora1": torch.randn(10, 10)} 129 | mock_gather = Mock() 130 | mock_weight_info = Mock() 131 | mock_params = Mock() 132 | method_args = {"normalize": True} 133 | 134 | # Execute 135 | result = linear_merge( 136 | tensors=mock_tensors, 137 | gather_tensors=mock_gather, 138 | weight_info=mock_weight_info, 139 | tensor_parameters=mock_params, 140 | method_args=method_args 141 | ) 142 | 143 | # Verify task was created with correct args 144 | mock_task_class.assert_called_once_with( 145 | gather_tensors=mock_gather, 146 | tensor_parameters=mock_params, 147 | normalize=True, 148 | weight_info=mock_weight_info, 149 | ) 150 | 151 | # Verify task was executed 152 | mock_task.execute.assert_called_once_with(tensors=mock_tensors) 153 | 154 | # Verify result 155 | assert torch.allclose(result, torch.ones(10, 10)) 156 | 157 | @patch('merge.algorithms.LinearMergeTask') 158 | def test_linear_merge_default_normalize(self, mock_task_class): 159 | """Test that normalize defaults to False if not in method_args.""" 160 | mock_task = Mock() 161 | mock_task.execute.return_value = torch.zeros(5, 5) 162 | mock_task_class.return_value = mock_task 163 | 164 | # No normalize in method_args 165 | linear_merge( 166 | tensors={}, 167 | gather_tensors=Mock(), 168 | weight_info=Mock(), 169 | tensor_parameters=Mock(), 170 | method_args={} 171 | ) 172 | 173 | # Should use default False 174 | call_kwargs = mock_task_class.call_args.kwargs 175 | assert call_kwargs["normalize"] == False 176 | 177 | 178 | # Fixtures for common test data 179 | 180 | @pytest.fixture 181 | def sample_tensors(): 182 | """Fixture providing sample tensors for testing.""" 183 | return { 184 | "lora_1": torch.randn(10, 5), 185 | "lora_2": torch.randn(10, 5), 186 | } 187 | 188 | 189 | @pytest.fixture 190 | def sample_tensor_parameters(): 191 | """Fixture providing sample tensor parameters.""" 192 | return { 193 | "lora_1": {"weight": 0.6}, 194 | "lora_2": {"weight": 0.4}, 195 | } 196 | 197 | 198 | @pytest.fixture 199 | def mock_mergekit_objects(): 200 | """Fixture providing mocked mergekit objects.""" 201 | return { 202 | 
"gather_tensors": Mock(), 203 | "weight_info": Mock(name="test.layer"), 204 | "method_args": {}, 205 | } 206 | 207 | 208 | class TestIntegrationWithFixtures: 209 | """Integration tests using fixtures.""" 210 | 211 | def test_apply_weights_with_sample_data(self, sample_tensors, sample_tensor_parameters): 212 | """Test apply_weights with realistic sample data.""" 213 | result = apply_weights_to_tensors(sample_tensors, sample_tensor_parameters) 214 | 215 | # Check all tensors are weighted 216 | assert len(result) == 2 217 | assert result["lora_1"].shape == (10, 5) 218 | assert result["lora_2"].shape == (10, 5) 219 | 220 | # Check weights were applied (result should be scaled versions) 221 | assert not torch.allclose(result["lora_1"], sample_tensors["lora_1"]) 222 | assert not torch.allclose(result["lora_2"], sample_tensors["lora_2"]) 223 | 224 | 225 | if __name__ == "__main__": 226 | pytest.main([__file__, "-v"]) 227 | -------------------------------------------------------------------------------- /src/lora_decompose.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import logging 3 | import time 4 | from typing import Optional, Set 5 | 6 | import torch 7 | import comfy 8 | 9 | from .architectures.sd_lora import analyse_keys, calc_up_down_alphas 10 | from .types import LORA_TENSORS_BY_LAYER, LORA_STACK, LORA_TENSOR_DICT 11 | from .utility import map_device, adjust_tensor_dims 12 | 13 | 14 | class LoraDecompose: 15 | """ 16 | Node for decomposing LoRA models into their components 17 | """ 18 | 19 | def __init__(self): 20 | self.last_lora_names_hash: Optional[list] = None 21 | self.last_tensor_sum: float = 0.0 22 | self.last_svd_rank: int = -1 23 | self.last_decomposition_method: str = "" 24 | self.last_layer_filter: Optional[Set[str]] = None 25 | self.last_result: LORA_TENSORS_BY_LAYER = {} 26 | 27 | @classmethod 28 | def INPUT_TYPES(s): 29 | return { 30 | "required": { 31 | "key_dicts": ("LoRAStack",), 32 | "decomposition_method": ( 33 | ["none", "rSVD", "energy_rSVD", "SVD"], 34 | { 35 | "default": "rSVD", 36 | "tooltip": ( 37 | "Method used to reconcile LoRA ranks when they differ. " 38 | "'none' will raise an error if ranks do not match. " 39 | "'SVD' uses full singular value decomposition (slow but optimal). " 40 | "'rSVD' uses randomized SVD (much faster, near-optimal). " 41 | "'energy_rSVD' first prunes low-energy LoRA components and then " 42 | "applies randomized SVD for fast, stable rank reduction " 43 | "(recommended for DiT and large LoRAs)." 44 | ), 45 | } 46 | ), 47 | "svd_rank": ( 48 | "INT", 49 | { 50 | "default": -1, 51 | "min": -1, 52 | "max": 128, 53 | "tooltip": ( 54 | "Target LoRA rank after decomposition. " 55 | "-1 keeps the rank of the first LoRA. " 56 | "Lower values reduce model size and strength." 57 | ), 58 | } 59 | ), 60 | "device": (["cuda", "cpu"], 61 | {"tooltip": "Decomposition device. Note: All decomposition uses float32 internally for numerical stability, then converts back to the original dtype."} 62 | ), 63 | }, 64 | } 65 | 66 | RETURN_TYPES = ("LoRATensors",) 67 | FUNCTION = "lora_decompose" 68 | CATEGORY = "LoRA PowerMerge" 69 | DESCRIPTION = """Decomposes LoRA stack into tensor components for merging. 70 | 71 | Extracts (up, down, alpha) tuples from each LoRA layer and handles rank mismatches using SVD-based decomposition methods. 
72 | 73 | Decomposition Methods: 74 | - none: Requires all LoRAs to have matching ranks (fastest, fails if ranks differ) 75 | - rSVD: Randomized SVD for rank reconciliation (fast, recommended for most cases) 76 | - energy_rSVD: Energy-based randomized SVD (best for DiT/large LoRAs) 77 | - SVD: Full SVD decomposition (slow but optimal) 78 | 79 | Features hash-based caching to skip recomputation when inputs haven't changed.""" 80 | 81 | def lora_decompose(self, key_dicts: LORA_STACK = None, 82 | decomposition_method="rSVD", svd_rank=-1, device=None): 83 | device, _ = map_device(device, "float32") 84 | 85 | logging.info(f"Decomposing LoRAs with method: {decomposition_method}, SVD rank: {svd_rank}") 86 | 87 | # check if key_dicts differs from the previous one 88 | lora_names_hash_new = self.compute_hash(list(key_dicts.keys())) 89 | if (self.last_lora_names_hash == lora_names_hash_new 90 | and self.last_svd_rank == svd_rank 91 | and self.last_decomposition_method == decomposition_method 92 | and self.last_tensor_sum == self.compute_sum(key_dicts)): 93 | logging.info("Key dicts have not changed, returning last result.") 94 | if self.last_result is not None: 95 | return (self.last_result,) 96 | else: 97 | logging.warning("No last result available, recomputing.") 98 | else: 99 | logging.info("Key dicts have changed, recomputing.") 100 | 101 | self.last_lora_names_hash = lora_names_hash_new 102 | self.last_tensor_sum = self.compute_sum(key_dicts) 103 | self.last_svd_rank = svd_rank 104 | self.last_decomposition_method = decomposition_method 105 | 106 | self.last_result = self.decompose(key_dicts=key_dicts, device=device, 107 | decomposition_method=decomposition_method, 108 | svd_rank=svd_rank) 109 | return (self.last_result,) 110 | 111 | @staticmethod 112 | def compute_hash(value): 113 | """Computes a hash of the value for change detection.""" 114 | return hashlib.md5(str(value).encode()).hexdigest() 115 | 116 | @staticmethod 117 | def compute_sum(lora_key_dicts: LORA_STACK): 118 | """Computes the sum of all up, down, and alpha tensors in the LoRA key dicts.""" 119 | sum_ = 0 120 | for lora_name, lora_key_dict in lora_key_dicts.items(): 121 | for key in lora_key_dict.keys(): 122 | lora_adapter = lora_key_dict[key] 123 | up, down, _, _, _, _ = lora_adapter.weights 124 | sum_ += up.sum().item() + down.sum().item() 125 | return sum_ 126 | 127 | def decompose(self, key_dicts, device, decomposition_method, svd_rank) -> LORA_TENSORS_BY_LAYER: 128 | """ 129 | Decomposes LoRA models into their components. 130 | Args: 131 | key_dicts: Dictionary of LoRA names and their respective keys. 132 | device: Device to load tensors on. 133 | decomposition_method: Method to use for dimension alignment ("none", "svd", "rSVD", or "energy_rSVD"). 134 | svd_rank: Target rank for decomposition. 135 | Returns: 136 | Dictionary of LoRA components. 
137 |             lora_key -> lora_name -> (up, down, alpha)
138 |         """
139 |         keys = list(analyse_keys(key_dicts))  # [:10] # Limit to 10 keys for testing
140 | 
141 |         pbar = comfy.utils.ProgressBar(len(keys))
142 |         start = time.time()
143 | 
144 |         def process_key(key, device_=device) -> LORA_TENSOR_DICT:
145 |             uda = calc_up_down_alphas(key_dicts, key, load_device=device_, scale_to_alpha_0=True)
146 | 
147 |             # Determine if SVD should be applied
148 |             if decomposition_method == "none":
149 |                 # Check if all LoRAs have the same rank
150 |                 ranks = [up.shape[1] for up, _, _ in uda.values()]
151 |                 if len(set(ranks)) > 1:
152 |                     rank_info = {lora_name: up.shape[1] for lora_name, (up, _, _) in uda.items()}
153 |                     raise ValueError(
154 |                         f"LoRAs have different ranks for key '{key}': {rank_info}. "
155 |                         f"Please select a decomposition method (SVD, rSVD, or energy_rSVD) to align dimensions."
156 |                     )
157 |                 # No adjustment needed
158 |                 return uda
159 |             else:
160 |                 # Apply the selected decomposition method
161 |                 uda_adjusted = adjust_tensor_dims(
162 |                     uda,
163 |                     apply_svd=True,
164 |                     svd_rank=svd_rank,
165 |                     method=decomposition_method
166 |                 )
167 |                 return uda_adjusted
168 | 
169 |         out = {}
170 |         for i, key in enumerate(keys):
171 |             out[key] = process_key(key)
172 |             pbar.update(1)
173 | 
174 |         logging.info(f"Processed {len(keys)} keys in {time.time() - start:.2f} seconds")
175 | 
176 |         torch.cuda.empty_cache()
177 | 
178 |         return out
179 | 
--------------------------------------------------------------------------------
/tests/test_types.py:
--------------------------------------------------------------------------------
1 | """
2 | Unit tests for type system and validators.
3 | 
4 | Tests type guards, validators, and type definitions from src/types.py.
5 | """
6 | 
7 | import pytest
8 | import torch
9 | import sys
10 | from pathlib import Path
11 | 
12 | # Add repo root to path; importing via the `src` package avoids shadowing the stdlib `types` module
13 | sys.path.insert(0, str(Path(__file__).parent.parent))
14 | 
15 | from src.types import (
16 |     is_lora_tensors,
17 |     is_lora_stack,
18 |     validate_lora_tensors,
19 |     validate_lora_stack,
20 |     LORA_TENSORS,
21 |     LORA_STACK,
22 |     MIN_SINGULAR_VALUE,
23 | )
24 | 
25 | 
26 | class TestLoRATensorsTypeGuard:
27 |     """Tests for is_lora_tensors type guard."""
28 | 
29 |     def test_valid_lora_tensors(self):
30 |         """Test that valid LORA_TENSORS tuple is recognized."""
31 |         up = torch.randn(10, 5)
32 |         down = torch.randn(5, 10)
33 |         alpha = torch.tensor(1.0)
34 |         tensors = (up, down, alpha)
35 | 
36 |         assert is_lora_tensors(tensors)
37 | 
38 |     def test_valid_lora_tensors_with_float_alpha(self):
39 |         """Test that LORA_TENSORS with float alpha is recognized."""
40 |         up = torch.randn(10, 5)
41 |         down = torch.randn(5, 10)
42 |         alpha = 1.0  # Float instead of tensor
43 |         tensors = (up, down, alpha)
44 | 
45 |         assert is_lora_tensors(tensors)
46 | 
47 |     def test_invalid_not_tuple(self):
48 |         """Test that non-tuple is rejected."""
49 |         assert not is_lora_tensors([torch.randn(10, 5), torch.randn(5, 10), 1.0])
50 | 
51 |     def test_invalid_wrong_length(self):
52 |         """Test that tuple with wrong length is rejected."""
53 |         assert not is_lora_tensors((torch.randn(10, 5), torch.randn(5, 10)))
54 | 
55 |     def test_invalid_non_tensor_elements(self):
56 |         """Test that tuple with non-tensor up/down is rejected."""
57 |         assert not is_lora_tensors(("not a tensor", torch.randn(5, 10), 1.0))
58 |         assert not is_lora_tensors((torch.randn(10, 5), "not a tensor", 1.0))
59 | 
60 |     def test_invalid_alpha_type(self):
61 |         """Test that invalid alpha type is rejected."""
62 |         up = torch.randn(10,
5) 63 | down = torch.randn(5, 10) 64 | assert not is_lora_tensors((up, down, "invalid")) 65 | 66 | 67 | class TestLoRAStackTypeGuard: 68 | """Tests for is_lora_stack type guard.""" 69 | 70 | def test_valid_lora_stack(self): 71 | """Test that valid LORA_STACK is recognized.""" 72 | from comfy.weight_adapter import LoRAAdapter 73 | 74 | # Create mock LoRAAdapter objects 75 | adapter1 = LoRAAdapter("lora", (torch.randn(10, 5), torch.randn(5, 10), 1.0, None, None, None)) 76 | adapter2 = LoRAAdapter("lora", (torch.randn(10, 5), torch.randn(5, 10), 1.0, None, None, None)) 77 | 78 | stack = { 79 | "lora1": {"layer1": adapter1}, 80 | "lora2": {"layer2": adapter2}, 81 | } 82 | 83 | assert is_lora_stack(stack) 84 | 85 | def test_invalid_not_dict(self): 86 | """Test that non-dict is rejected.""" 87 | assert not is_lora_stack([]) 88 | assert not is_lora_stack("not a dict") 89 | 90 | def test_invalid_non_string_keys(self): 91 | """Test that dict with non-string keys is rejected.""" 92 | assert not is_lora_stack({1: {}}) 93 | 94 | def test_invalid_non_dict_values(self): 95 | """Test that dict with non-dict values is rejected.""" 96 | assert not is_lora_stack({"lora1": "not a dict"}) 97 | 98 | 99 | class TestValidateLoRATensors: 100 | """Tests for validate_lora_tensors validator.""" 101 | 102 | def test_valid_2d_tensors(self): 103 | """Test validation of valid 2D LoRA tensors.""" 104 | up = torch.randn(10, 5) 105 | down = torch.randn(5, 10) 106 | alpha = torch.tensor(1.0) 107 | tensors = (up, down, alpha) 108 | 109 | # Should not raise 110 | validate_lora_tensors(tensors) 111 | 112 | def test_valid_4d_conv_tensors(self): 113 | """Test validation of valid 4D convolutional LoRA tensors.""" 114 | up = torch.randn(10, 5, 1, 1) 115 | down = torch.randn(5, 10, 1, 1) 116 | alpha = torch.tensor(1.0) 117 | tensors = (up, down, alpha) 118 | 119 | # Should not raise 120 | validate_lora_tensors(tensors) 121 | 122 | def test_invalid_structure(self): 123 | """Test that invalid structure raises ValueError.""" 124 | with pytest.raises(ValueError, match="Invalid LORA_TENSORS structure"): 125 | validate_lora_tensors("not a tuple") 126 | 127 | def test_invalid_up_dimensions(self): 128 | """Test that invalid up tensor dimensions raise ValueError.""" 129 | up = torch.randn(10) # 1D tensor 130 | down = torch.randn(5, 10) 131 | alpha = torch.tensor(1.0) 132 | tensors = (up, down, alpha) 133 | 134 | with pytest.raises(ValueError, match="Invalid up tensor dimensions"): 135 | validate_lora_tensors(tensors) 136 | 137 | def test_invalid_down_dimensions(self): 138 | """Test that invalid down tensor dimensions raise ValueError.""" 139 | up = torch.randn(10, 5) 140 | down = torch.randn(5) # 1D tensor 141 | alpha = torch.tensor(1.0) 142 | tensors = (up, down, alpha) 143 | 144 | with pytest.raises(ValueError, match="Invalid down tensor dimensions"): 145 | validate_lora_tensors(tensors) 146 | 147 | 148 | class TestValidateLoRAStack: 149 | """Tests for validate_lora_stack validator.""" 150 | 151 | def test_valid_stack(self): 152 | """Test validation of valid LoRA stack.""" 153 | from comfy.weight_adapter import LoRAAdapter 154 | 155 | adapter = LoRAAdapter("lora", (torch.randn(10, 5), torch.randn(5, 10), 1.0, None, None, None)) 156 | stack = {"lora1": {"layer1": adapter}} 157 | 158 | # Should not raise 159 | validate_lora_stack(stack) 160 | 161 | def test_invalid_structure(self): 162 | """Test that invalid structure raises ValueError.""" 163 | with pytest.raises(ValueError, match="Invalid LORA_STACK structure"): 164 | 
validate_lora_stack("not a dict") 165 | 166 | def test_empty_stack(self): 167 | """Test that empty stack raises ValueError.""" 168 | with pytest.raises(ValueError, match="LORA_STACK cannot be empty"): 169 | validate_lora_stack({}) 170 | 171 | 172 | class TestConstants: 173 | """Tests for constants defined in types module.""" 174 | 175 | def test_min_singular_value(self): 176 | """Test that MIN_SINGULAR_VALUE is defined correctly.""" 177 | assert MIN_SINGULAR_VALUE == 1e-6 178 | assert isinstance(MIN_SINGULAR_VALUE, float) 179 | 180 | 181 | # Fixtures for reusable test data 182 | 183 | @pytest.fixture 184 | def sample_lora_tensors(): 185 | """Fixture providing sample LORA_TENSORS.""" 186 | up = torch.randn(10, 5) 187 | down = torch.randn(5, 10) 188 | alpha = torch.tensor(1.0) 189 | return (up, down, alpha) 190 | 191 | 192 | @pytest.fixture 193 | def sample_lora_stack(): 194 | """Fixture providing sample LORA_STACK.""" 195 | from comfy.weight_adapter import LoRAAdapter 196 | 197 | adapter1 = LoRAAdapter("lora", (torch.randn(10, 5), torch.randn(5, 10), 1.0, None, None, None)) 198 | adapter2 = LoRAAdapter("lora", (torch.randn(8, 4), torch.randn(4, 8), 1.0, None, None, None)) 199 | 200 | return { 201 | "lora_1": { 202 | "layer.0.attn1": adapter1, 203 | "layer.0.attn2": adapter1, 204 | }, 205 | "lora_2": { 206 | "layer.0.attn1": adapter2, 207 | "layer.1.ff": adapter2, 208 | }, 209 | } 210 | 211 | 212 | class TestIntegrationWithFixtures: 213 | """Integration tests using fixtures.""" 214 | 215 | def test_sample_tensors_are_valid(self, sample_lora_tensors): 216 | """Test that sample tensors pass validation.""" 217 | assert is_lora_tensors(sample_lora_tensors) 218 | validate_lora_tensors(sample_lora_tensors) 219 | 220 | def test_sample_stack_is_valid(self, sample_lora_stack): 221 | """Test that sample stack passes validation.""" 222 | assert is_lora_stack(sample_lora_stack) 223 | validate_lora_stack(sample_lora_stack) 224 | 225 | 226 | if __name__ == "__main__": 227 | pytest.main([__file__, "-v"]) 228 | -------------------------------------------------------------------------------- /src/experimental/lora_attention_logger.py: -------------------------------------------------------------------------------- 1 | import io 2 | from collections import defaultdict 3 | from typing import List 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from PIL import Image 9 | from matplotlib import pyplot as plt 10 | 11 | import comfy 12 | from ..architectures.sd_lora import detect_block_names 13 | 14 | # Create a global variable to store the norm values for each layer 15 | # Format: { layer_name: [norm1, norm2, ...] } 16 | layer_step_log = dict() 17 | 18 | ''' 19 | Patch the KSampler to inject our step callback 20 | This allows us to log the step index during sampling. 
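
Note: the patched sampler below currently just forwards to the original
implementation; the callback wiring is left commented out as a sketch of where a
per-step callback could be injected.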
21 | ''' 22 | original_sample = comfy.samplers.KSampler.sample 23 | 24 | 25 | def patched_sample(self, *args, **kwargs): 26 | # extra = kwargs.get("extra_options", {}) 27 | # if "callback" not in extra: 28 | # extra["callback"] = step_callback 29 | # kwargs["extra_options"] = extra 30 | return original_sample(self, *args, **kwargs) 31 | 32 | 33 | comfy.samplers.KSampler.sample = patched_sample 34 | 35 | 36 | def make_logging_hook(layer_name): 37 | def hook_fn(module, input, output): 38 | with torch.no_grad(): 39 | # step_index = get_step_index_fn() 40 | norm = output.norm(dim=-1).mean().item() 41 | if layer_name not in layer_step_log: 42 | layer_step_log[layer_name] = [] 43 | layer_step_log[layer_name].append(norm) 44 | 45 | return hook_fn 46 | 47 | 48 | def register_attention_hooks(model): 49 | for name, module in model.named_modules(): 50 | # Customize this filter as needed for LoRA models 51 | if "attn" in name.lower() or "cross_attn" in name.lower(): 52 | try: 53 | module.register_forward_hook(make_logging_hook(name)) 54 | except Exception as e: 55 | print(f"Failed to register hook on {name}: {e}") 56 | return model 57 | 58 | 59 | class LoRAAttentionLogger: 60 | """ 61 | Wraps a model (UNet or any LoRA-modified model) and installs hooks 62 | that log attention activity during the forward pass. 63 | """ 64 | 65 | @classmethod 66 | def INPUT_TYPES(cls): 67 | return { 68 | "required": { 69 | "model": ("MODEL",), 70 | }, 71 | } 72 | 73 | RETURN_TYPES = ("MODEL",) 74 | FUNCTION = "apply_hooks" 75 | CATEGORY = "LoRA PowerMerge/Analytics" 76 | 77 | def apply_hooks(self, model): 78 | # The UNet is inside model.model.diffusion_model if coming from base nodes 79 | target_model = getattr(model, "model", model) 80 | 81 | print("[LoRAAttentionLogger] Registering attention hooks...") 82 | register_attention_hooks(target_model) 83 | 84 | global layer_step_log 85 | layer_step_log = defaultdict(lambda: defaultdict(list)) 86 | 87 | return (model,) 88 | 89 | 90 | class LoRAAttentionPlot: 91 | @classmethod 92 | def INPUT_TYPES(cls): 93 | return { 94 | "required": { 95 | "latent": ("LATENT",), 96 | } 97 | } 98 | 99 | RETURN_TYPES = ("IMAGE",) 100 | RETURN_NAMES = ("activity_plot",) 101 | FUNCTION = "plot_activity" 102 | CATEGORY = "custom" 103 | 104 | def plot_activity(self, latent): 105 | if not layer_step_log: 106 | raise ValueError("No attention activity data found.") 107 | 108 | # Group the layers by their block names 109 | layer_data = [] 110 | for layer_name, norms in layer_step_log.items(): 111 | data0 = [layer_name, norms] + list(detect_block_names(layer_name).values()) 112 | layer_data.append(data0) 113 | df = pd.DataFrame(layer_data, columns=["layer_name", "norms", 114 | "block_type", "block_idx", "inner_idx", "component", 115 | "main_block", "sub_block", "transformer_idx"]) 116 | 117 | # Convert to image 118 | pil_img = self.generate_plot_image(df) 119 | pil_detail_images: List[Image] = self.generate_block_detail_images(df) 120 | 121 | image_tensors = self.pil_to_tensor_batch([pil_img] + pil_detail_images) 122 | 123 | # Clear data after processing 124 | layer_step_log.clear() # Clear the log after processing 125 | 126 | return (image_tensors,) 127 | 128 | def generate_plot_image(self, df: pd.DataFrame): 129 | fig, axes = plt.subplots(2, 2, figsize=(12, 8)) 130 | subplot_map = { 131 | "input_blocks": axes[0, 0], 132 | "middle_block": axes[0, 1], 133 | "output_blocks": axes[1, 0], 134 | } 135 | 136 | subplot_titles = { 137 | "input_blocks": "Input Blocks", 138 | "middle_block": "Middle Blocks", 
139 | "output_blocks": "Output Blocks", 140 | } 141 | 142 | # Track which subplot had content 143 | filled_axes = set() 144 | 145 | # Determine number of steps from the first row 146 | step_count = len(df["norms"].iloc[0]) 147 | x_values = np.arange(step_count) 148 | 149 | # Group rows by block_type 150 | for block_type, ax in subplot_map.items(): 151 | block_df = df[df["block_type"] == block_type] 152 | 153 | if block_df.empty: 154 | ax.axis('off') 155 | continue 156 | 157 | filled_axes.add(block_type) 158 | 159 | # Group by main_block (e.g., input1, middle2...) for plotting 160 | for block_name, group in block_df.groupby("main_block"): 161 | norms_array = np.stack(group["norms"].values) # shape: (num_layers, step_count) 162 | summed_norms = norms_array.sum(axis=0) # shape: (step_count,) 163 | 164 | ax.plot(x_values, summed_norms, label=block_name, linewidth=1) 165 | 166 | ax.set_title(subplot_titles[block_type]) 167 | ax.set_xlabel("Denoising Step") 168 | ax.set_ylabel("Output Norm") 169 | ax.legend(fontsize='x-small', loc='upper right') 170 | 171 | # Hide unused subplot 172 | axes[1, 1].axis('off') 173 | 174 | plt.tight_layout() 175 | 176 | # Save to image 177 | buf = io.BytesIO() 178 | plt.savefig(buf, format="png") 179 | buf.seek(0) 180 | image = Image.open(buf).convert("RGB") 181 | buf.close() 182 | plt.close() 183 | 184 | return image 185 | 186 | def generate_block_detail_images(self, df: pd.DataFrame) -> List[Image.Image]: 187 | images = [] 188 | components = ["attn1", "attn2", "ff"] 189 | component_titles = { 190 | "attn1": "Attention", 191 | "attn2": "Cross-Attention", 192 | "ff": "Feedforward", 193 | } 194 | 195 | step_count = len(df["norms"].iloc[0]) 196 | x_values = np.arange(step_count) 197 | 198 | for main_block, group in df.groupby("main_block"): 199 | fig, axes = plt.subplots(2, 2, figsize=(12, 8)) 200 | flat_axes = axes.flatten() 201 | 202 | for i, component in enumerate(components): 203 | ax = flat_axes[i] 204 | comp_df = group[group["component"] == component] 205 | 206 | if comp_df.empty: 207 | ax.axis('off') 208 | continue 209 | 210 | # Plot series grouped by transformer_idx (or "others") 211 | for transformer_idx, sub_df in comp_df.groupby( 212 | comp_df["transformer_idx"].apply(lambda x: x if pd.notna(x) else "others") 213 | ): 214 | norms_array = np.stack(sub_df["norms"].values) # (num_layers, step_count) 215 | summed_norms = norms_array.sum(axis=0) 216 | label = f"{component_titles[component]} - {transformer_idx}" 217 | ax.plot(x_values, summed_norms, label=label, linewidth=1) 218 | 219 | ax.set_title(component_titles[component]) 220 | ax.set_xlabel("Denoising Step") 221 | ax.set_ylabel("Output Norm") 222 | ax.legend(fontsize="x-small", loc="upper right") 223 | 224 | # Leave last subplot empty 225 | flat_axes[3].axis('off') 226 | 227 | plt.suptitle(f"Block: {main_block}", fontsize=14) 228 | plt.tight_layout() 229 | 230 | buf = io.BytesIO() 231 | plt.savefig(buf, format="png") 232 | buf.seek(0) 233 | image = Image.open(buf).convert("RGB") 234 | buf.close() 235 | plt.close() 236 | 237 | images.append(image) 238 | 239 | return images 240 | 241 | def pil_to_tensor(self, image: Image.Image): 242 | image_np = np.array(image).astype(np.float32) / 255.0 243 | image_np = np.expand_dims(image_np, axis=0) # Shape: (1, H, W, 3) 244 | return torch.from_numpy(image_np) 245 | 246 | def pil_to_tensor_batch(self, images: List[Image.Image]) -> torch.Tensor: 247 | tensors = [] 248 | for img in images: 249 | tensor = self.pil_to_tensor(img) 250 | tensor = tensor.squeeze(0) # (H, W, 3) 
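            # All plot images are rendered at the same figure size, so their
            # (H, W) dims match; torch.stack below requires identical shapes.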
251 |             tensors.append(tensor)  # Convert each image to tensor
252 |         stack = torch.stack(tensors)
253 |         print('Shape of stacked tensors:', stack.shape)  # Debugging output
254 |         return stack
255 | 
--------------------------------------------------------------------------------
/src/lora_block_sampler.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | 
4 | import numpy as np
5 | import torch
6 | from PIL import Image, ImageFont, ImageDraw, ImageChops
7 | 
8 | from nodes import VAEDecode
9 | from comfy_extras.nodes_custom_sampler import SamplerCustom
10 | 
11 | from .comfy_util import load_as_comfy_lora
12 | from .architectures import sd_lora, dit_lora
13 | from .utility import FONTS_DIR
14 | 
15 | 
16 | class LoRABlockSampler:
17 |     @classmethod
18 |     def INPUT_TYPES(cls):
19 |         return {
20 |             "required": {
21 |                 "model": ("MODEL",),
22 |                 "add_noise": ("BOOLEAN", {"default": True}),
23 |                 "noise_seed": (
24 |                     "INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}
25 |                 ),
26 |                 "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
27 |                 "positive": ("CONDITIONING",),
28 |                 "negative": ("CONDITIONING",),
29 |                 "sampler": ("SAMPLER",),
30 |                 "sigmas": ("SIGMAS",),
31 |                 "latent_image": ("LATENT",),
32 |                 "lora": ("LoRABundle",),
33 |                 "vae": ("VAE",),
34 |                 "block_sampling_mode": (["round_robin_exclude", "round_robin_include"],),
35 |                 "image_display": (["image", "image_diff"],)
36 |             }
37 |         }
38 |     RETURN_TYPES = ("LATENT", "IMAGE")
39 |     RETURN_NAMES = ("latents", "image_grid")
40 |     FUNCTION = "sample"
41 |     CATEGORY = "LoRA PowerMerge/sampling"
42 | 
43 |     def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas,
44 |                latent_image, lora, vae, block_sampling_mode, image_display):
45 |         if 'lora' not in lora or lora['lora'] is None:
46 |             lora['lora'] = load_as_comfy_lora(lora, model)
47 | 
48 |         action = "include" if block_sampling_mode == "round_robin_include" else "exclude"
49 | 
50 |         patch_dict = lora['lora']
51 | 
52 |         # Debug: Log first few keys to understand structure
53 |         sample_keys = list(patch_dict.keys())[:3]
54 |         logging.info(f"PM LoRABlockSampler: Sample keys from patch_dict: {sample_keys}")
55 | 
56 |         # Helper function to extract string key from various formats
57 |         def get_string_key(key):
58 |             """Extract string key from tuple or return as-is if already string"""
59 |             if isinstance(key, tuple) and len(key) > 0 and isinstance(key[0], str):
60 |                 return key[0]
61 |             elif isinstance(key, str):
62 |                 return key
63 |             return None
64 | 
65 |         # Build a mapping of string keys for detection only
66 |         # Keep original keys intact for patch application
67 |         string_keys_for_detection = []
68 |         for key in patch_dict.keys():
69 |             str_key = get_string_key(key)
70 |             if str_key:
71 |                 string_keys_for_detection.append(str_key)
72 | 
73 |         # Auto-detect architecture using string keys
74 |         detection_dict = {get_string_key(k): v for k, v in patch_dict.items() if get_string_key(k)}
75 |         arch = dit_lora.detect_architecture(detection_dict)
76 |         if arch == "dit":
77 |             logging.info("PM LoRABlockSampler: Detected DiT architecture")
78 |             detect_fn = dit_lora.detect_block_names
79 |         else:
80 |             logging.info("PM LoRABlockSampler: Using SD/SDXL architecture")
81 |             detect_fn = sd_lora.detect_block_names
82 | 
83 |         # Detect main blocks using string keys
84 |         main_blocks = set()
85 |         for key in patch_dict.keys():
86 |             str_key = get_string_key(key)
87 |             if not str_key:
88 | 
logging.warning(f"PM LoRABlockSampler: Skipping unsupported key type: {type(key)} = {key}") 89 | continue 90 | block_names = detect_fn(str_key) 91 | if block_names is None: 92 | continue 93 | main_blocks.add(block_names["main_block"]) 94 | 95 | out = [] 96 | kSampler = SamplerCustom() 97 | 98 | logging.info(f"PM LoRABlockSampler: Detected main blocks: {main_blocks}") 99 | main_blocks = ["NONE", "ALL"] + sorted(list(main_blocks)) 100 | for block in main_blocks: 101 | patch_dict_filtered = {} 102 | for orig_key, value in patch_dict.items(): 103 | # Get string representation for filtering logic 104 | str_key = get_string_key(orig_key) 105 | if not str_key: 106 | continue 107 | if block == "NONE": 108 | continue 109 | if block == "ALL": 110 | # Use original key to preserve metadata 111 | patch_dict_filtered[orig_key] = value 112 | else: 113 | # Detect which main_block this key belongs to 114 | block_names = detect_fn(str_key) 115 | key_main_block = block_names["main_block"] if block_names and "main_block" in block_names else None 116 | 117 | # Filter based on detected main_block 118 | if key_main_block: 119 | if (action == "include" and key_main_block == block) or \ 120 | (action == "exclude" and key_main_block != block): 121 | patch_dict_filtered[orig_key] = value 122 | 123 | if block == "NONE": 124 | logging.info("PM LoRABlockSampler: Do not apply any of the patches.") 125 | elif block == "ALL": 126 | logging.info(f"PM LoRABlockSampler: Apply all patches. Total patches: {len(patch_dict.keys())}") 127 | else: 128 | logging.info(f"PM LoRABlockSampler: {action} block {block} from sampling, " 129 | f"remaining patches: {len(patch_dict_filtered)}") 130 | 131 | new_model_patcher = model.clone() 132 | new_model_patcher.add_patches(patch_dict_filtered, lora['strength_model']) 133 | 134 | denoised, _ = kSampler.sample( 135 | model=new_model_patcher, 136 | add_noise=add_noise, 137 | noise_seed=noise_seed, 138 | cfg=cfg, 139 | positive=positive, 140 | negative=negative, 141 | sampler=sampler, 142 | sigmas=sigmas, 143 | latent_image=latent_image, 144 | ) 145 | out.append(denoised) 146 | 147 | # Repack the output 148 | out = { 149 | "samples": torch.stack([s['samples'].squeeze(0) for s in out]) 150 | } 151 | 152 | grid_images = [] 153 | if vae is not None: 154 | vae_decoder = VAEDecode() 155 | images = list(vae_decoder.decode(vae, out)[0]) 156 | 157 | # Load a font 158 | font_path = f"{FONTS_DIR}/ShareTechMono-Regular.ttf" 159 | try: 160 | title_font = ImageFont.truetype(font_path, size=48) 161 | except OSError: 162 | logging.warning(f"PM LoRABlockSampler: Font not found at {font_path}, using default font.") 163 | title_font = ImageFont.load_default() 164 | 165 | img_diff_target = None 166 | for img_tensor, block_name in zip(images, main_blocks): 167 | i = 255. 
* img_tensor.cpu().numpy()
168 |                 img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
169 | 
170 |                 # Write a white text on the image indicating the block name
171 |                 if block_name == "NONE":
172 |                     title = "No LoRA blocks applied"
173 |                     if action == "include":
174 |                         img_diff_target = img
175 |                 elif block_name == "ALL":
176 |                     title = "All LoRA blocks applied"
177 |                     if action == "exclude":
178 |                         img_diff_target = img
179 |                 else:
180 |                     title = f"{action.capitalize()} block: {block_name}"
181 |                 title_width = title_font.getbbox(title)[2]
182 |                 title_padding = 6
183 |                 title_line_height = (title_font.getmask(title).getbbox()[3] + title_font.getmetrics()[1] +
184 |                                      title_padding * 2)
185 |                 title_text_height = title_line_height
186 |                 title_text_image = Image.new('RGB', (img.width, title_text_height), color=(0, 0, 0, 0))
187 | 
188 |                 draw = ImageDraw.Draw(title_text_image)
189 |                 draw.text((img.width // 2 - title_width // 2, title_padding), title, font=title_font,
190 |                           fill=(255, 255, 255))
191 |                 # Convert the title text image to a tensor
192 |                 title_text_image_tensor = torch.tensor(np.array(title_text_image).astype(np.float32) / 255.0)
193 | 
194 |                 if image_display == "image_diff":
195 |                     # Calculate the difference
196 |                     if block_name not in ("NONE", "ALL") and img_diff_target is not None:
197 |                         img_tensor = ImageChops.difference(img, img_diff_target)
198 |                         # Convert the image difference to a tensor
199 |                         img_tensor = torch.tensor(np.array(img_tensor).astype(np.float32) / 255.0)
200 | 
201 |                 out_image = torch.cat([title_text_image_tensor, img_tensor], 0)
202 |                 grid_images.append(out_image)
203 | 
204 |             grid_images = torch.stack(grid_images)
205 |         return out, grid_images,
--------------------------------------------------------------------------------
/src/experimental/lora_analyzer.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import io
3 | import logging
4 | from typing import Dict, Any, Tuple
5 | 
6 | import cairosvg
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 | from lxml import etree
11 | 
12 | 
13 | from ..architectures import sd_lora
14 | from comfy.utils import ProgressBar  # Assuming this is thread-safe or replaced
15 | from ..comfy_util import load_as_comfy_lora
16 | 
17 | 
18 | class LoRAAnalyzer:
19 |     @classmethod
20 |     def INPUT_TYPES(cls):
21 |         return {
22 |             "required": {
23 |                 "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
24 |                 "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
25 |                 "lora": ("LoRABundle",),
26 |                 "indicator": (["frobenius_norm", "sparsity"],),
27 |             },
28 |         }
29 | 
30 |     RETURN_TYPES = ("IMAGE", "STRING",)
31 |     RETURN_NAMES = ("image", "blocks_store")
32 | 
33 |     FUNCTION = "run"
34 |     CATEGORY = "LoRA PowerMerge/Analytics"
35 | 
36 |     def run(self, model, clip, lora: Dict[str, Any], indicator):
37 |         if 'lora' not in lora or lora['lora'] is None:
38 |             lora['lora'] = load_as_comfy_lora(lora, model, clip)
39 | 
40 |         # Calculate indicators for each tensor
41 |         layer_measures = self.calculate_all_measures(lora['lora'])
42 | 
43 |         # Generate SVG from block_dict
44 |         svg, color_settings = self.generate_svg(layer_measures, indicator=indicator, size=512)
45 |         # Generate image from SVG
46 |         image = self.generate_image(svg)
47 | 
48 |         return image, color_settings
49 | 
50 |     def process_layer(self, layer_key, lora_adapter):
51 |         try:
52 |             block_names = 
sd_lora.detect_block_names(layer_key) 53 | if not block_names or "main_block" not in block_names: 54 | logging.info(f"Skipping layer {layer_key} as it does not match the expected pattern.") 55 | return layer_key, None 56 | 57 | alpha = lora_adapter.weights[2] 58 | up_weights = lora_adapter.weights[0] 59 | down_weights = lora_adapter.weights[1] 60 | if up_weights.dim() > 2 or down_weights.dim() > 2: 61 | delta_W = alpha * up_weights.squeeze((2, 3)) @ down_weights.squeeze((2, 3)) 62 | else: 63 | delta_W = alpha * up_weights @ down_weights 64 | 65 | frobenius_norm, mean_frobenius = self.frobenius_norm(delta_W) 66 | sparsity = self.sparsity(delta_W, epsilon=1e-3) 67 | 68 | return layer_key, { 69 | "main_block": block_names["main_block"], 70 | "sub_block": block_names["sub_block"] if "sub_block" in block_names else None, 71 | "frobenius_norm": mean_frobenius, 72 | "sparsity": sparsity, 73 | } 74 | except Exception as e: 75 | logging.error(f"Error processing {layer_key}: {e}") 76 | return layer_key, None 77 | 78 | def calculate_all_measures(self, patch_dict): 79 | layer_measures = {} 80 | pbar = ProgressBar(len(patch_dict)) 81 | 82 | # Threaded version 83 | with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: 84 | futures = { 85 | executor.submit(self.process_layer, layer_key, lora_adapter): layer_key 86 | for layer_key, lora_adapter in patch_dict.items() 87 | } 88 | 89 | for future in concurrent.futures.as_completed(futures): 90 | layer_key, result = future.result() 91 | if result is not None: 92 | layer_measures[layer_key] = result 93 | pbar.update(1) 94 | 95 | return layer_measures 96 | 97 | def frobenius_norm(self, tensor: torch.Tensor) -> tuple[float, float]: 98 | """Calculate the Frobenius norm of a tensor.""" 99 | frobenius_norm = torch.norm(tensor, p='fro').item() 100 | frobenius_norm_mean = frobenius_norm / tensor.numel() 101 | return frobenius_norm, frobenius_norm_mean 102 | 103 | def sparsity(self, tensor: torch.Tensor, epsilon: float) -> float: 104 | """Calculate the sparsity of a tensor.""" 105 | if tensor.numel() == 0: 106 | return 0.0 107 | return (torch.sum(torch.abs(tensor) < epsilon).item() / tensor.numel()) * 100.0 108 | 109 | def generate_svg(self, layer_measures: Dict[str, Dict[str, Any]], indicator="frobenius_norm", size=512) -> \ 110 | Tuple[str, Dict[str, str]]: 111 | # Group the layer_measures by their main_block and calculate an average indicator value for each block. 
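        # e.g. block_info["input_blocks"] == {"block_type": "main_block", "frobenius_norm": 0.012}
        # (illustrative key and value; see group_measures below)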
112 |         block_info = self.group_measures(indicator, layer_measures)
113 | 
114 |         # Modify the SVG shapes with the block_info
115 |         color_settings = self.calculate_color_settings(block_info, indicator)
116 | 
117 |         # Load the SVG template and apply the color schema
118 |         svg = self.apply_color_schema(block_info, color_settings)
119 | 
120 |         return svg, color_settings
121 | 
122 |     def apply_color_schema(self, block_info, color_settings):
123 |         # read svg template from file js/sdxl_unet.svg
124 |         with open("custom_nodes/LoRA-Merger-ComfyUI/js/sdxl_unet.svg", "r") as f:
125 |             svg_template = f.read()
126 | 
127 |         root = etree.fromstring(svg_template)
128 |         # SVGs use namespaces
129 |         ns = {'svg': 'http://www.w3.org/2000/svg'}
130 |         for key, value in block_info.items():
131 |             # Loop through each block and replace the placeholders in the SVG template
132 |             # Find the svg element where id matches "{key}.rect"
133 |             element_id = f"{key}.rect"
134 |             # Find element by id
135 |             target = root.xpath(f'//*[@id="{element_id}"]', namespaces=ns)
136 |             if target:
137 |                 # Replace the background (search for fill) with a color based on the value
138 |                 color = color_settings[key]
139 |                 # Change attribute
140 |                 elem = target[0]
141 |                 elem.set("fill", color)
142 |             else:
143 |                 logging.warning(f"Block {key} not found in SVG template. Skipping.")
144 |         # Print modified SVG to string
145 |         svg_template = etree.tostring(root, pretty_print=True, xml_declaration=True, encoding='UTF-8').decode('utf-8')
146 |         return svg_template
147 | 
148 |     def calculate_color_settings(self, block_info, indicator) -> Dict[str, str]:
149 |         """Creates a dictionary with color settings for each block based on the indicator."""
150 |         color_settings = {}
151 |         indicator_values = [v[indicator] for v in block_info.values() if indicator in v]
152 |         min_value = min(indicator_values) if indicator_values else 0
153 |         max_value = max(indicator_values) if indicator_values else 1
154 | 
155 |         def interpolate_color(v, min_color=(255, 255, 255), max_color=(255, 0, 0)):
156 |             """Interpolate color based on the value."""
157 |             if max_value == min_value:
158 |                 return f'rgb({min_color[0]}, {min_color[1]}, {min_color[2]})'
159 |             ratio = (v - min_value) / (max_value - min_value)
160 |             r = int(min_color[0] + ratio * (max_color[0] - min_color[0]))
161 |             g = int(min_color[1] + ratio * (max_color[1] - min_color[1]))
162 |             b = int(min_color[2] + ratio * (max_color[2] - min_color[2]))
163 |             return f'rgb({r}, {g}, {b})'
164 | 
165 |         for key, value in block_info.items():
166 |             max_color = (0, 0, 255) if value['block_type'] == "sub_block" else (255, 0, 0)
167 |             color_settings[key] = interpolate_color(value[indicator], max_color=max_color)
168 |         return color_settings
169 | 
170 |     def group_measures(self, indicator, layer_measures):
171 |         # group the values of layer_measures by their main_block.
Calculate an average indicator value for each block.# 172 | block_info = {} 173 | for layer_key, measures in layer_measures.items(): 174 | main_block = measures["main_block"] 175 | sub_block = measures["sub_block"] 176 | value = measures[indicator] 177 | 178 | # Aggregate values for each block 179 | if main_block not in block_info: 180 | block_info[main_block] = {"block_type": "main_block", indicator: 0.} 181 | block_info[main_block][indicator] += value 182 | 183 | # If sub_block is present, also aggregate it 184 | if sub_block: 185 | if sub_block not in block_info: 186 | block_info[sub_block] = {"block_type": "sub_block", indicator: 0.} 187 | block_info[sub_block][indicator] += value 188 | 189 | # Normalize the values by the number of occurrences 190 | for block, values in block_info.items(): 191 | count = sum(1 for v in layer_measures.values() if v["main_block"] == block or v["sub_block"] == block) 192 | if count > 0: 193 | for key in values: 194 | # if values[key] is a number, normalize it 195 | if isinstance(values[key], (int, float)): 196 | values[key] /= count 197 | return block_info 198 | 199 | def generate_image(self, svg): 200 | # Convert SVG string to PNG 201 | png_bytes = cairosvg.svg2png(bytestring=svg.encode("utf-8")) 202 | 203 | # Load into PIL Image 204 | image = Image.open(io.BytesIO(png_bytes)).convert("RGB") 205 | 206 | # Convert to numpy and normalize 207 | image_np = np.array(image).astype(np.float32) / 255.0 208 | 209 | # Convert to torch tensor and add batch dimension (B, H, W, C) 210 | image_tensor = torch.from_numpy(image_np).unsqueeze(0) # shape: [1, H, W, C] 211 | 212 | # Return shape (B, H, W, C) 213 | return image_tensor -------------------------------------------------------------------------------- /src/lora_power_stacker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import folder_paths 3 | import comfy 4 | import comfy.lora 5 | import comfy.utils 6 | from nodes import LoraLoader 7 | 8 | from .merge.utils import parse_layer_filter, apply_layer_filter 9 | 10 | 11 | class AnyType(str): 12 | """A special class that is always equal in not equal comparisons. Credit to pythongosssss""" 13 | 14 | def __ne__(self, __value: object) -> bool: 15 | return False 16 | 17 | 18 | class FlexibleOptionalInputType(dict): 19 | """A special class to make flexible nodes that pass data to our python handlers. 20 | 21 | Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs 22 | (like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc). 23 | 24 | Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI 25 | that our node will handle the input, regardless of what it is. 26 | 27 | However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur 28 | requiring more details on the input itself. There, we need to return a list/tuple where the first 29 | item is the type. This can be a real type, or use the AnyType for additional flexibility. 30 | 31 | This should be forwards compatible unless more changes occur in the PR. 
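
    For example (hypothetical key name), any lookup succeeds and reports the
    flexible type:

        FlexibleOptionalInputType(any_type)["lora_7"]  # -> (any_type,)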
32 | """ 33 | 34 | def __init__(self, type): 35 | self.type = type 36 | 37 | def __getitem__(self, key): 38 | return (self.type,) 39 | 40 | def __contains__(self, key): 41 | return True 42 | 43 | 44 | any_type = AnyType("*") 45 | 46 | 47 | def get_lora_by_filename(file_path, log_node="PM LoRA Power Stacker"): 48 | """Returns a lora by filename, looking for exact paths and then fuzzier matching. 49 | 50 | Adapted from rgthree's power_prompt_utils.py 51 | """ 52 | lora_paths = folder_paths.get_filename_list('loras') 53 | 54 | if file_path in lora_paths: 55 | return file_path 56 | 57 | lora_paths_no_ext = [os.path.splitext(x)[0] for x in lora_paths] 58 | 59 | # See if we've entered the exact path, but without the extension 60 | if file_path in lora_paths_no_ext: 61 | found = lora_paths[lora_paths_no_ext.index(file_path)] 62 | return found 63 | 64 | # Same check, but ensure file_path is without extension. 65 | file_path_force_no_ext = os.path.splitext(file_path)[0] 66 | if file_path_force_no_ext in lora_paths_no_ext: 67 | found = lora_paths[lora_paths_no_ext.index(file_path_force_no_ext)] 68 | return found 69 | 70 | # See if we passed just the name, without paths. 71 | lora_filenames_only = [os.path.basename(x) for x in lora_paths] 72 | if file_path in lora_filenames_only: 73 | found = lora_paths[lora_filenames_only.index(file_path)] 74 | print(f'[{log_node}] Matched Lora input "{file_path}" to "{found}".') 75 | return found 76 | 77 | # Same, but force the input to be without paths 78 | file_path_force_filename = os.path.basename(file_path) 79 | lora_filenames_only = [os.path.basename(x) for x in lora_paths] 80 | if file_path_force_filename in lora_filenames_only: 81 | found = lora_paths[lora_filenames_only.index(file_path_force_filename)] 82 | print(f'[{log_node}] Matched Lora input "{file_path}" to "{found}".') 83 | return found 84 | 85 | # Check the filenames and without extension. 86 | lora_filenames_and_no_ext = [os.path.splitext(os.path.basename(x))[0] for x in lora_paths] 87 | if file_path in lora_filenames_and_no_ext: 88 | found = lora_paths[lora_filenames_and_no_ext.index(file_path)] 89 | print(f'[{log_node}] Matched Lora input "{file_path}" to "{found}".') 90 | return found 91 | 92 | # And, one last forcing the input to be the same 93 | file_path_force_filename_and_no_ext = os.path.splitext(os.path.basename(file_path))[0] 94 | if file_path_force_filename_and_no_ext in lora_filenames_and_no_ext: 95 | found = lora_paths[lora_filenames_and_no_ext.index(file_path_force_filename_and_no_ext)] 96 | print(f'[{log_node}] Matched Lora input "{file_path}" to "{found}".') 97 | return found 98 | 99 | # Finally, super fuzzy, we'll just check if the input exists in the path at all. 100 | for index, lora_path in enumerate(lora_paths): 101 | if file_path in lora_path: 102 | found = lora_paths[index] 103 | print(f'[{log_node}] Fuzzy-matched Lora input "{file_path}" to "{found}".') 104 | return found 105 | 106 | print(f'[{log_node}] WARNING: Lora "{file_path}" not found, skipping.') 107 | return None 108 | 109 | 110 | class LoraPowerStacker: 111 | """The Power LoRA Stacker is a flexible widget-based node to stack multiple LoRAs. 112 | 113 | Similar to rgthree's Power Lora Loader but outputs LoRAKeyDict and LoRAStrengths 114 | for use with PM LoRA PowerMerge workflow. 115 | """ 116 | 117 | NAME = "PM LoRA Power Stacker" 118 | CATEGORY = "LoRA PowerMerge" 119 | DESCRIPTION = """Widget-based LoRA stacker for PowerMerge workflow. 
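
Each enabled LoRA widget arrives as a keyword argument shaped roughly like
lora_1={"on": True, "lora": "my_lora.safetensors", "strength": 0.8, "strengthTwo": None}
(field names as consumed by stack_loras_widget; the values are illustrative).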
120 | 121 | Outputs: 122 | - LoRAStack: Processed model weights (filtered by layer_filter) 123 | - LoRAWeights: Strength metadata for each LoRA 124 | - LoRARawDict: Original raw state dicts (preserves CLIP weights) 125 | - CLIP: Modified CLIP model with all LoRAs applied 126 | 127 | Use the widget to add/remove LoRAs dynamically. Connect LoRARawDict to LoRA Select to preserve CLIP weights when saving merged LoRAs.""" 128 | 129 | @classmethod 130 | def INPUT_TYPES(cls): 131 | return { 132 | "required": { 133 | "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), 134 | "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}), 135 | }, 136 | "optional": FlexibleOptionalInputType(any_type), 137 | "hidden": {}, 138 | } 139 | 140 | RETURN_TYPES = ("LoRAStack", "LoRAWeights", "LoRARawDict", "CLIP") 141 | RETURN_NAMES = ("LoRAStack", "LoRAWeights", "LoRARawDict", "CLIP") 142 | FUNCTION = "stack_loras_widget" 143 | 144 | def stack_loras_widget(self, model, clip, **kwargs): 145 | """Loops over the provided loras in kwargs and builds stack outputs.""" 146 | 147 | # Extract layer_filter if provided (comes from widget, not LoRA data) 148 | layer_filter = kwargs.pop("layer_filter", "full") 149 | layer_filter = parse_layer_filter(layer_filter) 150 | 151 | # Build key_map for LoRA loading 152 | key_map = {} 153 | if model is not None: 154 | key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) 155 | 156 | # Initialize outputs 157 | lora_patch_dicts = {} 158 | lora_strengths = {} 159 | lora_raw_dicts = {} # Store raw LoRA state dicts for CLIP weights 160 | 161 | # Track how many LoRAs were loaded 162 | loaded_count = 0 163 | 164 | # Loop through kwargs looking for LoRA widgets 165 | for key, value in kwargs.items(): 166 | key_upper = key.upper() 167 | 168 | # Check if this is a LoRA widget (must have required fields) 169 | if (key_upper.startswith('LORA_') and 170 | isinstance(value, dict) and 171 | 'on' in value and 172 | 'lora' in value and 173 | 'strength' in value): 174 | 175 | # Extract values 176 | is_on = value.get('on', False) 177 | lora_name = value.get('lora') 178 | strength_model = value.get('strength', 1.0) 179 | 180 | # Handle separate model/clip strengths 181 | # If strengthTwo exists and is not None, use it for clip 182 | # Otherwise use strength for both model and clip 183 | strength_clip = value.get('strengthTwo') 184 | if strength_clip is None: 185 | strength_clip = strength_model 186 | 187 | # Skip if disabled or strength is zero 188 | if not is_on or (strength_model == 0 and strength_clip == 0): 189 | continue 190 | 191 | # Skip if no LoRA specified 192 | if not lora_name or lora_name == "None": 193 | continue 194 | 195 | # Find the LoRA file using fuzzy matching 196 | lora_file = get_lora_by_filename(lora_name, log_node=self.NAME) 197 | if lora_file is None: 198 | continue 199 | 200 | # Load the LoRA 201 | try: 202 | lora_path = folder_paths.get_full_path("loras", lora_file) 203 | lora_raw = comfy.utils.load_torch_file(lora_path, safe_load=True) 204 | 205 | # Get pretty name (without extension) 206 | lora_name_pretty = os.path.splitext(os.path.basename(lora_file))[0] 207 | 208 | # Load LoRA into patch dict 209 | patch_dict = comfy.lora.load_lora(lora_raw, key_map) 210 | 211 | # Apply layer filter 212 | patch_dict = apply_layer_filter(patch_dict, layer_filter) 213 | 214 | # Store in outputs 215 | lora_patch_dicts[lora_name_pretty] = patch_dict 216 | lora_strengths[lora_name_pretty] = { 217 | 'strength_model': strength_model, 
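                    # strength_clip is intentionally not stored here; it is
                    # applied directly to the CLIP model via LoraLoader below.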
218 | } 219 | lora_raw_dicts[lora_name_pretty] = lora_raw # Store raw state dict 220 | 221 | # Apply to CLIP 222 | # Note: We need a dummy model for LoraLoader, but we only care about CLIP output 223 | # So we'll use the standard LoraLoader node's load_lora method 224 | _, clip = LoraLoader().load_lora(model, clip, lora_file, strength_model, strength_clip) 225 | 226 | loaded_count += 1 227 | 228 | except Exception as e: 229 | print(f"[{self.NAME}] Error loading LoRA '{lora_name}': {e}") 230 | continue 231 | 232 | print(f"[{self.NAME}] Loaded {loaded_count} LoRAs") 233 | 234 | return (lora_patch_dicts, lora_strengths, lora_raw_dicts, clip) 235 | -------------------------------------------------------------------------------- /src/lora_resize.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import comfy.utils 5 | import torch 6 | from comfy.weight_adapter import LoRAAdapter 7 | 8 | from .types import LoRABundleDict 9 | from .utility import map_device, adjust_tensor_dims 10 | 11 | 12 | class LoraResizer: 13 | def __init__(self): 14 | self.loaded_lora = None 15 | @classmethod 16 | def INPUT_TYPES(s): 17 | return { 18 | "required": { 19 | "lora": ("LoRABundle",), 20 | "decomposition_method": ( 21 | ["rSVD", "energy_rSVD", "SVD"], 22 | { 23 | "default": "rSVD", 24 | "tooltip": ( 25 | "Method used to reconcile LoRA ranks when they differ. " 26 | "'SVD' uses full singular value decomposition (slow but optimal). " 27 | "'rSVD' uses randomized SVD (much faster, near-optimal). " 28 | "'energy_rSVD' first prunes low-energy LoRA components and then " 29 | "applies randomized SVD for fast, stable rank reduction " 30 | "(recommended for DiT and large LoRAs)." 31 | ), 32 | } 33 | ), 34 | "new_rank": ( 35 | "INT", 36 | { 37 | "default": 16, 38 | "min": 1, 39 | "max": 128, 40 | "tooltip": ( 41 | "Target LoRA rank after decomposition. " 42 | "Lower values reduce model size and strength." 43 | ), 44 | } 45 | ), 46 | "device": (["cuda", "cpu"],), 47 | "dtype": (["float32", "float16", "bfloat16"],), 48 | }, 49 | } 50 | 51 | RETURN_TYPES = ("LoRABundle",) 52 | FUNCTION = "lora_resize" 53 | CATEGORY = "LoRA PowerMerge" 54 | DESCRIPTION = """Resizes a LoRA to a different rank using tensor decomposition. 55 | 56 | This node reduces or increases the rank of all layers in a LoRA model using SVD-based methods. 57 | Lower ranks reduce memory usage and may reduce strength, while maintaining semantic meaning. 58 | 59 | Decomposition Methods: 60 | - SVD: Full singular value decomposition (slow but optimal) 61 | - rSVD: Randomized SVD (fast, recommended for most cases) 62 | - energy_rSVD: Energy-based randomized SVD (best for DiT/large LoRAs) 63 | 64 | The resizing uses asymmetric singular value distribution (all S values in up matrix) 65 | which differs from the symmetric distribution used in lora_decompose.""" 66 | 67 | def lora_resize( 68 | self, 69 | lora: LoRABundleDict, 70 | decomposition_method: str = "rSVD", 71 | new_rank: int = 16, 72 | device: str = "cuda", 73 | dtype: str = "float32" 74 | ) -> tuple: 75 | """ 76 | Resize a LoRA to a new rank using tensor decomposition. 
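
        Per layer, (up, down) is re-factorized so that up_new @ down_new
        approximates up @ down at the new rank, with the singular values folded
        into up_new (the asymmetric distribution noted in the node description).
        A sketch of the idea (hypothetical svd helper):

            delta_W = up @ down                    # (out, r) @ (r, in)
            U, S, Vh = svd(delta_W)                # or a randomized variant
            up_new = U[:, :new_rank] * S[:new_rank]
            down_new = Vh[:new_rank, :]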
77 | 78 | Args: 79 | lora: LoRA bundle containing lora_raw and lora (LoRAAdapter dict) 80 | decomposition_method: Method to use ('SVD', 'rSVD', 'energy_rSVD') 81 | new_rank: Target rank for all layers 82 | device: Device for computation 83 | dtype: Data type for computation 84 | 85 | Returns: 86 | Tuple containing resized LoRA bundle 87 | """ 88 | device, dtype = map_device(device, dtype) 89 | 90 | logging.info(f"Resizing LoRA '{lora.get('name', 'unknown')}' to rank {new_rank} using {decomposition_method}") 91 | 92 | # Extract the LoRA adapter dictionary (layer_key -> LoRAAdapter) 93 | lora_adapters = lora["lora"] 94 | lora_raw = lora.get("lora_raw", {}) 95 | 96 | # Get all keys from the LoRA 97 | keys = list(lora_adapters.keys()) 98 | 99 | logging.info(f"Processing {len(keys)} layers") 100 | 101 | pbar = comfy.utils.ProgressBar(len(keys)) 102 | start = time.time() 103 | 104 | # Process each layer 105 | resized_adapters = {} 106 | 107 | for key in keys: 108 | adapter = lora_adapters[key] 109 | up, down, alpha, mid, dora_scale, reshape = adapter.weights 110 | 111 | # Skip if mid exists (LoHA/LoCon - not supported for now) 112 | if mid is not None: 113 | logging.warning(f"Skipping layer {key}: LoHA/LoCon format with mid tensor not supported") 114 | resized_adapters[key] = adapter 115 | pbar.update(1) 116 | continue 117 | 118 | # Skip if DoRA scale exists (DoRA not fully supported) 119 | if dora_scale is not None: 120 | logging.warning(f"Skipping layer {key}: DoRA format not supported") 121 | resized_adapters[key] = adapter 122 | pbar.update(1) 123 | continue 124 | 125 | # Get current rank from down tensor 126 | current_rank = down.shape[0] 127 | 128 | # Handle alpha=None: in standard LoRA, None means alpha equals rank 129 | if alpha is None: 130 | alpha = current_rank 131 | 132 | # Check if resizing is needed 133 | if current_rank == new_rank: 134 | # No resizing needed 135 | resized_adapters[key] = adapter 136 | pbar.update(1) 137 | continue 138 | 139 | # Create a single-item dictionary for adjust_tensor_dims 140 | # adjust_tensor_dims expects Dict[str, LORA_TENSORS] 141 | temp_dict = {"temp": (up, down, alpha)} 142 | 143 | # Apply resizing using adjust_tensor_dims 144 | if decomposition_method == "none": 145 | # If method is 'none' and ranks differ, this will raise an error 146 | if current_rank != new_rank: 147 | raise ValueError( 148 | f"Layer '{key}' has rank {current_rank} but target is {new_rank}. " 149 | f"Cannot resize with decomposition_method='none'. " 150 | f"Please select a decomposition method (SVD, rSVD, or energy_rSVD)." 
151 | ) 152 | resized_adapters[key] = adapter 153 | else: 154 | # Resize using the selected method 155 | resized_dict = adjust_tensor_dims( 156 | temp_dict, 157 | apply_svd=True, 158 | svd_rank=new_rank, 159 | method=decomposition_method 160 | ) 161 | 162 | up_new, down_new, alpha_new = resized_dict["temp"] 163 | 164 | # Scale alpha proportionally to maintain the same effective strength 165 | # Original strength: (alpha / current_rank) * (up @ down) 166 | # New strength: (alpha_scaled / new_rank) * (up_new @ down_new) 167 | # To maintain same strength: alpha_scaled = alpha * (new_rank / current_rank) 168 | # However, since we want to preserve the original alpha value semantics, 169 | # we keep alpha unchanged and the SVD already captured the strength in the tensors 170 | # But if alpha was None (now set to current_rank), scale it to new_rank 171 | if alpha_new == current_rank: 172 | alpha_new = new_rank 173 | 174 | # Log shapes and alpha for debugging 175 | # logging.info(f"Layer {key}: Original shapes up={up.shape}, down={down.shape}, rank={current_rank}, alpha={alpha}") 176 | # logging.info(f"Layer {key}: Resized shapes up_new={up_new.shape}, down_new={down_new.shape}, rank={new_rank}, alpha_new={alpha_new}") 177 | 178 | # Ensure tensors are contiguous (required for safetensors saving) 179 | # SVD operations can produce non-contiguous tensors 180 | up_new = up_new.contiguous() 181 | down_new = down_new.contiguous() 182 | 183 | # Generate the loaded_keys that will be used during save 184 | # These keys need to match the format that convert_to_regular_lora expects 185 | key_str = key[0] if isinstance(key, tuple) else key 186 | 187 | # Strip .weight suffix if present (matches convert_to_regular_lora logic) 188 | key_base = key_str.replace("diffusion_model.", "") 189 | if key_base.endswith(".weight"): 190 | key_base = key_base[:-len(".weight")] 191 | 192 | # Convert dots to underscores for LoRA naming convention 193 | key_suffix = key_base.replace(".", "_") 194 | 195 | # Create the loaded_keys set for this adapter 196 | # Format matches what convert_to_regular_lora will generate 197 | new_loaded_keys = { 198 | f"lora_unet_{key_suffix}.lora_up.weight", 199 | f"lora_unet_{key_suffix}.lora_down.weight", 200 | f"lora_unet_{key_suffix}.alpha" 201 | } 202 | 203 | # Create new LoRAAdapter with resized tensors 204 | # Keep mid, dora_scale as None since we skipped those cases 205 | # Set reshape to None - the tensors are already in the correct shape after SVD 206 | # and keeping the old reshape metadata would cause mismatches 207 | resized_adapters[key] = LoRAAdapter( 208 | weights=(up_new, down_new, alpha_new, None, None, None), 209 | loaded_keys=new_loaded_keys 210 | ) 211 | 212 | pbar.update(1) 213 | 214 | logging.info(f"Resized {len(keys)} layers in {time.time() - start:.2f} seconds") 215 | 216 | # Create output bundle 217 | # Preserve the original raw state dict and metadata 218 | lora_out = { 219 | "lora_raw": lora_raw, # Keep original raw dict (includes CLIP weights) 220 | "lora": resized_adapters, # Resized adapters 221 | "strength_model": lora.get("strength_model", 1.0), 222 | "strength_clip": lora.get("strength_clip", 1.0), 223 | "name": f"{lora.get('name', 'LoRA')}_r{new_rank}" 224 | } 225 | 226 | torch.cuda.empty_cache() 227 | 228 | return (lora_out,) 229 | -------------------------------------------------------------------------------- /src/experimental/checkpoint_merge.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, 
Protocol, Optional 2 | 3 | import torch 4 | from mergekit.architecture import WeightInfo 5 | from mergekit.common import ModelReference, ModelPath, ImmutableMap 6 | from mergekit.io.tasks import GatherTensors 7 | from mergekit.merge_methods import REGISTERED_MERGE_METHODS 8 | from mergekit.merge_methods.generalized_task_arithmetic import GTATask 9 | from mergekit.merge_methods.karcher import KarcherTask 10 | from mergekit.merge_methods.linear import LinearMergeTask 11 | from mergekit.merge_methods.model_stock import ModelStockMergeTask 12 | from mergekit.merge_methods.nearswap import nearswap_merge 13 | from mergekit.merge_methods.nuslerp import NuSlerpTask 14 | 15 | import comfy 16 | from ..utility import map_device 17 | 18 | 19 | class MergeMethod(Protocol): 20 | def __call__( 21 | *, 22 | tensors: Dict[ModelReference, torch.Tensor], 23 | gather_tensors: GatherTensors, 24 | base_model: ModelReference, 25 | weight_info: WeightInfo, 26 | tensor_parameters: Optional[ImmutableMap[ModelReference, Any]] = ..., 27 | method_args: Optional[Dict] = ..., 28 | ) -> torch.Tensor: ... 29 | 30 | 31 | class CheckpointMergerMergekit: 32 | """ 33 | Node for merging LoRA models with Mergekit and algorithms that require SVD 34 | """ 35 | 36 | @classmethod 37 | def INPUT_TYPES(s): 38 | return { 39 | "required": { 40 | "method": ("MergeMethod",), 41 | "base_model": ("MODEL",), 42 | "model1": ("MODEL",), 43 | "lambda_": ("FLOAT", { 44 | "default": 1, 45 | "min": 0, 46 | "max": 1, 47 | "step": 0.01, 48 | "tooltip": "Scaling factor between 0 and 1 applied after weighted sum of task vectors.", 49 | }), 50 | "device": (["cuda", "cpu"],), 51 | "dtype": (["float16", "bfloat16", "float32"],), 52 | }, 53 | } 54 | 55 | RETURN_TYPES = ("MODEL",) 56 | FUNCTION = "checkpoint_mergekit" 57 | CATEGORY = "LoRA PowerMerge" 58 | 59 | @torch.no_grad() 60 | def checkpoint_mergekit(self, method: Dict = None, base_model=None, model1=None, lambda_=None, 61 | device=None, dtype=None, **kwargs): 62 | models = [model1] 63 | for k, v in kwargs.items(): 64 | models.append(v) 65 | 66 | if method['name'] == "linear": 67 | merge_method = linear_merge 68 | elif method['name'] == "model_stock": 69 | merge_method = model_stock_merge 70 | elif method['name'] == "nuslerp": 71 | merge_method = nuslerp_merge 72 | elif method['name'] == "nearswap": 73 | merge_method = nearswap_merge_ 74 | elif method['name'] == "task_arithmetic": 75 | merge_method = task_arithmetic 76 | else: 77 | raise Exception(f"Invalid / unsupported method {method['name']}") 78 | 79 | method_args = { 80 | "normalize": False, # LinearMerge & Task Arithmetic 81 | #"max_iter": 10, # kArcher only 82 | #"tol": 0.1, # kArcher only 83 | "row_wise": False, # NuSlerp only 84 | "flatten": True, # NuSlerp only 85 | "similarity_threshold": 0.5, # Nearswap only 86 | "int8_mask": False 87 | } 88 | # update method_args with dictionary method['settings'] 89 | method_args.update(method['settings']) 90 | 91 | return checkpoint_process(merge_method=merge_method, method_args=method_args, base_model=base_model, 92 | models=models, lambda_=lambda_, device=device, dtype=dtype) 93 | 94 | 95 | def checkpoint_process(merge_method: MergeMethod, method_args: Dict, base_model, models, device, dtype, lambda_=1.0): 96 | device, dtype = map_device(device, dtype) 97 | 98 | keys_base_model = base_model.get_key_patches("diffusion_model.") 99 | 100 | pbar = comfy.utils.ProgressBar(len(keys_base_model)) 101 | for key, v in keys_base_model.items(): 102 | # We need to get the model up and down tensors 103 | 
base_model_tensor = v[0][0].to(device=device, dtype=dtype) 104 | base_model_ref = ModelReference(model=ModelPath(path=key)) 105 | 106 | tensors = {base_model_ref: base_model_tensor} 107 | tensor_weights = {base_model_ref: 0} 108 | for i, model in enumerate(models): 109 | ref = ModelReference(model=ModelPath(path=str(i))) 110 | key_weight, _, _ = comfy.model_patcher.get_key_weight(model.model, key) 111 | tensors[ref] = key_weight.to(device=device, dtype=dtype) 112 | tensor_weights[ref] = 1 113 | 114 | # Wrap into mergekit data structure 115 | immutable_map = ImmutableMap({ 116 | r: WeightInfo(name=f'model{i}.{key}', dtype=dtype) for i, r in enumerate(tensors.keys()) 117 | }) 118 | gather_tensors = GatherTensors(weight_info=immutable_map) 119 | weight_info = WeightInfo(name='base.' + key, dtype=dtype) 120 | 121 | tensor_parameters = ImmutableMap({r: ImmutableMap({"weight": tensor_weights[r]}) for r in tensors.keys()}) 122 | 123 | merge = merge_method(tensors=tensors, base_model=base_model_ref, weight_info=weight_info, 124 | gather_tensors=gather_tensors, tensor_parameters=tensor_parameters, method_args=method_args) 125 | 126 | # Apply lambda_ to the merge 127 | merge *= lambda_ 128 | 129 | # pass the merged tensor to the new model 130 | base_model.add_patches({key: (merge.to(device="cpu"),)}, 1, 0) 131 | 132 | for tv in tensors.values(): 133 | tv.to(device="cpu") 134 | 135 | pbar.update(1) 136 | 137 | return (base_model,) 138 | 139 | 140 | def linear_merge( 141 | *, 142 | tensors: Dict[ModelReference, torch.Tensor], 143 | base_model: ModelReference, 144 | weight_info: WeightInfo, 145 | gather_tensors: GatherTensors, 146 | tensor_parameters: ImmutableMap[ModelReference, Any], 147 | method_args: Optional[Dict] = None 148 | ) -> torch.Tensor: 149 | """ 150 | Merges tensors using a linear merge strategy, excluding the base model from the merge. 
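
    With the weights set up in checkpoint_process (weight 0 for the base model,
    weight 1 for every other model), the effective result is the sum of the
    non-base tensors, or their average when method_args["normalize"] is true
    (a reading of the surrounding code, not a guarantee of mergekit internals).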
151 | """ 152 | 153 | method_args = method_args or {} 154 | 155 | # Exclude base_model from tensors 156 | tensors = {k: v for k, v in tensors.items() if k != base_model} 157 | 158 | task = LinearMergeTask( 159 | base_model=base_model, 160 | weight_info=weight_info, 161 | gather_tensors=gather_tensors, 162 | tensor_parameters=tensor_parameters, 163 | normalize=method_args.get("normalize", False) 164 | ) 165 | 166 | return task.execute(tensors=tensors) 167 | 168 | 169 | def karcher_merge( 170 | *, 171 | tensors: Dict[ModelReference, torch.Tensor], 172 | base_model: ModelReference, 173 | weight_info: WeightInfo, 174 | gather_tensors: GatherTensors, 175 | tensor_parameters: Optional[ImmutableMap[ModelReference, Any]] = None, 176 | method_args: Optional[Dict] = None, 177 | ) -> torch.Tensor: 178 | method_args = method_args or {} 179 | 180 | # Extract base tensor and remove it from the dictionary 181 | tensors.pop(base_model) 182 | 183 | task = KarcherTask( 184 | base_model=base_model, 185 | weight_info=weight_info, 186 | gather_tensors=gather_tensors, 187 | max_iter=method_args.get('max_iter', 10), 188 | tol=method_args.get('tol', 1e-4), 189 | ) 190 | return task.execute(tensors=tensors) 191 | 192 | 193 | def task_arithmetic( 194 | *, 195 | tensors: Dict[ModelReference, torch.Tensor], 196 | base_model: ModelReference, 197 | weight_info: WeightInfo, 198 | gather_tensors: GatherTensors, 199 | tensor_parameters: ImmutableMap[ModelReference, Any], 200 | method_args: Optional[Dict] = None, 201 | ) -> torch.Tensor: 202 | method_args = method_args or {} 203 | 204 | method = REGISTERED_MERGE_METHODS.get("task_arithmetic") 205 | 206 | task = GTATask( 207 | method=method, 208 | base_model=base_model, 209 | weight_info=weight_info, 210 | tensors=gather_tensors, 211 | tensor_parameters=tensor_parameters, 212 | int8_mask=method_args.get('int8_mask', False), 213 | normalize=method_args.get('normalize', True), 214 | lambda_=1.0, 215 | rescale_norm=None 216 | ) 217 | return task.execute(tensors=tensors) 218 | 219 | 220 | def nearswap_merge_( 221 | *, 222 | tensors: Dict[ModelReference, torch.Tensor], 223 | base_model: ModelReference, 224 | weight_info: WeightInfo, 225 | gather_tensors: GatherTensors, 226 | tensor_parameters: ImmutableMap[ModelReference, Any], 227 | method_args: Optional[Dict] = None, 228 | ) -> torch.Tensor: 229 | method_args = method_args or {} 230 | 231 | # Extract base tensor and remove it from the dictionary 232 | base_tensor = tensors.pop(base_model) 233 | other_tensors = list(tensors.values()) # Must be length 1 234 | 235 | return nearswap_merge( 236 | base_tensor=base_tensor, 237 | tensors=other_tensors, 238 | t=method_args.get('similarity_threshold', 0.9) 239 | ) 240 | 241 | 242 | def model_stock_merge( 243 | *, 244 | tensors: Dict[ModelReference, torch.Tensor], 245 | base_model: ModelReference, 246 | weight_info: WeightInfo, 247 | gather_tensors: GatherTensors, 248 | tensor_parameters: Optional[ImmutableMap[ModelReference, Any]] = None, 249 | method_args: Optional[Dict] = None, 250 | ) -> torch.Tensor: 251 | task = ModelStockMergeTask( 252 | base_model=base_model, 253 | weight_info=weight_info, 254 | gather_tensors=gather_tensors 255 | ) 256 | return task.execute(tensors=tensors) 257 | 258 | 259 | def nuslerp_merge( 260 | *, 261 | tensors: Dict[ModelReference, torch.Tensor], 262 | base_model: ModelReference, 263 | weight_info: WeightInfo, 264 | gather_tensors: GatherTensors, 265 | tensor_parameters: Optional[ImmutableMap[ModelReference, Any]] = None, 266 | method_args: Optional[Dict] 
= None, 267 | ) -> torch.Tensor: 268 | method_args = method_args or {} 269 | 270 | task = NuSlerpTask( 271 | base_model=base_model, 272 | weight_info=weight_info, 273 | gather_tensors=gather_tensors, 274 | tensor_parameters=tensor_parameters, 275 | row_wise=method_args.get('row_wise', False), 276 | flatten=method_args.get('flatten', False) 277 | ) 278 | return task.execute(tensors=tensors) 279 | 280 | -------------------------------------------------------------------------------- /tests/test_decomposition.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for decomposition module. 3 | 4 | Tests tensor decomposition functionality including SVD, QR, and error handling. 5 | """ 6 | 7 | import pytest 8 | import torch 9 | import sys 10 | from pathlib import Path 11 | 12 | # Add src to path for imports 13 | sys.path.insert(0, str(Path(__file__).parent.parent / "src")) 14 | 15 | from decomposition import ( 16 | SVDDecomposer, 17 | RandomizedSVDDecomposer, 18 | EnergyBasedRandomizedSVDDecomposer, 19 | QRDecomposer, 20 | SingularValueDistribution, 21 | ) 22 | 23 | 24 | class TestSVDDecomposer: 25 | """Tests for standard SVD decomposer.""" 26 | 27 | def test_basic_2d_decomposition(self): 28 | """Test basic 2D tensor decomposition.""" 29 | decomposer = SVDDecomposer() 30 | weight = torch.randn(100, 50) 31 | target_rank = 10 32 | 33 | up, down, alpha, stats = decomposer.decompose( 34 | weight, target_rank, return_statistics=True 35 | ) 36 | 37 | # Check shapes 38 | assert up.shape == (100, 10) 39 | assert down.shape == (10, 50) 40 | assert isinstance(alpha, float) 41 | 42 | # Check reconstruction quality 43 | reconstructed = up @ down 44 | reconstruction_error = torch.norm(weight - reconstructed) / torch.norm(weight) 45 | assert reconstruction_error < 0.5 # Should be reasonable approximation 46 | 47 | def test_4d_conv_decomposition(self): 48 | """Test 4D convolutional tensor decomposition.""" 49 | decomposer = SVDDecomposer() 50 | weight = torch.randn(64, 32, 3, 3) # Conv layer 51 | target_rank = 16 52 | 53 | up, down, alpha, _ = decomposer.decompose(weight, target_rank) 54 | 55 | # Check shapes 56 | assert up.shape == (64, 16, 1, 1) 57 | assert down.shape == (16, 32, 3, 3) 58 | 59 | def test_symmetric_vs_asymmetric_distribution(self): 60 | """Test different singular value distributions.""" 61 | weight = torch.randn(50, 30) 62 | target_rank = 10 63 | 64 | # Symmetric distribution 65 | decomposer_sym = SVDDecomposer( 66 | distribution=SingularValueDistribution.SYMMETRIC 67 | ) 68 | up_sym, down_sym, _, _ = decomposer_sym.decompose(weight, target_rank) 69 | 70 | # Asymmetric distribution 71 | decomposer_asym = SVDDecomposer( 72 | distribution=SingularValueDistribution.ASYMMETRIC 73 | ) 74 | up_asym, down_asym, _, _ = decomposer_asym.decompose(weight, target_rank) 75 | 76 | # Both should reconstruct similarly but with different scaling 77 | recon_sym = up_sym @ down_sym 78 | recon_asym = up_asym @ down_asym 79 | 80 | assert torch.allclose(recon_sym, recon_asym, rtol=1e-4) 81 | 82 | def test_statistics_calculation(self): 83 | """Test that statistics are calculated correctly.""" 84 | decomposer = SVDDecomposer(return_statistics=True) 85 | weight = torch.randn(80, 40) 86 | target_rank = 20 87 | 88 | _, _, _, stats = decomposer.decompose(weight, target_rank) 89 | 90 | assert stats is not None 91 | assert 'new_rank' in stats 92 | assert 'new_alpha' in stats 93 | assert 'sum_retained' in stats 94 | assert 'fro_retained' in stats 95 | assert 'max_ratio' 
in stats 96 | 97 | # Check statistics are reasonable 98 | assert stats['new_rank'] == 20 99 | assert 0.0 <= stats['sum_retained'] <= 1.0 100 | assert 0.0 <= stats['fro_retained'] <= 1.0 101 | 102 | def test_dynamic_rank_selection_ratio(self): 103 | """Test dynamic rank selection by singular value ratio.""" 104 | decomposer = SVDDecomposer(return_statistics=True) 105 | weight = torch.randn(100, 50) 106 | 107 | _, _, _, stats = decomposer.decompose( 108 | weight, 109 | target_rank=50, 110 | dynamic_method="sv_ratio", 111 | dynamic_param=100.0 # Ratio threshold 112 | ) 113 | 114 | # Rank should be selected based on ratio 115 | assert stats['new_rank'] <= 50 116 | assert stats['new_rank'] >= 1 117 | 118 | def test_dynamic_rank_selection_cumulative(self): 119 | """Test dynamic rank selection by cumulative singular values.""" 120 | decomposer = SVDDecomposer(return_statistics=True) 121 | weight = torch.randn(100, 50) 122 | 123 | _, _, _, stats = decomposer.decompose( 124 | weight, 125 | target_rank=50, 126 | dynamic_method="sv_cumulative", 127 | dynamic_param=0.95 # 95% of cumulative sum 128 | ) 129 | 130 | # Rank should be selected to capture 95% of singular values 131 | assert stats['sum_retained'] >= 0.90 # Allow some tolerance 132 | 133 | def test_dynamic_rank_selection_frobenius(self): 134 | """Test dynamic rank selection by Frobenius norm.""" 135 | decomposer = SVDDecomposer(return_statistics=True) 136 | weight = torch.randn(100, 50) 137 | 138 | _, _, _, stats = decomposer.decompose( 139 | weight, 140 | target_rank=50, 141 | dynamic_method="sv_fro", 142 | dynamic_param=0.99 # 99% of Frobenius norm 143 | ) 144 | 145 | # Rank should be selected to retain 99% of Frobenius norm 146 | assert stats['fro_retained'] >= 0.95 # Allow some tolerance 147 | 148 | 149 | class TestRandomizedSVDDecomposer: 150 | """Tests for randomized SVD decomposer.""" 151 | 152 | def test_randomized_svd_approximation(self): 153 | """Test that randomized SVD produces good approximation.""" 154 | weight = torch.randn(200, 100) 155 | target_rank = 20 156 | 157 | # Standard SVD 158 | decomposer_std = SVDDecomposer() 159 | up_std, down_std, _, _ = decomposer_std.decompose(weight, target_rank) 160 | recon_std = up_std @ down_std 161 | 162 | # Randomized SVD 163 | decomposer_rand = RandomizedSVDDecomposer(n_oversamples=10, n_iter=2) 164 | up_rand, down_rand, _, _ = decomposer_rand.decompose(weight, target_rank) 165 | recon_rand = up_rand @ down_rand 166 | 167 | # Reconstructions should be similar 168 | error_std = torch.norm(weight - recon_std) 169 | error_rand = torch.norm(weight - recon_rand) 170 | 171 | # Randomized should be close to standard (within 50% relative error) 172 | assert abs(error_rand - error_std) / error_std < 0.5 173 | 174 | def test_randomized_svd_small_matrix(self): 175 | """Test that small matrices fall back to standard SVD.""" 176 | decomposer = RandomizedSVDDecomposer() 177 | weight = torch.randn(50, 30) # Small matrix 178 | target_rank = 10 179 | 180 | # Should not raise error 181 | up, down, alpha, _ = decomposer.decompose(weight, target_rank) 182 | 183 | assert up.shape == (50, 10) 184 | assert down.shape == (10, 30) 185 | 186 | 187 | class TestEnergyBasedRandomizedSVDDecomposer: 188 | """Tests for energy-based randomized SVD.""" 189 | 190 | def test_energy_based_rank_selection(self): 191 | """Test that energy threshold affects rank selection.""" 192 | weight = torch.randn(100, 50) 193 | 194 | # Low energy threshold (fewer components) 195 | decomposer_low = EnergyBasedRandomizedSVDDecomposer( 196 | 
energy_threshold=0.8, 197 | return_statistics=True 198 | ) 199 | _, _, _, stats_low = decomposer_low.decompose(weight, target_rank=50) 200 | 201 | # High energy threshold (more components) 202 | decomposer_high = EnergyBasedRandomizedSVDDecomposer( 203 | energy_threshold=0.99, 204 | return_statistics=True 205 | ) 206 | _, _, _, stats_high = decomposer_high.decompose(weight, target_rank=50) 207 | 208 | # Higher threshold should generally use more components 209 | # (though not guaranteed due to randomness) 210 | assert stats_low is not None 211 | assert stats_high is not None 212 | 213 | 214 | class TestQRDecomposer: 215 | """Tests for QR decomposer.""" 216 | 217 | def test_qr_decomposition(self): 218 | """Test basic QR decomposition.""" 219 | decomposer = QRDecomposer() 220 | weight = torch.randn(100, 50) 221 | target_rank = 20 222 | 223 | up, down, alpha, _ = decomposer.decompose(weight, target_rank) 224 | 225 | # Check shapes 226 | assert up.shape == (100, 20) 227 | assert down.shape == (20, 50) 228 | 229 | # QR decomposition should still provide reasonable approximation 230 | reconstructed = up @ down 231 | reconstruction_error = torch.norm(weight - reconstructed) / torch.norm(weight) 232 | assert reconstruction_error < 1.0 # Looser bound than SVD 233 | 234 | 235 | class TestErrorHandling: 236 | """Tests for error handling in decomposition.""" 237 | 238 | def test_invalid_tensor_dimensions(self): 239 | """Test that invalid tensor dimensions raise errors.""" 240 | decomposer = SVDDecomposer() 241 | weight = torch.randn(10) # 1D tensor (invalid) 242 | 243 | with pytest.raises(ValueError, match="must be 2D, 3D, or 4D"): 244 | decomposer.decompose(weight, target_rank=5) 245 | 246 | def test_zero_matrix_handling(self): 247 | """Test handling of numerically zero matrices.""" 248 | decomposer = SVDDecomposer(return_statistics=True) 249 | weight = torch.zeros(50, 30) 250 | 251 | # Should not crash, should handle gracefully 252 | up, down, alpha, stats = decomposer.decompose(weight, target_rank=10) 253 | 254 | # Rank should be minimal for zero matrix 255 | assert stats['new_rank'] == 1 256 | 257 | def test_invalid_dynamic_method(self): 258 | """Test that invalid dynamic method raises error.""" 259 | decomposer = SVDDecomposer() 260 | weight = torch.randn(50, 30) 261 | 262 | with pytest.raises(ValueError, match="Unknown dynamic method"): 263 | decomposer.decompose( 264 | weight, 265 | target_rank=10, 266 | dynamic_method="invalid_method" 267 | ) 268 | 269 | 270 | # Fixtures 271 | 272 | @pytest.fixture 273 | def sample_2d_weight(): 274 | """Fixture providing sample 2D weight tensor.""" 275 | return torch.randn(100, 50) 276 | 277 | 278 | @pytest.fixture 279 | def sample_4d_weight(): 280 | """Fixture providing sample 4D convolutional weight.""" 281 | return torch.randn(64, 32, 3, 3) 282 | 283 | 284 | class TestIntegrationWithFixtures: 285 | """Integration tests using fixtures.""" 286 | 287 | def test_all_decomposers_with_2d(self, sample_2d_weight): 288 | """Test all decomposers work with 2D tensors.""" 289 | decomposers = [ 290 | SVDDecomposer(), 291 | RandomizedSVDDecomposer(), 292 | EnergyBasedRandomizedSVDDecomposer(), 293 | QRDecomposer(), 294 | ] 295 | 296 | for decomposer in decomposers: 297 | up, down, alpha, _ = decomposer.decompose( 298 | sample_2d_weight, 299 | target_rank=20 300 | ) 301 | 302 | assert up.shape[0] == 100 303 | assert up.shape[1] == 20 304 | assert down.shape[0] == 20 305 | assert down.shape[1] == 50 306 | 307 | def test_all_decomposers_with_4d(self, sample_4d_weight): 308 
| """Test all decomposers work with 4D conv tensors.""" 309 | decomposers = [ 310 | SVDDecomposer(), 311 | RandomizedSVDDecomposer(), 312 | QRDecomposer(), 313 | ] 314 | 315 | for decomposer in decomposers: 316 | up, down, alpha, _ = decomposer.decompose( 317 | sample_4d_weight, 318 | target_rank=16 319 | ) 320 | 321 | assert up.shape == (64, 16, 1, 1) 322 | assert down.shape == (16, 32, 3, 3) 323 | 324 | 325 | if __name__ == "__main__": 326 | pytest.main([__file__, "-v"]) 327 | -------------------------------------------------------------------------------- /src/lora_mergekit_merge.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | from typing import Dict 5 | 6 | import torch 7 | import comfy 8 | import comfy.model_management 9 | from comfy.weight_adapter import LoRAAdapter 10 | 11 | from mergekit.architecture import WeightInfo 12 | from mergekit.common import ModelReference, ImmutableMap, ModelPath 13 | from mergekit.io.tasks import GatherTensors 14 | 15 | from .architectures.sd_lora import weights_as_tuple 16 | 17 | # Import merge module components 18 | from .merge import ( 19 | create_map, 20 | create_tensor_param, 21 | get_merge_method, 22 | prepare_method_args, 23 | ) 24 | from .mergekit_utils import load_on_device 25 | # Import centralized types 26 | from .types import ( 27 | LORA_WEIGHTS, 28 | LORA_TENSOR_DICT, 29 | LORA_TENSORS_BY_LAYER, 30 | MergeMethod, 31 | ) 32 | from .utility import map_device 33 | # Import validation components 34 | from .validation import validate_tensor_shapes_compatible 35 | 36 | 37 | # Helper functions moved to src/merge/utils.py 38 | # Imported above for backward compatibility 39 | 40 | 41 | class LoraMergerMergekit: 42 | """ 43 | Node for merging LoRA models with Mergekit 44 | """ 45 | 46 | def __init__(self): 47 | self.components: LORA_TENSORS_BY_LAYER = {} 48 | self.strengths: LORA_WEIGHTS = {} 49 | 50 | @classmethod 51 | def INPUT_TYPES(s): 52 | return { 53 | "required": { 54 | "method": ("MergeMethod",), 55 | "components": ("LoRATensors", {"tooltip": "The decomposed components of the LoRAs to be merged."}), 56 | "strengths": ("LoRAWeights", {"tooltip": "The weights of the LoRAs to be merged."}), 57 | "lambda_": ("FLOAT", { 58 | "default": 1, 59 | "min": 0, 60 | "max": 1, 61 | "step": 0.01, 62 | "tooltip": "Lambda value for scaling the merged model.", 63 | }), 64 | "device": (["cuda", "cpu"],), 65 | "dtype": (["float16", "bfloat16", "float32"],), 66 | }, 67 | } 68 | 69 | RETURN_TYPES = ("LoRABundle", "MergeContext") 70 | RETURN_NAMES = ("lora", "merge_context") 71 | FUNCTION = "lora_mergekit" 72 | CATEGORY = "LoRA PowerMerge" 73 | DESCRIPTION = """Core LoRA merger using Mergekit algorithms. 74 | 75 | Merges decomposed LoRA components using the selected merge method (TIES, DARE, SLERP, etc.). Processes layers in parallel using ThreadPoolExecutor for performance. 
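Conceptually, each layer key is merged independently: the up matrices of all LoRAs are merged with the selected method, the down matrices likewise, and the first LoRA's alpha is carried over, so the merged adapter still reconstructs its weight delta as roughly (alpha / rank) * up @ down.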
76 | 77 | Inputs: 78 | - method: Merge algorithm configuration from method nodes 79 | - components: Decomposed LoRA tensors from LoRA Decompose node 80 | - strengths: Per-LoRA weight multipliers (strength_model for UNet, strength_clip for CLIP) 81 | - lambda_: Global scaling factor applied to final merged result (0-1) 82 | 83 | Outputs: 84 | - LoRABundle: Merged LoRA ready for application or saving 85 | - MergeContext: Reusable merge configuration for batch operations""" 86 | 87 | @torch.no_grad() 88 | def lora_mergekit(self, 89 | method: Dict = None, 90 | components: LORA_TENSORS_BY_LAYER = None, 91 | strengths: LORA_WEIGHTS = None, 92 | lambda_: float = 1.0, 93 | device=None, dtype=None): 94 | 95 | if components is None: 96 | raise Exception("No components provided for merging.") 97 | 98 | self.components = components 99 | self.strengths = strengths 100 | 101 | device, dtype = map_device(device, dtype) 102 | 103 | # Use dispatcher to get merge method 104 | merge_method = get_merge_method(method['name']) 105 | 106 | # Prepare method arguments 107 | method_args = prepare_method_args(method['name'], method['settings']) 108 | 109 | self.validate_input() 110 | 111 | # Adjust components to match the method requirements 112 | merge = self.merge(method=merge_method, method_args=method_args, lambda_=lambda_, device=device, dtype=dtype) 113 | # Clean up VRAM 114 | torch.cuda.empty_cache() 115 | 116 | # Create merge context for downstream nodes 117 | merge_context = { 118 | "method": method, 119 | "components": components, 120 | "strengths": strengths, 121 | "lambda_": lambda_, 122 | "device": device, 123 | "dtype": dtype 124 | } 125 | 126 | return merge + (merge_context,) 127 | 128 | def merge(self, method: MergeMethod, method_args, lambda_, device, dtype): 129 | pbar = comfy.utils.ProgressBar(len(self.components.keys())) 130 | start = time.time() 131 | 132 | def process_key(key): 133 | lora_key_tuples: LORA_TENSOR_DICT = self.components[key] 134 | 135 | weights = [self.strengths[lora_name]["strength_model"] for lora_name in lora_key_tuples.keys()] 136 | weights = torch.tensor(weights, dtype=dtype).to(device=device) 137 | 138 | def calculate(tensors_): 139 | tensor_map = {} 140 | tensor_weight_map = {} 141 | weight_info = WeightInfo(name=f'{key}.merge', dtype=dtype, is_embed=False) 142 | for i, t in enumerate(tensors_): 143 | ref = ModelReference(model=ModelPath(path=f'{key}.{i}')) 144 | tensor_map[ref] = t 145 | tensor_weight_map[ref] = weights[i] 146 | 147 | gather_tensors = GatherTensors(weight_info=create_map(key, tensor_map, dtype)) 148 | tensor_parameters = ImmutableMap( 149 | {r: ImmutableMap(create_tensor_param(tensor_weight_map[r], method_args)) for r in 150 | tensor_map.keys()}) 151 | 152 | # Load to the device 153 | load_on_device(tensor_map, tensor_weight_map, device, dtype) 154 | 155 | # Call the merge method 156 | out = method(tensor_map, gather_tensors, weight_info, tensor_parameters, method_args) 157 | 158 | # Apply lambda scaling 159 | if lambda_ < 1.0: 160 | out = out * lambda_ 161 | 162 | # Offload the result to CPU 163 | load_on_device(tensor_map, tensor_weight_map, "cpu", dtype) 164 | out = out.to(device='cpu', dtype=torch.float32) 165 | return out 166 | 167 | # Extract up and down tensors 168 | up_tensors = [u for u, _, _ in lora_key_tuples.values()] 169 | down_tensors = [d for _, d, _ in lora_key_tuples.values()] 170 | 171 | # Debug logging 172 | if len(up_tensors) > 0: 173 | logging.debug(f"Key {key}: up_tensors[0] shape = {up_tensors[0].shape}") 174 | if len(down_tensors) 
> 0: 175 | logging.debug(f"Key {key}: down_tensors[0] shape = {down_tensors[0].shape}") 176 | 177 | up = calculate(up_tensors) 178 | down = calculate(down_tensors) 179 | alpha_0 = next(iter(lora_key_tuples.values()))[2] 180 | 181 | # Debug logging for results 182 | logging.debug(f"Key {key}: merged up shape = {up.shape}, merged down shape = {down.shape}") 183 | 184 | # Sanity check: up should be (out_features, rank) and down should be (rank, in_features) 185 | # For LoRA to work, up.shape[1] should equal down.shape[0] (the rank dimension) 186 | # If they're swapped, we'll detect it here 187 | if len(up.shape) == 2 and len(down.shape) == 2: 188 | # Check if up and down appear to be swapped 189 | # up should have more elements in dim 0 than dim 1 (tall matrix) 190 | # down should have more elements in dim 1 than dim 0 (wide matrix) 191 | up_is_tall = up.shape[0] > up.shape[1] 192 | down_is_wide = down.shape[1] > down.shape[0] 193 | 194 | # If up is NOT tall or down is NOT wide, they might be swapped 195 | if not up_is_tall and down_is_wide: 196 | # up appears to be a wide matrix, might need transpose 197 | logging.warning(f"Key {key}: up tensor appears to be transposed (shape {up.shape}). This may cause issues.") 198 | if up_is_tall and not down_is_wide: 199 | # down appears to be a tall matrix, might need transpose 200 | logging.warning(f"Key {key}: down tensor appears to be transposed (shape {down.shape}). This may cause issues.") 201 | 202 | return key, (up, down, alpha_0) 203 | 204 | adapter_state_dict = {} 205 | with ThreadPoolExecutor(max_workers=8) as executor: 206 | keys = self.components.keys() 207 | 208 | # distribute the work across available devices 209 | futures = {executor.submit(process_key, key) for key in keys} 210 | 211 | for future in as_completed(futures): 212 | key, result = future.result() 213 | if result: 214 | up, down, alpha_0 = result 215 | adapter_state_dict[key] = LoRAAdapter(weights=weights_as_tuple(up, down, alpha_0), 216 | loaded_keys=set(keys)) 217 | pbar.update(1) 218 | 219 | logging.info(f"Processed {len(keys)} keys in {time.time() - start:.2f} seconds") 220 | 221 | lora_out = {"lora": adapter_state_dict, "strength_model": 1, "name": "Merge"} 222 | return (lora_out,) 223 | 224 | def validate_input(self): 225 | """ 226 | Validate input parameters for merge operation. 227 | 228 | Performs comprehensive validation of: 229 | - Component tensors (shapes, compatibility) 230 | - Strength values (presence, reasonable ranges) 231 | 232 | Logs warnings for potential issues and raises exceptions for 233 | critical errors that would cause merge to fail. 
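
        For example, a LoRA that appears in components but has no entry in
        strengths fails with a message of the form:
            ValueError: Validation failed with 1 error(s):
              - Missing strength for LoRA '<name>'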
234 | 235 | Raises: 236 | ValueError: If critical validation errors are found 237 | """ 238 | errors = [] 239 | warnings = [] 240 | 241 | # Validate components exist and are not empty 242 | if not self.components: 243 | raise ValueError("No components provided for merging") 244 | 245 | if len(self.components) == 0: 246 | raise ValueError("Components dictionary is empty") 247 | 248 | # Validate strengths exist 249 | if not self.strengths: 250 | raise ValueError("No strengths provided for merging") 251 | 252 | # Validate tensor shapes are compatible 253 | validation_result = validate_tensor_shapes_compatible(self.components) 254 | 255 | # Log validation warnings 256 | for warning in validation_result["warnings"]: 257 | logging.warning(f"Validation warning: {warning}") 258 | warnings.append(warning) 259 | 260 | # Collect validation errors 261 | for error in validation_result["errors"]: 262 | error_msg = f"{error['code']}: {error['message']}" 263 | if error.get('location'): 264 | error_msg += f" (at {error['location']})" 265 | errors.append(error_msg) 266 | logging.error(f"Validation error: {error_msg}") 267 | 268 | # Validate all LoRAs in components have corresponding strengths 269 | lora_names = set() 270 | for layer_tensors in self.components.values(): 271 | lora_names.update(layer_tensors.keys()) 272 | 273 | for lora_name in lora_names: 274 | if lora_name not in self.strengths: 275 | error_msg = f"Missing strength for LoRA '{lora_name}'" 276 | errors.append(error_msg) 277 | logging.error(f"Validation error: {error_msg}") 278 | 279 | # Raise exception if critical errors found 280 | if errors: 281 | error_summary = "\n".join(f" - {e}" for e in errors) 282 | raise ValueError( 283 | f"Validation failed with {len(errors)} error(s):\n{error_summary}" 284 | ) 285 | 286 | 287 | # ============================================================================ 288 | # Algorithm implementations moved to src/merge/algorithms.py 289 | # The functions below have been extracted to the merge module for better organization 290 | # ============================================================================ 291 | --------------------------------------------------------------------------------
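To make the layer-wise merge above concrete without the mergekit plumbing, the following is a minimal, self-contained sketch of a linear per-key LoRA merge in plain PyTorch. It mirrors what LoraMergerMergekit does for a single layer key with the linear method; the shapes, tensor names, and weights here are illustrative assumptions, not values taken from this repository.

import torch

def linear_merge_lora(loras, weights):
    # loras: list of (up, down, alpha) triples with identical shapes and ranks
    # weights: one strength per LoRA, analogous to the node's strengths input
    total = sum(weights)
    # Up and down factors are merged independently, as in LoraMergerMergekit
    up = sum(w * u for w, (u, _, _) in zip(weights, loras)) / total
    down = sum(w * d for w, (_, d, _) in zip(weights, loras)) / total
    alpha = loras[0][2]  # the node carries over the first LoRA's alpha
    return up, down, alpha

# Illustrative rank-8 LoRA factors for a 640 -> 320 linear layer
rank = 8
lora_a = (torch.randn(320, rank), torch.randn(rank, 640), 8.0)
lora_b = (torch.randn(320, rank), torch.randn(rank, 640), 8.0)

up, down, alpha = linear_merge_lora([lora_a, lora_b], weights=[1.0, 0.5])
delta_w = (alpha / rank) * (up @ down)  # effective weight delta of the merged LoRA
print(delta_w.shape)  # torch.Size([320, 640])

Merging the up and down factors separately keeps the result a valid low-rank adapter of the same rank, which is why the node never has to materialize the full delta matrices during the merge.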