Uses the Reinhard method of tonemapping (from comfyanonymous' ComfyUI Experiments) to clamp the CFG if the difference is too strong. 26 | 27 | A lower `tonemap_multiplier` clamps more noise, and a lower `tonemap_percentile` increases the standard deviation calculated from the original noise. Play with it!
28 | * Arctan: Clamps the values dynamically using a simple arctan curve. [Link to interactive Desmos visualization](https://www.desmos.com/calculator/e4nrcdpqbl). 29 | 30 | Recommended values for testing: `tonemap_multiplier` of 5, `tonemap_percentile` of 90.
31 | * Quantile: Clamps the values using `torch.quantile` to find the highest magnitudes, then clamps based on the result (a minimal sketch of this follows the list). 32 | 33 | 34 | `Closer to 100 percentile == stronger clamping`. Recommended values for testing: `tonemap_multiplier` of 1, `tonemap_percentile` of 99.
35 | * Gated: Clamps the values using `torch.quantile`, but only when the prediction's magnitude rises above a gate value scaled by `tonemap_multiplier`; the noise prediction latent is then clamped based on the percentile. 36 | 37 | 38 | `Closer to 100 percentile == stronger clamping, lower tonemap_multiplier == stronger clamping`. Recommended values for testing: `tonemap_multiplier` of 0.8-1, `tonemap_percentile` of 99.995.
39 | * CFG-Mimic: Attempts to mimic a lower or higher CFG based on `tonemap_multiplier`, and clamps it using `tonemap_percentile` with `torch.quantile`. 40 | 41 | 42 | `Closer to 100 percentile == stronger clamping, lower tonemap_multiplier == stronger clamping`. Recommended values for testing: `tonemap_multiplier` of 0.33-1.0, `tonemap_percentile` of 100.
43 | * Spatial-Norm: Clamps the values according to the noise prediction's root-mean-square magnitude in the spatial domain. `tonemap_multiplier` adjusts the strength of the clamping. 44 | 45 | 46 | `Lower tonemap_multiplier == stronger clamping`. Recommended value for testing: `tonemap_multiplier` of 0.5-2.0.
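For intuition, here is a minimal standalone sketch of the quantile-style clamping that the Quantile, Gated, and CFG-Mimic modes build on. This is not the node's exact code (the name `quantile_clamp` is illustrative only); `noise_pred` stands in for the CFG difference `cond - uncond`:

```python
import torch

def quantile_clamp(noise_pred: torch.Tensor, percentile: float = 99.0, multiplier: float = 1.0) -> torch.Tensor:
    # Per-sample magnitude at the requested percentile of |noise_pred|
    s = torch.quantile(noise_pred.flatten(start_dim=1).abs(), percentile / 100.0, dim=-1) * multiplier
    # Keep s >= 1 so in-range predictions aren't amplified by the divide below
    s = s.clamp(min=1.0).reshape(-1, 1, 1, 1)
    return noise_pred.clamp(-s, s) / s

noise_pred = torch.randn(2, 4, 64, 64) * 3.0
print(quantile_clamp(noise_pred).abs().max())  # <= 1 after clamping and normalization
```

Raising `percentile` raises the per-sample scale `s`, and since the result is divided by `s`, the prediction is compressed more overall, which is why the notes above say closer to 100 means stronger clamping.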
47 | 48 | ### Contrast Explanation: 49 | Scales the pixel values by the standard deviation, producing a higher-contrast look. In practice, this can effectively act as a secondary CFG slider for stylization. It doesn't modify subject poses much, if at all, which can be great for those looking to get more oomph out of their low-CFG setups. 50 | 51 | Using a negative value will apply the inverse of the operation to the latent.
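As a rough sketch of what this pass does (mirroring the node's internal `contrast` helper further down in this file; the timestep-derived `alpha` blend is how the multiplier is applied):

```python
import torch

def contrast(x: torch.Tensor) -> torch.Tensor:
    # Scale values by the per-sample standard deviation over (C, H, W)
    stddev = x.std(dim=(1, 2, 3), keepdim=True)
    return x / stddev

latent = torch.randn(1, 4, 64, 64) * 0.5
alpha = 0.02  # roughly 0.001 * contrast_multiplier in the node; a negative multiplier flips the blend
blended = contrast(latent) * alpha + latent * (1.0 - alpha)
```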
52 | 53 | ### Spectral Modification Explanation: 54 | We boost the low frequencies (low rate of change in the noise) and lower the high frequencies (high rate of change in the noise). 55 | 56 | Change the low/high frequency range using `spectral_mod_percentile` (default of 5.0, i.e. the upper and lower 5th percentiles). 57 | 58 | Increase or decrease the strength of the adjustment with `spectral_mod_multiplier`. 59 | 60 | Beware of percentile values higher than 15 and multiplier values higher than 5, especially for hard clamping. Here be dragons, as large values may cause the image to "noise-out", or become full of nonsensical noise, especially earlier in the diffusion process.
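A condensed sketch of the hard-clamp variant (see `spectral_modulation` later in this file; this version is lightly simplified, e.g. it takes quantiles of the signed log-amplitude and adds an epsilon to the log):

```python
import torch

def spectral_mod(x: torch.Tensor, multiplier: float = 1.0, percentile: float = 5.0) -> torch.Tensor:
    fourier = torch.fft.fft2(x, dim=(-2, -1))
    log_amp = torch.log(fourier.abs() + 1e-12)
    flat = log_amp.flatten(2)
    q_low = torch.quantile(flat, percentile / 100.0, dim=2)[..., None, None]
    q_high = torch.quantile(flat, 1.0 - percentile / 100.0, dim=2)[..., None, None]
    mask_low = ((log_amp < q_low).float() + 1.0).clamp(max=1.5)  # 1.5x boost below the low quantile
    mask_high = (log_amp < q_high).float().clamp(min=0.5)        # 0.5x cut above the high quantile
    filtered = fourier * ((mask_low * mask_high) ** multiplier)
    return torch.fft.ifft2(filtered, dim=(-2, -1)).real

out = spectral_mod(torch.randn(1, 4, 64, 64), multiplier=2.0, percentile=5.0)
```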
61 | 62 | 63 | #### Current Pipeline: 64 | >##### Add extra noise to conditioning -> Sharpen conditioning -> Convert to Noise Prediction -> Tonemap Noise Prediction -> Spectral Modification -> Modify contrast of noise prediction -> Rescale CFG -> Divisive Normalization -> Combat CFG Drift 65 | 66 | #### Why use this over `x` node? 67 | Since the `set_model_sampler_cfg_function` hijack in ComfyUI can only utilize a single function, we bundle many latent modification methods into one large function for processing. This is simpler than taking an existing hijack and modifying it, which may be possible, but my (Clybius') limited Python/PyTorch knowledge makes this the most straightforward approach. If you know how to do this, feel free to reach out through any means! 68 | 69 | #### Can you implement `x` function? 70 | Depends. Is there existing code for such a function, with an open license that allows use in this repository? If so, I could likely attempt adding it! Feel free to start an issue or to reach out with ideas you'd want implemented. 71 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from . import sampler_mega_modifier 2 | 3 | NODE_CLASS_MAPPINGS = { 4 | "Latent Diffusion Mega Modifier": sampler_mega_modifier.ModelSamplerLatentMegaModifier, 5 | } -------------------------------------------------------------------------------- /sampler_mega_modifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import random 6 | 7 | # Set manual seeds for noise 8 | # rand(n)_like, but with generator support 9 | def gen_like(f, input, generator=None): 10 | return f(input.size(), generator=generator).to(input) 11 | 12 | ''' 13 | The following snippet is adapted from https://github.com/Jamy-L/Pytorch-Contrast-Adaptive-Sharpening/ 14 | ''' 15 | def min_(tensor_list): 16 | # return the element-wise min of the tensor list. 17 | x = torch.stack(tensor_list) 18 | mn = x.min(axis=0)[0] 19 | return mn#torch.clamp(mn, min=-1) 20 | 21 | def max_(tensor_list): 22 | # return the element-wise max of the tensor list.
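    # (min_ and max_ are used by contrast_adaptive_sharpening below to reduce the cross/diagonal pixel neighborhoods element-wise.)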
23 | x = torch.stack(tensor_list) 24 | mx = x.max(axis=0)[0] 25 | return mx#torch.clamp(mx, max=1) 26 | def contrast_adaptive_sharpening(image, amount): 27 | img = F.pad(image, pad=(1, 1, 1, 1)) 28 | absmean = torch.abs(image.mean()) 29 | 30 | a = img[..., :-2, :-2] 31 | b = img[..., :-2, 1:-1] 32 | c = img[..., :-2, 2:] 33 | d = img[..., 1:-1, :-2] 34 | e = img[..., 1:-1, 1:-1] 35 | f = img[..., 1:-1, 2:] 36 | g = img[..., 2:, :-2] 37 | h = img[..., 2:, 1:-1] 38 | i = img[..., 2:, 2:] 39 | 40 | # Computing contrast 41 | cross = (b, d, e, f, h) 42 | mn = min_(cross) 43 | mx = max_(cross) 44 | 45 | diag = (a, c, g, i) 46 | mn2 = min_(diag) 47 | mx2 = max_(diag) 48 | mx = mx + mx2 49 | mn = mn + mn2 50 | 51 | # Computing local weight 52 | inv_mx = torch.reciprocal(mx) 53 | amp = inv_mx * torch.minimum(mn, (2 - mx)) 54 | 55 | # scaling 56 | amp = torch.copysign(torch.sqrt(torch.abs(amp)), amp) 57 | w = - amp * (amount * (1/5 - 1/8) + 1/8) 58 | div = torch.reciprocal(1 + 4*w).clamp(-10, 10) 59 | 60 | output = ((b + d + f + h)*w + e) * div 61 | output = torch.nan_to_num(output) 62 | 63 | return (output.to(image.device)) 64 | 65 | ''' 66 | The following gaussian functions were utilized from the Fooocus UI, many thanks to github.com/Illyasviel ! 67 | ''' 68 | def gaussian_kernel(kernel_size, sigma): 69 | kernel = np.fromfunction( 70 | lambda x, y: (1 / (2 * np.pi * sigma ** 2)) * 71 | np.exp(-((x - (kernel_size - 1) / 2) ** 2 + (y - (kernel_size - 1) / 2) ** 2) / (2 * sigma ** 2)), 72 | (kernel_size, kernel_size) 73 | ) 74 | return kernel / np.sum(kernel) 75 | 76 | 77 | class GaussianBlur(nn.Module): 78 | def __init__(self, channels, kernel_size, sigma): 79 | super(GaussianBlur, self).__init__() 80 | self.channels = channels 81 | self.kernel_size = kernel_size 82 | self.sigma = sigma 83 | self.padding = kernel_size // 2 # Ensure output size matches input size 84 | self.register_buffer('kernel', torch.tensor(gaussian_kernel(kernel_size, sigma), dtype=torch.float32)) 85 | self.kernel = self.kernel.view(1, 1, kernel_size, kernel_size) 86 | self.kernel = self.kernel.expand(self.channels, -1, -1, -1) # Repeat the kernel for each input channel 87 | 88 | def forward(self, x): 89 | x = F.conv2d(x, self.kernel.to(x), padding=self.padding, groups=self.channels) 90 | return x 91 | 92 | gaussian_filter_2d = GaussianBlur(4, 7, 0.8) 93 | 94 | ''' 95 | As of August 18th (on Fooocus' GitHub), the gaussian functions were replaced by an anisotropic function for better stability. 96 | ''' 97 | Tensor = torch.Tensor 98 | Device = torch.DeviceObjType 99 | Dtype = torch.Type 100 | pad = torch.nn.functional.pad 101 | 102 | 103 | def _compute_zero_padding(kernel_size: tuple[int, int] | int) -> tuple[int, int]: 104 | ky, kx = _unpack_2d_ks(kernel_size) 105 | return (ky - 1) // 2, (kx - 1) // 2 106 | 107 | 108 | def _unpack_2d_ks(kernel_size: tuple[int, int] | int) -> tuple[int, int]: 109 | if isinstance(kernel_size, int): 110 | ky = kx = kernel_size 111 | else: 112 | assert len(kernel_size) == 2, '2D Kernel size should have a length of 2.' 
113 | ky, kx = kernel_size 114 | 115 | ky = int(ky) 116 | kx = int(kx) 117 | return ky, kx 118 | 119 | 120 | def gaussian( 121 | window_size: int, sigma: Tensor | float, *, device: Device | None = None, dtype: Dtype | None = None 122 | ) -> Tensor: 123 | 124 | batch_size = sigma.shape[0] 125 | 126 | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) 127 | 128 | if window_size % 2 == 0: 129 | x = x + 0.5 130 | 131 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) 132 | 133 | return gauss / gauss.sum(-1, keepdim=True) 134 | 135 | 136 | def get_gaussian_kernel1d( 137 | kernel_size: int, 138 | sigma: float | Tensor, 139 | force_even: bool = False, 140 | *, 141 | device: Device | None = None, 142 | dtype: Dtype | None = None, 143 | ) -> Tensor: 144 | 145 | return gaussian(kernel_size, sigma, device=device, dtype=dtype) 146 | 147 | 148 | def get_gaussian_kernel2d( 149 | kernel_size: tuple[int, int] | int, 150 | sigma: tuple[float, float] | Tensor, 151 | force_even: bool = False, 152 | *, 153 | device: Device | None = None, 154 | dtype: Dtype | None = None, 155 | ) -> Tensor: 156 | 157 | sigma = torch.Tensor([[sigma, sigma]]).to(device=device, dtype=dtype) 158 | 159 | ksize_y, ksize_x = _unpack_2d_ks(kernel_size) 160 | sigma_y, sigma_x = sigma[:, 0, None], sigma[:, 1, None] 161 | 162 | kernel_y = get_gaussian_kernel1d(ksize_y, sigma_y, force_even, device=device, dtype=dtype)[..., None] 163 | kernel_x = get_gaussian_kernel1d(ksize_x, sigma_x, force_even, device=device, dtype=dtype)[..., None] 164 | 165 | return kernel_y * kernel_x.view(-1, 1, ksize_x) 166 | 167 | 168 | def _bilateral_blur( 169 | input: Tensor, 170 | guidance: Tensor | None, 171 | kernel_size: tuple[int, int] | int, 172 | sigma_color: float | Tensor, 173 | sigma_space: tuple[float, float] | Tensor, 174 | border_type: str = 'reflect', 175 | color_distance_type: str = 'l1', 176 | ) -> Tensor: 177 | 178 | if isinstance(sigma_color, Tensor): 179 | sigma_color = sigma_color.to(device=input.device, dtype=input.dtype).view(-1, 1, 1, 1, 1) 180 | 181 | ky, kx = _unpack_2d_ks(kernel_size) 182 | pad_y, pad_x = _compute_zero_padding(kernel_size) 183 | 184 | padded_input = pad(input, (pad_x, pad_x, pad_y, pad_y), mode=border_type) 185 | unfolded_input = padded_input.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2) # (B, C, H, W, Ky x Kx) 186 | 187 | if guidance is None: 188 | guidance = input 189 | unfolded_guidance = unfolded_input 190 | else: 191 | padded_guidance = pad(guidance, (pad_x, pad_x, pad_y, pad_y), mode=border_type) 192 | unfolded_guidance = padded_guidance.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2) # (B, C, H, W, Ky x Kx) 193 | 194 | diff = unfolded_guidance - guidance.unsqueeze(-1) 195 | if color_distance_type == "l1": 196 | color_distance_sq = diff.abs().sum(1, keepdim=True).square() 197 | elif color_distance_type == "l2": 198 | color_distance_sq = diff.square().sum(1, keepdim=True) 199 | else: 200 | raise ValueError("color_distance_type only accepts l1 or l2") 201 | color_kernel = (-0.5 / sigma_color**2 * color_distance_sq).exp() # (B, 1, H, W, Ky x Kx) 202 | 203 | space_kernel = get_gaussian_kernel2d(kernel_size, sigma_space, device=input.device, dtype=input.dtype) 204 | space_kernel = space_kernel.view(-1, 1, 1, 1, kx * ky) 205 | 206 | kernel = space_kernel * color_kernel 207 | out = (unfolded_input * kernel).sum(-1) / kernel.sum(-1) 208 | return out 209 | 210 | 211 | def bilateral_blur( 212 | input: Tensor, 213 | kernel_size: tuple[int, int] | int = (13,
13), 214 | sigma_color: float | Tensor = 3.0, 215 | sigma_space: tuple[float, float] | Tensor = 3.0, 216 | border_type: str = 'reflect', 217 | color_distance_type: str = 'l1', 218 | ) -> Tensor: 219 | return _bilateral_blur(input, None, kernel_size, sigma_color, sigma_space, border_type, color_distance_type) 220 | 221 | 222 | def joint_bilateral_blur( 223 | input: Tensor, 224 | guidance: Tensor, 225 | kernel_size: tuple[int, int] | int, 226 | sigma_color: float | Tensor, 227 | sigma_space: tuple[float, float] | Tensor, 228 | border_type: str = 'reflect', 229 | color_distance_type: str = 'l1', 230 | ) -> Tensor: 231 | return _bilateral_blur(input, guidance, kernel_size, sigma_color, sigma_space, border_type, color_distance_type) 232 | 233 | 234 | class _BilateralBlur(torch.nn.Module): 235 | def __init__( 236 | self, 237 | kernel_size: tuple[int, int] | int, 238 | sigma_color: float | Tensor, 239 | sigma_space: tuple[float, float] | Tensor, 240 | border_type: str = 'reflect', 241 | color_distance_type: str = "l1", 242 | ) -> None: 243 | super().__init__() 244 | self.kernel_size = kernel_size 245 | self.sigma_color = sigma_color 246 | self.sigma_space = sigma_space 247 | self.border_type = border_type 248 | self.color_distance_type = color_distance_type 249 | 250 | def __repr__(self) -> str: 251 | return ( 252 | f"{self.__class__.__name__}" 253 | f"(kernel_size={self.kernel_size}, " 254 | f"sigma_color={self.sigma_color}, " 255 | f"sigma_space={self.sigma_space}, " 256 | f"border_type={self.border_type}, " 257 | f"color_distance_type={self.color_distance_type})" 258 | ) 259 | 260 | 261 | class BilateralBlur(_BilateralBlur): 262 | def forward(self, input: Tensor) -> Tensor: 263 | return bilateral_blur( 264 | input, self.kernel_size, self.sigma_color, self.sigma_space, self.border_type, self.color_distance_type 265 | ) 266 | 267 | 268 | class JointBilateralBlur(_BilateralBlur): 269 | def forward(self, input: Tensor, guidance: Tensor) -> Tensor: 270 | return joint_bilateral_blur( 271 | input, 272 | guidance, 273 | self.kernel_size, 274 | self.sigma_color, 275 | self.sigma_space, 276 | self.border_type, 277 | self.color_distance_type, 278 | ) 279 | 280 | 281 | # Below is perlin noise from https://github.com/tasptz/pytorch-perlin-noise/blob/main/perlin_noise/perlin_noise.py 282 | from torch import Generator, Tensor, lerp 283 | from torch.nn.functional import unfold 284 | from typing import Callable, Tuple 285 | from math import pi 286 | 287 | def get_positions(block_shape: Tuple[int, int]) -> Tensor: 288 | """ 289 | Generate position tensor. 290 | 291 | Arguments: 292 | block_shape -- (height, width) of position tensor 293 | 294 | Returns: 295 | position vector shaped (1, height, width, 1, 1, 2) 296 | """ 297 | bh, bw = block_shape 298 | positions = torch.stack( 299 | torch.meshgrid( 300 | [(torch.arange(b) + 0.5) / b for b in (bw, bh)], 301 | indexing="xy", 302 | ), 303 | -1, 304 | ).view(1, bh, bw, 1, 1, 2) 305 | return positions 306 | 307 | 308 | def unfold_grid(vectors: Tensor) -> Tensor: 309 | """ 310 | Unfold vector grid to batched vectors. 311 | 312 | Arguments: 313 | vectors -- grid vectors 314 | 315 | Returns: 316 | batched grid vectors 317 | """ 318 | batch_size, _, gpy, gpx = vectors.shape 319 | return ( 320 | unfold(vectors, (2, 2)) 321 | .view(batch_size, 2, 4, -1) 322 | .permute(0, 2, 3, 1) 323 | .view(batch_size, 4, gpy - 1, gpx - 1, 2) 324 | ) 325 | 326 | 327 | def smooth_step(t: Tensor) -> Tensor: 328 | """ 329 | Smooth step function [0, 1] -> [0, 1]. 
330 | 331 | Arguments: 332 | t -- input values (any shape) 333 | 334 | Returns: 335 | output values (same shape as input values) 336 | """ 337 | return t * t * (3.0 - 2.0 * t) 338 | 339 | 340 | def perlin_noise_tensor( 341 | vectors: Tensor, positions: Tensor, step: Callable = None 342 | ) -> Tensor: 343 | """ 344 | Generate perlin noise from batched vectors and positions. 345 | 346 | Arguments: 347 | vectors -- batched grid vectors shaped (batch_size, 4, grid_height, grid_width, 2) 348 | positions -- batched grid positions shaped (batch_size or 1, block_height, block_width, grid_height or 1, grid_width or 1, 2) 349 | 350 | Keyword Arguments: 351 | step -- smooth step function [0, 1] -> [0, 1] (default: `smooth_step`) 352 | 353 | Raises: 354 | Exception: if position and vector shapes do not match 355 | 356 | Returns: 357 | (batch_size, block_height * grid_height, block_width * grid_width) 358 | """ 359 | if step is None: 360 | step = smooth_step 361 | 362 | batch_size = vectors.shape[0] 363 | # grid height, grid width 364 | gh, gw = vectors.shape[2:4] 365 | # block height, block width 366 | bh, bw = positions.shape[1:3] 367 | 368 | for i in range(2): 369 | if positions.shape[i + 3] not in (1, vectors.shape[i + 2]): 370 | raise Exception( 371 | f"Blocks shapes do not match: vectors ({vectors.shape[1]}, {vectors.shape[2]}), positions {gh}, {gw})" 372 | ) 373 | 374 | if positions.shape[0] not in (1, batch_size): 375 | raise Exception( 376 | f"Batch sizes do not match: vectors ({vectors.shape[0]}), positions ({positions.shape[0]})" 377 | ) 378 | 379 | vectors = vectors.view(batch_size, 4, 1, gh * gw, 2) 380 | positions = positions.view(positions.shape[0], bh * bw, -1, 2) 381 | 382 | step_x = step(positions[..., 0]) 383 | step_y = step(positions[..., 1]) 384 | 385 | row0 = lerp( 386 | (vectors[:, 0] * positions).sum(dim=-1), 387 | (vectors[:, 1] * (positions - positions.new_tensor((1, 0)))).sum(dim=-1), 388 | step_x, 389 | ) 390 | row1 = lerp( 391 | (vectors[:, 2] * (positions - positions.new_tensor((0, 1)))).sum(dim=-1), 392 | (vectors[:, 3] * (positions - positions.new_tensor((1, 1)))).sum(dim=-1), 393 | step_x, 394 | ) 395 | noise = lerp(row0, row1, step_y) 396 | return ( 397 | noise.view( 398 | batch_size, 399 | bh, 400 | bw, 401 | gh, 402 | gw, 403 | ) 404 | .permute(0, 3, 1, 4, 2) 405 | .reshape(batch_size, gh * bh, gw * bw) 406 | ) 407 | 408 | 409 | def perlin_noise( 410 | grid_shape: Tuple[int, int], 411 | out_shape: Tuple[int, int], 412 | batch_size: int = 1, 413 | generator: Generator = None, 414 | *args, 415 | **kwargs, 416 | ) -> Tensor: 417 | """ 418 | Generate perlin noise with given shape. `*args` and `**kwargs` are forwarded to `Tensor` creation. 419 | 420 | Arguments: 421 | grid_shape -- Shape of grid (height, width). 422 | out_shape -- Shape of output noise image (height, width). 
423 | 424 | Keyword Arguments: 425 | batch_size -- (default: {1}) 426 | generator -- random generator used for grid vectors (default: {None}) 427 | 428 | Raises: 429 | Exception: if grid and out shapes do not match 430 | 431 | Returns: 432 | Noise image shaped (batch_size, height, width) 433 | """ 434 | # grid height and width 435 | gh, gw = grid_shape 436 | # output height and width 437 | oh, ow = out_shape 438 | # block height and width 439 | bh, bw = oh // gh, ow // gw 440 | 441 | if oh != bh * gh: 442 | raise Exception(f"Output height {oh} must be divisible by grid height {gh}") 443 | if ow != bw * gw: 444 | raise Exception(f"Output width {ow} must be divisible by grid width {gw}") 445 | 446 | angle = torch.empty( 447 | [batch_size] + [s + 1 for s in grid_shape], *args, **kwargs 448 | ).uniform_(to=2.0 * pi, generator=generator) 449 | # random vectors on grid points 450 | vectors = unfold_grid(torch.stack((torch.cos(angle), torch.sin(angle)), dim=1)) 451 | # positions inside grid cells [0, 1) 452 | positions = get_positions((bh, bw)).to(vectors) 453 | return perlin_noise_tensor(vectors, positions).squeeze(0) 454 | 455 | def generate_1f_noise(tensor, alpha, k, generator=None): 456 | """Generate 1/f noise for a given tensor. 457 | 458 | Args: 459 | tensor: The tensor to add noise to. 460 | alpha: The parameter that determines the slope of the spectrum. 461 | k: A constant. 462 | 463 | Returns: 464 | A tensor with the same shape as `tensor` containing 1/f noise. 465 | """ 466 | fft = torch.fft.fft2(tensor) 467 | freq = torch.arange(1, len(fft) + 1, dtype=torch.float) 468 | spectral_density = k / freq**alpha 469 | noise = torch.randn(tensor.shape, generator=generator) * spectral_density 470 | return noise 471 | 472 | def green_noise(width, height, generator=None): 473 | noise = torch.randn(width, height, generator=generator) 474 | scale = 1.0 / (width * height) 475 | fy = torch.fft.fftfreq(width)[:, None] ** 2 476 | fx = torch.fft.fftfreq(height) ** 2 477 | f = fy + fx 478 | power = torch.sqrt(f) 479 | power[0, 0] = 1 480 | noise = torch.fft.ifft2(torch.fft.fft2(noise) / torch.sqrt(power)) 481 | noise *= scale / noise.std() 482 | return torch.real(noise) 483 | 484 | # Algorithm from https://github.com/v0xie/sd-webui-cads/ 485 | def add_cads_noise(y, timestep, cads_schedule_start, cads_schedule_end, cads_noise_scale, cads_rescale_factor, cads_rescale=False): 486 | timestep_as_float = (timestep / 999.0)[:, None, None, None].clone()[0].item() 487 | gamma = 0.0 488 | if timestep_as_float < cads_schedule_start: 489 | gamma = 1.0 490 | elif timestep_as_float > cads_schedule_end: 491 | gamma = 0.0 492 | else: 493 | gamma = (cads_schedule_end - timestep_as_float) / (cads_schedule_end - cads_schedule_start) 494 | 495 | y_mean, y_std = torch.mean(y), torch.std(y) 496 | y = np.sqrt(gamma) * y + cads_noise_scale * np.sqrt(1 - gamma) * torch.randn_like(y) 497 | 498 | if cads_rescale: 499 | y_scaled = (y - torch.mean(y)) / torch.std(y) * y_std + y_mean 500 | if not torch.isnan(y_scaled).any(): 501 | y = cads_rescale_factor * y_scaled + (1 - cads_rescale_factor) * y 502 | else: 503 | print("Encountered NaN in cads rescaling.
Skipping rescaling.") 504 | return y 505 | 506 | # Algorithm from https://github.com/v0xie/sd-webui-cads/ 507 | def add_cads_custom_noise(y, noise, timestep, cads_schedule_start, cads_schedule_end, cads_noise_scale, cads_rescale_factor, cads_rescale=False): 508 | timestep_as_float = (timestep / 999.0)[:, None, None, None].clone()[0].item() 509 | gamma = 0.0 510 | if timestep_as_float < cads_schedule_start: 511 | gamma = 1.0 512 | elif timestep_as_float > cads_schedule_end: 513 | gamma = 0.0 514 | else: 515 | gamma = (cads_schedule_end - timestep_as_float) / (cads_schedule_end - cads_schedule_start) 516 | 517 | y_mean, y_std = torch.mean(y), torch.std(y) 518 | y = np.sqrt(gamma) * y + cads_noise_scale * np.sqrt(1 - gamma) * noise#.sub_(noise.mean()).div_(noise.std()) 519 | 520 | if cads_rescale: 521 | y_scaled = (y - torch.mean(y)) / torch.std(y) * y_std + y_mean 522 | if not torch.isnan(y_scaled).any(): 523 | y = cads_rescale_factor * y_scaled + (1 - cads_rescale_factor) * y 524 | else: 525 | print("Encountered NaN in cads rescaling. Skipping rescaling.") 526 | return y 527 | 528 | # Tonemapping functions 529 | 530 | def train_difference(a: Tensor, b: Tensor, c: Tensor) -> Tensor: 531 | diff_AB = a.float() - b.float() 532 | distance_A0 = torch.abs(b.float() - c.float()) 533 | distance_A1 = torch.abs(b.float() - a.float()) 534 | 535 | sum_distances = distance_A0 + distance_A1 536 | 537 | scale = torch.where( 538 | sum_distances != 0, distance_A1 / sum_distances, torch.tensor(0.0).float() 539 | ) 540 | sign_scale = torch.sign(b.float() - c.float()) 541 | scale = sign_scale * torch.abs(scale) 542 | new_diff = scale * torch.abs(diff_AB) 543 | return new_diff 544 | 545 | def gated_thresholding(percentile: float, floor: float, t: Tensor) -> Tensor: 546 | """ 547 | Args: 548 | percentile: float between 0.0 and 1.0. for example 0.995 would subject only the top 0.5%ile to clamping. 
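        floor: lower bound applied to the computed quantile before thresholding (see `q.clamp_(min=floor)` below)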
549 | t: [b, c, v] tensor in pixel or latent space (where v is the result of flattening w and h) 550 | """ 551 | a = t.abs() # Magnitudes 552 | q = torch.quantile(a, percentile, dim=2) # Get clamp value via top % of magnitudes 553 | q.clamp_(min=floor) 554 | q = q.unsqueeze(2).expand(*t.shape) 555 | t = t.clamp(-q, q) # Clamp latent with magnitude value 556 | t = t / q 557 | return t 558 | 559 | def dyn_thresh_gate(latent: Tensor, centered_magnitudes: Tensor, tonemap_percentile: float, floor: float, ceil: float): 560 | if centered_magnitudes.lt(torch.tensor(ceil, device=centered_magnitudes.device)).all().item(): # If the magnitudes are less than the ceiling 561 | return latent # Return the unmodified centered latent 562 | else: 563 | latent = gated_thresholding(tonemap_percentile, floor, latent) # If the magnitudes are higher than the ceiling 564 | return latent # Gated-dynamic thresholding by Birchlabs 565 | 566 | def spatial_norm_thresholding(x0, value): 567 | # b c h w 568 | pow_x0 = torch.pow(torch.abs(x0), 2) 569 | s = pow_x0.mean(1, keepdim=True).sqrt().clamp(min=value) 570 | return x0 * (value / s) 571 | 572 | def spatial_norm_chw_thresholding(x0, value): 573 | # b c h w 574 | pow_x0 = torch.pow(torch.abs(x0), 2) 575 | s = pow_x0.mean(dim=(1, 2, 3), keepdim=True).sqrt().clamp(min=value) 576 | return x0 * (value / s) 577 | 578 | # Contrast function 579 | 580 | def contrast(x: Tensor): 581 | # Calculate the mean and standard deviation of the pixel values 582 | #mean = x.mean(dim=(1,2,3), keepdim=True) 583 | stddev = x.std(dim=(1,2,3), keepdim=True) 584 | # Scale the pixel values by the standard deviation 585 | scaled_pixels = (x) / stddev 586 | return scaled_pixels 587 | 588 | def contrast_with_mean(x: Tensor): 589 | # Calculate the mean and standard deviation of the pixel values 590 | #mean = x.mean(dim=(2,3), keepdim=True) 591 | stddev = x.std(dim=(1,2,3), keepdim=True) 592 | diff_mean = ((x / stddev) - x).mean(dim=(1,2,3), keepdim=True) 593 | # Scale the pixel values by the standard deviation 594 | scaled_pixels = x / stddev 595 | return scaled_pixels - diff_mean 596 | 597 | def center_latent(tensor): #https://birchlabs.co.uk/machine-learning#combating-mean-drift-in-cfg 598 | """Centers on 0 to combat CFG drift.""" 599 | tensor = tensor - tensor.mean(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1).expand(tensor.shape) 600 | return tensor 601 | 602 | def center_0channel(tensor): #https://birchlabs.co.uk/machine-learning#combating-mean-drift-in-cfg 603 | """Centers on 0 to combat CFG drift.""" 604 | std_dev_0 = tensor[:, [0]].std() 605 | mean_0 = tensor[:, [0]].mean() 606 | mean_12 = tensor[:, [1,2]].mean() 607 | mean_3 = tensor[:, [3]].mean() 608 | 609 | #tensor[:, [0]] /= std_dev_0 610 | tensor[:, [0]] -= mean_0 611 | tensor[:, [0]] += torch.copysign(torch.pow(torch.abs(mean_0), 1.5), mean_0) 612 | #tensor[:, [1, 2]] -= tensor[:, [1, 2]].mean() 613 | tensor[:, [1, 2]] -= mean_12 * 0.5 614 | tensor[:, [3]] -= mean_3 615 | tensor[:, [3]] += torch.copysign(torch.pow(torch.abs(mean_3), 1.5), mean_3) 616 | return tensor# - tensor.mean(dim=(2,3), keepdim=True) 617 | 618 | def channel_sharpen(tensor): 619 | """Centers on 0 to combat CFG drift.""" 620 | flattened = tensor.flatten(2) 621 | flat_std = flattened.std(dim=(2)).unsqueeze(2).expand(flattened.shape) 622 | flattened *= flat_std 623 | flattened -= flattened.mean(dim=(2)).unsqueeze(2).expand(flattened.shape) 624 | flattened /= flat_std 625 | tensor = flattened.unflatten(2, tensor.shape[2:]) 626 | return tensor 627 | 628 | 629 | def 
center_012channel(tensor): #https://birchlabs.co.uk/machine-learning#combating-mean-drift-in-cfg 630 | """Centers on 0 to combat CFG drift.""" 631 | curr_tens = tensor[:, [0,1,2]] 632 | tensor[:, [0,1,2]] -= curr_tens.mean() 633 | return tensor 634 | 635 | def center_latent_perchannel(tensor): # Does nothing different than above 636 | """Centers on 0 to combat CFG drift.""" 637 | flattened = tensor.flatten(2) 638 | flattened = flattened - flattened.mean(dim=(2)).unsqueeze(2).expand(flattened.shape) 639 | tensor = flattened.unflatten(2, tensor.shape[2:]) 640 | return tensor 641 | 642 | def center_latent_perchannel_with_magnitudes(tensor): # Does nothing different than above 643 | """Centers on 0 to combat CFG drift.""" 644 | flattened = tensor.flatten(2) 645 | flattened_magnitude = (torch.linalg.vector_norm(flattened, dim=(2), keepdim=True) + 0.0000000001) 646 | flattened /= flattened_magnitude 647 | flattened = flattened - flattened.mean(dim=(2)).unsqueeze(2).expand(flattened.shape) 648 | flattened *= flattened_magnitude 649 | tensor = flattened.unflatten(2, tensor.shape[2:]) 650 | return tensor 651 | 652 | def center_latent_perchannel_with_decorrelate(tensor): # Decorrelates data, slight change, test and play with it. 653 | """Centers on 0 to combat CFG drift, preprocesses the latent with decorrelation""" 654 | tensor = decorrelate_data(tensor) 655 | flattened = tensor.flatten(2) 656 | flattened_magnitude = (torch.linalg.vector_norm(flattened, dim=(2), keepdim=True) + 0.0000000001) 657 | flattened /= flattened_magnitude 658 | flattened = flattened - flattened.mean(dim=(2)).unsqueeze(2).expand(flattened.shape) 659 | flattened *= flattened_magnitude 660 | tensor = flattened.unflatten(2, tensor.shape[2:]) 661 | return tensor 662 | 663 | def center_latent_median(tensor): 664 | flattened = tensor.flatten(2) 665 | median = flattened.median() 666 | scaled_data = (flattened - median) 667 | scaled_data = scaled_data.unflatten(2, tensor.shape[2:]) 668 | return scaled_data 669 | 670 | def divisive_normalization(image_tensor, neighborhood_size, threshold=1e-6): 671 | # Compute the local mean and local variance 672 | local_mean = F.avg_pool2d(image_tensor, neighborhood_size, stride=1, padding=neighborhood_size // 2, count_include_pad=False) 673 | local_mean_squared = local_mean**2 674 | 675 | local_variance = F.avg_pool2d(image_tensor**2, neighborhood_size, stride=1, padding=neighborhood_size // 2, count_include_pad=False) - local_mean_squared 676 | 677 | # Add a small value to prevent division by zero 678 | local_variance = local_variance + threshold 679 | 680 | # Apply divisive normalization 681 | normalized_tensor = image_tensor / torch.sqrt(local_variance) 682 | 683 | return normalized_tensor 684 | 685 | def decorrelate_data(data): 686 | """flattened = tensor.flatten(2).squeeze(0) # this code aint shit, yo 687 | cov_matrix = torch.cov(flattened) 688 | sqrt_inv_cov_matrix = torch.linalg.inv(torch.sqrt(cov_matrix)) 689 | decorrelated_tensor = torch.dot(flattened, sqrt_inv_cov_matrix.T) 690 | decorrelated_tensor = decorrelated_tensor.unflatten(2, tensor.shape[2:]).unsqueeze(0)""" 691 | 692 | # Reshape the 4D tensor to a 2D tensor for covariance calculation 693 | num_samples, num_channels, height, width = data.size() 694 | data_reshaped = data.view(num_samples, num_channels, -1) 695 | data_reshaped = data_reshaped - torch.mean(data_reshaped, dim=2, keepdim=True) 696 | 697 | # Compute covariance matrix 698 | cov_matrix = torch.matmul(data_reshaped, data_reshaped.transpose(1, 2)) / (height * width - 1) 
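    # Dividing by (height * width - 1) yields the unbiased sample covariance between channels.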
699 | 700 | # Compute the inverse square root of the covariance matrix 701 | u, s, v = torch.svd(cov_matrix) 702 | sqrt_inv_cov_matrix = torch.matmul(u, torch.matmul(torch.diag_embed(1.0 / torch.sqrt(s)), v.transpose(1, 2))) 703 | 704 | # Reshape sqrt_inv_cov_matrix to match the dimensions of data_reshaped 705 | sqrt_inv_cov_matrix = sqrt_inv_cov_matrix.unsqueeze(0).expand(num_samples, -1, -1, -1) 706 | 707 | # Decorrelate the data 708 | decorrelated_data = torch.matmul(data_reshaped.transpose(1, 2), sqrt_inv_cov_matrix.transpose(2, 3)) 709 | decorrelated_data = decorrelated_data.transpose(2, 3) 710 | 711 | # Reshape back to the original shape 712 | decorrelated_data = decorrelated_data.view(num_samples, num_channels, height, width) 713 | 714 | return decorrelated_data.to(data.device) 715 | 716 | def get_low_frequency_noise(image: Tensor, threshold: float): 717 | # Convert image to Fourier domain 718 | fourier = torch.fft.fft2(image, dim=(-2, -1)) # Apply FFT along Height and Width dimensions 719 | 720 | # Compute the power spectrum 721 | power_spectrum = torch.abs(fourier) ** 2 722 | 723 | threshold = threshold ** 2 724 | 725 | # Drop low-frequency components 726 | mask = (power_spectrum < threshold).float() 727 | filtered_fourier = fourier * mask 728 | 729 | # Inverse transform back to spatial domain 730 | inverse_transformed = torch.fft.ifft2(filtered_fourier, dim=(-2, -1)) # Apply IFFT along Height and Width dimensions 731 | 732 | return inverse_transformed.real.to(image.device) 733 | 734 | def spectral_modulation(image: Tensor, modulation_multiplier: float, spectral_mod_percentile: float): # Reference implementation by Clybius, 2023 :tm::c::r: (jk idc who uses it :3) 735 | # Convert image to Fourier domain 736 | fourier = torch.fft.fft2(image, dim=(-2, -1)) # Apply FFT along Height and Width dimensions 737 | 738 | log_amp = torch.log(torch.sqrt(fourier.real ** 2 + fourier.imag ** 2)) 739 | 740 | quantile_low = torch.quantile( 741 | log_amp.abs().flatten(2), 742 | spectral_mod_percentile * 0.01, 743 | dim = 2 744 | ).unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape) 745 | 746 | quantile_high = torch.quantile( 747 | log_amp.abs().flatten(2), 748 | 1 - (spectral_mod_percentile * 0.01), 749 | dim = 2 750 | ).unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape) 751 | 752 | # Increase low-frequency components 753 | mask_low = ((log_amp < quantile_low).float() + 1).clamp_(max=1.5) # If lower than low 5% quantile, set to 1.5, otherwise 1 754 | # Decrease high-frequency components 755 | mask_high = ((log_amp < quantile_high).float()).clamp_(min=0.5) # If lower than high 5% quantile, set to 1, otherwise 0.5 756 | filtered_fourier = fourier * ((mask_low * mask_high) ** modulation_multiplier) # Effectively 757 | 758 | # Inverse transform back to spatial domain 759 | inverse_transformed = torch.fft.ifft2(filtered_fourier, dim=(-2, -1)) # Apply IFFT along Height and Width dimensions 760 | 761 | return inverse_transformed.real.to(image.device) 762 | 763 | def spectral_modulation_soft(image: Tensor, modulation_multiplier: float, spectral_mod_percentile: float): # Modified for soft quantile adjustment using a novel:tm::c::r: method titled linalg. 
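    # Soft variant: rather than fixed 1.5x/0.5x steps, each frequency bin is scaled smoothly by its distance from the low/high quantiles (additive_mult_low/high below).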
764 | # Convert image to Fourier domain 765 | fourier = torch.fft.fft2(image, dim=(-2, -1)) # Apply FFT along Height and Width dimensions 766 | 767 | log_amp = torch.log(torch.sqrt(fourier.real ** 2 + fourier.imag ** 2)) 768 | 769 | quantile_low = torch.quantile( 770 | log_amp.abs().flatten(2), 771 | spectral_mod_percentile * 0.01, 772 | dim = 2 773 | ).unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape) 774 | 775 | quantile_high = torch.quantile( 776 | log_amp.abs().flatten(2), 777 | 1 - (spectral_mod_percentile * 0.01), 778 | dim = 2 779 | ).unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape) 780 | 781 | quantile_max = torch.quantile( 782 | log_amp.abs().flatten(2), 783 | 1, 784 | dim = 2 785 | ).unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape) 786 | 787 | # Decrease high-frequency components 788 | mask_high = log_amp > quantile_high # If we're larger than 95th percentile 789 | 790 | additive_mult_high = torch.where( 791 | mask_high, 792 | 1 - ((log_amp - quantile_high) / (quantile_max - quantile_high)).clamp_(max=0.5), # (1) - (0-1), where 0 is 95th %ile and 1 is 100%ile 793 | torch.tensor(1.0) 794 | ) 795 | 796 | 797 | # Increase low-frequency components 798 | mask_low = log_amp < quantile_low 799 | additive_mult_low = torch.where( 800 | mask_low, 801 | 1 + (1 - (log_amp / quantile_low)).clamp_(max=0.5), # (1) + (0-1), where 0 is 5th %ile and 1 is 0%ile 802 | torch.tensor(1.0) 803 | ) 804 | 805 | mask_mult = ((additive_mult_low * additive_mult_high) ** modulation_multiplier).clamp_(min=0.05, max=20) 806 | #print(mask_mult) 807 | filtered_fourier = fourier * mask_mult 808 | 809 | # Inverse transform back to spatial domain 810 | inverse_transformed = torch.fft.ifft2(filtered_fourier, dim=(-2, -1)) # Apply IFFT along Height and Width dimensions 811 | 812 | return inverse_transformed.real.to(image.device) 813 | 814 | def pyramid_noise_like(x, discount=0.9, generator=None, rand_source=random): 815 | b, c, w, h = x.shape # EDIT: w and h get over-written, rename for a different variant! 
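    # Pyramid noise: accumulate noise at progressively smaller resolutions, upsampled back to full size, with each octave damped by discount**i.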
816 | u = torch.nn.Upsample(size=(w, h), mode='nearest-exact') 817 | noise = gen_like(torch.randn, x, generator=generator) 818 | for i in range(10): 819 | r = rand_source.random()*2+2 # Rather than always going 2x, 820 | w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i))) 821 | noise += u(torch.randn(b, c, w, h, generator=generator).to(x)) * discount**i 822 | if w==1 or h==1: break # Lowest resolution is 1x1 823 | return noise/noise.std() # Scaled back to roughly unit variance 824 | 825 | import math 826 | def dyn_cfg_modifier(conditioning, unconditioning, method, cond_scale, time_mult): 827 | match method: 828 | case "dyncfg-halfcosine": 829 | noise_pred = conditioning - unconditioning 830 | 831 | noise_pred_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:,None] 832 | 833 | time = time_mult.item() 834 | time_factor = -(math.cos(0.5 * time * math.pi) / 2) + 1 835 | noise_pred_timescaled_magnitude = (torch.linalg.vector_norm(noise_pred * time_factor, dim=(1)) + 0.0000000001)[:,None] 836 | 837 | noise_pred /= noise_pred_magnitude 838 | noise_pred *= noise_pred_timescaled_magnitude 839 | return noise_pred 840 | case "dyncfg-halfcosine-mimic": 841 | noise_pred = conditioning - unconditioning 842 | 843 | noise_pred_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:,None] 844 | 845 | time = time_mult.item() 846 | time_factor = -(math.cos(0.5 * time * math.pi) / 2) + 1 847 | 848 | latent = noise_pred 849 | 850 | mimic_latent = noise_pred * time_factor 851 | mimic_flattened = mimic_latent.flatten(2) 852 | mimic_means = mimic_flattened.mean(dim=2).unsqueeze(2) 853 | mimic_recentered = mimic_flattened - mimic_means 854 | mimic_abs = mimic_recentered.abs() 855 | mimic_max = mimic_abs.max(dim=2).values.unsqueeze(2) 856 | 857 | latent_flattened = latent.flatten(2) 858 | latent_means = latent_flattened.mean(dim=2).unsqueeze(2) 859 | latent_recentered = latent_flattened - latent_means 860 | latent_abs = latent_recentered.abs() 861 | latent_q = torch.quantile(latent_abs, 0.995, dim=2).unsqueeze(2) 862 | s = torch.maximum(latent_q, mimic_max) 863 | pred_clamped = noise_pred.flatten(2).clamp(-s, s) 864 | pred_normalized = pred_clamped / s 865 | pred_renorm = pred_normalized * mimic_max 866 | pred_uncentered = pred_renorm + latent_means 867 | noise_pred_degraded = pred_uncentered.unflatten(2, noise_pred.shape[2:]) 868 | 869 | noise_pred /= noise_pred_magnitude 870 | 871 | noise_pred_timescaled_magnitude = (torch.linalg.vector_norm(noise_pred_degraded, dim=(1)) + 0.0000000001)[:,None] 872 | noise_pred *= noise_pred_timescaled_magnitude 873 | return noise_pred 874 | 875 | 876 | class ModelSamplerLatentMegaModifier: 877 | @classmethod 878 | def INPUT_TYPES(s): 879 | return {"required": { "model": ("MODEL",), 880 | "sharpness_multiplier": ("FLOAT", {"default": 0.0, "min": -100.0, "max": 100.0, "step": 0.1}), 881 | "sharpness_method": (["anisotropic", "joint-anisotropic", "gaussian", "cas"], ), 882 | "tonemap_multiplier": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.01}), 883 | "tonemap_method": (["reinhard", "reinhard_perchannel", "arctan", "quantile", "gated", "cfg-mimic", "spatial-norm"], ), 884 | "tonemap_percentile": ("FLOAT", {"default": 100.0, "min": 0.0, "max": 100.0, "step": 0.005}), 885 | "contrast_multiplier": ("FLOAT", {"default": 0.0, "min": -100.0, "max": 100.0, "step": 0.1}), 886 | "combat_method": (["subtract", "subtract_channels", "subtract_median", "sharpen"], ), 887 | "combat_cfg_drift": ("FLOAT", {"default": 0.0, "min": 
-10.0, "max": 10.0, "step": 0.01}), 888 | "rescale_cfg_phi": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}), 889 | "extra_noise_type": (["gaussian", "uniform", "perlin", "pink", "green", "pyramid"], ), 890 | "extra_noise_method": (["add", "add_scaled", "speckle", "cads", "cads_rescaled", "cads_speckle", "cads_speckle_rescaled"], ), 891 | "extra_noise_multiplier": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}), 892 | "extra_noise_lowpass": ("INT", {"default": 100, "min": 0, "max": 1000, "step": 1}), 893 | "divisive_norm_size": ("INT", {"default": 127, "min": 1, "max": 255, "step": 1}), 894 | "divisive_norm_multiplier": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), 895 | "spectral_mod_mode": (["hard_clamp", "soft_clamp"], ), 896 | "spectral_mod_percentile": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.01}), 897 | "spectral_mod_multiplier": ("FLOAT", {"default": 0.0, "min": -15.0, "max": 15.0, "step": 0.01}), 898 | "affect_uncond": (["None", "Sharpness"], ), 899 | "dyn_cfg_augmentation": (["None", "dyncfg-halfcosine", "dyncfg-halfcosine-mimic"], ), 900 | }, 901 | "optional": { "seed": ("INT", {"min": 0, "max": 0xffffffffffffffff}) 902 | }} 903 | RETURN_TYPES = ("MODEL",) 904 | FUNCTION = "mega_modify" 905 | 906 | CATEGORY = "clybNodes" 907 | 908 | def mega_modify(self, model, sharpness_multiplier, sharpness_method, tonemap_multiplier, tonemap_method, tonemap_percentile, contrast_multiplier, combat_method, combat_cfg_drift, rescale_cfg_phi, extra_noise_type, extra_noise_method, extra_noise_multiplier, extra_noise_lowpass, divisive_norm_size, divisive_norm_multiplier, spectral_mod_mode, spectral_mod_percentile, spectral_mod_multiplier, affect_uncond, dyn_cfg_augmentation, seed=None): 909 | gen = None 910 | rand = random 911 | if seed is not None: 912 | gen = torch.Generator(device='cpu') 913 | rand = random.Random() 914 | gen.manual_seed(seed) 915 | rand.seed(seed) 916 | 917 | def modify_latent(args): 918 | x_input = args["input"] 919 | cond = args["cond"] 920 | uncond = args["uncond"] 921 | cond_scale = args["cond_scale"] 922 | timestep = model.model.model_sampling.timestep(args["timestep"]) 923 | sigma = args["sigma"] 924 | sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) 925 | #print(model.model.model_sampling.timestep(timestep)) 926 | 927 | x = x_input / (sigma * sigma + 1.0) 928 | cond = ((x - (x_input - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma) 929 | uncond = ((x - (x_input - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma) 930 | 931 | noise_pred = (cond - uncond) 932 | 933 | # Extra noise 934 | if extra_noise_multiplier > 0: 935 | match extra_noise_type: 936 | case "gaussian": 937 | extra_noise = gen_like(torch.randn, cond, generator=gen) 938 | case "uniform": 939 | extra_noise = (gen_like(torch.rand, cond, generator=gen) - 0.5) * 2 * 1.73 940 | case "perlin": 941 | cond_size_0 = cond.size(dim=2) 942 | cond_size_1 = cond.size(dim=3) 943 | extra_noise = perlin_noise(grid_shape=(cond_size_0, cond_size_1), out_shape=(cond_size_0, cond_size_1), batch_size=4, generator=gen).to(cond.device).unsqueeze(0) 944 | mean = torch.mean(extra_noise) 945 | std = torch.std(extra_noise) 946 | 947 | extra_noise.sub_(mean).div_(std) 948 | case "pink": 949 | extra_noise = generate_1f_noise(cond, 2, extra_noise_multiplier, generator=gen).to(cond.device) 950 | mean = torch.mean(extra_noise) 951 | std = torch.std(extra_noise) 952 | 953 | extra_noise.sub_(mean).div_(std) 954 | case "green": 955 | cond_size_0 = 
cond.size(dim=2) 956 | cond_size_1 = cond.size(dim=3) 957 | extra_noise = green_noise(cond_size_0, cond_size_1, generator=gen).to(cond.device) 958 | mean = torch.mean(extra_noise) 959 | std = torch.std(extra_noise) 960 | 961 | extra_noise.sub_(mean).div_(std) 962 | case "pyramid": 963 | extra_noise = pyramid_noise_like(cond) 964 | 965 | if extra_noise_lowpass > 0: 966 | extra_noise = get_low_frequency_noise(extra_noise, extra_noise_lowpass) 967 | 968 | alpha_noise = 1.0 - (timestep / 999.0)[:, None, None, None].clone() # Get alpha multiplier, lower alpha at high sigmas/high noise 969 | alpha_noise *= 0.001 * extra_noise_multiplier # User-input and weaken the strength so we don't annihilate the latent. 970 | match extra_noise_method: 971 | case "add": 972 | cond = cond + extra_noise * alpha_noise 973 | uncond = uncond - extra_noise * alpha_noise 974 | case "add_scaled": 975 | cond = cond + train_difference(cond, extra_noise, cond) * alpha_noise 976 | uncond = uncond - train_difference(uncond, extra_noise, uncond) * alpha_noise 977 | case "speckle": 978 | cond = cond + cond * extra_noise * alpha_noise 979 | uncond = uncond - uncond * extra_noise * alpha_noise 980 | case "cads": 981 | cond = add_cads_custom_noise(cond, extra_noise, timestep, 0.6, 0.9, extra_noise_multiplier / 100., 1, False) 982 | uncond = add_cads_custom_noise(uncond, extra_noise, timestep, 0.6, 0.9, extra_noise_multiplier / 100., 1, False) 983 | case "cads_rescaled": 984 | cond = add_cads_custom_noise(cond, extra_noise, timestep, 0.6, 0.9, extra_noise_multiplier / 100., 1, True) 985 | uncond = add_cads_custom_noise(uncond, extra_noise, timestep, 0.6, 0.9, extra_noise_multiplier / 100., 1, True) 986 | case "cads_speckle": 987 | cond = add_cads_custom_noise(cond, extra_noise * cond, timestep, 0.6, 0.9, extra_noise_multiplier / 100., 1, False) 988 | uncond = add_cads_custom_noise(uncond, extra_noise * uncond, timestep, 0.6, 0.9, extra_noise_multiplier / 100., 1, False) 989 | case "cads_speckle_rescaled": 990 | cond = add_cads_custom_noise(cond, extra_noise * cond, timestep, 0.6, 0.9, extra_noise_multiplier / 100., 1, True) 991 | uncond = add_cads_custom_noise(uncond, extra_noise * uncond, timestep, 0.6, 0.9, extra_noise_multiplier / 100., 1, True) 992 | case _: 993 | print("Haven't heard of a noise method named like that before... (Couldn't find method)") 994 | 995 | if sharpness_multiplier > 0.0 or sharpness_multiplier < 0.0: 996 | match sharpness_method: 997 | case "anisotropic": 998 | degrade_func = bilateral_blur 999 | case "joint-anisotropic": 1000 | degrade_func = lambda img: joint_bilateral_blur(img, (img - torch.mean(img, dim=(1, 2, 3), keepdim=True)) / torch.std(img, dim=(1, 2, 3), keepdim=True), 13, 3.0, 3.0, "reflect", "l1") 1001 | case "gaussian": 1002 | degrade_func = gaussian_filter_2d 1003 | case "cas": 1004 | degrade_func = lambda image: contrast_adaptive_sharpening(image, amount=sigma.clamp(max=1.00).item()) 1005 | case _: 1006 | print("For some reason, the sharpness filter could not be found.") 1007 | # Sharpness 1008 | alpha = 1.0 - (timestep / 999.0)[:, None, None, None].clone() # Get alpha multiplier, lower alpha at high sigmas/high noise 1009 | alpha *= 0.001 * sharpness_multiplier # User-input and weaken the strength so we don't annihilate the latent. 
1010 | cond = degrade_func(cond) * alpha + cond * (1.0 - alpha) # Mix the modified latent with the existing latent by the alpha 1011 | if affect_uncond == "Sharpness": 1012 | uncond = degrade_func(uncond) * alpha + uncond * (1.0 - alpha) 1013 | 1014 | time_mult = 1.0 - (timestep / 999.0)[:, None, None, None].clone() 1015 | noise_pred_degraded = (cond - uncond) if dyn_cfg_augmentation == "None" else dyn_cfg_modifier(cond, uncond, dyn_cfg_augmentation, cond_scale, time_mult) # New noise pred 1016 | 1017 | # After this point, we use `noise_pred_degraded` instead of just `cond` for the final set of calculations 1018 | 1019 | # Tonemap noise 1020 | if tonemap_multiplier == 0: 1021 | new_magnitude = 1.0 1022 | else: 1023 | match tonemap_method: 1024 | case "reinhard": 1025 | noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred_degraded, dim=(1)) + 0.0000000001)[:,None] 1026 | noise_pred_degraded /= noise_pred_vector_magnitude 1027 | 1028 | mean = torch.mean(noise_pred_vector_magnitude, dim=(1,2,3), keepdim=True) 1029 | std = torch.std(noise_pred_vector_magnitude, dim=(1,2,3), keepdim=True) 1030 | 1031 | top = (std * 3 * (100 / tonemap_percentile) + mean) * tonemap_multiplier 1032 | 1033 | noise_pred_vector_magnitude *= (1.0 / top) 1034 | new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0) 1035 | new_magnitude *= top 1036 | 1037 | noise_pred_degraded *= new_magnitude 1038 | case "reinhard_perchannel": # Testing the flatten strategy 1039 | flattened = noise_pred_degraded.flatten(2) 1040 | noise_pred_vector_magnitude = (torch.linalg.vector_norm(flattened, dim=(2), keepdim=True) + 0.0000000001) 1041 | flattened /= noise_pred_vector_magnitude 1042 | 1043 | mean = torch.mean(noise_pred_vector_magnitude, dim=(2), keepdim=True) 1044 | 1045 | top = (3 * (100 / tonemap_percentile) + mean) * tonemap_multiplier 1046 | 1047 | noise_pred_vector_magnitude *= (1.0 / top) 1048 | 1049 | new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0) 1050 | new_magnitude *= top 1051 | 1052 | flattened *= new_magnitude 1053 | noise_pred_degraded = flattened.unflatten(2, noise_pred_degraded.shape[2:]) 1054 | case "arctan": 1055 | noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred_degraded, dim=(1)) + 0.0000000001)[:,None] 1056 | noise_pred_degraded /= noise_pred_vector_magnitude 1057 | 1058 | noise_pred_degraded = (torch.arctan(noise_pred_degraded * tonemap_multiplier) * (1 / tonemap_multiplier)) + (noise_pred_degraded * (100 - tonemap_percentile) / 100) 1059 | 1060 | noise_pred_degraded *= noise_pred_vector_magnitude 1061 | case "quantile": 1062 | s: FloatTensor = torch.quantile( 1063 | (uncond + noise_pred_degraded * cond_scale).flatten(start_dim=1).abs(), 1064 | tonemap_percentile / 100, 1065 | dim = -1 1066 | ) * tonemap_multiplier 1067 | s.clamp_(min = 1.) 
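                        # Keeping s >= 1 prevents the divide below from amplifying predictions that are already in range.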
1068 | s = s.reshape(*s.shape, 1, 1, 1) 1069 | noise_pred_degraded = noise_pred_degraded.clamp(-s, s) / s 1070 | case "gated": # https://birchlabs.co.uk/machine-learning#dynamic-thresholding-latents so based,.,.,...., 1071 | latent_scale = model.model.latent_format.scale_factor 1072 | 1073 | latent = uncond + noise_pred_degraded * cond_scale # Get full latent from CFG formula 1074 | latent /= latent_scale # Divide full CFG by latent scale (~0.13 for sdxl) 1075 | flattened = latent.flatten(2) 1076 | means = flattened.mean(dim=2).unsqueeze(2) 1077 | centered_magnitudes = (flattened - means).abs().max() # Get highest magnitude of full CFG 1078 | 1079 | flattened_pred = (noise_pred_degraded / latent_scale).flatten(2) 1080 | 1081 | floor = 3.0560 1082 | ceil = 42. * tonemap_multiplier # as is the answer to life, unless you modify the multiplier cuz u aint a believer in life 1083 | 1084 | 1085 | thresholded_latent = dyn_thresh_gate(flattened_pred, centered_magnitudes, tonemap_percentile / 100., floor, ceil) # Threshold if passes ceil 1086 | thresholded_latent = thresholded_latent.unflatten(2, noise_pred_degraded.shape[2:]) 1087 | noise_pred_degraded = thresholded_latent * latent_scale # Rescale by latent 1088 | case "cfg-mimic": 1089 | latent = noise_pred_degraded 1090 | 1091 | mimic_latent = noise_pred_degraded * tonemap_multiplier 1092 | mimic_flattened = mimic_latent.flatten(2) 1093 | mimic_means = mimic_flattened.mean(dim=2).unsqueeze(2) 1094 | mimic_recentered = mimic_flattened - mimic_means 1095 | mimic_abs = mimic_recentered.abs() 1096 | mimic_max = mimic_abs.max(dim=2).values.unsqueeze(2) 1097 | 1098 | latent_flattened = latent.flatten(2) 1099 | latent_means = latent_flattened.mean(dim=2).unsqueeze(2) 1100 | latent_recentered = latent_flattened - latent_means 1101 | latent_abs = latent_recentered.abs() 1102 | latent_q = torch.quantile(latent_abs, tonemap_percentile / 100., dim=2).unsqueeze(2) 1103 | s = torch.maximum(latent_q, mimic_max) 1104 | pred_clamped = noise_pred_degraded.flatten(2).clamp(-s, s) 1105 | pred_normalized = pred_clamped / s 1106 | pred_renorm = pred_normalized * mimic_max 1107 | pred_uncentered = pred_renorm + mimic_means # Personal choice to re-mean from the mimic here... should be latent_means. 1108 | noise_pred_degraded = pred_uncentered.unflatten(2, noise_pred_degraded.shape[2:]) 1109 | case "spatial-norm": 1110 | #time = (1.0 - (timestep / 999.0)[:, None, None, None].clone().item()) 1111 | #time = -(math.cos(time * math.pi) / (3)) + (2/3) # 0.33333 to 1.0, half cosine 1112 | noise_pred_degraded = spatial_norm_chw_thresholding(noise_pred_degraded, tonemap_multiplier / 2 / cond_scale) 1113 | case _: 1114 | print("Could not tonemap, for the method was not found.") 1115 | 1116 | # Spectral Modification 1117 | if spectral_mod_multiplier > 0 or spectral_mod_multiplier < 0: 1118 | #alpha = 1. - (timestep / 999.0)[:, None, None, None].clone() # Get alpha multiplier, lower alpha at high sigmas/high noise 1119 | #alpha = spectral_mod_multiplier# User-input and weaken the strength so we don't annihilate the latent. 
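                # Apply the selected spectral filter, then fold the difference back into the noise prediction.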
1120 | match spectral_mod_mode: 1121 | case "hard_clamp": 1122 | modulation_func = spectral_modulation 1123 | case "soft_clamp": 1124 | modulation_func = spectral_modulation_soft 1125 | modulation_diff = modulation_func(noise_pred_degraded, spectral_mod_multiplier, spectral_mod_percentile) - noise_pred_degraded 1126 | noise_pred_degraded += modulation_diff 1127 | 1128 | if contrast_multiplier > 0 or contrast_multiplier < 0: 1129 | contrast_func = contrast 1130 | # Contrast, after tonemapping, to ensure user-set contrast is expected to behave similarly across tonemapping settings 1131 | alpha = 1.0 - (timestep / 999.0)[:, None, None, None].clone() 1132 | alpha *= 0.001 * contrast_multiplier 1133 | noise_pred_degraded = contrast_func(noise_pred_degraded) * alpha + (noise_pred_degraded) * (1.0 - alpha) # Temporary fix for contrast is to add the input? Maybe? It just doesn't work like before... 1134 | 1135 | # Rescale CFG 1136 | if rescale_cfg_phi == 0: 1137 | x_final = uncond + noise_pred_degraded * cond_scale 1138 | else: 1139 | x_cfg = uncond + noise_pred_degraded * cond_scale 1140 | ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True) 1141 | ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True) 1142 | 1143 | x_rescaled = x_cfg * (ro_pos / ro_cfg) 1144 | x_final = rescale_cfg_phi * x_rescaled + (1.0 - rescale_cfg_phi) * x_cfg 1145 | 1146 | if combat_cfg_drift > 0 or combat_cfg_drift < 0: 1147 | alpha = (1. - (timestep / 999.0)[:, None, None, None].clone()) 1148 | alpha ** 0.025 # Alpha might as well be 1, but we want to protect the first steps (?). 1149 | alpha = alpha.clamp_(max=1) 1150 | match combat_method: 1151 | case "subtract": 1152 | combat_drift_func = center_latent_perchannel 1153 | alpha *= combat_cfg_drift 1154 | case "subtract_channels": 1155 | combat_drift_func = center_0channel 1156 | alpha *= combat_cfg_drift 1157 | case "subtract_median": 1158 | combat_drift_func = center_latent_median 1159 | alpha *= combat_cfg_drift 1160 | case "sharpen": 1161 | combat_drift_func = channel_sharpen 1162 | alpha *= combat_cfg_drift 1163 | x_final = combat_drift_func(x_final) * alpha + x_final * (1.0 - alpha) # Mix the modified latent with the existing latent by the alpha 1164 | 1165 | if divisive_norm_multiplier > 0: 1166 | alpha = 1. - (timestep / 999.0)[:, None, None, None].clone() 1167 | alpha ** 0.025 # Alpha might as well be 1, but we want to protect the beginning steps (?). 1168 | alpha *= divisive_norm_multiplier 1169 | high_noise = divisive_normalization(x_final, (divisive_norm_size * 2) + 1) 1170 | x_final = high_noise * alpha + x_final * (1.0 - alpha) 1171 | 1172 | 1173 | return x_input - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5) # General formula for CFG. uncond + (cond - uncond) * cond_scale 1174 | 1175 | m = model.clone() 1176 | m.set_model_sampler_cfg_function(modify_latent) 1177 | return (m, ) --------------------------------------------------------------------------------