├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── modules ├── rng.py ├── rng_philox.py ├── shared.py └── text_processing │ ├── classic_engine.py │ ├── emphasis.py │ ├── parsing.py │ ├── past_classic_engine.py │ ├── prompt_parser.py │ ├── t5_engine.py │ └── textual_inversion.py ├── nodes.py ├── pyproject.toml ├── smZNodes.py └── web ├── exif.js ├── metadata.js └── smZdynamicWidgets.js /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | jobs: 11 | publish-node: 12 | name: Publish Custom Node to registry 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Check out code 16 | uses: actions/checkout@v4 17 | - name: Publish Custom Node 18 | uses: Comfy-Org/publish-node-action@main 19 | with: 20 | ## Add your own personal access token to your Github Repository secrets and reference it here. 21 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | backup* 163 | **/.DS_Store 164 | **/.venv 165 | **/.vscode -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # smZNodes 3 | A selection of custom nodes for [ComfyUI](https://github.com/comfyanonymous/ComfyUI). 4 | 5 | 1. [CLIP Text Encode++](#clip-text-encode) 6 | 2. 
[Settings](#settings) 7 | 8 | Contents 9 | 10 | * [Tips to get reproducible results on both UIs](#tips-to-get-reproducible-results-on-both-uis) 11 | * [FAQs](#faqs) 12 | * [Installation](#installation) 13 | 14 | 15 | ## CLIP Text Encode++ 16 | 17 |

18 |

19 |

20 | Clip Text Encode++ – Default settings on stable-diffusion-webui 21 |

22 |

23 | 24 | CLIP Text Encode++ can generate embeddings in [ComfyUI](https://github.com/comfyanonymous/ComfyUI) that are identical to those produced by [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui). 25 | 26 | This means you can reproduce on `ComfyUI` the same images generated with `stable-diffusion-webui`. 27 | 28 | Simple prompts generate _identical_ images. More complex prompts with complex attention/emphasis/weighting may generate images with slight differences. In that case, you can try using the `Settings` node to match outputs. 29 | 30 | ### Features 31 | 32 | - [Prompt editing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-editing) 33 | - [Alternating words](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alternating-words) 34 | - [`AND`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#composable-diffusion) keyword (similar to the ConditioningCombine node) 35 | - [`BREAK`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#break-keyword) keyword (same as the ConditioningConcat node) 36 | - Weight normalization 37 | - Optional `embedding:` identifier 38 | 39 | 40 | ### Comparisons 41 | These images can be dragged into ComfyUI to load their workflows. Each image was generated with the [Silicon29](https://huggingface.co/Xynon/SD-Silicon) (SD v1.5) checkpoint using the Heun sampler with 18 steps.
42 | 43 | |stable-diffusion-webui|A1111 parser|Comfy parser| 44 | |:---:|:---:|:---:| 45 | | ![00008-0-cinematic wide shot of the ocean, beach, (palmtrees_1 5), at sunset, milkyway](https://github.com/shiimizu/ComfyUI_smZNodes/assets/54494639/719457d8-96fc-495e-aabc-48c4fe4d648d) | ![A1111 parser comparison 1](https://github.com/shiimizu/ComfyUI_smZNodes/assets/54494639/e446b4ab-6f11-4194-b708-f7bdd1cb8fa8) | ![Comfy parser comparison 1](https://github.com/shiimizu/ComfyUI_smZNodes/assets/54494639/e2e04235-02cc-433a-a2f0-7d58be14d6f5) | 46 | | ![00007-0-a photo of an astronaut riding a horse on mars, ((palmtrees_1 2) on water)](https://github.com/shiimizu/ComfyUI_smZNodes/assets/54494639/9ad8b569-8c6d-4a09-bf36-288d81ce4cf9) | ![A1111 parser comparison 2](https://github.com/shiimizu/ComfyUI_smZNodes/assets/54494639/81767441-c286-41db-a59a-4c69603d84d7) | ![Comfy parser comparison 2](https://github.com/shiimizu/ComfyUI_smZNodes/assets/54494639/ed62c23c-c9bd-41cf-9a37-eab4f9d5e12b) | 47 | 48 | Image slider links: 49 | - https://imgsli.com/MTkxMjE0 50 | - https://imgsli.com/MTkxMjEy 51 | 52 | ### Options 53 | 54 | |Name|Description| 55 | | --- | --- | 56 | | `parser` | The parser to parse prompts into tokens and then transformed (encoded) into embeddings. Taken from [SD.Next](https://github.com/vladmandic/automatic/discussions/99#discussioncomment-5931014). | 57 | | `mean_normalization` | Whether to take the mean of your prompt weights. It's `true` by default on `stable-diffusion-webui`.
This is implemented according to how it is in `stable-diffusion-webui`. | 58 | | `multi_conditioning` |
For each prompt, the list is obtained by splitting the prompt using the `AND` separator.
See: [Compositional Visual Generation with Composable Diffusion Models](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/)
| 59 | |`use_old_emphasis_implementation`|
Use old emphasis implementation. Can be useful to reproduce old seeds.
| 60 | 61 | > [!TIP] 62 | > You can right click the node to show/hide some of the widgets. E.g. the `with_SDXL` option. 63 | 64 |
65 | 66 | | Parser | Description | 67 | | ----------------- | -------------------------------------------------------------------------------- | 68 | | `comfy` | The default way `ComfyUI` handles everything | 69 | | `comfy++` | Uses `ComfyUI`'s parser but encodes tokens the way `stable-diffusion-webui` does, allowing you to take the mean as it does. | 70 | | `A1111` | The default parser used in `stable-diffusion-webui` | 71 | | `full` | Same as `A1111` but whitespace, newlines, and special characters are stripped | 72 | | `compel` | Uses [`compel`](https://github.com/damian0815/compel) | 73 | | `fixed attention` | The prompt is left untouched | 74 | 75 | > [!IMPORTANT] 76 | > Every `parser` except `comfy` uses `stable-diffusion-webui`'s encoding pipeline. 77 | 78 | > [!WARNING] 79 | > LoRA syntax (`<lora:name:weight>`) is not supported. 80 | 81 | ## Settings 82 | 83 |
84 | Settings-node-showcase 85 |

Settings node showcase

86 |
87 | 88 | 89 | The `Settings` node is a dynamic node that functions similarly to the Reroute node and is used to fine-tune results during sampling or tokenization. The inputs can be replaced with another input type even after they've been connected. `CLIP` inputs only apply settings to CLIP Text Encode++. Settings apply locally based on the node's links, just like nodes that do model patches. I made this node to explore the various settings found in `stable-diffusion-webui`. 90 | 91 | This node can change whenever it is updated, so you may have to **recreate** it to prevent issues. Settings can be overridden by using another `Settings` node somewhere past a previous one. Right-click the node for the `Hide/show all descriptions` menu option. 92 | 93 | 94 | ## Tips to get reproducible results on both UIs 95 | - Use the same seed, sampler settings, RNG (CPU or GPU), clip skip (CLIP Set Last Layer), etc. 96 | - Ancestral and SDE samplers may not be deterministic. 97 | - If you're using `DDIM` as your sampler, use the `ddim_uniform` scheduler. 98 | - There are different `unipc` configurations. Adjust accordingly on both UIs. 99 | 100 | ## FAQs 101 | - How does this differ from [`ComfyUI_ADV_CLIP_emb`](https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb)? 102 | - While the weights are normalized in the same manner, the tokenization and encoding pipeline that's taken from stable-diffusion-webui differs from ComfyUI's. These small changes add up and ultimately produce different results. 103 | - Where can I learn more about how ComfyUI interprets weights? 104 | - https://comfyanonymous.github.io/ComfyUI_examples/faq/ 105 | - https://blenderneko.github.io/ComfyUI-docs/Interface/Textprompts/ 106 | - [comfyui.creamlab.net](https://comfyui.creamlab.net/guides/f_text_prompt) 107 | 108 | 109 | ## Installation 110 | 111 | Three methods are available for installation: 112 | 113 | 1. Load via [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) 114 | 2.
Clone the repository directly into ComfyUI's `custom_nodes` directory. 115 | 3. Download the project manually. 116 | 117 | 118 | ### Load via ComfyUI Manager 119 | 120 | 121 |
122 | ComfyUI Manager 123 |

Install via ComfyUI Manager

124 |
125 | 126 | ### Clone Repository 127 | 128 | ```shell 129 | cd path/to/your/ComfyUI/custom_nodes 130 | git clone https://github.com/shiimizu/ComfyUI_smZNodes.git 131 | ``` 132 | 133 | ### Download Manually 134 | 135 | 1. Download the project archive from [here](https://github.com/shiimizu/ComfyUI_smZNodes/archive/refs/heads/main.zip). 136 | 2. Extract the downloaded zip file. 137 | 3. Move the extracted files to `path/to/your/ComfyUI/custom_nodes`. 138 | 4. Restart ComfyUI 139 | 140 | The folder structure should resemble: `path/to/your/ComfyUI/custom_nodes/ComfyUI_smZNodes`. 141 | 142 | 143 | ### Update 144 | 145 | To update the extension, update via [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) or pull the latest changes from the repository: 146 | 147 | ```shell 148 | cd path/to/your/ComfyUI/custom_nodes/ComfyUI_smZNodes 149 | git pull 150 | ``` 151 | 152 | ## Credits 153 | 154 | * [AUTOMATIC1111](https://github.com/AUTOMATIC1111) / [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 155 | * [comfyanonymous](https://github.com/comfyanonymous) / [ComfyUI](https://github.com/comfyanonymous/ComfyUI) 156 | * [vladmandic](https://github.com/vladmandic) / [SD.Next](https://github.com/vladmandic/automatic) 157 | * [lllyasviel](https://github.com/lllyasviel) / [stable-diffusion-webui-forge](https://github.com/lllyasviel/stable-diffusion-webui-forge) -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import importlib 4 | import subprocess 5 | from pathlib import Path 6 | 7 | def get_modules(): 8 | from sys import modules 9 | s = set() 10 | for m in modules.values(): 11 | try: s.add(m) 12 | except Exception: ... 13 | return s 14 | 15 | def reload_modules(m): 16 | s = get_modules() 17 | for module in s - m : 18 | try: importlib.reload(module) 19 | except Exception: ... 
20 | 21 | PRELOADED_MODULES = None 22 | 23 | def install(module, PRELOADED_MODULES): 24 | if importlib.util.find_spec(module) is not None: return 25 | if PRELOADED_MODULES is None: 26 | PRELOADED_MODULES = get_modules() 27 | import sys 28 | try: 29 | print(f"\033[92m[smZNodes] \033[0;31m{module} is not installed. Attempting to install...\033[0m") 30 | subprocess.check_call([sys.executable, "-m", "pip", "install", module]) 31 | reload_modules(PRELOADED_MODULES) 32 | print(f"\033[92m[smZNodes] {module} Installed!\033[0m") 33 | except Exception as e: 34 | print(f"\033[92m[smZNodes] \033[0;31mFailed to install {module}.\033[0m") 35 | return PRELOADED_MODULES 36 | 37 | for pkg in ['compel', 'lark']: 38 | PRELOADED_MODULES = install(pkg, PRELOADED_MODULES) 39 | 40 | # ============================ 41 | # web 42 | 43 | cwd_path = Path(__file__).parent 44 | comfy_path = cwd_path.parent.parent 45 | 46 | def setup_web_extension(): 47 | import nodes 48 | web_extension_path = os.path.join(comfy_path, "web", "extensions", "smZNodes") 49 | if os.path.exists(web_extension_path): 50 | shutil.rmtree(web_extension_path) 51 | if not hasattr(nodes, "EXTENSION_WEB_DIRS"): 52 | if not os.path.exists(web_extension_path): 53 | os.makedirs(web_extension_path) 54 | js_src_path = os.path.join(cwd_path, "web", "smZdynamicWidgets.js") 55 | shutil.copy(js_src_path, web_extension_path) 56 | 57 | setup_web_extension() 58 | 59 | # ============================ 60 | 61 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 62 | WEB_DIRECTORY = "./web" 63 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"] 64 | 65 | from .smZNodes import add_custom_samplers, register_hooks 66 | 67 | add_custom_samplers() 68 | register_hooks() 69 | -------------------------------------------------------------------------------- /modules/rng.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from 
comfy.model_patcher import ModelPatcher 4 | from . import shared, rng_philox 5 | 6 | class TorchHijack: 7 | """This is here to replace torch.randn_like of k-diffusion. 8 | 9 | k-diffusion has random_sampler argument for most samplers, but not for all, so 10 | this is needed to properly replace every use of torch.randn_like. 11 | 12 | We need to replace to make images generated in batches to be same as images generated individually.""" 13 | 14 | def __init__(self, generator, randn_source, init=True): 15 | self.generator = generator 16 | self.randn_source = randn_source 17 | self.init = init 18 | 19 | def __getattr__(self, item): 20 | if item == 'randn_like': 21 | return self.randn_like 22 | 23 | if hasattr(torch, item): 24 | return getattr(torch, item) 25 | 26 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") 27 | 28 | def randn_like(self, x): 29 | return randn_without_seed(x, generator=self.generator, randn_source=self.randn_source) 30 | 31 | def randn_without_seed(x, generator=None, randn_source="cpu"): 32 | """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. 33 | 34 | Use either randn() or manual_seed() to initialize the generator.""" 35 | if randn_source == "nv": 36 | return torch.asarray(generator.randn(x.size()), device=x.device) 37 | else: 38 | return torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=generator.device, generator=generator).to(device=x.device) 39 | 40 | def prepare_noise(latent_image, seed, noise_inds=None, device='cpu'): 41 | """ 42 | creates random noise given a latent image and a seed. 
43 | optional arg skip can be used to skip and discard x number of noise generations for a given seed 44 | """ 45 | opts = None 46 | opts_found = False 47 | model = _find_outer_instance('model', ModelPatcher) 48 | if (model is not None and (opts:=model.model_options.get(shared.Options.KEY)) is None) or opts is None: 49 | import comfy.samplers 50 | guider = _find_outer_instance('guider', comfy.samplers.CFGGuider) 51 | model = getattr(guider, 'model_patcher', None) 52 | if (model is not None and (opts:=model.model_options.get(shared.Options.KEY)) is None) or opts is None: 53 | pass 54 | opts_found = opts is not None 55 | if not opts_found: 56 | opts = shared.opts_default 57 | device = torch.device("cpu") 58 | 59 | if opts.randn_source == 'gpu': 60 | import comfy.model_management 61 | device = comfy.model_management.get_torch_device() 62 | 63 | device_orig = device 64 | device = torch.device("cpu") if opts.randn_source == "cpu" else device_orig 65 | 66 | def get_generator(seed): 67 | nonlocal device, opts 68 | if opts.randn_source == 'nv': 69 | generator = rng_philox.Generator(seed) 70 | else: 71 | generator = torch.Generator(device=device).manual_seed(seed) 72 | return generator 73 | 74 | def get_generator_obj(seed): 75 | nonlocal opts 76 | generator = torch.manual_seed(seed) 77 | generator = generator_eta = get_generator(seed) 78 | if opts.eta_noise_seed_delta > 0: 79 | seed = min(int(seed + opts.eta_noise_seed_delta), int(0xffffffffffffffff)) 80 | generator_eta = get_generator(seed) 81 | return (generator, generator_eta) 82 | 83 | generator, generator_eta = get_generator_obj(seed) 84 | randn_source = opts.randn_source 85 | 86 | # ========== hijack randn_like =============== 87 | import comfy.k_diffusion.sampling 88 | # if not hasattr(comfy.k_diffusion.sampling, 'torch_orig'): 89 | # comfy.k_diffusion.sampling.torch_orig = comfy.k_diffusion.sampling.torch 90 | # comfy.k_diffusion.sampling.torch = TorchHijack(generator_eta, opts.randn_source) 91 | 92 | if not 
hasattr(comfy.k_diffusion.sampling, 'default_noise_sampler_orig'): 93 | comfy.k_diffusion.sampling.default_noise_sampler_orig = comfy.k_diffusion.sampling.default_noise_sampler 94 | if opts_found: 95 | th = TorchHijack(generator_eta, randn_source) 96 | def default_noise_sampler(x, seed=None, *args, **kwargs): 97 | nonlocal th 98 | return lambda sigma, sigma_next: th.randn_like(x) 99 | default_noise_sampler.init = True 100 | comfy.k_diffusion.sampling.default_noise_sampler = default_noise_sampler 101 | else: 102 | comfy.k_diffusion.sampling.default_noise_sampler = comfy.k_diffusion.sampling.default_noise_sampler_orig 103 | # ============================================= 104 | 105 | if noise_inds is None: 106 | shape = latent_image.size() 107 | if opts.randn_source == 'nv': 108 | noise = torch.asarray(generator.randn(shape), dtype=latent_image.dtype, device=device) 109 | else: 110 | noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator) 111 | noise = noise.to(device=device_orig) 112 | return noise 113 | 114 | unique_inds, inverse = np.unique(noise_inds, return_inverse=True) 115 | noises = [] 116 | for i in range(unique_inds[-1]+1): 117 | shape = [1] + list(latent_image.size())[1:] 118 | if opts.randn_source == 'nv': 119 | noise = torch.asarray(generator.randn(shape), dtype=latent_image.dtype, device=device) 120 | else: 121 | noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator) 122 | noise = noise.to(device=device_orig) 123 | if i in unique_inds: 124 | noises.append(noise) 125 | noises = [noises[i] for i in inverse] 126 | noises = torch.cat(noises, axis=0) 127 | return noises 128 | 129 | def _find_outer_instance(target:str, target_type=None, callback=None, max_len=10): 130 | import inspect 131 | frame = inspect.currentframe() 132 | i = 0 133 | while frame and i < max_len: 134 | if target in frame.f_locals: 135 | if callback is not None: 136 | 
return callback(frame) 137 | else: 138 | found = frame.f_locals[target] 139 | if isinstance(found, target_type): 140 | return found 141 | frame = frame.f_back 142 | i += 1 143 | return None 144 | 145 | -------------------------------------------------------------------------------- /modules/rng_philox.py: -------------------------------------------------------------------------------- 1 | """RNG imitiating torch cuda randn on CPU. You are welcome. 2 | 3 | Usage: 4 | 5 | ``` 6 | g = Generator(seed=0) 7 | print(g.randn(shape=(3, 4))) 8 | ``` 9 | 10 | Expected output: 11 | ``` 12 | [[-0.92466259 -0.42534415 -2.6438457 0.14518388] 13 | [-0.12086647 -0.57972564 -0.62285122 -0.32838709] 14 | [-1.07454231 -0.36314407 -1.67105067 2.26550497]] 15 | ``` 16 | """ 17 | 18 | import numpy as np 19 | 20 | philox_m = [0xD2511F53, 0xCD9E8D57] 21 | philox_w = [0x9E3779B9, 0xBB67AE85] 22 | 23 | two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32) 24 | two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32) 25 | 26 | 27 | def uint32(x): 28 | """Converts (N,) np.uint64 array into (2, N) np.unit32 array.""" 29 | return x.view(np.uint32).reshape(-1, 2).transpose(1, 0) 30 | 31 | 32 | def philox4_round(counter, key): 33 | """A single round of the Philox 4x32 random number generator.""" 34 | 35 | v1 = uint32(counter[0].astype(np.uint64) * philox_m[0]) 36 | v2 = uint32(counter[2].astype(np.uint64) * philox_m[1]) 37 | 38 | counter[0] = v2[1] ^ counter[1] ^ key[0] 39 | counter[1] = v2[0] 40 | counter[2] = v1[1] ^ counter[3] ^ key[1] 41 | counter[3] = v1[0] 42 | 43 | 44 | def philox4_32(counter, key, rounds=10): 45 | """Generates 32-bit random numbers using the Philox 4x32 random number generator. 46 | 47 | Parameters: 48 | counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation). 49 | key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed). 
50 | rounds (int): The number of rounds to perform. 51 | 52 | Returns: 53 | numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers. 54 | """ 55 | 56 | for _ in range(rounds - 1): 57 | philox4_round(counter, key) 58 | 59 | key[0] = key[0] + philox_w[0] 60 | key[1] = key[1] + philox_w[1] 61 | 62 | philox4_round(counter, key) 63 | return counter 64 | 65 | 66 | def box_muller(x, y): 67 | """Returns just the first out of two numbers generated by Box–Muller transform algorithm.""" 68 | u = x * two_pow32_inv + two_pow32_inv / 2 69 | v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2 70 | 71 | s = np.sqrt(-2.0 * np.log(u)) 72 | 73 | r1 = s * np.sin(v) 74 | return r1.astype(np.float32) 75 | 76 | 77 | class Generator: 78 | """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU""" 79 | 80 | def __init__(self, seed): 81 | self.seed = seed 82 | self.offset = 0 83 | 84 | def randn(self, shape): 85 | """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform.""" 86 | 87 | n = 1 88 | for x in shape: 89 | n *= x 90 | 91 | counter = np.zeros((4, n), dtype=np.uint32) 92 | counter[0] = self.offset 93 | counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3] 94 | self.offset += 1 95 | 96 | key = np.empty(n, dtype=np.uint64) 97 | key.fill(self.seed) 98 | key = uint32(key) 99 | 100 | g = philox4_32(counter, key) 101 | 102 | return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3] 103 | -------------------------------------------------------------------------------- /modules/shared.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from copy import deepcopy 3 | from comfy.model_management import vram_state, VRAMState 4 | from comfy.cli_args import args 5 | from comfy import model_management 6 | 7 | xformers_available 
= model_management.XFORMERS_IS_AVAILABLE 8 | logger = logging.getLogger("smZNodes") 9 | level = logging.INFO 10 | logger.propagate = False 11 | logger.setLevel(level) 12 | stdoutHandler = logging.StreamHandler() 13 | fmt = logging.Formatter("[%(name)s] | %(filename)s:%(lineno)s | %(message)s") 14 | stdoutHandler.setFormatter(fmt) 15 | logger.addHandler(stdoutHandler) 16 | 17 | def join_args(*args): 18 | return ' '.join(map(str, args)) 19 | 20 | class SimpleNamespaceFast: 21 | def __repr__(self): 22 | keys = sorted(self.__dict__) 23 | items = ("{}={!r}".format(k, self.__dict__[k]) for k in keys) 24 | return "{}({})".format(type(self).__name__, ", ".join(items)) 25 | 26 | def __eq__(self, other): 27 | return self.__dict__ == other.__dict__ 28 | 29 | class Options(SimpleNamespaceFast): 30 | KEY = 'smZ_opts' 31 | 32 | def clone(self): 33 | return deepcopy(self) 34 | 35 | def update(self, other): 36 | if isinstance(other, dict): 37 | self.__dict__ |= other 38 | else: 39 | self.__dict__ |= other.__dict__ 40 | return self 41 | 42 | opts = Options() 43 | opts.prompt_attention = 'A1111 parser' 44 | opts.prompt_mean_norm = True 45 | opts.comma_padding_backtrack = 20 46 | opts.CLIP_stop_at_last_layers = 1 47 | opts.enable_emphasis = True 48 | opts.use_old_emphasis_implementation = False 49 | opts.disable_nan_check = True 50 | opts.pad_cond_uncond = False 51 | opts.s_min_uncond = 0.0 52 | opts.s_min_uncond_all = False 53 | opts.skip_early_cond = 0.0 54 | opts.upcast_sampling = True 55 | opts.upcast_attn = not getattr(args, 'dont_upcast_attention', False) 56 | opts.textual_inversion_add_hashes_to_infotext = False 57 | opts.encode_count = 0 58 | opts.max_chunk_count = 0 59 | opts.return_batch_chunks = False 60 | opts.noise = None 61 | opts.start_step = None 62 | opts.pad_with_repeats = True 63 | opts.randn_source = "cpu" 64 | opts.lora_functional = False 65 | opts.use_old_scheduling = True 66 | opts.eta_noise_seed_delta = 0 67 | opts.multi_conditioning = False 68 | opts.eta = 
1.0 69 | opts.s_churn = 0.0 70 | opts.s_tmin = 0.0 71 | opts.s_tmax = 0.0 or float('inf') 72 | opts.s_noise = 1.0 73 | 74 | opts.use_CFGDenoiser = False 75 | opts.sgm_noise_multiplier = True 76 | opts.debug= False 77 | 78 | opts.sdxl_crop_top = 0 79 | opts.sdxl_crop_left = 0 80 | opts.sdxl_refiner_low_aesthetic_score = 2.5 81 | opts.sdxl_refiner_high_aesthetic_score = 6.0 82 | 83 | sd_model = Options() 84 | sd_model.cond_stage_model = Options() 85 | 86 | cmd_opts = Options() 87 | 88 | opts.batch_cond_uncond = False 89 | cmd_opts.lowvram = vram_state == VRAMState.LOW_VRAM 90 | cmd_opts.medvram = vram_state == VRAMState.NORMAL_VRAM 91 | should_batch_cond_uncond = lambda: opts.batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) 92 | opts.batch_cond_uncond = True 93 | 94 | opts_default = opts.clone() 95 | 96 | cmd_opts.xformers = xformers_available 97 | cmd_opts.force_enable_xformers = xformers_available 98 | 99 | opts.cross_attention_optimization = "None" 100 | cmd_opts.sub_quad_q_chunk_size = 512 101 | cmd_opts.sub_quad_kv_chunk_size = 512 102 | cmd_opts.sub_quad_chunk_threshold = 80 103 | cmd_opts.token_merging_ratio = 0.0 104 | cmd_opts.token_merging_ratio_img2img = 0.0 105 | cmd_opts.token_merging_ratio_hr = 0.0 106 | cmd_opts.sd_vae_sliced_encode = False 107 | cmd_opts.disable_opt_split_attention = False -------------------------------------------------------------------------------- /modules/text_processing/classic_engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import logging 4 | from collections import namedtuple 5 | from comfy import model_management 6 | from . 
import emphasis, prompt_parser 7 | from .textual_inversion import EmbeddingDatabase, parse_and_register_embeddings 8 | 9 | 10 | PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) 11 | last_extra_generation_params = {} 12 | 13 | def populate_self_variables(self, from_): 14 | attrs_from = vars(from_) 15 | attrs_self = vars(self) 16 | attrs_self.update(attrs_from) 17 | 18 | class PromptChunk: 19 | def __init__(self): 20 | self.tokens = [] 21 | self.multipliers = [] 22 | self.fixes = [] 23 | 24 | 25 | class CLIPEmbeddingForTextualInversion(torch.nn.Module): 26 | def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'): 27 | super().__init__() 28 | self.wrapped = wrapped 29 | self.embeddings = embeddings 30 | self.textual_inversion_key = textual_inversion_key 31 | self.weight = self.wrapped.weight 32 | 33 | def forward(self, input_ids, out_dtype): 34 | batch_fixes = self.embeddings.fixes 35 | self.embeddings.fixes = None 36 | 37 | inputs_embeds = self.wrapped(input_ids, out_dtype) 38 | 39 | if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: 40 | return inputs_embeds 41 | 42 | vecs = [] 43 | for fixes, tensor in zip(batch_fixes, inputs_embeds): 44 | for offset, embedding in fixes: 45 | emb = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec 46 | emb = emb.to(inputs_embeds) 47 | emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) 48 | try: 49 | tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype) 50 | except Exception: 51 | logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {} {} {} '{}'".format(tensor.shape[0], emb.shape[1], self.current_embeds.weight.shape[1], self.textual_inversion_key, embedding.name)) 52 | 53 | vecs.append(tensor) 54 | 55 | return torch.stack(vecs) 56 | 57 | 58 | class ClassicTextProcessingEngine: 59 | def 
__init__( 60 | self, text_encoder, tokenizer, chunk_length=75, 61 | embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="Original", 62 | text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=True, final_layer_norm=True 63 | ): 64 | super().__init__() 65 | populate_self_variables(self, tokenizer) 66 | self._tokenizer = tokenizer 67 | 68 | self.embeddings = EmbeddingDatabase(self.tokenizer, embedding_expected_shape) 69 | 70 | self.text_encoder = text_encoder 71 | self._try_get_embedding = tokenizer._try_get_embedding 72 | 73 | self.emphasis = emphasis.get_current_option(emphasis_name)() 74 | 75 | self.text_projection = text_projection 76 | self.minimal_clip_skip = minimal_clip_skip 77 | self.clip_skip = clip_skip 78 | self.return_pooled = return_pooled 79 | self.final_layer_norm = final_layer_norm 80 | 81 | self.chunk_length = chunk_length 82 | 83 | self.id_start = self.start_token 84 | self.id_end = self.end_token 85 | self.id_pad = self.pad_token 86 | 87 | model_embeddings = text_encoder.transformer.text_model.embeddings 88 | backup_embeds = self.text_encoder.transformer.get_input_embeddings() 89 | model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=self.embedding_key) 90 | model_embeddings.token_embedding.current_embeds = backup_embeds 91 | 92 | vocab = self.tokenizer.get_vocab() 93 | self.token_mults = {} 94 | tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] 95 | for text, ident in tokens_with_parens: 96 | mult = 1.0 97 | for c in text: 98 | if c == '[': 99 | mult /= 1.1 100 | if c == ']': 101 | mult *= 1.1 102 | if c == '(': 103 | mult *= 1.1 104 | if c == ')': 105 | mult /= 1.1 106 | if mult != 1.0: 107 | self.token_mults[ident] = mult 108 | 109 | self.comma_token = vocab.get(',', None) 110 | self.tokenizer._eventual_warn_about_too_long_sequence = lambda *args, 
**kwargs: None 111 | 112 | def unhook(self): 113 | self.text_encoder.transformer.text_model.embeddings.token_embedding = self.text_encoder.transformer.text_model.embeddings.token_embedding.wrapped 114 | del self._try_get_embedding 115 | w = '_eventual_warn_about_too_long_sequence' 116 | if hasattr(self.tokenizer, w): delattr(self.tokenizer, w) 117 | if hasattr(self._tokenizer, w): delattr(self._tokenizer, w) 118 | 119 | def empty_chunk(self): 120 | chunk = PromptChunk() 121 | chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1) 122 | chunk.multipliers = [1.0] * (self.chunk_length + 2) 123 | return chunk 124 | 125 | def get_target_prompt_token_count(self, token_count): 126 | return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length 127 | 128 | def tokenize(self, texts): 129 | tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] 130 | return tokenized 131 | 132 | def tokenize_with_weights(self, texts, return_word_ids=False): 133 | texts = [parse_and_register_embeddings(self, text) for text in texts] 134 | if self.opts.use_old_emphasis_implementation: 135 | return self.process_texts_past(texts) 136 | batch_chunks, token_count = self.process_texts(texts) 137 | 138 | used_embeddings = {} 139 | chunk_count = max([len(x) for x in batch_chunks]) 140 | 141 | zs = [] 142 | for i in range(chunk_count): 143 | batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] 144 | 145 | tokens = [x.tokens for x in batch_chunk] 146 | multipliers = [x.multipliers for x in batch_chunk] 147 | self.embeddings.fixes = [x.fixes for x in batch_chunk] 148 | 149 | for fixes in self.embeddings.fixes: 150 | for _position, embedding in fixes: 151 | used_embeddings[embedding.name] = embedding 152 | 153 | z = (tokens, multipliers) 154 | zs.append(z) 155 | 156 | return zs 157 | 158 | def encode_token_weights(self, token_weight_pairs): 159 | if isinstance(token_weight_pairs[0], str): 160 
| token_weight_pairs = self.tokenize_with_weights(token_weight_pairs) 161 | elif isinstance(token_weight_pairs[0], list): 162 | token_weight_pairs = list(map(lambda x: ([list(map(lambda y: y[0], x))], [list(map(lambda y: y[1], x))]), token_weight_pairs)) 163 | 164 | target_device = model_management.text_encoder_offload_device() 165 | zs = [] 166 | for tokens, multipliers in token_weight_pairs: 167 | z = self.process_tokens(tokens, multipliers) 168 | zs.append(z) 169 | if self.return_pooled: 170 | return torch.hstack(zs).to(target_device), zs[0].pooled.to(target_device) if zs[0].pooled is not None else None 171 | else: 172 | return torch.hstack(zs).to(target_device) 173 | 174 | def encode_with_transformers(self, tokens): 175 | try: 176 | z, pooled = self.text_encoder(tokens) 177 | except Exception: 178 | z, pooled = self.text_encoder(tokens.tolist()) 179 | z.pooled = pooled 180 | return z 181 | 182 | def tokenize_line(self, line): 183 | parsed = prompt_parser.parse_prompt_attention(line) 184 | 185 | tokenized = self.tokenize([text for text, _ in parsed]) 186 | 187 | chunks = [] 188 | chunk = PromptChunk() 189 | token_count = 0 190 | last_comma = -1 191 | 192 | def next_chunk(is_last=False): 193 | nonlocal token_count 194 | nonlocal last_comma 195 | nonlocal chunk 196 | 197 | if is_last: 198 | token_count += len(chunk.tokens) 199 | else: 200 | token_count += self.chunk_length 201 | 202 | to_add = self.chunk_length - len(chunk.tokens) 203 | if to_add > 0: 204 | chunk.tokens += [self.id_end] * to_add 205 | chunk.multipliers += [1.0] * to_add 206 | 207 | chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end] 208 | chunk.multipliers = [1.0] + chunk.multipliers + [1.0] 209 | 210 | last_comma = -1 211 | chunks.append(chunk) 212 | chunk = PromptChunk() 213 | 214 | for tokens, (text, weight) in zip(tokenized, parsed): 215 | if text == 'BREAK' and weight == -1: 216 | next_chunk() 217 | continue 218 | 219 | position = 0 220 | while position < len(tokens): 221 | token = 
tokens[position] 222 | 223 | comma_padding_backtrack = 20 224 | 225 | if token == self.comma_token: 226 | last_comma = len(chunk.tokens) 227 | 228 | elif comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= comma_padding_backtrack: 229 | break_location = last_comma + 1 230 | 231 | reloc_tokens = chunk.tokens[break_location:] 232 | reloc_mults = chunk.multipliers[break_location:] 233 | 234 | chunk.tokens = chunk.tokens[:break_location] 235 | chunk.multipliers = chunk.multipliers[:break_location] 236 | 237 | next_chunk() 238 | chunk.tokens = reloc_tokens 239 | chunk.multipliers = reloc_mults 240 | 241 | if len(chunk.tokens) == self.chunk_length: 242 | next_chunk() 243 | 244 | embedding, embedding_length_in_tokens = self.embeddings.find_embedding_at_position(tokens, position) 245 | if embedding is None: 246 | chunk.tokens.append(token) 247 | chunk.multipliers.append(weight) 248 | position += 1 249 | continue 250 | 251 | emb_len = int(embedding.vectors) 252 | if len(chunk.tokens) + emb_len > self.chunk_length: 253 | next_chunk() 254 | 255 | chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding)) 256 | 257 | chunk.tokens += [0] * emb_len 258 | chunk.multipliers += [weight] * emb_len 259 | position += embedding_length_in_tokens 260 | 261 | if chunk.tokens or not chunks: 262 | next_chunk(is_last=True) 263 | 264 | return chunks, token_count 265 | 266 | def process_texts(self, texts): 267 | token_count = 0 268 | 269 | cache = {} 270 | batch_chunks = [] 271 | for line in texts: 272 | if line in cache: 273 | chunks = cache[line] 274 | else: 275 | chunks, current_token_count = self.tokenize_line(line) 276 | token_count = max(current_token_count, token_count) 277 | 278 | cache[line] = chunks 279 | 280 | batch_chunks.append(chunks) 281 | 282 | return batch_chunks, token_count 283 | 284 | def __call__(self, texts): 285 | tokens = self.tokenize_with_weights(texts) 286 | return 
self.encode_token_weights(tokens) 287 | 288 | def process_tokens(self, remade_batch_tokens, batch_multipliers, *args, **kwargs): 289 | try: 290 | tokens = torch.asarray(remade_batch_tokens) 291 | 292 | if self.id_end != self.id_pad: 293 | for batch_pos in range(len(remade_batch_tokens)): 294 | index = remade_batch_tokens[batch_pos].index(self.id_end) 295 | tokens[batch_pos, index + 1:tokens.shape[1]] = self.id_pad 296 | 297 | z = self.encode_with_transformers(tokens) 298 | except ValueError: 299 | # Tokens including textual inversion embeddings in the list. 300 | # i.e. tensors in the list along with tokens. 301 | z = self.encode_with_transformers(remade_batch_tokens) 302 | 303 | pooled = getattr(z, 'pooled', None) 304 | 305 | self.emphasis.tokens = remade_batch_tokens 306 | self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z) 307 | self.emphasis.z = z 308 | self.emphasis.after_transformers() 309 | z = self.emphasis.z 310 | 311 | if pooled is not None: 312 | z.pooled = pooled 313 | 314 | return z 315 | -------------------------------------------------------------------------------- /modules/text_processing/emphasis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Emphasis: 5 | name: str = "Base" 6 | description: str = "" 7 | tokens: list[list[int]] 8 | multipliers: torch.Tensor 9 | z: torch.Tensor 10 | 11 | def after_transformers(self): 12 | pass 13 | 14 | 15 | class EmphasisNone(Emphasis): 16 | name = "None" 17 | description = "disable the mechanism entirely and treat (:.1.1) as literal characters" 18 | 19 | 20 | class EmphasisIgnore(Emphasis): 21 | name = "Ignore" 22 | description = "treat all emphasized words as if they have no emphasis" 23 | 24 | 25 | class EmphasisOriginal(Emphasis): 26 | name = "Original" 27 | description = "the original emphasis implementation" 28 | 29 | def after_transformers(self): 30 | original_mean = self.z.mean() 31 | self.z = self.z * 
self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape) 32 | new_mean = self.z.mean() 33 | self.z = self.z * (original_mean / new_mean) 34 | 35 | 36 | class EmphasisOriginalNoNorm(EmphasisOriginal): 37 | name = "No norm" 38 | description = "same as original, but without normalization (seems to work better for SDXL)" 39 | 40 | def after_transformers(self): 41 | self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape) 42 | 43 | 44 | def get_current_option(emphasis_option_name): 45 | return next(iter([x for x in options if x.name == emphasis_option_name]), EmphasisOriginal) 46 | 47 | 48 | def get_options_descriptions(): 49 | return ", ".join(f"{x.name}: {x.description}" for x in options) 50 | 51 | def get_options_descriptions_nl(): 52 | return "\n".join(f"{x.name}: {x.description}" for x in options) 53 | 54 | 55 | options = [ 56 | EmphasisNone, 57 | EmphasisIgnore, 58 | EmphasisOriginal, 59 | EmphasisOriginalNoNorm, 60 | ] 61 | -------------------------------------------------------------------------------- /modules/text_processing/parsing.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | re_attention = re.compile(r""" 5 | \\\(| 6 | \\\)| 7 | \\\[| 8 | \\]| 9 | \\\\| 10 | \\| 11 | \(| 12 | \[| 13 | :\s*([+-]?[.\d]+)\s*\)| 14 | \)| 15 | ]| 16 | [^\\()\[\]:]+| 17 | : 18 | """, re.X) 19 | 20 | re_break = re.compile(r"\s*\bBREAK\b\s*", re.S) 21 | 22 | 23 | def parse_prompt_attention(text): 24 | res = [] 25 | round_brackets = [] 26 | square_brackets = [] 27 | 28 | round_bracket_multiplier = 1.1 29 | square_bracket_multiplier = 1 / 1.1 30 | 31 | def multiply_range(start_position, multiplier): 32 | for p in range(start_position, len(res)): 33 | res[p][1] *= multiplier 34 | 35 | for m in re_attention.finditer(text): 36 | text = m.group(0) 37 | weight = m.group(1) 38 | 39 | if text.startswith('\\'): 40 | res.append([text[1:], 1.0]) 41 | elif text == '(': 42 | 
round_brackets.append(len(res)) 43 | elif text == '[': 44 | square_brackets.append(len(res)) 45 | elif weight is not None and round_brackets: 46 | multiply_range(round_brackets.pop(), float(weight)) 47 | elif text == ')' and round_brackets: 48 | multiply_range(round_brackets.pop(), round_bracket_multiplier) 49 | elif text == ']' and square_brackets: 50 | multiply_range(square_brackets.pop(), square_bracket_multiplier) 51 | else: 52 | parts = re.split(re_break, text) 53 | for i, part in enumerate(parts): 54 | if i > 0: 55 | res.append(["BREAK", -1]) 56 | res.append([part, 1.0]) 57 | 58 | for pos in round_brackets: 59 | multiply_range(pos, round_bracket_multiplier) 60 | 61 | for pos in square_brackets: 62 | multiply_range(pos, square_bracket_multiplier) 63 | 64 | if len(res) == 0: 65 | res = [["", 1.0]] 66 | 67 | i = 0 68 | while i + 1 < len(res): 69 | if res[i][1] == res[i + 1][1]: 70 | res[i][0] += res[i + 1][0] 71 | res.pop(i + 1) 72 | else: 73 | i += 1 74 | 75 | return res 76 | -------------------------------------------------------------------------------- /modules/text_processing/past_classic_engine.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | def process_text_old(self, texts): 4 | id_start = self.id_start 5 | id_end = self.id_end 6 | maxlen = self.max_length # you get to stay at 77 7 | used_custom_terms = [] 8 | remade_batch_tokens = [] 9 | hijack_comments = [] 10 | hijack_fixes = [] 11 | token_count = 0 12 | 13 | cache = {} 14 | batch_tokens = self.tokenize(texts) 15 | batch_multipliers = [] 16 | batch_multipliers_solo_emb = [] 17 | for tokens in batch_tokens: 18 | tuple_tokens = tuple(tokens) 19 | 20 | if tuple_tokens in cache: 21 | remade_tokens, fixes, multipliers, multipliers_solo_emb = cache[tuple_tokens] 22 | else: 23 | fixes = [] 24 | remade_tokens = [] 25 | multipliers = [] 26 | multipliers_solo_emb = [] 27 | mult = 1.0 28 | 29 | i = 0 30 | while i < len(tokens): 31 | token = tokens[i] 
32 | 33 | embedding, embedding_length_in_tokens = self.embeddings.find_embedding_at_position(tokens, i) 34 | if isinstance(embedding, dict): 35 | if 'open' in self.__class__.__name__.lower(): 36 | embedding = embedding.get('g', embedding) 37 | else: 38 | embedding.pop('g', None) 39 | embedding = next(iter(embedding.values())) 40 | 41 | mult_change = self.token_mults.get(token) if self.opts.enable_emphasis else None 42 | if mult_change is not None: 43 | mult *= mult_change 44 | i += 1 45 | elif embedding is None: 46 | remade_tokens.append(token) 47 | multipliers.append(mult) 48 | multipliers_solo_emb.append(mult) 49 | i += 1 50 | else: 51 | emb_len = int(embedding.vec.shape[0]) 52 | fixes.append((len(remade_tokens), embedding)) 53 | remade_tokens += [0] * emb_len 54 | multipliers += [mult] * emb_len 55 | multipliers_solo_emb += [mult] + ([1.0] * (emb_len-1)) 56 | used_custom_terms.append((embedding.name, embedding.checksum())) 57 | i += embedding_length_in_tokens 58 | 59 | if len(remade_tokens) > maxlen - 2: 60 | vocab = {v: k for k, v in self.tokenizer.get_vocab().items()} 61 | ovf = remade_tokens[maxlen - 2:] 62 | overflowing_words = [vocab.get(int(x), "") for x in ovf] 63 | overflowing_text = self.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) 64 | logging.warning(f"\033[33mWarning:\033[0m too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") 65 | 66 | token_count = len(remade_tokens) 67 | remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) 68 | remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] 69 | cache[tuple_tokens] = (remade_tokens, fixes, multipliers, multipliers_solo_emb) 70 | 71 | multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) 72 | multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] 73 | 74 | multipliers_solo_emb = multipliers_solo_emb + [1.0] * (maxlen - 2 - len(multipliers_solo_emb)) 75 | multipliers_solo_emb = [1.0] + 
multipliers_solo_emb[0:maxlen - 2] + [1.0] 76 | 77 | remade_batch_tokens.append(remade_tokens) 78 | hijack_fixes.append(fixes) 79 | batch_multipliers.append(multipliers) 80 | batch_multipliers_solo_emb.append(multipliers_solo_emb) 81 | return batch_multipliers, batch_multipliers_solo_emb, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count 82 | 83 | 84 | def forward_old(self, texts): 85 | batch_multipliers, batch_multipliers_solo_emb, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, _token_count = process_text_old(self, texts) 86 | 87 | chunk_count = max([len(x) for x in remade_batch_tokens]) 88 | 89 | if self.opts.return_batch_chunks: 90 | return (remade_batch_tokens, chunk_count) 91 | 92 | self.hijack.comments += hijack_comments 93 | 94 | if len(used_custom_terms) > 0: 95 | embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms) 96 | self.hijack.comments.append(f"Used embeddings: {embedding_names}") 97 | 98 | self.hijack.fixes = hijack_fixes 99 | return self.process_tokens(remade_batch_tokens, batch_multipliers, batch_multipliers_solo_emb) 100 | 101 | def process_texts_past(self, texts): 102 | batch_multipliers, batch_multipliers_solo_emb, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, _token_count = process_text_old(self, texts) 103 | return [(remade_batch_tokens, batch_multipliers)] -------------------------------------------------------------------------------- /modules/text_processing/prompt_parser.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import re 3 | import lark 4 | from typing import List 5 | from compel import Compel 6 | from collections import namedtuple 7 | if __name__ != "__main__": 8 | import torch 9 | from ..shared import opts, logger 10 | else: 11 | class Opts: ... 
12 | opts = Opts() 13 | opts.prompt_attention = 'A1111 parser' 14 | import logging as logger 15 | 16 | # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]" 17 | # will be represented with prompt_schedule like this (assuming steps=100): 18 | # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] 19 | # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy'] 20 | # [60, 'fantasy landscape with a lake and an oak in foreground in background masterful'] 21 | # [75, 'fantasy landscape with a lake and an oak in background masterful'] 22 | # [100, 'fantasy landscape with a lake and a christmas tree in background masterful'] 23 | 24 | schedule_parser = lark.Lark(r""" 25 | !start: (prompt | /[][():]/+)* 26 | prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)* 27 | !emphasized: "(" prompt ")" 28 | | "(" prompt ":" prompt ")" 29 | | "[" prompt "]" 30 | scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]" 31 | alternate: "[" prompt ("|" [prompt])+ "]" 32 | WHITESPACE: /\s+/ 33 | plain: /([^\\\[\]():|]|\\.)+/ 34 | %import common.SIGNED_NUMBER -> NUMBER 35 | """) 36 | re_clean = re.compile(r"^\W+", re.S) 37 | re_whitespace = re.compile(r"\s+", re.S) 38 | 39 | 40 | def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False): 41 | """ 42 | >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0] 43 | >>> g("test") 44 | [[10, 'test']] 45 | >>> g("a [b:3]") 46 | [[3, 'a '], [10, 'a b']] 47 | >>> g("a [b: 3]") 48 | [[3, 'a '], [10, 'a b']] 49 | >>> g("a [[[b]]:2]") 50 | [[2, 'a '], [10, 'a [[b]]']] 51 | >>> g("[(a:2):3]") 52 | [[3, ''], [10, '(a:2)']] 53 | >>> g("a [b : c : 1] d") 54 | [[1, 'a b d'], [10, 'a c d']] 55 | >>> g("a[b:[c:d:2]:1]e") 56 | [[1, 'abe'], [2, 'ace'], [10, 'ade']] 57 | >>> g("a [unbalanced") 58 | 
[[10, 'a [unbalanced']] 59 | >>> g("a [b:.5] c") 60 | [[5, 'a c'], [10, 'a b c']] 61 | >>> g("a [{b|d{:.5] c") # not handling this right now 62 | [[5, 'a c'], [10, 'a {b|d{ c']] 63 | >>> g("((a][:b:c [d:3]") 64 | [[3, '((a][:b:c '], [10, '((a][:b:c d']] 65 | >>> g("[a|(b:1.1)]") 66 | [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']] 67 | >>> g("[fe|]male") 68 | [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']] 69 | >>> g("[fe|||]male") 70 | [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']] 71 | >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0] 72 | >>> g("a [b:.5] c") 73 | [[10, 'a b c']] 74 | >>> g("a [b:1.5] c") 75 | [[5, 'a c'], [10, 'a b c']] 76 | """ 77 | 78 | if hires_steps is None or use_old_scheduling: 79 | int_offset = 0 80 | flt_offset = 0 81 | steps = base_steps 82 | else: 83 | int_offset = base_steps 84 | flt_offset = 1.0 85 | steps = hires_steps 86 | 87 | def collect_steps(steps, tree): 88 | res = [steps] 89 | 90 | class CollectSteps(lark.Visitor): 91 | def scheduled(self, tree): 92 | s = tree.children[-2] 93 | v = float(s) 94 | if use_old_scheduling: 95 | v = v*steps if v<1 else v 96 | else: 97 | if "." 
in s: 98 | v = (v - flt_offset) * steps 99 | else: 100 | v = (v - int_offset) 101 | tree.children[-2] = min(steps, int(v)) 102 | if tree.children[-2] >= 1: 103 | res.append(tree.children[-2]) 104 | 105 | def alternate(self, tree): 106 | res.extend(range(1, steps+1)) 107 | 108 | CollectSteps().visit(tree) 109 | return sorted(set(res)) 110 | 111 | def at_step(step, tree): 112 | class AtStep(lark.Transformer): 113 | def scheduled(self, args): 114 | before, after, _, when, _ = args 115 | yield before or () if step <= when else after 116 | def alternate(self, args): 117 | args = ["" if not arg else arg for arg in args] 118 | yield args[(step - 1) % len(args)] 119 | def start(self, args): 120 | def flatten(x): 121 | if isinstance(x, str): 122 | yield x 123 | else: 124 | for gen in x: 125 | yield from flatten(gen) 126 | return ''.join(flatten(args)) 127 | def plain(self, args): 128 | yield args[0].value 129 | def __default__(self, data, children, meta): 130 | for child in children: 131 | yield child 132 | return AtStep().transform(tree) 133 | 134 | def get_schedule(prompt): 135 | try: 136 | tree = schedule_parser.parse(prompt) 137 | except lark.exceptions.LarkError: 138 | if 0: 139 | import traceback 140 | traceback.print_exc() 141 | return [[steps, prompt]] 142 | return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)] 143 | 144 | promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)} 145 | return [promptdict[prompt] for prompt in prompts] 146 | 147 | 148 | ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) 149 | 150 | 151 | class SdConditioning(list): 152 | """ 153 | A list with prompts for stable diffusion's conditioner model. 154 | Can also specify width and height of created image - SDXL needs it. 
155 | """ 156 | def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None): 157 | super().__init__() 158 | self.extend(prompts) 159 | 160 | if copy_from is None: 161 | copy_from = prompts 162 | 163 | self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False) 164 | self.width = width or getattr(copy_from, 'width', None) 165 | self.height = height or getattr(copy_from, 'height', None) 166 | 167 | 168 | 169 | def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False): 170 | """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), 171 | and the sampling step at which this condition is to be replaced by the next one. 172 | 173 | Input: 174 | (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20) 175 | 176 | Output: 177 | [ 178 | [ 179 | ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0')) 180 | ], 181 | [ 182 | ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')), 183 | ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0')) 184 | ] 185 | ] 186 | """ 187 | res = [] 188 | 189 | prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling) 190 | logger.debug(prompt_schedules) 191 | cache = {} 192 | for prompt, prompt_schedule in zip(prompts, prompt_schedules): 193 | 194 | cached = cache.get(prompt, None) 195 | if cached is not None: 196 | res.append(cached) 
197 | continue 198 | 199 | texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts) 200 | # conds = model.get_learned_conditioning(texts) 201 | conds = model(texts) 202 | if isinstance(conds, tuple): 203 | conds, pooleds = conds 204 | conds.pooled = pooleds 205 | cond_schedule = [] 206 | for i, (end_at_step, _) in enumerate(prompt_schedule): 207 | if isinstance(conds, dict): 208 | if 'cond' not in conds: 209 | cond = {k: v[i] for k, v in conds.items()} 210 | else: 211 | # cond = {k: (v[i:i+1, :] if isinstance(v, torch.Tensor) else v) for k, v in conds.items()} 212 | cond = {k: (v[i].unsqueeze(0)[0:1] if isinstance(v, torch.Tensor) else v) for k, v in conds.items()} 213 | else: 214 | # cond = conds[i] 215 | cond = conds[i:i+1, :] 216 | if conds.pooled is not None: 217 | cond.pooled = conds.pooled[i:i+1, :] 218 | cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond)) 219 | 220 | cache[prompt] = cond_schedule 221 | res.append(cond_schedule) 222 | 223 | return res 224 | 225 | 226 | re_AND = re.compile(r"\bAND\b") 227 | re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$") 228 | 229 | 230 | def get_multicond_prompt_list(prompts: SdConditioning | list[str]): 231 | res_indexes = [] 232 | 233 | prompt_indexes = {} 234 | prompt_flat_list = SdConditioning(prompts) 235 | prompt_flat_list.clear() 236 | 237 | for prompt in prompts: 238 | subprompts = re_AND.split(prompt) 239 | 240 | indexes = [] 241 | for subprompt in subprompts: 242 | match = re_weight.search(subprompt) 243 | 244 | text, weight = match.groups() if match is not None else (subprompt, 1.0) 245 | 246 | weight = float(weight) if weight is not None else 1.0 247 | 248 | index = prompt_indexes.get(text, None) 249 | if index is None: 250 | index = len(prompt_flat_list) 251 | prompt_flat_list.append(text) 252 | prompt_indexes[text] = index 253 | 254 | indexes.append((index, weight)) 255 | 256 | res_indexes.append(indexes) 257 | 258 | return res_indexes, 
prompt_flat_list, prompt_indexes 259 | 260 | 261 | class ComposableScheduledPromptConditioning: 262 | def __init__(self, schedules, weight=1.0): 263 | self.schedules: List[ScheduledPromptConditioning] = schedules 264 | self.weight: float = weight 265 | 266 | 267 | class MulticondLearnedConditioning: 268 | def __init__(self, shape, batch): 269 | self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS 270 | self.batch: List[List[ComposableScheduledPromptConditioning]] = batch 271 | 272 | 273 | def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning: 274 | """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. 275 | For each prompt, the list is obtained by splitting the prompt using the AND separator. 276 | 277 | https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/ 278 | """ 279 | 280 | res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts) 281 | 282 | learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling) 283 | 284 | res = [] 285 | for indexes in res_indexes: 286 | res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes]) 287 | 288 | return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) 289 | 290 | 291 | class DictWithShape(dict): 292 | def __init__(self, x, shape=None): 293 | super().__init__() 294 | self.update(x) 295 | 296 | @property 297 | def shape(self): 298 | return self["crossattn"].shape 299 | 300 | 301 | def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step): 302 | param = c[0][0].cond 303 | is_dict = isinstance(param, dict) 304 | pooled_outputs = [] 305 | 306 | if is_dict: 307 | dict_cond = param 308 | if 'crossattn' in dict_cond: 309 | 
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()} 310 | res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape) 311 | elif 'cond' in dict_cond: 312 | param = dict_cond['cond'] 313 | res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) 314 | else: 315 | res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) 316 | 317 | for i, cond_schedule in enumerate(c): 318 | target_index = 0 319 | for current, entry in enumerate(cond_schedule): 320 | if current_step <= entry.end_at_step: 321 | target_index = current 322 | break 323 | 324 | cond_target = cond_schedule[target_index].cond 325 | if is_dict: 326 | if 'cond' in cond_target: 327 | res[i] = cond_target['cond'] 328 | pooled_outputs.append(cond_target['pooled_output']) 329 | else: 330 | for k, param in cond_target.items(): 331 | res[k][i] = param 332 | else: 333 | res[i] = cond_target 334 | pooled_outputs.append(cond_target.pooled) 335 | res.pooled = torch.cat(pooled_outputs).to(param) if pooled_outputs[0] is not None else None 336 | return res 337 | 338 | 339 | def stack_conds(tensors): 340 | # if prompts have wildly different lengths above the limit we'll get tensors of different shapes 341 | # and won't be able to torch.stack them. So this fixes that. 
342 | token_count = max([x.shape[0] for x in tensors]) 343 | for i in range(len(tensors)): 344 | if tensors[i].shape[0] != token_count: 345 | last_vector = tensors[i][-1:] 346 | last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) 347 | tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) 348 | 349 | return torch.stack(tensors) 350 | 351 | 352 | 353 | def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): 354 | param = c.batch[0][0].schedules[0].cond 355 | 356 | tensors = [] 357 | pooled_outputs = [] 358 | conds_list = [] 359 | 360 | for composable_prompts in c.batch: 361 | conds_for_batch = [] 362 | 363 | for composable_prompt in composable_prompts: 364 | target_index = 0 365 | for current, entry in enumerate(composable_prompt.schedules): 366 | if current_step <= entry.end_at_step: 367 | target_index = current 368 | break 369 | 370 | conds_for_batch.append((len(tensors), composable_prompt.weight)) 371 | tensors.append(composable_prompt.schedules[target_index].cond) 372 | pooled_outputs.append(composable_prompt.schedules[target_index].cond.pooled) 373 | 374 | conds_list.append(conds_for_batch) 375 | 376 | if isinstance(tensors[0], dict): 377 | keys = list(tensors[0].keys()) 378 | stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys} 379 | stacked = DictWithShape(stacked, stacked['crossattn'].shape) 380 | else: 381 | stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype) 382 | stacked.pooled = torch.cat(pooled_outputs).to(device=param.device, dtype=param.dtype) 383 | return conds_list, stacked 384 | 385 | 386 | re_attention = re.compile(r""" 387 | \\\(| 388 | \\\)| 389 | \\\[| 390 | \\]| 391 | \\\\| 392 | \\| 393 | \(| 394 | \[| 395 | :\s*([+-]?[.\d]+)\s*\)| 396 | \)| 397 | ]| 398 | [^\\()\[\]:]+| 399 | : 400 | """, re.X) 401 | 402 | re_break = re.compile(r"\s*\bBREAK\b\s*", re.S) 403 | re_attention_v1 = re_attention 404 | 405 | def parse_prompt_attention(text): 406 | """ 
407 | Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 408 | Accepted tokens are: 409 | (abc) - increases attention to abc by a multiplier of 1.1 410 | (abc:3.12) - increases attention to abc by a multiplier of 3.12 411 | [abc] - decreases attention to abc by a multiplier of 1.1 412 | \( - literal character '(' 413 | \[ - literal character '[' 414 | \) - literal character ')' 415 | \] - literal character ']' 416 | \\ - literal character '\' 417 | anything else - just text 418 | 419 | >>> parse_prompt_attention('normal text') 420 | [['normal text', 1.0]] 421 | >>> parse_prompt_attention('an (important) word') 422 | [['an ', 1.0], ['important', 1.1], [' word', 1.0]] 423 | >>> parse_prompt_attention('(unbalanced') 424 | [['unbalanced', 1.1]] 425 | >>> parse_prompt_attention('\(literal\]') 426 | [['(literal]', 1.0]] 427 | >>> parse_prompt_attention('(unnecessary)(parens)') 428 | [['unnecessaryparens', 1.1]] 429 | >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') 430 | [['a ', 1.0], 431 | ['house', 1.5730000000000004], 432 | [' ', 1.1], 433 | ['on', 1.0], 434 | [' a ', 1.1], 435 | ['hill', 0.55], 436 | [', sun, ', 1.1], 437 | ['sky', 1.4641000000000006], 438 | ['.', 1.1]] 439 | """ 440 | 441 | res = [] 442 | round_brackets = [] 443 | square_brackets = [] 444 | 445 | round_bracket_multiplier = 1.1 446 | square_bracket_multiplier = 1 / 1.1 447 | if opts.prompt_attention == 'Fixed attention': 448 | res = [[text, 1.0]] 449 | return res 450 | elif opts.prompt_attention == 'Compel parser': 451 | conjunction = Compel.parse_prompt_string(text) 452 | if conjunction is None or conjunction.prompts is None or conjunction.prompts is None: 453 | return [["", 1.0]] 454 | use_and = '.and(' in text 455 | cprompts = conjunction.prompts if 'Blend' in conjunction.prompts[0].__class__.__name__ else [conjunction.prompts] 456 | first_el = conjunction.prompts[0].prompts[0] if hasattr(conjunction.prompts[0], 
'prompts') else conjunction.prompts[0] 457 | if len(first_el.children) == 0: 458 | return [["", 1.0]] 459 | res = [] 460 | for bprompt in cprompts: 461 | for prompt_idx, prompt in enumerate(bprompt.prompts if hasattr(bprompt, 'prompts') else bprompt): 462 | for k, frag in enumerate(prompt.children): 463 | res.append([f'{" AND " if use_and and k == 0 and prompt_idx > 0 else ""}{frag.text}', frag.weight]) 464 | return res 465 | elif opts.prompt_attention == 'A1111 parser': 466 | re_attention = re_attention_v1 467 | whitespace = '' 468 | else: 469 | re_attention = re_attention_v1 470 | text = text.replace('\\n', ' ') 471 | whitespace = ' ' 472 | 473 | def multiply_range(start_position, multiplier): 474 | for p in range(start_position, len(res)): 475 | res[p][1] *= multiplier 476 | 477 | for m in re_attention.finditer(text): 478 | text = m.group(0) 479 | weight = m.group(1) 480 | 481 | if text.startswith('\\'): 482 | res.append([text[1:], 1.0]) 483 | elif text == '(': 484 | round_brackets.append(len(res)) 485 | elif text == '[': 486 | square_brackets.append(len(res)) 487 | elif weight is not None and round_brackets: 488 | multiply_range(round_brackets.pop(), float(weight)) 489 | elif text == ')' and round_brackets: 490 | multiply_range(round_brackets.pop(), round_bracket_multiplier) 491 | elif text == ']' and square_brackets: 492 | multiply_range(square_brackets.pop(), square_bracket_multiplier) 493 | else: 494 | parts = re.split(re_break, text) 495 | for i, part in enumerate(parts): 496 | if i > 0: 497 | res.append(["BREAK", -1]) 498 | if opts.prompt_attention == 'Full parser': 499 | part = re_clean.sub("", part) 500 | part = re_whitespace.sub(" ", part).strip() 501 | if len(part) == 0: 502 | continue 503 | res.append([part, 1.0]) 504 | 505 | for pos in round_brackets: 506 | multiply_range(pos, round_bracket_multiplier) 507 | 508 | for pos in square_brackets: 509 | multiply_range(pos, square_bracket_multiplier) 510 | 511 | if len(res) == 0: 512 | res = [["", 1.0]] 513 
| 514 | # merge runs of identical weights 515 | i = 0 516 | while i + 1 < len(res): 517 | if res[i][1] == res[i + 1][1]: 518 | res[i][0] += whitespace + res[i + 1][0] 519 | res.pop(i + 1) 520 | else: 521 | i += 1 522 | 523 | return res 524 | 525 | if __name__ == "__main__": 526 | import doctest 527 | doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) 528 | import sys 529 | args = sys.argv[1:] 530 | if len(args) > 0: 531 | input_text = " ".join(args) 532 | else: 533 | input_text = '[black] [[grey]] (white) ((gray)) ((orange:1.1) yellow) ((purple) and [dark] red:1.1) [mouse:0.2] [(cat:1.1):0.5]' 534 | print(f'Prompt: {input_text}') 535 | all_schedules = get_learned_conditioning_prompt_schedules([input_text], 100)[0] 536 | print('Schedules', all_schedules) 537 | for schedule in all_schedules: 538 | print('Schedule', schedule[0]) 539 | opts.prompt_attention = 'Fixed attention' 540 | output_list = parse_prompt_attention(schedule[1]) 541 | print(' Fixed:', output_list) 542 | opts.prompt_attention = 'Compel parser' 543 | output_list = parse_prompt_attention(schedule[1]) 544 | print(' Compel:', output_list) 545 | opts.prompt_attention = 'A1111 parser' 546 | output_list = parse_prompt_attention(schedule[1]) 547 | print(' A1111:', output_list) 548 | opts.prompt_attention = 'Full parser' 549 | output_list = parse_prompt_attention(schedule[1]) 550 | print(' Full :', output_list) 551 | else: 552 | import torch # doctest faster -------------------------------------------------------------------------------- /modules/text_processing/t5_engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import namedtuple 3 | from . 
import prompt_parser, emphasis 4 | from comfy import model_management 5 | 6 | 7 | PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) 8 | 9 | def populate_self_variables(self, from_): 10 | attrs_from = vars(from_) 11 | attrs_self = vars(self) 12 | attrs_self.update(attrs_from) 13 | 14 | class PromptChunk: 15 | def __init__(self): 16 | self.tokens = [] 17 | self.multipliers = [] 18 | 19 | 20 | class T5TextProcessingEngine: 21 | def __init__(self, text_encoder, tokenizer, emphasis_name="Original", min_length=256): 22 | super().__init__() 23 | populate_self_variables(self, tokenizer) 24 | self._tokenizer = tokenizer 25 | 26 | self.text_encoder = text_encoder 27 | 28 | self.emphasis = emphasis.get_current_option(emphasis_name)() 29 | self.min_length = self.min_length or self.max_length 30 | self.id_end = self.end_token 31 | self.id_pad = self.pad_token 32 | vocab = self.tokenizer.get_vocab() 33 | self.comma_token = vocab.get(',', None) 34 | self.token_mults = {} 35 | tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] 36 | for text, ident in tokens_with_parens: 37 | mult = 1.0 38 | for c in text: 39 | if c == '[': 40 | mult /= 1.1 41 | if c == ']': 42 | mult *= 1.1 43 | if c == '(': 44 | mult *= 1.1 45 | if c == ')': 46 | mult /= 1.1 47 | 48 | if mult != 1.0: 49 | self.token_mults[ident] = mult 50 | self.tokenizer._eventual_warn_about_too_long_sequence = lambda *args, **kwargs: None 51 | 52 | 53 | def tokenize(self, texts): 54 | tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] 55 | return tokenized 56 | 57 | def encode_with_transformers(self, tokens): 58 | try: 59 | z, pooled = self.text_encoder(tokens) 60 | except Exception: 61 | z, pooled = self.text_encoder(tokens.tolist()) 62 | return z 63 | 64 | def tokenize_line(self, line): 65 | parsed = prompt_parser.parse_prompt_attention(line) 66 | 67 | tokenized = self.tokenize([text for text, _ in parsed]) 68 | 
69 | chunks = [] 70 | chunk = PromptChunk() 71 | token_count = 0 72 | 73 | def next_chunk(): 74 | nonlocal token_count 75 | nonlocal chunk 76 | 77 | chunk.tokens = chunk.tokens + [self.id_end] 78 | chunk.multipliers = chunk.multipliers + [1.0] 79 | current_chunk_length = len(chunk.tokens) 80 | 81 | token_count += current_chunk_length 82 | remaining_count = self.min_length - current_chunk_length 83 | 84 | if remaining_count > 0: 85 | chunk.tokens += [self.id_pad] * remaining_count 86 | chunk.multipliers += [1.0] * remaining_count 87 | 88 | chunks.append(chunk) 89 | chunk = PromptChunk() 90 | 91 | for tokens, (text, weight) in zip(tokenized, parsed): 92 | if text == 'BREAK' and weight == -1: 93 | next_chunk() 94 | continue 95 | 96 | position = 0 97 | while position < len(tokens): 98 | token = tokens[position] 99 | chunk.tokens.append(token) 100 | chunk.multipliers.append(weight) 101 | position += 1 102 | 103 | if chunk.tokens or not chunks: 104 | next_chunk() 105 | 106 | return chunks, token_count 107 | 108 | def unhook(self): 109 | w = '_eventual_warn_about_too_long_sequence' 110 | if hasattr(self.tokenizer, w): delattr(self.tokenizer, w) 111 | if hasattr(self._tokenizer, w): delattr(self._tokenizer, w) 112 | 113 | def tokenize_with_weights(self, texts, return_word_ids=False): 114 | tokens_and_weights = [] 115 | cache = {} 116 | for line in texts: 117 | if line not in cache: 118 | chunks, token_count = self.tokenize_line(line) 119 | line_tokens_and_weights = [] 120 | 121 | # Pad all chunks to the length of the longest chunk 122 | max_tokens = 0 123 | for chunk in chunks: 124 | max_tokens = max (len(chunk.tokens), max_tokens) 125 | 126 | for chunk in chunks: 127 | tokens = chunk.tokens 128 | multipliers = chunk.multipliers 129 | remaining_count = max_tokens - len(tokens) 130 | if remaining_count > 0: 131 | tokens += [self.id_pad] * remaining_count 132 | multipliers += [1.0] * remaining_count 133 | line_tokens_and_weights.append((tokens, multipliers)) 134 | 
cache[line] = line_tokens_and_weights 135 | 136 | tokens_and_weights.extend(cache[line]) 137 | return tokens_and_weights 138 | 139 | def encode_token_weights(self, token_weight_pairs): 140 | if isinstance(token_weight_pairs[0], str): 141 | token_weight_pairs = self.tokenize_with_weights(token_weight_pairs) 142 | elif isinstance(token_weight_pairs[0], list): 143 | token_weight_pairs = list(map(lambda x: (list(map(lambda y: y[0], x)), list(map(lambda y: y[1], x))), token_weight_pairs)) 144 | 145 | target_device = model_management.text_encoder_offload_device() 146 | zs = [] 147 | cache = {} 148 | for tokens, multipliers in token_weight_pairs: 149 | token_key = (tuple(tokens), tuple(multipliers)) 150 | if token_key not in cache: 151 | z = self.process_tokens([tokens], [multipliers])[0] 152 | cache[token_key] = z 153 | zs.append(cache[token_key]) 154 | return torch.stack(zs).to(target_device), None 155 | 156 | def __call__(self, texts): 157 | tokens = self.tokenize_with_weights(texts) 158 | return self.encode_token_weights(tokens) 159 | 160 | def process_tokens(self, batch_tokens, batch_multipliers): 161 | tokens = torch.asarray(batch_tokens) 162 | 163 | z = self.encode_with_transformers(tokens) 164 | 165 | self.emphasis.tokens = batch_tokens 166 | self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z) 167 | self.emphasis.z = z 168 | self.emphasis.after_transformers() 169 | z = self.emphasis.z 170 | 171 | return z 172 | -------------------------------------------------------------------------------- /modules/text_processing/textual_inversion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import base64 4 | import json 5 | import zlib 6 | import logging 7 | import numpy as np 8 | import safetensors.torch 9 | from PIL import Image 10 | 11 | 12 | class EmbeddingEncoder(json.JSONEncoder): 13 | def default(self, obj): 14 | if isinstance(obj, torch.Tensor): 15 | return {'TORCHTENSOR': 
obj.cpu().detach().numpy().tolist()} 16 | return json.JSONEncoder.default(self, obj) 17 | 18 | 19 | class EmbeddingDecoder(json.JSONDecoder): 20 | def __init__(self, *args, **kwargs): 21 | json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs) 22 | 23 | def object_hook(self, d): 24 | if 'TORCHTENSOR' in d: 25 | return torch.from_numpy(np.array(d['TORCHTENSOR'])) 26 | return d 27 | 28 | 29 | def embedding_to_b64(data): 30 | d = json.dumps(data, cls=EmbeddingEncoder) 31 | return base64.b64encode(d.encode()) 32 | 33 | 34 | def embedding_from_b64(data): 35 | d = base64.b64decode(data) 36 | return json.loads(d, cls=EmbeddingDecoder) 37 | 38 | 39 | def lcg(m=2 ** 32, a=1664525, c=1013904223, seed=0): 40 | while True: 41 | seed = (a * seed + c) % m 42 | yield seed % 255 43 | 44 | 45 | def xor_block(block): 46 | g = lcg() 47 | randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape) 48 | return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F) 49 | 50 | 51 | def crop_black(img, tol=0): 52 | mask = (img > tol).all(2) 53 | mask0, mask1 = mask.any(0), mask.any(1) 54 | col_start, col_end = mask0.argmax(), mask.shape[1] - mask0[::-1].argmax() 55 | row_start, row_end = mask1.argmax(), mask.shape[0] - mask1[::-1].argmax() 56 | return img[row_start:row_end, col_start:col_end] 57 | 58 | 59 | def extract_image_data_embed(image): 60 | d = 3 61 | outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F 62 | black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0) 63 | if black_cols[0].shape[0] < 2: 64 | print(f'{os.path.basename(getattr(image, "filename", "unknown image file"))}: no embedded information found.') 65 | return None 66 | 67 | data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8) 68 | data_block_upper = outarr[:, black_cols[0].max() + 1:, :].astype(np.uint8) 69 | 70 | data_block_lower = 
xor_block(data_block_lower) 71 | data_block_upper = xor_block(data_block_upper) 72 | 73 | data_block = (data_block_upper << 4) | (data_block_lower) 74 | data_block = data_block.flatten().tobytes() 75 | 76 | data = zlib.decompress(data_block) 77 | return json.loads(data, cls=EmbeddingDecoder) 78 | 79 | 80 | class Embedding: 81 | def __init__(self, vec, name, step=None): 82 | self.vec = vec 83 | self.name = name 84 | self.step = step 85 | self.shape = None 86 | self.vectors = 0 87 | self.sd_checkpoint = None 88 | self.sd_checkpoint_name = None 89 | 90 | 91 | class DirWithTextualInversionEmbeddings: 92 | def __init__(self, path): 93 | self.path = path 94 | self.mtime = None 95 | 96 | def has_changed(self): 97 | if not os.path.isdir(self.path): 98 | return False 99 | 100 | mt = os.path.getmtime(self.path) 101 | if self.mtime is None or mt > self.mtime: 102 | return True 103 | 104 | def update(self): 105 | if not os.path.isdir(self.path): 106 | return 107 | 108 | self.mtime = os.path.getmtime(self.path) 109 | 110 | 111 | class EmbeddingDatabase: 112 | def __init__(self, tokenizer, expected_shape=-1): 113 | self.ids_lookup = {} 114 | self.word_embeddings = {} 115 | self.embedding_dirs = {} 116 | self.skipped_embeddings = {} 117 | self.expected_shape = expected_shape 118 | self.tokenizer = tokenizer 119 | self.fixes = [] 120 | 121 | def add_embedding_dir(self, path): 122 | self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) 123 | 124 | def clear_embedding_dirs(self): 125 | self.embedding_dirs.clear() 126 | 127 | def register_embedding(self, embedding): 128 | return self.register_embedding_by_name(embedding, embedding.name) 129 | 130 | def register_embedding_by_name(self, embedding, name): 131 | ids = self.tokenizer([name], truncation=False, add_special_tokens=False)["input_ids"][0] 132 | first_id = ids[0] 133 | if first_id not in self.ids_lookup: 134 | self.ids_lookup[first_id] = [] 135 | if name in self.word_embeddings: 136 | lookup = [x for x in 
self.ids_lookup[first_id] if x[1].name != name] 137 | else: 138 | lookup = self.ids_lookup[first_id] 139 | if embedding is not None: 140 | lookup += [(ids, embedding)] 141 | self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True) 142 | if embedding is None: 143 | if name in self.word_embeddings: 144 | del self.word_embeddings[name] 145 | if len(self.ids_lookup[first_id]) == 0: 146 | del self.ids_lookup[first_id] 147 | return None 148 | self.word_embeddings[name] = embedding 149 | return embedding 150 | 151 | def load_from_file(self, path, filename): 152 | name, ext = os.path.splitext(filename) 153 | ext = ext.upper() 154 | 155 | if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: 156 | _, second_ext = os.path.splitext(name) 157 | if second_ext.upper() == '.PREVIEW': 158 | return 159 | 160 | embed_image = Image.open(path) 161 | if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: 162 | data = embedding_from_b64(embed_image.text['sd-ti-embedding']) 163 | name = data.get('name', name) 164 | else: 165 | data = extract_image_data_embed(embed_image) 166 | if data: 167 | name = data.get('name', name) 168 | else: 169 | return 170 | elif ext in ['.BIN', '.PT']: 171 | data = torch.load(path, map_location="cpu") 172 | elif ext in ['.SAFETENSORS']: 173 | data = safetensors.torch.load_file(path, device="cpu") 174 | else: 175 | return 176 | 177 | emb_out = None 178 | if data is not None: 179 | embedding = create_embedding_from_data(data, name, filename=filename, filepath=path) 180 | 181 | # if self.expected_shape == -1 or self.expected_shape == embedding.shape: 182 | emb_out = self.register_embedding(embedding) 183 | # else: 184 | # emb_out = self.skipped_embeddings[name] = embedding 185 | else: 186 | print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.") 187 | return emb_out 188 | 189 | def load_from_dir(self, embdir): 190 | if not os.path.isdir(embdir.path): 191 | return 192 | 193 | for root, _, fns in 
os.walk(embdir.path, followlinks=True): 194 | for fn in fns: 195 | try: 196 | fullfn = os.path.join(root, fn) 197 | 198 | if os.stat(fullfn).st_size == 0: 199 | continue 200 | 201 | self.load_from_file(fullfn, fn) 202 | except Exception: 203 | print(f"Error loading embedding {fn}") 204 | continue 205 | 206 | def load_textual_inversion_embeddings(self): 207 | self.ids_lookup.clear() 208 | self.word_embeddings.clear() 209 | self.skipped_embeddings.clear() 210 | 211 | for embdir in self.embedding_dirs.values(): 212 | self.load_from_dir(embdir) 213 | embdir.update() 214 | 215 | return 216 | 217 | def find_embedding_at_position(self, tokens, offset): 218 | token = tokens[offset] 219 | possible_matches = self.ids_lookup.get(token, None) 220 | 221 | if possible_matches is None: 222 | return None, None 223 | 224 | for ids, embedding in possible_matches: 225 | if tokens[offset:offset + len(ids)] == ids: 226 | return embedding, len(ids) 227 | 228 | return None, None 229 | 230 | 231 | def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None): 232 | if 'string_to_param' in data: # textual inversion embeddings 233 | param_dict = data['string_to_param'] 234 | param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 235 | assert len(param_dict) == 1, 'embedding file has multiple terms in it' 236 | emb = next(iter(param_dict.items()))[1] 237 | vec = emb.detach().to(dtype=torch.float32) 238 | shape = vec.shape[-1] 239 | vectors = vec.shape[0] 240 | elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding 241 | vec = {k: v.detach().to(dtype=torch.float32) for k, v in data.items()} 242 | shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] 243 | vectors = data['clip_g'].shape[0] 244 | elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts 245 | assert len(data.keys()) == 1, 'embedding file has multiple terms in it' 
246 | 247 | emb = next(iter(data.values())) 248 | if len(emb.shape) == 1: 249 | emb = emb.unsqueeze(0) 250 | vec = emb.detach().to(dtype=torch.float32) 251 | shape = vec.shape[-1] 252 | vectors = vec.shape[0] 253 | else: 254 | raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") 255 | 256 | embedding = Embedding(vec, name) 257 | embedding.step = data.get('step', None) 258 | embedding.sd_checkpoint = data.get('sd_checkpoint', None) 259 | embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) 260 | embedding.vectors = vectors 261 | embedding.shape = shape 262 | 263 | return embedding 264 | 265 | from comfy.sd1_clip import expand_directory_list 266 | def get_embed_file_path(embedding_name, embedding_directory): 267 | if isinstance(embedding_directory, str): 268 | embedding_directory = [embedding_directory] 269 | 270 | embedding_directory = expand_directory_list(embedding_directory) 271 | 272 | valid_file = None 273 | for embed_dir in embedding_directory: 274 | embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name)) 275 | embed_dir = os.path.abspath(embed_dir) 276 | try: 277 | if os.path.commonpath((embed_dir, embed_path)) != embed_dir: 278 | continue 279 | except Exception: 280 | continue 281 | if not os.path.isfile(embed_path): 282 | extensions = ['.safetensors', '.pt', '.bin'] 283 | for x in extensions: 284 | t = embed_path + x 285 | if os.path.isfile(t): 286 | valid_file = t 287 | break 288 | else: 289 | valid_file = embed_path 290 | if valid_file is not None: 291 | break 292 | 293 | if valid_file is None: 294 | return None 295 | 296 | return valid_file 297 | 298 | import re 299 | from ..shared import logger 300 | emb_re_ = r"(embedding:)?(?:({}[\w\.\-\!\$\/\\]+(\.safetensors|\.pt|\.bin)|(?(1)[\w\.\-\!\$\/\\]+|(?!)))(\.safetensors|\.pt|\.bin)?)(?:(:)(\d+\.?\d*|\d*\.\d+))?" 
301 | def get_valid_embeddings(embedding_directories): 302 | from builtins import any as b_any 303 | exts = ['.safetensors', '.pt', '.bin'] 304 | if isinstance(embedding_directories, str): 305 | embedding_directories = [embedding_directories] 306 | embedding_directories = expand_directory_list(embedding_directories) 307 | embs = set() 308 | from collections import OrderedDict, namedtuple 309 | EmbedInfo = namedtuple('EmbedInfo', ['basename', 'filename', 'filepath']) 310 | store = OrderedDict() 311 | for embd in embedding_directories: 312 | for root, dirs, files in os.walk(embd, followlinks=True, topdown=False): 313 | for name in files: 314 | if not b_any(x in os.path.splitext(name)[1] for x in exts): continue 315 | basename = os.path.basename(name) 316 | for ext in exts: basename=basename.removesuffix(ext) 317 | relpath_basename = os.path.normpath(os.path.join(os.path.relpath(root, embd), basename)) 318 | k = os.path.normpath(os.path.join(os.path.relpath(root, embd), name)) 319 | store[k] = EmbedInfo(basename, name, os.path.join(root, name)) 320 | # add its counterpart 321 | if '/' in k: 322 | store[k.replace('/', '\\')] = EmbedInfo(basename, name, os.path.join(root, name)) 323 | elif '\\' in relpath_basename: 324 | store[k.replace('\\', '/')] = EmbedInfo(basename, name, os.path.join(root, name)) 325 | 326 | embs = OrderedDict(sorted(store.items(), key=lambda item: len(item[0]), reverse=True)) 327 | return embs 328 | 329 | class EmbbeddingRegex: 330 | STR_PATTERN = r"(embedding:)?(?:({}[\w\.\-\!\$\/\\]+(\.safetensors|\.pt|\.bin)|(?(1)[\w\.\-\!\$\/\\]+|(?!)))(\.safetensors|\.pt|\.bin)?)(?:(:)(\d+\.?\d*|\d*\.\d+))?" 
331 | def __init__(self, embedding_directory) -> None: 332 | self.embedding_directory = embedding_directory 333 | self.embeddings = get_valid_embeddings(self.embedding_directory) if self.embedding_directory is not None else {} 334 | joined_keys = '|'.join([re.escape(os.path.splitext(k)[0]) for k in self.embeddings.keys()]) 335 | emb_re = self.STR_PATTERN.format(joined_keys + '|' if joined_keys else '') 336 | self.pattern = re.compile(emb_re, flags=re.MULTILINE | re.UNICODE | re.IGNORECASE) 337 | 338 | def parse_and_register_embeddings(self, text: str): 339 | embr = EmbbeddingRegex(self.embedding_directory) 340 | embs = embr.embeddings 341 | matches = embr.pattern.finditer(text) 342 | exts = ['.pt', '.safetensors', '.bin'] 343 | 344 | for matchNum, match in enumerate(matches, start=1): 345 | found=False 346 | ext = (match.group(4) or (match.group(3) or '')) 347 | embedding_sname = (match.group(2) or '').removesuffix(ext) 348 | embedding_name = embedding_sname + ext 349 | if embedding_name: 350 | embed = None 351 | if ext: 352 | embed_info = embs.get(embedding_name + ext, None) 353 | else: 354 | for _ext in exts: 355 | embed_info = embs.get(embedding_name + _ext, None) 356 | if embed_info is not None: break 357 | if embed_info is not None: 358 | found=True 359 | try: 360 | embed = self.embeddings.load_from_file(embed_info.filepath, embed_info.filename) 361 | except Exception as e: 362 | logging.warning(f'\033[33mWarning\033[0m loading embedding `{embedding_name + ext}`: {e}') 363 | if embed is not None: 364 | found=True 365 | logger.debug(f'using embedding:{embedding_name}') 366 | if not found: 367 | logging.warning(f"\033[33mwarning\033[0m, embedding:{embedding_name} does not exist, ignoring") 368 | # ComfyUI trims non-existent embedding_names while A1111 doesn't. 369 | # here we get group 2,5,6. 
(group 2 minus its file extension) 370 | out = embr.pattern.sub(lambda m: (m.group(2) or '').removesuffix(m.group(4) or (m.group(3) or '')) + (m.group(5) or '') + (m.group(6) or ''), text) 371 | return out -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import re 2 | import logging 3 | from itertools import chain 4 | from nodes import MAX_RESOLUTION 5 | import comfy.model_patcher 6 | import comfy.sd 7 | import comfy.model_management 8 | import comfy.samplers 9 | from .smZNodes import HijackClip, HijackClipComfy, get_learned_conditioning 10 | from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXL 11 | 12 | class smZ_CLIPTextEncode: 13 | @classmethod 14 | def INPUT_TYPES(s): 15 | return {"required": { 16 | "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), 17 | "clip": ("CLIP", ), 18 | "parser": (["comfy", "comfy++", "A1111", "full", "compel", "fixed attention"],{"default": "comfy"}), 19 | "mean_normalization": ("BOOLEAN", {"default": True, "tooltip": "Toggles whether weights are normalized by taking the mean"}), 20 | "multi_conditioning": ("BOOLEAN", {"default": True}), 21 | "use_old_emphasis_implementation": ("BOOLEAN", {"default": False}), 22 | "with_SDXL": ("BOOLEAN", {"default": False}), 23 | "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), 24 | "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), 25 | "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), 26 | "crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), 27 | "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), 28 | "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), 29 | "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), 30 | "text_g": ("STRING", {"multiline": True, "placeholder": "CLIP_G", 
"dynamicPrompts": True}), 31 | "text_l": ("STRING", {"multiline": True, "placeholder": "CLIP_L", "dynamicPrompts": True}), 32 | }, 33 | "optional": { 34 | "smZ_steps": ("INT", {"default": 1, "min": 1, "max": 0xffffffffffffffff}), 35 | }, 36 | } 37 | RETURN_TYPES = ("CONDITIONING",) 38 | FUNCTION = "encode" 39 | CATEGORY = "conditioning" 40 | 41 | def encode(self, clip: comfy.sd.CLIP, text, parser, mean_normalization, 42 | multi_conditioning, use_old_emphasis_implementation, 43 | with_SDXL, ascore, width, height, crop_w, 44 | crop_h, target_width, target_height, text_g, text_l, smZ_steps=1): 45 | from .modules.shared import Options, opts, opts_default 46 | debug=opts.debug # get global opts' debug 47 | if (opts_new := clip.patcher.model_options.get(Options.KEY, None)) is not None: 48 | opts.update(opts_new) 49 | debug = opts_new.debug 50 | else: 51 | opts.update(opts_default) 52 | opts.debug = debug 53 | opts.prompt_mean_norm = mean_normalization 54 | opts.use_old_emphasis_implementation = use_old_emphasis_implementation 55 | opts.multi_conditioning = multi_conditioning 56 | class_name = clip.cond_stage_model.__class__.__name__ 57 | is_sdxl = "SDXL" in class_name 58 | on_sdxl = with_SDXL and is_sdxl 59 | parsers = { 60 | "full": "Full parser", 61 | "compel": "Compel parser", 62 | "A1111": "A1111 parser", 63 | "fixed attention": "Fixed attention", 64 | "comfy++": "Comfy++ parser", 65 | } 66 | opts.prompt_attention = parsers.get(parser, "Comfy parser") 67 | 68 | def _comfy_path(clip, text): 69 | nonlocal on_sdxl, class_name, ascore, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l 70 | if on_sdxl and class_name == "SDXLClipModel": 71 | return CLIPTextEncodeSDXL().encode(clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) 72 | elif on_sdxl and class_name == "SDXLRefinerClipModel": 73 | from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXLRefiner 74 | return CLIPTextEncodeSDXLRefiner().encode(clip, clip, 
ascore, width, height, text) 75 | else: 76 | from nodes import CLIPTextEncode 77 | return CLIPTextEncode().encode(clip, text) 78 | 79 | def comfy_path(clip): 80 | nonlocal text 81 | if on_sdxl and class_name == "SDXLRefinerClipModel": 82 | return _comfy_path(clip, text) 83 | else: 84 | if multi_conditioning: 85 | prompts = re.compile(r"\bAND\b").split(text) 86 | return (list(chain(*(_comfy_path(clip, prompt)[0] for prompt in prompts))), ) 87 | else: 88 | return _comfy_path(clip, text) 89 | 90 | if parser == "comfy": 91 | with HijackClipComfy(clip) as clip: 92 | return comfy_path(clip) 93 | elif parser == "comfy++": 94 | with HijackClip(clip, opts) as clip: 95 | with HijackClipComfy(clip) as clip: 96 | return comfy_path(clip) 97 | 98 | with HijackClip(clip, opts) as clip: 99 | model = lambda txt: clip.encode_from_tokens(clip.tokenize(txt), return_pooled=True, return_dict=True) 100 | steps = max(smZ_steps, 1) 101 | if on_sdxl and class_name == "SDXLClipModel": 102 | # skip prompt-editing 103 | schedules = CLIPTextEncodeSDXL().encode(clip, width, height, crop_w, crop_h, target_width, target_height, [text_g], [text_l])[0] 104 | else: 105 | schedules = get_learned_conditioning(model, [text], steps, multi_conditioning) 106 | if on_sdxl and class_name == "SDXLRefinerClipModel": 107 | for cx in schedules: 108 | cx[1] |= {"aesthetic_score": ascore, "width": width,"height": height} 109 | return (schedules, ) 110 | 111 | # Hack: string type that is always equal in not equal comparisons 112 | class AnyType(str): 113 | def __eq__(self, _): 114 | return True 115 | 116 | def __ne__(self, _): 117 | return False 118 | 119 | # Our any instance wants to be a wildcard string 120 | anytype = AnyType("*") 121 | 122 | class smZ_Settings: 123 | @classmethod 124 | def INPUT_TYPES(s): 125 | from .modules.shared import opts_default as opts 126 | from .modules.text_processing.emphasis import get_options_descriptions_nl 127 | i = 0 128 | def create_heading(): 129 | nonlocal i 130 | return 
"ㅤ"*(i:=i+1) 131 | create_heading_value = lambda x: ("STRING", {"multiline": False, "default": x, "placeholder": x}) 132 | optional = { 133 | # "show_headings": ("BOOLEAN", {"default": True}), 134 | # "show_descriptions": ("BOOLEAN", {"default":True}), 135 | 136 | create_heading(): create_heading_value("Stable Diffusion"), 137 | "info_comma_padding_backtrack": ("STRING", {"multiline": True, "placeholder": "Prompt word wrap length limit\nin tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"}), 138 | "Prompt word wrap length limit": ("INT", {"default": opts.comma_padding_backtrack, "min": 0, "max": 74, "step": 1, "tooltip": "🚧Prompt word wrap length limit\n\nin tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"}), 139 | "enable_emphasis": ("BOOLEAN", {"default": opts.enable_emphasis, "tooltip": "🚧Emphasis mode\n\nmakes it possible to make model to pay (more:1.1) or (less:0.9) attention to text when you use the syntax in prompt;\n\n" + get_options_descriptions_nl()}), 140 | 141 | "info_RNG": ("STRING", {"multiline": True, "placeholder": "Random number generator source.\nchanges seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"}), 142 | "RNG": (["cpu", "gpu", "nv"],{"default": opts.randn_source, "tooltip": "Random number generator source.\n\nchanges seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"}), 143 | 144 | create_heading(): create_heading_value("Compute Settings"), 145 | "info_disable_nan_check": ("STRING", {"multiline": True, "placeholder": "Disable NaN check in produced images/latent spaces. 
Only for CFGDenoiser."}), 146 | "disable_nan_check": ("BOOLEAN", {"default": opts.disable_nan_check, "tooltip": "Disable NaN check in produced images/latent spaces. Only for CFGDenoiser."}), 147 | 148 | create_heading(): create_heading_value("Sampler parameters"), 149 | "info_eta_ancestral": ("STRING", {"multiline": True, "placeholder": "Eta for k-diffusion samplers\nnoise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"}), 150 | "eta": ("FLOAT", {"default": opts.eta, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Eta for k-diffusion samplers\n\nnoise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"}), 151 | "info_s_churn": ("STRING", {"multiline": True, "placeholder": "Sigma churn\namount of stochasticity; only applies to Euler, Heun, Heun++2, and DPM2"}), 152 | "s_churn": ("FLOAT", {"default": opts.s_churn, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Sigma churn\n\namount of stochasticity; only applies to Euler, Heun, Heun++2, and DPM2"}), 153 | "info_s_tmin": ("STRING", {"multiline": True, "placeholder": "Sigma tmin\nenable stochasticity; start value of the sigma range; only applies to Euler, Heun, Heun++2, and DPM2'"}), 154 | "s_tmin": ("FLOAT", {"default": opts.s_tmin, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Sigma tmin\n\nenable stochasticity; start value of the sigma range; only applies to Euler, Heun, Heun++2, and DPM2'"}), 155 | "info_s_tmax": ("STRING", {"multiline": True, "placeholder": "Sigma tmax\n0 = inf; end value of the sigma range; only applies to Euler, Heun, Heun++2, and DPM2"}), 156 | "s_tmax": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 999.0, "step": 0.01, "tooltip": "Sigma tmax\n\n0 = inf; end value of the sigma range; only applies to Euler, Heun, Heun++2, and DPM2"}), 157 | "info_s_noise": ("STRING", {"multiline": True, "placeholder": "Sigma noise\namount of additional noise to counteract loss of detail during sampling"}), 158 | 
"s_noise": ("FLOAT", {"default": opts.s_noise, "min": 0.0, "max": 1.1, "step": 0.001, "tooltip": "Sigma noise\n\namount of additional noise to counteract loss of detail during sampling"}), 159 | "info_eta_noise_seed_delta": ("STRING", {"multiline": True, "placeholder": "Eta noise seed delta\ndoes not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"}), 160 | "ENSD": ("INT", {"default": opts.eta_noise_seed_delta, "min": 0, "max": 0xffffffffffffffff, "step": 1, "tooltip": "Eta noise seed delta\n\ndoes not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"}), 161 | "info_skip_early_cond": ("STRING", {"multiline": True, "placeholder": "Ignore negative prompt during early sampling\ndisables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling"}), 162 | "skip_early_cond": ("FLOAT", {"default": opts.skip_early_cond, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Ignore negative prompt during early sampling\n\ndisables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling"}), 163 | "info_sgm_noise_multiplier": ("STRING", {"multiline": True, "placeholder": "SGM noise multiplier\nmatch initial noise to official SDXL implementation - only useful for reproducing images\nsee https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818"}), 164 | "sgm_noise_multiplier": ("BOOLEAN", {"default": opts.sgm_noise_multiplier, "tooltip": "SGM noise multiplier\n\nmatch initial noise to official SDXL implementation - only useful for reproducing images\nsee https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818"}), 165 | "info_upcast_sampling": ("STRING", {"multiline": True, "placeholder": "upcast sampling.\nNo effect with --force-fp32. 
Usually produces similar results to --force-fp32 with better performance while using less memory."}), 166 | "upcast_sampling": ("BOOLEAN", {"default": opts.upcast_sampling, "tooltip": "🚧upcast sampling.\n\nNo effect with --force-fp32. Usually produces similar results to --force-fp32 with better performance while using less memory."}), 167 | 168 | create_heading(): create_heading_value("Optimizations"), 169 | "info_NGMS": ("STRING", {"multiline": True, "placeholder": "Negative Guidance minimum sigma\nskip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster. Only for CFGDenoiser.\nsee https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177\nhttps://github.com/lllyasviel/stable-diffusion-webui-forge/pull/1434"}), 170 | "NGMS": ("FLOAT", {"default": opts.s_min_uncond, "min": 0.0, "max": 15.0, "step": 0.01, "tooltip": "Negative Guidance minimum sigma\n\nskip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster.\nsee https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177\nhttps://github.com/lllyasviel/stable-diffusion-webui-forge/pull/1434"}), 171 | "info_NGMS_all_steps": ("STRING", {"multiline": True, "placeholder": "Negative Guidance minimum sigma all steps\nBy default, NGMS above skips every other step; this makes it skip all steps"}), 172 | "NGMS all steps": ("BOOLEAN", {"default": opts.s_min_uncond_all, "tooltip": "Negative Guidance minimum sigma all steps\n\nBy default, NGMS above skips every other step; this makes it skip all steps"}), 173 | "info_pad_cond_uncond": ("STRING", {"multiline": True, "placeholder": "Pad prompt/negative prompt to be same length\nimproves performance when prompt and negative prompt have different lengths; changes seeds. 
Only for CFGDenoiser."}), 174 | "pad_cond_uncond": ("BOOLEAN", {"default": opts.pad_cond_uncond, "tooltip": "🚧Pad prompt/negative prompt to be same length\n\nimproves performance when prompt and negative prompt have different lengths; changes seeds. Only for CFGDenoiser."}), 175 | "info_batch_cond_uncond": ("STRING", {"multiline": True, "placeholder": "Batch cond/uncond\ndo both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed. Only for CFGDenoiser."}), 176 | "batch_cond_uncond": ("BOOLEAN", {"default": opts.batch_cond_uncond, "tooltip": "🚧Batch cond/uncond\n\ndo both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed. Only for CFGDenoiser."}), 177 | 178 | create_heading(): create_heading_value("Compatibility"), 179 | "info_use_prev_scheduling": ("STRING", {"multiline": True, "placeholder": "Previous prompt editing timelines\nFor [red:green:N]; previous: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"}), 180 | "Use previous prompt editing timelines": ("BOOLEAN", {"default": opts.use_old_scheduling, "tooltip": "🚧Previous prompt editing timelines\n\nFor [red:green:N]; previous: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"}), 181 | 182 | create_heading(): create_heading_value("Experimental"), 183 | "info_use_CFGDenoiser": ("STRING", {"multiline": True, "placeholder": "CFGDenoiser\nAn experimental option to use stable-diffusion-webui's denoiser. 
It allows you to use the 'Optimizations' settings listed here."}), 184 | "Use CFGDenoiser": ("BOOLEAN", {"default": opts.use_CFGDenoiser, "tooltip": "🚧CFGDenoiser\n\nAn experimental option to use stable-diffusion-webui's denoiser. It allows you to use the 'Optimizations' settings listed here."}), 185 | "info_debug": ("STRING", {"multiline": True, "placeholder": "Debugging messages in the console."}), 186 | "debug": ("BOOLEAN", {"default": opts.debug, "label_on": "on", "label_off": "off", "tooltip": "Debugging messages in the console."}), 187 | } 188 | return { 189 | "required": { 190 | "*": (anytype, {"forceInput": True}), 191 | }, 192 | "optional": { 193 | "extra": ("STRING", {"multiline": True, "default": '{"show_headings":true,"show_descriptions":false,"mode":"*"}'}), 194 | **optional, 195 | }, 196 | } 197 | RETURN_TYPES = (anytype,) 198 | FUNCTION = "apply" 199 | CATEGORY = "advanced" 200 | OUTPUT_TOOLTIPS = ("The model used for denoising latents.",) 201 | 202 | def apply(self, *args, **kwargs): 203 | first = kwargs.pop('*', None) if '*' in kwargs else args[0] 204 | if not hasattr(first, 'clone') or first is None: return (first,) 205 | 206 | kwargs['s_min_uncond'] = kwargs.pop('NGMS', 0.0) 207 | kwargs['s_min_uncond_all'] = kwargs.pop('NGMS all steps', False) 208 | kwargs['comma_padding_backtrack'] = kwargs.pop('Prompt word wrap length limit') 209 | kwargs['use_old_scheduling']=kwargs.pop("Use previous prompt editing timelines") 210 | kwargs['use_CFGDenoiser'] = kwargs.pop("Use CFGDenoiser") 211 | kwargs['randn_source'] = kwargs.pop('RNG') 212 | kwargs['eta_noise_seed_delta'] = kwargs.pop('ENSD') 213 | kwargs['s_tmax'] = kwargs['s_tmax'] or float('inf') 214 | 215 | from .modules.shared import Options, logger, opts_default, opts as opts_global 216 | 217 | opts_global.update(opts_default) 218 | opts = opts_default.clone() 219 | kwargs_new = {k: v for k, v in kwargs.items() if not ('info' in k or 'heading' in k or 'ㅤ' in k)} 220 | opts.update(kwargs_new) 221 | 
opts_global.debug = opts.debug 222 | 223 | opts_key = Options.KEY 224 | if isinstance(first, comfy.model_patcher.ModelPatcher): 225 | first = first.clone() 226 | first.model_options[opts_key] = opts 227 | elif isinstance(first, comfy.sd.CLIP): 228 | first = first.clone() 229 | first.patcher.model_options[opts_key] = opts 230 | logger.setLevel(logging.DEBUG if opts_global.debug else logging.INFO) 231 | return (first,) 232 | 233 | # A dictionary that contains all nodes you want to export with their names 234 | # NOTE: names should be globally unique 235 | NODE_CLASS_MAPPINGS = { 236 | "smZ CLIPTextEncode": smZ_CLIPTextEncode, 237 | "smZ Settings": smZ_Settings, 238 | } 239 | # A dictionary that contains the friendly/humanly readable titles for the nodes 240 | NODE_DISPLAY_NAME_MAPPINGS = { 241 | "smZ CLIPTextEncode" : "CLIP Text Encode++", 242 | "smZ Settings" : "Settings (smZ)", 243 | } 244 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_smznodes" 3 | description = "Nodes such as CLIP Text Encode++ to achieve identical embeddings from stable-diffusion-webui for ComfyUI." 
4 | version = "1.2.19" 5 | license = { file = "LICENSE" } 6 | dependencies = ["lark >= 1.1.9", "compel"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/shiimizu/ComfyUI_smZNodes" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "shiimizu" 14 | DisplayName = "ComfyUI_smZNodes" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /smZNodes.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import re 3 | import torch 4 | import inspect 5 | import contextlib 6 | import logging 7 | import comfy 8 | import math 9 | import ctypes 10 | from decimal import Decimal 11 | from functools import partial 12 | from random import getrandbits 13 | import comfy.sdxl_clip 14 | import comfy.sd1_clip 15 | import comfy.sample 16 | import comfy.utils 17 | import comfy.samplers 18 | from comfy.sd1_clip import unescape_important, escape_important, token_weights 19 | from .modules.shared import SimpleNamespaceFast, Options, logger, join_args 20 | from .modules.text_processing import prompt_parser 21 | from .modules.text_processing.past_classic_engine import process_texts_past 22 | from .modules.text_processing.textual_inversion import EmbbeddingRegex 23 | from .modules.text_processing.classic_engine import ClassicTextProcessingEngine 24 | from .modules.text_processing.t5_engine import T5TextProcessingEngine 25 | 26 | class Store(SimpleNamespaceFast): ... 
27 | 28 | store = Store() 29 | 30 | def register_hooks(): 31 | from .modules.rng import prepare_noise 32 | patches = [ 33 | (comfy.samplers, 'get_area_and_mult', get_area_and_mult), 34 | (comfy.samplers.KSampler, 'sample', KSampler_sample), 35 | (comfy.samplers.KSAMPLER, 'sample', KSAMPLER_sample), 36 | (comfy.samplers, 'sample', sample), 37 | (comfy.samplers.Sampler, 'max_denoise', max_denoise), 38 | (comfy.samplers, 'sampling_function', sampling_function), 39 | (comfy.sample, 'prepare_noise', prepare_noise), 40 | ] 41 | for parent, fn_name, fn_patch in patches: 42 | if not hasattr(store, fn_patch.__name__): 43 | setattr(store, fn_patch.__name__, getattr(parent, fn_name)) 44 | setattr(parent, fn_name, fn_patch) 45 | 46 | def iter_items(d): 47 | for key, value in d.items(): 48 | yield key, value 49 | if isinstance(value, dict): 50 | yield from iter_items(value) 51 | 52 | def find_nearest(a,b): 53 | # Calculate the absolute differences. 54 | diff = (a - b).abs() 55 | 56 | # Find the indices of the nearest elements 57 | nearest_indices = diff.argmin() 58 | 59 | # Get the nearest elements from b 60 | return b[nearest_indices] 61 | 62 | def get_area_and_mult(*args, **kwargs): 63 | conds = args[0] 64 | if 'start_perc' in conds and 'end_perc' in conds and "init_steps" in conds: 65 | timestep_in = args[2] 66 | sigmas = store.sigmas 67 | if conds['init_steps'] == sigmas.shape[0] - 1: 68 | total = Decimal(sigmas.shape[0] - 1) 69 | else: 70 | sigmas_ = store.sigmas.unique(sorted=True).sort(descending=True)[0] 71 | if len(sigmas) == len(sigmas_): 72 | # Sampler Custom with sigmas: no change 73 | total = Decimal(sigmas.shape[0] - 1) 74 | else: 75 | # Sampler with restarts: dedup the sigmas and add one 76 | sigmas = sigmas_ 77 | total = Decimal(sigmas.shape[0] + 1) 78 | ts_in = find_nearest(timestep_in, sigmas) 79 | cur_i = ss[0].item() if (ss:=(sigmas == ts_in).nonzero()).shape[0] != 0 else 0 80 | cur = Decimal(cur_i) / total 81 | start = conds['start_perc'] 82 | end = 
conds['end_perc'] 83 | if not (cur >= start and cur < end): 84 | return None 85 | return store.get_area_and_mult(*args, **kwargs) 86 | 87 | def KSAMPLER_sample(*args, **kwargs): 88 | orig_fn = store.KSAMPLER_sample 89 | extra_args = None 90 | model_options = None 91 | try: 92 | extra_args = kwargs['extra_args'] if 'extra_args' in kwargs else args[3] 93 | model_options = extra_args['model_options'] 94 | except Exception: ... 95 | if model_options is not None and extra_args is not None: 96 | sigmas_ = kwargs['sigmas'] if 'sigmas' in kwargs else args[2] 97 | sigmas_all = model_options.pop('sigmas', None) 98 | sigmas = sigmas_all if sigmas_all is not None else sigmas_ 99 | store.sigmas = sigmas 100 | 101 | import comfy.k_diffusion.sampling 102 | if hasattr(comfy.k_diffusion.sampling, 'default_noise_sampler_orig'): 103 | if getattr(comfy.k_diffusion.sampling.default_noise_sampler, 'init', False): 104 | comfy.k_diffusion.sampling.default_noise_sampler.init = False 105 | else: 106 | comfy.k_diffusion.sampling.default_noise_sampler = comfy.k_diffusion.sampling.default_noise_sampler_orig 107 | 108 | if 'Hijack' in comfy.k_diffusion.sampling.torch.__class__.__name__: 109 | if getattr(comfy.k_diffusion.sampling.torch, 'init', False): 110 | comfy.k_diffusion.sampling.torch.init = False 111 | else: 112 | if hasattr(comfy.k_diffusion.sampling, 'torch_orig'): 113 | comfy.k_diffusion.sampling.torch = comfy.k_diffusion.sampling.torch_orig 114 | return orig_fn(*args, **kwargs) 115 | 116 | def KSampler_sample(*args, **kwargs): 117 | orig_fn = store.KSampler_sample 118 | self = args[0] 119 | model_patcher = getattr(self, 'model', None) 120 | model_options = getattr(model_patcher, 'model_options', None) 121 | if model_options is not None: 122 | sigmas = None 123 | try: sigmas = kwargs['sigmas'] if 'sigmas' in kwargs else args[10] 124 | except Exception: ... 
125 | if sigmas is None: 126 | sigmas = getattr(self, 'sigmas', None) 127 | if sigmas is not None: 128 | model_options = model_options.copy() 129 | model_options['sigmas'] = sigmas 130 | self.model.model_options = model_options 131 | return orig_fn(*args, **kwargs) 132 | 133 | def sample(*args, **kwargs): 134 | orig_fn = store.sample 135 | model_patcher = args[0] 136 | model_options = getattr(model_patcher, 'model_options', None) 137 | sampler = kwargs['sampler'] if 'sampler' in kwargs else args[6] 138 | if model_options is not None and Options.KEY in model_options: 139 | if hasattr(sampler, 'sampler_function'): 140 | opts = model_options[Options.KEY] 141 | if not hasattr(sampler, f'_sampler_function'): 142 | sampler._sampler_function = sampler.sampler_function 143 | sampler_function_sig_params = inspect.signature(sampler._sampler_function).parameters 144 | params = {x: getattr(opts, x) for x in ['eta', 's_churn', 's_tmin', 's_tmax', 's_noise'] if x in sampler_function_sig_params} 145 | sampler.sampler_function = lambda *a, **kw: sampler._sampler_function(*a, **{**kw, **params}) 146 | else: 147 | if hasattr(sampler, '_sampler_function'): 148 | sampler.sampler_function = sampler._sampler_function 149 | return orig_fn(*args, **kwargs) 150 | 151 | def max_denoise(*args, **kwargs): 152 | orig_fn = store.max_denoise 153 | model_wrap = kwargs['model_wrap'] if 'model_wrap' in kwargs else args[1] 154 | base_model = getattr(model_wrap, 'inner_model', None) 155 | model_options = getattr(model_wrap, 'model_options', getattr(base_model, 'model_options', None)) 156 | return orig_fn(*args, **kwargs) if getattr(model_options.get(Options.KEY, True), 'sgm_noise_multiplier', True) else False 157 | 158 | def sampling_function(*args, **kwargs): 159 | orig_fn = store.sampling_function 160 | model_options = kwargs['model_options'] if 'model_options' in kwargs else args[6] 161 | model_options=model_options.copy() 162 | kwargs['model_options'] = model_options 163 | if Options.KEY in 
model_options: 164 | opts = model_options[Options.KEY] 165 | if opts.s_min_uncond_all or opts.s_min_uncond > 0 or opts.skip_early_cond > 0: 166 | cond_scale = _cond_scale = kwargs['cond_scale'] if 'cond_scale' in kwargs else args[5] 167 | sigmas = store.sigmas 168 | sigma = kwargs['timestep'] if 'timestep' in kwargs else args[2] 169 | ts_in = find_nearest(sigma, sigmas) 170 | step = ss[0].item() if (ss:=(sigmas == ts_in).nonzero()).shape[0] != 0 else 0 171 | total_steps = sigmas.shape[0] - 1 172 | 173 | if opts.skip_early_cond > 0 and step / total_steps <= opts.skip_early_cond: 174 | cond_scale = 1.0 175 | elif (step % 2 or opts.s_min_uncond_all) and opts.s_min_uncond > 0 and sigma[0] < opts.s_min_uncond: 176 | cond_scale = 1.0 177 | 178 | if cond_scale != _cond_scale: 179 | if 'cond_scale' not in kwargs: 180 | args = args[:5] 181 | kwargs['cond_scale'] = cond_scale 182 | 183 | cond = kwargs['cond'] if 'cond' in kwargs else args[4] 184 | weights = [x.get('weight', None) for x in cond] 185 | has_some = any(item is not None for item in weights) and len(weights) > 1 186 | if has_some: 187 | out = CFGDenoiser(orig_fn).sampling_function(*args, **kwargs) 188 | else: 189 | out = orig_fn(*args, **kwargs) 190 | return out 191 | 192 | 193 | @contextlib.contextmanager 194 | def HijackClip(clip, opts): 195 | a1 = 'tokenizer', 'tokenize_with_weights' 196 | a2 = 'cond_stage_model', 'encode_token_weights' 197 | ls = [a1, a2] 198 | store = {} 199 | store_orig = {} 200 | try: 201 | for obj, attr in ls: 202 | for clip_name, v in iter_items(getattr(clip, obj).__dict__): 203 | if hasattr(v, attr): 204 | logger.debug(join_args(attr, obj, clip_name, type(v).__qualname__, getattr(v, attr).__qualname__)) 205 | if clip_name not in store_orig: 206 | store_orig[clip_name] = {} 207 | store_orig[clip_name][obj] = v 208 | for clip_name, inner_store in store_orig.items(): 209 | text_encoder = inner_store['cond_stage_model'] 210 | tokenizer = inner_store['tokenizer'] 211 | emphasis_name = 
'Original' if opts.prompt_mean_norm else "No norm" 212 | if 't5' in clip_name: 213 | text_processing_engine = T5TextProcessingEngine( 214 | text_encoder=text_encoder, 215 | tokenizer=tokenizer, 216 | emphasis_name=emphasis_name, 217 | ) 218 | else: 219 | text_processing_engine = ClassicTextProcessingEngine( 220 | text_encoder=text_encoder, 221 | tokenizer=tokenizer, 222 | emphasis_name=emphasis_name, 223 | ) 224 | text_processing_engine.opts = opts 225 | text_processing_engine.process_texts_past = partial(process_texts_past, text_processing_engine) 226 | store[clip_name] = text_processing_engine 227 | for obj, attr in ls: 228 | setattr(inner_store[obj], attr, getattr(store[clip_name], attr)) 229 | yield clip 230 | finally: 231 | for clip_name, inner_store in store_orig.items(): 232 | getattr(inner_store[a2[0]], a2[1]).__self__.unhook() 233 | for obj, attr in ls: 234 | try: delattr(inner_store[obj], attr) 235 | except Exception: ... 236 | del store 237 | del store_orig 238 | 239 | @contextlib.contextmanager 240 | def HijackClipComfy(clip): 241 | a1 = 'tokenizer', 'tokenize_with_weights' 242 | ls = [a1] 243 | store_orig = {} 244 | try: 245 | for obj, attr in ls: 246 | for clip_name, v in iter_items(getattr(clip, obj).__dict__): 247 | if hasattr(v, attr): 248 | logger.debug(join_args(attr, obj, clip_name, type(v).__qualname__, getattr(v, attr).__qualname__)) 249 | if clip_name not in store_orig: 250 | store_orig[clip_name] = {} 251 | store_orig[clip_name][obj] = v 252 | setattr(v, attr, partial(tokenize_with_weights_custom, v)) 253 | yield clip 254 | finally: 255 | for clip_name, inner_store in store_orig.items(): 256 | for obj, attr in ls: 257 | try: delattr(inner_store[obj], attr) 258 | except Exception: ... 
259 | del store_orig 260 | 261 | def transform_schedules(steps, schedules, weight=None, with_weight=False): 262 | end_steps = [schedule.end_at_step for schedule in schedules] 263 | start_end_pairs = list(zip([0] + end_steps[:-1], end_steps)) 264 | with_prompt_editing = len(schedules) > 1 265 | 266 | def process(schedule, start_step, end_step): 267 | nonlocal with_prompt_editing 268 | d = schedule.cond.copy() 269 | d.pop('cond', None) 270 | if with_prompt_editing: 271 | d |= {"start_perc": Decimal(start_step) / Decimal(steps), "end_perc": Decimal(end_step) / Decimal(steps), "init_steps": steps} 272 | if weight is not None and with_weight: 273 | d['weight'] = weight 274 | return d 275 | return [ 276 | [ 277 | schedule.cond.get("cond", None), 278 | process(schedule, start_step, end_step) 279 | ] 280 | for schedule, (start_step, end_step) in zip(schedules, start_end_pairs) 281 | ] 282 | 283 | def flatten(nested_list): 284 | return [item for sublist in nested_list for item in sublist] 285 | 286 | def convert_schedules_to_comfy(schedules, steps, multi=False): 287 | if multi: 288 | out = [[transform_schedules(steps, x.schedules, x.weight, len(batch)>1) for x in batch] for batch in schedules.batch] 289 | out = flatten(out) 290 | else: 291 | out = [transform_schedules(steps, sublist) for sublist in schedules] 292 | return flatten(out) 293 | 294 | def get_learned_conditioning(model, prompts, steps, multi=False, *args, **kwargs): 295 | if multi: 296 | schedules = prompt_parser.get_multicond_learned_conditioning(model, prompts, steps, *args, **kwargs) 297 | else: 298 | schedules = prompt_parser.get_learned_conditioning(model, prompts, steps, *args, **kwargs) 299 | schedules_c = convert_schedules_to_comfy(schedules, steps, multi) 300 | return schedules_c 301 | 302 | class CustomList(list): 303 | def __init__(self, *args): 304 | super().__init__(*args) 305 | def __setattr__(self, name: str, value: re.Any): 306 | super().__setattr__(name, value) 307 | return self 308 | 309 | def 
modify_locals_values(frame, fn): 310 | # https://stackoverflow.com/a/34671307 311 | try: ctypes.pythonapi.PyFrame_LocalsToFast(ctypes.py_object(frame), ctypes.c_int(1)) 312 | except Exception: ... 313 | fn(frame) 314 | try: ctypes.pythonapi.PyFrame_LocalsToFast(ctypes.py_object(frame), ctypes.c_int(1)) 315 | except Exception: ... 316 | 317 | def update_locals(frame,k,v,list_app=False): 318 | if not list_app: 319 | modify_locals_values(frame, lambda _frame: _frame.f_locals.__setitem__(k, v)) 320 | else: 321 | if not isinstance(frame.f_locals[k], CustomList): 322 | out_conds_store = CustomList([]) 323 | out_conds_store.outputs=[] 324 | modify_locals_values(frame, lambda _frame: _frame.f_locals.__setitem__(k, out_conds_store)) 325 | v.area = frame.f_locals['area'] 326 | v.mult = frame.f_locals['mult'] 327 | frame.f_locals[k].outputs.append(v) 328 | frame.f_locals[k].out_conds = frame.f_locals['out_conds'] 329 | frame.f_locals[k].out_counts = frame.f_locals['out_counts'] 330 | modify_locals_values(frame, lambda _frame: _frame.f_locals.__setitem__('batch_chunks', 0)) 331 | 332 | def model_function_wrapper_cd(model, args, id, model_options={}): 333 | input_x = args['input'] 334 | timestep_ = args['timestep'] 335 | c = args['c'] 336 | cond_or_uncond = args['cond_or_uncond'] 337 | batch_chunks = len(cond_or_uncond) 338 | if f'model_function_wrapper_{id}' in model_options: 339 | output = model_options[f'model_function_wrapper_{id}'](model, args) 340 | else: 341 | output = model(input_x, timestep_, **c) 342 | output.cond_or_uncond = cond_or_uncond 343 | output.batch_chunks = batch_chunks 344 | output.output_chunks = output.chunk(batch_chunks) 345 | output.chunk = lambda *aa, **kw: output 346 | get_parent_variable('out_conds', list, lambda frame: update_locals(frame, 'out_conds', output, list_app=True)) 347 | return output 348 | 349 | def get_parent_variable(vname, vtype, fn): 350 | frame = inspect.currentframe().f_back # Get the current frame's parent 351 | while frame: 352 
| if vname in frame.f_locals: 353 | val = frame.f_locals[vname] 354 | if isinstance(val, vtype): 355 | if fn is not None: 356 | fn(frame) 357 | return frame.f_locals[vname] 358 | frame = frame.f_back 359 | return None 360 | 361 | def cd_cfg_function(kwargs, id): 362 | model_options = kwargs['model_options'] 363 | if f"sampler_cfg_function_{id}" in model_options: 364 | return model_options[f'sampler_cfg_function_{id}'](kwargs) 365 | x = kwargs['input'] 366 | cond_pred = kwargs['cond_denoised'] 367 | uncond_pred = kwargs['uncond_denoised'] 368 | cond_scale = kwargs['cond_scale'] 369 | cfg_result = model_options['cfg_result'] 370 | cfg_result += (cond_pred - uncond_pred) * cond_scale 371 | return x - cfg_result 372 | 373 | class CFGDenoiser: 374 | def __init__(self, orig_fn) -> None: 375 | self.orig_fn = orig_fn 376 | 377 | def sampling_function(self, model, x, timestep, uncond, cond, cond_scale, model_options, *args0, **kwargs0): 378 | if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: 379 | uncond_ = None 380 | else: 381 | uncond_ = uncond 382 | 383 | conds = [cond, uncond_] 384 | 385 | if uncond_ is None: 386 | return self.orig_fn(model, x, timestep, uncond, cond, cond_scale, model_options, *args0, **kwargs0) 387 | 388 | id = getrandbits(7) 389 | if 'model_function_wrapper' in model_options: 390 | model_options[f'model_function_wrapper_{id}'] = model_options.pop('model_function_wrapper') 391 | model_options['model_function_wrapper'] = partial(model_function_wrapper_cd, id=id, model_options=model_options) 392 | out = comfy.samplers.calc_cond_batch(model, conds, x, timestep, model_options) 393 | model_options.pop('model_function_wrapper', None) 394 | if f'model_function_wrapper_{id}' in model_options: 395 | model_options['model_function_wrapper'] = model_options.pop(f'model_function_wrapper_{id}') 396 | 397 | outputs = out.outputs 398 | 399 | out_conds = out.out_conds 400 | out_counts= out.out_counts 401 | if 
len(out_conds) < len(out_counts): 402 | for _ in out_counts: 403 | out_conds.append(torch.zeros_like(outputs[0].output_chunks[0])) 404 | 405 | oconds=[] 406 | for _output in outputs: 407 | cond_or_uncond=_output.cond_or_uncond 408 | batch_chunks=_output.batch_chunks 409 | output=_output.output_chunks 410 | area=_output.area 411 | mult=_output.mult 412 | for o in range(batch_chunks): 413 | cond_index = cond_or_uncond[o] 414 | a = area[o] 415 | if a is None: 416 | if cond_index == 0: 417 | oconds.append(output[o] * mult[o]) 418 | else: 419 | out_conds[cond_index] += output[o] * mult[o] 420 | out_counts[cond_index] += mult[o] 421 | else: 422 | out_c = out_conds[cond_index] if cond_index != 0 else torch.zeros_like(out_conds[cond_index]) 423 | out_cts = out_counts[cond_index] 424 | dims = len(a) // 2 425 | for i in range(dims): 426 | out_c = out_c.narrow(i + 2, a[i + dims], a[i]) 427 | out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) 428 | out_c += output[o] * mult[o] 429 | out_cts += mult[o] 430 | if cond_index == 0: 431 | oconds.append(out_c) 432 | 433 | for i in range(len(out_conds)): 434 | if i != 0: 435 | out_conds[i] /= out_counts[i] 436 | 437 | del out 438 | out = out_conds 439 | 440 | for fn in model_options.get("sampler_pre_cfg_function", []): 441 | out[0] = torch.cat(oconds).to(oconds[0]) 442 | args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep, 443 | "input": x, "sigma": timestep, "model": model, "model_options": model_options} 444 | out = fn(args) 445 | 446 | # ComfyUI: last prompt -> first 447 | # conds were reversed in calc_cond_batch, so do the same for weights 448 | weights = [x.get('weight', None) for x in cond] 449 | weights.reverse() 450 | out_uncond = out[1] 451 | cfg_result = out_uncond.clone() 452 | cond_scale = cond_scale / max(len(oconds), 1) 453 | 454 | if "sampler_cfg_function" in model_options: 455 | model_options[f'sampler_cfg_function_{id}'] = model_options.pop('sampler_cfg_function') 456 | 
model_options['sampler_cfg_function'] = partial(cd_cfg_function, id=id) 457 | model_options['cfg_result'] = cfg_result 458 | 459 | # ComfyUI: computes the average -> do cfg 460 | # A1111: (cond - uncond) / total_len_of_conds -> in-place addition for each cond -> results in cfg 461 | for ix, ocond in enumerate(oconds): 462 | weight = (weights[ix:ix+1] or [1.0])[0] or 1.0 463 | # cfg_result += (ocond - out_uncond) * (weight * cond_scale) # all this code just to do this 464 | if f"sampler_cfg_function_{id}" in model_options: 465 | # case when there's another cfg_fn. subtract out_uncond and in-place add the result. feed result back in. 466 | cfg_result += comfy.samplers.cfg_function(model, ocond, out_uncond, weight * cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_) - out_uncond 467 | else: # calls cd_cfg_function and does an in-place addition 468 | if model_options.get("sampler_post_cfg_function", []): 469 | # feed the result back in. 470 | cfg_result = comfy.samplers.cfg_function(model, ocond, out_uncond, weight * cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_) 471 | else: 472 | # default case. discards the output. 473 | comfy.samplers.cfg_function(model, ocond, out_uncond, weight * cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_) 474 | model_options['cfg_result'] = cfg_result 475 | return cfg_result 476 | 477 | def tokenize_with_weights_custom(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs): 478 | ''' 479 | Takes a prompt and converts it to a list of (token, weight, word id) elements. 480 | Tokens can both be integer tokens and pre computed CLIP tensors. 481 | Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens. 
482 | Returned list has the dimensions NxM where M is the input size of CLIP 483 | ''' 484 | min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length) 485 | min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) 486 | 487 | text = escape_important(text) 488 | parsed_weights = token_weights(text, 1.0) 489 | embr = EmbbeddingRegex(self.embedding_directory) 490 | 491 | # tokenize words 492 | tokens = [] 493 | for weighted_segment, weight in parsed_weights: 494 | to_tokenize = unescape_important(weighted_segment) 495 | split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize) 496 | to_tokenize = [split[0]] 497 | for i in range(1, len(split)): 498 | to_tokenize.append("{}{}".format(self.embedding_identifier, split[i])) 499 | 500 | to_tokenize = [x for x in to_tokenize if x != ""] 501 | for word in to_tokenize: 502 | matches = embr.pattern.finditer(word) 503 | last_end = 0 504 | leftovers=[] 505 | for _, match in enumerate(matches, start=1): 506 | start=match.start() 507 | end_match=match.end() 508 | if (fragment:=word[last_end:start]): 509 | leftovers.append(fragment) 510 | ext = (match.group(4) or (match.group(3) or '')) 511 | embedding_sname = (match.group(2) or '').removesuffix(ext) 512 | embedding_name = embedding_sname + ext 513 | if embedding_name: 514 | embed, leftover = self._try_get_embedding(embedding_name) 515 | if embed is None: 516 | logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring") 517 | else: 518 | logger.debug(f'using embedding:{embedding_name}') 519 | if len(embed.shape) == 1: 520 | tokens.append([(embed, weight)]) 521 | else: 522 | tokens.append([(embed[x], weight) for x in range(embed.shape[0])]) 523 | last_end = end_match 524 | if (fragment:=word[last_end:]): 525 | leftovers.append(fragment) 526 | word_new = ''.join(leftovers) 527 | end = 999999999999 528 | if self.tokenizer_adds_end_token: 529 | end = -1 530 | #parse 
word 531 | tokens.append([(t, weight) for t in self.tokenizer(word_new)["input_ids"][self.tokens_start:end]]) 532 | 533 | #reshape token array to CLIP input size 534 | batched_tokens = [] 535 | batch = [] 536 | if self.start_token is not None: 537 | batch.append((self.start_token, 1.0, 0)) 538 | batched_tokens.append(batch) 539 | for i, t_group in enumerate(tokens): 540 | #determine if we're going to try and keep the tokens in a single batch 541 | is_large = len(t_group) >= self.max_word_length 542 | if self.end_token is not None: 543 | has_end_token = 1 544 | else: 545 | has_end_token = 0 546 | 547 | while len(t_group) > 0: 548 | if len(t_group) + len(batch) > self.max_length - has_end_token: 549 | remaining_length = self.max_length - len(batch) - has_end_token 550 | #break word in two and add end token 551 | if is_large: 552 | batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) 553 | if self.end_token is not None: 554 | batch.append((self.end_token, 1.0, 0)) 555 | t_group = t_group[remaining_length:] 556 | #add end token and pad 557 | else: 558 | if self.end_token is not None: 559 | batch.append((self.end_token, 1.0, 0)) 560 | if self.pad_to_max_length: 561 | batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length)) 562 | #start new batch 563 | batch = [] 564 | if self.start_token is not None: 565 | batch.append((self.start_token, 1.0, 0)) 566 | batched_tokens.append(batch) 567 | else: 568 | batch.extend([(t,w,i+1) for t,w in t_group]) 569 | t_group = [] 570 | 571 | #fill last batch 572 | if self.end_token is not None: 573 | batch.append((self.end_token, 1.0, 0)) 574 | if min_padding is not None: 575 | batch.extend([(self.pad_token, 1.0, 0)] * min_padding) 576 | if self.pad_to_max_length and len(batch) < self.max_length: 577 | batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch))) 578 | if min_length is not None and len(batch) < min_length: 579 | batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch))) 580 | 581 | if 
not return_word_ids: 582 | batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] 583 | 584 | return batched_tokens 585 | 586 | # ======================================================================== 587 | 588 | from server import PromptServer 589 | 590 | def is_prompt_editing(schedules): 591 | if schedules == None: return False 592 | if not isinstance(schedules, dict): 593 | schedules = {'g': schedules} 594 | ret = False 595 | for k,v in schedules.items(): 596 | if type(v) is dict and 'schedules' in v: 597 | v=v['schedules'] 598 | if type(v) == list: 599 | for vb in v: 600 | if len(vb) != 1: ret = True 601 | else: 602 | if v: 603 | for vb in v.batch: 604 | for cs in vb: 605 | if len(cs.schedules) != 1: ret = True 606 | return ret 607 | 608 | def prompt_handler(json_data): 609 | data=json_data['prompt'] 610 | steps_validator = lambda x: isinstance(x, (int, float, str)) 611 | text_validator = lambda x: isinstance(x, str) 612 | def find_nearest_ksampler(clip_id): 613 | """Find the nearest KSampler node that references the given CLIPTextEncode id.""" 614 | nonlocal data, steps_validator 615 | for ksampler_id, node in data.items(): 616 | if "class_type" in node and ("Sampler" in node["class_type"] or "sampler" in node["class_type"]): 617 | # Check if this KSampler node directly or indirectly references the given CLIPTextEncode node 618 | if check_link_to_clip(ksampler_id, clip_id): 619 | return get_val(data, ksampler_id, steps_validator, 'steps') 620 | return None 621 | 622 | def get_val(graph, node_id, validator, val): 623 | node = graph.get(str(node_id), {}) 624 | if val == 'steps': 625 | steps_input_value = node.get("inputs", {}).get("steps", None) 626 | if steps_input_value is None: 627 | steps_input_value = node.get("inputs", {}).get("sigmas", None) 628 | else: 629 | steps_input_value = node.get("inputs", {}).get(val, None) 630 | 631 | while(True): 632 | # Base case: it's a direct value 633 | if not isinstance(steps_input_value, list) and 
validator(steps_input_value): 634 | if val == 'steps': 635 | s = 1 636 | try: s = min(max(1, int(steps_input_value)), 10000) 637 | except Exception as e: 638 | logging.warning(f"\033[33mWarning:\033[0m [smZNodes] Skipping prompt editing. Try recreating the node. {e}") 639 | return s 640 | else: 641 | return steps_input_value 642 | # Loop case: it's a reference to another node 643 | elif isinstance(steps_input_value, list): 644 | ref_node_id, ref_input_index = steps_input_value 645 | ref_node = graph.get(str(ref_node_id), {}) 646 | steps_input_value = ref_node.get("inputs", {}).get(val, None) 647 | if steps_input_value is None: 648 | keys = list(ref_node.get("inputs", {}).keys()) 649 | ref_input_key = keys[ref_input_index % len(keys)] 650 | steps_input_value = ref_node.get("inputs", {}).get(ref_input_key) 651 | else: 652 | return None 653 | 654 | def check_link_to_clip(node_id, clip_id, visited=None): 655 | """Check if a given node links directly or indirectly to a CLIPTextEncode node.""" 656 | nonlocal data 657 | if visited is None: 658 | visited = set() 659 | 660 | node = data[node_id] 661 | 662 | if node_id in visited: 663 | return False 664 | visited.add(node_id) 665 | 666 | for input_value in node["inputs"].values(): 667 | if isinstance(input_value, list) and input_value[0] == clip_id: 668 | return True 669 | if isinstance(input_value, list) and check_link_to_clip(input_value[0], clip_id, visited): 670 | return True 671 | 672 | return False 673 | 674 | 675 | # Update each CLIPTextEncode node's steps with the steps from its nearest referencing KSampler node 676 | for clip_id, node in data.items(): 677 | if "class_type" in node and node["class_type"] == "smZ CLIPTextEncode": 678 | check_str = prompt_editing = False 679 | if check_str: 680 | if (fast_search:=True): 681 | with_SDXL = get_val(data, clip_id, lambda x: isinstance(x, (bool, int, float)), 'with_SDXL') 682 | if with_SDXL: 683 | ls = is_prompt_editing_str(get_val(data, clip_id, text_validator, 'text_l')) 
684 | gs = is_prompt_editing_str(get_val(data, clip_id, text_validator, 'text_g')) 685 | prompt_editing = ls or gs 686 | else: 687 | text = get_val(data, clip_id, text_validator, 'text') 688 | prompt_editing = is_prompt_editing_str(text) 689 | else: 690 | text = get_val(data, clip_id, text_validator, 'text') 691 | prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules([text], steps, None, False) 692 | prompt_editing = sum([len(ps) for ps in prompt_schedules]) != 1 693 | if check_str and not prompt_editing: continue 694 | steps = find_nearest_ksampler(clip_id) 695 | if steps is not None: 696 | node["inputs"]["smZ_steps"] = steps 697 | # logger.debug(f'id: {clip_id} | steps: {steps}') 698 | return json_data 699 | 700 | def is_prompt_editing_str(t: str): 701 | """ 702 | Determine if a string includes prompt editing. 703 | This won't cover every case, but it does the job for most. 704 | """ 705 | if t is None: return True 706 | if (openb:=t.find('[')) != -1: 707 | if (colon:=t.find(':', openb)) != -1 and t.find(']', colon) != -1: 708 | return True 709 | elif (pipe:=t.find('|', openb)) != -1 and t.find(']', pipe) != -1: 710 | return True 711 | return False 712 | 713 | if hasattr(PromptServer.instance, 'add_on_prompt_handler'): 714 | PromptServer.instance.add_on_prompt_handler(prompt_handler) 715 | 716 | # ======================================================================== 717 | 718 | # DPM++ 2M alt 719 | 720 | from tqdm.auto import trange 721 | @torch.no_grad() 722 | def sample_dpmpp_2m_alt(model, x, sigmas, extra_args=None, callback=None, disable=None): 723 | """DPM-Solver++(2M).""" 724 | extra_args = {} if extra_args is None else extra_args 725 | s_in = x.new_ones([x.shape[0]]) 726 | sigma_fn = lambda t: t.neg().exp() 727 | t_fn = lambda sigma: sigma.log().neg() 728 | old_denoised = None 729 | 730 | for i in trange(len(sigmas) - 1, disable=disable): 731 | denoised = model(x, sigmas[i] * s_in, **extra_args) 732 | if callback is not None: 733 | 
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 734 | t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) 735 | h = t_next - t 736 | if old_denoised is None or sigmas[i + 1] == 0: 737 | x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised 738 | else: 739 | h_last = t - t_fn(sigmas[i - 1]) 740 | r = h_last / h 741 | denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised 742 | x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d 743 | sigma_progress = i / len(sigmas) 744 | adjustment_factor = 1 + (0.15 * (sigma_progress * sigma_progress)) 745 | old_denoised = denoised * adjustment_factor 746 | return x 747 | 748 | 749 | def add_sample_dpmpp_2m_alt(): 750 | from comfy.samplers import KSampler, k_diffusion_sampling 751 | if "dpmpp_2m_alt" not in KSampler.SAMPLERS: 752 | try: 753 | idx = KSampler.SAMPLERS.index("dpmpp_2m") 754 | KSampler.SAMPLERS.insert(idx+1, "dpmpp_2m_alt") 755 | setattr(k_diffusion_sampling, 'sample_dpmpp_2m_alt', sample_dpmpp_2m_alt) 756 | except Exception: ... 
757 | 758 | def add_custom_samplers(): 759 | samplers = [ 760 | add_sample_dpmpp_2m_alt, 761 | ] 762 | for add_sampler in samplers: 763 | add_sampler() 764 | -------------------------------------------------------------------------------- /web/exif.js: -------------------------------------------------------------------------------- 1 | // Original file: /npm/exif-js@2.3.0/exif.js 2 | try{(function(){var d=!1,l=function(e){return e instanceof l?e:this instanceof l?void(this.EXIFwrapped=e):new l(e)};"undefined"!=typeof exports?("undefined"!=typeof module&&module.exports&&(exports=module.exports=l),exports.EXIF=l):this.EXIF=l;var u=l.Tags={36864:"ExifVersion",40960:"FlashpixVersion",40961:"ColorSpace",40962:"PixelXDimension",40963:"PixelYDimension",37121:"ComponentsConfiguration",37122:"CompressedBitsPerPixel",37500:"MakerNote",37510:"UserComment",40964:"RelatedSoundFile",36867:"DateTimeOriginal",36868:"DateTimeDigitized",37520:"SubsecTime",37521:"SubsecTimeOriginal",37522:"SubsecTimeDigitized",33434:"ExposureTime",33437:"FNumber",34850:"ExposureProgram",34852:"SpectralSensitivity",34855:"ISOSpeedRatings",34856:"OECF",37377:"ShutterSpeedValue",37378:"ApertureValue",37379:"BrightnessValue",37380:"ExposureBias",37381:"MaxApertureValue",37382:"SubjectDistance",37383:"MeteringMode",37384:"LightSource",37385:"Flash",37396:"SubjectArea",37386:"FocalLength",41483:"FlashEnergy",41484:"SpatialFrequencyResponse",41486:"FocalPlaneXResolution",41487:"FocalPlaneYResolution",41488:"FocalPlaneResolutionUnit",41492:"SubjectLocation",41493:"ExposureIndex",41495:"SensingMethod",41728:"FileSource",41729:"SceneType",41730:"CFAPattern",41985:"CustomRendered",41986:"ExposureMode",41987:"WhiteBalance",41988:"DigitalZoomRation",41989:"FocalLengthIn35mmFilm",41990:"SceneCaptureType",41991:"GainControl",41992:"Contrast",41993:"Saturation",41994:"Sharpness",41995:"DeviceSettingDescription",41996:"SubjectDistanceRange",40965:"InteroperabilityIFDPointer",42016:"ImageUniqueID"},c=l.TiffTags={256:"I
mageWidth",257:"ImageHeight",34665:"ExifIFDPointer",34853:"GPSInfoIFDPointer",40965:"InteroperabilityIFDPointer",258:"BitsPerSample",259:"Compression",262:"PhotometricInterpretation",274:"Orientation",277:"SamplesPerPixel",284:"PlanarConfiguration",530:"YCbCrSubSampling",531:"YCbCrPositioning",282:"XResolution",283:"YResolution",296:"ResolutionUnit",273:"StripOffsets",278:"RowsPerStrip",279:"StripByteCounts",513:"JPEGInterchangeFormat",514:"JPEGInterchangeFormatLength",301:"TransferFunction",318:"WhitePoint",319:"PrimaryChromaticities",529:"YCbCrCoefficients",532:"ReferenceBlackWhite",306:"DateTime",270:"ImageDescription",271:"Make",272:"Model",305:"Software",315:"Artist",33432:"Copyright"},f=l.GPSTags={0:"GPSVersionID",1:"GPSLatitudeRef",2:"GPSLatitude",3:"GPSLongitudeRef",4:"GPSLongitude",5:"GPSAltitudeRef",6:"GPSAltitude",7:"GPSTimeStamp",8:"GPSSatellites",9:"GPSStatus",10:"GPSMeasureMode",11:"GPSDOP",12:"GPSSpeedRef",13:"GPSSpeed",14:"GPSTrackRef",15:"GPSTrack",16:"GPSImgDirectionRef",17:"GPSImgDirection",18:"GPSMapDatum",19:"GPSDestLatitudeRef",20:"GPSDestLatitude",21:"GPSDestLongitudeRef",22:"GPSDestLongitude",23:"GPSDestBearingRef",24:"GPSDestBearing",25:"GPSDestDistanceRef",26:"GPSDestDistance",27:"GPSProcessingMethod",28:"GPSAreaInformation",29:"GPSDateStamp",30:"GPSDifferential"},g=l.IFD1Tags={256:"ImageWidth",257:"ImageHeight",258:"BitsPerSample",259:"Compression",262:"PhotometricInterpretation",273:"StripOffsets",274:"Orientation",277:"SamplesPerPixel",278:"RowsPerStrip",279:"StripByteCounts",282:"XResolution",283:"YResolution",284:"PlanarConfiguration",296:"ResolutionUnit",513:"JpegIFOffset",514:"JpegIFByteCount",529:"YCbCrCoefficients",530:"YCbCrSubSampling",531:"YCbCrPositioning",532:"ReferenceBlackWhite"},m=l.StringValues={ExposureProgram:{0:"Not defined",1:"Manual",2:"Normal program",3:"Aperture priority",4:"Shutter priority",5:"Creative program",6:"Action program",7:"Portrait mode",8:"Landscape 
mode"},MeteringMode:{0:"Unknown",1:"Average",2:"CenterWeightedAverage",3:"Spot",4:"MultiSpot",5:"Pattern",6:"Partial",255:"Other"},LightSource:{0:"Unknown",1:"Daylight",2:"Fluorescent",3:"Tungsten (incandescent light)",4:"Flash",9:"Fine weather",10:"Cloudy weather",11:"Shade",12:"Daylight fluorescent (D 5700 - 7100K)",13:"Day white fluorescent (N 4600 - 5400K)",14:"Cool white fluorescent (W 3900 - 4500K)",15:"White fluorescent (WW 3200 - 3700K)",17:"Standard light A",18:"Standard light B",19:"Standard light C",20:"D55",21:"D65",22:"D75",23:"D50",24:"ISO studio tungsten",255:"Other"},Flash:{0:"Flash did not fire",1:"Flash fired",5:"Strobe return light not detected",7:"Strobe return light detected",9:"Flash fired, compulsory flash mode",13:"Flash fired, compulsory flash mode, return light not detected",15:"Flash fired, compulsory flash mode, return light detected",16:"Flash did not fire, compulsory flash mode",24:"Flash did not fire, auto mode",25:"Flash fired, auto mode",29:"Flash fired, auto mode, return light not detected",31:"Flash fired, auto mode, return light detected",32:"No flash function",65:"Flash fired, red-eye reduction mode",69:"Flash fired, red-eye reduction mode, return light not detected",71:"Flash fired, red-eye reduction mode, return light detected",73:"Flash fired, compulsory flash mode, red-eye reduction mode",77:"Flash fired, compulsory flash mode, red-eye reduction mode, return light not detected",79:"Flash fired, compulsory flash mode, red-eye reduction mode, return light detected",89:"Flash fired, auto mode, red-eye reduction mode",93:"Flash fired, auto mode, return light not detected, red-eye reduction mode",95:"Flash fired, auto mode, return light detected, red-eye reduction mode"},SensingMethod:{1:"Not defined",2:"One-chip color area sensor",3:"Two-chip color area sensor",4:"Three-chip color area sensor",5:"Color sequential area sensor",7:"Trilinear sensor",8:"Color sequential linear 
sensor"},SceneCaptureType:{0:"Standard",1:"Landscape",2:"Portrait",3:"Night scene"},SceneType:{1:"Directly photographed"},CustomRendered:{0:"Normal process",1:"Custom process"},WhiteBalance:{0:"Auto white balance",1:"Manual white balance"},GainControl:{0:"None",1:"Low gain up",2:"High gain up",3:"Low gain down",4:"High gain down"},Contrast:{0:"Normal",1:"Soft",2:"Hard"},Saturation:{0:"Normal",1:"Low saturation",2:"High saturation"},Sharpness:{0:"Normal",1:"Soft",2:"Hard"},SubjectDistanceRange:{0:"Unknown",1:"Macro",2:"Close view",3:"Distant view"},FileSource:{3:"DSC"},Components:{0:"",1:"Y",2:"Cb",3:"Cr",4:"R",5:"G",6:"B"}};function i(e){return!!e.exifdata}function r(i,o){function t(e){var t=p(e);i.exifdata=t||{};var n=function(e){var t=new DataView(e);d&&console.log("Got file of length "+e.byteLength);if(255!=t.getUint8(0)||216!=t.getUint8(1))return d&&console.log("Not a valid JPEG"),!1;var n=2,r=e.byteLength;for(;n")+8,u=(s=s.substring(s.indexOf("e.byteLength)return{};var u=P(e,t,t+l,g,r);if(u.Compression)switch(u.Compression){case 6:if(u.JpegIFOffset&&u.JpegIFByteCount){var c=t+u.JpegIFOffset,d=u.JpegIFByteCount;u.blob=new Blob([new Uint8Array(e.buffer,c,d)],{type:"image/jpeg"})}break;case 1:console.log("Thumbnail image format is TIFF, which is not implemented.");break;default:console.log("Unknown thumbnail image format '%s'",u.Compression)}else 2==u.PhotometricInterpretation&&console.log("Thumbnail image format is RGB, which is not implemented.");return u}(e,s,l,n),r}function b(e){var t={};if(1==e.nodeType){if(0 node.widgets.find((w) => w.name === name); 15 | export const findWidgetsByName = (node, name) => node.widgets.filter((w) => w.name.endsWith(name)); 16 | 17 | // export const doesInputWithNameExist = (node, name) => node.inputs ? 
node.inputs.some((input) => input.name === name) : false; 18 | 19 | // round in increments of n, with an offset 20 | export function round(number, increment = 10, offset = 0) { 21 | return Math.ceil((number - offset) / increment ) * increment + offset; 22 | } 23 | 24 | export function toggleWidget(node, widget, force) { 25 | if (!widget) return; 26 | // if (!widget || doesInputWithNameExist(node, widget.name)) return; 27 | widget.options[HIDDEN_TAG] ??= (widget.options.origType = widget.type, widget.options.origComputeSize = widget.computeSize, HIDDEN_TAG); 28 | 29 | const hide = force ?? (widget.type !== HIDDEN_TAG); 30 | 31 | widget.type = hide ? widget.options[HIDDEN_TAG] : widget.options.origType; 32 | 33 | widget.computeSize = hide ? () => [0, -3.3] : widget.options.origComputeSize; 34 | 35 | widget.linkedWidgets?.forEach(w => toggleWidget(node, w, force)); 36 | 37 | for (const el of ["inputEl", "input"]) 38 | widget[el]?.classList?.toggle(HIDDEN_TAG, force); 39 | 40 | const height = hide ? node.size[1] : Math.max(node.computeSize()[1], node.size[1]); 41 | node.setSize([node.size[0], height]); 42 | 43 | if (hide) 44 | widget.computedHeight = 0; 45 | else 46 | delete widget.computedHeight; 47 | } 48 | 49 | // Toggles a menu widget for both its canvas and HTML counterpart. Also accounts for group nodes. 50 | // Passing an array with a companion_widget_name means 51 | // the toggle will happen for every name that looks like it. 52 | // Useful for duplicates created in group nodes. 53 | export function toggleMenuOption(node, widget_arr, show) { 54 | const [widget_name, companion_widget_name] = Array.isArray(widget_arr) ? 
widget_arr : [widget_arr] 55 | let arr = [widget_name]; 56 | // companion_widget_name to use the new name assigned in a group node to get the correct widget 57 | if (companion_widget_name) { 58 | for (const gnc of getGroupNodeConfig(node)) { 59 | const omap = Object.values(gnc.oldToNewWidgetMap).find(x => Object.values(x).find(z => z === companion_widget_name)); 60 | const n = omap[widget_name]; 61 | if(n) arr.push(n); 62 | } 63 | } 64 | const widgets = companion_widget_name ? arr.map(it => findWidgetByName(node, it)) : findWidgetsByName(node, arr[0]) 65 | const hide = show !== undefined ? !show : undefined; 66 | widgets.forEach(widget => { 67 | toggleWidget(node, widget, hide) 68 | node.setDirtyCanvas(true); 69 | }); 70 | } 71 | 72 | export function getGroupNodeConfig(node) { 73 | let ls = [] 74 | let nodeData = node.constructor?.nodeData 75 | if (nodeData) { 76 | for(const sym of Object.getOwnPropertySymbols(nodeData) ) { 77 | const o = nodeData[sym]; 78 | if (o) ls.push(o) 79 | } 80 | } 81 | return ls 82 | } 83 | 84 | export function widgetLogic(node, widget) { 85 | const wname = widget.name 86 | const uoei = widgets[widgets.length - 2] 87 | if (wname.endsWith("parser")) { 88 | const in_comfy = widget?.value?.includes?.("comfy") 89 | // toggleMenuOption(node, ['multi_conditioning', wname], !in_comfy) 90 | toggleMenuOption(node, ['mean_normalization', wname], widget.value !== "comfy") 91 | toggleMenuOption(node, [uoei, wname], in_comfy ? 
false : findWidgetsByName(node, uoei)?.some(x => x?.value)) 92 | } else if (wname.endsWith("with_SDXL")) { 93 | toggleMenuOption(node, ['text', wname], !widget.value) 94 | toggleMenuOption(node, ['multi_conditioning', wname], !widget.value) 95 | 96 | // Resize node when widget is set to false 97 | if (!widget.value) { 98 | // Prevents resizing on init/webpage reload 99 | if(widget.init === false) { 100 | // Resize when set to false 101 | node.setSize([node.size[0], Math.max(100, round(node.size[1]/1.5))]) 102 | } 103 | } else { 104 | // When enabled, set init to false 105 | widget.init = false 106 | } 107 | 108 | // Toggle sdxl widgets if sdxl widget value is true/false 109 | for (const w of widgets_sdxl) { 110 | toggleMenuOption(node, [w, wname], widget.value) 111 | } 112 | 113 | // Keep showing the widget if it's enabled 114 | if (widget.value && widget.type === HIDDEN_TAG) { 115 | toggleMenuOption(node, [widget.name, wname], true) 116 | } 117 | } 118 | } 119 | 120 | // Specfic to cliptextencode 121 | function applyWidgetLogic(node) { 122 | if (!node.widgets?.length) return; 123 | 124 | const uoei = widgets[widgets.length - 2] 125 | const in_comfy = findWidgetsByName(node, "parser")?.some(it => it?.value?.includes?.("comfy")); 126 | const uoei_w = findWidgetByName(node, uoei); 127 | toggleMenuOption(node, [uoei, uoei], in_comfy ? false : uoei_w.value) 128 | 129 | let gncl = getGroupNodeConfig(node) 130 | for (const w of node.widgets) { 131 | for (const gsw of [...getSetWidgets]) { 132 | if (!w.name.endsWith(gsw)) continue; 133 | // Possibly uneeded 134 | /* 135 | let shouldBreak = false 136 | for (const gnc of gncl) { 137 | const nwmap = gnc.newToOldWidgetMap[w.name] 138 | console.log('=== gnc.newToOldWidgetMap',gnc.newToOldWidgetMap,'w.name',w.name,'nwmap',nwmap) 139 | // nwmap.inputName: resolved, actual widget name. 
140 | if (nwmap && !(ids1.has(nwmap.node.type) && nwmap.inputName === gsw)) 141 | shouldBreak = true 142 | } 143 | if (shouldBreak) break; 144 | */ 145 | widgetLogic(node, w); 146 | 147 | let val = w.value; 148 | Object.defineProperty(w, 'value', { 149 | get() { 150 | return val; 151 | }, 152 | set(newVal) { 153 | if (newVal !== val) { 154 | val = newVal 155 | widgetLogic(node, w); 156 | } 157 | } 158 | }); 159 | 160 | // Hide SDXL widget on init 161 | // Doing it in nodeCreated fixes its toggling for some reason 162 | if (w.name.endsWith('with_SDXL')) { 163 | toggleMenuOption(node, ['with_SDXL', w.name]) 164 | w.init = true 165 | 166 | // Hide steps 167 | toggleMenuOption(node, ['smZ_steps', w.name], false) 168 | } 169 | } 170 | } 171 | 172 | // Reduce initial node size cause of SDXL widgets 173 | // node.setSize([node.size[0], Math.max(node.size[1]/1.5, 220)]) 174 | node.setSize([node.size[0], 220]) 175 | } 176 | 177 | function create_custom_option(content, _callback) { 178 | return { 179 | content: content, 180 | callback: () => _callback(), 181 | } 182 | }; 183 | 184 | function widgetLogicSettings(node) { 185 | const supported_types = {MODEL: 'MODEL', CLIP: 'CLIP'} 186 | const clip_headings = ['Stable Diffusion', 'Compatibility'] 187 | const clip_entries = 188 | ["info_comma_padding_backtrack", 189 | "Prompt word wrap length limit", 190 | "enable_emphasis", 191 | "info_use_prev_scheduling", 192 | "Use previous prompt editing timelines"] 193 | 194 | if(!node.widgets?.length) return; 195 | 196 | const index = node.index || 0 197 | 198 | const extra = node.widgets.find(w => w.name.endsWith('extra')) 199 | let extra_data = extra._value 200 | toggleMenuOption(node, extra.name, false) 201 | const condition = (sup) => node.outputs?.[index] && (node.outputs[index].name === sup || node.outputs[index].type === sup) || 202 | node.inputs?.[index] && node.inputs[index].type === sup; 203 | for (const w of node.widgets) { 204 | if (w.name.endsWith('extra')) continue; 205 | // 
w._name: to make sure it's from our node. though, it should have a better name for the variable 206 | if(!w._name) continue; 207 | 208 | // heading `values` won't get duplicated names due to group nodes 209 | // So we won't need to do `clip_headings.some()` 210 | // though, it should be read only... 211 | if (condition(supported_types.MODEL)) { 212 | const flag=(clip_entries.some(str => str.includes(w.name)) || (typeof w.value === 'string' && w.heading && w.value.includes('Compatibility'))) 213 | if (w.info && !clip_entries.some(str => str.includes(w.name))) 214 | toggleMenuOption(node, [w.name, w.name], extra_data.show_descriptions) 215 | else if (w.heading && typeof w.value === 'string' && !w.value.includes('Compatibility')) 216 | toggleMenuOption(node, [w.name, w.name], extra_data.show_headings) 217 | else 218 | toggleMenuOption(node, [w.name, w.name], !flag) 219 | // toggleMenuOption(node, w.name, flag ? false : true) //doesn't work? 220 | } else if (condition(supported_types.CLIP)) { 221 | // if w.name in list -> enable, else disable 222 | const flag = clip_entries.some(str => str.includes(w.name)) 223 | if (w.info && flag) 224 | toggleMenuOption(node, [w.name, w.name], extra_data.show_descriptions) 225 | else if (w.heading && typeof w.value === 'string' && clip_headings.includes(w.value)) 226 | toggleMenuOption(node, [w.name, w.name], extra_data.show_headings) 227 | else 228 | toggleMenuOption(node, [w.name, w.name], flag) 229 | } else { 230 | toggleMenuOption(node, w.name, false) 231 | } 232 | } 233 | 234 | let skip = [HEADING_IDENTIFIER, "info", "extra"]; 235 | let i = node.inputs.length; 236 | while (i--) { 237 | const ni = node.inputs[i]; 238 | const wname = ni?.widget?.name; 239 | if (i !== 0 && skip.some(it => wname?.includes(it))) { 240 | node.inputs.splice(i, 1); 241 | } 242 | } 243 | 244 | node.setSize([node.size[0], node.computeSize()[1]]) 245 | node.onResize?.(node.size) 246 | node.setDirtyCanvas(true); 247 | _app.graph.setDirtyCanvas(true, true); 
248 | } 249 | 250 | _app.registerExtension({ 251 | name: "Comfy.smZ.dynamicWidgets", 252 | 253 | init() { 254 | let style = document.createElement('style'); 255 | let placeholder_opacity = 0.75; 256 | const createCssText = (it) => `.smZ-custom-textarea::-${it} { color: inherit; opacity: ${placeholder_opacity}}`; 257 | // WebKit, IE, Firefox 258 | ['webkit-input-placeholder', 'ms-input-placeholder', 'moz-placeholder'] 259 | .map(createCssText).forEach(it => style.appendChild(document.createTextNode(it))); 260 | style.appendChild(document.createTextNode(`.${HIDDEN_TAG} { hidden: "hidden"; display: none; }`)); 261 | document.head.appendChild(style); 262 | }, 263 | 264 | /** 265 | * Called when a node is created. Used to add menu options to nodes. 266 | * @param node The node that was created. 267 | * @param app The app. 268 | */ 269 | nodeCreated(node, app) { 270 | const nodeType = node.type || node.constructor?.type 271 | const anyType = "*"; 272 | let inGroupNode = false 273 | let inGroupNode2 = false 274 | let nodeData = node.constructor?.nodeData 275 | if (nodeData) { 276 | for(let sym of Object.getOwnPropertySymbols(nodeData) ) { 277 | const nds = nodeData[sym]; 278 | const nodes = nds?.nodeData?.nodes 279 | if (nodes) { 280 | for (const _node of nodes) { 281 | const _nodeType = _node.type 282 | if (ids1.has(_nodeType)) 283 | inGroupNode = true 284 | if (inGroupNode) 285 | ids1.add(nodeType) // GroupNode's type 286 | if (ids2.has(_nodeType)) 287 | inGroupNode2 = true 288 | if (inGroupNode2) 289 | ids2.add(nodeType) // GroupNode's type 290 | } 291 | } 292 | } 293 | } 294 | 295 | // ClipTextEncode++ node 296 | if (ids1.has(nodeType) || inGroupNode) { 297 | applyWidgetLogic(node) 298 | } 299 | // Settings node 300 | if (ids2.has(nodeType) || inGroupNode2) { 301 | if(!node.properties) 302 | node.properties = {} 303 | node.properties.showOutputText = true 304 | 305 | 306 | // allows bypass (in conjunction with below's `allows bypass`) 307 | // by setting 
node.inputs[0].type to a concrete type, instead of anyType. ComfyUI will complain otherwise. 308 | node.onBeforeConnectInput = function(inputIndex) { 309 | if (inputIndex !== node.index) return inputIndex 310 | 311 | // so we can connect to reroutes 312 | const tp = 'Reroute' 313 | node.type = tp 314 | if (node.constructor) node.constructor.type=tp 315 | this.type = tp 316 | if (this.constructor) this.constructor.type=tp 317 | Object.assign(node.inputs[inputIndex], {...node.inputs[inputIndex], name: anyType, type: anyType}); 318 | Object.assign(this.inputs[inputIndex], {...this.inputs[inputIndex], name: anyType, type: anyType}); 319 | node.beforeConnectInput = true 320 | this.beforeConnectInput = true 321 | return inputIndex; 322 | } 323 | 324 | // Call once on node creation 325 | node.setupWidgetLogic = function() { 326 | let nt = nodeType // JSON.parse(JSON.stringify(nodeType)) 327 | node.type = nodeType 328 | if (node.constructor) node.constructor.type=nodeType 329 | node.applyOrientation = function() { 330 | node.type = nodeType 331 | if (node.constructor) node.constructor.type=nodeType 332 | } 333 | node.index = 0 334 | let i = 0 335 | const innerNode = node.getInnerNodes?.().find(n => {const r = ids2.has(n.type); if (r) node.index = i; ++i; return r } ) 336 | const innerNodeWidgets = innerNode?.widgets 337 | i = 0 338 | node.widgets.forEach(function(w) { 339 | if (innerNodeWidgets) { 340 | if(innerNodeWidgets.some(iw => w.name.endsWith(iw.name))) 341 | w._name = w.name 342 | } else 343 | w._name = w.name 344 | 345 | // Styling. 
346 | if (w.name.includes(HEADING_IDENTIFIER)) { 347 | w.heading = true 348 | // w.disabled = true 349 | } else if (w.name.includes('info')) { 350 | w.info = true 351 | for (const el of ["input", "inputEl"]) { 352 | if (w[el]) { 353 | w[el].disabled = true; 354 | w[el].readOnly = true; 355 | w[el].style.alignContent = 'center'; 356 | w[el].style.textAlign = 'center'; 357 | if (!w[el].classList.contains('smZ-custom-textarea')) 358 | w[el].classList.add('smZ-custom-textarea'); 359 | } 360 | } 361 | } 362 | }) 363 | const extra_widget = node.widgets.find(w => w.name.endsWith('extra')) 364 | if (extra_widget) { 365 | let extra_data = null 366 | try { 367 | extra_data = JSON.parse(extra_widget.value); 368 | } catch (error) { 369 | // when node definitions change due to an update or some other error 370 | extra_data = {show_headings: true, show_descriptions: false, mode: anyType} 371 | } 372 | extra_widget._value = extra_data 373 | Object.defineProperty(extra_widget, '_value', { 374 | get() { 375 | return extra_data; 376 | }, 377 | set(newVal) { 378 | extra_data = newVal 379 | extra_widget.value = JSON.stringify(extra_data) 380 | } 381 | }); 382 | } 383 | 384 | widgetLogicSettings(node); 385 | // Hijack getting our node type so we can work with Reroutes 386 | Object.defineProperty(node.constructor, 'type', { 387 | get() { 388 | let s = new Error().stack 389 | // const rr = ['rerouteNode.', 'baseCreateRenderer'] 390 | const rr = ['rerouteNode.'] 391 | // const rr = ['rerouteNode.js', 'reroutePrimitive.js'] 392 | // const rr = ['rerouteNode.js', 'groupNode.js', 'reroutePrimitive.js'] 393 | if (rr.some(rx => s.includes(rx))) { 394 | return 'Reroute' 395 | } 396 | return nt; 397 | }, 398 | set(newVal) { 399 | if (newVal !== nt) { 400 | nt = newVal 401 | } 402 | } 403 | }); 404 | 405 | if (node.outputs[node.index]) { 406 | let val = node.outputs[node.index].type; 407 | // Hijacks getting/setting type 408 | Object.defineProperty(node.outputs[node.index], 'type', { 409 | get() 
{ 410 | return val; 411 | }, 412 | set(newVal) { 413 | if (newVal !== val) { 414 | val = newVal 415 | // console.log('====== group node test. node', node, 'group', node.getInnerNodes?.()) 416 | if (node.inputs && node.inputs[node.index]) node.inputs[node.index].type = newVal 417 | if (node.outputs && node.outputs[node.index]) node.outputs[node.index].name = newVal || anyType; 418 | node.properties.showOutputText = true // Reroute node accesses this 419 | node.type = nodeType 420 | if (node.constructor) node.constructor.type=nodeType 421 | // this.type = nodeType 422 | // if (this.constructor) this.constructor.type=nodeType 423 | // console.log('==== setupWidgetLogic', `val: '${val}' newval: '${newVal}' '${node.outputs[node.index].name}'`) 424 | widgetLogicSettings(node); 425 | } 426 | } 427 | }); 428 | } 429 | } 430 | node.setupWidgetLogic() 431 | 432 | const onConfigure = node.onConfigure; 433 | node.onConfigure = function(o) { 434 | // Call again after the node is created since there might be link 435 | // For example: if we reload the graph 436 | node.setupWidgetLogic() 437 | const r = onConfigure ? 
onConfigure.apply(this, arguments) : undefined; 438 | return r; 439 | } 440 | 441 | // ================= Adapted from rerouteNode.js ================= 442 | node.onAfterGraphConfigured = function () { 443 | requestAnimationFrame(() => { 444 | node.onConnectionsChange(LiteGraph.INPUT, null, true, null); 445 | }); 446 | }; 447 | node.onConnectionsChange = function (type, index, connected, link_info) { 448 | // node.index = index || 0 449 | // index = node.index || 0 450 | if (index !== node.index) return 451 | node.setupWidgetLogic() 452 | // if (index === node.index) node.setupWidgetLogic() 453 | 454 | const type_map = { 455 | [LiteGraph.OUTPUT] : 'OUTPUT', 456 | [LiteGraph.INPUT] : 'INPUT', 457 | } 458 | 459 | // console.log("======== onConnectionsChange type", type, "connected", connected, 'app.graph.links[l]',app.graph.links) 460 | // console.log('=== app.graph.links',app.graph.links, 'node.inputs[index]',node.inputs[index],'node.outputs[index]',node.outputs[index]) 461 | // console.log('=== onConnectionsChange type', type_map[type], 'index',index,'connected',connected,'node.inputs', node.inputs,'node.outputs', node.outputs,'link_info',link_info, 'node', node) 462 | 463 | node.type = nodeType 464 | if (node.constructor) node.constructor.type=nodeType 465 | this.type = nodeType 466 | if (this.constructor) this.constructor.type=nodeType 467 | 468 | 469 | // Prevent multiple connections to different types when we have no input 470 | if (connected && type === LiteGraph.OUTPUT) { 471 | // Ignore wildcard nodes as these will be updated to real types 472 | const types = this.outputs?.[index]?.links ? 
new Set(this.outputs[index].links.map((l) => app.graph.links[l]?.type)?.filter((t) => t !== anyType)) : new Set() 473 | if (types?.size > 1) { 474 | const linksToDisconnect = []; 475 | for (let i = 0; i < this.outputs[index].links.length - 1; i++) { 476 | const linkId = this.outputs[index].links[i]; 477 | const link = app.graph.links[linkId]; 478 | if (link) 479 | linksToDisconnect.push(link); 480 | } 481 | for (const link of linksToDisconnect) { 482 | if (link) { 483 | const node = app.graph.getNodeById(link.target_id); 484 | node.disconnectInput(link.target_slot); 485 | } 486 | } 487 | } 488 | } 489 | 490 | // Find root input 491 | let currentNode = this; 492 | let updateNodes = []; 493 | let inputType = null; 494 | let inputNode = null; 495 | while (currentNode) { 496 | updateNodes.unshift(currentNode); 497 | const linkId = currentNode?.inputs?.[index]?.link; 498 | if (linkId) { 499 | const link = app.graph.links[linkId]; 500 | if (!link) return; 501 | const node = app.graph.getNodeById(link.origin_id); 502 | const type = node.constructor.type || node.type; 503 | if (type === "Reroute" || ids2.has(type)) { 504 | if (node === this) { 505 | // We've found a circle 506 | currentNode.disconnectInput(link.target_slot); 507 | currentNode = null; 508 | } else { 509 | // Move the previous node 510 | currentNode = node; 511 | } 512 | } else { 513 | // We've found the end 514 | inputNode = currentNode; 515 | inputType = node.outputs[link.origin_slot]?.type ?? null; 516 | break; 517 | } 518 | } else { 519 | // This path has no input node 520 | currentNode = null; 521 | break; 522 | } 523 | } 524 | 525 | // Find all outputs 526 | const nodes = [this]; 527 | // const nodes = [link_info?.origin_id ? app.graph.getNodeById(link_info.origin_id).outputs ? 
528 | // app.graph.getNodeById(link_info.origin_id): this : this,this]; 529 | let outputType = null; 530 | while (nodes.length) { 531 | currentNode = nodes.pop(); 532 | // if (currentNode.outputs) 533 | const outputs = (currentNode && currentNode.outputs?.[index]?.links ? currentNode.outputs?.[index]?.links : []) || []; 534 | // console.log('=== .outputs',outputs,'currentNode',currentNode) 535 | if (outputs.length) { 536 | for (const linkId of outputs) { 537 | const link = app.graph.links[linkId]; 538 | 539 | // When disconnecting sometimes the link is still registered 540 | if (!link) continue; 541 | 542 | const node = app.graph.getNodeById(link.target_id); 543 | const type = node.constructor.type || node.type; 544 | 545 | if (type === "Reroute" || ids2.has(type)) { 546 | // Follow reroute nodes 547 | nodes.push(node); 548 | updateNodes.push(node); 549 | } else { 550 | // We've found an output 551 | const nodeOutType = 552 | node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type 553 | ? 
node.inputs[link.target_slot].type 554 | : null; 555 | if (inputType && inputType !== anyType && nodeOutType !== inputType) { 556 | // The output doesnt match our input so disconnect it 557 | node.disconnectInput(link.target_slot); 558 | } else { 559 | outputType = nodeOutType; 560 | } 561 | } 562 | } 563 | } else { 564 | // No more outputs for this path 565 | } 566 | } 567 | 568 | const displayType = inputType || outputType || anyType; 569 | const color = LGraphCanvas.link_type_colors[displayType]; 570 | 571 | let widgetConfig; 572 | let targetWidget; 573 | let widgetType; 574 | // Update the types of each node 575 | for (const node of updateNodes) { 576 | // If we dont have an input type we are always wildcard but we'll show the output type 577 | // This lets you change the output link to a different type and all nodes will update 578 | if (!(node.outputs && node.outputs[index])) continue 579 | node.outputs[index].type = inputType || anyType; 580 | node.__outputType = displayType; 581 | node.outputs[index].name = node.properties.showOutputText ? displayType : ""; 582 | // node.size = node.computeSize(); 583 | // node.applyOrientation(); 584 | 585 | for (const l of node.outputs[index].links || []) { 586 | const link = app.graph.links[l]; 587 | if (link) { 588 | link.color = color; 589 | 590 | if (app.configuringGraph) continue; 591 | const targetNode = app.graph.getNodeById(link.target_id); 592 | const targetInput = targetNode.inputs?.[link.target_slot]; 593 | if (targetInput?.widget) { 594 | const config = getWidgetConfig(targetInput); 595 | if (!widgetConfig) { 596 | widgetConfig = config[1] ?? 
{}; 597 | widgetType = config[0]; 598 | } 599 | if (!targetWidget) { 600 | targetWidget = targetNode.widgets?.find((w) => w.name === targetInput.widget.name); 601 | } 602 | 603 | const merged = mergeIfValid(targetInput, [config[0], widgetConfig]); 604 | if (merged.customConfig) { 605 | widgetConfig = merged.customConfig; 606 | } 607 | } 608 | } 609 | } 610 | } 611 | 612 | for (const node of updateNodes) { 613 | if (!(node.inputs && node.inputs[index])) continue 614 | if (widgetConfig && outputType) { 615 | node.inputs[index].widget = { name: "value" }; 616 | setWidgetConfig(node.inputs[index], [widgetType ?? displayType, widgetConfig], targetWidget); 617 | } else { 618 | setWidgetConfig(node.inputs[index], null); 619 | } 620 | } 621 | 622 | if (inputNode && inputNode.inputs[index]) { 623 | const link = app.graph.links[inputNode.inputs[index].link]; 624 | if (link) { 625 | link.color = color; 626 | if (node.outputs?.[index]) { 627 | node.outputs[index].name = inputNode.__outputType || inputNode.outputs[index].type; 628 | 629 | // allows bypass 630 | node.inputs[index] = Object.assign(node.inputs[0], {...node.inputs[index], name: anyType, type: node.outputs[index].name}); 631 | } 632 | } 633 | } 634 | } 635 | // ================= Adapted from rerouteNode.js ================= 636 | } 637 | 638 | // Add extra MenuOptions for 639 | // ClipTextEncode++ and Settings node 640 | if (ids1.has(nodeType) || inGroupNode || ids2.has(nodeType) || inGroupNode2) { 641 | // Save the original options 642 | const getExtraMenuOptions = node.getExtraMenuOptions; 643 | 644 | node.getExtraMenuOptions = function (_, options) { 645 | // Call the original function for the default menu options 646 | const r = getExtraMenuOptions ? 
getExtraMenuOptions.apply(this, arguments) : undefined; 647 | let customOptions = [] 648 | node.setDirtyCanvas(true, true); 649 | // if (!r) return r; 650 | 651 | if (ids2.has(nodeType) || inGroupNode2) { 652 | // Clean up MenuOption 653 | const hiddenWidgets = node.widgets.filter(w => w.type === HIDDEN_TAG || w.heading || w.info) 654 | let filtered = false 655 | let wo = options.filter(o => o === null || (o && !hiddenWidgets.some(w => {const i = o.content.includes(`Convert ${w.name} to input`); if(i) filtered = i; return i} ))) 656 | options.splice(0, options.length, ...wo); 657 | 658 | if(hiddenWidgets.length !== node.widgets.length) { 659 | customOptions.push(null) // seperator 660 | const d=function(_node) { 661 | const extra = _node.widgets.find(w => w.name.endsWith('extra')) 662 | let extra_data = extra._value 663 | // extra_data.show_descriptions = !extra_data.show_descriptions 664 | extra._value = {...extra._value, show_descriptions: !extra_data.show_descriptions} 665 | widgetLogicSettings(_node) 666 | } 667 | const h=function(_node) { 668 | const extra = _node.widgets.find(w => w.name.endsWith('extra')) 669 | let extra_data = extra._value 670 | extra._value = {...extra._value, show_headings: !extra_data.show_headings} 671 | // extra_data.show_headings = !extra_data.show_headings 672 | widgetLogicSettings(_node) 673 | } 674 | customOptions.push(create_custom_option("Hide/show all headings", h.bind(this, node))) 675 | customOptions.push(create_custom_option("Hide/show all descriptions", d.bind(this, node))) 676 | } 677 | } 678 | 679 | if (ids1.has(nodeType) || inGroupNode) { 680 | // Dynamic MenuOption depending on the widgets 681 | const content_hide_show = "Hide/show "; 682 | // const whWidgets = node.widgets.filter(w => w.name === 'width' || w.name === 'height') 683 | const hiddenWidgets = node.widgets.filter(w => w.type === HIDDEN_TAG) 684 | // doesn't take GroupNode into account 685 | // const with_SDXL = node.widgets.find(w => 
w.name.endsWith('with_SDXL')) 686 | const parser = node.widgets.find(w => w.name.endsWith('parser')) 687 | const in_comfy = parser?.value?.includes?.("comfy") 688 | let ws = widgets.map(widget_name => create_custom_option(content_hide_show + widget_name, toggleMenuOption.bind(this, node, widget_name))) 689 | ws = ws.filter((w) => (in_comfy && parser.value !== 'comfy' && w.content.includes('mean_normalization')) || (in_comfy && w.content.includes('with_SDXL')) || !in_comfy || w.content.includes('multi_conditioning') ) 690 | // customOptions.push(null) // seperator 691 | customOptions.push(...ws) 692 | 693 | let wo = options.filter(o => o === null || (o && !hiddenWidgets.some(w => o.content.includes(`Convert ${w.name} to input`)))) 694 | const width = node.widgets.find(w => w.name.endsWith('width')) 695 | const height = node.widgets.find(w => w.name.endsWith('height')) 696 | if (width && height) { 697 | const width_type = width.type.toLowerCase() 698 | const height_type = height.type.toLowerCase() 699 | if (!(width_type.includes('number') || width_type.includes('int') || width_type.includes('float') || 700 | height_type.includes('number') || height_type.includes('int') || height_type.includes('float'))) 701 | wo = wo.filter(o => o === null || (o && !o.content.includes('Swap width/height'))) 702 | } 703 | options.splice(0, options.length, ...wo); 704 | } 705 | // options.unshift(...customOptions); // top 706 | options.splice(options.length - 1, 0, ...customOptions) 707 | // return r; 708 | } 709 | } 710 | } 711 | }); 712 | --------------------------------------------------------------------------------