├── .gitignore
├── LICENSE
├── README.md
├── lib_neutral_prompt
    ├── cfg_denoiser_hijack.py
    ├── external_code
    │   ├── __init__.py
    │   └── api.py
    ├── global_state.py
    ├── hijacker.py
    ├── neutral_prompt_parser.py
    ├── prompt_parser_hijack.py
    ├── ui.py
    └── xyz_grid.py
├── scripts
    └── neutral_prompt.py
└── test
    └── perp_parser
        ├── __init__.py
        ├── basic_test.py
        └── malicious_test.py


/.gitignore:
--------------------------------------------------------------------------------
__pycache__/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 ljleb

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Neutral Prompt

Neutral prompt is an a1111 webui extension that adds alternative composable diffusion keywords to the prompt language. It builds on the webui's original `AND` implementation using more recent research.

## Features

- [Perp-Neg](https://perp-neg.github.io/) orthogonal prompts, invoked using the `AND_PERP` keyword
- saliency-aware noise blending, invoked using the `AND_SALT` keyword (credits to [Magic Fusion](https://magicfusion.github.io/) for the algorithm used to compute saliency-aware noise blending (SNB) maps from the noise predictions, or epsilons)
- semantic guidance top-k filtering, invoked using the `AND_TOPK` keyword (reference: https://arxiv.org/abs/2301.12247)
- standard-deviation-based CFG rescaling, controlled by the `CFG rescale` slider (reference: https://arxiv.org/abs/2305.08891, section 3.4)

## Usage

*Disclaimer: some sections of the readme have been generated by GPT-4. If anything is unclear, feel free to ask for clarifications in the [discussions](https://github.com/ljleb/sd-webui-neutral-prompt/discussions).*

### Keyword `AND_PERP`

The `AND_PERP` keyword, standing for "PERPendicular `AND`", integrates the orthogonalization process described in the Perp-Neg paper. Essentially, `AND_PERP` lets you prompt for concepts that overlap heavily with the regular prompts: only the component of the `AND_PERP` prompt that is perpendicular to the regular `AND` prompts is applied, so it cannot directly counteract them.

You could visualize it as such: if `AND` prompts are "greedy" (taking as much space as possible in the output), `AND_PERP` prompts are the opposite, relinquishing control as soon as there is a disagreement in the generated output.

### Keyword `AND_SALT`

Saliency-aware blending is made possible using the `AND_SALT` keyword, shorthand for "SALienT `AND`". In essence, at each denoising step and for each latent pixel, `AND_SALT` keeps the noise prediction of whichever prompt has the highest activation there.
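The per-pixel competition described above can be sketched in plain Python. This is a simplified, hypothetical illustration that mirrors `get_salience` and `salient_blend` from `lib_neutral_prompt/cfg_denoiser_hijack.py`. It assumes each noise delta (a prompt's prediction minus the unconditional prediction) is flattened to a list of floats instead of a `torch.Tensor`. Salience is a softmax of absolute values over all latent pixels of one delta, and at each pixel the most salient prompt wins.

```python
import math
from typing import List, Tuple


def get_salience(delta: List[float]) -> List[float]:
    # Softmax of absolute values, normalized over all latent pixels of this delta.
    exps = [math.exp(abs(value)) for value in delta]
    total = sum(exps)
    return [e / total for e in exps]


def salient_blend(normal: List[float], salt_deltas: List[Tuple[List[float], float]]) -> List[float]:
    """Correction to add to `normal` (the AND delta) so that each AND_SALT delta
    replaces it at the latent pixels where that AND_SALT delta is the most salient."""
    salience_maps = [get_salience(normal)] + [get_salience(delta) for delta, _ in salt_deltas]
    correction = [0.0] * len(normal)
    for pixel in range(len(normal)):
        # Index of the most salient map at this pixel; 0 means the AND prompts keep it.
        winner = max(range(len(salience_maps)), key=lambda i: salience_maps[i][pixel])
        if winner > 0:
            delta, weight = salt_deltas[winner - 1]
            correction[pixel] = weight * (delta[pixel] - normal[pixel])
    return correction


and_delta = [1.0, 0.1, 0.0, 0.2]   # hypothetical AND prompt delta
salt_delta = [0.0, 0.0, 2.0, 0.0]  # hypothetical AND_SALT prompt delta, weight 1
correction = salient_blend(and_delta, [(salt_delta, 1.0)])
blended = [a + c for a, c in zip(and_delta, correction)]
print(correction)  # [0.0, 0.0, 2.0, 0.0]
print(blended)     # [1.0, 0.1, 2.0, 0.2]
```

With a weight of 1, the third pixel is fully replaced by the `AND_SALT` value and the other pixels are untouched. Each map is normalized separately, so a pixel goes to whichever prompt stands out most relative to the rest of its own map, not simply to the largest raw value.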

Think of it as a territorial dispute: the image generated by the `AND` prompts is one country, and the images generated by `AND_SALT` prompts represent neighbouring nations. They're all vying for the same land; whoever strikes the strongest at a given time (denoising step) and location (latent pixel) claims it.

### Keyword `AND_TOPK`

The `AND_TOPK` keyword refers to "TOP-K filtering". It keeps only the "k" latent pixels with the highest activation (largest absolute value) in the noise map and discards the rest. It works similarly to `AND_SALT`, except that the high-activation regions are simply added instead of replacing previous content.

Currently, k is fixed at 5% of all latent pixels, so the weakest 95% of latent pixel values are discarded at each step.

Top-k filtering is useful when you want a more targeted effect on the generated image. It should work best with smaller objects and details.

## Examples

### Using the `AND_PERP` Keyword

Here is an example to illustrate one use case of the `AND_PERP` keyword. Prompt:

`beautiful castle landscape AND monster house castle :-1`

This is an XY grid with prompt S/R `AND, AND_PERP`:

![image](https://github.com/ljleb/sd-webui-neutral-prompt/assets/32277961/29f3cf34-2ed4-45d2-b73a-b6fadec21d61)

Key observations:

- The `AND_PERP` images exhibit a higher dynamic range compared to the `AND` images.
- Since the prompts have a lot of overlap, the `AND` images sometimes struggle to depict a castle. This isn't a problem for the `AND_PERP` images.
- The `AND` images tend to lean towards a purple color, because this was the path of least resistance between the two opposing prompts during generation. In contrast, the `AND_PERP` images, free from this tug-of-war, present a clearer representation.
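The vector operations behind `AND_PERP` and `AND_TOPK` are small enough to sketch in plain Python. This is a simplified illustration that uses flattened lists of floats instead of `torch.Tensor`s.

- `perpendicular_component` mirrors `get_perpendicular_component` in `cfg_denoiser_hijack.py`. It projects out the part of the `AND_PERP` delta that is parallel to the `AND` delta. When the `AND` delta is all zeros, it falls back to the unmodified delta.
- `filter_abs_top_k` keeps the `k_ratio` fraction of largest-magnitude values by sorting. This is an assumed simplification of the extension's `torch.kthvalue` threshold.

The delta values below are hypothetical.

```python
from typing import List


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def perpendicular_component(normal: List[float], vector: List[float]) -> List[float]:
    """Remove from `vector` its projection onto `normal` (Perp-Neg style)."""
    norm_sq = dot(normal, normal)
    if norm_sq == 0.0:
        return list(vector)  # no direction to project onto: keep the delta as-is
    scale = dot(normal, vector) / norm_sq
    return [v - scale * n for n, v in zip(normal, vector)]


def filter_abs_top_k(vector: List[float], k_ratio: float = 0.05) -> List[float]:
    """Zero out everything except the k_ratio largest-magnitude values (ties are kept)."""
    k = max(1, int(len(vector) * k_ratio))
    threshold = sorted((abs(v) for v in vector), reverse=True)[k - 1]
    return [v if abs(v) >= threshold else 0.0 for v in vector]


and_delta = [1.0, 0.0]   # hypothetical delta of "beautiful castle landscape"
perp_delta = [0.8, 0.6]  # hypothetical delta of "monster house castle"

# The part of perp_delta that survives projection; the ":-1" weight is applied to this.
print(perpendicular_component(and_delta, perp_delta))  # [0.0, 0.6]
print(filter_abs_top_k([0.5, -4.0, 1.0, 3.0, -0.2, 0.0, 0.1, 2.0, -1.5, 0.3], k_ratio=0.2))
# [0.0, -4.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

With the `:-1` weight, a plain `AND` would subtract the whole `[0.8, 0.6]` delta, including the `0.8` that points along the castle prompt. The `AND_PERP` version only subtracts the perpendicular `[0.0, 0.6]`, which leaves the castle direction alone.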

### Using the `AND_SALT` Keyword

The `AND_SALT` keyword can be used to invoke saliency-aware blending. It spotlights and accentuates areas of high activation in the output.

Consider this example prompt using `AND_SALT`:

```
a vibrant rainforest with lush green foliage
AND_SALT the glimmering rays of a golden sunset piercing through the trees
```

In this case, the extension compares, at every latent pixel and denoising step, how salient the sunset prompt is against the rainforest prompt. This produces a mask of the regions where the sunset prompt is the most salient, and the mask is applied to the rainforest prompt. Only the portions of the rainforest prompt that coincide with these salient areas are affected: there, they are replaced by pixels from the sunset prompt.

This is an XY grid with prompt S/R `AND_SALT, AND, AND_PERP`:

![xyz_grid-0008-1564977627-a vibrant rainforest with lush green foliage_AND_SALT the glimmering rays of a golden sunset piercing through the trees](https://github.com/ljleb/sd-webui-neutral-prompt/assets/32277961/2404f20b-47f6-457f-b4c5-76b9fd919345)

Key observations:

- `AND_SALT` behaves more diplomatically, enhancing areas where its impact makes the most sense and aligning with high-activity regions in the output
- `AND` gives equal weight to both prompts, creating a blended result
- `AND_PERP` will find its way through anything not blocked by the regular prompt

## Advanced Features

### Nesting prompts

The extension supports nesting of all prompt keywords, including `AND`, allowing greater flexibility and control over the final output.
Here's an example of how these keywords can be combined:

```
magical tree forests, eternal city
AND_PERP [
  electrical pole voyage
  AND_SALT small nocturne companion
]
AND_SALT [
  electrical tornado
  AND_SALT electric arcs, bzzz, sparks
]
```

To generate the final image from the diffusion model:

1. The extension first processes the root `AND` prompts. In this case, it's just `magical tree forests, eternal city`
2. It then processes the `AND_SALT` prompt `small nocturne companion` in the context of `electrical pole voyage`. This enhances salient features in the `electrical pole voyage` image
3. This new image is orthogonalized with the image from `magical tree forests, eternal city`, blending the details of `electrical pole voyage` into the main scene without creating conflicts
4. The extension then turns to the second `AND_SALT` group. It processes `electric arcs, bzzz, sparks` in the context of `electrical tornado`, amplifying salient features in the `electrical tornado` image
5. The image from this `AND_SALT` group is then combined with the `magical tree forests, eternal city` image. The final output retains the strongest features from both the `electrical tornado` (enhanced by `electric arcs, bzzz, sparks`) and the earlier `magical tree forests, eternal city` scene influenced by `electrical pole voyage`

Each keyword can define a distinct denoising space within its square brackets `[...]`. The prompts inside the brackets are merged into a single image before being combined with the rest of the prompt tree.

While there's no strict limit on the depth of nesting, experimental evidence suggests that going beyond a depth of 2 is generally unnecessary. We're still exploring whether deeper nesting adds precision. If you discover innovative ways of controlling the generations using nested prompts, please share in the discussions!
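Even when prompts are nested, the webui still computes one noise prediction per leaf prompt. The tree only changes how those predictions are recombined afterwards. The sketch below illustrates this with hypothetical `Leaf` and `Group` classes standing in for the extension's `LeafPrompt` and `CompositePrompt`. The functions `flat_size` and `to_webui_prompt` mirror `FlatSizeVisitor` and `WebuiPromptVisitor`: they count the leaves and join them with plain `AND` before the prompt is handed to the webui.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class Leaf:
    prompt: str
    weight: float = 1.0
    keyword: Optional[str] = None  # None for plain AND, else e.g. "AND_PERP" or "AND_SALT"


@dataclass
class Group:
    children: List["Node"]
    weight: float = 1.0
    keyword: Optional[str] = None


Node = Union[Leaf, Group]


def flat_size(node: Node) -> int:
    # One noise prediction per leaf, however deeply it is nested.
    if isinstance(node, Leaf):
        return 1
    return sum(flat_size(child) for child in node.children)


def to_webui_prompt(node: Node) -> str:
    # Keywords and brackets are dropped: the webui only sees plain AND-separated leaves.
    if isinstance(node, Leaf):
        return f"{node.prompt} :{node.weight}"
    return " AND ".join(to_webui_prompt(child) for child in node.children)


# The nested example prompt above, built by hand instead of parsed.
root = Group([
    Leaf("magical tree forests, eternal city"),
    Group([
        Leaf("electrical pole voyage"),
        Leaf("small nocturne companion", keyword="AND_SALT"),
    ], keyword="AND_PERP"),
    Group([
        Leaf("electrical tornado"),
        Leaf("electric arcs, bzzz, sparks", keyword="AND_SALT"),
    ], keyword="AND_SALT"),
])

print(flat_size(root))  # 5
print(to_webui_prompt(root))
```

The second line prints the five leaves joined with `AND`, each followed by its weight. Keywords and group structure are not part of that string. The extension keeps the parsed tree in `global_state.prompt_exprs` and applies them itself when it recombines the noise predictions in `combine_denoised_hijack`.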

![image](https://github.com/ljleb/sd-webui-neutral-prompt/assets/32277961/f16587fe-2244-4832-a253-98f819a9e2e0)

## Special Mentions

Special thanks to these people for helping make this extension possible:

- [Ai-Casanova](https://github.com/AI-Casanova) : for sharing mathematical knowledge and time, and for conducting proof-testing to enhance the robustness of this extension

--------------------------------------------------------------------------------
/lib_neutral_prompt/cfg_denoiser_hijack.py:
--------------------------------------------------------------------------------
from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
from modules import script_callbacks, sd_samplers, shared
from typing import Tuple, List
import dataclasses
import functools
import torch
import sys
import textwrap


def combine_denoised_hijack(
    x_out: torch.Tensor,
    batch_cond_indices: List[List[Tuple[int, float]]],
    text_uncond: torch.Tensor,
    cond_scale: float,
    original_function,
) -> torch.Tensor:
    if not global_state.is_enabled:
        return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)

    denoised = get_webui_denoised(x_out, batch_cond_indices, text_uncond, cond_scale, original_function)
    uncond = x_out[-text_uncond.shape[0]:]

    for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
        args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
        cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
        aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
        cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
        denoised[batch_i] = cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)

    return denoised


def get_webui_denoised(
    x_out: torch.Tensor,
    batch_cond_indices: List[List[Tuple[int, float]]],
    text_uncond: torch.Tensor,
    cond_scale: float,
    original_function,
):
    uncond = x_out[-text_uncond.shape[0]:]
    sliced_batch_x_out = []
    sliced_batch_cond_indices = []

    for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
        args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
        sliced_x_out, sliced_cond_indices = gather_webui_conds(prompt, args, 0, len(sliced_batch_x_out))
        if sliced_cond_indices:
            sliced_batch_cond_indices.append(sliced_cond_indices)
        sliced_batch_x_out.extend(sliced_x_out)

    sliced_batch_x_out += list(uncond)
    sliced_batch_x_out = torch.stack(sliced_batch_x_out, dim=0)
    return original_function(sliced_batch_x_out, sliced_batch_cond_indices, text_uncond, cond_scale)


def cfg_rescale(cfg_cond, cond):
    if global_state.cfg_rescale == 0:
        return cfg_cond

    global_state.apply_and_clear_cfg_rescale_override()
    cfg_cond_mean = cfg_cond.mean()
    cfg_rescale_mean = (1 - global_state.cfg_rescale) * cfg_cond_mean + global_state.cfg_rescale * cond.mean()
    cfg_rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_cond.std() - 1) + 1
    return cfg_rescale_mean + (cfg_cond - cfg_cond_mean) * cfg_rescale_factor


@dataclasses.dataclass
class CombineDenoiseArgs:
    x_out: torch.Tensor
    uncond: torch.Tensor
    cond_indices: List[Tuple[int, float]]


def gather_webui_conds(
    prompt: neutral_prompt_parser.CompositePrompt,
    args: CombineDenoiseArgs,
    index_in: int,
    index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
    sliced_x_out = []
    sliced_cond_indices = []

    for child in prompt.children:
        if child.conciliation is None:
            if isinstance(child, neutral_prompt_parser.LeafPrompt):
                child_x_out = args.x_out[args.cond_indices[index_in][0]]
            else:
                child_x_out = child.accept(CondDeltaVisitor(), args, index_in)
                child_x_out += child.accept(AuxCondDeltaVisitor(), args, child_x_out, index_in)
                child_x_out += args.uncond
            index_offset = index_out + len(sliced_x_out)
            sliced_x_out.append(child_x_out)
            sliced_cond_indices.append((index_offset, child.weight))

        index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())

    return sliced_x_out, sliced_cond_indices


class CondDeltaVisitor:
    def visit_leaf_prompt(
        self,
        that: neutral_prompt_parser.LeafPrompt,
        args: CombineDenoiseArgs,
        index: int,
    ) -> torch.Tensor:
        cond_info = args.cond_indices[index]
        if that.weight != cond_info[1]:
            console_warn(f'''
                An unexpected noise weight was encountered at prompt #{index}
                Expected :{that.weight}, but got :{cond_info[1]}
                This is likely due to another extension also monkey patching the webui `combine_denoised` function
                Please open a bug report here so that the conflict can be resolved:
                https://github.com/ljleb/sd-webui-neutral-prompt/issues
            ''')

        return args.x_out[cond_info[0]] - args.uncond

    def visit_composite_prompt(
        self,
        that: neutral_prompt_parser.CompositePrompt,
        args: CombineDenoiseArgs,
        index: int,
    ) -> torch.Tensor:
        cond_delta = torch.zeros_like(args.x_out[0])

        for child in that.children:
            if child.conciliation is None:
                child_cond_delta = child.accept(CondDeltaVisitor(), args, index)
                child_cond_delta += child.accept(AuxCondDeltaVisitor(), args, child_cond_delta, index)
                cond_delta += child.weight * child_cond_delta

            index += child.accept(neutral_prompt_parser.FlatSizeVisitor())

        return cond_delta


class AuxCondDeltaVisitor:
    def visit_leaf_prompt(
        self,
        that: neutral_prompt_parser.LeafPrompt,
        args: CombineDenoiseArgs,
        cond_delta: torch.Tensor,
        index: int,
    ) -> torch.Tensor:
        return torch.zeros_like(args.x_out[0])

    def visit_composite_prompt(
        self,
        that: neutral_prompt_parser.CompositePrompt,
        args: CombineDenoiseArgs,
        cond_delta: torch.Tensor,
        index: int,
    ) -> torch.Tensor:
        aux_cond_delta = torch.zeros_like(args.x_out[0])
        salient_cond_deltas = []

        for child in that.children:
            if child.conciliation is not None:
                child_cond_delta = child.accept(CondDeltaVisitor(), args, index)
                child_cond_delta += child.accept(AuxCondDeltaVisitor(), args, child_cond_delta, index)

                if child.conciliation == neutral_prompt_parser.ConciliationStrategy.PERPENDICULAR:
                    aux_cond_delta += child.weight * get_perpendicular_component(cond_delta, child_cond_delta)
                elif child.conciliation == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK:
                    salient_cond_deltas.append((child_cond_delta, child.weight))
                elif child.conciliation == neutral_prompt_parser.ConciliationStrategy.SEMANTIC_GUIDANCE:
                    aux_cond_delta += child.weight * filter_abs_top_k(child_cond_delta, 0.05)

            index += child.accept(neutral_prompt_parser.FlatSizeVisitor())

        aux_cond_delta += salient_blend(cond_delta, salient_cond_deltas)
        return aux_cond_delta


def get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
    if (normal == 0).all():
        if shared.state.sampling_step <= 0:
            warn_projection_not_found()

        return vector

    return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2


def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float]]) -> torch.Tensor:
    """
    Blends the `normal` tensor with `vectors` in salient regions, weighting contributions by their weights.
    Salience maps are calculated to identify regions of interest.
    The blended result combines `normal` and vector information in salient regions.
    """

    salience_maps = [get_salience(normal)] + [get_salience(vector) for vector, _ in vectors]
    mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)

    result = torch.zeros_like(normal)
    for mask_i, (vector, weight) in enumerate(vectors, start=1):
        vector_mask = (mask == mask_i).float()
        result += weight * vector_mask * (vector - normal)

    return result


def get_salience(vector: torch.Tensor) -> torch.Tensor:
    return torch.softmax(torch.abs(vector).flatten(), dim=0).reshape_as(vector)


def filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
    k = int(torch.numel(vector) * (1 - k_ratio))
    top_k, _ = torch.kthvalue(torch.abs(torch.flatten(vector)), k)
    return vector * (torch.abs(vector) >= top_k).to(vector.dtype)


sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
    module=sd_samplers,
    hijacker_attribute='__neutral_prompt_hijacker',
    on_uninstall=script_callbacks.on_script_unloaded,
)


@sd_samplers_hijacker.hijack('create_sampler')
def create_sampler_hijack(name: str, model, original_function):
    sampler = original_function(name, model)
    if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
        if global_state.is_enabled:
            warn_unsupported_sampler()

        return sampler

    sampler.model_wrap_cfg.combine_denoised = functools.partial(
        combine_denoised_hijack,
        original_function=sampler.model_wrap_cfg.combine_denoised
    )
    return sampler


def warn_unsupported_sampler():
    console_warn('''
        Neutral prompt relies on composition via AND, which the webui does not support when using any of the DDIM, PLMS and UniPC samplers
        The sampler will NOT be patched
        Falling back on original sampler implementation...
    ''')


def warn_projection_not_found():
    console_warn('''
        Could not find a projection for one or more AND_PERP prompts
        These prompts will NOT be made perpendicular
    ''')


def console_warn(message):
    if not global_state.verbose:
        return

    print(f'\n[sd-webui-neutral-prompt extension]{textwrap.dedent(message)}', file=sys.stderr)

--------------------------------------------------------------------------------
/lib_neutral_prompt/external_code/__init__.py:
--------------------------------------------------------------------------------
import contextlib


@contextlib.contextmanager
def fix_path():
    import sys
    from pathlib import Path

    extension_path = str(Path(__file__).parent.parent.parent)
    added = False
    if extension_path not in sys.path:
        sys.path.insert(0, extension_path)
        added = True

    yield

    if added:
        sys.path.remove(extension_path)


with fix_path():
    del fix_path, contextlib
    from .api import *

--------------------------------------------------------------------------------
/lib_neutral_prompt/external_code/api.py:
--------------------------------------------------------------------------------
from lib_neutral_prompt import global_state


def override_cfg_rescale(cfg_rescale: float):
    global_state.cfg_rescale_override = cfg_rescale

--------------------------------------------------------------------------------
/lib_neutral_prompt/global_state.py:
--------------------------------------------------------------------------------
from typing import List, Optional
from lib_neutral_prompt import neutral_prompt_parser


is_enabled: bool = False
prompt_exprs: List[neutral_prompt_parser.PromptExpr] = []
cfg_rescale: float = 0.0
verbose: bool = True
cfg_rescale_override: Optional[float] = None


def apply_and_clear_cfg_rescale_override():
    global cfg_rescale, cfg_rescale_override
    if cfg_rescale_override is not None:
        cfg_rescale = cfg_rescale_override
        cfg_rescale_override = None

--------------------------------------------------------------------------------
/lib_neutral_prompt/hijacker.py:
--------------------------------------------------------------------------------
import functools


class ModuleHijacker:
    def __init__(self, module):
        self.__module = module
        self.__original_functions = dict()

    def hijack(self, attribute):
        if attribute not in self.__original_functions:
            self.__original_functions[attribute] = getattr(self.__module, attribute)

        def decorator(function):
            setattr(self.__module, attribute, functools.partial(function, original_function=self.__original_functions[attribute]))
            return function

        return decorator

    def reset_module(self):
        for attribute, original_function in self.__original_functions.items():
            setattr(self.__module, attribute, original_function)

        self.__original_functions.clear()

    @staticmethod
    def install_or_get(module, hijacker_attribute, on_uninstall=lambda _callback: None):
        if not hasattr(module, hijacker_attribute):
            module_hijacker = ModuleHijacker(module)
            setattr(module, hijacker_attribute, module_hijacker)
            on_uninstall(lambda: delattr(module, hijacker_attribute))
            on_uninstall(module_hijacker.reset_module)
            return module_hijacker
        else:
            return getattr(module, hijacker_attribute)

--------------------------------------------------------------------------------
/lib_neutral_prompt/neutral_prompt_parser.py:
--------------------------------------------------------------------------------
import abc
import dataclasses
import re
from enum import Enum
from typing import List, Tuple, Any, Optional


class PromptKeyword(Enum):
    AND = 'AND'
    AND_PERP = 'AND_PERP'
    AND_SALT = 'AND_SALT'
    AND_TOPK = 'AND_TOPK'


prompt_keywords = [e.value for e in PromptKeyword]


class ConciliationStrategy(Enum):
    PERPENDICULAR = PromptKeyword.AND_PERP.value
    SALIENCE_MASK = PromptKeyword.AND_SALT.value
    SEMANTIC_GUIDANCE = PromptKeyword.AND_TOPK.value


conciliation_strategies = [e.value for e in ConciliationStrategy]


@dataclasses.dataclass
class PromptExpr(abc.ABC):
    weight: float
    conciliation: Optional[ConciliationStrategy]

    @abc.abstractmethod
    def accept(self, visitor, *args, **kwargs) -> Any:
        pass


@dataclasses.dataclass
class LeafPrompt(PromptExpr):
    prompt: str

    def accept(self, visitor, *args, **kwargs):
        return visitor.visit_leaf_prompt(self, *args, **kwargs)


@dataclasses.dataclass
class CompositePrompt(PromptExpr):
    children: List[PromptExpr]

    def accept(self, visitor, *args, **kwargs):
        return visitor.visit_composite_prompt(self, *args, **kwargs)


class FlatSizeVisitor:
    def visit_leaf_prompt(self, that: LeafPrompt) -> int:
        return 1

    def visit_composite_prompt(self, that: CompositePrompt) -> int:
        return sum(child.accept(self) for child in that.children) if that.children else 0


def parse_root(string: str) -> CompositePrompt:
    tokens = tokenize(string)
    prompts = parse_prompts(tokens)
    return CompositePrompt(1., None, prompts)


def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]:
    prompts = [parse_prompt(tokens, first=True, nested=nested)]
    while tokens:
        if nested and tokens[0] in [']']:
            break

        prompts.append(parse_prompt(tokens, first=False, nested=nested))

    return prompts


def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr:
    if not first and tokens[0] in prompt_keywords:
        prompt_type = tokens.pop(0)
    else:
        prompt_type = PromptKeyword.AND.value

    tokens_copy = tokens.copy()
    if tokens_copy and tokens_copy[0] == '[':
        tokens_copy.pop(0)
        prompts = parse_prompts(tokens_copy, nested=True)
        if tokens_copy:
            assert tokens_copy.pop(0) == ']'
        if len(prompts) > 1:
            tokens[:] = tokens_copy
            weight = parse_weight(tokens)
            conciliation = ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None
            return CompositePrompt(weight, conciliation, prompts)

    prompt_text, weight = parse_prompt_text(tokens, nested=nested)
    return LeafPrompt(weight, ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None, prompt_text)


def parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]:
    text = ''
    depth = 0
    weight = 1.
    while tokens:
        if tokens[0] == ']':
            if depth == 0:
                if nested:
                    break
            else:
                depth -= 1
        elif tokens[0] == '[':
            depth += 1
        elif tokens[0] == ':':
            if len(tokens) >= 2 and is_float(tokens[1].strip()):
                if len(tokens) < 3 or tokens[2] in prompt_keywords or tokens[2] == ']' and depth == 0:
                    tokens.pop(0)
                    weight = float(tokens.pop(0).strip())
                    break
        elif tokens[0] in prompt_keywords:
            break

        text += tokens.pop(0)

    return text, weight


def parse_weight(tokens: List[str]) -> float:
    weight = 1.
    if len(tokens) >= 2 and tokens[0] == ':' and is_float(tokens[1]):
        tokens.pop(0)
        weight = float(tokens.pop(0))
    return weight


def tokenize(s: str):
    prompt_keywords_regex = '|'.join(rf'\b{keyword}\b' for keyword in prompt_keywords)
    return [token for token in re.split(rf'(\[|\]|:|{prompt_keywords_regex})', s) if token.strip()]


def is_float(string: str) -> bool:
    try:
        float(string)
        return True
    except ValueError:
        return False


if __name__ == '__main__':
    res = parse_root('''
    hello
    AND_PERP [
        arst
        AND defg : 2
        AND_SALT [
            very nested huh? what do you say :.0
        ]
    ]
    ''')
    pass

--------------------------------------------------------------------------------
/lib_neutral_prompt/prompt_parser_hijack.py:
--------------------------------------------------------------------------------
from typing import List

from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
from modules import script_callbacks, prompt_parser


prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
    module=prompt_parser,
    hijacker_attribute='__neutral_prompt_hijacker',
    on_uninstall=script_callbacks.on_script_unloaded,
)


@prompt_parser_hijacker.hijack('get_multicond_prompt_list')
def get_multicond_prompt_list_hijack(prompts, original_function):
    if not global_state.is_enabled:
        return original_function(prompts)

    global_state.prompt_exprs = parse_prompts(prompts)
    webui_prompts = transpile_exprs(global_state.prompt_exprs)
    if isinstance(prompts, getattr(prompt_parser, 'SdConditioning', type(None))):
        webui_prompts = prompt_parser.SdConditioning(webui_prompts, copy_from=prompts)

    return original_function(webui_prompts)


def parse_prompts(prompts: List[str]) -> List[neutral_prompt_parser.PromptExpr]:
    exprs = []
    for prompt in prompts:
        expr = neutral_prompt_parser.parse_root(prompt)
        exprs.append(expr)

    return exprs


def transpile_exprs(exprs: List[neutral_prompt_parser.PromptExpr]):
    webui_prompts = []
    for expr in exprs:
        webui_prompts.append(expr.accept(WebuiPromptVisitor()))

    return webui_prompts


class WebuiPromptVisitor:
    def visit_leaf_prompt(self, that: neutral_prompt_parser.LeafPrompt) -> str:
        return f'{that.prompt} :{that.weight}'

    def visit_composite_prompt(self, that: neutral_prompt_parser.CompositePrompt) -> str:
        return ' AND '.join(child.accept(self) for child in that.children)

--------------------------------------------------------------------------------
/lib_neutral_prompt/ui.py:
--------------------------------------------------------------------------------
from lib_neutral_prompt import global_state, neutral_prompt_parser
from modules import script_callbacks, shared
from typing import Dict, Tuple, List, Callable
import gradio as gr
import dataclasses


txt2img_prompt_textbox = None
img2img_prompt_textbox = None


prompt_types = {
    'Perpendicular': neutral_prompt_parser.PromptKeyword.AND_PERP.value,
    'Saliency-aware': neutral_prompt_parser.PromptKeyword.AND_SALT.value,
    'Semantic guidance top-k': neutral_prompt_parser.PromptKeyword.AND_TOPK.value,
}
prompt_types_tooltip = '\n'.join([
    'AND - add all prompt features equally (webui builtin)',
    'Perpendicular - reduce the impact of contradicting prompt features',
    'Saliency-aware - strongest prompt features win',
    'Semantic guidance top-k - small targeted changes',
])


@dataclasses.dataclass
class AccordionInterface:
    get_elem_id: Callable

    def __post_init__(self):
        self.is_rendered = False

        self.cfg_rescale = gr.Slider(label='CFG rescale', minimum=0, maximum=1, value=0)
        self.neutral_prompt = gr.Textbox(label='Neutral prompt', show_label=False, lines=3, placeholder='Neutral prompt (click on apply below to append this to the positive prompt textbox)')
        self.neutral_cond_scale = gr.Slider(label='Prompt weight', minimum=-3, maximum=3, value=1)
        self.aux_prompt_type = gr.Dropdown(label='Prompt type', choices=list(prompt_types.keys()), value=next(iter(prompt_types.keys())), tooltip=prompt_types_tooltip, elem_id=self.get_elem_id('formatter_prompt_type'))
        self.append_to_prompt_button = gr.Button(value='Apply to prompt')

    def arrange_components(self, is_img2img: bool):
        if self.is_rendered:
            return

        with gr.Accordion(label='Neutral Prompt', open=False):
            self.cfg_rescale.render()
            with gr.Accordion(label='Prompt formatter', open=False):
                self.neutral_prompt.render()
                self.neutral_cond_scale.render()
                self.aux_prompt_type.render()
                self.append_to_prompt_button.render()

    def connect_events(self, is_img2img: bool):
        if self.is_rendered:
            return

        prompt_textbox = img2img_prompt_textbox if is_img2img else txt2img_prompt_textbox
        self.append_to_prompt_button.click(
            fn=lambda init_prompt, prompt, scale, prompt_type: (f'{init_prompt}\n{prompt_types[prompt_type]} {prompt} :{scale}', ''),
            inputs=[prompt_textbox, self.neutral_prompt, self.neutral_cond_scale, self.aux_prompt_type],
            outputs=[prompt_textbox, self.neutral_prompt]
        )

    def set_rendered(self, value: bool = True):
        self.is_rendered = value

    def get_components(self) -> Tuple[gr.components.Component]:
        return (
            self.cfg_rescale,
        )

    def get_infotext_fields(self) -> Tuple[Tuple[gr.components.Component, str]]:
        return tuple(zip(self.get_components(), (
            'CFG Rescale phi',
        )))

    def get_paste_field_names(self) -> List[str]:
        return [
            'CFG Rescale phi',
        ]

    def get_extra_generation_params(self, args: Dict) -> Dict:
        return {
            'CFG Rescale phi': args['cfg_rescale'],
        }

    def unpack_processing_args(
        self,
        cfg_rescale: float,
    ) -> Dict:
        return {
            'cfg_rescale': cfg_rescale,
        }


def on_ui_settings():
    section = ('neutral_prompt', 'Neutral Prompt')

    shared.opts.add_option('neutral_prompt_enabled', shared.OptionInfo(True, 'Enable neutral-prompt extension', section=section))
    global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)

    shared.opts.add_option('neutral_prompt_verbose', shared.OptionInfo(False, 'Enable verbose debugging for neutral-prompt', section=section))
    shared.opts.onchange('neutral_prompt_verbose', update_verbose)


script_callbacks.on_ui_settings(on_ui_settings)


def update_verbose():
    global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)


def on_after_component(component, **_kwargs):
    if getattr(component, 'elem_id', None) == 'txt2img_prompt':
        global txt2img_prompt_textbox
        txt2img_prompt_textbox = component

    if getattr(component, 'elem_id', None) == 'img2img_prompt':
        global img2img_prompt_textbox
        img2img_prompt_textbox = component


script_callbacks.on_after_component(on_after_component)

--------------------------------------------------------------------------------
/lib_neutral_prompt/xyz_grid.py:
--------------------------------------------------------------------------------
import sys
from types import ModuleType
from typing import Optional
from modules import scripts
from lib_neutral_prompt import global_state


def patch():
    xyz_module = find_xyz_module()
    if xyz_module is None:
        print("[sd-webui-neutral-prompt]", "xyz_grid.py not found.", file=sys.stderr)
        return

xyz_module.axis_options.extend([ 15 | xyz_module.AxisOption("[Neutral Prompt] CFG Rescale", int_or_float, apply_cfg_rescale()), 16 | ]) 17 | 18 | 19 | class XyzFloat(float): 20 | is_xyz: bool = True 21 | 22 | 23 | def apply_cfg_rescale(): 24 | def callback(_p, v, _vs): 25 | global_state.cfg_rescale = XyzFloat(v) 26 | 27 | return callback 28 | 29 | 30 | def int_or_float(string): 31 | try: 32 | return int(string) 33 | except ValueError: 34 | return float(string) 35 | 36 | 37 | def find_xyz_module() -> Optional[ModuleType]: 38 | for data in scripts.scripts_data: 39 | if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"): 40 | return data.module 41 | 42 | return None 43 | -------------------------------------------------------------------------------- /scripts/neutral_prompt.py: -------------------------------------------------------------------------------- 1 | from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid 2 | from modules import scripts, processing, shared 3 | from typing import Dict 4 | import functools 5 | 6 | 7 | class NeutralPromptScript(scripts.Script): 8 | def __init__(self): 9 | self.accordion_interface = None 10 | self._is_img2img = False 11 | 12 | @property 13 | def is_img2img(self): 14 | return self._is_img2img 15 | 16 | @is_img2img.setter 17 | def is_img2img(self, is_img2img): 18 | self._is_img2img = is_img2img 19 | if self.accordion_interface is None: 20 | self.accordion_interface = ui.AccordionInterface(self.elem_id) 21 | 22 | def title(self) -> str: 23 | return "Neutral Prompt" 24 | 25 | def show(self, is_img2img: bool): 26 | return scripts.AlwaysVisible 27 | 28 | def ui(self, is_img2img: bool): 29 | self.hijack_composable_lora(is_img2img) 30 | 31 | self.accordion_interface.arrange_components(is_img2img) 32 | self.accordion_interface.connect_events(is_img2img) 33 | self.infotext_fields = 
self.accordion_interface.get_infotext_fields() 34 | self.paste_field_names = self.accordion_interface.get_paste_field_names() 35 | self.accordion_interface.set_rendered() 36 | return self.accordion_interface.get_components() 37 | 38 | def process(self, p: processing.StableDiffusionProcessing, *args): 39 | args = self.accordion_interface.unpack_processing_args(*args) 40 | 41 | self.update_global_state(args) 42 | if global_state.is_enabled: 43 | p.extra_generation_params.update(self.accordion_interface.get_extra_generation_params(args)) 44 | 45 | def update_global_state(self, args: Dict): 46 | if shared.state.job_no == 0: 47 | global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True) 48 | 49 | for k, v in args.items(): 50 | try: 51 | getattr(global_state, k) 52 | except AttributeError: 53 | continue 54 | 55 | if getattr(getattr(global_state, k), 'is_xyz', False): 56 | xyz_attr = getattr(global_state, k) 57 | xyz_attr.is_xyz = False 58 | args[k] = xyz_attr 59 | continue 60 | 61 | if shared.state.job_no > 0: 62 | continue 63 | 64 | setattr(global_state, k, v) 65 | 66 | def hijack_composable_lora(self, is_img2img): 67 | if self.accordion_interface.is_rendered: 68 | return 69 | 70 | lora_script = None 71 | script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img 72 | 73 | for script in script_runner.alwayson_scripts: 74 | if script.title().lower() == "composable lora": 75 | lora_script = script 76 | break 77 | 78 | if lora_script is not None: 79 | lora_script.process = functools.partial(composable_lora_process_hijack, original_function=lora_script.process) 80 | 81 | 82 | def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *args, original_function, **kwargs): 83 | if not global_state.is_enabled: 84 | return original_function(p, *args, **kwargs) 85 | 86 | exprs = prompt_parser_hijack.parse_prompts(p.all_prompts) 87 | all_prompts, p.all_prompts = p.all_prompts, prompt_parser_hijack.transpile_exprs(exprs) 
    res = original_function(p, *args, **kwargs)
    # restore original prompts
    p.all_prompts = all_prompts
    return res


xyz_grid.patch()

--------------------------------------------------------------------------------
/test/perp_parser/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljleb/sd-webui-neutral-prompt/64ee2bef6318d2c97b87962e4babadc0d804d93c/test/perp_parser/__init__.py

--------------------------------------------------------------------------------
/test/perp_parser/basic_test.py:
--------------------------------------------------------------------------------
import unittest
import pathlib
import sys
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
from lib_neutral_prompt import neutral_prompt_parser


class TestPromptParser(unittest.TestCase):
    def setUp(self):
        self.simple_prompt = neutral_prompt_parser.parse_root("hello :1.0")
        self.and_prompt = neutral_prompt_parser.parse_root("hello AND goodbye :2.0")
        self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP goodbye :2.0")
        self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT goodbye :2.0")
        self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
        self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
        self.invalid_weight = neutral_prompt_parser.parse_root("hello :not_a_float")

    def test_simple_prompt_child_count(self):
        self.assertEqual(len(self.simple_prompt.children), 1)

    def test_simple_prompt_child_weight(self):
        self.assertEqual(self.simple_prompt.children[0].weight, 1.0)

    def test_simple_prompt_child_prompt(self):
        self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")

    def test_square_weight_prompt(self):
        prompt = "a [b c d e : f g h :1.5]"
        parsed = neutral_prompt_parser.parse_root(prompt)
        self.assertEqual(parsed.children[0].prompt, prompt)

        composed_prompt = f"{prompt} AND_PERP other prompt"
        parsed = neutral_prompt_parser.parse_root(composed_prompt)
        self.assertEqual(parsed.children[0].prompt, prompt)

    def test_and_prompt_child_count(self):
        self.assertEqual(len(self.and_prompt.children), 2)

    def test_and_prompt_child_weights_and_prompts(self):
        self.assertEqual(self.and_prompt.children[0].weight, 1.0)
        self.assertEqual(self.and_prompt.children[0].prompt, "hello ")
        self.assertEqual(self.and_prompt.children[1].weight, 2.0)
        self.assertEqual(self.and_prompt.children[1].prompt, " goodbye ")

    def test_and_perp_prompt_child_count(self):
        self.assertEqual(len(self.and_perp_prompt.children), 2)

    def test_and_perp_prompt_child_types(self):
        self.assertIsInstance(self.and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.LeafPrompt)

    def test_and_perp_prompt_nested_child(self):
        nested_child = self.and_perp_prompt.children[1]
        self.assertEqual(nested_child.weight, 2.0)
        self.assertEqual(nested_child.prompt.strip(), "goodbye")

    def test_nested_and_perp_prompt_child_count(self):
        self.assertEqual(len(self.nested_and_perp_prompt.children), 2)

    def test_nested_and_perp_prompt_child_types(self):
        self.assertIsInstance(self.nested_and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.nested_and_perp_prompt.children[1], neutral_prompt_parser.CompositePrompt)

    def test_nested_and_perp_prompt_nested_child_types(self):
        nested_child = self.nested_and_perp_prompt.children[1].children[0]
        self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
        nested_child = self.nested_and_perp_prompt.children[1].children[1]
        self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)

    def test_nested_and_perp_prompt_nested_child(self):
        nested_child = self.nested_and_perp_prompt.children[1].children[1]
        self.assertEqual(nested_child.weight, 3.0)
        self.assertEqual(nested_child.prompt.strip(), "welcome")

    def test_invalid_weight_child_count(self):
        self.assertEqual(len(self.invalid_weight.children), 1)

    def test_invalid_weight_child_weight(self):
        self.assertEqual(self.invalid_weight.children[0].weight, 1.0)

    def test_invalid_weight_child_prompt(self):
        self.assertEqual(self.invalid_weight.children[0].prompt, "hello :not_a_float")

    def test_and_salt_prompt_child_count(self):
        self.assertEqual(len(self.and_salt_prompt.children), 2)

    def test_and_salt_prompt_child_types(self):
        self.assertIsInstance(self.and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.LeafPrompt)

    def test_and_salt_prompt_nested_child(self):
        nested_child = self.and_salt_prompt.children[1]
        self.assertEqual(nested_child.weight, 2.0)
        self.assertEqual(nested_child.prompt.strip(), "goodbye")

    def test_nested_and_salt_prompt_child_count(self):
        self.assertEqual(len(self.nested_and_salt_prompt.children), 2)

    def test_nested_and_salt_prompt_child_types(self):
        self.assertIsInstance(self.nested_and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.nested_and_salt_prompt.children[1], neutral_prompt_parser.CompositePrompt)

    def test_nested_and_salt_prompt_nested_child_types(self):
        nested_child = self.nested_and_salt_prompt.children[1].children[0]
        self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
        nested_child = self.nested_and_salt_prompt.children[1].children[1]
        self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)

    def test_nested_and_salt_prompt_nested_child(self):
        nested_child = self.nested_and_salt_prompt.children[1].children[1]
        self.assertEqual(nested_child.weight, 3.0)
        self.assertEqual(nested_child.prompt.strip(), "welcome")

    def test_start_with_prompt_editing(self):
        prompt = "[(long shot:1.2):0.1] detail.."
        res = neutral_prompt_parser.parse_root(prompt)
        self.assertEqual(res.children[0].weight, 1.0)
        self.assertEqual(res.children[0].prompt, prompt)


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/test/perp_parser/malicious_test.py:
--------------------------------------------------------------------------------
import unittest
import pathlib
import sys
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
from lib_neutral_prompt import neutral_prompt_parser


class TestMaliciousPromptParser(unittest.TestCase):
    def setUp(self):
        self.parser = neutral_prompt_parser

    def test_empty(self):
        result = self.parser.parse_root("")
        self.assertEqual(result.children[0].prompt, "")
        self.assertEqual(result.children[0].weight, 1.0)

    def test_zero_weight(self):
        result = self.parser.parse_root("hello :0.0")
        self.assertEqual(result.children[0].weight, 0.0)

    def test_mixed_positive_and_negative_weights(self):
        result = self.parser.parse_root("hello :1.0 AND goodbye :-2.0")
        self.assertEqual(result.children[0].weight, 1.0)
        self.assertEqual(result.children[1].weight, -2.0)

    def test_debalanced_square_brackets(self):
        prompt = "a [ b " * 100
        result = self.parser.parse_root(prompt)
        self.assertEqual(result.children[0].prompt, prompt)

        prompt = "a ] b " * 100
        result = self.parser.parse_root(prompt)
        self.assertEqual(result.children[0].prompt, prompt)

        repeats = 10
        prompt = "a [ [ b AND c ] " * repeats
        result = self.parser.parse_root(prompt)
        self.assertEqual([x.prompt for x in result.children], ["a [[ b ", *[" c ] a [[ b "] * (repeats - 1), " c ]"])

        repeats = 10
        prompt = "a [ b AND c ] ] " * repeats
        result = self.parser.parse_root(prompt)
        self.assertEqual([x.prompt for x in result.children], ["a [ b ", *[" c ]] a [ b "] * (repeats - 1), " c ]]"])

    def test_erroneous_syntax(self):
        result = self.parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0")
        self.assertEqual(result.children[0].weight, 1.0)
        self.assertEqual(result.children[1].prompt, "[goodbye ")
        self.assertEqual(result.children[1].weight, 2.0)

        result = self.parser.parse_root("hello :1.0 AND_PERP goodbye :2.0]")
        self.assertEqual(result.children[0].weight, 1.0)
        self.assertEqual(result.children[1].prompt, " goodbye ")

        result = self.parser.parse_root("hello :1.0 AND_PERP goodbye] :2.0")
        self.assertEqual(result.children[1].prompt, " goodbye]")
        self.assertEqual(result.children[1].weight, 2.0)

        result = self.parser.parse_root("hello :1.0 AND_PERP a [ goodbye :2.0")
        self.assertEqual(result.children[1].weight, 2.0)
        self.assertEqual(result.children[1].prompt, " a [ goodbye ")

        result = self.parser.parse_root("hello :1.0 AND_PERP AND goodbye :2.0")
        self.assertEqual(result.children[0].weight, 1.0)
        self.assertEqual(result.children[2].prompt, " goodbye ")

    def test_huge_number_of_prompt_parts(self):
        result = self.parser.parse_root(" AND ".join(f"hello{i} :{i}" for i in range(10**4)))
        self.assertEqual(len(result.children), 10**4)

    def test_prompt_ending_with_weight(self):
        result = self.parser.parse_root("hello :1.0 AND :2.0")
        self.assertEqual(result.children[0].weight, 1.0)
        self.assertEqual(result.children[1].prompt, "")
        self.assertEqual(result.children[1].weight, 2.0)

    def test_huge_input_string(self):
        big_string = "hello :1.0 AND " * 10**4
        result = self.parser.parse_root(big_string)
        self.assertEqual(len(result.children), 10**4 + 1)

    def test_deeply_nested_prompt(self):
        deeply_nested_prompt = "hello :1.0" + " AND_PERP [goodbye :2.0" * 100 + "]" * 100
        result = self.parser.parse_root(deeply_nested_prompt)
        self.assertIsInstance(result.children[1], neutral_prompt_parser.CompositePrompt)

    def test_complex_nested_prompts(self):
        complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP greetings:5.0]"
        result = self.parser.parse_root(complex_prompt)
        self.assertEqual(result.children[0].weight, 1.0)
        self.assertEqual(result.children[1].weight, 2.0)
        self.assertEqual(result.children[2].children[0].weight, 3.0)
        self.assertEqual(result.children[2].children[1].weight, 4.0)
        self.assertEqual(result.children[2].children[2].weight, 5.0)

    def test_string_with_random_characters(self):
        random_chars = "ASDFGHJKL:@#$/.,|}{><~`12[3]456AND_PERP7890"
        try:
            self.parser.parse_root(random_chars)
        except Exception:
            self.fail("parse_root couldn't handle a string with random characters.")

    def test_string_with_unexpected_symbols(self):
        unexpected_symbols = "hello :1.0 AND $%^&*()goodbye :2.0"
        try:
            self.parser.parse_root(unexpected_symbols)
        except Exception:
            self.fail("parse_root couldn't handle a string with unexpected symbols.")

    def test_string_with_unconventional_structure(self):
        unconventional_structure = "hello :1.0 AND_PERP :2.0 AND [goodbye]"
        try:
            self.parser.parse_root(unconventional_structure)
        except Exception:
            self.fail("parse_root couldn't handle a string with unconventional structure.")

    def test_string_with_mixed_alphabets_and_numbers(self):
        mixed_alphabets_and_numbers = "123hello :1.0 AND goodbye456 :2.0"
        try:
            self.parser.parse_root(mixed_alphabets_and_numbers)
        except Exception:
            self.fail("parse_root couldn't handle a string with mixed alphabets and numbers.")

    def test_string_with_nested_brackets(self):
        nested_brackets = "hello :1.0 AND [goodbye :2.0 AND [[welcome :3.0]]]"
        try:
            self.parser.parse_root(nested_brackets)
        except Exception:
            self.fail("parse_root couldn't handle a string with nested brackets.")

    def test_unmatched_opening_braces(self):
        unmatched_opening_braces = "hello [[[[[[[[[ :1.0 AND_PERP goodbye :2.0"
        try:
            self.parser.parse_root(unmatched_opening_braces)
        except Exception:
            self.fail("parse_root couldn't handle a string with unmatched opening braces.")

    def test_unmatched_closing_braces(self):
        unmatched_closing_braces = "hello :1.0 AND_PERP goodbye ]]]]]]]]] :2.0"
        try:
            self.parser.parse_root(unmatched_closing_braces)
        except Exception:
            self.fail("parse_root couldn't handle a string with unmatched closing braces.")

    def test_repeating_colons(self):
        repeating_colons = "hello ::::::: :1.0 AND_PERP goodbye :::: :2.0"
        try:
            self.parser.parse_root(repeating_colons)
        except Exception:
            self.fail("parse_root couldn't handle a string with repeating colons.")

    def test_excessive_whitespace(self):
        excessive_whitespace = "hello :1.0 AND_PERP goodbye :2.0"
        try:
            self.parser.parse_root(excessive_whitespace)
        except Exception:
            self.fail("parse_root couldn't handle a string with excessive whitespace.")

    def test_repeating_AND_keyword(self):
        repeating_AND_keyword = "hello :1.0 AND AND AND AND AND goodbye :2.0"
        try:
            self.parser.parse_root(repeating_AND_keyword)
        except Exception:
            self.fail("parse_root couldn't handle a string with repeating AND keyword.")

    def test_repeating_AND_PERP_keyword(self):
        repeating_AND_PERP_keyword = "hello :1.0 AND_PERP AND_PERP AND_PERP AND_PERP goodbye :2.0"
        try:
            self.parser.parse_root(repeating_AND_PERP_keyword)
        except Exception:
            self.fail("parse_root couldn't handle a string with repeating AND_PERP keyword.")

    def test_square_weight_prompt(self):
        prompt = "AND_PERP [weighted] you thought it was the end"
        try:
            self.parser.parse_root(prompt)
        except Exception:
            self.fail("parse_root couldn't handle a string starting with a square-weighted sub-prompt.")


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------