├── .gitignore ├── LICENSE ├── README.md ├── lib_free_u ├── global_state.py ├── unet.py └── xyz_grid.py └── scripts └── freeu.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | presets.json 3 | *.csv 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ljleb 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sd-webui-freeu 2 | Implementation of [FreeU](https://github.com/ChenyangSi/FreeU) as an [a1111 sd webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) extension. 3 | 4 | At each of the 3 stages of the UNet decoder: 5 | - Scale a window over the feature channels of the backbone 6 | - Rescale the low and high frequencies of the skip connection 7 | 8 | ## Settings 9 | 10 | - Start At Step: do not apply FreeU until this sampling step is reached 11 | - Stop At Step: stop applying FreeU once this sampling step is reached 12 | - Transition Smoothness: see $k_{smooth}$ in [this Desmos graph](https://www.desmos.com/calculator/ngcqo5ictm) 13 | - Backbone n Scale: scalar applied to the backbone window during UNet stage n 14 | - Backbone n Offset: offset of the window along the backbone feature channels; 1 is the same as 0 because the window wraps around 15 | - Backbone n Width: width of the window, as a ratio of the backbone feature channels 16 | - Skip n Scale: scalar applied to the low frequencies (low end) of the skip connection during UNet stage n 17 | - Skip n High End Scale: scalar applied to the high frequencies (high end) of the skip connection during UNet stage n 18 | - Skip n Cutoff: ratio that separates low from high frequencies; 0 means "Skip n Scale" only controls the lowest frequencies, and 1 means "Skip n Scale" applies to all frequencies 19 | 20 | ## API 21 | 22 | When making API calls, you can pass a single dict as the args of the `freeu` entry in `alwayson_scripts`: 23 | 24 | ```json 25 | { 26 | "alwayson_scripts": { 27 | "freeu": { 28 | "args": [{ 29 | "enable": true, 30 | "start_ratio": 0.1, 31 | "stop_ratio": 0.9, 32 | "transition_smoothness": 0.1, 33 | "stage_infos": [ 34 | { 35 | "backbone_factor": 1.2, 36 | "backbone_offset": 0.5, 37 | "backbone_width": 0.75, 38 | "skip_factor": 0.9, 39 | "skip_high_end_factor": 1.1, 40 | "skip_cutoff": 0.3 41 | },
42 | { 43 | "backbone_factor": 1.4, 44 | "backbone_offset": 0.5, 45 | "backbone_width": 0.75, 46 | "skip_factor": 0.2, 47 | "skip_high_end_factor": 1.1, 48 | "skip_cutoff": 0.3 49 | }, 50 | { 51 | "backbone_factor": 1.1, 52 | "backbone_offset": 0.5, 53 | "backbone_width": 0.75, 54 | "skip_factor": 0.9, 55 | "skip_high_end_factor": 1.1, 56 | "skip_cutoff": 0.3 57 | } 58 | ] 59 | }] 60 | } 61 | } 62 | } 63 | ``` 64 | 65 | Any entry can be omitted, in which case it falls back to its default (note that `enable` defaults to `true`). For example: 66 | 67 | ```json 68 | { 69 | "alwayson_scripts": { 70 | "freeu": { 71 | "args": [{ 72 | "start_ratio": 0.1, 73 | "stage_infos": [ 74 | { 75 | "backbone_factor": 0.8, 76 | "backbone_offset": 0.5, 77 | "skip_high_end_factor": 0.9 78 | } 79 | ] 80 | }] 81 | } 82 | } 83 | } 84 | ``` 85 | 86 | Here, since there is a single dict in the `stage_infos` array, FreeU only affects the first stage of the UNet. 87 | To modify only the second stage, prepend one empty dict `{}` to the `"stage_infos"` array. 88 | To modify only the third stage, prepend two empty dicts. 89 | 90 | If `"stop_ratio"` or `"start_ratio"` is an integer, it is an absolute step number. 91 | Otherwise, it must be a float between `0.0` and `1.0` representing a ratio of the total sampling steps (so `1` means step 1, whereas `1.0` means the full step count).
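As a rough illustration of the rule above, the sketch below mirrors how `to_denoising_step` and the flat part of `get_schedule_ratio` in `lib_free_u/unet.py` appear to resolve these values. It is a simplified standalone approximation, not the extension's code: it only covers Transition Smoothness 0, takes the total step count as an explicit argument, and `is_active` is a hypothetical helper name. It assumes the sampling step counter starts at 0 for each batch, as `process_batch` sets it.

```python
from typing import Union


def to_denoising_step(value: Union[int, float], total_steps: int) -> int:
    # Floats are ratios of the total step count; ints are absolute step numbers.
    if isinstance(value, float):
        return int(value * total_steps)
    return value


def is_active(step: int, start: Union[int, float], stop: Union[int, float], total_steps: int) -> bool:
    # Hypothetical helper: with Transition Smoothness 0, FreeU applies on [start_step, stop_step).
    start_step = to_denoising_step(start, total_steps)
    stop_step = to_denoising_step(stop, total_steps)
    return start_step <= step < stop_step


if __name__ == "__main__":
    # 20 steps with start_ratio 0.1 and stop_ratio 0.9
    print([s for s in range(20) if is_active(s, 0.1, 0.9, 20)])
    # JSON `1` parses as an int (step 1), JSON `1.0` as a float (all 20 steps)
    print(to_denoising_step(1, 20), to_denoising_step(1.0, 20))
```

With 20 sampling steps, `"start_ratio": 0.1` and `"stop_ratio": 0.9` resolve to steps 2 and 18, so in this sketch FreeU is active on steps 2 through 17.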
92 | -------------------------------------------------------------------------------- /lib_free_u/global_state.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import inspect 3 | import json 4 | import pathlib 5 | import re 6 | import sys 7 | from typing import Union, List, Any 8 | 9 | 10 | @dataclasses.dataclass 11 | class StageInfo: 12 | backbone_factor: float = 1.0 13 | skip_factor: float = 1.0 14 | backbone_offset: float = 0.0 15 | backbone_width: float = 0.5 16 | skip_cutoff: float = 0.0 17 | skip_high_end_factor: float = 1.0 18 | # <- add new fields at the end here for png info backwards compatibility 19 | 20 | def to_dict(self, include_default=False): 21 | default_stage_info = StageInfo() 22 | res = vars(self).copy() 23 | for k, v in res.copy().items(): 24 | if not include_default and v == getattr(default_stage_info, k): 25 | del res[k] 26 | 27 | return res 28 | 29 | def copy(self): 30 | return StageInfo(**vars(self)) 31 | 32 | 33 | STAGE_INFO_ARGS_LEN = len(inspect.getfullargspec(StageInfo.__init__)[0]) - 1 # off by one because of self 34 | STAGES_COUNT = 3 35 | shorthand_re = re.compile(r"^([a-z]{1,2})([0-9]+)$") 36 | all_versions = { 37 | f"Version {version+1}": str(version+1) 38 | for version in range(2) 39 | } 40 | reversed_all_versions = { 41 | v: k 42 | for k, v in all_versions.items() 43 | } 44 | 45 | xyz_attrs: dict = {} 46 | current_sampling_step: float = 0 47 | 48 | 49 | @dataclasses.dataclass 50 | class State: 51 | enable: bool = True 52 | start_ratio: Union[float, int] = 0.0 53 | stop_ratio: Union[float, int] = 1.0 54 | transition_smoothness: float = 0.0 55 | version: str = "1" 56 | stage_infos: List[Union[StageInfo, dict, Any]] = dataclasses.field(default_factory=lambda: [StageInfo() for _ in range(STAGES_COUNT)]) 57 | 58 | def __post_init__(self): 59 | self.stage_infos = self.group_stage_infos() 60 | self.version = self.format_version() 61 | 62 | def group_stage_infos(self): 63 | 
res = [] 64 | i = 0 65 | while i < len(self.stage_infos) and len(res) < STAGES_COUNT: 66 | if isinstance(self.stage_infos[i], StageInfo): 67 | res.append(self.stage_infos[i]) 68 | i += 1 69 | elif isinstance(self.stage_infos[i], dict): 70 | res.append(StageInfo(**self.stage_infos[i])) 71 | i += 1 72 | else: 73 | next_i = i + STAGE_INFO_ARGS_LEN 74 | res.append(StageInfo(*self.stage_infos[i:next_i])) 75 | i = next_i 76 | 77 | for _ in range(STAGES_COUNT - len(res)): 78 | res.append(StageInfo()) 79 | 80 | return res 81 | 82 | def format_version(self): 83 | if self.version not in reversed_all_versions: 84 | return all_versions.get(self.version, "1") 85 | 86 | return str(self.version) 87 | 88 | def to_dict(self): 89 | result = vars(self).copy() 90 | result["stage_infos"] = [stage_info.to_dict() for stage_info in result["stage_infos"]] 91 | del result["enable"] 92 | return result 93 | 94 | def copy(self): 95 | self_vars = vars(self).copy() # copy the dict so that self is not mutated and its stage infos are not shared 96 | old_stage_infos = self_vars["stage_infos"] 97 | self_vars["stage_infos"] = old_stage_infos.copy() 98 | for i, stage_info in enumerate(old_stage_infos): 99 | self_vars["stage_infos"][i] = stage_info.copy() 100 | 101 | return State(**self_vars) 102 | 103 | def update_attr(self, key, value): 104 | if match := shorthand_re.match(key): 105 | char, index = match.group(1, 2) 106 | stage_info = self.stage_infos[int(index)] 107 | if char == "b": 108 | stage_info.backbone_factor = value 109 | elif char == "s": 110 | stage_info.skip_factor = value 111 | elif char == "o": 112 | stage_info.backbone_offset = value 113 | elif char == "w": 114 | stage_info.backbone_width = value 115 | elif char == "t": 116 | stage_info.skip_cutoff = value 117 | elif char == "h": 118 | stage_info.skip_high_end_factor = value 119 | else: 120 | self.__dict__[key] = value 121 | 122 | 123 | def apply_xyz(): 124 | global instance 125 | 126 | if preset_key := xyz_attrs.get("preset"): 127 | if preset := all_presets.get(preset_key): 128 | instance = preset.copy() 129 | elif
preset_key != "UI Settings": 130 | print("[sd-webui-freeu]", f"XYZ Preset '{preset_key}' does not exist", file=sys.stderr) 131 | 132 | for k, v in xyz_attrs.items(): 133 | if k == "preset": 134 | continue 135 | 136 | instance.update_attr(k, v) 137 | 138 | 139 | STATE_ARGS_LEN = len(inspect.getfullargspec(State.__init__)[0]) - 1 # off by one because of self 140 | PRESETS_PATH = pathlib.Path(__file__).parent.parent / "presets.json" 141 | 142 | instance = State() 143 | default_presets = { 144 | "SD1.4 Recommendations": State( 145 | stage_infos=[ 146 | StageInfo(1.2, 0.9), 147 | StageInfo(1.4, 0.2), 148 | StageInfo(1, 1), 149 | ], 150 | ), 151 | "SD2.1 Recommendations": State( 152 | stage_infos=[ 153 | StageInfo(1.1, 0.9), 154 | StageInfo(1.2, 0.2), 155 | StageInfo(1, 1), 156 | ], 157 | ), 158 | "SDXL Recommendations": State( 159 | stage_infos=[ 160 | StageInfo(1.1, 0.6), 161 | StageInfo(1.2, 0.4), 162 | StageInfo(1, 1), 163 | ], 164 | ), 165 | } 166 | all_presets = {} 167 | 168 | 169 | def reload_presets(): 170 | all_presets.clear() 171 | all_presets.update(default_presets) 172 | all_presets.update(load_presets()) 173 | 174 | 175 | def load_presets(): 176 | if not PRESETS_PATH.exists(): 177 | return {} 178 | 179 | with open(PRESETS_PATH, "r") as f: 180 | return { 181 | k: State(**v) 182 | for k, v in json.load(f).items() 183 | } 184 | 185 | 186 | def save_presets(presets=None): 187 | if presets is None: 188 | presets = get_user_presets() 189 | 190 | presets = {k: v.to_dict() for k, v in presets.items()} 191 | 192 | with open(PRESETS_PATH, "w") as f: 193 | json.dump(presets, f, indent=4) 194 | 195 | 196 | def get_user_presets(): 197 | return { 198 | k: v 199 | for k, v in all_presets.items() 200 | if k not in default_presets 201 | } 202 | -------------------------------------------------------------------------------- /lib_free_u/unet.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import math 3 | import
pathlib 4 | import sys 5 | from typing import Tuple, Union, Optional 6 | from lib_free_u import global_state 7 | from modules import scripts, shared 8 | from modules.sd_hijack_unet import th 9 | import torch 10 | 11 | 12 | def patch(): 13 | th.cat = functools.partial(free_u_cat_hijack, original_function=th.cat) 14 | 15 | cn_script_paths = [ 16 | str(pathlib.Path(scripts.basedir()).parent.parent / "extensions-builtin" / "sd-webui-controlnet"), 17 | str(pathlib.Path(scripts.basedir()).parent / "sd-webui-controlnet"), 18 | ] 19 | sys.path[0:0] = cn_script_paths 20 | cn_status = "enabled" 21 | try: 22 | import scripts.hook as controlnet_hook 23 | except ImportError: 24 | cn_status = "disabled" 25 | else: 26 | controlnet_hook.th.cat = functools.partial(free_u_cat_hijack, original_function=controlnet_hook.th.cat) 27 | finally: 28 | for p in cn_script_paths: 29 | sys.path.remove(p) 30 | 31 | print("[sd-webui-freeu]", f"Controlnet support: *{cn_status}*") 32 | 33 | 34 | def free_u_cat_hijack(hs, *args, original_function, **kwargs): 35 | if not global_state.instance.enable: 36 | return original_function(hs, *args, **kwargs) 37 | 38 | schedule_ratio = get_schedule_ratio() 39 | if schedule_ratio == 0: 40 | return original_function(hs, *args, **kwargs) 41 | 42 | try: 43 | h, h_skip = hs 44 | if list(kwargs.keys()) != ["dim"] or kwargs.get("dim", -1) != 1: 45 | return original_function(hs, *args, **kwargs) 46 | except ValueError: 47 | return original_function(hs, *args, **kwargs) 48 | 49 | dims = h.shape[1] 50 | try: 51 | index = [1280, 640, 320].index(dims) 52 | stage_info = global_state.instance.stage_infos[index] 53 | except ValueError: 54 | stage_info = None 55 | 56 | if stage_info is not None: 57 | region_begin, region_end, region_inverted = ratio_to_region(stage_info.backbone_width, stage_info.backbone_offset, dims) 58 | mask = torch.arange(dims, device=h.device) 59 | mask = (region_begin <= mask) & (mask <= region_end) 60 | if region_inverted: 61 | mask = ~mask 62 | mask 
= mask.reshape(1, -1, 1, 1).to(h.dtype) 63 | 64 | scale = get_backbone_scale( 65 | h, 66 | backbone_factor=lerp(1, stage_info.backbone_factor, schedule_ratio), 67 | ) 68 | h *= mask * scale + (1 - mask) 69 | 70 | h_skip = filter_skip( 71 | h_skip, 72 | threshold=stage_info.skip_cutoff, 73 | scale=lerp(1, stage_info.skip_factor, schedule_ratio), 74 | scale_high=lerp(1, stage_info.skip_high_end_factor, schedule_ratio), 75 | ) 76 | 77 | return original_function([h, h_skip], *args, **kwargs) 78 | 79 | 80 | def get_backbone_scale(h, backbone_factor): 81 | if global_state.instance.version == "1": 82 | return backbone_factor 83 | 84 | #if global_state.instance.version == "2": 85 | features_mean = h.mean(1, keepdim=True) 86 | batch_dims = h.shape[0] 87 | features_max, _ = torch.max(features_mean.view(batch_dims, -1), dim=-1, keepdim=True) 88 | features_min, _ = torch.min(features_mean.view(batch_dims, -1), dim=-1, keepdim=True) 89 | hidden_mean = (features_mean - features_min.unsqueeze(2).unsqueeze(3)) / (features_max - features_min).unsqueeze(2).unsqueeze(3) 90 | return 1 + (backbone_factor - 1) * hidden_mean 91 | 92 | 93 | def filter_skip(x, threshold, scale, scale_high): 94 | if scale == 1 and scale_high == 1: 95 | return x 96 | 97 | fft_device = x.device 98 | if not is_gpu_complex_supported(x): 99 | fft_device = "cpu" 100 | 101 | # FFT 102 | x_freq = torch.fft.fftn(x.to(fft_device, dtype=torch.float32), dim=(-2, -1)) 103 | x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1)) 104 | 105 | B, C, H, W = x_freq.shape 106 | mask = torch.full((B, C, H, W), float(scale_high), device=fft_device) 107 | 108 | crow, ccol = H // 2, W // 2 109 | threshold_row = max(1, math.floor(crow * threshold)) 110 | threshold_col = max(1, math.floor(ccol * threshold)) 111 | mask[..., crow - threshold_row:crow + threshold_row, ccol - threshold_col:ccol + threshold_col] = scale 112 | x_freq *= mask 113 | 114 | # IFFT 115 | x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1)) 116 | x_filtered = 
torch.fft.ifftn(x_freq, dim=(-2, -1)).real.to(device=x.device, dtype=x.dtype) 117 | 118 | return x_filtered 119 | 120 | 121 | def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]: 122 | if width < 0: 123 | offset += width 124 | width = -width 125 | width = min(width, 1) 126 | 127 | if offset < 0: 128 | offset = 1 + offset - int(offset) 129 | offset = math.fmod(offset, 1.0) 130 | 131 | if width + offset <= 1: 132 | inverted = False 133 | start = offset * n 134 | end = (width + offset) * n 135 | else: 136 | inverted = True 137 | start = (width + offset - 1) * n 138 | end = offset * n 139 | 140 | return round(start), round(end), inverted 141 | 142 | 143 | def get_schedule_ratio(): 144 | start_step = to_denoising_step(global_state.instance.start_ratio) 145 | stop_step = to_denoising_step(global_state.instance.stop_ratio) 146 | 147 | if start_step == stop_step: 148 | smooth_schedule_ratio = 0.0 149 | elif global_state.current_sampling_step < start_step: 150 | smooth_schedule_ratio = min(1.0, max(0.0, global_state.current_sampling_step / start_step)) 151 | else: 152 | smooth_schedule_ratio = min(1.0, max(0.0, 1 + (global_state.current_sampling_step - start_step) / (start_step - stop_step))) 153 | 154 | flat_schedule_ratio = 1.0 if start_step <= global_state.current_sampling_step < stop_step else 0.0 155 | 156 | return lerp(flat_schedule_ratio, smooth_schedule_ratio, global_state.instance.transition_smoothness) 157 | 158 | 159 | def to_denoising_step(number: Union[float, int], steps=None) -> int: 160 | if steps is None: 161 | steps = shared.state.sampling_steps 162 | 163 | if isinstance(number, float): 164 | return int(number * steps) 165 | 166 | return number 167 | 168 | 169 | def lerp(a, b, r): 170 | return (1-r)*a + r*b 171 | 172 | 173 | gpu_complex_support: Optional[bool] = None 174 | def is_gpu_complex_supported(x): 175 | global gpu_complex_support 176 | 177 | if x.is_cpu: 178 | return True 179 | 180 | if gpu_complex_support is not None: 
181 | return gpu_complex_support 182 | 183 | # catch known cases in advance 184 | mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() 185 | try: 186 | import torch_directml 187 | except ImportError: 188 | dml_available = False 189 | else: 190 | dml_available = torch_directml.is_available() 191 | 192 | gpu_complex_support = not (mps_available or dml_available) 193 | if gpu_complex_support: 194 | # try filter_skip fft to make sure it is viable on the gpu 195 | try: 196 | torch.fft.fftn(x.float(), dim=(-2, -1)) 197 | except RuntimeError: 198 | gpu_complex_support = False 199 | 200 | return gpu_complex_support 201 | -------------------------------------------------------------------------------- /lib_free_u/xyz_grid.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from types import ModuleType 3 | from typing import Optional 4 | from modules import scripts 5 | from lib_free_u import global_state 6 | 7 | 8 | def patch(): 9 | xyz_module = find_xyz_module() 10 | if xyz_module is None: 11 | print("[sd-webui-freeu]", "xyz_grid.py not found.", file=sys.stderr) 12 | return 13 | xyz_module.axis_options.extend([ 14 | xyz_module.AxisOption("[FreeU] Enabled", str_to_bool, apply_global_state("enable"), choices=choices_bool), 15 | xyz_module.AxisOption("[FreeU] Version", str, apply_global_state("version", key_map=global_state.all_versions), choices=choices_version), 16 | xyz_module.AxisOption("[FreeU] Preset", str, apply_global_state("preset"), choices=choices_preset), 17 | xyz_module.AxisOption("[FreeU] Start At Step", int_or_float, apply_global_state("start_ratio")), 18 | xyz_module.AxisOption("[FreeU] Stop At Step", int_or_float, apply_global_state("stop_ratio")), 19 | xyz_module.AxisOption("[FreeU] Transition Smoothness", int_or_float, apply_global_state("transition_smoothness")), 20 | *[ 21 | opt 22 | for index in range(global_state.STAGES_COUNT) 23 | for opt in [ 24 | 
xyz_module.AxisOption(f"[FreeU] Stage {index+1} Backbone Scale", float, apply_global_state(f"b{index}")), 25 | xyz_module.AxisOption(f"[FreeU] Stage {index+1} Backbone Offset", float, apply_global_state(f"o{index}")), 26 | xyz_module.AxisOption(f"[FreeU] Stage {index+1} Backbone Width", float, apply_global_state(f"w{index}")), 27 | xyz_module.AxisOption(f"[FreeU] Stage {index+1} Skip Scale", float, apply_global_state(f"s{index}")), 28 | xyz_module.AxisOption(f"[FreeU] Stage {index+1} Skip Cutoff", float, apply_global_state(f"t{index}")), 29 | xyz_module.AxisOption(f"[FreeU] Stage {index+1} Skip High End Scale", float, apply_global_state(f"h{index}")), 30 | ] 31 | ] 32 | ]) 33 | 34 | 35 | def apply_global_state(k, key_map=None): 36 | def callback(_p, v, _vs): 37 | if key_map is not None: 38 | v = key_map[v] 39 | global_state.xyz_attrs[k] = v 40 | 41 | return callback 42 | 43 | 44 | def str_to_bool(string): 45 | string = str(string) 46 | if string in ["None", ""]: 47 | return None 48 | elif string.lower() in ["true", "1"]: 49 | return True 50 | elif string.lower() in ["false", "0"]: 51 | return False 52 | else: 53 | raise ValueError(f"Could not convert string to boolean: {string}") 54 | 55 | 56 | def int_or_float(string): 57 | try: 58 | return int(string) 59 | except ValueError: 60 | return float(string) 61 | 62 | 63 | def choices_bool(): 64 | return ["False", "True"] 65 | 66 | 67 | def choices_version(): 68 | return list(global_state.all_versions.keys()) 69 | 70 | 71 | def choices_preset(): 72 | presets = list(global_state.all_presets.keys()) 73 | presets.insert(0, "UI Settings") 74 | return presets 75 | 76 | 77 | def find_xyz_module() -> Optional[ModuleType]: 78 | for data in scripts.scripts_data: 79 | if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"): 80 | return data.module 81 | 82 | return None 83 | -------------------------------------------------------------------------------- /scripts/freeu.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import gradio as gr 3 | from modules import scripts, script_callbacks, processing, shared 4 | from lib_free_u import global_state, unet, xyz_grid 5 | 6 | 7 | txt2img_steps_component = None 8 | img2img_steps_component = None 9 | txt2img_steps_callbacks = [] 10 | img2img_steps_callbacks = [] 11 | 12 | 13 | class FreeUScript(scripts.Script): 14 | def title(self): 15 | return "FreeU" 16 | 17 | def show(self, is_img2img): 18 | return scripts.AlwaysVisible 19 | 20 | def ui(self, is_img2img): 21 | global_state.reload_presets() 22 | default_stage_infos = next(iter(global_state.all_presets.values())).stage_infos 23 | 24 | with gr.Accordion(open=False, label=self.title()): 25 | with gr.Row(): 26 | with gr.Row(): 27 | enabled = gr.Checkbox( 28 | label="Enable", 29 | value=False, 30 | ) 31 | 32 | version = gr.Dropdown( 33 | show_label=False, 34 | elem_id=self.elem_id("version"), 35 | choices=list(global_state.all_versions.keys()), 36 | value=next(iter(reversed(global_state.all_versions.keys()))), 37 | ) 38 | 39 | preset_name = gr.Dropdown( 40 | show_label=False, 41 | choices=list(global_state.all_presets.keys()), 42 | value=next(iter(global_state.all_presets.keys())), 43 | type="value", 44 | elem_id=self.elem_id("preset_name"), 45 | allow_custom_value=True, 46 | tooltip="The apply button loads the selected preset\nType a custom name to enable saving\nDeleting a preset saves the change to file automatically", 47 | size="sm", 48 | ) 49 | 50 | is_custom_preset = preset_name.value not in global_state.default_presets 51 | preset_exists = preset_name.value in global_state.all_presets 52 | 53 | apply_preset = gr.Button( 54 | value="✅", 55 | size="lg", 56 | elem_classes="tool", 57 | interactive=preset_exists, 58 | ) 59 | save_preset = gr.Button( 60 | value="💾", 61 | size="lg", 62 | elem_classes="tool", 63 | interactive=is_custom_preset, 64 | ) 65 | refresh_presets = gr.Button( 66 | value="🔄", 67 | size="lg", 68 |
elem_classes="tool" 69 | ) 70 | delete_preset = gr.Button( 71 | value="🗑️", 72 | size="lg", 73 | elem_classes="tool", 74 | interactive=is_custom_preset and preset_exists, 75 | ) 76 | 77 | with gr.Row(): 78 | start_ratio = gr.Slider( 79 | label="Start At Step", 80 | elem_id=self.elem_id("start_at_step"), 81 | minimum=0, 82 | maximum=1, 83 | value=0, 84 | ) 85 | 86 | stop_ratio = gr.Slider( 87 | label="Stop At Step", 88 | elem_id=self.elem_id("stop_at_step"), 89 | minimum=0, 90 | maximum=1, 91 | value=1, 92 | ) 93 | 94 | transition_smoothness = gr.Slider( 95 | label="Transition Smoothness", 96 | elem_id=self.elem_id("transition_smoothness"), 97 | minimum=0, 98 | maximum=1, 99 | value=0, 100 | ) 101 | 102 | flat_stage_infos = [] 103 | 104 | for index in range(global_state.STAGES_COUNT): 105 | stage_n = index + 1 106 | default_stage_info = default_stage_infos[index] 107 | 108 | with gr.Accordion(open=index < 2, label=f"Stage {stage_n}"): 109 | with gr.Row(): 110 | backbone_scale = gr.Slider( 111 | label=f"Backbone {stage_n} Scale", 112 | elem_id=self.elem_id(f"backbone_scale_{stage_n}"), 113 | minimum=-1, 114 | maximum=3, 115 | value=default_stage_info.backbone_factor, 116 | ) 117 | 118 | backbone_offset = gr.Slider( 119 | label=f"Backbone {stage_n} Offset", 120 | elem_id=self.elem_id(f"backbone_offset_{stage_n}"), 121 | minimum=0, 122 | maximum=1, 123 | value=default_stage_info.backbone_offset, 124 | ) 125 | 126 | backbone_width = gr.Slider( 127 | label=f"Backbone {stage_n} Width", 128 | elem_id=self.elem_id(f"backbone_width_{stage_n}"), 129 | minimum=0, 130 | maximum=1, 131 | value=default_stage_info.backbone_width, 132 | ) 133 | 134 | with gr.Row(): 135 | skip_scale = gr.Slider( 136 | label=f"Skip {stage_n} Scale", 137 | elem_id=self.elem_id(f"skip_scale_{stage_n}"), 138 | minimum=-1, 139 | maximum=3, 140 | value=default_stage_info.skip_factor, 141 | ) 142 | 143 | skip_high_end_scale = gr.Slider( 144 | label=f"Skip {stage_n} High End Scale", 145 | 
elem_id=self.elem_id(f"skip_high_end_scale_{stage_n}"), 146 | minimum=-1, 147 | maximum=3, 148 | value=default_stage_info.skip_high_end_factor, 149 | ) 150 | 151 | skip_cutoff = gr.Slider( 152 | label=f"Skip {stage_n} Cutoff", 153 | elem_id=self.elem_id(f"skip_cutoff_{stage_n}"), 154 | minimum=0.0, 155 | maximum=1.0, 156 | value=default_stage_info.skip_cutoff, 157 | ) 158 | 159 | flat_stage_infos.extend([ 160 | backbone_scale, 161 | skip_scale, 162 | backbone_offset, 163 | backbone_width, 164 | skip_cutoff, 165 | skip_high_end_scale, 166 | ]) 167 | 168 | def on_preset_name_change(preset_name): 169 | is_custom_preset = preset_name not in global_state.default_presets 170 | preset_exists = preset_name in global_state.all_presets 171 | return ( 172 | gr.Button.update(interactive=preset_exists), 173 | gr.Button.update(interactive=is_custom_preset), 174 | gr.Button.update(interactive=is_custom_preset and preset_exists), 175 | ) 176 | 177 | preset_name.change( 178 | fn=on_preset_name_change, 179 | inputs=[preset_name], 180 | outputs=[apply_preset, save_preset, delete_preset], 181 | ) 182 | 183 | def on_apply_click(user_settings_name): 184 | preset = global_state.all_presets[user_settings_name] 185 | return ( 186 | gr.Slider.update(value=preset.start_ratio), 187 | gr.Slider.update(value=preset.stop_ratio), 188 | gr.Slider.update(value=preset.transition_smoothness), 189 | *[ 190 | gr.update(value=v) 191 | for stage_info in preset.stage_infos 192 | for v in stage_info.to_dict(include_default=True).values() 193 | ], 194 | ) 195 | 196 | apply_preset.click( 197 | fn=on_apply_click, 198 | inputs=[preset_name], 199 | outputs=[start_ratio, stop_ratio, transition_smoothness, *flat_stage_infos], 200 | ) 201 | 202 | def on_save_click(preset_name, start_ratio, stop_ratio, transition_smoothness, *flat_stage_infos): 203 | global_state.all_presets[preset_name] = global_state.State( 204 | stage_infos=flat_stage_infos, 205 | start_ratio=start_ratio, 206 | stop_ratio=stop_ratio, 207 | 
transition_smoothness=transition_smoothness, 208 | ) 209 | global_state.save_presets() 210 | 211 | return ( 212 | gr.Dropdown.update(choices=list(global_state.all_presets.keys())), 213 | gr.Button.update(interactive=True), 214 | gr.Button.update(interactive=True), 215 | ) 216 | 217 | save_preset.click( 218 | fn=on_save_click, 219 | inputs=[preset_name, start_ratio, stop_ratio, transition_smoothness, *flat_stage_infos], 220 | outputs=[preset_name, apply_preset, delete_preset], 221 | ) 222 | 223 | def on_refresh_click(preset_name): 224 | global_state.reload_presets() 225 | is_custom_preset = preset_name not in global_state.default_presets 226 | preset_exists = preset_name in global_state.all_presets 227 | 228 | return ( 229 | gr.Dropdown.update(value=preset_name, choices=list(global_state.all_presets.keys())), 230 | gr.Button.update(interactive=preset_exists), 231 | gr.Button.update(interactive=is_custom_preset), 232 | gr.Button.update(interactive=is_custom_preset and preset_exists), 233 | ) 234 | 235 | refresh_presets.click( 236 | fn=on_refresh_click, 237 | inputs=[preset_name], 238 | outputs=[preset_name, apply_preset, save_preset, delete_preset], 239 | ) 240 | 241 | def on_delete_click(preset_name): 242 | preset_name_index = list(global_state.all_presets.keys()).index(preset_name) 243 | del global_state.all_presets[preset_name] 244 | global_state.save_presets() 245 | 246 | preset_name_index = min(len(global_state.all_presets) - 1, preset_name_index) 247 | preset_names = list(global_state.all_presets.keys()) 248 | preset_name = preset_names[preset_name_index] 249 | 250 | is_custom_preset = preset_name not in global_state.default_presets 251 | preset_exists = preset_name in global_state.all_presets 252 | return ( 253 | gr.Dropdown.update(value=preset_name, choices=preset_names), 254 | gr.Button.update(interactive=preset_exists), 255 | gr.Button.update(interactive=is_custom_preset), 256 | gr.Button.update(interactive=is_custom_preset and preset_exists), 257 | ) 258 | 
259 | delete_preset.click( 260 | fn=on_delete_click, 261 | inputs=[preset_name], 262 | outputs=[preset_name, apply_preset, save_preset, delete_preset], 263 | ) 264 | 265 | schedule_infotext = gr.HTML(visible=False, interactive=False) 266 | stages_infotext = gr.HTML(visible=False, interactive=False) 267 | version_infotext = gr.HTML(visible=False, interactive=False) 268 | 269 | def register_schedule_infotext_change(steps_component): 270 | schedule_infotext.change( 271 | fn=self.on_schedule_infotext_update, 272 | inputs=[schedule_infotext, steps_component], 273 | outputs=[schedule_infotext, start_ratio, stop_ratio, transition_smoothness], 274 | ) 275 | 276 | steps_component, steps_callbacks = ( 277 | (img2img_steps_component, img2img_steps_callbacks) 278 | if is_img2img else 279 | (txt2img_steps_component, txt2img_steps_callbacks) 280 | ) 281 | 282 | if steps_component is None: 283 | steps_callbacks.append(register_schedule_infotext_change) 284 | else: 285 | register_schedule_infotext_change(steps_component) 286 | 287 | stages_infotext.change( 288 | fn=self.on_stages_infotext_update, 289 | inputs=[stages_infotext], 290 | outputs=[stages_infotext, enabled, *flat_stage_infos], 291 | ) 292 | 293 | version_infotext.change( 294 | fn=self.on_version_infotext_update, 295 | inputs=[version_infotext], 296 | outputs=[version_infotext, version] 297 | ) 298 | 299 | self.infotext_fields = [ 300 | (schedule_infotext, "FreeU Schedule"), 301 | (stages_infotext, "FreeU Stages"), 302 | (version_infotext, "FreeU Version"), 303 | ] 304 | self.paste_field_names = [f for _, f in self.infotext_fields] 305 | 306 | return enabled, start_ratio, stop_ratio, transition_smoothness, version, *flat_stage_infos 307 | 308 | def on_schedule_infotext_update(self, infotext, steps): 309 | if not infotext: 310 | return (gr.skip(),) * 4 311 | 312 | start_ratio, stop_ratio, transition_smoothness, *_ = infotext.split(", ") 313 | 314 | return ( 315 | gr.update(value=""), 316 | 
gr.update(value=unet.to_denoising_step(xyz_grid.int_or_float(start_ratio), steps) / steps), 317 | gr.update(value=unet.to_denoising_step(xyz_grid.int_or_float(stop_ratio), steps) / steps), 318 | gr.update(value=float(transition_smoothness)), 319 | ) 320 | 321 | def on_stages_infotext_update(self, infotext): 322 | if not infotext: 323 | return (gr.skip(),) * (2 + global_state.STAGES_COUNT * global_state.STAGE_INFO_ARGS_LEN) 324 | 325 | stage_infos = json.loads(infotext) 326 | stage_infos = [ 327 | global_state.StageInfo(**stage_info) 328 | for stage_info in stage_infos 329 | ] 330 | stage_infos.extend([ 331 | global_state.StageInfo() 332 | for _ in range(global_state.STAGES_COUNT - len(stage_infos)) 333 | ]) 334 | 335 | return ( 336 | gr.update(value=""), 337 | gr.update(value=shared.opts.data.get("freeu_png_info_auto_enable", True)), 338 | *( 339 | gr.update(value=v) 340 | for stage_info in stage_infos 341 | for v in stage_info.to_dict(include_default=True).values() 342 | ) 343 | ) 344 | 345 | def on_version_infotext_update(self, infotext): 346 | if not infotext: 347 | return (gr.skip(),) * 2 348 | 349 | return ( 350 | gr.update(value=""), 351 | gr.update(value=global_state.reversed_all_versions.get(infotext, infotext)), 352 | ) 353 | 354 | def process( 355 | self, 356 | p: processing.StableDiffusionProcessing, 357 | *args 358 | ): 359 | if isinstance(args[0], dict): 360 | global_state.instance = global_state.State(**args[0]) 361 | elif isinstance(args[0], bool): 362 | stage_infos_begin = global_state.STATE_ARGS_LEN - 1 363 | global_state.instance = global_state.State( 364 | args[0], 365 | *[float(n) for n in args[1:stage_infos_begin-1]], 366 | args[stage_infos_begin-1], 367 | args[stage_infos_begin:], 368 | ) 369 | else: 370 | raise TypeError(f"Unrecognized args sequence starting with type {type(args[0])}") 371 | 372 | global_state.apply_xyz() 373 | global_state.xyz_attrs.clear() 374 | if not global_state.instance.enable: 375 | return 376 | 377 | last_d = False 
378 | p.extra_generation_params["FreeU Stages"] = json.dumps(list(reversed([ 379 | stage_info.to_dict() 380 | for stage_info in reversed(global_state.instance.stage_infos) 381 | # strip all empty dicts 382 | if last_d or stage_info.to_dict() and (last_d := True) 383 | ]))) 384 | p.extra_generation_params["FreeU Schedule"] = ", ".join([ 385 | str(global_state.instance.start_ratio), 386 | str(global_state.instance.stop_ratio), 387 | str(global_state.instance.transition_smoothness), 388 | ]) 389 | p.extra_generation_params["FreeU Version"] = global_state.instance.version 390 | 391 | def process_batch(self, p, *args, **kwargs): 392 | global_state.current_sampling_step = 0 393 | 394 | 395 | def increment_sampling_step(*_args, **_kwargs): 396 | global_state.current_sampling_step += 1 397 | 398 | 399 | try: 400 | script_callbacks.on_cfg_after_cfg(increment_sampling_step) 401 | except AttributeError: 402 | # webui < 1.6.0 403 | # normally we should increment the current sampling step after cfg 404 | # but as long as we don't need to run code during cfg it should be fine to increment early 405 | script_callbacks.on_cfg_denoised(increment_sampling_step) 406 | 407 | 408 | def on_after_component(component, **kwargs): 409 | global txt2img_steps_component, img2img_steps_component 410 | 411 | if kwargs.get("elem_id", None) == "img2img_steps": 412 | img2img_steps_component = component 413 | for callback in img2img_steps_callbacks: 414 | callback(component) 415 | 416 | if kwargs.get("elem_id", None) == "txt2img_steps": 417 | txt2img_steps_component = component 418 | for callback in txt2img_steps_callbacks: 419 | callback(component) 420 | 421 | 422 | script_callbacks.on_after_component(on_after_component) 423 | 424 | 425 | def on_ui_settings(): 426 | section = ("freeu", "FreeU") 427 | shared.opts.add_option( 428 | "freeu_png_info_auto_enable", 429 | shared.OptionInfo( 430 | default=True, 431 | label="Auto enable when loading the PNG Info of a generation that used FreeU", 432 | 
section=section, 433 | ) 434 | ) 435 | 436 | 437 | script_callbacks.on_ui_settings(on_ui_settings) 438 | 439 | 440 | unet.patch() 441 | xyz_grid.patch() 442 | --------------------------------------------------------------------------------
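The "Backbone n Offset", "Backbone n Width" and "Skip n Cutoff" settings described in `README.md` can be explored without torch. The sketch below copies the index arithmetic of `ratio_to_region` and the low-frequency box of `filter_skip` from `lib_free_u/unet.py` into plain Python; `window_mask` and `low_freq_box` are hypothetical helper names introduced only for illustration, and lists of booleans stand in for the tensor masks.

```python
import math
from typing import List, Tuple


def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
    # Same arithmetic as ratio_to_region in lib_free_u/unet.py.
    if width < 0:
        offset += width
        width = -width
    width = min(width, 1)

    if offset < 0:
        offset = 1 + offset - int(offset)
    offset = math.fmod(offset, 1.0)  # an offset of 1.0 wraps back to 0.0

    if width + offset <= 1:
        return round(offset * n), round((width + offset) * n), False
    # The window runs past the last channel: return the gap between its two parts and invert it.
    return round((width + offset - 1) * n), round(offset * n), True


def window_mask(width: float, offset: float, n: int) -> List[bool]:
    # Hypothetical helper: which of the n backbone channels receive "Backbone n Scale".
    begin, end, inverted = ratio_to_region(width, offset, n)
    mask = [begin <= i <= end for i in range(n)]  # inclusive bounds, as in free_u_cat_hijack
    return [not m for m in mask] if inverted else mask


def low_freq_box(h: int, w: int, cutoff: float) -> Tuple[slice, slice]:
    # Hypothetical helper: rows and columns of the fftshift-ed spectrum scaled by "Skip n Scale";
    # all other frequencies are scaled by "Skip n High End Scale" (see filter_skip).
    crow, ccol = h // 2, w // 2
    tr = max(1, math.floor(crow * cutoff))
    tc = max(1, math.floor(ccol * cutoff))
    return slice(crow - tr, crow + tr), slice(ccol - tc, ccol + tc)


if __name__ == "__main__":
    print(window_mask(0.5, 0.0, 8))   # window starting at channel 0
    print(window_mask(0.5, 0.75, 8))  # window that wraps around the last channel
    print(low_freq_box(64, 64, 0.0), low_freq_box(64, 64, 1.0))
```

Because both bounds are inclusive, a width of `0.5` over 8 channels selects 5 channels when the window does not wrap and 3 when it does, one channel more or fewer than half. A cutoff of `0` still scales a 2×2 block around the centre of the shifted spectrum, while `1` covers the whole spectrum.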