├── .gitignore
├── LICENSE.md
├── README.md
├── assets
│   ├── 00007-1449410826.png
│   ├── 00008-1449410826.png
│   ├── 00009-1449410826.png
│   ├── 00012-1449410826.png
│   ├── 00013-1449410826.png
│   ├── 00014-1449410826.png
│   ├── 00015-1449410826.png
│   ├── 00016-1449410826.png
│   ├── 00017-1449410826.png
│   ├── tmp6vhmj4ty.png
│   ├── tmpumkrx_oc.png
│   └── tmpzxoq_cn7.png
├── loractl
│   └── lib
│       ├── __init__.py
│       ├── lora_ctl_network.py
│       ├── network_patch.py
│       ├── plot.py
│       └── utils.py
├── scripts
│   └── loractl.py
└── test
    ├── __init__.py
    └── tests.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright 2023 Chris Heald
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LoRA Control - Dynamic Weights Controller
2 |
3 | This is an extension for the [Automatic1111 Stable Diffusion web interface](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which replaces the standard built-in LoraNetwork with one that understands additional syntax for mutating lora weights over the course of a generation. Both positional and named arguments are honored, and separate control over the normal versus high-res passes is provided.
4 |
5 | The idea was inspired by the step control in [composable lora](https://github.com/a2569875/stable-diffusion-webui-composable-lora), which unfortunately doesn't work with newer versions of the webui.
6 |
7 | Quick features:
8 |
9 | * Easily specify keyframe weights for loras at arbitrary points.
10 | * Extends the existing lora syntax; no new systems to learn.
11 | * Provides separate control of lora weights over the initial and hires passes.
12 |
13 | ### Installation
14 |
15 | **⚠️ This extension only works with Automatic1111 1.5 RC or later. It is directly tied to the new network handling architecture. It does not currently work with SD.Next, and will not unless SD.Next adopts the 1.5 extra network architecture.**
16 |
17 | In the Automatic1111 UI, go to Extensions, then "Install from URL" and enter
18 |
19 | ```
20 | https://github.com/cheald/sd-webui-loractl
21 | ```
22 |
23 | Then go to the "Installed" tab and click "Apply and Restart UI". You don't have to enable the extension; it is "always on". If you don't use the extended syntax, the behavior is identical to stock.
24 |
25 | ### Basic Usage
26 |
27 | The standard built-in Lora syntax is:
28 |
29 |     <lora:network_name:weight>
30 |
31 | or with named parameters:
32 |
33 |     <lora:network_name:te=te_weight:unet=unet_weight>
34 |
35 | This extension extends this syntax so that rather than just single weights, you can provide a list of `weight@step` pairs, delimited by commas or semicolons. The weight will then be interpolated per step appropriately.
36 |
37 | The weight may be:
38 |
39 | * A single number, which will be used for all steps (e.g., `1.0`)
40 | * A comma-separated list of weight-at-step pairs, e.g. `0@0,1@0.5,0@1` to start at 0, reach 1.0 strength halfway through generation, then scale back down to 0 by the end of generation. This is smoothly interpolated, so the weight curve looks something like:
41 |
42 | ![](assets/tmpumkrx_oc.png)
43 |
44 | The step value (after the `@`) may be a float in the 0.0 to 1.0 range, in which case it is interpreted as a fraction of that pass's total steps. If it is greater than 1, it is interpreted as an absolute step number.
45 |
46 | If only a single weight argument (or just `te=`) is given, it applies to both the text encoder and the unet.
47 |
48 | The default weight for the network at step 0 is the earliest weight given. That is, given a step weight of `0.25@0.5,1@1`, the weight will begin at 0.25, stay there until half the steps are run, then interpolate up to 1.0 for the final step.
49 |
50 | ### Network Aliases
51 |
52 | You can also use `loractl` as the network name; this is functionally identical, but may let you dodge compatibility issues with other network handling, and will cause loractl networks to simply do nothing when the extension is not enabled.
53 |
54 |     <loractl:network_name:0@0,1@0.5,0@1>
55 |
56 | ### Separate high-res pass control
57 |
58 | You can use the named arguments `hr`, `hrte`, and `hrunet` to specify weights for the whole lora, or for just the te/unet, during the high-res pass. For example, you could apply a lora at half weight during the first pass, and full weight during the HR pass, with:
59 |
60 |     <lora:network_name:0.5:hr=1.0>
61 |
62 | ![](assets/tmp6vhmj4ty.png)
63 |
64 | Or, you could grow the lora strength during the first pass, and then let it decline during the HR pass:
65 |
66 |     <lora:network_name:0@0,1@1:hr=1@0,0@1>
67 |
68 | ### Lora mixing
69 |
70 | Sometimes, one lora or another is too powerful. Either it overpowers the base model, or it overpowers other loras in the prompt. For example, I have the [Mechanical Bird](https://civitai.com/models/98218/mechanical-bird) and [Star Wars AT-AT](https://civitai.com/models/97961/star-wars-at-at-1980) loras together in a prompt. I want a sweet birdlike cybernetic AT walker!
71 |
72 | So I first try just throwing them together at full strength (the lora filenames in these examples are illustrative; substitute your own):
73 |
74 | ```
75 | <lora:mechanical_bird:1> <lora:star_wars_atat:1> mechanical bird, st4rw4ar5at4t
76 | ```
77 |
78 | ![](assets/00007-1449410826.png)
79 |
80 | The AT lora is clearly too powerful, so I'll try mixing them together more conservatively:
81 |
82 | ```
83 | mechanical bird, st4rw4ar5at4t
84 | ```
85 |
86 | ![](assets/00008-1449410826.png)
87 |
88 | The bird hardly comes through; the AT-AT lora is clearly far more heavily overtrained.
89 |
90 | I can try reducing the AT-AT lora weight to let more of the bird come through:
91 |
92 | ```
93 | mechanical bird, st4rw4ar5at4t
94 | ```
95 | ![](assets/00012-1449410826.png)
96 |
97 | That AT-AT model is just way too strong, and we can't get enough bird to come through. This is where we can use loractl!
98 |
99 | ```
100 | <lora:star_wars_atat:0@0,1@0.4> <lora:mechanical_bird:1@0,0.5@0.5> mechanical bird, st4rw4ar5at4t
101 | ```
102 |
103 | Here, I'm going to set the AT lora's weight to 0 to start, ramping up to full strength by 40% of the way through the pass. The mechanical bird will start at full strength, and will ramp down to 50% strength by the halfway point.
104 |
105 | ![](assets/00009-1449410826.png)
106 |
107 | That's more like it!
108 |
109 | We can see how the lora weights were applied over the course of the run:
110 |
111 | ![](assets/tmpzxoq_cn7.png)
112 |
113 | ### Lora warmup
114 |
115 | Sometimes, a lora has useful elements that we want in an image, but it conflicts with the base model. For this example, I'm using Realistic Vision 4.0 (a hyper-realistic model) and a Pixelart lora (which plays well with anime models, but which fights with realistic models).
116 |
117 | I want a picture of an awesome mecha at sunset. First we'll get a baseline with the pixelart lora at 0 strength:
118 |
119 | ```
120 | <lora:pixelart:0> pixelart, mecha on the beach, beautiful sunset, epic lighting
121 | ```
122 |
123 | ![](assets/00013-1449410826.png)
124 |
125 | Sweet! That looks awesome, but let's pixelify it:
126 |
127 | ```
128 | <lora:pixelart:1> pixelart, mecha on the beach, beautiful sunset, epic lighting
129 | ```
130 |
131 | ![](assets/00014-1449410826.png)
132 |
133 | The clash between the animated and realistic models is evident here: they fight for initial control of the image, and the result ends up looking like something you'd have played in the 80s on a Tandy 1000.
134 |
135 | So, we're going to hold the pixelart lora at zero strength for the first 5 steps, THEN turn it on. This lets the underlying model determine the overall compositional elements of the image before the lora starts exerting its influence.
136 |
137 | ```
138 | <lora:pixelart:0@5,1@6> pixelart, mecha on the beach, beautiful sunset, epic lighting
139 | ```
140 |
141 | ![](assets/00015-1449410826.png)
142 |
143 | Awesome.
144 |
145 | ### Separate text encoder/unet control
146 |
147 | The new 1.5.0 RC allows for separate control of the text encoder and unet weights in the lora syntax. loractl allows for variable control of each of them independently, as well:
148 |
149 | ```
150 | pixelart, mecha on the beach, beautiful sunset, epic lighting
151 | ```
152 |
153 | ![](assets/00016-1449410826.png)
154 |
155 | ```
156 | pixelart, mecha on the beach, beautiful sunset, epic lighting
157 | ```
158 |
159 | ![](assets/00017-1449410826.png)
160 |
161 | You can play with each of the weights individually to achieve the effect and model mix you want. For example, `<lora:pixelart:te=1:unet=0@0,1@0.5>` would hold the text encoder at full strength while ramping the unet in over the first half of the pass.
162 |
163 | ### Running tests
164 |
165 | A basic test suite is included to assert that parsing and setup of weight params is correct. Invoke it with:
166 |
167 |     python -m unittest discover test
168 |
169 | Please note that the extension needs to be installed into a working webui install to be tested, as it relies on imports from the webui itself.
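170 |
171 | If you just want a feel for how a `weight@step` spec turns into per-step weights without a webui install, here's a rough standalone sketch of the idea (illustrative only; the real parser, `sorted_positions` in `loractl/lib/utils.py`, also handles semicolons, bare weights, and absolute step values):
172 |
173 | ```
174 | import numpy as np
175 |
176 | def parse(spec):  # "0@0,1@0.5,0@1" -> ([0, 1, 0], [0, 0.5, 1])
177 |     pairs = sorted((p.split("@") for p in spec.split(",")), key=lambda s: float(s[1]))
178 |     return [float(w) for w, _ in pairs], [float(s) for _, s in pairs]
179 |
180 | weights, steps = parse("0@0,1@0.5,0@1")
181 | # At 25% of the way through the pass, the weight is halfway up the ramp:
182 | print(np.interp(0.25, steps, weights))  # 0.5
183 | ```
184 |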
--------------------------------------------------------------------------------
/assets/00007-1449410826.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/00007-1449410826.png
--------------------------------------------------------------------------------
/assets/00008-1449410826.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/00008-1449410826.png
--------------------------------------------------------------------------------
/assets/00009-1449410826.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/00009-1449410826.png
--------------------------------------------------------------------------------
/assets/00012-1449410826.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/00012-1449410826.png
--------------------------------------------------------------------------------
/assets/00013-1449410826.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/00013-1449410826.png
--------------------------------------------------------------------------------
/assets/00014-1449410826.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/00014-1449410826.png
--------------------------------------------------------------------------------
/assets/00015-1449410826.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/00015-1449410826.png
--------------------------------------------------------------------------------
/assets/00016-1449410826.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/00016-1449410826.png
--------------------------------------------------------------------------------
/assets/00017-1449410826.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/00017-1449410826.png
--------------------------------------------------------------------------------
/assets/tmp6vhmj4ty.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/tmp6vhmj4ty.png
--------------------------------------------------------------------------------
/assets/tmpumkrx_oc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/tmpumkrx_oc.png
--------------------------------------------------------------------------------
/assets/tmpzxoq_cn7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/assets/tmpzxoq_cn7.png
--------------------------------------------------------------------------------
/loractl/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/loractl/lib/__init__.py
--------------------------------------------------------------------------------
/loractl/lib/lora_ctl_network.py:
--------------------------------------------------------------------------------
1 | from modules import extra_networks, script_callbacks, shared
2 | from loractl.lib import utils
3 |
4 | import sys, importlib
5 | from pathlib import Path
6 |
7 | # extensions-builtin isn't normally referenceable due to the dash; this hacks around that
8 | lora_path = str(Path(__file__).parent.parent.parent.parent.parent / "extensions-builtin" / "Lora")
9 | sys.path.insert(0, lora_path)
10 | import network, networks, network_lora, extra_networks_lora
11 | sys.path.remove(lora_path)
12 |
13 | lora_weights = {}
14 |
15 |
16 | def reset_weights():
17 |     lora_weights.clear()
18 |
19 |
20 | class LoraCtlNetwork(extra_networks_lora.ExtraNetworkLora):
21 |     # Hijack the params parser and feed it dummy weights instead so it doesn't choke trying to
22 |     # parse our extended syntax
23 |     def activate(self, p, params_list):
24 |         if not utils.is_active():
25 |             return super().activate(p, params_list)
26 |
27 |         for params in params_list:
28 |             assert params.items
29 |             name = params.positional[0]
30 |             if lora_weights.get(name) is None:
31 |                 lora_weights[name] = utils.params_to_weights(params)
32 |             # The hardcoded 1 weight is fine here, since our actual patch looks up the weights from
33 |             # our lora_weights dict
34 |             params.positional = [name, 1]
35 |             params.named = {}
36 |         return super().activate(p, params_list)
--------------------------------------------------------------------------------
/loractl/lib/network_patch.py:
--------------------------------------------------------------------------------
1 | from modules import shared
2 | import numpy as np
3 |
4 | from loractl.lib.lora_ctl_network import network, lora_weights
5 | from loractl.lib.utils import calculate_weight, is_hires
6 |
7 | # Patch network.Network so it reapplies properly for dynamic weights.
8 | # By default, network application is cached, with (name, te, unet, dim) as the key.
9 | # By replacing the bare properties with getters, we can ensure that we cause SD
10 | # to reapply the network each time we change its weights, while still taking advantage
11 | # of caching when weights are not updated.
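12 | #
13 | # Concretely, the getters below look up this network's name in lora_weights and
14 | # interpolate a weight for the current sampling step; when the interpolated value
15 | # changes, the cache key changes and the network is reapplied.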
16 |
17 |
18 | def get_weight(m):
19 |     return calculate_weight(m, shared.state.sampling_step, shared.state.sampling_steps, step_offset=2)
20 |
21 |
22 | def get_dynamic_te(self):
23 |     if self.name in lora_weights:
24 |         key = "te" if not is_hires() else "hrte"
25 |         w = lora_weights[self.name]
26 |         return get_weight(w.get(key, self._te_multiplier))
27 |
28 |     return get_weight(self._te_multiplier)
29 |
30 |
31 | def get_dynamic_unet(self):
32 |     if self.name in lora_weights:
33 |         key = "unet" if not is_hires() else "hrunet"
34 |         w = lora_weights[self.name]
35 |         return get_weight(w.get(key, self._unet_multiplier))
36 |
37 |     return get_weight(self._unet_multiplier)
38 |
39 |
40 | def set_dynamic_te(self, value):
41 |     self._te_multiplier = value
42 |
43 |
44 | def set_dynamic_unet(self, value):
45 |     self._unet_multiplier = value
46 |
47 |
48 | def apply():
49 |     if getattr(network.Network, "te_multiplier", None) is None:
50 |         network.Network.te_multiplier = property(get_dynamic_te, set_dynamic_te)
51 |         network.Network.unet_multiplier = property(get_dynamic_unet, set_dynamic_unet)
--------------------------------------------------------------------------------
/loractl/lib/plot.py:
--------------------------------------------------------------------------------
1 | import io
2 | from PIL import Image
3 | from modules import script_callbacks
4 | import matplotlib
5 | import pandas as pd
6 | from loractl.lib.lora_ctl_network import networks
7 |
8 | log_weights = []
9 | log_names = []
10 | last_plotted_step = -1
11 |
12 |
13 | # Copied from composable_lora
14 | def plot_lora_weight(lora_weights, lora_names):
15 |     data = pd.DataFrame(lora_weights, columns=lora_names)
16 |     ax = data.plot()
17 |     ax.set_xlabel("Steps")
18 |     ax.set_ylabel("LoRA weight")
19 |     ax.set_title("LoRA weight in all steps")
20 |     ax.legend(loc=0)
21 |     result_image = fig2img(ax)
22 |     matplotlib.pyplot.close(ax.figure)
23 |     del ax
24 |     return result_image
25 |
26 |
27 | # Copied from composable_lora
28 | def fig2img(fig):
29 |     buf = io.BytesIO()
30 |     fig.figure.savefig(buf)
31 |     buf.seek(0)
32 |     img = Image.open(buf)
33 |     return img
34 |
35 |
36 | def reset_plot():
37 |     global last_plotted_step
38 |     last_plotted_step = -1
39 |     log_weights.clear()
40 |     log_names.clear()
41 |
42 |
43 | def make_plot():
44 |     return plot_lora_weight(log_weights, log_names)
45 |
46 |
47 | # On each step, capture our lora weights for plotting
48 | def on_step(params):
49 |     global last_plotted_step
50 |     if last_plotted_step == params.sampling_step and len(log_weights) > 0:
51 |         log_weights.pop()
52 |     last_plotted_step = params.sampling_step
53 |     if len(log_names) == 0:
54 |         for net in networks.loaded_networks:
55 |             log_names.append(net.name + "_te")
56 |             log_names.append(net.name + "_unet")
57 |     frame = []
58 |     for net in networks.loaded_networks:
59 |         frame.append(net.te_multiplier)
60 |         frame.append(net.unet_multiplier)
61 |     log_weights.append(frame)
62 |
63 |
64 | script_callbacks.on_cfg_after_cfg(on_step)
--------------------------------------------------------------------------------
/loractl/lib/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 |
4 | # Given a string like "x@y,z@a", returns [[x, z], [y, a]] sorted for consumption by np.interp
5 |
6 |
7 | def sorted_positions(raw_steps):
8 |     steps = [[float(s.strip()) for s in re.split("[@~]", x)]
9 |              for x in re.split("[,;]", str(raw_steps))]
10 |     # If we just got a single number, just return it
11 |     if len(steps[0]) == 1:
12 |         return steps[0][0]
13 |
14 |     # Add an implicit step of 1 (end of generation) to any entries that don't specify one
15 |     steps = [[s[0], s[1] if len(s) == 2 else 1] for s in steps]
16 |
17 |     # Sort by step
18 |     steps.sort(key=lambda k: k[1])
19 |
20 |     steps = [list(v) for v in zip(*steps)]
21 |     return steps
22 |
23 |
24 | def calculate_weight(m, step, max_steps, step_offset=2):
25 |     if isinstance(m, list):
26 |         if m[1][-1] <= 1.0:
27 |             if max_steps > 0:
28 |                 step = step / (max_steps - step_offset)
29 |             else:
30 |                 step = 1.0
31 |         v = np.interp(step, m[1], m[0])
32 |         return v
33 |     else:
34 |         return m
35 |
36 |
37 | def params_to_weights(params):
38 |     weights = {"unet": None, "te": 1.0, "hrunet": None, "hrte": None}
39 |
40 |     if len(params.positional) > 1:
41 |         weights["te"] = sorted_positions(params.positional[1])
42 |
43 |     if len(params.positional) > 2:
44 |         weights["unet"] = sorted_positions(params.positional[2])
45 |
46 |     if params.named.get("te"):
47 |         weights["te"] = sorted_positions(params.named.get("te"))
48 |
49 |     if params.named.get("unet"):
50 |         weights["unet"] = sorted_positions(params.named.get("unet"))
51 |
52 |     if params.named.get("hr"):
53 |         weights["hrunet"] = sorted_positions(params.named.get("hr"))
54 |         weights["hrte"] = sorted_positions(params.named.get("hr"))
55 |
56 |     if params.named.get("hrunet"):
57 |         weights["hrunet"] = sorted_positions(params.named.get("hrunet"))
58 |
59 |     if params.named.get("hrte"):
60 |         weights["hrte"] = sorted_positions(params.named.get("hrte"))
61 |
62 |     # If unet ended up unset, then use the te value
63 |     weights["unet"] = weights["unet"] if weights["unet"] is not None else weights["te"]
64 |     # If hrunet ended up unset, use the unet value
65 |     weights["hrunet"] = weights["hrunet"] if weights["hrunet"] is not None else weights["unet"]
66 |     # If hrte ended up unset, use the te value
67 |     weights["hrte"] = weights["hrte"] if weights["hrte"] is not None else weights["te"]
68 |
69 |     return weights
70 |
71 |
72 | hires = False
73 | loractl_active = True
74 |
75 |
76 | def is_hires():
77 |     return hires
78 |
79 |
80 | def set_hires(value):
81 |     global hires
82 |     hires = value
83 |
84 |
85 | def set_active(value):
86 |     global loractl_active
87 |     loractl_active = value
88 |
89 |
90 | def is_active():
91 |     return loractl_active
92 |
--------------------------------------------------------------------------------
/scripts/loractl.py:
--------------------------------------------------------------------------------
1 | import modules.scripts as scripts
2 | from modules import extra_networks
3 | from modules.processing import StableDiffusionProcessing
4 | import gradio as gr
5 | from loractl.lib import utils, plot, lora_ctl_network, network_patch
6 |
7 |
8 | class LoraCtlScript(scripts.Script):
9 |     def __init__(self):
10 |         self.original_network = None
11 |         super().__init__()
12 |
13 |     def title(self):
14 |         return "Dynamic Lora Weights"
15 |
16 |     def show(self, is_img2img):
17 |         return scripts.AlwaysVisible
18 |
19 |     def ui(self, is_img2img):
20 |         with gr.Group():
21 |             with gr.Accordion("Dynamic Lora Weights", open=False):
22 |                 opt_enable = gr.Checkbox(
23 |                     value=True, label="Enable Dynamic Lora Weights")
24 |                 opt_plot_lora_weight = gr.Checkbox(
25 |                     value=False, label="Plot the LoRA weight in all steps")
26 |         return [opt_enable, opt_plot_lora_weight]
27 |
28 |     def process(self, p: StableDiffusionProcessing, opt_enable=True, opt_plot_lora_weight=False, **kwargs):
29 |         if opt_enable and type(extra_networks.extra_network_registry["lora"]) != lora_ctl_network.LoraCtlNetwork:
30 |             self.original_network = extra_networks.extra_network_registry["lora"]
31 |             network = lora_ctl_network.LoraCtlNetwork()
32 |             extra_networks.register_extra_network(network)
33 |             extra_networks.register_extra_network_alias(network, "loractl")
34 |         elif not opt_enable and type(extra_networks.extra_network_registry["lora"]) != lora_ctl_network.LoraCtlNetwork.__bases__[0]:
35 |             extra_networks.register_extra_network(self.original_network)
36 |             self.original_network = None
37 |         network_patch.apply()
38 |         utils.set_hires(False)
39 |         utils.set_active(opt_enable)
40 |         lora_ctl_network.reset_weights()
41 |         plot.reset_plot()
42 |
43 |     def before_hr(self, p, *args):
44 |         utils.set_hires(True)
45 |
46 |     def postprocess(self, p, processed, opt_enable=True, opt_plot_lora_weight=False, **kwargs):
47 |         if opt_plot_lora_weight and opt_enable:
48 |             processed.images.extend([plot.make_plot()])
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheald/sd-webui-loractl/b48cf602b0c414a362726fa079a284f3ae6a51e2/test/__init__.py
--------------------------------------------------------------------------------
/test/tests.py:
--------------------------------------------------------------------------------
1 | import sys, unittest
2 | from pathlib import Path
3 | from loractl.lib.utils import sorted_positions, calculate_weight, params_to_weights
4 |
5 | path = str(Path(__file__).parent.parent.parent.parent)
6 | sys.path.insert(0, path)
7 | from modules.extra_networks import ExtraNetworkParams
8 | sys.path.remove(path)
9 |
10 |
11 | class LoraCtlTests(unittest.TestCase):
12 |     def test_sorted_positions(self):
13 |         self.assertEqual(sorted_positions("1"), 1.0)
14 |         self.assertEqual(sorted_positions("1@0,0.5@3,1@6"),
15 |                          [[1.0, 0.5, 1.0], [0.0, 3.0, 6.0]])
16 |         self.assertEqual(sorted_positions("0.5@3,1@6,1@0"),
17 |                          [[1.0, 0.5, 1.0], [0.0, 3.0, 6.0]])
18 |         self.assertEqual(sorted_positions("0.5@0,0.5@0.5,0@1"),
19 |                          [[0.5, 0.5, 0.0], [0.0, 0.5, 1.0]])
20 |
21 |     def test_sorted_position_semicolons(self):
22 |         self.assertEqual(sorted_positions("1@0;0.5@3;1@6"),
23 |                          [[1.0, 0.5, 1.0], [0.0, 3.0, 6.0]])
24 |
25 |     def test_weight_interpolation(self):
26 |         # Bare weights are never interpolated
27 |         steps = sorted_positions("1.0")
28 |         self.assertEqual(calculate_weight(steps, 0, 30), 1.0)
29 |         self.assertEqual(calculate_weight(steps, 15, 30), 1.0)
30 |         self.assertEqual(calculate_weight(steps, 30, 30), 1.0)
31 |
32 |         # Weights are interpolated correctly
33 |         steps = sorted_positions("0.75@0;0.5@3;1@6")
34 |         self.assertEqual(calculate_weight(steps, 0, 30), 0.75)
35 |         self.assertEqual(calculate_weight(steps, 3, 30), 0.5)
36 |         self.assertEqual(calculate_weight(steps, 6, 30), 1.0)
37 |         self.assertEqual(calculate_weight(steps, 9, 30), 1.0)
38 |
39 |         # An implicit 0-step is added
40 |         steps = sorted_positions("0.5@5,1.0@10")
41 |         self.assertEqual(calculate_weight(steps, 0, 30), 0.5)
42 |         self.assertEqual(calculate_weight(steps, 5, 30), 0.5)
43 |         self.assertEqual(calculate_weight(steps, 8, 30), 0.8)
44 |         self.assertEqual(calculate_weight(steps, 15, 30), 1.0)
45 |
46 |
47 | class LoraCtlNetworkTests(unittest.TestCase):
48 |     def assert_params(self, prompt, expected):
49 |         params = ExtraNetworkParams(prompt.split(":"))
50 |         self.assertEqual(params_to_weights(params), expected)
51 |
52 |     def test_params_to_weights(self):
53 |         # TE cascades to all
54 |         self.assert_params("loraname:1.0", {
55 |             'hrte': 1.0,
56 |             'hrunet': 1.0,
57 |             'te': 1.0,
58 |             'unet': 1.0
59 |         })
60 |
61 |         # HR can be specified separately
62 |         self.assert_params("loraname:0.5@0,1@1:hr=0.6", {
63 |             'hrte': 0.6,
64 |             'hrunet': 0.6,
65 |             'te': [[0.5, 1.0], [0.0, 1.0]],
66 |             'unet': [[0.5, 1.0], [0.0, 1.0]]
67 |         })
68 |
69 |         # Explicit TE cascades
70 |         self.assert_params("loraname:te=0.5@0,1@1", {
71 |             'te': [[0.5, 1.0], [0.0, 1.0]],
72 |             'unet': [[0.5, 1.0], [0.0, 1.0]],
73 |             'hrte': [[0.5, 1.0], [0.0, 1.0]],
74 |             'hrunet': [[0.5, 1.0], [0.0, 1.0]],
75 |         })
76 |
77 |         # Implicit TE cascades, explicit unet cascades
78 |         self.assert_params("loraname:unet=0.5@0,1@1", {
79 |             'te': 1.0,
80 |             'unet': [[0.5, 1.0], [0.0, 1.0]],
81 |             'hrte': 1.0,
82 |             'hrunet': [[0.5, 1.0], [0.0, 1.0]],
83 |         })
84 |
85 |         # Explicit HR TE overrides lowres TE
86 |         self.assert_params("loraname:unet=0.5@0,1@1:hrte=0.5", {
87 |             'te': 1.0,
88 |             'unet': [[0.5, 1.0], [0.0, 1.0]],
89 |             'hrte': 0.5,
90 |             'hrunet': [[0.5, 1.0], [0.0, 1.0]],
91 |         })
92 |
93 |         # Explicit HR TE overrides HR
94 |         self.assert_params("loraname:hr=0.6:hrte=0.5", {
95 |             'te': 1.0,
96 |             'unet': 1.0,
97 |             'hrte': 0.5,
98 |             'hrunet': 0.6,
99 |         })
100 |
101 |         self.assert_params("loraname:0.8@0.15,0@0.3:hr=0", {
102 |             'hrte': 0.0,
103 |             'hrunet': 0.0,
104 |             'te': [[0.8, 0.0], [0.15, 0.3]],
105 |             'unet': [[0.8, 0.0], [0.15, 0.3]]
106 |         })
--------------------------------------------------------------------------------