├── pyproject.toml
├── .github
│   └── workflows
│       └── publish.yml
├── README.md
├── __init__.py
├── LICENSE
└── ras.py

/pyproject.toml:
--------------------------------------------------------------------------------
[project]
name = "comfyui-regional-adaptive-sampling"
description = "ComfyUI implementation of Regional Adaptive Sampling (original implementation at https://github.com/microsoft/RAS)."
version = "1.1.0"
license = {file = "LICENSE"}
dependencies = ["torch", "einops", "diffusers"]

[project.urls]
Repository = "https://github.com/Slickytail/ComfyUI-RegionalAdaptiveSampling"
# Used by Comfy Registry https://comfyregistry.org

[tool.comfy]
PublisherId = "slickytail"
DisplayName = "ComfyUI-RegionalAdaptiveSampling"
Icon = ""

--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
name: Publish to Comfy registry
on:
  workflow_dispatch:
  push:
    branches:
      - main
    paths:
      - "pyproject.toml"

permissions:
  issues: write

jobs:
  publish-node:
    name: Publish Custom Node to registry
    runs-on: ubuntu-latest
    if: ${{ github.repository_owner == 'Slickytail' }}
    steps:
      - name: Check out code
        uses: actions/checkout@v4
      - name: Publish Custom Node
        uses: Comfy-Org/publish-node-action@v1
        with:
          ## Add your own personal access token to your GitHub repository secrets and reference it here.
          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Regional Adaptive Sampling

***Version 1.1.0***: Support for Wan/VACE.
This support is currently experimental; I have tested it with Wan2.1 VACE 1.3B T2I and Wan2.2 T2I 5B.
Please feel free to open an issue if your Wan workflow doesn't work (there are a lot of Wan variants, and I might not have handled all cases correctly).

[Regional Adaptive Sampling](https://github.com/microsoft/RAS) is a new technique for accelerating the inference of diffusion transformers.
It essentially works as a KV cache inside the model: at each diffusion step it picks the regions that are most likely to be updated and passes only those tokens through the model (both mechanisms are sketched at the end of the Usage section below).

This implementation is simple to use, and compatible with Flux (dev & schnell), HunyuanVideo, and some Wan variants.

## Usage
Apply the `Regional Adaptive Sampling` node to the desired model. It has the following parameters:
- **sample_ratio**: The fraction of tokens to keep in the model on a RAS pass. Anything below 0.3 usually gives very poor quality.
- **warmup_steps**: The number of steps to run without RAS at the beginning. Setting it higher decreases the speedup; setting it lower degrades the composition.
- **hydrate_every**: Every `hydrate_every` steps, we do a full run through the model with all tokens to refresh the stale cache (see the first sketch below). Set to 0 to disable hydration and run RAS on every step after the warmup.
- **starvation_scale**: Controls how the model decides which parts of the image to focus on. Increasing it will probably shift quality from the main subject to the background. The default of 0.1 is what's used in the paper, and I haven't tried anything else.
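
For intuition, here is a minimal sketch of the step schedule (it mirrors `skip_ratio` in `ras.py`; `is_full_pass` is an illustrative name, not part of the node):

```python
def is_full_pass(step: int, warmup_steps: int, hydrate_every: int) -> bool:
    """True if this diffusion step should run with all tokens (no RAS)."""
    if step < warmup_steps:
        return True  # warmup: always run the full model
    if hydrate_every and (1 + step - warmup_steps) % hydrate_every == 0:
        return True  # periodic hydration pass to refresh the stale KV cache
    return False  # RAS pass: only a sample_ratio fraction of tokens is run

# With the defaults (warmup_steps=4, hydrate_every=4), steps 0-3, 7, 11, ...
# run the full model; everything in between is a RAS pass.
```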
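
And a rough sketch of what a RAS pass does inside each attention block (the names here are illustrative; see `DoubleStreamBlockWrapper.forward` in `ras.py` for the real version):

```python
def ras_self_attention(q, k, v, k_cache, v_cache, live_idx, attention):
    """q, k, and v have been computed for the live tokens only."""
    # scatter the freshly computed keys/values into the full-size cache
    k_cache[..., live_idx, :] = k
    v_cache[..., live_idx, :] = v
    # the live queries attend over the full (partially stale) key/value set,
    # so the skipped tokens still influence the tokens being updated
    return attention(q, k_cache, v_cache)
```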

## Todos
- Support batch size > 1 or CFG (this makes the token caching logic more complicated).

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
from .ras import RASConfig, RASManager
from comfy.model_patcher import ModelPatcher


class RegionalAdaptiveSampling:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "sample_ratio": (
                    "FLOAT",
                    {"default": 0.5, "min": 0.05, "max": 1.0, "step": 0.05},
                ),
                "warmup_steps": (
                    "INT",
                    {"default": 4, "min": 0, "max": 100},
                ),
                "hydrate_every": (
                    "INT",
                    {"default": 4, "min": 0, "max": 100},
                ),
                "starvation_scale": (
                    "FLOAT",
                    {"default": 0.1, "min": 0.01, "max": 1.0, "step": 0.01},
                ),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_ras"
    CATEGORY = "ras"

    def apply_ras(
        self,
        model: ModelPatcher,
        sample_ratio: float,
        warmup_steps: int,
        hydrate_every: int,
        starvation_scale: float,
    ):
        model = model.clone()
        # unpatch the model
        # this makes sure that we're wrapping the model "in a pure state"
        # the model will repatch itself later
        model.unpatch_model()
        config = RASConfig(
            sample_ratio=sample_ratio,
            warmup_steps=warmup_steps,
            hydrate_every=hydrate_every,
            starvation_scale=starvation_scale,
        )
        manager = RASManager(config)
        manager.wrap_model(model)
        return (model,)


NODE_CLASS_MAPPINGS = {"RegionalAdaptiveSampling": RegionalAdaptiveSampling}
NODE_DISPLAY_NAME_MAPPINGS = {"RegionalAdaptiveSampling": "Regional Adaptive Sampling"}
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/ras.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from types import MethodType

import torch
from torch import nn
from torch import Tensor
from einops import rearrange

from comfy.ldm.flux.model import Flux
from comfy.ldm.hunyuan_video.model import HunyuanVideo
from comfy.ldm.flux.layers import DoubleStreamBlock, SingleStreamBlock, LastLayer
from comfy.ldm.wan.model import (
    WanModel,
    VaceWanModel,
    CameraWanModel,
    WanModel_S2V,
    HumoWanModel,
    WanAttentionBlock,
    VaceWanAttentionBlock,
    Head,
)
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.modules.attention import optimized_attention
from comfy.model_patcher import ModelPatcher
import comfy.model_management


def apply_pe(x: Tensor, pe: Tensor) -> Tensor:
    """
    The PE application from flux.math.attention, factored out so that we can cache the keys post-PE
    """
    shape = x.shape
    dtype = x.dtype
    x = x.float().reshape(*x.shape[:-1], -1, 1, 2)
    x = (pe[..., 0] * x[..., 0] + pe[..., 1] * x[..., 1]).reshape(*shape).to(dtype)

    return x


def take_attributes_from(source, target, keys):
    for x in keys:
        setattr(target, x, getattr(source, x))


@dataclass
class RASConfig:
    warmup_steps: int = 4
    hydrate_every: int = 5
    sample_ratio: float = 0.5
    starvation_scale: float = 0.1
    high_ratio: float = 1.0


class RASManager:
    """
    Coordinates the live indices, metrics, and model wrapping.
    """

    def __init__(self, config: RASConfig):
        self.flipped_img_txt = False
        self.timestep: int = 0
        self.n_txt: int = 0
        self.n_img: int = 0
        self.cached_output: Tensor | None = None
        self.live_txt_indices: Tensor | None = None
        self.live_img_indices: Tensor | None = None
        self.drop_count: torch.Tensor | None = None
        self.config = config
        self.patch_size: list[int]
        self.model: Flux | HunyuanVideo | WanModel
        assert (
            self.config.high_ratio >= 0 and self.config.high_ratio <= 1
        ), "high_ratio should be in the range [0, 1]"

    def wrap_layer(self, layer, first: bool = False, last: bool = False):
        if isinstance(
            layer,
            (
                DoubleStreamBlockWrapper,
                SingleStreamBlockWrapper,
                LastLayerWrapper,
                WanAttentionBlockWrapper,
                VaceWanAttentionBlockWrapper,
            ),
        ):
            raise TypeError("Old wrapping wasn't removed!")

        if isinstance(layer, DoubleStreamBlock):
            wrapped = DoubleStreamBlockWrapper(layer, self, first)
        elif isinstance(layer, SingleStreamBlock):
            wrapped = SingleStreamBlockWrapper(layer, self, last)
        elif isinstance(layer, LastLayer):
            wrapped = LastLayerWrapper(layer, self)
        # note: VaceWanAttentionBlock is a subclass of WanAttentionBlock,
        # so we have to check for the VACE block first
        elif isinstance(layer, VaceWanAttentionBlock):
            wrapped = VaceWanAttentionBlockWrapper(layer, self, first, last)
        elif isinstance(layer, WanAttentionBlock):
            wrapped = WanAttentionBlockWrapper(layer, self, first, last)
        elif isinstance(layer, Head):
            wrapped = HeadWrapper(layer, self)
        else:
            raise TypeError(f"Can't wrap layer of type {layer.__class__.__name__}")
        return wrapped

    def wrap_model(self, patcher: ModelPatcher):
        model = patcher.model.diffusion_model
        self.model = model
        if isinstance(model, Flux):
            self.patch_size = [model.patch_size, model.patch_size]
        elif isinstance(model, HunyuanVideo):
            self.patch_size = model.patch_size
        elif isinstance(
            model, (WanModel, VaceWanModel, CameraWanModel, WanModel_S2V, HumoWanModel)
        ):
            self.patch_size = model.patch_size
        else:
            raise TypeError(f"Can't wrap model of type {model.__class__.__name__}")
        # Handle different model architectures
        if isinstance(model, (Flux, HunyuanVideo)):
            # wrap the single and double blocks to have caching
            for i, v in enumerate(model.double_blocks):
                # first block has the special responsibility of removing tokens
                if i == 0:
                    self.flipped_img_txt = v.flipped_img_txt
                    layer = self.wrap_layer(v, first=True)
                else:
                    layer = self.wrap_layer(v)
                patcher.add_object_patch(f"diffusion_model.double_blocks.{i}", layer)
            for i, v in enumerate(model.single_blocks):
                # last block will put them back
                patcher.add_object_patch(
                    f"diffusion_model.single_blocks.{i}",
                    self.wrap_layer(v, last=(i == (len(model.single_blocks) - 1))),
                )
        elif isinstance(
            model, (WanModel, VaceWanModel, CameraWanModel, WanModel_S2V, HumoWanModel)
        ):
            # Wan models have a different structure with just 'blocks'
            self.flipped_img_txt = False  # Wan doesn't use the flipped pattern
            for i, v in enumerate(model.blocks):
                # first block has the special responsibility of removing tokens
                # last block will put them back
                layer = self.wrap_layer(
                    v, first=(i == 0), last=(i == (len(model.blocks) - 1))
                )
                patcher.add_object_patch(f"diffusion_model.blocks.{i}", layer)

            # wrap the vace blocks as well, if present
            if hasattr(model, "vace_blocks"):
                for i, v in enumerate(model.vace_blocks):
                    layer = self.wrap_layer(
                        v,
                        first=(i == 0),
                        # we DONT put back the conditioning tokens
                        # because the last vace block is well before the last real block
                        last=False,
                        # last=(i == (len(model.vace_blocks) - 1))
                    )
                    patcher.add_object_patch(f"diffusion_model.vace_blocks.{i}", layer)

        # wrap the forward_orig method, to be able to get the timestep
        forward_orig = model.forward_orig

        def new_forward(_self, *args, **kwargs):
            # Get transformer_options from kwargs for Wan models, or from args for Flux/Hunyuan
            if "transformer_options" in kwargs:
                transformer_options = kwargs["transformer_options"]
            else:
                transformer_options = args[-1]

            self.timestep = self.timestep_from_sigmas(
                transformer_options["sigmas"], transformer_options["sample_sigmas"]
            )
            if self.timestep == 0:
                # reset as much as possible
                self.live_img_indices = None
                self.live_txt_indices = None
                self.drop_count = None
            return forward_orig(*args, **kwargs)

        patcher.add_object_patch(
            "diffusion_model.forward_orig", MethodType(new_forward, model)
        )
        # wrap the last_layer, to be able to read the output and calculate the metric
        if isinstance(model, (Flux, HunyuanVideo)):
            patcher.add_object_patch(
                "diffusion_model.final_layer", self.wrap_layer(model.final_layer)
            )
        elif isinstance(
            model, (WanModel, VaceWanModel, CameraWanModel, WanModel_S2V, HumoWanModel)
        ):
            # Wan models use 'head' instead of 'final_layer'
            patcher.add_object_patch(
                "diffusion_model.head", self.wrap_layer(model.head)
            )

    @staticmethod
    def timestep_from_sigmas(sigmas: Tensor, sample_sigmas: Tensor):
        # we assume that one element of sample_sigmas is exactly equal to sigmas
        # but we'll still check explicitly, using an argmin, in case of some loss of precision
        s = sigmas.item()
        i = torch.argmin(torch.abs(sample_sigmas - s).flatten())
        return int(i.item())

    def skip_ratio(self, timestep: int) -> float:
        if timestep < self.config.warmup_steps:
            return 0

        if self.config.hydrate_every:
            if (
                1 + timestep - self.config.warmup_steps
            ) % self.config.hydrate_every == 0:
                return 0
        result = 1.0 - self.config.sample_ratio
        return result

    def select_indices(self, diff: Tensor, timestep: int):
        if isinstance(self.model, Flux):
            # b (h w) (c ph pw) = model_out.shape
            metric = rearrange(
                diff,
                "b s (c ph pw) -> b s ph pw c",
                ph=self.patch_size[0],
                pw=self.patch_size[1],
            )
            metric = torch.std(metric, dim=-1).mean((-1, -2))
        elif isinstance(
            self.model,
            (
                HunyuanVideo,
                WanModel,
                VaceWanModel,
                CameraWanModel,
                WanModel_S2V,
                HumoWanModel,
            ),
        ):
            # b (t h w) (c pt ph pw) = model_out.shape
            # Both HunyuanVideo and Wan models are video models with 3D patches
            metric = rearrange(
                diff,
                "b s (c pt ph pw) -> b s pt ph pw c",
                pt=self.patch_size[0],
                ph=self.patch_size[1],
                pw=self.patch_size[2],
            )
            metric = torch.std(metric, dim=-1).mean((-1, -2, -3))
        else:
            raise TypeError("Unknown latent type!")
        # for batch size > 1, we pick separate indices per batch
        # for now, JUST FOR TESTING, we'll merge all the batches and use the indices that are the most relevant for all batches
        metric = metric.mean(dim=0)
        metric = metric.flatten()
        if self.drop_count is None:
            self.drop_count = torch.zeros(
                metric.shape, dtype=torch.int, device=diff.device
            )
        # hmm, what if we do a gaussian blur or some sort of spatial lowpass, to improve the spatial continuity of the patches?
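        # RAS token selection (description of the lines below): the metric of
        # long-cached patches is inflated by exp(starvation_scale * drop_count)
        # so their rank drifts over time, all patches are sorted by the result,
        # the skip_ratio fraction at the extremes of the ranking is cached
        # (split between the low and high ends by high_ratio), and the middle
        # of the ranking stays live. drop_count tracks consecutive caching.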
        metric *= torch.exp(self.config.starvation_scale * self.drop_count)
        indices = torch.sort(metric, dim=-1, descending=False).indices
        skip_ratio = self.skip_ratio(timestep)
        if skip_ratio <= 0.01:
            # we're not dropping anything -- remove the live_indices
            # we use the value None to indicate a full hydrate
            self.live_img_indices = None
        else:
            low_bar = int(skip_ratio * len(metric) * (1 - self.config.high_ratio))
            high_bar = int(skip_ratio * len(metric) * self.config.high_ratio)
            cache_indices = torch.cat([indices[:low_bar], indices[-high_bar:]])
            self.live_img_indices = indices[low_bar:-high_bar]
            self.drop_count[cache_indices] += 1

        # TODO: for now we keep all txt tokens
        # in the future, we can probably do something like randomly keep a fraction of them
        if self.n_txt > 0:
            self.live_txt_indices = torch.arange(
                0, self.n_txt, dtype=torch.int, device=diff.device
            )

    def live_indices(self):
        if self.live_img_indices is None or self.live_txt_indices is None:
            return self.live_img_indices
        if self.flipped_img_txt:
            result = torch.cat(
                (self.live_img_indices, self.live_txt_indices + self.n_img)
            )
        else:
            result = torch.cat(
                (self.live_txt_indices, self.live_img_indices + self.n_txt)
            )
        return result


class DoubleStreamBlockWrapper(DoubleStreamBlock):
    """
    Same as the DoubleStreamBlock, but uses a RASManager and a per-block KV cache.
    """

    def __init__(self, original: DoubleStreamBlock, manager: RASManager, first=False):
        nn.Module.__init__(self)
        take_attributes_from(
            original,
            self,
            [
                "num_heads",
                "hidden_size",
                "img_mod",
                "img_norm1",
                "img_attn",
                "img_norm2",
                "img_mlp",
                "txt_mod",
                "txt_norm1",
                "txt_attn",
                "txt_norm2",
                "txt_mlp",
                "flipped_img_txt",
            ],
        )
        self.manager = manager
        self.k_cache: torch.Tensor
        self.v_cache: torch.Tensor
        self.first = first

    def forward(self, img, txt, vec, pe, attn_mask=None):
        # RAS: if this is the first doublestreamblock, then we should drop some of the img and txt tokens
        idx = self.manager.live_indices()
        if self.first:
            self.manager.n_txt = txt.shape[1]
            self.manager.n_img = img.shape[1]

            img_idx = self.manager.live_img_indices
            txt_idx = self.manager.live_txt_indices
            if idx is not None:
                img = img[..., img_idx, :]
                txt = txt[..., txt_idx, :]

        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(
            img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1
        ).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(
            txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1
        ).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

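        # RAS: note that RoPE is applied to the keys (via apply_pe) *before*
        # they are cached, so entries that stay in the cache across steps keep
        # the correct positional phase, and only the live rows of the PE are
        # needed below.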
        # RAS: KV Cache and Attention Call
        # select part of the PE
        if idx is not None:
            pe = pe[:, :, idx]

        # create queries, keys, and values
        if self.flipped_img_txt:
            queries = apply_pe(torch.cat((img_q, txt_q), dim=2), pe)
            keys = apply_pe(torch.cat((img_k, txt_k), dim=2), pe)
            values = torch.cat((img_v, txt_v), dim=2)
        else:
            queries = apply_pe(torch.cat((txt_q, img_q), dim=2), pe)
            keys = apply_pe(torch.cat((txt_k, img_k), dim=2), pe)
            values = torch.cat((txt_v, img_v), dim=2)

        # fill in the KV cache
        if idx is None:
            self.k_cache = keys
            self.v_cache = values
        else:
            self.k_cache[..., idx, :] = keys
            self.v_cache[..., idx, :] = values
        # actual attention call
        attn = optimized_attention(
            queries,
            self.k_cache,
            self.v_cache,
            img_q.shape[1],
            skip_reshape=True,
            mask=attn_mask,
        )
        if self.flipped_img_txt:
            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
        else:
            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
        # End of RAS code

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )

        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt


class SingleStreamBlockWrapper(SingleStreamBlock):
    """
    Same as the SingleStreamBlock, but uses a RASManager and a per-block KV cache.
    """

    def __init__(self, original: SingleStreamBlock, manager: RASManager, last=False):
        nn.Module.__init__(self)
        # steal all the attributes from the SingleStreamBlock
        take_attributes_from(
            original,
            self,
            [
                "hidden_dim",
                "num_heads",
                "scale",
                "mlp_hidden_dim",
                "linear1",
                "linear2",
                "norm",
                "hidden_size",
                "pre_norm",
                "mlp_act",
                "modulation",
            ],
        )
        self.manager = manager
        self.k_cache: torch.Tensor
        self.v_cache: torch.Tensor
        self.last = last

    def forward(self, x, vec, pe, attn_mask=None):
        idx = self.manager.live_indices()
        mod, _ = self.modulation(vec)
        qkv, mlp = torch.split(
            self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift),
            [3 * self.hidden_size, self.mlp_hidden_dim],
            dim=-1,
        )

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(
            2, 0, 3, 1, 4
        )
        q, k = self.norm(q, k, v)

        # RAS: KV Cache
        if idx is not None:
            pe = pe[:, :, idx]
        q = apply_pe(q, pe)
        k = apply_pe(k, pe)
        # full hydrate
        if idx is None:
            self.k_cache = k
            self.v_cache = v
        # partial update
        else:
            self.k_cache[..., idx, :] = k
            self.v_cache[..., idx, :] = v
        attn = optimized_attention(
            q,
            self.k_cache,
            self.v_cache,
            q.shape[1],
            skip_reshape=True,
            mask=attn_mask,
        )
        # End of RAS code
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x += mod.gate * output
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)

        if self.last:
            # put the relevant tokens back into the cached output
            if idx is None or self.manager.cached_output is None:
                self.manager.cached_output = x.clone()
            else:
                self.manager.cached_output[..., idx, :] = x
            return self.manager.cached_output
        return x


class LastLayerWrapper(LastLayer):
    """
    Same as the LastLayer, but reports its output to a manager.
    """

    def __init__(self, original: LastLayer, manager: RASManager):
        nn.Module.__init__(self)
        take_attributes_from(
            original, self, ["norm_final", "linear", "adaLN_modulation"]
        )
        self.manager = manager

    def forward(self, x, vec) -> Tensor:
        output = super().forward(x, vec)
        self.manager.select_indices(output, self.manager.timestep)
        return output


class WanAttentionBlockWrapper(WanAttentionBlock):
    """
    Same as the WanAttentionBlock, but uses a RASManager and a per-block KV cache.
    """

    def __init__(self, original: WanAttentionBlock, manager: RASManager, first, last):
        nn.Module.__init__(self)
        take_attributes_from(
            original,
            self,
            [
                "dim",
                "ffn_dim",
                "num_heads",
                "window_size",
                "qk_norm",
                "cross_attn_norm",
                "eps",
                "norm1",
                "self_attn",
                "norm3",
                "cross_attn",
                "norm2",
                "ffn",
                "modulation",
            ],
        )
        self.manager = manager
        self.k_cache: torch.Tensor
        self.v_cache: torch.Tensor
        self.first = first
        self.last = last

    def forward(
        self, x, e, freqs, context, context_img_len=257, transformer_options={}
    ):
        # RAS: if this is the first block, then we should drop some tokens
        idx = self.manager.live_indices()

        if self.first:
            # we just have img tokens
            self.manager.n_txt = 0
            self.manager.n_img = x.shape[1]
            if idx is not None:
                x = x[..., idx, :]

        # Modulation handling (copied from original)
        if e.ndim < 4:
            e = (
                comfy.model_management.cast_to(
                    self.modulation, dtype=x.dtype, device=x.device
                )
                + e
            ).chunk(6, dim=1)
        else:
            e = (
                comfy.model_management.cast_to(
                    self.modulation, dtype=x.dtype, device=x.device
                ).unsqueeze(0)
                + e
            ).unbind(2)

        # Self-attention with RAS caching
        y = self.self_attn_with_cache(
            torch.addcmul(
                self.repeat_e(e[0], x), self.norm1(x), 1 + self.repeat_e(e[1], x)
            ),
            freqs,
            idx,
            transformer_options=transformer_options,
        )

        x = torch.addcmul(x, y, self.repeat_e(e[2], x))

        # Cross-attention & ffn (unchanged from original)
        x = x + self.cross_attn(
            self.norm3(x),
            context,
            context_img_len=context_img_len,
            transformer_options=transformer_options,
        )
        y = self.ffn(
            torch.addcmul(
                self.repeat_e(e[3], x), self.norm2(x), 1 + self.repeat_e(e[4], x)
            )
        )
        x = torch.addcmul(x, y, self.repeat_e(e[5], x))

        if self.last:
            # put the relevant tokens back into the cached output
            if idx is None or self.manager.cached_output is None:
                self.manager.cached_output = x.clone()
            else:
                self.manager.cached_output[..., idx, :] = x
            return self.manager.cached_output
        return x

    def repeat_e(self, e, x):
        """Helper function for modulation broadcasting"""
        repeats = 1
        if e.size(1) > 1:
            repeats = x.size(1) // e.size(1)
        if repeats == 1:
            return e
        if repeats * e.size(1) == x.size(1):
            return torch.repeat_interleave(e, repeats, dim=1)
        else:
            return torch.repeat_interleave(e, repeats + 1, dim=1)[:, : x.size(1)]

    def self_attn_with_cache(self, x, freqs, idx, transformer_options={}):
        """Modified self-attention that uses KV caching"""
        b, s, n, d = *x.shape[:2], self.num_heads, self.self_attn.head_dim

        # just pull out part of the frequencies
        if idx is not None:
            freqs = freqs[:, idx]

        # Compute QKV like original Wan self-attention
        q = self.self_attn.norm_q(self.self_attn.q(x)).view(b, s, n, d)
        q = apply_rope1(q, freqs).view(b, s, n * d)
        k = self.self_attn.norm_k(self.self_attn.k(x)).view(b, s, n, d)
        k = apply_rope1(k, freqs).view(b, s, n * d)
        v = self.self_attn.v(x).view(b, s, n * d)

        # RAS: KV Cache management
        if idx is None:
            self.k_cache = k
            self.v_cache = v
        else:
            self.k_cache[:, idx, :] = k
            self.v_cache[:, idx, :] = v
        x = optimized_attention(
            q,
            self.k_cache,
            self.v_cache,
            heads=n,
            transformer_options=transformer_options,
        )

        x = self.self_attn.o(x)
        return x


class VaceWanAttentionBlockWrapper(WanAttentionBlockWrapper):
    """
    Same as the VaceWanAttentionBlock, but uses a RASManager and a per-block KV cache.
    """

    def __init__(
        self,
        original: VaceWanAttentionBlock,
        manager: RASManager,
        first=False,
        last=False,
    ):
        nn.Module.__init__(self)
        # Copy the standard WanAttentionBlock attributes (block_id is VACE-specific)
        attributes_to_copy = [
            "dim",
            "ffn_dim",
            "num_heads",
            "window_size",
            "qk_norm",
            "cross_attn_norm",
            "eps",
            "block_id",
            "norm1",
            "self_attn",
            "norm3",
            "cross_attn",
            "norm2",
            "ffn",
            "modulation",
        ]
        take_attributes_from(original, self, attributes_to_copy)

        # Copy optional attributes if they exist
        for attr in ["before_proj", "after_proj"]:
            if hasattr(original, attr):
                setattr(self, attr, getattr(original, attr))
        self.manager = manager
        self.k_cache: torch.Tensor
        self.v_cache: torch.Tensor
        self.first = first
        self.last = last

    def forward(self, c, x, **kwargs):
        if hasattr(self, "before_proj"):
            c = self.before_proj(c) + x
        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)
        return c_skip, c


class HeadWrapper(Head):
    """
    Same as the Head (Wan final layer), but reports its output to a manager.
    """

    def __init__(self, original: Head, manager: RASManager):
        nn.Module.__init__(self)
        take_attributes_from(
            original,
            self,
            ["dim", "out_dim", "patch_size", "eps", "norm", "head", "modulation"],
        )
        self.manager = manager

    def forward(self, x, e) -> Tensor:
        output = super().forward(x, e)
        self.manager.select_indices(output, self.manager.timestep)
        return output

--------------------------------------------------------------------------------