├── README.md └── scripts └── blockcache.py /README.md:
--------------------------------------------------------------------------------
1 | ## First Block Cache and TeaCache, in Forge webUI ##
2 | ### accelerate inference at some, perhaps minimal, quality cost ###
3 | 
4 | derived, with lots of reworking, from:
5 | * https://github.com/likelovewant/sd-forge-teacache (flux only, teacache only)
6 | 
7 | more info:
8 | * https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4FLUX
9 | * https://github.com/chengzeyi/Comfy-WaveSpeed
10 | 
11 | install:
12 | **Extensions** tab, **Install from URL**, use the URL of this repo
13 | 
14 | >[!NOTE]
15 | >This handles SelfAttentionGuidance and PerturbedAttentionGuidance (and anything else that calculates a cond), and applies the caching to them too, independently.
16 | >
17 | >Previous implementation moved to `old` branch.
18 | >
19 | >(30/05/2025) pre-SD3/Chroma version moved to `less-old` branch
20 | 
21 | usage:
22 | 1. Enable the extension.
23 | 2. Select a caching threshold: higher threshold = more caching = faster, but lower quality.
24 | 3. Low-step models (Hyper) will need a higher threshold to have any effect.
25 | 4. Generate.
26 | 5. You'll need to experiment to find settings that work with your favoured models, step counts, and samplers.
27 | 
28 | >[!NOTE]
29 | >Both methods work with SD1.5, SD2, SDXL (including separated cond processing), and Flux.
30 | >
31 | >(30/05/2025) added versions for SD3(.5) and Chroma. Caching SD3 does not seem to work especially well, tending to reduce detail too much, but it may be more useful at higher step counts.
32 | >
33 | >The use of cached residuals applies to the whole batch, so results will not be identical between different batch sizes. This is absolutely 100% *will not fix*.
34 | 
35 | Now works with batch_size > 1, but results will not be consistent with the same seed at batch_size == 1.
36 | 
37 | Added an option for the maximum number of consecutive cached steps (0: no limit), and made not using the cache on the final step an option (previously the final step was always processed).
38 | 
39 | Some samplers (DPM++ 2M, UniPC, likely others) need a very low threshold and/or a delayed start plus a limit on consecutive cached steps.
40 | ---
41 | ---
42 | original README:
43 | 
44 | ## Sd-Forge-TeaCache: Speed up Your Diffusion Models
45 | 
46 | **Introduction**
47 | 
48 | Timestep Embedding Aware Cache (TeaCache) is a training-free caching approach that leverages the
49 | fluctuating differences between model outputs across timesteps. This acceleration technique significantly boosts
50 | inference speed for various diffusion models, including image, video, and audio.
51 | 
52 | This is TeaCache integrated into SD Forge WebUI, for Flux only. Installation is as
53 | straightforward as for any other extension:
54 | 
55 | * **Clone:** `git clone https://github.com/likelovewant/sd-forge-teacache.git`
56 | 
57 | into the extensions directory, then relaunch the system.
58 | 
59 | 
60 | **Speed Up Your Diffusion Generation**
61 | 
62 | TeaCache can accelerate FLUX inference by up to 2x with minimal visual quality degradation, all without requiring any training.
63 | 
64 | Within the Forge WebUI, you can easily adjust the following settings:
65 | 
66 | * **Relative L1 Threshold:** Controls the sensitivity of TeaCache's caching mechanism.
67 | * **Steps:** Should match the number of sampling steps used for generation.
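For intuition, here is a minimal, hypothetical sketch of the decision rule that both caching methods in this extension share (simplified from `scripts/blockcache.py`; the names here are invented for illustration). A relative L1 distance between the current and previous step's probe tensor (the first block's output for First Block Cache, the model input for TeaCache) is accumulated across steps; while the running total stays below the threshold, the cached residual is reused instead of running the remaining blocks:

```python
import torch

class CacheState:
    previous = None   # probe tensor from the previous step
    residual = None   # cached (output - input) from the last fully computed step
    distance = 0.0    # accumulated relative L1 distance since the last full step

def should_reuse_cache(probe: torch.Tensor, state: CacheState, threshold: float) -> bool:
    # Nothing cached yet: the full model must run this step.
    if state.previous is None or state.residual is None:
        state.previous = probe.clone()
        return False
    # Relative L1 distance between this step's probe and the previous one
    # (small epsilon guards against division by zero).
    rel_l1 = ((probe - state.previous).abs().mean()
              / (1e-6 + state.previous.abs().mean())).item()
    state.distance += rel_l1
    state.previous = probe.clone()
    if state.distance < threshold:
        return True       # close enough: approximate output as input + residual
    state.distance = 0.0  # drifted too far: run all blocks and recache
    return False
```

A higher threshold lets the accumulated distance grow for longer before a full recompute is forced, which is why higher values cache more aggressively.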
68 | 
69 | **Performance Tuning**
70 | 
71 | Based on [TeaCache4FLUX](https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4FLUX), you can achieve different
72 | speedups:
73 | 
74 | * 0.25 threshold for 1.5x speedup
75 | * 0.4 threshold for 1.8x speedup
76 | * 0.6 threshold for 2.0x speedup
77 | * 0.8 threshold for 2.25x speedup
78 | 
79 | **Important Notes:**
80 | 
81 | * **Maintain Consistency:** Keep the TeaCache steps setting aligned with the sampling steps used for your Flux generation. Discrepancies can lead to lower-quality outputs.
82 | * **LoRA Considerations:** When utilizing LoRAs, adjust the steps or scales based on your GPU's capabilities. A recommended starting point is 28 steps or more.
83 | 
84 | To ensure smooth operation, remember to:
85 | 
86 | 1. **Clear Residual Cache (optional):** When changing image sizes or disabling the TeaCache extension, always click "Clear Residual Cache" within the Forge WebUI. This prevents potential conflicts and maintains optimal performance.
87 | 2. **Disable TeaCache Properly:** Disable the TeaCache extension when you don't need it. If the residual cache is not properly cleared, you may encounter unexpected behavior and require a full relaunch.
88 | 
89 | 
90 | Several AI assistants assisted with code generation and refinement for this extension, based on the resources below.
91 | 
92 | **Credits and Resources**
93 | 
94 | This adaptation leverages [TeaCache4FLUX](https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4FLUX)
95 | from the ali-vilab TeaCache repository: [TeaCache](https://github.com/ali-vilab/TeaCache).
96 | 
97 | For additional information and other integrations, explore:
98 | 
99 | * [ComfyUI-TeaCache](https://github.com/welltop-cn/ComfyUI-TeaCache)
100 | 
101 | 
--------------------------------------------------------------------------------
/scripts/blockcache.py:
--------------------------------------------------------------------------------
1 | ## First Block Cache / TeaCache for Forge2 webui
2 | ## with option to skip caching for early steps
3 | ## option to always process the last step
4 | ## option for maximum consecutive steps to apply caching (0: no limit)
5 | ## handles highresfix
6 | ## handles PAG and SAG (with unet models, not Flux) by accelerating them too, independently
7 | ## opposite time/quality trade offs ... 
but some way of handling them is necessary to avoid potential errors
8 | 
9 | ## derived from https://github.com/likelovewant/sd-forge-teacache (flux only, teacache only)
10 | 
11 | # fbc for flux
12 | # fbc and tc for sd1, sdxl
13 | # fbc and tc for sd3
14 | # fbc and tc for chroma - untested
15 | 
16 | # actually, I'm skeptical about these coefficients
17 | 
18 | 
19 | import torch
20 | import numpy as np
21 | from einops import repeat    ## needed by the patched MMDiTX forward (register tokens)
22 | import gradio as gr
23 | from modules import scripts
24 | from modules.ui_components import InputAccordion
25 | from backend.nn.flux import IntegratedFluxTransformer2DModel
26 | from backend.nn.flux import timestep_embedding as timestep_embedding_flux
27 | from backend.nn.unet import IntegratedUNet2DConditionModel, apply_control
28 | from backend.nn.unet import timestep_embedding as timestep_embedding_unet
29 | 
30 | try:
31 |     from backend.nn.mmditx import MMDiTX
32 | except ImportError:
33 |     MMDiTX = None
34 | 
35 | try:
36 |     from backend.nn.chroma import IntegratedChromaTransformer2DModel
37 |     from backend.nn.chroma import timestep_embedding as timestep_embedding_chroma
38 | except ImportError:
39 |     IntegratedChromaTransformer2DModel = None
40 | 
41 | 
42 | class BlockCache(scripts.Script):
43 |     original_inner_forward = None
44 | 
45 |     def __init__(self):
46 |         if BlockCache.original_inner_forward is None:
47 |             if IntegratedChromaTransformer2DModel is not None:
48 |                 BlockCache.chroma_inner_forward = IntegratedChromaTransformer2DModel.inner_forward
49 |             BlockCache.original_inner_forward = IntegratedFluxTransformer2DModel.inner_forward
50 |             BlockCache.original_forward_unet = IntegratedUNet2DConditionModel.forward
51 |             if MMDiTX is not None:
52 |                 BlockCache.original_forward_mmditx = MMDiTX.forward
53 | 
54 |     def title(self):
55 |         return "First Block Cache / TeaCache"
56 | 
57 |     def show(self, is_img2img):
58 |         return scripts.AlwaysVisible
59 | 
60 |     def ui(self, is_img2img):
61 |         with InputAccordion(False, label=self.title()) as enabled:
62 |             method = gr.Radio(label="Method", choices=["First Block Cache", "TeaCache"], type="value", value="First Block Cache")
63 |             with gr.Row():
64 |                 nocache_steps = gr.Number(label="Uncached starting steps", scale=0,
65 |                     minimum=1, maximum=12, value=1, step=1,
66 |                 )
67 |                 threshold = gr.Slider(label="Caching threshold; higher values cache more aggressively.",
68 |                     minimum=0.0, maximum=1.0, value=0.1, step=0.001,
69 |                 )
70 |             with gr.Row():
71 |                 max_cached = gr.Number(label="Max. 
consecutive cached", scale=0, 72 | minimum=0, maximum=99, value=0, step=1, 73 | ) 74 | always_last = gr.Checkbox(label="Do not use cache on last step", value=False) 75 | 76 | enabled.do_not_save_to_config = True 77 | method.do_not_save_to_config = True 78 | nocache_steps.do_not_save_to_config = True 79 | threshold.do_not_save_to_config = True 80 | max_cached.do_not_save_to_config = True 81 | always_last.do_not_save_to_config = True 82 | 83 | self.infotext_fields = [ 84 | (enabled, lambda d: d.get("bc_enabled", False)), 85 | (method, "bc_method"), 86 | (threshold, "bc_threshold"), 87 | (nocache_steps, "bc_nocache_steps"), 88 | (max_cached, "bc_skip_limit"), 89 | (always_last, "bc_always_last"), 90 | ] 91 | 92 | return [enabled, method, threshold, nocache_steps, max_cached, always_last] 93 | 94 | 95 | def process(self, p, *args): 96 | enabled, method, threshold, nocache_steps, max_cached, always_last = args 97 | 98 | if enabled: 99 | if method == "First Block Cache": 100 | if (p.sd_model.is_sd1 == True) or (p.sd_model.is_sd2 == True) or (p.sd_model.is_sdxl == True): 101 | IntegratedUNet2DConditionModel.forward = patched_forward_unet_fbc 102 | elif p.sd_model.is_sd3 == True: 103 | MMDiTX.forward = patched_forward_mmditx_fbc 104 | else: 105 | IntegratedFluxTransformer2DModel.inner_forward = patched_inner_forward_flux_fbc 106 | if IntegratedChromaTransformer2DModel is not None: 107 | IntegratedChromaTransformer2DModel.inner_forward = patched_inner_forward_chroma_fbc 108 | else: 109 | if (p.sd_model.is_sd1 == True) or (p.sd_model.is_sd2 == True) or (p.sd_model.is_sdxl == True): 110 | IntegratedUNet2DConditionModel.forward = patched_forward_unet_tc 111 | elif p.sd_model.is_sd3 == True: 112 | MMDiTX.forward = patched_forward_mmditx_tc 113 | else: 114 | # identify flux / chroma to avoid patching both 115 | IntegratedFluxTransformer2DModel.inner_forward = patched_inner_forward_flux_tc 116 | if IntegratedChromaTransformer2DModel is not None: 117 | IntegratedChromaTransformer2DModel.inner_forward = patched_inner_forward_chroma_tc 118 | 119 | p.extra_generation_params.update({ 120 | "bc_enabled" : enabled, 121 | "bc_method" : method, 122 | "bc_threshold" : threshold, 123 | "bc_nocache_steps" : nocache_steps, 124 | "bc_skip_limit" : max_cached, 125 | "bc_always_last" : always_last, 126 | }) 127 | 128 | setattr(BlockCache, "threshold", threshold) 129 | setattr(BlockCache, "nocache_steps", nocache_steps) 130 | setattr(BlockCache, "skip_limit", max_cached) 131 | setattr(BlockCache, "always_last", always_last) 132 | 133 | 134 | def process_before_every_sampling(self, p, *args, **kwargs): 135 | enabled = args[0] 136 | 137 | # possibly many passes through the forward method on each step: cond, uncond, PAG, SAG, TRAsce, SLG, 138 | 139 | if enabled: 140 | setattr(BlockCache, "index", 0) 141 | setattr(BlockCache, "distance", [0]) 142 | setattr(BlockCache, "this_step", 0) 143 | setattr(BlockCache, "last_step", p.hr_second_pass_steps if p.is_hr_pass else p.steps) 144 | setattr(BlockCache, "residual", [None]) 145 | setattr(BlockCache, "previous", [None]) 146 | setattr(BlockCache, "previousSigma", None) 147 | setattr(BlockCache, "skipped", [0]) 148 | 149 | 150 | def post_sample (self, params, ps, *args): 151 | # def postprocess(self, params, processed, *args): 152 | # always clean up after processing 153 | enabled = args[0] 154 | 155 | if enabled: 156 | # restore the original inner_forward method 157 | if IntegratedChromaTransformer2DModel is not None: 158 | IntegratedChromaTransformer2DModel.inner_forward = 
BlockCache.chroma_inner_forward 159 | IntegratedFluxTransformer2DModel.inner_forward = BlockCache.original_inner_forward 160 | IntegratedUNet2DConditionModel.forward = BlockCache.original_forward_unet 161 | if MMDiTX is not None: 162 | MMDiTX.forward = BlockCache.original_forward_mmditx 163 | 164 | delattr(BlockCache, "index") 165 | delattr(BlockCache, "threshold") 166 | delattr(BlockCache, "nocache_steps") 167 | delattr(BlockCache, "skip_limit") 168 | delattr(BlockCache, "always_last") 169 | delattr(BlockCache, "distance") 170 | delattr(BlockCache, "this_step") 171 | delattr(BlockCache, "last_step") 172 | delattr(BlockCache, "residual") 173 | delattr(BlockCache, "previous") 174 | delattr(BlockCache, "previousSigma") 175 | delattr(BlockCache, "skipped") 176 | 177 | 178 | # patches forward, with inline forward_with_concat 179 | def patched_forward_mmditx_fbc( 180 | self, 181 | x: torch.Tensor, 182 | t: torch.Tensor, 183 | y = None, 184 | context = None, 185 | control=None, transformer_options={}, **kwargs) -> torch.Tensor: 186 | 187 | thisSigma = t[0].item() 188 | 189 | if BlockCache.previousSigma == thisSigma: 190 | BlockCache.index += 1 191 | if BlockCache.index == len(BlockCache.distance): 192 | BlockCache.distance.append(0) 193 | BlockCache.residual.append(None) 194 | BlockCache.previous.append(None) 195 | BlockCache.skipped.append(0) 196 | else: 197 | BlockCache.previousSigma = thisSigma 198 | BlockCache.index = 0 199 | BlockCache.this_step += 1 200 | 201 | index = BlockCache.index 202 | 203 | skip_layers = transformer_options.get("skip_layers", []) 204 | 205 | hw = x.shape[-2:] 206 | 207 | x = self.x_embedder(x) + self.cropped_pos_embed(hw).to(x.device, x.dtype) 208 | c = self.t_embedder(t, dtype=x.dtype) # (N, D) 209 | if y is not None: 210 | y = self.y_embedder(y) # (N, D) 211 | c = c + y # (N, D) 212 | 213 | context = self.context_embedder(context) 214 | 215 | if self.register_length > 0: 216 | context = torch.cat( 217 | ( 218 | repeat(self.register, "1 ... -> b ...", b=x.shape[0]), 219 | context if context is not None else torch.Tensor([]).type_as(x), 220 | ), 221 | 1, 222 | ) 223 | 224 | original_x = x.clone() 225 | 226 | epsilon = 1e-6 227 | 228 | first_block = True 229 | for i, block in enumerate(self.joint_blocks): 230 | if i in skip_layers: 231 | continue 232 | 233 | context, x = block(context, x, c=c) 234 | if control is not None: 235 | controlnet_block_interval = len(self.joint_blocks) // len( 236 | control 237 | ) 238 | x = x + control[i // controlnet_block_interval] 239 | 240 | if first_block: 241 | first_block = False 242 | if BlockCache.this_step <= BlockCache.nocache_steps: 243 | skip_check = False 244 | elif BlockCache.always_last and BlockCache.this_step >= BlockCache.last_step: 245 | skip_check = False 246 | else: 247 | skip_check = True 248 | if BlockCache.previous[index] is None or BlockCache.residual[index] is None: 249 | skip_check = False 250 | if BlockCache.skip_limit > 0 and BlockCache.skipped[index] >= BlockCache.skip_limit: 251 | skip_check = False 252 | 253 | if skip_check: 254 | ## accumulate (then average?) 
distance per channel 255 | thisDistance = torch.zeros_like(x) 256 | for i in range(len(x)): 257 | thisDistance += (x[i] - BlockCache.previous[index][i]).abs() / (epsilon + BlockCache.previous[index][i].abs()) 258 | 259 | avgDistance = thisDistance.mean().cpu().item() 260 | 261 | # fullDistance = (x - BlockCache.previous[index]).abs().mean() / (epsilon + BlockCache.previous[index].abs().mean()).cpu().item() 262 | # print (avgDistance, fullDistance) 263 | 264 | BlockCache.distance[index] += avgDistance 265 | 266 | BlockCache.previous[index] = x.clone() 267 | if BlockCache.distance[index] < BlockCache.threshold: 268 | BlockCache.skipped[index] += 1 269 | # print (x.mean(), x.std(), BlockCache.residual[index].mean(), BlockCache.residual[index].std()) 270 | # for i in range(len(x)): 271 | # x[i] += BlockCache.residual[index][i] * (x[i].mean().abs() / BlockCache.residual[index][i].mean().abs()) * x[i].std() 272 | 273 | x += BlockCache.residual[index] * (x.mean().abs() / BlockCache.residual[index].mean().abs())# * x.std() 274 | 275 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 276 | x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W) 277 | return x ## early exit 278 | else: 279 | BlockCache.previous[index] = x.clone() 280 | 281 | BlockCache.residual[index] = x - original_x 282 | BlockCache.distance[index] = 0 283 | BlockCache.skipped[index] = 0 284 | 285 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 286 | 287 | x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W) 288 | 289 | return x 290 | 291 | 292 | def patched_inner_forward_chroma_fbc(self, img, img_ids, txt, txt_ids, timesteps, guidance=None): 293 | # BlockCache version 294 | 295 | thisSigma = timesteps[0].item() 296 | if BlockCache.previousSigma == thisSigma: 297 | BlockCache.index += 1 298 | if BlockCache.index == len(BlockCache.distance): 299 | BlockCache.distance.append(0) 300 | BlockCache.residual.append(None) 301 | BlockCache.previous.append(None) 302 | BlockCache.skipped.append(0) 303 | else: 304 | BlockCache.previousSigma = thisSigma 305 | BlockCache.index = 0 306 | BlockCache.this_step += 1 307 | 308 | index = BlockCache.index 309 | 310 | if img.ndim != 3 or txt.ndim != 3: 311 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 312 | 313 | img = self.img_in(img) 314 | device = img.device 315 | dtype = img.dtype 316 | nb_double_block = len(self.double_blocks) 317 | nb_single_block = len(self.single_blocks) 318 | 319 | mod_index_length = nb_double_block*12 + nb_single_block*3 + 2 320 | distill_timestep = timestep_embedding_chroma(timesteps.detach().clone(), 16).to(device=device, dtype=dtype) 321 | distil_guidance = timestep_embedding_chroma(guidance.detach().clone(), 16).to(device=device, dtype=dtype) 322 | modulation_index = timestep_embedding_chroma(torch.arange(mod_index_length), 32).to(device=device, dtype=dtype) 323 | modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1) 324 | timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1) 325 | input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) 326 | mod_vectors = self.distilled_guidance_layer(input_vec) 327 | mod_vectors_dict = self.distribute_modulations(mod_vectors, nb_single_block, nb_double_block) 328 | 329 | txt = self.txt_in(txt) 330 | del guidance 331 | ids = torch.cat((txt_ids, img_ids), dim=1) 332 | del txt_ids, img_ids 333 | pe = self.pe_embedder(ids) 334 | del ids 335 | 336 | original_img = img.clone() 
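    ## First Block Cache: run only the first double block, then (in the check
    ## below) accumulate the relative L1 change of its output versus the
    ## previous step; while the accumulated change stays under the threshold,
    ## add the cached residual to the embedded input, apply the final layer,
    ## and exit early instead of running the remaining blocks.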
337 | 338 | first_block = True 339 | for i, block in enumerate(self.double_blocks): 340 | img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] 341 | txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] 342 | double_mod = [img_mod, txt_mod] 343 | img, txt = block(img=img, txt=txt, mod=double_mod, pe=pe) 344 | if first_block: 345 | first_block = False 346 | if BlockCache.this_step <= BlockCache.nocache_steps: 347 | skip_check = False 348 | elif BlockCache.always_last and BlockCache.this_step >= BlockCache.last_step: 349 | skip_check = False 350 | else: 351 | skip_check = True 352 | if BlockCache.previous[index] is None or BlockCache.residual[index] is None: 353 | skip_check = False 354 | if BlockCache.skip_limit > 0 and BlockCache.skipped[index] >= BlockCache.skip_limit: 355 | skip_check = False 356 | 357 | if skip_check: 358 | BlockCache.distance[index] += ((img - BlockCache.previous[index]).abs().mean() / BlockCache.previous[index].abs().mean()).cpu().item() 359 | BlockCache.previous[index] = img.clone() 360 | if BlockCache.distance[index] < BlockCache.threshold: 361 | BlockCache.skipped[index] += 1 362 | img = original_img + BlockCache.residual[index] 363 | final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] 364 | img = self.final_layer(img, final_mod) 365 | return img ## early exit 366 | else: 367 | BlockCache.previous[index] = img 368 | 369 | img = torch.cat((txt, img), 1) 370 | for i, block in enumerate(self.single_blocks): 371 | single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] 372 | img = block(img, mod=single_mod, pe=pe) 373 | del pe 374 | img = img[:, txt.shape[1]:, ...] 375 | 376 | BlockCache.residual[index] = img - original_img 377 | BlockCache.distance[index] = 0 378 | BlockCache.skipped[index] = 0 379 | 380 | final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] 381 | img = self.final_layer(img, final_mod) 382 | return img 383 | 384 | 385 | def patched_inner_forward_flux_fbc(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None): 386 | # BlockCache version 387 | 388 | thisSigma = timesteps[0].item() 389 | if BlockCache.previousSigma == thisSigma: 390 | BlockCache.index += 1 391 | if BlockCache.index == len(BlockCache.distance): 392 | BlockCache.distance.append(0) 393 | BlockCache.residual.append(None) 394 | BlockCache.previous.append(None) 395 | BlockCache.skipped.append(0) 396 | else: 397 | BlockCache.previousSigma = thisSigma 398 | BlockCache.index = 0 399 | BlockCache.this_step += 1 400 | 401 | index = BlockCache.index 402 | 403 | if img.ndim != 3 or txt.ndim != 3: 404 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 405 | 406 | # Image and text embedding 407 | img = self.img_in(img) 408 | vec = self.time_in(timestep_embedding_flux(timesteps, 256).to(img.dtype)) 409 | 410 | # If guidance_embed is enabled, add guidance information 411 | if self.guidance_embed: 412 | if guidance is None: 413 | raise ValueError("Didn't get guidance strength for guidance distilled model.") 414 | vec = vec + self.guidance_in(timestep_embedding_flux(guidance, 256).to(img.dtype)) 415 | 416 | vec = vec + self.vector_in(y) 417 | txt = self.txt_in(txt) 418 | 419 | # Merge image and text IDs 420 | ids = torch.cat((txt_ids, img_ids), dim=1) 421 | pe = self.pe_embedder(ids) 422 | 423 | original_img = img.clone() 424 | 425 | first_block = True 426 | for block in self.double_blocks: 427 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe) 428 | if first_block: 429 | first_block = False 430 | if BlockCache.this_step <= 
BlockCache.nocache_steps: 431 | skip_check = False 432 | elif BlockCache.always_last and BlockCache.this_step >= BlockCache.last_step: 433 | skip_check = False 434 | else: 435 | skip_check = True 436 | if BlockCache.previous[index] is None or BlockCache.residual[index] is None: 437 | skip_check = False 438 | if BlockCache.skip_limit > 0 and BlockCache.skipped[index] >= BlockCache.skip_limit: 439 | skip_check = False 440 | 441 | if skip_check: 442 | BlockCache.distance[index] += ((img - BlockCache.previous[index]).abs().mean() / BlockCache.previous[index].abs().mean()).cpu().item() 443 | BlockCache.previous[index] = img.clone() 444 | if BlockCache.distance[index] < BlockCache.threshold: 445 | BlockCache.skipped[index] += 1 446 | img = original_img + BlockCache.residual[index] 447 | img = self.final_layer(img, vec) 448 | return img ## early exit 449 | else: 450 | BlockCache.previous[index] = img 451 | 452 | img = torch.cat((txt, img), 1) 453 | for block in self.single_blocks: 454 | img = block(img, vec=vec, pe=pe) 455 | img = img[:, txt.shape[1]:, ...] 456 | 457 | BlockCache.residual[index] = img - original_img 458 | BlockCache.distance[index] = 0 459 | BlockCache.skipped[index] = 0 460 | 461 | img = self.final_layer(img, vec) 462 | return img 463 | 464 | 465 | def patched_forward_unet_fbc(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): 466 | # BlockCache version 467 | 468 | thisSigma = transformer_options["sigmas"][0].item() 469 | if BlockCache.previousSigma == thisSigma: 470 | BlockCache.index += 1 471 | if BlockCache.index == len(BlockCache.distance): 472 | BlockCache.distance.append(0) 473 | BlockCache.residual.append(None) 474 | BlockCache.previous.append(None) 475 | BlockCache.skipped.append(0) 476 | else: 477 | BlockCache.previousSigma = thisSigma 478 | BlockCache.index = 0 479 | BlockCache.this_step += 1 480 | 481 | index = BlockCache.index 482 | residual = BlockCache.residual[index] 483 | previous = BlockCache.previous[index] 484 | distance = BlockCache.distance[index] 485 | skipped = BlockCache.skipped[index] 486 | 487 | transformer_options["original_shape"] = list(x.shape) 488 | transformer_options["transformer_index"] = 0 489 | transformer_patches = transformer_options.get("patches", {}) 490 | block_modifiers = transformer_options.get("block_modifiers", []) 491 | assert (y is not None) == (self.num_classes is not None) 492 | hs = [] 493 | t_emb = timestep_embedding_unet(timesteps, self.model_channels, repeat_only=False).to(x.dtype) 494 | emb = self.time_embed(t_emb) 495 | if self.num_classes is not None: 496 | assert y.shape[0] == x.shape[0] 497 | emb = emb + self.label_emb(y) 498 | h = x 499 | 500 | original_h = h.clone() 501 | 502 | skip = False 503 | first_block = True 504 | for id, module in enumerate(self.input_blocks): 505 | transformer_options["block"] = ("input", id) 506 | for block_modifier in block_modifiers: 507 | h = block_modifier(h, 'before', transformer_options) 508 | h = module(h, emb, context, transformer_options) 509 | h = apply_control(h, control, 'input') 510 | for block_modifier in block_modifiers: 511 | h = block_modifier(h, 'after', transformer_options) 512 | if "input_block_patch" in transformer_patches: 513 | patch = transformer_patches["input_block_patch"] 514 | for p in patch: 515 | h = p(h, transformer_options) 516 | hs.append(h) 517 | if "input_block_patch_after_skip" in transformer_patches: 518 | patch = transformer_patches["input_block_patch_after_skip"] 519 | for p in patch: 520 | h = p(h, 
transformer_options) 521 | 522 | if first_block: 523 | first_block = False 524 | if BlockCache.this_step <= BlockCache.nocache_steps: 525 | skip_check = False 526 | elif BlockCache.always_last and BlockCache.this_step >= BlockCache.last_step: 527 | skip_check = False 528 | else: 529 | skip_check = True 530 | if previous is None or residual is None: 531 | skip_check = False 532 | if BlockCache.skip_limit > 0 and skipped >= BlockCache.skip_limit: 533 | skip_check = False 534 | 535 | if skip_check: 536 | distance += ((h - previous).abs().mean() / previous.abs().mean()).cpu().item() 537 | previous = h.clone() 538 | if distance < BlockCache.threshold: 539 | h = original_h + residual 540 | skip = True 541 | skipped += 1 542 | break 543 | else: 544 | previous = h.clone() 545 | 546 | if not skip: 547 | transformer_options["block"] = ("middle", 0) 548 | for block_modifier in block_modifiers: 549 | h = block_modifier(h, 'before', transformer_options) 550 | h = self.middle_block(h, emb, context, transformer_options) 551 | h = apply_control(h, control, 'middle') 552 | for block_modifier in block_modifiers: 553 | h = block_modifier(h, 'after', transformer_options) 554 | for id, module in enumerate(self.output_blocks): 555 | transformer_options["block"] = ("output", id) 556 | hsp = hs.pop() 557 | hsp = apply_control(hsp, control, 'output') 558 | if "output_block_patch" in transformer_patches: 559 | patch = transformer_patches["output_block_patch"] 560 | for p in patch: 561 | h, hsp = p(h, hsp, transformer_options) 562 | h = torch.cat([h, hsp], dim=1) 563 | del hsp 564 | if len(hs) > 0: 565 | output_shape = hs[-1].shape 566 | else: 567 | output_shape = None 568 | for block_modifier in block_modifiers: 569 | h = block_modifier(h, 'before', transformer_options) 570 | h = module(h, emb, context, transformer_options, output_shape) 571 | for block_modifier in block_modifiers: 572 | h = block_modifier(h, 'after', transformer_options) 573 | transformer_options["block"] = ("last", 0) 574 | for block_modifier in block_modifiers: 575 | h = block_modifier(h, 'before', transformer_options) 576 | if "group_norm_wrapper" in transformer_options: 577 | out_norm, out_rest = self.out[0], self.out[1:] 578 | h = transformer_options["group_norm_wrapper"](out_norm, h, transformer_options) 579 | h = out_rest(h) 580 | else: 581 | h = self.out(h) 582 | for block_modifier in block_modifiers: 583 | h = block_modifier(h, 'after', transformer_options) 584 | 585 | residual = h - original_h 586 | distance = 0 587 | skipped = 0 588 | 589 | BlockCache.residual[index] = residual 590 | BlockCache.previous[index] = previous 591 | BlockCache.distance[index] = distance 592 | BlockCache.skipped[index] = skipped 593 | 594 | return h.type(x.dtype) 595 | 596 | 597 | def patched_forward_mmditx_tc( 598 | self, 599 | x: torch.Tensor, 600 | t: torch.Tensor, 601 | y = None, 602 | context = None, 603 | control=None, transformer_options={}, **kwargs) -> torch.Tensor: 604 | """ 605 | Forward pass of DiT. 
606 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 607 | t: (N,) tensor of diffusion timesteps 608 | y: (N,) tensor of class labels 609 | """ 610 | 611 | thisSigma = t[0].item() 612 | if BlockCache.previousSigma == thisSigma: 613 | BlockCache.index += 1 614 | if BlockCache.index == len(BlockCache.distance): 615 | BlockCache.distance.append(0) 616 | BlockCache.residual.append(None) 617 | BlockCache.previous.append(None) 618 | BlockCache.skipped.append(0) 619 | else: 620 | BlockCache.previousSigma = thisSigma 621 | BlockCache.index = 0 622 | BlockCache.this_step += 1 623 | 624 | index = BlockCache.index 625 | residual = BlockCache.residual[index] 626 | previous = BlockCache.previous[index] 627 | distance = BlockCache.distance[index] 628 | skipped = BlockCache.skipped[index] 629 | 630 | if BlockCache.this_step <= BlockCache.nocache_steps: 631 | skip_check = False 632 | elif BlockCache.always_last and BlockCache.this_step == BlockCache.last_step: 633 | skip_check = False 634 | else: 635 | skip_check = True 636 | if previous is None or previous.shape != x.shape: 637 | skip_check = False 638 | if residual is None: 639 | skip_check = False 640 | if BlockCache.skip_limit > 0 and skipped >= BlockCache.skip_limit: 641 | skip_check = False 642 | 643 | epsilon = 1e-6 644 | 645 | skip = False 646 | if skip_check: 647 | # distance += ((x - previous).abs().mean() / previous.abs().mean()).cpu().item() 648 | 649 | # coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] 650 | # rescale_func = np.poly1d(coefficients) 651 | # distance += rescale_func( 652 | # ((x - previous).abs().mean() / previous.abs().mean()).cpu().item() 653 | # ) 654 | # print ("SD3 tc distance:", distance); 655 | 656 | thisDistance = torch.zeros_like(x) 657 | for i in range(len(x)): 658 | thisDistance += (x[i] - BlockCache.previous[index][i]).abs() / (epsilon + BlockCache.previous[index][i].abs()) 659 | 660 | avgDistance = thisDistance.mean().cpu().item() 661 | 662 | # fullDistance = ((x - previous).abs().mean() / previous.abs().mean()).cpu().item() 663 | # print (avgDistance, fullDistance) 664 | distance += avgDistance 665 | 666 | if distance < BlockCache.threshold: 667 | skip = True 668 | 669 | 670 | previous = x.clone() 671 | 672 | if skip: 673 | x += residual 674 | skipped += 1 675 | else: 676 | hw = x.shape[-2:] 677 | x = self.x_embedder(x) + self.cropped_pos_embed(hw).to(x.device, x.dtype) 678 | skip_layers = transformer_options.get("skip_layers", []) 679 | c = self.t_embedder(t, dtype=x.dtype) # (N, D) 680 | if y is not None: 681 | y = self.y_embedder(y) # (N, D) 682 | c = c + y # (N, D) 683 | 684 | context = self.context_embedder(context) 685 | 686 | x = self.forward_core_with_concat(x, c, context, skip_layers, control) 687 | x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W) 688 | 689 | residual = x - previous 690 | distance = 0 691 | skipped = 0 692 | 693 | 694 | BlockCache.residual[index] = residual 695 | BlockCache.previous[index] = previous 696 | BlockCache.distance[index] = distance 697 | BlockCache.skipped[index] = skipped 698 | 699 | return x 700 | 701 | 702 | def patched_inner_forward_chroma_tc(self, img, img_ids, txt, txt_ids, timesteps, guidance=None): 703 | # TeaCache version 704 | 705 | thisSigma = timesteps[0].item() 706 | if BlockCache.previousSigma == thisSigma: 707 | BlockCache.index += 1 708 | if BlockCache.index == len(BlockCache.distance): 709 | BlockCache.distance.append(0) 710 | BlockCache.residual.append(None) 711 | 
BlockCache.previous.append(None) 712 | BlockCache.skipped.append(0) 713 | else: 714 | BlockCache.previousSigma = thisSigma 715 | BlockCache.index = 0 716 | BlockCache.this_step += 1 717 | 718 | index = BlockCache.index 719 | 720 | if img.ndim != 3 or txt.ndim != 3: 721 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 722 | 723 | img = self.img_in(img) 724 | device = img.device 725 | dtype = img.dtype 726 | nb_double_block = len(self.double_blocks) 727 | nb_single_block = len(self.single_blocks) 728 | 729 | mod_index_length = nb_double_block*12 + nb_single_block*3 + 2 730 | distill_timestep = timestep_embedding_chroma(timesteps.detach().clone(), 16).to(device=device, dtype=dtype) 731 | distil_guidance = timestep_embedding_chroma(guidance.detach().clone(), 16).to(device=device, dtype=dtype) 732 | modulation_index = timestep_embedding_chroma(torch.arange(mod_index_length), 32).to(device=device, dtype=dtype) 733 | modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1) 734 | timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1) 735 | input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) 736 | mod_vectors = self.distilled_guidance_layer(input_vec) 737 | mod_vectors_dict = self.distribute_modulations(mod_vectors, nb_single_block, nb_double_block) 738 | 739 | txt = self.txt_in(txt) 740 | del guidance 741 | ids = torch.cat((txt_ids, img_ids), dim=1) 742 | del txt_ids, img_ids 743 | pe = self.pe_embedder(ids) 744 | del ids 745 | 746 | original_img = img.clone() 747 | 748 | if BlockCache.this_step <= BlockCache.nocache_steps: 749 | skip_check = False 750 | elif BlockCache.always_last and BlockCache.this_step == BlockCache.last_step: 751 | skip_check = False 752 | else: 753 | skip_check = True 754 | if BlockCache.previous[index] is None or BlockCache.previous[index].shape != original_img.shape: 755 | skip_check = False 756 | if BlockCache.residual[index] is None: 757 | skip_check = False 758 | if BlockCache.skip_limit > 0 and BlockCache.skipped[index] >= BlockCache.skip_limit: 759 | skip_check = False 760 | 761 | skip = False 762 | if skip_check: 763 | coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] 764 | rescale_func = np.poly1d(coefficients) 765 | BlockCache.distance[index] += rescale_func( 766 | ((original_img - BlockCache.previous[index]).abs().mean() / BlockCache.previous[index].abs().mean()).cpu().item() 767 | ) 768 | 769 | if BlockCache.distance[index] < BlockCache.threshold: 770 | skip = True 771 | 772 | BlockCache.previous[index] = original_img 773 | 774 | if skip: 775 | img += BlockCache.residual[index] 776 | BlockCache.skipped[index] += 1 777 | else: 778 | for i, block in enumerate(self.double_blocks): 779 | img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] 780 | txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] 781 | double_mod = [img_mod, txt_mod] 782 | img, txt = block(img=img, txt=txt, mod=double_mod, pe=pe) 783 | img = torch.cat((txt, img), 1) 784 | for i, block in enumerate(self.single_blocks): 785 | single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] 786 | img = block(img, mod=single_mod, pe=pe) 787 | del pe 788 | img = img[:, txt.shape[1]:, ...] 
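        ## note: the residual cached below is taken before final_layer is
        ## applied, so it excludes the final projection; final_layer then runs
        ## afterwards on both the cached and fully-computed paths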
789 | # final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] 790 | # img = self.final_layer(img, final_mod) 791 | 792 | BlockCache.residual[index] = img - original_img 793 | BlockCache.distance[index] = 0 794 | BlockCache.skipped[index] = 0 795 | 796 | final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] 797 | img = self.final_layer(img, final_mod) 798 | return img 799 | 800 | 801 | def patched_inner_forward_flux_tc(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None): 802 | # TeaCache version 803 | 804 | thisSigma = timesteps[0].item() 805 | if BlockCache.previousSigma == thisSigma: 806 | BlockCache.index += 1 807 | if BlockCache.index == len(BlockCache.distance): 808 | BlockCache.distance.append(0) 809 | BlockCache.residual.append(None) 810 | BlockCache.previous.append(None) 811 | BlockCache.skipped.append(0) 812 | else: 813 | BlockCache.previousSigma = thisSigma 814 | BlockCache.index = 0 815 | BlockCache.this_step += 1 816 | 817 | index = BlockCache.index 818 | 819 | if img.ndim != 3 or txt.ndim != 3: 820 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 821 | 822 | # Image and text embedding 823 | img = self.img_in(img) 824 | vec = self.time_in(timestep_embedding_flux(timesteps, 256).to(img.dtype)) 825 | 826 | # If guidance_embed is enabled, add guidance information 827 | if self.guidance_embed: 828 | if guidance is None: 829 | raise ValueError("Didn't get guidance strength for guidance distilled model.") 830 | vec = vec + self.guidance_in(timestep_embedding_flux(guidance, 256).to(img.dtype)) 831 | 832 | vec = vec + self.vector_in(y) 833 | txt = self.txt_in(txt) 834 | 835 | # Merge image and text IDs 836 | ids = torch.cat((txt_ids, img_ids), dim=1) 837 | pe = self.pe_embedder(ids) 838 | 839 | original_img = img.clone() 840 | 841 | if BlockCache.this_step <= BlockCache.nocache_steps: 842 | skip_check = False 843 | elif BlockCache.always_last and BlockCache.this_step == BlockCache.last_step: 844 | skip_check = False 845 | else: 846 | skip_check = True 847 | if BlockCache.previous[index] is None or BlockCache.previous[index].shape != original_img.shape: 848 | skip_check = False 849 | if BlockCache.residual[index] is None: 850 | skip_check = False 851 | if BlockCache.skip_limit > 0 and BlockCache.skipped[index] >= BlockCache.skip_limit: 852 | skip_check = False 853 | 854 | skip = False 855 | if skip_check: 856 | coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] 857 | rescale_func = np.poly1d(coefficients) 858 | BlockCache.distance[index] += rescale_func( 859 | ((original_img - BlockCache.previous[index]).abs().mean() / BlockCache.previous[index].abs().mean()).cpu().item() 860 | ) 861 | 862 | if BlockCache.distance[index] < BlockCache.threshold: 863 | skip = True 864 | 865 | BlockCache.previous[index] = original_img 866 | 867 | if skip: 868 | img += BlockCache.residual[index] 869 | BlockCache.skipped[index] += 1 870 | else: 871 | for block in self.double_blocks: 872 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe) 873 | img = torch.cat((txt, img), 1) 874 | for block in self.single_blocks: 875 | img = block(img, vec=vec, pe=pe) 876 | img = img[:, txt.shape[1]:, ...] 
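        ## cache the whole-model residual (processed tokens minus the embedded
        ## input) so that skipped steps can approximate the transformer output
        ## as embedded input + residual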
877 | BlockCache.residual[index] = img - original_img 878 | BlockCache.distance[index] = 0 879 | BlockCache.skipped[index] = 0 880 | 881 | img = self.final_layer(img, vec) 882 | return img 883 | 884 | 885 | def patched_forward_unet_tc(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): 886 | # TeaCache version 887 | 888 | thisSigma = transformer_options["sigmas"][0].item() 889 | if BlockCache.previousSigma == thisSigma: 890 | BlockCache.index += 1 891 | if BlockCache.index == len(BlockCache.distance): 892 | BlockCache.distance.append(0) 893 | BlockCache.residual.append(None) 894 | BlockCache.previous.append(None) 895 | BlockCache.skipped.append(0) 896 | else: 897 | BlockCache.previousSigma = thisSigma 898 | BlockCache.index = 0 899 | BlockCache.this_step += 1 900 | 901 | index = BlockCache.index 902 | residual = BlockCache.residual[index] 903 | previous = BlockCache.previous[index] 904 | distance = BlockCache.distance[index] 905 | skipped = BlockCache.skipped[index] 906 | 907 | # print (BlockCache.this_step, index, thisSigma, distance) 908 | 909 | transformer_options["original_shape"] = list(x.shape) 910 | transformer_options["transformer_index"] = 0 911 | transformer_patches = transformer_options.get("patches", {}) 912 | block_modifiers = transformer_options.get("block_modifiers", []) 913 | assert (y is not None) == (self.num_classes is not None) 914 | hs = [] 915 | t_emb = timestep_embedding_unet(timesteps, self.model_channels, repeat_only=False).to(x.dtype) 916 | emb = self.time_embed(t_emb) 917 | if self.num_classes is not None: 918 | assert y.shape[0] == x.shape[0] 919 | emb = emb + self.label_emb(y) 920 | h = x 921 | 922 | original_h = h.clone() 923 | 924 | if BlockCache.this_step <= BlockCache.nocache_steps: 925 | skip_check = False 926 | elif BlockCache.always_last and BlockCache.this_step == BlockCache.last_step: 927 | skip_check = False 928 | else: 929 | skip_check = True 930 | if previous is None or previous.shape != original_h.shape: 931 | skip_check = False 932 | if residual is None: 933 | skip_check = False 934 | if BlockCache.skip_limit > 0 and skipped >= BlockCache.skip_limit: 935 | skip_check = False 936 | 937 | skip = False 938 | if skip_check: 939 | distance += ((original_h - BlockCache.previous[index]).abs().mean() / BlockCache.previous[index].abs().mean()).cpu().item() 940 | 941 | if distance < BlockCache.threshold: 942 | skip = True 943 | 944 | if skip: 945 | h += residual 946 | skipped += 1 947 | else: 948 | for id, module in enumerate(self.input_blocks): 949 | transformer_options["block"] = ("input", id) 950 | for block_modifier in block_modifiers: 951 | h = block_modifier(h, 'before', transformer_options) 952 | h = module(h, emb, context, transformer_options) 953 | h = apply_control(h, control, 'input') 954 | for block_modifier in block_modifiers: 955 | h = block_modifier(h, 'after', transformer_options) 956 | if "input_block_patch" in transformer_patches: 957 | patch = transformer_patches["input_block_patch"] 958 | for p in patch: 959 | h = p(h, transformer_options) 960 | hs.append(h) 961 | if "input_block_patch_after_skip" in transformer_patches: 962 | patch = transformer_patches["input_block_patch_after_skip"] 963 | for p in patch: 964 | h = p(h, transformer_options) 965 | 966 | transformer_options["block"] = ("middle", 0) 967 | for block_modifier in block_modifiers: 968 | h = block_modifier(h, 'before', transformer_options) 969 | h = self.middle_block(h, emb, context, transformer_options) 970 | h = apply_control(h, 
control, 'middle') 971 | for block_modifier in block_modifiers: 972 | h = block_modifier(h, 'after', transformer_options) 973 | for id, module in enumerate(self.output_blocks): 974 | transformer_options["block"] = ("output", id) 975 | hsp = hs.pop() 976 | hsp = apply_control(hsp, control, 'output') 977 | if "output_block_patch" in transformer_patches: 978 | patch = transformer_patches["output_block_patch"] 979 | for p in patch: 980 | h, hsp = p(h, hsp, transformer_options) 981 | h = torch.cat([h, hsp], dim=1) 982 | del hsp 983 | if len(hs) > 0: 984 | output_shape = hs[-1].shape 985 | else: 986 | output_shape = None 987 | for block_modifier in block_modifiers: 988 | h = block_modifier(h, 'before', transformer_options) 989 | h = module(h, emb, context, transformer_options, output_shape) 990 | for block_modifier in block_modifiers: 991 | h = block_modifier(h, 'after', transformer_options) 992 | transformer_options["block"] = ("last", 0) 993 | for block_modifier in block_modifiers: 994 | h = block_modifier(h, 'before', transformer_options) 995 | if "group_norm_wrapper" in transformer_options: 996 | out_norm, out_rest = self.out[0], self.out[1:] 997 | h = transformer_options["group_norm_wrapper"](out_norm, h, transformer_options) 998 | h = out_rest(h) 999 | else: 1000 | h = self.out(h) 1001 | for block_modifier in block_modifiers: 1002 | h = block_modifier(h, 'after', transformer_options) 1003 | 1004 | residual = h - original_h 1005 | distance = 0 1006 | skipped = 0 1007 | 1008 | BlockCache.residual[index] = residual 1009 | BlockCache.previous[index] = original_h 1010 | BlockCache.distance[index] = distance 1011 | BlockCache.skipped[index] = skipped 1012 | 1013 | return h.type(x.dtype) 1014 | 1015 | --------------------------------------------------------------------------------