├── .gitignore
├── LICENSE
├── README.md
├── images
│   ├── cover.jpg
│   ├── idea.png
│   ├── sample-1.png
│   ├── sample-2.png
│   ├── sample-3.png
│   ├── sample-4.png
│   ├── sample-4_small.png
│   ├── sample-5.png
│   └── sample-5_small.png
└── scripts
    ├── cutoff.py
    └── cutofflib
        ├── embedding.py
        ├── sdhook.py
        ├── utils.py
        └── xyz.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | /.vs
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cutoff - Cutting Off Prompt Effect
2 |
3 | 
4 |
5 |
6 | ## Update Info
7 |
8 | Newer entries are listed first.
9 |
10 |
11 | - 026ff95a492a533a4a6e5fb2959c2324258c232c
12 | - SDXL support.
13 | - 527ed922b2c4f8d2620376589dfce0f9f4b622ad
14 | - Add support for the newer version of WebUI.
15 | - 20e87ce264338b824296b7559679ed1bb0bdacd7
16 | - Skip empty targets.
17 | - 03bfe60162ba418e18dbaf8f1b9711fd62195ef3
18 |   - Add `Disable for Negative prompt` option. Default is `True`.
19 | - f0990088fed0f5013a659cacedb194313a398860
20 | - Accept an empty prompt.
21 |
22 |
23 |
24 | ## What is this?
25 |
26 | This is an extension for [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which limits the tokens' influence scope.
27 |
28 | SDv1, SDv2 and SDXL are supported.
29 |
30 | ## Usage
31 |
32 | 1. Tick the `Enabled` checkbox.
33 | 2. Enter the words whose scope you want to limit into `Target tokens`.
34 | 3. Generate images.
35 |
36 | ## Note
37 |
38 | If the generated image is corrupted or otherwise broken, try changing the `Weight` value or switching the interpolation method to `SLerp`. The interpolation method can be found under `Details`.
39 |
40 | ### `Details` section
41 |
42 |
43 | - Disable for Negative prompt.
44 |   - If enabled, Cutoff will not be applied to the negative prompt. Default is `true`.
45 | - Cutoff strongly.
46 |   - See the description below. Default is `false`.
47 | - Interpolation method
48 |   - How the "padded" and "original" vectors are interpolated. Default is `Lerp`.
49 | - Padding token
50 |   - Which token is used as padding in place of the `Target tokens`. Default is `_` (underscore).
51 |
52 |
53 | ## Examples
54 |
55 | ### SDv1
56 |
57 | ```
58 | 7th_anime_v3_A-fp16 / kl-f8-anime2 / DPM++ 2M Karras / 15 steps / 512x768
59 | Prompt: a cute girl, white shirt with green tie, red shoes, blue hair, yellow eyes, pink skirt
60 | Negative Prompt: (low quality, worst quality:1.4), nsfw
61 | Target tokens: white, green, red, blue, yellow, pink
62 | ```
63 |
64 | Sample 1.
65 |
66 | 
67 |
68 | Sample 2. (use `SLerp` for interpolation)
69 |
70 | 
71 |
72 | Sample 3.
73 |
74 | 
75 |
76 | ### SDXL
77 |
78 | Stability AI's SDXL base model seems to be much better at token separation, so the effect of `cutoff` is limited.
79 |
80 | ```
81 | (some models) / sdxl_vae / DPM++ 3M SDE / 50 steps / 768x1344
82 | Prompt: full body shot of a cute girl, wearing white shirt with green tie, red shoes, blue hair, yellow eyes, pink skirt
83 | Negative Prompt: (low quality, worst quality:1.4), nsfw, close up
84 | Target tokens: white, green, red, blue, yellow, pink
85 | Cutoff weight: 1.0
86 | ```
87 |
88 | Sample 4. (Model = `sd_xl_base_1.0`)
89 |
90 | 
91 |
92 | Sample 5. (Model = `hassakuXLSfwNsfw_alphaV07`)
93 |
94 | 
95 |
96 | ## How it works
97 |
98 | For more discussion, see [#5](https://github.com/hnmr293/sd-webui-cutoff/issues/5).
99 |
100 | 
101 |
102 | Passing a prompt through CLIP yields a (77, 768)-dimensional embedded representation (I am not sure of the formal term). Roughly speaking, its 77 row vectors correspond to the 75 tokens of the prompt plus a start token and an end token.
103 |
104 | Note: for layout reasons, the figure above is drawn with the rows and columns swapped relative to this explanation.
105 |
106 | Each vector should encode not only the meaning of its own token but also information aggregated from the whole sentence, such as which words modify which.
107 |
108 | Now consider the prompt `a cute girl, pink hair, red shoes`. The intent of such a prompt is usually that:
109 |
110 | 1. `pink` modifies only `hair`, not `shoes`.
111 | 2. Likewise, `red` does not modify `hair`.
112 | 3. `a cute girl` applies to the image as a whole, and the `hair` and `shoes` suit the girl.
113 |
114 | However, inspecting the relationships between tokens with tools such as [EvViz2](https://github.com/hnmr293/sd-webui-evviz2) shows that things often do not work out that way: the vector at the position of `shoes` may carry the influence of `pink`.
115 |
116 | On the other hand, as noted above, we do want the influence of `a cute girl` to remain, so we need a way to remove the influence of specific tokens only.
117 |
118 | This extension achieves that by rewriting the specified tokens as a *padding token*.
119 |
120 | For example, for the `red shoes` part, it generates the prompt `a cute girl, _ hair, red shoes`, and by overwriting the vectors at the positions of `red` and `shoes` with the ones generated from this prompt, it removes the influence of `pink`.
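121 |
122 | A minimal sketch of this prompt rewriting (simplified from the extension's `CutoffPrompt._cutoff` in `scripts/cutofflib/embedding.py`, which also repeats the padding token so that it occupies as many tokens as the word it replaces): for each comma-separated block that contains a target, build a variant in which every other block is padded.
123 |
124 | ```
125 | import re
126 |
127 | def pad_block(block: str, targets: list, padding: str = '_') -> str:
128 |     # replace each whole-word target within one comma-separated block
129 |     for t in targets:
130 |         block = re.sub(r'\b' + re.escape(t) + r'\b', padding, block)
131 |     return block
132 |
133 | prompt = 'a cute girl, pink hair, red shoes'
134 | targets = ['pink', 'red']
135 | blocks = prompt.split(',')
136 | for i, block in enumerate(blocks):
137 |     if pad_block(block, targets) != block:  # this block contains a target
138 |         variant = [x if j == i else pad_block(x, targets)
139 |                    for j, x in enumerate(blocks)]
140 |         print(','.join(variant))
141 | # -> a cute girl, pink hair, _ shoes
142 | # -> a cute girl, _ hair, red shoes
143 | ```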
144 |
145 | Seen from `pink`'s side, its influence is now confined to the scope of `pink hair`; this is what "limits the tokens' influence scope" in "What is this?" means.
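146 |
147 | And a sketch of the overwriting step itself (illustrative only, not the extension's actual code; `encode` is a stand-in for the CLIP text encoder, and the token indices are made up): encode the original prompt and a padded variant, then blend the vectors at the affected token positions with the configured `Weight`.
148 |
149 | ```
150 | import numpy as np
151 |
152 | def encode(prompt: str) -> np.ndarray:
153 |     # stand-in for CLIP: a deterministic pseudo-embedding of shape (77, 768)
154 |     rng = np.random.default_rng(abs(hash(prompt)) % (2 ** 32))
155 |     return rng.standard_normal((77, 768))
156 |
157 | def cutoff(prompt: str, padded: str, positions: list, weight: float) -> np.ndarray:
158 |     base = encode(prompt)
159 |     cut = encode(padded)
160 |     for i in positions:
161 |         # linear interpolation; the extension can also use SLerp here
162 |         base[i] = (1.0 - weight) * base[i] + weight * cut[i]
163 |     return base
164 |
165 | # token positions 9-10 (made-up indices) hold `red shoes`; their vectors
166 | # are rebuilt from the variant in which `pink` is padded out
167 | emb = cutoff('a cute girl, pink hair, red shoes',
168 |              'a cute girl, _ hair, red shoes',
169 |              positions=[9, 10], weight=1.0)
170 | print(emb.shape)  # (77, 768)
171 | ```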
172 |
173 | Incidentally, it seems fine for `a cute girl` either to be influenced by `pink hair` and `red shoes` or not. For such don't-care parts, this extension lets you choose which of
174 |
175 | 1. `a cute girl, pink hair, red shoes`
176 | 2. `a cute girl, _ hair, _ shoes`
177 |
178 | is applied. This is the `Cutoff strongly` option in `Details`: when it is off, variant 1 is used; when it is on, variant 2. Off stays closer to the original image, so it is the default.
179 |
--------------------------------------------------------------------------------
/images/cover.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hnmr293/sd-webui-cutoff/33f102c009da7c8934adeabab19ecfc4aa757f52/images/cover.jpg
--------------------------------------------------------------------------------
/images/idea.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hnmr293/sd-webui-cutoff/33f102c009da7c8934adeabab19ecfc4aa757f52/images/idea.png
--------------------------------------------------------------------------------
/images/sample-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hnmr293/sd-webui-cutoff/33f102c009da7c8934adeabab19ecfc4aa757f52/images/sample-1.png
--------------------------------------------------------------------------------
/images/sample-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hnmr293/sd-webui-cutoff/33f102c009da7c8934adeabab19ecfc4aa757f52/images/sample-2.png
--------------------------------------------------------------------------------
/images/sample-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hnmr293/sd-webui-cutoff/33f102c009da7c8934adeabab19ecfc4aa757f52/images/sample-3.png
--------------------------------------------------------------------------------
/images/sample-4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hnmr293/sd-webui-cutoff/33f102c009da7c8934adeabab19ecfc4aa757f52/images/sample-4.png
--------------------------------------------------------------------------------
/images/sample-4_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hnmr293/sd-webui-cutoff/33f102c009da7c8934adeabab19ecfc4aa757f52/images/sample-4_small.png
--------------------------------------------------------------------------------
/images/sample-5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hnmr293/sd-webui-cutoff/33f102c009da7c8934adeabab19ecfc4aa757f52/images/sample-5.png
--------------------------------------------------------------------------------
/images/sample-5_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hnmr293/sd-webui-cutoff/33f102c009da7c8934adeabab19ecfc4aa757f52/images/sample-5_small.png
--------------------------------------------------------------------------------
/scripts/cutoff.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Union, List, Tuple, Dict
3 |
4 | import numpy as np
5 | import torch
6 | from torch import Tensor, nn
7 | import gradio as gr
8 |
9 | from modules.processing import StableDiffusionProcessing
10 | from modules import scripts
11 |
12 | from scripts.cutofflib.sdhook import SDHook
13 | from scripts.cutofflib.embedding import CLIP, CLIP_SDXL, generate_prompts, token_to_block
14 | from scripts.cutofflib.utils import log, set_debug
15 | from scripts.cutofflib.xyz import init_xyz
16 |
17 | NAME = 'Cutoff'
18 | PAD = '_'
19 |
20 | def check_neg(s: str, negative_prompt: str, all_negative_prompts: Union[List[str],None]):
21 | if s == negative_prompt:
22 | return True
23 |
24 | if all_negative_prompts is not None:
25 | return s in all_negative_prompts
26 |
27 | return False
28 |
29 | def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
30 | # cf. https://memo.sugyan.com/entry/2022/09/09/230645
31 |
32 |     inputs_are_torch = not isinstance(v0, np.ndarray)
33 |     if inputs_are_torch:
34 |         # torch tensors: remember the device and compute in numpy
35 |         input_device = v0.device
36 |         v0 = v0.cpu().numpy()
37 |         v1 = v1.cpu().numpy()
38 |
39 | dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
40 | if np.abs(dot) > DOT_THRESHOLD:
41 | v2 = (1 - t) * v0 + t * v1
42 | else:
43 | theta_0 = np.arccos(dot)
44 | sin_theta_0 = np.sin(theta_0)
45 | theta_t = theta_0 * t
46 | sin_theta_t = np.sin(theta_t)
47 | s0 = np.sin(theta_0 - theta_t) / sin_theta_0
48 | s1 = sin_theta_t / sin_theta_0
49 | v2 = s0 * v0 + s1 * v1
50 |
51 | if inputs_are_torch:
52 | v2 = torch.from_numpy(v2).to(input_device)
53 |
54 | return v2
55 |
56 |
57 | class Hook(SDHook):
58 |
59 | def __init__(
60 | self,
61 | enabled: bool,
62 | targets: List[str],
63 | padding: Union[str,int],
64 | weight: float,
65 | disable_neg: bool,
66 | strong: bool,
67 | interpolate: str,
68 | ):
69 | super().__init__(enabled)
70 | self.targets = targets
71 | self.padding = padding
72 | self.weight = float(weight)
73 | self.disable_neg = disable_neg
74 | self.strong = strong
75 | self.intp = interpolate
76 |
77 | def interpolate(self, t1: Tensor, t2: Tensor, w):
78 | if self.intp == 'lerp':
79 | return torch.lerp(t1, t2, w)
80 | else:
81 | return slerp(w, t1, t2)
82 |
83 | def hook_clip(self, p: StableDiffusionProcessing, clip: nn.Module):
84 |
85 | skip = False
86 |
87 | def hook(mod: nn.Module, inputs: Tuple[Union[List[str],Dict[str,Tensor]]], output: Union[Tensor,Dict[str,Tensor]]):
88 | nonlocal skip
89 |
90 | if skip:
91 | # called from below
92 | return
93 |
94 | if not hasattr(p.sd_model, 'is_sdxl') or not p.sd_model.is_sdxl:
95 | # SD
96 | assert isinstance(mod, CLIP)
97 | prompts = inputs[0]
98 | output = output.clone()
99 | output_vector = output
100 | def process(prompts):
101 | return mod(prompts)
102 | else:
103 | # SDXL
104 | assert isinstance(mod, CLIP_SDXL)
105 | prompts = inputs[0]['txt']
106 | output = { k : v.clone() for k, v in output.items() }
107 | assert 'crossattn' in output, f'output keys: {", ".join(output.keys())}'
108 | output_vector = output['crossattn']
109 | def process(prompts):
110 | new_inputs = dict()
111 | for k, v in inputs[0].items():
112 | if isinstance(v, Tensor):
113 | new_inputs[k] = torch.cat([v]*len(prompts), 0)
114 | else:
115 | new_inputs[k] = [v]*len(prompts)
116 | new_inputs['txt'] = prompts
117 | vs = mod(new_inputs)
118 | return vs['crossattn']
119 |
120 | assert len(prompts) == output_vector.shape[0], f"number of prompts different than expected: {len(prompts)} != {output_vector.shape[0]}"
121 |
122 |             # Check whether we are processing the negative prompt.
123 | # I firmly believe there is no one who uses a negative prompt
124 | # exactly identical to a prompt.
125 | if self.disable_neg:
126 | if all(check_neg(x, p.negative_prompt, p.all_negative_prompts) for x in prompts):
127 | # Now we are processing Negative prompt and skip it.
128 | return
129 |
130 | for pidx, prompt in enumerate(prompts):
131 | tt = token_to_block(mod, prompt)
132 |
133 | cutoff = generate_prompts(mod, prompt, self.targets, self.padding)
134 | switch_base = np.full_like(cutoff.sw, self.strong)
135 | switch = np.full_like(cutoff.sw, True)
136 | active = cutoff.active_blocks()
137 |
138 | prompt_to_tokens = defaultdict(lambda: [])
139 |                 for tidx, (token, block_index) in enumerate(tt):
140 |                     if block_index in active:
141 |                         sw = switch.copy()
142 |                         sw[block_index] = False
143 |                         variant = cutoff.text(sw)
144 |                     else:
145 |                         variant = cutoff.text(switch_base)
146 |                     prompt_to_tokens[variant].append((tidx, token))
147 |
148 | #log(prompt_to_tokens)
149 |
150 | ks = list(prompt_to_tokens.keys())
151 | if len(ks) == 0:
152 | # without any (negative) prompts
153 | ks.append('')
154 |
155 | try:
156 |                     # set `skip` so that the encoder call below does not re-trigger this hook
157 | skip = True
158 | vs = process(ks)
159 | finally:
160 | skip = False
161 |
162 | tensor = output_vector[pidx, :, :] # e.g. (77, 768)
163 | for k, t in zip(ks, vs):
164 | assert tensor.shape == t.shape, f"tensor and t must have same shape\ntensor: {tensor.shape}\n t:{t.shape}"
165 | for tidx, token in prompt_to_tokens[k]:
166 | log(f'{tidx:03} {token.token:<16} {k}')
167 | tensor[tidx, :] = self.interpolate(tensor[tidx,:], t[tidx,:], self.weight)
168 |
169 | return output
170 |
171 | self.hook_layer(clip, hook)
172 |
173 |
174 | class Script(scripts.Script):
175 |
176 | def __init__(self):
177 | super().__init__()
178 | self.last_hooker: Union[SDHook,None] = None
179 |
180 | def title(self):
181 | return NAME
182 |
183 | def show(self, is_img2img):
184 | return scripts.AlwaysVisible
185 |
186 | def ui(self, is_img2img):
187 | with gr.Accordion(NAME, open=False):
188 | enabled = gr.Checkbox(label='Enabled', value=False)
189 | targets = gr.Textbox(label='Target tokens (comma separated)', placeholder='red, blue')
190 | weight = gr.Slider(minimum=-1.0, maximum=2.0, step=0.01, value=0.5, label='Weight')
191 | with gr.Accordion('Details', open=False):
192 | disable_neg = gr.Checkbox(value=True, label='Disable for Negative prompt.')
193 | strong = gr.Checkbox(value=False, label='Cutoff strongly.')
194 | padding = gr.Textbox(label='Padding token (ID or single token)')
195 | lerp = gr.Radio(choices=['Lerp', 'SLerp'], value='Lerp', label='Interpolation method')
196 |
197 | debug = gr.Checkbox(value=False, label='Debug log')
198 | debug.change(fn=set_debug, inputs=[debug], outputs=[])
199 |
200 | return [
201 | enabled,
202 | targets,
203 | weight,
204 | disable_neg,
205 | strong,
206 | padding,
207 | lerp,
208 | debug,
209 | ]
210 |
211 | def process(
212 | self,
213 | p: StableDiffusionProcessing,
214 | enabled: bool,
215 | targets_: str,
216 | weight: Union[float,int],
217 | disable_neg: bool,
218 | strong: bool,
219 | padding: Union[str,int],
220 | intp: str,
221 | debug: bool,
222 | ):
223 | set_debug(debug)
224 |
225 | if self.last_hooker is not None:
226 | self.last_hooker.__exit__(None, None, None)
227 | self.last_hooker = None
228 |
229 | if not enabled:
230 | return
231 |
232 | if targets_ is None or len(targets_) == 0:
233 | return
234 |
235 | targets = [x.strip() for x in targets_.split(',')]
236 | targets = [x for x in targets if len(x) != 0]
237 |
238 | if len(targets) == 0:
239 | return
240 |
241 | if padding is None:
242 | padding = PAD
243 | elif isinstance(padding, str):
244 | if len(padding) == 0:
245 | padding = PAD
246 | else:
247 |                 try:
248 |                     padding = int(padding)
249 |                 except ValueError:
250 |                     if not padding.endswith('</w>'):
251 |                         padding += '</w>'
252 |
253 | weight = float(weight)
254 | intp = intp.lower()
255 |
256 | self.last_hooker = Hook(
257 | enabled=True,
258 | targets=targets,
259 | padding=padding,
260 | weight=weight,
261 | disable_neg=disable_neg,
262 | strong=strong,
263 | interpolate=intp,
264 | )
265 |
266 | self.last_hooker.setup(p)
267 | self.last_hooker.__enter__()
268 |
269 | p.extra_generation_params.update({
270 | f'{NAME} enabled': enabled,
271 | f'{NAME} targets': targets,
272 | f'{NAME} padding': padding,
273 | f'{NAME} weight': weight,
274 | f'{NAME} disable_for_neg': disable_neg,
275 | f'{NAME} strong': strong,
276 | f'{NAME} interpolation': intp,
277 | })
278 |
279 | if hasattr(p, 'cached_c'):
280 | p.cached_c = [None, None]
281 | if hasattr(p, 'cached_uc'):
282 | p.cached_uc = [None, None]
283 |
284 | init_xyz(Script, NAME)
285 |
--------------------------------------------------------------------------------
/scripts/cutofflib/embedding.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from itertools import product
3 | import re
4 | from typing import Union, List, Tuple
5 | import numpy as np
6 | import open_clip
7 | from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase as CLIP
8 | try:
9 | from sgm.modules import GeneralConditioner as CLIP_SDXL
10 | except ImportError:
11 | print("[Cutoff] failed to load `sgm.modules.GeneralConditioner`")
12 |     CLIP_SDXL = int  # dummy fallback type so isinstance() checks simply fail
13 | from modules import prompt_parser, shared
14 | from scripts.cutofflib.utils import log
15 |
16 | class ClipWrapper:
17 | def __init__(self, te: Union[CLIP,CLIP_SDXL]):
18 | self.te = te
19 | self.v1 = hasattr(te.wrapped, 'tokenizer')
20 | self.sdxl = hasattr(te, 'embedders')
21 | self.t = (
22 | te.wrapped.tokenizer if self.v1
23 | else open_clip.tokenizer._tokenizer
24 | )
25 |
26 | def token_to_id(self, token: str) -> int:
27 | if self.v1:
28 | return self.t._convert_token_to_id(token) # type: ignore
29 | else:
30 | return self.t.encoder[token]
31 |
32 | def id_to_token(self, id: int) -> str:
33 | if self.v1:
34 | return self.t.convert_ids_to_tokens(id) # type: ignore
35 | else:
36 | return self.t.decoder[id]
37 |
38 | def ids_to_tokens(self, ids: List[int]) -> List[str]:
39 | if self.v1:
40 | return self.t.convert_ids_to_tokens(ids) # type: ignore
41 | else:
42 | return [self.t.decoder[id] for id in ids]
43 |
44 | def token(self, token: Union[int,str]):
45 | if isinstance(token, int):
46 | return Token(token, self.id_to_token(token))
47 | else:
48 | return Token(self.token_to_id(token), token)
49 |
50 | @property
51 | def id_start(self):
52 | if self.sdxl:
53 | return self.te.embedders[0].id_start
54 | else:
55 | return self.te.id_start
56 |
57 | @property
58 | def id_end(self):
59 | if self.sdxl:
60 | return self.te.embedders[0].id_end
61 | else:
62 | return self.te.id_end
63 |
64 | @property
65 | def hijack(self):
66 | if self.sdxl:
67 | return self.te.embedders[0].hijack
68 | else:
69 | return self.te.hijack
70 |
71 |
72 | @dataclass
73 | class Token:
74 | id: int
75 | token: str
76 |
77 | class CutoffPrompt:
78 |
79 | @staticmethod
80 | def _cutoff(prompt: str, clip: CLIP, tokens: List[str], padding: str):
81 | def token_count(text: str):
82 | te = ClipWrapper(clip)
83 | tt = token_to_block(clip, text)
84 | # tt[0] == te.id_start (<|startoftext|>)
85 | for index, (t, _) in enumerate(tt):
86 | if t.id == te.id_end: # <|endoftext|>
87 | return index - 1
88 | return 0 # must not happen...
89 |
90 | re_targets = [ re.compile(r'\b' + re.escape(x) + r'\b') for x in tokens ]
91 | replacer = [ ' ' + ' '.join([padding] * token_count(x)) + ' ' for x in tokens ]
92 |
93 | rows: List[Tuple[str,str]] = []
94 | for block in prompt.split(','):
95 | b0 = block
96 | for r, p in zip(re_targets, replacer):
97 | block = r.sub(p, block)
98 | b1 = block
99 | rows.append((b0, b1))
100 |
101 | return rows
102 |
103 | def __init__(self, prompt: str, clip: CLIP, tokens: List[str], padding: str):
104 | self.prompt = prompt
105 | rows = CutoffPrompt._cutoff(prompt, clip, tokens, padding)
106 | self.base = np.array([x[0] for x in rows])
107 | self.cut = np.array([x[1] for x in rows])
108 | self.sw = np.array([False] * len(rows))
109 |
110 | @property
111 | def block_count(self):
112 | return self.base.shape[0]
113 |
114 | def switch(self, block_index: int, to: Union[bool,None] = None):
115 | if to is None:
116 | to = not self.sw[block_index]
117 | self.sw[block_index] = to
118 | return to
119 |
120 | def text(self, sw=None):
121 | if sw is None:
122 | sw = self.sw
123 | blocks = np.where(sw, self.cut, self.base)
124 | return ','.join(blocks)
125 |
126 | def active_blocks(self) -> np.ndarray:
127 | indices, = (self.base != self.cut).nonzero()
128 | return indices
129 |
130 | def generate(self):
131 | indices = self.active_blocks()
132 | for diff_sw in product([False, True], repeat=indices.shape[0]):
133 | sw = np.full_like(self.sw, False)
134 | sw[indices] = diff_sw
135 | yield diff_sw, self.text(sw)
136 |
137 |
138 | def generate_prompts(
139 | clip: CLIP,
140 | prompt: str,
141 | targets: List[str],
142 | padding: Union[str,int,Token],
143 | ) -> CutoffPrompt:
144 |
145 | te = ClipWrapper(clip)
146 |
147 | if not isinstance(padding, Token):
148 | o_pad = padding
149 | padding = te.token(padding)
150 | if padding.id == te.id_end:
151 | raise ValueError(f'`{o_pad}` is not a valid token.')
152 |
153 |     result = CutoffPrompt(prompt, clip, targets, padding.token.replace('</w>', ''))
154 |
155 | log(f'[Cutoff] replace: {", ".join(targets)}')
156 | log(f'[Cutoff] to: {padding.token} ({padding.id})')
157 | log(f'[Cutoff] original: {prompt}')
158 | for i, (_, pp) in enumerate(result.generate()):
159 | log(f'[Cutoff] #{i}: {pp}')
160 |
161 | return result
162 |
163 |
164 | def token_to_block(clip: CLIP, prompt: str):
165 | te = ClipWrapper(clip)
166 |
167 | # cf. sd_hijack_clip.py
168 |
169 | parsed = prompt_parser.parse_prompt_attention(prompt)
170 | tokenized: List[List[int]] = clip.tokenize([text for text, _ in parsed])
171 |
172 | CHUNK_LENGTH = 75
173 | id_start = te.token(te.id_start) # type: ignore
174 | id_end = te.token(te.id_end) # type: ignore
175 | comma = te.token(',')
176 |
177 | last_comma = -1
178 | current_block = 0
179 | current_tokens: List[Tuple[Token,int]] = []
180 | result: List[Tuple[Token,int]] = []
181 |
182 | def next_chunk():
183 | nonlocal current_tokens, last_comma
184 |
185 | to_add = CHUNK_LENGTH - len(current_tokens)
186 | if 0 < to_add:
187 | current_tokens += [(id_end, -1)] * to_add
188 |
189 | current_tokens = [(id_start, -1)] + current_tokens + [(id_end, -1)]
190 |
191 | last_comma = -1
192 | result.extend(current_tokens)
193 | current_tokens = []
194 |
195 | for tokens, (text, weight) in zip(tokenized, parsed):
196 | if text == 'BREAK' and weight == -1:
197 | next_chunk()
198 | continue
199 |
200 | p = 0
201 | while p < len(tokens):
202 | token = tokens[p]
203 |
204 | if token == comma.id:
205 | last_comma = len(current_tokens)
206 | current_block += 1
207 |
208 | elif (
209 | shared.opts.comma_padding_backtrack != 0
210 | and len(current_tokens) == CHUNK_LENGTH
211 | and last_comma != -1
212 | and len(current_tokens) - last_comma <= shared.opts.comma_padding_backtrack
213 | ):
214 | break_location = last_comma + 1
215 | reloc_tokens = current_tokens[break_location:]
216 | current_tokens = current_tokens[:break_location]
217 | next_chunk()
218 | current_tokens = reloc_tokens
219 |
220 | if len(current_tokens) == CHUNK_LENGTH:
221 | next_chunk()
222 |
223 | embedding, embedding_length_in_tokens = te.hijack.embedding_db.find_embedding_at_position(tokens, p)
224 | if embedding is None:
225 | if token == comma.id:
226 | current_tokens.append((te.token(token), -1))
227 | else:
228 | current_tokens.append((te.token(token), current_block))
229 | p += 1
230 | continue
231 |
232 | emb_len = int(embedding.vec.shape[0])
233 | if len(current_tokens) + emb_len > CHUNK_LENGTH:
234 | next_chunk()
235 |
236 | current_tokens += [(te.token(0), current_block)] * emb_len
237 | p += embedding_length_in_tokens
238 |
239 | if len(current_tokens) > 0:
240 | next_chunk()
241 |
242 | return result
243 |
--------------------------------------------------------------------------------
/scripts/cutofflib/sdhook.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Any, Callable, Union
3 |
4 | from torch import nn
5 | from torch.utils.hooks import RemovableHandle
6 |
7 | from ldm.modules.diffusionmodules.openaimodel import (
8 | TimestepEmbedSequential,
9 | )
10 | from ldm.modules.attention import (
11 | SpatialTransformer,
12 | BasicTransformerBlock,
13 | CrossAttention,
14 | MemoryEfficientCrossAttention,
15 | )
16 | from ldm.modules.diffusionmodules.openaimodel import (
17 | ResBlock,
18 | )
19 | from modules.processing import StableDiffusionProcessing
20 | from modules import shared
21 |
22 | class ForwardHook:
23 |
24 | def __init__(self, module: nn.Module, fn: Callable[[nn.Module, Callable[..., Any], Any], Any]):
25 | self.o = module.forward
26 | self.fn = fn
27 | self.module = module
28 | self.module.forward = self.forward
29 |
30 | def remove(self):
31 | if self.module is not None and self.o is not None:
32 | self.module.forward = self.o
33 | self.module = None
34 | self.o = None
35 | self.fn = None
36 |
37 | def forward(self, *args, **kwargs):
38 | if self.module is not None and self.o is not None:
39 | if self.fn is not None:
40 | return self.fn(self.module, self.o, *args, **kwargs)
41 | return None
42 |
43 |
44 | class SDHook:
45 |
46 | def __init__(self, enabled: bool):
47 | self._enabled = enabled
48 | self._handles: list[Union[RemovableHandle,ForwardHook]] = []
49 |
50 | @property
51 | def enabled(self):
52 | return self._enabled
53 |
54 | @property
55 | def batch_num(self):
56 | return shared.state.job_no
57 |
58 | @property
59 | def step_num(self):
60 | return shared.state.current_image_sampling_step
61 |
62 | def __enter__(self):
63 |         # nothing to do on entry; hooks are registered via setup()
64 |         pass
65 |
66 | def __exit__(self, exc_type, exc_value, traceback):
67 | if self.enabled:
68 | for handle in self._handles:
69 | handle.remove()
70 | self._handles.clear()
71 | self.dispose()
72 |
73 | def dispose(self):
74 | pass
75 |
76 | def setup(
77 | self,
78 | p: StableDiffusionProcessing
79 | ):
80 | if not self.enabled:
81 | return
82 |
83 | wrapper = getattr(p.sd_model, "model", None)
84 |
85 | unet: Union[nn.Module,None] = getattr(wrapper, "diffusion_model", None) if wrapper is not None else None
86 | vae: Union[nn.Module,None] = getattr(p.sd_model, "first_stage_model", None)
87 | clip: Union[nn.Module,None] = getattr(p.sd_model, "cond_stage_model", None)
88 |
89 | assert unet is not None, "p.sd_model.diffusion_model is not found. broken model???"
90 | self._do_hook(p, p.sd_model, unet=unet, vae=vae, clip=clip) # type: ignore
91 | self.on_setup()
92 |
93 | def on_setup(self):
94 | pass
95 |
96 | def _do_hook(
97 | self,
98 | p: StableDiffusionProcessing,
99 | model: Any,
100 | unet: Union[nn.Module,None],
101 | vae: Union[nn.Module,None],
102 | clip: Union[nn.Module,None]
103 | ):
104 | assert model is not None, "empty model???"
105 |
106 | if clip is not None:
107 | self.hook_clip(p, clip)
108 |
109 | if unet is not None:
110 | self.hook_unet(p, unet)
111 |
112 | if vae is not None:
113 | self.hook_vae(p, vae)
114 |
115 | def hook_vae(
116 | self,
117 | p: StableDiffusionProcessing,
118 | vae: nn.Module
119 | ):
120 | pass
121 |
122 | def hook_unet(
123 | self,
124 | p: StableDiffusionProcessing,
125 | unet: nn.Module
126 | ):
127 | pass
128 |
129 | def hook_clip(
130 | self,
131 | p: StableDiffusionProcessing,
132 | clip: nn.Module
133 | ):
134 | pass
135 |
136 | def hook_layer(
137 | self,
138 | module: Union[nn.Module,Any],
139 | fn: Callable[[nn.Module, tuple, Any], Any]
140 | ):
141 | if not self.enabled:
142 | return
143 |
144 | assert module is not None
145 | assert isinstance(module, nn.Module)
146 | self._handles.append(module.register_forward_hook(fn))
147 |
148 | def hook_layer_pre(
149 | self,
150 | module: Union[nn.Module,Any],
151 | fn: Callable[[nn.Module, tuple], Any]
152 | ):
153 | if not self.enabled:
154 | return
155 |
156 | assert module is not None
157 | assert isinstance(module, nn.Module)
158 | self._handles.append(module.register_forward_pre_hook(fn))
159 |
160 | def hook_forward(
161 | self,
162 | module: Union[nn.Module,Any],
163 | fn: Callable[[nn.Module, Callable[..., Any], Any], Any]
164 | ):
165 | assert module is not None
166 | assert isinstance(module, nn.Module)
167 | self._handles.append(ForwardHook(module, fn))
168 |
169 | def log(self, msg: str):
170 | print(msg, file=sys.stderr)
171 |
172 |
173 | # enumerate SpatialTransformer in TimestepEmbedSequential
174 | def each_transformer(unet_block: TimestepEmbedSequential):
175 | for block in unet_block.children():
176 | if isinstance(block, SpatialTransformer):
177 | yield block
178 |
179 | # enumerate BasicTransformerBlock in SpatialTransformer
180 | def each_basic_block(trans: SpatialTransformer):
181 | for block in trans.transformer_blocks.children():
182 | if isinstance(block, BasicTransformerBlock):
183 | yield block
184 |
185 | # enumerate Attention Layers in TimestepEmbedSequential
186 | # each_transformer + each_basic_block
187 | def each_attns(unet_block: TimestepEmbedSequential):
188 | for n, trans in enumerate(each_transformer(unet_block)):
189 | for depth, basic_block in enumerate(each_basic_block(trans)):
190 | # attn1: Union[CrossAttention,MemoryEfficientCrossAttention]
191 | # attn2: Union[CrossAttention,MemoryEfficientCrossAttention]
192 |
193 | attn1, attn2 = basic_block.attn1, basic_block.attn2
194 | assert isinstance(attn1, CrossAttention) or isinstance(attn1, MemoryEfficientCrossAttention)
195 | assert isinstance(attn2, CrossAttention) or isinstance(attn2, MemoryEfficientCrossAttention)
196 |
197 | yield n, depth, attn1, attn2
198 |
199 | def each_unet_attn_layers(unet: nn.Module):
200 | def get_attns(layer_index: int, block: TimestepEmbedSequential, format: str):
201 | for n, d, attn1, attn2 in each_attns(block):
202 | kwargs = {
203 | 'layer_index': layer_index,
204 | 'trans_index': n,
205 | 'block_index': d
206 | }
207 | yield format.format(attn_name='sattn', **kwargs), attn1
208 | yield format.format(attn_name='xattn', **kwargs), attn2
209 |
210 | def enumerate_all(blocks: nn.ModuleList, format: str):
211 | for idx, block in enumerate(blocks.children()):
212 | if isinstance(block, TimestepEmbedSequential):
213 | yield from get_attns(idx, block, format)
214 |
215 | inputs: nn.ModuleList = unet.input_blocks # type: ignore
216 | middle: TimestepEmbedSequential = unet.middle_block # type: ignore
217 | outputs: nn.ModuleList = unet.output_blocks # type: ignore
218 |
219 | yield from enumerate_all(inputs, 'IN{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
220 | yield from get_attns(0, middle, 'M{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
221 | yield from enumerate_all(outputs, 'OUT{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
222 |
223 |
224 | def each_unet_transformers(unet: nn.Module):
225 | def get_trans(layer_index: int, block: TimestepEmbedSequential, format: str):
226 | for n, trans in enumerate(each_transformer(block)):
227 | kwargs = {
228 | 'layer_index': layer_index,
229 | 'block_index': n,
230 | 'block_name': 'trans',
231 | }
232 | yield format.format(**kwargs), trans
233 |
234 | def enumerate_all(blocks: nn.ModuleList, format: str):
235 | for idx, block in enumerate(blocks.children()):
236 | if isinstance(block, TimestepEmbedSequential):
237 | yield from get_trans(idx, block, format)
238 |
239 | inputs: nn.ModuleList = unet.input_blocks # type: ignore
240 | middle: TimestepEmbedSequential = unet.middle_block # type: ignore
241 | outputs: nn.ModuleList = unet.output_blocks # type: ignore
242 |
243 | yield from enumerate_all(inputs, 'IN{layer_index:02}_{block_index:02}_{block_name}')
244 | yield from get_trans(0, middle, 'M{layer_index:02}_{block_index:02}_{block_name}')
245 | yield from enumerate_all(outputs, 'OUT{layer_index:02}_{block_index:02}_{block_name}')
246 |
247 |
248 | def each_resblock(unet_block: TimestepEmbedSequential):
249 | for block in unet_block.children():
250 | if isinstance(block, ResBlock):
251 | yield block
252 |
253 | def each_unet_resblock(unet: nn.Module):
254 | def get_resblock(layer_index: int, block: TimestepEmbedSequential, format: str):
255 | for n, res in enumerate(each_resblock(block)):
256 | kwargs = {
257 | 'layer_index': layer_index,
258 | 'block_index': n,
259 | 'block_name': 'resblock',
260 | }
261 | yield format.format(**kwargs), res
262 |
263 | def enumerate_all(blocks: nn.ModuleList, format: str):
264 | for idx, block in enumerate(blocks.children()):
265 | if isinstance(block, TimestepEmbedSequential):
266 | yield from get_resblock(idx, block, format)
267 |
268 | inputs: nn.ModuleList = unet.input_blocks # type: ignore
269 | middle: TimestepEmbedSequential = unet.middle_block # type: ignore
270 | outputs: nn.ModuleList = unet.output_blocks # type: ignore
271 |
272 | yield from enumerate_all(inputs, 'IN{layer_index:02}_{block_index:02}_{block_name}')
273 | yield from get_resblock(0, middle, 'M{layer_index:02}_{block_index:02}_{block_name}')
274 | yield from enumerate_all(outputs, 'OUT{layer_index:02}_{block_index:02}_{block_name}')
275 |
276 |
--------------------------------------------------------------------------------
/scripts/cutofflib/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | _debug = False
4 |
5 | def set_debug(is_debug: bool):
6 | global _debug
7 | _debug = is_debug
8 |
9 | def log(s: str):
10 | if _debug:
11 | print(s, file=sys.stderr)
12 |
--------------------------------------------------------------------------------
/scripts/cutofflib/xyz.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Union, List, Callable
3 |
4 | from modules import scripts
5 | from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
6 |
7 |
8 | def __set_value(p: StableDiffusionProcessing, script: type, index: int, value):
9 | args = list(p.script_args)
10 |
11 | if isinstance(p, StableDiffusionProcessingTxt2Img):
12 | all_scripts = scripts.scripts_txt2img.scripts
13 | else:
14 | all_scripts = scripts.scripts_img2img.scripts
15 |
16 | froms = [x.args_from for x in all_scripts if isinstance(x, script)]
17 | for idx in froms:
18 | assert idx is not None
19 | args[idx + index] = value
20 |
21 | p.script_args = type(p.script_args)(args)
22 |
23 |
24 | def to_bool(v: str):
25 | if len(v) == 0: return False
26 | v = v.lower()
27 | if 'true' in v: return True
28 | if 'false' in v: return False
29 |
30 | try:
31 | w = int(v)
32 | return bool(w)
33 |     except ValueError:
34 |         acceptable = ['True', 'False', '1', '0']
35 |         s = ', '.join(f'`{a}`' for a in acceptable)
36 |         raise ValueError(f'value must be one of {s}.')
37 |
38 |
39 | class AxisOptions:
40 |
41 | def __init__(self, AxisOption: type, axis_options: list):
42 | self.AxisOption = AxisOption
43 | self.target = axis_options
44 | self.options = []
45 |
46 | def __enter__(self):
47 | self.options.clear()
48 | return self
49 |
50 | def __exit__(self, ex_type, ex_value, trace):
51 | if ex_type is not None:
52 | return
53 |
54 | for opt in self.options:
55 | self.target.append(opt)
56 |
57 | self.options.clear()
58 |
59 | def create(self, name: str, type_fn: Callable, action: Callable, choices: Union[List[str],None]):
60 | if choices is None or len(choices) == 0:
61 | opt = self.AxisOption(name, type_fn, action)
62 | else:
63 | opt = self.AxisOption(name, type_fn, action, choices=lambda: choices)
64 | return opt
65 |
66 | def add(self, axis_option):
67 | self.target.append(axis_option)
68 |
69 |
70 | __init = False
71 |
72 | def init_xyz(script: type, ext_name: str):
73 | global __init
74 |
75 | if __init:
76 | return
77 |
78 | for data in scripts.scripts_data:
79 | name = os.path.basename(data.path)
80 | if name != 'xy_grid.py' and name != 'xyz_grid.py':
81 | continue
82 |
83 | if not hasattr(data.module, 'AxisOption'):
84 | continue
85 |
86 | if not hasattr(data.module, 'axis_options'):
87 | continue
88 |
89 | AxisOption = data.module.AxisOption
90 | axis_options = data.module.axis_options
91 |
92 | if not isinstance(AxisOption, type):
93 | continue
94 |
95 | if not isinstance(axis_options, list):
96 | continue
97 |
98 | try:
99 | create_options(ext_name, script, AxisOption, axis_options)
100 | except:
101 | pass
102 |
103 | __init = True
104 |
105 |
106 | def create_options(ext_name: str, script: type, AxisOptionClass: type, axis_options: list):
107 | with AxisOptions(AxisOptionClass, axis_options) as opts:
108 | def define(param: str, index: int, type_fn: Callable, choices: List[str] = []):
109 | def fn(p, x, xs):
110 | __set_value(p, script, index, x)
111 |
112 | name = f'[{ext_name}] {param}'
113 | return opts.create(name, type_fn, fn, choices)
114 |
115 | options = [
116 | define('Enabled', 0, to_bool, choices=['false', 'true']),
117 | define('Targets', 1, str),
118 | define('Weight', 2, float),
119 | define('Disable for Negative Prompt', 3, to_bool, choices=['false', 'true']),
120 | define('Strong', 4, to_bool, choices=['false', 'true']),
121 | define('Padding', 5, str),
122 | define('Interpolation', 6, str, choices=['Lerp', 'SLerp']),
123 | ]
124 |
125 | for opt in options:
126 | opts.add(opt)
127 |
--------------------------------------------------------------------------------