├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── assets └── insane_chain_example.png ├── blehconfig.example.json ├── blehconfig.example.yaml ├── changelog.md ├── docs └── blockops.md ├── py ├── __init__.py ├── better_previews │ ├── __init__.py │ ├── previewer.py │ └── tae_vid.py ├── latent_utils.py ├── nodes │ ├── __init__.py │ ├── blockCFG.py │ ├── deepShrink.py │ ├── hyperTile.py │ ├── misc.py │ ├── modelPatchConditional.py │ ├── ops.py │ ├── refinerAfter.py │ ├── sageAttention.py │ ├── samplers.py │ └── taevid.py └── settings.py └── ruff.toml /.gitignore: -------------------------------------------------------------------------------- 1 | blehconfig.json 2 | blehconfig.yaml 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BLEH 2 | 3 | A ComfyUI nodes collection of utility and model patching functions. Also includes improved previewer that allows previewing batches during generation. 4 | 5 | For recent user-visible changes, please see the [ChangeLog](changelog.md). 6 | 7 | ## Features 8 | 9 | 1. Better TAESD previews (see below). 10 | 2. 
Allow setting seed, timestep range and step interval for HyperTile (look for the [`BlehHyperTile`](#blehhypertile) node). 11 | 3. Allow applying Kohya Deep Shrink to multiple blocks, also allow gradually fading out the downscale factor (look for the [`BlehDeepShrink`](#blehdeepshrink) node). 12 | 4. Allow discarding penultimate sigma (look for the `BlehDiscardPenultimateSigma` node). This can be useful if you find certain samplers are ruining your image by spewing a bunch of noise into it at the very end (usually only an issue with `dpm2 a` or SDE samplers). 13 | 5. Allow more conveniently switching between samplers during sampling (look for the [BlehInsaneChainSampler](#blehinsanechainsampler) node). 14 | 6. Apply arbitrary model patches at an interval and/or for a percentage of sampling (look for the [BlehModelPatchConditional](#blehmodelpatchconditional) node). 15 | 7. Ensure a seed is set even when `add_noise` is turned off in a sampler. Yes, that's right: if you don't have `add_noise` enabled _no_ seed gets set for samplers like `euler_a` and it's not possible to reproduce generations. (look for the [BlehForceSeedSampler](#blehforceseedsampler) node). For `SamplerCustomAdvanced` you can use `BlehDisableNoise` to accomplish the same thing. 16 | 8. Allow swapping to a refiner model at a predefined time (look for the [BlehRefinerAfter](#blehrefinerafter) node). 17 | 9. Allow defining arbitrary model patches (look for the [BlehBlockOps](#blehblockops) node). 18 | 10. Experimental blockwise CFG type effect (look for the [BlehBlockCFG](#blehblockcfg) node). 19 | 11. [SageAttention](https://github.com/thu-ml/SageAttention/) support either globally or as a sampler wrapper. Look for the [BlehSageAttentionSampler](#blehsageattentionsampler) and `BlehGlobalSageAttention` nodes. 20 | 21 | ## Configuration 22 | 23 | Copy either `blehconfig.example.yaml` or `blehconfig.example.json` to `blehconfig.yaml` or `blehconfig.json` respectively and edit the copy. When loading configuration, the YAML file will be prioritized if it exists and Python has YAML support. 24 | 25 | Restart ComfyUI to apply any new changes. 26 | 27 | ### Better Previews 28 | 29 | * Supports setting max preview size (ComfyUI default is hardcoded to 512 max). 30 | * Supports showing previews for more than the first latent in the batch. 31 | * Supports throttling previews. Do you really need your expensive high quality preview to get updated 3 times a second? 32 | 33 | **General settings defaults:** 34 | 35 | |Key|Default|Description| 36 | |-|-|-| 37 | |`enabled`|`true`|Toggles whether enhanced TAESD previews are enabled| 38 | |`max_size`|`768`|Max width or height for previews. Note this does not affect TAESD decoding, just the preview image| 39 | |`max_width`|`max_size`|Same as `max_size` except allows setting the width independently. Previews may not work well with non-square max dimensions.| 40 | |`max_height`|`max_size`|Same as `max_size` except allows setting the height independently. Previews may not work well with non-square max dimensions.| 41 | |`max_batch`|`4`|Max number of latents in a batch to preview| 42 | |`max_batch_cols`|`2`|Max number of columns to use when previewing batches| 43 | |`throttle_secs`|`2`|Max frequency to decode the latents for previewing. `0.25` would be every quarter second, `2` would be once every two seconds| 44 | |`maxed_batch_step_mode`|`false`|When `false`, you will see the first `max_batch` previews; when `true` you will see previews spread across the batch. 
Also applies to video frames.| 45 | |`preview_device`|`null`|`null` (use the default device) or a string with a PyTorch device name like `"cpu"`, `"cuda:0"`, etc. Can be used to run TAESD previews on CPU or other available devices. Not recommended to change this unless you really need to; using the CPU device may prevent out of memory errors but will likely significantly slow down generation.| 46 | |`compile_previewer`|`false`|Controls whether the previewer gets compiled with `torch.compile`. May be a boolean or an object in which case the object will be used as arguments to `torch.compile`. Note: May cause a delay/memory spike on the first preview.| 47 | |`oom_fallback`|`latent2rgb`|May be set to `none` or `latent2rgb`. Controls what happens if trying to decode the preview runs out of memory.| 48 | |`oom_retry`|`true`|If set to `false`, we will give up and use the `oom_fallback` behavior after hitting the first OOM. Otherwise, we'll attempt to decode with the normal previewer each time a preview is requested, even if that previously ran out of memory.| 49 | |`whitelist_formats`|(empty list)|List of latent formats to whitelist. See [example YAML config](blehconfig.example.yaml) for more information.| 50 | |`blacklist_formats`|(empty list)|List of latent formats to blacklist. See [example YAML config](blehconfig.example.yaml) for more information.| 51 | 52 | **Note**: Most options here that refer to batches will also apply to video models and in that case frames will be treated like batch items. Batches aren't supported when generating videos. 53 | 54 | **Image model settings defaults:** 55 | 56 | |Key|Default|Description| 57 | |-|-|-| 58 | |`skip_upscale_layers`|`0`|The TAESD model has three upscale layers, each doubles the size of the result. Skipping some of them will significantly speed up TAESD previews at the cost of smaller preview image results. You can set this to `-1` to automatically pop layers until at least one dimension is within the max width/height or `-2` to aggressively pop until _both_ dimensions are within the limit.| 59 | 60 | More detailed explanation for skipping upscale layers: Latents (the thing you're running the TAESD preview on) are 8 times smaller than the image you get decoding by normal VAE or TAESD. The TAESD decoder has three upscale layers, each doubling the size: `1 * 2 * 2 * 2 = 8`. So for example if normal decoding would get you a `1280x1280` image, skipping one TAESD upscale layer will get you a `640x640` result, skipping two will get you `320x320` and so on. I did some testing running TAESD decode on CPU for a `1280x1280` image: the base speed is about `1.95` sec, `1.15` sec with one upscale layer skipped, `0.44` sec with two upscale layers skipped and `0.16` sec with all three upscale layers popped (of course you only get a `160x160` preview at that point). The upshot is if you are using TAESD to preview large images or batches or you want to run TAESD on CPU (normally pretty slow) you would probably benefit from setting `skip_upscale_layers` to `1` or `2`. Also if your max preview size is `768` and you are decoding a `1280x1280` image, it's just going to get scaled down to `768x768` anyway. 61 | 62 | **Video model settings defaults:** 63 | 64 | |Key|Default|Description| 65 | |-|-|-| 66 | |`video_parallel`|`false`|Use parallel mode when decoding video latents. May actually use more memory than a full VAE decode.| 67 | |`video_max_frames`|`-1`|Maximum frames to include in a preview. Frame limiting is treated like batch limiting. 
`-1` means unlimited.| 68 | |`video_temporal_upscale_level`|`0`|Number of temporal upscale blocks to use, 0 will not do any temporal upscaling, 2 means full temporal upscaling.| 69 | 70 | These defaults are conservative. I would recommend setting `throttle_secs` to something relatively high (like 5-10) especially if you are generating batches at high resolution. 71 | 72 | Slightly more detailed explanation for `maxed_batch_step_mode`: If max previews is set to `3` and the batch size is `15` you will see previews for indexes `0, 5, 10`. Or to put it a different way, it steps through the batch by `batch_size / max_previews` rounded up. This behavior may be useful for previewing generations with a high batch count like when using AnimateDiff. 73 | 74 | 75 | **Note**: Other node packs that patch ComfyUI's previewer behavior may interfere with this feature. One I am aware of is [ComfyUI-VideoHelperSuite](https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite) - if you have displaying animated previews turned on, it will overwrite Bleh's patched previewer. Or possibly, depending on the load order, Bleh will prevent it from working correctly. 76 | 77 | ### BlehModelPatchConditional 78 | 79 | **Note**: Very experimental. 80 | 81 | This node takes a `default` model and a `matched` model. When the interval or start/end percentage match, the `matched` model will apply, otherwise the `default` one will. This can be used to apply something like HyperTile, Self Attention Guidance or other arbitrary model patches conditionally. 82 | 83 | The first sampling step that matches the timestep range always applies `matched`, after that the following behavior applies: If the interval is positive then you just get `matched` every `interval` steps. It is also possible to set interval to a negative value, for example `-3` would mean out of every three steps, the first two use `matched` and the third doesn't. 84 | 85 | _Notes and limitations_: Not all types of model modifications/patches can be intercepted with a node like this. You also almost certainly can't use this to mix different models: both inputs should be instances of the same loaded model. It's also probably a bad idea to apply further patches on top of the `BlehModelPatchConditional` node output: it should most likely be the last thing before a sampler or something that actually uses the model. 86 | 87 | ### BlehHyperTile 88 | 89 | Adds the ability to set a seed and timestep range that HyperTile gets applied for. *Not* well tested, and I just assumed the Inspire version works which may or may not be the case. 90 | 91 | It is also possible to set an interval for HyperTile steps, this time it is just normal sampling steps that match the timestep range. The first sampling step that matches the timestep range always applies HyperTile, after that the following behavior applies: If the interval is positive then you just get HyperTile every `interval` steps. It is also possible to set interval to a negative value, for example `-3` would mean out of every three steps, the first two have HyperTile and the third doesn't. 92 | 93 | **Note**: Timesteps start from 999 and count down to 0 and also are not necessarily linear. Figuring out exactly which sampling step a timestep applies 94 | to is left as an exercise for you, dear node user. As an example, Karras and exponential samplers essentially rush to low timesteps and spend quite a bit of time there. 
95 | 96 | HyperTile credits: 97 | 98 | The node was originally taken by Comfy from taken from: https://github.com/tfernd/HyperTile/ 99 | 100 | Then the Inspire node pack took it from the base ComfyUI node: https://github.com/ltdrdata/ComfyUI-Inspire-Pack 101 | 102 | Then I took it from the Inspire node pack. The original license was MIT so I assume yoinking it into this repo is probably okay. 103 | 104 | ### BlehDeepShrink 105 | 106 | AKA `PatchModelAddDownScale` AKA Kohya Deep Shrink. Compared to the built-in Deep Shrink node this version has the following differences: 107 | 108 | 1. Instead of choosing a block to apply the downscale effect to, you can enter a comma-separated list of blocks. This may or not actually be useful but it seems like you can get interesting effects applying it to multiple blocks. Try `2,3` or `1,2,3`. 109 | 2. Adds a `start_fadeout_percent` input. When this is less than `end_percent` the downscale will be scaled to end at `end_percent`. For example, if `downscale_factor=2.0`, `start_percent=0.0`, `end_percent=0.5` and `start_fadeout_percent=0.0` then at 25% you could expect `downscale_factor` to be around `1.5`. This is because we are deep shrinking between 0 and 50% and we are halfway through the effect range. (`downscale_factor=1.0` would of course be a no-op and values below 1 don't seem to work.) 110 | 3. Expands the options for upscale and downscale types, you can also turn on antialiasing for `bicubic` and `bilinear` modes. 111 | 112 | *Notes*: It seems like when shrinking multiple blocks, blocks downstream are also affected. So if you do x2 downscaling on 3 blocks, you are going to be applying `x2 * 3` downscaling to the lowest block (and maybe downstream ones?). I am not 100% sure how it works, but the takeway is you want to reduce the downscale amount when you are downscaling multiple blocks. For example, using blocks `2,3,4` and a downscale factor of `2.0` or `2.5` generating at 3072x3072 seems to work pretty well. Another note is schedulers that move at a steady pace seem to produce better results when fading out the deep shrink effect. In other words, exponential or Karras schedulers don't work well (and may produce complete nonsense). `ddim_uniform` and `sgm_uniform` seem to work pretty well and `normal` appears to be decent. 113 | 114 | Deep Shrink credits: 115 | 116 | Adapted from the ComfyUI source which I presume was adapted from the version Kohya initially published. 117 | 118 | ### BlehInsaneChainSampler 119 | 120 | **Note**: I'd recommend using my [Overly Complicated Sampling](https://github.com/blepping/comfyui_overly_complicated_sampling) node pack over this. It generally has better tools for scheduling samplers. 121 | 122 | A picture is worth a thousand words, so: 123 | 124 | ![Insane chain example](assets/insane_chain_example.png) 125 | 126 | This will use `heunpp2` for the first five steps, `euler_ancestral` for the next five, and `dpmpp_2m` for however many remain. 127 | 128 | This is basically the same as chaining a bunch of samplers together and manually setting the start/end steps. 129 | 130 | **Note**: Even though the `dpmpp_2m` insane chain sampler node has `steps=1` it will run for five steps. This is because the requirement of fifteen total steps must be fulfilled and... you can't really sample stuff without a sampler. Also note progress might be a little weird splitting sampling up like this. 
131 | 132 | ### BlehForceSeedSampler 133 | 134 | Currently, the way ComfyUI's advanced and custom samplers work is if you turn off `add_noise` _no_ global RNG seed gets set. Samplers like `euler_a` use this (SDE samplers use a different RNG method and aren't subject to this issue). Anyway, the upshot is you will get a different generation every time regardless of what the seed is set to. This node simply wraps another sampler and ensures that the seed gets set. 135 | 136 | ### BlehDisableNoise 137 | 138 | Basically the same idea as `BlehForceSeedSampler`, however it is usable with `SamplerCustomAdvanced`. 139 | 140 | 141 | ### BlehPlug 142 | 143 | You can connect this node to any input and it will be the same as if the input had no connection. Why is this useful? It's mainly for [Use Everywhere](https://github.com/chrisgoringe/cg-use-everywhere) — sometimes it's desirable to leave an input unconnected, but if you have Use Everywhere broadcasting an output it can be inconvenient. Just shove a plug in those inputs. 144 | 145 | ### BlehSetSamplerPreset 146 | 147 | Allows associating a `SAMPLER` with a name in list of samplers (`bleh_preset_0`, etc) so you can use a custom sampler in places that do not allow custom sampling - FaceDetailer for example. You can adjust the number of presets by setting the environment variable `COMFYUI_BLEH_SAMPLER_PRESET_COUNT` - it defaults to 1 if unset. If set to 0, no sampler presets will be added to the list. 148 | 149 | This node needs to run before sampling with the preset begins - it takes a wildcard input with can be used to pass through something like the model or latent to make sure the node runs before sampling. **Note**: Since the input and outputs are wildcards, ComfyUI's normal type checking does not apply here - be sure you connect the output to something that supports the input type. For example, if you connect a `MODEL` to `any_input`, ComfyUI will let you connect that to something expecting `LATENT` which won't work very well. 150 | 151 | It's also possible to override the sigmas used for sampling - possibly to do something like Restart sampling in nodes that don't currently allow you to pass in sigmas. This is an advanced option, if you don't know that you need it then I suggest not connecting anything here. *Note*: If the sampler is adding noise then you likely will get unexpected results if the two sets of sigmas start at different values. (This could also be considered a feature since it effectively lets you apply a multiplier to the initial noise.) 152 | 153 | The `dummy_opt` input can be attached to anything and isn't used by the node. However, you can connect something like a string or integer and change it to ensure the node runs again and sets your preset. See the note below. 154 | 155 | *Note*: One thing to be aware of is that this node assigns the preset within the ComfyUI server when it runs, so if you are changing and using the same preset ID between samplers, you need to make sure the `BlehSetSamplerPreset` node runs before the corresponding sampler. For example, suppose you have a workflow that looks like `Set Preset 0 -> KSampler1 (with preset 0) -> Set Preset 0 -> KSampler2 (with preset 0)`. The `Set Preset` nodes will run before each KSampler as expected the first time you execute the workflow. 
However, if you go back and change a setting in `KSampler1` and run the workflow, this *won't* cause the first `Set Preset` node to run again so you'll be sampling with whatever got assigned to the preset with the second `Set Preset` node. You can change a value connected to the `dummy_opt` input to force the node to run again. 156 | 157 | ### BlehRefinerAfter 158 | 159 | Allows switching to a refiner model at a predefined time. There are three time modes: 160 | 161 | * `timestep`: Note that this is not a sampling step but a value between `0` and `999` where `999` is the beginning of sampling 162 | and `0` is the end. It is basically equivalent to the percentage of sampling remaining - `999` = ~99.9% sampling remaining. 163 | * `percent`: Value between `0.0` and `1.0` where `0.0` is the start of sampling and 1.0 is the end. Note that this is not 164 | based on sampling _steps_. 165 | * `sigma`: Advanced option. If you don't know what this is you probably don't need to worry about it. 166 | 167 | **Note**: This only patches the unet apply function, most other stuff including conditioning comes from the base model so 168 | you likely can only use this to swap between models that are closely related. For example, switching from SD 1.5 to 169 | SDXL is not going to work at all. 170 | 171 | ### BlehBlockCFG 172 | 173 | Experimental model patch that attempts to guide either `cond` (positive prompt) or `uncond` (negative prompt) away from its opposite. 174 | In other words, when applied to `cond` it will try to push it further away from what `uncond` is doing and vice versa. Stronger effect when 175 | applied to `cond` or output blocks. The defaults are reasonable for SD 1.5 (or as reasonable as weird stuff like this can be). 176 | 177 | Enter comma separated blocks numbers starting with one of **I**input, **O**utput or **M**iddle like `i4,m0,o4`. You may also use `*` rather than a block 178 | number to select all blocks in the category, for example `i*, o*` matches all input and all output blocks. 179 | 180 | The patch can be applied to the same model multiple times. 181 | 182 | Is it good, or even doing what I think? Who knows! Both positive and negative scales seem to have positive effect on the generation. Low negative scales applied to `cond` seem to make the generation bright and colorful. 183 | 184 | _Note_: Probably only works with SD 1.x and SDXL. Middle block patching will probably only work if you have [FreeU_Advanced](https://github.com/WASasquatch/FreeU_Advanced) installed. 185 | 186 | **Note**: Doesn't work correctly with Tiled Diffusion when using tile batch sizes over 1. 187 | 188 | ### BlehSageAttentionSampler 189 | 190 | Allows using the SageAttention attention optimization as a sampler wrapper. SageAttention 2.0.1 supports head sizes up to 128 and should have some effect for most models. Earlier SageAttention versions had more limited support and, for example, didn't support any of SD1.5's head sizes. You will probably notice the biggest difference for high resolution generations. 191 | 192 | If you run into custom nodes that don't seem to be honoring SageAttention (you can verify this with `sageattn_verbose: true` in the YAML options), feel free to let me know and I can probably add support. At this point the Bleh SageAttention stuff should work for most custom nodes. 193 | 194 | **Note:** Requires manually installing SageAttention into your Python environment. Should work with SageAttention 1.0 and 2.0.x (2.0.x currently requires CUDA 8+). 
Link: https://github.com/thu-ml/SageAttention 195 | 196 | 197 | ### BlehGlobalSageAttention 198 | 199 | Enables SageAttention (see description above) globally. Prefer using the sampler wrapper when possible as it has less sharp edges. 200 | 201 | **Note**: This isn't a real model patch. The settings are applied when the node runs, so, for example, if you enable it and then bypass the node that won't actually disable SageAttention. The node needs to actually run each time you want your settings applied. 202 | 203 | ### BlehBlockOps 204 | 205 | Very experimental advanced node that allows defining model patches using YAML. This node is still under development and may be changed. 206 | 207 | * [Extended BlockOps documentation](docs/blockops.md) 208 | 209 | #### Examples 210 | 211 | Just to show what's possible, you can implement model patches like FreeU or Deep Shrink using BlockOps. 212 | 213 | **FreeU V2** 214 | 215 | ```yaml 216 | # FreeU V2 b1=1.1, b2=1.2, s1=0.9, s2=0.2 217 | - if: 218 | type: output 219 | stage: 1 220 | ops: 221 | - [slice, 0.75, 1.1, 1, null, true] 222 | - [target_skip, true] 223 | - [ffilter, 0.9, none, 1.0, 1] 224 | - if: 225 | type: output 226 | stage: 2 227 | ops: 228 | - [slice, 0.75, 1.2, 1, null, true] 229 | - [target_skip, true] 230 | - [ffilter, 0.2, none, 1.0, 1] 231 | ``` 232 | 233 | **Kohya Deep Shrink** 234 | 235 | ```yaml 236 | # Deep Shrink, downscale 2, apply up to 35%. 237 | - if: 238 | type: input_after_skip 239 | block: 3 240 | to_percent: 0.35 241 | ops: [[scale, bicubic, bicubic, 0.5, 0.5, 0]] 242 | - if: 243 | type: output 244 | ops: [[unscale, bicubic, bicubic, 0]] 245 | ``` 246 | 247 | 248 | ### BlehLatentOps 249 | 250 | Basically the same as BlehBlockOps, except the condition `type` will be `latent`. Obviously stuff involving steps, percentages, etc do not apply. 251 | This node allows you to apply the blending/filtering/scaling operations to a latent. 252 | 253 | ### BlehLatentScaleBy 254 | 255 | Like the builtin `LatentScaleBy` node, however it allows setting the horizontal and vertical scaling types and scales independently 256 | as well as allowing providing an extended list of scaling options. Can also be useful for testing what different types of scaling or 257 | enhancement effects look like. 258 | 259 | ### BlehLatentBlend 260 | 261 | Allows blending latents using any of the blending modes available. 262 | 263 | ### BlehCast 264 | 265 | Advanced node: Allows tricking ComfyUI into thinking a value of one type is a different type. This does not actually convert anything, just lets you connect things that otherwise couldn't be connected. In other words, don't do it unless you know the actual object is compatible with the input. 266 | 267 | ### BlehSetSigmas 268 | 269 | Advanced sigma manipulation node which can be used to insert sigmas into other sigmas, adjust them, replace them or 270 | just manually enter a list of sigmas. Note: Experimental, not well tested. 271 | 272 | ### BlehEnsurePreviewer 273 | 274 | Ensures that Bleh's previewer is used. Generally not necessary unless some other custom node pack is overriding the default previewer. The node acts as a bridge for any input type. 275 | 276 | ### BlehTAEVideoEncode and BlehTAEVideoDecode 277 | 278 | Fast video latent encoding/decoding with models from madebyollin (same person that made TAESD). Supports WAN 2.1, Hunyuan and Mochi. The node has a toggle for parallel mode which is faster but may use a lot of memory. 
279 | 280 | You will need to download the models and put them in `models/vae_approx`. Don't change the names. 281 | 282 | * **WAN 2.1**: https://github.com/madebyollin/taehv/blob/main/taew2_1.pth 283 | * **Hunyuan**: https://github.com/madebyollin/taehv/blob/main/taehv.pth 284 | * **Mochi**: https://github.com/madebyollin/taem1/blob/main/taem1.pth 285 | 286 | *Note*: If you run into issues it's probably a problem with my implementation and not the TAE video models or original inference code. 287 | 288 | *** 289 | 290 | ## Scaling Types 291 | 292 | * bicubic: Generally the safe option. 293 | * bilinear: Like bicubic but slightly not as good? 294 | * nearest-exact 295 | * area 296 | * bislerp: Interpolates between tensors a and b using normalized linear interpolation. 297 | * colorize: Supposedly transfers color. May or may not work that way. 298 | * hslerp: Hybrid Spherical Linear Interpolation, supposedly smooths transitions between orientations and colors. 299 | * bibislerp: Uses bislerp as the slerp function in bislerp. When slerping once just isn't enough. 300 | * cosinterp: Cosine interpolation. 301 | * cuberp: Cubic interpolation. 302 | * inject: Adds the value scaled by the ratio. Probably not the best for scaling. 303 | * lineardodge: Supposedly simulates a brightening effect. 304 | * random: Chooses a random relatively normal scaling function each time. My thought is this will avoid artifacts from 305 | a specific scaling type getting reinforced each step. Generally only useful for Deep Shrink or 306 | [jankhdiffusion](https://github.com/blepping/comfyui_jankhidiffusion). 307 | * randomaa: Like `random`, however it will also choose a random antialias size. 308 | 309 | Scaling types like `bicubic+something` will apply the `something` enhancement after scaling. See below. 310 | 311 | Scaling types that start with `rev` like `revinject` reverse the arguments to the scaling function. 312 | For example, `inject` does `a + b * scale`, `revinject` does `b + a * scale`. When is this desirable? 313 | I really don't know! Just stuff to experiment with. It may or may not be useful. (`revcosinterp` looks better than `cosinterp` though.) 314 | 315 | **Note**: Scaling types like `random` are very experimental and may be modified or removed. 316 | 317 | ## Enhancement Types 318 | 319 | * randmultihighlowpass: Randomly uses multihighpass or multilowpass filter. Effect is generally quite strong. 320 | * randhilowpass: Randomly uses a highpass or lowpass filter. When you filter both high and low frequencies you are left with... 321 | nothing! The effect is very strong. May not be useful. 322 | * randlowbandpass: Randomly uses a bandpass or lowpass filter. 323 | * randhibandpass: Randomly uses a bandpass or highpass filter. 324 | * renoise1: Adds some gaussian noise. Starts off relatively weak and increases based on sigma. 325 | * renoise2: Adds some gaussian noise. Starts relatively strong and decreases based on sigma. 326 | * korniabilateralblur: Applies a bilateral (edge preserving) blur effect. 327 | * korniagaussianblur: Applies a gaussian blur effect. 328 | * korniasharpen: Applies a sharpen effect. 329 | * korniaedge: Applies an edge enhancement effect. 330 | * korniarevedge: Applies an edge softening effect - may not work correctly. 331 | * korniarandblursharp: Randomly chooses between blurring and sharpening. 332 | 333 | Also may be an item from [Filters](#filters). 334 | 335 | **Note**: These enhancements are very experimental and may be modified or removed. 
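For a rough idea of how these names get used, here is a hedged variation of the Deep Shrink BlockOps example from earlier. It assumes a combined `scaling+enhancement` name like `bicubic+korniasharpen` follows the pattern described above; the values are illustrative placeholders, not recommendations.

```yaml
# Hypothetical sketch: Deep Shrink style downscale that also sharpens the result.
- if:
    type: input_after_skip
    block: 3
    to_percent: 0.35
  ops:
    # "bicubic+korniasharpen" would scale with bicubic, then apply the
    # korniasharpen enhancement to the scaled result.
    - [scale, bicubic+korniasharpen, bicubic, 0.5, 0.5, 0]
- if:
    type: output
  ops:
    - [unscale, bicubic, bicubic, 0]
```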
336 | 337 | ## Credits 338 | 339 | Many latent blending and scaling and filter functions based on implementation from https://github.com/WASasquatch/FreeU_Advanced - thanks! 340 | 341 | TAE video model support based on code from https://github.com/madebyollin/taehv/. 342 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import nodes 4 | 5 | from . import py 6 | from .py import settings 7 | from .py.nodes import samplers 8 | 9 | BLEH_VERSION = 2 10 | 11 | 12 | settings.load_settings() 13 | 14 | from .py.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 15 | 16 | 17 | def blep_init(): 18 | bi = getattr(nodes, "_blepping_integrations", {}) 19 | if "bleh" in bi: 20 | return 21 | bi["bleh"] = sys.modules[__name__] 22 | nodes._blepping_integrations = bi # noqa: SLF001 23 | samplers.add_sampler_presets() 24 | 25 | 26 | blep_init() 27 | 28 | __all__ = ("BLEH_VERSION", "NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "py") 29 | -------------------------------------------------------------------------------- /assets/insane_chain_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blepping/ComfyUI-bleh/9606ca236c05746e2b196361cdce411b7759f787/assets/insane_chain_example.png -------------------------------------------------------------------------------- /blehconfig.example.json: -------------------------------------------------------------------------------- 1 | { 2 | "betterTaesdPreviews": { 3 | "enabled": true, 4 | "max_size": 768, 5 | "max_batch": 4, 6 | "max_batch_cols": 2, 7 | "throttle_secs": 1, 8 | "maxed_batch_step_mode": false, 9 | "preview_device": null, 10 | "skip_upscale_layers": 0, 11 | "compile_previewer": false, 12 | "oom_fallback": "latent2rgb", 13 | "oom_retry": true 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /blehconfig.example.yaml: -------------------------------------------------------------------------------- 1 | # Copy this file to blehconfig.yaml 2 | betterTaesdPreviews: 3 | # If disabled, will use the old ComfyUI previewer. 4 | enabled: true 5 | 6 | # Maximum preview size (applies to both height and width). 7 | max_size: 768 8 | 9 | # Maximum preview width. If set, will override max_size. 10 | max_width: 768 11 | 12 | # Maximum preview height. If set, will override max_size. 13 | max_height: 768 14 | 15 | # Maximum batch items to preview. 16 | max_batch: 4 17 | 18 | # Maximum columns to use when previewing batches. 19 | max_batch_cols: 2 20 | 21 | # Minimum time between updating previews. The default will update the preview at most once per second. 22 | throttle_secs: 1 23 | 24 | # When enabled and previewing batches, you will see previews spread across the batch. Otherwise it will be the first max_batch items. 25 | maxed_batch_step_mode: false 26 | 27 | # Allows overriding the preview device, for example you could set it to "cpu". Note: Generally should be left 28 | # alone unless you know you need to change it. Previewing on CPU will likely be quite slow. 29 | preview_device: null 30 | 31 | # Allows skipping upscale layers in the TAESD model, may increase performance when previewing large images or batches. 32 | # May be set to -1 (conservative) or -2 (aggressive) to automatically calculate how many to skip. See README.md for details. 
33 | skip_upscale_layers: 0 34 | 35 | # Controls whether the previewer model is compiled (using torch.compile). Only works if your 36 | # Torch version and GPU support compiling. This also may cause a delay/memory spike on decoding the first preview. 37 | # This may be a boolean or object with arguments to pass to torch.compile. For example: 38 | # compile_previewer: 39 | # mode: max-autotune 40 | # backend: inductor 41 | compile_previewer: false 42 | 43 | # Controls behavior if we run out of memory trying to decode the preview. 44 | # Possible values: none, latent2rgb 45 | oom_fallback: "latent2rgb" 46 | 47 | # When enabled, we will try to use the normal previewer on each call 48 | # and only use the fallback if the normal previewer fails. 49 | # When disabled, we use the fallback starting from the first OOM. 50 | oom_retry: true 51 | 52 | # List of lowercase latent format names from https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py 53 | # If the list is empty, this disables the whitelist. Otherwise, Bleh will 54 | # only handle previewing for formats in the list. 55 | whitelist_formats: [] 56 | 57 | # List of lowercase latent format names (see above). 58 | # Bleh will delegate to the normal previewer for any latent formats in the blacklist. 59 | blacklist_formats: [] 60 | 61 | # Controls whether video previewing uses parallel mode (faster, requires much more memory). 62 | video_parallel: false 63 | 64 | # Maximum frames to include in a preview. -1 means no limit. 65 | # Frame limiting is treated like batch limiting so maxed_batch_step mode, etc will apply here. 66 | video_max_frames: -1 67 | 68 | # Number of temporal upscale blocks to run. 2 is the maximum and will fully decode 69 | # the latent into image frames. 70 | video_temporal_upscale_level: 0 71 | -------------------------------------------------------------------------------- /changelog.md: -------------------------------------------------------------------------------- 1 | # Changes 2 | 3 | Note, only relatively significant changes to user-visible functionality will be included here. Most recent changes at the top. 4 | 5 | ## 20250504 6 | 7 | This is a fairly large set of changes. Please create an issue if you experience problems. 8 | 9 | * Support for TAE video models/video previewing. 10 | * Added `BlehTAEVideoEncode` and `BlehTAEVideoDecode` nodes for fast video latent encoding/decoding. 11 | * Added many more blend modes. 12 | * Added `BlehEnsurePreviewer` node (for use when other custom node packs overwrite Bleh's previewer). 13 | 14 | ## 20250313 15 | 16 | * Added OOM fallback to the previewer. 17 | * Added ability to compile the previewer (and a few other related options). 18 | * Added `BlehEnsurePreviewer` node that can be used to ensure Bleh's previewer is used if some other custom node overrides it. 19 | 20 | ## 20250119 21 | 22 | * Fixed SageAttention 1.x support and setting `tensor_layout` should work properly for SageAttention 2.x now. Please create an issue if you experience problems. 23 | 24 | ## 20250114 25 | 26 | * Added `BlehLatentBlend` node. 27 | * Added `BlehCast` node that lets crazy people connect things that shouldn't be connected. 28 | * Added `BlehSetSigmas` node. 29 | * Some BlockOps functions have expanded capabilities now. 30 | 31 | ## 20250109 32 | 33 | * The strategy the SageAttention nodes use to patch ComfyUI's attention should work for third-party custom nodes more reliably now. Please create an issue if you experience problems. 
34 | 35 | ## 20241228 36 | 37 | * The SageAttention nodes can now take advantage of SageAttention 2.0.1 which supports head sizes up to 128. 38 | * Fixed an issue which prevented SageAttention from applying to models like Flux. 39 | 40 | ## 20241122 41 | 42 | * Added the `BlehGlobalSageAttention` and `BlehSageAttentionSampler` nodes. See the README for details. 43 | 44 | ## 20241103 45 | 46 | * Added the `BlehSetSamplerPreset` node and sampler presets feature. 47 | 48 | ## 20241021 49 | 50 | * Added `seed_offset` parameter to `BlehDisableNoise` and `BlehForceSeedSampler` nodes. This is to avoid a case where the same noise would be used during sampling as the initial noise. **Note**: Changes seeds. You can set `seed_offset` to 0 to get the same behavior as before. 51 | 52 | ## 20240830 53 | 54 | * Added the `BlehBlockCFG` node (see README for usage and details). 55 | * More scaling/blending types. Some of them don't work well with scaling and will be filtered, you can set the environment variable `COMFYUI_BLEH_OVERRIDE_NO_SCALE` if you want the full list to be available (but you might just get garbage if you try to use them for scaling). 56 | * Possibly better normalization function (may change seeds). Set the environment variable `COMFYUI_BLEH_ORIG_NORMALIZE` to disable. 57 | * TAESD previews should be faster. Also now can dynamically set the number of upscale layers to skip based on the preview size limits. Additionally it's possible to set the max preview width/height seperately - see the YAML example config. 58 | 59 | ## 20240506 60 | 61 | * Add many new scaling types. 62 | * Add enhancements that can be combined with scaling, also `apply_enhancement` blockops function. 63 | 64 | ## 20240423 65 | 66 | * Added `BlehPlug` and `BlehDisableNoise` (see README for usage and description). 67 | * Increased the available upscale/downscale types for `BlehDeepShrink`. 68 | 69 | ## 20240412 70 | 71 | * Added `BlehBlockOps` and `BlehLatentOps` nodes. 72 | 73 | ## 20240403 74 | 75 | * Added `BlehRefinerAfter` node. 76 | 77 | ## 20240218 78 | 79 | * Added `BlehForceSeedSampler` node. 80 | 81 | ## 20240216 82 | 83 | * Added `BlehModelPatchConditional` node (see README for usage and description). 84 | 85 | ## 20240208 86 | 87 | * Added `BlehInsaneChainSampler` node. 88 | * Added ability to run previews on a specific device and skip TAESD upscale layers for increased TAESD decoding performance (see README). 89 | 90 | ## 20240202 91 | 92 | * Added `BlehDiscardPenultimateSigma` node. 93 | 94 | ## 20240201 95 | 96 | * Added `BlehDeepShrink` node (see README for usage and description) 97 | * Add more upscale/downscale methods to the Deep Shrink node, allow setting a higher downscale factor, allow enabling antialiasing for `bilinear` and `bicubic` modes. 98 | 99 | ## 20240128 100 | 101 | * Removed CUDA-specific stuff from TAESD previewer as the performance gains were marginal and it had a major effect on VRAM usage. 102 | * (Hopefully) improved heuristics for batch preview layout. 103 | * Added `maxed_batch_step_mode` setting for TAESD previewer. 104 | * Fixed reversed HyperTile default start/end step values. 105 | * Allow only applying HyperTile at a step interval. 106 | -------------------------------------------------------------------------------- /docs/blockops.md: -------------------------------------------------------------------------------- 1 | 2 | # BlehBlockOps 3 | 4 | The top level YAML should consist of a list of objects with a condition `if`, a list of `ops` that run if the condition succeeds. 
5 | Objects `then` and `else` also take the same form as the top level object and apply when the `if` condition matches (or not in the case of `else`). 6 | 7 | All object fields (`if`, `then`, `else`, `ops`) are optional. An empty object is valid, it just doesn't do anything. 8 | 9 | ```yaml 10 | - if: 11 | cond1: [value1, value2] 12 | cond2: value # Values may be specified as a list or single item. 13 | ops: [[opname1, oparg1, oparg2], [opname2, oparg1, oparg2]] 14 | then: 15 | if: [[opname1, oparg1, oparg2]] # Conditions may also be specified as a list. 16 | ops: [] # and so on 17 | else: 18 | ops: [] 19 | # then and else may also be nested to an arbitrary depth. 20 | ``` 21 | 22 | *Note*: Blocks match by default, conditions restrict them. So a block with no `if` matches everything. 23 | 24 | #### Blend Modes 25 | 26 | 1. bislerp: Interpolates between tensors a and b using normalized linear interpolation. 27 | 2. colorize: Supposedly transfers color. May or may not work that way. 28 | 3. cosinterp: Cosine interpolation. 29 | 4. cuberp 30 | 5. hslerp: Hybrid Spherical Linear Interporation, supposedly smooths transitions between orientations and colors. 31 | 6. inject: Inject just adds the value scaled by the ratio, so if ratio is `1.0` this simply adds it. 32 | 7. lerp: Linear interpolation. 33 | 8. lineardodge: Supposedly simulates a brightning effect. 34 | 35 | #### Filters 36 | 37 | 1. none 38 | 2. bandpass 39 | 3. lowpass: Allows low frequencies and suppresses high frequencies. 40 | 4. highpass: Allows high frequencies and suppresses low frequencies. 41 | 5. passthrough: Maybe doesn't do anything? 42 | 6. gaussianblur: Blur. 43 | 7. edge: Edge enhance. 44 | 8. sharpen: Sharpens the target. 45 | 9. multilowpass: The multi versions apply to multiple bands. 46 | 10. multihighpass 47 | 11. multipassthrough 48 | 12. multigaussianblur 49 | 13. multiedge 50 | 14. multisharpen 51 | 52 | Custom filters may also be defined. For example, `gaussianblur` in the YAML filter definition would be `[[10,0.5]]`, 53 | `sharpen` would be `[[10, 1.5]]`. 54 | 55 | #### Scaling Functions 56 | 57 | See [Scaling Types](#scaling-types) below. 58 | 59 | #### Conditions 60 | 61 | **`type`**: One of `input`, `input_after_skip`, `middle`, `output` (preceding are block patches), `latent`, `post_cfg`. 62 | **Note**: ComfyUI doesn't allow patching the middle blocks by default, this feature is only available if you have 63 | [FreeU Advanced](https://github.com/WASasquatch/FreeU_Advanced) installed and enabled. (It patches ComfyUI to support patching 64 | the middle blocks.) 65 | 66 | **`block`**: The block number. Only applies when type is `input`, `input_after_skip`, `middle` or `output`. 67 | 68 | **`stage`**: The model stage. Applies to the same types as `block`. You can think of this in terms of FreeU's `b1`, `b2` - the number is the stage. 69 | 70 | **`percent`**: Percentage of sampling completed as a number between `0.0` and `1.0`. Note that this is sampling percentage, not percentage of steps. 71 | Does not apply to type `latent`. 72 | 73 | **`from_percent`**: Matches when sampling is greater or equal to the percent. Same restrictions as `percent`. 74 | 75 | **`to_percent`**: Matches when sampling is less or equal to the percent. Same restrictions as `from_percent`. 76 | 77 | **`step`**: Only applies when sigmas are connected to the `BlehBlockOps` node. A step will be determined as the index of the closest 78 | matching sigma. 
In other words, if you don't connect sigmas that exactly match the sigmas used for sampling you won't get accurate steps.
79 | Does not apply to type `latent`.
80 |
81 | **`step_exact`**: Same restrictions as `step`, however it will only be set if the current sigma _exactly_ matches a step. Otherwise the
82 | value will be `-1`.
83 |
84 | **`from_step`**: As above, but matches when the step is greater than or equal to the value.
85 |
86 | **`to_step`**: As above, but matches when the step is less than or equal to the value.
87 |
88 | **`step_interval`**: Same restrictions as the other step condition types. Matches when `step % interval` is 0. For example, with an
89 | interval of `2` this matches every other step starting from the first: step `0` matches, while step `1` does not (since `1 % 2 == 1`, which is not 0).
90 |
91 | **`cond`**: Generic condition, has two forms:
92 |
93 | *Comparison*: Takes three arguments: comparison type (`eq`, `ne`, `gt`, `lt`, `ge`, `le`), a condition type with
94 | a numeric value (`block`, `stage`, `percent`, `step`, `step_exact`) and a value or list of values to compare with.
95 |
96 | Example:
97 | ```yaml
98 | - if: [cond, [lt, percent, 0.35]]
99 | ```
100 |
101 | *Logic*: Takes a logic operation type (`not`, `and`, `or`) and a list of condition blocks. **Note**: The logic operation is applied
102 | to the result of the condition block and not the fields within it.
103 |
104 | Example:
105 | ```yaml
106 | - if:
107 |     cond: [not,
108 |       [cond, [or,
109 |         [cond, [lt, step, 1]],
110 |         [cond, [gt, step, 5]],
111 |       ]]
112 |     ] # A verbose way of expressing step >= 1 and step <= 5
113 | - if:
114 |     - [cond, [ge, step, 1]]
115 |     - [cond, [le, step, 5]] # Same as above
116 | - if: [[from_step, 1], [to_step, 5]] # Also same as above
117 | ```
118 |
119 | #### Operations
120 |
121 | Operations mostly modify a target which can be `h` or `hsp`. `hsp` is only a valid target when `type` is `output`. I think it has something
122 | to do with skip connections but I don't know the specifics. It's important for FreeU.
123 |
124 | Default values are shown in parentheses next to the operation argument name. You may supply an incomplete argument list,
125 | in which case default values will be used for the remaining arguments. Ex: `[flip]` is the same as `[flip, h]`. You may
126 | also specify the arguments as a map; keys that aren't included will use the default values. Ex: `[flip, {direction: h}]`
127 |
128 | **`slice`**: Applies a filtering operation on a slice of the target.
129 |
130 | 1. `scale`(`1.0`): Slice scale, `1.0` would mean apply to 100% of the target, `0.5` would mean 50% of it.
131 | 2. `strength`(`1.0`): Scales the target. `1.0` would mean 100%.
132 | 3. `blend`(`1.0`): Ratio of the transformed value to blend in. `1.0` means replace it with no blending.
133 | 4. `blend_mode`(`bislerp`): See the blend mode section.
134 | 5. `use_hidden_mean`(`true`): No idea what this does really, but FreeU V2 uses it when slicing and V1 doesn't.
135 |
136 | **`ffilter`**: Applies a Fourier filter operation to the target.
137 |
138 | 1. `scale`(`1.0`): Scales the target. `1.0` would mean 100%.
139 | 2. `filter`(`none`): May be a string with a predefined filter name or a list of lists defining filters. See the filter section.
140 | 3. `filter_strength`(`0.5`): Strength of the filter. `1.0` would mean to apply it at 100%.
141 | 4. `threshold`(`1`): Threshold for the Fourier filter. This generally should be 1.
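
As an illustration of how the conditions and operations above fit together, a block along the following lines (an untested sketch using only the conditions, filter names and argument defaults described in this document) would apply the predefined `gaussianblur` Fourier filter at 50% strength to output block `2` during the first 35% of sampling:

```yaml
- if:
    type: output
    block: 2
    to_percent: 0.35
  ops:
    - [ffilter, 1.0, gaussianblur, 0.5]
```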
142 |
143 | **`scale_torch`**: Scales the target up or down, using PyTorch's `interpolate` function.
144 |
145 | 1. `type`(`bicubic`): One of `bicubic`, `nearest`, `bilinear` or `area`.
146 | 2. `scale_width`(`1.0`): Ratio to scale the width. `2.0` would mean double it, `0.5` would mean half of it.
147 | 3. `scale_height`(`1.0`): As above.
148 | 4. `antialias`(`false`): Whether to apply antialiasing after scaling (`true` or `false`).
149 |
150 | **`unscale_torch`**: Scales the target to be the same size as `hsp`. Can only be used when the target isn't `hsp` and condition `type` is `output`.
151 | Can be used to reverse a `scale` or `scale_torch` operation without having to worry about calculating the ratios to get the original size back.
152 |
153 | 1. `type`(`bicubic`): Same as `scale_torch`.
154 | 2. `antialias`(`false`): Same as `scale_torch`.
155 |
156 | **`scale`**: Scales the target up or down using various functions. See the scaling functions section.
157 |
158 | 1. `type_width`(`bicubic`): Scaling function to use for width. Note that if the type is one of the ones from `scale_torch`, it cannot be combined with other scaling functions.
159 | 2. `type_height`(`bicubic`): As above.
160 | 3. `scale_width`(`1.0`): Ratio to scale the width. `2.0` would mean double it, `0.5` would mean half of it.
161 | 4. `scale_height`(`1.0`): As above.
162 | 5. `antialias_size`(`0`): Size of the antialias kernel. Between 1 and 7 inclusive. Higher numbers seem to increase blurriness.
163 |
164 | **`unscale`**: Like `unscale_torch` except it supports more scale functions and can specify the width/height scale function independently.
165 | Same restriction as `scale`.
166 |
167 | 1. `type_width`(`bicubic`): Scaling function to use for width. Note that if the type is one of the ones from `scale_torch`, it cannot be combined with other scaling functions.
168 | 2. `type_height`(`bicubic`): As above.
169 | 3. `antialias_size`(`0`): Size of the antialias kernel. Between 1 and 7 inclusive. Higher numbers seem to increase blurriness.
170 |
171 | **`flip`**: Flips the target.
172 |
173 | 1. `direction`(`h`): `h` for horizontal flip, `v` for vertical. Note that latents generally don't tolerate being flipped very well.
174 |
175 | **`rot90`**: Does a 90 degree rotation of the target.
176 |
177 | 1. `count`(`1`): Number of times to rotate (can also be negative). Note that if you rotate in a way that makes the tensors not match then stuff will probably break.
178 | Also, as with `flip`, it's generally pretty destructive to latents.
179 |
180 | **`roll`**: Rotates the values in a dimension of the target.
181 |
182 | 1. `direction`(`c`): `horizontal`, `vertical`, `channels`. Note that when `type` is `input`, `input_after_skip`, `middle` or `output` you aren't actually dealing
183 | with a latent. The second dimension ("channels") is actually the features in the layer. Rotating them can produce some pretty weird effects.
184 | 2. `amount`(`1`): If it's a number greater than `-1.0` and less than `1.0` this will rotate forward or backward by a percentage of the size. Otherwise it is
185 | interpreted as the number of items to rotate forward or backward.
186 |
187 | **`roll_channels`**: Same as `roll` but you only specify the count; it always targets channels and you can't use percentages.
188 |
189 | 1. `count`(`1`): Number of channels to rotate. May be negative.
190 |
191 | **`target_skip`**: Changes the target.
192 |
193 | 1. `active`(`true`): If `true` will target `hsp`, otherwise will target `h`.
Targeting `hsp` is only allowed when `type` is `output`; it has no effect otherwise.
194 |
195 | **`multiply`**: Multiplies the target by the value.
196 |
197 | 1. `factor`(`1.0`): Multiplier. `2.0` would double all values in the target.
198 |
199 | **`antialias`**: Applies an antialias effect to the target. Works the same as with `scale`.
200 |
201 | 1. `size`(`7`): The antialias kernel size as a number between 1 and 7.
202 |
203 | **`noise`**: Adds noise to the target. Can only be used when sigmas are connected. Noise will be scaled by `sigma - sigma_next`.
204 |
205 | 1. `scale`(`0.5`): Additionally scales the noise by the supplied factor. `1.0` would mean no scaling, `2.0` would double it, etc.
206 | 2. `type`(`gaussian`): Only `gaussian` unless [ComfyUI-sonar](https://github.com/blepping/ComfyUI-sonar) is installed and active, in which case
207 | you may also use the additional noise types Sonar provides.
208 | 3. `scale_mode`(`sigdiff`): `sigdiff` scales the noise by the current sigma minus the next (requires sigmas connected),
209 | `sigma` scales by the current sigma, and `none` or an invalid type uses no scaling (you get exactly `noise * scale`).
210 |
211 | **`debug`**: Outputs some debug information about the state.
212 |
213 | **`blend_op`**: Allows applying a blend function to the result of another operation.
214 |
215 | 1. `blend`(`1.0`): Ratio of the transformed value to blend in.
216 | 2. `blend_mode`(`bislerp`): See the blend mode section.
217 | 3. `ops`(empty): The operation as a list, with the name first, e.g. `[blend_op, 0.5, inject, [multiply, 0.5]]`. May also be a list of operations.
218 |
219 | **`pad`**: Pads the target.
220 |
221 | 1. `mode`(`reflect`): One of `constant`, `reflect`, `replicate`, `circular` - see https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html#torch.nn.functional.pad
222 | 2. `top`(`0`): Amount of top padding. If this is a floating point value, it will be treated as a percentage of the dimension.
223 | 3. `bottom`(`0`): " " "
224 | 4. `left`(`0`): " " "
225 | 5. `right`(`0`): " " "
226 | 6. `constant`(`0`): Constant value to use; only applies when mode is `constant`.
227 |
228 | _Note_: If you pad `input` (rather than `input_after_skip`) then you will need to crop the corresponding block in `output`
229 | for both `h` and `hsp` (i.e. with `target_skip`).
230 |
231 | **`crop`**: Crops the target.
232 |
233 | 1. `top`(`0`): Items to crop from the top. If this is a floating point value, it will be treated as a percentage of the dimension.
234 | 2. `bottom`(`0`): " " "
235 | 3. `left`(`0`): " " "
236 | 4. `right`(`0`): " " "
237 |
238 | **`mask_example_op`**: Lets you provide a mask by example and masks the result of an operation (or list of operations) with it.
239 |
240 | 1. `scale_mode`(`bicubic`): Scaling type, same as with `scale`.
241 | 2. `antialias`(`7`): Antialias size, same as with `scale`.
242 | 3. `mask`(mask targeting corners): A two dimensional list of mask values. See below.
243 | 4. `ops`(empty): Same as with `blend_op`.
244 |
245 | Simple example of a mask:
246 |
247 | ```plaintext
248 | [ [1.0, 0.0, 0.0, 1.0],
249 |   [0.0, 0.0, 0.0, 0.0],
250 |   [1.0, 0.0, 0.0, 1.0],
251 | ]
252 | ```
253 |
254 | With this mask, the result of the mask ops will be applied at full strength to the corners. The mask is scaled up to
255 | the size of the target tensor, so with this example the masked corners will be proportionately quite large if the
There are two convenience tricks for defining larger masks without 257 | having to specify each value: 258 | 259 | * If the first element in a row is `"rep"` then the second element is interpreted as a row repeat count and the 260 | rest of the items in the row constitute the row. Ex: `["rep", 2, 1, 0, 1]` expands to two rows of `1, 0, 1`. 261 | * If a column item is a list, the first element is interpreted as the repeat count and the remaining items are repeated 262 | however many times. Ex: `[2, 1.2, 0.5]` as a column would expand to `1.2, 0.5, 1.2, 0.5`. 263 | 264 | These two shortcuts can be combined. A mask of `[["rep", 2, 1, [3, 0], 2]]` expands to: 265 | 266 | ```plaintext 267 | [ 268 | [1, 0, 0, 0, 2], 269 | [1, 0, 0, 0, 2], 270 | ] 271 | ``` 272 | 273 | **`apply_enhancement`**: Applies an [enhancement](#enhancement-types) to the target. 274 | 275 | 1. `scale`: 1.0 276 | 2. `type`: korniabilateralblur 277 | 278 | -------------------------------------------------------------------------------- /py/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | better_previews, 3 | latent_utils, 4 | nodes, 5 | settings, 6 | ) 7 | 8 | __all__ = ("better_previews", "latent_utils", "nodes", "settings") 9 | -------------------------------------------------------------------------------- /py/better_previews/__init__.py: -------------------------------------------------------------------------------- 1 | from . import previewer, tae_vid 2 | 3 | __all__ = ("previewer", "tae_vid") 4 | -------------------------------------------------------------------------------- /py/better_previews/previewer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from time import time 5 | from typing import TYPE_CHECKING, NamedTuple 6 | 7 | import folder_paths 8 | import latent_preview 9 | import torch 10 | from comfy import latent_formats 11 | from comfy.cli_args import LatentPreviewMethod 12 | from comfy.cli_args import args as comfy_args 13 | from comfy.model_management import device_supports_non_blocking 14 | from comfy.taesd.taesd import TAESD 15 | from PIL import Image 16 | from tqdm import tqdm 17 | 18 | from ..settings import SETTINGS # noqa: TID252 19 | from .tae_vid import TAEVid 20 | 21 | if TYPE_CHECKING: 22 | from pathlib import Path 23 | 24 | import numpy as np 25 | 26 | _ORIG_PREVIEWER = latent_preview.TAESDPreviewerImpl 27 | _ORIG_GET_PREVIEWER = latent_preview.get_previewer 28 | 29 | LAST_LATENT_FORMAT = None 30 | 31 | 32 | class VideoModelInfo(NamedTuple): 33 | latent_format: latent_formats.LatentFormat 34 | fps: int = 24 35 | temporal_compression: int = 8 36 | tae_model: str | Path | None = None 37 | 38 | 39 | VIDEO_FORMATS = { 40 | "mochi": VideoModelInfo( 41 | latent_formats.Mochi, 42 | temporal_compression=6, 43 | tae_model="taem1.pth", 44 | ), 45 | "hunyuanvideo": VideoModelInfo( 46 | latent_formats.HunyuanVideo, 47 | temporal_compression=4, 48 | tae_model="taehv.pth", 49 | ), 50 | "cosmos1cv8x8x8": VideoModelInfo(latent_formats.Cosmos1CV8x8x8), 51 | "wan21": VideoModelInfo( 52 | latent_formats.Wan21, 53 | fps=16, 54 | temporal_compression=4, 55 | tae_model="taew2_1.pth", 56 | ), 57 | } 58 | 59 | 60 | class ImageWrapper: 61 | def __init__(self, frames: tuple, frame_duration: int): 62 | self._frames = frames 63 | self._frame_duration = frame_duration 64 | 65 | def save(self, fp, format: str | None, **kwargs: dict): # noqa: A002 66 | if 
len(self._frames) == 1: 67 | return self._frames[0].save(fp, format, **kwargs) 68 | kwargs |= { 69 | "loop": 0, 70 | "save_all": True, 71 | "append_images": self._frames[1:], 72 | "duration": self._frame_duration, 73 | } 74 | return self._frames[0].save(fp, "webp", **kwargs) 75 | 76 | def resize(self, *args: list, **kwargs: dict) -> ImageWrapper: 77 | return ImageWrapper( 78 | tuple(frame.resize(*args, **kwargs) for frame in self._frames), 79 | frame_duration=self._frame_duration, 80 | ) 81 | 82 | def __getattr__(self, key): 83 | return getattr(self._frames[0], key) 84 | 85 | 86 | class FallbackPreviewerModel(torch.nn.Module): 87 | @torch.no_grad() 88 | def __init__( 89 | self, 90 | latent_format: latent_formats.LatentFormat, 91 | *, 92 | dtype: torch.dtype, 93 | device: torch.device, 94 | scale_factor: float = 8.0, 95 | upscale_mode: str = "bilinear", 96 | ): 97 | super().__init__() 98 | 99 | raw_factors = latent_format.latent_rgb_factors 100 | raw_bias = latent_format.latent_rgb_factors_bias 101 | factors = torch.tensor(raw_factors, device=device, dtype=dtype).transpose(0, 1) 102 | bias = ( 103 | torch.tensor(raw_bias, device=device, dtype=dtype) 104 | if raw_bias is not None 105 | else None 106 | ) 107 | self.lin = torch.nn.Linear( 108 | factors.shape[1], 109 | factors.shape[0], 110 | device=device, 111 | dtype=dtype, 112 | bias=bias is not None, 113 | ) 114 | self.upsample = torch.nn.Upsample(scale_factor=scale_factor, mode=upscale_mode) 115 | self.requires_grad_(False) # noqa: FBT003 116 | self.lin.weight.copy_(factors) 117 | if bias is not None: 118 | self.lin.bias.copy_(bias) 119 | 120 | @torch.no_grad() 121 | def forward(self, x: torch.Tensor) -> torch.Tensor: 122 | x = self.lin(x.movedim(1, -1)).movedim(-1, 1) 123 | x = self.upsample(x).movedim(1, -1) 124 | return x.add_(1.0).mul_(127.5).clamp_(0.0, 255.0) 125 | 126 | 127 | class BetterPreviewer(_ORIG_PREVIEWER): 128 | def __init__( 129 | self, 130 | *, 131 | taesd: torch.nn.Module | None = None, 132 | latent_format: latent_formats.LatentFormat, 133 | vid_info: VideoModelInfo | None = None, 134 | ): 135 | self.latent_format = latent_format 136 | self.vid_info = vid_info 137 | self.fallback_previewer_model = None 138 | self.device = ( 139 | None 140 | if SETTINGS.btp_preview_device is None 141 | else torch.device(SETTINGS.btp_preview_device) 142 | ) 143 | if taesd is not None: 144 | if hasattr(taesd, "taesd_encoder"): 145 | del taesd.taesd_encoder 146 | if hasattr(taesd, "encoder"): 147 | del taesd.encoder 148 | if self.device and self.device != next(taesd.parameters()).device: 149 | taesd = taesd.to(self.device) 150 | self.taesd = taesd 151 | self.stamp = None 152 | self.cached = None 153 | self.blank = Image.new("RGB", size=(1, 1)) 154 | self.oom_fallback = SETTINGS.btp_oom_fallback == "latent2rgb" 155 | self.oom_retry = SETTINGS.btp_oom_retry 156 | self.oom_count = 0 157 | self.skip_upscale_layers = SETTINGS.btp_skip_upscale_layers 158 | self.preview_max_width = SETTINGS.btp_max_width 159 | self.preview_max_height = SETTINGS.btp_max_height 160 | self.throttle_secs = SETTINGS.btp_throttle_secs 161 | self.max_batch_preview = SETTINGS.btp_max_batch 162 | self.maxed_batch_step_mode = SETTINGS.btp_maxed_batch_step_mode 163 | self.max_batch_cols = SETTINGS.btp_max_batch_cols 164 | self.compile_previewer = SETTINGS.btp_compile_previewer 165 | self.maybe_pop_upscale_layers() 166 | if self.compile_previewer: 167 | compile_kwargs = ( 168 | {} 169 | if not isinstance(self.compile_previewer, dict) 170 | else self.compile_previewer 171 | ) 
172 | self.taesd = torch.compile(self.taesd, **compile_kwargs) 173 | 174 | # Popping upscale layers trick from https://github.com/madebyollin/ 175 | def maybe_pop_upscale_layers(self, *, width=None, height=None) -> None: 176 | skip = self.skip_upscale_layers 177 | if skip == 0 or not isinstance(self.taesd, TAESD): 178 | return 179 | upscale_layers = tuple( 180 | idx 181 | for idx, layer in enumerate(self.taesd.taesd_decoder) 182 | if isinstance(layer, torch.nn.Upsample) 183 | ) 184 | num_upscale_layers = len(upscale_layers) 185 | if skip < 0: 186 | if width is None or height is None: 187 | return 188 | aggressive = skip == -2 189 | skip = 0 190 | max_width, max_height = self.preview_max_width, self.preview_max_height 191 | while skip < num_upscale_layers and ( 192 | width > max_width or height > max_height 193 | ): 194 | width //= 2 195 | height //= 2 196 | if not aggressive and width < max_width and height < max_height: 197 | # Popping another would overshoot the size requirement. 198 | break 199 | skip += 1 200 | if not aggressive and (width <= max_width or height <= max_height): 201 | # At least one dimension is within the size requirement. 202 | break 203 | if skip > 0: 204 | skip = min(skip, num_upscale_layers) 205 | for idx in range(1, skip + 1): 206 | self.taesd.taesd_decoder.pop(upscale_layers[-idx]) 207 | self.skip_upscale_layers = 0 208 | 209 | def decode_latent_to_preview_image( 210 | self, 211 | preview_format: str, 212 | x0: torch.Tensor, 213 | ) -> tuple[str, Image, int]: 214 | preview_image = self.decode_latent_to_preview(x0) 215 | return ( 216 | preview_format if not isinstance(preview_image, ImageWrapper) else "WEBP", 217 | preview_image, 218 | min( 219 | max(*preview_image.size), 220 | max(self.preview_max_width, self.preview_max_height), 221 | ), 222 | ) 223 | 224 | def check_use_cached(self) -> bool: 225 | now = time() 226 | if ( 227 | self.cached is not None and self.stamp is not None 228 | ) and now - self.stamp < self.throttle_secs: 229 | return True 230 | self.stamp = now 231 | return False 232 | 233 | def calculate_indexes(self, batch_size: int, *, is_video=False) -> range: 234 | max_batch = ( 235 | SETTINGS.btp_video_max_frames if is_video else self.max_batch_preview 236 | ) 237 | if max_batch < 0: 238 | return range(batch_size) 239 | if not self.maxed_batch_step_mode: 240 | return range(min(max_batch, batch_size)) 241 | return range( 242 | 0, 243 | batch_size, 244 | math.ceil(batch_size / max_batch), 245 | )[:max_batch] 246 | 247 | def prepare_decode_latent( 248 | self, 249 | x0: torch.Tensor, 250 | *, 251 | frames_to_batch=True, 252 | ) -> tuple[torch.Tensor, int, int]: 253 | is_video = x0.ndim == 5 254 | if frames_to_batch and is_video: 255 | x0 = x0.transpose(2, 1).reshape(-1, x0.shape[1], *x0.shape[-2:]) 256 | batch = x0.shape[0] 257 | x0 = x0[self.calculate_indexes(batch, is_video=is_video), :] 258 | batch = x0.shape[0] 259 | height, width = x0.shape[-2:] 260 | if self.device and x0.device != self.device: 261 | x0 = x0.to( 262 | device=self.device, 263 | non_blocking=device_supports_non_blocking(x0.device), 264 | ) 265 | cols, rows = self.calc_cols_rows( 266 | min(batch, self.max_batch_preview), 267 | width, 268 | height, 269 | ) 270 | return x0, cols, rows 271 | 272 | def _decode_latent_taevid(self, x0: torch.Tensor) -> tuple[torch.Tensor, int, int]: 273 | height, width = x0.shape[-2:] 274 | if self.device and x0.device != self.device: 275 | x0 = x0.to( 276 | device=self.device, 277 | non_blocking=device_supports_non_blocking(x0.device), 278 | ) 279 | 
decoded = self.taesd.decode( 280 | x0.transpose(1, 2), 281 | parallel=SETTINGS.btp_video_parallel, 282 | ).movedim(2, -1) 283 | del x0 284 | decoded = decoded.reshape(-1, *decoded.shape[2:]) 285 | batch = decoded.shape[0] 286 | decoded = decoded[self.calculate_indexes(batch, is_video=True), :] 287 | cols, rows = self.calc_cols_rows( 288 | min( 289 | batch, 290 | SETTINGS.btp_video_max_frames 291 | if SETTINGS.btp_video_max_frames >= 0 292 | else batch, 293 | ), 294 | width, 295 | height, 296 | ) 297 | return ( 298 | decoded.mul_(255.0).round_().clamp_(min=0, max=255.0).detach(), 299 | cols, 300 | rows, 301 | ) 302 | 303 | def _decode_latent_taesd(self, x0: torch.Tensor) -> tuple[torch.Tensor, int, int]: 304 | x0, cols, rows = self.prepare_decode_latent( 305 | x0, 306 | frames_to_batch=not isinstance(self.taesd, TAEVid), 307 | ) 308 | height, width = x0.shape[-2:] 309 | if self.skip_upscale_layers < 0: 310 | self.maybe_pop_upscale_layers( 311 | width=width * 8 * cols, 312 | height=height * 8 * rows, 313 | ) 314 | return ( 315 | ( 316 | self.taesd.decode(x0) 317 | .movedim(1, -1) 318 | .add_(1.0) 319 | .mul_(127.5) 320 | .clamp_(min=0, max=255.0) 321 | .detach() 322 | ), 323 | cols, 324 | rows, 325 | ) 326 | 327 | def calc_cols_rows( 328 | self, 329 | batch_size: int, 330 | width: int, 331 | height: int, 332 | ) -> tuple[int, int]: 333 | max_cols = self.max_batch_cols 334 | ratio = height / width 335 | cols = max(1, min(round((batch_size * ratio) ** 0.5), max_cols, batch_size)) 336 | rows = math.ceil(batch_size / cols) 337 | return cols, rows 338 | 339 | @classmethod 340 | def decoded_to_animation(cls, samples: np.ndarray) -> ImageWrapper: 341 | batch = samples.shape[0] 342 | return ImageWrapper( 343 | tuple(Image.fromarray(samples[idx]) for idx in range(batch)), 344 | frame_duration=250, 345 | ) 346 | 347 | def decoded_to_image( 348 | self, 349 | samples: torch.Tensor, 350 | cols: int, 351 | rows: int, 352 | *, 353 | is_video=False, 354 | ) -> Image | ImageWrapper: 355 | batch, (height, width) = samples.shape[0], samples.shape[-3:-1] 356 | samples = samples.to( 357 | device="cpu", 358 | dtype=torch.uint8, 359 | non_blocking=device_supports_non_blocking(samples.device), 360 | ).numpy() 361 | if batch == 1: 362 | self.cached = Image.fromarray(samples[0]) 363 | return self.cached 364 | if SETTINGS.btp_animate_preview == "both" or ( 365 | is_video, 366 | SETTINGS.btp_animate_preview, 367 | ) in {(True, "video"), (False, "batch")}: 368 | return self.decoded_to_animation(samples) 369 | cols, rows = self.calc_cols_rows(batch, width, height) 370 | img_size = (width * cols, height * rows) 371 | if self.cached is not None and self.cached.size == img_size: 372 | result = self.cached 373 | else: 374 | self.cached = result = Image.new("RGB", size=(width * cols, height * rows)) 375 | for idx in range(batch): 376 | result.paste( 377 | Image.fromarray(samples[idx]), 378 | box=((idx % cols) * width, ((idx // cols) % rows) * height), 379 | ) 380 | return result 381 | 382 | @torch.no_grad() 383 | def init_fallback_previewer(self, device: torch.device, dtype: torch.dtype) -> bool: 384 | if self.fallback_previewer_model is not None: 385 | return True 386 | if self.latent_format is None: 387 | return False 388 | self.fallback_previewer_model = FallbackPreviewerModel( 389 | self.latent_format, 390 | device=device, 391 | dtype=dtype, 392 | ) 393 | return True 394 | 395 | def fallback_previewer(self, x0: torch.Tensor, *, quiet=False) -> Image: 396 | if not quiet: 397 | fallback_mode = "using fallback" if 
self.oom_fallback else "skipping" 398 | tqdm.write( 399 | f"*** BlehBetterPreviews: Got out of memory error while decoding preview - {fallback_mode}.", 400 | ) 401 | if not self.oom_fallback: 402 | return self.blank 403 | if not self.init_fallback_previewer(x0.device, x0.dtype): 404 | self.oom_fallback = False 405 | tqdm.write( 406 | "*** BlehBetterPreviews: Couldn't initialize fallback previewer, giving up on previews.", 407 | ) 408 | return self.blank 409 | x0, cols, rows = self.prepare_decode_latent(x0) 410 | try: 411 | return self.decoded_to_image( 412 | self.fallback_previewer_model(x0), 413 | cols, 414 | rows, 415 | ) 416 | except torch.OutOfMemoryError: 417 | return self.blank 418 | 419 | def decode_latent_to_preview(self, x0: torch.Tensor) -> Image: 420 | if self.check_use_cached(): 421 | return self.cached 422 | if x0.shape[0] == 0: 423 | return self.blank # Shouldn't actually be possible. 424 | if (self.oom_count and not self.oom_retry) or self.taesd is None: 425 | return self.fallback_previewer(x0, quiet=True) 426 | is_video = x0.ndim == 5 427 | used_fallback = False 428 | start_time = time() 429 | try: 430 | dargs = ( 431 | self._decode_latent_taevid(x0) 432 | if is_video 433 | else self._decode_latent_taesd(x0) 434 | ) 435 | result = self.decoded_to_image(*dargs, is_video=is_video) 436 | except torch.OutOfMemoryError: 437 | used_fallback = True 438 | result = self.fallback_previewer(x0) 439 | if SETTINGS.btp_verbose: 440 | tqdm.write( 441 | f"BlehPreview: used fallback: {used_fallback}, decode time: {time() - start_time:0.2f}", 442 | ) 443 | return result 444 | 445 | 446 | def bleh_get_previewer( 447 | device, 448 | latent_format: latent_formats.LatentFormat, 449 | *args: list, 450 | **kwargs: dict, 451 | ) -> object | None: 452 | preview_method = comfy_args.preview_method 453 | format_name = latent_format.__class__.__name__.lower() 454 | if ( 455 | not SETTINGS.btp_enabled 456 | or format_name in SETTINGS.btp_blacklist 457 | or (SETTINGS.btp_whitelist and format_name not in SETTINGS.btp_whitelist) 458 | ): 459 | return _ORIG_GET_PREVIEWER(device, latent_format, *args, **kwargs) 460 | tae_model = None 461 | if preview_method in {LatentPreviewMethod.TAESD, LatentPreviewMethod.Auto}: 462 | vid_info = VIDEO_FORMATS.get(format_name) 463 | if vid_info is not None and vid_info.tae_model is not None: 464 | tae_model_path = folder_paths.get_full_path( 465 | "vae_approx", 466 | vid_info.tae_model, 467 | ) 468 | tupscale_limit = SETTINGS.btp_video_temporal_upscale_level 469 | decoder_time_upscale = tuple( 470 | i < tupscale_limit for i in range(TAEVid.temporal_upscale_blocks) 471 | ) 472 | tae_model = ( 473 | TAEVid( 474 | checkpoint_path=tae_model_path, 475 | latent_channels=latent_format.latent_channels, 476 | device=device, 477 | decoder_time_upscale=decoder_time_upscale, 478 | ).to(device) 479 | if tae_model_path is not None 480 | else None 481 | ) 482 | if tae_model is None and latent_format.taesd_decoder_name is not None: 483 | taesd_path = folder_paths.get_full_path( 484 | "vae_approx", 485 | f"{latent_format.taesd_decoder_name}.pth", 486 | ) 487 | tae_model = ( 488 | TAESD( 489 | None, 490 | taesd_path, 491 | latent_channels=latent_format.latent_channels, 492 | ).to(device) 493 | if taesd_path is not None 494 | else None 495 | ) 496 | return BetterPreviewer( 497 | taesd=tae_model, 498 | latent_format=latent_format, 499 | vid_info=vid_info, 500 | ) 501 | if ( 502 | preview_method == LatentPreviewMethod.NoPreviews 503 | or latent_format.latent_rgb_factors is None 504 | ): 505 | 
return None 506 | if preview_method == LatentPreviewMethod.Latent2RGB: 507 | return BetterPreviewer(latent_format=latent_format) 508 | return _ORIG_GET_PREVIEWER(device, latent_format, *args, **kwargs) 509 | 510 | 511 | def ensure_previewer(): 512 | if latent_preview.get_previewer != bleh_get_previewer: 513 | latent_preview.BLEH_ORIG_get_previewer = _ORIG_GET_PREVIEWER 514 | latent_preview.get_previewer = bleh_get_previewer 515 | 516 | 517 | ensure_previewer() 518 | -------------------------------------------------------------------------------- /py/better_previews/tae_vid.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/madebyollin/taehv/blob/main/taehv.py 2 | 3 | # ruff: noqa: N806 4 | 5 | from __future__ import annotations 6 | 7 | from typing import TYPE_CHECKING, NamedTuple 8 | 9 | import torch 10 | from torch import nn 11 | from tqdm.auto import tqdm 12 | 13 | if TYPE_CHECKING: 14 | from collections.abc import Iterable 15 | from pathlib import Path 16 | 17 | F = torch.nn.functional 18 | 19 | 20 | class TWorkItem(NamedTuple): 21 | input_tensor: torch.Tensor 22 | block_index: int 23 | 24 | 25 | def conv(n_in: int, n_out: int, **kwargs: dict) -> nn.Conv2d: 26 | return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) 27 | 28 | 29 | class Clamp(nn.Module): 30 | @classmethod 31 | def forward(cls, x: torch.Tensor) -> torch.Tensor: 32 | return torch.tanh(x / 3) * 3 33 | 34 | 35 | class MemBlock(nn.Module): 36 | def __init__(self, n_in, n_out): 37 | super().__init__() 38 | self.conv = nn.Sequential( 39 | conv(n_in * 2, n_out), 40 | nn.ReLU(inplace=True), 41 | conv(n_out, n_out), 42 | nn.ReLU(inplace=True), 43 | conv(n_out, n_out), 44 | ) 45 | self.skip = ( 46 | nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() 47 | ) 48 | self.act = nn.ReLU(inplace=True) 49 | 50 | def forward(self, x: torch.Tensor, past: torch.Tensor) -> torch.Tensor: 51 | return self.act(self.conv(torch.cat((x, past), 1)) + self.skip(x)) 52 | 53 | 54 | class TPool(nn.Module): 55 | def __init__(self, n_f, stride): 56 | super().__init__() 57 | self.stride = stride 58 | self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False) 59 | 60 | def forward(self, x: torch.Tensor) -> torch.Tensor: 61 | c, h, w = x.shape[-3:] 62 | return self.conv(x.reshape(-1, self.stride * c, h, w)) 63 | 64 | 65 | class TGrow(nn.Module): 66 | def __init__(self, n_f, stride): 67 | super().__init__() 68 | self.stride = stride 69 | self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False) 70 | 71 | def forward(self, x: torch.Tensor) -> torch.Tensor: 72 | orig_shape = x.shape 73 | return self.conv(x).reshape(-1, *orig_shape[-3:]) 74 | 75 | 76 | class TAEVidContext: 77 | def __init__(self, model): 78 | self.model = model 79 | self.HANDLERS = { 80 | MemBlock: self.handle_memblock, 81 | TPool: self.handle_tpool, 82 | TGrow: self.handle_tgrow, 83 | } 84 | 85 | def reset(self, x: torch.Tensor) -> None: 86 | N, T, C, H, W = x.shape 87 | self.N, self.T = N, T 88 | self.work_queue = [ 89 | TWorkItem(xt, 0) 90 | for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1)) 91 | ] 92 | self.mem = [None] * len(self.model) 93 | 94 | def handle_memblock( 95 | self, 96 | i: int, 97 | xt: torch.Tensor, 98 | b: nn.Module, 99 | ) -> Iterable[torch.Tensor]: 100 | mem = self.mem 101 | # mem blocks are simple since we're visiting the graph in causal order 102 | if mem[i] is None: 103 | xt_new = b(xt, torch.zeros_like(xt)) 104 | mem[i] = xt 105 | else: 106 | xt_new = b(xt, mem[i]) 
107 | # inplace might reduce mysterious pytorch memory allocations? doesn't help though 108 | mem[i].copy_(xt) 109 | return (xt_new,) 110 | 111 | def handle_tpool( 112 | self, 113 | i: int, 114 | xt: torch.Tensor, 115 | b: nn.Module, 116 | ) -> Iterable[torch.Tensor]: 117 | mem = self.mem 118 | # pool blocks are miserable 119 | if mem[i] is None: 120 | mem[i] = [] # pool memory is itself a queue of inputs to pool 121 | mem[i].append(xt) 122 | if len(mem[i]) > b.stride: 123 | # pool mem is in invalid state, we should have pooled before this 124 | raise RuntimeError("Internal error: Invalid mem state") 125 | if len(mem[i]) < b.stride: 126 | # pool mem is not yet full, go back to processing the work queue 127 | return () 128 | # pool mem is ready, run the pool block 129 | N, C, H, W = xt.shape 130 | xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W)) 131 | # reset the pool mem 132 | mem[i] = [] 133 | return (xt,) 134 | 135 | def handle_tgrow( 136 | self, 137 | _i: int, 138 | xt: torch.Tensor, 139 | b: nn.Module, 140 | ) -> Iterable[torch.Tensor]: 141 | xt = b(xt) 142 | C, H, W = xt.shape[1:] 143 | return reversed( 144 | xt.view(self.N, b.stride * C, H, W).chunk(b.stride, 1), 145 | ) 146 | 147 | @classmethod 148 | def handle_default( 149 | cls, 150 | _i: int, 151 | xt: torch.Tensor, 152 | b: nn.Module, 153 | ) -> Iterable[torch.Tensor]: 154 | return (b(xt),) 155 | 156 | def handle_block(self, i: int, xt: torch.Tensor, b: nn.Module) -> None: 157 | handler = self.HANDLERS.get(b.__class__, self.handle_default) 158 | for xt_new in handler(i, xt, b): 159 | self.work_queue.insert(0, TWorkItem(xt_new, i + 1)) 160 | 161 | def apply(self, x: torch.Tensor, *, show_progress=False) -> torch.Tensor: 162 | if x.ndim != 5: 163 | raise ValueError("Expected 5 dimensional tensor") 164 | self.reset(x) 165 | out = [] 166 | work_queue = self.work_queue 167 | model = self.model 168 | model_len = len(model) 169 | 170 | with tqdm(range(self.T), disable=not show_progress) as pbar: 171 | while work_queue: 172 | xt, i = work_queue.pop(0) 173 | if i == model_len: 174 | # reached end of the graph, append result to output list 175 | out.append(xt) 176 | continue 177 | if i == 0: 178 | # new source node consumed 179 | pbar.update(1) 180 | self.handle_block(i, xt, model[i]) 181 | return torch.stack(out, 1) 182 | 183 | 184 | class TAEVid(nn.Module): 185 | temporal_upscale_blocks = 2 186 | spatial_upscale_blocks = 3 187 | 188 | def __init__( 189 | self, 190 | *, 191 | checkpoint_path: str | Path, 192 | latent_channels: int, 193 | image_channels: int = 3, 194 | device="cpu", 195 | decoder_time_upscale=(True, True), 196 | decoder_space_upscale=(True, True, True), 197 | ): 198 | super().__init__() 199 | self.latent_channels = latent_channels 200 | self.image_channels = image_channels 201 | self.encoder = nn.Sequential( 202 | conv(image_channels, 64), 203 | nn.ReLU(inplace=True), 204 | TPool(64, 2), 205 | conv(64, 64, stride=2, bias=False), 206 | MemBlock(64, 64), 207 | MemBlock(64, 64), 208 | MemBlock(64, 64), 209 | TPool(64, 2), 210 | conv(64, 64, stride=2, bias=False), 211 | MemBlock(64, 64), 212 | MemBlock(64, 64), 213 | MemBlock(64, 64), 214 | TPool(64, 1), 215 | conv(64, 64, stride=2, bias=False), 216 | MemBlock(64, 64), 217 | MemBlock(64, 64), 218 | MemBlock(64, 64), 219 | conv(64, latent_channels), 220 | ) 221 | n_f = (256, 128, 64, 64) 222 | self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1 223 | self.decoder = nn.Sequential( 224 | Clamp(), 225 | conv(latent_channels, n_f[0]), 226 | nn.ReLU(inplace=True), 227 
| MemBlock(n_f[0], n_f[0]), 228 | MemBlock(n_f[0], n_f[0]), 229 | MemBlock(n_f[0], n_f[0]), 230 | nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), 231 | TGrow(n_f[0], 1), 232 | conv(n_f[0], n_f[1], bias=False), 233 | MemBlock(n_f[1], n_f[1]), 234 | MemBlock(n_f[1], n_f[1]), 235 | MemBlock(n_f[1], n_f[1]), 236 | nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), 237 | TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), 238 | conv(n_f[1], n_f[2], bias=False), 239 | MemBlock(n_f[2], n_f[2]), 240 | MemBlock(n_f[2], n_f[2]), 241 | MemBlock(n_f[2], n_f[2]), 242 | nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), 243 | TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), 244 | conv(n_f[2], n_f[3], bias=False), 245 | nn.ReLU(inplace=True), 246 | conv(n_f[3], image_channels), 247 | ) 248 | if checkpoint_path is None: 249 | return 250 | self.load_state_dict( 251 | self.patch_tgrow_layers( 252 | torch.load(checkpoint_path, map_location=device, weights_only=True), 253 | ), 254 | ) 255 | 256 | def patch_tgrow_layers(self, sd: dict) -> dict: 257 | new_sd = self.state_dict() 258 | for i, layer in enumerate(self.decoder): 259 | if isinstance(layer, TGrow): 260 | key = f"decoder.{i}.conv.weight" 261 | if sd[key].shape[0] > new_sd[key].shape[0]: 262 | # take the last-timestep output channels 263 | sd[key] = sd[key][-new_sd[key].shape[0] :] 264 | return sd 265 | 266 | @classmethod 267 | def apply_parallel( 268 | cls, 269 | x: torch.Tensor, 270 | model: nn.Module, 271 | *, 272 | show_progress=False, 273 | ) -> torch.Tensor: 274 | padding = (0, 0, 0, 0, 0, 0, 1, 0) 275 | n, t, c, h, w = x.shape 276 | x = x.reshape(n * t, c, h, w) 277 | # parallel over input timesteps, iterate over blocks 278 | for b in tqdm(model, disable=not show_progress): 279 | if not isinstance(b, MemBlock): 280 | x = b(x) 281 | continue 282 | nt, c, h, w = x.shape 283 | t = nt // n 284 | mem = F.pad(x.reshape(n, t, c, h, w), padding, value=0)[:, :t].reshape( 285 | x.shape, 286 | ) 287 | x = b(x, mem) 288 | del mem 289 | nt, c, h, w = x.shape 290 | t = nt // n 291 | return x.view(n, t, c, h, w) 292 | 293 | def apply( 294 | self, 295 | x: torch.Tensor, 296 | *, 297 | decode=True, 298 | parallel=True, 299 | show_progress=False, 300 | ) -> torch.Tensor: 301 | model = self.decoder if decode else self.encoder 302 | if parallel: 303 | return self.apply_parallel(x, model, show_progress=show_progress) 304 | return TAEVidContext(model).apply(x, show_progress=show_progress) 305 | 306 | def decode(self, *args: list, **kwargs: dict) -> torch.Tensor: 307 | return self.apply(*args, decode=True, **kwargs)[:, self.frames_to_trim :] 308 | 309 | def encode(self, *args: list, **kwargs: dict) -> torch.Tensor: 310 | return self.apply(*args, decode=False, **kwargs) 311 | 312 | def forward(self, x: torch.Tensor) -> torch.Tensor: 313 | return self.c(x) 314 | -------------------------------------------------------------------------------- /py/nodes/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | blockCFG, 3 | deepShrink, 4 | hyperTile, 5 | misc, 6 | modelPatchConditional, 7 | ops, 8 | refinerAfter, 9 | sageAttention, 10 | samplers, 11 | taevid, 12 | ) 13 | 14 | NODE_CLASS_MAPPINGS = { 15 | "BlehBlockCFG": blockCFG.BlockCFGBleh, 16 | "BlehBlockOps": ops.BlehBlockOps, 17 | "BlehDeepShrink": deepShrink.DeepShrinkBleh, 18 | "BlehDisableNoise": misc.BlehDisableNoise, 19 | "BlehDiscardPenultimateSigma": misc.DiscardPenultimateSigma, 20 | "BlehForceSeedSampler": samplers.BlehForceSeedSampler, 21 | "BlehGlobalSageAttention": sageAttention.BlehGlobalSageAttention, 22 | "BlehHyperTile": hyperTile.HyperTileBleh, 23 | "BlehInsaneChainSampler": samplers.BlehInsaneChainSampler, 24 | "BlehLatentOps": ops.BlehLatentOps, 25 | "BlehLatentScaleBy": ops.BlehLatentScaleBy, 26 | "BlehLatentBlend": ops.BlehLatentBlend, 27 | "BlehModelPatchConditional": modelPatchConditional.ModelPatchConditionalNode, 28 | "BlehPlug": misc.BlehPlug, 29 | "BlehRefinerAfter": refinerAfter.BlehRefinerAfter, 30 | "BlehSageAttentionSampler": sageAttention.BlehSageAttentionSampler, 31 | "BlehSetSamplerPreset": samplers.BlehSetSamplerPreset, 32 | "BlehCast": misc.BlehCast, 33 | "BlehSetSigmas": misc.BlehSetSigmas, 34 | "BlehEnsurePreviewer": misc.BlehEnsurePreviewer, 35 | "BlehTAEVideoDecode": taevid.TAEVideoDecode, 36 | "BlehTAEVideoEncode": taevid.TAEVideoEncode, 37 | } 38 | 39 | NODE_DISPLAY_NAME_MAPPINGS = { 40 | "BlehHyperTile": "HyperTile (bleh)", 41 | "BlehDeepShrink": "Kohya Deep Shrink (bleh)", 42 | } 43 | 44 | __all__ = ("NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS") 45 | -------------------------------------------------------------------------------- /py/nodes/blockCFG.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | 4 | class BlockCFGBleh: 5 | RETURN_TYPES = ("MODEL",) 6 | FUNCTION = "patch" 7 | CATEGORY = "bleh/model_patches" 8 | DESCRIPTION = ( 9 | "Applies a CFG type effect to the model blocks themselves during evaluation." 10 | ) 11 | 12 | @classmethod 13 | def INPUT_TYPES(cls): 14 | return { 15 | "required": { 16 | "model": ( 17 | "MODEL", 18 | { 19 | "tooltip": "Model to patch", 20 | }, 21 | ), 22 | "commasep_block_numbers": ( 23 | "STRING", 24 | { 25 | "default": "i4,m0,o4", 26 | "tooltip": "Comma separated list of block numbers, each should start with one of i(input), m(iddle), o(utput). You may also use * instead of a block number to select all blocks in the category.", 27 | }, 28 | ), 29 | "scale": ( 30 | "FLOAT", 31 | { 32 | "default": 0.25, 33 | "min": -100.0, 34 | "max": 100.0, 35 | "step": 0.001, 36 | "round": False, 37 | "tooltip": "Effect strength", 38 | }, 39 | ), 40 | "start_percent": ( 41 | "FLOAT", 42 | { 43 | "default": 0.2, 44 | "min": 0.0, 45 | "max": 1.0, 46 | "step": 0.001, 47 | "tooltip": "Start time as sampling percentage (not percentage of steps). Percentages are inclusive.", 48 | }, 49 | ), 50 | "end_percent": ( 51 | "FLOAT", 52 | { 53 | "default": 0.8, 54 | "min": 0.0, 55 | "max": 1.0, 56 | "step": 0.001, 57 | "tooltip": "End time as sampling percentage (not percentage of steps). Percentages are inclusive.", 58 | }, 59 | ), 60 | "skip_mode": ( 61 | "BOOLEAN", 62 | { 63 | "default": False, 64 | "tooltip": "For output blocks, this causes the effect to apply to the skip connection. For input blocks it patches after the skip connection. 
No effect for middle blocks.", 65 | }, 66 | ), 67 | "apply_to": ( 68 | ("cond", "uncond"), 69 | { 70 | "default": "uncond", 71 | "tooltip": "Guides the specified target away from its opposite. cond=positive prompt, uncond=negative prompt.", 72 | }, 73 | ), 74 | }, 75 | } 76 | 77 | @classmethod 78 | def patch( 79 | cls, 80 | *, 81 | model, 82 | commasep_block_numbers, 83 | scale, 84 | start_percent, 85 | end_percent, 86 | skip_mode, 87 | apply_to, 88 | ): 89 | input_blocks = {} 90 | middle_blocks = {} 91 | output_blocks = {} 92 | for idx, item_ in enumerate(commasep_block_numbers.split(",")): 93 | item = item_.strip().lower() 94 | if not item: 95 | continue 96 | block_type = item[0] 97 | if block_type not in "imo" or len(item) < 2: 98 | errstr = f"Bad block definition at item {idx}" 99 | raise ValueError(errstr) 100 | if item[1] == "*": 101 | block = tidx = -1 102 | else: 103 | block, *tidx = item[1:].split(".", 1) 104 | block = int(block) 105 | tidx = int(tidx) if tidx else -1 106 | if block_type == "i": 107 | bd = input_blocks 108 | else: 109 | bd = output_blocks if block_type == "o" else middle_blocks 110 | bd[block] = tidx 111 | 112 | if ( 113 | scale == 0 114 | or end_percent <= 0 115 | or start_percent >= 1 116 | or not (input_blocks or middle_blocks or output_blocks) 117 | ): 118 | return (model,) 119 | 120 | ms = model.get_model_object("model_sampling") 121 | sigma_start = ms.percent_to_sigma(start_percent) 122 | sigma_end = ms.percent_to_sigma(end_percent) 123 | reverse = apply_to != "cond" 124 | 125 | def check_applies(block_list, transformer_options): 126 | cond_or_uncond = transformer_options["cond_or_uncond"] 127 | if ( 128 | not (0 in cond_or_uncond and 1 in cond_or_uncond) 129 | or len(cond_or_uncond) != 2 130 | ): 131 | return False 132 | block_num = transformer_options["block"][1] 133 | sigma_tensor = transformer_options["sigmas"].max() 134 | sigma = sigma_tensor.detach().cpu().item() 135 | block_def = block_list.get(block_num) 136 | ok_time = sigma_end <= sigma <= sigma_start 137 | if not ok_time: 138 | return False 139 | if block_def is None: 140 | return -1 in block_list 141 | return block_def in {-1, transformer_options.get("transformer_index")} 142 | 143 | def apply_cfg_fun(tensor, primary_offset): 144 | secondary_offset = 0 if primary_offset == 1 else 1 145 | if reverse: 146 | primary_offset, secondary_offset = secondary_offset, primary_offset 147 | result = tensor.clone() 148 | batch = tensor.shape[0] // 2 149 | primary_idxs, secondary_idxs = ( 150 | tuple(range(batch * offs, batch + batch * offs)) 151 | for offs in (primary_offset, secondary_offset) 152 | ) 153 | # print(f"\nIDXS: cond={primary_idxs}, uncond={secondary_idxs}") 154 | result[primary_idxs, ...] -= ( 155 | tensor[primary_idxs, ...] - tensor[secondary_idxs, ...] 
156 | ).mul_(scale) 157 | return result 158 | 159 | def non_output_block_patch(h, transformer_options, *, block_list): 160 | cond_or_uncond = transformer_options["cond_or_uncond"] 161 | if not check_applies( 162 | block_list, 163 | transformer_options, 164 | ): 165 | return h 166 | return apply_cfg_fun(h, cond_or_uncond.index(0)) 167 | 168 | def output_block_patch(h, hsp, transformer_options, *, block_list): 169 | cond_or_uncond = transformer_options["cond_or_uncond"] 170 | if not check_applies( 171 | block_list, 172 | transformer_options, 173 | ): 174 | return h, hsp 175 | cond_idx = cond_or_uncond.index(0) 176 | return ( 177 | (apply_cfg_fun(h, cond_idx), hsp) 178 | if not skip_mode 179 | else (h, apply_cfg_fun(hsp, cond_idx)) 180 | ) 181 | 182 | m = model.clone() 183 | if input_blocks: 184 | ( 185 | m.set_model_input_block_patch_after_skip 186 | if skip_mode 187 | else m.set_model_input_block_patch 188 | )( 189 | partial(non_output_block_patch, block_list=input_blocks), 190 | ) 191 | if middle_blocks: 192 | m.set_model_patch( 193 | partial(non_output_block_patch, block_list=middle_blocks), 194 | "middle_block_patch", 195 | ) 196 | if output_blocks: 197 | m.set_model_output_block_patch( 198 | partial(output_block_patch, block_list=output_blocks), 199 | ) 200 | return (m,) 201 | -------------------------------------------------------------------------------- /py/nodes/deepShrink.py: -------------------------------------------------------------------------------- 1 | # Adapted from the ComfyUI built-in node 2 | 3 | from .. import latent_utils # noqa: TID252 4 | 5 | 6 | class DeepShrinkBleh: 7 | RETURN_TYPES = ("MODEL",) 8 | FUNCTION = "patch" 9 | CATEGORY = "bleh/model_patches" 10 | DESCRIPTION = "Model patch that enables generating at higher resolution than the model was trained for by downscaling the image near the start of generation." 11 | 12 | upscale_methods = ( 13 | "bicubic", 14 | "nearest-exact", 15 | "bilinear", 16 | "area", 17 | "bislerp", 18 | ) 19 | 20 | @classmethod 21 | def INPUT_TYPES(cls): 22 | return { 23 | "required": { 24 | "model": ( 25 | "MODEL", 26 | { 27 | "tooltip": "Model to patch", 28 | }, 29 | ), 30 | "commasep_block_numbers": ( 31 | "STRING", 32 | { 33 | "default": "3", 34 | "tooltip": "A comma separated list of input block numbers, the default should work for SD 1.5 and SDXL.", 35 | }, 36 | ), 37 | "downscale_factor": ( 38 | "FLOAT", 39 | { 40 | "default": 2.0, 41 | "min": 1.0, 42 | "max": 32.0, 43 | "step": 0.1, 44 | "tooltip": "Controls how much the block will get downscaled while the effect is active.", 45 | }, 46 | ), 47 | "start_percent": ( 48 | "FLOAT", 49 | { 50 | "default": 0.0, 51 | "min": 0.0, 52 | "max": 1.0, 53 | "step": 0.001, 54 | "tooltip": "Start time as sampling percentage (not percentage of steps). Percentages are inclusive.", 55 | }, 56 | ), 57 | "start_fadeout_percent": ( 58 | "FLOAT", 59 | { 60 | "default": 1.0, 61 | "min": 0.0, 62 | "max": 1.0, 63 | "step": 0.001, 64 | "tooltip": "When enabled, the downscale_factor will fade out such that at end_percent it will be around 1.0 (no downscaling). May reduce artifacts... or cause them!", 65 | }, 66 | ), 67 | "end_percent": ( 68 | "FLOAT", 69 | { 70 | "default": 0.35, 71 | "min": 0.0, 72 | "max": 1.0, 73 | "step": 0.001, 74 | "tooltip": "End time as sampling percentage (not percentage of steps). 
Percentages are inclusive.", 75 | }, 76 | ), 77 | "downscale_after_skip": ( 78 | "BOOLEAN", 79 | { 80 | "default": True, 81 | "tooltip": "Controls whether the downscale effect occurs after the skip conection. Generally should be left enabled.", 82 | }, 83 | ), 84 | "downscale_method": ( 85 | latent_utils.UPSCALE_METHODS, 86 | { 87 | "default": "bicubic", 88 | "tooltip": "Mode used for downscaling. Bicubic is generally a safe choice.", 89 | }, 90 | ), 91 | "upscale_method": ( 92 | latent_utils.UPSCALE_METHODS, 93 | { 94 | "default": "bicubic", 95 | "tooltip": "Mode used for upscaling. Bicubic is generally a safe choice.", 96 | }, 97 | ), 98 | "antialias_downscale": ( 99 | "BOOLEAN", 100 | { 101 | "default": False, 102 | "tooltip": "Experimental option to anti-alias (smooth) the latent after downscaling.", 103 | }, 104 | ), 105 | "antialias_upscale": ( 106 | "BOOLEAN", 107 | { 108 | "default": False, 109 | "tooltip": "Experimental option to anti-alias (smooth) the latent after upscaling.", 110 | }, 111 | ), 112 | }, 113 | } 114 | 115 | @classmethod 116 | def patch( 117 | cls, 118 | *, 119 | model, 120 | commasep_block_numbers, 121 | downscale_factor, 122 | start_percent, 123 | start_fadeout_percent, 124 | end_percent, 125 | downscale_after_skip, 126 | downscale_method, 127 | upscale_method, 128 | antialias_downscale, 129 | antialias_upscale, 130 | ): 131 | block_numbers = tuple( 132 | int(x) for x in commasep_block_numbers.split(",") if x.strip() 133 | ) 134 | downscale_factor = 1.0 / downscale_factor 135 | if not (block_numbers and all(val > 0 and val <= 32 for val in block_numbers)): 136 | raise ValueError( 137 | "BlehDeepShrink: Bad value for block numbers: must be comma-separated list of numbers between 1-32", 138 | ) 139 | antialias_downscale = antialias_downscale and downscale_method in { 140 | "bicubic", 141 | "bilinear", 142 | } 143 | antialias_upscale = antialias_upscale and upscale_method in { 144 | "bicubic", 145 | "bilinear", 146 | } 147 | if start_fadeout_percent < start_percent: 148 | start_fadeout_percent = start_percent 149 | elif start_fadeout_percent > end_percent: 150 | # No fadeout. 151 | start_fadeout_percent = 1000.0 152 | 153 | ms = model.get_model_object("model_sampling") 154 | sigma_start = ms.percent_to_sigma(start_percent) 155 | sigma_end = ms.percent_to_sigma(end_percent) 156 | 157 | def input_block_patch(h, transformer_options): 158 | block_num = transformer_options["block"][1] 159 | sigma_tensor = transformer_options["sigmas"].max() 160 | sigma = sigma_tensor.detach().cpu().item() 161 | if ( 162 | sigma > sigma_start 163 | or sigma < sigma_end 164 | or block_num not in block_numbers 165 | ): 166 | return h 167 | pct = 1.0 - (ms.timestep(sigma_tensor).detach().cpu().item() / 999) 168 | if ( 169 | pct < start_fadeout_percent 170 | or start_fadeout_percent > end_percent 171 | or pct > end_percent 172 | ): 173 | scaled_scale = downscale_factor 174 | else: 175 | # May or not be accurate but the idea is to scale the downscale factor by the percentage 176 | # of the start fade to end deep shrink we have currently traversed. It at least sort of works. 
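                    # Note: at pct == start_fadeout_percent this yields the full downscale_factor;
                    # the effect then fades toward 1.0 (no downscaling) as pct approaches end_percent.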
177 | downscale_pct = 1.0 - ( 178 | (pct - start_fadeout_percent) 179 | / (end_percent - start_fadeout_percent) 180 | ) 181 | scaled_scale = 1.0 - ((1.0 - downscale_factor) * downscale_pct) 182 | orig_width, orig_height = h.shape[-1], h.shape[-2] 183 | width, height = ( 184 | round(orig_width * scaled_scale), 185 | round(orig_height * scaled_scale), 186 | ) 187 | if scaled_scale >= 0.98 or width >= orig_width or height >= orig_height: 188 | return h 189 | return latent_utils.scale_samples( 190 | h, 191 | width, 192 | height, 193 | mode=downscale_method, 194 | antialias_size=3 if antialias_downscale else 0, 195 | sigma=sigma, 196 | ) 197 | 198 | def output_block_patch(h, hsp, transformer_options): 199 | sigma = transformer_options["sigmas"][0].cpu().item() 200 | if ( 201 | h.shape[-2:] == hsp.shape[-2:] 202 | or sigma > sigma_start 203 | or sigma < sigma_end 204 | ): 205 | return h, hsp 206 | return latent_utils.scale_samples( 207 | h, 208 | hsp.shape[-1], 209 | hsp.shape[-2], 210 | mode=upscale_method, 211 | antialias_size=3 if antialias_upscale else 0, 212 | sigma=sigma, 213 | ), hsp 214 | 215 | m = model.clone() 216 | if downscale_factor == 0.0 or start_percent >= 1.0: 217 | return (m,) 218 | if downscale_after_skip: 219 | m.set_model_input_block_patch_after_skip(input_block_patch) 220 | else: 221 | m.set_model_input_block_patch(input_block_patch) 222 | m.set_model_output_block_patch(output_block_patch) 223 | return (m,) 224 | -------------------------------------------------------------------------------- /py/nodes/hyperTile.py: -------------------------------------------------------------------------------- 1 | # The chain of yoinks grows ever longer. 2 | # Originally taken from: https://github.com/tfernd/HyperTile/ 3 | # Modified version of ComfyUI main code 4 | # https://github.com/comfyanonymous/ComfyUI/blob/master/comfy_extras/nodes_hypertile.py 5 | from __future__ import annotations 6 | 7 | import math 8 | 9 | import torch 10 | from einops import rearrange 11 | 12 | 13 | class HyperTile: 14 | def __init__( # noqa: PLR0917 15 | self, 16 | model, 17 | seed, 18 | tile_size, 19 | swap_size, 20 | max_depth, 21 | scale_depth, 22 | interval, 23 | start_step, 24 | end_step, 25 | ): 26 | self.rng = torch.Generator() 27 | self.rng.manual_seed(seed) 28 | self.model = model 29 | self.latent_tile_size = max(32, tile_size) // 8 30 | self.swap_size = swap_size 31 | self.max_depth = max_depth 32 | self.scale_depth = scale_depth 33 | self.start_step = start_step 34 | self.end_step = end_step 35 | self.interval = interval 36 | self.last_timestep = -1 37 | self.counter = -1 38 | # Temporary storage for rearranged tensors in the output part 39 | self.temp = None 40 | 41 | def patch(self): 42 | model = self.model 43 | model.set_model_attn1_patch(self.attn1_in) 44 | model.set_model_attn1_output_patch(self.attn1_out) 45 | return model 46 | 47 | def random_divisor( 48 | self, 49 | value: int, 50 | min_value: int, 51 | /, 52 | max_options: int = 1, 53 | ) -> int: 54 | min_value = min(min_value, value) 55 | # All big divisors of value (inclusive) 56 | divisors = tuple(i for i in range(min_value, value + 1) if value % i == 0) 57 | ns = tuple(value // i for i in divisors[:max_options]) # has at least 1 element 58 | if len(ns) < 2: 59 | return ns[0] 60 | return ns[ 61 | torch.randint( 62 | generator=self.rng, 63 | low=0, 64 | high=len(ns) - 1, 65 | size=(1,), 66 | ).item() 67 | ] 68 | 69 | def check_timestep(self, extra_options): 70 | current_timestep = self.model.model.model_sampling.timestep( 71 | 
extra_options["sigmas"][0], 72 | ).item() 73 | matched = ( 74 | current_timestep <= self.start_step and current_timestep >= self.end_step 75 | ) 76 | if current_timestep > self.last_timestep: 77 | # Detecting if the model got reused to sample again... maybe? 78 | self.counter = 0 if matched else -1 79 | self.temp = None 80 | elif matched and current_timestep != self.last_timestep: 81 | self.counter += 1 82 | self.last_timestep = current_timestep 83 | return matched 84 | 85 | def check_interval(self): 86 | if self.interval > 0: 87 | return (self.counter % self.interval) == 0 88 | if self.interval < 0: 89 | return ((self.counter + 1) % abs(self.interval)) > 0 90 | return False 91 | 92 | def attn1_in(self, q, k, v, extra_options): 93 | if not (self.check_timestep(extra_options) and self.check_interval()): 94 | return q, k, v 95 | model_chans = q.shape[-2] 96 | orig_shape = extra_options["original_shape"] 97 | 98 | apply_to = tuple( 99 | (orig_shape[-2] / (2**i)) * (orig_shape[-1] / (2**i)) 100 | for i in range(self.max_depth + 1) 101 | ) 102 | if model_chans not in apply_to: 103 | return q, k, v 104 | 105 | aspect_ratio = orig_shape[-1] / orig_shape[-2] 106 | 107 | hw = q.size(1) 108 | h, w = ( 109 | round(math.sqrt(hw * aspect_ratio)), 110 | round(math.sqrt(hw / aspect_ratio)), 111 | ) 112 | 113 | factor = (2 ** apply_to.index(model_chans)) if self.scale_depth else 1 114 | 115 | nh = self.random_divisor(h, self.latent_tile_size * factor, self.swap_size) 116 | nw = self.random_divisor(w, self.latent_tile_size * factor, self.swap_size) 117 | 118 | if nh * nw <= 1: 119 | return q, k, v 120 | 121 | q = rearrange( 122 | q, 123 | "b (nh h nw w) c -> (b nh nw) (h w) c", 124 | h=h // nh, 125 | w=w // nw, 126 | nh=nh, 127 | nw=nw, 128 | ) 129 | self.temp = (nh, nw, h, w) 130 | return q, k, v 131 | 132 | def attn1_out(self, out, _extra_options): 133 | if self.temp is None: 134 | return out 135 | nh, nw, h, w = self.temp 136 | self.temp = None 137 | out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) 138 | return rearrange( 139 | out, 140 | "b nh nw (h w) c -> b (nh h nw w) c", 141 | h=h // nh, 142 | w=w // nw, 143 | ) 144 | 145 | 146 | class HyperTileBleh: 147 | RETURN_TYPES = ("MODEL",) 148 | FUNCTION = "patch" 149 | CATEGORY = "bleh/model_patches" 150 | DESCRIPTION = "Model patch that speeds up generation at some cost of quality." 
151 | 152 | @classmethod 153 | def INPUT_TYPES(cls): 154 | return { 155 | "required": { 156 | "model": ("MODEL",), 157 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}), 158 | "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}), 159 | "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}), 160 | "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}), 161 | "scale_depth": ("BOOLEAN", {"default": False}), 162 | "interval": ("INT", {"default": 1, "min": -999, "max": 999}), 163 | "start_step": ( 164 | "INT", 165 | { 166 | "default": 1000, 167 | "min": 0, 168 | "max": 1000, 169 | "step": 1, 170 | "display": "number", 171 | }, 172 | ), 173 | "end_step": ( 174 | "INT", 175 | { 176 | "default": 0, 177 | "min": 0, 178 | "max": 1000, 179 | "step": 1, 180 | }, 181 | ), 182 | }, 183 | } 184 | 185 | @classmethod 186 | def patch( 187 | cls, 188 | *, 189 | model, 190 | seed, 191 | tile_size, 192 | swap_size, 193 | max_depth, 194 | scale_depth, 195 | interval, 196 | start_step, 197 | end_step, 198 | ): 199 | return ( 200 | HyperTile( 201 | model.clone(), 202 | seed, 203 | tile_size, 204 | swap_size, 205 | max_depth, 206 | scale_depth, 207 | interval, 208 | start_step, 209 | end_step, 210 | ).patch(), 211 | ) 212 | -------------------------------------------------------------------------------- /py/nodes/misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import operator 4 | import random 5 | from decimal import Decimal 6 | 7 | import torch 8 | from comfy import model_management 9 | 10 | from ..better_previews.previewer import ensure_previewer # noqa: TID252 11 | 12 | 13 | class DiscardPenultimateSigma: 14 | @classmethod 15 | def INPUT_TYPES(cls): 16 | return { 17 | "required": { 18 | "enabled": ("BOOLEAN", {"default": True}), 19 | "sigmas": ("SIGMAS", {"forceInput": True}), 20 | }, 21 | } 22 | 23 | FUNCTION = "go" 24 | RETURN_TYPES = ("SIGMAS",) 25 | CATEGORY = "sampling/custom_sampling/sigmas" 26 | DESCRIPTION = "Discards the next to last sigma in the list." 27 | 28 | @classmethod 29 | def go(cls, enabled: bool, sigmas: torch.Tensor) -> tuple[torch.Tensor]: 30 | if not enabled or len(sigmas) < 2: 31 | return (sigmas,) 32 | return (torch.cat((sigmas[:-2], sigmas[-1:])),) 33 | 34 | 35 | class SeededDisableNoise: 36 | def __init__(self, seed: int, seed_offset: int = 1): 37 | self.seed = seed 38 | self.seed_offset = seed_offset 39 | 40 | def generate_noise(self, latent): 41 | samples = latent["samples"] 42 | torch.manual_seed(self.seed) 43 | random.seed(self.seed) # For good measure. 44 | shape = samples.shape 45 | device = model_management.get_torch_device() 46 | for _ in range(self.seed_offset): 47 | _ = random.random() # noqa: S311 48 | _ = torch.randn( 49 | *shape, 50 | dtype=samples.dtype, 51 | layout=samples.layout, 52 | device=device, 53 | ) 54 | return torch.zeros( 55 | samples.shape, 56 | dtype=samples.dtype, 57 | layout=samples.layout, 58 | device="cpu", 59 | ) 60 | 61 | 62 | class BlehDisableNoise: 63 | DESCRIPTION = "Allows setting a seed even when disabling noise. Used for SamplerCustomAdvanced or other nodes that take a NOISE input." 
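# SeededDisableNoise (above) seeds both torch and Python's RNG with noise_seed, advances
# both RNGs seed_offset times (one random.random() call and one latent-shaped torch.randn()
# per step), then returns an all-zero tensor as the "noise". The throwaway draws keep noise
# generated later in sampling from lining up with what the disabled initial noise would
# have been.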
64 | RETURN_TYPES = ("NOISE",) 65 | FUNCTION = "go" 66 | CATEGORY = "sampling/custom_sampling/noise" 67 | 68 | @classmethod 69 | def INPUT_TYPES(cls): 70 | return { 71 | "required": { 72 | "noise_seed": ( 73 | "INT", 74 | {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}, 75 | ), 76 | }, 77 | "optional": { 78 | "seed_offset": ( 79 | "INT", 80 | { 81 | "default": 1, 82 | "min": 0, 83 | "max": 200, 84 | "tooltip": "Advances the RNG this many times to avoid the mistake of using the same noise for sampling as the initial noise. I recommend leaving this at 1 (or higher) but you can set it to 0. to disable", 85 | }, 86 | ), 87 | }, 88 | } 89 | 90 | @classmethod 91 | def go( 92 | cls, 93 | noise_seed: int, 94 | seed_offset: int | None = 1, 95 | ) -> tuple[SeededDisableNoise]: 96 | return ( 97 | SeededDisableNoise( 98 | noise_seed, 99 | seed_offset if seed_offset is not None else 0, 100 | ), 101 | ) 102 | 103 | 104 | class Wildcard(str): # noqa: FURB189 105 | __slots__ = () 106 | 107 | def __ne__(self, _unused): 108 | return False 109 | 110 | 111 | class BlehPlug: 112 | DESCRIPTION = "This node can be used to plug up an input but act like the input was not actually connected. Can be used to prevent something like Use Everywhere nodes from supplying an input without having to set up blacklists or other configuration." 113 | FUNCTION = "go" 114 | OUTPUT_NODE = False 115 | CATEGORY = "hacks" 116 | 117 | WILDCARD = Wildcard("*") 118 | RETURN_TYPES = (WILDCARD,) 119 | 120 | @classmethod 121 | def INPUT_TYPES(cls): 122 | return {} 123 | 124 | @classmethod 125 | def go(cls): 126 | return (None,) 127 | 128 | 129 | class BlehCast: 130 | DESCRIPTION = "UNSAFE: This node allows casting its input to any type. NOTE: This does not actually change the data in any way, it just allows you to connect its output to any input. Only use if you know for sure the data is compatible." 131 | FUNCTION = "go" 132 | CATEGORY = "hacks" 133 | 134 | WILDCARD = Wildcard("*") 135 | RETURN_TYPES = (WILDCARD,) 136 | 137 | @classmethod 138 | def INPUT_TYPES(cls): 139 | return { 140 | "required": { 141 | "any_input": ( 142 | cls.WILDCARD, 143 | { 144 | "forceInput": True, 145 | "description": "You can connect any type of input here, but take to ensure that you connect the output from this node to an input that is compatible.", 146 | }, 147 | ), 148 | }, 149 | } 150 | 151 | @classmethod 152 | def go(cls, *, any_input): 153 | return (any_input,) 154 | 155 | 156 | class BlehSetSigmas: 157 | DESCRIPTION = "Advanced node that allows manipulating SIGMAS. For example, you can manually enter a list of sigmas, insert some new sigmas into existing SIGMAS, etc." 158 | FUNCTION = "go" 159 | CATEGORY = "sampling/custom_sampling/sigmas" 160 | RETURN_TYPES = ("SIGMAS",) 161 | 162 | @classmethod 163 | def INPUT_TYPES(cls): 164 | return { 165 | "required": { 166 | "start_index": ( 167 | "INT", 168 | { 169 | "default": 0, 170 | "tooltip": "Start index for modifying sigmas, zero-based. May be set to a negative value to index from the end, i.e. -1 is the last item, -2 is the penultimate item.", 171 | }, 172 | ), 173 | "mode": ( 174 | ("replace", "insert", "multiply", "add", "subtract", "divide"), 175 | { 176 | "default": "replace", 177 | "tooltip": "", 178 | }, 179 | ), 180 | "order": ( 181 | ("AB", "BA"), 182 | { 183 | "default": "AB", 184 | "tooltip": "Only applies to add, subtract, multiply and divide operations. Controls the order of operations. 
For example if order AB then add means A*B, if order BA then add means B*A.", 185 | }, 186 | ), 187 | "commasep_sigmas_b": ( 188 | "STRING", 189 | { 190 | "default": "", 191 | "tooltip": "Exclusive with sigmas_b. Enter a comma-separated list of sigma values here. For non-insert mode, the input sigmas will be padded with zeros if necessary. Example: start_index=2 (3rd item), mode=replace, input sigmas 4,3,2,1 and you used replace mode with 0.3,0.2,0.1 the output would be 4,3,0.3,0.2,0.1", 192 | }, 193 | ), 194 | }, 195 | "optional": { 196 | "sigmas_a": ( 197 | "SIGMAS", 198 | { 199 | "forceInput": True, 200 | "tooltip": "Optional input as long as commasep_sigmas is not also empty. If not supplied, an initial sigmas list of the appropriate size will be generated filled with zeros.", 201 | }, 202 | ), 203 | "sigmas_b": ( 204 | "SIGMAS", 205 | { 206 | "forceInput": True, 207 | "tooltip": "Optionally populate this or commasep_sigmas_b but not both.", 208 | }, 209 | ), 210 | }, 211 | } 212 | 213 | OP_MAP = { # noqa: RUF012 214 | "add": operator.add, 215 | "subtract": operator.sub, 216 | "multiply": operator.mul, 217 | "divide": operator.truediv, 218 | } 219 | 220 | @classmethod 221 | def go( 222 | cls, 223 | *, 224 | start_index: int, 225 | mode: str, 226 | order: str, 227 | commasep_sigmas_b: str, 228 | sigmas_a: torch.Tensor | None = None, 229 | sigmas_b: torch.Tensor | None = None, 230 | ) -> tuple: 231 | new_sigmas_list = tuple( 232 | Decimal(val) for val in commasep_sigmas_b.strip().split(",") if val.strip() 233 | ) 234 | if new_sigmas_list and sigmas_b is not None: 235 | raise ValueError( 236 | "Must populate one of sigmas_b or commasep_sigmas_b but not both.", 237 | ) 238 | if sigmas_b is not None: 239 | sigmas_b = sigmas_b.to(dtype=torch.float64, device="cpu", copy=True) 240 | else: 241 | sigmas_b = torch.tensor(new_sigmas_list, device="cpu", dtype=torch.float64) 242 | newlen = sigmas_b.numel() 243 | if sigmas_a is None or sigmas_a.numel() == 0: 244 | sigmas_a = None 245 | if not newlen: 246 | raise ValueError( 247 | "sigmas_a, commasep_sigmas_b and sigmas_b can't all be empty.", 248 | ) 249 | if start_index < 0: 250 | raise ValueError( 251 | "Negative start_index doesn't make sense when input sigmas are empty.", 252 | ) 253 | if newlen == 0: 254 | return (sigmas_a.to(dtype=torch.float, copy=True),) 255 | oldlen = 0 if sigmas_a is None else sigmas_a.numel() 256 | if start_index < 0: 257 | start_index = oldlen + start_index 258 | if start_index < 0: 259 | raise ValueError( 260 | "Negative start index points past the beginning of sigmas_a", 261 | ) 262 | past_end = 0 if start_index < oldlen else start_index + 1 - oldlen 263 | if past_end and mode == "insert": 264 | mode = "replace" 265 | if past_end: 266 | outlen = oldlen + newlen + past_end - 1 267 | elif mode == "insert": 268 | outlen = oldlen + newlen 269 | else: 270 | outlen = oldlen + max(0, newlen - (oldlen - start_index)) 271 | sigmas_out = torch.zeros(outlen, device="cpu", dtype=torch.float64) 272 | if mode == "insert": 273 | sigmas_out[:start_index] = sigmas_a[:start_index] 274 | sigmas_out[start_index : start_index + newlen] = sigmas_b 275 | sigmas_out[start_index + newlen :] = sigmas_a[start_index:] 276 | else: 277 | if oldlen: 278 | sigmas_out[:oldlen] = sigmas_a 279 | if mode == "replace": 280 | sigmas_out[start_index : start_index + newlen] = sigmas_b 281 | else: 282 | opfun = cls.OP_MAP.get(mode) 283 | if opfun is None: 284 | raise ValueError("Bad mode") 285 | arga = sigmas_out[start_index : start_index + newlen] 286 | if 
order == "BA": 287 | arga, argb = sigmas_b, arga 288 | else: 289 | argb = sigmas_b 290 | sigmas_out[start_index : start_index + newlen] = opfun(arga, argb) 291 | return (sigmas_out.to(torch.float),) 292 | 293 | 294 | class BlehEnsurePreviewer: 295 | DESCRIPTION = "This node ensures Bleh is used for previews. Can be used if other custom nodes overwrite the Bleh previewer. It will pass through any value unchanged." 296 | FUNCTION = "go" 297 | OUTPUT_NODE = False 298 | CATEGORY = "hacks" 299 | 300 | WILDCARD = Wildcard("*") 301 | RETURN_TYPES = (WILDCARD,) 302 | 303 | @classmethod 304 | def INPUT_TYPES(cls): 305 | return { 306 | "required": { 307 | "any_input": ( 308 | cls.WILDCARD, 309 | { 310 | "forceInput": True, 311 | "description": "You can connect any type of input here, but take to ensure that you connect the output from this node to an input that is compatible.", 312 | }, 313 | ), 314 | }, 315 | } 316 | 317 | @classmethod 318 | def go(cls, *, any_input): 319 | ensure_previewer() 320 | return (any_input,) 321 | -------------------------------------------------------------------------------- /py/nodes/modelPatchConditional.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | import torch 6 | from comfy.ldm.modules.attention import optimized_attention 7 | 8 | # ** transformer_options 9 | # input_block_patch* : p(h, transformer_options) -> h 10 | # input_block_patch_after_skip*: p(h, transformer_options) -> h 11 | # output_block_patch* : p(h, hsp, transformer_options) -> h, hsp 12 | # attn1_patch* : p(n, context_attn1, value_attn1, extra_options) -> n, context_attn1, value_attn1 13 | # attn1_output_patch* : p(n, extra_options) -> n 14 | # attn2_patch* : p(n, context_attn2, value_attn2, extra_options) -> n, context_attn2, value_attn2 15 | # attn2_output_patch* : p(n, extra_options) -> n 16 | # middle_patch* : p(x, extra_options) -> x 17 | # attn1_replace* : p(n, context_attn1, value_attn1, extra_options) -> n 18 | # attn2_replace* : p(n, context_attn2, value_attn2, extra_options) -> n 19 | 20 | # ** model_options 21 | # model_function_wrapper : p(model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}) -> output 22 | # sampler_cfg_function : p({"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}) -> output 23 | # sampler_post_cfg_function* : p({"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, "sigma": timestep, "model_options": model_options, "input": x}) -> output 24 | 25 | 26 | class PatchTypeTransformer: 27 | def __init__( 28 | self, 29 | name, 30 | nresult=1, 31 | ): 32 | self.name = name 33 | self.nresult = nresult 34 | 35 | def get_patches(self, model_options): 36 | return ( 37 | model_options.get("transformer_options", {}) 38 | .get("patches", {}) 39 | .get(self.name, []) 40 | ) 41 | 42 | def set_patches(self, model_options, val): 43 | to = model_options.get("transformer_options", {}) 44 | model_options["transformer_options"] = to 45 | patches = to.get("patches", {}) 46 | to["patches"] = patches 47 | patches[self.name] = val 48 | 49 | def exists(self, model_options): 50 | return len(self.get_patches(model_options)) > 0 51 | 52 | def _call(self, patches, *args: 
list[Any]): 53 | result_part, arg_part = args[: self.nresult], args[self.nresult :] 54 | for p in patches: 55 | result_part = p(*result_part, *arg_part) 56 | return result_part 57 | 58 | @torch.no_grad() 59 | def __call__(self, model_options, *args: list[Any]): 60 | result = self._call(self.get_patches(model_options), *args) 61 | return ( 62 | result[0] 63 | if isinstance(result, (tuple, list)) and len(result) == 1 64 | else result 65 | ) 66 | 67 | 68 | class PatchTypeTransformerReplace(PatchTypeTransformer): 69 | def get_patches(self, model_options): 70 | return ( 71 | model_options.get("transformer_options", {}) 72 | .get("patches_replace", {}) 73 | .get(self.name, {}) 74 | ) 75 | 76 | def set_patches(self, model_options, val): 77 | to = model_options.get("transformer_options", {}) 78 | model_options["transformer_options"] = to 79 | patches = to.get("patches_replace", {}) 80 | to["patches_replace"] = patches 81 | patches[self.name] = val 82 | 83 | def __call__(self, key, model_options, *args: list[Any]): 84 | return self._call(key, self.get_patches(model_options), *args) 85 | 86 | @classmethod 87 | def _call(cls, key, patches, *args: list[Any]): 88 | p = patches.get(key) 89 | if p: 90 | return p(*args) 91 | return optimized_attention(*args[:-1], heads=args[-1]["n_heads"]) 92 | 93 | 94 | class PatchTypeModel(PatchTypeTransformer): 95 | def set_patches(self, model_options, val): 96 | model_options[self.name] = val[0] 97 | 98 | def get_patches(self, model_options): 99 | return () if self.name not in model_options else (model_options[self.name],) 100 | 101 | 102 | class PatchTypeModelWrapper(PatchTypeModel): 103 | @classmethod 104 | def _call(cls, patches, apply_model, opts): 105 | if not patches: 106 | return apply_model(opts["input"], opts["timestep"], **opts["c"]) 107 | return patches[0](apply_model, opts) 108 | 109 | 110 | class PatchTypeSamplerPostCfgFunction(PatchTypeModel): 111 | def get_patches(self, model_options): 112 | return model_options.get(self.name, ()) 113 | 114 | def set_patches(self, model_options, val): 115 | model_options[self.name] = val 116 | 117 | _call_result_key = "denoised" 118 | 119 | @classmethod 120 | def _call(cls, patches, opts): 121 | if not patches: 122 | return opts["denoised"] 123 | curr_opts = opts.copy() 124 | key = cls._call_result_key 125 | for p in patches: 126 | result = p(curr_opts) 127 | curr_opts[key] = result 128 | return result 129 | 130 | 131 | class PatchTypeSamplerPreCfgFunction(PatchTypeSamplerPostCfgFunction): 132 | _call_result_key = "conds_out" 133 | 134 | 135 | class PatchTypeSamplerCfgFunction(PatchTypeModel): 136 | @classmethod 137 | def _call(cls, patches, opts): 138 | if not patches: 139 | cond_pred, uncond_pred = opts["cond_denoised"], opts["uncond_denoised"] 140 | return uncond_pred + (cond_pred - uncond_pred) * opts["cond_scale"] 141 | return patches[0](opts) 142 | 143 | 144 | PATCH_TYPES = { 145 | "input_block_patch": PatchTypeTransformer("input_block_patch"), 146 | "input_block_patch_after_skip": PatchTypeTransformer( 147 | "input_block_patch_after_skip", 148 | ), 149 | "output_block_patch": PatchTypeTransformer("output_block_patch", nresult=2), 150 | "attn1_patch": PatchTypeTransformer("attn1_patch", nresult=3), 151 | "attn1_output_patch": PatchTypeTransformer("attn1_output_patch"), 152 | "attn2_patch": PatchTypeTransformer("attn2_patch", nresult=3), 153 | "attn2_output_patch": PatchTypeTransformer("attn2_output_patch"), 154 | "middle_patch": PatchTypeTransformer("middle_patch"), 155 | "attn1": 
PatchTypeTransformerReplace("attn1"), 156 | "attn2": PatchTypeTransformerReplace("attn2"), 157 | "model_function_wrapper": PatchTypeModelWrapper("model_function_wrapper"), 158 | "sampler_cfg_function": PatchTypeSamplerCfgFunction("sampler_cfg_function"), 159 | "sampler_post_cfg_function": PatchTypeSamplerPostCfgFunction( 160 | "sampler_post_cfg_function", 161 | ), 162 | "sampler_pre_cfg_function": PatchTypeSamplerPostCfgFunction( 163 | "sampler_pre_cfg_function", 164 | ), 165 | } 166 | 167 | 168 | class ModelConditionalState: 169 | def __init__(self): 170 | self.last_sigma = None 171 | self.step = None 172 | 173 | def update(self, sigma): 174 | if self.last_sigma is None or sigma > self.last_sigma: 175 | self.step = 0 176 | elif sigma != self.last_sigma: 177 | self.step += 1 178 | self.last_sigma = sigma 179 | return self.step 180 | 181 | 182 | class ModelPatchConditional: 183 | def __init__( # noqa: PLR0917 184 | self, 185 | model_default, 186 | model_matched, 187 | start_percent: float = 0.0, 188 | end_percent: float = 1.0, 189 | interval: int = 1, 190 | base_on_default: bool = True, 191 | ): 192 | self.options_default = dict(**model_default.model_options) 193 | self.options_matched = dict(**model_matched.model_options) 194 | self.start_percent = start_percent 195 | self.end_percent = end_percent 196 | self.interval = interval 197 | self.base = model_default if base_on_default else model_matched 198 | self.base_on_default = base_on_default 199 | self.sigma_end = self.sigma_start = None 200 | 201 | def lazy_calc_steps(self): 202 | if self.sigma_start is not None: 203 | return 204 | model = self.base 205 | count = 0 206 | while hasattr(model, "model"): 207 | count += 1 208 | if count > 128: 209 | raise ValueError("I can't handle these insane levels of modelception!") 210 | model = model.model 211 | self.sigma_start = model.model_sampling.percent_to_sigma( 212 | self.start_percent, 213 | ) 214 | self.sigma_end = model.model_sampling.percent_to_sigma( 215 | self.end_percent, 216 | ) 217 | 218 | def should_use_patched(self, state, opts): 219 | self.lazy_calc_steps() 220 | if "sigmas" in opts: 221 | sigmas = opts["sigmas"] 222 | elif "transformer_options" in opts: 223 | sigmas = opts["transformer_options"]["sigmas"] 224 | elif "c" in opts: 225 | sigmas = opts["c"]["transformer_options"]["sigmas"] 226 | elif "sigma" in opts: 227 | sigmas = opts["sigma"] 228 | else: 229 | raise ValueError("Cannot determine sigma") 230 | iv = self.interval 231 | sigma = sigmas[0].item() 232 | step = state.update(sigma) 233 | matched = sigma <= self.sigma_start and sigma >= self.sigma_end 234 | matched &= (step % iv) == 0 if iv > 0 else ((step + 1) % abs(iv)) > 0 235 | return matched 236 | 237 | def mk_patch_handler(self, pt, state, key=None): 238 | def handler(*args: list[Any]): 239 | matched = self.should_use_patched(state, args[-1]) 240 | # print(f">> {pt.name}, active: {matched}, step: {state.step}") 241 | opts = self.options_matched if matched else self.options_default 242 | return pt(key, opts, *args) if key else pt(opts, *args) 243 | 244 | return handler 245 | 246 | def patch(self): 247 | state = ModelConditionalState() 248 | base_patched = self.base.clone() 249 | for pt in PATCH_TYPES.values(): 250 | if not (pt.exists(self.options_default) or pt.exists(self.options_matched)): 251 | continue 252 | # print(f"set patch {pt.name}") 253 | if not isinstance(pt, PatchTypeTransformerReplace): 254 | pt.set_patches( 255 | base_patched.model_options, 256 | [self.mk_patch_handler(pt, state)], 257 | ) 258 | continue 
259 | pt.set_patches( 260 | base_patched.model_options, 261 | { 262 | k: self.mk_patch_handler(pt, state, key=k) 263 | for k in ( 264 | pt.get_patches(self.options_default).keys() 265 | | pt.get_patches(self.options_matched).keys() 266 | ) 267 | }, 268 | ) 269 | base_patched.model_options["disable_cfg1_optimization"] = ( 270 | self.options_default.get("disable_cfg1_optimization", False) 271 | or self.options_matched.get("disable_cfg1_optimization", False) 272 | ) 273 | return base_patched 274 | 275 | 276 | class ModelPatchConditionalNode: 277 | RETURN_TYPES = ("MODEL",) 278 | FUNCTION = "patch" 279 | CATEGORY = "bleh/model_patches" 280 | DESCRIPTION = "Experimental model patch that lets you control when other model patches are active." 281 | 282 | @classmethod 283 | def INPUT_TYPES(cls): 284 | return { 285 | "required": { 286 | "model_default": ( 287 | "MODEL", 288 | { 289 | "tooltip": "Fallback model patches, used when start/end/interval do not match.", 290 | }, 291 | ), 292 | "model_matched": ( 293 | "MODEL", 294 | {"tooltip": "Model patches used when start/end/interval match."}, 295 | ), 296 | "start_percent": ( 297 | "FLOAT", 298 | { 299 | "default": 0.0, 300 | "min": 0.0, 301 | "max": 1.0, 302 | "step": 0.001, 303 | "tooltip": "Start time as sampling percentage (not percentage of steps). Percentages are inclusive.", 304 | }, 305 | ), 306 | "end_percent": ( 307 | "FLOAT", 308 | { 309 | "default": 1.0, 310 | "min": 0.0, 311 | "max": 1.0, 312 | "step": 0.001, 313 | "tooltip": "End time as sampling percentage (not percentage of steps). Percentages are inclusive.", 314 | }, 315 | ), 316 | "interval": ( 317 | "INT", 318 | { 319 | "default": 1, 320 | "min": -999, 321 | "max": 999, 322 | "tooltip": "Step interval to use model_matched. If positive 3 would mean activate every third step, if negative -3 would mean skip every third step.", 323 | }, 324 | ), 325 | "base_on_default": ( 326 | "BOOLEAN", 327 | { 328 | "default": True, 329 | "tooltip": "When true, the active set of patches will be applied to model_default, otherwise they will be applied to model_matched.", 330 | }, 331 | ), 332 | }, 333 | } 334 | 335 | @classmethod 336 | def patch( 337 | cls, 338 | *, 339 | model_default, 340 | model_matched=None, 341 | start_percent: float = 0.0, 342 | end_percent: float = 1.0, 343 | interval: int = 1, 344 | base_on_default: bool = True, 345 | ): 346 | if not model_matched or start_percent >= 1.0 or interval == 0: 347 | return (model_default.clone(),) 348 | mopts = getattr(model_default, "model_options", None) 349 | if mopts is None or not isinstance(mopts, dict): 350 | # Not an instance of ModelPatcher, apparently so we can't do anything here. 
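# (Everything this node conditions on lives in model_options, so without that dict there
# is nothing to wrap; just hand back a plain clone of model_default.)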
351 | return (model_default.clone(),) 352 | return ( 353 | ModelPatchConditional( 354 | model_default, 355 | model_matched, 356 | start_percent, 357 | end_percent, 358 | interval, 359 | base_on_default=base_on_default, 360 | ).patch(), 361 | ) 362 | -------------------------------------------------------------------------------- /py/nodes/ops.py: -------------------------------------------------------------------------------- 1 | # Adapted from the ComfyUI built-in node 2 | from __future__ import annotations 3 | 4 | import bisect 5 | import importlib 6 | import operator as pyop 7 | from collections import OrderedDict 8 | from enum import Enum, auto 9 | from itertools import starmap 10 | 11 | import torch 12 | import yaml 13 | 14 | from ..latent_utils import * # noqa: TID252 15 | 16 | try: 17 | sonar_noise = importlib.import_module("custom_nodes.ComfyUI-sonar.py.noise") 18 | get_noise_sampler = sonar_noise.get_noise_sampler 19 | except ImportError: 20 | 21 | def get_noise_sampler(noise_type, x, *_args: list, **_kwargs: dict): 22 | if noise_type != "gaussian": 23 | raise ValueError("Only gaussian noise supported") 24 | return lambda _s, _sn: torch.randn_like(x) 25 | 26 | 27 | class CondType(Enum): 28 | TYPE = auto() 29 | BLOCK = auto() 30 | STAGE = auto() 31 | FROM_PERCENT = auto() 32 | TO_PERCENT = auto() 33 | PERCENT = auto() 34 | STEP = auto() # Calculated from closest sigma. 35 | STEP_EXACT = auto() # Only exact matching sigma or -1. 36 | FROM_STEP = auto() 37 | TO_STEP = auto() 38 | STEP_INTERVAL = auto() 39 | COND = auto() 40 | 41 | 42 | class PatchType(Enum): 43 | LATENT = auto() 44 | INPUT = auto() 45 | INPUT_AFTER_SKIP = auto() 46 | MIDDLE = auto() 47 | OUTPUT = auto() 48 | POST_CFG = auto() 49 | PRE_APPLY_MODEL = auto() 50 | POST_APPLY_MODEL = auto() 51 | 52 | 53 | class CompareType(Enum): 54 | EQ = auto() 55 | NE = auto() 56 | GT = auto() 57 | LT = auto() 58 | GE = auto() 59 | LE = auto() 60 | NOT = auto() 61 | OR = auto() 62 | AND = auto() 63 | 64 | 65 | class OpType(Enum): 66 | # scale, strength, blend, blend mode, use hidden mean, dim, scale offset 67 | SLICE = auto() 68 | 69 | # scale, filter, filter strength, threshold 70 | FFILTER = auto() 71 | 72 | # type (bicubic, nearest, bilinear, area), scale width, scale height, antialias 73 | SCALE_TORCH = auto() 74 | 75 | # type (bicubic, nearest, bilinear, area), antialias 76 | UNSCALE_TORCH = auto() 77 | 78 | # type width (slerp, slerp_alt, hslerp, colorize), type height, scale width, scale height, antialias size 79 | SCALE = auto() 80 | 81 | # type width (slerp, slerp_alt, hslerp, colorize), type height, antialias size 82 | UNSCALE = auto() 83 | 84 | # direction (h, v) 85 | FLIP = auto() 86 | 87 | # count 88 | ROT90 = auto() 89 | 90 | # count 91 | ROLL_CHANNELS = auto() 92 | 93 | # direction (horizontal, vertical, channels) or list of dims, amount (integer or percentage >-1.0 < 1.0) 94 | ROLL = auto() 95 | 96 | # true/false - only makes sense with output block 97 | TARGET_SKIP = auto() 98 | 99 | # factor 100 | MULTIPLY = auto() 101 | 102 | # blend strength, blend_mode, [op] 103 | BLEND_OP = auto() 104 | 105 | # scale mode, antialias size, mask example, [op], blend_mode 106 | MASK_EXAMPLE_OP = auto() 107 | 108 | # size 109 | ANTIALIAS = auto() 110 | 111 | # scale, type, scale_mode (none, sigma, sigdiff) 112 | NOISE = auto() 113 | 114 | # none 115 | DEBUG = auto() 116 | 117 | # mode (constant, reflect, replicate, circular), top, bottom, left, right, constant 118 | PAD = auto() 119 | 120 | # top, bottom, left, right 121 | CROP = 
auto() 122 | 123 | # count, [ops] 124 | REPEAT = auto() 125 | 126 | # scale, type 127 | APPLY_ENHANCEMENT = auto() 128 | 129 | 130 | OP_DEFAULTS = { 131 | OpType.SLICE: OrderedDict( 132 | scale=1.0, 133 | strength=1.0, 134 | blend=1.0, 135 | blend_mode="bislerp", 136 | use_hidden_mean=True, 137 | dim=1, 138 | scale_offset=0, 139 | ), 140 | OpType.FFILTER: OrderedDict( 141 | scale=1.0, 142 | filter="none", 143 | filter_strength=0.5, 144 | threshold=1, 145 | ), 146 | OpType.SCALE_TORCH: OrderedDict( 147 | type="bicubic", 148 | scale_width=1.0, 149 | scale_height=None, 150 | antialias=False, 151 | ), 152 | OpType.SCALE: OrderedDict( 153 | type_width="bicubic", 154 | type_height="bicubic", 155 | scale_width=1.0, 156 | scale_height=None, 157 | antialias_size=0, 158 | ), 159 | OpType.UNSCALE_TORCH: OrderedDict( 160 | type="bicubic", 161 | antialias=False, 162 | ), 163 | OpType.UNSCALE: OrderedDict( 164 | type_width="bicubic", 165 | type_height="bicubic", 166 | antialias_size=0, 167 | ), 168 | OpType.FLIP: OrderedDict(direction="h"), 169 | OpType.ROT90: OrderedDict(count=1), 170 | OpType.ROLL_CHANNELS: OrderedDict(count=1), 171 | OpType.ROLL: OrderedDict(direction="c", amount=1), 172 | OpType.TARGET_SKIP: OrderedDict(active=True), 173 | OpType.MULTIPLY: OrderedDict(factor=1.0), 174 | OpType.BLEND_OP: OrderedDict(blend=1.0, blend_mode="bislerp", ops=()), 175 | OpType.MASK_EXAMPLE_OP: OrderedDict( 176 | scale_mode="bicubic", 177 | antialias_size=7, 178 | mask=( 179 | (0.5, 0.25, (16, 0.0), 0.25, 0.5), 180 | ("rep", 18, (20, 0.0)), 181 | (0.5, 0.25, (16, 0.0), 0.25, 0.5), 182 | ), 183 | ops=(), 184 | blend_mode="lerp", 185 | ), 186 | OpType.ANTIALIAS: OrderedDict(size=7), 187 | OpType.NOISE: OrderedDict(scale=0.5, type="gaussian", scale_mode="sigdiff"), 188 | OpType.DEBUG: OrderedDict(), 189 | OpType.PAD: OrderedDict( 190 | mode="reflect", 191 | top=0, 192 | bottom=0, 193 | left=0, 194 | right=0, 195 | constant=None, 196 | ), 197 | OpType.CROP: OrderedDict(top=0, bottom=0, left=0, right=0), 198 | OpType.REPEAT: OrderedDict(count=2, ops=()), 199 | OpType.APPLY_ENHANCEMENT: OrderedDict(scale=1.0, type="korniabilateralblur"), 200 | } 201 | 202 | 203 | class Compare: 204 | VALID_TYPES = { # noqa: RUF012 205 | CondType.BLOCK, 206 | CondType.STAGE, 207 | CondType.PERCENT, 208 | CondType.STEP, 209 | CondType.STEP_EXACT, 210 | } 211 | 212 | def __init__(self, typ: str, value): 213 | self.typ = getattr(CompareType, typ.upper().strip()) 214 | if self.typ in {CompareType.OR, CompareType.AND, CompareType.NOT}: 215 | self.value = tuple(ConditionGroup(v) for v in value) 216 | self.field = None 217 | return 218 | self.field = getattr(CondType, value[0].upper().strip()) 219 | if self.field not in self.VALID_TYPES: 220 | raise TypeError("Invalid type compare operation") 221 | self.opfn = getattr(pyop, self.typ.name.lower()) 222 | self.value = value[1:] 223 | if not isinstance(self.value, (list, tuple)): 224 | self.value = (self.value,) 225 | 226 | def test(self, state: dict) -> bool: 227 | if self.typ == CompareType.NOT: 228 | return all(not v.test(state) for v in self.value) 229 | if self.typ == CompareType.AND: 230 | return all(v.test(state) for v in self.value) 231 | if self.typ == CompareType.OR: 232 | return any(v.test(state) for v in self.value) 233 | opfn, fieldval = self.opfn, state[self.field] 234 | return all(opfn(fieldval, val) for val in self.value) 235 | 236 | def __repr__(self) -> str: 237 | return f"" 238 | 239 | 240 | class Condition: 241 | def __init__(self, typ: str, value): 242 | self.typ = 
getattr(CondType, typ.upper().strip()) 243 | if self.typ == CondType.TYPE: 244 | if not isinstance(value, (list, tuple)): 245 | value = (value,) 246 | self.value = {getattr(PatchType, pt.strip().upper()) for pt in value} 247 | elif self.typ is not CondType.COND: 248 | self.value = set(value if isinstance(value, (list, tuple)) else (value,)) 249 | else: 250 | self.value = Compare(value[0], value[1:]) 251 | 252 | def test(self, state: dict) -> bool: 253 | if self.typ == CondType.FROM_PERCENT: 254 | pct = state[CondType.PERCENT] 255 | result = all(pct >= v for v in self.value) 256 | elif self.typ == CondType.TO_PERCENT: 257 | pct = state[CondType.PERCENT] 258 | result = all(pct <= v for v in self.value) 259 | elif self.typ == CondType.FROM_STEP: 260 | step = state[CondType.STEP] 261 | result = step > 0 and all(step >= v for v in self.value) 262 | elif self.typ == CondType.TO_STEP: 263 | step = state[CondType.STEP] 264 | result = step > 0 and all(step <= v for v in self.value) 265 | elif self.typ == CondType.STEP_INTERVAL: 266 | step = state[CondType.STEP] 267 | result = step > 0 and all(step % v == 0 for v in self.value) 268 | elif self.typ == CondType.COND: 269 | result = self.value.test(state) 270 | else: 271 | result = state[self.typ] in self.value 272 | return result 273 | 274 | def __repr__(self) -> str: 275 | return f"" 276 | 277 | 278 | class ConditionGroup: 279 | def __init__(self, conds): 280 | if not conds: 281 | self.conds = () 282 | return 283 | if isinstance(conds, dict): 284 | conds = tuple(conds.items()) 285 | if isinstance(conds[0], str): 286 | conds = (conds,) 287 | self.conds = tuple(starmap(Condition, conds)) 288 | 289 | def test(self, state: dict) -> bool: 290 | return all(c.test(state) for c in self.conds) 291 | 292 | def get_all_types(self) -> set[str]: 293 | pass 294 | 295 | def __repr__(self) -> str: 296 | return f"" 297 | 298 | 299 | # Copied from https://github.com/WASasquatch/FreeU_Advanced 300 | def hidden_mean(h): 301 | hidden_mean = h.mean(1).unsqueeze(1) 302 | b = hidden_mean.shape[0] 303 | hidden_max, _ = torch.max(hidden_mean.view(b, -1), dim=-1, keepdim=True) 304 | hidden_min, _ = torch.min(hidden_mean.view(b, -1), dim=-1, keepdim=True) 305 | return (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / ( 306 | hidden_max - hidden_min 307 | ).unsqueeze(2).unsqueeze(3) 308 | 309 | 310 | class Operation: 311 | IDX = 0 312 | 313 | def __init__(self, typ: str | OpType, *args: list): 314 | if isinstance(typ, str): 315 | typ = getattr(OpType, typ.upper().strip()) 316 | self.typ = typ 317 | defaults = OP_DEFAULTS[self.typ] 318 | if len(args) == 1 and isinstance(args[0], dict): 319 | args = args[0] 320 | extra = set(args.keys()) - set(defaults.keys()) 321 | if extra: 322 | errstr = f"Unexpected argument keys for operation {typ}: {extra}" 323 | raise ValueError(errstr) 324 | self.args = tuple(starmap(args.get, defaults.items())) 325 | else: 326 | if len(args) > len(defaults): 327 | raise ValueError("Too many arguments for operation") 328 | self.args = (*args, *tuple(defaults.values())[len(args) :]) 329 | 330 | @staticmethod 331 | def build(typ: str | OpType, *args: list) -> object: 332 | if isinstance(typ, str): 333 | typ = getattr(OpType, typ.upper().strip()) 334 | return OP_TO_OPCLASS[typ](typ, *args) 335 | 336 | def eval(self, state: dict): 337 | out = self.op(state[state["target"]], state) 338 | state[state["target"]] = out 339 | 340 | def __repr__(self) -> str: 341 | return f"" 342 | 343 | 344 | class SubOpsOperation(Operation): 345 | SUBOPS_IDXS = () 346 | 347 | 
def __init__(self, *args: list, **kwargs: dict): 348 | super().__init__(*args, **kwargs) 349 | for argidx in self.SUBOPS_IDXS: 350 | subops = self.args[argidx] 351 | if subops and isinstance(subops[0], str): 352 | # Simple single subop. 353 | subops = (subops,) 354 | compiled_subops = [] 355 | for idx in range(len(subops)): 356 | subop = subops[idx] 357 | if isinstance(subop, dict): 358 | # Compile to rule. 359 | subop = subops[idx] = Rule.from_dict(subops[idx]) 360 | elif isinstance(subop, (list, tuple)): 361 | # Compile to op. 362 | subop = Operation.build(subop[0], *subop[1:]) 363 | compiled_subops.append(subop) 364 | temp = list(self.args) 365 | temp[argidx] = compiled_subops 366 | self.args = tuple(temp) 367 | 368 | 369 | class OpSlice(Operation): 370 | def op(self, t, _state): 371 | out = t 372 | scale, strength, blend, mode, use_hm, dim, scale_offset = self.args 373 | if dim < 0: 374 | dim = t.ndim + dim 375 | dim_size = t.shape[dim] 376 | slice_size = max(1, round(dim_size * scale)) 377 | slice_offset = int(dim_size * scale_offset) 378 | slice_def = tuple( 379 | slice(None, None) 380 | if idx != dim 381 | else slice(slice_offset, slice_offset + slice_size) 382 | for idx in range(dim + 1) 383 | ) 384 | sliced = t[slice_def] 385 | if use_hm: 386 | result = sliced * ((strength - 1) * hidden_mean(t)[slice_def] + 1) 387 | else: 388 | result = sliced * strength 389 | if blend != 1: 390 | result = BLENDING_MODES[mode](sliced, result, blend) 391 | out[slice_def] = result 392 | return out 393 | 394 | 395 | class OpFFilter(Operation): 396 | def op(self, t, _state): 397 | scale, filt, strength, threshold = self.args 398 | if isinstance(filt, str): 399 | filt = FILTER_PRESETS[filt] 400 | elif filt is None: 401 | filt = () 402 | return ffilter(t, threshold, scale, filt, strength) 403 | 404 | 405 | class OpScaleTorch(Operation): 406 | def op(self, t, state): 407 | if self.typ == OpType.SCALE_TORCH: 408 | mode, scale_w, scale_h, antialias = self.args 409 | width, height = ( 410 | round(t.shape[-1] * scale_w), 411 | round(t.shape[-2] * scale_h), 412 | ) 413 | else: 414 | hsp = state["hsp"] 415 | if hsp is None: 416 | raise ValueError( 417 | "Can only use unscale_torch when HSP is set (output)", 418 | ) 419 | if t.shape[-1] == hsp.shape[-1] and t.shape[-2] == hsp.shape[-2]: 420 | return t 421 | mode, antialias = self.args 422 | width, height = hsp.shape[-1], hsp.shape[-2] 423 | return scale_samples( 424 | t, 425 | width, 426 | height, 427 | mode, 428 | antialias_size=8 if antialias else 0, 429 | sigma=state.get("sigma"), 430 | ) 431 | 432 | 433 | class OpUnscaleTorch(OpScaleTorch): 434 | pass 435 | 436 | 437 | class OpScale(Operation): 438 | def op(self, t, state): 439 | if self.typ == OpType.SCALE: 440 | mode_w, mode_h, scale_w, scale_h, antialias_size = self.args 441 | width, height = ( 442 | round(t.shape[-1] * scale_w), 443 | round(t.shape[-2] * scale_h), 444 | ) 445 | else: 446 | hsp = state["hsp"] 447 | if hsp is None: 448 | raise ValueError( 449 | "Can only use unscale when HSP is set (output)", 450 | ) 451 | if t.shape[-1] == hsp.shape[-1] and t.shape[-2] == hsp.shape[-2]: 452 | return t 453 | mode_w, mode_h, antialias_size = self.args 454 | width, height = hsp.shape[-1], hsp.shape[-2] 455 | return scale_samples( 456 | t, 457 | width, 458 | height, 459 | mode=mode_w, 460 | mode_h=mode_h, 461 | antialias_size=antialias_size, 462 | ) 463 | 464 | 465 | class OpUnscale(OpScale): 466 | pass 467 | 468 | 469 | class OpFlip(Operation): 470 | def op(self, t, _state): 471 | dimarg = self.args[0] 472 
| if isinstance(dimarg, str): 473 | dim = dimarg[:1] == "v" 474 | elif isinstance(dimarg, int): 475 | dim = (dimarg,) 476 | else: 477 | dim = dimarg 478 | return torch.flip(t, dims=dim) 479 | 480 | 481 | class OpRot90(Operation): 482 | def op(self, t, _state): 483 | return torch.rot90(t, self.args[0], dims=(3, 2)) 484 | 485 | 486 | class OpRollChannels(Operation): 487 | def op(self, t, _state): 488 | return torch.roll(t, self.args[0], dims=(1,)) 489 | 490 | 491 | class OpRoll(Operation): 492 | def op(self, t, _state): 493 | dims, amount = self.args 494 | if isinstance(dims, str): 495 | if dims in {"h", "horizontal"}: 496 | dims = (3,) 497 | elif dims in {"v", "vertical"}: 498 | dims = (2,) 499 | elif dims in {"c", "channels"}: 500 | dims = (1,) 501 | else: 502 | raise ValueError("Bad roll direction") 503 | elif isinstance(dims, int): 504 | dims = (dims,) 505 | if isinstance(amount, float) and amount < 1.0 and amount > -1.0: 506 | if len(dims) > 1: 507 | raise ValueError( 508 | "Cannot use percentage based amount with multiple roll dimensions", 509 | ) 510 | amount = int(t.shape[dims[0]] * amount) 511 | return torch.roll(t, amount, dims=dims) 512 | 513 | 514 | class OpTargetSkip(Operation): 515 | def op(self, t, state): 516 | if state.get("hsp") is None: 517 | if state["target"] == "hsp": 518 | state["target"] = "h" 519 | return t 520 | state["target"] = "hsp" if self.args[0] is True else "h" 521 | return state[state["target"]] 522 | 523 | 524 | class OpMultiply(Operation): 525 | def op(self, t, _state): 526 | return t.mul_(self.args[0]) 527 | 528 | 529 | class OpBlendOp(SubOpsOperation): 530 | SUBOPS_IDXS = (2,) 531 | 532 | def op(self, t, state): 533 | blend, mode, subops = self.args 534 | tempname = f"temp{Operation.IDX}" 535 | Operation.IDX += 1 536 | old_target = state["target"] 537 | state[tempname] = t.clone() 538 | for subop in subops: 539 | state["target"] = tempname 540 | subop.eval(state) 541 | state["target"] = old_target 542 | out = BLENDING_MODES[mode](t, state[tempname], blend) 543 | del state[tempname] 544 | return out 545 | 546 | 547 | class OpMaskExampleOp(SubOpsOperation): 548 | SUBOPS_IDXS = (3,) 549 | 550 | def __init__(self, *args: list, **kwargs: dict): 551 | super().__init__(*args, **kwargs) 552 | scale_mode, antialias_size, maskdef, subops, blend_mode = self.args 553 | blend_function = BLENDING_MODES.get(blend_mode) 554 | if blend_function is None: 555 | raise ValueError("Bad blend mode") 556 | mask = [] 557 | for rowidx in range(len(maskdef)): 558 | repeats = 1 559 | rowdef = maskdef[rowidx] 560 | if rowdef and rowdef[0] == "rep": 561 | repeats = int(rowdef[1]) 562 | rowdef = rowdef[2:] 563 | row = [] 564 | for col in rowdef: 565 | if isinstance(col, (list, tuple)): 566 | row += col[1:] * col[0] 567 | else: 568 | row.append(col) 569 | mask += (row,) * repeats 570 | mask = torch.tensor(mask, dtype=torch.float32, device="cpu") 571 | self.args = ( 572 | scale_mode, 573 | antialias_size, 574 | mask, 575 | subops, 576 | blend_function, 577 | ) 578 | 579 | def op(self, t, state): 580 | scale_mode, antialias_size, mask, subops, blend_function = self.args 581 | mask = scale_samples( 582 | mask.view(1, 1, *mask.shape).to(t.device, dtype=t.dtype), 583 | t.shape[-1], 584 | t.shape[-2], 585 | mode=scale_mode, 586 | antialias_size=antialias_size, 587 | ).broadcast_to(t.shape) 588 | tempname = f"temp{Operation.IDX}" 589 | Operation.IDX += 1 590 | old_target = state["target"] 591 | state[tempname] = t.clone() 592 | for subop in subops: 593 | state["target"] = tempname 594 | 
subop.eval(state) 595 | state["target"] = old_target 596 | out = blend_function(t, state[tempname], mask) 597 | del state[tempname] 598 | return out 599 | 600 | 601 | class OpAntialias(Operation): 602 | def op(self, t, _state): 603 | return antialias_tensor(t, self.args[0]) 604 | 605 | 606 | class OpNoise(Operation): 607 | def op(self, t, state): 608 | scale, noise_type, scale_mode = self.args 609 | if scale_mode == "sigma": 610 | step_scale = state.get("sigma", 1.0) 611 | elif scale_mode == "sigdiff": 612 | if "sigma" in state and "sigma_next" in state: 613 | step_scale = state["sigma"] - state["sigma_next"] 614 | else: 615 | step_scale = state.get("sigma", 1.0) 616 | else: 617 | step_scale = 1.0 618 | noise_sampler = get_noise_sampler( 619 | noise_type, 620 | t, 621 | state["sigma_min"], 622 | state["sigma_max"], 623 | ) 624 | noise = noise_sampler(state.get("sigma"), state.get("sigma_next")) 625 | t += noise * step_scale * scale 626 | return t 627 | 628 | 629 | class OpDebug(Operation): 630 | @classmethod 631 | def op(cls, t, state): 632 | stcopy = { 633 | k: v 634 | if not isinstance(v, torch.Tensor) 635 | else f"" 636 | for k, v in state.items() 637 | } 638 | stcopy["target_shape"] = t.shape 639 | print(f">> BlehOps debug: {stcopy!r}") 640 | return t 641 | 642 | 643 | class OpPad(Operation): 644 | def op(self, t, _state): 645 | mode, top, bottom, left, right, constant_value = self.args 646 | if mode != "constant": 647 | constant_value = None 648 | shp = t.shape 649 | top, bottom = tuple( 650 | val if isinstance(val, int) else int(shp[-2] * val) for val in (top, bottom) 651 | ) 652 | left, right = tuple( 653 | val if isinstance(val, int) else int(shp[-1] * val) for val in (left, right) 654 | ) 655 | return torch.nn.functional.pad( 656 | t, 657 | (left, right, top, bottom), 658 | mode=mode, 659 | value=constant_value, 660 | ) 661 | 662 | 663 | class OpCrop(Operation): 664 | def op(self, t, _state): 665 | top, bottom, left, right = self.args 666 | shp = t.shape 667 | top, bottom = tuple( 668 | val if isinstance(val, int) else int(shp[-2] * val) for val in (top, bottom) 669 | ) 670 | left, right = tuple( 671 | val if isinstance(val, int) else int(shp[-1] * val) for val in (left, right) 672 | ) 673 | bottom, right = shp[-2] - bottom, shp[-1] - right 674 | return t[:, :, top:bottom, left:right] 675 | 676 | 677 | class OpRepeat(SubOpsOperation): 678 | SUBOPS_IDXS = (1,) 679 | 680 | def op(self, _t, state): 681 | count, subops = self.args 682 | for _ in range(count): 683 | for subop in subops: 684 | subop.eval(state) 685 | return state[state["target"]] 686 | 687 | 688 | class OpApplyEnhancement(Operation): 689 | def op(self, t, state): 690 | scale, typ = self.args 691 | return enhance_tensor(t, typ, scale=scale, sigma=state.get("sigma")) 692 | 693 | 694 | OP_TO_OPCLASS = { 695 | OpType.SLICE: OpSlice, 696 | OpType.FFILTER: OpFFilter, 697 | OpType.SCALE_TORCH: OpScaleTorch, 698 | OpType.UNSCALE_TORCH: OpUnscaleTorch, 699 | OpType.SCALE: OpScale, 700 | OpType.UNSCALE: OpUnscale, 701 | OpType.FLIP: OpFlip, 702 | OpType.ROT90: OpRot90, 703 | OpType.ROLL_CHANNELS: OpRollChannels, 704 | OpType.ROLL: OpRoll, 705 | OpType.TARGET_SKIP: OpTargetSkip, 706 | OpType.MULTIPLY: OpMultiply, 707 | OpType.BLEND_OP: OpBlendOp, 708 | OpType.MASK_EXAMPLE_OP: OpMaskExampleOp, 709 | OpType.ANTIALIAS: OpAntialias, 710 | OpType.NOISE: OpNoise, 711 | OpType.DEBUG: OpDebug, 712 | OpType.PAD: OpPad, 713 | OpType.CROP: OpCrop, 714 | OpType.REPEAT: OpRepeat, 715 | OpType.APPLY_ENHANCEMENT: OpApplyEnhancement, 716 | } 717 
| 718 | 719 | class Rule: 720 | @classmethod 721 | def from_dict(cls, val) -> object: 722 | if not isinstance(val, (list, tuple)): 723 | val = (val,) 724 | 725 | return tuple( 726 | cls( 727 | conds=d.get("if", ()), 728 | ops=d.get("ops", ()), 729 | matched=d.get("then", ()), 730 | nomatched=d.get("else", ()), 731 | ) 732 | for d in val 733 | if not d.get("disable") 734 | ) 735 | 736 | def __init__(self, conds=(), ops=(), matched=(), nomatched=()): 737 | self.conds = ConditionGroup(conds) 738 | if ops and isinstance(ops[0], str): 739 | ops = (ops,) 740 | self.ops = tuple(Operation.build(o[0], *o[1:]) for o in ops) 741 | self.matched = Rule.from_dict(matched) 742 | self.nomatched = Rule.from_dict(nomatched) 743 | 744 | def get_all_types(self) -> set: 745 | result = {c.value for c in self.conds if c.typ == CondType.TYPE} 746 | for r in self.matched: 747 | result |= r.get_all_types() 748 | for r in self.nomatched: 749 | result |= r.get_all_types() 750 | return result 751 | 752 | def eval(self, state: dict) -> None: 753 | # print("EVAL", state | {"h": None, "hsp": None}) 754 | 755 | if not self.conds.test(state): 756 | for r in self.nomatched: 757 | r.eval(state) 758 | return 759 | for o in self.ops: 760 | o.eval(state) 761 | for r in self.matched: 762 | r.eval(state) 763 | 764 | def __repr__(self): 765 | return f"" 766 | 767 | 768 | class RuleGroup: 769 | @classmethod 770 | def from_yaml(cls, s: str) -> object: 771 | parsed_rules = yaml.safe_load(s) 772 | if parsed_rules is None: 773 | return cls(()) 774 | return cls(tuple(r for rs in parsed_rules for r in Rule.from_dict(rs))) 775 | 776 | def __init__(self, rules): 777 | self.rules = rules 778 | 779 | def eval(self, state, toplevel=False): 780 | for rule in self.rules: 781 | if toplevel: 782 | state["target"] = "h" 783 | rule.eval(state) 784 | return state 785 | 786 | def __repr__(self) -> str: 787 | return f"" 788 | 789 | 790 | class BlehBlockOps: 791 | RETURN_TYPES = ("MODEL",) 792 | FUNCTION = "patch" 793 | CATEGORY = "bleh/model_patches" 794 | 795 | @classmethod 796 | def INPUT_TYPES(cls): 797 | return { 798 | "required": { 799 | "model": ("MODEL",), 800 | "rules": ("STRING", {"multiline": True, "dynamicPrompts": False}), 801 | }, 802 | "optional": { 803 | "sigmas_opt": ("SIGMAS",), 804 | }, 805 | } 806 | 807 | @classmethod 808 | def patch( 809 | cls, 810 | model, 811 | rules: str, 812 | sigmas_opt: torch.Tensor | None = None, 813 | ): 814 | rules = rules.strip() 815 | if len(rules) == 0: 816 | return (model.clone(),) 817 | rules = RuleGroup.from_yaml(rules) 818 | # print("RULES", rules) 819 | 820 | # Arbitrary number that should have good enough precision 821 | pct_steps = 400 822 | pct_incr = 1.0 / pct_steps 823 | model_sampling = model.get_model_object("model_sampling") 824 | sig2pct = tuple( 825 | model_sampling.percent_to_sigma(x / pct_steps) 826 | for x in range(pct_steps, -1, -1) 827 | ) 828 | 829 | def get_pct(topts): 830 | sigma = topts["sigmas"][0].item() 831 | # This is obviously terrible but I couldn't find a better way to get the percentage from the current sigma. 832 | idx = bisect.bisect_right(sig2pct, sigma) 833 | if idx >= len(sig2pct): 834 | # Sigma out of range somehow? 
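# (Callers treat a missing percentage as "do nothing": make_state returns None and the
# block patches pass the tensor through without evaluating any rules.)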
835 | return None 836 | return pct_incr * (pct_steps - idx) 837 | 838 | def set_state_step(state, sigma): 839 | sdict = { 840 | CondType.STEP: -1, 841 | CondType.STEP_EXACT: -1, 842 | "sigma": sigma, 843 | "sigma_min": model_sampling.sigma_min, 844 | "sigma_max": model_sampling.sigma_max, 845 | } 846 | if sigmas_opt is None: 847 | state |= sdict 848 | return state 849 | sigmadiff, idx = torch.min(torch.abs(sigmas_opt[:-1] - sigma), 0) 850 | idx = idx.item() 851 | state |= sdict | { 852 | CondType.STEP: idx + 1, 853 | CondType.STEP_EXACT: -1 if sigmadiff.item() > 1.5e-06 else idx + 1, 854 | "sigma_next": sigmas_opt[idx + 1].item(), 855 | } 856 | return state 857 | 858 | stages = (1280, 640, 320) 859 | 860 | def make_state(typ: PatchType, topts: dict, h, hsp=None): 861 | pct = get_pct(topts) 862 | if pct is None: 863 | return None 864 | 865 | stage = stages.index(h.shape[1]) + 1 if h.shape[1] in stages else -1 866 | # print(">>", typ, topts["original_shape"], h.shape, stage, topts["block"]) 867 | result = { 868 | CondType.TYPE: typ, 869 | CondType.PERCENT: pct, 870 | CondType.BLOCK: topts["block"][1], 871 | CondType.STAGE: stage, 872 | "h": h, 873 | "hsp": hsp, 874 | "target": "h", 875 | } 876 | set_state_step(result, topts["sigmas"].max().item()) 877 | return result 878 | 879 | def block_patch(typ, h, topts: dict): 880 | state = make_state(typ, topts, h) 881 | if state is None: 882 | return h 883 | return rules.eval(state, toplevel=True)["h"] 884 | 885 | def output_block_patch(h, hsp, transformer_options: dict): 886 | state = make_state(PatchType.OUTPUT, transformer_options, h, hsp) 887 | if state is None: 888 | return h 889 | rules.eval(state, toplevel=True) 890 | return state["h"], state["hsp"] 891 | 892 | def post_cfg_patch(args: dict): 893 | pct = get_pct({"sigmas": args["sigma"]}) 894 | if pct is None: 895 | return None 896 | state = { 897 | CondType.TYPE: PatchType.POST_CFG, 898 | CondType.PERCENT: pct, 899 | CondType.BLOCK: -1, 900 | CondType.STAGE: -1, 901 | "h": args["denoised"], 902 | "hsp": None, 903 | "target": "h", 904 | } 905 | set_state_step(state, args["sigma"].max().item()) 906 | return rules.eval(state)["h"] 907 | 908 | m = model.clone() 909 | m.set_model_input_block_patch_after_skip( 910 | lambda *args: block_patch(PatchType.INPUT_AFTER_SKIP, *args), 911 | ) 912 | m.set_model_input_block_patch(lambda *args: block_patch(PatchType.INPUT, *args)) 913 | m.set_model_patch( 914 | lambda *args: block_patch(PatchType.MIDDLE, *args), 915 | "middle_block_patch", 916 | ) 917 | m.set_model_output_block_patch(output_block_patch) 918 | m.set_model_sampler_post_cfg_function( 919 | post_cfg_patch, 920 | disable_cfg1_optimization=True, 921 | ) 922 | orig_model_function_wrapper = model.model_options.get("model_function_wrapper") 923 | 924 | def pre_model(state): 925 | state[CondType.TYPE] = PatchType.PRE_APPLY_MODEL 926 | return rules.eval(state, toplevel=True)["h"] 927 | 928 | def post_model(state, result): 929 | state[CondType.TYPE] = PatchType.POST_APPLY_MODEL 930 | state["target"] = "h" 931 | state["h"] = result 932 | return rules.eval(state, toplevel=True)["h"] 933 | 934 | def model_unet_function_wrapper(apply_model, args): 935 | pct = get_pct({"sigmas": args["timestep"]}) 936 | if pct is None: 937 | return None 938 | state = { 939 | CondType.PERCENT: pct, 940 | CondType.BLOCK: -1, 941 | CondType.STAGE: -1, 942 | "h": args["input"], 943 | "hsp": None, 944 | "target": "h", 945 | } 946 | set_state_step(state, args["timestep"].max().item()) 947 | x = pre_model(state) 948 | args = args | 
{"input": x} # noqa: PLR6104 949 | if orig_model_function_wrapper is not None: 950 | result = orig_model_function_wrapper(apply_model, args) 951 | else: 952 | result = apply_model(args["input"], args["timestep"], **args["c"]) 953 | return post_model(state, result) 954 | 955 | m.set_model_unet_function_wrapper(model_unet_function_wrapper) 956 | 957 | return (m,) 958 | 959 | 960 | class BlehLatentScaleBy: 961 | @classmethod 962 | def INPUT_TYPES(cls): 963 | return { 964 | "required": { 965 | "samples": ("LATENT",), 966 | "method_horizontal": (UPSCALE_METHODS,), 967 | "method_vertical": (("same", *UPSCALE_METHODS),), 968 | "scale_width": ( 969 | "FLOAT", 970 | {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}, 971 | ), 972 | "scale_height": ( 973 | "FLOAT", 974 | {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}, 975 | ), 976 | "antialias_size": ("INT", {"default": 0}), 977 | }, 978 | } 979 | 980 | RETURN_TYPES = ("LATENT",) 981 | FUNCTION = "upscale" 982 | 983 | CATEGORY = "latent" 984 | 985 | @classmethod 986 | def upscale( 987 | cls, 988 | *, 989 | samples: dict, 990 | method_horizontal: str, 991 | method_vertical: str, 992 | scale_width: float, 993 | scale_height: float, 994 | antialias_size: int, 995 | ): 996 | if method_vertical == "same": 997 | method_vertical = method_horizontal 998 | samples = samples.copy() 999 | stensor = samples["samples"] 1000 | width = round(stensor.shape[3] * scale_width) 1001 | height = round(stensor.shape[2] * scale_height) 1002 | samples["samples"] = scale_samples( 1003 | stensor, 1004 | width, 1005 | height, 1006 | mode=method_horizontal, 1007 | mode_h=method_vertical, 1008 | antialias_size=antialias_size, 1009 | ) 1010 | return (samples,) 1011 | 1012 | 1013 | class BlehLatentOps: 1014 | @classmethod 1015 | def INPUT_TYPES(cls): 1016 | return { 1017 | "required": { 1018 | "samples": ("LATENT",), 1019 | "rules": ("STRING", {"multiline": True, "dynamicPrompts": False}), 1020 | }, 1021 | "optional": { 1022 | "samples_hsp": ("LATENT",), 1023 | }, 1024 | } 1025 | 1026 | RETURN_TYPES = ("LATENT",) 1027 | FUNCTION = "go" 1028 | 1029 | CATEGORY = "latent" 1030 | 1031 | @classmethod 1032 | def go( 1033 | cls, 1034 | *, 1035 | samples: dict, 1036 | rules: str, 1037 | samples_hsp: dict | None = None, 1038 | ): 1039 | samples = samples.copy() 1040 | rules = rules.strip() 1041 | if len(rules) == 0: 1042 | return (samples,) 1043 | rules = RuleGroup.from_yaml(rules) 1044 | stensor = samples["samples"] 1045 | state = { 1046 | CondType.TYPE: PatchType.LATENT, 1047 | CondType.PERCENT: 0.0, 1048 | CondType.BLOCK: -1, 1049 | CondType.STAGE: -1, 1050 | "h": stensor, 1051 | "hsp": None if samples_hsp is None else samples_hsp["samples"], 1052 | "target": "h", 1053 | } 1054 | rules.eval(state, toplevel=True) 1055 | return ({"samples": state["h"]},) 1056 | 1057 | 1058 | class BlehLatentBlend: 1059 | @classmethod 1060 | def INPUT_TYPES(cls): 1061 | return { 1062 | "required": { 1063 | "samples1": ("LATENT",), 1064 | "samples2": ("LATENT",), 1065 | "samples2_percent": ("FLOAT", {"default": 0.5}), 1066 | "blend_mode": (tuple(BLENDING_MODES.keys()),), 1067 | }, 1068 | } 1069 | 1070 | RETURN_TYPES = ("LATENT",) 1071 | FUNCTION = "go" 1072 | 1073 | CATEGORY = "latent" 1074 | 1075 | @classmethod 1076 | def go( 1077 | cls, 1078 | *, 1079 | samples1: dict, 1080 | samples2: dict, 1081 | samples2_percent=0.5, 1082 | blend_mode="lerp", 1083 | ): 1084 | a, b = samples1["samples"], samples2["samples"] 1085 | blend_function = BLENDING_MODES[blend_mode] 1086 | return ({"samples": 
blend_function(a, b, samples2_percent)},) 1087 | -------------------------------------------------------------------------------- /py/nodes/refinerAfter.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import comfy.model_management as mm 4 | 5 | 6 | class BlehRefinerAfter: 7 | DESCRIPTION = "Allows switching to another model at a certain point in sampling. Only works with models that are closely related as the sampling type and conditioning must match. Can be used to switch to a refiner model near the end of sampling." 8 | RETURN_TYPES = ("MODEL",) 9 | CATEGORY = "bleh/model_patches" 10 | 11 | @classmethod 12 | def INPUT_TYPES(cls): 13 | return { 14 | "required": { 15 | "time_mode": ( 16 | ( 17 | "timestep", 18 | "percent", 19 | "sigma", 20 | ), 21 | { 22 | "tooltip": "Controls how start_time is interpreted. Timestep will be 999 at the start of sampling and 0 at the end - it is basically the inverse of the sampling percentage with a multiplier. Percent is the percent of sampling (not steps) and will be 0.0 at the start of sampling and 1.0 at the end. Sigma is an advanced option - if you don't know what it is, you don't need to use it.", 23 | }, 24 | ), 25 | "start_time": ( 26 | "FLOAT", 27 | { 28 | "default": 199.0, 29 | "min": 0.0, 30 | "max": 999.0, 31 | "tooltip": "Time the refiner_model will become active. The type of value you enter here will depend on what time_mode is set to.", 32 | }, 33 | ), 34 | "model": ( 35 | "MODEL", 36 | { 37 | "tooltip": "Model to patch. This will also be the active model until the start_time condition is met.", 38 | }, 39 | ), 40 | "refiner_model": ( 41 | "MODEL", 42 | { 43 | "tooltip": "Model to switch to after the start_time condition is met.", 44 | }, 45 | ), 46 | }, 47 | } 48 | 49 | FUNCTION = "patch" 50 | 51 | @staticmethod 52 | def get_real_model(model: object) -> object: 53 | while hasattr(model, "model"): 54 | model = model.model 55 | return model 56 | 57 | @staticmethod 58 | def load_if_needed(model: object) -> bool: 59 | if mm.LoadedModel(model) in mm.current_loaded_models: 60 | return False 61 | mm.load_models_gpu([model]) 62 | return True 63 | 64 | def patch( # noqa: PLR0911 65 | self, 66 | start_time: float, 67 | model: object, 68 | refiner_model: object, 69 | time_mode: str = "timestep", 70 | ) -> tuple[object]: 71 | model = model.clone() 72 | refiner_model = refiner_model.clone() 73 | ms = self.get_real_model(model).model_sampling 74 | real_refiner_model = None 75 | 76 | if time_mode == "sigma": 77 | if start_time <= ms.sigma_min: 78 | return (model,) 79 | if start_time >= ms.sigma_max: 80 | return (refiner_model,) 81 | 82 | def check_time(sigma): 83 | return sigma.item() <= start_time 84 | 85 | elif time_mode == "percent": 86 | if start_time > 1.0 or start_time < 0.0: 87 | raise ValueError( 88 | "BlehRefinerAfter: invalid value for percent start time", 89 | ) 90 | if start_time >= 1.0: 91 | return (model,) 92 | if start_time <= 0.0: 93 | return (refiner_model,) 94 | 95 | def check_time(sigma): 96 | return sigma.item() <= ms.percent_to_sigma(start_time) 97 | 98 | elif time_mode == "timestep": 99 | if start_time <= 0.0: 100 | return (model,) 101 | if start_time >= 999.0: 102 | return (refiner_model,) 103 | 104 | def check_time(sigma): 105 | return ms.timestep(sigma) <= start_time 106 | 107 | else: 108 | raise ValueError("BlehRefinerAfter: invalid time mode") 109 | 110 | def unet_wrapper(apply_model, args): 111 | nonlocal real_refiner_model 112 | 113 | inp, 
timestep, c = args["input"], args["timestep"], args["c"] 114 | if not check_time(timestep.max()): 115 | real_refiner_model = None 116 | self.load_if_needed(model) 117 | return apply_model(inp, timestep, **c) 118 | if self.load_if_needed(refiner_model) or not real_refiner_model: 119 | real_refiner_model = self.get_real_model(refiner_model) 120 | return real_refiner_model.apply_model(inp, timestep, **c) 121 | 122 | model.set_model_unet_function_wrapper(unet_wrapper) 123 | return (model,) 124 | -------------------------------------------------------------------------------- /py/nodes/sageAttention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import contextlib 4 | import importlib 5 | from typing import TYPE_CHECKING 6 | 7 | import comfy.ldm.modules.attention as comfyattn 8 | import yaml 9 | from comfy.samplers import KSAMPLER 10 | 11 | try: 12 | import sageattention 13 | 14 | sageattn_default_function = sageattention.sageattn 15 | except ImportError: 16 | sageattention = None 17 | sageattn_default_head_sizes = None 18 | sageattn_default_function = None 19 | 20 | 21 | if TYPE_CHECKING: 22 | import collections 23 | from collections.abc import Callable 24 | 25 | import torch 26 | 27 | 28 | if sageattention is not None: 29 | try: 30 | sageattn_version = importlib.metadata.version("sageattention") 31 | except Exception: # noqa: BLE001 32 | sageattn_version = "unknown" 33 | if ( 34 | sageattn_version == "unknown" 35 | or sageattn_version.startswith("1.") 36 | or sageattn_version == "2.0.0" 37 | ): 38 | sageattn_default_head_sizes = {64, 96, 128} 39 | else: 40 | # SageAttention 2.0.1 (and later one would assume) supports up to 128. 41 | sageattn_default_head_sizes = set(range(1, 129)) 42 | else: 43 | sageattn_version = "unknown" 44 | 45 | 46 | def attention_sage( 47 | q: torch.Tensor, 48 | k: torch.Tensor, 49 | v: torch.Tensor, 50 | heads: int, 51 | *, 52 | orig_attention: Callable, 53 | sageattn_allow_head_sizes: collections.abc.Collection 54 | | None = sageattn_default_head_sizes, 55 | sageattn_function: collections.abc.Callable = sageattn_default_function, 56 | sageattn_version: str = sageattn_version, 57 | sageattn_verbose: bool = False, 58 | **kwargs: dict[str], 59 | ) -> torch.Tensor: 60 | old_sageattn = sageattn_version[:2] in {"1.", "un"} 61 | mask = kwargs.get("mask") 62 | skip_reshape = kwargs.get("skip_reshape", False) 63 | skip_output_reshape = kwargs.get("skip_output_reshape", False) 64 | batch = q.shape[0] 65 | dim_head = q.shape[-1] // (1 if skip_reshape else heads) 66 | enabled = sageattn_allow_head_sizes is None or dim_head in sageattn_allow_head_sizes 67 | if enabled and old_sageattn: 68 | enabled = all(t.shape == q.shape for t in (k, v)) 69 | if sageattn_verbose: 70 | print( 71 | f"\n>> SAGE({enabled}): reshape={not skip_reshape}, output_reshape={not skip_output_reshape}, dim_head={q.shape[-1]}, heads={heads}, adj_heads={dim_head}, q={q.shape}, k={k.shape}, v={v.shape}, args: {kwargs}\n", 72 | ) 73 | if not enabled: 74 | filtered_kwargs = { 75 | k: v 76 | for k, v in kwargs.items() 77 | if k in {"mask", "skip_reshape", "skip_output_reshape", "attn_precision"} 78 | } 79 | return orig_attention(q, k, v, heads, **filtered_kwargs) 80 | tensor_layout = kwargs.pop("tensor_layout", None) 81 | if old_sageattn: 82 | tensor_layout = "HND" 83 | elif tensor_layout is None: 84 | tensor_layout = "HND" if skip_reshape else "NHD" 85 | tensor_layout = tensor_layout.strip().upper() 86 | if tensor_layout not in 
{"NHD", "HND"}: 87 | raise ValueError("Bad tensor_layout, must be one of NHD, HND") 88 | if mask is not None: 89 | if mask.ndim == 2: 90 | mask = mask[None, None, ...] 91 | elif mask.ndim == 3: 92 | mask = mask.unsqueeze(1) 93 | if not skip_reshape: 94 | if tensor_layout == "HND": 95 | q, k, v = ( 96 | t.view(batch, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v) 97 | ) 98 | do_transpose = True 99 | else: 100 | q, k, v = (t.view(batch, -1, heads, dim_head) for t in (q, k, v)) 101 | do_transpose = False 102 | else: 103 | do_transpose = not skip_output_reshape 104 | if not old_sageattn and tensor_layout == "NHD": 105 | q, k, v = (t.transpose(1, 2) for t in (q, k, v)) 106 | do_transpose = skip_output_reshape 107 | if not old_sageattn: 108 | kwargs["tensor_layout"] = tensor_layout 109 | sm_scale_hd = kwargs.pop(f"sm_scale_{dim_head}", None) 110 | if sm_scale_hd is not None: 111 | kwargs["sm_scale"] = sm_scale_hd 112 | result = sageattn_function( 113 | q, 114 | k, 115 | v, 116 | is_causal=False, 117 | attn_mask=mask, 118 | dropout_p=0.0, 119 | **kwargs, 120 | ) 121 | if do_transpose: 122 | result = result.transpose(1, 2) 123 | if not skip_output_reshape: 124 | result = result.reshape(batch, -1, heads * dim_head) 125 | return result 126 | 127 | 128 | def copy_funattrs(fun, dest=None): 129 | if dest is None: 130 | dest = fun.__class__(fun.__code__, fun.__globals__) 131 | for k in ( 132 | "__code__", 133 | "__defaults__", 134 | "__kwdefaults__", 135 | "__module__", 136 | ): 137 | setattr(dest, k, getattr(fun, k)) 138 | return dest 139 | 140 | 141 | def make_sageattn_wrapper( 142 | *, 143 | orig_attn, 144 | sageattn_function: str = "sageattn", 145 | **kwargs: dict, 146 | ): 147 | outer_kwargs = kwargs 148 | sageattn_function = getattr(sageattention, sageattn_function) 149 | 150 | def attn( 151 | *args: list, 152 | _sage_outer_kwargs=outer_kwargs, 153 | _sage_orig_attention=orig_attn, 154 | _sage_sageattn_function=sageattn_function, 155 | _sage_attn=attention_sage, 156 | **kwargs: dict, 157 | ) -> torch.Tensor: 158 | return _sage_attn( 159 | *args, 160 | orig_attention=_sage_orig_attention, 161 | sageattn_function=_sage_sageattn_function, 162 | **_sage_outer_kwargs, 163 | **kwargs, 164 | ) 165 | 166 | return attn 167 | 168 | 169 | @contextlib.contextmanager 170 | def sageattn_context( 171 | enabled: bool, 172 | **kwargs: dict, 173 | ): 174 | if not enabled: 175 | yield None 176 | return 177 | orig_attn = copy_funattrs(comfyattn.optimized_attention) 178 | attn = make_sageattn_wrapper(orig_attn=orig_attn, **kwargs) 179 | try: 180 | copy_funattrs(attn, comfyattn.optimized_attention) 181 | yield None 182 | finally: 183 | copy_funattrs(orig_attn, comfyattn.optimized_attention) 184 | 185 | 186 | def get_yaml_parameters(yaml_parameters: str | None = None) -> dict: 187 | if not yaml_parameters: 188 | return {} 189 | extra_params = yaml.safe_load(yaml_parameters) 190 | if extra_params is None: 191 | return {} 192 | if not isinstance(extra_params, dict): 193 | raise ValueError( # noqa: TRY004 194 | "BlehSageAttention: yaml_parameters must either be null or an object", 195 | ) 196 | return extra_params 197 | 198 | 199 | class BlehGlobalSageAttention: 200 | DESCRIPTION = "Deprecated: Prefer using BlehSageAttentionSampler if possible. This node allows globally replacing ComfyUI's attention with SageAtteniton (performance enhancement). Requires SageAttention to be installed into the ComfyUI Python environment. IMPORTANT: This is not a normal model patch. 
For settings to apply (including toggling on or off) the node must actually be run. If you toggle it on, run your workflow and then bypass or mute the node this will not actually disable SageAttention." 201 | RETURN_TYPES = ("MODEL",) 202 | FUNCTION = "go" 203 | CATEGORY = "hacks" 204 | 205 | @classmethod 206 | def INPUT_TYPES(cls) -> dict: 207 | return { 208 | "required": { 209 | "model": ("MODEL",), 210 | "enabled": ( 211 | "BOOLEAN", 212 | {"default": True}, 213 | ), 214 | }, 215 | "optional": { 216 | "yaml_parameters": ( 217 | "STRING", 218 | { 219 | "tooltip": "Allows specifying custom parameters via YAML. These are mostly passed directly to the SageAttention function with no error checking. Must be empty or a YAML object.", 220 | "dynamicPrompts": False, 221 | "multiline": True, 222 | "defaultInput": True, 223 | }, 224 | ), 225 | }, 226 | } 227 | 228 | orig_attn = None 229 | 230 | @classmethod 231 | def go( 232 | cls, 233 | *, 234 | model: object, 235 | enabled: bool, 236 | yaml_parameters: str | None = None, 237 | ) -> tuple: 238 | if not enabled: 239 | if cls.orig_attn is not None: 240 | copy_funattrs(cls.orig_attn, comfyattn.optimized_attention) 241 | cls.orig_attn = None 242 | return (model,) 243 | if sageattention is None: 244 | raise RuntimeError( 245 | "sageattention not installed to Python environment: SageAttention feature unavailable", 246 | ) 247 | if not cls.orig_attn: 248 | cls.orig_attn = copy_funattrs(comfyattn.optimized_attention) 249 | attn = make_sageattn_wrapper( 250 | orig_attn=cls.orig_attn, 251 | **get_yaml_parameters(yaml_parameters), 252 | ) 253 | copy_funattrs(attn, comfyattn.optimized_attention) 254 | return (model,) 255 | 256 | 257 | def sageattn_sampler( 258 | model: object, 259 | x: torch.Tensor, 260 | sigmas: torch.Tensor, 261 | *, 262 | sageattn_sampler_options: tuple, 263 | **kwargs: dict, 264 | ) -> torch.Tensor: 265 | sampler, start_percent, end_percent, sageattn_kwargs = sageattn_sampler_options 266 | ms = model.inner_model.inner_model.model_sampling 267 | start_sigma, end_sigma = ( 268 | round(ms.percent_to_sigma(start_percent), 4), 269 | round(ms.percent_to_sigma(end_percent), 4), 270 | ) 271 | del ms 272 | 273 | def model_wrapper( 274 | x: torch.Tensor, 275 | sigma: torch.Tensor, 276 | **extra_args: dict[str], 277 | ) -> torch.Tensor: 278 | sigma_float = float(sigma.max().detach().cpu()) 279 | enabled = end_sigma <= sigma_float <= start_sigma 280 | with sageattn_context( 281 | enabled=enabled, 282 | **sageattn_kwargs, 283 | ): 284 | return model(x, sigma, **extra_args) 285 | 286 | for k in ( 287 | "inner_model", 288 | "sigmas", 289 | ): 290 | if hasattr(model, k): 291 | setattr(model_wrapper, k, getattr(model, k)) 292 | return sampler.sampler_function( 293 | model_wrapper, 294 | x, 295 | sigmas, 296 | **kwargs, 297 | **sampler.extra_options, 298 | ) 299 | 300 | 301 | class BlehSageAttentionSampler: 302 | DESCRIPTION = "Sampler wrapper that enables using SageAttention (performance enhancement) while sampling is in progress. Requires SageAttention to be installed into the ComfyUI Python environment." 
303 | CATEGORY = "sampling/custom_sampling/samplers" 304 | RETURN_TYPES = ("SAMPLER",) 305 | FUNCTION = "go" 306 | 307 | @classmethod 308 | def INPUT_TYPES(cls) -> dict: 309 | return { 310 | "required": { 311 | "sampler": ("SAMPLER",), 312 | }, 313 | "optional": { 314 | "start_percent": ( 315 | "FLOAT", 316 | { 317 | "default": 0.0, 318 | "min": 0.0, 319 | "max": 1.0, 320 | "step": 0.001, 321 | "tooltip": "Time the effect becomes active as a percentage of sampling, not steps.", 322 | }, 323 | ), 324 | "end_percent": ( 325 | "FLOAT", 326 | { 327 | "default": 1.0, 328 | "min": 0.0, 329 | "max": 1.0, 330 | "step": 0.001, 331 | "tooltip": "Time the effect ends (inclusive) as a percentage of sampling, not steps.", 332 | }, 333 | ), 334 | "yaml_parameters": ( 335 | "STRING", 336 | { 337 | "tooltip": "Allows specifying custom parameters via YAML. These are mostly passed directly to the SageAttention function with no error checking. Must be empty or a YAML object.", 338 | "dynamicPrompts": False, 339 | "multiline": True, 340 | "defaultInput": True, 341 | }, 342 | ), 343 | }, 344 | } 345 | 346 | @classmethod 347 | def go( 348 | cls, 349 | sampler: object, 350 | *, 351 | start_percent: float = 0.0, 352 | end_percent: float = 1.0, 353 | yaml_parameters: str | None = None, 354 | ) -> tuple: 355 | if sageattention is None: 356 | raise RuntimeError( 357 | "sageattention not installed to Python environment: SageAttention feature unavailable", 358 | ) 359 | return ( 360 | KSAMPLER( 361 | sageattn_sampler, 362 | extra_options={ 363 | "sageattn_sampler_options": ( 364 | sampler, 365 | start_percent, 366 | end_percent, 367 | get_yaml_parameters(yaml_parameters), 368 | ), 369 | }, 370 | ), 371 | ) 372 | -------------------------------------------------------------------------------- /py/nodes/samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import contextlib 4 | import importlib 5 | import random 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import environ 9 | from typing import Any, Callable, NamedTuple 10 | 11 | import torch 12 | from comfy.samplers import KSAMPLER, KSampler, k_diffusion_sampling 13 | from tqdm import tqdm 14 | 15 | from .misc import Wildcard 16 | 17 | BLEH_PRESET_LIMIT = 16 18 | BLEH_PRESET_COUNT = 1 19 | with contextlib.suppress(Exception): 20 | BLEH_PRESET_COUNT = min( 21 | BLEH_PRESET_LIMIT, 22 | max( 23 | 0, 24 | int(environ.get("COMFYUI_BLEH_SAMPLER_PRESET_COUNT", "1")), 25 | ), 26 | ) 27 | 28 | 29 | class SamplerChain(NamedTuple): 30 | prev: SamplerChain | None = None 31 | steps: int = 0 32 | sampler: object | None = None 33 | chain_sampler: Callable | None = None 34 | 35 | 36 | class BlehInsaneChainSampler: 37 | RETURN_TYPES = ("SAMPLER", "BLEH_SAMPLER_CHAIN") 38 | CATEGORY = "sampling/custom_sampling/samplers" 39 | FUNCTION = "build" 40 | 41 | @classmethod 42 | def INPUT_TYPES(cls): 43 | return { 44 | "required": { 45 | "sampler": ("SAMPLER",), 46 | "steps": ("INT", {"default": 0, "min": 0, "max": 9999}), 47 | }, 48 | "optional": { 49 | "sampler_chain_opt": ("BLEH_SAMPLER_CHAIN",), 50 | }, 51 | } 52 | 53 | def build( 54 | self, 55 | sampler: object | None = None, 56 | steps: int = 0, 57 | sampler_chain_opt: SamplerChain | None = None, 58 | ) -> tuple[KSAMPLER, SamplerChain]: 59 | if sampler is None: 60 | raise ValueError("BlehInsaneChainSampler: sampler missing") 61 | if steps > 0: 62 | chain = SamplerChain(steps=steps, sampler=sampler, prev=sampler_chain_opt) 
63 | else: 64 | chain = sampler_chain_opt or SamplerChain() 65 | return (KSAMPLER(self.sampler, {"sampler_chain": chain}), chain) 66 | 67 | @classmethod 68 | @torch.no_grad() 69 | def sampler( 70 | cls, 71 | model, 72 | x, 73 | sigmas, 74 | *args: list[Any], 75 | disable=None, 76 | sampler_chain=None, 77 | **kwargs: dict[str, Any], 78 | ): 79 | if not sampler_chain: 80 | return x 81 | chain = sampler_chain 82 | remaining_steps = len(sigmas) - 1 83 | i = 0 84 | progress = tqdm(total=remaining_steps, disable=disable) 85 | while remaining_steps > 0 and chain: 86 | while chain and (chain.steps == 0 or chain.sampler is None): 87 | chain = chain.prev 88 | if chain is None or chain.sampler is None: 89 | raise ValueError("Sampler chain didn't provide a sampler for sampling!") 90 | steps = min(chain.steps, remaining_steps) 91 | real_next = chain.prev 92 | while real_next and (real_next.steps == 0 or real_next.sampler is None): 93 | real_next = real_next.prev 94 | if real_next and (real_next.steps == 0 or real_next.sampler is None): 95 | real_next = None 96 | if real_next is None: 97 | steps = remaining_steps + 1 98 | start_idx = max(i - 1, i) 99 | end_idx = start_idx + steps 100 | curr_sigmas = sigmas[start_idx : end_idx + 1] 101 | x = chain.sampler.sampler_function( 102 | model, 103 | x, 104 | curr_sigmas, 105 | *args, 106 | disable=disable, 107 | **chain.sampler.extra_options, 108 | **kwargs, 109 | ) 110 | i += steps 111 | progress.update(n=min(steps, remaining_steps)) 112 | remaining_steps -= steps 113 | chain = real_next 114 | progress.close() 115 | return x 116 | 117 | 118 | class BlehForceSeedSampler: 119 | DESCRIPTION = "ComfyUI has a bug where it will not set any seed if you have add_noise disabled in the sampler. This node is a workaround that ensures a seed always gets set." 120 | RETURN_TYPES = ("SAMPLER",) 121 | CATEGORY = "sampling/custom_sampling/samplers" 122 | 123 | @classmethod 124 | def INPUT_TYPES(cls): 125 | return { 126 | "required": { 127 | "sampler": ("SAMPLER",), 128 | }, 129 | "optional": { 130 | "seed_offset": ( 131 | "INT", 132 | { 133 | "default": 1, 134 | "min": 0, 135 | "max": 200, 136 | "tooltip": "Advances the RNG this many times to avoid the mistake of using the same noise for sampling as the initial noise. I recommend leaving this at 1 (or higher), but you can set it to 0 
to disable", 137 | }, 138 | ), 139 | }, 140 | } 141 | 142 | FUNCTION = "go" 143 | 144 | def go( 145 | self, 146 | sampler: object, 147 | seed_offset: int | None = 1, 148 | ) -> tuple[KSAMPLER, SamplerChain]: 149 | return ( 150 | KSAMPLER( 151 | self.sampler_function, 152 | extra_options=sampler.extra_options 153 | | { 154 | "bleh_wrapped_sampler": sampler, 155 | "bleh_seed_offset": seed_offset, 156 | }, 157 | inpaint_options=sampler.inpaint_options | {}, 158 | ), 159 | ) 160 | 161 | @classmethod 162 | @torch.no_grad() 163 | def sampler_function( 164 | cls, 165 | model: object, 166 | x: torch.Tensor, 167 | *args: list[Any], 168 | extra_args: dict[str, Any] | None = None, 169 | bleh_wrapped_sampler: object | None = None, 170 | bleh_seed_offset: int | None = 1, 171 | **kwargs: dict[str, Any], 172 | ): 173 | if not bleh_wrapped_sampler: 174 | raise ValueError("wrapped sampler missing!") 175 | seed = (extra_args or {}).get("seed") 176 | if seed is not None: 177 | random.seed(seed) 178 | torch.manual_seed(seed) 179 | for _ in range(bleh_seed_offset if bleh_seed_offset is not None else 0): 180 | _ = random.random() # noqa: S311 181 | _ = torch.randn_like(x) 182 | return bleh_wrapped_sampler.sampler_function( 183 | model, 184 | x, 185 | *args, 186 | extra_args=extra_args, 187 | **kwargs, 188 | ) 189 | 190 | 191 | BLEH_PRESET = [None] * BLEH_PRESET_COUNT 192 | 193 | 194 | def bleh_sampler_preset_wrapper( 195 | preset_idx, 196 | model, 197 | x, 198 | sigmas, 199 | *args: list, 200 | **kwargs: dict, 201 | ) -> torch.Tensor: 202 | if not (0 <= preset_idx < BLEH_PRESET_COUNT): 203 | raise ValueError("Bleh sampler preset out of range") 204 | preset = BLEH_PRESET[preset_idx] 205 | if preset is None: 206 | errstr = f"Cannot use bleh_preset_{preset_idx} - present not defined. Ensure BlehSetSamplerPreset runs before sampling." 207 | raise RuntimeError(errstr) 208 | sampler, override_sigmas = preset 209 | if override_sigmas is not None: 210 | sigmas = override_sigmas.detach().clone().to(sigmas) 211 | return sampler.sampler_function( 212 | model, 213 | x, 214 | sigmas, 215 | *args, 216 | **sampler.extra_options, 217 | **kwargs, 218 | ) 219 | 220 | 221 | def add_sampler_presets(): 222 | if BLEH_PRESET_COUNT < 1: 223 | return 224 | for idx in range(BLEH_PRESET_COUNT): 225 | key = f"bleh_preset_{idx}" 226 | if key in KSampler.SAMPLERS: 227 | print( 228 | f"\n** ComfyUI-bleh: Warning: {key} already exists in sampler list, skipping adding preset samplers.", 229 | ) 230 | if idx == 0: 231 | return 232 | break 233 | KSampler.SAMPLERS.append(key) 234 | setattr( 235 | k_diffusion_sampling, 236 | f"sample_{key}", 237 | partial(bleh_sampler_preset_wrapper, idx), 238 | ) 239 | importlib.reload(k_diffusion_sampling) 240 | 241 | 242 | class BlehSetSamplerPreset: 243 | WILDCARD = Wildcard("*") 244 | DESCRIPTION = "This node allows setting a custom sampler as a preset that can be selected in nodes that don't support custom sampling (FaceDetailer for example). This node needs to run at least once with any preset changes before actual sampling begins. The any_input input acts as a passthrough so you can do something like pass your model or latent through before you start sampling to ensure the node runs. You can also connect something like an integer or string to the dummy_opt input and change it to force the node to run again. The number of presets can be adjusted (and the whole feature disabled if desired) by setting the environment variable COMFYUI_BLEH_SAMPLER_PRESET_COUNT. 
WARNING: Since the input and output are wildcards, this bypasses ComfyUI's normal type checking. Make sure you connect the output to something that actually accepts the input type." 245 | RETURN_TYPES = (WILDCARD,) 246 | OUTPUT_TOOLTIPS = ( 247 | "This just returns the value of any_input unchanged. WARNING: ComfyUI's normal typechecking is disabled here, make sure you connect this output to something that allows the input type.", 248 | ) 249 | CATEGORY = "hacks" 250 | NOT_IDEMPOTENT = True 251 | FUNCTION = "go" 252 | 253 | @classmethod 254 | def INPUT_TYPES(cls): 255 | return { 256 | "required": { 257 | "sampler": ("SAMPLER", {"tooltip": "Sampler to use for this preset."}), 258 | "any_input": ( 259 | cls.WILDCARD, 260 | { 261 | "tooltip": "This input is simply returned as the output. Note: Make sure you connect this node's output to something that supports input connected here.", 262 | }, 263 | ), 264 | "preset": ( 265 | "INT", 266 | { 267 | "min": -1, 268 | "max": BLEH_PRESET_COUNT - 1, 269 | "default": 0 if BLEH_PRESET_COUNT > 0 else -1, 270 | "tooltip": "Preset index to set. If set to -1, no preset assignment will be done. The number of presets can be adjusted, see the README.", 271 | }, 272 | ), 273 | "discard_penultimate_sigma": ( 274 | "BOOLEAN", 275 | { 276 | "default": False, 277 | "tooltip": "Advanced option to allow discarding the penultimate sigma. May be needed for some samplers like dpmpp_3m_sde - if it seems like the generation has a bunch of noise added at the very last step then you can try enabling this. Note: Cannot be used when override sigmas are attached.", 278 | }, 279 | ), 280 | }, 281 | "optional": { 282 | "override_sigmas_opt": ( 283 | "SIGMAS", 284 | { 285 | "tooltip": "Advanced option that allows overriding the sigmas used for sampling. Note: Cannot be used with discard_penultimate_sigma. Also this cannot control the noise added by the sampler, so if the schedule used by the sampler starts on a different sigma than the override you will likely run into issues.", 286 | }, 287 | ), 288 | "dummy_opt": ( 289 | cls.WILDCARD, 290 | { 291 | "tooltip": "This input can optionally be connected to any value as a way to force the node to run again on demand. 
See the README.", 292 | }, 293 | ), 294 | }, 295 | } 296 | 297 | @classmethod 298 | def go( 299 | cls, 300 | *, 301 | sampler, 302 | any_input, 303 | preset, 304 | discard_penultimate_sigma, 305 | override_sigmas_opt: torch.Tensor | None = None, 306 | dummy_opt=None, # noqa: ARG003 307 | ): 308 | if not (0 <= preset < BLEH_PRESET_COUNT): 309 | return (any_input,) 310 | if discard_penultimate_sigma and override_sigmas_opt is not None: 311 | raise ValueError( 312 | "BlehSetSamplerPreset: Cannot override sigmas and also enable discard penultimate sigma", 313 | ) 314 | dps_samplers = getattr(KSampler, "DISCARD_PENULTIMATE_SIGMA_SAMPLERS", None) 315 | if dps_samplers is not None: 316 | key = f"bleh_preset_{preset}" 317 | if discard_penultimate_sigma: 318 | dps_samplers.update(key) 319 | else: 320 | dps_samplers -= {key} 321 | sigmas = ( 322 | None 323 | if override_sigmas_opt is None 324 | else override_sigmas_opt.detach().clone().cpu() 325 | ) 326 | BLEH_PRESET[preset] = (deepcopy(sampler), sigmas) 327 | return (any_input,) 328 | -------------------------------------------------------------------------------- /py/nodes/taevid.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: TID252 2 | 3 | import torch # noqa: I001 4 | 5 | import folder_paths 6 | from comfy import model_management 7 | 8 | from ..better_previews.previewer import VIDEO_FORMATS, VideoModelInfo 9 | from ..better_previews.tae_vid import TAEVid 10 | 11 | 12 | class TAEVideoNodeBase: 13 | FUNCTION = "go" 14 | CATEGORY = "latent" 15 | 16 | @classmethod 17 | def INPUT_TYPES(cls) -> dict: 18 | return { 19 | "required": { 20 | "latent_type": (("wan21", "hunyuanvideo", "mochi"),), 21 | "parallel_mode": ( 22 | "BOOLEAN", 23 | { 24 | "default": False, 25 | "tooltip": "Parallel mode is faster but requires more memory.", 26 | }, 27 | ), 28 | }, 29 | } 30 | 31 | @classmethod 32 | def get_taevid_model( 33 | cls, 34 | latent_type: str, 35 | ) -> tuple[TAEVid, torch.device, torch.dtype, VideoModelInfo]: 36 | vmi = VIDEO_FORMATS.get(latent_type) 37 | if vmi is None or vmi.tae_model is None: 38 | raise ValueError("Bad latent type") 39 | tae_model_path = folder_paths.get_full_path("vae_approx", vmi.tae_model) 40 | if tae_model_path is None: 41 | if latent_type == "wan21": 42 | model_src = "taew2_1.pth from https://github.com/madebyollin/taehv" 43 | elif latent_type == "hunyuanvideo": 44 | model_src = "taehv.pth from https://github.com/madebyollin/taehv" 45 | else: 46 | model_src = "taem1.pth from https://github.com/madebyollin/taem1" 47 | err_string = f"Missing TAE video model. Download {model_src} and place it in the models/vae_approx directory" 48 | raise RuntimeError(err_string) 49 | device = model_management.vae_device() 50 | dtype = model_management.vae_dtype(device=device) 51 | return ( 52 | TAEVid( 53 | checkpoint_path=tae_model_path, 54 | latent_channels=vmi.latent_format.latent_channels, 55 | device=device, 56 | ).to(device), 57 | device, 58 | dtype, 59 | vmi, 60 | ) 61 | 62 | @classmethod 63 | def go(cls, *, latent, latent_type: str, parallel_mode: bool) -> tuple: 64 | raise NotImplementedError 65 | 66 | 67 | class TAEVideoDecode(TAEVideoNodeBase): 68 | RETURN_TYPES = ("IMAGE",) 69 | CATEGORY = "latent" 70 | DESCRIPTION = "Fast decoding of Wan, Hunyuan and Mochi video latents with the video equivalent of TAESD." 
71 | 72 | @classmethod 73 | def INPUT_TYPES(cls) -> dict: 74 | result = super().INPUT_TYPES() 75 | result["required"] |= { 76 | "latent": ("LATENT",), 77 | } 78 | return result 79 | 80 | @classmethod 81 | def go(cls, *, latent: dict, latent_type: str, parallel_mode: bool) -> tuple: 82 | model, device, dtype, vmi = cls.get_taevid_model(latent_type) 83 | samples = latent["samples"].detach().to(device=device, dtype=dtype, copy=True) 84 | samples = vmi.latent_format().process_in(samples) 85 | img = ( 86 | model.decode( 87 | samples.transpose(1, 2), 88 | parallel=parallel_mode, 89 | show_progress=True, 90 | ) 91 | .movedim(2, -1) 92 | .to( 93 | dtype=torch.float, 94 | device="cpu", 95 | ) 96 | ) 97 | img = img.reshape(-1, *img.shape[-3:]) 98 | return (img,) 99 | 100 | 101 | class TAEVideoEncode(TAEVideoNodeBase): 102 | RETURN_TYPES = ("LATENT",) 103 | CATEGORY = "latent" 104 | DESCRIPTION = "Fast encoding of Wan, Hunyuan and Mochi video latents with the video equivalent of TAESD." 105 | 106 | @classmethod 107 | def INPUT_TYPES(cls) -> dict: 108 | result = super().INPUT_TYPES() 109 | result["required"] |= { 110 | "image": ("IMAGE",), 111 | } 112 | return result 113 | 114 | @classmethod 115 | def go(cls, *, image: torch.Tensor, latent_type: str, parallel_mode: bool) -> tuple: 116 | model, device, dtype, vmi = cls.get_taevid_model(latent_type) 117 | image = image.detach().to(device=device, dtype=dtype, copy=True) 118 | if image.ndim == 4: 119 | image = image.unsqueeze(0) 120 | latent = model.encode( 121 | image.movedim(-1, 2), 122 | parallel=parallel_mode, 123 | show_progress=True, 124 | ).transpose(1, 2) 125 | latent = ( 126 | vmi.latent_format() 127 | .process_out(latent) 128 | .to( 129 | dtype=torch.float, 130 | device="cpu", 131 | ) 132 | ) 133 | return ({"samples": latent},) 134 | -------------------------------------------------------------------------------- /py/settings.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | class Settings: 5 | def __init__(self): 6 | self.btp_enabled = False 7 | 8 | def update(self, obj): 9 | btp = obj.get("betterTaesdPreviews", None) 10 | self.btp_enabled = btp is not None and btp.get("enabled", True) is True 11 | if not self.btp_enabled: 12 | return 13 | max_size = max(8, btp.get("max_size", 768)) 14 | self.btp_max_width = max(8, btp.get("max_width", max_size)) 15 | self.btp_max_height = max(8, btp.get("max_height", max_size)) 16 | self.btp_max_batch = max(1, btp.get("max_batch", 4)) 17 | self.btp_max_batch_cols = max(1, btp.get("max_batch_cols", 2)) 18 | self.btp_throttle_secs = btp.get("throttle_secs", 1) 19 | self.btp_skip_upscale_layers = btp.get("skip_upscale_layers", 0) 20 | self.btp_preview_device = btp.get("preview_device") 21 | self.btp_maxed_batch_step_mode = btp.get("maxed_batch_step_mode", False) 22 | self.btp_compile_previewer = btp.get("compile_previewer", False) 23 | self.btp_oom_fallback = btp.get("oom_fallback", "latent2rgb") 24 | self.btp_oom_retry = btp.get("oom_retry", True) 25 | self.btp_whitelist = frozenset(btp.get("whitelist_formats", frozenset())) 26 | self.btp_blacklist = frozenset(btp.get("blacklist_formats", frozenset())) 27 | self.btp_video_parallel = btp.get("video_parallel", False) 28 | self.btp_video_max_frames = btp.get("video_max_frames", -1) 29 | self.btp_video_temporal_upscale_level = btp.get( 30 | "video_temporal_upscale_level", 31 | 0, 32 | ) 33 | self.btp_animate_preview = btp.get("animate_preview", "none") 34 | self.btp_verbose = 
btp.get("verbose", False) 35 | 36 | @staticmethod 37 | def get_cfg_path(filename) -> Path: 38 | my_path = Path.resolve(Path(__file__).parent) 39 | return my_path.parent / filename 40 | 41 | def try_update_from_json(self, filename): 42 | import json # noqa: PLC0415 43 | 44 | try: 45 | with Path.open(self.get_cfg_path(filename)) as fp: 46 | self.update(json.load(fp)) 47 | return True 48 | except OSError: 49 | return False 50 | 51 | def try_update_from_yaml(self, filename): 52 | try: 53 | import yaml # noqa: PLC0415 54 | 55 | with Path.open(self.get_cfg_path(filename)) as fp: 56 | self.update(yaml.safe_load(fp)) 57 | return True 58 | except (OSError, ImportError): 59 | return False 60 | 61 | 62 | SETTINGS = Settings() 63 | 64 | 65 | def load_settings(): 66 | if not SETTINGS.try_update_from_yaml("blehconfig.yaml"): 67 | SETTINGS.try_update_from_json("blehconfig.json") 68 | return SETTINGS 69 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | [lint] 2 | ignore = [ 3 | "ANN001", 4 | "ANN101", 5 | "ANN102", 6 | "ANN201", 7 | "ANN202", 8 | "ANN204", 9 | "ANN206", 10 | "C901", 11 | "CPY001", 12 | "D100", 13 | "D101", 14 | "D102", 15 | "D103", 16 | "D104", 17 | "D105", 18 | "D107", 19 | "D211", 20 | "D213", 21 | "E402", 22 | "E501", 23 | "EM101", 24 | "ERA001", 25 | "F403", 26 | "F405", 27 | "FBT001", 28 | "FBT002", 29 | "PLR0912", 30 | "PLR0913", 31 | "PLR0915", 32 | "PLR2004", 33 | "T201", 34 | "TD001", 35 | "TD002", 36 | "TD003", 37 | "TRY003", 38 | "N802", 39 | "N999", 40 | ] 41 | select = ["ALL"] 42 | --------------------------------------------------------------------------------