├── demo.jpg
├── method.jpg
├── constraints.txt
├── requirements.txt
├── lora_algebra
│   ├── __init__.py
│   ├── lora_cache.py
│   ├── core.py
│   ├── utils.py
│   ├── analysis.py
│   └── operations.py
├── fixed_files
│   ├── README.md
│   └── flux_merge_lora.py
├── LICENSE
├── .gitignore
├── windows_install.bat
├── lora_algebra_gui.py
├── setup.py
├── update.bat
├── update_launchers.py
├── README.md
└── install.py
/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shootthesound/lora-the-explorer/HEAD/demo.jpg
--------------------------------------------------------------------------------
/method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shootthesound/lora-the-explorer/HEAD/method.jpg
--------------------------------------------------------------------------------
/constraints.txt:
--------------------------------------------------------------------------------
1 | # Version constraints for sd-scripts dependencies
2 | # These override versions specified in sd-scripts/requirements.txt
3 | accelerate==1.8.1
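  4 | 
  5 | # Typically applied alongside the requirements file via pip's constraints flag, e.g.:
  6 | #   pip install -r requirements.txt -c constraints.txt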
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Only dependencies NOT in sd-scripts requirements.txt
2 | gradio>=5.35.0
3 | gradio_client>=1.10.4
4 | torch
5 | torchvision
6 |
7 | # GUI-specific dependencies
8 | pyyaml
9 | tqdm
10 |
11 | # Critical dependencies that might not install from sd-scripts
12 | opencv-python
13 | toml==0.10.2
14 | imagesize==1.4.1
15 |
16 | # LoRA-specific dependencies not in sd-scripts
17 | lycoris-lora==1.8.3
18 | omegaconf
19 | scipy
20 | k-diffusion==0.0.16
21 | peft==0.16.0
22 | # Pin accelerate to prevent dependency resolution conflicts
23 | accelerate==1.8.1
--------------------------------------------------------------------------------
/lora_algebra/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | LoRA Algebra: Advanced LoRA Manipulation Toolkit
3 |
  4 | A toolkit for mathematical operations on Low-Rank Adaptation (LoRA) models:
  5 | subtraction, merging, analysis, and compatibility prediction.
6 | """
7 |
8 | __version__ = "1.0.0"
9 | __author__ = "LoRA Algebra Team"
10 | __license__ = "MIT"
11 |
12 | from .core import LoRAProcessor
13 | from .operations import subtract_loras, merge_loras, analyze_lora
14 | from .analysis import extract_metadata, predict_compatibility
15 |
16 | __all__ = [
17 | "LoRAProcessor",
18 | "subtract_loras",
19 | "merge_loras",
20 | "analyze_lora",
21 | "extract_metadata",
22 | "predict_compatibility"
23 | ]
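 24 | 
 25 | # Usage sketch (the file paths below are hypothetical placeholders):
 26 | #
 27 | #   from lora_algebra import LoRAProcessor, predict_compatibility
 28 | #
 29 | #   processor = LoRAProcessor(sd_scripts_path="sd-scripts")
 30 | #   metadata = processor.extract_metadata("style_lora.safetensors")
 31 | #   report = predict_compatibility("style_lora.safetensors", "character_lora.safetensors")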
--------------------------------------------------------------------------------
/fixed_files/README.md:
--------------------------------------------------------------------------------
1 | # Fixed Files Directory
2 |
3 | This directory contains corrected versions of sd-scripts files that have bugs affecting LoRA the Explorer.
4 |
5 | ## flux_merge_lora.py
6 |
7 | **Issue:** The original file sets `ss_network_module` to `"networks.lora"` instead of `"networks.lora_flux"` for FLUX LoRAs.
8 |
9 | **Impact:** This metadata mismatch prevents FLUX LoRAs created by LoRA the Explorer from merging with LoRAs created by other tools (like FluxGym) even when they have the same rank.
10 |
11 | **Fix:** Line 550 changed from:
12 | ```python
13 | metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None)
14 | ```
15 | to:
16 | ```python
17 | metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora_flux", dims, alphas, None)
18 | ```
19 |
20 | **Applied:** During installation, this corrected file is copied over the sd-scripts version.
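 21 | 
 22 | To verify which network module a given LoRA declares, a minimal sketch using `safetensors` (the file path is a hypothetical placeholder):
 23 | 
 24 | ```python
 25 | from safetensors import safe_open
 26 | 
 27 | with safe_open("my_flux_lora.safetensors", framework="pt", device="cpu") as f:
 28 |     metadata = f.metadata() or {}
 29 |     # A FLUX LoRA merged with the corrected script should report "networks.lora_flux"
 30 |     print(metadata.get("ss_network_module"))
 31 | ```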
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 LoRA Algebra Team
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | pip-wheel-metadata/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
26 | # PyInstaller
27 | *.manifest
28 | *.spec
29 |
30 | # Installer logs
31 | pip-log.txt
32 | pip-delete-this-directory.txt
33 |
34 | # Unit test / coverage reports
35 | htmlcov/
36 | .tox/
37 | .nox/
38 | .coverage
39 | .coverage.*
40 | .cache
41 | nosetests.xml
42 | coverage.xml
43 | *.cover
44 | *.py,cover
45 | .hypothesis/
46 | .pytest_cache/
47 |
48 | # Virtual environments
49 | .env
50 | .venv
51 | env/
52 | venv/
53 | ENV/
54 | env.bak/
55 | venv.bak/
56 |
57 | # IDEs
58 | .vscode/
59 | .idea/
60 | *.swp
61 | *.swo
62 | *~
63 |
64 | # OS
65 | .DS_Store
66 | .DS_Store?
67 | ._*
68 | .Spotlight-V100
69 | .Trashes
70 | ehthumbs.db
71 | Thumbs.db
72 |
73 | # LoRA files and outputs
74 | *.safetensors
75 | output/
76 | outputs/
77 | temp/
78 | backups/
79 |
80 | # Logs
81 | *.log
82 | logs/
83 |
84 | # Model files
85 | models/
86 | checkpoints/
87 | weights/
88 |
89 | # Downloaded dependencies
90 | sd-scripts/
91 | *.zip
92 |
93 | # Data
94 | data/
95 | datasets/
96 | training_data/
97 |
98 | # Jupyter
99 | .ipynb_checkpoints
100 |
101 | # Local config
102 | config.local.yaml
103 | .env.local
--------------------------------------------------------------------------------
/windows_install.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | title LoRA the Explorer - Installation
3 |
4 | echo.
5 | echo ===============================================
6 | echo LoRA the Explorer - Installation
7 | echo ===============================================
8 | echo.
9 |
10 | REM Check if Python is available
11 | python --version >nul 2>&1
12 | if errorlevel 1 (
13 | echo [ERROR] PYTHON NOT FOUND
14 | echo.
15 | echo LoRA the Explorer requires Python 3.8 or higher to run.
16 | echo.
17 | echo NEXT STEPS:
18 | echo 1. Download Python from: https://www.python.org/downloads/
19 | echo 2. During installation, CHECK "Add Python to PATH"
20 | echo 3. Restart your computer after installation
21 | echo 4. Run this installer again
22 | echo.
23 | echo IMPORTANT: You MUST check "Add Python to PATH" during installation
24 | echo or this installer will not work.
25 | echo.
26 | pause
27 | exit /b 1
28 | )
29 |
30 | echo [OK] Python found
31 | python --version
32 | echo.
33 |
34 | echo Starting installation...
35 | echo This will create a virtual environment and install all dependencies
36 | echo.
37 |
38 | REM Run the installer
39 | python install.py
40 |
41 | echo.
42 | if errorlevel 1 (
43 | echo [ERROR] Installation failed!
44 | echo.
45 | echo Please check the error messages above and try again.
46 | echo If problems persist, please report the issue.
47 | echo.
48 | ) else (
49 | echo [SUCCESS] Installation completed successfully!
50 | echo.
51 | echo You can now launch LoRA the Explorer by:
52 | echo - Double-clicking start_gui.bat
53 | echo - Or running: python lora_algebra_gui.py
54 | echo.
55 | )
56 |
57 | echo Press any key to close this window...
58 | pause >nul
--------------------------------------------------------------------------------
/lora_algebra_gui.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | LoRA Algebra GUI Launcher
4 |
5 | Simple launcher script for the LoRA Algebra GUI
6 | """
7 |
8 | import os
9 | import sys
10 | from pathlib import Path
11 |
12 | # Add current directory to Python path
13 | sys.path.insert(0, str(Path(__file__).parent))
14 |
15 | def main():
16 | """Launch the LoRA Algebra GUI"""
17 | try:
18 | from lora_algebra.gui import launch_gui
19 |
20 | # Look for sd-scripts (should be installed locally)
21 | sd_scripts_path = os.path.abspath("sd-scripts")
22 |
23 | if not os.path.exists(sd_scripts_path) or not os.path.exists(os.path.join(sd_scripts_path, "networks")):
24 | print("⚠️ Warning: sd-scripts not found.")
25 | print(" Please run 'python install.py' first to set up dependencies.")
26 | print(" Some features may not work without sd-scripts.")
27 | print()
28 | sd_scripts_path = None
29 |
30 | print("Swiper no swiping...")
31 | print(" This may take a moment to start...")
32 | print()
33 |
34 | # Launch GUI
35 | launch_gui(
36 | sd_scripts_path=sd_scripts_path,
37 | share=False,
38 | inbrowser=True
39 | )
40 |
41 | except ImportError as e:
42 | print("❌ Error: Missing dependencies")
43 | print(f" {e}")
44 | print()
45 | print("Please install requirements:")
46 | print(" pip install -r requirements.txt")
47 | sys.exit(1)
48 |
49 | except Exception as e:
50 | print(f"❌ Error launching GUI: {e}")
51 | sys.exit(1)
52 |
53 | if __name__ == "__main__":
54 | main()
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Setup script for LoRA the Explorer
3 | """
4 |
5 | from setuptools import setup, find_packages
6 |
7 | with open("README.md", "r", encoding="utf-8") as fh:
8 | long_description = fh.read()
9 |
10 | with open("requirements.txt", "r", encoding="utf-8") as fh:
11 | requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")]
12 |
13 | setup(
14 | name="lora-the-explorer",
15 | version="1.0.0",
16 | author="LoRA the Explorer",
17 | author_email="pneill@gmail.com",
18 | description="Advanced FLUX LoRA manipulation toolkit with GUI interface for layer targeting, difference operations, and compatibility fixes",
19 | long_description=long_description,
20 | long_description_content_type="text/markdown",
21 | url="https://github.com/shootthesound/lora-the-explorer",
22 | packages=find_packages(),
23 | classifiers=[
24 | "Development Status :: 5 - Production/Stable",
25 | "Intended Audience :: Developers",
26 | "Intended Audience :: Science/Research",
27 | "License :: OSI Approved :: MIT License",
28 | "Operating System :: OS Independent",
29 | "Programming Language :: Python :: 3",
30 | "Programming Language :: Python :: 3.8",
31 | "Programming Language :: Python :: 3.9",
32 | "Programming Language :: Python :: 3.10",
33 | "Programming Language :: Python :: 3.11",
34 | "Programming Language :: Python :: 3.12",
35 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
36 | "Topic :: Software Development :: Libraries :: Python Modules",
37 | "Topic :: Multimedia :: Graphics",
38 | "Environment :: X11 Applications",
39 | "Environment :: Win32 (MS Windows)",
40 | "Environment :: MacOS X",
41 | ],
42 | python_requires=">=3.8",
43 | install_requires=requirements,
44 | extras_require={
45 | "dev": [
46 | "pytest>=8.0.0",
47 | "pytest-cov>=5.0.0",
48 | "black>=24.0.0",
49 | "ruff>=0.1.0",
50 | "mypy>=1.8.0",
51 | ],
52 | "docs": [
53 | "sphinx>=7.0.0",
54 | "sphinx-rtd-theme>=2.0.0",
55 | ],
56 | },
57 | include_package_data=True,
58 | zip_safe=False,
59 | keywords="lora, flux, stable-diffusion, ai, gui, layer-targeting",
60 | project_urls={
61 | "Bug Reports": "https://github.com/shootthesound/lora-the-explorer/issues",
62 | "Documentation": "https://github.com/shootthesound/lora-the-explorer/blob/main/README.md",
63 | "Source": "https://github.com/shootthesound/lora-the-explorer",
64 | "Support": "https://buymeacoffee.com/loratheexplorer",
65 | },
66 | )
--------------------------------------------------------------------------------
/update.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | setlocal EnableDelayedExpansion
3 | title LoRA the Explorer - Update
4 |
5 | echo.
6 | echo ===============================================
7 | echo LoRA the Explorer - Update
8 | echo ===============================================
9 | echo.
10 |
11 | REM Change to the project directory
12 | cd /d "%~dp0"
13 |
14 | REM Check if git is available
15 | git --version >nul 2>&1
16 | if errorlevel 1 (
17 | echo [ERROR] GIT NOT FOUND
18 | echo.
19 | echo Git is required to update LoRA the Explorer.
20 | echo.
21 | echo NEXT STEPS:
22 | echo 1. Download Git from: https://git-scm.com/download/win
23 | echo 2. Install Git with default settings
24 | echo 3. Restart your computer after installation
25 | echo 4. Run this updater again
26 | echo.
27 | pause
28 | exit /b 1
29 | )
30 |
31 | echo [OK] Git found
32 | git --version
33 | echo.
34 |
35 | REM Check if we're in a git repository
36 | git status >nul 2>&1
37 | if errorlevel 1 (
38 | echo [ERROR] NOT A GIT REPOSITORY
39 | echo.
40 | echo This directory is not a git repository.
41 | echo Cannot perform update.
42 | echo.
43 | pause
44 | exit /b 1
45 | )
46 |
47 | echo [OK] Git repository detected
48 | echo.
49 |
50 | echo Checking for updates...
51 | echo.
52 |
53 | REM Fetch latest changes
54 | git fetch
55 | if errorlevel 1 (
56 | echo [ERROR] Failed to fetch updates from remote repository
57 | echo.
58 | echo Please check your internet connection and try again.
59 | echo.
60 | pause
61 | exit /b 1
62 | )
63 |
64 | REM Check if there are updates available
65 | set UPDATE_AVAILABLE=0
66 |
67 | REM Check if we have an upstream branch configured
68 | git rev-parse --abbrev-ref @{u} >nul 2>&1
69 | if not errorlevel 1 (
70 | REM Count commits behind using rev-list
71 | for /f %%i in ('git rev-list --count HEAD..@{u} 2^>nul') do set BEHIND_COUNT=%%i
72 | if not "!BEHIND_COUNT!"=="0" if not "!BEHIND_COUNT!"=="" (
73 | set UPDATE_AVAILABLE=1
74 | )
75 | )
76 |
77 | if !UPDATE_AVAILABLE!==0 (
78 | echo [INFO] Already up to date!
79 | echo No updates available.
80 | echo.
81 | ) else (
82 | echo [INFO] Updates available. Pulling changes...
83 | echo.
84 |
85 | REM Pull the latest changes
86 | git pull
87 | if errorlevel 1 (
88 | echo [ERROR] Failed to pull updates
89 | echo.
90 | echo There may be local changes conflicting with the update.
91 | echo Please resolve any conflicts manually or contact support.
92 | echo.
93 | ) else (
94 | echo [SUCCESS] Update completed successfully!
95 | echo.
96 | echo Latest changes have been applied to your LoRA the Explorer installation.
97 | echo.
98 |
99 | REM Update launcher scripts to get latest improvements
100 | echo [INFO] Updating launcher scripts...
101 | python update_launchers.py
102 | if errorlevel 1 (
103 | echo [WARNING] Failed to update launcher scripts
104 | echo Your start_gui.bat may not have the latest improvements.
105 | echo You can manually run: python update_launchers.py
106 | echo.
107 | ) else (
108 | echo [OK] Launcher scripts updated
109 | echo.
110 | )
111 | )
112 | )
113 |
114 | echo Press any key to close this window...
115 | pause >nul
--------------------------------------------------------------------------------
/lora_algebra/lora_cache.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple global LoRA cache for autocomplete functionality
3 | """
4 |
5 | import os
6 | import json
7 | from typing import List, Optional, Tuple
8 | from .utils import find_lora_files
9 |
10 | class LoRACache:
11 | """Simple cache to store discovered LoRA file paths"""
12 |
13 | def __init__(self):
14 | self.lora_paths: List[str] = []
15 | self.scan_directory: Optional[str] = None
16 | self.default_output_path: Optional[str] = None
17 | self.settings_file = os.path.join(os.path.expanduser("~"), ".lora_algebra_settings.json")
18 | self.load_settings()
19 |
20 | def scan_directory_for_loras(self, directory: str) -> Tuple[bool, str]:
21 | """Scan a directory for LoRA files and cache the results
22 |
23 | Args:
24 | directory: Directory path to scan
25 |
26 | Returns:
27 | Tuple of (success: bool, message: str)
28 | """
29 | try:
30 | if not os.path.exists(directory):
31 | return False, f"Directory does not exist: {directory}"
32 |
33 | if not os.path.isdir(directory):
34 | return False, f"Path is not a directory: {directory}"
35 |
36 | # Use existing utility function to find LoRA files recursively
37 | found_files = find_lora_files(directory, recursive=True)
38 |
 39 |             # Replace any previously cached paths with the new scan results
 40 |             self.lora_paths = found_files
44 | self.scan_directory = directory
45 |
46 | # Save settings
47 | self.save_settings()
48 |
49 | return True, f"Found {len(found_files)} LoRA files in {directory}"
50 |
51 | except Exception as e:
52 | return False, f"Error scanning directory: {str(e)}"
53 |
54 | def get_matching_loras(self, query: str) -> List[str]:
55 | """Get LoRA paths that match the query string
56 |
57 | Args:
58 | query: Search string to match against
59 |
60 | Returns:
61 | List of matching LoRA file paths
62 | """
63 | if not query or len(query) < 2:
64 | return []
65 |
66 | query_lower = query.lower()
67 | matches = []
68 |
69 | for lora_path in self.lora_paths:
70 | # Extract filename for matching
71 | filename = os.path.basename(lora_path).lower()
72 |
73 | # Simple contains matching
74 | if query_lower in filename:
75 | matches.append(lora_path)
76 |
77 | # Sort matches - exact filename matches first, then contains matches
78 | def sort_key(path):
79 | filename = os.path.basename(path).lower()
80 | if filename.startswith(query_lower):
81 | return (0, filename) # Starts with query - highest priority
82 | else:
83 | return (1, filename) # Contains query - lower priority
84 |
85 | matches.sort(key=sort_key)
86 |
87 | # Limit results to avoid UI overload
88 | return matches[:50]
89 |
90 | def get_cache_info(self) -> dict:
91 | """Get information about the current cache state"""
92 | return {
93 | "total_loras": len(self.lora_paths),
94 | "scan_directory": self.scan_directory,
95 | "has_data": len(self.lora_paths) > 0
96 | }
97 |
98 | def clear_cache(self):
99 | """Clear the cache"""
100 | self.lora_paths = []
101 | self.scan_directory = None
102 |
103 | def save_settings(self):
104 | """Save current settings to file"""
105 | try:
106 | settings = {
107 | "scan_directory": self.scan_directory,
108 | "default_output_path": self.default_output_path
109 | }
110 | with open(self.settings_file, 'w') as f:
111 | json.dump(settings, f, indent=2)
112 | except Exception as e:
113 | print(f"Error saving settings: {e}")
114 |
115 | def load_settings(self):
116 | """Load settings from file"""
117 | try:
118 | if os.path.exists(self.settings_file):
119 | with open(self.settings_file, 'r') as f:
120 | settings = json.load(f)
121 | self.scan_directory = settings.get("scan_directory")
122 | self.default_output_path = settings.get("default_output_path")
123 | except Exception as e:
124 | print(f"Error loading settings: {e}")
125 |
126 | def auto_scan_on_startup(self) -> Tuple[bool, str]:
127 | """Automatically scan the saved directory on startup"""
128 | if self.scan_directory and os.path.exists(self.scan_directory):
129 | return self.scan_directory_for_loras(self.scan_directory)
130 | return False, "No saved directory to scan"
131 |
132 | def set_default_output_path(self, path: str):
133 | """Set and save the default output path"""
134 | self.default_output_path = path
135 | self.save_settings()
136 |
137 | def get_default_output_path(self) -> str:
138 | """Get the default output path or fallback"""
139 | return self.default_output_path or "output"
140 |
141 | # Global cache instance
142 | lora_cache = LoRACache()
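143 | 
144 | # Usage sketch (the directory below is a hypothetical placeholder):
145 | #
146 | #   ok, message = lora_cache.scan_directory_for_loras("/path/to/loras")
147 | #   suggestions = lora_cache.get_matching_loras("char")  # at most 50 matches
148 | #   info = lora_cache.get_cache_info()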
--------------------------------------------------------------------------------
/update_launchers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Update Launcher Scripts for LoRA the Explorer
4 |
5 | This script regenerates the launcher scripts (start_gui.bat/start_gui.sh)
6 | after updates to ensure users get the latest launcher improvements.
7 | """
8 |
9 | import os
10 | import sys
11 | import platform
12 | from pathlib import Path
13 |
14 | def get_python_executable():
15 | """Get the path to Python executable in virtual environment"""
16 | if platform.system() == "Windows":
17 | return Path("env") / "Scripts" / "python.exe"
18 | else:
19 | return Path("env") / "bin" / "python"
20 |
21 | def create_launcher_scripts():
22 | """Create launcher scripts for easy access"""
23 | python_exe = get_python_executable()
24 |
25 | # Check if virtual environment exists
26 | if not python_exe.exists():
27 | print(f"[ERROR] Virtual environment not found at: {python_exe}")
28 | print("Please run install.py first to set up the environment.")
29 | return False
30 |
31 | if platform.system() == "Windows":
32 | # Windows batch file
33 | launcher_content = f"""@echo off
34 | setlocal EnableDelayedExpansion
35 | echo Launching LoRA the Explorer GUI...
36 | echo.
37 |
38 | REM Check for updates if git is available (non-blocking)
39 | git --version >nul 2>&1
40 | if not errorlevel 1 (
41 | git status >nul 2>&1
42 | if not errorlevel 1 (
43 | echo [INFO] Checking for updates...
44 | git fetch >nul 2>&1
45 | if not errorlevel 1 (
46 | REM Check if we have an upstream branch configured
47 | git rev-parse --abbrev-ref @{{u}} >nul 2>&1
48 | if not errorlevel 1 (
49 | REM Count commits behind using rev-list
50 | for /f %%i in ('git rev-list --count HEAD..@{{u}} 2^>nul') do set BEHIND_COUNT=%%i
51 | if not "!BEHIND_COUNT!"=="0" if not "!BEHIND_COUNT!"=="" (
52 | echo.
53 | echo ===============================================
54 | echo UPDATE AVAILABLE!
55 | echo ===============================================
56 | echo.
57 | echo A newer version of LoRA the Explorer is available.
58 | echo Run update.bat to get the latest features and fixes.
59 | echo.
60 | echo Press any key to continue launching the GUI...
61 | pause >nul
62 | echo.
63 | ) else (
64 | echo [OK] You are running the latest version
65 | echo.
66 | )
67 | ) else (
68 | echo [INFO] No upstream branch configured, skipping update check
69 | echo.
70 | )
71 | )
72 | )
73 | )
74 |
75 | echo Starting GUI...
76 | "{python_exe.absolute()}" lora_algebra_gui.py
77 | pause
78 | """
79 | with open("start_gui.bat", "w") as f:
80 | f.write(launcher_content)
81 |
82 | print("[OK] Updated Windows launcher script: start_gui.bat")
83 | return True
84 |
85 | else:
86 | # Unix shell script
87 | launcher_content = f"""#!/bin/bash
88 | echo " Launching LoRA the Explorer GUI..."
89 | echo
90 |
91 | # Check for updates if git is available (non-blocking)
92 | if command -v git >/dev/null 2>&1; then
93 | if git status >/dev/null 2>&1; then
94 | echo "[INFO] Checking for updates..."
95 | if git fetch >/dev/null 2>&1; then
96 | # Check if we have an upstream branch configured
97 | if git rev-parse --abbrev-ref @{{u}} >/dev/null 2>&1; then
98 | # Count commits behind using rev-list
99 | BEHIND_COUNT=$(git rev-list --count HEAD..@{{u}} 2>/dev/null)
100 | if [ "$BEHIND_COUNT" -gt 0 ] 2>/dev/null; then
101 | echo
102 | echo "==============================================="
103 | echo " UPDATE AVAILABLE!"
104 | echo "==============================================="
105 | echo
106 | echo "A newer version of LoRA the Explorer is available."
107 | echo "Run 'git pull' to get the latest features and fixes."
108 | echo
109 | echo "Press any key to continue launching the GUI..."
110 | read -n 1 -s
111 | echo
112 | else
113 | echo "[OK] You are running the latest version"
114 | echo
115 | fi
116 | else
117 | echo "[INFO] No upstream branch configured, skipping update check"
118 | echo
119 | fi
120 | fi
121 | fi
122 | fi
123 |
124 | echo "Starting GUI..."
125 | "{python_exe.absolute()}" lora_algebra_gui.py
126 | """
127 | with open("start_gui.sh", "w") as f:
128 | f.write(launcher_content)
129 | os.chmod("start_gui.sh", 0o755)
130 |
131 | print("[OK] Updated Unix launcher script: start_gui.sh")
132 | return True
133 |
134 | def main():
135 | """Main function to update launcher scripts"""
136 | print("Updating LoRA the Explorer launcher scripts...")
137 |
138 | try:
139 | success = create_launcher_scripts()
140 | if success:
141 | print("[SUCCESS] Launcher scripts updated successfully!")
142 | else:
143 | print("[ERROR] Failed to update launcher scripts")
144 | sys.exit(1)
145 | except Exception as e:
146 | print(f"[ERROR] Failed to update launcher scripts: {e}")
147 | sys.exit(1)
148 |
149 | if __name__ == "__main__":
150 | main()
--------------------------------------------------------------------------------
/lora_algebra/core.py:
--------------------------------------------------------------------------------
1 | """
2 | Core LoRA manipulation functionality
3 | """
4 |
5 | import os
6 | import sys
7 | import subprocess
8 | from pathlib import Path
9 | from typing import Dict, Any, Optional, Tuple
10 | from safetensors import safe_open
11 |
12 | class LoRAProcessor:
13 | """Main class for LoRA manipulation operations"""
14 |
15 | def __init__(self, sd_scripts_path: Optional[str] = None):
16 | """Initialize LoRA processor
17 |
18 | Args:
19 | sd_scripts_path: Path to sd-scripts directory. If None, looks in parent directory.
20 | """
21 | self.sd_scripts_path = self._find_sd_scripts(sd_scripts_path)
22 |
23 | def _find_sd_scripts(self, custom_path: Optional[str] = None) -> str:
24 | """Find sd-scripts directory"""
25 | if custom_path and os.path.exists(custom_path):
26 | return custom_path
27 |
28 | # Look in common locations
29 | possible_paths = [
30 | "sd-scripts",
31 | "../sd-scripts",
32 | "../../sd-scripts",
33 | os.path.join(os.path.dirname(__file__), "..", "..", "sd-scripts")
34 | ]
35 |
36 | for path in possible_paths:
37 | abs_path = os.path.abspath(path)
38 | if os.path.exists(abs_path) and os.path.exists(os.path.join(abs_path, "networks")):
39 | return abs_path
40 |
41 | raise FileNotFoundError("sd-scripts directory not found. Please specify the path manually.")
42 |
43 | def extract_metadata(self, lora_path: str) -> Optional[Dict[str, Any]]:
44 | """Extract metadata from LoRA file
45 |
46 | Args:
47 | lora_path: Path to LoRA .safetensors file
48 |
49 | Returns:
50 | Dictionary containing LoRA metadata or None if extraction fails
51 | """
52 | try:
53 | if not os.path.exists(lora_path):
54 | return None
55 |
56 | metadata = {}
57 |
58 | with safe_open(lora_path, framework="pt", device="cpu") as f:
59 | file_metadata = f.metadata()
60 |
61 | # Extract network dimension (rank)
62 | network_dim = None
63 | network_alpha = None
64 |
65 | # Try to get from metadata first
66 | if file_metadata:
67 | network_dim = file_metadata.get('ss_network_dim')
68 | network_alpha = file_metadata.get('ss_network_alpha')
69 |
70 | # Get other useful metadata
71 | metadata['learning_rate'] = file_metadata.get('ss_learning_rate', '1e-4')
72 | metadata['base_model'] = file_metadata.get('ss_base_model_version', '')
73 | metadata['training_comment'] = file_metadata.get('ss_training_comment', '')
74 |
75 | # If not in metadata, inspect tensor shapes
76 | if network_dim is None:
77 | for key in f.keys():
78 | if 'lora_down.weight' in key:
79 | tensor = f.get_tensor(key)
80 | if len(tensor.shape) == 2: # Linear layer
81 | network_dim = tensor.shape[0]
82 | break
83 | elif len(tensor.shape) == 4: # Conv layer
84 | network_dim = tensor.shape[0]
85 | break
86 |
87 | # If still not found, try alpha tensors
88 | if network_alpha is None:
89 | for key in f.keys():
90 | if key.endswith('.alpha'):
91 | alpha_tensor = f.get_tensor(key)
92 | network_alpha = float(alpha_tensor.item())
93 | break
94 |
95 | # Convert to proper types
96 | if network_dim:
97 | try:
98 | network_dim = int(network_dim)
 99 |                 except (TypeError, ValueError):
100 | network_dim = 32 # fallback
101 | else:
102 | network_dim = 32 # fallback
103 |
104 | if network_alpha:
105 | try:
106 | network_alpha = float(network_alpha)
107 |                 except (TypeError, ValueError):
108 | network_alpha = 32.0 # fallback
109 | else:
110 | network_alpha = 32.0 # fallback
111 |
112 | metadata['network_dim'] = network_dim
113 | metadata['network_alpha'] = network_alpha
114 |
115 | return metadata
116 |
117 | except Exception as e:
118 | print(f"Error extracting LoRA metadata from {lora_path}: {e}")
119 | return None
120 |
121 | def _run_sd_script(self, script_name: str, args: list) -> Tuple[bool, str]:
122 | """Run an sd-scripts command
123 |
124 | Args:
125 | script_name: Name of the script (e.g., 'flux_merge_lora.py')
126 | args: List of arguments to pass to the script
127 |
128 | Returns:
129 | Tuple of (success: bool, output: str)
130 | """
131 | script_path = os.path.join(self.sd_scripts_path, "networks", script_name)
132 |
133 | if not os.path.exists(script_path):
134 | return False, f"Script not found: {script_path}"
135 |
136 | command = [sys.executable, script_path] + args
137 |
138 | print(f"DEBUG: Running command: {' '.join(command)}")
139 | print(f"DEBUG: Working directory: {self.sd_scripts_path}")
140 | print(f"DEBUG: Script exists: {os.path.exists(script_path)}")
141 |
142 | try:
143 | result = subprocess.run(
144 | command,
145 | capture_output=True,
146 | text=True,
147 | cwd=self.sd_scripts_path
148 | )
149 |
150 | success = result.returncode == 0
151 |
152 | if success:
153 | output = result.stdout
154 | else:
155 | # Combine both stdout and stderr for better error reporting
156 | output = f"Return code: {result.returncode}\n"
157 | if result.stdout:
158 | output += f"STDOUT:\n{result.stdout}\n"
159 | if result.stderr:
160 | output += f"STDERR:\n{result.stderr}\n"
161 | if not result.stdout and not result.stderr:
162 | output += "No output from command"
163 |
164 | return success, output
165 |
166 | except Exception as e:
167 | return False, f"Exception running command: {str(e)}\nCommand: {' '.join(command)}"
168 |
169 | def _resolve_path(self, path: str) -> str:
170 | """Convert relative path to absolute with quotes"""
171 | abs_path = os.path.abspath(path)
172 | return f'"{abs_path}"'
173 |
174 | def _resolve_path_without_quotes(self, path: str) -> str:
175 | """Convert relative path to absolute without quotes"""
176 | return os.path.abspath(path)
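177 | 
178 | # Usage sketch (paths below are hypothetical placeholders):
179 | #
180 | #   processor = LoRAProcessor(sd_scripts_path="sd-scripts")
181 | #   metadata = processor.extract_metadata("example_lora.safetensors")
182 | #   # -> dict with 'network_dim' and 'network_alpha' (falling back to 32 / 32.0),
183 | #   #    or None if the file cannot be read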
--------------------------------------------------------------------------------
/lora_algebra/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for LoRA operations
3 | """
4 |
5 | import os
6 | import re
7 | from typing import List, Tuple, Optional
8 | from pathlib import Path
9 |
10 | def find_lora_files(directory: str, recursive: bool = True) -> List[str]:
11 | """Find all LoRA files in a directory
12 |
13 | Args:
14 | directory: Directory to search
15 | recursive: Whether to search subdirectories
16 |
17 | Returns:
18 | List of LoRA file paths
19 | """
20 | lora_files = []
21 | search_pattern = "**/*.safetensors" if recursive else "*.safetensors"
22 |
23 | try:
24 | directory_path = Path(directory)
25 | if directory_path.exists():
26 | for file_path in directory_path.glob(search_pattern):
27 | if file_path.is_file():
28 | lora_files.append(str(file_path))
29 | except Exception as e:
30 | print(f"Error searching directory {directory}: {e}")
31 |
32 | return sorted(lora_files)
33 |
34 | def validate_lora_path(path: str) -> Tuple[bool, str]:
35 | """Validate that a path points to a valid LoRA file
36 |
37 | Args:
38 | path: Path to validate
39 |
40 | Returns:
41 | Tuple of (is_valid: bool, message: str)
42 | """
43 | if not path:
44 | return False, "Path is empty"
45 |
46 | if not os.path.exists(path):
47 | return False, f"File does not exist: {path}"
48 |
49 | if not path.lower().endswith('.safetensors'):
50 | return False, "File must be a .safetensors file"
51 |
52 | if not os.path.isfile(path):
53 | return False, "Path is not a file"
54 |
55 | # Check file size (LoRAs should be at least 1KB, typically much larger)
56 | try:
57 | size = os.path.getsize(path)
58 | if size < 1024: # Less than 1KB
59 | return False, "File is too small to be a valid LoRA"
60 | except Exception as e:
61 | return False, f"Could not check file size: {e}"
62 |
63 | return True, "Valid LoRA file"
64 |
65 | def sanitize_filename(filename: str) -> str:
66 | """Sanitize a filename for safe use across different operating systems
67 |
68 | Args:
69 | filename: Original filename
70 |
71 | Returns:
72 | Sanitized filename
73 | """
74 | # Remove or replace invalid characters
75 | filename = re.sub(r'[<>:"/\\|?*]', '_', filename)
76 |
77 | # Remove leading/trailing spaces and dots
78 | filename = filename.strip(' .')
79 |
80 | # Ensure it's not empty
81 | if not filename:
82 | filename = "unnamed"
83 |
84 | # Truncate if too long (255 chars is common limit)
85 | if len(filename) > 200: # Leave room for extension
86 | filename = filename[:200]
87 |
88 | return filename
89 |
90 | def generate_output_path(base_dir: str, name: str, operation: str = "processed") -> str:
91 | """Generate a safe output path for processed LoRAs
92 |
93 | Args:
94 | base_dir: Base directory for output
95 | name: Base name for the file
96 | operation: Type of operation (for naming)
97 |
98 | Returns:
99 | Full output path
100 | """
101 | # Sanitize the name
102 | safe_name = sanitize_filename(name)
103 |
104 | # Add operation suffix
105 | if operation != "processed":
106 | safe_name = f"{safe_name}_{operation}"
107 |
108 | # Ensure .safetensors extension
109 | if not safe_name.lower().endswith('.safetensors'):
110 | safe_name += '.safetensors'
111 |
112 | # Create full path
113 | output_path = os.path.join(base_dir, safe_name)
114 |
115 | # Handle name conflicts by adding numbers
116 | counter = 1
117 | original_path = output_path
118 | while os.path.exists(output_path):
119 | base, ext = os.path.splitext(original_path)
120 | output_path = f"{base}_{counter}{ext}"
121 | counter += 1
122 |
123 | return output_path
124 |
125 | def format_file_size(size_bytes: int) -> str:
126 | """Format file size in human-readable format
127 |
128 | Args:
129 | size_bytes: Size in bytes
130 |
131 | Returns:
132 | Formatted size string
133 | """
134 | for unit in ['B', 'KB', 'MB', 'GB']:
135 | if size_bytes < 1024:
136 | return f"{size_bytes:.1f} {unit}"
137 | size_bytes /= 1024
138 | return f"{size_bytes:.1f} TB"
139 |
140 | def get_recent_loras(directory: str, limit: int = 10) -> List[Tuple[str, str, float]]:
141 | """Get recently modified LoRA files
142 |
143 | Args:
144 | directory: Directory to search
145 | limit: Maximum number of files to return
146 |
147 | Returns:
148 | List of tuples (path, name, modification_time)
149 | """
150 | lora_files = find_lora_files(directory, recursive=True)
151 |
152 | # Get modification times and sort
153 | file_info = []
154 | for file_path in lora_files:
155 | try:
156 | mod_time = os.path.getmtime(file_path)
157 | name = os.path.basename(file_path)
158 | file_info.append((file_path, name, mod_time))
159 | except Exception:
160 | continue # Skip files we can't access
161 |
162 | # Sort by modification time (newest first) and limit
163 | file_info.sort(key=lambda x: x[2], reverse=True)
164 | return file_info[:limit]
165 |
166 | def estimate_processing_time(lora_paths: List[str], operation: str = "subtract") -> float:
167 | """Estimate processing time for LoRA operations
168 |
169 | Args:
170 | lora_paths: List of LoRA file paths
171 | operation: Type of operation
172 |
173 | Returns:
174 | Estimated time in seconds
175 | """
176 | if not lora_paths:
177 | return 0.0
178 |
179 | # Base processing time per operation
180 | base_times = {
181 | "subtract": 3.0,
182 | "merge": 2.5,
183 | "analyze": 1.0
184 | }
185 |
186 | base_time = base_times.get(operation, 2.0)
187 |
188 | # Adjust based on file sizes
189 | total_size = 0
190 | for path in lora_paths:
191 | if os.path.exists(path):
192 | total_size += os.path.getsize(path)
193 |
194 | # Add time based on file size (rough estimate)
195 | size_factor = (total_size / (1024 * 1024)) * 0.1 # 0.1 seconds per MB
196 |
197 | return base_time + size_factor
198 |
199 | def create_backup(file_path: str, backup_dir: Optional[str] = None) -> str:
200 | """Create a backup of a file
201 |
202 | Args:
203 | file_path: Path to file to backup
204 | backup_dir: Directory for backup (default: same directory)
205 |
206 | Returns:
207 | Path to backup file
208 | """
209 | if not os.path.exists(file_path):
210 | raise FileNotFoundError(f"File not found: {file_path}")
211 |
212 | # Determine backup directory
213 | if backup_dir is None:
214 | backup_dir = os.path.dirname(file_path)
215 | else:
216 | os.makedirs(backup_dir, exist_ok=True)
217 |
218 | # Generate backup filename
219 | base_name = os.path.basename(file_path)
220 | name, ext = os.path.splitext(base_name)
221 |
222 | import time
223 | timestamp = int(time.time())
224 | backup_name = f"{name}_backup_{timestamp}{ext}"
225 | backup_path = os.path.join(backup_dir, backup_name)
226 |
227 | # Copy file
228 | import shutil
229 | shutil.copy2(file_path, backup_path)
230 |
231 | return backup_path
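232 | 
233 | # Usage sketch (paths below are hypothetical placeholders):
234 | #
235 | #   is_valid, message = validate_lora_path("loras/style.safetensors")
236 | #   out_path = generate_output_path("output", "style_clean", operation="subtract")
237 | #   # -> "output/style_clean_subtract.safetensors", suffixed _1, _2, ... on name conflicts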
--------------------------------------------------------------------------------
/lora_algebra/analysis.py:
--------------------------------------------------------------------------------
1 | """
2 | LoRA analysis and compatibility prediction
3 | """
4 |
5 | import os
6 | from typing import List, Dict, Any, Tuple
7 | from .core import LoRAProcessor
8 |
9 | def extract_metadata(lora_path: str) -> Dict[str, Any]:
10 | """Extract metadata from a LoRA file
11 |
12 | Args:
13 | lora_path: Path to LoRA file
14 |
15 | Returns:
16 | Dictionary containing metadata
17 | """
18 | processor = LoRAProcessor()
19 | return processor.extract_metadata(lora_path) or {}
20 |
21 | def predict_compatibility(lora_a_path: str, lora_b_path: str) -> Dict[str, Any]:
22 | """Predict compatibility between two LoRAs
23 |
24 | Args:
25 | lora_a_path: Path to first LoRA
26 | lora_b_path: Path to second LoRA
27 |
28 | Returns:
29 | Dictionary containing compatibility analysis
30 | """
31 | processor = LoRAProcessor()
32 |
33 | metadata_a = processor.extract_metadata(lora_a_path)
34 | metadata_b = processor.extract_metadata(lora_b_path)
35 |
36 | if not metadata_a or not metadata_b:
37 | return {
38 | "compatible": False,
39 | "confidence": 0.0,
40 | "issues": ["Could not extract metadata from one or both LoRAs"],
41 | "recommendations": ["Check that both files are valid LoRA .safetensors files"]
42 | }
43 |
44 | issues = []
45 | recommendations = []
46 | compatibility_score = 1.0
47 |
48 | # Check rank compatibility
49 | rank_a = metadata_a.get('network_dim', 32)
50 | rank_b = metadata_b.get('network_dim', 32)
51 |
52 | if abs(rank_a - rank_b) > 64:
53 | issues.append(f"Large rank difference: {rank_a} vs {rank_b}")
54 | recommendations.append("Consider using concat mode when merging")
55 | compatibility_score -= 0.2
56 |
57 | # Check alpha compatibility
58 | alpha_a = metadata_a.get('network_alpha', 32.0)
59 | alpha_b = metadata_b.get('network_alpha', 32.0)
60 |
61 | alpha_ratio = max(alpha_a, alpha_b) / min(alpha_a, alpha_b) if min(alpha_a, alpha_b) > 0 else 1.0
62 |
63 | if alpha_ratio > 2.0:
64 | issues.append(f"Very different alpha values: {alpha_a} vs {alpha_b}")
65 | recommendations.append("May need strength adjustment when combining")
66 | compatibility_score -= 0.1
67 |
68 | # Check base model compatibility
69 | base_a = metadata_a.get('base_model', '').lower()
70 | base_b = metadata_b.get('base_model', '').lower()
71 |
72 | if base_a and base_b and base_a != base_b:
73 | if 'flux' in base_a and 'flux' in base_b:
74 | # Both are Flux models, probably compatible
75 | pass
76 | elif 'sd' in base_a and 'sd' in base_b:
77 | # Both are SD models, check version compatibility
 78 |                 if ('1.5' in base_a and 'xl' in base_b) or ('xl' in base_a and '1.5' in base_b):
79 | issues.append(f"Different SD versions: {base_a} vs {base_b}")
80 | recommendations.append("These LoRAs may not be compatible - consider training on same base model")
81 | compatibility_score -= 0.5
82 | else:
83 | issues.append(f"Different model families: {base_a} vs {base_b}")
84 | recommendations.append("Cross-model LoRAs may not work well together")
85 | compatibility_score -= 0.3
86 |
87 | # Rank efficiency analysis
88 | efficiency_a = alpha_a / rank_a if rank_a > 0 else 1.0
89 | efficiency_b = alpha_b / rank_b if rank_b > 0 else 1.0
90 |
91 | if abs(efficiency_a - efficiency_b) > 1.0:
92 | issues.append("Different training efficiencies detected")
93 | recommendations.append("Consider adjusting relative strengths when combining")
94 | compatibility_score -= 0.1
95 |
96 | # Overall assessment
97 | if compatibility_score > 0.8:
98 | status = "Highly Compatible"
99 | confidence = compatibility_score
100 | elif compatibility_score > 0.6:
101 | status = "Compatible with Adjustments"
102 | confidence = compatibility_score
103 | elif compatibility_score > 0.4:
104 | status = "Limited Compatibility"
105 | confidence = compatibility_score
106 | recommendations.append("Consider using LoRA difference to resolve conflicts")
107 | else:
108 | status = "Incompatible"
109 | confidence = compatibility_score
110 | recommendations.append("Strong recommendation to use LoRA difference for conflict resolution")
111 |
112 | if not issues:
113 | issues.append("No major compatibility issues detected")
114 |
115 | if not recommendations:
116 | recommendations.append("LoRAs should work well together with standard strengths")
117 |
118 | return {
119 | "compatible": compatibility_score > 0.4,
120 | "status": status,
121 | "confidence": round(confidence, 2),
122 | "compatibility_score": round(compatibility_score, 2),
123 | "issues": issues,
124 | "recommendations": recommendations,
125 | "metadata_a": metadata_a,
126 | "metadata_b": metadata_b,
127 | "suggested_strengths": {
128 | "lora_a": min(1.0, 1.0 * compatibility_score + 0.2),
129 | "lora_b": min(1.0, 1.0 * compatibility_score + 0.2)
130 | }
131 | }
132 |
133 | def analyze_multiple_loras(lora_paths: List[str]) -> Dict[str, Any]:
134 | """Analyze compatibility between multiple LoRAs
135 |
136 | Args:
137 | lora_paths: List of paths to LoRA files
138 |
139 | Returns:
140 | Dictionary containing multi-LoRA analysis
141 | """
142 | if len(lora_paths) < 2:
143 | return {"error": "Need at least 2 LoRAs for compatibility analysis"}
144 |
145 | # Extract metadata for all LoRAs
146 | processor = LoRAProcessor()
147 | lora_metadata = []
148 |
149 | for path in lora_paths:
150 | metadata = processor.extract_metadata(path)
151 | if metadata:
152 | metadata['path'] = path
153 | lora_metadata.append(metadata)
154 |
155 | if len(lora_metadata) < 2:
156 | return {"error": "Could not extract metadata from enough LoRAs"}
157 |
158 | # Build compatibility matrix
159 | compatibility_matrix = {}
160 | overall_issues = []
161 |
162 | for i, lora_a in enumerate(lora_metadata):
163 | for j, lora_b in enumerate(lora_metadata):
164 | if i < j: # Only check each pair once
165 | key = f"{os.path.basename(lora_a['path'])} × {os.path.basename(lora_b['path'])}"
166 | compat = predict_compatibility(lora_a['path'], lora_b['path'])
167 | compatibility_matrix[key] = compat
168 |
169 | if not compat['compatible']:
170 | overall_issues.extend(compat['issues'])
171 |
172 | # Calculate overall compatibility
173 | scores = [result['compatibility_score'] for result in compatibility_matrix.values()]
174 | avg_score = sum(scores) / len(scores) if scores else 0.0
175 |
176 | # Recommendations for multi-LoRA usage
177 | recommendations = []
178 |
179 | if avg_score > 0.7:
180 | recommendations.append("All LoRAs show good compatibility - can be used together")
181 | elif avg_score > 0.5:
182 | recommendations.append("Some LoRAs may conflict - consider adjusting strengths")
183 | recommendations.append("Use LoRA difference to resolve specific conflicts")
184 | else:
185 | recommendations.append("Multiple compatibility issues detected")
186 | recommendations.append("Strong recommendation to use LoRA difference workflow")
187 | recommendations.append("Consider processing LoRAs in pairs before combining all")
188 |
189 | return {
190 | "lora_count": len(lora_metadata),
191 | "overall_compatibility": avg_score > 0.5,
192 | "average_score": round(avg_score, 2),
193 | "compatibility_matrix": compatibility_matrix,
194 | "overall_issues": list(set(overall_issues)), # Remove duplicates
195 | "recommendations": recommendations,
196 | "metadata": lora_metadata
197 | }
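198 | 
199 | # Usage sketch (paths below are hypothetical placeholders):
200 | #
201 | #   report = predict_compatibility("style.safetensors", "character.safetensors")
202 | #   print(report["status"], report["compatibility_score"])
203 | #   for issue in report["issues"]:
204 | #       print("-", issue)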
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LoRA the Explorer - advanced FLUX LoRA manipulation
2 |
3 | 
4 | 
5 | 
6 |
  7 | This tool provides a range of FLUX LoRA manipulation techniques, including difference operations, merging, targeted layer merging between LoRAs, layer targeting for zeroing, and analysis. It's designed to help you create compatible LoRAs and experiment with different combination approaches.
8 |
9 | ## What You Can Do:
10 |
11 |
12 |
 13 | - **LoRA Difference**: Remove conflicts between LoRAs
 14 | - **Traditional Merging**: Combine LoRAs with custom weights (auto-detects incompatibilities)
 15 | - **Layer-Based Merging**: Surgically combine specific layers from different LoRAs
 16 | - **Layer Targeting**: Selectively mute facial or style layers
 17 | - **LoRA MetaViewer**: Examine LoRA characteristics, metadata, and compatibility
 18 | - **LoRA MetaEditor**: Direct metadata editing for fixing and customizing any LoRA file
 19 | - **Universal Compatibility**: Works with LoRAs from any training tool (AI-Toolkit, FluxGym, sd-scripts)
 20 | - **Automatic Fixes**: Auto-detects and resolves common compatibility issues
 21 | 
 22 | 
 23 | *Demo image created using layer-based LoRA merging in the app. See LoRAs merged in Credits further down the page.*
24 |
25 | ## 🛠️ Installation
26 |
 27 | **Latest news:** Several updates as of 7/7/25 - 7:25 BST to improve installation on earlier Python versions and tame some dependency hellscapes.
28 |
29 | ### Automatic Setup (Recommended)
30 |
31 | ```bash
32 | # Clone the repository
33 | git clone https://github.com/shootthesound/lora-the-explorer.git
34 | cd lora-the-explorer
35 |
36 | # Run the installer
37 | python install.py
38 |
39 | ```
40 |
41 | This will:
42 | - ✅ Create a Python virtual environment
43 | - ✅ Download and set up sd-scripts (sd3 branch with Flux support, pinned to stable commit)
44 | - ✅ Apply compatibility fixes for FLUX LoRA metadata
45 | - ✅ Install all dependencies
46 | - ✅ Create launcher scripts
47 |
48 | ### Quick Launch
49 | ```bash
50 | # Windows: Double-click start_gui.bat
51 | # Unix/macOS: ./start_gui.sh
52 | # Or manually:
53 | python lora_algebra_gui.py
54 | ```
55 |
56 | ### Keeping Updated
57 | ```bash
58 | # Windows: Double-click update.bat
59 | # Or manually:
60 | git pull
61 | ```
62 |
63 | The `update.bat` script automatically:
64 | - ✅ Checks for Git availability
65 | - ✅ Validates you're in a git repository
66 | - ✅ Fetches latest changes from the repository
67 | - ✅ Applies updates with clear status messages
68 | - ✅ Handles errors gracefully with helpful guidance
69 |
70 | ## 🎮 Getting Started
71 |
72 | 1. **Set up paths**: Use the "LoRA Paths" tab to scan your LoRA directory for autocomplete, and choose your default save directory
73 | 2. **Experiment**: Try different operations - each tab has usage guides and presets
 74 | 3. **Start simple**: Begin with basic merging or difference operations before trying advanced layer techniques. My tip: find a great style LoRA and merge a character into it via the face layers.
75 |
76 | ## 🎯 Key Features
77 |
78 | ### LoRA Difference
79 | Remove unwanted influences from LoRAs:
80 | ```
81 | Style_LoRA - Character_LoRA = Clean_Style_LoRA
82 | ```
83 | Perfect for removing face changes from style LoRAs and creating character-neutral styles.
84 |
85 | ### Layer Targeting (FLUX)
86 | Selectively mute specific layers in FLUX LoRAs:
87 | - **Facial Layers (7,12,16,20)**: Remove face details while keeping style/costume
88 | - **Aggressive Mode**: Maximum facial identity removal
89 | - **Custom Selection**: Choose any combination of available layers
90 |
91 | Perfect for extracting character costumes without faces (like Gandalf costume without Ian McKellen's face).
92 |
93 | ### Layer-Based Merging
94 | Surgically combine layers from different LoRAs:
95 | - **Face A + Style B**: Facial layers from character LoRA, style from artistic LoRA
96 | - **Facial Priority**: All potential facial layers from one LoRA
97 | - **Complement Split**: Early layers from one LoRA, late layers from another
98 | - **Fix Overtrained LoRAs**: Replace problematic layers with clean ones
99 |
100 | ### LoRA MetaViewer & Analysis
101 | Technical validation and metadata viewing:
102 | - Check rank, alpha, base model compatibility
103 | - Predict potential conflicts and suggest optimal strengths
104 | - Browse and analyze existing LoRAs
105 | - View complete LoRA metadata in readable format
106 | - Double-click file selection from analysis results
107 |
108 | *Note: This analysis is a guide to catch common technical issues only. It checks: rank differences, alpha mismatches, base model compatibility, and training efficiency.*
109 |
110 | ### Path Management
111 | Streamlined workflow:
112 | - Recursive LoRA directory scanning with autocomplete for all path inputs
113 | - Default output directory settings that persist across sessions
114 | - Double-click file selection in analysis results
115 |
116 | ### Universal Compatibility
117 | Automatic fixes for cross-tool compatibility:
118 | - **Auto-detects incompatible LoRAs** and enables concat mode when needed
119 | - **Fixes FLUX metadata issues** (networks.lora → networks.lora_flux) automatically
120 | - **Handles mixed-dimension LoRAs** from layer merging operations
121 | - **Works with any training tool**: AI-Toolkit, FluxGym, sd-scripts, kohya-ss
122 | - **Enhanced error messages** with helpful suggestions when edge cases occur
123 |
124 | ### LoRA MetaEditor
125 | Direct metadata editing with full control:
126 | - **Edit any metadata field** in raw JSON format
127 | - **Fix common issues**: Wrong network modules, incorrect base models, missing fields
128 | - **In-place editing**: Modifies original file (ensure you have backups!)
129 | - **Full user responsibility**: No safety nets - edit at your own risk
130 | - **Universal metadata repair** for LoRAs from any source
131 | - **Cross-tool compatibility fixes**: Repair LoRAs from different training tools
132 | - **Custom metadata management**: Add training details, tags, or remove sensitive info
133 |
134 | ## 💡 Use Cases
135 |
136 | ### Style LoRA Cleaning
137 | Use the LoRA Difference tab to remove face changes from style LoRAs:
138 | 1. Load your style LoRA as "LoRA A"
139 | 2. Load a character/face LoRA as "LoRA B"
140 | 3. Set strength B to ~0.7 (85% of normal usage)
141 | 4. Output a clean style LoRA that won't change faces
142 |
143 | ### Character Costume Extraction
144 | - **Gandalf costume, no Ian McKellen face**: Mute facial layers (7,12,16,20)
145 |
146 | ### Cross-Tool LoRA Compatibility
147 | Fix LoRAs that won't merge due to different training tools:
148 | 1. **Auto-detection**: Tool automatically detects and fixes metadata issues
149 | 2. **Manual fixes**: Use LoRA MetaEditor to fix network modules manually
150 | 3. **Mixed dimensions**: Layer-merged LoRAs automatically trigger concat mode
151 | 4. **Universal merging**: Combine LoRAs from AI-Toolkit, FluxGym, sd-scripts seamlessly
152 |
153 | ### Metadata Repair & Customization
154 | Use the LoRA MetaEditor for advanced metadata management:
155 | 1. **Fix training tool bugs**: Correct wrong network modules or base model info
156 | 2. **Add missing information**: Training details, base model versions, custom tags
157 | 3. **Remove sensitive data**: Clean out unwanted metadata fields
158 | 4. **Standardize collections**: Ensure consistent metadata across your LoRA library
159 | ⚠️ **MetaEditor Safety**: Always back up your LoRAs before editing metadata. Incorrect changes can break your LoRAs permanently.
160 |
161 | ### Advanced Combinations
162 | - Mix facial features from one LoRA with style from another using layer-based merging
163 | - Create hybrid concepts using selective layers
164 | - Rescue partially corrupted or overtrained LoRAs
165 |
166 | ## 📊 FLUX Layer Architecture
167 |
168 | LoRA the Explorer works with FLUX's layer architecture:
169 | - **Text Encoder (0-11)**: 12 layers
170 | - **Double Blocks (0-19)**: 20 layers
171 | - **Single Blocks (0-37)**: 38 layers
172 |
173 | **Known layer functions:**
174 | - **Layers 7 & 20**: Primary facial structure and details
175 | - **Layers 12 & 16**: Secondary facial features
176 | - **Other layers**: Style, composition, lighting (experimental)
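
To see which double or single blocks a particular LoRA actually touches, you can scan its tensor key names. A minimal sketch (key naming varies between training tools, so the pattern below is an assumption; the file path is a hypothetical placeholder):

```python
import re
from safetensors import safe_open

blocks = set()
with safe_open("my_flux_lora.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        m = re.search(r"(double_blocks|single_blocks)[._](\d+)", key)
        if m:
            blocks.add((m.group(1), int(m.group(2))))

# e.g. {('double_blocks', 7), ('double_blocks', 12), ('single_blocks', 0), ...}
print(sorted(blocks))
```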
177 |
178 | ## 💡 Tips & Best Practices
179 |
180 | ### Difference Sweet Spot
181 | ```
182 | Optimal_Difference_Strength = Normal_Usage_Strength × 0.9 to 1.0
183 | ```
184 |
185 | ### Layer Targeting Strategy
186 | - Start with preset combinations (facial, style, aggressive)
187 | - Experiment with individual layers for fine control
188 | - Use preview to see selected layers before applying
189 |
190 | ## 🌟 Community Project
191 |
192 | I'm building this for the community and welcome your feedback, suggestions, and bug reports. The goal is to make LoRA manipulation more accessible and experimentation-friendly for those who are less comfortable with CLI-based tools.
193 |
194 | ### Free to Use
195 | This tool is completely free. If you find it useful and want to support development, you can do so at:
196 |
197 | **☕ [buymeacoffee.com/loratheexplorer](https://buymeacoffee.com/loratheexplorer)**
198 |
199 | Recurring supporters get early access to test new features as milestones arise in development.
200 |
201 | ### Feedback Welcome
202 | Found a bug? Have a feature request? Want to share results? All feedback helps improve the tool for everyone. Email pneill@gmail.com.
203 |
204 | ## 📄 License
205 |
206 | MIT License - see [LICENSE](LICENSE) for details.
207 |
208 | This tool is free to use, modify, and distribute. The goal is to make LoRA manipulation more accessible to everyone.
209 |
210 | ## 🙏 Credits & Dependencies
211 |
212 | ### sd-scripts Integration
213 | LoRA the Explorer relies on [sd-scripts by kohya-ss](https://github.com/kohya-ss/sd-scripts) for the core LoRA processing functionality. Our installer automatically downloads the sd3 branch which includes FLUX support.
214 |
215 | **sd-scripts provides:**
216 | - FLUX LoRA manipulation capabilities
217 | - SafeTensors file handling
218 | - Core mathematical operations for LoRA algebra
219 |
220 | Special thanks to kohya-ss and the sd-scripts community for creating and maintaining this essential toolkit.
221 |
222 | **Demo Image LoRAs:**
223 | - [Eurasian Golden Oriole](https://civitai.green/models/1668493/eurasian-golden-oriole?modelVersionId=1888520) by hloveex30w126 on CivitAI
224 | - [Fantasy LoRA](https://civitai.green/models/789313?modelVersionId=1287297) by ArsMachina on CivitAI
225 |
226 | ## 🔬 Technical Notes
227 |
228 | ### Mathematical Foundation
229 | LoRAs modify model weights as: `W_new = W_original + α(BA)`
230 |
231 | Difference operation: `LoRA_A - LoRA_B = B₁A₁ - B₂A₂`
232 |
233 | This removes overlapping parameter modifications while preserving unique characteristics.
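
As a numeric illustration of the formulas above (the shapes and the common alpha/rank scaling convention are illustrative, not this tool's exact implementation):

```python
import torch

rank, dim = 4, 16
alpha = 4.0

# Low-rank factors for two LoRAs targeting the same layer
A1, B1 = torch.randn(rank, dim), torch.randn(dim, rank)
A2, B2 = torch.randn(rank, dim), torch.randn(dim, rank)

delta_a = (B1 @ A1) * (alpha / rank)  # weight delta contributed by LoRA A
delta_b = (B2 @ A2) * (alpha / rank)  # weight delta contributed by LoRA B

difference = delta_a - delta_b              # the "LoRA_A - LoRA_B" update
W_new = torch.randn(dim, dim) + difference  # applied on top of base weights
```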
234 |
235 | ### Performance
236 | - **Processing Speed**: ~2-5 seconds per operation
237 | - **Memory Usage**: ~2GB RAM for large LoRAs
238 | - **Compatibility**: Works with safetensors, ckpt, and pt formats
239 | - **Quality**: No observable degradation when operations use the recommended parameters
240 |
241 | ---
242 |
243 | **⭐ Star the repository at [github.com/shootthesound/lora-the-explorer](https://github.com/shootthesound/lora-the-explorer) if LoRA the Explorer helps with your LoRA workflow!**
--------------------------------------------------------------------------------
/install.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | LoRA the Explorer Installation Script
4 |
5 | Creates a Python virtual environment and installs dependencies, similar to FluxGym.
6 | """
7 |
8 | import os
9 | import sys
10 | import subprocess
11 | import platform
12 | from pathlib import Path
13 |
14 | def run_command(command, description, check=True):
15 | """Run a command with nice output"""
16 | print(f" {description}...")
17 | try:
18 | if isinstance(command, str):
19 | result = subprocess.run(command, shell=True, check=check, capture_output=True, text=True)
20 | else:
21 | result = subprocess.run(command, check=check, capture_output=True, text=True)
22 |
23 | if result.stdout:
24 | print(f" {result.stdout.strip()}")
25 | return result
26 | except subprocess.CalledProcessError as e:
27 | print(f" Error: {e}")
28 | if e.stderr:
29 | print(f" Error details: {e.stderr.strip()}")
30 | if check:
31 | raise
32 | return e
33 |
34 | def check_python_version():
35 | """Check if Python version is compatible"""
36 | version = sys.version_info
37 | if version.major < 3 or (version.major == 3 and version.minor < 8):
38 | print(" Python 3.8 or higher is required")
39 | print(f" Current version: {version.major}.{version.minor}.{version.micro}")
40 | sys.exit(1)
41 |
42 | print(f" Python {version.major}.{version.minor}.{version.micro} - Compatible")
43 | return True
44 |
45 | def create_virtual_environment():
46 | """Create Python virtual environment"""
47 | env_path = Path("env")
48 |
49 | if env_path.exists():
50 | print(" Virtual environment already exists")
51 | return env_path
52 |
53 | print(" Creating Python virtual environment...")
54 |
55 | # Create virtual environment
56 | result = run_command([sys.executable, "-m", "venv", "env"], "Creating virtual environment")
57 |
58 | if result.returncode == 0:
59 | print(" Virtual environment created successfully")
60 | return env_path
61 | else:
62 | print(" Failed to create virtual environment")
63 | sys.exit(1)
64 |
65 | def get_python_executable():
66 | """Get the path to Python executable in virtual environment"""
67 | if platform.system() == "Windows":
68 | return Path("env") / "Scripts" / "python.exe"
69 | else:
70 | return Path("env") / "bin" / "python"
71 |
72 | def get_pip_executable():
73 | """Get the path to pip executable in virtual environment"""
74 | if platform.system() == "Windows":
75 | return Path("env") / "Scripts" / "pip.exe"
76 | else:
77 | return Path("env") / "bin" / "pip"
78 |
79 | def install_dependencies():
80 | """Install required dependencies"""
81 | pip_exe = get_pip_executable()
82 | python_exe = get_python_executable()
83 |
84 | if not pip_exe.exists():
85 | print(" pip not found in virtual environment")
86 | sys.exit(1)
87 |
88 | print(" Installing dependencies...")
89 |
90 | # Upgrade pip first using python -m pip (more reliable on Windows)
91 | run_command([str(python_exe), "-m", "pip", "install", "--upgrade", "pip"], "Upgrading pip", check=False)
92 |
93 | # Install wheel for better package building
94 | run_command([str(pip_exe), "install", "wheel"], "Installing wheel")
95 |
96 | # Install requirements
97 | if Path("requirements.txt").exists():
98 | run_command([str(pip_exe), "install", "-r", "requirements.txt"], "Installing requirements")
99 | else:
100 | # Fallback to manual installation of core dependencies
101 | dependencies = [
102 | "torch==2.7.1",
103 | "torchvision==0.22.1",
104 | "accelerate==1.8.1",
105 | "transformers==4.44.0",
106 | "diffusers[torch]==0.25.0",
107 | "safetensors==0.4.4",
108 | "sentencepiece==0.2.0",
109 | "gradio>=4.0.0",
110 | "einops==0.7.0",
111 | "huggingface-hub==0.24.5",
112 | "rich==13.7.0",
113 | "numpy>=1.24.0",
114 | "pyyaml>=6.0.0",
115 | "pillow>=10.0.0",
116 | "tqdm>=4.66.0"
117 | ]
118 |
119 | for dep in dependencies:
120 | run_command([str(pip_exe), "install", dep], f"Installing {dep}")
121 |
122 | # Install current package in development mode
123 | run_command([str(pip_exe), "install", "-e", "."], "Installing LoRA the Explorer in development mode")
124 |
125 | def create_launcher_scripts():
126 | """Create launcher scripts for easy access"""
127 | python_exe = get_python_executable()
128 |
129 | if platform.system() == "Windows":
130 | # Windows batch file
131 | launcher_content = f"""@echo off
132 | setlocal EnableDelayedExpansion
133 | echo Launching LoRA the Explorer GUI...
134 | echo.
135 |
136 | REM Check for updates if git is available (non-blocking)
137 | git --version >nul 2>&1
138 | if not errorlevel 1 (
139 | git status >nul 2>&1
140 | if not errorlevel 1 (
141 | echo [INFO] Checking for updates...
142 | git fetch >nul 2>&1
143 | if not errorlevel 1 (
144 | REM Check if we have an upstream branch configured
145 | git rev-parse --abbrev-ref @{{u}} >nul 2>&1
146 | if not errorlevel 1 (
147 | REM Count commits behind using rev-list
148 | for /f %%i in ('git rev-list --count HEAD..@{{u}} 2^>nul') do set BEHIND_COUNT=%%i
149 | if not "!BEHIND_COUNT!"=="0" if not "!BEHIND_COUNT!"=="" (
150 | echo.
151 | echo ===============================================
152 | echo UPDATE AVAILABLE!
153 | echo ===============================================
154 | echo.
155 | echo A newer version of LoRA the Explorer is available.
156 | echo Run update.bat to get the latest features and fixes.
157 | echo.
158 | echo Press any key to continue launching the GUI...
159 | pause >nul
160 | echo.
161 | ) else (
162 | echo [OK] You are running the latest version
163 | echo.
164 | )
165 | ) else (
166 | echo [INFO] No upstream branch configured, skipping update check
167 | echo.
168 | )
169 | )
170 | )
171 | )
172 |
173 | echo Starting GUI...
174 | "{python_exe.absolute()}" lora_algebra_gui.py
175 | pause
176 | """
177 | with open("start_gui.bat", "w") as f:
178 | f.write(launcher_content)
179 |
180 | print(" Created Windows launcher script:")
181 | print(" start_gui.bat - Launch GUI")
182 |
183 | else:
184 | # Unix shell script
185 | launcher_content = f"""#!/bin/bash
186 | echo " Launching LoRA the Explorer GUI..."
187 | echo
188 |
189 | # Check for updates if git is available (non-blocking)
190 | if command -v git >/dev/null 2>&1; then
191 | if git status >/dev/null 2>&1; then
192 | echo "[INFO] Checking for updates..."
193 | if git fetch >/dev/null 2>&1; then
194 | # Check if we have an upstream branch configured
195 | if git rev-parse --abbrev-ref @{{u}} >/dev/null 2>&1; then
196 | # Count commits behind using rev-list
197 | BEHIND_COUNT=$(git rev-list --count HEAD..@{{u}} 2>/dev/null)
198 | if [ "$BEHIND_COUNT" -gt 0 ] 2>/dev/null; then
199 | echo
200 | echo "==============================================="
201 | echo " UPDATE AVAILABLE!"
202 | echo "==============================================="
203 | echo
204 | echo "A newer version of LoRA the Explorer is available."
205 | echo "Run 'git pull' to get the latest features and fixes."
206 | echo
207 | echo "Press any key to continue launching the GUI..."
208 | read -n 1 -s
209 | echo
210 | else
211 | echo "[OK] You are running the latest version"
212 | echo
213 | fi
214 | else
215 | echo "[INFO] No upstream branch configured, skipping update check"
216 | echo
217 | fi
218 | fi
219 | fi
220 | fi
221 |
222 | echo "Starting GUI..."
223 | "{python_exe.absolute()}" lora_algebra_gui.py
224 | """
225 | with open("start_gui.sh", "w") as f:
226 | f.write(launcher_content)
227 | os.chmod("start_gui.sh", 0o755)
228 |
229 | print(" Created Unix launcher script:")
230 | print(" start_gui.sh - Launch GUI")
231 |
232 | def download_sd_scripts():
233 | """Download and set up sd-scripts"""
234 | sd_scripts_path = Path("sd-scripts")
235 |
236 | if sd_scripts_path.exists() and (sd_scripts_path / "networks").exists():
237 | print(" sd-scripts already installed")
238 | return sd_scripts_path
239 |
240 | print(" Downloading sd-scripts...")
241 |
242 | # Clone sd-scripts repository (sd3 branch with Flux support)
243 | clone_result = run_command([
244 | "git", "clone",
245 | "-b", "sd3",
246 | "https://github.com/kohya-ss/sd-scripts.git",
247 | "sd-scripts"
248 | ], "Cloning sd-scripts sd3 branch (Flux support)", check=False)
249 |
250 | if clone_result.returncode == 0:
251 | # Pin to specific commit for version stability
252 | print(" Pinning sd-scripts to tested commit...")
253 |
254 | # Change to sd-scripts directory to run git checkout
255 | original_dir = os.getcwd()
256 | try:
257 | os.chdir("sd-scripts")
258 | run_command([
259 | "git", "checkout", "3e6935a07edcb944407840ef74fcaf6fcad352f7"
260 | ], "Pinning to stable commit", check=False)
261 | finally:
262 | os.chdir(original_dir)
263 |
264 | if clone_result.returncode != 0:
265 | print(" Failed to clone sd-scripts. Trying alternative method...")
266 |
267 | # Alternative: download as zip
268 | try:
269 | import urllib.request
270 | import zipfile
271 |
272 | print(" Downloading sd-scripts sd3 branch as ZIP...")
273 | url = "https://github.com/kohya-ss/sd-scripts/archive/refs/heads/sd3.zip"
274 | zip_path = "sd-scripts-sd3.zip"
275 |
276 | urllib.request.urlretrieve(url, zip_path)
277 |
278 | with zipfile.ZipFile(zip_path, 'r') as zip_ref:
279 | zip_ref.extractall(".")
280 |
281 | # Rename extracted folder
282 | if Path("sd-scripts-sd3").exists():
283 | Path("sd-scripts-sd3").rename("sd-scripts")
284 |
285 | # Clean up
286 | Path(zip_path).unlink()
287 | print(" sd-scripts downloaded successfully")
288 |
289 | except Exception as e:
290 | print(f" Failed to download sd-scripts: {e}")
291 | print(" Please manually download from: https://github.com/kohya-ss/sd-scripts")
292 | print(" Extract to: ./sd-scripts/")
293 | return None
294 |
295 | # Install sd-scripts requirements and package
296 | if sd_scripts_path.exists():
297 | print(" Installing sd-scripts dependencies...")
298 | pip_exe = get_pip_executable()
299 | python_exe = get_python_executable()
300 |
301 | # Install sd-scripts requirements
302 | sd_requirements = sd_scripts_path / "requirements.txt"
303 | if sd_requirements.exists():
304 | run_command([
305 | str(pip_exe), "install", "-r", str(sd_requirements), "-c", "constraints.txt"
306 | ], "Installing sd-scripts requirements", check=False)
307 | else:
308 | print(" Warning: sd-scripts requirements.txt not found")
309 |
310 | # Install sd-scripts as editable package
311 | print(" Installing sd-scripts library...")
312 | install_result = run_command([
313 | str(pip_exe), "install", "-e", str(sd_scripts_path)
314 | ], "Installing sd-scripts library", check=False)
315 |
316 | # Verify installation by checking if library module can be imported
317 | print(" Verifying sd-scripts installation...")
318 | verify_result = run_command([
319 | str(python_exe), "-c", f"import sys; sys.path.insert(0, '{sd_scripts_path}'); import library.utils; print(' sd-scripts library verified')"
320 | ], "Verifying library module", check=False)
321 |
322 | if verify_result.returncode == 0:
323 | # Apply fix for FLUX LoRA metadata by copying corrected file
324 | print("🔧 Applying FLUX metadata fix...")
325 | try:
326 | import shutil
327 | fixed_file = Path("fixed_files") / "flux_merge_lora.py"
328 | target_file = sd_scripts_path / "networks" / "flux_merge_lora.py"
329 |
330 | if fixed_file.exists() and target_file.exists():
331 | shutil.copy2(str(fixed_file), str(target_file))
332 | print(" FLUX metadata fix applied (networks.lora_flux)")
333 | elif not fixed_file.exists():
334 | print(" Fixed file not found in fixed_files directory")
335 | elif not target_file.exists():
336 | print(" Target file not found in sd-scripts")
337 | except Exception as e:
338 | print(f"️ Could not apply FLUX metadata fix: {e}")
339 | print(" LoRAs may have incorrect network module metadata")
340 |
341 | print(" sd-scripts setup complete")
342 | else:
343 | print(" sd-scripts installed but library verification failed")
344 | print(" This may cause issues with LoRA operations")
345 |
346 | return sd_scripts_path
347 |
348 | return None
349 |
350 | def main():
351 | """Main installation process"""
352 | print("LoRA the Explorer Installation")
353 | print("=" * 50)
354 | print()
355 |
356 | # Check Python version
357 | check_python_version()
358 | print()
359 |
360 | # Create virtual environment
361 | env_path = create_virtual_environment()
362 | print()
363 |
364 | # Install dependencies
365 | install_dependencies()
366 | print()
367 |
368 | # Create launcher scripts
369 | create_launcher_scripts()
370 | print()
371 |
372 | # Download and set up sd-scripts
373 | sd_scripts_path = download_sd_scripts()
374 | print()
375 |
376 | # Success message
377 | print(" Installation Complete!")
378 | print("=" * 50)
379 | print()
380 | print(" Quick Start:")
381 |
382 | if platform.system() == "Windows":
383 | print(" GUI: Double-click start_gui.bat")
384 | else:
385 | print(" GUI: ./start_gui.sh")
386 |
387 | print()
388 | print(" Manual command:")
389 | python_exe = get_python_executable()
390 | print(f" GUI: {python_exe} lora_algebra_gui.py")
391 | print()
392 | print("🔗 Project: https://github.com/shootthesound/lora-the-explorer")
393 |
394 | if sd_scripts_path:
395 | print(f" sd-scripts: {sd_scripts_path.absolute()}")
396 |
397 | print()
398 |
399 | if __name__ == "__main__":
400 | try:
401 | main()
402 | except KeyboardInterrupt:
403 | print("\n Installation cancelled by user")
404 | sys.exit(1)
405 | except Exception as e:
406 | print(f"\n Installation failed: {e}")
407 | sys.exit(1)
--------------------------------------------------------------------------------
/fixed_files/flux_merge_lora.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | import time
5 | from typing import Any, Dict, Union
6 |
7 | import torch
8 | from safetensors import safe_open
9 | from safetensors.torch import load_file, save_file
10 | from tqdm import tqdm
11 |
12 | from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
13 |
14 | setup_logging()
15 | import logging
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | import lora_flux as lora_flux
20 | from library import sai_model_spec, train_util
21 |
22 |
23 | def load_state_dict(file_name, dtype):
24 | if os.path.splitext(file_name)[1] == ".safetensors":
25 | sd = load_file(file_name)
26 | metadata = train_util.load_metadata_from_safetensors(file_name)
27 | else:
28 | sd = torch.load(file_name, map_location="cpu")
29 | metadata = {}
30 |
31 | for key in list(sd.keys()):
32 | if type(sd[key]) == torch.Tensor:
33 | sd[key] = sd[key].to(dtype)
34 |
35 | return sd, metadata
36 |
37 |
38 | def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False):
39 | if dtype is not None:
40 | logger.info(f"converting to {dtype}...")
41 | for key in tqdm(list(state_dict.keys())):
42 | if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point:
43 | state_dict[key] = state_dict[key].to(dtype)
44 |
45 | logger.info(f"saving to: {file_name}")
46 | if mem_eff_save:
47 | mem_eff_save_file(state_dict, file_name, metadata=metadata)
48 | else:
49 | save_file(state_dict, file_name, metadata=metadata)
50 |
51 |
52 | def merge_to_flux_model(
53 | loading_device,
54 | working_device,
55 | flux_path: str,
56 | clip_l_path: str,
57 | t5xxl_path: str,
58 | models,
59 | ratios,
60 | merge_dtype,
61 | save_dtype,
62 | mem_eff_load_save=False,
63 | ):
64 | # create module map without loading state_dict
65 | lora_name_to_module_key = {}
66 | if flux_path is not None:
67 | logger.info(f"loading keys from FLUX.1 model: {flux_path}")
68 | with safe_open(flux_path, framework="pt", device=loading_device) as flux_file:
69 | keys = list(flux_file.keys())
70 | for key in keys:
71 | if key.endswith(".weight"):
72 | module_name = ".".join(key.split(".")[:-1])
73 | lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
74 | lora_name_to_module_key[lora_name] = key
75 |
76 | lora_name_to_clip_l_key = {}
77 | if clip_l_path is not None:
78 | logger.info(f"loading keys from clip_l model: {clip_l_path}")
79 | with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file:
80 | keys = list(clip_l_file.keys())
81 | for key in keys:
82 | if key.endswith(".weight"):
83 | module_name = ".".join(key.split(".")[:-1])
84 | lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_")
85 | lora_name_to_clip_l_key[lora_name] = key
86 |
87 | lora_name_to_t5xxl_key = {}
88 | if t5xxl_path is not None:
89 | logger.info(f"loading keys from t5xxl model: {t5xxl_path}")
90 | with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file:
91 | keys = list(t5xxl_file.keys())
92 | for key in keys:
93 | if key.endswith(".weight"):
94 | module_name = ".".join(key.split(".")[:-1])
95 | lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_")
96 | lora_name_to_t5xxl_key[lora_name] = key
97 |
98 | flux_state_dict = {}
99 | clip_l_state_dict = {}
100 | t5xxl_state_dict = {}
101 | if mem_eff_load_save:
102 | if flux_path is not None:
103 | with MemoryEfficientSafeOpen(flux_path) as flux_file:
104 | for key in tqdm(flux_file.keys()):
105 | flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
106 |
107 | if clip_l_path is not None:
108 | with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file:
109 | for key in tqdm(clip_l_file.keys()):
110 | clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device)
111 |
112 | if t5xxl_path is not None:
113 | with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file:
114 | for key in tqdm(t5xxl_file.keys()):
115 | t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device)
116 | else:
117 | if flux_path is not None:
118 | flux_state_dict = load_file(flux_path, device=loading_device)
119 | if clip_l_path is not None:
120 | clip_l_state_dict = load_file(clip_l_path, device=loading_device)
121 | if t5xxl_path is not None:
122 | t5xxl_state_dict = load_file(t5xxl_path, device=loading_device)
123 |
124 | for model, ratio in zip(models, ratios):
125 | logger.info(f"loading: {model}")
126 | lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU
127 |
128 | logger.info(f"merging...")
129 | for key in tqdm(list(lora_sd.keys())):
130 | if "lora_down" in key:
131 | lora_name = key[: key.rfind(".lora_down")]
132 | up_key = key.replace("lora_down", "lora_up")
133 | alpha_key = key[: key.index("lora_down")] + "alpha"
134 |
135 | if lora_name in lora_name_to_module_key:
136 | module_weight_key = lora_name_to_module_key[lora_name]
137 | state_dict = flux_state_dict
138 | elif lora_name in lora_name_to_clip_l_key:
139 | module_weight_key = lora_name_to_clip_l_key[lora_name]
140 | state_dict = clip_l_state_dict
141 | elif lora_name in lora_name_to_t5xxl_key:
142 | module_weight_key = lora_name_to_t5xxl_key[lora_name]
143 | state_dict = t5xxl_state_dict
144 | else:
145 | logger.warning(
146 | f"no module found for LoRA weight: {key}. Skipping..."
147 | f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。"
148 | )
149 | continue
150 |
151 | down_weight = lora_sd.pop(key)
152 | up_weight = lora_sd.pop(up_key)
153 |
154 | dim = down_weight.size()[0]
155 | alpha = lora_sd.pop(alpha_key, dim)
156 | scale = alpha / dim
157 |
158 | # W <- W + U * D
159 | weight = state_dict[module_weight_key]
160 |
161 | weight = weight.to(working_device, merge_dtype)
162 | up_weight = up_weight.to(working_device, merge_dtype)
163 | down_weight = down_weight.to(working_device, merge_dtype)
164 |
165 | # logger.info(module_name, down_weight.size(), up_weight.size())
166 | if len(weight.size()) == 2:
167 | # linear
168 | weight = weight + ratio * (up_weight @ down_weight) * scale
169 | elif down_weight.size()[2:4] == (1, 1):
170 | # conv2d 1x1
171 | weight = (
172 | weight
173 | + ratio
174 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
175 | * scale
176 | )
177 | else:
178 | # conv2d 3x3
179 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
180 | # logger.info(conved.size(), weight.size(), module.stride, module.padding)
181 | weight = weight + ratio * conved * scale
182 |
183 | state_dict[module_weight_key] = weight.to(loading_device, save_dtype)
184 | del up_weight
185 | del down_weight
186 | del weight
187 |
188 | if len(lora_sd) > 0:
189 | logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}")
190 |
191 | return flux_state_dict, clip_l_state_dict, t5xxl_state_dict
192 |
193 |
194 | def merge_to_flux_model_diffusers(
195 | loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
196 | ):
197 | logger.info(f"loading keys from FLUX.1 model: {flux_model}")
198 | if mem_eff_load_save:
199 | flux_state_dict = {}
200 | with MemoryEfficientSafeOpen(flux_model) as flux_file:
201 | for key in tqdm(flux_file.keys()):
202 | flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
203 | else:
204 | flux_state_dict = load_file(flux_model, device=loading_device)
205 |
206 | def create_key_map(n_double_layers, n_single_layers):
207 | key_map = {}
208 | for index in range(n_double_layers):
209 | prefix_from = f"transformer_blocks.{index}"
210 | prefix_to = f"double_blocks.{index}"
211 |
212 | for end in ("weight", "bias"):
213 | k = f"{prefix_from}.attn."
214 | qkv_img = f"{prefix_to}.img_attn.qkv.{end}"
215 | qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}"
216 |
217 | key_map[f"{k}to_q.{end}"] = qkv_img
218 | key_map[f"{k}to_k.{end}"] = qkv_img
219 | key_map[f"{k}to_v.{end}"] = qkv_img
220 | key_map[f"{k}add_q_proj.{end}"] = qkv_txt
221 | key_map[f"{k}add_k_proj.{end}"] = qkv_txt
222 | key_map[f"{k}add_v_proj.{end}"] = qkv_txt
223 |
224 | block_map = {
225 | "attn.to_out.0.weight": "img_attn.proj.weight",
226 | "attn.to_out.0.bias": "img_attn.proj.bias",
227 | "norm1.linear.weight": "img_mod.lin.weight",
228 | "norm1.linear.bias": "img_mod.lin.bias",
229 | "norm1_context.linear.weight": "txt_mod.lin.weight",
230 | "norm1_context.linear.bias": "txt_mod.lin.bias",
231 | "attn.to_add_out.weight": "txt_attn.proj.weight",
232 | "attn.to_add_out.bias": "txt_attn.proj.bias",
233 | "ff.net.0.proj.weight": "img_mlp.0.weight",
234 | "ff.net.0.proj.bias": "img_mlp.0.bias",
235 | "ff.net.2.weight": "img_mlp.2.weight",
236 | "ff.net.2.bias": "img_mlp.2.bias",
237 | "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
238 | "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
239 | "ff_context.net.2.weight": "txt_mlp.2.weight",
240 | "ff_context.net.2.bias": "txt_mlp.2.bias",
241 | "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
242 | "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
243 | "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
244 | "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
245 | }
246 |
247 | for k, v in block_map.items():
248 | key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}"
249 |
250 | for index in range(n_single_layers):
251 | prefix_from = f"single_transformer_blocks.{index}"
252 | prefix_to = f"single_blocks.{index}"
253 |
254 | for end in ("weight", "bias"):
255 | k = f"{prefix_from}.attn."
256 | qkv = f"{prefix_to}.linear1.{end}"
257 | key_map[f"{k}to_q.{end}"] = qkv
258 | key_map[f"{k}to_k.{end}"] = qkv
259 | key_map[f"{k}to_v.{end}"] = qkv
260 | key_map[f"{prefix_from}.proj_mlp.{end}"] = qkv
261 |
262 | block_map = {
263 | "norm.linear.weight": "modulation.lin.weight",
264 | "norm.linear.bias": "modulation.lin.bias",
265 | "proj_out.weight": "linear2.weight",
266 | "proj_out.bias": "linear2.bias",
267 | "attn.norm_q.weight": "norm.query_norm.scale",
268 | "attn.norm_k.weight": "norm.key_norm.scale",
269 | }
270 |
271 | for k, v in block_map.items():
272 | key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}"
273 |
274 | # add as-is keys
275 | values = list([(v if isinstance(v, str) else v[0]) for v in set(key_map.values())])
276 | values.sort()
277 | key_map.update({v: v for v in values})
278 |
279 | return key_map
280 |
281 | key_map = create_key_map(18, 38) # 18 double layers, 38 single layers
282 |
283 | def find_matching_key(flux_dict, lora_key):
284 | lora_key = lora_key.replace("diffusion_model.", "")
285 | lora_key = lora_key.replace("transformer.", "")
286 | lora_key = lora_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up")
287 | lora_key = lora_key.replace("single_transformer_blocks", "single_blocks")
288 | lora_key = lora_key.replace("transformer_blocks", "double_blocks")
289 |
290 | double_block_map = {
291 | "attn.to_out.0": "img_attn.proj",
292 | "norm1.linear": "img_mod.lin",
293 | "norm1_context.linear": "txt_mod.lin",
294 | "attn.to_add_out": "txt_attn.proj",
295 | "ff.net.0.proj": "img_mlp.0",
296 | "ff.net.2": "img_mlp.2",
297 | "ff_context.net.0.proj": "txt_mlp.0",
298 | "ff_context.net.2": "txt_mlp.2",
299 | "attn.norm_q": "img_attn.norm.query_norm",
300 | "attn.norm_k": "img_attn.norm.key_norm",
301 | "attn.norm_added_q": "txt_attn.norm.query_norm",
302 | "attn.norm_added_k": "txt_attn.norm.key_norm",
303 | "attn.to_q": "img_attn.qkv",
304 | "attn.to_k": "img_attn.qkv",
305 | "attn.to_v": "img_attn.qkv",
306 | "attn.add_q_proj": "txt_attn.qkv",
307 | "attn.add_k_proj": "txt_attn.qkv",
308 | "attn.add_v_proj": "txt_attn.qkv",
309 | }
310 | single_block_map = {
311 | "norm.linear": "modulation.lin",
312 | "proj_out": "linear2",
313 | "attn.norm_q": "norm.query_norm",
314 | "attn.norm_k": "norm.key_norm",
315 | "attn.to_q": "linear1",
316 | "attn.to_k": "linear1",
317 | "attn.to_v": "linear1",
318 | "proj_mlp": "linear1",
319 | }
320 |
321 | # same key exists in both single_block_map and double_block_map, so we must care about single/double
322 | # print("lora_key before double_block_map", lora_key)
323 | for old, new in double_block_map.items():
324 | if "double" in lora_key:
325 | lora_key = lora_key.replace(old, new)
326 | # print("lora_key before single_block_map", lora_key)
327 | for old, new in single_block_map.items():
328 | if "single" in lora_key:
329 | lora_key = lora_key.replace(old, new)
330 | # print("lora_key after mapping", lora_key)
331 |
332 | if lora_key in key_map:
333 | flux_key = key_map[lora_key]
334 | logger.info(f"Found matching key: {flux_key}")
335 | return flux_key
336 |
337 | # If not found in key_map, try partial matching
338 | potential_key = lora_key + ".weight"
339 | logger.info(f"Searching for key: {potential_key}")
340 | matches = [k for k in flux_dict.keys() if potential_key in k]
341 | if matches:
342 | logger.info(f"Found matching key: {matches[0]}")
343 | return matches[0]
344 | return None
345 |
346 | merged_keys = set()
347 | for model, ratio in zip(models, ratios):
348 | logger.info(f"loading: {model}")
349 | lora_sd, _ = load_state_dict(model, merge_dtype)
350 |
351 | logger.info("merging...")
352 | for key in lora_sd.keys():
353 | if "lora_down" in key or "lora_A" in key:
354 | lora_name = key[: key.rfind(".lora_down" if "lora_down" in key else ".lora_A")]
355 | up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B")
356 | alpha_key = key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + "alpha"
357 |
358 | logger.info(f"Processing LoRA key: {lora_name}")
359 | flux_key = find_matching_key(flux_state_dict, lora_name)
360 |
361 | if flux_key is None:
362 | logger.warning(f"no module found for LoRA weight: {key}")
363 | continue
364 |
365 | logger.info(f"Merging LoRA key {lora_name} into Flux key {flux_key}")
366 |
367 | down_weight = lora_sd[key]
368 | up_weight = lora_sd[up_key]
369 |
370 | dim = down_weight.size()[0]
371 | alpha = lora_sd.get(alpha_key, dim)
372 | scale = alpha / dim
373 |
374 | weight = flux_state_dict[flux_key]
375 |
376 | weight = weight.to(working_device, merge_dtype)
377 | up_weight = up_weight.to(working_device, merge_dtype)
378 | down_weight = down_weight.to(working_device, merge_dtype)
379 |
380 | # print(up_weight.size(), down_weight.size(), weight.size())
381 |
382 | if lora_name.startswith("transformer."):
383 | if "qkv" in flux_key or "linear1" in flux_key: # combined qkv or qkv+mlp
384 | update = ratio * (up_weight @ down_weight) * scale
385 | # print(update.shape)
386 |
387 | if "img_attn" in flux_key or "txt_attn" in flux_key:
388 | q, k, v = torch.chunk(weight, 3, dim=0)
389 | if "to_q" in lora_name or "add_q_proj" in lora_name:
390 | q += update.reshape(q.shape)
391 | elif "to_k" in lora_name or "add_k_proj" in lora_name:
392 | k += update.reshape(k.shape)
393 | elif "to_v" in lora_name or "add_v_proj" in lora_name:
394 | v += update.reshape(v.shape)
395 | weight = torch.cat([q, k, v], dim=0)
396 | elif "linear1" in flux_key:
397 | q, k, v = torch.chunk(weight[: int(update.shape[-1] * 3)], 3, dim=0)
398 | mlp = weight[int(update.shape[-1] * 3) :]
399 | # print(q.shape, k.shape, v.shape, mlp.shape)
400 | if "to_q" in lora_name:
401 | q += update.reshape(q.shape)
402 | elif "to_k" in lora_name:
403 | k += update.reshape(k.shape)
404 | elif "to_v" in lora_name:
405 | v += update.reshape(v.shape)
406 | elif "proj_mlp" in lora_name:
407 | mlp += update.reshape(mlp.shape)
408 | weight = torch.cat([q, k, v, mlp], dim=0)
409 | else:
410 | if len(weight.size()) == 2:
411 | weight = weight + ratio * (up_weight @ down_weight) * scale
412 | elif down_weight.size()[2:4] == (1, 1):
413 | weight = (
414 | weight
415 | + ratio
416 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
417 | * scale
418 | )
419 | else:
420 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
421 | weight = weight + ratio * conved * scale
422 | else:
423 | if len(weight.size()) == 2:
424 | weight = weight + ratio * (up_weight @ down_weight) * scale
425 | elif down_weight.size()[2:4] == (1, 1):
426 | weight = (
427 | weight
428 | + ratio
429 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
430 | * scale
431 | )
432 | else:
433 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
434 | weight = weight + ratio * conved * scale
435 |
436 | flux_state_dict[flux_key] = weight.to(loading_device, save_dtype)
437 | merged_keys.add(flux_key)
438 | del up_weight
439 | del down_weight
440 | del weight
441 |
442 | logger.info(f"Merged keys: {sorted(list(merged_keys))}")
443 | return flux_state_dict
444 |
445 |
446 | def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
447 | base_alphas = {} # alpha for merged model
448 | base_dims = {}
449 |
450 | merged_sd = {}
451 | base_model = None
452 | for model, ratio in zip(models, ratios):
453 | logger.info(f"loading: {model}")
454 | lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
455 |
456 | if lora_metadata is not None:
457 | if base_model is None:
458 | base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
459 |
460 | # get alpha and dim
461 | alphas = {} # alpha for current model
462 | dims = {} # dims for current model
463 | for key in lora_sd.keys():
464 | if "alpha" in key:
465 | lora_module_name = key[: key.rfind(".alpha")]
466 | alpha = float(lora_sd[key].detach().numpy())
467 | alphas[lora_module_name] = alpha
468 | if lora_module_name not in base_alphas:
469 | base_alphas[lora_module_name] = alpha
470 | elif "lora_down" in key:
471 | lora_module_name = key[: key.rfind(".lora_down")]
472 | dim = lora_sd[key].size()[0]
473 | dims[lora_module_name] = dim
474 | if lora_module_name not in base_dims:
475 | base_dims[lora_module_name] = dim
476 |
477 | for lora_module_name in dims.keys():
478 | if lora_module_name not in alphas:
479 | alpha = dims[lora_module_name]
480 | alphas[lora_module_name] = alpha
481 | if lora_module_name not in base_alphas:
482 | base_alphas[lora_module_name] = alpha
483 |
484 | logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
485 |
486 | # merge
487 | logger.info("merging...")
488 | for key in tqdm(lora_sd.keys()):
489 | if "alpha" in key:
490 | continue
491 |
492 | if "lora_up" in key and concat:
493 | concat_dim = 1
494 | elif "lora_down" in key and concat:
495 | concat_dim = 0
496 | else:
497 | concat_dim = None
498 |
499 | lora_module_name = key[: key.rfind(".lora_")]
500 |
501 | base_alpha = base_alphas[lora_module_name]
502 | alpha = alphas[lora_module_name]
503 |
504 | scale = math.sqrt(alpha / base_alpha) * ratio
505 | scale = abs(scale) if "lora_up" in key else scale  # handle negative weights: apply the sign only via lora_down
506 |
507 | if key in merged_sd:
508 | assert (
509 | merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
510 | ), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。"
511 | if concat_dim is not None:
512 | merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
513 | else:
514 | merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
515 | else:
516 | merged_sd[key] = lora_sd[key] * scale
517 |
518 | # set alpha to sd
519 | for lora_module_name, alpha in base_alphas.items():
520 | key = lora_module_name + ".alpha"
521 | merged_sd[key] = torch.tensor(alpha)
522 | if shuffle:
523 | key_down = lora_module_name + ".lora_down.weight"
524 | key_up = lora_module_name + ".lora_up.weight"
525 | dim = merged_sd[key_down].shape[0]
526 | perm = torch.randperm(dim)
527 | merged_sd[key_down] = merged_sd[key_down][perm]
528 | merged_sd[key_up] = merged_sd[key_up][:, perm]
529 |
530 | logger.info("merged model")
531 | logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
532 |
533 | # check all dims are same
534 | dims_list = list(set(base_dims.values()))
535 | alphas_list = list(set(base_alphas.values()))
536 | all_same_dims = True
537 | all_same_alphas = True
538 | for dims in dims_list:
539 | if dims != dims_list[0]:
540 | all_same_dims = False
541 | break
542 | for alphas in alphas_list:
543 | if alphas != alphas_list[0]:
544 | all_same_alphas = False
545 | break
546 |
547 | # build minimum metadata
548 | dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
549 | alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
550 | metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora_flux", dims, alphas, None)
551 |
552 | return merged_sd, metadata
553 |
554 |
555 | def merge(args):
556 | if args.models is None:
557 | args.models = []
558 | if args.ratios is None:
559 | args.ratios = []
560 |
561 | assert len(args.models) == len(
562 | args.ratios
563 | ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
564 |
565 | merge_dtype = str_to_dtype(args.precision)
566 | save_dtype = str_to_dtype(args.save_precision)
567 | if save_dtype is None:
568 | save_dtype = merge_dtype
569 |
570 | assert (
571 | args.save_to or args.clip_l_save_to or args.t5xxl_save_to
572 | ), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください"
573 | dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to)
574 | if not os.path.exists(dest_dir):
575 | logger.info(f"creating directory: {dest_dir}")
576 | os.makedirs(dest_dir)
577 |
578 | if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None:
579 | if not args.diffusers:
580 | assert (args.clip_l is None and args.clip_l_save_to is None) or (
581 | args.clip_l is not None and args.clip_l_save_to is not None
582 | ), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください"
583 | assert (args.t5xxl is None and args.t5xxl_save_to is None) or (
584 | args.t5xxl is not None and args.t5xxl_save_to is not None
585 | ), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください"
586 | flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model(
587 | args.loading_device,
588 | args.working_device,
589 | args.flux_model,
590 | args.clip_l,
591 | args.t5xxl,
592 | args.models,
593 | args.ratios,
594 | merge_dtype,
595 | save_dtype,
596 | args.mem_eff_load_save,
597 | )
598 | else:
599 | assert (
600 | args.clip_l is None and args.t5xxl is None
601 | ), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません"
602 | flux_state_dict = merge_to_flux_model_diffusers(
603 | args.loading_device,
604 | args.working_device,
605 | args.flux_model,
606 | args.models,
607 | args.ratios,
608 | merge_dtype,
609 | save_dtype,
610 | args.mem_eff_load_save,
611 | )
612 | clip_l_state_dict = None
613 | t5xxl_state_dict = None
614 |
615 | if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0):
616 | sai_metadata = None
617 | else:
618 | merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models)
619 | title = os.path.splitext(os.path.basename(args.save_to))[0]
620 | sai_metadata = sai_model_spec.build_metadata(
621 | None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev"
622 | )
623 |
624 | if flux_state_dict is not None and len(flux_state_dict) > 0:
625 | logger.info(f"saving FLUX model to: {args.save_to}")
626 | save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save)
627 |
628 | if clip_l_state_dict is not None and len(clip_l_state_dict) > 0:
629 | logger.info(f"saving clip_l model to: {args.clip_l_save_to}")
630 | save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save)
631 |
632 | if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0:
633 | logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}")
634 | save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save)
635 |
636 | else:
637 | flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
638 |
639 | logger.info("calculating hashes and creating metadata...")
640 |
641 | model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata)
642 | metadata["sshs_model_hash"] = model_hash
643 | metadata["sshs_legacy_hash"] = legacy_hash
644 |
645 | if not args.no_metadata:
646 | merged_from = sai_model_spec.build_merged_from(args.models)
647 | title = os.path.splitext(os.path.basename(args.save_to))[0]
648 | sai_metadata = sai_model_spec.build_metadata(
649 | flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev"
650 | )
651 | metadata.update(sai_metadata)
652 |
653 | logger.info(f"saving model to: {args.save_to}")
654 | save_to_file(args.save_to, flux_state_dict, save_dtype, metadata)
655 |
656 |
657 | def setup_parser() -> argparse.ArgumentParser:
658 | parser = argparse.ArgumentParser()
659 | parser.add_argument(
660 | "--save_precision",
661 | type=str,
662 | default=None,
663 | help="precision in saving, same to merging if omitted. supported types: "
664 | "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz"
665 | " / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
666 | )
667 | parser.add_argument(
668 | "--precision",
669 | type=str,
670 | default="float",
671 | help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
672 | )
673 | parser.add_argument(
674 | "--flux_model",
675 | type=str,
676 | default=None,
677 | help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする",
678 | )
679 | parser.add_argument(
680 | "--clip_l",
681 | type=str,
682 | default=None,
683 | help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)",
684 | )
685 | parser.add_argument(
686 | "--t5xxl",
687 | type=str,
688 | default=None,
689 | help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)",
690 | )
691 | parser.add_argument(
692 | "--mem_eff_load_save",
693 | action="store_true",
694 | help="use custom memory efficient load and save functions for FLUX.1 model"
695 | " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する",
696 | )
697 | parser.add_argument(
698 | "--loading_device",
699 | type=str,
700 | default="cpu",
701 | help="device to load FLUX.1 model. LoRA models are loaded on CPU / FLUX.1モデルを読み込むデバイス。LoRAモデルはCPUで読み込まれます",
702 | )
703 | parser.add_argument(
704 | "--working_device",
705 | type=str,
706 | default="cpu",
707 | help="device to work (merge). Merging LoRA models are done on CPU."
708 | + " / 作業(マージ)するデバイス。LoRAモデルのマージはCPUで行われます。",
709 | )
710 | parser.add_argument(
711 | "--save_to",
712 | type=str,
713 | default=None,
714 | help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル",
715 | )
716 | parser.add_argument(
717 | "--clip_l_save_to",
718 | type=str,
719 | default=None,
720 | help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル",
721 | )
722 | parser.add_argument(
723 | "--t5xxl_save_to",
724 | type=str,
725 | default=None,
726 | help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル",
727 | )
728 | parser.add_argument(
729 | "--models",
730 | type=str,
731 | nargs="*",
732 | help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル",
733 | )
734 | parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
735 | parser.add_argument(
736 | "--no_metadata",
737 | action="store_true",
738 | help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
739 | + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
740 | )
741 | parser.add_argument(
742 | "--concat",
743 | action="store_true",
744 | help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
745 | + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
746 | )
747 | parser.add_argument(
748 | "--shuffle",
749 | action="store_true",
750 | help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
751 | )
752 | parser.add_argument(
753 | "--diffusers",
754 | action="store_true",
755 | help="merge Diffusers (?) LoRA models / Diffusers (?) LoRAモデルをマージする",
756 | )
757 |
758 | return parser
759 |
760 |
761 | if __name__ == "__main__":
762 | parser = setup_parser()
763 |
764 | args = parser.parse_args()
765 | merge(args)
766 |
--------------------------------------------------------------------------------
/lora_algebra/operations.py:
--------------------------------------------------------------------------------
1 | """
2 | LoRA mathematical operations (difference, merge, etc.)
3 | """
4 |
5 | import os
6 | import sys
7 | import subprocess
8 | from typing import Optional, Tuple, List
9 | from safetensors import safe_open
10 | from safetensors.torch import save_file
11 | import torch
12 |
13 |
14 | class LoRAProcessor:
15 | """Helper class for LoRA operations using sd-scripts"""
16 |
17 | def __init__(self, sd_scripts_path: Optional[str] = None):
18 | self.sd_scripts_path = sd_scripts_path or resolve_path_without_quotes("../sd-scripts")
19 |
20 | def _run_sd_script(self, script_name: str, args: List[str]) -> Tuple[bool, str]:
21 | """Run an sd-scripts script with given arguments"""
22 | script_path = os.path.join(self.sd_scripts_path, "networks", script_name)
23 |
24 | if not os.path.exists(script_path):
25 | return False, f"Script not found: {script_path}"
26 |
27 | command = [sys.executable, script_path] + args
28 |
29 | try:
30 | project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
31 |
32 | # Set up environment with sd-scripts in Python path
33 | env = os.environ.copy()
34 | if 'PYTHONPATH' in env:
35 | env['PYTHONPATH'] = f"{self.sd_scripts_path}{os.pathsep}{env['PYTHONPATH']}"
36 | else:
37 | env['PYTHONPATH'] = self.sd_scripts_path
38 |
39 | result = subprocess.run(
40 | command,
41 | capture_output=True,
42 | text=True,
43 | cwd=project_root,
44 | env=env
45 | )
46 |
47 | if result.returncode == 0:
48 | return True, result.stdout
49 | else:
50 | error_msg = result.stderr or result.stdout
51 | if "ModuleNotFoundError: No module named 'library'" in error_msg:
52 | error_msg += f"\n\nTroubleshooting: sd-scripts library not found. Try:\n"
53 | error_msg += f"1. Ensure sd-scripts is properly installed\n"
54 | error_msg += f"2. Check that sd-scripts path is correct: {self.sd_scripts_path}\n"
55 | error_msg += f"3. Reinstall by running the installer again"
56 | return False, error_msg
57 |
58 | except Exception as e:
59 | return False, f"Error running script: {str(e)}"
60 |
61 | def extract_metadata(self, lora_path: str) -> dict:
62 | """Extract metadata from LoRA file"""
63 | try:
64 | with safe_open(lora_path, framework="pt", device="cpu") as f:
65 | metadata = {}
66 | if f.metadata():
67 | metadata.update(f.metadata())
68 | return metadata
69 | except Exception as e:
70 | return {"error": f"Could not read metadata: {str(e)}"}
71 |
72 | def resolve_path_without_quotes(p):
73 | """Copy of custom.py's path resolution"""
74 | current_dir = os.path.dirname(os.path.abspath(__file__))
75 | norm_path = os.path.normpath(os.path.join(current_dir, p))
76 | return norm_path
77 |
78 | def subtract_loras(
79 | lora_a_path: str,
80 | lora_b_path: str,
81 | output_path: str,
82 | strength_a: float = 1.0,
83 | strength_b: float = 1.0,
84 | use_concat: bool = True,
85 | sd_scripts_path: Optional[str] = None
86 | ) -> Tuple[bool, str]:
87 | """Extract difference between two LoRAs (A - B) using negative weights - copied from custom.py"""
88 |
89 | # Validate inputs
90 | if not lora_a_path or not lora_b_path:
91 | return False, "Error: Please provide paths for both LoRAs"
92 |
93 | if not os.path.exists(lora_a_path):
94 | return False, f"Error: LoRA A file not found: {lora_a_path}"
95 |
96 | if not os.path.exists(lora_b_path):
97 | return False, f"Error: LoRA B file not found: {lora_b_path}"
98 |
99 | if not output_path:
100 | return False, "Error: Please provide an output path"
101 |
102 | # Create output directory
103 | os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
104 |
105 | # Build command for LoRA difference extraction using negative weight - EXACT copy from custom.py
106 | script_path = resolve_path_without_quotes("../sd-scripts/networks/flux_merge_lora.py")
107 |
108 | # Use positive weight for A and negative weight for B to get A - B
109 | command = [
110 | sys.executable,
111 | script_path,
112 | "--save_to", output_path,
113 | "--models", lora_a_path, lora_b_path,
114 | "--ratios", str(strength_a), str(-strength_b), # Negative for difference
115 | "--save_precision", "fp16"
116 | ]
117 |
118 | # Add concat flag if selected
119 | if use_concat:
120 | command.append("--concat")
121 |
122 | try:
123 | # Run the difference extraction command - but use project root as working directory
124 | project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
125 |
126 | # Set up environment with sd-scripts in Python path
127 | sd_scripts_dir = sd_scripts_path or resolve_path_without_quotes("../sd-scripts")
128 | env = os.environ.copy()
129 | if 'PYTHONPATH' in env:
130 | env['PYTHONPATH'] = f"{sd_scripts_dir}{os.pathsep}{env['PYTHONPATH']}"
131 | else:
132 | env['PYTHONPATH'] = sd_scripts_dir
133 |
134 | result = subprocess.run(
135 | command,
136 | capture_output=True,
137 | text=True,
138 | cwd=project_root,
139 | env=env
140 | )
141 |
142 | if result.returncode == 0:
143 | return True, f"Success! Difference LoRA saved to: {output_path}\n\nThis LoRA represents: A (strength {strength_a}) - B (strength {strength_b})\n\nOutput:\n{result.stdout}"
144 | else:
145 | error_msg = result.stderr or result.stdout
146 | if "ModuleNotFoundError: No module named 'library'" in error_msg:
147 | error_msg += f"\n\nTroubleshooting: sd-scripts library not found. Try:\n"
148 | error_msg += f"1. Ensure sd-scripts is properly installed\n"
149 | error_msg += f"2. Check that sd-scripts path is correct: {sd_scripts_dir}\n"
150 | error_msg += f"3. Reinstall by running the installer again"
151 | return False, f"Error during difference extraction:\n{error_msg}"
152 |
153 | except Exception as e:
154 | return False, f"Error running difference extraction: {str(e)}"
155 |
156 | def merge_loras(
157 | lora_a_path: str,
158 | lora_b_path: str,
159 | output_path: str,
160 | strength_a: float = 1.0,
161 | strength_b: float = 1.0,
162 | use_concat: bool = True,
163 | sd_scripts_path: Optional[str] = None
164 | ) -> Tuple[bool, str]:
165 | """Merge two LoRAs with positive weights (A + B)
166 |
167 | Args:
168 | lora_a_path: Path to LoRA A
169 | lora_b_path: Path to LoRA B
170 | output_path: Path for output LoRA
171 | strength_a: Strength multiplier for LoRA A
172 | strength_b: Strength multiplier for LoRA B
173 | use_concat: Whether to use concat mode for different ranks
174 | sd_scripts_path: Custom path to sd-scripts directory
175 |
176 | Returns:
177 | Tuple of (success: bool, message: str)
178 | """
179 | processor = LoRAProcessor(sd_scripts_path)
180 |
181 | # Validate inputs
182 | if not os.path.exists(lora_a_path):
183 | return False, f"LoRA A file not found: {lora_a_path}"
184 |
185 | if not os.path.exists(lora_b_path):
186 | return False, f"LoRA B file not found: {lora_b_path}"
187 |
188 | # Create output directory
189 | os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
190 |
191 | # Build merge command
192 | args = [
193 | "--save_to", output_path,
194 | "--models", lora_a_path, lora_b_path,
195 | "--ratios", str(strength_a), str(strength_b),
196 | "--save_precision", "fp16"
197 | ]
198 |
199 | # Add concat flag if selected
200 | if use_concat:
201 | args.append("--concat")
202 |
203 | # Run the merge
204 | success, output = processor._run_sd_script("flux_merge_lora.py", args)
205 |
206 | if success and os.path.exists(output_path):
207 | return True, f"Success! Merged LoRA saved to: {output_path}\\n\\nCombined: A (strength {strength_a}) + B (strength {strength_b})"
208 | else:
209 | return False, f"Error during merge: {output}"
210 |
211 | def analyze_lora(lora_path: str, sd_scripts_path: Optional[str] = None) -> dict:
212 | """Analyze a LoRA file and return detailed information
213 |
214 | Args:
215 | lora_path: Path to LoRA file
216 | sd_scripts_path: Custom path to sd-scripts directory
217 |
218 | Returns:
219 | Dictionary containing analysis results
220 | """
221 | processor = LoRAProcessor(sd_scripts_path)
222 |
223 | if not os.path.exists(lora_path):
224 | return {"error": f"File not found: {lora_path}"}
225 |
226 | # Extract metadata
227 | metadata = processor.extract_metadata(lora_path)
228 |
229 | if not metadata:
230 | return {"error": "Could not extract metadata from LoRA file"}
231 |
232 | # File size analysis
233 | file_size = os.path.getsize(lora_path)
234 | file_size_mb = file_size / (1024 * 1024)
235 |
236 | # Rank analysis
237 | rank = int(metadata.get('network_dim', 32))  # safetensors metadata values are strings
238 | alpha = float(metadata.get('network_alpha', 32.0))
239 |
240 | # Calculate approximate parameter count
241 | # This is a rough estimate for Flux LoRAs
242 | approx_params = rank * rank * 50 # Rough estimate
243 |
244 | analysis = {
245 | "file_path": lora_path,
246 | "file_size_mb": round(file_size_mb, 2),
247 | "rank": rank,
248 | "alpha": alpha,
249 | "learning_rate": metadata.get('learning_rate', 'Unknown'),
250 | "base_model": metadata.get('base_model', 'Unknown'),
251 | "training_comment": metadata.get('training_comment', ''),
252 | "estimated_parameters": approx_params,
253 | "rank_efficiency": f"{alpha/rank:.2f}" if rank > 0 else "N/A",
254 | "metadata": metadata
255 | }
256 |
257 | return analysis
258 |
259 | def target_lora_layers(
260 | lora_path: str,
261 | output_path: str,
262 | mute_layers: List[int],
263 | sd_scripts_path: Optional[str] = None
264 | ) -> Tuple[bool, str]:
265 | """Mute specific FLUX layers in a LoRA by setting their weights to zero
266 |
267 | Args:
268 | lora_path: Path to input LoRA file
269 | output_path: Path for output LoRA file
270 | mute_layers: List of layer numbers to mute (e.g., [7, 20])
271 | sd_scripts_path: Not used but kept for consistency
272 |
273 | Returns:
274 | Tuple of (success: bool, message: str)
275 | """
276 | try:
277 | # Validate input
278 | if not os.path.exists(lora_path):
279 | return False, f"LoRA file not found: {lora_path}"
280 |
281 | if not mute_layers:
282 | return False, "No layers selected to mute"
283 |
284 | # Create output directory
285 | os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
286 |
287 | print(f"🔍 Loading LoRA from: {lora_path}")
288 | print(f"🎯 Targeting layers: {mute_layers}")
289 |
290 | # Load the LoRA file and analyze structure
291 | tensors = {}
292 | metadata = {}
293 | total_tensors = 0
294 | muted_tensors = 0
295 | layer_analysis = {}
296 |
297 | with safe_open(lora_path, framework="pt", device="cpu") as f:
298 | # Copy metadata
299 | if f.metadata():
300 | metadata.update(f.metadata())
301 |
302 | # First pass: analyze tensor structure
303 | all_keys = list(f.keys())
304 | total_tensors = len(all_keys)
305 | print(f"📊 Total tensors in LoRA: {total_tensors}")
306 |
307 | # Debug: Print ALL tensor names to understand the naming pattern
308 | print(f"🔍 DEBUG: ALL {total_tensors} tensor names:")
309 | for i, key in enumerate(all_keys):
310 | print(f" {i+1:4d}. {key}")
311 |
312 | # Also analyze unique patterns
313 | unique_prefixes = set()
314 | for key in all_keys:
315 | parts = key.split('.')
316 | if len(parts) >= 2:
317 | unique_prefixes.add('.'.join(parts[:2]))
318 |
319 | print(f"🔍 DEBUG: Unique tensor prefixes found:")
320 | for prefix in sorted(unique_prefixes):
321 | count = len([k for k in all_keys if k.startswith(prefix)])
322 | print(f" {prefix}: {count} tensors")
323 |
324 | # Analyze layer distribution with FLUX-specific patterns
325 | for key in all_keys:
326 | for layer_num in range(50): # Check layers 0-49
327 | # FLUX LoRA naming patterns
328 | patterns = [
329 | f"_layers_{layer_num}_", # lora_te1_text_model_encoder_layers_7_
330 | f"_blocks_{layer_num}_", # lora_unet_double_blocks_7_, lora_unet_single_blocks_7_
331 | f"single_blocks_{layer_num}_", # lora_unet_single_blocks_7_
332 | f"double_blocks_{layer_num}_", # lora_unet_double_blocks_7_
333 | # Legacy patterns for compatibility
334 | f"single_transformer_blocks.{layer_num}.",
335 | f"transformer.single_transformer_blocks.{layer_num}.",
336 | f"transformer_blocks.{layer_num}.",
337 | f"blocks.{layer_num}.",
338 | f"layer.{layer_num}.",
339 | f"layers.{layer_num}."
340 | ]
341 |
342 | found_match = False
343 | for pattern in patterns:
344 | if pattern in key:
345 | if layer_num not in layer_analysis:
346 | layer_analysis[layer_num] = []
347 | layer_analysis[layer_num].append(key)
348 | found_match = True
349 | break
350 |
351 | if found_match:
352 | break
353 |
354 | print(f"📈 Layer distribution found:")
355 | for layer_num in sorted(layer_analysis.keys()):
356 | tensor_count = len(layer_analysis[layer_num])
357 | is_target = layer_num in mute_layers
358 | status = "🎯 TARGET" if is_target else "✅ keep"
359 | print(f" Layer {layer_num}: {tensor_count} tensors {status}")
360 |
361 | # Second pass: process each tensor
362 | for key in all_keys:
363 | tensor = f.get_tensor(key)
364 |
365 | # Check if this tensor belongs to a layer we want to mute
366 | should_mute = False
367 | matched_layer = None
368 |
369 | for layer_num in mute_layers:
370 | # Use FLUX LoRA naming patterns (same as analysis)
371 | layer_patterns = [
372 | f"_layers_{layer_num}_", # lora_te1_text_model_encoder_layers_7_
373 | f"_blocks_{layer_num}_", # lora_unet_double_blocks_7_, lora_unet_single_blocks_7_
374 | f"single_blocks_{layer_num}_", # lora_unet_single_blocks_7_
375 | f"double_blocks_{layer_num}_", # lora_unet_double_blocks_7_
376 | # Legacy patterns for compatibility
377 | f"single_transformer_blocks.{layer_num}.",
378 | f"transformer.single_transformer_blocks.{layer_num}.",
379 | f"transformer_blocks.{layer_num}.",
380 | f"blocks.{layer_num}.",
381 | f"layer.{layer_num}.",
382 | f"layers.{layer_num}."
383 | ]
384 |
385 | if any(pattern in key for pattern in layer_patterns):
386 | should_mute = True
387 | matched_layer = layer_num
388 | muted_tensors += 1
389 | print(f"🔇 Muting L{layer_num}: {key}")
390 | break
391 |
392 | if should_mute:
393 | # Set tensor to zeros (mute the layer)
394 | tensors[key] = torch.zeros_like(tensor)
395 | else:
396 | # Keep original tensor
397 | tensors[key] = tensor.clone()
398 |
399 | print(f"✅ Processing complete: {muted_tensors}/{total_tensors} tensors muted")
400 |
401 | # Add targeting info to metadata
402 | metadata["lora_algebra_targeted_layers"] = ",".join(map(str, mute_layers))
403 | metadata["lora_algebra_operation"] = "layer_targeting"
404 |
405 | # Save the modified LoRA
406 | print(f"💾 Saving modified LoRA to: {output_path}")
407 | save_file(tensors, output_path, metadata=metadata)
408 |
409 | # Create detailed report
410 | layer_report = ""
411 | for layer_num in sorted(layer_analysis.keys()):
412 | tensor_count = len(layer_analysis[layer_num])
413 | is_muted = layer_num in mute_layers
414 | status = "🔇 MUTED" if is_muted else "✅ preserved"
415 | layer_report += f"Layer {layer_num}: {tensor_count} tensors - {status}\n"
416 |
417 | report = f"""✅ Layer targeting completed successfully!
418 |
419 | 📊 ANALYSIS REPORT:
420 | Total tensors processed: {total_tensors}
421 | Tensors muted: {muted_tensors}
422 | Layers targeted: {mute_layers}
423 |
424 | 📈 LAYER BREAKDOWN:
425 | {layer_report}
426 | 💾 Output saved to: {output_path}
427 |
428 | 🎯 OPERATION SUMMARY:
429 | The selected layers have been zeroed out while all other weights are preserved. If you targeted the standard facial layers (7, 12, 16, 20), this LoRA should now keep its style, poses, and costume elements without imposing the original facial features."""
430 |
431 | return True, report
432 |
433 | except Exception as e:
434 | return False, f"Error during layer targeting: {str(e)}"
435 |
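    | # Hedged illustration (documentation helper, not called anywhere): zeroing
    | # both factors of a LoRA pair mutes the layer, because the applied update
    | # is delta_W = (alpha / rank) * lora_up @ lora_down, which vanishes when
    | # either factor is all zeros, at any merge strength.
    | def _demo_muted_pair_is_inert() -> bool:
    |     import torch
    |     lora_down = torch.zeros(4, 64)  # muted rank-4 down-projection
    |     lora_up = torch.zeros(64, 4)    # muted up-projection
    |     delta_w = lora_up @ lora_down   # the layer's weight update
    |     return bool(torch.all(delta_w == 0))  # True: nothing is contributed
    |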
436 | def selective_layer_merge(
437 | lora_a_path: str,
438 | lora_b_path: str,
439 | output_path: str,
440 | layers_from_a: List[int],
441 | layers_from_b: List[int],
442 | strength_a: float = 1.0,
443 | strength_b: float = 1.0,
444 | sd_scripts_path: Optional[str] = None
445 | ) -> Tuple[bool, str]:
446 | """Selectively merge specific layers from two LoRAs with strength control
447 |
448 | Args:
449 | lora_a_path: Path to LoRA A
450 | lora_b_path: Path to LoRA B
451 | output_path: Path for output LoRA
452 | layers_from_a: List of layer numbers to take from LoRA A
453 | layers_from_b: List of layer numbers to take from LoRA B
454 | strength_a: Strength multiplier for all layers from LoRA A (default 1.0)
455 | strength_b: Strength multiplier for all layers from LoRA B (default 1.0)
456 | sd_scripts_path: Not used but kept for consistency
457 |
458 | Returns:
459 | Tuple of (success: bool, message: str)
460 | """
461 | try:
462 | # Validate inputs
463 | if not os.path.exists(lora_a_path):
464 | return False, f"LoRA A file not found: {lora_a_path}"
465 |
466 | if not os.path.exists(lora_b_path):
467 | return False, f"LoRA B file not found: {lora_b_path}"
468 |
469 | if not layers_from_a and not layers_from_b:
470 | return False, "No layers selected for merging"
471 |
472 | # Create output directory (dirname is empty when output_path is a bare filename)
473 | os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
474 |
475 | print(f"🔀 Loading LoRAs for selective merge")
476 | print(f"📁 LoRA A: {lora_a_path} (strength: {strength_a})")
477 | print(f"📁 LoRA B: {lora_b_path} (strength: {strength_b})")
478 | print(f"🔵 Layers from A: {layers_from_a}")
479 | print(f"🔴 Layers from B: {layers_from_b}")
480 |
481 | # Load both LoRAs
482 | merged_tensors = {}
483 | merged_metadata = {}
484 |
485 | # Statistics
486 | tensors_from_a = 0
487 | tensors_from_b = 0
488 | total_tensors = 0
489 |
490 | # First, load LoRA A and copy selected layers
491 | with safe_open(lora_a_path, framework="pt", device="cpu") as f_a:
492 | # Copy metadata from A as base
493 | if f_a.metadata():
494 | merged_metadata.update(f_a.metadata())
495 |
496 | all_keys_a = list(f_a.keys())
497 | print(f"📊 LoRA A has {len(all_keys_a)} tensors")
498 |
499 | for key in all_keys_a:
500 | tensor = f_a.get_tensor(key)
501 | total_tensors += 1
502 |
503 | # Check if this tensor belongs to a layer we want from A
504 | should_include = False
505 | for layer_num in layers_from_a:
506 | # Use same layer patterns as targeting
507 | layer_patterns = [
508 | f"_layers_{layer_num}_", # lora_te1_text_model_encoder_layers_7_
509 | f"_blocks_{layer_num}_", # lora_unet_double_blocks_7_, lora_unet_single_blocks_7_
510 | f"single_blocks_{layer_num}_", # lora_unet_single_blocks_7_
511 | f"double_blocks_{layer_num}_", # lora_unet_double_blocks_7_
512 | ]
513 |
514 | if any(pattern in key for pattern in layer_patterns):
515 | should_include = True
516 | tensors_from_a += 1
517 | print(f"🔵 From A L{layer_num}: {key}")
518 | break
519 |
520 | if should_include:
521 | # Scale only the lora_down factor: delta_W = (alpha/rank) * up @ down, and
522 | # up, down and alpha keys all match the layer patterns, so scaling every
523 | # tensor would compound the multiplier instead of applying it once
524 | merged_tensors[key] = tensor.clone() * (strength_a if "lora_down" in key else 1.0)
524 |
525 | # Then, load LoRA B and copy selected layers
526 | with safe_open(lora_b_path, framework="pt", device="cpu") as f_b:
527 | all_keys_b = list(f_b.keys())
528 | print(f"📊 LoRA B has {len(all_keys_b)} tensors")
529 |
530 | for key in all_keys_b:
531 | tensor = f_b.get_tensor(key)
532 |
533 | # Check if this tensor belongs to a layer we want from B
534 | should_include = False
535 | for layer_num in layers_from_b:
536 | # Use same layer patterns as targeting
537 | layer_patterns = [
538 | f"_layers_{layer_num}_", # lora_te1_text_model_encoder_layers_7_
539 | f"_blocks_{layer_num}_", # lora_unet_double_blocks_7_, lora_unet_single_blocks_7_
540 | f"single_blocks_{layer_num}_", # lora_unet_single_blocks_7_
541 | f"double_blocks_{layer_num}_", # lora_unet_double_blocks_7_
542 | ]
543 |
544 | if any(pattern in key for pattern in layer_patterns):
545 | should_include = True
546 | tensors_from_b += 1
547 | print(f"🔴 From B L{layer_num}: {key}")
548 | break
549 |
550 | if should_include:
551 | # Check for conflicts (same tensor name from both LoRAs)
552 | if key in merged_tensors:
553 | print(f"⚠️ Conflict detected: {key} exists in both LoRAs, using version from B")
554 | # As with A, scale only the lora_down factor so the multiplier applies once
555 | modified_tensor = tensor.clone() * (strength_b if "lora_down" in key else 1.0)
556 | merged_tensors[key] = modified_tensor
557 |
558 | print(f"✅ Merge statistics:")
559 | print(f" Tensors from A: {tensors_from_a}")
560 | print(f" Tensors from B: {tensors_from_b}")
561 | print(f" Total merged: {len(merged_tensors)}")
562 |
563 | if len(merged_tensors) == 0:
564 | return False, "No tensors were selected for merging. Check layer selection."
565 |
566 | # Add merge info to metadata
567 | merged_metadata["lora_algebra_merge_type"] = "selective_layer_merge"
568 | merged_metadata["lora_algebra_source_a"] = os.path.basename(lora_a_path)
569 | merged_metadata["lora_algebra_source_b"] = os.path.basename(lora_b_path)
570 | merged_metadata["lora_algebra_layers_from_a"] = ",".join(map(str, layers_from_a))
571 | merged_metadata["lora_algebra_layers_from_b"] = ",".join(map(str, layers_from_b))
572 | merged_metadata["lora_algebra_strength_a"] = str(strength_a)
573 | merged_metadata["lora_algebra_strength_b"] = str(strength_b)
574 |
575 | # Save the merged LoRA
576 | print(f"💾 Saving merged LoRA to: {output_path}")
577 | save_file(merged_tensors, output_path, metadata=merged_metadata)
578 |
579 | # Create detailed report
580 | report = f"""✅ Selective layer merge completed successfully!
581 |
582 | 📊 MERGE STATISTICS:
583 | Source A: {os.path.basename(lora_a_path)} @ {strength_a}x strength
584 | Source B: {os.path.basename(lora_b_path)} @ {strength_b}x strength
585 | Tensors from A: {tensors_from_a}
586 | Tensors from B: {tensors_from_b}
587 | Total merged tensors: {len(merged_tensors)}
588 |
589 | 🔵 Layers from A: {layers_from_a}
590 | 🔴 Layers from B: {layers_from_b}
591 | ⚡ Strength multipliers: A={strength_a}x, B={strength_b}x
592 |
593 | 💾 Output saved to: {output_path}
594 |
595 | 🎯 HYBRID LoRA CREATED:
596 | This LoRA combines the selected layers from both sources, each scaled by its strength multiplier. A/B test it against the source LoRAs to confirm that the intended traits carried over."""
597 |
598 | return True, report
599 |
600 | except Exception as e:
601 | return False, f"Error during selective layer merge: {str(e)}"
602 |
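    | # Hedged sketch of the strength rule above (documentation helper only):
    | # scaling lora_down by s scales delta_W = (alpha/rank) * up @ down
    | # linearly in s, which is why only one factor gets the multiplier.
    | def _demo_linear_strength(s: float = 0.5) -> bool:
    |     import torch
    |     down = torch.randn(4, 64)
    |     up = torch.randn(64, 4)
    |     return bool(torch.allclose(up @ (down * s), (up @ down) * s))
    |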
603 | def deep_layer_analysis(
604 | lora_path: str,
605 | user_goal: str = "Keep maximum flexibility (works with any style)",
606 | sd_scripts_path: Optional[str] = None
607 | ) -> Tuple[bool, dict]:
608 | """Perform deep analysis of LoRA layer patterns and generate recommendations
609 |
610 | Args:
611 | lora_path: Path to LoRA file
612 | user_goal: User's intended use case for tailored recommendations
613 | sd_scripts_path: Not used but kept for consistency
614 |
615 | Returns:
616 | Tuple of (success: bool, analysis_dict: dict)
617 | """
618 | try:
619 | import numpy as np
620 | import re, time  # key parsing and the analysis timestamp
621 | if not os.path.exists(lora_path):
622 | return False, {"error": f"LoRA file not found: {lora_path}"}
623 |
624 | print(f"🧠 Starting deep layer analysis: {lora_path}")
625 | print(f"🎯 User goal: {user_goal}")
626 |
627 | # Initialize per-block-type analysis structures, keyed by layer number;
628 | # each entry holds the layer's tensor keys and their per-tensor stats
629 | te_layers = {} # Text Encoder layers 0-11
630 | double_layers = {} # Double Block layers 0-19
631 | single_layers = {} # Single Block layers 0-37
632 |
633 | # (raw tensors are not retained; only their summary stats are kept)
634 | total_tensors = 0
635 |
636 | # Load and analyze all tensors
637 | with safe_open(lora_path, framework="pt", device="cpu") as f:
638 | all_keys = list(f.keys())
639 | total_tensors = len(all_keys)
640 |
641 | print(f"📊 Analyzing {total_tensors} tensors...")
642 |
643 | for key in all_keys:
644 | tensor = f.get_tensor(key)
645 | # summarize immediately below; retaining every raw tensor wastes memory
646 |
647 | # Calculate tensor statistics
648 | tensor_stats = {
649 | 'mean_abs': float(torch.mean(torch.abs(tensor)).item()),
650 | 'std': float(torch.std(tensor).item()),
651 | 'max_abs': float(torch.max(torch.abs(tensor)).item()),
652 | 'frobenius_norm': float(torch.norm(tensor, 'fro').item()),
653 | 'sparsity': float((tensor == 0).sum().item() / tensor.numel()),
654 | 'tensor_size': tensor.numel(),
655 | 'shape': list(tensor.shape)
656 | }
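    | # The Frobenius norm (sqrt of the sum of squared entries) is the magnitude
    | # signal used for every layer comparison below; e.g. a 2x2 tensor of ones
    | # has norm sqrt(4) = 2.0.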
657 |
658 | # Classify by layer type and number
659 | layer_num = None
660 | layer_type = None
661 |
662 | # Text Encoder layers
663 | if 'te1_text_model_encoder_layers_' in key:
664 | # layer index is the integer in "..._encoder_layers_<n>_..."
665 | match = re.search(r'layers_(\d+)_', key)
666 | if match:
667 | layer_num = int(match.group(1))
668 | layer_type = 'te'
669 | if layer_num not in te_layers:
670 | te_layers[layer_num] = {'tensors': [], 'stats': []}
671 | te_layers[layer_num]['tensors'].append(key)
672 | te_layers[layer_num]['stats'].append(tensor_stats)
673 |
674 | # UNet Double Block layers
675 | elif 'unet_double_blocks_' in key:
676 | # layer index is the integer in "..._double_blocks_<n>_..."
677 | match = re.search(r'blocks_(\d+)_', key)
678 | if match:
679 | layer_num = int(match.group(1))
680 | layer_type = 'double'
681 | if layer_num not in double_layers:
682 | double_layers[layer_num] = {'tensors': [], 'stats': []}
683 | double_layers[layer_num]['tensors'].append(key)
684 | double_layers[layer_num]['stats'].append(tensor_stats)
685 |
686 | # UNet Single Block layers
687 | elif 'unet_single_blocks_' in key:
688 | # layer index is the integer in "..._single_blocks_<n>_..."
689 | match = re.search(r'blocks_(\d+)_', key)
690 | if match:
691 | layer_num = int(match.group(1))
692 | layer_type = 'single'
693 | if layer_num not in single_layers:
694 | single_layers[layer_num] = {'tensors': [], 'stats': []}
695 | single_layers[layer_num]['tensors'].append(key)
696 | single_layers[layer_num]['stats'].append(tensor_stats)
697 |
698 | # Aggregate layer statistics
699 | def aggregate_layer_stats(layer_dict):
700 | aggregated = {}
701 | for layer_num, data in layer_dict.items():
702 | stats_list = data['stats']
703 | if stats_list:
704 | aggregated[layer_num] = {
705 | 'tensor_count': len(stats_list),
706 | 'total_magnitude': sum(s['frobenius_norm'] for s in stats_list),
707 | 'avg_magnitude': sum(s['frobenius_norm'] for s in stats_list) / len(stats_list),
708 | 'max_magnitude': max(s['frobenius_norm'] for s in stats_list),
709 | 'total_parameters': sum(s['tensor_size'] for s in stats_list),
710 | 'avg_sparsity': sum(s['sparsity'] for s in stats_list) / len(stats_list)
711 | }
712 | return aggregated
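    | # Shape of one aggregated entry (numbers are illustrative only):
    | #   {7: {'tensor_count': 12, 'total_magnitude': 8.4, 'avg_magnitude': 0.7,
    | #        'max_magnitude': 1.9, 'total_parameters': 1572864,
    | #        'avg_sparsity': 0.0}}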
713 |
714 | te_aggregated = aggregate_layer_stats(te_layers)
715 | double_aggregated = aggregate_layer_stats(double_layers)
716 | single_aggregated = aggregate_layer_stats(single_layers)
717 |
718 | # Detect patterns and anomalies
719 | known_facial_layers = [7, 12, 16, 20]
720 |
721 | # Find layers with unusually high magnitude (potential overtraining)
722 | suspicious_layers = []
723 |
724 | # Check for facial data in non-facial layers
725 | if single_aggregated:
726 | # Calculate baseline from known facial layers
727 | facial_magnitudes = [single_aggregated.get(layer, {}).get('avg_magnitude', 0)
728 | for layer in known_facial_layers if layer in single_aggregated]
729 |
730 | if facial_magnitudes:
731 | facial_baseline = np.mean(facial_magnitudes)
732 | facial_std = np.std(facial_magnitudes) if len(facial_magnitudes) > 1 else facial_baseline * 0.3
733 |
734 | # Check all layers for suspicious activity
735 | for layer_num, stats in single_aggregated.items():
736 | if layer_num not in known_facial_layers:
737 | if stats['avg_magnitude'] > facial_baseline * 0.7: # 70% of facial layer magnitude
738 | confidence = min(100, (stats['avg_magnitude'] / facial_baseline) * 100)
739 | suspicious_layers.append({
740 | 'layer': layer_num,
741 | 'type': 'single',
742 | 'magnitude': stats['avg_magnitude'],
743 | 'confidence': confidence,
744 | 'reason': 'High magnitude suggesting facial data'
745 | })
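    | # Worked example of the heuristic above: with a facial baseline of 2.0,
    | # any non-facial layer above 1.4 (70%) is flagged, and a layer at 1.8
    | # reports min(100, 1.8 / 2.0 * 100) = 90% confidence.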
746 |
747 | # Generate analysis report
748 | analysis = {
749 | 'file_info': {
750 | 'path': lora_path,
751 | 'filename': os.path.basename(lora_path),
752 | 'file_size_mb': round(os.path.getsize(lora_path) / (1024 * 1024), 2),
753 | 'total_tensors': total_tensors
754 | },
755 | 'layer_distribution': {
756 | 'text_encoder': len(te_aggregated),
757 | 'double_blocks': len(double_aggregated),
758 | 'single_blocks': len(single_aggregated)
759 | },
760 | 'layer_stats': {
761 | 'text_encoder': te_aggregated,
762 | 'double_blocks': double_aggregated,
763 | 'single_blocks': single_aggregated
764 | },
765 | 'pattern_analysis': {
766 | 'known_facial_layers': known_facial_layers,
767 | 'suspicious_layers': suspicious_layers,
768 | 'overtraining_detected': len(suspicious_layers) > 0
769 | },
770 | 'user_goal': user_goal,
771 | 'analysis_timestamp': str(time.time())  # seconds since the epoch
772 | }
773 |
774 | print(f"✅ Analysis complete: {len(suspicious_layers)} suspicious layers detected")
775 |
776 | return True, analysis
777 |
778 | except Exception as e:
779 | print(f"❌ Error during deep analysis: {str(e)}")
780 | return False, {"error": f"Analysis failed: {str(e)}"}
781 |
782 | def generate_recommendations(analysis_result: dict) -> str:
783 | """Generate contextual recommendations based on analysis and user goals"""
784 |
785 | if "error" in analysis_result:
786 | return f"**❌ Analysis Error**\n\n{analysis_result['error']}"
787 |
788 | user_goal = analysis_result.get('user_goal', 'Keep maximum flexibility (works with any style)')
789 | suspicious_layers = analysis_result.get('pattern_analysis', {}).get('suspicious_layers', [])
790 | overtraining = analysis_result.get('pattern_analysis', {}).get('overtraining_detected', False)
791 |
792 | recommendations = []
793 |
794 | # Header based on analysis
795 | if overtraining:
796 | recommendations.append("## ⚠️ **Possible Overtraining Detected**\n")
797 | recommendations.append(f"Found {len(suspicious_layers)} non-standard layers whose magnitudes suggest facial data.\n")
798 | else:
799 | recommendations.append("## ✅ **Clean LoRA Structure**\n")
800 | recommendations.append("No obvious overtraining patterns detected.\n")
801 |
802 | # Goal-specific recommendations
803 | recommendations.append(f"### 🎯 **Recommendations for: {user_goal}**\n")
804 |
805 | if user_goal == "Keep maximum flexibility (works with any style)":
806 | if overtraining:
807 | recommendations.append("**For Maximum Flexibility:**")
808 | recommendations.append("- Use **🎨 Style Layers** preset to remove facial data from non-facial layers")
809 | recommendations.append("- This will prevent interference when combining with style LoRAs")
810 | recommendations.append(f"- Specifically target layers: {', '.join(str(l['layer']) for l in suspicious_layers)}")
811 | recommendations.append("- **Why:** Facial data in style layers causes conflicts with artistic LoRAs\n")
812 | else:
813 | recommendations.append("**Your LoRA looks clean!**")
814 | recommendations.append("- No changes needed for maximum flexibility")
815 | recommendations.append("- Should work well with most style LoRAs")
816 | recommendations.append("- Consider light targeting (7,12,16,20) only if you experience conflicts\n")
817 |
818 | elif user_goal == "Preserve strong character identity":
819 | if overtraining:
820 | recommendations.append("**For Strong Character Preservation:**")
821 | recommendations.append("- Use **👤🔥 Facial Priority** preset to keep ALL facial data")
822 | recommendations.append("- Include standard layers (7,12,16,20) PLUS detected layers")
823 | recommendations.append(f"- Keep layers: 7,12,16,20,{','.join(str(l['layer']) for l in suspicious_layers)}")
824 | recommendations.append("- **Trade-off:** Strong identity but may conflict with style LoRAs")
825 | recommendations.append("- **Why:** Preserves all facial information for maximum character fidelity\n")
826 | else:
827 | recommendations.append("**For Character Preservation:**")
828 | recommendations.append("- Use standard **👤🎨 Face A + Style B** when merging")
829 | recommendations.append("- Your LoRA has clean facial separation")
830 | recommendations.append("- Should preserve identity well without modifications\n")
831 |
832 | elif user_goal == "Fix overtraining issues":
833 | if overtraining:
834 | recommendations.append("**Overtraining Fix Strategy:**")
835 | recommendations.append("- Use **🔥 Aggressive** preset in Layer Targeting")
836 | recommendations.append("- Target ALL detected problematic layers")
837 | recommendations.append(f"- Mute layers: {', '.join(str(l['layer']) for l in suspicious_layers)}")
838 | for layer_info in suspicious_layers:
839 | recommendations.append(f" - Layer {layer_info['layer']}: {layer_info['reason']} ({layer_info['confidence']:.0f}% confidence)")
840 | recommendations.append("- **Result:** Cleaner LoRA with proper layer separation\n")
841 | else:
842 | recommendations.append("**No Overtraining Detected:**")
843 | recommendations.append("- Your LoRA appears well-trained")
844 | recommendations.append("- No fixing needed")
845 | recommendations.append("- Layer separation looks appropriate\n")
846 |
847 | elif user_goal == "Understand layer distribution":
848 | layer_stats = analysis_result.get('layer_stats', {})
849 | recommendations.append("**Layer Distribution Analysis:**")
850 |
851 | if layer_stats.get('text_encoder'):
852 | recommendations.append(f"- **Text Encoder:** {len(layer_stats['text_encoder'])} active layers")
853 | if layer_stats.get('double_blocks'):
854 | recommendations.append(f"- **Double Blocks:** {len(layer_stats['double_blocks'])} active layers")
855 | if layer_stats.get('single_blocks'):
856 | recommendations.append(f"- **Single Blocks:** {len(layer_stats['single_blocks'])} active layers")
857 |
858 | if suspicious_layers:
859 | recommendations.append("\n**Unusual Activity:**")
860 | for layer_info in suspicious_layers:
861 | recommendations.append(f"- Layer {layer_info['layer']}: {layer_info['reason']}")
862 |
863 | recommendations.append("\n**Layer Function Reference:**")
864 | recommendations.append("- Layers 7,12,16,20: Standard facial features")
865 | recommendations.append("- Layers 0-6: Early structure/composition")
866 | recommendations.append("- Layers 21-37: Fine details/textures")
867 |
868 | # Add specific action items
869 | recommendations.append("\n### 🔧 **Recommended Actions:**\n")
870 |
871 | if overtraining and user_goal != "Preserve strong character identity":
872 | recommendations.append("1. **Immediate:** Use Layer Targeting to clean detected layers")
873 | recommendations.append("2. **Test:** Check if targeted LoRA works better with style combinations")
874 | recommendations.append("3. **Compare:** A/B test original vs cleaned version")
875 | else:
876 | recommendations.append("1. **Current LoRA:** Appears suitable for your goal")
877 | recommendations.append("2. **Optional:** Light targeting if you experience conflicts")
878 | recommendations.append("3. **Monitor:** Watch for style conflicts in actual use")
879 |
880 | return "\n".join(recommendations)
881 |
882 | def create_layer_heatmap(analysis_result: dict):
883 | """Create a visual heatmap of layer magnitudes"""
884 | try:
885 | import matplotlib.pyplot as plt
886 | # imported lazily: if matplotlib is unavailable, the except below returns None
887 |
888 | if "error" in analysis_result:
889 | return None
890 |
891 | layer_stats = analysis_result.get('layer_stats', {})
892 |
893 | # Prepare data for heatmap
894 | single_blocks = layer_stats.get('single_blocks', {})
895 | double_blocks = layer_stats.get('double_blocks', {})
896 | te_blocks = layer_stats.get('text_encoder', {})
897 |
898 | # Create magnitude arrays
899 | single_mags = [single_blocks.get(i, {}).get('avg_magnitude', 0) for i in range(38)]
900 | double_mags = [double_blocks.get(i, {}).get('avg_magnitude', 0) for i in range(20)]
901 | te_mags = [te_blocks.get(i, {}).get('avg_magnitude', 0) for i in range(12)]
902 |
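    | # Red bars flag the known facial layers present in each index range
    | # (7/12/16/20 for single blocks, 7/12/16 within the 0-19 double range,
    | # layer 7 for the text encoder).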
903 | # Create the plot
904 | fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 8))
905 |
906 | # Single blocks heatmap
907 | ax1.bar(range(38), single_mags, color=['red' if i in [7,12,16,20] else 'lightblue' for i in range(38)])
908 | ax1.set_title('Single Block Layers (0-37)')
909 | ax1.set_ylabel('Magnitude')
910 | ax1.set_xticks(range(0, 38, 5))
911 |
912 | # Double blocks heatmap
913 | ax2.bar(range(20), double_mags, color=['red' if i in [7,12,16] else 'lightgreen' for i in range(20)])
914 | ax2.set_title('Double Block Layers (0-19)')
915 | ax2.set_ylabel('Magnitude')
916 | ax2.set_xticks(range(0, 20, 2))
917 |
918 | # Text encoder heatmap
919 | ax3.bar(range(12), te_mags, color=['red' if i == 7 else 'lightyellow' for i in range(12)])
920 | ax3.set_title('Text Encoder Layers (0-11)')
921 | ax3.set_ylabel('Magnitude')
922 | ax3.set_xlabel('Layer Number')
923 | ax3.set_xticks(range(12))
924 |
925 | plt.tight_layout()
926 | plt.suptitle(f"Layer Magnitude Analysis - {analysis_result['file_info']['filename']}", y=0.98)
927 |
928 | return fig
929 |
930 | except Exception as e:
931 | print(f"Error creating heatmap: {e}")
932 | return None
--------------------------------------------------------------------------------