├── .gitignore ├── assets ├── bg_beach.jpg ├── fc_output.jpg ├── fbc_output.jpg ├── i2i_output1.jpg ├── i2i_output2.jpg ├── reinforce_off.jpg ├── reinforce_on.jpg ├── subject_input.jpg └── subject_rembg.jpg ├── style.css ├── install.py ├── lib_iclight ├── settings.py ├── backend.py ├── logging.py ├── __init__.py ├── detail_utils.py ├── model_loader.py ├── backgrounds.py ├── backends │ ├── forge.py │ └── a1111.py ├── rembg_utils.py ├── utils.py ├── ic_light_nodes.py ├── patch_weight.py └── parameters.py ├── README.md ├── LICENSE └── scripts └── ic_light_script.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /assets/bg_beach.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoming02/sd-forge-ic-light/HEAD/assets/bg_beach.jpg -------------------------------------------------------------------------------- /assets/fc_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoming02/sd-forge-ic-light/HEAD/assets/fc_output.jpg -------------------------------------------------------------------------------- /assets/fbc_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoming02/sd-forge-ic-light/HEAD/assets/fbc_output.jpg -------------------------------------------------------------------------------- /assets/i2i_output1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoming02/sd-forge-ic-light/HEAD/assets/i2i_output1.jpg -------------------------------------------------------------------------------- /assets/i2i_output2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Haoming02/sd-forge-ic-light/HEAD/assets/i2i_output2.jpg -------------------------------------------------------------------------------- /assets/reinforce_off.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoming02/sd-forge-ic-light/HEAD/assets/reinforce_off.jpg -------------------------------------------------------------------------------- /assets/reinforce_on.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoming02/sd-forge-ic-light/HEAD/assets/reinforce_on.jpg -------------------------------------------------------------------------------- /assets/subject_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoming02/sd-forge-ic-light/HEAD/assets/subject_input.jpg -------------------------------------------------------------------------------- /assets/subject_rembg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoming02/sd-forge-ic-light/HEAD/assets/subject_rembg.jpg -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | .ic-light-desc { 2 | padding-left: 1em; 3 | } 4 | 5 | .ic-light-btns button { 6 | border-radius: 0.5em !important; 7 | } 8 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import launch 2 | 3 | if not launch.is_installed("onnxruntime"): 4 | launch.run_pip("install onnxruntime>=1.21.0", "onnxruntime for ic-light") 5 | 6 | if not launch.is_installed("rembg"): 7 | launch.run_pip("install rembg==2.0.65", "rembg for ic-light") 8 | 9 | if not launch.is_installed("cv2"): 10 | 
launch.run_pip("install opencv-python~=4.8.1", "opencv for ic-light") 11 | -------------------------------------------------------------------------------- /lib_iclight/settings.py: -------------------------------------------------------------------------------- 1 | from modules.shared import OptionInfo, opts 2 | 3 | 4 | def ic_settings(): 5 | args = {"section": ("ic", "IC Light"), "category_id": "sd"} 6 | 7 | opts.add_option( 8 | "ic_sync_dim", 9 | OptionInfo(True, "Show [Sync Resolution] Button", **args).needs_reload_ui(), 10 | ) 11 | 12 | opts.add_option( 13 | "ic_all_rembg", 14 | OptionInfo(False, "List all available Rembg models", **args).needs_reload_ui(), 15 | ) 16 | -------------------------------------------------------------------------------- /lib_iclight/backend.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from enum import Enum 3 | 4 | 5 | class BackendType(Enum): 6 | A1111 = -1 7 | Forge = 0 8 | reForge = 1 9 | Classic = 2 10 | 11 | 12 | def _import(module: str) -> bool: 13 | try: 14 | importlib.import_module(module) 15 | return True 16 | except ImportError: 17 | return False 18 | 19 | 20 | def detect_backend() -> BackendType: 21 | if _import("backend.shared"): 22 | return BackendType.Forge 23 | 24 | if _import("modules_forge.forge_version"): 25 | from modules_forge.forge_version import version 26 | 27 | if "1.10.1" in version: 28 | return BackendType.reForge 29 | else: 30 | return BackendType.Classic 31 | 32 | return BackendType.A1111 33 | -------------------------------------------------------------------------------- /lib_iclight/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | 5 | class ColorCode: 6 | RESET = "\033[0m" 7 | BLACK = "\033[0;90m" 8 | CYAN = "\033[0;36m" 9 | YELLOW = "\033[0;33m" 10 | RED = "\033[0;31m" 11 | 12 | MAP = { 13 | "DEBUG": BLACK, 14 | "INFO": CYAN, 15 | "WARNING": YELLOW, 16 
| "ERROR": RED, 17 | } 18 | 19 | 20 | class ColoredFormatter(logging.Formatter): 21 | def format(self, record): 22 | levelname = record.levelname 23 | record.levelname = f"{ColorCode.MAP[levelname]}{levelname}{ColorCode.RESET}" 24 | return super().format(record) 25 | 26 | 27 | logger = logging.getLogger("IC-Light") 28 | logger.setLevel(logging.INFO) 29 | logger.propagate = False 30 | 31 | if not logger.handlers: 32 | handler = logging.StreamHandler(sys.stdout) 33 | handler.setFormatter(ColoredFormatter("[%(name)s] %(levelname)s - %(message)s")) 34 | logger.addHandler(handler) 35 | -------------------------------------------------------------------------------- /lib_iclight/__init__.py: -------------------------------------------------------------------------------- 1 | VERSION = "2.0" 2 | 3 | t2i_fc: str = """ 4 | Relighting with Foreground Condition
5 | Given a foreground image, generate a new background via txt2img, 6 | then blend them together with coherent lighting conditions. 7 | """ 8 | 9 | t2i_fbc: str = """ 10 | Relighting with Foreground and Background Condition
11 | Extract the subject from the foreground image, then blend it onto the background image, 12 | while keeping the lighting conditions coherent.
13 | Sampler and Steps are important; Prompts doesn't matter. 14 | """ 15 | 16 | i2i_fc: str = """ 17 | Relighting with Light-Map Condition
18 | Given an input image, generate a new background using conditioned lighting. 19 | """ 20 | 21 | removal: str = """ 22 | Note: Disable this feature if the image already has no background 23 | """ 24 | 25 | raw: str = """ 26 | Use the input before the background removal as the "Original" to restore details from 27 | """ 28 | -------------------------------------------------------------------------------- /lib_iclight/detail_utils.py: -------------------------------------------------------------------------------- 1 | # =========================================== # 2 | # Reference: # 3 | # https://youtu.be/5EuYKEvugLU?feature=shared # 4 | # =========================================== # 5 | 6 | import cv2 7 | import numpy as np 8 | from PIL import Image 9 | 10 | from modules.images import resize_image 11 | 12 | 13 | def resize_input(img: np.ndarray, h: int, w: int, mode: int = 1) -> np.ndarray: 14 | img = Image.fromarray(img) 15 | resized_img: Image.Image = resize_image(mode, img, w, h) # Crop & Resize 16 | 17 | return np.asarray(resized_img.convert("RGB"), dtype=np.uint8) 18 | 19 | 20 | def restore_detail( 21 | ic_light_image: np.ndarray, 22 | original_image: np.ndarray, 23 | blur_radius: int, 24 | ) -> Image.Image: 25 | h, w, c = ic_light_image.shape 26 | if c == 4: 27 | ic_light_image = cv2.cvtColor(ic_light_image, cv2.COLOR_RGBA2RGB) 28 | 29 | original_image = resize_input(original_image, h, w) 30 | 31 | ic_light_image = ic_light_image.astype(np.float32) / 255.0 32 | original_image = original_image.astype(np.float32) / 255.0 33 | 34 | blurred_ic_light = cv2.GaussianBlur(ic_light_image, (blur_radius, blur_radius), 0) 35 | blurred_original = cv2.GaussianBlur(original_image, (blur_radius, blur_radius), 0) 36 | 37 | DoG = original_image + (blurred_ic_light - blurred_original) 38 | DoG = np.clip(DoG * 255.0, 0, 255).round().astype(np.uint8) 39 | 40 | return Image.fromarray(DoG) 41 | -------------------------------------------------------------------------------- 
/lib_iclight/model_loader.py: -------------------------------------------------------------------------------- 1 | from .logging import logger 2 | 3 | 4 | class ICModels: 5 | _init: bool = False 6 | 7 | fc: str = "" 8 | fbc: str = "" 9 | 10 | fc_path: str = None 11 | fbc_path: str = None 12 | 13 | @classmethod 14 | def detect_models(cls): 15 | if cls._init: 16 | return 17 | else: 18 | cls._init = True 19 | 20 | import os 21 | 22 | from modules.paths import models_path 23 | 24 | folder = os.path.join(models_path, "ic-light") 25 | os.makedirs(folder, exist_ok=True) 26 | 27 | fc, fbc = None, None 28 | 29 | for obj in os.listdir(folder): 30 | if not obj.endswith(".safetensors"): 31 | continue 32 | if "fc" in obj.lower(): 33 | fc = os.path.join(folder, obj) 34 | if "fbc" in obj.lower(): 35 | fbc = os.path.join(folder, obj) 36 | 37 | if fc is None or fbc is None: 38 | logger.error("Failed to locate IC-Light models! Download from Releases!") 39 | return 40 | 41 | cls.fc: str = os.path.basename(fc).rsplit(".", 1)[0] 42 | cls.fc_path: str = fc 43 | 44 | cls.fbc: str = os.path.basename(fbc).rsplit(".", 1)[0] 45 | cls.fbc_path: str = fbc 46 | 47 | @classmethod 48 | def get_path(cls, model: str) -> str: 49 | match model: 50 | case cls.fc: 51 | return cls.fc_path 52 | case cls.fbc: 53 | return cls.fbc_path 54 | case _: 55 | raise ValueError 56 | -------------------------------------------------------------------------------- /lib_iclight/backgrounds.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import numpy as np 4 | 5 | 6 | class BackgroundFC(Enum): 7 | """Background Source for FC Models""" 8 | 9 | LEFT = "Left Light" 10 | RIGHT = "Right Light" 11 | TOP = "Top Light" 12 | BOTTOM = "Bottom Light" 13 | GREY = "Ambient" 14 | CUSTOM = "Custom" 15 | 16 | def get_bg(self, width: int = 512, height: int = 512) -> np.ndarray: 17 | match self: 18 | case BackgroundFC.LEFT: 19 | gradient = np.linspace(255, 0, width) 
20 | image = np.tile(gradient, (height, 1)) 21 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) 22 | 23 | case BackgroundFC.RIGHT: 24 | gradient = np.linspace(0, 255, width) 25 | image = np.tile(gradient, (height, 1)) 26 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) 27 | 28 | case BackgroundFC.TOP: 29 | gradient = np.linspace(255, 0, height)[:, None] 30 | image = np.tile(gradient, (1, width)) 31 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) 32 | 33 | case BackgroundFC.BOTTOM: 34 | gradient = np.linspace(0, 255, height)[:, None] 35 | image = np.tile(gradient, (1, width)) 36 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) 37 | 38 | case BackgroundFC.GREY: 39 | input_bg = np.zeros((height, width, 3), dtype=np.uint8) + 127 40 | 41 | case BackgroundFC.CUSTOM: 42 | return None 43 | 44 | return input_bg 45 | -------------------------------------------------------------------------------- /lib_iclight/backends/forge.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | if TYPE_CHECKING: 4 | from modules.processing import StableDiffusionProcessing 5 | 6 | from ..parameters import ICLightArgs 7 | 8 | try: 9 | from ldm_patched.modules.model_patcher import ModelPatcher 10 | from ldm_patched.modules.sd import VAE 11 | from ldm_patched.modules.utils import load_torch_file 12 | 13 | classic = True 14 | 15 | except ImportError: 16 | from backend.patcher.base import ModelPatcher 17 | from backend.patcher.vae import VAE 18 | from backend.utils import load_torch_file 19 | 20 | classic = False 21 | 22 | import torch 23 | 24 | from modules.devices import device, dtype 25 | 26 | from ..ic_light_nodes import ICLight 27 | from ..model_loader import ICModels 28 | from ..utils import forge_numpy2pytorch 29 | 30 | 31 | @torch.inference_mode() 32 | def apply_ic_light(p: "StableDiffusionProcessing", args: "ICLightArgs"): 33 | sd = load_torch_file( 34 | 
ICModels.get_path(args.model_type), 35 | safe_load=True, 36 | device=device, 37 | ) 38 | 39 | work_model: ModelPatcher = p.sd_model.forge_objects.unet.clone() 40 | vae: VAE = p.sd_model.forge_objects.vae 41 | 42 | pixel_concat = ( 43 | forge_numpy2pytorch(args.get_concat_cond(p)) 44 | .to(device=vae.device, dtype=dtype) 45 | .movedim(1, 3) 46 | ) 47 | 48 | patched_unet: ModelPatcher = ICLight.apply( 49 | model=work_model, 50 | ic_model_state_dict=sd, 51 | c_concat={"samples": vae.encode(pixel_concat)}, 52 | mode=None if classic else args.model_type, 53 | ) 54 | 55 | p.sd_model.forge_objects.unet = patched_unet 56 | -------------------------------------------------------------------------------- /lib_iclight/rembg_utils.py: -------------------------------------------------------------------------------- 1 | # ============================================================= # 2 | # Reference: # 3 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui-rembg # 4 | # ============================================================= # 5 | 6 | import os 7 | 8 | import numpy as np 9 | import rembg 10 | from PIL import Image 11 | 12 | from modules.paths import models_path 13 | from modules.shared import opts 14 | 15 | if "U2NET_HOME" not in os.environ: 16 | os.environ["U2NET_HOME"] = os.path.join(models_path, "u2net") 17 | 18 | 19 | def get_models() -> tuple[str]: 20 | if getattr(opts, "ic_all_rembg", False): 21 | return ( 22 | "u2net", 23 | "u2netp", 24 | "u2net_human_seg", 25 | "u2net_cloth_seg", 26 | "isnet-anime", 27 | "isnet-general-use", 28 | "silueta", 29 | ) 30 | else: 31 | return ( 32 | "u2net_human_seg", 33 | "isnet-anime", 34 | ) 35 | 36 | 37 | def run_rmbg( 38 | np_image: np.ndarray, 39 | model: str, 40 | foreground_threshold: int, 41 | background_threshold: int, 42 | erode_size: int, 43 | ) -> np.ndarray: 44 | image = Image.fromarray(np_image.astype(np.uint8)).convert("RGB") 45 | 46 | processed_image = rembg.remove( 47 | image, 48 | session=rembg.new_session( 49 
| model_name=model, 50 | providers=["CPUExecutionProvider"], 51 | ), 52 | alpha_matting=True, 53 | alpha_matting_foreground_threshold=foreground_threshold, 54 | alpha_matting_background_threshold=background_threshold, 55 | alpha_matting_erode_size=erode_size, 56 | post_process_mask=True, 57 | only_mask=False, 58 | bgcolor=(127, 127, 127, 255), 59 | ) 60 | 61 | return np.asarray(processed_image.convert("RGB"), dtype=np.uint8) 62 | -------------------------------------------------------------------------------- /lib_iclight/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | 5 | from modules.images import LANCZOS 6 | 7 | 8 | def numpy2pytorch(imgs: np.ndarray) -> torch.Tensor: 9 | """Automatic1111's VAE accepts -1.0 ~ 1.0 tensors""" 10 | h = torch.from_numpy(np.stack(imgs, axis=0, dtype=np.float32)) / 127 - 1.0 11 | h = h.movedim(-1, 1) 12 | return h 13 | 14 | 15 | def forge_numpy2pytorch(img: np.ndarray) -> torch.Tensor: 16 | """Forge & ComfyUI's VAE accepts 0.0 ~ 1.0 tensors""" 17 | h = torch.from_numpy(img.astype(np.float32) / 255) 18 | h = h.movedim(-1, 1) 19 | return h 20 | 21 | 22 | def resize_and_center_crop(image: np.ndarray, w: int, h: int) -> np.ndarray: 23 | pil_image = Image.fromarray(image) 24 | original_width, original_height = pil_image.size 25 | scale_factor = max(w / original_width, h / original_height) 26 | resized_width = int(round(original_width * scale_factor)) 27 | resized_height = int(round(original_height * scale_factor)) 28 | resized_image = pil_image.resize((resized_width, resized_height), LANCZOS) 29 | left = (resized_width - w) / 2 30 | top = (resized_height - h) / 2 31 | right = (resized_width + w) / 2 32 | bottom = (resized_height + h) / 2 33 | cropped_image = resized_image.crop((left, top, right, bottom)) 34 | return np.asarray(cropped_image, dtype=np.uint8) 35 | 36 | 37 | def align_dim_latent(x: int) -> int: 38 | """ 39 | Align 
the pixel dimension to latent dimension\n 40 | Stable Diffusion uses 1:8 ratio for latent:pixel\n 41 | i.e. 1 latent unit == 8 pixel unit 42 | """ 43 | return round(x / 8) * 8 44 | 45 | 46 | def make_masked_area_grey(image: np.ndarray, alpha: np.ndarray) -> np.ndarray: 47 | """Make the masked area grey""" 48 | return ( 49 | (image.astype(np.float32) * alpha + (1.0 - alpha) * 127) 50 | .round() 51 | .clip(0, 255) 52 | .astype(np.uint8) 53 | ) 54 | -------------------------------------------------------------------------------- /lib_iclight/backends/a1111.py: -------------------------------------------------------------------------------- 1 | from ..logging import logger 2 | 3 | try: 4 | from lib_modelpatcher.model_patcher import ModulePatch 5 | except ImportError: 6 | logger.error("Please install [sd-webui-model-patcher] first!") 7 | raise 8 | 9 | from typing import TYPE_CHECKING, Callable 10 | 11 | if TYPE_CHECKING: 12 | from modules.processing import StableDiffusionProcessing 13 | 14 | from ..parameters import ICLightArgs 15 | 16 | from functools import wraps 17 | 18 | import safetensors.torch 19 | import torch 20 | 21 | from modules.devices import device, dtype 22 | 23 | from ..model_loader import ICModels 24 | from ..utils import numpy2pytorch 25 | 26 | 27 | def vae_encode(sd_model, image: torch.Tensor) -> torch.Tensor: 28 | """ 29 | image: [B, C, H, W] format tensor, ranging from -1.0 to 1.0 30 | Return: tensor in [B, C, H, W] format 31 | 32 | Note: Input image format differs from Forge/Comfy's VAE input format 33 | """ 34 | return sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image)) 35 | 36 | 37 | @torch.inference_mode() 38 | def apply_ic_light(p: "StableDiffusionProcessing", args: "ICLightArgs"): 39 | sd = safetensors.torch.load_file(ICModels.get_path(args.model_type)) 40 | 41 | concat_conds = vae_encode( 42 | p.sd_model, 43 | numpy2pytorch(args.get_concat_cond(p)).to(dtype=dtype, device=device), 44 | ).to(dtype=dtype) 45 | 46 | concat_conds 
= torch.cat([c[None, ...] for c in concat_conds], dim=1) 47 | 48 | def apply_c_concat(unet, old_forward: Callable) -> Callable: 49 | @wraps(old_forward) 50 | def new_forward(x, timesteps=None, context=None, **kwargs): 51 | c_concat = torch.cat( 52 | ([concat_conds.to(x.device)] * (x.shape[0] // concat_conds.shape[0])), 53 | dim=0, 54 | ) 55 | new_x = torch.cat([x, c_concat], dim=1) 56 | return old_forward(new_x, timesteps, context, **kwargs) 57 | 58 | return new_forward 59 | 60 | model_patcher = p.get_model_patcher() 61 | model_patcher.add_module_patch( 62 | "diffusion_model", 63 | ModulePatch(create_new_forward_func=apply_c_concat), 64 | ) 65 | model_patcher.add_patches( 66 | patches={ 67 | "diffusion_model." + key: (value.to(dtype=dtype, device=device),) 68 | for key, value in sd.items() 69 | } 70 | ) 71 | -------------------------------------------------------------------------------- /lib_iclight/ic_light_nodes.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, TypedDict 2 | 3 | import torch 4 | 5 | from modules.devices import device, dtype 6 | 7 | try: 8 | from ldm_patched.modules.model_patcher import ModelPatcher 9 | except ImportError: 10 | from backend.patcher.base import ModelPatcher 11 | 12 | 13 | class UnetParams(TypedDict): 14 | input: torch.Tensor 15 | timestep: torch.Tensor 16 | c: dict 17 | cond_or_uncond: torch.Tensor 18 | 19 | 20 | class ICLight: 21 | """IC-Light Implementation""" 22 | 23 | @staticmethod 24 | def apply( 25 | model: ModelPatcher, 26 | ic_model_state_dict: dict[str, torch.Tensor], 27 | c_concat: dict, 28 | mode: Optional[str] = None, 29 | ) -> ModelPatcher: 30 | work_model = model.clone() 31 | 32 | model_config = ( 33 | work_model.model.model_config 34 | if hasattr(work_model.model, "model_config") 35 | else work_model.model.config 36 | ) 37 | scale_factor: float = model_config.latent_format.scale_factor 38 | 39 | concat_conds: torch.Tensor = 
c_concat["samples"] * scale_factor 40 | concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1) 41 | 42 | def apply_c_concat(params: UnetParams) -> UnetParams: 43 | """Apply c_concat on Unet call""" 44 | sample = params["input"] 45 | params["c"]["c_concat"] = torch.cat( 46 | ( 47 | [concat_conds.to(sample.device)] 48 | * (sample.shape[0] // concat_conds.shape[0]) 49 | ), 50 | dim=0, 51 | ) 52 | return params 53 | 54 | def unet_dummy_apply(unet_apply: Callable, params: UnetParams) -> Callable: 55 | """A dummy unet apply wrapper serving as the endpoint of wrapper chain""" 56 | return unet_apply(x=params["input"], t=params["timestep"], **params["c"]) 57 | 58 | existing_wrapper = work_model.model_options.get( 59 | "model_function_wrapper", unet_dummy_apply 60 | ) 61 | 62 | def wrapper_func(unet_apply: Callable, params: UnetParams) -> Callable: 63 | return existing_wrapper(unet_apply, params=apply_c_concat(params)) 64 | 65 | work_model.set_model_unet_function_wrapper(wrapper_func) 66 | 67 | args = { 68 | "patches": { 69 | ("diffusion_model." + key): (value.to(dtype=dtype, device=device),) 70 | for key, value in ic_model_state_dict.items() 71 | } 72 | } 73 | 74 | if mode is not None: 75 | args["filename"] = f"ic-light-{mode}" 76 | 77 | work_model.add_patches(**args) 78 | return work_model 79 | -------------------------------------------------------------------------------- /lib_iclight/patch_weight.py: -------------------------------------------------------------------------------- 1 | """ 2 | Credit: huchenlei 3 | https://github.com/huchenlei/ComfyUI-layerdiffuse/blob/v0.1.0/layered_diffusion.py#L35 4 | 5 | Modified by. 
Haoming02 to work with reForge 6 | """ 7 | 8 | from functools import wraps 9 | from typing import Callable 10 | 11 | import torch 12 | 13 | from ldm_patched.modules import lora 14 | from ldm_patched.modules.model_management import cast_to_device 15 | from ldm_patched.modules.model_patcher import ModelPatcher 16 | 17 | 18 | def adjust_channel(func: Callable): 19 | """Patches weight application to accept multi-channel inputs""" 20 | 21 | @torch.inference_mode() 22 | @wraps(func) 23 | def calculate_weight(*args) -> torch.Tensor: 24 | weight = func(*args) 25 | 26 | if isinstance(args[0], list): 27 | patches, weight, key = args[0:3] 28 | else: 29 | assert isinstance(args[1], list) 30 | patches, weight, key = args[1:4] 31 | 32 | for p in patches: 33 | alpha = p[0] 34 | v = p[1] 35 | 36 | # The recursion call should be handled in the main func call. 37 | if isinstance(v, list): 38 | continue 39 | 40 | if len(v) == 1: 41 | patch_type = "diff" 42 | elif len(v) == 2: 43 | patch_type = v[0] 44 | v = v[1] 45 | 46 | if patch_type == "diff": 47 | w1 = v[0] 48 | if all( 49 | ( 50 | alpha != 0.0, 51 | w1.shape != weight.shape, 52 | w1.ndim == weight.ndim == 4, 53 | ) 54 | ): 55 | new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)] 56 | print( 57 | f"Merged with {key} channel changed from {weight.shape} to {new_shape}" 58 | ) 59 | new_diff = alpha * cast_to_device(w1, weight.device, weight.dtype) 60 | new_weight = torch.zeros(size=new_shape).to(weight) 61 | new_weight[ 62 | : weight.shape[0], 63 | : weight.shape[1], 64 | : weight.shape[2], 65 | : weight.shape[3], 66 | ] = weight 67 | new_weight[ 68 | : new_diff.shape[0], 69 | : new_diff.shape[1], 70 | : new_diff.shape[2], 71 | : new_diff.shape[3], 72 | ] += new_diff 73 | new_weight = new_weight.contiguous().clone() 74 | weight = new_weight 75 | 76 | return weight 77 | 78 | return calculate_weight 79 | 80 | 81 | if hasattr(lora, "calculate_weight"): 82 | lora.calculate_weight = adjust_channel(lora.calculate_weight) 83 | 
print("\nlora.calculate_weight Patched!\n") 84 | else: 85 | ModelPatcher.calculate_weight = adjust_channel(ModelPatcher.calculate_weight) 86 | print("\nModelPatcher.calculate_weight Patched!\n") 87 | -------------------------------------------------------------------------------- /lib_iclight/parameters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | from modules.api.api import decode_base64_to_image 5 | from modules.processing import ( 6 | StableDiffusionProcessing, 7 | StableDiffusionProcessingImg2Img, 8 | StableDiffusionProcessingTxt2Img, 9 | ) 10 | 11 | from .logging import logger 12 | from .model_loader import ICModels 13 | from .rembg_utils import run_rmbg 14 | from .utils import ( 15 | align_dim_latent, 16 | make_masked_area_grey, 17 | resize_and_center_crop, 18 | ) 19 | 20 | 21 | class DetailTransfer: 22 | def __init__(self, transfer: bool, radius: int, original: np.ndarray): 23 | self.enable = transfer 24 | self.radius = radius 25 | self.original = original 26 | 27 | 28 | class ICLightArgs: 29 | def __init__( 30 | self, 31 | p: StableDiffusionProcessing, 32 | model_type: str, 33 | input_fg: np.ndarray, 34 | uploaded_bg: np.ndarray, 35 | remove_bg: bool, 36 | rembg_model: str, 37 | foreground_threshold: int, 38 | background_threshold: int, 39 | erode_size: int, 40 | detail_transfer: bool, 41 | detail_transfer_raw: bool, 42 | detail_transfer_blur_radius: int, 43 | reinforce_fg: bool, 44 | ): 45 | self.model_type: str = model_type 46 | 47 | if isinstance(p, StableDiffusionProcessingImg2Img): 48 | self.input_fg: np.ndarray = np.asarray(p.init_images[0], dtype=np.uint8) 49 | p.init_images[0] = Image.fromarray(self.parse_image(input_fg)) 50 | 51 | if p.cfg_scale > 2.5: 52 | logger.warning("Low CFG is recommended!") 53 | if p.denoising_strength < 0.9: 54 | logger.warning("High Denoising Strength is recommended!") 55 | 56 | else: 57 | self.input_fg: np.ndarray = 
self.parse_image(input_fg) 58 | 59 | self.uploaded_bg: np.ndarray = self.parse_image(uploaded_bg) 60 | 61 | self.input_fg_rgb: np.ndarray = self.process_input_foreground( 62 | self.input_fg, 63 | remove_bg, 64 | rembg_model, 65 | foreground_threshold, 66 | background_threshold, 67 | erode_size, 68 | ) 69 | 70 | self.detail_transfer = DetailTransfer( 71 | detail_transfer, 72 | detail_transfer_blur_radius, 73 | self.input_fg if detail_transfer_raw else self.input_fg_rgb, 74 | ) 75 | 76 | if detail_transfer and reinforce_fg: 77 | assert isinstance(p, StableDiffusionProcessingImg2Img) 78 | assert self.model_type == ICModels.fc 79 | 80 | lightmap = np.asarray(p.init_images[0], dtype=np.uint8) 81 | 82 | mask = np.all(self.input_fg_rgb == np.asarray([127, 127, 127]), axis=-1) 83 | mask = mask[..., None] # [H, W, 1] 84 | lightmap = resize_and_center_crop( 85 | lightmap, 86 | w=self.input_fg_rgb.shape[1], 87 | h=self.input_fg_rgb.shape[0], 88 | ) 89 | lightmap_rgb = lightmap[..., :3] 90 | lightmap_alpha = lightmap[..., 3:4] 91 | lightmap_rgb = self.input_fg_rgb * (1 - mask) + lightmap_rgb * mask 92 | lightmap = np.concatenate([lightmap_rgb, lightmap_alpha], axis=-1) 93 | 94 | p.init_images[0] = Image.fromarray(lightmap.astype(np.uint8)) 95 | 96 | @staticmethod 97 | def process_input_foreground( 98 | image: np.ndarray, 99 | remove_bg: bool, 100 | rembg_model: str, 101 | foreground_threshold: int, 102 | background_threshold: int, 103 | erode_size: int, 104 | ) -> np.ndarray: 105 | """Process input foreground image into [H, W, 3] format""" 106 | 107 | if image is None: 108 | return None 109 | 110 | if remove_bg: 111 | return run_rmbg( 112 | image, 113 | rembg_model, 114 | foreground_threshold, 115 | background_threshold, 116 | erode_size, 117 | ) 118 | 119 | assert len(image.shape) == 3, "Does not support greyscale image..." 
120 | 121 | if image.shape[2] == 3: 122 | return image 123 | 124 | return make_masked_area_grey( 125 | image[..., :3], 126 | image[..., 3:].astype(np.float32) / 255.0, 127 | ) 128 | 129 | def get_concat_cond(self, p: StableDiffusionProcessing) -> np.ndarray: 130 | """Returns concat condition in [B, H, W, C] format.""" 131 | 132 | if getattr(p, "is_hr_pass", False): 133 | assert isinstance(p, StableDiffusionProcessingTxt2Img) 134 | if p.hr_resize_x == 0 and p.hr_resize_y == 0: 135 | hr_x = int(p.width * p.hr_scale) 136 | hr_y = int(p.height * p.hr_scale) 137 | else: 138 | hr_y, hr_x = p.hr_resize_y, p.hr_resize_x 139 | image_width = align_dim_latent(hr_x) 140 | image_height = align_dim_latent(hr_y) 141 | 142 | else: 143 | image_width = p.width 144 | image_height = p.height 145 | 146 | fg = resize_and_center_crop(self.input_fg_rgb, image_width, image_height) 147 | 148 | match self.model_type: 149 | case ICModels.fc: 150 | np_concat = [fg] 151 | case ICModels.fbc: 152 | bg = resize_and_center_crop( 153 | self.uploaded_bg, 154 | image_width, 155 | image_height, 156 | ) 157 | np_concat = [fg, bg] 158 | case _: 159 | raise ValueError 160 | 161 | return np.stack(np_concat, axis=0) 162 | 163 | @staticmethod 164 | def decode_base64(base64string: str) -> np.ndarray: 165 | return np.asarray(decode_base64_to_image(base64string), dtype=np.uint8) 166 | 167 | @staticmethod 168 | def parse_image(value) -> np.ndarray: 169 | if isinstance(value, str): 170 | return ICLightArgs.decode_base64(value) 171 | assert isinstance(value, np.ndarray) or (value is None) 172 | return value 173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SD Forge IC-Light 2 | This is an Extension for the [Forge Webui](https://github.com/lllyasviel/stable-diffusion-webui-forge), which implements [IC-Light](https://github.com/lllyasviel/IC-Light), allowing you to manipulate the 
illumination of images. 3 | 4 | ### Compatibility Matrix 5 | 6 | > **Last Checked:** 2025 Apr.29 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 |
| Automatic1111 dev | Forge (Gradio 4) | Forge Classic (Gradio 3) | reForge main | reForge dev | reForge dev2 |
|:---:|:---:|:---:|:---:|:---:|:---:|
| Working | Working | Working | Working | Working | Working |
26 | 27 |
28 | for Automatic1111 Webui 29 | 30 | - You additionally need to install [sd-webui-model-patcher](https://github.com/huchenlei/sd-webui-model-patcher) first 31 | 32 |
33 | 34 | ## Getting Started 35 | 1. Download the two models from [Releases](https://github.com/Haoming02/sd-forge-ic-light/releases) 36 | 2. Create a new folder, `ic-light`, inside your webui `models` folder 37 | 3. Place the two models inside the `ic-light` folder 38 | 4. **(Optional)** You can rename the models, as long as the filename contains either **`fc`** or **`fbc`** 39 | 40 | ## How to Use 41 | 42 | > [!Important] 43 | > IC-Light only supports **SD1** checkpoints 44 | 45 | #### Index 46 | 47 | 1. [txt2img - FC](#txt2img---fc) 48 | 2. [txt2img - FBC](#txt2img---fbc) 49 | 3. [img2img - FC](#img2img---fc) 50 | - [Reinforce Foreground](#reinforce-foreground) 51 | 4. Options 52 | - [Background Removal](#background-removal) 53 | - [Restore Details](#restore-details) 54 | 55 |

56 |
57 | example Foreground image 58 |

59 | 60 | ### txt2img - FC 61 | > Relighting with Foreground Condition 62 | 63 | - In the Extension input, upload an image of your subject, then generate a new background using **txt2img** 64 | - If the generation aspect ratio differs from that of the `Foreground` image, the image is first processed with `Crop and resize` 65 | - `Hires. Fix` is supported 66 | 67 |

68 |
69 | example output
70 | a photo of a gentleman in suit, standing under sunset 71 |

72 | 73 | ### txt2img - FBC 74 | > Relighting with Foreground and Background Condition 75 | 76 | - In the Extension inputs, upload an image of your subject, and another image as the background 77 | - Simply write some quality tags as the prompts 78 | - `Hires. Fix` is supported 79 | 80 |

81 |
82 | example Background image 83 |

84 | 85 |

86 |
87 | example output
88 | a photo of a gentleman in suit, standing at a beach, sunny day 89 |

90 | 91 | ### img2img - FC 92 | > Relighting with Light-Map Condition 93 | 94 | - In the **img2img** input, upload an image of your subject as normal 95 | - In the Extension input, you can select between different light directions, or select `Custom LightMap` and upload one yourself 96 | - Describe the scene with the prompts 97 | - Low `CFG` *(`~2.0`)* and high `Denoising strength` *(`~1.0`)* are recommended 98 | 99 |

100 | 101 |
102 | example output
103 | Right Light | Left Light
104 | a photo of a gentleman in suit, winter, snowing 105 |

106 | 107 | #### Reinforce Foreground 108 | 109 | When enabled, the subject is additionally pasted onto the light map to preserve its original colors. This may improve the details at the cost of weaker lighting influence. 110 | 111 | > As shown below, the suit gets brightened to a khaki color by the prompt; using `Reinforce Foreground` allows the suit to retain more of its original color 112 | 113 |

114 | 115 |
116 | example output
117 | Off | On
118 | fire, explosion 119 |

120 | 121 |
122 | 123 | ### Options 124 | > These settings are available for all 3 modes 125 | 126 | #### Background Removal 127 | 128 | - Uses the **[rembg](https://github.com/danielgatis/rembg)** package to separate the subject from the background. 129 | - If you already have a subject image with an alpha channel, you can simply disable this option. 130 | - If you have an anime subject instead, select `isnet-anime` from the **Background Removal Model** dropdown. 131 | - When this option is enabled, the result is additionally appended to the outputs. 132 | - If the separation is not clean enough, adjust the **Threshold** parameters to improve the accuracy. 133 | 134 |

135 |
136 | example result 137 |

138 | 139 | #### Restore Details 140 | 141 | Uses the *Difference of Gaussians* (`DoG`) algorithm to transfer the details from the input to the output. 142 | 143 | By default, this only uses the `DoG` of the subject with its background removed. You can also switch to using the `DoG` of the entire input image instead. Increasing the **Blur Radius** will strengthen the effect. 144 | 145 |
146 | 147 | ## Settings 148 | 149 | > The settings are in the **IC Light** section under the Stable Diffusion category in the **Settings** tab 150 | 151 | - **Sync Resolution Button:** Adds a button in the `txt2img` tab that sets the `Width` and `Height` parameters to match the aspect ratio of the uploaded `Foreground` image (rounded to multiples of 64). 152 | - **All Rembg Models:** By default, the Extension only shows the `u2net_human_seg` and `isnet-anime` options. If those do not suit your needs *(**e.g.** your subject is not a "person")*, you may enable this to list all available models instead. 153 | 154 | ## Roadmap 155 | - [X] Select different `rembg` models 156 | - [X] API Support 157 | - see [wiki](https://github.com/Haoming02/sd-forge-ic-light/wiki/API) 158 | - [ ] Improve `Reinforce Foreground` 159 | - [ ] Improve `Restore Details` 160 | 161 |
162 | 163 |
164 | Copyright 2024 huchenlei
165 | Copyright 2025 Haoming02
166 | 
167 | Licensed under the Apache License, Version 2.0 (the "License");
168 | you may not use this file except in compliance with the License.
169 | You may obtain a copy of the License at
170 | 
171 |     http://www.apache.org/licenses/LICENSE-2.0
172 | 
173 | Unless required by applicable law or agreed to in writing, software
174 | distributed under the License is distributed on an "AS IS" BASIS,
175 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
176 | See the License for the specific language governing permissions and
177 | limitations under the License.
178 | 
179 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /scripts/ic_light_script.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import numpy as np 3 | from lib_iclight import VERSION, i2i_fc, raw, removal, t2i_fbc, t2i_fc 4 | from lib_iclight.backend import BackendType, detect_backend 5 | from lib_iclight.backgrounds import BackgroundFC 6 | from lib_iclight.detail_utils import restore_detail 7 | from lib_iclight.logging import logger 8 | from lib_iclight.model_loader import ICModels 9 | from lib_iclight.parameters import ICLightArgs 10 | from lib_iclight.rembg_utils import get_models 11 | from lib_iclight.settings import ic_settings 12 | 13 | from modules import scripts 14 | from modules.script_callbacks import on_ui_settings 15 | from modules.shared import opts 16 | from modules.ui_components import InputAccordion 17 | 18 | backend_type = detect_backend() 19 | 20 | if backend_type is BackendType.A1111: 21 | from lib_iclight.backends.a1111 import apply_ic_light 22 | else: 23 | from lib_iclight.backends.forge import apply_ic_light 24 | 25 | if backend_type in (BackendType.Classic, BackendType.reForge): 26 | from lib_iclight import patch_weight # noqa 27 | 28 | 29 | class ICLightScript(scripts.Script): 30 | def __init__(self): 31 | ICModels.detect_models() 32 | self.args: ICLightArgs 33 | self.extra_images: list[np.ndarray] 34 | 35 | def title(self): 36 | return "IC Light" 37 | 38 | def show(self, is_img2img): 39 | return scripts.AlwaysVisible 40 | 41 | def ui(self, is_img2img) -> list[gr.components.Component]: 42 | with InputAccordion(False, label=f"{self.title()} {VERSION}") as enable: 43 | with gr.Row(): 44 | model_type = gr.Dropdown( 45 | label="Mode", 46 | choices=[ICModels.fc, ICModels.fbc], 47 | value=ICModels.fc, 48 | interactive=(not is_img2img), 49 | ) 50 | desc = gr.Markdown( 51 | value=(i2i_fc if is_img2img else t2i_fc), 
52 | elem_classes=["ic-light-desc"], 53 | ) 54 | 55 | with gr.Column(variant="panel"): 56 | with gr.Row(): 57 | input_fg = gr.Image( 58 | label=("Lighting Conditioning" if is_img2img else "Foreground"), 59 | source="upload", 60 | type="numpy", 61 | height=480, 62 | visible=True, 63 | image_mode="RGBA", 64 | ) 65 | uploaded_bg = gr.Image( 66 | label="Background", 67 | source="upload", 68 | type="numpy", 69 | height=480, 70 | visible=False, 71 | image_mode="RGB", 72 | ) 73 | 74 | def parse_resolution(img: np.ndarray | None) -> list[int, int]: 75 | if img is None: 76 | return [gr.skip(), gr.skip()] 77 | 78 | h, w, _ = img.shape 79 | while (w > 2048) or (h > 2048): 80 | w /= 2 81 | h /= 2 82 | 83 | return [round(w / 64) * 64, round(h / 64) * 64] 84 | 85 | if not is_img2img: 86 | _sync: bool = getattr(opts, "ic_sync_dim", True) 87 | with gr.Row(variant="compact", elem_classes=["ic-light-btns"]): 88 | sync = gr.Button("Sync Resolution", visible=_sync) 89 | sync.click( 90 | fn=parse_resolution, 91 | inputs=[input_fg], 92 | outputs=[self.txt2img_width, self.txt2img_height], 93 | show_progress="hidden", 94 | ) 95 | 96 | flip_bg = gr.Button("Flip Background", visible=False) 97 | 98 | _sources = [bg.value for bg in BackgroundFC] 99 | background_source = gr.Radio( 100 | label="Background Source", 101 | choices=_sources, 102 | value=_sources[-1], 103 | visible=is_img2img, 104 | type="value", 105 | ) 106 | 107 | with InputAccordion(True, label="Background Removal") as remove_bg: 108 | gr.Markdown(removal) 109 | 110 | _rembg_models = get_models() 111 | rembg_model = gr.Dropdown( 112 | label="Background Removal Model", 113 | choices=_rembg_models, 114 | value=_rembg_models[0], 115 | ) 116 | with gr.Row(): 117 | foreground_threshold = gr.Slider( 118 | label="Foreground Threshold", 119 | value=225, 120 | minimum=0, 121 | maximum=255, 122 | step=1, 123 | ) 124 | background_threshold = gr.Slider( 125 | label="Background Threshold", 126 | value=16, 127 | minimum=0, 128 | maximum=255, 
129 | step=1, 130 | ) 131 | erode_size = gr.Slider( 132 | label="Erode Size", 133 | value=16, 134 | minimum=0, 135 | maximum=128, 136 | step=1, 137 | ) 138 | 139 | with InputAccordion(False, label="Restore Details") as detail_transfer: 140 | detail_transfer_raw = gr.Checkbox(False, label=raw) 141 | detail_transfer_blur_radius = gr.Slider( 142 | label="Blur Radius", 143 | info="for Difference of Gaussian; higher = stronger", 144 | value=3, 145 | minimum=1, 146 | maximum=9, 147 | step=2, 148 | ) 149 | 150 | with gr.Row(variant="compact", visible=is_img2img): 151 | reinforce_fg = gr.Checkbox( 152 | value=False, 153 | label="Reinforce Foreground", 154 | info="Paste the Subject onto the Lighting Conditioning", 155 | ) 156 | 157 | if is_img2img: 158 | self._hook_i2i(input_fg, background_source) 159 | else: 160 | self._hook_t2i(model_type, flip_bg, uploaded_bg, desc) 161 | 162 | components: list[gr.components.Component] = [ 163 | enable, 164 | model_type, 165 | input_fg, 166 | uploaded_bg, 167 | remove_bg, 168 | rembg_model, 169 | foreground_threshold, 170 | background_threshold, 171 | erode_size, 172 | detail_transfer, 173 | detail_transfer_raw, 174 | detail_transfer_blur_radius, 175 | reinforce_fg, 176 | ] 177 | 178 | for comp in components: 179 | comp.do_not_save_to_config = True 180 | 181 | return components 182 | 183 | def before_process(self, p, enable: bool, *args, **kwargs): 184 | self.extra_images: list[np.ndarray] = [] 185 | self.args = None 186 | 187 | if not enable: 188 | return 189 | if not getattr(p.sd_model, "is_sd1", True): 190 | logger.error("IC-Light only supports SD1 checkpoint...") 191 | return 192 | if args[1] is None: 193 | logger.error("An input image is required...") 194 | return 195 | 196 | self.args = ICLightArgs(p, *args) 197 | self.extra_images.append(self.args.input_fg_rgb) 198 | 199 | def process_before_every_sampling(self, p, *args, **kwargs): 200 | if self.args is not None: 201 | apply_ic_light(p, self.args) 202 | 203 | def 
postprocess_image(self, p, pp, *args, **kwargs): 204 | if self.args is None: 205 | return 206 | if not self.args.detail_transfer.enable: 207 | return 208 | 209 | self.extra_images.append( 210 | restore_detail( 211 | np.asarray(pp.image, dtype=np.uint8), 212 | self.args.detail_transfer.original, 213 | self.args.detail_transfer.radius, 214 | ) 215 | ) 216 | 217 | def postprocess(self, p, processed, *args, **kwargs): 218 | if self.args is None: 219 | return 220 | 221 | processed.images.extend(self.extra_images) 222 | 223 | def after_component(self, component: gr.Slider, **kwargs): 224 | if not getattr(opts, "ic_sync_dim", True): 225 | return 226 | 227 | if not (elem_id := kwargs.get("elem_id", None)): 228 | return 229 | 230 | if elem_id == "txt2img_width": 231 | self.txt2img_width = component 232 | if elem_id == "txt2img_height": 233 | self.txt2img_height = component 234 | 235 | @staticmethod 236 | def _hook_t2i(model_type: gr.Dropdown, flip_bg: gr.Button, uploaded_bg, desc): 237 | def on_model_change(model: str): 238 | match model: 239 | case ICModels.fc: 240 | return ( 241 | gr.update(visible=False), 242 | gr.update(visible=False), 243 | gr.update(value=t2i_fc), 244 | ) 245 | case ICModels.fbc: 246 | return ( 247 | gr.update(visible=True), 248 | gr.update(visible=True), 249 | gr.update(value=t2i_fbc), 250 | ) 251 | case _: 252 | raise ValueError 253 | 254 | model_type.change( 255 | fn=on_model_change, 256 | inputs=[model_type], 257 | outputs=[flip_bg, uploaded_bg, desc], 258 | show_progress="hidden", 259 | ) 260 | 261 | def on_flip_image(image: np.ndarray) -> np.ndarray: 262 | if image is None: 263 | return gr.skip() 264 | return gr.update(value=np.fliplr(image)) 265 | 266 | flip_bg.click(fn=on_flip_image, inputs=[uploaded_bg], outputs=[uploaded_bg]) 267 | 268 | @staticmethod 269 | def _hook_i2i(input_fg: gr.Image, background_source: gr.Dropdown): 270 | def update_img2img_input(source: str): 271 | source_fc = BackgroundFC(source) 272 | if source_fc is 
BackgroundFC.CUSTOM: 273 | return gr.skip() 274 | else: 275 | return gr.update(value=source_fc.get_bg()) 276 | 277 | background_source.input( 278 | fn=update_img2img_input, 279 | inputs=[background_source], 280 | outputs=[input_fg], 281 | show_progress="hidden", 282 | ) 283 | 284 | input_fg.upload( 285 | fn=lambda: gr.update(value=BackgroundFC.CUSTOM.value), 286 | outputs=[background_source], 287 | show_progress="hidden", 288 | ) 289 | 290 | 291 | on_ui_settings(ic_settings) 292 | --------------------------------------------------------------------------------