├── LICENSE
├── README.md
├── examples
│   ├── base.png
│   ├── dance_base.png
│   ├── dance_diff.png
│   ├── dance_merge.png
│   ├── real_base.png
│   ├── real_diff.png
│   ├── real_merge.png
│   ├── studying_10.png
│   ├── studying_50.png
│   ├── studying_75.png
│   ├── studying_base.png
│   ├── studying_diff.png
│   └── studying_merge.png
├── latents
│   ├── k
│   │   └── latents-saved-here.txt
│   └── v
│       └── latents-saved-here.txt
└── scripts
    └── refdrop.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Thomas Cantrell

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SD-RefDrop
This is an extension for the [ReForge version](https://github.com/Panchovix/stable-diffusion-webui-reForge) of the Automatic1111 Stable Diffusion web interface. It implements RefDrop image consistency based on [RefDrop: Controllable Consistency in Image or Video Generation via Reference Feature Guidance](https://arxiv.org/abs/2405.17661) (Fan et al. 2024). RefDrop allows for either consistency or diversification of diffusion model outputs relative to an initially recorded output.

In practice, you can find a prompt and seed that produce a character or scene you like and then apply aspects of that output to all future generations. Conversely, you can save an output with features you want to avoid and suppress those aspects in future outputs. The strength of this consistency or diversification is controlled by the RFG Coefficient, a parameter that ranges from -1 to 1. Positive values push outputs toward the saved image, while negative values push them away from it. It seems to work best in the 0.2 to 0.3 range for consistency or the -0.2 to -0.3 range for diversification.
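
Concretely, during `Use` mode the extension blends each self-attention output with the attention computed against the saved keys and values, weighted by the RFG Coefficient. A minimal sketch of that blend (the helper name `refdrop_blend` is illustrative; the actual logic lives in `scripts/refdrop.py`):

```python
import torch

def refdrop_blend(out_base: torch.Tensor, out_ref: torch.Tensor, rfg: float) -> torch.Tensor:
    # out_base: attention computed from the current generation's K/V
    # out_ref:  attention computed against the K/V recorded during the Save step
    # rfg:      RFG Coefficient in [-1, 1]; 0 reduces to ordinary attention
    return (1 - rfg) * out_base + rfg * out_ref
```
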
## Examples
The "Original" images were generated using a random seed, and the following images were made with a single different seed. For the different seed, a slight change was made to the prompt, denoted in brackets below:
```
Positive prompt:
score_9,score_8_up,score_7_up,
1girl,simple hoodie, jeans, solo,
light smile, [dancing,][library, studying,]
ponytail,looking at viewer,grey_background,wide_shot

Negative prompt:
score_6, score_5, score_4,
graphic t-shirt,
```
All were generated at 512x768 at CFG 7 using `Euler a` with 20 sampling steps. For this first set, all images were made using the WAI-CUTE-v6.0 fine-tune of SDXL.

| Original | Dancing | Consistent | Diversify |
| ------------ | ------------ | ------------ | ------------ |
| ![Base](examples/base.png) | ![Dance Base](examples/dance_base.png) | ![Dance Merge](examples/dance_merge.png) | ![Dance Diff](examples/dance_diff.png) |

| Original | Studying | Consistent | Diversify |
| ------------ | ------------ | ------------ | ------------ |
| ![Base](examples/base.png) | ![Studying Base](examples/studying_base.png) | ![Studying Merge](examples/studying_merge.png) | ![Studying Diff](examples/studying_diff.png) |

The following images use the same saved original output but are merged or diversified using a separate, realistic fine-tuned SDXL model. In practice, this method can transfer stylized aspects onto a model trained on photos via consistency, or, with a negative RFG Coefficient, emphasize the detail and realism of an output when the saved input comes from a less detailed model. The authors of the original paper also showed how this output diversification can help overcome stereotypes the model may have learned.

| Original | Studying | Consistent | Diversify |
| ------------ | ------------ | ------------ | ------------ |
| ![Base](examples/base.png) | ![Real Base](examples/real_base.png) | ![Real Merge](examples/real_merge.png) | ![Real Diff](examples/real_diff.png) |

## Usage Guide
Install using the `Extensions` tab within the ReForge web interface. Navigate to the `Install from URL` tab, enter the URL of this repository, and click `Install`. When it finishes, go to the `Installed` tab and click `Apply and restart UI`. After the interface reloads, select `RefDrop` from the list of parameters on the `txt2img` tab.

First, find a specific image you want to use as the base for consistency or diversification. Once you've found the single image output you'll use, save its seed using the recycle symbol next to the `Seed` field. Click `Enabled` in the RefDrop menu and under `Mode` select `Save`. The RFG Coefficient doesn't matter for this first step, because at this point we are only saving the network details of the base image.

> [!WARNING]
> Using `Disk` will write a large amount of data to the `extensions\refdrop\latents` folder. The small "Original" image above produced 5,602 files totaling 7.3 GB, and more detailed images using hires fix can require much more. However, this data is only written to disk during the `Save` step, and the files are deleted and replaced every time you run a new `Save`. There is also a dedicated button for deleting all RefDrop latent files on disk and data in memory.

> [!TIP]
> This extension only stores data for one base image at a time. If you have multiple images you care about, it is easiest to note their prompts and seeds and rerun the `Save` step as needed. Alternatively, you can run in `Disk` save mode and back up the contents of the `extensions\refdrop\latents` folder, but this is a lot of data.
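
For reference, each saved tensor pair is named for the U-Net layer group, sampling step, and layer index that produced it (hires-fix passes get an extra `_hires` suffix), so the contents of a `Disk` save look roughly like this:

```
extensions/refdrop/latents/
├── k/
│   ├── input_step0_layer0.pt
│   ├── input_step0_layer1.pt
│   └── ...
└── v/
    ├── input_step0_layer0.pt
    └── ...
```
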
The amount of data stored can be limited with the `Save Percentage` parameter, which restricts saving to the first portion of the sampling steps (for example, at 20 steps a setting of 50% saves K and V data only for the first 10 steps). This option also has the added benefit of decreasing the overall run time during `Use` mode. However, lowering it too far weakens the overall effect of RefDrop. Using the "studying" example, we can see how the outputs are affected at different percentages below.

| Studying | 75% | 50% | 10% |
| ------------ | ------------ | ------------ | ------------ |
| ![Studying Base](examples/studying_base.png) | ![Studying 75%](examples/studying_75.png) | ![Studying 50%](examples/studying_50.png) | ![Studying 10%](examples/studying_10.png) |

After the save step finishes and you have recreated the original image you want to use for the process, switch the mode to `Use` and set the RFG Coefficient as described earlier. While `Enabled` is checked, all outputs will use RefDrop.

> [!IMPORTANT]
> When generating new images using RefDrop, the network architecture must be the same as for the original saved image. In practical terms, this means only using models from the same lineage (for example, if the base image was made with SD1.5, only use SD1.5 fine-tunes for the output). This RefDrop implementation will use whatever embedding data is available from the `Save` step and then continue on as normal, so I have not seen any issues with changing the height, width, sampling method, etc. between `Save` and `Use`.
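
If the `Use` pass requests a step or layer that was never saved (for example, extra hires fix or ADetailer passes), the extension prints a notice and finishes the generation without RefDrop rather than failing. A typical console sequence looks like:

```
RefDrop Enabled
Applying RefDrop data
Saved RefDrop file not found. Continuing without RefDrop.
```
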
## Afterword
A big thank you to the author of this paper for coming up with the idea and taking the time to talk with me about it at NeurIPS 2024. There is a lot left to explore, including applicability to other domains such as video. One important point that still needs to be addressed for this implementation is figuring out how to prune what gets saved while still achieving the desired results. The current implementation saves every K and V value created at any point during image generation, which is probably overkill.

--------------------------------------------------------------------------------
/examples/base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/base.png
--------------------------------------------------------------------------------
/examples/dance_base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/dance_base.png
--------------------------------------------------------------------------------
/examples/dance_diff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/dance_diff.png
--------------------------------------------------------------------------------
/examples/dance_merge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/dance_merge.png
--------------------------------------------------------------------------------
/examples/real_base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/real_base.png
--------------------------------------------------------------------------------
/examples/real_diff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/real_diff.png
--------------------------------------------------------------------------------
/examples/real_merge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/real_merge.png
--------------------------------------------------------------------------------
/examples/studying_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/studying_10.png
--------------------------------------------------------------------------------
/examples/studying_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/studying_50.png
--------------------------------------------------------------------------------
/examples/studying_75.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/studying_75.png
--------------------------------------------------------------------------------
/examples/studying_base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/studying_base.png
--------------------------------------------------------------------------------
/examples/studying_diff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/studying_diff.png
--------------------------------------------------------------------------------
/examples/studying_merge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tocantrell/sd-refdrop/e880a8190343bc11ba29bc405742b6d31a75fdab/examples/studying_merge.png
--------------------------------------------------------------------------------
/latents/k/latents-saved-here.txt:
--------------------------------------------------------------------------------
When run in Save mode, the RefDrop code will save a large amount of network data in this folder.
Any new Save run will first delete the previous data.
--------------------------------------------------------------------------------
/latents/v/latents-saved-here.txt:
--------------------------------------------------------------------------------
When run in Save mode, the RefDrop code will save a large amount of network data in this folder.
Any new Save run will first delete the previous data.
--------------------------------------------------------------------------------
/scripts/refdrop.py:
--------------------------------------------------------------------------------
import os
import glob

import gradio as gr
import torch

import modules.scripts as scripts
from ldm_patched.ldm.modules.attention import CrossAttention, BasicTransformerBlock
import ldm_patched.ldm.modules.attention as attention

current_extension_directory = scripts.basedir()

def remove_latent_files(save_loc):
    #Delete any K and V data left over from a previous Save run
    if save_loc == 'Disk':
        files = glob.glob(current_extension_directory + '/latents/k/*.pt')
        for f in files:
            os.remove(f)
        files = glob.glob(current_extension_directory + '/latents/v/*.pt')
        for f in files:
            os.remove(f)
    else:
        try:
            del CrossAttention.k_dict
        except AttributeError:
            pass
        try:
            del CrossAttention.v_dict
        except AttributeError:
            pass
        CrossAttention.k_dict = {}
        CrossAttention.v_dict = {}

def remove_all_latents():
    remove_latent_files('Disk')
    remove_latent_files('RAM')

class Script(scripts.Script):

    def title(self):
        return "RefDrop"

    def show(self, is_txt2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with gr.Group():
            with gr.Accordion("RefDrop", open=False):
                enabled = gr.Checkbox(label="Enabled", value=False)

                with gr.Row(equal_height=True):
                    save_or_use = gr.Radio(["Save", "Use"], value="Save", label="Mode",
                                           info="You must first generate a single image to record its embedding information. Caution: Running \"Save\" a second time will overwrite existing data.")
                    enabled_hr = gr.Checkbox(label="Enabled for hires fix", value=False)

                rfg = gr.Slider(minimum=-1.0, maximum=1.0, step=0.01, value=0.,
                                label="RFG Coefficient",
                                info="RFG is only used when applying saved data to a new image. Positive values increase consistency with the saved data while negative values do the opposite.")

                with gr.Row(equal_height=True):
                    save_loc = gr.Radio(["RAM", "Disk"], value="RAM", label="Latent Store Location", info="Choose 'Disk' if low on memory.")
                    save_percent = gr.Slider(minimum=0, maximum=100, step=1, value=100,
                                             label="Save Percentage",
                                             info="Reduce run time by limiting the number of embedding files saved. Minimal impact >=50%")
                with gr.Row():
                    layer_input = gr.Checkbox(label='input', value=True, info='Select which layer groups to use. Use mode must use layers that were selected during Save mode, though fewer may be selected during Use mode.')
                    delete_button = gr.Button('Delete Saved RefDrop Latents', size='sm', scale=0)
                    layer_middle = gr.Checkbox(label='middle', value=True)
                    layer_output = gr.Checkbox(label='output', value=True)

                delete_button.click(remove_all_latents)

        return [enabled, rfg, save_or_use, enabled_hr, save_loc, save_percent, layer_input, layer_middle, layer_output]
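
    #The hooks below are called by ReForge at different points in the
    #generation loop. They maintain class-level counters on CrossAttention
    #(current step, layer group, layer index) so that each saved K/V tensor
    #can be matched to the same attention call during Use mode.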
    def process_before_every_step(
        self,
        p,
        enabled,
        rfg,
        save_or_use,
        enabled_hr,
        save_loc,
        save_percent,
        layer_input,
        layer_middle,
        layer_output,
        *args,
        **kwargs
    ):

        if enabled:
            CrossAttention.current_step += 1
            CrossAttention.layer_index = 0

    def process_before_every_sampling(
        self,
        p,
        enabled,
        rfg,
        save_or_use,
        enabled_hr,
        save_loc,
        save_percent,
        layer_input,
        layer_middle,
        layer_output,
        *args,
        **kwargs
    ):

        if enabled:
            #Reset the layer name to the start
            CrossAttention.layer_name = 'input'
            CrossAttention.layer_index = 0

            #Disable after the initial pass if "Enabled for hires fix" is not selected
            if p.is_hr_pass:
                if not enabled_hr:
                    print('Not using RefDrop for hires fix')
                    CrossAttention.refdrop = 'Done'
                else:
                    CrossAttention.refdrop_hires = True

    def before_process_batch(
        self,
        p,
        enabled,
        rfg,
        save_or_use,
        enabled_hr,
        save_loc,
        save_percent,
        layer_input,
        layer_middle,
        layer_output,
        *args,
        **kwargs
    ):

        #Record which layer groups were selected in the UI
        layer_list = ['input', 'middle', 'output']
        layer_selected = [layer_input, layer_middle, layer_output]
        CrossAttention.layer_refdrop = [name for name, selected in zip(layer_list, layer_selected) if selected]

        if enabled:
            print('RefDrop Enabled')

            #Limit how many steps are saved or used, per the Save Percentage slider
            if save_percent != 100:
                CrossAttention.max_step = int(p.steps * save_percent / 100.)
            else:
                CrossAttention.max_step = p.steps

            CrossAttention.rfg = rfg
            CrossAttention.current_step = 0
            CrossAttention.layer_name = 'input'
            CrossAttention.layer_index = 0
            CrossAttention.refdrop_hires = False
            CrossAttention.to_disk = False
            if save_loc == 'Disk':
                CrossAttention.to_disk = True

            if save_or_use == 'Use':
                print('Applying RefDrop data')
                CrossAttention.refdrop = 'Use'

            if save_or_use == 'Save':
                print('Saving RefDrop data')
                CrossAttention.refdrop = 'Save'
                #Delete existing latent data before writing the new set
                remove_latent_files(save_loc)
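
            #The two nested functions below are monkey patches installed while
            #RefDrop is enabled. _forwardBasicTransformerBlock mirrors the stock
            #ldm_patched BasicTransformerBlock._forward, adding lookups of the
            #saved K/V tensors for the current layer and step before the
            #self-attention (attn1) call.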
            def _forwardBasicTransformerBlock(self, x, context=None, transformer_options={}):
                extra_options = {}
                block = transformer_options.get("block", None)
                block_index = transformer_options.get("block_index", 0)
                transformer_patches = {}
                transformer_patches_replace = {}

                #Reset the layer index when moving to a new layer group
                if block is not None and CrossAttention.layer_name != block[0]:
                    CrossAttention.layer_name = block[0]
                    CrossAttention.layer_index = 0

                #Define the file save or read location for this layer and step
                if CrossAttention.refdrop_hires:
                    hires = '_hires'
                else:
                    hires = ''
                latentname = CrossAttention.layer_name + '_step' + str(CrossAttention.current_step) + '_layer' + str(CrossAttention.layer_index) + hires
                if CrossAttention.to_disk:
                    k_file = current_extension_directory + '/latents/k/' + latentname + '.pt'
                    v_file = current_extension_directory + '/latents/v/' + latentname + '.pt'
                else:
                    k_file = latentname
                    v_file = latentname
                refdrop_save = False
                refdrop_use = False

                #Stop saving or applying data once past the Save Percentage step limit
                if CrossAttention.max_step <= CrossAttention.current_step:
                    CrossAttention.refdrop = 'Done'
                    v_refdrop = None
                    k_refdrop = None

                if (CrossAttention.refdrop == 'Use') and CrossAttention.to_disk and (CrossAttention.layer_name in CrossAttention.layer_refdrop):
                    try:
                        v_refdrop = torch.load(v_file, weights_only=True)
                        k_refdrop = torch.load(k_file, weights_only=True)
                        refdrop_use = True
                    except Exception:
                        #Running without the last few K and V files will not significantly change the results.
                        #This also allows for variable hires fix and adetailer settings.
                        print('Saved RefDrop file not found. Continuing without RefDrop.')
                        CrossAttention.refdrop = 'Done'
                        v_refdrop = None
                        k_refdrop = None
                elif (CrossAttention.refdrop == 'Use') and (not CrossAttention.to_disk) and (CrossAttention.layer_name in CrossAttention.layer_refdrop):
                    try:
                        v_refdrop = CrossAttention.v_dict[v_file].to('cuda')
                        k_refdrop = CrossAttention.k_dict[k_file].to('cuda')
                        refdrop_use = True
                    except Exception:
                        #Running without the last few K and V entries will not significantly change the results.
                        #This also allows for variable hires fix and adetailer settings.
                        print('Saved RefDrop file not found. Continuing without RefDrop.')
                        CrossAttention.refdrop = 'Done'
                        v_refdrop = None
                        k_refdrop = None
                elif CrossAttention.refdrop == 'Save':
                    v_refdrop = None
                    k_refdrop = None
                    refdrop_save = True
                else:
                    v_refdrop = None
                    k_refdrop = None
                    refdrop_use = False
                    refdrop_save = False
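
                #Everything from here down follows the stock ldm_patched
                #BasicTransformerBlock._forward logic. The only RefDrop changes
                #are the extra arguments passed to the attn1 call and the
                #layer_index bookkeeping at the end.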
                for k in transformer_options:
                    if k == "patches":
                        transformer_patches = transformer_options[k]
                    elif k == "patches_replace":
                        transformer_patches_replace = transformer_options[k]
                    else:
                        extra_options[k] = transformer_options[k]

                extra_options["n_heads"] = self.n_heads
                extra_options["dim_head"] = self.d_head

                if self.ff_in:
                    x_skip = x
                    x = self.ff_in(self.norm_in(x))
                    if self.is_res:
                        x += x_skip

                n = self.norm1(x)
                if self.disable_self_attn:
                    context_attn1 = context
                else:
                    context_attn1 = None
                value_attn1 = None

                if "attn1_patch" in transformer_patches:
                    patch = transformer_patches["attn1_patch"]
                    if context_attn1 is None:
                        context_attn1 = n
                    value_attn1 = context_attn1
                    for p in patch:
                        n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

                if block is not None:
                    transformer_block = (block[0], block[1], block_index)
                else:
                    transformer_block = None
                attn1_replace_patch = transformer_patches_replace.get("attn1", {})
                block_attn1 = transformer_block
                if block_attn1 not in attn1_replace_patch:
                    block_attn1 = block

                if block_attn1 in attn1_replace_patch:
                    if context_attn1 is None:
                        context_attn1 = n
                        value_attn1 = n
                    n = self.attn1.to_q(n)
                    context_attn1 = self.attn1.to_k(context_attn1)
                    value_attn1 = self.attn1.to_v(value_attn1)
                    n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
                    n = self.attn1.to_out(n)
                else:
                    #Apply RefDrop if the current layer group is in the selected list
                    if CrossAttention.layer_name in CrossAttention.layer_refdrop:
                        n = self.attn1(
                            n,
                            context=context_attn1,
                            value=value_attn1,
                            transformer_options=extra_options,
                            k_refdrop=k_refdrop,
                            v_refdrop=v_refdrop,
                            refdrop_save=refdrop_save,
                            refdrop_use=refdrop_use,
                            k_file=k_file,
                            v_file=v_file
                        )
                    else:
                        n = self.attn1(
                            n,
                            context=context_attn1,
                            value=value_attn1,
                            transformer_options=extra_options
                        )

                if "attn1_output_patch" in transformer_patches:
                    patch = transformer_patches["attn1_output_patch"]
                    for p in patch:
                        n = p(n, extra_options)

                x += n
                if "middle_patch" in transformer_patches:
                    patch = transformer_patches["middle_patch"]
                    for p in patch:
                        x = p(x, extra_options)

                if self.attn2 is not None:
                    n = self.norm2(x)
                    if self.switch_temporal_ca_to_sa:
                        context_attn2 = n
                    else:
                        context_attn2 = context
                    value_attn2 = None
                    if "attn2_patch" in transformer_patches:
                        patch = transformer_patches["attn2_patch"]
                        value_attn2 = context_attn2
                        for p in patch:
                            n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

                    attn2_replace_patch = transformer_patches_replace.get("attn2", {})
                    block_attn2 = transformer_block
                    if block_attn2 not in attn2_replace_patch:
                        block_attn2 = block
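
                    #attn2 (cross-attention with the prompt conditioning) is
                    #left unmodified; RefDrop only blends the self-attention
                    #(attn1) keys and values.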
                    if block_attn2 in attn2_replace_patch:
                        if value_attn2 is None:
                            value_attn2 = context_attn2
                        n = self.attn2.to_q(n)
                        context_attn2 = self.attn2.to_k(context_attn2)
                        value_attn2 = self.attn2.to_v(value_attn2)
                        n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
                        n = self.attn2.to_out(n)
                    else:
                        n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=extra_options)

                    if "attn2_output_patch" in transformer_patches:
                        patch = transformer_patches["attn2_output_patch"]
                        for p in patch:
                            n = p(n, extra_options)

                    x += n

                if self.is_res:
                    x_skip = x
                x = self.ff(self.norm3(x))
                if self.is_res:
                    x += x_skip

                #Advance to the next layer within the current layer group
                CrossAttention.layer_index += 1

                return x

            BasicTransformerBlock._forward = _forwardBasicTransformerBlock

            def forward_crossattention(
                self,
                x,
                context=None,
                value=None,
                mask=None,
                transformer_options=None,
                k_refdrop=None,
                v_refdrop=None,
                refdrop_save=False,
                refdrop_use=False,
                k_file=None,
                v_file=None
            ):

                q = self.to_q(x)
                context = attention.default(context, x)
                k = self.to_k(context)
                if value is not None:
                    v = self.to_v(value)
                    del value
                else:
                    v = self.to_v(context)

                if refdrop_save:
                    if CrossAttention.to_disk:
                        #Save K and V to files on disk
                        torch.save(k, k_file)
                        torch.save(v, v_file)
                    else:
                        #Save K and V to memory via a dictionary
                        CrossAttention.k_dict.update({k_file: k.to('cpu')})
                        CrossAttention.v_dict.update({v_file: v.to('cpu')})

                if mask is None:
                    out = attention.optimized_attention(q, k, v, self.heads)
                    if refdrop_use:
                        out_refdrop = attention.optimized_attention(q, k_refdrop, v_refdrop, self.heads)
                else:
                    out = attention.optimized_attention_masked(q, k, v, self.heads, mask)
                    if refdrop_use:
                        #The reference attention is computed without the mask
                        out_refdrop = attention.optimized_attention(q, k_refdrop, v_refdrop, self.heads)

                if refdrop_use:
                    #Blend the normal attention output with the reference attention output
                    out = (out * (1 - CrossAttention.rfg)) + (out_refdrop * CrossAttention.rfg)

                return self.to_out(out)

            CrossAttention.forward = forward_crossattention

        else:

            CrossAttention.layer_index = 0
            CrossAttention.refdrop = None
            CrossAttention.rfg = rfg
            CrossAttention.current_step = 0
            CrossAttention.layer_name = 'input'

--------------------------------------------------------------------------------