├── .gitignore
├── LICENSE
├── README.md
├── automatic1111-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch
├── resources
│   └── img
│       ├── xyz_grid-0001-232592377.png
│       └── xyz_grid-0014-232592377.png
└── scripts
    ├── SAG.py
    └── xyz_grid_support_sag.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Implementation of Self-Attention Guidance in webui
An implementation of Self-Attention Guidance (SAG) for Stable Diffusion webui: https://arxiv.org/abs/2210.00939

## Additional setup requirements after installation:

### For AUTOMATIC1111 webui:
The patch below targets commit 22bcc7be.

Run the following command from the webui root directory (`stable-diffusion-webui/`):
```
git apply --ignore-whitespace extensions/sd_webui_SAG/automatic1111-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch
```
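
To verify that the patch will apply cleanly before it modifies any files, you can first do a dry run with `git apply --check` (optional; same arguments as above):
```
git apply --check --ignore-whitespace extensions/sd_webui_SAG/automatic1111-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch
```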

### ~~For vladmandic webui:~~
Patching is no longer required after commit [cb465b1](https://github.com/vladmandic/automatic/commit/cb465b12ddc5e4b5f6566030021a26630b927ba6), since the required changes have been merged upstream.


Demos:
![xyz_grid-0014-232592377.png](resources%2Fimg%2Fxyz_grid-0014-232592377.png)
![xyz_grid-0001-232592377.png](resources%2Fimg%2Fxyz_grid-0001-232592377.png)

--------------------------------------------------------------------------------
/automatic1111-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch:
--------------------------------------------------------------------------------
From cb129e420612813d74043aa4a6a49575b53e9c14 Mon Sep 17 00:00:00 2001
From: Ashen
Date: Fri, 21 Apr 2023 09:40:59 -0700
Subject: [PATCH] CFGDenoiser and script_callbacks mod for SAG

---
 modules/script_callbacks.py       | 34 +++++++++++++++++++++++++++++++
 modules/sd_samplers_kdiffusion.py |  7 ++++++-
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 07911876..d3d3df14 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -53,6 +53,21 @@ class CFGDenoiserParams:
 
 
 class CFGDenoisedParams:
+    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
+        self.x = x
+        """Latent image representation in the process of being denoised"""
+
+        self.sampling_step = sampling_step
+        """Current sampling step number"""
+
+        self.total_sampling_steps = total_sampling_steps
+        """Total number of sampling steps planned"""
+
+        self.inner_model = inner_model
+        """Inner model reference that is being used for denoising"""
+
+
+class AfterCFGCallbackParams:
     def __init__(self, x, sampling_step, total_sampling_steps):
         self.x = x
         """Latent image representation in the process of being denoised"""
@@ -63,6 +78,8 @@ class CFGDenoisedParams:
         self.total_sampling_steps = total_sampling_steps
         """Total number of sampling steps planned"""
 
+        self.output_altered = False
+        """A flag for CFGDenoiser that indicates whether the output has been altered by the callback"""
 
 class UiTrainTabParams:
     def __init__(self, txt2img_preview_params):
@@ -87,6 +104,7 @@ callback_map = dict(
     callbacks_image_saved=[],
     callbacks_cfg_denoiser=[],
     callbacks_cfg_denoised=[],
+    callbacks_cfg_after_cfg=[],
     callbacks_before_component=[],
     callbacks_after_component=[],
     callbacks_image_grid=[],
@@ -177,6 +195,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
             report_exception(c, 'cfg_denoised_callback')
 
 
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
+    for c in callback_map['callbacks_cfg_after_cfg']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'cfg_after_cfg_callback')
+
+
 def before_component_callback(component, **kwargs):
     for c in callback_map['callbacks_before_component']:
         try:
@@ -318,6 +344,14 @@ def on_cfg_denoised(callback):
     add_callback(callback_map['callbacks_cfg_denoised'], callback)
 
 
+def on_cfg_after_cfg(callback):
+    """register a function to be called in the kdiffusion cfg_denoiser method after CFG calculations have completed.
+    The callback is called with one argument:
+        - params: AfterCFGCallbackParams - the denoised latent and sampling state details.
+    """
+    add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
+
+
 def on_before_component(callback):
     """register a function to be called before a component is created.
     The callback is called with arguments:
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index e9f08518..6ff55ba6 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -9,6 +9,7 @@ from modules.shared import opts, state
 import modules.shared as shared
 from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
 from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
 
 samplers_k_diffusion = [
     ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -146,7 +147,7 @@ class CFGDenoiser(torch.nn.Module):
 
             x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
 
-        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
         cfg_denoised_callback(denoised_params)
 
         devices.test_for_nans(x_out, "unet")
@@ -164,6 +165,10 @@ class CFGDenoiser(torch.nn.Module):
         if self.mask is not None:
             denoised = self.init_latent * self.mask + self.nmask * denoised
 
+        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+        cfg_after_cfg_callback(after_cfg_callback_params)
+        if after_cfg_callback_params.output_altered:
+            denoised = after_cfg_callback_params.x
         self.step += 1
 
         return denoised
-- 
2.40.0
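
For extension authors, the callback API added by this patch is consumed like so (a minimal sketch; the adjustment itself is a hypothetical placeholder, while the registration pattern and the `output_altered` contract come from the patch above and from `scripts/SAG.py` below):

```python
from modules.script_callbacks import AfterCFGCallbackParams, on_cfg_after_cfg

def after_cfg(params: AfterCFGCallbackParams):
    # params.x is the denoised latent after CFG (and after any inpainting
    # mask blend); replace it and raise the flag so CFGDenoiser picks up
    # the modified tensor.
    params.x = params.x * 1.01  # hypothetical adjustment, for illustration only
    params.output_altered = True

on_cfg_after_cfg(after_cfg)
```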
77 | + """ 78 | + add_callback(callback_map['callbacks_cfg_after_cfg'], callback) 79 | + 80 | + 81 | def on_before_component(callback): 82 | """register a function to be called before a component is created. 83 | The callback is called with arguments: 84 | diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py 85 | index e9f08518..6ff55ba6 100644 86 | --- a/modules/sd_samplers_kdiffusion.py 87 | +++ b/modules/sd_samplers_kdiffusion.py 88 | @@ -9,6 +9,7 @@ from modules.shared import opts, state 89 | import modules.shared as shared 90 | from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback 91 | from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback 92 | +from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback 93 | 94 | samplers_k_diffusion = [ 95 | ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), 96 | @@ -146,7 +147,7 @@ class CFGDenoiser(torch.nn.Module): 97 | 98 | x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:])) 99 | 100 | - denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) 101 | + denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) 102 | cfg_denoised_callback(denoised_params) 103 | 104 | devices.test_for_nans(x_out, "unet") 105 | @@ -164,6 +165,10 @@ class CFGDenoiser(torch.nn.Module): 106 | if self.mask is not None: 107 | denoised = self.init_latent * self.mask + self.nmask * denoised 108 | 109 | + after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) 110 | + cfg_after_cfg_callback(after_cfg_callback_params) 111 | + if after_cfg_callback_params.output_altered: 112 | + denoised = after_cfg_callback_params.x 113 | self.step += 1 114 | 115 | return denoised 116 | -- 117 | 2.40.0 118 | 119 | -------------------------------------------------------------------------------- /resources/img/xyz_grid-0001-232592377.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashen-sensored/sd_webui_SAG/83d66a6cd8af4b6c6d087f8f1c979f0da188d40b/resources/img/xyz_grid-0001-232592377.png -------------------------------------------------------------------------------- /resources/img/xyz_grid-0014-232592377.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashen-sensored/sd_webui_SAG/83d66a6cd8af4b6c6d087f8f1c979f0da188d40b/resources/img/xyz_grid-0014-232592377.png -------------------------------------------------------------------------------- /scripts/SAG.py: -------------------------------------------------------------------------------- 1 | 2 | from inspect import isfunction 3 | import torch 4 | from torch import nn, einsum 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | 8 | from modules.processing import StableDiffusionProcessing 9 | 10 | import math 11 | 12 | 13 | 14 | import modules.scripts as scripts 15 | from modules import shared 16 | import gradio as gr 17 | 18 | from modules.script_callbacks import on_cfg_denoiser,CFGDenoiserParams, CFGDenoisedParams, on_cfg_denoised, AfterCFGCallbackParams, on_cfg_after_cfg 19 | 20 | import os 21 | 22 | from scripts import xyz_grid_support_sag 23 | 24 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") 25 | 
def exists(val):
    return val is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# Mirror of ldm's CrossAttention that additionally records the softmaxed
# attention map in self.attn_probs. Kept for reference; at runtime the
# extension binds xattn_forward_log (below) onto the existing module instead
# of instantiating this class.
class LoggedSelfAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )
        self.attn_probs = None

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION == "fp32":
            with torch.autocast(enabled=False, device_type='cuda'):
                q, k = q.float(), k.float()
                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        self.attn_probs = sim

        out = einsum('b i j, b j d -> b i d', sim, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)

def xattn_forward_log(self, x, context=None, mask=None):
    # Same computation as LoggedSelfAttention.forward, but bound onto the
    # UNet middle block's attn1 module at runtime; it additionally publishes
    # the attention map and output size through module-level globals.
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    # force cast to fp32 to avoid overflowing
    if _ATTN_PRECISION == "fp32":
        with torch.autocast(enabled=False, device_type='cuda'):
            q, k = q.float(), k.float()
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    del q, k

    if exists(mask):
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    self.attn_probs = sim
    global current_selfattn_map
    current_selfattn_map = sim

    out = einsum('b i j, b j d -> b i d', sim, v)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    out = self.to_out(out)
    global current_outsize
    current_outsize = out.shape[-2:]
    return out


saved_original_selfattn_forward = None
current_selfattn_map = None
current_sag_guidance_scale = 1.0
sag_enabled = False
sag_mask_threshold = 1.0

current_xin = None
current_outsize = (64, 64)
current_batch_size = 1
current_degraded_pred = None
current_unet_kwargs = {}
current_uncond_pred = None
current_degraded_pred_compensation = None


def gaussian_blur_2d(img, kernel_size, sigma):
    # Separable Gaussian blur: build a 1-D kernel, outer-product it into 2-D,
    # then apply it per channel as a depthwise convolution with reflect padding.
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)

    pdf = torch.exp(-0.5 * (x / sigma).pow(2))

    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)

    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])

    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]

    img = F.pad(img, padding, mode="reflect")
    img = F.conv2d(img, kernel2d, groups=img.shape[-3])

    return img
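
# Quick shape sanity check for gaussian_blur_2d (a sketch with hypothetical
# shapes; kernel_size=9 and sigma=1.0 are the values denoised_callback below
# actually uses). Reflect padding keeps the spatial size intact:
#
#     latent = torch.randn(1, 4, 64, 64)  # batch, channels, height, width
#     blurred = gaussian_blur_2d(latent, kernel_size=9, sigma=1.0)
#     assert blurred.shape == latent.shape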

class Script(scripts.Script):

    def __init__(self):
        pass

    def title(self):
        return "Self Attention Guidance"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def denoiser_callback(self, params: CFGDenoiserParams):
        if not sag_enabled:
            return
        global current_xin, current_batch_size

        # log the current uncond size for cond/uncond output separation
        current_batch_size = params.text_uncond.shape[0]
        # log the current input for eps calculation later
        current_xin = params.x[-current_batch_size:]

        # log the information needed for the SAG prediction
        current_uncond_emb = params.text_uncond
        current_sigma = params.sigma
        current_image_cond_in = params.image_cond
        global current_unet_kwargs
        current_unet_kwargs = {
            "sigma": current_sigma[-current_batch_size:],
            "image_cond": current_image_cond_in[-current_batch_size:],
            "text_uncond": current_uncond_emb,
        }

    def denoised_callback(self, params: CFGDenoisedParams):
        if not sag_enabled:
            return
        # output from DiscreteEpsDDPMDenoiser is already pred_x0
        uncond_output = params.x[-current_batch_size:]
        original_latents = uncond_output
        global current_uncond_pred
        current_uncond_pred = uncond_output

        # Produce the attention mask. We're only interested in the last
        # current_batch_size * head_count slices of the logged self-attention map.
        attn_map = current_selfattn_map[-current_batch_size * 8:]
        bh, hw1, hw2 = attn_map.shape
        b, latent_channel, latent_h, latent_w = original_latents.shape
        h = 8  # head count of the middle block's self-attention

        # the middle block operates at 1/8 of the latent resolution
        middle_layer_latent_size = [math.ceil(latent_h / 8), math.ceil(latent_w / 8)]

        attn_map = attn_map.reshape(b, h, hw1, hw2)
        attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > sag_mask_threshold
        attn_mask = (
            attn_mask.reshape(b, middle_layer_latent_size[0], middle_layer_latent_size[1])
            .unsqueeze(1)
            .repeat(1, latent_channel, 1, 1)
            .type(attn_map.dtype)
        )
        attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))

        # Blur only the regions selected by the self-attention mask
        degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
        degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)

        # re-noise the degraded pred_x0 back to the current noise level
        renoised_degraded_latent = degraded_latents - (uncond_output - current_xin)
        # renoised_degraded_latent = degraded_latents
        # get the predicted x0 in the degraded direction
        global current_degraded_pred_compensation
        current_degraded_pred_compensation = uncond_output - degraded_latents
        if shared.sd_model.model.conditioning_key == "crossattn-adm":
            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
        else:
            make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
        degraded_pred = params.inner_model(
            renoised_degraded_latent,
            current_unet_kwargs['sigma'],
            cond=make_condition_dict([current_unet_kwargs['text_uncond']], [current_unet_kwargs['image_cond']])
        )
        global current_degraded_pred
        current_degraded_pred = degraded_pred

    def cfg_after_cfg_callback(self, params: AfterCFGCallbackParams):
        if not sag_enabled:
            return

        # SAG update: push the CFG result away from the degraded (blurred)
        # prediction by the guidance scale, compensating for the re-noising
        # offset introduced in denoised_callback.
        params.x = params.x + (current_uncond_pred - (current_degraded_pred + current_degraded_pred_compensation)) * float(current_sag_guidance_scale)
        params.output_altered = True
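
    # How the guidance term above relates to the paper: substituting
    # current_degraded_pred_compensation = uncond_output - degraded_latents
    # from denoised_callback, the correction reduces to
    #
    #     x <- x + scale * (degraded_latents - current_degraded_pred)
    #
    # i.e. the gap between the selectively blurred pred_x0 and the model's
    # prediction for the re-noised blurred latent, which plays the role of
    # the eps-difference guidance term of the SAG paper, expressed in
    # k-diffusion's pred_x0 space.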
    def ui(self, is_img2img):
        with gr.Accordion('Self Attention Guidance', open=False):
            enabled = gr.Checkbox(label="Enabled", value=False)
            scale = gr.Slider(label='Scale', minimum=-2.0, maximum=10.0, step=0.01, value=0.75)
            mask_threshold = gr.Slider(label='SAG Mask Threshold', minimum=0.0, maximum=2.0, step=0.01, value=1.0)

        return [enabled, scale, mask_threshold]

    def process(self, p: StableDiffusionProcessing, *args, **kwargs):
        enabled, scale, mask_threshold = args
        global sag_enabled, sag_mask_threshold
        if enabled:
            sag_enabled = True
            sag_mask_threshold = mask_threshold
            global current_sag_guidance_scale
            current_sag_guidance_scale = scale
            global saved_original_selfattn_forward
            # replace the target self-attention module in the UNet middle block with ours
            org_attn_module = shared.sd_model.model.diffusion_model.middle_block._modules['1'].transformer_blocks._modules['0'].attn1
            saved_original_selfattn_forward = org_attn_module.forward
            org_attn_module.forward = xattn_forward_log.__get__(org_attn_module, org_attn_module.__class__)

            p.extra_generation_params["SAG Guidance Scale"] = scale
            p.extra_generation_params["SAG Mask Threshold"] = mask_threshold
        else:
            sag_enabled = False

        if not hasattr(self, 'callbacks_added'):
            on_cfg_denoiser(self.denoiser_callback)
            on_cfg_denoised(self.denoised_callback)
            on_cfg_after_cfg(self.cfg_after_cfg_callback)
            self.callbacks_added = True

        return

    def postprocess(self, p, processed, *args):
        enabled, scale, mask_threshold = args
        if enabled:
            # restore the original self-attention module's forward function
            attn_module = shared.sd_model.model.diffusion_model.middle_block._modules['1'].transformer_blocks._modules['0'].attn1
            attn_module.forward = saved_original_selfattn_forward
        return


xyz_grid_support_sag.initialize(Script)

--------------------------------------------------------------------------------
/scripts/xyz_grid_support_sag.py:
--------------------------------------------------------------------------------
import os.path

import modules.scripts as scripts

xy_grid = None  # XY Grid module
script_class = None  # the SAG scripts.Script class


def update_script_args(p, value, arg_idx):
    global script_class
    for s in scripts.scripts_txt2img.alwayson_scripts:
        if isinstance(s, script_class):
            args = list(p.script_args)
            # print(f"Changed arg {arg_idx} from {args[s.args_from + arg_idx - 1]} to {value}")
            args[s.args_from + arg_idx] = value
            p.script_args = tuple(args)
            break


# The four helpers below are leftovers from the additional_networks X/Y grid
# support this module was adapted from; they are not used by SAG.
def apply_module(p, x, xs, i):
    update_script_args(p, True, 0)  # set Enabled to True
    update_script_args(p, x, 2 + 4 * i)  # enabled, separate_weights, ({module}, model, weight_unet, weight_tenc), ...


def apply_weight(p, x, xs, i):
    update_script_args(p, True, 0)
    update_script_args(p, x, 4 + 4 * i)  # enabled, separate_weights, (module, model, {weight_unet, weight_tenc}), ...
    update_script_args(p, x, 5 + 4 * i)


def apply_weight_unet(p, x, xs, i):
    update_script_args(p, True, 0)
    update_script_args(p, x, 4 + 4 * i)  # enabled, separate_weights, (module, model, {weight_unet}, weight_tenc), ...


def apply_weight_tenc(p, x, xs, i):
    update_script_args(p, True, 0)
    update_script_args(p, x, 5 + 4 * i)  # enabled, separate_weights, (module, model, weight_unet, {weight_tenc}), ...


def apply_sag_guidance_scale(p, x, xs):
    update_script_args(p, True, 0)  # set Enabled to True
    update_script_args(p, x, 1)  # sag_guidance_scale


def apply_sag_mask_threshold(p, x, xs):
    update_script_args(p, True, 0)  # set Enabled to True
    update_script_args(p, x, 2)  # sag_mask_threshold


def initialize(script):
    global xy_grid, script_class
    xy_grid = None
    script_class = script
    for script_data in scripts.scripts_data:
        if os.path.basename(script_data.path) in ("xy_grid.py", "xyz_grid.py"):
            xy_grid = script_data.module
            sag_guidance_scale = xy_grid.AxisOption("SAG Guidance Scale", float, apply_sag_guidance_scale, xy_grid.format_value_add_label, None, cost=0.5)
            sag_mask_threshold = xy_grid.AxisOption("SAG Mask Threshold", float, apply_sag_mask_threshold, xy_grid.format_value_add_label, None, cost=0.5)
            xy_grid.axis_options.extend([sag_guidance_scale, sag_mask_threshold])

--------------------------------------------------------------------------------