├── samples ├── pww-flow.png ├── sample-3-1.png ├── sample-3-2.png ├── sample-1-compressed.png ├── sample-2-compressed.png ├── sample-1-output-compressed.png └── sample-2-output-compressed.png ├── requirements.txt ├── README.md ├── .gitignore ├── modules ├── flash_attn.py ├── safechecker.py ├── lora.py ├── prompt_parser.py └── model.py └── app.py /samples/pww-flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mikubill/sd-paint-with-words/HEAD/samples/pww-flow.png -------------------------------------------------------------------------------- /samples/sample-3-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mikubill/sd-paint-with-words/HEAD/samples/sample-3-1.png -------------------------------------------------------------------------------- /samples/sample-3-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mikubill/sd-paint-with-words/HEAD/samples/sample-3-2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | einops 3 | diffusers 4 | transformers 5 | k_diffusion 6 | safetensors 7 | gradio 8 | -------------------------------------------------------------------------------- /samples/sample-1-compressed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mikubill/sd-paint-with-words/HEAD/samples/sample-1-compressed.png -------------------------------------------------------------------------------- /samples/sample-2-compressed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mikubill/sd-paint-with-words/HEAD/samples/sample-2-compressed.png -------------------------------------------------------------------------------- /samples/sample-1-output-compressed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mikubill/sd-paint-with-words/HEAD/samples/sample-1-output-compressed.png -------------------------------------------------------------------------------- /samples/sample-2-output-compressed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mikubill/sd-paint-with-words/HEAD/samples/sample-2-output-compressed.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Paint-With-Words 2 | 3 | [Paper](https://arxiv.org/abs/2211.01324) | [Demo](https://huggingface.co/spaces/nyanko7/sd-diffusers-webui) 4 | 5 | Paint-with-words is a method proposed by researchers from NVIDIA that allows users to control the location of objects by selecting phrases and painting them on the canvas. The user-specified masks increase the values of the corresponding entries of the attention matrix in the cross-attention layers. 6 | 7 | Inspired by this method, we created a simple a1111-style sketching UI that accepts multiple masks, so the same area can be addressed by different tokens. Textual-inversion and LoRA support are also fully functional: you can add them to the generation process and adjust both their strength and the area they are applied to.
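
For intuition, here is a minimal, self-contained sketch of that mechanism. It mirrors the weighting used in `modules/model.py` (`w * log(1 + sigma) * scores.max()`): a painted mask is downsampled to the attention resolution and added as a bias to the raw cross-attention scores of the tokens it is bound to. All shapes, the token positions, and the random tensors below are toy stand-ins, not the real pipeline.

```python
import math
import torch
import torch.nn.functional as F

# Toy stand-ins: 8 attention heads, a 64x64 latent, 77 text tokens, head dim 40.
heads, h, w, n_tokens, d = 8, 64, 64, 77, 40
query = torch.randn(heads, h * w, d)   # image queries
key = torch.randn(heads, n_tokens, d)  # text keys
value = torch.randn(heads, n_tokens, d)
scores = torch.einsum("bid,bjd->bij", query, key) * d ** -0.5

# A user-painted mask for one phrase, drawn on a 512x512 canvas,
# downsampled to the 64x64 attention resolution.
mask = torch.zeros(1, 1, 512, 512)
mask[..., 100:300, 150:350] = 1.0
mask = F.interpolate(mask, size=(h, w), mode="bilinear").reshape(1, h * w, 1)

# Bias only the score columns of the tokens that spell the phrase
# (positions 2-4 here are hypothetical).
bias = torch.zeros(1, h * w, n_tokens)
bias[..., [2, 3, 4]] = mask

# Same weighting as modules/model.py: w * log(1 + sigma) * max(QK^T).
weight, sigma = 0.3, 5.0
scores = scores + weight * math.log(1 + sigma) * scores.max() * bias

# Attention now favors the painted region for those tokens.
out = torch.bmm(scores.softmax(dim=-1), value)
```

In the actual code, `encode_sketchs()` precomputes one such bias per attention resolution and `CrossAttnProcessor` adds it to the scores at every denoising step, with `sigma` taken from the sampler.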
8 | 9 | **Config and Run** 10 | 11 | 1. Set your model path in https://github.com/Mikubill/sd-paint-with-words/blob/15e800e6c5ec14763567ec47173c9528fedc2649/app.py#L28-L35 12 | 2. `python app.py` 13 | 14 | **Some samples** 15 | 16 | | Sketch | Image | 17 | |:-------------------------:|:-------------------------:| 18 | | | | 19 | | | | 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .pt 131 | *.pt 132 | weights/ 133 | -------------------------------------------------------------------------------- /modules/flash_attn.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | from einops import rearrange 5 | from torch import einsum 6 | from torch.autograd.function import Function 7 | 8 | 9 | EPSILON = 1e-6 10 | exists = lambda val: val is not None 11 | default = lambda val, d: val if exists(val) else d 12 | 13 | class FlashAttentionFunction(Function): 14 | @staticmethod 15 | @torch.no_grad() 16 | def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): 17 | """Algorithm 2 in the paper""" 18 | 19 | device = q.device 20 | max_neg_value = -torch.finfo(q.dtype).max 21 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 22 | 23 | o = torch.zeros_like(q) 24 | all_row_sums = torch.zeros((*q.shape[:-1], 1), device=device) 25 | all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device=device) 26 | 27 | scale = q.shape[-1] ** -0.5 28 | 29 | if not exists(mask): 30 | mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) 31 | else: 32 | mask = rearrange(mask, "b n -> b 1 1 n") 33 | mask = mask.split(q_bucket_size, dim=-1) 34 | 35 | row_splits = zip( 36 | q.split(q_bucket_size, dim=-2), 37 | o.split(q_bucket_size, dim=-2), 38 | mask, 39 | all_row_sums.split(q_bucket_size, dim=-2), 40 | all_row_maxes.split(q_bucket_size, dim=-2), 41 | ) 42 | 43 | for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): 44 | q_start_index = ind * q_bucket_size - qk_len_diff 45 | 46 | col_splits = zip( 47 | k.split(k_bucket_size, dim=-2), 48 | v.split(k_bucket_size, dim=-2), 49 | ) 50 | 51 | for k_ind, (kc, vc) in enumerate(col_splits): 52 | k_start_index = k_ind * k_bucket_size 53 | 54 | attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale 55 | 56 | if exists(row_mask): 57 | attn_weights.masked_fill_(~row_mask, max_neg_value) 58 | 59 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 60 | causal_mask = torch.ones( 61 | (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device 62 | ).triu(q_start_index - k_start_index + 1) 63 | attn_weights.masked_fill_(causal_mask, max_neg_value) 64 | 65 | block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) 66 | attn_weights -= block_row_maxes 67 | exp_weights = torch.exp(attn_weights) 68 | 69 | if exists(row_mask): 70 | exp_weights.masked_fill_(~row_mask, 0.0) 71 | 72 | block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( 73 | min=EPSILON 74 | ) 75 | 76 | new_row_maxes = torch.maximum(block_row_maxes, row_maxes) 77 | 78 | exp_values = einsum("... i j, ... j d -> ... 
i d", exp_weights, vc) 79 | 80 | exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) 81 | exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) 82 | 83 | new_row_sums = ( 84 | exp_row_max_diff * row_sums 85 | + exp_block_row_max_diff * block_row_sums 86 | ) 87 | 88 | oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( 89 | (exp_block_row_max_diff / new_row_sums) * exp_values 90 | ) 91 | 92 | row_maxes.copy_(new_row_maxes) 93 | row_sums.copy_(new_row_sums) 94 | 95 | lse = all_row_sums.log() + all_row_maxes 96 | 97 | ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) 98 | ctx.save_for_backward(q, k, v, o, lse) 99 | 100 | return o 101 | 102 | @staticmethod 103 | @torch.no_grad() 104 | def backward(ctx, do): 105 | """Algorithm 4 in the paper""" 106 | 107 | causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args 108 | q, k, v, o, lse = ctx.saved_tensors 109 | 110 | device = q.device 111 | 112 | max_neg_value = -torch.finfo(q.dtype).max 113 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 114 | 115 | dq = torch.zeros_like(q) 116 | dk = torch.zeros_like(k) 117 | dv = torch.zeros_like(v) 118 | 119 | row_splits = zip( 120 | q.split(q_bucket_size, dim=-2), 121 | o.split(q_bucket_size, dim=-2), 122 | do.split(q_bucket_size, dim=-2), 123 | mask, 124 | lse.split(q_bucket_size, dim=-2), 125 | dq.split(q_bucket_size, dim=-2), 126 | ) 127 | 128 | for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits): 129 | q_start_index = ind * q_bucket_size - qk_len_diff 130 | 131 | col_splits = zip( 132 | k.split(k_bucket_size, dim=-2), 133 | v.split(k_bucket_size, dim=-2), 134 | dk.split(k_bucket_size, dim=-2), 135 | dv.split(k_bucket_size, dim=-2), 136 | ) 137 | 138 | for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): 139 | k_start_index = k_ind * k_bucket_size 140 | 141 | attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale 142 | 143 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 144 | causal_mask = torch.ones( 145 | (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device 146 | ).triu(q_start_index - k_start_index + 1) 147 | attn_weights.masked_fill_(causal_mask, max_neg_value) 148 | 149 | p = torch.exp(attn_weights - lsec) 150 | 151 | if exists(row_mask): 152 | p.masked_fill_(~row_mask, 0.0) 153 | 154 | dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) 155 | dp = einsum("... i d, ... j d -> ... i j", doc, vc) 156 | 157 | D = (doc * oc).sum(dim=-1, keepdims=True) 158 | ds = p * scale * (dp - D) 159 | 160 | dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) 161 | dk_chunk = einsum("... i j, ... i d -> ... 
j d", ds, qc) 162 | 163 | dqc.add_(dq_chunk) 164 | dkc.add_(dk_chunk) 165 | dvc.add_(dv_chunk) 166 | 167 | return dq, dk, dv, None, None, None, None 168 | -------------------------------------------------------------------------------- /modules/safechecker.py: -------------------------------------------------------------------------------- 1 | # this code is adapted from the script contributed by anon from /h/ 2 | # modified, from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/safe.py 3 | 4 | import io 5 | import pickle 6 | import collections 7 | import sys 8 | import traceback 9 | 10 | import torch 11 | import numpy 12 | import _codecs 13 | import zipfile 14 | import re 15 | 16 | 17 | # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage 18 | TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage 19 | 20 | 21 | def encode(*args): 22 | out = _codecs.encode(*args) 23 | return out 24 | 25 | 26 | class RestrictedUnpickler(pickle.Unpickler): 27 | extra_handler = None 28 | 29 | def persistent_load(self, saved_id): 30 | assert saved_id[0] == 'storage' 31 | return TypedStorage() 32 | 33 | def find_class(self, module, name): 34 | if self.extra_handler is not None: 35 | res = self.extra_handler(module, name) 36 | if res is not None: 37 | return res 38 | 39 | if module == 'collections' and name == 'OrderedDict': 40 | return getattr(collections, name) 41 | if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: 42 | return getattr(torch._utils, name) 43 | if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']: 44 | return getattr(torch, name) 45 | if module == 'torch.nn.modules.container' and name in ['ParameterDict']: 46 | return getattr(torch.nn.modules.container, name) 47 | if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']: 48 | return getattr(numpy.core.multiarray, name) 49 | if module == 'numpy' and name in ['dtype', 'ndarray']: 50 | return getattr(numpy, name) 51 | if module == '_codecs' and name == 'encode': 52 | return encode 53 | if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': 54 | import pytorch_lightning.callbacks 55 | return pytorch_lightning.callbacks.model_checkpoint 56 | if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': 57 | import pytorch_lightning.callbacks.model_checkpoint 58 | return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 59 | if module == "__builtin__" and name == 'set': 60 | return set 61 | 62 | # Forbid everything else. 
63 | raise Exception(f"global '{module}/{name}' is forbidden") 64 | 65 | 66 | # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/' 67 | allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$") 68 | data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$") 69 | 70 | def check_zip_filenames(filename, names): 71 | for name in names: 72 | if allowed_zip_names_re.match(name): 73 | continue 74 | 75 | raise Exception(f"bad file inside {filename}: {name}") 76 | 77 | 78 | def check_pt(filename, extra_handler): 79 | try: 80 | 81 | # new pytorch format is a zip file 82 | with zipfile.ZipFile(filename) as z: 83 | check_zip_filenames(filename, z.namelist()) 84 | 85 | # find filename of data.pkl in zip file: '/data.pkl' 86 | data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)] 87 | if len(data_pkl_filenames) == 0: 88 | raise Exception(f"data.pkl not found in {filename}") 89 | if len(data_pkl_filenames) > 1: 90 | raise Exception(f"Multiple data.pkl found in {filename}") 91 | with z.open(data_pkl_filenames[0]) as file: 92 | unpickler = RestrictedUnpickler(file) 93 | unpickler.extra_handler = extra_handler 94 | unpickler.load() 95 | 96 | except zipfile.BadZipfile: 97 | 98 | # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle 99 | with open(filename, "rb") as file: 100 | unpickler = RestrictedUnpickler(file) 101 | unpickler.extra_handler = extra_handler 102 | for i in range(5): 103 | unpickler.load() 104 | 105 | 106 | def load(filename, *args, **kwargs): 107 | return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs) 108 | 109 | 110 | def load_with_extra(filename, extra_handler=None, *args, **kwargs): 111 | """ 112 | this function is intended to be used by extensions that want to load models with 113 | some extra classes in them that the usual unpickler would find suspicious. 114 | 115 | Use the extra_handler argument to specify a function that takes module and field name as text, 116 | and returns that field's value: 117 | 118 | ```python 119 | def extra(module, name): 120 | if module == 'collections' and name == 'OrderedDict': 121 | return collections.OrderedDict 122 | 123 | return None 124 | 125 | safe.load_with_extra('model.pt', extra_handler=extra) 126 | ``` 127 | 128 | The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is 129 | definitely unsafe. 130 | """ 131 | 132 | try: 133 | check_pt(filename, extra_handler) 134 | 135 | except pickle.UnpicklingError: 136 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr) 137 | print(traceback.format_exc(), file=sys.stderr) 138 | print("The file is most likely corrupted.", file=sys.stderr) 139 | return None 140 | 141 | except Exception: 142 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr) 143 | print(traceback.format_exc(), file=sys.stderr) 144 | print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) 145 | print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) 146 | return None 147 | 148 | return unsafe_torch_load(filename, *args, **kwargs) 149 | 150 | 151 | class Extra: 152 | """ 153 | A class for temporarily setting the global handler for when you can't explicitly call load_with_extra 154 | (because it's not your code making the torch.load call). 
The intended use is like this: 155 | 156 | ``` 157 | import torch 158 | from modules import safe 159 | 160 | def handler(module, name): 161 | if module == 'torch' and name in ['float64', 'float16']: 162 | return getattr(torch, name) 163 | 164 | return None 165 | 166 | with safe.Extra(handler): 167 | x = torch.load('model.pt') 168 | ``` 169 | """ 170 | 171 | def __init__(self, handler): 172 | self.handler = handler 173 | 174 | def __enter__(self): 175 | global global_extra_handler 176 | 177 | assert global_extra_handler is None, 'already inside an Extra() block' 178 | global_extra_handler = self.handler 179 | 180 | def __exit__(self, exc_type, exc_val, exc_tb): 181 | global global_extra_handler 182 | 183 | global_extra_handler = None 184 | 185 | 186 | unsafe_torch_load = torch.load 187 | torch.load = load 188 | global_extra_handler = None 189 | -------------------------------------------------------------------------------- /modules/lora.py: -------------------------------------------------------------------------------- 1 | # LoRA network module 2 | # reference: 3 | # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py 4 | # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py 5 | # https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py#L48 6 | 7 | import math 8 | import os 9 | import torch 10 | import modules.safechecker as _ 11 | from safetensors.torch import load_file 12 | 13 | 14 | class LoRAModule(torch.nn.Module): 15 | """ 16 | replaces forward method of the original Linear, instead of replacing the original Linear module. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | lora_name, 22 | org_module: torch.nn.Module, 23 | multiplier=1.0, 24 | lora_dim=4, 25 | alpha=1, 26 | ): 27 | """if alpha == 0 or None, alpha is rank (no scaling).""" 28 | super().__init__() 29 | self.lora_name = lora_name 30 | self.lora_dim = lora_dim 31 | 32 | if org_module.__class__.__name__ == "Conv2d": 33 | in_dim = org_module.in_channels 34 | out_dim = org_module.out_channels 35 | self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False) 36 | self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False) 37 | else: 38 | in_dim = org_module.in_features 39 | out_dim = org_module.out_features 40 | self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) 41 | self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) 42 | 43 | if type(alpha) == torch.Tensor: 44 | alpha = alpha.detach().float().numpy() # without casting, bf16 causes error 45 | 46 | alpha = lora_dim if alpha is None or alpha == 0 else alpha 47 | self.scale = alpha / self.lora_dim 48 | self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える 49 | 50 | # same as microsoft's 51 | torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) 52 | torch.nn.init.zeros_(self.lora_up.weight) 53 | 54 | self.multiplier = multiplier 55 | self.org_module = org_module # remove in applying 56 | self.enable = False 57 | 58 | def resize(self, rank, alpha, multiplier): 59 | self.alpha = torch.tensor(alpha) 60 | self.multiplier = multiplier 61 | self.scale = alpha / rank 62 | if self.lora_down.__class__.__name__ == "Conv2d": 63 | in_dim = self.lora_down.in_channels 64 | out_dim = self.lora_up.out_channels 65 | self.lora_down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False) 66 | self.lora_up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False) 67 | else: 68 | in_dim = self.lora_down.in_features 69 | out_dim = self.lora_up.out_features 70 | self.lora_down = torch.nn.Linear(in_dim, rank, 
bias=False) 71 | self.lora_up = torch.nn.Linear(rank, out_dim, bias=False) 72 | 73 | def apply(self): 74 | if hasattr(self, "org_module"): 75 | self.org_forward = self.org_module.forward 76 | self.org_module.forward = self.forward 77 | del self.org_module 78 | 79 | def forward(self, x): 80 | if self.enable: 81 | return ( 82 | self.org_forward(x) 83 | + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale 84 | ) 85 | return self.org_forward(x) 86 | 87 | 88 | class LoRANetwork(torch.nn.Module): 89 | UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] 90 | TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] 91 | LORA_PREFIX_UNET = "lora_unet" 92 | LORA_PREFIX_TEXT_ENCODER = "lora_te" 93 | 94 | def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None: 95 | super().__init__() 96 | self.multiplier = multiplier 97 | self.lora_dim = lora_dim 98 | self.alpha = alpha 99 | 100 | # create module instances 101 | def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules): 102 | loras = [] 103 | for name, module in root_module.named_modules(): 104 | if module.__class__.__name__ in target_replace_modules: 105 | for child_name, child_module in module.named_modules(): 106 | if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): 107 | lora_name = prefix + "." + name + "." + child_name 108 | lora_name = lora_name.replace(".", "_") 109 | lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha,) 110 | loras.append(lora) 111 | return loras 112 | 113 | if isinstance(text_encoder, list): 114 | self.text_encoder_loras = text_encoder 115 | else: 116 | self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) 117 | print(f"Create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") 118 | 119 | self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE) 120 | print(f"Create LoRA for U-Net: {len(self.unet_loras)} modules.") 121 | 122 | self.weights_sd = None 123 | 124 | # assertion 125 | names = set() 126 | for lora in self.text_encoder_loras + self.unet_loras: 127 | assert (lora.lora_name not in names), f"duplicated lora name: {lora.lora_name}" 128 | names.add(lora.lora_name) 129 | 130 | lora.apply() 131 | self.add_module(lora.lora_name, lora) 132 | 133 | def reset(self): 134 | for lora in self.text_encoder_loras + self.unet_loras: 135 | lora.enable = False 136 | 137 | def load(self, file, scale): 138 | 139 | weights = None 140 | if os.path.splitext(file)[1] == ".safetensors": 141 | weights = load_file(file) 142 | else: 143 | weights = torch.load(file, map_location="cpu") 144 | 145 | if not weights: 146 | return 147 | 148 | network_alpha = None 149 | network_dim = None 150 | for key, value in weights.items(): 151 | if network_alpha is None and "alpha" in key: 152 | network_alpha = value 153 | if network_dim is None and "lora_down" in key and len(value.size()) == 2: 154 | network_dim = value.size()[0] 155 | 156 | if network_alpha is None: 157 | network_alpha = network_dim 158 | 159 | weights_has_text_encoder = weights_has_unet = False 160 | weights_to_modify = [] 161 | 162 | for key in weights.keys(): 163 | if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): 164 | weights_has_text_encoder = True 165 | 166 | if key.startswith(LoRANetwork.LORA_PREFIX_UNET): 167 | weights_has_unet = 
True 168 | 169 | if weights_has_text_encoder: 170 | weights_to_modify += self.text_encoder_loras 171 | 172 | if weights_has_unet: 173 | weights_to_modify += self.unet_loras 174 | 175 | for lora in self.text_encoder_loras + self.unet_loras: 176 | lora.resize(network_dim, network_alpha, scale) 177 | if lora in weights_to_modify: 178 | lora.enable = True 179 | 180 | info = self.load_state_dict(weights, False) 181 | if len(info.unexpected_keys) > 0: 182 | print(f"Weights are loaded. Unexpected keys={info.unexpected_keys}") 183 | -------------------------------------------------------------------------------- /modules/prompt_parser.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | # Code from https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/8e2aeee4a127b295bfc880800e4a312e0f049b85, modified. 7 | 8 | 9 | class PromptChunk: 10 | """ 11 | This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt. 12 | If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary. 13 | Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token, 14 | so just 75 tokens from prompt. 15 | """ 16 | 17 | def __init__(self): 18 | self.tokens = [] 19 | self.multipliers = [] 20 | self.fixes = [] 21 | 22 | 23 | class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): 24 | """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to 25 | have unlimited prompt length and assign weights to tokens in prompt. 26 | """ 27 | 28 | def __init__(self, text_encoder, enable_emphasis=True): 29 | super().__init__() 30 | 31 | self.device = lambda: text_encoder.device 32 | self.enable_emphasis = enable_emphasis 33 | """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation, 34 | depending on model.""" 35 | 36 | self.chunk_length = 75 37 | 38 | def empty_chunk(self): 39 | """creates an empty PromptChunk and returns it""" 40 | 41 | chunk = PromptChunk() 42 | chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1) 43 | chunk.multipliers = [1.0] * (self.chunk_length + 2) 44 | return chunk 45 | 46 | def get_target_prompt_token_count(self, token_count): 47 | """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented""" 48 | 49 | return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length 50 | 51 | def tokenize_line(self, line): 52 | """ 53 | this transforms a single prompt into a list of PromptChunk objects - as many as needed to 54 | represent the prompt. 55 | Returns the list and the total number of tokens in the prompt. 
56 | """ 57 | 58 | if self.enable_emphasis: 59 | parsed = parse_prompt_attention(line) 60 | else: 61 | parsed = [[line, 1.0]] 62 | 63 | tokenized = self.tokenize([text for text, _ in parsed]) 64 | 65 | chunks = [] 66 | chunk = PromptChunk() 67 | token_count = 0 68 | last_comma = -1 69 | 70 | def next_chunk(is_last=False): 71 | """puts current chunk into the list of results and produces the next one - empty; 72 | if is_last is true, tokens tokens at the end won't add to token_count""" 73 | nonlocal token_count 74 | nonlocal last_comma 75 | nonlocal chunk 76 | 77 | if is_last: 78 | token_count += len(chunk.tokens) 79 | else: 80 | token_count += self.chunk_length 81 | 82 | to_add = self.chunk_length - len(chunk.tokens) 83 | if to_add > 0: 84 | chunk.tokens += [self.id_end] * to_add 85 | chunk.multipliers += [1.0] * to_add 86 | 87 | chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end] 88 | chunk.multipliers = [1.0] + chunk.multipliers + [1.0] 89 | 90 | last_comma = -1 91 | chunks.append(chunk) 92 | chunk = PromptChunk() 93 | 94 | comma_padding_backtrack = 20 # default value in https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/shared.py#L410 95 | for tokens, (text, weight) in zip(tokenized, parsed): 96 | if text == "BREAK" and weight == -1: 97 | next_chunk() 98 | continue 99 | 100 | position = 0 101 | while position < len(tokens): 102 | token = tokens[position] 103 | 104 | if token == self.comma_token: 105 | last_comma = len(chunk.tokens) 106 | 107 | # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack 108 | # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. 109 | elif ( 110 | comma_padding_backtrack != 0 111 | and len(chunk.tokens) == self.chunk_length 112 | and last_comma != -1 113 | and len(chunk.tokens) - last_comma <= comma_padding_backtrack 114 | ): 115 | break_location = last_comma + 1 116 | 117 | reloc_tokens = chunk.tokens[break_location:] 118 | reloc_mults = chunk.multipliers[break_location:] 119 | 120 | chunk.tokens = chunk.tokens[:break_location] 121 | chunk.multipliers = chunk.multipliers[:break_location] 122 | 123 | next_chunk() 124 | chunk.tokens = reloc_tokens 125 | chunk.multipliers = reloc_mults 126 | 127 | if len(chunk.tokens) == self.chunk_length: 128 | next_chunk() 129 | 130 | chunk.tokens.append(token) 131 | chunk.multipliers.append(weight) 132 | position += 1 133 | 134 | if len(chunk.tokens) > 0 or len(chunks) == 0: 135 | next_chunk(is_last=True) 136 | 137 | return chunks, token_count 138 | 139 | def process_texts(self, texts): 140 | """ 141 | Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum 142 | length, in tokens, of all texts. 143 | """ 144 | 145 | token_count = 0 146 | 147 | cache = {} 148 | batch_chunks = [] 149 | for line in texts: 150 | if line in cache: 151 | chunks = cache[line] 152 | else: 153 | chunks, current_token_count = self.tokenize_line(line) 154 | token_count = max(current_token_count, token_count) 155 | 156 | cache[line] = chunks 157 | 158 | batch_chunks.append(chunks) 159 | 160 | return batch_chunks, token_count 161 | 162 | def forward(self, texts): 163 | """ 164 | Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. 
165 | Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will 166 | be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024. 167 | An example shape returned by this function can be: (2, 77, 768). 168 | Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet 169 | is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" 170 | """ 171 | 172 | batch_chunks, token_count = self.process_texts(texts) 173 | chunk_count = max([len(x) for x in batch_chunks]) 174 | 175 | zs = [] 176 | ts = [] 177 | for i in range(chunk_count): 178 | batch_chunk = [ 179 | chunks[i] if i < len(chunks) else self.empty_chunk() 180 | for chunks in batch_chunks 181 | ] 182 | 183 | tokens = [x.tokens for x in batch_chunk] 184 | multipliers = [x.multipliers for x in batch_chunk] 185 | # self.embeddings.fixes = [x.fixes for x in batch_chunk] 186 | 187 | # for fixes in self.embeddings.fixes: 188 | # for position, embedding in fixes: 189 | # used_embeddings[embedding.name] = embedding 190 | 191 | z = self.process_tokens(tokens, multipliers) 192 | zs.append(z) 193 | ts.append(tokens) 194 | 195 | return np.hstack(ts), torch.hstack(zs) 196 | 197 | def process_tokens(self, remade_batch_tokens, batch_multipliers): 198 | """ 199 | sends one single prompt chunk to be encoded by transformers neural network. 200 | remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually 201 | there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens. 202 | Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier 203 | corresponds to one token. 204 | """ 205 | tokens = torch.asarray(remade_batch_tokens).to(self.device()) 206 | 207 | # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones. 
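# When they differ, everything after the first end-of-text token in each
# prompt is rewritten to the dedicated pad id, so the encoder sees real
# padding rather than a run of repeated end-of-text tokens.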
208 | if self.id_end != self.id_pad: 209 | for batch_pos in range(len(remade_batch_tokens)): 210 | index = remade_batch_tokens[batch_pos].index(self.id_end) 211 | tokens[batch_pos, index + 1 : tokens.shape[1]] = self.id_pad 212 | 213 | z = self.encode_with_transformers(tokens) 214 | 215 | # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise 216 | batch_multipliers = torch.asarray(batch_multipliers).to(self.device()) 217 | original_mean = z.mean() 218 | z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand( 219 | z.shape 220 | ) 221 | new_mean = z.mean() 222 | z = z * (original_mean / new_mean) 223 | 224 | return z 225 | 226 | 227 | class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): 228 | CLIP_stop_at_last_layers = 1 229 | 230 | def __init__(self, tokenizer, text_encoder): 231 | super().__init__(text_encoder) 232 | self.tokenizer = tokenizer 233 | self.text_encoder = text_encoder 234 | 235 | vocab = self.tokenizer.get_vocab() 236 | 237 | self.comma_token = vocab.get(",", None) 238 | 239 | self.token_mults = {} 240 | tokens_with_parens = [ 241 | (k, v) 242 | for k, v in vocab.items() 243 | if "(" in k or ")" in k or "[" in k or "]" in k 244 | ] 245 | for text, ident in tokens_with_parens: 246 | mult = 1.0 247 | for c in text: 248 | if c == "[": 249 | mult /= 1.1 250 | if c == "]": 251 | mult *= 1.1 252 | if c == "(": 253 | mult *= 1.1 254 | if c == ")": 255 | mult /= 1.1 256 | 257 | if mult != 1.0: 258 | self.token_mults[ident] = mult 259 | 260 | self.id_start = self.tokenizer.bos_token_id 261 | self.id_end = self.tokenizer.eos_token_id 262 | self.id_pad = self.id_end 263 | 264 | def tokenize(self, texts): 265 | tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)[ 266 | "input_ids" 267 | ] 268 | 269 | return tokenized 270 | 271 | def encode_with_transformers(self, tokens): 272 | tokens = tokens.to(self.text_encoder.device) 273 | outputs = self.text_encoder(tokens, output_hidden_states=True) 274 | 275 | if self.CLIP_stop_at_last_layers > 1: 276 | z = outputs.hidden_states[-self.CLIP_stop_at_last_layers] 277 | z = self.text_encoder.text_model.final_layer_norm(z) 278 | else: 279 | z = outputs.last_hidden_state 280 | 281 | return z 282 | 283 | 284 | re_attention = re.compile( 285 | r""" 286 | \\\(| 287 | \\\)| 288 | \\\[| 289 | \\]| 290 | \\\\| 291 | \\| 292 | \(| 293 | \[| 294 | :([+-]?[.\d]+)\)| 295 | \)| 296 | ]| 297 | [^\\()\[\]:]+| 298 | : 299 | """, 300 | re.X, 301 | ) 302 | 303 | re_break = re.compile(r"\s*\bBREAK\b\s*", re.S) 304 | 305 | 306 | def parse_prompt_attention(text): 307 | """ 308 | Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 
309 | Accepted tokens are: 310 | (abc) - increases attention to abc by a multiplier of 1.1 311 | (abc:3.12) - increases attention to abc by a multiplier of 3.12 312 | [abc] - decreases attention to abc by a multiplier of 1.1 313 | \( - literal character '(' 314 | \[ - literal character '[' 315 | \) - literal character ')' 316 | \] - literal character ']' 317 | \\ - literal character '\' 318 | anything else - just text 319 | 320 | >>> parse_prompt_attention('normal text') 321 | [['normal text', 1.0]] 322 | >>> parse_prompt_attention('an (important) word') 323 | [['an ', 1.0], ['important', 1.1], [' word', 1.0]] 324 | >>> parse_prompt_attention('(unbalanced') 325 | [['unbalanced', 1.1]] 326 | >>> parse_prompt_attention('\(literal\]') 327 | [['(literal]', 1.0]] 328 | >>> parse_prompt_attention('(unnecessary)(parens)') 329 | [['unnecessaryparens', 1.1]] 330 | >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') 331 | [['a ', 1.0], 332 | ['house', 1.5730000000000004], 333 | [' ', 1.1], 334 | ['on', 1.0], 335 | [' a ', 1.1], 336 | ['hill', 0.55], 337 | [', sun, ', 1.1], 338 | ['sky', 1.4641000000000006], 339 | ['.', 1.1]] 340 | """ 341 | 342 | res = [] 343 | round_brackets = [] 344 | square_brackets = [] 345 | 346 | round_bracket_multiplier = 1.1 347 | square_bracket_multiplier = 1 / 1.1 348 | 349 | def multiply_range(start_position, multiplier): 350 | for p in range(start_position, len(res)): 351 | res[p][1] *= multiplier 352 | 353 | for m in re_attention.finditer(text): 354 | text = m.group(0) 355 | weight = m.group(1) 356 | 357 | if text.startswith("\\"): 358 | res.append([text[1:], 1.0]) 359 | elif text == "(": 360 | round_brackets.append(len(res)) 361 | elif text == "[": 362 | square_brackets.append(len(res)) 363 | elif weight is not None and len(round_brackets) > 0: 364 | multiply_range(round_brackets.pop(), float(weight)) 365 | elif text == ")" and len(round_brackets) > 0: 366 | multiply_range(round_brackets.pop(), round_bracket_multiplier) 367 | elif text == "]" and len(square_brackets) > 0: 368 | multiply_range(square_brackets.pop(), square_bracket_multiplier) 369 | else: 370 | parts = re.split(re_break, text) 371 | for i, part in enumerate(parts): 372 | if i > 0: 373 | res.append(["BREAK", -1]) 374 | res.append([part, 1.0]) 375 | 376 | for pos in round_brackets: 377 | multiply_range(pos, round_bracket_multiplier) 378 | 379 | for pos in square_brackets: 380 | multiply_range(pos, square_bracket_multiplier) 381 | 382 | if len(res) == 0: 383 | res = [["", 1.0]] 384 | 385 | # merge runs of identical weights 386 | i = 0 387 | while i + 1 < len(res): 388 | if res[i][1] == res[i + 1][1]: 389 | res[i][0] += res[i + 1][0] 390 | res.pop(i + 1) 391 | else: 392 | i += 1 393 | 394 | return res 395 | -------------------------------------------------------------------------------- /modules/model.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import inspect 3 | import math 4 | from pathlib import Path 5 | import re 6 | from collections import defaultdict 7 | from typing import List, Optional, Union 8 | 9 | import k_diffusion 10 | import numpy as np 11 | import PIL 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from einops import rearrange 16 | from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser 17 | from modules.prompt_parser import FrozenCLIPEmbedderWithCustomWords 18 | from torch import einsum 19 | from torch.autograd.function import Function 20 | 21 | 
from diffusers import DiffusionPipeline 22 | from diffusers.utils import PIL_INTERPOLATION, is_accelerate_available 23 | from diffusers.utils import logging, randn_tensor 24 | from modules.flash_attn import FlashAttentionFunction 25 | 26 | import modules.safechecker as _ 27 | from safetensors.torch import load_file 28 | 29 | xformers_available = False 30 | try: 31 | import xformers 32 | 33 | xformers_available = True 34 | except ImportError: 35 | pass 36 | 37 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 38 | 39 | 40 | def get_attention_scores(attn, query, key, attention_mask=None): 41 | 42 | if attn.upcast_attention: 43 | query = query.float() 44 | key = key.float() 45 | 46 | attention_scores = torch.baddbmm( 47 | torch.empty( 48 | query.shape[0], 49 | query.shape[1], 50 | key.shape[1], 51 | dtype=query.dtype, 52 | device=query.device, 53 | ), 54 | query, 55 | key.transpose(-1, -2), 56 | beta=0, 57 | alpha=attn.scale, 58 | ) 59 | 60 | if attention_mask is not None: 61 | attention_scores = attention_scores + attention_mask 62 | 63 | if attn.upcast_softmax: 64 | attention_scores = attention_scores.float() 65 | 66 | return attention_scores 67 | 68 | 69 | class CrossAttnProcessor(nn.Module): 70 | def __call__( 71 | self, 72 | attn, 73 | hidden_states, 74 | encoder_hidden_states=None, 75 | attention_mask=None, 76 | ): 77 | batch_size, sequence_length, _ = hidden_states.shape 78 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) 79 | 80 | encoder_states = hidden_states 81 | is_xattn = False 82 | if encoder_hidden_states is not None: 83 | is_xattn = True 84 | img_state = encoder_hidden_states["img_state"] 85 | encoder_states = encoder_hidden_states["states"] 86 | weight_func = encoder_hidden_states["weight_func"] 87 | sigma = encoder_hidden_states["sigma"] 88 | 89 | query = attn.to_q(hidden_states) 90 | key = attn.to_k(encoder_states) 91 | value = attn.to_v(encoder_states) 92 | 93 | query = attn.head_to_batch_dim(query) 94 | key = attn.head_to_batch_dim(key) 95 | value = attn.head_to_batch_dim(value) 96 | 97 | if is_xattn and isinstance(img_state, dict): 98 | # use torch.baddbmm method (slow) 99 | attention_scores = get_attention_scores(attn, query, key, attention_mask) 100 | w = img_state[sequence_length].to(query.device) 101 | cross_attention_weight = weight_func(w, sigma, attention_scores) 102 | attention_scores += torch.repeat_interleave( 103 | cross_attention_weight, repeats=attn.heads, dim=0 104 | ) 105 | 106 | # calc probs 107 | attention_probs = attention_scores.softmax(dim=-1) 108 | attention_probs = attention_probs.to(query.dtype) 109 | hidden_states = torch.bmm(attention_probs, value) 110 | 111 | elif xformers_available: 112 | hidden_states = xformers.ops.memory_efficient_attention( 113 | query.contiguous(), 114 | key.contiguous(), 115 | value.contiguous(), 116 | attn_bias=attention_mask, 117 | ) 118 | hidden_states = hidden_states.to(query.dtype) 119 | 120 | else: 121 | q_bucket_size = 512 122 | k_bucket_size = 1024 123 | 124 | # use flash-attention 125 | hidden_states = FlashAttentionFunction.apply( 126 | query.contiguous(), 127 | key.contiguous(), 128 | value.contiguous(), 129 | attention_mask, 130 | causal=False, 131 | q_bucket_size=q_bucket_size, 132 | k_bucket_size=k_bucket_size, 133 | ) 134 | hidden_states = hidden_states.to(query.dtype) 135 | 136 | hidden_states = attn.batch_to_head_dim(hidden_states) 137 | 138 | # linear proj 139 | hidden_states = attn.to_out[0](hidden_states) 140 | 141 | # dropout 142 | hidden_states = 
attn.to_out[1](hidden_states) 143 | 144 | return hidden_states 145 | 146 | class ModelWrapper: 147 | def __init__(self, model, alphas_cumprod): 148 | self.model = model 149 | self.alphas_cumprod = alphas_cumprod 150 | 151 | def apply_model(self, *args, **kwargs): 152 | if len(args) == 3: 153 | encoder_hidden_states = args[-1] 154 | args = args[:2] 155 | if kwargs.get("cond", None) is not None: 156 | encoder_hidden_states = kwargs.pop("cond") 157 | return self.model( 158 | *args, encoder_hidden_states=encoder_hidden_states, **kwargs 159 | ).sample 160 | 161 | 162 | class StableDiffusionPipeline(DiffusionPipeline): 163 | 164 | _optional_components = ["safety_checker", "feature_extractor"] 165 | 166 | def __init__( 167 | self, 168 | vae, 169 | text_encoder, 170 | tokenizer, 171 | unet, 172 | scheduler, 173 | ): 174 | super().__init__() 175 | 176 | # get correct sigmas from LMS 177 | self.register_modules( 178 | vae=vae, 179 | text_encoder=text_encoder, 180 | tokenizer=tokenizer, 181 | unet=unet, 182 | scheduler=scheduler, 183 | ) 184 | self.setup_unet(self.unet) 185 | self.setup_text_encoder() 186 | 187 | def setup_text_encoder(self, n=1, new_encoder=None): 188 | if new_encoder is not None: 189 | self.text_encoder = new_encoder 190 | 191 | self.prompt_parser = FrozenCLIPEmbedderWithCustomWords(self.tokenizer, self.text_encoder) 192 | self.prompt_parser.CLIP_stop_at_last_layers = n 193 | 194 | def setup_unet(self, unet): 195 | unet = unet.to(self.device) 196 | model = ModelWrapper(unet, self.scheduler.alphas_cumprod) 197 | if self.scheduler.prediction_type == "v_prediction": 198 | self.k_diffusion_model = CompVisVDenoiser(model) 199 | else: 200 | self.k_diffusion_model = CompVisDenoiser(model) 201 | 202 | def get_scheduler(self, scheduler_type: str): 203 | library = importlib.import_module("k_diffusion") 204 | sampling = getattr(library, "sampling") 205 | return getattr(sampling, scheduler_type) 206 | 207 | def encode_sketchs(self, state, scale_ratio=8, g_strength=1.0, text_ids=None): 208 | uncond, cond = text_ids[0], text_ids[1] 209 | 210 | img_state = [] 211 | if state is None: 212 | return torch.FloatTensor(0) 213 | 214 | for k, v in state.items(): 215 | if v["map"] is None: 216 | continue 217 | 218 | v_input = self.tokenizer( 219 | k, 220 | max_length=self.tokenizer.model_max_length, 221 | truncation=True, 222 | add_special_tokens=False, 223 | ).input_ids 224 | 225 | dotmap = v["map"] < 255 226 | out = dotmap.astype(float) 227 | if v["mask_outsides"]: 228 | out[out==0] = -1 229 | 230 | arr = torch.from_numpy( 231 | out * float(v["weight"]) * g_strength 232 | ) 233 | img_state.append((v_input, arr)) 234 | 235 | if len(img_state) == 0: 236 | return torch.FloatTensor(0) 237 | 238 | w_tensors = dict() 239 | cond = cond.tolist() 240 | uncond = uncond.tolist() 241 | for layer in self.unet.down_blocks: 242 | c = int(len(cond)) 243 | w, h = img_state[0][1].shape 244 | w_r, h_r = w // scale_ratio, h // scale_ratio 245 | 246 | ret_cond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32) 247 | ret_uncond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32) 248 | 249 | for v_as_tokens, img_where_color in img_state: 250 | is_in = 0 251 | 252 | ret = ( 253 | F.interpolate( 254 | img_where_color.unsqueeze(0).unsqueeze(1), 255 | scale_factor=1 / scale_ratio, 256 | mode="bilinear", 257 | align_corners=True, 258 | ) 259 | .squeeze() 260 | .reshape(-1, 1) 261 | .repeat(1, len(v_as_tokens)) 262 | ) 263 | 264 | for idx, tok in enumerate(cond): 265 | if cond[idx : idx + 
len(v_as_tokens)] == v_as_tokens: 266 | is_in = 1 267 | ret_cond_tensor[0, :, idx : idx + len(v_as_tokens)] += ret 268 | 269 | for idx, tok in enumerate(uncond): 270 | if uncond[idx : idx + len(v_as_tokens)] == v_as_tokens: 271 | is_in = 1 272 | ret_uncond_tensor[0, :, idx : idx + len(v_as_tokens)] += ret 273 | 274 | if not is_in == 1: 275 | print(f"tokens {v_as_tokens} not found in text") 276 | 277 | w_tensors[w_r * h_r] = torch.cat([ret_uncond_tensor, ret_cond_tensor]) 278 | scale_ratio *= 2 279 | 280 | return w_tensors 281 | 282 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): 283 | r""" 284 | Enable sliced attention computation. 285 | 286 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 287 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 288 | 289 | Args: 290 | slice_size (`str` or `int`, *optional*, defaults to `"auto"`): 291 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 292 | a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, 293 | `attention_head_dim` must be a multiple of `slice_size`. 294 | """ 295 | if slice_size == "auto": 296 | # half the attention head size is usually a good trade-off between 297 | # speed and memory 298 | slice_size = self.unet.config.attention_head_dim // 2 299 | self.unet.set_attention_slice(slice_size) 300 | 301 | def disable_attention_slicing(self): 302 | r""" 303 | Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go 304 | back to computing attention in one step. 305 | """ 306 | # set slice_size = `None` to disable `attention slicing` 307 | self.enable_attention_slicing(None) 308 | 309 | def enable_sequential_cpu_offload(self, gpu_id=0): 310 | r""" 311 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, 312 | text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a 313 | `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 314 | """ 315 | if is_accelerate_available(): 316 | from accelerate import cpu_offload 317 | else: 318 | raise ImportError("Please install accelerate via `pip install accelerate`") 319 | 320 | device = torch.device(f"cuda:{gpu_id}") 321 | 322 | for cpu_offloaded_model in [ 323 | self.unet, 324 | self.text_encoder, 325 | self.vae, 326 | self.safety_checker, 327 | ]: 328 | if cpu_offloaded_model is not None: 329 | cpu_offload(cpu_offloaded_model, device) 330 | 331 | @property 332 | def _execution_device(self): 333 | r""" 334 | Returns the device on which the pipeline's models will be executed. After calling 335 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module 336 | hooks. 
337 | """ 338 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 339 | return self.device 340 | for module in self.unet.modules(): 341 | if ( 342 | hasattr(module, "_hf_hook") 343 | and hasattr(module._hf_hook, "execution_device") 344 | and module._hf_hook.execution_device is not None 345 | ): 346 | return torch.device(module._hf_hook.execution_device) 347 | return self.device 348 | 349 | def decode_latents(self, latents): 350 | latents = latents.to(self.device, dtype=self.vae.dtype) 351 | latents = 1 / 0.18215 * latents 352 | image = self.vae.decode(latents).sample 353 | image = (image / 2 + 0.5).clamp(0, 1) 354 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 355 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 356 | return image 357 | 358 | def check_inputs(self, prompt, height, width, callback_steps): 359 | if not isinstance(prompt, str) and not isinstance(prompt, list): 360 | raise ValueError( 361 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 362 | ) 363 | 364 | if height % 8 != 0 or width % 8 != 0: 365 | raise ValueError( 366 | f"`height` and `width` have to be divisible by 8 but are {height} and {width}." 367 | ) 368 | 369 | if (callback_steps is None) or ( 370 | callback_steps is not None 371 | and (not isinstance(callback_steps, int) or callback_steps <= 0) 372 | ): 373 | raise ValueError( 374 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 375 | f" {type(callback_steps)}." 376 | ) 377 | 378 | def prepare_latents( 379 | self, 380 | batch_size, 381 | num_channels_latents, 382 | height, 383 | width, 384 | dtype, 385 | device, 386 | generator, 387 | latents=None, 388 | ): 389 | shape = (batch_size, num_channels_latents, height // 8, width // 8) 390 | if latents is None: 391 | if device.type == "mps": 392 | # randn does not work reproducibly on mps 393 | latents = torch.randn( 394 | shape, generator=generator, device="cpu", dtype=dtype 395 | ).to(device) 396 | else: 397 | latents = torch.randn( 398 | shape, generator=generator, device=device, dtype=dtype 399 | ) 400 | else: 401 | # if latents.shape != shape: 402 | # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 403 | latents = latents.to(device) 404 | 405 | # scale the initial noise by the standard deviation required by the scheduler 406 | return latents 407 | 408 | def preprocess(self, image): 409 | if isinstance(image, torch.Tensor): 410 | return image 411 | elif isinstance(image, PIL.Image.Image): 412 | image = [image] 413 | 414 | if isinstance(image[0], PIL.Image.Image): 415 | w, h = image[0].size 416 | w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 417 | 418 | image = [ 419 | np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[ 420 | None, : 421 | ] 422 | for i in image 423 | ] 424 | image = np.concatenate(image, axis=0) 425 | image = np.array(image).astype(np.float32) / 255.0 426 | image = image.transpose(0, 3, 1, 2) 427 | image = 2.0 * image - 1.0 428 | image = torch.from_numpy(image) 429 | elif isinstance(image[0], torch.Tensor): 430 | image = torch.cat(image, dim=0) 431 | return image 432 | 433 | @torch.no_grad() 434 | def img2img( 435 | self, 436 | prompt: Union[str, List[str]], 437 | num_inference_steps: int = 50, 438 | guidance_scale: float = 7.5, 439 | negative_prompt: Optional[Union[str, List[str]]] = None, 440 | generator: Optional[torch.Generator] = None, 441 | image: 
Optional[torch.FloatTensor] = None, 442 | output_type: Optional[str] = "pil", 443 | latents=None, 444 | strength=1.0, 445 | pww_state=None, 446 | pww_attn_weight=1.0, 447 | sampler_name="", 448 | sampler_opt={}, 449 | scale_ratio=8.0, 450 | ): 451 | sampler = self.get_scheduler(sampler_name) 452 | if image is not None: 453 | image = self.preprocess(image) 454 | image = image.to(self.vae.device, dtype=self.vae.dtype) 455 | 456 | init_latents = self.vae.encode(image).latent_dist.sample(generator) 457 | latents = 0.18215 * init_latents 458 | 459 | # 2. Define call parameters 460 | batch_size = 1 if isinstance(prompt, str) else len(prompt) 461 | device = self._execution_device 462 | latents = latents.to(device, dtype=self.unet.dtype) 463 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 464 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 465 | # corresponds to doing no classifier free guidance. 466 | do_classifier_free_guidance = True 467 | if guidance_scale <= 1.0: 468 | raise ValueError("has to use guidance_scale") 469 | 470 | # 3. Encode input prompt 471 | text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt]) 472 | text_embeddings = text_embeddings.to(self.unet.dtype) 473 | 474 | init_timestep = ( 475 | int(num_inference_steps / min(strength, 0.999)) if strength > 0 else 0 476 | ) 477 | sigmas = self.get_sigmas(init_timestep, sampler_opt).to( 478 | text_embeddings.device, dtype=text_embeddings.dtype 479 | ) 480 | 481 | t_start = max(init_timestep - num_inference_steps, 0) 482 | sigma_sched = sigmas[t_start:] 483 | 484 | noise = randn_tensor( 485 | latents.shape, 486 | generator=generator, 487 | device=device, 488 | dtype=text_embeddings.dtype, 489 | ) 490 | latents = latents.to(device) 491 | latents = latents + noise * sigma_sched[0] 492 | 493 | # 5. Prepare latent variables 494 | self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) 495 | self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to( 496 | latents.device 497 | ) 498 | 499 | img_state = self.encode_sketchs( 500 | pww_state, 501 | g_strength=pww_attn_weight, 502 | text_ids=text_ids, 503 | ) 504 | 505 | def model_fn(x, sigma): 506 | 507 | latent_model_input = torch.cat([x] * 2) 508 | weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max() 509 | encoder_state = { 510 | "img_state": img_state, 511 | "states": text_embeddings, 512 | "sigma": sigma[0], 513 | "weight_func": weight_func, 514 | } 515 | 516 | noise_pred = self.k_diffusion_model( 517 | latent_model_input, sigma, cond=encoder_state 518 | ) 519 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 520 | noise_pred = noise_pred_uncond + guidance_scale * ( 521 | noise_pred_text - noise_pred_uncond 522 | ) 523 | return noise_pred 524 | 525 | sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler) 526 | latents = sampler(model_fn, latents, **sampler_args) 527 | 528 | # 8. Post-processing 529 | image = self.decode_latents(latents) 530 | 531 | # 10. 
Convert to PIL 532 | if output_type == "pil": 533 | image = self.numpy_to_pil(image) 534 | 535 | return (image,) 536 | 537 | def get_sigmas(self, steps, params): 538 | discard_next_to_last_sigma = params.get("discard_next_to_last_sigma", False) 539 | steps += 1 if discard_next_to_last_sigma else 0 540 | 541 | if params.get("scheduler", None) == "karras": 542 | sigma_min, sigma_max = ( 543 | self.k_diffusion_model.sigmas[0].item(), 544 | self.k_diffusion_model.sigmas[-1].item(), 545 | ) 546 | sigmas = k_diffusion.sampling.get_sigmas_karras( 547 | n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=self.device 548 | ) 549 | else: 550 | sigmas = self.k_diffusion_model.get_sigmas(steps) 551 | 552 | if discard_next_to_last_sigma: 553 | sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) 554 | 555 | return sigmas 556 | 557 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454 558 | def get_sampler_extra_args_t2i(self, sigmas, eta, steps, func): 559 | extra_params_kwargs = {} 560 | 561 | if "eta" in inspect.signature(func).parameters: 562 | extra_params_kwargs["eta"] = eta 563 | 564 | if "sigma_min" in inspect.signature(func).parameters: 565 | extra_params_kwargs["sigma_min"] = sigmas[0].item() 566 | extra_params_kwargs["sigma_max"] = sigmas[-1].item() 567 | 568 | if "n" in inspect.signature(func).parameters: 569 | extra_params_kwargs["n"] = steps 570 | else: 571 | extra_params_kwargs["sigmas"] = sigmas 572 | 573 | return extra_params_kwargs 574 | 575 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454 576 | def get_sampler_extra_args_i2i(self, sigmas, func): 577 | extra_params_kwargs = {} 578 | 579 | if "sigma_min" in inspect.signature(func).parameters: 580 | ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last 581 | extra_params_kwargs["sigma_min"] = sigmas[-2] 582 | 583 | if "sigma_max" in inspect.signature(func).parameters: 584 | extra_params_kwargs["sigma_max"] = sigmas[0] 585 | 586 | if "n" in inspect.signature(func).parameters: 587 | extra_params_kwargs["n"] = len(sigmas) - 1 588 | 589 | if "sigma_sched" in inspect.signature(func).parameters: 590 | extra_params_kwargs["sigma_sched"] = sigmas 591 | 592 | if "sigmas" in inspect.signature(func).parameters: 593 | extra_params_kwargs["sigmas"] = sigmas 594 | 595 | return extra_params_kwargs 596 | 597 | @torch.no_grad() 598 | def txt2img( 599 | self, 600 | prompt: Union[str, List[str]], 601 | height: int = 512, 602 | width: int = 512, 603 | num_inference_steps: int = 50, 604 | guidance_scale: float = 7.5, 605 | negative_prompt: Optional[Union[str, List[str]]] = None, 606 | eta: float = 0.0, 607 | generator: Optional[torch.Generator] = None, 608 | latents: Optional[torch.FloatTensor] = None, 609 | output_type: Optional[str] = "pil", 610 | callback_steps: Optional[int] = 1, 611 | upscale=False, 612 | upscale_x: float = 2.0, 613 | upscale_method: str = "bicubic", 614 | upscale_antialias: bool = False, 615 | upscale_denoising_strength: int = 0.7, 616 | pww_state=None, 617 | pww_attn_weight=1.0, 618 | sampler_name="", 619 | sampler_opt={}, 620 | ): 621 | sampler = self.get_scheduler(sampler_name) 622 | # 1. Check inputs. Raise error if not correct 623 | self.check_inputs(prompt, height, width, callback_steps) 624 | 625 | # 2. 
Define call parameters 626 | batch_size = 1 if isinstance(prompt, str) else len(prompt) 627 | device = self._execution_device 628 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 629 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 630 | # corresponds to doing no classifier free guidance. 631 | do_classifier_free_guidance = True 632 | if guidance_scale <= 1.0: 633 | raise ValueError("has to use guidance_scale") 634 | 635 | # 3. Encode input prompt 636 | text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt]) 637 | text_embeddings = text_embeddings.to(self.unet.dtype) 638 | 639 | # 4. Prepare timesteps 640 | sigmas = self.get_sigmas(num_inference_steps, sampler_opt).to( 641 | text_embeddings.device, dtype=text_embeddings.dtype 642 | ) 643 | 644 | # 5. Prepare latent variables 645 | num_channels_latents = self.unet.in_channels 646 | latents = self.prepare_latents( 647 | batch_size, 648 | num_channels_latents, 649 | height, 650 | width, 651 | text_embeddings.dtype, 652 | device, 653 | generator, 654 | latents, 655 | ) 656 | latents = latents * sigmas[0] 657 | self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) 658 | self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to( 659 | latents.device 660 | ) 661 | 662 | img_state = self.encode_sketchs( 663 | pww_state, 664 | g_strength=pww_attn_weight, 665 | text_ids=text_ids, 666 | ) 667 | 668 | def model_fn(x, sigma): 669 | 670 | latent_model_input = torch.cat([x] * 2) 671 | weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max() 672 | encoder_state = { 673 | "img_state": img_state, 674 | "states": text_embeddings, 675 | "sigma": sigma[0], 676 | "weight_func": weight_func, 677 | } 678 | 679 | noise_pred = self.k_diffusion_model( 680 | latent_model_input, sigma, cond=encoder_state 681 | ) 682 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 683 | noise_pred = noise_pred_uncond + guidance_scale * ( 684 | noise_pred_text - noise_pred_uncond 685 | ) 686 | return noise_pred 687 | 688 | extra_args = self.get_sampler_extra_args_t2i( 689 | sigmas, eta, num_inference_steps, sampler 690 | ) 691 | latents = sampler(model_fn, latents, **extra_args) 692 | 693 | if upscale: 694 | target_height = height * upscale_x 695 | target_width = width * upscale_x 696 | vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 697 | latents = torch.nn.functional.interpolate( 698 | latents, 699 | size=( 700 | int(target_height // vae_scale_factor), 701 | int(target_width // vae_scale_factor), 702 | ), 703 | mode=upscale_method, 704 | antialias=upscale_antialias, 705 | ) 706 | return self.img2img( 707 | prompt=prompt, 708 | num_inference_steps=num_inference_steps, 709 | guidance_scale=guidance_scale, 710 | negative_prompt=negative_prompt, 711 | generator=generator, 712 | latents=latents, 713 | strength=upscale_denoising_strength, 714 | sampler_name=sampler_name, 715 | sampler_opt=sampler_opt, 716 | pww_state=None, 717 | pww_attn_weight=pww_attn_weight / 2, 718 | ) 719 | 720 | # 8. Post-processing 721 | image = self.decode_latents(latents) 722 | 723 | # 10. 
Convert to PIL 724 | if output_type == "pil": 725 | image = self.numpy_to_pil(image) 726 | 727 | return (image,) 728 | 729 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import string 4 | import tempfile 5 | import time 6 | import gradio as gr 7 | import numpy as np 8 | import torch 9 | 10 | from gradio import inputs 11 | from diffusers import ( 12 | AutoencoderKL, 13 | DDIMScheduler, 14 | UNet2DConditionModel, 15 | ) 16 | from modules.model import ( 17 | CrossAttnProcessor, 18 | StableDiffusionPipeline, 19 | ) 20 | from torchvision import transforms 21 | from transformers import CLIPTokenizer, CLIPTextModel 22 | from PIL import Image 23 | from pathlib import Path 24 | from safetensors.torch import load_file 25 | import modules.safechecker as _ 26 | from modules.lora import LoRANetwork 27 | 28 | models = [ 29 | # format: name, model_path, clip_skip 30 | ("AbyssOrangeMix_Base", "/root/workspace/storage/models/orangemix", 2), 31 | ("Stable Diffuison 1.5", "/root/models/stable-diffusion-v1-5", 1), 32 | ("AnimeSFW", "/root/workspace/animesfw", 2), 33 | ] 34 | 35 | base_name, base_model, clip_skip = models[0] 36 | 37 | samplers_k_diffusion = [ 38 | ("Euler a", "sample_euler_ancestral", {}), 39 | ("Euler", "sample_euler", {}), 40 | ("LMS", "sample_lms", {}), 41 | ("Heun", "sample_heun", {}), 42 | ("DPM2", "sample_dpm_2", {"discard_next_to_last_sigma": True}), 43 | ("DPM2 a", "sample_dpm_2_ancestral", {"discard_next_to_last_sigma": True}), 44 | ("DPM++ 2S a", "sample_dpmpp_2s_ancestral", {}), 45 | ("DPM++ 2M", "sample_dpmpp_2m", {}), 46 | ("DPM++ SDE", "sample_dpmpp_sde", {}), 47 | ("LMS Karras", "sample_lms", {"scheduler": "karras"}), 48 | ("DPM2 Karras", "sample_dpm_2", {"scheduler": "karras", "discard_next_to_last_sigma": True}), 49 | ("DPM2 a Karras", "sample_dpm_2_ancestral", {"scheduler": "karras", "discard_next_to_last_sigma": True}), 50 | ("DPM++ 2S a Karras", "sample_dpmpp_2s_ancestral", {"scheduler": "karras"}), 51 | ("DPM++ 2M Karras", "sample_dpmpp_2m", {"scheduler": "karras"}), 52 | ("DPM++ SDE Karras", "sample_dpmpp_sde", {"scheduler": "karras"}), 53 | ] 54 | 55 | # samplers_diffusers = [ 56 | # ("DDIMScheduler", "diffusers.schedulers.DDIMScheduler", {}) 57 | # ("DDPMScheduler", "diffusers.schedulers.DDPMScheduler", {}) 58 | # ("DEISMultistepScheduler", "diffusers.schedulers.DEISMultistepScheduler", {}) 59 | # ] 60 | 61 | start_time = time.time() 62 | 63 | scheduler = DDIMScheduler.from_pretrained( 64 | base_model, 65 | subfolder="scheduler", 66 | ) 67 | vae = AutoencoderKL.from_pretrained( 68 | "stabilityai/sd-vae-ft-ema", 69 | torch_dtype=torch.float16 70 | ) 71 | text_encoder = CLIPTextModel.from_pretrained( 72 | base_model, 73 | subfolder="text_encoder", 74 | torch_dtype=torch.float16, 75 | ) 76 | tokenizer = CLIPTokenizer.from_pretrained( 77 | base_model, 78 | subfolder="tokenizer", 79 | torch_dtype=torch.float16, 80 | ) 81 | unet = UNet2DConditionModel.from_pretrained( 82 | base_model, 83 | subfolder="unet", 84 | torch_dtype=torch.float16, 85 | ) 86 | pipe = StableDiffusionPipeline( 87 | text_encoder=text_encoder, 88 | tokenizer=tokenizer, 89 | unet=unet, 90 | vae=vae, 91 | scheduler=scheduler, 92 | ) 93 | 94 | unet.set_attn_processor(CrossAttnProcessor()) 95 | pipe.setup_text_encoder(clip_skip, text_encoder) 96 | if torch.cuda.is_available(): 97 | pipe = pipe.to("cuda") 98 | 99 | 100 | def get_model_list(): 101 | 
model_available = [] 102 | for model in models: 103 | if Path(model[1]).is_dir(): 104 | model_available.append(model) 105 | return model_available 106 | 107 | te_cache = { 108 | base_model: text_encoder 109 | } 110 | 111 | unet_cache = { 112 | base_model: unet 113 | } 114 | 115 | lora_cache = { 116 | base_model: LoRANetwork(text_encoder, unet) 117 | } 118 | 119 | te_base_weight_length = text_encoder.get_input_embeddings().weight.data.shape[0] 120 | original_prepare_for_tokenization = tokenizer.prepare_for_tokenization 121 | current_model = base_model 122 | 123 | def setup_model(name, lora_state=None, lora_scale=1.0): 124 | global pipe, current_model 125 | 126 | keys = [k[0] for k in models] 127 | model = models[keys.index(name)][1] 128 | if model not in unet_cache: 129 | unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", torch_dtype=torch.float16) 130 | text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder", torch_dtype=torch.float16) 131 | 132 | unet_cache[model] = unet 133 | te_cache[model] = text_encoder 134 | lora_cache[model] = LoRANetwork(text_encoder, unet) 135 | 136 | if current_model != model: 137 | # offload current model 138 | unet_cache[current_model].to("cpu") 139 | te_cache[current_model].to("cpu") 140 | lora_cache[current_model].to("cpu") 141 | current_model = model 142 | 143 | local_te, local_unet, local_lora, = te_cache[model], unet_cache[model], lora_cache[model] 144 | local_unet.set_attn_processor(CrossAttnProcessor()) 145 | local_lora.reset() 146 | clip_skip = models[keys.index(name)][2] 147 | 148 | if torch.cuda.is_available(): 149 | local_unet.to("cuda") 150 | local_te.to("cuda") 151 | 152 | if lora_state is not None and lora_state != "": 153 | local_lora.load(lora_state, lora_scale) 154 | local_lora.to(local_unet.device, dtype=local_unet.dtype) 155 | 156 | pipe.text_encoder, pipe.unet = local_te, local_unet 157 | pipe.setup_unet(local_unet) 158 | pipe.tokenizer.prepare_for_tokenization = original_prepare_for_tokenization 159 | pipe.tokenizer.added_tokens_encoder = {} 160 | pipe.tokenizer.added_tokens_decoder = {} 161 | pipe.setup_text_encoder(clip_skip, local_te) 162 | return pipe 163 | 164 | def error_str(error, title="Error"): 165 | return ( 166 | f"""#### {title} 167 | {error}""" 168 | if error 169 | else "" 170 | ) 171 | 172 | def make_token_names(embs): 173 | all_tokens = [] 174 | for name, vec in embs.items(): 175 | tokens = [f'emb-{name}-{i}' for i in range(len(vec))] 176 | all_tokens.append(tokens) 177 | return all_tokens 178 | 179 | def setup_tokenizer(tokenizer, embs): 180 | reg_match = [re.compile(fr"(?:^|(?<=\s|,)){k}(?=,|\s|$)") for k in embs.keys()] 181 | clip_keywords = [' '.join(s) for s in make_token_names(embs)] 182 | 183 | def parse_prompt(prompt: str): 184 | for m, v in zip(reg_match, clip_keywords): 185 | prompt = m.sub(v, prompt) 186 | return prompt 187 | 188 | def prepare_for_tokenization(self, text: str, is_split_into_words: bool = False, **kwargs): 189 | text = parse_prompt(text) 190 | r = original_prepare_for_tokenization(text, is_split_into_words, **kwargs) 191 | return r 192 | 193 | tokenizer.prepare_for_tokenization = prepare_for_tokenization.__get__(tokenizer, CLIPTokenizer) 194 | return [t for sublist in make_token_names(embs) for t in sublist] 195 | 196 | 197 | def inference( 198 | prompt, 199 | guidance, 200 | steps, 201 | width=512, 202 | height=512, 203 | seed=0, 204 | neg_prompt="", 205 | state=None, 206 | g_strength=0.4, 207 | img_input=None, 208 | i2i_scale=0.5, 209 | hr_enabled=False, 210 
| hr_method="Latent", 211 | hr_scale=1.5, 212 | hr_denoise=0.8, 213 | sampler="DPM++ 2M Karras", 214 | embs=None, 215 | model=None, 216 | lora_state=None, 217 | lora_scale=None, 218 | ): 219 | if seed is None or seed == 0: 220 | seed = random.randint(0, 2147483647) 221 | 222 | pipe = setup_model(model, lora_state, lora_scale) 223 | generator = torch.Generator("cuda").manual_seed(int(seed)) 224 | sampler_name, sampler_opt = None, None 225 | for label, funcname, options in samplers_k_diffusion: 226 | if label == sampler: 227 | sampler_name, sampler_opt = funcname, options 228 | 229 | tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder 230 | if embs is not None and len(embs) > 0: 231 | delta_weight = [] 232 | ti_embs = {} 233 | for name, file in embs.items(): 234 | if str(file).endswith(".pt"): 235 | loaded_learned_embeds = torch.load(file, map_location="cpu") 236 | else: 237 | loaded_learned_embeds = load_file(file, device="cpu") 238 | loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"] 239 | ti_embs[name] = loaded_learned_embeds 240 | 241 | if len(ti_embs) > 0: 242 | tokens = setup_tokenizer(tokenizer, ti_embs) 243 | added_tokens = tokenizer.add_tokens(tokens) 244 | delta_weight = torch.cat([val for val in ti_embs.values()], dim=0) 245 | 246 | assert added_tokens == delta_weight.shape[0] 247 | text_encoder.resize_token_embeddings(len(tokenizer)) 248 | token_embeds = text_encoder.get_input_embeddings().weight.data 249 | token_embeds[-delta_weight.shape[0]:] = delta_weight 250 | 251 | config = { 252 | "negative_prompt": neg_prompt, 253 | "num_inference_steps": int(steps), 254 | "guidance_scale": guidance, 255 | "generator": generator, 256 | "sampler_name": sampler_name, 257 | "sampler_opt": sampler_opt, 258 | "pww_state": state, 259 | "pww_attn_weight": g_strength, 260 | } 261 | 262 | if img_input is not None: 263 | ratio = min(height / img_input.height, width / img_input.width) 264 | img_input = img_input.resize( 265 | (int(img_input.width * ratio), int(img_input.height * ratio)), Image.LANCZOS 266 | ) 267 | result = pipe.img2img(prompt, image=img_input, strength=i2i_scale, **config) 268 | elif hr_enabled: 269 | result = pipe.txt2img( 270 | prompt, 271 | width=width, 272 | height=height, 273 | upscale=True, 274 | upscale_x=hr_scale, 275 | upscale_denoising_strength=hr_denoise, 276 | **config, 277 | **latent_upscale_modes[hr_method], 278 | ) 279 | else: 280 | result = pipe.txt2img(prompt, width=width, height=height, **config) 281 | 282 | return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}") 283 | 284 | 285 | color_list = [] 286 | 287 | 288 | def get_color(n): 289 | for _ in range(n - len(color_list)): 290 | color_list.append(tuple(np.random.random(size=3) * 256)) 291 | return color_list 292 | 293 | 294 | def create_mixed_img(current, state, w=512, h=512): 295 | w, h = int(w), int(h) 296 | image_np = np.full([h, w, 4], 255) 297 | if state is None: 298 | state = {} 299 | 300 | colors = get_color(len(state)) 301 | idx = 0 302 | 303 | for key, item in state.items(): 304 | if item["map"] is not None: 305 | m = item["map"] < 255 306 | alpha = 150 307 | if current == key: 308 | alpha = 200 309 | image_np[m] = colors[idx] + (alpha,) 310 | idx += 1 311 | 312 | return image_np 313 | 314 | 315 | # width.change(apply_new_res, inputs=[width, height, global_stats], outputs=[global_stats, sp, rendered]) 316 | def apply_new_res(w, h, state): 317 | w, h = int(w), int(h) 318 | 319 | for key, item in state.items(): 320 | if item["map"] is not None: 321 | item["map"] = 
resize(item["map"], w, h) 322 | 323 | update_img = gr.Image.update(value=create_mixed_img("", state, w, h)) 324 | return state, update_img 325 | 326 | 327 | def detect_text(text, state, width, height): 328 | 329 | if text is None or text == "": 330 | return None, None, gr.Radio.update(value=None), None 331 | 332 | t = text.split(",") 333 | new_state = {} 334 | 335 | for item in t: 336 | item = item.strip() 337 | if item == "": 338 | continue 339 | if state is not None and item in state: 340 | new_state[item] = { 341 | "map": state[item]["map"], 342 | "weight": state[item]["weight"], 343 | "mask_outsides": state[item]["mask_outsides"], 344 | } 345 | else: 346 | new_state[item] = { 347 | "map": None, 348 | "weight": 0.5, 349 | "mask_outsides": False 350 | } 351 | update = gr.Radio.update(choices=[key for key in new_state.keys()], value=None) 352 | update_img = gr.update(value=create_mixed_img("", new_state, width, height)) 353 | update_sketch = gr.update(value=None, interactive=False) 354 | return new_state, update_sketch, update, update_img 355 | 356 | 357 | def resize(img, w, h): 358 | trs = transforms.Compose( 359 | [ 360 | transforms.ToPILImage(), 361 | transforms.Resize(min(h, w)), 362 | transforms.CenterCrop((h, w)), 363 | ] 364 | ) 365 | result = np.array(trs(img), dtype=np.uint8) 366 | return result 367 | 368 | 369 | def switch_canvas(entry, state, width, height): 370 | if entry == None: 371 | return None, 0.5, False, create_mixed_img("", state, width, height) 372 | 373 | return ( 374 | gr.update(value=None, interactive=True), 375 | gr.update(value=state[entry]["weight"] if entry in state else 0.5), 376 | gr.update(value=state[entry]["mask_outsides"] if entry in state else False), 377 | create_mixed_img(entry, state, width, height), 378 | ) 379 | 380 | 381 | def apply_canvas(selected, draw, state, w, h): 382 | if selected in state: 383 | w, h = int(w), int(h) 384 | state[selected]["map"] = resize(draw, w, h) 385 | return state, gr.Image.update(value=create_mixed_img(selected, state, w, h)) 386 | 387 | 388 | def apply_weight(selected, weight, state): 389 | if selected in state: 390 | state[selected]["weight"] = weight 391 | return state 392 | 393 | 394 | def apply_option(selected, mask, state): 395 | if selected in state: 396 | state[selected]["mask_outsides"] = mask 397 | return state 398 | 399 | 400 | # sp2, radio, width, height, global_stats 401 | def apply_image(image, selected, w, h, strgength, mask, state): 402 | if selected in state: 403 | state[selected] = { 404 | "map": resize(image, w, h), 405 | "weight": strgength, 406 | "mask_outsides": mask 407 | } 408 | 409 | return state, gr.Image.update(value=create_mixed_img(selected, state, w, h)) 410 | 411 | # [ti_state, lora_state, ti_vals, lora_vals, uploads] 412 | def add_net(files, ti_state, lora_state): 413 | if files is None: 414 | return ti_state, "", lora_state, None 415 | 416 | for file in files: 417 | item = Path(file.name) 418 | stripedname = str(item.stem).strip() 419 | if item.suffix == ".pt": 420 | state_dict = torch.load(file.name, map_location="cpu") 421 | else: 422 | state_dict = load_file(file.name, device="cpu") 423 | if any("lora" in k for k in state_dict.keys()): 424 | lora_state = file.name 425 | else: 426 | ti_state[stripedname] = file.name 427 | 428 | return ( 429 | ti_state, 430 | lora_state, 431 | gr.Text.update(f"{[key for key in ti_state.keys()]}"), 432 | gr.Text.update(f"{lora_state}"), 433 | gr.Files.update(value=None), 434 | ) 435 | 436 | 437 | # [ti_state, lora_state, ti_vals, lora_vals, uploads] 438 
| def clean_states(ti_state, lora_state): 439 | return ( 440 | dict(), 441 | None, 442 | gr.Text.update(f""), 443 | gr.Text.update(f""), 444 | gr.File.update(value=None), 445 | ) 446 | 447 | 448 | latent_upscale_modes = { 449 | "Latent": {"upscale_method": "bilinear", "upscale_antialias": False}, 450 | "Latent (antialiased)": {"upscale_method": "bilinear", "upscale_antialias": True}, 451 | "Latent (bicubic)": {"upscale_method": "bicubic", "upscale_antialias": False}, 452 | "Latent (bicubic antialiased)": { 453 | "upscale_method": "bicubic", 454 | "upscale_antialias": True, 455 | }, 456 | "Latent (nearest)": {"upscale_method": "nearest", "upscale_antialias": False}, 457 | "Latent (nearest-exact)": { 458 | "upscale_method": "nearest-exact", 459 | "upscale_antialias": False, 460 | }, 461 | } 462 | 463 | css = """ 464 | .finetuned-diffusion-div div{ 465 | display:inline-flex; 466 | align-items:center; 467 | gap:.8rem; 468 | font-size:1.75rem; 469 | padding-top:2rem; 470 | } 471 | .finetuned-diffusion-div div h1{ 472 | font-weight:900; 473 | margin-bottom:7px 474 | } 475 | .finetuned-diffusion-div p{ 476 | margin-bottom:10px; 477 | font-size:94% 478 | } 479 | .box { 480 | float: left; 481 | height: 20px; 482 | width: 20px; 483 | margin-bottom: 15px; 484 | border: 1px solid black; 485 | clear: both; 486 | } 487 | a{ 488 | text-decoration:underline 489 | } 490 | .tabs{ 491 | margin-top:0; 492 | margin-bottom:0 493 | } 494 | #gallery{ 495 | min-height:20rem 496 | } 497 | .no-border { 498 | border: none !important; 499 | } 500 | """ 501 | with gr.Blocks(css=css) as demo: 502 | gr.HTML( 503 | f""" 504 |
505 |             <div>
506 |               <h1>Demo for diffusion models</h1>
507 |             </div>
508 |             <p>Hso @ nyanko.sketch2img.gradio</p>
509 |           </div>
510 | """ 511 | ) 512 | global_stats = gr.State(value={}) 513 | 514 | with gr.Row(): 515 | 516 | with gr.Column(scale=55): 517 | model = gr.Dropdown( 518 | choices=[k[0] for k in get_model_list()], 519 | label="Model", 520 | value=base_name, 521 | ) 522 | image_out = gr.Image(height=512) 523 | # gallery = gr.Gallery( 524 | # label="Generated images", show_label=False, elem_id="gallery" 525 | # ).style(grid=[1], height="auto") 526 | 527 | with gr.Column(scale=45): 528 | 529 | with gr.Group(): 530 | 531 | with gr.Row(): 532 | with gr.Column(scale=70): 533 | 534 | prompt = gr.Textbox( 535 | label="Prompt", 536 | value="loli cat girl, blue eyes, flat chest, solo, long messy silver hair, blue capelet, cat ears, cat tail, upper body", 537 | show_label=True, 538 | max_lines=4, 539 | placeholder="Enter prompt.", 540 | ) 541 | neg_prompt = gr.Textbox( 542 | label="Negative Prompt", 543 | value="bad quality, low quality, jpeg artifact, cropped", 544 | show_label=True, 545 | max_lines=4, 546 | placeholder="Enter negative prompt.", 547 | ) 548 | 549 | generate = gr.Button(value="Generate").style( 550 | rounded=(False, True, True, False) 551 | ) 552 | 553 | with gr.Tab("Options"): 554 | 555 | with gr.Group(): 556 | 557 | # n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1) 558 | with gr.Row(): 559 | guidance = gr.Slider( 560 | label="Guidance scale", value=7.5, maximum=15 561 | ) 562 | steps = gr.Slider( 563 | label="Steps", value=25, minimum=2, maximum=75, step=1 564 | ) 565 | 566 | with gr.Row(): 567 | width = gr.Slider( 568 | label="Width", value=512, minimum=64, maximum=2048, step=64 569 | ) 570 | height = gr.Slider( 571 | label="Height", value=512, minimum=64, maximum=2048, step=64 572 | ) 573 | 574 | sampler = gr.Dropdown( 575 | value="DPM++ 2M Karras", 576 | label="Sampler", 577 | choices=[s[0] for s in samplers_k_diffusion], 578 | ) 579 | seed = gr.Number(label="Seed (0 = random)", value=0) 580 | 581 | with gr.Tab("Image to image"): 582 | with gr.Group(): 583 | 584 | inf_image = gr.Image( 585 | label="Image", height=256, tool="editor", type="pil" 586 | ) 587 | inf_strength = gr.Slider( 588 | label="Transformation strength", 589 | minimum=0, 590 | maximum=1, 591 | step=0.01, 592 | value=0.5, 593 | ) 594 | 595 | def res_cap(g, w, h, x): 596 | if g: 597 | return f"Enable upscaler: {w}x{h} to {int(w*x)}x{int(h*x)}" 598 | else: 599 | return "Enable upscaler" 600 | 601 | with gr.Tab("Hires fix"): 602 | with gr.Group(): 603 | 604 | hr_enabled = gr.Checkbox(label="Enable upscaler", value=False) 605 | hr_method = gr.Dropdown( 606 | [key for key in latent_upscale_modes.keys()], 607 | value="Latent", 608 | label="Upscale method", 609 | ) 610 | hr_scale = gr.Slider( 611 | label="Upscale factor", 612 | minimum=1.0, 613 | maximum=3, 614 | step=0.1, 615 | value=1.5, 616 | ) 617 | hr_denoise = gr.Slider( 618 | label="Denoising strength", 619 | minimum=0.0, 620 | maximum=1.0, 621 | step=0.1, 622 | value=0.8, 623 | ) 624 | 625 | hr_scale.change( 626 | lambda g, x, w, h: gr.Checkbox.update( 627 | label=res_cap(g, w, h, x) 628 | ), 629 | inputs=[hr_enabled, hr_scale, width, height], 630 | outputs=hr_enabled, 631 | ) 632 | hr_enabled.change( 633 | lambda g, x, w, h: gr.Checkbox.update( 634 | label=res_cap(g, w, h, x) 635 | ), 636 | inputs=[hr_enabled, hr_scale, width, height], 637 | outputs=hr_enabled, 638 | ) 639 | 640 | with gr.Tab("Embeddings/Loras"): 641 | 642 | ti_state = gr.State(dict()) 643 | lora_state = gr.State() 644 | 645 | with gr.Group(): 646 | with gr.Row(): 647 | with 
gr.Column(scale=90): 648 | ti_vals = gr.Text(label="Loaded embeddings") 649 | 650 | with gr.Row(): 651 | with gr.Column(scale=90): 652 | lora_vals = gr.Text(label="Loaded loras") 653 | 654 | with gr.Row(): 655 | 656 | uploads = gr.Files(label="Upload new embeddings/lora") 657 | 658 | with gr.Column(): 659 | lora_scale = gr.Slider( 660 | label="Lora scale", 661 | minimum=0, 662 | maximum=2, 663 | step=0.01, 664 | value=1.0, 665 | ) 666 | btn = gr.Button(value="Upload") 667 | btn_del = gr.Button(value="Reset") 668 | 669 | btn.click( 670 | add_net, 671 | inputs=[uploads, ti_state, lora_state], 672 | outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads], 673 | ) 674 | btn_del.click( 675 | clean_states, 676 | inputs=[ti_state, lora_state], 677 | outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads], 678 | ) 679 | 680 | # error_output = gr.Markdown() 681 | 682 | gr.HTML( 683 | f""" 684 |
685 |           <div>
686 |             <h1>Paint with words</h1>
687 |           </div>
688 |           <p>
689 |             Will use the following formula: w = scale * token_weight_matrix * log(1 + sigma) * max(qk).
690 |           </p>
691 |         </div>
692 | """ 693 | ) 694 | 695 | with gr.Row(): 696 | 697 | with gr.Column(scale=55): 698 | 699 | rendered = gr.Image( 700 | invert_colors=True, 701 | source="canvas", 702 | interactive=False, 703 | image_mode="RGBA", 704 | ) 705 | 706 | with gr.Column(scale=45): 707 | 708 | with gr.Group(): 709 | with gr.Row(): 710 | with gr.Column(scale=70): 711 | g_strength = gr.Slider( 712 | label="Weight scaling", 713 | minimum=0, 714 | maximum=0.8, 715 | step=0.01, 716 | value=0.4, 717 | ) 718 | 719 | text = gr.Textbox( 720 | lines=2, 721 | interactive=True, 722 | label="Token to Draw: (Separate by comma)", 723 | ) 724 | 725 | radio = gr.Radio([], label="Tokens") 726 | 727 | sk_update = gr.Button(value="Update").style( 728 | rounded=(False, True, True, False) 729 | ) 730 | 731 | # g_strength.change(lambda b: gr.update(f"Scaled additional attn: $w = {b} \log (1 + \sigma) \std (Q^T K)$."), inputs=g_strength, outputs=[g_output]) 732 | 733 | with gr.Tab("SketchPad"): 734 | 735 | sp = gr.Image( 736 | image_mode="L", 737 | tool="sketch", 738 | source="canvas", 739 | interactive=False, 740 | ) 741 | 742 | mask_outsides = gr.Checkbox( 743 | label="Mask other areas", 744 | value=False 745 | ) 746 | 747 | strength = gr.Slider( 748 | label="Token strength", 749 | minimum=0, 750 | maximum=0.8, 751 | step=0.01, 752 | value=0.5, 753 | ) 754 | 755 | 756 | sk_update.click( 757 | detect_text, 758 | inputs=[text, global_stats, width, height], 759 | outputs=[global_stats, sp, radio, rendered], 760 | ) 761 | radio.change( 762 | switch_canvas, 763 | inputs=[radio, global_stats, width, height], 764 | outputs=[sp, strength, mask_outsides, rendered], 765 | ) 766 | sp.edit( 767 | apply_canvas, 768 | inputs=[radio, sp, global_stats, width, height], 769 | outputs=[global_stats, rendered], 770 | ) 771 | strength.change( 772 | apply_weight, 773 | inputs=[radio, strength, global_stats], 774 | outputs=[global_stats], 775 | ) 776 | mask_outsides.change( 777 | apply_option, 778 | inputs=[radio, mask_outsides, global_stats], 779 | outputs=[global_stats], 780 | ) 781 | 782 | with gr.Tab("UploadFile"): 783 | 784 | sp2 = gr.Image( 785 | image_mode="L", 786 | source="upload", 787 | shape=(512, 512), 788 | ) 789 | 790 | mask_outsides2 = gr.Checkbox( 791 | label="Mask other areas", 792 | value=False 793 | ) 794 | 795 | strength2 = gr.Slider( 796 | label="Token strength", 797 | minimum=0, 798 | maximum=0.8, 799 | step=0.01, 800 | value=0.5, 801 | ) 802 | 803 | apply_style = gr.Button(value="Apply") 804 | apply_style.click( 805 | apply_image, 806 | inputs=[sp2, radio, width, height, strength2, mask_outsides2, global_stats], 807 | outputs=[global_stats, rendered], 808 | ) 809 | 810 | width.change( 811 | apply_new_res, 812 | inputs=[width, height, global_stats], 813 | outputs=[global_stats, rendered], 814 | ) 815 | height.change( 816 | apply_new_res, 817 | inputs=[width, height, global_stats], 818 | outputs=[global_stats, rendered], 819 | ) 820 | 821 | # color_stats = gr.State(value={}) 822 | # text.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered]) 823 | # sp.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered]) 824 | 825 | inputs = [ 826 | prompt, 827 | guidance, 828 | steps, 829 | width, 830 | height, 831 | seed, 832 | neg_prompt, 833 | global_stats, 834 | g_strength, 835 | inf_image, 836 | inf_strength, 837 | hr_enabled, 838 | hr_method, 839 | hr_scale, 840 | hr_denoise, 841 | sampler, 842 | ti_state, 843 | model, 844 | lora_state, 845 | lora_scale, 846 | ] 847 | outputs = 
[image_out] 848 | prompt.submit(inference, inputs=inputs, outputs=outputs) 849 | generate.click(inference, inputs=inputs, outputs=outputs) 850 | 851 | print(f"Space built in {time.time() - start_time:.2f} seconds") 852 | # demo.launch(share=True) 853 | demo.launch() 854 | --------------------------------------------------------------------------------
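
A minimal sketch of the weighting formula shown in the UI above (w = scale * token_weight_matrix * log(1 + sigma) * max(qk)), assuming a plain scaled-dot-product cross-attention: the user-painted token-weight matrix is scaled by log(1 + sigma) and the peak attention score, then added to the query-key scores before softmax, mirroring the `weight_func` lambda passed to the k-diffusion model in `txt2img`/`img2img`. The helper names (`pww_bias`, `cross_attention_with_pww`), tensor shapes, and the attention math around the bias are illustrative assumptions only, not the repository's actual `CrossAttnProcessor` or `encode_sketchs` output.

# Illustrative sketch only -- names and shapes are assumptions, not the repo's CrossAttnProcessor.
import math
import torch


def pww_bias(w: torch.Tensor, sigma: float, qk: torch.Tensor) -> torch.Tensor:
    # w: (pixels, tokens) user-painted weight matrix, already multiplied by the
    # "Weight scaling" slider value; qk: raw attention scores.
    # Mirrors weight_func: w * log(1 + sigma) * max(qk).
    return w * math.log(1 + sigma) * qk.max()


def cross_attention_with_pww(q, k, v, w, sigma):
    # q: (heads, pixels, dim); k, v: (heads, tokens, dim); w: (pixels, tokens)
    scale = q.shape[-1] ** -0.5
    qk = torch.bmm(q, k.transpose(-1, -2)) * scale   # (heads, pixels, tokens)
    qk = qk + pww_bias(w, sigma, qk)                 # inject the painted-word bias
    attn = qk.softmax(dim=-1)
    return torch.bmm(attn, v)                        # (heads, pixels, dim)


# Example: 64x64 latent, 77 text tokens, 8 heads of width 40.
q = torch.randn(8, 64 * 64, 40)
k = v = torch.randn(8, 77, 40)
w = torch.zeros(64 * 64, 77)   # painted masks would set entries > 0 here
out = cross_attention_with_pww(q, k, v, w, sigma=5.0)
print(out.shape)               # torch.Size([8, 4096, 40])

Scaling by max(qk) keeps the painted bias proportional to the raw attention logits, and the log(1 + sigma) factor shrinks as sigma falls, so the effect of the sketch fades toward the end of sampling.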