├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── modules
│   ├── rng.py
│   ├── rng_philox.py
│   ├── shared.py
│   └── text_processing
│       ├── classic_engine.py
│       ├── emphasis.py
│       ├── parsing.py
│       ├── past_classic_engine.py
│       ├── prompt_parser.py
│       ├── t5_engine.py
│       └── textual_inversion.py
├── nodes.py
├── pyproject.toml
├── smZNodes.py
└── web
    ├── exif.js
    ├── metadata.js
    └── smZdynamicWidgets.js
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | - "pyproject.toml"
9 |
10 | jobs:
11 | publish-node:
12 | name: Publish Custom Node to registry
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: Check out code
16 | uses: actions/checkout@v4
17 | - name: Publish Custom Node
18 | uses: Comfy-Org/publish-node-action@main
19 | with:
20 | ## Add your own personal access token to your Github Repository secrets and reference it here.
21 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | backup*
163 | **/.DS_Store
164 | **/.venv
165 | **/.vscode
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # smZNodes
3 | A selection of custom nodes for [ComfyUI](https://github.com/comfyanonymous/ComfyUI).
4 |
5 | 1. [CLIP Text Encode++](#clip-text-encode)
6 | 2. [Settings](#settings)
7 |
8 | Contents
9 |
10 | * [Tips to get reproducible results on both UIs](#tips-to-get-reproducible-results-on-both-uis)
11 | * [FAQs](#faqs)
12 | * [Installation](#installation)
13 |
14 |
15 | ## CLIP Text Encode++
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | CLIP Text Encode++ can generate identical embeddings from [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for [ComfyUI](https://github.com/comfyanonymous/ComfyUI).
25 |
26 | This means you can reproduce the same images generated from `stable-diffusion-webui` on `ComfyUI`.
27 |
28 | Simple prompts generate _identical_ images. More complex prompts with complex attention/emphasis/weighting may generate images with slight differences. In that case, you can try using the `Settings` node to match outputs.
29 |
30 | ### Features
31 |
32 | - [Prompt editing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-editing)
33 | - [Alternating words](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alternating-words)
34 | - [`AND`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#composable-diffusion) keyword (similar to the ConditioningCombine node)
35 | - [`BREAK`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#break-keyword) keyword (same as the ConditioningConcat node)
36 | - Weight normalization
37 | - Optional `embedding:` identifier
38 |
39 |
40 | ### Comparisons
41 | These images can be dragged into ComfyUI to load their workflows. Each image was generated with the [Silicon29](https://huggingface.co/Xynon/SD-Silicon) checkpoint (SD 1.5), 18 steps, and the Heun sampler.
42 |
43 | |stable-diffusion-webui|A1111 parser|Comfy parser|
44 | |:---:|:---:|:---:|
45 | |  |  |  |
46 | |  |  |  |
47 |
48 | Image slider links:
49 | - https://imgsli.com/MTkxMjE0
50 | - https://imgsli.com/MTkxMjEy
51 |
52 | ### Options
53 |
54 | |Name|Description|
55 | | --- | --- |
56 | | `parser` | The parser used to convert prompts into tokens, which are then transformed (encoded) into embeddings. Taken from [SD.Next](https://github.com/vladmandic/automatic/discussions/99#discussioncomment-5931014). |
57 | | `mean_normalization` | Whether to take the mean of your prompt weights. It's `true` by default on `stable-diffusion-webui`.<br>This is implemented according to how it is in `stable-diffusion-webui`. |
58 | | `multi_conditioning` | For each prompt, the list is obtained by splitting the prompt using the `AND` separator.<br>See: [Compositional Visual Generation with Composable Diffusion Models](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/)<br>- a way to use multiple prompts at once<br>- supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2` (for non-comfy parsers). The weights default to 1.<br>- each prompt gets a cfg value of `cfg * weight / N`, where `N` is the number of positive prompts. In `stable-diffusion-webui`, each prompt gets a cfg value of `cfg * weight`. To match their behaviour, you can add a weight of `:N` to every prompt _or_ simply set a cfg value of `cfg * N` (see the example below the table). |
59 | | `use_old_emphasis_implementation` | Use the old emphasis implementation. Can be useful to reproduce old seeds. |
60 |
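A quick sketch of the `multi_conditioning` cfg rule above (hypothetical numbers, shown only to illustrate the formula):

```python
# "a cat :1.2 AND a dog AND a penguin :2.2" at cfg = 9
cfg, weights = 9.0, [1.2, 1.0, 2.2]
N = len(weights)                                   # number of positive prompts
effective_here  = [cfg * w / N for w in weights]   # this node:              [3.6, 3.0, 6.6]
effective_webui = [cfg * w for w in weights]       # stable-diffusion-webui: [10.8, 9.0, 19.8]
# To match webui behaviour here, add a weight of :3 (i.e. :N) to every prompt or simply set cfg to 9 * 3 = 27.
```
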
61 | > [!TIP]
62 | > You can right click the node to show/hide some of the widgets. E.g. the `with_SDXL` option.
63 |
64 |
65 |
66 | | Parser | Description |
67 | | ----------------- | -------------------------------------------------------------------------------- |
68 | | `comfy` | The default way `ComfyUI` handles everything |
69 | | `comfy++`          | Uses `ComfyUI`'s parser but encodes tokens the way `stable-diffusion-webui` does, making it possible to take the mean of the weights as they do. |
70 | | `A1111` | The default parser used in `stable-diffusion-webui` |
71 | | `full` | Same as `A1111` but whitespaces, newlines, and special characters are stripped |
72 | | `compel` | Uses [`compel`](https://github.com/damian0815/compel) |
73 | | `fixed attention` | Prompt is untampered with |
74 |
75 | > [!IMPORTANT]
76 | > Every `parser` except `comfy` uses `stable-diffusion-webui`'s encoding pipeline.
77 |
78 | > [!WARNING]
79 | > LoRA syntax (`<lora:name:multiplier>`) is not supported.
80 |
81 | ## Settings
82 |
83 |
84 |

85 | *Settings node showcase*
86 |
87 |
88 |
89 | The `Settings` node is a dynamic node that functions similarly to the Reroute node and is used to fine-tune results during sampling or tokenization. An input can be replaced with another input type even after it has been connected. `CLIP` inputs only apply settings to CLIP Text Encode++. Settings apply locally based on their links, just like nodes that apply model patches. I made this node to explore the various settings found in `stable-diffusion-webui`.
90 |
91 | This node can change whenever it is updated, so you may have to **recreate** it to prevent issues. Settings can be overridden by using another `Settings` node somewhere past a previous one. Right click the node for the `Hide/show all descriptions` menu option.
92 |
93 |
94 | ## Tips to get reproducible results on both UIs
95 | - Use the same seed, sampler settings, RNG (CPU or GPU), clip skip (CLIP Set Last Layer), etc.
96 | - Ancestral and SDE samplers may not be deterministic.
97 | - If you're using `DDIM` as your sampler, use the `ddim_uniform` scheduler.
98 | - There are different `unipc` configurations. Adjust accordingly on both UIs.
99 |
100 | ## FAQs
101 | - How does this differ from [`ComfyUI_ADV_CLIP_emb`](https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb)?
102 | - While the weights are normalized in the same manner, the tokenization and encoding pipeline that's taken from stable-diffusion-webui differs from ComfyUI's. These small changes add up and ultimately produce different results.
103 | - Where can I learn more about how ComfyUI interprets weights?
104 | - https://comfyanonymous.github.io/ComfyUI_examples/faq/
105 | - https://blenderneko.github.io/ComfyUI-docs/Interface/Textprompts/
106 | - [comfyui.creamlab.net](https://comfyui.creamlab.net/guides/f_text_prompt)
107 |
108 |
109 | ## Installation
110 |
111 | Three methods are available for installation:
112 |
113 | 1. Load via [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager)
114 | 2. Clone the repository directly into the extensions directory.
115 | 3. Download the project manually.
116 |
117 |
118 | ### Load via ComfyUI Manager
119 |
120 |
121 |
122 |

123 | *Install via ComfyUI Manager*
124 |
125 |
126 | ### Clone Repository
127 |
128 | ```shell
129 | cd path/to/your/ComfyUI/custom_nodes
130 | git clone https://github.com/shiimizu/ComfyUI_smZNodes.git
131 | ```
132 |
133 | ### Download Manually
134 |
135 | 1. Download the project archive from [here](https://github.com/shiimizu/ComfyUI_smZNodes/archive/refs/heads/main.zip).
136 | 2. Extract the downloaded zip file.
137 | 3. Move the extracted files to `path/to/your/ComfyUI/custom_nodes`.
138 | 4. Restart ComfyUI
139 |
140 | The folder structure should resemble: `path/to/your/ComfyUI/custom_nodes/ComfyUI_smZNodes`.
141 |
142 |
143 | ### Update
144 |
145 | To update the extension, update via [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) or pull the latest changes from the repository:
146 |
147 | ```shell
148 | cd path/to/your/ComfyUI/custom_nodes/ComfyUI_smZNodes
149 | git pull
150 | ```
151 |
152 | ## Credits
153 |
154 | * [AUTOMATIC1111](https://github.com/AUTOMATIC1111) / [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
155 | * [comfyanonymous](https://github.com/comfyanonymous) / [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
156 | * [vladmandic](https://github.com/vladmandic) / [SD.Next](https://github.com/vladmandic/automatic)
157 | * [lllyasviel](https://github.com/lllyasviel) / [stable-diffusion-webui-forge](https://github.com/lllyasviel/stable-diffusion-webui-forge)
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import importlib
4 | import subprocess
5 | from pathlib import Path
6 |
7 | def get_modules():
8 | from sys import modules
9 | s = set()
10 | for m in modules.values():
11 | try: s.add(m)
12 | except Exception: ...
13 | return s
14 |
15 | def reload_modules(m):
16 | s = get_modules()
17 | for module in s - m :
18 | try: importlib.reload(module)
19 | except Exception: ...
20 |
21 | PRELOADED_MODULES = None
22 |
23 | def install(module, PRELOADED_MODULES):
24 | if importlib.util.find_spec(module) is not None: return
25 | if PRELOADED_MODULES is None:
26 | PRELOADED_MODULES = get_modules()
27 | import sys
28 | try:
29 | print(f"\033[92m[smZNodes] \033[0;31m{module} is not installed. Attempting to install...\033[0m")
30 | subprocess.check_call([sys.executable, "-m", "pip", "install", module])
31 | reload_modules(PRELOADED_MODULES)
32 | print(f"\033[92m[smZNodes] {module} Installed!\033[0m")
33 | except Exception as e:
34 | print(f"\033[92m[smZNodes] \033[0;31mFailed to install {module}.\033[0m")
35 | return PRELOADED_MODULES
36 |
37 | for pkg in ['compel', 'lark']:
38 | PRELOADED_MODULES = install(pkg, PRELOADED_MODULES)
39 |
40 | # ============================
41 | # web
42 |
43 | cwd_path = Path(__file__).parent
44 | comfy_path = cwd_path.parent.parent
45 |
46 | def setup_web_extension():
47 | import nodes
48 | web_extension_path = os.path.join(comfy_path, "web", "extensions", "smZNodes")
49 | if os.path.exists(web_extension_path):
50 | shutil.rmtree(web_extension_path)
51 | if not hasattr(nodes, "EXTENSION_WEB_DIRS"):
52 | if not os.path.exists(web_extension_path):
53 | os.makedirs(web_extension_path)
54 | js_src_path = os.path.join(cwd_path, "web", "smZdynamicWidgets.js")
55 | shutil.copy(js_src_path, web_extension_path)
56 |
57 | setup_web_extension()
58 |
59 | # ============================
60 |
61 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
62 | WEB_DIRECTORY = "./web"
63 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
64 |
65 | from .smZNodes import add_custom_samplers, register_hooks
66 |
67 | add_custom_samplers()
68 | register_hooks()
69 |
--------------------------------------------------------------------------------
/modules/rng.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from comfy.model_patcher import ModelPatcher
4 | from . import shared, rng_philox
5 |
6 | class TorchHijack:
7 | """This is here to replace torch.randn_like of k-diffusion.
8 |
9 | k-diffusion has random_sampler argument for most samplers, but not for all, so
10 | this is needed to properly replace every use of torch.randn_like.
11 |
12 |     We need to replace it so that images generated in batches are the same as images generated individually."""
13 |
14 | def __init__(self, generator, randn_source, init=True):
15 | self.generator = generator
16 | self.randn_source = randn_source
17 | self.init = init
18 |
19 | def __getattr__(self, item):
20 | if item == 'randn_like':
21 | return self.randn_like
22 |
23 | if hasattr(torch, item):
24 | return getattr(torch, item)
25 |
26 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
27 |
28 | def randn_like(self, x):
29 | return randn_without_seed(x, generator=self.generator, randn_source=self.randn_source)
30 |
31 | def randn_without_seed(x, generator=None, randn_source="cpu"):
32 |     """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
33 |
34 | Use either randn() or manual_seed() to initialize the generator."""
35 | if randn_source == "nv":
36 | return torch.asarray(generator.randn(x.size()), device=x.device)
37 | else:
38 | return torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=generator.device, generator=generator).to(device=x.device)
39 |
40 | def prepare_noise(latent_image, seed, noise_inds=None, device='cpu'):
41 | """
42 | creates random noise given a latent image and a seed.
43 |     optional arg noise_inds can be used to skip and discard noise generations for a given seed, keeping only the listed batch indices (duplicate indices reuse the same noise)
44 | """
45 | opts = None
46 | opts_found = False
47 | model = _find_outer_instance('model', ModelPatcher)
48 | if (model is not None and (opts:=model.model_options.get(shared.Options.KEY)) is None) or opts is None:
49 | import comfy.samplers
50 | guider = _find_outer_instance('guider', comfy.samplers.CFGGuider)
51 | model = getattr(guider, 'model_patcher', None)
52 | if (model is not None and (opts:=model.model_options.get(shared.Options.KEY)) is None) or opts is None:
53 | pass
54 | opts_found = opts is not None
55 | if not opts_found:
56 | opts = shared.opts_default
57 | device = torch.device("cpu")
58 |
59 | if opts.randn_source == 'gpu':
60 | import comfy.model_management
61 | device = comfy.model_management.get_torch_device()
62 |
63 | device_orig = device
64 | device = torch.device("cpu") if opts.randn_source == "cpu" else device_orig
65 |
66 | def get_generator(seed):
67 | nonlocal device, opts
68 | if opts.randn_source == 'nv':
69 | generator = rng_philox.Generator(seed)
70 | else:
71 | generator = torch.Generator(device=device).manual_seed(seed)
72 | return generator
73 |
74 | def get_generator_obj(seed):
75 | nonlocal opts
76 | generator = torch.manual_seed(seed)
77 | generator = generator_eta = get_generator(seed)
78 | if opts.eta_noise_seed_delta > 0:
79 | seed = min(int(seed + opts.eta_noise_seed_delta), int(0xffffffffffffffff))
80 | generator_eta = get_generator(seed)
81 | return (generator, generator_eta)
82 |
83 | generator, generator_eta = get_generator_obj(seed)
84 | randn_source = opts.randn_source
85 |
86 | # ========== hijack randn_like ===============
87 | import comfy.k_diffusion.sampling
88 | # if not hasattr(comfy.k_diffusion.sampling, 'torch_orig'):
89 | # comfy.k_diffusion.sampling.torch_orig = comfy.k_diffusion.sampling.torch
90 | # comfy.k_diffusion.sampling.torch = TorchHijack(generator_eta, opts.randn_source)
91 |
92 | if not hasattr(comfy.k_diffusion.sampling, 'default_noise_sampler_orig'):
93 | comfy.k_diffusion.sampling.default_noise_sampler_orig = comfy.k_diffusion.sampling.default_noise_sampler
94 | if opts_found:
95 | th = TorchHijack(generator_eta, randn_source)
96 | def default_noise_sampler(x, seed=None, *args, **kwargs):
97 | nonlocal th
98 | return lambda sigma, sigma_next: th.randn_like(x)
99 | default_noise_sampler.init = True
100 | comfy.k_diffusion.sampling.default_noise_sampler = default_noise_sampler
101 | else:
102 | comfy.k_diffusion.sampling.default_noise_sampler = comfy.k_diffusion.sampling.default_noise_sampler_orig
103 | # =============================================
104 |
105 | if noise_inds is None:
106 | shape = latent_image.size()
107 | if opts.randn_source == 'nv':
108 | noise = torch.asarray(generator.randn(shape), dtype=latent_image.dtype, device=device)
109 | else:
110 | noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator)
111 | noise = noise.to(device=device_orig)
112 | return noise
113 |
114 | unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
115 | noises = []
116 | for i in range(unique_inds[-1]+1):
117 | shape = [1] + list(latent_image.size())[1:]
118 | if opts.randn_source == 'nv':
119 | noise = torch.asarray(generator.randn(shape), dtype=latent_image.dtype, device=device)
120 | else:
121 | noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator)
122 | noise = noise.to(device=device_orig)
123 | if i in unique_inds:
124 | noises.append(noise)
125 | noises = [noises[i] for i in inverse]
126 | noises = torch.cat(noises, axis=0)
127 | return noises
128 |
129 | def _find_outer_instance(target:str, target_type=None, callback=None, max_len=10):
130 | import inspect
131 | frame = inspect.currentframe()
132 | i = 0
133 | while frame and i < max_len:
134 | if target in frame.f_locals:
135 | if callback is not None:
136 | return callback(frame)
137 | else:
138 | found = frame.f_locals[target]
139 | if isinstance(found, target_type):
140 | return found
141 | frame = frame.f_back
142 | i += 1
143 | return None
144 |
145 |
--------------------------------------------------------------------------------
/modules/rng_philox.py:
--------------------------------------------------------------------------------
1 | """RNG imitiating torch cuda randn on CPU. You are welcome.
2 |
3 | Usage:
4 |
5 | ```
6 | g = Generator(seed=0)
7 | print(g.randn(shape=(3, 4)))
8 | ```
9 |
10 | Expected output:
11 | ```
12 | [[-0.92466259 -0.42534415 -2.6438457 0.14518388]
13 | [-0.12086647 -0.57972564 -0.62285122 -0.32838709]
14 | [-1.07454231 -0.36314407 -1.67105067 2.26550497]]
15 | ```
16 | """
17 |
18 | import numpy as np
19 |
20 | philox_m = [0xD2511F53, 0xCD9E8D57]
21 | philox_w = [0x9E3779B9, 0xBB67AE85]
22 |
23 | two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
24 | two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
25 |
26 |
27 | def uint32(x):
28 |     """Converts (N,) np.uint64 array into (2, N) np.uint32 array."""
29 | return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
30 |
31 |
32 | def philox4_round(counter, key):
33 | """A single round of the Philox 4x32 random number generator."""
34 |
35 | v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
36 | v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
37 |
38 | counter[0] = v2[1] ^ counter[1] ^ key[0]
39 | counter[1] = v2[0]
40 | counter[2] = v1[1] ^ counter[3] ^ key[1]
41 | counter[3] = v1[0]
42 |
43 |
44 | def philox4_32(counter, key, rounds=10):
45 | """Generates 32-bit random numbers using the Philox 4x32 random number generator.
46 |
47 | Parameters:
48 | counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
49 | key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
50 | rounds (int): The number of rounds to perform.
51 |
52 | Returns:
53 | numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
54 | """
55 |
56 | for _ in range(rounds - 1):
57 | philox4_round(counter, key)
58 |
59 | key[0] = key[0] + philox_w[0]
60 | key[1] = key[1] + philox_w[1]
61 |
62 | philox4_round(counter, key)
63 | return counter
64 |
65 |
66 | def box_muller(x, y):
67 | """Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
68 | u = x * two_pow32_inv + two_pow32_inv / 2
69 | v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
70 |
71 | s = np.sqrt(-2.0 * np.log(u))
72 |
73 | r1 = s * np.sin(v)
74 | return r1.astype(np.float32)
75 |
76 |
77 | class Generator:
78 | """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
79 |
80 | def __init__(self, seed):
81 | self.seed = seed
82 | self.offset = 0
83 |
84 | def randn(self, shape):
85 | """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
86 |
87 | n = 1
88 | for x in shape:
89 | n *= x
90 |
91 | counter = np.zeros((4, n), dtype=np.uint32)
92 | counter[0] = self.offset
93 | counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
94 | self.offset += 1
95 |
96 | key = np.empty(n, dtype=np.uint64)
97 | key.fill(self.seed)
98 | key = uint32(key)
99 |
100 | g = philox4_32(counter, key)
101 |
102 | return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3]
103 |
--------------------------------------------------------------------------------
/modules/shared.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from copy import deepcopy
3 | from comfy.model_management import vram_state, VRAMState
4 | from comfy.cli_args import args
5 | from comfy import model_management
6 |
7 | xformers_available = model_management.XFORMERS_IS_AVAILABLE
8 | logger = logging.getLogger("smZNodes")
9 | level = logging.INFO
10 | logger.propagate = False
11 | logger.setLevel(level)
12 | stdoutHandler = logging.StreamHandler()
13 | fmt = logging.Formatter("[%(name)s] | %(filename)s:%(lineno)s | %(message)s")
14 | stdoutHandler.setFormatter(fmt)
15 | logger.addHandler(stdoutHandler)
16 |
17 | def join_args(*args):
18 | return ' '.join(map(str, args))
19 |
20 | class SimpleNamespaceFast:
21 | def __repr__(self):
22 | keys = sorted(self.__dict__)
23 | items = ("{}={!r}".format(k, self.__dict__[k]) for k in keys)
24 | return "{}({})".format(type(self).__name__, ", ".join(items))
25 |
26 | def __eq__(self, other):
27 | return self.__dict__ == other.__dict__
28 |
29 | class Options(SimpleNamespaceFast):
30 | KEY = 'smZ_opts'
31 |
32 | def clone(self):
33 | return deepcopy(self)
34 |
35 | def update(self, other):
36 | if isinstance(other, dict):
37 | self.__dict__ |= other
38 | else:
39 | self.__dict__ |= other.__dict__
40 | return self
41 |
42 | opts = Options()
43 | opts.prompt_attention = 'A1111 parser'
44 | opts.prompt_mean_norm = True
45 | opts.comma_padding_backtrack = 20
46 | opts.CLIP_stop_at_last_layers = 1
47 | opts.enable_emphasis = True
48 | opts.use_old_emphasis_implementation = False
49 | opts.disable_nan_check = True
50 | opts.pad_cond_uncond = False
51 | opts.s_min_uncond = 0.0
52 | opts.s_min_uncond_all = False
53 | opts.skip_early_cond = 0.0
54 | opts.upcast_sampling = True
55 | opts.upcast_attn = not getattr(args, 'dont_upcast_attention', False)
56 | opts.textual_inversion_add_hashes_to_infotext = False
57 | opts.encode_count = 0
58 | opts.max_chunk_count = 0
59 | opts.return_batch_chunks = False
60 | opts.noise = None
61 | opts.start_step = None
62 | opts.pad_with_repeats = True
63 | opts.randn_source = "cpu"
64 | opts.lora_functional = False
65 | opts.use_old_scheduling = True
66 | opts.eta_noise_seed_delta = 0
67 | opts.multi_conditioning = False
68 | opts.eta = 1.0
69 | opts.s_churn = 0.0
70 | opts.s_tmin = 0.0
71 | opts.s_tmax = 0.0 or float('inf')
72 | opts.s_noise = 1.0
73 |
74 | opts.use_CFGDenoiser = False
75 | opts.sgm_noise_multiplier = True
76 | opts.debug = False
77 |
78 | opts.sdxl_crop_top = 0
79 | opts.sdxl_crop_left = 0
80 | opts.sdxl_refiner_low_aesthetic_score = 2.5
81 | opts.sdxl_refiner_high_aesthetic_score = 6.0
82 |
83 | sd_model = Options()
84 | sd_model.cond_stage_model = Options()
85 |
86 | cmd_opts = Options()
87 |
88 | opts.batch_cond_uncond = False
89 | cmd_opts.lowvram = vram_state == VRAMState.LOW_VRAM
90 | cmd_opts.medvram = vram_state == VRAMState.NORMAL_VRAM
91 | should_batch_cond_uncond = lambda: opts.batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
92 | opts.batch_cond_uncond = True
93 |
94 | opts_default = opts.clone()
95 |
96 | cmd_opts.xformers = xformers_available
97 | cmd_opts.force_enable_xformers = xformers_available
98 |
99 | opts.cross_attention_optimization = "None"
100 | cmd_opts.sub_quad_q_chunk_size = 512
101 | cmd_opts.sub_quad_kv_chunk_size = 512
102 | cmd_opts.sub_quad_chunk_threshold = 80
103 | cmd_opts.token_merging_ratio = 0.0
104 | cmd_opts.token_merging_ratio_img2img = 0.0
105 | cmd_opts.token_merging_ratio_hr = 0.0
106 | cmd_opts.sd_vae_sliced_encode = False
107 | cmd_opts.disable_opt_split_attention = False
--------------------------------------------------------------------------------
/modules/text_processing/classic_engine.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import logging
4 | from collections import namedtuple
5 | from comfy import model_management
6 | from . import emphasis, prompt_parser
7 | from .textual_inversion import EmbeddingDatabase, parse_and_register_embeddings
8 |
9 |
10 | PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
11 | last_extra_generation_params = {}
12 |
13 | def populate_self_variables(self, from_):
14 | attrs_from = vars(from_)
15 | attrs_self = vars(self)
16 | attrs_self.update(attrs_from)
17 |
18 | class PromptChunk:
19 | def __init__(self):
20 | self.tokens = []
21 | self.multipliers = []
22 | self.fixes = []
23 |
24 |
25 | class CLIPEmbeddingForTextualInversion(torch.nn.Module):
26 | def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
27 | super().__init__()
28 | self.wrapped = wrapped
29 | self.embeddings = embeddings
30 | self.textual_inversion_key = textual_inversion_key
31 | self.weight = self.wrapped.weight
32 |
33 | def forward(self, input_ids, out_dtype):
34 | batch_fixes = self.embeddings.fixes
35 | self.embeddings.fixes = None
36 |
37 | inputs_embeds = self.wrapped(input_ids, out_dtype)
38 |
39 | if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
40 | return inputs_embeds
41 |
42 | vecs = []
43 | for fixes, tensor in zip(batch_fixes, inputs_embeds):
44 | for offset, embedding in fixes:
45 | emb = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
46 | emb = emb.to(inputs_embeds)
47 | emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
48 | try:
49 | tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
50 | except Exception:
51 | logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {} {} {} '{}'".format(tensor.shape[0], emb.shape[1], self.current_embeds.weight.shape[1], self.textual_inversion_key, embedding.name))
52 |
53 | vecs.append(tensor)
54 |
55 | return torch.stack(vecs)
56 |
57 |
58 | class ClassicTextProcessingEngine:
59 | def __init__(
60 | self, text_encoder, tokenizer, chunk_length=75,
61 | embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="Original",
62 | text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=True, final_layer_norm=True
63 | ):
64 | super().__init__()
65 | populate_self_variables(self, tokenizer)
66 | self._tokenizer = tokenizer
67 |
68 | self.embeddings = EmbeddingDatabase(self.tokenizer, embedding_expected_shape)
69 |
70 | self.text_encoder = text_encoder
71 | self._try_get_embedding = tokenizer._try_get_embedding
72 |
73 | self.emphasis = emphasis.get_current_option(emphasis_name)()
74 |
75 | self.text_projection = text_projection
76 | self.minimal_clip_skip = minimal_clip_skip
77 | self.clip_skip = clip_skip
78 | self.return_pooled = return_pooled
79 | self.final_layer_norm = final_layer_norm
80 |
81 | self.chunk_length = chunk_length
82 |
83 | self.id_start = self.start_token
84 | self.id_end = self.end_token
85 | self.id_pad = self.pad_token
86 |
87 | model_embeddings = text_encoder.transformer.text_model.embeddings
88 | backup_embeds = self.text_encoder.transformer.get_input_embeddings()
89 | model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=self.embedding_key)
90 | model_embeddings.token_embedding.current_embeds = backup_embeds
91 |
92 | vocab = self.tokenizer.get_vocab()
93 | self.token_mults = {}
94 | tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
95 | for text, ident in tokens_with_parens:
96 | mult = 1.0
97 | for c in text:
98 | if c == '[':
99 | mult /= 1.1
100 | if c == ']':
101 | mult *= 1.1
102 | if c == '(':
103 | mult *= 1.1
104 | if c == ')':
105 | mult /= 1.1
106 | if mult != 1.0:
107 | self.token_mults[ident] = mult
108 |
109 | self.comma_token = vocab.get(',', None)
110 | self.tokenizer._eventual_warn_about_too_long_sequence = lambda *args, **kwargs: None
111 |
112 | def unhook(self):
113 | self.text_encoder.transformer.text_model.embeddings.token_embedding = self.text_encoder.transformer.text_model.embeddings.token_embedding.wrapped
114 | del self._try_get_embedding
115 | w = '_eventual_warn_about_too_long_sequence'
116 | if hasattr(self.tokenizer, w): delattr(self.tokenizer, w)
117 | if hasattr(self._tokenizer, w): delattr(self._tokenizer, w)
118 |
119 | def empty_chunk(self):
120 | chunk = PromptChunk()
121 | chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
122 | chunk.multipliers = [1.0] * (self.chunk_length + 2)
123 | return chunk
124 |
125 | def get_target_prompt_token_count(self, token_count):
126 | return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
127 |
128 | def tokenize(self, texts):
129 | tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
130 | return tokenized
131 |
132 | def tokenize_with_weights(self, texts, return_word_ids=False):
133 | texts = [parse_and_register_embeddings(self, text) for text in texts]
134 | if self.opts.use_old_emphasis_implementation:
135 | return self.process_texts_past(texts)
136 | batch_chunks, token_count = self.process_texts(texts)
137 |
138 | used_embeddings = {}
139 | chunk_count = max([len(x) for x in batch_chunks])
140 |
141 | zs = []
142 | for i in range(chunk_count):
143 | batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
144 |
145 | tokens = [x.tokens for x in batch_chunk]
146 | multipliers = [x.multipliers for x in batch_chunk]
147 | self.embeddings.fixes = [x.fixes for x in batch_chunk]
148 |
149 | for fixes in self.embeddings.fixes:
150 | for _position, embedding in fixes:
151 | used_embeddings[embedding.name] = embedding
152 |
153 | z = (tokens, multipliers)
154 | zs.append(z)
155 |
156 | return zs
157 |
158 | def encode_token_weights(self, token_weight_pairs):
159 | if isinstance(token_weight_pairs[0], str):
160 | token_weight_pairs = self.tokenize_with_weights(token_weight_pairs)
161 | elif isinstance(token_weight_pairs[0], list):
162 | token_weight_pairs = list(map(lambda x: ([list(map(lambda y: y[0], x))], [list(map(lambda y: y[1], x))]), token_weight_pairs))
163 |
164 | target_device = model_management.text_encoder_offload_device()
165 | zs = []
166 | for tokens, multipliers in token_weight_pairs:
167 | z = self.process_tokens(tokens, multipliers)
168 | zs.append(z)
169 | if self.return_pooled:
170 | return torch.hstack(zs).to(target_device), zs[0].pooled.to(target_device) if zs[0].pooled is not None else None
171 | else:
172 | return torch.hstack(zs).to(target_device)
173 |
174 | def encode_with_transformers(self, tokens):
175 | try:
176 | z, pooled = self.text_encoder(tokens)
177 | except Exception:
178 | z, pooled = self.text_encoder(tokens.tolist())
179 | z.pooled = pooled
180 | return z
181 |
182 | def tokenize_line(self, line):
183 | parsed = prompt_parser.parse_prompt_attention(line)
184 |
185 | tokenized = self.tokenize([text for text, _ in parsed])
186 |
187 | chunks = []
188 | chunk = PromptChunk()
189 | token_count = 0
190 | last_comma = -1
191 |
192 | def next_chunk(is_last=False):
193 | nonlocal token_count
194 | nonlocal last_comma
195 | nonlocal chunk
196 |
197 | if is_last:
198 | token_count += len(chunk.tokens)
199 | else:
200 | token_count += self.chunk_length
201 |
202 | to_add = self.chunk_length - len(chunk.tokens)
203 | if to_add > 0:
204 | chunk.tokens += [self.id_end] * to_add
205 | chunk.multipliers += [1.0] * to_add
206 |
207 | chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
208 | chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
209 |
210 | last_comma = -1
211 | chunks.append(chunk)
212 | chunk = PromptChunk()
213 |
214 | for tokens, (text, weight) in zip(tokenized, parsed):
215 | if text == 'BREAK' and weight == -1:
216 | next_chunk()
217 | continue
218 |
219 | position = 0
220 | while position < len(tokens):
221 | token = tokens[position]
222 |
223 | comma_padding_backtrack = 20
224 |
225 | if token == self.comma_token:
226 | last_comma = len(chunk.tokens)
227 |
228 | elif comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= comma_padding_backtrack:
229 | break_location = last_comma + 1
230 |
231 | reloc_tokens = chunk.tokens[break_location:]
232 | reloc_mults = chunk.multipliers[break_location:]
233 |
234 | chunk.tokens = chunk.tokens[:break_location]
235 | chunk.multipliers = chunk.multipliers[:break_location]
236 |
237 | next_chunk()
238 | chunk.tokens = reloc_tokens
239 | chunk.multipliers = reloc_mults
240 |
241 | if len(chunk.tokens) == self.chunk_length:
242 | next_chunk()
243 |
244 | embedding, embedding_length_in_tokens = self.embeddings.find_embedding_at_position(tokens, position)
245 | if embedding is None:
246 | chunk.tokens.append(token)
247 | chunk.multipliers.append(weight)
248 | position += 1
249 | continue
250 |
251 | emb_len = int(embedding.vectors)
252 | if len(chunk.tokens) + emb_len > self.chunk_length:
253 | next_chunk()
254 |
255 | chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
256 |
257 | chunk.tokens += [0] * emb_len
258 | chunk.multipliers += [weight] * emb_len
259 | position += embedding_length_in_tokens
260 |
261 | if chunk.tokens or not chunks:
262 | next_chunk(is_last=True)
263 |
264 | return chunks, token_count
265 |
266 | def process_texts(self, texts):
267 | token_count = 0
268 |
269 | cache = {}
270 | batch_chunks = []
271 | for line in texts:
272 | if line in cache:
273 | chunks = cache[line]
274 | else:
275 | chunks, current_token_count = self.tokenize_line(line)
276 | token_count = max(current_token_count, token_count)
277 |
278 | cache[line] = chunks
279 |
280 | batch_chunks.append(chunks)
281 |
282 | return batch_chunks, token_count
283 |
284 | def __call__(self, texts):
285 | tokens = self.tokenize_with_weights(texts)
286 | return self.encode_token_weights(tokens)
287 |
288 | def process_tokens(self, remade_batch_tokens, batch_multipliers, *args, **kwargs):
289 | try:
290 | tokens = torch.asarray(remade_batch_tokens)
291 |
292 | if self.id_end != self.id_pad:
293 | for batch_pos in range(len(remade_batch_tokens)):
294 | index = remade_batch_tokens[batch_pos].index(self.id_end)
295 | tokens[batch_pos, index + 1:tokens.shape[1]] = self.id_pad
296 |
297 | z = self.encode_with_transformers(tokens)
298 | except ValueError:
299 | # Tokens including textual inversion embeddings in the list.
300 | # i.e. tensors in the list along with tokens.
301 | z = self.encode_with_transformers(remade_batch_tokens)
302 |
303 | pooled = getattr(z, 'pooled', None)
304 |
305 | self.emphasis.tokens = remade_batch_tokens
306 | self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)
307 | self.emphasis.z = z
308 | self.emphasis.after_transformers()
309 | z = self.emphasis.z
310 |
311 | if pooled is not None:
312 | z.pooled = pooled
313 |
314 | return z
315 |
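# Minimal usage sketch (assumptions: a loaded ComfyUI CLIP text encoder and tokenizer, and the
# engine's `opts` populated by the caller, as smZNodes presumably does elsewhere):
#   engine = ClassicTextProcessingEngine(text_encoder, tokenizer, emphasis_name="Original")
#   cond, pooled = engine(["a photo of a cat"])  # __call__ = tokenize_with_weights() + encode_token_weights()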
--------------------------------------------------------------------------------
/modules/text_processing/emphasis.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Emphasis:
5 | name: str = "Base"
6 | description: str = ""
7 | tokens: list[list[int]]
8 | multipliers: torch.Tensor
9 | z: torch.Tensor
10 |
11 | def after_transformers(self):
12 | pass
13 |
14 |
15 | class EmphasisNone(Emphasis):
16 | name = "None"
17 |     description = "disable the mechanism entirely and treat (:1.1) as literal characters"
18 |
19 |
20 | class EmphasisIgnore(Emphasis):
21 | name = "Ignore"
22 | description = "treat all emphasized words as if they have no emphasis"
23 |
24 |
25 | class EmphasisOriginal(Emphasis):
26 | name = "Original"
27 | description = "the original emphasis implementation"
28 |
29 | def after_transformers(self):
30 | original_mean = self.z.mean()
31 | self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
32 | new_mean = self.z.mean()
33 | self.z = self.z * (original_mean / new_mean)
34 |
35 |
36 | class EmphasisOriginalNoNorm(EmphasisOriginal):
37 | name = "No norm"
38 | description = "same as original, but without normalization (seems to work better for SDXL)"
39 |
40 | def after_transformers(self):
41 | self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
42 |
43 |
44 | def get_current_option(emphasis_option_name):
45 | return next(iter([x for x in options if x.name == emphasis_option_name]), EmphasisOriginal)
46 |
47 |
48 | def get_options_descriptions():
49 | return ", ".join(f"{x.name}: {x.description}" for x in options)
50 |
51 | def get_options_descriptions_nl():
52 | return "\n".join(f"{x.name}: {x.description}" for x in options)
53 |
54 |
55 | options = [
56 | EmphasisNone,
57 | EmphasisIgnore,
58 | EmphasisOriginal,
59 | EmphasisOriginalNoNorm,
60 | ]
61 |
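# Usage sketch (mirrors how ClassicTextProcessingEngine.process_tokens drives this module):
#   emph = get_current_option("Original")()   # unknown names fall back to EmphasisOriginal
#   emph.tokens, emph.multipliers, emph.z = tokens, weights, z
#   emph.after_transformers()                 # scales z by the weights (and, for "Original", restores the mean)
#   z = emph.z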
--------------------------------------------------------------------------------
/modules/text_processing/parsing.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | re_attention = re.compile(r"""
5 | \\\(|
6 | \\\)|
7 | \\\[|
8 | \\]|
9 | \\\\|
10 | \\|
11 | \(|
12 | \[|
13 | :\s*([+-]?[.\d]+)\s*\)|
14 | \)|
15 | ]|
16 | [^\\()\[\]:]+|
17 | :
18 | """, re.X)
19 |
20 | re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
21 |
22 |
23 | def parse_prompt_attention(text):
24 | res = []
25 | round_brackets = []
26 | square_brackets = []
27 |
28 | round_bracket_multiplier = 1.1
29 | square_bracket_multiplier = 1 / 1.1
30 |
31 | def multiply_range(start_position, multiplier):
32 | for p in range(start_position, len(res)):
33 | res[p][1] *= multiplier
34 |
35 | for m in re_attention.finditer(text):
36 | text = m.group(0)
37 | weight = m.group(1)
38 |
39 | if text.startswith('\\'):
40 | res.append([text[1:], 1.0])
41 | elif text == '(':
42 | round_brackets.append(len(res))
43 | elif text == '[':
44 | square_brackets.append(len(res))
45 | elif weight is not None and round_brackets:
46 | multiply_range(round_brackets.pop(), float(weight))
47 | elif text == ')' and round_brackets:
48 | multiply_range(round_brackets.pop(), round_bracket_multiplier)
49 | elif text == ']' and square_brackets:
50 | multiply_range(square_brackets.pop(), square_bracket_multiplier)
51 | else:
52 | parts = re.split(re_break, text)
53 | for i, part in enumerate(parts):
54 | if i > 0:
55 | res.append(["BREAK", -1])
56 | res.append([part, 1.0])
57 |
58 | for pos in round_brackets:
59 | multiply_range(pos, round_bracket_multiplier)
60 |
61 | for pos in square_brackets:
62 | multiply_range(pos, square_bracket_multiplier)
63 |
64 | if len(res) == 0:
65 | res = [["", 1.0]]
66 |
67 | i = 0
68 | while i + 1 < len(res):
69 | if res[i][1] == res[i + 1][1]:
70 | res[i][0] += res[i + 1][0]
71 | res.pop(i + 1)
72 | else:
73 | i += 1
74 |
75 | return res
76 |
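# Illustrative examples of parse_prompt_attention() with A1111-style attention syntax:
#   parse_prompt_attention('normal text')    -> [['normal text', 1.0]]
#   parse_prompt_attention('(important)')    -> [['important', 1.1]]
#   parse_prompt_attention('[unimportant]')  -> [['unimportant', 0.909...]]  # 1 / 1.1
#   parse_prompt_attention('(value:1.5)')    -> [['value', 1.5]]
#   parse_prompt_attention('\\(literal\\)')  -> [['(literal)', 1.0]]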
--------------------------------------------------------------------------------
/modules/text_processing/past_classic_engine.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | def process_text_old(self, texts):
4 | id_start = self.id_start
5 | id_end = self.id_end
6 | maxlen = self.max_length # you get to stay at 77
7 | used_custom_terms = []
8 | remade_batch_tokens = []
9 | hijack_comments = []
10 | hijack_fixes = []
11 | token_count = 0
12 |
13 | cache = {}
14 | batch_tokens = self.tokenize(texts)
15 | batch_multipliers = []
16 | batch_multipliers_solo_emb = []
17 | for tokens in batch_tokens:
18 | tuple_tokens = tuple(tokens)
19 |
20 | if tuple_tokens in cache:
21 | remade_tokens, fixes, multipliers, multipliers_solo_emb = cache[tuple_tokens]
22 | else:
23 | fixes = []
24 | remade_tokens = []
25 | multipliers = []
26 | multipliers_solo_emb = []
27 | mult = 1.0
28 |
29 | i = 0
30 | while i < len(tokens):
31 | token = tokens[i]
32 |
33 | embedding, embedding_length_in_tokens = self.embeddings.find_embedding_at_position(tokens, i)
34 | if isinstance(embedding, dict):
35 | if 'open' in self.__class__.__name__.lower():
36 | embedding = embedding.get('g', embedding)
37 | else:
38 | embedding.pop('g', None)
39 | embedding = next(iter(embedding.values()))
40 |
41 | mult_change = self.token_mults.get(token) if self.opts.enable_emphasis else None
42 | if mult_change is not None:
43 | mult *= mult_change
44 | i += 1
45 | elif embedding is None:
46 | remade_tokens.append(token)
47 | multipliers.append(mult)
48 | multipliers_solo_emb.append(mult)
49 | i += 1
50 | else:
51 | emb_len = int(embedding.vec.shape[0])
52 | fixes.append((len(remade_tokens), embedding))
53 | remade_tokens += [0] * emb_len
54 | multipliers += [mult] * emb_len
55 | multipliers_solo_emb += [mult] + ([1.0] * (emb_len-1))
56 | used_custom_terms.append((embedding.name, embedding.checksum()))
57 | i += embedding_length_in_tokens
58 |
59 | if len(remade_tokens) > maxlen - 2:
60 | vocab = {v: k for k, v in self.tokenizer.get_vocab().items()}
61 | ovf = remade_tokens[maxlen - 2:]
62 | overflowing_words = [vocab.get(int(x), "") for x in ovf]
63 | overflowing_text = self.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
64 | logging.warning(f"\033[33mWarning:\033[0m too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
65 |
66 | token_count = len(remade_tokens)
67 | remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
68 | remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
69 | cache[tuple_tokens] = (remade_tokens, fixes, multipliers, multipliers_solo_emb)
70 |
71 | multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
72 | multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
73 |
74 | multipliers_solo_emb = multipliers_solo_emb + [1.0] * (maxlen - 2 - len(multipliers_solo_emb))
75 | multipliers_solo_emb = [1.0] + multipliers_solo_emb[0:maxlen - 2] + [1.0]
76 |
77 | remade_batch_tokens.append(remade_tokens)
78 | hijack_fixes.append(fixes)
79 | batch_multipliers.append(multipliers)
80 | batch_multipliers_solo_emb.append(multipliers_solo_emb)
81 | return batch_multipliers, batch_multipliers_solo_emb, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
82 |
83 |
84 | def forward_old(self, texts):
85 | batch_multipliers, batch_multipliers_solo_emb, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, _token_count = process_text_old(self, texts)
86 |
87 | chunk_count = max([len(x) for x in remade_batch_tokens])
88 |
89 | if self.opts.return_batch_chunks:
90 | return (remade_batch_tokens, chunk_count)
91 |
92 | self.hijack.comments += hijack_comments
93 |
94 | if len(used_custom_terms) > 0:
95 | embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
96 | self.hijack.comments.append(f"Used embeddings: {embedding_names}")
97 |
98 | self.hijack.fixes = hijack_fixes
99 | return self.process_tokens(remade_batch_tokens, batch_multipliers, batch_multipliers_solo_emb)
100 |
101 | def process_texts_past(self, texts):
102 | batch_multipliers, batch_multipliers_solo_emb, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, _token_count = process_text_old(self, texts)
103 | return [(remade_batch_tokens, batch_multipliers)]
--------------------------------------------------------------------------------
/modules/text_processing/prompt_parser.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import re
3 | import lark
4 | from typing import List
5 | from compel import Compel
6 | from collections import namedtuple
7 | if __name__ != "__main__":
8 | import torch
9 | from ..shared import opts, logger
10 | else:
11 | class Opts: ...
12 | opts = Opts()
13 | opts.prompt_attention = 'A1111 parser'
14 | import logging as logger
15 |
16 | # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]"
17 | # will be represented with prompt_schedule like this (assuming steps=100):
18 | # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
19 | # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
20 | # [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
21 | # [75, 'fantasy landscape with a lake and an oak in background masterful']
22 | # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
23 |
24 | schedule_parser = lark.Lark(r"""
25 | !start: (prompt | /[][():]/+)*
26 | prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
27 | !emphasized: "(" prompt ")"
28 | | "(" prompt ":" prompt ")"
29 | | "[" prompt "]"
30 | scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
31 | alternate: "[" prompt ("|" [prompt])+ "]"
32 | WHITESPACE: /\s+/
33 | plain: /([^\\\[\]():|]|\\.)+/
34 | %import common.SIGNED_NUMBER -> NUMBER
35 | """)
36 | re_clean = re.compile(r"^\W+", re.S)
37 | re_whitespace = re.compile(r"\s+", re.S)
38 |
39 |
40 | def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
41 | """
42 | >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
43 | >>> g("test")
44 | [[10, 'test']]
45 | >>> g("a [b:3]")
46 | [[3, 'a '], [10, 'a b']]
47 | >>> g("a [b: 3]")
48 | [[3, 'a '], [10, 'a b']]
49 | >>> g("a [[[b]]:2]")
50 | [[2, 'a '], [10, 'a [[b]]']]
51 | >>> g("[(a:2):3]")
52 | [[3, ''], [10, '(a:2)']]
53 | >>> g("a [b : c : 1] d")
54 | [[1, 'a b d'], [10, 'a c d']]
55 | >>> g("a[b:[c:d:2]:1]e")
56 | [[1, 'abe'], [2, 'ace'], [10, 'ade']]
57 | >>> g("a [unbalanced")
58 | [[10, 'a [unbalanced']]
59 | >>> g("a [b:.5] c")
60 | [[5, 'a c'], [10, 'a b c']]
61 | >>> g("a [{b|d{:.5] c") # not handling this right now
62 | [[5, 'a c'], [10, 'a {b|d{ c']]
63 | >>> g("((a][:b:c [d:3]")
64 | [[3, '((a][:b:c '], [10, '((a][:b:c d']]
65 | >>> g("[a|(b:1.1)]")
66 | [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
67 | >>> g("[fe|]male")
68 | [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
69 | >>> g("[fe|||]male")
70 | [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
71 | >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
72 | >>> g("a [b:.5] c")
73 | [[10, 'a b c']]
74 | >>> g("a [b:1.5] c")
75 | [[5, 'a c'], [10, 'a b c']]
76 | """
77 |
78 | if hires_steps is None or use_old_scheduling:
79 | int_offset = 0
80 | flt_offset = 0
81 | steps = base_steps
82 | else:
83 | int_offset = base_steps
84 | flt_offset = 1.0
85 | steps = hires_steps
86 |
87 | def collect_steps(steps, tree):
88 | res = [steps]
89 |
90 | class CollectSteps(lark.Visitor):
91 | def scheduled(self, tree):
92 | s = tree.children[-2]
93 | v = float(s)
94 | if use_old_scheduling:
95 | v = v*steps if v<1 else v
96 | else:
97 | if "." in s:
98 | v = (v - flt_offset) * steps
99 | else:
100 | v = (v - int_offset)
101 | tree.children[-2] = min(steps, int(v))
102 | if tree.children[-2] >= 1:
103 | res.append(tree.children[-2])
104 |
105 | def alternate(self, tree):
106 | res.extend(range(1, steps+1))
107 |
108 | CollectSteps().visit(tree)
109 | return sorted(set(res))
110 |
111 | def at_step(step, tree):
112 | class AtStep(lark.Transformer):
113 | def scheduled(self, args):
114 | before, after, _, when, _ = args
115 | yield before or () if step <= when else after
116 | def alternate(self, args):
117 | args = ["" if not arg else arg for arg in args]
118 | yield args[(step - 1) % len(args)]
119 | def start(self, args):
120 | def flatten(x):
121 | if isinstance(x, str):
122 | yield x
123 | else:
124 | for gen in x:
125 | yield from flatten(gen)
126 | return ''.join(flatten(args))
127 | def plain(self, args):
128 | yield args[0].value
129 | def __default__(self, data, children, meta):
130 | for child in children:
131 | yield child
132 | return AtStep().transform(tree)
133 |
134 | def get_schedule(prompt):
135 | try:
136 | tree = schedule_parser.parse(prompt)
137 | except lark.exceptions.LarkError:
138 | if 0:
139 | import traceback
140 | traceback.print_exc()
141 | return [[steps, prompt]]
142 | return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
143 |
144 | promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
145 | return [promptdict[prompt] for prompt in prompts]
146 |
147 |
148 | ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
149 |
150 |
151 | class SdConditioning(list):
152 | """
153 | A list with prompts for stable diffusion's conditioner model.
154 | Can also specify width and height of created image - SDXL needs it.
155 | """
156 | def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
157 | super().__init__()
158 | self.extend(prompts)
159 |
160 | if copy_from is None:
161 | copy_from = prompts
162 |
163 | self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
164 | self.width = width or getattr(copy_from, 'width', None)
165 | self.height = height or getattr(copy_from, 'height', None)
166 |
167 |
168 |
169 | def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
170 |     """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
171 | and the sampling step at which this condition is to be replaced by the next one.
172 |
173 | Input:
174 | (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
175 |
176 | Output:
177 | [
178 | [
179 | ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
180 | ],
181 | [
182 | ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
183 | ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
184 | ]
185 | ]
186 | """
187 | res = []
188 |
189 | prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
190 | logger.debug(prompt_schedules)
191 | cache = {}
192 | for prompt, prompt_schedule in zip(prompts, prompt_schedules):
193 |
194 | cached = cache.get(prompt, None)
195 | if cached is not None:
196 | res.append(cached)
197 | continue
198 |
199 | texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
200 | # conds = model.get_learned_conditioning(texts)
201 | conds = model(texts)
202 | if isinstance(conds, tuple):
203 | conds, pooleds = conds
204 | conds.pooled = pooleds
205 | cond_schedule = []
206 | for i, (end_at_step, _) in enumerate(prompt_schedule):
207 | if isinstance(conds, dict):
208 | if 'cond' not in conds:
209 | cond = {k: v[i] for k, v in conds.items()}
210 | else:
211 | # cond = {k: (v[i:i+1, :] if isinstance(v, torch.Tensor) else v) for k, v in conds.items()}
212 | cond = {k: (v[i].unsqueeze(0)[0:1] if isinstance(v, torch.Tensor) else v) for k, v in conds.items()}
213 | else:
214 | # cond = conds[i]
215 | cond = conds[i:i+1, :]
216 | if conds.pooled is not None:
217 | cond.pooled = conds.pooled[i:i+1, :]
218 | cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
219 |
220 | cache[prompt] = cond_schedule
221 | res.append(cond_schedule)
222 |
223 | return res
224 |
225 |
226 | re_AND = re.compile(r"\bAND\b")
227 | re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
228 |
229 |
230 | def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
231 | res_indexes = []
232 |
233 | prompt_indexes = {}
234 | prompt_flat_list = SdConditioning(prompts)
235 | prompt_flat_list.clear()
236 |
237 | for prompt in prompts:
238 | subprompts = re_AND.split(prompt)
239 |
240 | indexes = []
241 | for subprompt in subprompts:
242 | match = re_weight.search(subprompt)
243 |
244 | text, weight = match.groups() if match is not None else (subprompt, 1.0)
245 |
246 | weight = float(weight) if weight is not None else 1.0
247 |
248 | index = prompt_indexes.get(text, None)
249 | if index is None:
250 | index = len(prompt_flat_list)
251 | prompt_flat_list.append(text)
252 | prompt_indexes[text] = index
253 |
254 | indexes.append((index, weight))
255 |
256 | res_indexes.append(indexes)
257 |
258 | return res_indexes, prompt_flat_list, prompt_indexes
259 |
260 |
261 | class ComposableScheduledPromptConditioning:
262 | def __init__(self, schedules, weight=1.0):
263 | self.schedules: List[ScheduledPromptConditioning] = schedules
264 | self.weight: float = weight
265 |
266 |
267 | class MulticondLearnedConditioning:
268 | def __init__(self, shape, batch):
269 | self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
270 | self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
271 |
272 |
273 | def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
274 | """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
275 | For each prompt, the list is obtained by splitting the prompt using the AND separator.
276 |
277 | https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
278 | """
279 |
280 | res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
281 |
282 | learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
283 |
284 | res = []
285 | for indexes in res_indexes:
286 | res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
287 |
288 | return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
289 |
290 |
291 | class DictWithShape(dict):
292 | def __init__(self, x, shape=None):
293 | super().__init__()
294 | self.update(x)
295 |
296 | @property
297 | def shape(self):
298 | return self["crossattn"].shape
299 |
300 |
301 | def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
302 | param = c[0][0].cond
303 | is_dict = isinstance(param, dict)
304 | pooled_outputs = []
305 |
306 | if is_dict:
307 | dict_cond = param
308 | if 'crossattn' in dict_cond:
309 | res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
310 | res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
311 | elif 'cond' in dict_cond:
312 | param = dict_cond['cond']
313 | res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
314 | else:
315 | res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
316 |
317 | for i, cond_schedule in enumerate(c):
318 | target_index = 0
319 | for current, entry in enumerate(cond_schedule):
320 | if current_step <= entry.end_at_step:
321 | target_index = current
322 | break
323 |
324 | cond_target = cond_schedule[target_index].cond
325 | if is_dict:
326 | if 'cond' in cond_target:
327 | res[i] = cond_target['cond']
328 | pooled_outputs.append(cond_target['pooled_output'])
329 | else:
330 | for k, param in cond_target.items():
331 | res[k][i] = param
332 | else:
333 | res[i] = cond_target
334 | pooled_outputs.append(cond_target.pooled)
335 | res.pooled = torch.cat(pooled_outputs).to(param) if pooled_outputs[0] is not None else None
336 | return res
337 |
338 |
339 | def stack_conds(tensors):
340 | # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
341 | # and won't be able to torch.stack them. So this fixes that.
342 | token_count = max([x.shape[0] for x in tensors])
343 | for i in range(len(tensors)):
344 | if tensors[i].shape[0] != token_count:
345 | last_vector = tensors[i][-1:]
346 | last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
347 | tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
348 |
349 | return torch.stack(tensors)
350 |
351 |
352 |
353 | def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
354 | param = c.batch[0][0].schedules[0].cond
355 |
356 | tensors = []
357 | pooled_outputs = []
358 | conds_list = []
359 |
360 | for composable_prompts in c.batch:
361 | conds_for_batch = []
362 |
363 | for composable_prompt in composable_prompts:
364 | target_index = 0
365 | for current, entry in enumerate(composable_prompt.schedules):
366 | if current_step <= entry.end_at_step:
367 | target_index = current
368 | break
369 |
370 | conds_for_batch.append((len(tensors), composable_prompt.weight))
371 | tensors.append(composable_prompt.schedules[target_index].cond)
372 | pooled_outputs.append(composable_prompt.schedules[target_index].cond.pooled)
373 |
374 | conds_list.append(conds_for_batch)
375 |
376 | if isinstance(tensors[0], dict):
377 | keys = list(tensors[0].keys())
378 | stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
379 | stacked = DictWithShape(stacked, stacked['crossattn'].shape)
380 | else:
381 | stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
382 | stacked.pooled = torch.cat(pooled_outputs).to(device=param.device, dtype=param.dtype)
383 | return conds_list, stacked
384 |
385 |
386 | re_attention = re.compile(r"""
387 | \\\(|
388 | \\\)|
389 | \\\[|
390 | \\]|
391 | \\\\|
392 | \\|
393 | \(|
394 | \[|
395 | :\s*([+-]?[.\d]+)\s*\)|
396 | \)|
397 | ]|
398 | [^\\()\[\]:]+|
399 | :
400 | """, re.X)
401 |
402 | re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
403 | re_attention_v1 = re_attention
404 |
405 | def parse_prompt_attention(text):
406 | """
407 | Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
408 | Accepted tokens are:
409 | (abc) - increases attention to abc by a multiplier of 1.1
410 | (abc:3.12) - increases attention to abc by a multiplier of 3.12
411 | [abc] - decreases attention to abc by a multiplier of 1.1
412 | \( - literal character '('
413 | \[ - literal character '['
414 | \) - literal character ')'
415 | \] - literal character ']'
416 | \\ - literal character '\'
417 | anything else - just text
418 |
419 | >>> parse_prompt_attention('normal text')
420 | [['normal text', 1.0]]
421 | >>> parse_prompt_attention('an (important) word')
422 | [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
423 | >>> parse_prompt_attention('(unbalanced')
424 | [['unbalanced', 1.1]]
425 | >>> parse_prompt_attention('\(literal\]')
426 | [['(literal]', 1.0]]
427 | >>> parse_prompt_attention('(unnecessary)(parens)')
428 | [['unnecessaryparens', 1.1]]
429 | >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
430 | [['a ', 1.0],
431 | ['house', 1.5730000000000004],
432 | [' ', 1.1],
433 | ['on', 1.0],
434 | [' a ', 1.1],
435 | ['hill', 0.55],
436 | [', sun, ', 1.1],
437 | ['sky', 1.4641000000000006],
438 | ['.', 1.1]]
439 | """
440 |
441 | res = []
442 | round_brackets = []
443 | square_brackets = []
444 |
445 | round_bracket_multiplier = 1.1
446 | square_bracket_multiplier = 1 / 1.1
447 | if opts.prompt_attention == 'Fixed attention':
448 | res = [[text, 1.0]]
449 | return res
450 | elif opts.prompt_attention == 'Compel parser':
451 | conjunction = Compel.parse_prompt_string(text)
452 |         if conjunction is None or conjunction.prompts is None or len(conjunction.prompts) == 0:
453 | return [["", 1.0]]
454 | use_and = '.and(' in text
455 | cprompts = conjunction.prompts if 'Blend' in conjunction.prompts[0].__class__.__name__ else [conjunction.prompts]
456 | first_el = conjunction.prompts[0].prompts[0] if hasattr(conjunction.prompts[0], 'prompts') else conjunction.prompts[0]
457 | if len(first_el.children) == 0:
458 | return [["", 1.0]]
459 | res = []
460 | for bprompt in cprompts:
461 | for prompt_idx, prompt in enumerate(bprompt.prompts if hasattr(bprompt, 'prompts') else bprompt):
462 | for k, frag in enumerate(prompt.children):
463 | res.append([f'{" AND " if use_and and k == 0 and prompt_idx > 0 else ""}{frag.text}', frag.weight])
464 | return res
465 | elif opts.prompt_attention == 'A1111 parser':
466 | re_attention = re_attention_v1
467 | whitespace = ''
468 | else:
469 | re_attention = re_attention_v1
470 | text = text.replace('\\n', ' ')
471 | whitespace = ' '
472 |
473 | def multiply_range(start_position, multiplier):
474 | for p in range(start_position, len(res)):
475 | res[p][1] *= multiplier
476 |
477 | for m in re_attention.finditer(text):
478 | text = m.group(0)
479 | weight = m.group(1)
480 |
481 | if text.startswith('\\'):
482 | res.append([text[1:], 1.0])
483 | elif text == '(':
484 | round_brackets.append(len(res))
485 | elif text == '[':
486 | square_brackets.append(len(res))
487 | elif weight is not None and round_brackets:
488 | multiply_range(round_brackets.pop(), float(weight))
489 | elif text == ')' and round_brackets:
490 | multiply_range(round_brackets.pop(), round_bracket_multiplier)
491 | elif text == ']' and square_brackets:
492 | multiply_range(square_brackets.pop(), square_bracket_multiplier)
493 | else:
494 | parts = re.split(re_break, text)
495 | for i, part in enumerate(parts):
496 | if i > 0:
497 | res.append(["BREAK", -1])
498 | if opts.prompt_attention == 'Full parser':
499 | part = re_clean.sub("", part)
500 | part = re_whitespace.sub(" ", part).strip()
501 | if len(part) == 0:
502 | continue
503 | res.append([part, 1.0])
504 |
505 | for pos in round_brackets:
506 | multiply_range(pos, round_bracket_multiplier)
507 |
508 | for pos in square_brackets:
509 | multiply_range(pos, square_bracket_multiplier)
510 |
511 | if len(res) == 0:
512 | res = [["", 1.0]]
513 |
514 | # merge runs of identical weights
515 | i = 0
516 | while i + 1 < len(res):
517 | if res[i][1] == res[i + 1][1]:
518 | res[i][0] += whitespace + res[i + 1][0]
519 | res.pop(i + 1)
520 | else:
521 | i += 1
522 |
523 | return res
524 |
525 | if __name__ == "__main__":
526 | import doctest
527 | doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
528 | import sys
529 | args = sys.argv[1:]
530 | if len(args) > 0:
531 | input_text = " ".join(args)
532 | else:
533 | input_text = '[black] [[grey]] (white) ((gray)) ((orange:1.1) yellow) ((purple) and [dark] red:1.1) [mouse:0.2] [(cat:1.1):0.5]'
534 | print(f'Prompt: {input_text}')
535 | all_schedules = get_learned_conditioning_prompt_schedules([input_text], 100)[0]
536 | print('Schedules', all_schedules)
537 | for schedule in all_schedules:
538 | print('Schedule', schedule[0])
539 | opts.prompt_attention = 'Fixed attention'
540 | output_list = parse_prompt_attention(schedule[1])
541 | print(' Fixed:', output_list)
542 | opts.prompt_attention = 'Compel parser'
543 | output_list = parse_prompt_attention(schedule[1])
544 | print(' Compel:', output_list)
545 | opts.prompt_attention = 'A1111 parser'
546 | output_list = parse_prompt_attention(schedule[1])
547 | print(' A1111:', output_list)
548 | opts.prompt_attention = 'Full parser'
549 | output_list = parse_prompt_attention(schedule[1])
550 | print(' Full :', output_list)
551 | else:
552 | import torch # doctest faster
--------------------------------------------------------------------------------
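A minimal sketch of how the schedules built above are consumed: per the get_learned_conditioning docstring, 'a [blue:green:5] jeweled crown' at 20 steps yields one conditioning valid through step 5 and another through step 20, and reconstruct_cond_batch picks the first entry whose end_at_step has not yet been passed. The tensors below are placeholders, not real CLIP embeddings:

# Illustrative sketch only; placeholder tensors stand in for real conditioning.
from collections import namedtuple
import torch

ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])

schedule = [
    ScheduledPromptConditioning(end_at_step=5, cond=torch.zeros(77, 768)),   # "blue" cond, steps 1-5
    ScheduledPromptConditioning(end_at_step=20, cond=torch.ones(77, 768)),   # "green" cond, steps 6-20
]

def pick(cond_schedule, current_step):
    # Same selection rule as reconstruct_cond_batch: first entry whose end_at_step
    # is still at or ahead of the current step, defaulting to the first entry.
    target_index = 0
    for current, entry in enumerate(cond_schedule):
        if current_step <= entry.end_at_step:
            target_index = current
            break
    return cond_schedule[target_index].cond

print(pick(schedule, 3).mean().item())    # 0.0 -> the "blue" conditioning
print(pick(schedule, 12).mean().item())   # 1.0 -> the "green" conditioning
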
/modules/text_processing/t5_engine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import namedtuple
3 | from . import prompt_parser, emphasis
4 | from comfy import model_management
5 |
6 |
7 | PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
8 |
9 | def populate_self_variables(self, from_):
10 | attrs_from = vars(from_)
11 | attrs_self = vars(self)
12 | attrs_self.update(attrs_from)
13 |
14 | class PromptChunk:
15 | def __init__(self):
16 | self.tokens = []
17 | self.multipliers = []
18 |
19 |
20 | class T5TextProcessingEngine:
21 | def __init__(self, text_encoder, tokenizer, emphasis_name="Original", min_length=256):
22 | super().__init__()
23 | populate_self_variables(self, tokenizer)
24 | self._tokenizer = tokenizer
25 |
26 | self.text_encoder = text_encoder
27 |
28 | self.emphasis = emphasis.get_current_option(emphasis_name)()
29 | self.min_length = self.min_length or self.max_length
30 | self.id_end = self.end_token
31 | self.id_pad = self.pad_token
32 | vocab = self.tokenizer.get_vocab()
33 | self.comma_token = vocab.get(',', None)
34 | self.token_mults = {}
35 | tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
36 | for text, ident in tokens_with_parens:
37 | mult = 1.0
38 | for c in text:
39 | if c == '[':
40 | mult /= 1.1
41 | if c == ']':
42 | mult *= 1.1
43 | if c == '(':
44 | mult *= 1.1
45 | if c == ')':
46 | mult /= 1.1
47 |
48 | if mult != 1.0:
49 | self.token_mults[ident] = mult
50 | self.tokenizer._eventual_warn_about_too_long_sequence = lambda *args, **kwargs: None
51 |
52 |
53 | def tokenize(self, texts):
54 | tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
55 | return tokenized
56 |
57 | def encode_with_transformers(self, tokens):
58 | try:
59 | z, pooled = self.text_encoder(tokens)
60 | except Exception:
61 | z, pooled = self.text_encoder(tokens.tolist())
62 | return z
63 |
64 | def tokenize_line(self, line):
65 | parsed = prompt_parser.parse_prompt_attention(line)
66 |
67 | tokenized = self.tokenize([text for text, _ in parsed])
68 |
69 | chunks = []
70 | chunk = PromptChunk()
71 | token_count = 0
72 |
73 | def next_chunk():
74 | nonlocal token_count
75 | nonlocal chunk
76 |
77 | chunk.tokens = chunk.tokens + [self.id_end]
78 | chunk.multipliers = chunk.multipliers + [1.0]
79 | current_chunk_length = len(chunk.tokens)
80 |
81 | token_count += current_chunk_length
82 | remaining_count = self.min_length - current_chunk_length
83 |
84 | if remaining_count > 0:
85 | chunk.tokens += [self.id_pad] * remaining_count
86 | chunk.multipliers += [1.0] * remaining_count
87 |
88 | chunks.append(chunk)
89 | chunk = PromptChunk()
90 |
91 | for tokens, (text, weight) in zip(tokenized, parsed):
92 | if text == 'BREAK' and weight == -1:
93 | next_chunk()
94 | continue
95 |
96 | position = 0
97 | while position < len(tokens):
98 | token = tokens[position]
99 | chunk.tokens.append(token)
100 | chunk.multipliers.append(weight)
101 | position += 1
102 |
103 | if chunk.tokens or not chunks:
104 | next_chunk()
105 |
106 | return chunks, token_count
107 |
108 | def unhook(self):
109 | w = '_eventual_warn_about_too_long_sequence'
110 | if hasattr(self.tokenizer, w): delattr(self.tokenizer, w)
111 | if hasattr(self._tokenizer, w): delattr(self._tokenizer, w)
112 |
113 | def tokenize_with_weights(self, texts, return_word_ids=False):
114 | tokens_and_weights = []
115 | cache = {}
116 | for line in texts:
117 | if line not in cache:
118 | chunks, token_count = self.tokenize_line(line)
119 | line_tokens_and_weights = []
120 |
121 | # Pad all chunks to the length of the longest chunk
122 | max_tokens = 0
123 | for chunk in chunks:
124 |                 max_tokens = max(len(chunk.tokens), max_tokens)
125 |
126 | for chunk in chunks:
127 | tokens = chunk.tokens
128 | multipliers = chunk.multipliers
129 | remaining_count = max_tokens - len(tokens)
130 | if remaining_count > 0:
131 | tokens += [self.id_pad] * remaining_count
132 | multipliers += [1.0] * remaining_count
133 | line_tokens_and_weights.append((tokens, multipliers))
134 | cache[line] = line_tokens_and_weights
135 |
136 | tokens_and_weights.extend(cache[line])
137 | return tokens_and_weights
138 |
139 | def encode_token_weights(self, token_weight_pairs):
140 | if isinstance(token_weight_pairs[0], str):
141 | token_weight_pairs = self.tokenize_with_weights(token_weight_pairs)
142 | elif isinstance(token_weight_pairs[0], list):
143 | token_weight_pairs = list(map(lambda x: (list(map(lambda y: y[0], x)), list(map(lambda y: y[1], x))), token_weight_pairs))
144 |
145 | target_device = model_management.text_encoder_offload_device()
146 | zs = []
147 | cache = {}
148 | for tokens, multipliers in token_weight_pairs:
149 | token_key = (tuple(tokens), tuple(multipliers))
150 | if token_key not in cache:
151 | z = self.process_tokens([tokens], [multipliers])[0]
152 | cache[token_key] = z
153 | zs.append(cache[token_key])
154 | return torch.stack(zs).to(target_device), None
155 |
156 | def __call__(self, texts):
157 | tokens = self.tokenize_with_weights(texts)
158 | return self.encode_token_weights(tokens)
159 |
160 | def process_tokens(self, batch_tokens, batch_multipliers):
161 | tokens = torch.asarray(batch_tokens)
162 |
163 | z = self.encode_with_transformers(tokens)
164 |
165 | self.emphasis.tokens = batch_tokens
166 | self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)
167 | self.emphasis.z = z
168 | self.emphasis.after_transformers()
169 | z = self.emphasis.z
170 |
171 | return z
172 |
--------------------------------------------------------------------------------
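A toy illustration of the chunk layout that tokenize_line above builds: each BREAK (weight -1) closes the current chunk with the end token and pads it to min_length, and every token inherits the weight of the text segment it came from. The vocabulary, token ids and min_length here are invented for the sketch:

# Toy sketch of tokenize_line's chunking; vocabulary and ids are made up.
parsed = [("a red crown", 1.0), ("BREAK", -1), ("sunset", 1.21)]   # parse_prompt_attention-style pairs
ID_END, ID_PAD, MIN_LENGTH = 1, 0, 8
fake_vocab = {"a": 10, "red": 11, "crown": 12, "sunset": 13}

def close_chunk(tokens, mults):
    # Append the end token, then pad with id_pad up to min_length, weight 1.0.
    tokens, mults = tokens + [ID_END], mults + [1.0]
    pad = MIN_LENGTH - len(tokens)
    return tokens + [ID_PAD] * pad, mults + [1.0] * pad

chunks, tokens, mults = [], [], []
for text, weight in parsed:
    if text == "BREAK" and weight == -1:
        chunks.append(close_chunk(tokens, mults))
        tokens, mults = [], []
        continue
    for word in text.split():
        tokens.append(fake_vocab[word])
        mults.append(weight)          # every token gets its segment's weight
chunks.append(close_chunk(tokens, mults))

for t, m in chunks:
    print(t, m)
# [10, 11, 12, 1, 0, 0, 0, 0] [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
# [13, 1, 0, 0, 0, 0, 0, 0]   [1.21, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
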
/modules/text_processing/textual_inversion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import base64
4 | import json
5 | import zlib
6 | import logging
7 | import numpy as np
8 | import safetensors.torch
9 | from PIL import Image
10 |
11 |
12 | class EmbeddingEncoder(json.JSONEncoder):
13 | def default(self, obj):
14 | if isinstance(obj, torch.Tensor):
15 | return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
16 | return json.JSONEncoder.default(self, obj)
17 |
18 |
19 | class EmbeddingDecoder(json.JSONDecoder):
20 | def __init__(self, *args, **kwargs):
21 | json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
22 |
23 | def object_hook(self, d):
24 | if 'TORCHTENSOR' in d:
25 | return torch.from_numpy(np.array(d['TORCHTENSOR']))
26 | return d
27 |
28 |
29 | def embedding_to_b64(data):
30 | d = json.dumps(data, cls=EmbeddingEncoder)
31 | return base64.b64encode(d.encode())
32 |
33 |
34 | def embedding_from_b64(data):
35 | d = base64.b64decode(data)
36 | return json.loads(d, cls=EmbeddingDecoder)
37 |
38 |
39 | def lcg(m=2 ** 32, a=1664525, c=1013904223, seed=0):
40 | while True:
41 | seed = (a * seed + c) % m
42 | yield seed % 255
43 |
44 |
45 | def xor_block(block):
46 | g = lcg()
47 | randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
48 | return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
49 |
50 |
51 | def crop_black(img, tol=0):
52 | mask = (img > tol).all(2)
53 | mask0, mask1 = mask.any(0), mask.any(1)
54 | col_start, col_end = mask0.argmax(), mask.shape[1] - mask0[::-1].argmax()
55 | row_start, row_end = mask1.argmax(), mask.shape[0] - mask1[::-1].argmax()
56 | return img[row_start:row_end, col_start:col_end]
57 |
58 |
59 | def extract_image_data_embed(image):
60 | d = 3
61 | outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
62 | black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
63 | if black_cols[0].shape[0] < 2:
64 | print(f'{os.path.basename(getattr(image, "filename", "unknown image file"))}: no embedded information found.')
65 | return None
66 |
67 | data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
68 | data_block_upper = outarr[:, black_cols[0].max() + 1:, :].astype(np.uint8)
69 |
70 | data_block_lower = xor_block(data_block_lower)
71 | data_block_upper = xor_block(data_block_upper)
72 |
73 | data_block = (data_block_upper << 4) | (data_block_lower)
74 | data_block = data_block.flatten().tobytes()
75 |
76 | data = zlib.decompress(data_block)
77 | return json.loads(data, cls=EmbeddingDecoder)
78 |
79 |
80 | class Embedding:
81 | def __init__(self, vec, name, step=None):
82 | self.vec = vec
83 | self.name = name
84 | self.step = step
85 | self.shape = None
86 | self.vectors = 0
87 | self.sd_checkpoint = None
88 | self.sd_checkpoint_name = None
89 |
90 |
91 | class DirWithTextualInversionEmbeddings:
92 | def __init__(self, path):
93 | self.path = path
94 | self.mtime = None
95 |
96 | def has_changed(self):
97 | if not os.path.isdir(self.path):
98 | return False
99 |
100 | mt = os.path.getmtime(self.path)
101 | if self.mtime is None or mt > self.mtime:
102 | return True
103 |
104 | def update(self):
105 | if not os.path.isdir(self.path):
106 | return
107 |
108 | self.mtime = os.path.getmtime(self.path)
109 |
110 |
111 | class EmbeddingDatabase:
112 | def __init__(self, tokenizer, expected_shape=-1):
113 | self.ids_lookup = {}
114 | self.word_embeddings = {}
115 | self.embedding_dirs = {}
116 | self.skipped_embeddings = {}
117 | self.expected_shape = expected_shape
118 | self.tokenizer = tokenizer
119 | self.fixes = []
120 |
121 | def add_embedding_dir(self, path):
122 | self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
123 |
124 | def clear_embedding_dirs(self):
125 | self.embedding_dirs.clear()
126 |
127 | def register_embedding(self, embedding):
128 | return self.register_embedding_by_name(embedding, embedding.name)
129 |
130 | def register_embedding_by_name(self, embedding, name):
131 | ids = self.tokenizer([name], truncation=False, add_special_tokens=False)["input_ids"][0]
132 | first_id = ids[0]
133 | if first_id not in self.ids_lookup:
134 | self.ids_lookup[first_id] = []
135 | if name in self.word_embeddings:
136 | lookup = [x for x in self.ids_lookup[first_id] if x[1].name != name]
137 | else:
138 | lookup = self.ids_lookup[first_id]
139 | if embedding is not None:
140 | lookup += [(ids, embedding)]
141 | self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
142 | if embedding is None:
143 | if name in self.word_embeddings:
144 | del self.word_embeddings[name]
145 | if len(self.ids_lookup[first_id]) == 0:
146 | del self.ids_lookup[first_id]
147 | return None
148 | self.word_embeddings[name] = embedding
149 | return embedding
150 |
151 | def load_from_file(self, path, filename):
152 | name, ext = os.path.splitext(filename)
153 | ext = ext.upper()
154 |
155 | if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
156 | _, second_ext = os.path.splitext(name)
157 | if second_ext.upper() == '.PREVIEW':
158 | return
159 |
160 | embed_image = Image.open(path)
161 | if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
162 | data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
163 | name = data.get('name', name)
164 | else:
165 | data = extract_image_data_embed(embed_image)
166 | if data:
167 | name = data.get('name', name)
168 | else:
169 | return
170 | elif ext in ['.BIN', '.PT']:
171 | data = torch.load(path, map_location="cpu")
172 | elif ext in ['.SAFETENSORS']:
173 | data = safetensors.torch.load_file(path, device="cpu")
174 | else:
175 | return
176 |
177 | emb_out = None
178 | if data is not None:
179 | embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
180 |
181 | # if self.expected_shape == -1 or self.expected_shape == embedding.shape:
182 | emb_out = self.register_embedding(embedding)
183 | # else:
184 | # emb_out = self.skipped_embeddings[name] = embedding
185 | else:
186 | print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
187 | return emb_out
188 |
189 | def load_from_dir(self, embdir):
190 | if not os.path.isdir(embdir.path):
191 | return
192 |
193 | for root, _, fns in os.walk(embdir.path, followlinks=True):
194 | for fn in fns:
195 | try:
196 | fullfn = os.path.join(root, fn)
197 |
198 | if os.stat(fullfn).st_size == 0:
199 | continue
200 |
201 | self.load_from_file(fullfn, fn)
202 | except Exception:
203 | print(f"Error loading embedding {fn}")
204 | continue
205 |
206 | def load_textual_inversion_embeddings(self):
207 | self.ids_lookup.clear()
208 | self.word_embeddings.clear()
209 | self.skipped_embeddings.clear()
210 |
211 | for embdir in self.embedding_dirs.values():
212 | self.load_from_dir(embdir)
213 | embdir.update()
214 |
215 | return
216 |
217 | def find_embedding_at_position(self, tokens, offset):
218 | token = tokens[offset]
219 | possible_matches = self.ids_lookup.get(token, None)
220 |
221 | if possible_matches is None:
222 | return None, None
223 |
224 | for ids, embedding in possible_matches:
225 | if tokens[offset:offset + len(ids)] == ids:
226 | return embedding, len(ids)
227 |
228 | return None, None
229 |
230 |
231 | def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
232 | if 'string_to_param' in data: # textual inversion embeddings
233 | param_dict = data['string_to_param']
234 | param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
235 | assert len(param_dict) == 1, 'embedding file has multiple terms in it'
236 | emb = next(iter(param_dict.items()))[1]
237 | vec = emb.detach().to(dtype=torch.float32)
238 | shape = vec.shape[-1]
239 | vectors = vec.shape[0]
240 | elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
241 | vec = {k: v.detach().to(dtype=torch.float32) for k, v in data.items()}
242 | shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
243 | vectors = data['clip_g'].shape[0]
244 | elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
245 | assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
246 |
247 | emb = next(iter(data.values()))
248 | if len(emb.shape) == 1:
249 | emb = emb.unsqueeze(0)
250 | vec = emb.detach().to(dtype=torch.float32)
251 | shape = vec.shape[-1]
252 | vectors = vec.shape[0]
253 | else:
254 |         raise Exception(f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept.")
255 |
256 | embedding = Embedding(vec, name)
257 | embedding.step = data.get('step', None)
258 | embedding.sd_checkpoint = data.get('sd_checkpoint', None)
259 | embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
260 | embedding.vectors = vectors
261 | embedding.shape = shape
262 |
263 | return embedding
264 |
265 | from comfy.sd1_clip import expand_directory_list
266 | def get_embed_file_path(embedding_name, embedding_directory):
267 | if isinstance(embedding_directory, str):
268 | embedding_directory = [embedding_directory]
269 |
270 | embedding_directory = expand_directory_list(embedding_directory)
271 |
272 | valid_file = None
273 | for embed_dir in embedding_directory:
274 | embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
275 | embed_dir = os.path.abspath(embed_dir)
276 | try:
277 | if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
278 | continue
279 | except Exception:
280 | continue
281 | if not os.path.isfile(embed_path):
282 | extensions = ['.safetensors', '.pt', '.bin']
283 | for x in extensions:
284 | t = embed_path + x
285 | if os.path.isfile(t):
286 | valid_file = t
287 | break
288 | else:
289 | valid_file = embed_path
290 | if valid_file is not None:
291 | break
292 |
293 | if valid_file is None:
294 | return None
295 |
296 | return valid_file
297 |
298 | import re
299 | from ..shared import logger
300 | emb_re_ = r"(embedding:)?(?:({}[\w\.\-\!\$\/\\]+(\.safetensors|\.pt|\.bin)|(?(1)[\w\.\-\!\$\/\\]+|(?!)))(\.safetensors|\.pt|\.bin)?)(?:(:)(\d+\.?\d*|\d*\.\d+))?"
301 | def get_valid_embeddings(embedding_directories):
302 | from builtins import any as b_any
303 | exts = ['.safetensors', '.pt', '.bin']
304 | if isinstance(embedding_directories, str):
305 | embedding_directories = [embedding_directories]
306 | embedding_directories = expand_directory_list(embedding_directories)
307 | embs = set()
308 | from collections import OrderedDict, namedtuple
309 | EmbedInfo = namedtuple('EmbedInfo', ['basename', 'filename', 'filepath'])
310 | store = OrderedDict()
311 | for embd in embedding_directories:
312 | for root, dirs, files in os.walk(embd, followlinks=True, topdown=False):
313 | for name in files:
314 | if not b_any(x in os.path.splitext(name)[1] for x in exts): continue
315 | basename = os.path.basename(name)
316 | for ext in exts: basename=basename.removesuffix(ext)
317 | relpath_basename = os.path.normpath(os.path.join(os.path.relpath(root, embd), basename))
318 | k = os.path.normpath(os.path.join(os.path.relpath(root, embd), name))
319 | store[k] = EmbedInfo(basename, name, os.path.join(root, name))
320 | # add its counterpart
321 | if '/' in k:
322 | store[k.replace('/', '\\')] = EmbedInfo(basename, name, os.path.join(root, name))
323 | elif '\\' in relpath_basename:
324 | store[k.replace('\\', '/')] = EmbedInfo(basename, name, os.path.join(root, name))
325 |
326 | embs = OrderedDict(sorted(store.items(), key=lambda item: len(item[0]), reverse=True))
327 | return embs
328 |
329 | class EmbbeddingRegex:
330 | STR_PATTERN = r"(embedding:)?(?:({}[\w\.\-\!\$\/\\]+(\.safetensors|\.pt|\.bin)|(?(1)[\w\.\-\!\$\/\\]+|(?!)))(\.safetensors|\.pt|\.bin)?)(?:(:)(\d+\.?\d*|\d*\.\d+))?"
331 | def __init__(self, embedding_directory) -> None:
332 | self.embedding_directory = embedding_directory
333 | self.embeddings = get_valid_embeddings(self.embedding_directory) if self.embedding_directory is not None else {}
334 | joined_keys = '|'.join([re.escape(os.path.splitext(k)[0]) for k in self.embeddings.keys()])
335 | emb_re = self.STR_PATTERN.format(joined_keys + '|' if joined_keys else '')
336 | self.pattern = re.compile(emb_re, flags=re.MULTILINE | re.UNICODE | re.IGNORECASE)
337 |
338 | def parse_and_register_embeddings(self, text: str):
339 | embr = EmbbeddingRegex(self.embedding_directory)
340 | embs = embr.embeddings
341 | matches = embr.pattern.finditer(text)
342 | exts = ['.pt', '.safetensors', '.bin']
343 |
344 | for matchNum, match in enumerate(matches, start=1):
345 | found=False
346 | ext = (match.group(4) or (match.group(3) or ''))
347 | embedding_sname = (match.group(2) or '').removesuffix(ext)
348 | embedding_name = embedding_sname + ext
349 | if embedding_name:
350 | embed = None
351 | if ext:
352 |                 embed_info = embs.get(embedding_name, None)  # embedding_name already carries its extension
353 | else:
354 | for _ext in exts:
355 | embed_info = embs.get(embedding_name + _ext, None)
356 | if embed_info is not None: break
357 | if embed_info is not None:
358 | found=True
359 | try:
360 | embed = self.embeddings.load_from_file(embed_info.filepath, embed_info.filename)
361 | except Exception as e:
362 |                     logging.warning(f'\033[33mWarning\033[0m loading embedding `{embedding_name}`: {e}')
363 | if embed is not None:
364 | found=True
365 | logger.debug(f'using embedding:{embedding_name}')
366 | if not found:
367 | logging.warning(f"\033[33mwarning\033[0m, embedding:{embedding_name} does not exist, ignoring")
368 | # ComfyUI trims non-existent embedding_names while A1111 doesn't.
369 | # here we get group 2,5,6. (group 2 minus its file extension)
370 | out = embr.pattern.sub(lambda m: (m.group(2) or '').removesuffix(m.group(4) or (m.group(3) or '')) + (m.group(5) or '') + (m.group(6) or ''), text)
371 | return out
--------------------------------------------------------------------------------
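A self-contained check of the embedding_to_b64 / embedding_from_b64 pair defined above: tensors are serialized as {'TORCHTENSOR': nested_list} inside JSON, base64-encoded, and rebuilt with torch.from_numpy on the way back. The import path is an assumption about how this package is laid out; adjust it to your install:

# Round-trip a tiny fake textual-inversion payload through the b64 helpers.
# Import path is assumed; it depends on where this repository is installed.
import torch
from modules.text_processing.textual_inversion import embedding_to_b64, embedding_from_b64

data = {"string_to_param": {"*": torch.zeros(2, 768)}, "name": "my-embedding"}
blob = embedding_to_b64(data)        # bytes, the same format stored in PNG text metadata ('sd-ti-embedding')
restored = embedding_from_b64(blob)  # tensors come back via the EmbeddingDecoder object_hook

assert restored["name"] == "my-embedding"
assert restored["string_to_param"]["*"].shape == torch.Size([2, 768])
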
/nodes.py:
--------------------------------------------------------------------------------
1 | import re
2 | import logging
3 | from itertools import chain
4 | from nodes import MAX_RESOLUTION
5 | import comfy.model_patcher
6 | import comfy.sd
7 | import comfy.model_management
8 | import comfy.samplers
9 | from .smZNodes import HijackClip, HijackClipComfy, get_learned_conditioning
10 | from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXL
11 |
12 | class smZ_CLIPTextEncode:
13 | @classmethod
14 | def INPUT_TYPES(s):
15 | return {"required": {
16 | "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
17 | "clip": ("CLIP", ),
18 | "parser": (["comfy", "comfy++", "A1111", "full", "compel", "fixed attention"],{"default": "comfy"}),
19 | "mean_normalization": ("BOOLEAN", {"default": True, "tooltip": "Toggles whether weights are normalized by taking the mean"}),
20 | "multi_conditioning": ("BOOLEAN", {"default": True}),
21 | "use_old_emphasis_implementation": ("BOOLEAN", {"default": False}),
22 | "with_SDXL": ("BOOLEAN", {"default": False}),
23 | "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
24 | "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
25 | "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
26 | "crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
27 | "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
28 | "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
29 | "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
30 | "text_g": ("STRING", {"multiline": True, "placeholder": "CLIP_G", "dynamicPrompts": True}),
31 | "text_l": ("STRING", {"multiline": True, "placeholder": "CLIP_L", "dynamicPrompts": True}),
32 | },
33 | "optional": {
34 | "smZ_steps": ("INT", {"default": 1, "min": 1, "max": 0xffffffffffffffff}),
35 | },
36 | }
37 | RETURN_TYPES = ("CONDITIONING",)
38 | FUNCTION = "encode"
39 | CATEGORY = "conditioning"
40 |
41 | def encode(self, clip: comfy.sd.CLIP, text, parser, mean_normalization,
42 | multi_conditioning, use_old_emphasis_implementation,
43 | with_SDXL, ascore, width, height, crop_w,
44 | crop_h, target_width, target_height, text_g, text_l, smZ_steps=1):
45 | from .modules.shared import Options, opts, opts_default
46 | debug=opts.debug # get global opts' debug
47 | if (opts_new := clip.patcher.model_options.get(Options.KEY, None)) is not None:
48 | opts.update(opts_new)
49 | debug = opts_new.debug
50 | else:
51 | opts.update(opts_default)
52 | opts.debug = debug
53 | opts.prompt_mean_norm = mean_normalization
54 | opts.use_old_emphasis_implementation = use_old_emphasis_implementation
55 | opts.multi_conditioning = multi_conditioning
56 | class_name = clip.cond_stage_model.__class__.__name__
57 | is_sdxl = "SDXL" in class_name
58 | on_sdxl = with_SDXL and is_sdxl
59 | parsers = {
60 | "full": "Full parser",
61 | "compel": "Compel parser",
62 | "A1111": "A1111 parser",
63 | "fixed attention": "Fixed attention",
64 | "comfy++": "Comfy++ parser",
65 | }
66 | opts.prompt_attention = parsers.get(parser, "Comfy parser")
67 |
68 | def _comfy_path(clip, text):
69 | nonlocal on_sdxl, class_name, ascore, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l
70 | if on_sdxl and class_name == "SDXLClipModel":
71 | return CLIPTextEncodeSDXL().encode(clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l)
72 | elif on_sdxl and class_name == "SDXLRefinerClipModel":
73 | from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXLRefiner
74 |                 return CLIPTextEncodeSDXLRefiner().encode(clip, ascore, width, height, text)
75 | else:
76 | from nodes import CLIPTextEncode
77 | return CLIPTextEncode().encode(clip, text)
78 |
79 | def comfy_path(clip):
80 | nonlocal text
81 | if on_sdxl and class_name == "SDXLRefinerClipModel":
82 | return _comfy_path(clip, text)
83 | else:
84 | if multi_conditioning:
85 | prompts = re.compile(r"\bAND\b").split(text)
86 | return (list(chain(*(_comfy_path(clip, prompt)[0] for prompt in prompts))), )
87 | else:
88 | return _comfy_path(clip, text)
89 |
90 | if parser == "comfy":
91 | with HijackClipComfy(clip) as clip:
92 | return comfy_path(clip)
93 | elif parser == "comfy++":
94 | with HijackClip(clip, opts) as clip:
95 | with HijackClipComfy(clip) as clip:
96 | return comfy_path(clip)
97 |
98 | with HijackClip(clip, opts) as clip:
99 | model = lambda txt: clip.encode_from_tokens(clip.tokenize(txt), return_pooled=True, return_dict=True)
100 | steps = max(smZ_steps, 1)
101 | if on_sdxl and class_name == "SDXLClipModel":
102 | # skip prompt-editing
103 | schedules = CLIPTextEncodeSDXL().encode(clip, width, height, crop_w, crop_h, target_width, target_height, [text_g], [text_l])[0]
104 | else:
105 | schedules = get_learned_conditioning(model, [text], steps, multi_conditioning)
106 | if on_sdxl and class_name == "SDXLRefinerClipModel":
107 | for cx in schedules:
108 |                     cx[1] |= {"aesthetic_score": ascore, "width": width, "height": height}
109 | return (schedules, )
110 |
111 | # Hack: a string subclass whose comparisons always succeed (== is always True, != always False), so it acts as a wildcard type
112 | class AnyType(str):
113 | def __eq__(self, _):
114 | return True
115 |
116 | def __ne__(self, _):
117 | return False
118 |
119 | # Our any instance wants to be a wildcard string
120 | anytype = AnyType("*")
121 |
122 | class smZ_Settings:
123 | @classmethod
124 | def INPUT_TYPES(s):
125 | from .modules.shared import opts_default as opts
126 | from .modules.text_processing.emphasis import get_options_descriptions_nl
127 | i = 0
128 | def create_heading():
129 | nonlocal i
130 | return "ㅤ"*(i:=i+1)
131 | create_heading_value = lambda x: ("STRING", {"multiline": False, "default": x, "placeholder": x})
132 | optional = {
133 | # "show_headings": ("BOOLEAN", {"default": True}),
134 | # "show_descriptions": ("BOOLEAN", {"default":True}),
135 |
136 | create_heading(): create_heading_value("Stable Diffusion"),
137 | "info_comma_padding_backtrack": ("STRING", {"multiline": True, "placeholder": "Prompt word wrap length limit\nin tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"}),
138 | "Prompt word wrap length limit": ("INT", {"default": opts.comma_padding_backtrack, "min": 0, "max": 74, "step": 1, "tooltip": "🚧Prompt word wrap length limit\n\nin tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"}),
139 | "enable_emphasis": ("BOOLEAN", {"default": opts.enable_emphasis, "tooltip": "🚧Emphasis mode\n\nmakes it possible to make model to pay (more:1.1) or (less:0.9) attention to text when you use the syntax in prompt;\n\n" + get_options_descriptions_nl()}),
140 |
141 | "info_RNG": ("STRING", {"multiline": True, "placeholder": "Random number generator source.\nchanges seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"}),
142 | "RNG": (["cpu", "gpu", "nv"],{"default": opts.randn_source, "tooltip": "Random number generator source.\n\nchanges seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"}),
143 |
144 | create_heading(): create_heading_value("Compute Settings"),
145 | "info_disable_nan_check": ("STRING", {"multiline": True, "placeholder": "Disable NaN check in produced images/latent spaces. Only for CFGDenoiser."}),
146 | "disable_nan_check": ("BOOLEAN", {"default": opts.disable_nan_check, "tooltip": "Disable NaN check in produced images/latent spaces. Only for CFGDenoiser."}),
147 |
148 | create_heading(): create_heading_value("Sampler parameters"),
149 | "info_eta_ancestral": ("STRING", {"multiline": True, "placeholder": "Eta for k-diffusion samplers\nnoise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"}),
150 | "eta": ("FLOAT", {"default": opts.eta, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Eta for k-diffusion samplers\n\nnoise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"}),
151 | "info_s_churn": ("STRING", {"multiline": True, "placeholder": "Sigma churn\namount of stochasticity; only applies to Euler, Heun, Heun++2, and DPM2"}),
152 | "s_churn": ("FLOAT", {"default": opts.s_churn, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Sigma churn\n\namount of stochasticity; only applies to Euler, Heun, Heun++2, and DPM2"}),
153 | "info_s_tmin": ("STRING", {"multiline": True, "placeholder": "Sigma tmin\nenable stochasticity; start value of the sigma range; only applies to Euler, Heun, Heun++2, and DPM2'"}),
154 | "s_tmin": ("FLOAT", {"default": opts.s_tmin, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Sigma tmin\n\nenable stochasticity; start value of the sigma range; only applies to Euler, Heun, Heun++2, and DPM2'"}),
155 | "info_s_tmax": ("STRING", {"multiline": True, "placeholder": "Sigma tmax\n0 = inf; end value of the sigma range; only applies to Euler, Heun, Heun++2, and DPM2"}),
156 | "s_tmax": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 999.0, "step": 0.01, "tooltip": "Sigma tmax\n\n0 = inf; end value of the sigma range; only applies to Euler, Heun, Heun++2, and DPM2"}),
157 | "info_s_noise": ("STRING", {"multiline": True, "placeholder": "Sigma noise\namount of additional noise to counteract loss of detail during sampling"}),
158 | "s_noise": ("FLOAT", {"default": opts.s_noise, "min": 0.0, "max": 1.1, "step": 0.001, "tooltip": "Sigma noise\n\namount of additional noise to counteract loss of detail during sampling"}),
159 | "info_eta_noise_seed_delta": ("STRING", {"multiline": True, "placeholder": "Eta noise seed delta\ndoes not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"}),
160 | "ENSD": ("INT", {"default": opts.eta_noise_seed_delta, "min": 0, "max": 0xffffffffffffffff, "step": 1, "tooltip": "Eta noise seed delta\n\ndoes not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"}),
161 | "info_skip_early_cond": ("STRING", {"multiline": True, "placeholder": "Ignore negative prompt during early sampling\ndisables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling"}),
162 | "skip_early_cond": ("FLOAT", {"default": opts.skip_early_cond, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Ignore negative prompt during early sampling\n\ndisables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling"}),
163 | "info_sgm_noise_multiplier": ("STRING", {"multiline": True, "placeholder": "SGM noise multiplier\nmatch initial noise to official SDXL implementation - only useful for reproducing images\nsee https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818"}),
164 | "sgm_noise_multiplier": ("BOOLEAN", {"default": opts.sgm_noise_multiplier, "tooltip": "SGM noise multiplier\n\nmatch initial noise to official SDXL implementation - only useful for reproducing images\nsee https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818"}),
165 | "info_upcast_sampling": ("STRING", {"multiline": True, "placeholder": "upcast sampling.\nNo effect with --force-fp32. Usually produces similar results to --force-fp32 with better performance while using less memory."}),
166 | "upcast_sampling": ("BOOLEAN", {"default": opts.upcast_sampling, "tooltip": "🚧upcast sampling.\n\nNo effect with --force-fp32. Usually produces similar results to --force-fp32 with better performance while using less memory."}),
167 |
168 | create_heading(): create_heading_value("Optimizations"),
169 | "info_NGMS": ("STRING", {"multiline": True, "placeholder": "Negative Guidance minimum sigma\nskip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster. Only for CFGDenoiser.\nsee https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177\nhttps://github.com/lllyasviel/stable-diffusion-webui-forge/pull/1434"}),
170 | "NGMS": ("FLOAT", {"default": opts.s_min_uncond, "min": 0.0, "max": 15.0, "step": 0.01, "tooltip": "Negative Guidance minimum sigma\n\nskip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster.\nsee https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177\nhttps://github.com/lllyasviel/stable-diffusion-webui-forge/pull/1434"}),
171 | "info_NGMS_all_steps": ("STRING", {"multiline": True, "placeholder": "Negative Guidance minimum sigma all steps\nBy default, NGMS above skips every other step; this makes it skip all steps"}),
172 | "NGMS all steps": ("BOOLEAN", {"default": opts.s_min_uncond_all, "tooltip": "Negative Guidance minimum sigma all steps\n\nBy default, NGMS above skips every other step; this makes it skip all steps"}),
173 | "info_pad_cond_uncond": ("STRING", {"multiline": True, "placeholder": "Pad prompt/negative prompt to be same length\nimproves performance when prompt and negative prompt have different lengths; changes seeds. Only for CFGDenoiser."}),
174 | "pad_cond_uncond": ("BOOLEAN", {"default": opts.pad_cond_uncond, "tooltip": "🚧Pad prompt/negative prompt to be same length\n\nimproves performance when prompt and negative prompt have different lengths; changes seeds. Only for CFGDenoiser."}),
175 | "info_batch_cond_uncond": ("STRING", {"multiline": True, "placeholder": "Batch cond/uncond\ndo both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed. Only for CFGDenoiser."}),
176 | "batch_cond_uncond": ("BOOLEAN", {"default": opts.batch_cond_uncond, "tooltip": "🚧Batch cond/uncond\n\ndo both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed. Only for CFGDenoiser."}),
177 |
178 | create_heading(): create_heading_value("Compatibility"),
179 | "info_use_prev_scheduling": ("STRING", {"multiline": True, "placeholder": "Previous prompt editing timelines\nFor [red:green:N]; previous: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"}),
180 | "Use previous prompt editing timelines": ("BOOLEAN", {"default": opts.use_old_scheduling, "tooltip": "🚧Previous prompt editing timelines\n\nFor [red:green:N]; previous: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"}),
181 |
182 | create_heading(): create_heading_value("Experimental"),
183 | "info_use_CFGDenoiser": ("STRING", {"multiline": True, "placeholder": "CFGDenoiser\nAn experimental option to use stable-diffusion-webui's denoiser. It allows you to use the 'Optimizations' settings listed here."}),
184 | "Use CFGDenoiser": ("BOOLEAN", {"default": opts.use_CFGDenoiser, "tooltip": "🚧CFGDenoiser\n\nAn experimental option to use stable-diffusion-webui's denoiser. It allows you to use the 'Optimizations' settings listed here."}),
185 | "info_debug": ("STRING", {"multiline": True, "placeholder": "Debugging messages in the console."}),
186 | "debug": ("BOOLEAN", {"default": opts.debug, "label_on": "on", "label_off": "off", "tooltip": "Debugging messages in the console."}),
187 | }
188 | return {
189 | "required": {
190 | "*": (anytype, {"forceInput": True}),
191 | },
192 | "optional": {
193 | "extra": ("STRING", {"multiline": True, "default": '{"show_headings":true,"show_descriptions":false,"mode":"*"}'}),
194 | **optional,
195 | },
196 | }
197 | RETURN_TYPES = (anytype,)
198 | FUNCTION = "apply"
199 | CATEGORY = "advanced"
200 | OUTPUT_TOOLTIPS = ("The model used for denoising latents.",)
201 |
202 | def apply(self, *args, **kwargs):
203 | first = kwargs.pop('*', None) if '*' in kwargs else args[0]
204 |         if first is None or not hasattr(first, 'clone'): return (first,)
205 |
206 | kwargs['s_min_uncond'] = kwargs.pop('NGMS', 0.0)
207 | kwargs['s_min_uncond_all'] = kwargs.pop('NGMS all steps', False)
208 | kwargs['comma_padding_backtrack'] = kwargs.pop('Prompt word wrap length limit')
209 | kwargs['use_old_scheduling']=kwargs.pop("Use previous prompt editing timelines")
210 | kwargs['use_CFGDenoiser'] = kwargs.pop("Use CFGDenoiser")
211 | kwargs['randn_source'] = kwargs.pop('RNG')
212 | kwargs['eta_noise_seed_delta'] = kwargs.pop('ENSD')
213 | kwargs['s_tmax'] = kwargs['s_tmax'] or float('inf')
214 |
215 | from .modules.shared import Options, logger, opts_default, opts as opts_global
216 |
217 | opts_global.update(opts_default)
218 | opts = opts_default.clone()
219 | kwargs_new = {k: v for k, v in kwargs.items() if not ('info' in k or 'heading' in k or 'ㅤ' in k)}
220 | opts.update(kwargs_new)
221 | opts_global.debug = opts.debug
222 |
223 | opts_key = Options.KEY
224 | if isinstance(first, comfy.model_patcher.ModelPatcher):
225 | first = first.clone()
226 | first.model_options[opts_key] = opts
227 | elif isinstance(first, comfy.sd.CLIP):
228 | first = first.clone()
229 | first.patcher.model_options[opts_key] = opts
230 | logger.setLevel(logging.DEBUG if opts_global.debug else logging.INFO)
231 | return (first,)
232 |
233 | # A dictionary that contains all nodes you want to export with their names
234 | # NOTE: names should be globally unique
235 | NODE_CLASS_MAPPINGS = {
236 | "smZ CLIPTextEncode": smZ_CLIPTextEncode,
237 | "smZ Settings": smZ_Settings,
238 | }
239 | # A dictionary that contains the friendly/humanly readable titles for the nodes
240 | NODE_DISPLAY_NAME_MAPPINGS = {
241 | "smZ CLIPTextEncode" : "CLIP Text Encode++",
242 | "smZ Settings" : "Settings (smZ)",
243 | }
244 |
--------------------------------------------------------------------------------
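A small, standalone illustration of the AnyType wildcard used by smZ_Settings above: because equality always succeeds and inequality always fails, type comparisons against the '*' input and output never reject a connection. This snippet only mirrors the class defined in nodes.py:

# Standalone copy of the wildcard-string trick from nodes.py, for illustration only.
class AnyType(str):
    def __eq__(self, _):
        return True
    def __ne__(self, _):
        return False

anytype = AnyType("*")
print(anytype == "MODEL", anytype == "CLIP")   # True True
print(anytype != "MODEL")                      # False
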
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyui_smznodes"
3 | description = "Nodes such as CLIP Text Encode++ to achieve identical embeddings from stable-diffusion-webui for ComfyUI."
4 | version = "1.2.19"
5 | license = { file = "LICENSE" }
6 | dependencies = ["lark >= 1.1.9", "compel"]
7 |
8 | [project.urls]
9 | Repository = "https://github.com/shiimizu/ComfyUI_smZNodes"
10 | # Used by Comfy Registry https://comfyregistry.org
11 |
12 | [tool.comfy]
13 | PublisherId = "shiimizu"
14 | DisplayName = "ComfyUI_smZNodes"
15 | Icon = ""
16 |
--------------------------------------------------------------------------------
/smZNodes.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import re
3 | import torch
4 | import inspect
5 | import contextlib
6 | import logging
7 | import comfy
8 | import math
9 | import ctypes
10 | from decimal import Decimal
11 | from functools import partial
12 | from random import getrandbits
13 | import comfy.sdxl_clip
14 | import comfy.sd1_clip
15 | import comfy.sample
16 | import comfy.utils
17 | import comfy.samplers
18 | from comfy.sd1_clip import unescape_important, escape_important, token_weights
19 | from .modules.shared import SimpleNamespaceFast, Options, logger, join_args
20 | from .modules.text_processing import prompt_parser
21 | from .modules.text_processing.past_classic_engine import process_texts_past
22 | from .modules.text_processing.textual_inversion import EmbbeddingRegex
23 | from .modules.text_processing.classic_engine import ClassicTextProcessingEngine
24 | from .modules.text_processing.t5_engine import T5TextProcessingEngine
25 |
26 | class Store(SimpleNamespaceFast): ...
27 |
28 | store = Store()
29 |
30 | def register_hooks():
31 | from .modules.rng import prepare_noise
32 | patches = [
33 | (comfy.samplers, 'get_area_and_mult', get_area_and_mult),
34 | (comfy.samplers.KSampler, 'sample', KSampler_sample),
35 | (comfy.samplers.KSAMPLER, 'sample', KSAMPLER_sample),
36 | (comfy.samplers, 'sample', sample),
37 | (comfy.samplers.Sampler, 'max_denoise', max_denoise),
38 | (comfy.samplers, 'sampling_function', sampling_function),
39 | (comfy.sample, 'prepare_noise', prepare_noise),
40 | ]
41 | for parent, fn_name, fn_patch in patches:
42 | if not hasattr(store, fn_patch.__name__):
43 | setattr(store, fn_patch.__name__, getattr(parent, fn_name))
44 | setattr(parent, fn_name, fn_patch)
45 |
46 | def iter_items(d):
47 | for key, value in d.items():
48 | yield key, value
49 | if isinstance(value, dict):
50 | yield from iter_items(value)
51 |
52 | def find_nearest(a,b):
53 | # Calculate the absolute differences.
54 | diff = (a - b).abs()
55 |
56 | # Find the indices of the nearest elements
57 | nearest_indices = diff.argmin()
58 |
59 | # Get the nearest elements from b
60 | return b[nearest_indices]
61 |
62 | def get_area_and_mult(*args, **kwargs):
63 | conds = args[0]
64 | if 'start_perc' in conds and 'end_perc' in conds and "init_steps" in conds:
65 | timestep_in = args[2]
66 | sigmas = store.sigmas
67 | if conds['init_steps'] == sigmas.shape[0] - 1:
68 | total = Decimal(sigmas.shape[0] - 1)
69 | else:
70 | sigmas_ = store.sigmas.unique(sorted=True).sort(descending=True)[0]
71 | if len(sigmas) == len(sigmas_):
72 | # Sampler Custom with sigmas: no change
73 | total = Decimal(sigmas.shape[0] - 1)
74 | else:
75 | # Sampler with restarts: dedup the sigmas and add one
76 | sigmas = sigmas_
77 | total = Decimal(sigmas.shape[0] + 1)
78 | ts_in = find_nearest(timestep_in, sigmas)
79 | cur_i = ss[0].item() if (ss:=(sigmas == ts_in).nonzero()).shape[0] != 0 else 0
80 | cur = Decimal(cur_i) / total
81 | start = conds['start_perc']
82 | end = conds['end_perc']
83 | if not (cur >= start and cur < end):
84 | return None
85 | return store.get_area_and_mult(*args, **kwargs)
86 |
87 | def KSAMPLER_sample(*args, **kwargs):
88 | orig_fn = store.KSAMPLER_sample
89 | extra_args = None
90 | model_options = None
91 | try:
92 | extra_args = kwargs['extra_args'] if 'extra_args' in kwargs else args[3]
93 | model_options = extra_args['model_options']
94 | except Exception: ...
95 | if model_options is not None and extra_args is not None:
96 | sigmas_ = kwargs['sigmas'] if 'sigmas' in kwargs else args[2]
97 | sigmas_all = model_options.pop('sigmas', None)
98 | sigmas = sigmas_all if sigmas_all is not None else sigmas_
99 | store.sigmas = sigmas
100 |
101 | import comfy.k_diffusion.sampling
102 | if hasattr(comfy.k_diffusion.sampling, 'default_noise_sampler_orig'):
103 | if getattr(comfy.k_diffusion.sampling.default_noise_sampler, 'init', False):
104 | comfy.k_diffusion.sampling.default_noise_sampler.init = False
105 | else:
106 | comfy.k_diffusion.sampling.default_noise_sampler = comfy.k_diffusion.sampling.default_noise_sampler_orig
107 |
108 | if 'Hijack' in comfy.k_diffusion.sampling.torch.__class__.__name__:
109 | if getattr(comfy.k_diffusion.sampling.torch, 'init', False):
110 | comfy.k_diffusion.sampling.torch.init = False
111 | else:
112 | if hasattr(comfy.k_diffusion.sampling, 'torch_orig'):
113 | comfy.k_diffusion.sampling.torch = comfy.k_diffusion.sampling.torch_orig
114 | return orig_fn(*args, **kwargs)
115 |
116 | def KSampler_sample(*args, **kwargs):
117 | orig_fn = store.KSampler_sample
118 | self = args[0]
119 | model_patcher = getattr(self, 'model', None)
120 | model_options = getattr(model_patcher, 'model_options', None)
121 | if model_options is not None:
122 | sigmas = None
123 | try: sigmas = kwargs['sigmas'] if 'sigmas' in kwargs else args[10]
124 | except Exception: ...
125 | if sigmas is None:
126 | sigmas = getattr(self, 'sigmas', None)
127 | if sigmas is not None:
128 | model_options = model_options.copy()
129 | model_options['sigmas'] = sigmas
130 | self.model.model_options = model_options
131 | return orig_fn(*args, **kwargs)
132 |
133 | def sample(*args, **kwargs):
134 | orig_fn = store.sample
135 | model_patcher = args[0]
136 | model_options = getattr(model_patcher, 'model_options', None)
137 | sampler = kwargs['sampler'] if 'sampler' in kwargs else args[6]
138 | if model_options is not None and Options.KEY in model_options:
139 | if hasattr(sampler, 'sampler_function'):
140 | opts = model_options[Options.KEY]
141 |             if not hasattr(sampler, '_sampler_function'):
142 | sampler._sampler_function = sampler.sampler_function
143 | sampler_function_sig_params = inspect.signature(sampler._sampler_function).parameters
144 | params = {x: getattr(opts, x) for x in ['eta', 's_churn', 's_tmin', 's_tmax', 's_noise'] if x in sampler_function_sig_params}
145 | sampler.sampler_function = lambda *a, **kw: sampler._sampler_function(*a, **{**kw, **params})
146 | else:
147 | if hasattr(sampler, '_sampler_function'):
148 | sampler.sampler_function = sampler._sampler_function
149 | return orig_fn(*args, **kwargs)
150 |
151 | def max_denoise(*args, **kwargs):
152 | orig_fn = store.max_denoise
153 | model_wrap = kwargs['model_wrap'] if 'model_wrap' in kwargs else args[1]
154 | base_model = getattr(model_wrap, 'inner_model', None)
155 | model_options = getattr(model_wrap, 'model_options', getattr(base_model, 'model_options', None))
156 | return orig_fn(*args, **kwargs) if getattr(model_options.get(Options.KEY, True), 'sgm_noise_multiplier', True) else False
157 |
158 | def sampling_function(*args, **kwargs):
159 | orig_fn = store.sampling_function
160 | model_options = kwargs['model_options'] if 'model_options' in kwargs else args[6]
161 | model_options = model_options.copy()
162 | kwargs['model_options'] = model_options
163 | if Options.KEY in model_options:
164 | opts = model_options[Options.KEY]
165 | if opts.s_min_uncond_all or opts.s_min_uncond > 0 or opts.skip_early_cond > 0:
166 | cond_scale = _cond_scale = kwargs['cond_scale'] if 'cond_scale' in kwargs else args[5]
167 | sigmas = store.sigmas
168 | sigma = kwargs['timestep'] if 'timestep' in kwargs else args[2]
169 | ts_in = find_nearest(sigma, sigmas)
170 | step = ss[0].item() if (ss:=(sigmas == ts_in).nonzero()).shape[0] != 0 else 0
171 | total_steps = sigmas.shape[0] - 1
172 |
173 | if opts.skip_early_cond > 0 and step / total_steps <= opts.skip_early_cond:
174 | cond_scale = 1.0
175 | elif (step % 2 or opts.s_min_uncond_all) and opts.s_min_uncond > 0 and sigma[0] < opts.s_min_uncond:
176 | cond_scale = 1.0
177 |
178 | if cond_scale != _cond_scale:
179 | if 'cond_scale' not in kwargs:
180 | args = args[:5]
181 | kwargs['cond_scale'] = cond_scale
182 |
183 | cond = kwargs['cond'] if 'cond' in kwargs else args[4]
184 | weights = [x.get('weight', None) for x in cond]
185 | has_some = any(item is not None for item in weights) and len(weights) > 1
186 | if has_some:
187 | out = CFGDenoiser(orig_fn).sampling_function(*args, **kwargs)
188 | else:
189 | out = orig_fn(*args, **kwargs)
190 | return out
191 |
192 |
193 | @contextlib.contextmanager
194 | def HijackClip(clip, opts):
195 | a1 = 'tokenizer', 'tokenize_with_weights'
196 | a2 = 'cond_stage_model', 'encode_token_weights'
197 | ls = [a1, a2]
198 | store = {}
199 | store_orig = {}
200 | try:
201 | for obj, attr in ls:
202 | for clip_name, v in iter_items(getattr(clip, obj).__dict__):
203 | if hasattr(v, attr):
204 | logger.debug(join_args(attr, obj, clip_name, type(v).__qualname__, getattr(v, attr).__qualname__))
205 | if clip_name not in store_orig:
206 | store_orig[clip_name] = {}
207 | store_orig[clip_name][obj] = v
208 | for clip_name, inner_store in store_orig.items():
209 | text_encoder = inner_store['cond_stage_model']
210 | tokenizer = inner_store['tokenizer']
211 | emphasis_name = 'Original' if opts.prompt_mean_norm else "No norm"
212 | if 't5' in clip_name:
213 | text_processing_engine = T5TextProcessingEngine(
214 | text_encoder=text_encoder,
215 | tokenizer=tokenizer,
216 | emphasis_name=emphasis_name,
217 | )
218 | else:
219 | text_processing_engine = ClassicTextProcessingEngine(
220 | text_encoder=text_encoder,
221 | tokenizer=tokenizer,
222 | emphasis_name=emphasis_name,
223 | )
224 | text_processing_engine.opts = opts
225 | text_processing_engine.process_texts_past = partial(process_texts_past, text_processing_engine)
226 | store[clip_name] = text_processing_engine
227 | for obj, attr in ls:
228 | setattr(inner_store[obj], attr, getattr(store[clip_name], attr))
229 | yield clip
230 | finally:
231 | for clip_name, inner_store in store_orig.items():
232 | getattr(inner_store[a2[0]], a2[1]).__self__.unhook()
233 | for obj, attr in ls:
234 | try: delattr(inner_store[obj], attr)
235 | except Exception: ...
236 | del store
237 | del store_orig
238 |
239 | @contextlib.contextmanager
240 | def HijackClipComfy(clip):
241 | a1 = 'tokenizer', 'tokenize_with_weights'
242 | ls = [a1]
243 | store_orig = {}
244 | try:
245 | for obj, attr in ls:
246 | for clip_name, v in iter_items(getattr(clip, obj).__dict__):
247 | if hasattr(v, attr):
248 | logger.debug(join_args(attr, obj, clip_name, type(v).__qualname__, getattr(v, attr).__qualname__))
249 | if clip_name not in store_orig:
250 | store_orig[clip_name] = {}
251 | store_orig[clip_name][obj] = v
252 | setattr(v, attr, partial(tokenize_with_weights_custom, v))
253 | yield clip
254 | finally:
255 | for clip_name, inner_store in store_orig.items():
256 | for obj, attr in ls:
257 | try: delattr(inner_store[obj], attr)
258 | except Exception: ...
259 | del store_orig
260 |
261 | def transform_schedules(steps, schedules, weight=None, with_weight=False):
262 | end_steps = [schedule.end_at_step for schedule in schedules]
263 | start_end_pairs = list(zip([0] + end_steps[:-1], end_steps))
264 | with_prompt_editing = len(schedules) > 1
265 |
266 | def process(schedule, start_step, end_step):
267 | nonlocal with_prompt_editing
268 | d = schedule.cond.copy()
269 | d.pop('cond', None)
270 | if with_prompt_editing:
271 | d |= {"start_perc": Decimal(start_step) / Decimal(steps), "end_perc": Decimal(end_step) / Decimal(steps), "init_steps": steps}
272 | if weight is not None and with_weight:
273 | d['weight'] = weight
274 | return d
275 | return [
276 | [
277 | schedule.cond.get("cond", None),
278 | process(schedule, start_step, end_step)
279 | ]
280 | for schedule, (start_step, end_step) in zip(schedules, start_end_pairs)
281 | ]
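# Illustrative result (tensors and extra cond keys elided): for a prompt edited
# at step 10 of 20, the function above returns two ComfyUI-style [cond, dict] pairs,
#   [[cond_a, {..., "start_perc": Decimal(0)/Decimal(20), "end_perc": Decimal(10)/Decimal(20), "init_steps": 20}],
#    [cond_b, {..., "start_perc": Decimal(10)/Decimal(20), "end_perc": Decimal(20)/Decimal(20), "init_steps": 20}]]
# whose percentage window is later consumed by the hijacked get_area_and_mult.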
282 |
283 | def flatten(nested_list):
284 | return [item for sublist in nested_list for item in sublist]
285 |
286 | def convert_schedules_to_comfy(schedules, steps, multi=False):
287 | if multi:
288 | out = [[transform_schedules(steps, x.schedules, x.weight, len(batch)>1) for x in batch] for batch in schedules.batch]
289 | out = flatten(out)
290 | else:
291 | out = [transform_schedules(steps, sublist) for sublist in schedules]
292 | return flatten(out)
293 |
294 | def get_learned_conditioning(model, prompts, steps, multi=False, *args, **kwargs):
295 | if multi:
296 | schedules = prompt_parser.get_multicond_learned_conditioning(model, prompts, steps, *args, **kwargs)
297 | else:
298 | schedules = prompt_parser.get_learned_conditioning(model, prompts, steps, *args, **kwargs)
299 | schedules_c = convert_schedules_to_comfy(schedules, steps, multi)
300 | return schedules_c
301 |
302 | class CustomList(list):
303 | def __init__(self, *args):
304 | super().__init__(*args)
305 | def __setattr__(self, name: str, value):
306 | super().__setattr__(name, value)
307 | return self
308 |
309 | def modify_locals_values(frame, fn):
310 | # https://stackoverflow.com/a/34671307
311 | try: ctypes.pythonapi.PyFrame_LocalsToFast(ctypes.py_object(frame), ctypes.c_int(1))
312 | except Exception: ...
313 | fn(frame)
314 | try: ctypes.pythonapi.PyFrame_LocalsToFast(ctypes.py_object(frame), ctypes.c_int(1))
315 | except Exception: ...
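# Note: PyFrame_LocalsToFast is a CPython implementation detail that writes
# frame.f_locals back into the frame's fast-locals array; PEP 667 (Python 3.13)
# changes frame-locals semantics, so the try/except above tolerates the call
# being absent or a no-op. Usage sketch with assumed names:
#   frame = inspect.currentframe().f_back
#   modify_locals_values(frame, lambda f: f.f_locals.__setitem__('x', 42))
# which rebinds the caller's local variable `x` to 42 where the call succeeds.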
316 |
317 | def update_locals(frame,k,v,list_app=False):
318 | if not list_app:
319 | modify_locals_values(frame, lambda _frame: _frame.f_locals.__setitem__(k, v))
320 | else:
321 | if not isinstance(frame.f_locals[k], CustomList):
322 | out_conds_store = CustomList([])
323 | out_conds_store.outputs=[]
324 | modify_locals_values(frame, lambda _frame: _frame.f_locals.__setitem__(k, out_conds_store))
325 | v.area = frame.f_locals['area']
326 | v.mult = frame.f_locals['mult']
327 | frame.f_locals[k].outputs.append(v)
328 | frame.f_locals[k].out_conds = frame.f_locals['out_conds']
329 | frame.f_locals[k].out_counts = frame.f_locals['out_counts']
330 | modify_locals_values(frame, lambda _frame: _frame.f_locals.__setitem__('batch_chunks', 0))
331 |
332 | def model_function_wrapper_cd(model, args, id, model_options={}):
333 | input_x = args['input']
334 | timestep_ = args['timestep']
335 | c = args['c']
336 | cond_or_uncond = args['cond_or_uncond']
337 | batch_chunks = len(cond_or_uncond)
338 | if f'model_function_wrapper_{id}' in model_options:
339 | output = model_options[f'model_function_wrapper_{id}'](model, args)
340 | else:
341 | output = model(input_x, timestep_, **c)
342 | output.cond_or_uncond = cond_or_uncond
343 | output.batch_chunks = batch_chunks
344 | output.output_chunks = output.chunk(batch_chunks)
345 | output.chunk = lambda *aa, **kw: output
346 | get_parent_variable('out_conds', list, lambda frame: update_locals(frame, 'out_conds', output, list_app=True))
347 | return output
348 |
349 | def get_parent_variable(vname, vtype, fn):
350 | frame = inspect.currentframe().f_back # Get the current frame's parent
351 | while frame:
352 | if vname in frame.f_locals:
353 | val = frame.f_locals[vname]
354 | if isinstance(val, vtype):
355 | if fn is not None:
356 | fn(frame)
357 | return frame.f_locals[vname]
358 | frame = frame.f_back
359 | return None
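# Usage sketch (illustrative): get_parent_variable('out_conds', list, None)
# walks up the caller's stack and returns the nearest enclosing frame's local
# named `out_conds` that is a list (or None if no such frame exists); the
# model_function_wrapper_cd hook above uses it to reach into
# comfy.samplers.calc_cond_batch's locals.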
360 |
361 | def cd_cfg_function(kwargs, id):
362 | model_options = kwargs['model_options']
363 | if f"sampler_cfg_function_{id}" in model_options:
364 | return model_options[f'sampler_cfg_function_{id}'](kwargs)
365 | x = kwargs['input']
366 | cond_pred = kwargs['cond_denoised']
367 | uncond_pred = kwargs['uncond_denoised']
368 | cond_scale = kwargs['cond_scale']
369 | cfg_result = model_options['cfg_result']
370 | cfg_result += (cond_pred - uncond_pred) * cond_scale
371 | return x - cfg_result
372 |
373 | class CFGDenoiser:
374 | def __init__(self, orig_fn) -> None:
375 | self.orig_fn = orig_fn
376 |
377 | def sampling_function(self, model, x, timestep, uncond, cond, cond_scale, model_options, *args0, **kwargs0):
378 | if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
379 | uncond_ = None
380 | else:
381 | uncond_ = uncond
382 |
383 | conds = [cond, uncond_]
384 |
385 | if uncond_ is None:
386 | return self.orig_fn(model, x, timestep, uncond, cond, cond_scale, model_options, *args0, **kwargs0)
387 |
388 | id = getrandbits(7)
389 | if 'model_function_wrapper' in model_options:
390 | model_options[f'model_function_wrapper_{id}'] = model_options.pop('model_function_wrapper')
391 | model_options['model_function_wrapper'] = partial(model_function_wrapper_cd, id=id, model_options=model_options)
392 | out = comfy.samplers.calc_cond_batch(model, conds, x, timestep, model_options)
393 | model_options.pop('model_function_wrapper', None)
394 | if f'model_function_wrapper_{id}' in model_options:
395 | model_options['model_function_wrapper'] = model_options.pop(f'model_function_wrapper_{id}')
396 |
397 | outputs = out.outputs
398 |
399 | out_conds = out.out_conds
400 | out_counts = out.out_counts
401 | if len(out_conds) < len(out_counts):
402 | for _ in out_counts:
403 | out_conds.append(torch.zeros_like(outputs[0].output_chunks[0]))
404 |
405 | oconds = []
406 | for _output in outputs:
407 | cond_or_uncond = _output.cond_or_uncond
408 | batch_chunks = _output.batch_chunks
409 | output = _output.output_chunks
410 | area = _output.area
411 | mult = _output.mult
412 | for o in range(batch_chunks):
413 | cond_index = cond_or_uncond[o]
414 | a = area[o]
415 | if a is None:
416 | if cond_index == 0:
417 | oconds.append(output[o] * mult[o])
418 | else:
419 | out_conds[cond_index] += output[o] * mult[o]
420 | out_counts[cond_index] += mult[o]
421 | else:
422 | out_c = out_conds[cond_index] if cond_index != 0 else torch.zeros_like(out_conds[cond_index])
423 | out_cts = out_counts[cond_index]
424 | dims = len(a) // 2
425 | for i in range(dims):
426 | out_c = out_c.narrow(i + 2, a[i + dims], a[i])
427 | out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
428 | out_c += output[o] * mult[o]
429 | out_cts += mult[o]
430 | if cond_index == 0:
431 | oconds.append(out_c)
432 |
433 | for i in range(len(out_conds)):
434 | if i != 0:
435 | out_conds[i] /= out_counts[i]
436 |
437 | del out
438 | out = out_conds
439 |
440 | for fn in model_options.get("sampler_pre_cfg_function", []):
441 | out[0] = torch.cat(oconds).to(oconds[0])
442 | args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
443 | "input": x, "sigma": timestep, "model": model, "model_options": model_options}
444 | out = fn(args)
445 |
446 | # ComfyUI: last prompt -> first
447 | # conds were reversed in calc_cond_batch, so do the same for weights
448 | weights = [x.get('weight', None) for x in cond]
449 | weights.reverse()
450 | out_uncond = out[1]
451 | cfg_result = out_uncond.clone()
452 | cond_scale = cond_scale / max(len(oconds), 1)
453 |
454 | if "sampler_cfg_function" in model_options:
455 | model_options[f'sampler_cfg_function_{id}'] = model_options.pop('sampler_cfg_function')
456 | model_options['sampler_cfg_function'] = partial(cd_cfg_function, id=id)
457 | model_options['cfg_result'] = cfg_result
458 |
459 | # ComfyUI: computes the average -> do cfg
460 | # A1111: (cond - uncond) / total_len_of_conds -> in-place addition for each cond -> results in cfg
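# In effect (sketch): with N prompt chunks c_1..c_N, weights w_i, uncond u, and
# the per-chunk scale s = cond_scale / N computed above, the loop below
# accumulates
#     cfg_result = u + sum_i w_i * s * (denoised(c_i) - denoised(u))
# i.e. A1111-style multi-cond CFG rather than ComfyUI's average-then-CFG.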
461 | for ix, ocond in enumerate(oconds):
462 | weight = (weights[ix:ix+1] or [1.0])[0] or 1.0
463 | # cfg_result += (ocond - out_uncond) * (weight * cond_scale) # all this code just to do this
464 | if f"sampler_cfg_function_{id}" in model_options:
465 | # case when there's another cfg_fn. subtract out_uncond and in-place add the result. feed result back in.
466 | cfg_result += comfy.samplers.cfg_function(model, ocond, out_uncond, weight * cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_) - out_uncond
467 | else: # calls cd_cfg_function and does an in-place addition
468 | if model_options.get("sampler_post_cfg_function", []):
469 | # feed the result back in.
470 | cfg_result = comfy.samplers.cfg_function(model, ocond, out_uncond, weight * cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
471 | else:
472 | # default case. discards the output.
473 | comfy.samplers.cfg_function(model, ocond, out_uncond, weight * cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
474 | model_options['cfg_result'] = cfg_result
475 | return cfg_result
476 |
477 | def tokenize_with_weights_custom(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
478 | '''
479 | Takes a prompt and converts it to a list of (token, weight, word id) elements.
480 | Tokens can be either integer tokens or precomputed CLIP tensors.
481 | Word id values are unique per word and embedding, with id 0 reserved for non-word tokens.
482 | The returned list has dimensions NxM, where M is the input size of CLIP.
483 | '''
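# Illustrative output shape (token ids are placeholders): for a typical SD1
# CLIP-L tokenizer, a call such as
#   tokenize_with_weights_custom(tok, "a (red:1.2) cat")
# yields [[(start_token, 1.0), (id_a, 1.0), (id_red, 1.2), (id_cat, 1.0), (end_token, 1.0), (pad_token, 1.0), ...]]
# i.e. one max_length-sized batch of (token, weight) pairs per 75-token chunk,
# with a word id appended to each tuple when return_word_ids=True.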
484 | min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
485 | min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
486 |
487 | text = escape_important(text)
488 | parsed_weights = token_weights(text, 1.0)
489 | embr = EmbbeddingRegex(self.embedding_directory)
490 |
491 | # tokenize words
492 | tokens = []
493 | for weighted_segment, weight in parsed_weights:
494 | to_tokenize = unescape_important(weighted_segment)
495 | split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
496 | to_tokenize = [split[0]]
497 | for i in range(1, len(split)):
498 | to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
499 |
500 | to_tokenize = [x for x in to_tokenize if x != ""]
501 | for word in to_tokenize:
502 | matches = embr.pattern.finditer(word)
503 | last_end = 0
504 | leftovers=[]
505 | for _, match in enumerate(matches, start=1):
506 | start=match.start()
507 | end_match=match.end()
508 | if (fragment:=word[last_end:start]):
509 | leftovers.append(fragment)
510 | ext = (match.group(4) or (match.group(3) or ''))
511 | embedding_sname = (match.group(2) or '').removesuffix(ext)
512 | embedding_name = embedding_sname + ext
513 | if embedding_name:
514 | embed, leftover = self._try_get_embedding(embedding_name)
515 | if embed is None:
516 | logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
517 | else:
518 | logger.debug(f'using embedding:{embedding_name}')
519 | if len(embed.shape) == 1:
520 | tokens.append([(embed, weight)])
521 | else:
522 | tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
523 | last_end = end_match
524 | if (fragment:=word[last_end:]):
525 | leftovers.append(fragment)
526 | word_new = ''.join(leftovers)
527 | end = 999999999999
528 | if self.tokenizer_adds_end_token:
529 | end = -1
530 | #parse word
531 | tokens.append([(t, weight) for t in self.tokenizer(word_new)["input_ids"][self.tokens_start:end]])
532 |
533 | #reshape token array to CLIP input size
534 | batched_tokens = []
535 | batch = []
536 | if self.start_token is not None:
537 | batch.append((self.start_token, 1.0, 0))
538 | batched_tokens.append(batch)
539 | for i, t_group in enumerate(tokens):
540 | #determine if we're going to try and keep the tokens in a single batch
541 | is_large = len(t_group) >= self.max_word_length
542 | if self.end_token is not None:
543 | has_end_token = 1
544 | else:
545 | has_end_token = 0
546 |
547 | while len(t_group) > 0:
548 | if len(t_group) + len(batch) > self.max_length - has_end_token:
549 | remaining_length = self.max_length - len(batch) - has_end_token
550 | #break word in two and add end token
551 | if is_large:
552 | batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
553 | if self.end_token is not None:
554 | batch.append((self.end_token, 1.0, 0))
555 | t_group = t_group[remaining_length:]
556 | #add end token and pad
557 | else:
558 | if self.end_token is not None:
559 | batch.append((self.end_token, 1.0, 0))
560 | if self.pad_to_max_length:
561 | batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
562 | #start new batch
563 | batch = []
564 | if self.start_token is not None:
565 | batch.append((self.start_token, 1.0, 0))
566 | batched_tokens.append(batch)
567 | else:
568 | batch.extend([(t,w,i+1) for t,w in t_group])
569 | t_group = []
570 |
571 | #fill last batch
572 | if self.end_token is not None:
573 | batch.append((self.end_token, 1.0, 0))
574 | if min_padding is not None:
575 | batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
576 | if self.pad_to_max_length and len(batch) < self.max_length:
577 | batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
578 | if min_length is not None and len(batch) < min_length:
579 | batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
580 |
581 | if not return_word_ids:
582 | batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
583 |
584 | return batched_tokens
585 |
586 | # ========================================================================
587 |
588 | from server import PromptServer
589 |
590 | def is_prompt_editing(schedules):
591 | if schedules is None: return False
592 | if not isinstance(schedules, dict):
593 | schedules = {'g': schedules}
594 | ret = False
595 | for k,v in schedules.items():
596 | if type(v) is dict and 'schedules' in v:
597 | v=v['schedules']
598 | if type(v) == list:
599 | for vb in v:
600 | if len(vb) != 1: ret = True
601 | else:
602 | if v:
603 | for vb in v.batch:
604 | for cs in vb:
605 | if len(cs.schedules) != 1: ret = True
606 | return ret
607 |
608 | def prompt_handler(json_data):
609 | data=json_data['prompt']
610 | steps_validator = lambda x: isinstance(x, (int, float, str))
611 | text_validator = lambda x: isinstance(x, str)
612 | def find_nearest_ksampler(clip_id):
613 | """Find the nearest KSampler node that references the given CLIPTextEncode id."""
614 | nonlocal data, steps_validator
615 | for ksampler_id, node in data.items():
616 | if "class_type" in node and ("Sampler" in node["class_type"] or "sampler" in node["class_type"]):
617 | # Check if this KSampler node directly or indirectly references the given CLIPTextEncode node
618 | if check_link_to_clip(ksampler_id, clip_id):
619 | return get_val(data, ksampler_id, steps_validator, 'steps')
620 | return None
621 |
622 | def get_val(graph, node_id, validator, val):
623 | node = graph.get(str(node_id), {})
624 | if val == 'steps':
625 | steps_input_value = node.get("inputs", {}).get("steps", None)
626 | if steps_input_value is None:
627 | steps_input_value = node.get("inputs", {}).get("sigmas", None)
628 | else:
629 | steps_input_value = node.get("inputs", {}).get(val, None)
630 |
631 | while True:
632 | # Base case: it's a direct value
633 | if not isinstance(steps_input_value, list) and validator(steps_input_value):
634 | if val == 'steps':
635 | s = 1
636 | try: s = min(max(1, int(steps_input_value)), 10000)
637 | except Exception as e:
638 | logging.warning(f"\033[33mWarning:\033[0m [smZNodes] Skipping prompt editing. Try recreating the node. {e}")
639 | return s
640 | else:
641 | return steps_input_value
642 | # Loop case: it's a reference to another node
643 | elif isinstance(steps_input_value, list):
644 | ref_node_id, ref_input_index = steps_input_value
645 | ref_node = graph.get(str(ref_node_id), {})
646 | steps_input_value = ref_node.get("inputs", {}).get(val, None)
647 | if steps_input_value is None:
648 | keys = list(ref_node.get("inputs", {}).keys())
649 | ref_input_key = keys[ref_input_index % len(keys)]
650 | steps_input_value = ref_node.get("inputs", {}).get(ref_input_key)
651 | else:
652 | return None
653 |
654 | def check_link_to_clip(node_id, clip_id, visited=None):
655 | """Check if a given node links directly or indirectly to a CLIPTextEncode node."""
656 | nonlocal data
657 | if visited is None:
658 | visited = set()
659 |
660 | node = data[node_id]
661 |
662 | if node_id in visited:
663 | return False
664 | visited.add(node_id)
665 |
666 | for input_value in node["inputs"].values():
667 | if isinstance(input_value, list) and input_value[0] == clip_id:
668 | return True
669 | if isinstance(input_value, list) and check_link_to_clip(input_value[0], clip_id, visited):
670 | return True
671 |
672 | return False
673 |
674 |
675 | # Update each CLIPTextEncode node's steps with the steps from its nearest referencing KSampler node
676 | for clip_id, node in data.items():
677 | if "class_type" in node and node["class_type"] == "smZ CLIPTextEncode":
678 | check_str = prompt_editing = False
679 | if check_str:
680 | if (fast_search:=True):
681 | with_SDXL = get_val(data, clip_id, lambda x: isinstance(x, (bool, int, float)), 'with_SDXL')
682 | if with_SDXL:
683 | ls = is_prompt_editing_str(get_val(data, clip_id, text_validator, 'text_l'))
684 | gs = is_prompt_editing_str(get_val(data, clip_id, text_validator, 'text_g'))
685 | prompt_editing = ls or gs
686 | else:
687 | text = get_val(data, clip_id, text_validator, 'text')
688 | prompt_editing = is_prompt_editing_str(text)
689 | else:
690 | text = get_val(data, clip_id, text_validator, 'text')
691 | prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules([text], steps, None, False)
692 | prompt_editing = sum([len(ps) for ps in prompt_schedules]) != 1
693 | if check_str and not prompt_editing: continue
694 | steps = find_nearest_ksampler(clip_id)
695 | if steps is not None:
696 | node["inputs"]["smZ_steps"] = steps
697 | # logger.debug(f'id: {clip_id} | steps: {steps}')
698 | return json_data
699 |
700 | def is_prompt_editing_str(t: str):
701 | """
702 | Determine if a string includes prompt editing.
703 | This won't cover every case, but it does the job for most.
704 | """
705 | if t is None: return True
706 | if (openb:=t.find('[')) != -1:
707 | if (colon:=t.find(':', openb)) != -1 and t.find(']', colon) != -1:
708 | return True
709 | elif (pipe:=t.find('|', openb)) != -1 and t.find(']', pipe) != -1:
710 | return True
711 | return False
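# Examples (illustrative):
#   is_prompt_editing_str("a [cat:dog:0.5] photo")  -> True   (prompt editing)
#   is_prompt_editing_str("a [cat|dog] photo")      -> True   (alternation)
#   is_prompt_editing_str("a plain prompt")         -> False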
712 |
713 | if hasattr(PromptServer.instance, 'add_on_prompt_handler'):
714 | PromptServer.instance.add_on_prompt_handler(prompt_handler)
715 |
716 | # ========================================================================
717 |
718 | # DPM++ 2M alt
719 |
720 | from tqdm.auto import trange
721 | @torch.no_grad()
722 | def sample_dpmpp_2m_alt(model, x, sigmas, extra_args=None, callback=None, disable=None):
723 | """DPM-Solver++(2M)."""
724 | extra_args = {} if extra_args is None else extra_args
725 | s_in = x.new_ones([x.shape[0]])
726 | sigma_fn = lambda t: t.neg().exp()
727 | t_fn = lambda sigma: sigma.log().neg()
728 | old_denoised = None
729 |
730 | for i in trange(len(sigmas) - 1, disable=disable):
731 | denoised = model(x, sigmas[i] * s_in, **extra_args)
732 | if callback is not None:
733 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
734 | t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
735 | h = t_next - t
736 | if old_denoised is None or sigmas[i + 1] == 0:
737 | x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
738 | else:
739 | h_last = t - t_fn(sigmas[i - 1])
740 | r = h_last / h
741 | denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
742 | x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
743 | sigma_progress = i / len(sigmas)
744 | adjustment_factor = 1 + (0.15 * (sigma_progress * sigma_progress))
745 | old_denoised = denoised * adjustment_factor
746 | return x
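# Sketch of how this differs from the stock dpmpp_2m sampler: the standard
# update keeps `old_denoised = denoised`, whereas the "alt" variant above
# scales it by 1 + 0.15 * (i / len(sigmas))**2, weighting the history term
# slightly more as sampling progresses.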
747 |
748 |
749 | def add_sample_dpmpp_2m_alt():
750 | from comfy.samplers import KSampler, k_diffusion_sampling
751 | if "dpmpp_2m_alt" not in KSampler.SAMPLERS:
752 | try:
753 | idx = KSampler.SAMPLERS.index("dpmpp_2m")
754 | KSampler.SAMPLERS.insert(idx+1, "dpmpp_2m_alt")
755 | setattr(k_diffusion_sampling, 'sample_dpmpp_2m_alt', sample_dpmpp_2m_alt)
756 | except Exception: ...
757 |
758 | def add_custom_samplers():
759 | samplers = [
760 | add_sample_dpmpp_2m_alt,
761 | ]
762 | for add_sampler in samplers:
763 | add_sampler()
764 |
--------------------------------------------------------------------------------
/web/exif.js:
--------------------------------------------------------------------------------
1 | // Original file: /npm/exif-js@2.3.0/exif.js
2 | // (minified exif-js 2.3.0 library source omitted)
--------------------------------------------------------------------------------
/web/smZdynamicWidgets.js:
--------------------------------------------------------------------------------
14 | export const findWidgetByName = (node, name) => node.widgets.find((w) => w.name === name);
15 | export const findWidgetsByName = (node, name) => node.widgets.filter((w) => w.name.endsWith(name));
16 |
17 | // export const doesInputWithNameExist = (node, name) => node.inputs ? node.inputs.some((input) => input.name === name) : false;
18 |
19 | // round in increments of n, with an offset
20 | export function round(number, increment = 10, offset = 0) {
21 | return Math.ceil((number - offset) / increment ) * increment + offset;
22 | }
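// e.g. round(101) === 110, round(101, 10, 5) === 105, round(7, 4) === 8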
23 |
24 | export function toggleWidget(node, widget, force) {
25 | if (!widget) return;
26 | // if (!widget || doesInputWithNameExist(node, widget.name)) return;
27 | widget.options[HIDDEN_TAG] ??= (widget.options.origType = widget.type, widget.options.origComputeSize = widget.computeSize, HIDDEN_TAG);
28 |
29 | const hide = force ?? (widget.type !== HIDDEN_TAG);
30 |
31 | widget.type = hide ? widget.options[HIDDEN_TAG] : widget.options.origType;
32 |
33 | widget.computeSize = hide ? () => [0, -3.3] : widget.options.origComputeSize;
34 |
35 | widget.linkedWidgets?.forEach(w => toggleWidget(node, w, force));
36 |
37 | for (const el of ["inputEl", "input"])
38 | widget[el]?.classList?.toggle(HIDDEN_TAG, force);
39 |
40 | const height = hide ? node.size[1] : Math.max(node.computeSize()[1], node.size[1]);
41 | node.setSize([node.size[0], height]);
42 |
43 | if (hide)
44 | widget.computedHeight = 0;
45 | else
46 | delete widget.computedHeight;
47 | }
48 |
49 | // Toggles a menu widget for both its canvas and HTML counterpart. Also accounts for group nodes.
50 | // Passing an array with a companion_widget_name means the toggle
51 | // is applied to every widget whose name matches it.
52 | // Useful for duplicates created in group nodes.
53 | export function toggleMenuOption(node, widget_arr, show) {
54 | const [widget_name, companion_widget_name] = Array.isArray(widget_arr) ? widget_arr : [widget_arr]
55 | let arr = [widget_name];
56 | // companion_widget_name to use the new name assigned in a group node to get the correct widget
57 | if (companion_widget_name) {
58 | for (const gnc of getGroupNodeConfig(node)) {
59 | const omap = Object.values(gnc.oldToNewWidgetMap).find(x => Object.values(x).find(z => z === companion_widget_name));
60 | const n = omap[widget_name];
61 | if(n) arr.push(n);
62 | }
63 | }
64 | const widgets = companion_widget_name ? arr.map(it => findWidgetByName(node, it)) : findWidgetsByName(node, arr[0])
65 | const hide = show !== undefined ? !show : undefined;
66 | widgets.forEach(widget => {
67 | toggleWidget(node, widget, hide)
68 | node.setDirtyCanvas(true);
69 | });
70 | }
71 |
72 | export function getGroupNodeConfig(node) {
73 | let ls = []
74 | let nodeData = node.constructor?.nodeData
75 | if (nodeData) {
76 | for(const sym of Object.getOwnPropertySymbols(nodeData) ) {
77 | const o = nodeData[sym];
78 | if (o) ls.push(o)
79 | }
80 | }
81 | return ls
82 | }
83 |
84 | export function widgetLogic(node, widget) {
85 | const wname = widget.name
86 | const uoei = widgets[widgets.length - 2]
87 | if (wname.endsWith("parser")) {
88 | const in_comfy = widget?.value?.includes?.("comfy")
89 | // toggleMenuOption(node, ['multi_conditioning', wname], !in_comfy)
90 | toggleMenuOption(node, ['mean_normalization', wname], widget.value !== "comfy")
91 | toggleMenuOption(node, [uoei, wname], in_comfy ? false : findWidgetsByName(node, uoei)?.some(x => x?.value))
92 | } else if (wname.endsWith("with_SDXL")) {
93 | toggleMenuOption(node, ['text', wname], !widget.value)
94 | toggleMenuOption(node, ['multi_conditioning', wname], !widget.value)
95 |
96 | // Resize node when widget is set to false
97 | if (!widget.value) {
98 | // Prevents resizing on init/webpage reload
99 | if(widget.init === false) {
100 | // Resize when set to false
101 | node.setSize([node.size[0], Math.max(100, round(node.size[1]/1.5))])
102 | }
103 | } else {
104 | // When enabled, set init to false
105 | widget.init = false
106 | }
107 |
108 | // Toggle sdxl widgets if sdxl widget value is true/false
109 | for (const w of widgets_sdxl) {
110 | toggleMenuOption(node, [w, wname], widget.value)
111 | }
112 |
113 | // Keep showing the widget if it's enabled
114 | if (widget.value && widget.type === HIDDEN_TAG) {
115 | toggleMenuOption(node, [widget.name, wname], true)
116 | }
117 | }
118 | }
119 |
120 | // Specific to CLIPTextEncode
121 | function applyWidgetLogic(node) {
122 | if (!node.widgets?.length) return;
123 |
124 | const uoei = widgets[widgets.length - 2]
125 | const in_comfy = findWidgetsByName(node, "parser")?.some(it => it?.value?.includes?.("comfy"));
126 | const uoei_w = findWidgetByName(node, uoei);
127 | toggleMenuOption(node, [uoei, uoei], in_comfy ? false : uoei_w.value)
128 |
129 | let gncl = getGroupNodeConfig(node)
130 | for (const w of node.widgets) {
131 | for (const gsw of [...getSetWidgets]) {
132 | if (!w.name.endsWith(gsw)) continue;
133 | // Possibly unneeded
134 | /*
135 | let shouldBreak = false
136 | for (const gnc of gncl) {
137 | const nwmap = gnc.newToOldWidgetMap[w.name]
138 | console.log('=== gnc.newToOldWidgetMap',gnc.newToOldWidgetMap,'w.name',w.name,'nwmap',nwmap)
139 | // nwmap.inputName: resolved, actual widget name.
140 | if (nwmap && !(ids1.has(nwmap.node.type) && nwmap.inputName === gsw))
141 | shouldBreak = true
142 | }
143 | if (shouldBreak) break;
144 | */
145 | widgetLogic(node, w);
146 |
147 | let val = w.value;
148 | Object.defineProperty(w, 'value', {
149 | get() {
150 | return val;
151 | },
152 | set(newVal) {
153 | if (newVal !== val) {
154 | val = newVal
155 | widgetLogic(node, w);
156 | }
157 | }
158 | });
159 |
160 | // Hide SDXL widget on init
161 | // Doing it in nodeCreated fixes its toggling for some reason
162 | if (w.name.endsWith('with_SDXL')) {
163 | toggleMenuOption(node, ['with_SDXL', w.name])
164 | w.init = true
165 |
166 | // Hide steps
167 | toggleMenuOption(node, ['smZ_steps', w.name], false)
168 | }
169 | }
170 | }
171 |
172 | // Reduce initial node size because of SDXL widgets
173 | // node.setSize([node.size[0], Math.max(node.size[1]/1.5, 220)])
174 | node.setSize([node.size[0], 220])
175 | }
176 |
177 | function create_custom_option(content, _callback) {
178 | return {
179 | content: content,
180 | callback: () => _callback(),
181 | }
182 | };
183 |
184 | function widgetLogicSettings(node) {
185 | const supported_types = {MODEL: 'MODEL', CLIP: 'CLIP'}
186 | const clip_headings = ['Stable Diffusion', 'Compatibility']
187 | const clip_entries =
188 | ["info_comma_padding_backtrack",
189 | "Prompt word wrap length limit",
190 | "enable_emphasis",
191 | "info_use_prev_scheduling",
192 | "Use previous prompt editing timelines"]
193 |
194 | if(!node.widgets?.length) return;
195 |
196 | const index = node.index || 0
197 |
198 | const extra = node.widgets.find(w => w.name.endsWith('extra'))
199 | let extra_data = extra._value
200 | toggleMenuOption(node, extra.name, false)
201 | const condition = (sup) => node.outputs?.[index] && (node.outputs[index].name === sup || node.outputs[index].type === sup) ||
202 | node.inputs?.[index] && node.inputs[index].type === sup;
203 | for (const w of node.widgets) {
204 | if (w.name.endsWith('extra')) continue;
205 | // w._name: to make sure it's from our node (though it deserves a better variable name)
206 | if(!w._name) continue;
207 |
208 | // Heading `values` won't get duplicated names due to group nodes,
209 | // so we don't need to do `clip_headings.some()`
210 | // (though it should be read-only...)
211 | if (condition(supported_types.MODEL)) {
212 | const flag=(clip_entries.some(str => str.includes(w.name)) || (typeof w.value === 'string' && w.heading && w.value.includes('Compatibility')))
213 | if (w.info && !clip_entries.some(str => str.includes(w.name)))
214 | toggleMenuOption(node, [w.name, w.name], extra_data.show_descriptions)
215 | else if (w.heading && typeof w.value === 'string' && !w.value.includes('Compatibility'))
216 | toggleMenuOption(node, [w.name, w.name], extra_data.show_headings)
217 | else
218 | toggleMenuOption(node, [w.name, w.name], !flag)
219 | // toggleMenuOption(node, w.name, flag ? false : true) //doesn't work?
220 | } else if (condition(supported_types.CLIP)) {
221 | // if w.name in list -> enable, else disable
222 | const flag = clip_entries.some(str => str.includes(w.name))
223 | if (w.info && flag)
224 | toggleMenuOption(node, [w.name, w.name], extra_data.show_descriptions)
225 | else if (w.heading && typeof w.value === 'string' && clip_headings.includes(w.value))
226 | toggleMenuOption(node, [w.name, w.name], extra_data.show_headings)
227 | else
228 | toggleMenuOption(node, [w.name, w.name], flag)
229 | } else {
230 | toggleMenuOption(node, w.name, false)
231 | }
232 | }
233 |
234 | let skip = [HEADING_IDENTIFIER, "info", "extra"];
235 | let i = node.inputs.length;
236 | while (i--) {
237 | const ni = node.inputs[i];
238 | const wname = ni?.widget?.name;
239 | if (i !== 0 && skip.some(it => wname?.includes(it))) {
240 | node.inputs.splice(i, 1);
241 | }
242 | }
243 |
244 | node.setSize([node.size[0], node.computeSize()[1]])
245 | node.onResize?.(node.size)
246 | node.setDirtyCanvas(true);
247 | _app.graph.setDirtyCanvas(true, true);
248 | }
249 |
250 | _app.registerExtension({
251 | name: "Comfy.smZ.dynamicWidgets",
252 |
253 | init() {
254 | let style = document.createElement('style');
255 | let placeholder_opacity = 0.75;
256 | const createCssText = (it) => `.smZ-custom-textarea::-${it} { color: inherit; opacity: ${placeholder_opacity}}`;
257 | // WebKit, IE, Firefox
258 | ['webkit-input-placeholder', 'ms-input-placeholder', 'moz-placeholder']
259 | .map(createCssText).forEach(it => style.appendChild(document.createTextNode(it)));
260 | style.appendChild(document.createTextNode(`.${HIDDEN_TAG} { hidden: "hidden"; display: none; }`));
261 | document.head.appendChild(style);
262 | },
263 |
264 | /**
265 | * Called when a node is created. Used to add menu options to nodes.
266 | * @param node The node that was created.
267 | * @param app The app.
268 | */
269 | nodeCreated(node, app) {
270 | const nodeType = node.type || node.constructor?.type
271 | const anyType = "*";
272 | let inGroupNode = false
273 | let inGroupNode2 = false
274 | let nodeData = node.constructor?.nodeData
275 | if (nodeData) {
276 | for(let sym of Object.getOwnPropertySymbols(nodeData) ) {
277 | const nds = nodeData[sym];
278 | const nodes = nds?.nodeData?.nodes
279 | if (nodes) {
280 | for (const _node of nodes) {
281 | const _nodeType = _node.type
282 | if (ids1.has(_nodeType))
283 | inGroupNode = true
284 | if (inGroupNode)
285 | ids1.add(nodeType) // GroupNode's type
286 | if (ids2.has(_nodeType))
287 | inGroupNode2 = true
288 | if (inGroupNode2)
289 | ids2.add(nodeType) // GroupNode's type
290 | }
291 | }
292 | }
293 | }
294 |
295 | // ClipTextEncode++ node
296 | if (ids1.has(nodeType) || inGroupNode) {
297 | applyWidgetLogic(node)
298 | }
299 | // Settings node
300 | if (ids2.has(nodeType) || inGroupNode2) {
301 | if(!node.properties)
302 | node.properties = {}
303 | node.properties.showOutputText = true
304 |
305 |
306 | // allows bypass (in conjunction with the `allows bypass` block below)
307 | // by setting node.inputs[0].type to a concrete type, instead of anyType. ComfyUI will complain otherwise.
308 | node.onBeforeConnectInput = function(inputIndex) {
309 | if (inputIndex !== node.index) return inputIndex
310 |
311 | // so we can connect to reroutes
312 | const tp = 'Reroute'
313 | node.type = tp
314 | if (node.constructor) node.constructor.type=tp
315 | this.type = tp
316 | if (this.constructor) this.constructor.type=tp
317 | Object.assign(node.inputs[inputIndex], {...node.inputs[inputIndex], name: anyType, type: anyType});
318 | Object.assign(this.inputs[inputIndex], {...this.inputs[inputIndex], name: anyType, type: anyType});
319 | node.beforeConnectInput = true
320 | this.beforeConnectInput = true
321 | return inputIndex;
322 | }
323 |
324 | // Call once on node creation
325 | node.setupWidgetLogic = function() {
326 | let nt = nodeType // JSON.parse(JSON.stringify(nodeType))
327 | node.type = nodeType
328 | if (node.constructor) node.constructor.type=nodeType
329 | node.applyOrientation = function() {
330 | node.type = nodeType
331 | if (node.constructor) node.constructor.type=nodeType
332 | }
333 | node.index = 0
334 | let i = 0
335 | const innerNode = node.getInnerNodes?.().find(n => {const r = ids2.has(n.type); if (r) node.index = i; ++i; return r } )
336 | const innerNodeWidgets = innerNode?.widgets
337 | i = 0
338 | node.widgets.forEach(function(w) {
339 | if (innerNodeWidgets) {
340 | if(innerNodeWidgets.some(iw => w.name.endsWith(iw.name)))
341 | w._name = w.name
342 | } else
343 | w._name = w.name
344 |
345 | // Styling.
346 | if (w.name.includes(HEADING_IDENTIFIER)) {
347 | w.heading = true
348 | // w.disabled = true
349 | } else if (w.name.includes('info')) {
350 | w.info = true
351 | for (const el of ["input", "inputEl"]) {
352 | if (w[el]) {
353 | w[el].disabled = true;
354 | w[el].readOnly = true;
355 | w[el].style.alignContent = 'center';
356 | w[el].style.textAlign = 'center';
357 | if (!w[el].classList.contains('smZ-custom-textarea'))
358 | w[el].classList.add('smZ-custom-textarea');
359 | }
360 | }
361 | }
362 | })
363 | const extra_widget = node.widgets.find(w => w.name.endsWith('extra'))
364 | if (extra_widget) {
365 | let extra_data = null
366 | try {
367 | extra_data = JSON.parse(extra_widget.value);
368 | } catch (error) {
369 | // when node definitions change due to an update or some other error
370 | extra_data = {show_headings: true, show_descriptions: false, mode: anyType}
371 | }
372 | extra_widget._value = extra_data
373 | Object.defineProperty(extra_widget, '_value', {
374 | get() {
375 | return extra_data;
376 | },
377 | set(newVal) {
378 | extra_data = newVal
379 | extra_widget.value = JSON.stringify(extra_data)
380 | }
381 | });
382 | }
383 |
384 | widgetLogicSettings(node);
385 | // Hijack getting our node type so we can work with Reroutes
386 | Object.defineProperty(node.constructor, 'type', {
387 | get() {
388 | let s = new Error().stack
389 | // const rr = ['rerouteNode.', 'baseCreateRenderer']
390 | const rr = ['rerouteNode.']
391 | // const rr = ['rerouteNode.js', 'reroutePrimitive.js']
392 | // const rr = ['rerouteNode.js', 'groupNode.js', 'reroutePrimitive.js']
393 | if (rr.some(rx => s.includes(rx))) {
394 | return 'Reroute'
395 | }
396 | return nt;
397 | },
398 | set(newVal) {
399 | if (newVal !== nt) {
400 | nt = newVal
401 | }
402 | }
403 | });
404 |
405 | if (node.outputs[node.index]) {
406 | let val = node.outputs[node.index].type;
407 | // Hijacks getting/setting type
408 | Object.defineProperty(node.outputs[node.index], 'type', {
409 | get() {
410 | return val;
411 | },
412 | set(newVal) {
413 | if (newVal !== val) {
414 | val = newVal
415 | // console.log('====== group node test. node', node, 'group', node.getInnerNodes?.())
416 | if (node.inputs && node.inputs[node.index]) node.inputs[node.index].type = newVal
417 | if (node.outputs && node.outputs[node.index]) node.outputs[node.index].name = newVal || anyType;
418 | node.properties.showOutputText = true // Reroute node accesses this
419 | node.type = nodeType
420 | if (node.constructor) node.constructor.type=nodeType
421 | // this.type = nodeType
422 | // if (this.constructor) this.constructor.type=nodeType
423 | // console.log('==== setupWidgetLogic', `val: '${val}' newval: '${newVal}' '${node.outputs[node.index].name}'`)
424 | widgetLogicSettings(node);
425 | }
426 | }
427 | });
428 | }
429 | }
430 | node.setupWidgetLogic()
431 |
432 | const onConfigure = node.onConfigure;
433 | node.onConfigure = function(o) {
434 | // Call again after the node is created since there might be links,
435 | // for example when we reload the graph
436 | node.setupWidgetLogic()
437 | const r = onConfigure ? onConfigure.apply(this, arguments) : undefined;
438 | return r;
439 | }
440 |
441 | // ================= Adapted from rerouteNode.js =================
442 | node.onAfterGraphConfigured = function () {
443 | requestAnimationFrame(() => {
444 | node.onConnectionsChange(LiteGraph.INPUT, null, true, null);
445 | });
446 | };
447 | node.onConnectionsChange = function (type, index, connected, link_info) {
448 | // node.index = index || 0
449 | // index = node.index || 0
450 | if (index !== node.index) return
451 | node.setupWidgetLogic()
452 | // if (index === node.index) node.setupWidgetLogic()
453 |
454 | const type_map = {
455 | [LiteGraph.OUTPUT] : 'OUTPUT',
456 | [LiteGraph.INPUT] : 'INPUT',
457 | }
458 |
459 | // console.log("======== onConnectionsChange type", type, "connected", connected, 'app.graph.links[l]',app.graph.links)
460 | // console.log('=== app.graph.links',app.graph.links, 'node.inputs[index]',node.inputs[index],'node.outputs[index]',node.outputs[index])
461 | // console.log('=== onConnectionsChange type', type_map[type], 'index',index,'connected',connected,'node.inputs', node.inputs,'node.outputs', node.outputs,'link_info',link_info, 'node', node)
462 |
463 | node.type = nodeType
464 | if (node.constructor) node.constructor.type=nodeType
465 | this.type = nodeType
466 | if (this.constructor) this.constructor.type=nodeType
467 |
468 |
469 | // Prevent multiple connections to different types when we have no input
470 | if (connected && type === LiteGraph.OUTPUT) {
471 | // Ignore wildcard nodes as these will be updated to real types
472 | const types = this.outputs?.[index]?.links ? new Set(this.outputs[index].links.map((l) => app.graph.links[l]?.type)?.filter((t) => t !== anyType)) : new Set()
473 | if (types?.size > 1) {
474 | const linksToDisconnect = [];
475 | for (let i = 0; i < this.outputs[index].links.length - 1; i++) {
476 | const linkId = this.outputs[index].links[i];
477 | const link = app.graph.links[linkId];
478 | if (link)
479 | linksToDisconnect.push(link);
480 | }
481 | for (const link of linksToDisconnect) {
482 | if (link) {
483 | const node = app.graph.getNodeById(link.target_id);
484 | node.disconnectInput(link.target_slot);
485 | }
486 | }
487 | }
488 | }
489 |
490 | // Find root input
491 | let currentNode = this;
492 | let updateNodes = [];
493 | let inputType = null;
494 | let inputNode = null;
495 | while (currentNode) {
496 | updateNodes.unshift(currentNode);
497 | const linkId = currentNode?.inputs?.[index]?.link;
498 | if (linkId) {
499 | const link = app.graph.links[linkId];
500 | if (!link) return;
501 | const node = app.graph.getNodeById(link.origin_id);
502 | const type = node.constructor.type || node.type;
503 | if (type === "Reroute" || ids2.has(type)) {
504 | if (node === this) {
505 | // We've found a circle
506 | currentNode.disconnectInput(link.target_slot);
507 | currentNode = null;
508 | } else {
509 | // Move the previous node
510 | currentNode = node;
511 | }
512 | } else {
513 | // We've found the end
514 | inputNode = currentNode;
515 | inputType = node.outputs[link.origin_slot]?.type ?? null;
516 | break;
517 | }
518 | } else {
519 | // This path has no input node
520 | currentNode = null;
521 | break;
522 | }
523 | }
524 |
525 | // Find all outputs
526 | const nodes = [this];
527 | // const nodes = [link_info?.origin_id ? app.graph.getNodeById(link_info.origin_id).outputs ?
528 | // app.graph.getNodeById(link_info.origin_id): this : this,this];
529 | let outputType = null;
530 | while (nodes.length) {
531 | currentNode = nodes.pop();
532 | // if (currentNode.outputs)
533 | const outputs = (currentNode && currentNode.outputs?.[index]?.links ? currentNode.outputs?.[index]?.links : []) || [];
534 | // console.log('=== .outputs',outputs,'currentNode',currentNode)
535 | if (outputs.length) {
536 | for (const linkId of outputs) {
537 | const link = app.graph.links[linkId];
538 |
539 | // When disconnecting sometimes the link is still registered
540 | if (!link) continue;
541 |
542 | const node = app.graph.getNodeById(link.target_id);
543 | const type = node.constructor.type || node.type;
544 |
545 | if (type === "Reroute" || ids2.has(type)) {
546 | // Follow reroute nodes
547 | nodes.push(node);
548 | updateNodes.push(node);
549 | } else {
550 | // We've found an output
551 | const nodeOutType =
552 | node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type
553 | ? node.inputs[link.target_slot].type
554 | : null;
555 | if (inputType && inputType !== anyType && nodeOutType !== inputType) {
556 | // The output doesn't match our input so disconnect it
557 | node.disconnectInput(link.target_slot);
558 | } else {
559 | outputType = nodeOutType;
560 | }
561 | }
562 | }
563 | } else {
564 | // No more outputs for this path
565 | }
566 | }
567 |
568 | const displayType = inputType || outputType || anyType;
569 | const color = LGraphCanvas.link_type_colors[displayType];
570 |
571 | let widgetConfig;
572 | let targetWidget;
573 | let widgetType;
574 | // Update the types of each node
575 | for (const node of updateNodes) {
576 | // If we don't have an input type we are always wildcard, but we'll show the output type
577 | // This lets you change the output link to a different type and all nodes will update
578 | if (!(node.outputs && node.outputs[index])) continue
579 | node.outputs[index].type = inputType || anyType;
580 | node.__outputType = displayType;
581 | node.outputs[index].name = node.properties.showOutputText ? displayType : "";
582 | // node.size = node.computeSize();
583 | // node.applyOrientation();
584 |
585 | for (const l of node.outputs[index].links || []) {
586 | const link = app.graph.links[l];
587 | if (link) {
588 | link.color = color;
589 |
590 | if (app.configuringGraph) continue;
591 | const targetNode = app.graph.getNodeById(link.target_id);
592 | const targetInput = targetNode.inputs?.[link.target_slot];
593 | if (targetInput?.widget) {
594 | const config = getWidgetConfig(targetInput);
595 | if (!widgetConfig) {
596 | widgetConfig = config[1] ?? {};
597 | widgetType = config[0];
598 | }
599 | if (!targetWidget) {
600 | targetWidget = targetNode.widgets?.find((w) => w.name === targetInput.widget.name);
601 | }
602 |
603 | const merged = mergeIfValid(targetInput, [config[0], widgetConfig]);
604 | if (merged.customConfig) {
605 | widgetConfig = merged.customConfig;
606 | }
607 | }
608 | }
609 | }
610 | }
611 |
612 | for (const node of updateNodes) {
613 | if (!(node.inputs && node.inputs[index])) continue
614 | if (widgetConfig && outputType) {
615 | node.inputs[index].widget = { name: "value" };
616 | setWidgetConfig(node.inputs[index], [widgetType ?? displayType, widgetConfig], targetWidget);
617 | } else {
618 | setWidgetConfig(node.inputs[index], null);
619 | }
620 | }
621 |
622 | if (inputNode && inputNode.inputs[index]) {
623 | const link = app.graph.links[inputNode.inputs[index].link];
624 | if (link) {
625 | link.color = color;
626 | if (node.outputs?.[index]) {
627 | node.outputs[index].name = inputNode.__outputType || inputNode.outputs[index].type;
628 |
629 | // allows bypass
630 | node.inputs[index] = Object.assign(node.inputs[0], {...node.inputs[index], name: anyType, type: node.outputs[index].name});
631 | }
632 | }
633 | }
634 | }
635 | // ================= Adapted from rerouteNode.js =================
636 | }
637 |
638 | // Add extra MenuOptions for
639 | // ClipTextEncode++ and Settings node
640 | if (ids1.has(nodeType) || inGroupNode || ids2.has(nodeType) || inGroupNode2) {
641 | // Save the original options
642 | const getExtraMenuOptions = node.getExtraMenuOptions;
643 |
644 | node.getExtraMenuOptions = function (_, options) {
645 | // Call the original function for the default menu options
646 | const r = getExtraMenuOptions ? getExtraMenuOptions.apply(this, arguments) : undefined;
647 | let customOptions = []
648 | node.setDirtyCanvas(true, true);
649 | // if (!r) return r;
650 |
651 | if (ids2.has(nodeType) || inGroupNode2) {
652 | // Clean up MenuOption
653 | const hiddenWidgets = node.widgets.filter(w => w.type === HIDDEN_TAG || w.heading || w.info)
654 | let filtered = false
655 | let wo = options.filter(o => o === null || (o && !hiddenWidgets.some(w => {const i = o.content.includes(`Convert ${w.name} to input`); if(i) filtered = i; return i} )))
656 | options.splice(0, options.length, ...wo);
657 |
658 | if(hiddenWidgets.length !== node.widgets.length) {
659 | customOptions.push(null) // separator
660 | const d=function(_node) {
661 | const extra = _node.widgets.find(w => w.name.endsWith('extra'))
662 | let extra_data = extra._value
663 | // extra_data.show_descriptions = !extra_data.show_descriptions
664 | extra._value = {...extra._value, show_descriptions: !extra_data.show_descriptions}
665 | widgetLogicSettings(_node)
666 | }
667 | const h=function(_node) {
668 | const extra = _node.widgets.find(w => w.name.endsWith('extra'))
669 | let extra_data = extra._value
670 | extra._value = {...extra._value, show_headings: !extra_data.show_headings}
671 | // extra_data.show_headings = !extra_data.show_headings
672 | widgetLogicSettings(_node)
673 | }
674 | customOptions.push(create_custom_option("Hide/show all headings", h.bind(this, node)))
675 | customOptions.push(create_custom_option("Hide/show all descriptions", d.bind(this, node)))
676 | }
677 | }
678 |
679 | if (ids1.has(nodeType) || inGroupNode) {
680 | // Dynamic MenuOption depending on the widgets
681 | const content_hide_show = "Hide/show ";
682 | // const whWidgets = node.widgets.filter(w => w.name === 'width' || w.name === 'height')
683 | const hiddenWidgets = node.widgets.filter(w => w.type === HIDDEN_TAG)
684 | // doesn't take GroupNode into account
685 | // const with_SDXL = node.widgets.find(w => w.name.endsWith('with_SDXL'))
686 | const parser = node.widgets.find(w => w.name.endsWith('parser'))
687 | const in_comfy = parser?.value?.includes?.("comfy")
688 | let ws = widgets.map(widget_name => create_custom_option(content_hide_show + widget_name, toggleMenuOption.bind(this, node, widget_name)))
689 | ws = ws.filter((w) => (in_comfy && parser.value !== 'comfy' && w.content.includes('mean_normalization')) || (in_comfy && w.content.includes('with_SDXL')) || !in_comfy || w.content.includes('multi_conditioning') )
690 | // customOptions.push(null) // separator
691 | customOptions.push(...ws)
692 |
693 | let wo = options.filter(o => o === null || (o && !hiddenWidgets.some(w => o.content.includes(`Convert ${w.name} to input`))))
694 | const width = node.widgets.find(w => w.name.endsWith('width'))
695 | const height = node.widgets.find(w => w.name.endsWith('height'))
696 | if (width && height) {
697 | const width_type = width.type.toLowerCase()
698 | const height_type = height.type.toLowerCase()
699 | if (!(width_type.includes('number') || width_type.includes('int') || width_type.includes('float') ||
700 | height_type.includes('number') || height_type.includes('int') || height_type.includes('float')))
701 | wo = wo.filter(o => o === null || (o && !o.content.includes('Swap width/height')))
702 | }
703 | options.splice(0, options.length, ...wo);
704 | }
705 | // options.unshift(...customOptions); // top
706 | options.splice(options.length - 1, 0, ...customOptions)
707 | // return r;
708 | }
709 | }
710 | }
711 | });
712 |
--------------------------------------------------------------------------------