├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── assets ├── basic_sampling.png └── group_time_example.png ├── docs ├── expression.md └── filter.md └── py ├── __init__.py ├── custom_noise ├── __init__.py ├── base.py ├── nodes.py └── noise_perlin.py ├── expression ├── __init__.py ├── expression.py ├── handler.py ├── parser.py ├── types.py ├── util.py └── validation.py ├── expression_handlers.py ├── external.py ├── filtering.py ├── latent.py ├── model.py ├── nodes.py ├── noise.py ├── res_support.py ├── restart.py ├── sampling.py ├── step_samplers.py ├── substep_merging.py ├── substep_sampling.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 blepping 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .py import nodes 2 | from .py import custom_noise 3 | 4 | NODE_CLASS_MAPPINGS = { 5 | "OCS Sampler": nodes.SamplerNode, 6 | "OCS Substeps": nodes.SubstepsNode, 7 | "OCS Group": nodes.GroupNode, 8 | "OCS Param": nodes.ParamNode, 9 | "OCS MultiParam": nodes.MultiParamNode, 10 | "OCS ModelSetMaxSigma": nodes.ModelSetMaxSigmaNode, 11 | "OCS SimpleRestartSchedule": nodes.SimpleRestartSchedule, 12 | } | custom_noise.NODE_CLASS_MAPPINGS 13 | __all__ = ["NODE_CLASS_MAPPINGS"] 14 | -------------------------------------------------------------------------------- /assets/basic_sampling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blepping/comfyui_overly_complicated_sampling/716b0d749ad64be71d9c535a03c4331e75272d92/assets/basic_sampling.png -------------------------------------------------------------------------------- /assets/group_time_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blepping/comfyui_overly_complicated_sampling/716b0d749ad64be71d9c535a03c4331e75272d92/assets/group_time_example.png -------------------------------------------------------------------------------- /docs/expression.md: -------------------------------------------------------------------------------- 1 | # OCS Expressions 2 | 3 | See [Filters](filter.md) for places where expressions apply. 4 | 5 | ## Expressions 6 | 7 | OCS implements a simple expression language. 8 | 9 | Supported math operators: `+`, `-`, `*`, `/`, `//` (integer division), `**` (power) 10 | 11 | Supported logic operators: `||`, `&&`, `==`, `!=`, `>`, `<`, `>=`, `<=` 12 | 13 | Operator precedence should generally work the way you'd expect. 
14 | 15 | You may surround a function name with backticks to turn it into a binary operator (only for functions that take two arguments). 16 | 17 | Functions are called via `name(param1, param2)`. Keyword arguments may be passed using the `:>` operator. Example: 18 | `name(param1, key :> 123, key2 :> otherfunction(10))`. 19 | 20 | Symbols (simple string type) are defined using `'symbol_name` - note the solitary single quote. They may not contain spaces. 21 | 22 | `;` can be used to sequence operations. I.E. `exp1 ; exp2` evaluates `exp1`, then `exp2` and then the result of the expression is whatever `exp2` returned. 23 | 24 | `:=` is used to assign to a temporary variable (see `set_var` below). 25 | 26 | The expression language supports a C/JavaScript style ternary operator: `condition ? true_branch : false_branch` is the equivalent of `if(condition, true_branch, false_branch)`. 27 | 28 | Like Python, a parenthesized expression with a trailing comma can be used to create a single-element tuple. Example: `(1,)` 29 | 30 | ## Filter Variables 31 | 32 | Indexes like `step` are zero-based: `0` will be the first step. 33 | 34 | ### Basic Variables 35 | 36 | * `default`: Context specific default value. i.e. if used in an `input` expression this would be `x`, if used for `output` this would be the current result. 37 | * `step`: Current step. 38 | * `substep`: Current substep. 39 | * `dt`: `sigma_next - sigma` 40 | * `sigma_idx`: Index of the current sigma. Note that when using restarts this will be based on the restart sigma chunks, not full sigma list. 41 | * `sigma`: The current sigma. 42 | * `sigma_next`: The next sigma. 43 | * `sigma_down`: The down sigma in ancestral sampling. 44 | * `sigma_up`: The up sigma in ancestral sampling. 45 | * `sigma_prev`: The previous sigma (may be `None`). 46 | * `hist_len`: Current available history length. "Now" counts as one. 47 | * `sigma_min`: The minimum sigma (based on the full list).
48 | * `sigma_max`: The maximum sigma (based on the full list). 49 | * `step_pct`: Percentage for the current step (based on total steps). 50 | * `total_steps`: Total steps to be sampled. 51 | 52 | ### Extended Variables 53 | 54 | * `denoised`: From the current step or substep. May not be available in model `input` or group `pre_filter`. 55 | * `cond`: From the current step or substep. May not be available in model `input` or group `pre_filter`. 56 | * `uncond`: From the current step or substep. May not be available in model `input` or group `pre_filter`. 57 | * `denoised_prev`: Only available when model history exists. 58 | * `cond_prev`: Only available when model history exists. 59 | * `uncond_prev`: Only available when model history exists. 60 | 61 | ### Model Filter Variables 62 | 63 | * `model_call`: Only applicable to `model` filters, will be the model call index. I.E. if the sampler calls the model three times, the filter would be called with model call indexes `0`, `1` and `2`. 64 | 65 | Available in model filters, with the exception of the `input` filter. 66 | 67 | * `denoised_curr` 68 | * `cond_curr` 69 | * `uncond_curr` 70 | 71 | ## Basic Expression Functions 72 | 73 | | | Name | Input | Output | 74 | | :--- | :--- | :--- | :--- | 75 | |⬤| `all` | `B`\* | `B` | 76 | | Evaluates to true if all its arguments evaluate to true.
**Example:** `all(x > 1, y < 1)` | 77 | |⬤| `any` | `B`\* | `B` | 78 | | Evaluates to true if any of its arguments evaluate to true.
**Example:** `any(x > 1, y < 1)` | 79 | |⬤| `between` | value:`N`, from:`N`, to:`N` | `B` | 80 | | Boolean range checking.
**Example:** `between(value, low, high)` | 81 | |⬤| `comment` | `*` | `null` | 82 | | Ignores any arguments passed to it (they won't be evaluated at all but must parse as a valid expression) and returns `None` | 83 | |⬤| `dict` | `*`* | `dict` | 84 | | Constructs a dictionary from its keyword arguments. _Note_: You may not pass positional arguments.
**Example:** `dict(key1 :> value1, keyN :> valueN)` | 85 | |⬤| `get` | name:`SY`, fallback:`*` | `*` | 86 | | Returns a variable if set, otherwise the fallback.
**Example:** `get('somevar, 123)` | 87 | |⬤| `if` | condition:`B`, then:`*`, else:`*` | `*` | 88 | | Conditional expressions.
**Example:** `if(condition, true_expression, false_expression)` | 89 | |⬤| `index` | index:`IDX`, value:`S \| T` | `*` | 90 | | Index function. | 91 | |⬤| `is_set` | name:`SY` | `B` | 92 | | Tests whether a variable is set. | 93 | |⬤| `max` | values:`SN` | `N` | 94 | | Maximum operation. _Note_: Takes one sequence argument.
**Example:** `max((1, 2, 3))` | 95 | |⬤| `min` | values: `SN` | `N` | 96 | | Minimum operation. _Note_: Takes one sequence argument.
**Example:** `min((1, 2, 3))` | 97 | |⬤| `mod` | lhs:`N`, rhs:`N` | `N` | 98 | | Modulus operation:
**Example:** `mod(5, 2)` | 99 | |⬤| `neg` | `N` | `N` | 100 | | Negation.
**Example:** `neg(2)` | 101 | |⬤| `not` | `B` | `B` | 102 | | Boolean negation | 103 | |⬤| `s_` | start:`I(null)`, end:`I(null)`, step:`I(null)` | `slice` | 104 | | Creates a slice object from the `start`, `end`, `step` values. See NumPy [s_](https://numpy.org/doc/stable/reference/generated/numpy.s_.html) | 105 | |⬤| `set_var` | `SY`, `*` | `*` | 106 | | Sets a temporary variable to the specified value and returns the value. Alias for the `:=` assignment operator.
**Example**: `test1 := 2; set_var('test2, 10); test1 * test2` | 107 | |⬤| `unsafe_call` | `callable`, `*`\* | `*` | 108 | | Allows calling an arbitrary callable.
**Example:** `unsafe_call(some_callable, arg1, arg2, kwarg1 :> 123)` | 109 | 110 | **Legend**: `B`=boolean, `N`=numeric, `NS`=scalar numeric, `I`=integer, `F`=float, `T`=tensor, `S`=sequence, `SN`=numeric sequence, `SY`=symbol, `*`=any -- parenthesized values indicate argument defaults. `*` following the type indicates variable length arguments. For functions that take keyword arguments, the type will be written like "_name: `TYPE(default_value)`_". 111 | 112 | ## Tensor Expression Functions 113 | 114 | *Tensor dimensions hint*: Most tensors you'll be dealing with are laid out as `batch`, `channels`, `height`, `width`. Negative indexes start from the end, so dimension `-1` would mean _width_ just the same as `3`. 115 | 116 | | | Name | Input | Output | 117 | | :--- | :--- | :--- | :--- | 118 | |⬤| `t_bleh_enhance` | tensor:`T`, mode:`SY`, scale:`N(1.0)` | `T` | 119 | | Available if you have the [ComfyUI-bleh](https://github.com/blepping/ComfyUI-bleh) node pack installed. See [Filtering](filter.md#bleh_enhance).
**Example:** `t_bleh_enhance(some_tensor, 'bandpass, 0.5)` | 120 | |⬤| `t_blend` | tensor1:`T`, tensor2:`T`, scale:`N(0.5)`, mode:`SY(lerp)` | `T` | 121 | | Tensor blend operation.
**Example:** `t_blend(t1, t2, 0.75, 'lerp)` | 122 | |⬤| `t_contrast_adaptive_sharpening` | tensor:`T`, scale:`N(0.5)` | `T` | 123 | | Contrast adaptive sharpening. _Note_: Not recommended to call on noisy tensors (so `denoised` but probably not `x`).
**Example:** `t_contrast_adaptive_sharpening(some_tensor, 0.1)` | 124 | |⬤| `t_flip` | tensor:`T`, dim:`NS`, mirror:`B(false)` | `T` | 125 | | Flips a tensor on the specified dimension. If the third argument is true, it will be mirrored around the center in that dimension.
**Example:** `t_flip(some_tensor, -1, true)` | 126 | |⬤| `t_mean` | tensor:`T`, dim:`SN(-3, -2, -1)` | `T` | 127 | | Tensor mean, second argument is dimensions.
**Example:** `t_mean(some_tensor, (-2, -1))` | 128 | |⬤| `t_noise` | tensor:`T`, type:`SY(gaussian)` | `T` | 129 | | Generates un-normalized noise (use `t_norm` if you want to normalize it). If you have ComfyUI-sonar you can use any noise type it supports, otherwise only `gaussian`. The generated noise will have the same shape as the supplied tensor (hopefully, may not be true for every exotic noise type but at least should be broadcastable to the tensor).
Example: `t_noise(some_tensor, 'pyramid)` | 130 | |⬤| `t_norm` | tensor:`T`, factor:`N(1.0)`, dim:`SN(-3, -2, -1)` | `T` | 131 | | Tensor normalization (subtracts mean, divides by std).
**Example:** `t_norm(some_tensor, 1.0, (-2, -1))` | 132 | |⬤| `t_sonar_power_filter` | tensor:`T`, filter:`dict` | `T` | 133 | | Available if you have [ComfyUI-sonar](https://github.com/blepping/ComfyUI-sonar) installed. See [Filtering](filter.md#sonar_power_filter). Constructs a power filter from a dictionary argument. _Note_: May be slow as the filter is reconstructed on every evaluation.
**Example:** `t_sonar_power_filter(some_tensor, dict(alpha :> 0.1, min_freq :> 0.2, max_freq :> 0.6))` | 134 | |⬤| `t_roll` | tensor:`T`, amount:`NS(0.5)`, dim:`SN((-2,))` | `T` | 135 | | Rolls a tensor along the specified dimensions. If amount is >= -1.0 and < 1.0 this will be interpreted as a percentage.
**Example:** `t_roll(some_tensor, 10, (-2,))` | 136 | |⬤| `t_scale` | tensor:`T`, scale:`SN \| NS`, mode:`SY(bicubic)`, absolute_scale:`B(false)` | `T` | 137 | | Scales a tensor. If scale is a tuple, it will be interpreted as `(height, width)`. When `absolute_scale` is not set, the scales will be interpreted as percentages otherwise absolute values will be used.
Example: `t_scale(some_tensor, (0.75, 0.5), 'bilinear)` | 138 | |⬤| `t_scale_nnlatentupscale` | tensor:`T`, mode:`SY(sd1)`, scale:`SN(2.0)` | `T` | 139 | | Available if you have [ComfyUi_NNLatentUpscale](https://github.com/Ttl/ComfyUi_NNLatentUpscale) installed. `mode` must be one of `sd1` or `sdxl`. `scale` should be between 1.0 and 2.0 (may or may not work out of that range).
**Example:** `t_scale_nnlatentupscale(some_tensor, 'sdxl, 1.5)` | 140 | |⬤| `t_shape` | tensor:`T` | `SN` | 141 | | Returns a tensor's shape as a tuple.
Example: `shp := t_shape(some_tensor); width := shp[-1]; height := shp[-2]` | 142 | |⬤| `t_std` | tensor:`T`, dim:`SN(-3, -2, -1)` | `T` | 143 | | Tensor std, second argument is dimensions.
**Example:** `t_std(some_tensor, (-2, -1))` | 144 | |⬤| `t_taesd_decode` | tensor:`T`, mode:`SY(sd15)` | `T` | 145 | | Decodes a latent tensor using TAESD. Mode must be one of `sd15`, `sdxl`. Only works if the appropriate models are in `vae_approx`.
**Example:** `t_taesd_decode(some_tensor, 'sd15)` | 146 | |⬤| `unsafe_tensor_method` | `T`, `SY`, `*`\* | `*` | 147 | | Unsafe tensor method call. See note below.
**Example:** `unsafe_tensor_method(some_tensor, 'mul, 10)` | 148 | |⬤| `unsafe_torch` | path:`SY` | `*` | 149 | | Unsafe Torch module attribute access. See note below.
**Example:** `unsafe_torch('nn.functional.interpolate)` | 150 | 151 | **Note on `unsafe_tensor_method` and `unsafe_torch`**: These functions are disabled by default. If the environment variable `COMFYUI_OCS_ALLOW_UNSAFE_EXPRESSIONS` is set to anything then you can use `unsafe_tensor_method` with a whitelisted set of methods (best effort to avoid anything actually unsafe). If the environment variable `COMFYUI_OCS_ALLOW_ALL_UNSAFE` is set to anything then `unsafe_torch` is enabled and `unsafe_tensor_method` will allow calling any method. ***WARNING***: Allowing _all_ unsafe with workflows you don't trust is _not_ recommended and a malicious workflow will likely have access to anything ComfyUI can access. It is effectively the same as letting the workflow run an arbitrary script on your system. 152 | 153 | ## Image Expression Functions 154 | 155 | `IMG` is used here to denote the type for functions that take an image. This may actually be an image batch rather than a single image. 156 | 157 | | | Name | Input | Output | 158 | | :--- | :--- | :--- | :--- | 159 | |⬤| `img_pil_resize` | image:`IMG`, size:`SN \| NS`, resample_mode:`SY(bicubic)`, absolute_scale:`B(false)` | `IMG` | 160 | | Scales an image batch using Pillow's [`Image.resize`](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.resize) function (follow link for information about resample modes, etc). If `size` is a tuple, it will be interpreted as `(height, width)`. When `absolute_scale` is not set, the scales will be interpreted as percentages otherwise absolute values will be used.
Example: `img_pil_resize(image_batch, (0.75, 0.5), 'lanczos)` | 161 | |⬤| `img_shape` | image:`IMG` | `SN` | 162 | | Returns an image's shape as a tuple. Will fail if all the images in the batch aren't the same size.
Example: `shp := img_shape(image_batch); width := shp[-1]; height := shp[-2]` | 163 | |⬤| `img_taesd_encode` | image:`IMG`, reference_latent: `T`, mode:`SY(sd15)` | `T` | 164 | | Encodes an image batch into a latent tensor. The reference latent is only used to determine what device and type the output should be. Mode must be one of `sd15`, `sdxl`. Only works if the appropriate models are in `vae_approx`.
**Example:** `img_taesd_encode(image_batch, some_tensor, 'sd15)` | 165 | -------------------------------------------------------------------------------- /docs/filter.md: -------------------------------------------------------------------------------- 1 | # OCS Filters 2 | 3 | Filters allow changing sampler inputs/outputs, model input/outputs, generated noise and so on. They are 4 | configured by the advanced YAML/JSON parameter block in the node. 5 | 6 | ## Filter Support 7 | 8 | ### `OCS Substeps` 9 | 10 | Set via the `pre_filter` and `post_filter` keys. 11 | 12 | ### `OCS Group` 13 | 14 | Set via the `pre_filter` and `post_filter` keys. 15 | 16 | *Note*: Since the group pre-filter may be called before any model calls, variables like `denoised` may not be available 17 | in expressions. 18 | 19 | ### `OCS Sampler` 20 | 21 | **Noise** 22 | 23 | ```yaml 24 | noise: 25 | # Or set to a valid filter definition. 26 | filter: null 27 | ``` 28 | 29 | **Model** 30 | 31 | *Note*: Since the model filters may be called before any other model calls, variables like `denoised` may not be available 32 | in expressions. With the exception of the `input` filter you will have access to `denoised_curr`, `cond_curr`, etc. 33 | See [Expressions](expression.md#model-filter-variables). 34 | 35 | ```yaml 36 | model: 37 | filter: 38 | # Applies to the input passed to the model. 39 | input: null 40 | 41 | # Applies to denoised output. 42 | denoised: null 43 | 44 | # Applies to JVP denoised output (only used by TTM sampler) 45 | jdenoised: null 46 | 47 | # Applies to cond output. 48 | cond: null 49 | 50 | # Applies to uncond output. 51 | uncond: null 52 | 53 | ``` 54 | 55 | ### Immiscible 56 | 57 | `immiscible` is a special type of filter: in this case, you do not set `filter_type`. Normal filter keys 58 | apply in the places where `immiscible` can be set. 59 | 60 | ## Filter Definitions 61 | 62 | For information about expressions, see [Expressions](expression.md). 
63 | 64 | A basic filter supports these keys: 65 | 66 | ```yaml 67 | enabled: true 68 | 69 | filter_type: simple 70 | 71 | # Expression that is evaluated to determine whether the filter applies. May be null. 72 | # If set, should evaluate to a boolean. 73 | when: null 74 | 75 | blend_mode: lerp 76 | 77 | # Blend strength applied to output. 78 | strength: 1.0 79 | 80 | # Input expression. input, ref, output and final should evaluate to a tensor. 81 | input: default 82 | 83 | # Reference expression (only used for immiscible noise currently). 84 | ref: default 85 | 86 | # Output expression. 87 | output: default 88 | 89 | # Final expression - occurs *after* blending. 90 | final: default 91 | ``` 92 | 93 | There may be additional keys depending on the filter type. 94 | 95 | If you have [ComfyUI-bleh](https://github.com/blepping/ComfyUI-bleh) available, you can use any blend mode it supports. Otherwise OCS provides these built-in blend modes: `lerp`, `a_only`, `b_only`. _Note_: `a` is considered the original value, `b` the changed value. `a_only` and `b_only` will still scale their output by the `strength`. 96 | 97 | ## Filter Types 98 | 99 | ### `simple` 100 | 101 | Base filter, no special behavior. No additional parameters. 102 | 103 | ### `blend` 104 | 105 | Blends the result of two other filters. 106 | 107 | Keys: 108 | 109 | ```yaml 110 | # No default, value for example purpose only. 111 | filter1: 112 | filter_type: simple 113 | 114 | # No default, value for example purpose only. 115 | filter2: 116 | filter_type: simple 117 | ``` 118 | 119 | ### `list` 120 | 121 | A list of filters. The output of the previous is given to the next as input. 122 | The `list` filter's blend applies to the output from the final filter in the list. 123 | 124 | Keys: 125 | 126 | ```yaml 127 | # Values for example only, default is an empty list. 
128 | filters: 129 | - filter_type: simple 130 | strength: 1.0 131 | - filter_type: simple 132 | strength: 1.0 133 | ``` 134 | 135 | ### `bleh_enhance` 136 | 137 | Available if you have the [ComfyUI-bleh](https://github.com/blepping/ComfyUI-bleh) node pack installed. See: 138 | https://github.com/blepping/ComfyUI-bleh#enhancement-types 139 | 140 | Keys: 141 | 142 | ```yaml 143 | enhance_mode: null 144 | enhance_scale: 1.0 145 | ``` 146 | 147 | ### `bleh_ops` 148 | 149 | Available if you have the [ComfyUI-bleh](https://github.com/blepping/ComfyUI-bleh) node pack installed. See: 150 | https://github.com/blepping/ComfyUI-bleh#blehblockops 151 | 152 | Keys: 153 | 154 | ```yaml 155 | # May be specified as a string containing the YAML rule definitions or inline. 156 | # Values for example only, default is an empty list of ops. 157 | ops: 158 | - if: 159 | to_percent: 0.5 160 | ops: # Not recommended to actually do this. 161 | - [flip, { direction: h }] 162 | - [roll, { direction: channels, amount: -2 }] 163 | ``` 164 | 165 | ### `sonar_power_filter` 166 | 167 | Available if you have [ComfyUI-sonar](https://github.com/blepping/ComfyUI-sonar) installed. See: 168 | https://github.com/blepping/ComfyUI-sonar/blob/main/docs/advanced_power_noise.md 169 | 170 | Keys: 171 | 172 | ```yaml 173 | power_filter: 174 | mix: 1.0 175 | normalization_factor: 1.0 176 | common_mode: 0.0 177 | channel_correlation: "1,1,1,1,1,1" 178 | alpha: 0.0 179 | min_freq: 0.0 180 | max_freq: 0.7071 181 | stretch: 1.0 182 | rotate: 0.0 183 | pnorm: 2.0 184 | scale: 1.0 185 | compose_mode: max 186 | 187 | # If specified should be another power filter definition. 
188 | compose_with: null 189 | ``` 190 | -------------------------------------------------------------------------------- /py/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blepping/comfyui_overly_complicated_sampling/716b0d749ad64be71d9c535a03c4331e75272d92/py/__init__.py -------------------------------------------------------------------------------- /py/custom_noise/__init__.py: -------------------------------------------------------------------------------- 1 | from . import noise_perlin 2 | from . import nodes 3 | 4 | NODE_CLASS_MAPPINGS = { 5 | "OCSNoise PerlinSimple": noise_perlin.PerlinSimpleNode, 6 | "OCSNoise PerlinAdvanced": noise_perlin.PerlinAdvancedNode, 7 | "OCSNoise to SONAR_CUSTOM_NOISE": nodes.ToSonarNode, 8 | } 9 | -------------------------------------------------------------------------------- /py/custom_noise/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | 4 | from typing import Callable, Any 5 | 6 | from ..noise import scale_noise 7 | 8 | 9 | class CustomNoiseItemBase(abc.ABC): 10 | def __init__(self, factor, **kwargs): 11 | self.factor = factor 12 | self.keys = set(kwargs.keys()) 13 | for k, v in kwargs.items(): 14 | setattr(self, k, v) 15 | 16 | def clone_key(self, k): 17 | return getattr(self, k) 18 | 19 | def clone(self): 20 | return self.__class__(self.factor, **{k: self.clone_key(k) for k in self.keys}) 21 | 22 | def set_factor(self, factor): 23 | self.factor = factor 24 | return self 25 | 26 | def get_normalize(self, k, default=None): 27 | val = getattr(self, k, None) 28 | return default if val is None else val 29 | 30 | @abc.abstractmethod 31 | def make_noise_sampler( 32 | self, 33 | x: torch.Tensor, 34 | sigma_min=None, 35 | sigma_max=None, 36 | seed=None, 37 | cpu=True, 38 | normalized=True, 39 | ): 40 | raise NotImplementedError 41 | 42 | 43 | class CustomNoiseChain: 44 | def 
__init__(self, items=None): 45 | self.items = items if items is not None else [] 46 | 47 | def clone(self): 48 | return CustomNoiseChain( 49 | [i.clone() for i in self.items], 50 | ) 51 | 52 | def add(self, item): 53 | if item is None: 54 | raise ValueError("Attempt to add nil item") 55 | self.items.append(item) 56 | 57 | @property 58 | def factor(self): 59 | return sum(abs(i.factor) for i in self.items) 60 | 61 | def rescaled(self, scale=1.0): 62 | divisor = self.factor / scale 63 | divisor = divisor if divisor != 0 else 1.0 64 | result = self.clone() 65 | if divisor != 1: 66 | for i in result.items: 67 | i.set_factor(i.factor / divisor) 68 | return result 69 | 70 | @torch.no_grad() 71 | def make_noise_sampler( 72 | self, 73 | x: torch.Tensor, 74 | sigma_min=None, 75 | sigma_max=None, 76 | seed=None, 77 | cpu=True, 78 | normalized=True, 79 | ) -> Callable: 80 | noise_samplers = tuple( 81 | i.make_noise_sampler( 82 | x, 83 | sigma_min, 84 | sigma_max, 85 | seed=seed, 86 | cpu=cpu, 87 | normalized=False, 88 | ) 89 | for i in self.items 90 | ) 91 | if not noise_samplers or not all(noise_samplers): 92 | raise ValueError("Failed to get noise sampler") 93 | factor = self.factor 94 | 95 | def noise_sampler(sigma, sigma_next): 96 | result = None 97 | for ns in noise_samplers: 98 | noise = ns(sigma, sigma_next) 99 | if result is None: 100 | result = noise 101 | else: 102 | result += noise 103 | return scale_noise(result, factor, normalized=normalized) 104 | 105 | return noise_sampler 106 | 107 | 108 | class CustomNoiseNodeBase(abc.ABC): 109 | DESCRIPTION = "An Overly Complicated Sampling custom noise item." 
110 | RETURN_TYPES = ("OCS_NOISE",) 111 | OUTPUT_TOOLTIPS = ("A custom noise chain.",) 112 | CATEGORY = "OveryComplicatedSampling/noise" 113 | FUNCTION = "go" 114 | 115 | @abc.abstractmethod 116 | def get_item_class(self): 117 | raise NotImplementedError 118 | 119 | @classmethod 120 | def INPUT_TYPES(cls, *, include_rescale=True, include_chain=True): 121 | result = { 122 | "required": { 123 | "factor": ( 124 | "FLOAT", 125 | { 126 | "default": 1.0, 127 | "min": -100.0, 128 | "max": 100.0, 129 | "step": 0.001, 130 | "round": False, 131 | "tooltip": "Scaling factor for the generated noise of this type.", 132 | }, 133 | ), 134 | }, 135 | "optional": {}, 136 | } 137 | if include_rescale: 138 | result["required"] |= { 139 | "rescale": ( 140 | "FLOAT", 141 | { 142 | "default": 0.0, 143 | "min": 0.0, 144 | "max": 100.0, 145 | "step": 0.001, 146 | "round": False, 147 | "tooltip": "When non-zero, this custom noise item and other custom noise items items connected to it will have their factor scaled to add up to the specified rescale value.", 148 | }, 149 | ), 150 | } 151 | if include_chain: 152 | result["optional"] |= { 153 | "ocs_noise_opt": ( 154 | "OCS_NOISE", 155 | { 156 | "tooltip": "Optional input for more custom noise items.", 157 | }, 158 | ), 159 | } 160 | return result 161 | 162 | def go( 163 | self, 164 | factor=1.0, 165 | rescale=0.0, 166 | ocs_noise_opt=None, 167 | **kwargs: dict[str, Any], 168 | ): 169 | nis = ocs_noise_opt.clone() if ocs_noise_opt else CustomNoiseChain() 170 | if factor != 0: 171 | nis.add(self.get_item_class()(factor, **kwargs)) 172 | return (nis if rescale == 0 else nis.rescaled(rescale),) 173 | 174 | 175 | class NormalizeNoiseNodeMixin: 176 | @staticmethod 177 | def get_normalize(val: str) -> None | bool: 178 | return None if val == "default" else val == "forced" 179 | -------------------------------------------------------------------------------- /py/custom_noise/nodes.py: 
class Expression:
    """A parsed expression that can be evaluated against a handler context.

    The constructor accepts either source text (which is tokenized first) or
    an iterable of already-prepared tokens, and stores the parse result in
    ``self.expr``.
    """

    # Tokenizer pattern; each match's group(1) is one raw token. Compiled with
    # re.X, so whitespace and ``#`` comments inside the pattern are ignored.
    EXPR_RE = re.compile(
        r"""
        \s*
        (
            \d+ # Numeric literal
            (?: \. \d* )? # Floating point
            (?: e [+-]? \d+)? # Scientific notation (sign optional)
            | (?: \*\* | // ) # Doubled operators
            | [<>]=? # Relative comparison
            | [!=]= # Equality
            | (?: \|\| | && ) # Logic
            | [-+*/|!(),] # Operators
            | :> # Key value binop
            | := # Assignment
            | ; # Sequencing
            | [?:] # Ternary
            | \[ | ] # Index
            | \.\.\. # Index ellipsis
            | '[-\w.]+ # Symbol
            | `?[a-z][\w.]*`? # Function/variable names
        )
        \s*
        """,
        re.I | re.S | re.X | re.A,
    )

    # Lower-cased literal tokens that map directly to Python constants.
    FIXUP = {"true": True, "false": False, "...": Ellipsis, "none": None}

    def __init__(self, toks):
        if isinstance(toks, str):
            toks = tuple(self.tokenize(toks))
        self.expr = Parser(ExprParserSpec(), iter(toks)).go()

    def __repr__(self):
        # Previously returned an empty string, which made instances
        # indistinguishable in logs and tracebacks.
        return f"<Expression: {self.expr!r}>"

    def __call__(self, *args, **kwargs):
        return self.eval(*args, **kwargs)

    def eval(self, handlers, *args, **kwargs):
        """Evaluate the parsed expression.

        A parse result that is not an ``ExpBase`` (for example a folded
        constant) is returned unchanged; anything else is evaluated with the
        supplied handler context and extra arguments.
        """
        # NOTE(review): this prints on every evaluation of a non-"default"
        # expression; it looks like leftover debug output - confirm before
        # removing or routing through logging.
        if self.expr != ExpOp("default"):
            print("\nEVAL", self.expr)
        if not isinstance(self.expr, ExpBase):
            return self.expr
        return self.expr.eval(handlers, *args, **kwargs)

    def __len__(self):
        return len(self.expr)

    def pretty_string(self, depth=0):
        """Return an indented, multi-line rendering of the parsed expression."""
        sval = (
            repr(self.expr)
            if not isinstance(self.expr, ExpBase)
            else self.expr.pretty_string(depth=depth + 1)
        )
        pad = " " * (depth + 1) * 2
        # Previously returned an empty string and discarded sval/pad.
        return f"<Expression:\n{pad}{sval}\n{pad[:-2]}>"

    @classmethod
    def fixup_token(cls, t):
        """Convert one raw token string into its parsed token value.

        Tokens are lower-cased first. Literal names in ``FIXUP`` become Python
        constants, backtick-quoted names become ``ExpBinOp``, ``'name``
        becomes ``ExpSym``, numeric text becomes ``int`` or ``float``, and
        everything else becomes ``ExpOp``. The empty string is returned as is.
        """
        if t == "":
            return t
        t = t.lower()
        if t in cls.FIXUP:
            return cls.FIXUP[t]
        first = t[0]
        if first == "`":
            return ExpBinOp(t.strip("`"))
        if first == "'":
            return ExpSym(t[1:])
        if first.isdigit() or (len(t) > 1 and first == "-" and t[1].isdigit()):
            # A numeric token with an exponent but no decimal point (such as
            # "1e+5") is still a float; int() would raise ValueError on it.
            return float(t) if "." in t or "e" in t else int(t)
        return ExpOp(t)

    @classmethod
    def tokenize(cls, s):
        """Yield the fixed-up token for every ``EXPR_RE`` match in ``s``."""
        yield from (cls.fixup_token(m.group(1)) for m in cls.EXPR_RE.finditer(s))
@staticmethod
def left_comma(p, token, left, bp):
    """Infix handler for ``,``: build or extend a tuple.

    The left operand is treated as a one-element tuple unless it is already
    an ``ExpTuple``. When the comma is directly followed by ``)`` it is a
    trailing comma and the tuple is returned as is; otherwise the next
    expression is parsed at binding power ``bp`` and appended.
    """
    items = left if isinstance(left, ExpTuple) else ExpTuple((left,))
    if p.token == ")":
        return items
    return ExpTuple((*items, p.parse_until(bp)))
@staticmethod
def get_type(token):
    """Classify a token for parser table lookup.

    Returns ``"number"``, ``"sym"``, ``"binop"`` or ``"op"`` for the
    corresponding token kinds; any other token is returned unchanged and
    serves as its own lookup key.
    """
    # Order matters: ExpBinOp must be tested before the more general ExpOp.
    for kind, label in (
        ((int, float), "number"),
        (ExpSym, "sym"),
        (ExpBinOp, "binop"),
    ):
        if isinstance(token, kind):
            return label
    if isinstance(token, ExpOp) and token[0].isalpha():
        return "op"
    return token
class BaseHandler:
    """Base class for expression function/operator handlers.

    Subclasses declare ``input_validators`` (a sequence of objects with a
    ``name`` attribute, callable as ``validator(key, value)``) and implement
    ``handle``. Calling the handler runs ``handle`` then ``validate_output``
    and wraps any exception in ``HandlerError``.
    """

    input_validators = ()

    def __init__(self):
        # Map each validator name to (positional index, validator) so keyword
        # lookups can find the matching positional slot.
        self.input_validators_by_key = {
            v.name: (idx, v) for idx, v in enumerate(self.input_validators)
        }

    def __call__(self, obj, *, getter):
        try:
            val = self.handle(obj, getter)
            return self.validate_output(obj, val)
        except Exception as exc:
            raise HandlerError(f'Error evaluating "{obj.name}":\n {exc!r}') from exc

    def safe_get(self, key, obj, getter=None, *, default=Empty):
        """Fetch and validate one argument of the function application ``obj``.

        ``key`` is either a validator name (string) or a positional index.
        A named argument is read from its positional slot when that slot was
        supplied, otherwise from ``obj.kwargs``. When ``getter`` is given it
        is used to fetch the value; otherwise the raw entry is read.

        If the argument was not supplied and a default is available (passed
        explicitly or declared on the validator), the default is returned
        without validation. This mirrors the validator's own "missing value
        yields default" rule and lets defaults such as ``None`` be used with
        type-checking validators.

        Raises ``ValidateError`` when a positional index is out of range for
        the supplied arguments or when validation fails.
        """
        str_key = isinstance(key, str)
        if str_key:
            argidx, validator = self.input_validators_by_key.get(key, (-1, None))
        else:
            argidx = key
            validator = (
                self.input_validators[key]
                if key < len(self.input_validators)
                else None
            )
        if default is Empty and validator is not None:
            default = getattr(validator, "default", Empty)
        if 0 <= argidx < len(obj.args):
            eff_key, str_eff_key = argidx, False
        elif str_key:
            eff_key, str_eff_key = key, True
        else:
            raise ValidateError(
                f"Error validating input argument {key} for {obj.name}, out of range for actual function arguments"
            )
        if str_eff_key and default is not Empty and eff_key not in obj.kwargs:
            # Argument omitted entirely: fall back to the default as is.
            return default
        if getter is not None:
            val = getter(eff_key, default=default)
        elif str_eff_key:
            val = obj.kwargs.get(eff_key)
        else:
            # eff_key is known to be a valid index here (checked above).
            val = obj.args[eff_key]
        if validator is None:
            return val
        try:
            return validator(key, val)
        except ValidateError as exc:
            raise ValidateError(
                f"Error validating input argument {key} for {obj.name}, type {type(val)}: {exc!r}"
            ) from exc

    def safe_get_multi(self, keys, obj, getter=None, *, default=Empty):
        """Return a generator of ``safe_get`` results for each key in ``keys``."""
        return (self.safe_get(k, obj, getter, default=default) for k in keys)

    def safe_get_all(self, obj, getter=None, *, default=Empty):
        """Return a generator of every declared argument, in validator order."""
        return self.safe_get_multi(
            (v.name for v in self.input_validators), obj, getter, default=default
        )

    def handle(self, obj, getter):
        """Compute the handler result; must be implemented by subclasses."""
        raise NotImplementedError

    def validate_output(self, obj, value):
        """Hook to check or convert the result of ``handle``; identity here."""
        return value
class AnyHandler(BinopLogicHandler):
    """Logical OR over every positional and keyword argument."""

    # Variadic: clear the inherited lhs/rhs validators (as AllHandler does).
    # With them in place, a keyword argument named "lhs" or "rhs" would be
    # resolved to the positional slot of the same index instead of the keyword
    # value actually passed.
    input_validators = ()

    def handle(self, obj, getter):
        return any(
            operator.truth(self.safe_get(idx, obj, getter=getter))
            for idx in range(len(obj.args))
        ) or any(
            operator.truth(self.safe_get(key, obj, getter=getter))
            for key in obj.kwargs
        )
class MinusHandler(SimpleMathHandler):
    """Handler for ``-``: binary subtraction, or negation with one operand."""

    input_validators = (Arg.numeric("lhs"), Arg.numeric("rhs", default=Empty))

    # No wrapped operator function is needed, so skip SimpleMathHandler's
    # constructor (which requires one).
    __init__ = BaseHandler.__init__

    def handle(self, obj, getter):
        lhs = self.safe_get("lhs", obj, getter)
        # Decide unary vs binary from the arguments actually supplied.
        # Fetching a missing "rhs" raises rather than yielding Empty, so the
        # previous "rhs is Empty" check could never be reached.
        if len(obj.args) < 2 and "rhs" not in obj.kwargs:
            return operator.neg(lhs)
        return operator.sub(lhs, self.safe_get("rhs", obj, getter))
class MinHandler(BaseHandler):
    """Return the smallest item of a sequence of numeric scalars."""

    input_validators = (Arg.numscalar_sequence("values"),)

    def handle(self, obj, getter):
        # Pass the sequence itself: unpacking it (min(*values)) fails with
        # TypeError when the sequence has exactly one element.
        return min(self.safe_get("values", obj, getter))

    def validate_output(self, obj, value):
        return ValidateArg.validate_numeric(-1, value)


class MaxHandler(MinHandler):
    """Return the largest item of a sequence of numeric scalars."""

    def handle(self, obj, getter):
        # See MinHandler.handle for why the sequence is not unpacked.
        return max(self.safe_get("values", obj, getter))
class DictHandler(BaseHandler):
    """Build an ``ExpDict`` from the keyword arguments of the call."""

    def handle(self, obj, getter):
        # Positional (non key/value) arguments are rejected. The length is
        # checked explicitly rather than relying on truthiness of obj.args.
        if len(obj.args) > 0:
            raise ValueError("Non-KV items passed to dict constructor")
        evaluated = {}
        for name in obj.kwargs:
            evaluated[name] = self.safe_get(name, obj, getter)
        return ExpDict(evaluated)
class ParseError(Exception):
    """Raised when a token stream cannot be parsed."""


# Pratt parsing referenced from https://github.com/andychu/pratt-parsing-demo
class ParserSpec:
    """Lookup tables mapping token types to prefix (nud) and infix (led) rules."""

    @staticmethod
    def null_error(p, token, bp):
        """Prefix rule for tokens that may not start an expression."""
        raise ParseError(f"{token!r} cannot be used in prefix position")

    @staticmethod
    def left_error(p, token, *args):
        """Infix rule for tokens that may not follow an expression.

        Infix rules are invoked as ``led(parser, token, left, rbp)``; the
        trailing arguments are accepted variadically so this works both with
        that call shape and with the older three-argument form.
        """
        raise ParseError(f"{token!r} cannot be used in infix position")

    class LeftInfo:
        """Infix rule plus its left and right binding powers."""

        def __init__(self, led=None, lbp=0, rbp=0):
            self.led, self.lbp, self.rbp = led or ParserSpec.left_error, lbp, rbp

    class NullInfo:
        """Prefix rule plus its binding power."""

        def __init__(self, nud=None, bp=0):
            self.nud, self.bp = nud or ParserSpec.null_error, bp

    def __init__(self):
        self.null_lookup = {}
        self.left_lookup = {}

    def add_null(self, bp, nud, tokens):
        """Register prefix rule ``nud`` for each token type in ``tokens``."""
        for token in tokens:
            self.null_lookup[token] = self.NullInfo(nud, bp)
            # Make sure the token is known in infix position too (as an error
            # rule with binding power 0) so lookups do not fail.
            if token not in self.left_lookup:
                self.left_lookup[token] = self.LeftInfo()

    def add_led(self, lbp, rbp, led, tokens):
        """Register infix rule ``led`` with explicit left/right binding powers."""
        for token in tokens:
            self.left_lookup[token] = self.LeftInfo(led, lbp, rbp)
            if token not in self.null_lookup:
                self.null_lookup[token] = self.NullInfo(self.null_error)

    def add_left(self, bp, led, tokens):
        """Register a left-associative infix rule."""
        return self.add_led(bp, bp, led, tokens)

    def add_leftright(self, bp, led, tokens):
        """Register a right-associative infix rule (right power is ``bp - 1``)."""
        return self.add_led(bp, bp - 1, led, tokens)

    def lookup(self, token, is_left):
        """Return the infix (``is_left``) or prefix info for a token type.

        Raises ``ParseError`` when the token type was never registered.
        """
        result = (self.left_lookup if is_left else self.null_lookup).get(token)
        if result is None:
            raise ParseError(f"Unexpected token {token!r}")
        return result

    @staticmethod
    def get_type(token):
        """Map a token to its lookup key: ``"number"``, ``"op"`` or itself."""
        if isinstance(token, (int, float)):
            return "number"
        if isinstance(token, str) and token.isidentifier():
            return "op"
        return token


class Parser:
    """Pratt parser driven by a ``ParserSpec`` over an iterator of tokens.

    ``self.lexer`` is set to ``None`` once the iterator is exhausted; that is
    the end-of-input signal used throughout.
    """

    def __init__(self, spec, lexer):
        self.spec = spec
        self.lexer = lexer
        self.token = None
        self.token_type = None
        self.pos = -1

    def advance(self):
        """Move to the next token and return it (``None`` at end of input)."""
        if self.lexer is None:
            self.token_type = self.token = None
            return None
        try:
            self.token = next(self.lexer)
            self.token_type = self.spec.get_type(self.token)
            self.pos += 1
        except StopIteration:
            self.token = self.token_type = self.lexer = None
        return self.token

    def expect(self, val):
        """Require the current token to equal ``val``, then advance past it."""
        if val is not None and (self.lexer is None or self.token != val):
            raise ParseError(f"expected {val!r}, got {self.token!r}")
        return self.advance()

    def parse_until(self, rbp):
        """Parse one expression, consuming infix tokens that bind tighter than ``rbp``."""
        if self.lexer is None:
            raise ParseError("unexpected end of input")
        spec = self.spec
        token, token_type = self.token, self.token_type
        self.advance()
        ni = spec.lookup(token_type, False)
        node = ni.nud(self, token, ni.bp)
        while self.lexer is not None:
            token, token_type = self.token, self.token_type
            li = spec.lookup(token_type, True)
            if rbp >= li.lbp:
                break
            self.advance()
            node = li.led(self, token, node, li.rbp)
        return node

    def go(self):
        """Parse the whole token stream and return the resulting node.

        Raises ``ParseError`` (annotated with the token position) on any
        parse failure, including tokens left over after a complete
        expression.
        """
        self.advance()
        try:
            result = self.parse_until(0)
        except ParseError as exc:
            raise ParseError(
                f"pos {self.pos} at token {self.token!r}: parse error: {exc}"
            ) from None
        if self.lexer is not None:
            # Input remains after a complete expression, so the problem is an
            # extra token, not missing input.
            raise ParseError(
                f"pos {self.pos}: unexpected token {self.token!r} after complete expression"
            )
        return result
class Empty:
    """Sentinel type meaning "no value".

    Code in this module compares against the class object itself
    (``value is Empty``) rather than an instance.
    """

    def __bool__(self):
        return False


class ExpBase:
    """Common base for expression nodes; instances are always truthy."""

    def __bool__(self):
        return True

    def pretty_string(self, *, depth=0):
        return repr(self)

    def eval(self, *args, **kwargs):
        return self

    def clone(self, *, mapper=None):
        return self if not mapper else mapper(self)


class ExpOp(str, ExpBase):
    """A function or variable name; evaluates to the variable's value."""

    __slots__ = ()

    def eval(self, handlers, *args, **kwargs):
        value = handlers.get_var(self)
        if value is Empty:
            raise KeyError(f"No handler for op/var {self}")
        return value


class ExpBinOp(ExpOp):
    """A name written in backticks, used as an infix operator."""

    __slots__ = ()


class ExpSym(str, ExpBase):
    """A quoted symbol; evaluates to itself."""

    __slots__ = ()

    def __repr__(self):
        return f"'{self}"


class ExpTuple(tuple, ExpBase):
    """Immutable sequence of expression nodes and/or plain values."""

    __slots__ = ()

    def clone(self):
        """Return a copy whose ``ExpBase`` items are themselves cloned."""
        return self.__class__(
            v.clone() if isinstance(v, ExpBase) else v for v in self
        )

    def get_eval(self, k, handlers, *args, default=None, **kwargs):
        """Return item ``k``, evaluated first if it is an expression node."""
        val = super().__getitem__(k)
        if isinstance(val, ExpBase):
            return val.eval(handlers, *args, **kwargs)
        return val

    def pretty_string(self, depth=0):
        vals = (
            repr(v) if not isinstance(v, ExpBase) else v.pretty_string(depth=depth + 1)
            for v in self
        )
        pad = " " * (depth + 1) * 2
        nlpad = f",\n{pad}"
        return f"(\n{pad}{nlpad.join(vals)}\n{pad[:-2]})"

    def eval(self, handlers, *args, **kwargs):
        return tuple(
            v.eval(handlers, *args, **kwargs) if isinstance(v, ExpBase) else v
            for v in self
        )


class ExpKV(ExpBase):
    """A ``key :> value`` pair produced by the parser."""

    __slots__ = ("k", "v")

    def __init__(self, k, v):
        self.k = k
        self.v = v


class ExpDict(dict, ExpBase):
    """Read-only mapping of names to expression nodes and/or plain values.

    All mutating dict methods raise ``NotImplementedError``.
    """

    __slots__ = ()

    def clone(self):
        """Return a copy whose ``ExpBase`` values are themselves cloned."""
        return self.__class__(
            {k: v.clone() if isinstance(v, ExpBase) else v for k, v in self.items()}
        )

    def pop(self, *args, **kwargs):
        raise NotImplementedError

    def get_eval(self, k, handlers, *args, default=Empty, **kwargs):
        """Return the value for ``k`` (or ``default``), evaluated if it is a node."""
        val = super().get(k, default)
        if isinstance(val, ExpBase):
            return val.eval(handlers, *args, **kwargs)
        return val

    def pretty_string(self, depth=0):
        vals = (
            f"{k}: {v!r}"
            if not isinstance(v, ExpBase)
            else f"{k}: {v.pretty_string(depth=depth + 1)}"
            for k, v in self.items()
        )
        pad = " " * (depth + 1) * 2
        nlpad = f",\n{pad}"
        return f"{{\n{pad}{nlpad.join(vals)}\n{pad[:-2]}}}"

    def eval(self, handlers, *args, **kwargs):
        return {
            k: v.eval(handlers, *args, **kwargs) if isinstance(v, ExpBase) else v
            for k, v in self.items()
        }

    # Every mutator shares the raising implementation above.
    popitem = pop
    update = pop
    clear = pop
    __delitem__ = pop
    __setitem__ = pop
    __ior__ = pop


class ExpStatements(ExpBase):
    """A non-empty sequence of statements; evaluates to the last one's value."""

    def __init__(self, statements):
        if not isinstance(statements, ExpTuple) or not len(statements):
            raise ValueError("Must have at least one statement")
        self.statements = statements

    def eval(self, handlers, *args, **kwargs):
        result = Empty
        for stmt in self.statements:
            result = (
                stmt.eval(handlers, *args, **kwargs)
                if isinstance(stmt, ExpBase)
                else stmt
            )
        return result

    def __repr__(self):
        return f"@{self.statements}"


class ExprGetter:
    """Callable that fetches and evaluates one argument of a function application.

    String keys are looked up in ``obj.kwargs``; any other key indexes
    ``obj.args``. ``KeyError`` is raised when the result is ``Empty``.
    """

    def __init__(self, obj, ctx, *args, **kwargs):
        self.obj = obj
        self.ctx = ctx
        self.args = args
        self.kwargs = kwargs

    def __call__(self, k, *, default=Empty):
        obj = self.obj
        result = (
            obj.kwargs.get_eval(k, self.ctx, *self.args, default=default, **self.kwargs)
            if isinstance(k, str)
            else obj.args.get_eval(k, self.ctx, *self.args, **self.kwargs)
        )
        if result is Empty:
            raise KeyError(f"Unknown key {k!r}")
        return result


class ExpFunAp(ExpBase):
    """Application of a named function/operator to positional and keyword args."""

    __slots__ = ("name", "args", "kwargs")

    def __init__(self, name, args=None, kwargs=None):
        self.name = name
        self.args = args if args is not None else ExpTuple()
        self.kwargs = kwargs if kwargs is not None else ExpDict()

    def eval(self, handlers, *args, **kwargs):
        """Look up the handler for ``self.name`` and invoke it on this node."""
        handler = handlers.get_handler(self.name)
        if handler is Empty:
            raise KeyError(f"No handler for op: {self.name!r}")
        return handler(
            self, getter=ExprGetter(self, handlers, *args, **kwargs), **kwargs
        )

    def clone(self):
        return self.__class__(self.name, self.args.clone(), self.kwargs.clone())

    def pretty_string(self, depth=0):
        pad = " " * (depth + 1) * 2
        args_str = (
            self.args.pretty_string(depth=depth + 1)
            if isinstance(self.args, ExpBase)
            else repr(self.args)
        )
        # len() rather than truthiness: ExpDict instances are always truthy
        # because ExpBase.__bool__ returns True.
        kwargs_str = (
            f", {self.kwargs.pretty_string(depth + 1)}" if len(self.kwargs) else ""
        )
        # Previously returned an empty string and discarded pad/kwargs_str.
        return f"<{self.name}:\n{pad}{args_str}{kwargs_str}\n{pad[:-2]}>"

    def __repr__(self):
        kwargs_str = f", {self.kwargs}" if len(self.kwargs) else ""
        # Previously returned an empty string.
        return f"<{self.name}{self.args!r}{kwargs_str}>"


class ExpBoundFunAp(ExpFunAp):
    """Function application bound directly to a Python callable."""

    __slots__ = ("fun",)

    def __init__(self, name, fun, args, kwargs):
        super().__init__(name, args, kwargs)
        self.fun = fun

    def eval(self, handlers, *args, **kwargs):
        def get_evaled(k, default=None):
            return (
                self.kwargs.get_eval(k, handlers, *args, default=default, **kwargs)
                if isinstance(k, str)
                else self.args.get_eval(k, handlers, *args, **kwargs)
            )

        return self.fun(self.name, self.args, *args, getter=get_evaled, **kwargs)

    def clone(self):
        # The inherited clone() would call this class's four-argument
        # constructor with three arguments; pass the bound callable along.
        return self.__class__(
            self.name, self.fun, self.args.clone(), self.kwargs.clone()
        )


__all__ = (
    "Empty",
    "ExpBase",
    "ExpOp",
    "ExpBinOp",
    "ExpSym",
    "ExpTuple",
    "ExpKV",
    "ExpDict",
    "ExpStatements",
    "ExprGetter",
    "ExpFunAp",
    "ExpBoundFunAp",
)
def split_iterable(seq, pred):
    """Yield tuples of consecutive items of ``seq`` for which ``pred`` is true.

    Items failing ``pred`` act as separators and are dropped. Empty runs
    (leading, trailing or consecutive separators) produce no output, and
    iteration continues until ``seq`` is exhausted. The previous
    implementation stopped at the first empty run, silently discarding
    everything after it.
    """
    run = []
    for item in seq:
        if pred(item):
            run.append(item)
        elif run:
            yield tuple(run)
            run = []
    if run:
        yield tuple(run)
class ValidateArg:
    """Collection of value validators, plus a combinator over several of them.

    The ``validate_*`` methods take ``(idx, val)``, return ``val`` (possibly
    converted) on success and raise ``ValidateError`` otherwise; ``idx`` is
    only used in error messages.

    An instance bundles one or more validators by name (``"integer"`` selects
    ``validate_integer`` and so on). Calling the instance runs each validator
    on the given arguments and passes the results to the ``group`` callable.
    """

    __slots__ = ("valfuns", "groupfun", "kwargs", "kwargslist")

    def __init__(self, name, *args, kwargslist=(), group=all, **kwargs):
        """Select validators by name.

        ``name`` is one validator name or a list/tuple of names. For a single
        name, extra keyword arguments are forwarded to that validator. For
        several names, ``kwargs`` is forwarded to every validator and
        ``kwargslist`` is a sequence of per-validator keyword dicts matched by
        position. Raises ``ValueError`` for an unknown validator name.
        """
        if not isinstance(name, (list, tuple)):
            name, kwargslist, kwargs = (name,), (kwargs,), {}
        # Materialize as a tuple: a generator would be exhausted by the all()
        # check below, leaving nothing to run when the instance is called.
        valfuns = tuple(getattr(self, f"validate_{n}", None) for n in name)
        if not all(valfuns):
            raise ValueError("Unknown validator")
        self.valfuns = valfuns
        self.groupfun = group
        self.kwargs = kwargs
        if kwargslist is None:
            kwargslist = ()
        elif isinstance(kwargslist, dict):
            # Legacy form: a single dict used to be applied to the first
            # len(dict) validators; keep that reading for existing callers.
            kwargslist = tuple(kwargslist for _ in range(len(kwargslist)))
        self.kwargslist = tuple(kwargslist)

    def __call__(self, *args, **kwargs):
        per_validator = self.kwargslist
        count = len(per_validator)
        return self.groupfun(
            vf(
                *args,
                **(per_validator[idx] if idx < count else {}),
                **self.kwargs,
            )
            for idx, vf in enumerate(self.valfuns)
        )

    @staticmethod
    def validate_numeric(idx, val):
        if not isinstance(val, (int, float, torch.Tensor)):
            raise ValidateError(
                f"Expected numeric or tensor argument at {idx}, got {type(val)}"
            )
        return val

    @classmethod
    def validate_numeric_scalar(cls, idx, val):
        if not isinstance(val, (int, float)):
            raise ValidateError(f"Expected numeric argument at {idx}, got {type(val)}")
        return val

    @classmethod
    def validate_integer(cls, idx, val):
        if not isinstance(val, int):
            raise ValidateError(f"Expected integer argument at {idx}, got {type(val)}")
        return val

    @staticmethod
    def validate_tensor(idx, val):
        if not isinstance(val, torch.Tensor):
            raise ValidateError(f"Expected tensor argument at {idx}, got {type(val)}")
        return val

    @staticmethod
    def validate_image(idx, val):
        if not isinstance(val, ImageBatch):
            raise ValidateError(
                f"Expected PIL Image argument at {idx}, got {type(val)}"
            )
        return val

    @staticmethod
    def validate_sequence(idx, val, *, item_validator=None):
        """Require a list or tuple; optionally validate each item.

        With ``item_validator`` the result is a tuple of the validated items;
        without it the original sequence is returned unchanged.
        """
        if not isinstance(val, (list, tuple)):
            raise ValidateError(f"Expected sequence argument at {idx}, got {type(val)}")
        if item_validator is None:
            return val
        try:
            return tuple(item_validator(iidx, v) for iidx, v in enumerate(val))
        except ValidateError as exc:
            raise ValidateError(
                f"Item validation failed in sequence at {idx}: {exc}"
            ) from exc

    @classmethod
    def validate_numscalar_sequence(cls, idx, val):
        return cls.validate_sequence(
            idx, val, item_validator=cls.validate_numeric_scalar
        )

    @classmethod
    def validate_string(cls, idx, val):
        if not isinstance(val, str):
            raise ValidateError(f"Expected string argument at {idx}, got {type(val)}")
        return val

    @classmethod
    def validate_boolean(cls, idx, val):
        if val is not True and val is not False:
            raise ValidateError(f"Expected boolean argument at {idx}, got {type(val)}")
        return val

    @classmethod
    def validate_passthrough(cls, idx, val):
        return val
class NormHandler(expr.BaseHandler):
    """Handler registered as ``t_norm``.

    Arguments: ``tensor``, ``factor`` (default 1.0) and ``dim`` (default the
    last three dimensions).  The work is delegated to ``scale_noise``.  The
    other tensor handlers derive from this class to reuse its tensor output
    validator.
    """

    validate_output = expr.Arg.tensor("output")

    input_validators = (
        expr.Arg.tensor("tensor"),
        expr.Arg.numeric("factor", 1.0),
        expr.Arg.numscalar_sequence("dim", (-3, -2, -1)),
    )

    def handle(self, obj, getter):
        source, strength, norm_dims = self.safe_get_all(obj, getter)
        return scale_noise(source, strength, normalize_dims=norm_dims)


class MeanHandler(NormHandler):
    """Handler registered as ``t_mean``: mean over ``dim``, dimensions kept."""

    input_validators = (
        expr.Arg.tensor("tensor"),
        expr.Arg.numscalar_sequence("dim", (-3, -2, -1)),
    )

    def handle(self, obj, getter):
        source, reduce_dims = self.safe_get_all(obj, getter)
        return source.mean(dim=reduce_dims, keepdim=True)


class StdHandler(NormHandler):
    """Handler registered as ``t_std``: standard deviation over ``dim``, dimensions kept."""

    input_validators = (
        expr.Arg.tensor("tensor"),
        expr.Arg.numscalar_sequence("dim", (-3, -2, -1)),
    )

    def handle(self, obj, getter):
        source, reduce_dims = self.safe_get_all(obj, getter)
        return source.std(dim=reduce_dims, keepdim=True)
class FlipHandler(NormHandler):
    """Handler registered as ``t_flip``: flip or mirror a tensor along one dimension.

    Arguments:
        tensor: The tensor to transform.
        dim: Dimension to flip; negative values count from the end.
        mirror: When false (default) the whole dimension is reversed.  When
            true, a copy is returned whose trailing half along ``dim`` is
            replaced by the reversed leading half.
    """

    input_validators = (
        expr.Arg.tensor("tensor"),
        expr.Arg.integer("dim"),
        expr.Arg.boolean("mirror", False),
    )

    def handle(self, obj, getter):
        tensor, dim, mirror = self.safe_get_all(obj, getter)
        ndim = tensor.ndim
        wanted = dim
        if dim < 0:
            dim += ndim
        if not 0 <= dim < ndim:
            # Report the dimension as the caller wrote it, not the normalised one.
            raise ValueError(
                f"Dimension out of range, wanted {wanted}, tensor has {ndim} dimension(s)"
            )
        if not mirror:
            return torch.flip(tensor, (dim,))
        size = tensor.shape[dim]
        half = size // 2
        head = [slice(None)] * ndim
        tail = [slice(None)] * ndim
        head[dim] = slice(None, half)
        # Start the destination at size - half so both slices always have the
        # same length.  For an odd-sized dimension the middle element is left
        # untouched; starting at `half` would make the destination one longer
        # than the source and the assignment would fail.
        tail[dim] = slice(size - half, None)
        result = tensor.detach().clone()
        result[tuple(tail)] = torch.flip(tensor[tuple(head)], dims=(dim,))
        return result


class BlendHandler(NormHandler):
    """Handler registered as ``t_blend``: blend two tensors.

    Arguments: ``tensor1``, ``tensor2``, ``scale`` (default 0.5) and ``mode``
    (default ``"lerp"``), which must be a key of ``BLENDING_MODES``.
    """

    input_validators = (
        expr.Arg.tensor("tensor1"),
        expr.Arg.tensor("tensor2"),
        expr.Arg.numeric("scale", 0.5),
        expr.Arg.string("mode", "lerp"),
    )

    def handle(self, obj, getter):
        t1, t2, scale, mode = self.safe_get_all(obj, getter)
        blend_handler = BLENDING_MODES.get(mode)
        if not blend_handler:
            raise KeyError(f"Unknown blend mode {mode!r}")
        return blend_handler(t1, t2, scale)
class ScaleHandler(NormHandler):
    """Handler registered as ``t_scale``: resize a tensor's last two dimensions.

    Arguments:
        tensor: The tensor to resize.
        scale: A number applied to both height and width, or a two-item
            ``(h, w)`` sequence.
        mode: Interpolation mode passed to ``latent.scale_samples``
            (default ``"bicubic"``).
        absolute_scale: When true ``scale`` is the target size itself;
            otherwise it multiplies the current height and width.
    """

    input_validators = (
        expr.Arg.tensor("tensor"),
        expr.Arg.one_of(
            "scale",
            (
                expr.ValidateArg.validate_numeric_scalar,
                expr.ValidateArg.validate_numscalar_sequence,
            ),
        ),
        expr.Arg.string("mode", "bicubic"),
        expr.Arg.boolean("absolute_scale", False),
    )

    def handle(self, obj, getter):
        t, scale, mode, abs_scale = self.safe_get_all(obj, getter)
        if not isinstance(scale, (list, tuple)):
            scale = (scale, scale)
        elif len(scale) != 2:
            raise ValueError(
                "When passing scale as a tuple, it must be in the form (h, w)"
            )
        if abs_scale:
            scale = tuple(int(v) for v in scale)
        else:
            scale = (int(t.shape[-2] * scale[0]), int(t.shape[-1] * scale[1]))
        # Truncation to int can produce 0 for small factors, so check after it.
        if not all(v > 0 for v in scale):
            raise ValueError(f"Invalid scale: scale values must be > 0, got: {scale!r}")
        # scale is (h, w); scale_samples is called with width first.
        return latent.scale_samples(t, scale[1], scale[0], mode=mode)


class NoiseHandler(NormHandler):
    """Handler registered as ``t_noise``: sample noise shaped like a tensor.

    Arguments: ``tensor`` and ``type`` (default ``"gaussian"``).  The sigma
    values are read from the evaluation context variables ``sigma_min``,
    ``sigma_max``, ``sigma`` and ``sigma_next``.
    """

    input_validators = (
        expr.Arg.tensor("tensor"),
        expr.Arg.string("type", "gaussian"),
    )

    def handle(self, obj, getter):
        t, typ = self.safe_get_all(obj, getter)
        ctx = getter.ctx
        smin, smax, s, sn = (
            ctx.get_var(k) for k in ("sigma_min", "sigma_max", "sigma", "sigma_next")
        )
        ns = latent.get_noise_sampler(typ, t, smin, smax, normalized=False)
        return ns(s, sn)


class ShapeHandler(expr.BaseHandler):
    """Handler registered as ``t_shape``: return a tensor's shape as an ``ExpTuple``."""

    input_validators = (expr.Arg.tensor("tensor"),)

    def handle(self, obj, getter):
        t = self.safe_get("tensor", obj, getter)
        return expr.types.ExpTuple(tuple(t.shape))
class TAESDEncodeHandler(expr.BaseHandler):
    """Handler registered as ``img_taesd_encode``: encode an image batch with TAESD.

    Arguments: ``image`` (an image batch), ``reference_latent`` (a tensor
    forwarded to ``OCSTAESD.encode``) and ``mode`` (default ``"sd15"``).
    """

    input_validators = (
        expr.Arg.image("image"),
        expr.Arg.tensor("reference_latent"),
        expr.Arg.string("mode", "sd15"),
    )
    validate_output = expr.Arg.tensor("output")

    def handle(self, obj, getter):
        imgbatch, ref, mode = self.safe_get_all(obj, getter)
        return OCSTAESD.encode(mode, imgbatch, ref)


class ImgShapeHandler(expr.BaseHandler):
    """Handler registered as ``img_shape``: size of the first image in a batch.

    Returns an ``ExpTuple`` built from the image's ``size`` attribute with the
    two entries swapped -- presumably ``(height, width)`` from PIL's
    ``(width, height)``; confirm against ``ImageBatch``.

    Raises:
        ValueError: If the batch is empty.
    """

    # The argument is an image batch, like the other img_* handlers.  A tensor
    # validator here rejected every usable input: tensors pass validation but
    # fail below, because Tensor.size is a method rather than a sequence.
    input_validators = (expr.Arg.image("image"),)

    def handle(self, obj, getter):
        imgbatch = self.safe_get("image", obj, getter)
        if len(imgbatch) == 0:
            raise ValueError("Can't get shape of empty image batch")
        isz = imgbatch[0].size
        return expr.types.ExpTuple((isz[1], isz[0]))
class UnsafeTorchTensorMethodHandler(NormHandler):
    """Handler registered as ``unsafe_tensor_method``: call a tensor method by name.

    The first positional argument is the tensor and the second the method
    name.  Remaining positional and keyword arguments are forwarded to the
    method.  Which methods may be called depends on the module flags:

    * ``ALLOW_ALL_UNSAFE``: any method name is accepted.
    * ``ALLOW_UNSAFE``: only the names in ``whitelist``.
    * neither: the whitelist is empty, so every call is rejected.
    """

    input_validators = (
        expr.Arg.tensor("__tensor"),
        expr.Arg.string("__method"),
    )

    if ALLOW_ALL_UNSAFE:

        class AlwaysContains:
            """Container stand-in that reports every key as present."""

            def __contains__(self, k):
                return True

        whitelist = AlwaysContains()
    elif ALLOW_UNSAFE:
        # fmt: off
        whitelist = {
            "abs", "absolute", "acos", "acosh", "add", "addbmm", "addcdiv",
            "addcmul", "addmm", "addmv", "addr", "adjoint", "all", "allclose",
            "amax", "amin", "aminmax", "angle", "any", "arccos", "arccosh",
            "arcsin", "arcsinh", "arctan", "arctan2", "arctanh", "argmax",
            "argmin", "argsort", "argwhere", "as_strided", "asin", "asinh",
            "atan", "atan2", "atanh",
            "baddbmm", "bernoulli", "bincount", "bitwise_and",
            "bitwise_left_shift", "bitwise_not", "bitwise_or",
            "bitwise_right_shift", "bitwise_xor", "bmm", "broadcast_to",
            "ceil", "cholesky", "cholesky_inverse", "cholesky_solve", "chunk",
            "clamp", "clip", "clone", "conj", "conj_physical", "contiguous",
            "copysign", "corrcoef", "cos", "cosh", "count_nonzero", "cov",
            "cross", "cummax", "cummin", "cumprod", "cumsum",
            "deg2rad", "det", "detach", "diag", "diag_embed", "diagflat",
            "diagonal", "diagonal_scatter", "diff", "digamma", "dim", "dist",
            "div", "divide", "dot", "dsplit",
            "eq", "equal", "erf", "erfc", "erfinv", "exp", "expand",
            "expand_as", "expm1",
            "fix", "flatten", "flip", "fliplr", "flipud", "float_power",
            "floor", "floor_divide", "fmax", "fmin", "fmod", "frac", "frexp",
            "gather", "gcd", "ge", "geqrf", "ger", "greater", "greater_equal",
            "gt",
            "hardshrink", "heaviside", "histc", "hsplit", "hypot",
            "i0", "igamma", "igammac", "index_add", "index_copy", "index_fill",
            "index_put", "index_reduce", "index_select", "inner", "inverse",
            "isclose", "isfinite", "isinf", "isnan", "isneginf", "isposinf",
            "kthvalue",
            # Was "lcm()", which can never equal an attribute name.
            "lcm", "ldexp", "le", "lerp", "less", "less_equal", "lgamma",
            "log", "log10", "log1p", "log2", "logaddexp", "logaddexp2",
            "logcumsumexp", "logdet", "logical_and", "logical_not",
            "logical_or", "logical_xor", "logit", "logsumexp", "lt", "lu",
            "lu_solve",
            "masked_fill", "masked_scatter", "masked_select", "matmul",
            "matrix_exp", "max", "maximum", "mean", "median", "min", "minimum",
            "mm", "mode", "moveaxis", "movedim", "msort", "mul", "multinomial",
            "multiply", "mv", "mvlgamma",
            "nan_to_num", "nanmean", "nanmedian", "nanquantile", "nansum",
            "narrow", "narrow_copy", "ne", "neg", "negative", "new_empty",
            "new_full", "new_ones", "new_zeros", "nextafter", "nonzero", "norm",
            "not_equal", "numel",
            "orgqr", "ormqr", "outer",
            "permute", "polygamma", "positive", "pow", "prod",
            "qr", "quantile",
            "rad2deg", "ravel", "reciprocal", "remainder", "renorm", "repeat",
            "repeat_interleave", "reshape", "reshape_as", "resolve_conj",
            "resolve_neg", "roll", "rot90", "round", "rsqrt",
            "scatter", "scatter_add", "scatter_reduce", "select",
            "select_scatter", "sgn", "sigmoid", "sign", "signbit", "sin",
            "sinc", "sinh", "slice_scatter", "slogdet", "smm", "softmax",
            "sort", "sparse_mask", "split", "sqrt", "square", "squeeze",
            "sspaddmm", "std", "stft", "sub", "subtract", "sum", "sum_to_size",
            "svd", "swapaxes", "swapdims",
            "t", "take", "take_along_dim", "tan", "tanh", "tensor_split",
            "tile", "topk", "transpose", "triangular_solve", "tril", "triu",
            "true_divide", "trunc",
            "unflatten", "unfold", "unique", "unique_consecutive", "unsqueeze",
            "var", "vdot", "view", "view_as", "vsplit",
            "where",
            "xlogy",
        }
        # fmt: on
    else:
        whitelist = set()

    def handle(self, obj, getter):
        if "__method" in obj.kwargs or "__tensor" in obj.kwargs:
            raise ValueError(
                "Tensor method call doesn't support passing method or tensor with keyword args"
            )
        tensor = self.safe_get("__tensor", obj, getter=getter)
        method = self.safe_get("__method", obj, getter=getter)
        # Reject the call before evaluating any of the forwarded arguments.
        if method not in self.whitelist:
            raise ValueError(f"Method {method} not whitelisted: cannot call")
        methodfun = getattr(tensor, method, None)
        if methodfun is None:
            raise KeyError(f"Unknown method {method} for Torch tensor")
        # The first two positional arguments are the tensor and method name.
        args = (
            self.safe_get(idx, obj, getter=getter) for idx in range(2, len(obj.args))
        )
        kwargs = {k: self.safe_get(k, obj, getter=getter) for k in obj.kwargs}
        return methodfun(*args, **kwargs)


class UnsafeTorchHandler(expr.BaseHandler):
    """Handler registered as ``unsafe_torch``: resolve a dotted path under ``torch``.

    Only functional when ``ALLOW_ALL_UNSAFE`` is set; otherwise every call
    raises ``ValueError``.
    """

    input_validators = (expr.Arg.string("path"),)

    if not ALLOW_ALL_UNSAFE:

        def handle(self, obj, getter):
            raise ValueError("Unsafe Torch access not allowed")

    else:

        def handle(self, obj, getter):
            path = self.safe_get("path", obj, getter)
            keys = path.split(".")
            # An empty component means a leading, trailing or doubled dot.
            if not all(keys):
                raise ValueError(f"Bad path {path}")
            return resolve_value(keys, torch)


if EXT_BLEH:

    class BlehEnhanceHandler(expr.BaseHandler):
        """Handler registered as ``t_bleh_enhance`` (needs the bleh extension).

        Arguments: ``tensor``, ``mode`` and ``scale`` (default 1.0), forwarded
        to bleh's ``latent_utils.enhance_tensor``.
        """

        input_validators = (
            expr.Arg.tensor("tensor"),
            expr.Arg.string("mode"),
            expr.Arg.numeric_scalar("scale", 1.0),
        )
        # NOTE(review): the built-in handlers name this attribute
        # `validate_output`; confirm which one BaseHandler actually reads.
        output_validator = expr.Arg.tensor("output")

        def handle(self, obj, getter):
            tensor, mode, scale = self.safe_get_all(obj, getter)
            return EXT_BLEH.latent_utils.enhance_tensor(
                tensor, mode, scale=scale, adjust_scale=False
            )

    HANDLERS["t_bleh_enhance"] = BlehEnhanceHandler()

if EXT_SONAR:

    class SonarPowerFilterHandler(expr.BaseHandler):
        """Handler registered as ``t_sonar_power_filter`` (needs the sonar extension).

        Arguments: ``tensor`` and ``filter``, a dictionary describing a Sonar
        power filter.  The keys of ``default_power_filter`` are taken out of
        the dictionary and passed to ``PowerNoiseItem``; the rest goes to
        ``PowerFilter``.  A nested ``compose_with`` dictionary is built
        recursively.
        """

        input_validators = (
            expr.Arg.tensor("tensor"),
            expr.Arg.present("filter"),
        )
        # NOTE(review): the built-in handlers name this attribute
        # `validate_output`; confirm which one BaseHandler actually reads.
        output_validator = expr.Arg.tensor("output")

        default_power_filter = {
            "mix": 1.0,
            "normalization_factor": 1.0,
            "common_mode": 0.0,
            "channel_correlation": "1,1,1,1,1,1",
        }

        @classmethod
        def make_power_filter(cls, fdict, *, toplevel=True):
            """Build a Sonar filter object from ``fdict``.

            Returns a ``PowerNoiseItem`` at the top level and a bare
            ``PowerFilter`` for nested ``compose_with`` definitions.

            Raises:
                TypeError: If ``compose_with`` is not a dictionary, or if
                    ``channel_correlation`` is neither a string nor a
                    sequence of numbers.
            """
            fdict = fdict.copy()
            compose_with = fdict.pop("compose_with", None)
            if compose_with:
                if not isinstance(compose_with, dict):
                    raise TypeError("compose_with must be a dictionary")
                fdict["compose_with"] = cls.make_power_filter(
                    compose_with, toplevel=False
                )
            topargs = {
                k: fdict.pop(k, dv) for k, dv in cls.default_power_filter.items()
            }
            power_filter = EXT_SONAR.powernoise.PowerFilter(**fdict)
            if not toplevel:
                return power_filter
            cc = topargs.get("channel_correlation")
            # A string (including the default) is already in the
            # comma-separated form and passes through; only sequences need
            # converting.  Rejecting strings made the default itself invalid.
            if cc is not None and not isinstance(cc, str):
                if not isinstance(cc, (list, tuple)) or not all(
                    isinstance(v, (int, float)) for v in cc
                ):
                    raise TypeError(
                        "Bad channel correlation type: must be comma separated string or numeric sequence"
                    )
                topargs["channel_correlation"] = ",".join(repr(v) for v in cc)
            return EXT_SONAR.powernoise.PowerNoiseItem(
                1, power_filter=power_filter, time_brownian=True, **topargs
            )

        def handle(self, obj, getter):
            tensor, filter_def = self.safe_get_all(obj, getter)
            if not isinstance(filter_def, dict):
                raise TypeError("filter argument must be a dictionary")
            power_filter = self.make_power_filter(filter_def)
            filter_rfft = power_filter.make_filter(tensor.shape).to(
                tensor.device, non_blocking=True
            )
            # The "noise source" handed to Sonar simply returns the input
            # tensor, so the sampler filters that tensor.
            ns = power_filter.make_noise_sampler_internal(
                tensor,
                lambda *_unused, latent=tensor: latent,
                filter_rfft,
                normalized=False,
            )
            return ns(None, None)

    HANDLERS["t_sonar_power_filter"] = SonarPowerFilterHandler()

if EXT_NNLATENTUPSCALE:
    from .latent import scale_nnlatentupscale

    class ScaleNNLatentUpscaleHandler(expr.BaseHandler):
        """Handler registered as ``t_scale_nnlatentupscale`` (needs NNLatentUpscale).

        Arguments: ``tensor``, ``mode`` (``"sd1"`` or ``"sdxl"``, default
        ``"sd1"``) and ``scale`` (default 2.0).
        """

        input_validators = (
            expr.Arg.tensor("tensor"),
            expr.Arg.string("mode", "sd1"),
            expr.Arg.numeric_scalar("scale", 2.0),
        )
        # NOTE(review): the built-in handlers name this attribute
        # `validate_output`; confirm which one BaseHandler actually reads.
        output_validator = expr.Arg.tensor("output")

        def handle(self, obj, getter):
            tensor, mode, scale = self.safe_get_all(obj, getter)
            if mode not in {"sd1", "sdxl"}:
                # The message must name the values that are actually accepted.
                raise ValueError(
                    "Bad mode for t_scale_nnlatentupscale: must be either sd1 or sdxl"
                )
            return scale_nnlatentupscale(mode, tensor, scale)

    HANDLERS["t_scale_nnlatentupscale"] = ScaleNNLatentUpscaleHandler()

# Expression function name -> handler instance, for tensor operations.
TENSOR_OP_HANDLERS = {
    "t_norm": NormHandler(),
    "t_mean": MeanHandler(),
    "t_std": StdHandler(),
    "t_blend": BlendHandler(),
    "t_roll": RollHandler(),
    "t_flip": FlipHandler(),
    "t_contrast_adaptive_sharpening": ContrastAdaptiveSharpeningHandler(),
    "t_scale": ScaleHandler(),
    "t_noise": NoiseHandler(),
    "t_shape": ShapeHandler(),
    "t_taesd_decode": TAESDDecodeHandler(),
    "unsafe_tensor_method": UnsafeTorchTensorMethodHandler(),
    "unsafe_torch": UnsafeTorchHandler(),
}

# Expression function name -> handler instance, for image operations.
IMAGE_OP_HANDLERS = {
    "img_taesd_encode": TAESDEncodeHandler(),
    "img_shape": ImgShapeHandler(),
    "img_pil_resize": ImgPILResizeHandler(),
}

HANDLERS |= TENSOR_OP_HANDLERS
HANDLERS |= IMAGE_OP_HANDLERS
class FilterRefs:
    """Mapping-like container of named values made available to filter expressions.

    The values live in the ``kvs`` dictionary.  The class offers the usual
    mapping operations plus ``|`` / ``|=`` merging, and constructors that
    collect values from sampler state objects.
    """

    def __init__(self, kvs=None):
        self.kvs = fallback(kvs, {})

    def get(self, k, default=None):
        return self.kvs.get(k, default)

    def __getitem__(self, k):
        return self.kvs[k]

    def __setitem__(self, k, v):
        self.kvs[k] = v

    def clone(self):
        """Return a new instance holding a shallow copy of the values."""
        return self.__class__(self.kvs.copy())

    def __or__(self, other):
        return self.__class__(self.kvs | other.kvs)

    def __ior__(self, other):
        self.kvs |= other.kvs
        return self

    def __delitem__(self, k):
        del self.kvs[k]

    def __contains__(self, k):
        return k in self.kvs

    def __missing__(self, k):
        # Plain dicts have no __missing__, so delegating unconditionally
        # raised AttributeError.  Defer only when the mapping defines one.
        missing = getattr(self.kvs, "__missing__", None)
        if missing is None:
            raise KeyError(k)
        return missing(k)

    def __len__(self):
        return len(self.kvs)

    def __iter__(self):
        return iter(self.kvs)

    def items(self):
        return self.kvs.items()

    @classmethod
    def _from_attrs(cls, source, pairs):
        """Build an instance from ``(key, attribute)`` pairs, skipping missing/None attributes."""
        return cls({
            k: getattr(source, ak)
            for k, ak in pairs
            if getattr(source, ak, None) is not None
        })

    @classmethod
    def from_ss(cls, ss, *, have_current=False):
        """Collect values from sampler state ``ss``.

        Step and sigma information is always included.  With ``have_current``
        and a non-empty history, the values of ``ss.hcur`` and ``ss.d`` are
        merged in.  The most recent history entry that is not the current one
        is added under ``*_prev`` keys together with ``d_prev``.
        """
        ms = ss.model.model_sampling
        fr = cls({
            "step": ss.step,
            "substep": ss.substep,
            "dt": ss.dt,
            "sigma_idx": ss.idx,
            "sigma": ss.sigma,
            "sigma_next": ss.sigma_next,
            "sigma_down": ss.sigma_down,
            "sigma_up": ss.sigma_up,
            "sigma_prev": ss.sigma_prev,
            "hist_len": len(ss.hist),
            "sigma_min": ms.sigma_min.item(),
            "sigma_max": ms.sigma_max.item(),
            "step_pct": float(ss.step / ss.total_steps),
            "total_steps": ss.total_steps,
            "sampling_pct": (999 - ms.timestep(ss.sigma).item()) / 999,
            "is_rectified_flow": ss.model.is_rectified_flow,
            "original_cfg_scale": ss.model.inner_cfg_scale,
        })
        hist_len = len(ss.hist)
        if have_current and hist_len > 0:
            fr |= cls.from_mr(ss.hcur)
            fr["d"] = ss.d
        # Without a current result the newest history entry is the previous
        # one; with a current result it is the entry before that.
        if not have_current and hist_len > 0:
            hist_offs = -1
        elif hist_len > 1:
            hist_offs = -2
        else:
            hist_offs = None
        if hist_offs is not None:
            hprev = ss.hist[hist_offs]
            fr.kvs |= {f"{k}_prev": v for k, v in cls.from_mr(hprev).kvs.items()}
            fr["d_prev"] = hprev.d
        return fr

    @classmethod
    def from_mr(cls, mr):
        """Collect the non-None fields of a model result ``mr``."""
        return cls._from_attrs(
            mr,
            (
                ("cond", "denoised_cond"),
                ("denoised", "denoised"),
                ("model_call", "call_idx"),
                ("sigma", "sigma"),
                ("uncond", "denoised_uncond"),
                ("x", "x"),
            ),
        )

    @classmethod
    def from_sr(cls, sr):
        """Collect the non-None fields of a step result ``sr``."""
        return cls._from_attrs(
            sr,
            (
                ("cond", "denoised_cond"),
                ("denoised", "denoised"),
                ("noise", "noise_pred"),
                ("sigma_down", "sigma_down"),
                ("sigma_next", "sigma_next"),
                ("sigma_up", "sigma_up"),
                ("sigma", "sigma"),
                ("step", "step"),
                ("substep", "substep"),
                ("uncond", "denoised_uncond"),
                ("x", "x"),
            ),
        )
uses_ref = False 153 | default_options = { 154 | "enabled": True, 155 | "when": None, 156 | "input": "default", 157 | "output": "default", 158 | "ref": "default", 159 | "final": "default", 160 | "blend_mode": "lerp", 161 | "strength": 1.0, 162 | } 163 | 164 | def __init__(self, **options): 165 | self.options = options 166 | self.set_options(self.default_options) 167 | if self.when is not None: 168 | self.when = expr.Expression(self.when) 169 | for key in ("input", "output", "ref", "final"): 170 | if not self.uses_ref and key == "ref": 171 | continue 172 | val = getattr(self, key) 173 | val = make_filter(val) if isinstance(val, dict) else expr.Expression(val) 174 | setattr(self, key, val) 175 | if self.blend_mode not in BLENDING_MODES: 176 | raise ValueError("Bad blend mode") 177 | 178 | def set_options(self, defaults): 179 | for k, v in defaults.items(): 180 | setattr(self, k, self.options.pop(k, v)) 181 | 182 | def apply(self, input_latent, default_ref=None, refs=None, **kwargs): 183 | if not self.check_applies(refs): 184 | return input_latent 185 | refs = fallback(refs, FilterRefs()).clone() 186 | latent = self.get_ref("input", input_latent, self.input, refs=refs) 187 | refs["input"] = latent 188 | if not self.uses_ref: 189 | ref_latent = None 190 | else: 191 | ref_latent = ( 192 | self.get_ref("ref", default_ref, self.ref, refs=refs) 193 | if default_ref is not None 194 | else None 195 | ) 196 | refs["ref"] = ref_latent 197 | output_latent = self.get_ref( 198 | "output", 199 | self.filter(latent, ref_latent, refs=refs, **kwargs), 200 | self.output, 201 | refs=refs, 202 | ) 203 | refs["output"] = output_latent 204 | return self.get_ref( 205 | "final", 206 | BLENDING_MODES[self.blend_mode]( 207 | input_latent[: output_latent.shape[0]], output_latent, self.strength 208 | ), 209 | self.final, 210 | refs=refs, 211 | ) 212 | 213 | def filter(self, latent, ref_latent, *, refs, **kwargs): 214 | raise NotImplementedError 215 | 216 | def check_applies(self, refs=None): 
class SimpleFilter(Filter):
    """Filter named ``simple``: returns the latent unchanged.

    Any effect comes from the ``input`` / ``output`` / ``final`` expressions
    and the blending done in ``Filter.apply``.
    """

    name = "simple"

    def filter(self, latent, *args, **kwargs):
        return latent


class BlendFilter(Filter):
    """Filter named ``blend``: blend the results of two nested filters.

    Options ``filter1`` and ``filter2`` must both be filter definition
    dictionaries.  Their results are combined with this filter's
    ``blend_mode`` and ``strength``.
    """

    name = "blend"
    default_options = Filter.default_options | {"filter1": None, "filter2": None}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if not (isinstance(self.filter1, dict) and isinstance(self.filter2, dict)):
            raise ValueError("Must set filter1 and filter2")
        self.filter1 = make_filter(self.filter1)
        self.filter2 = make_filter(self.filter2)

    def filter(self, latent, ref_latent, *, refs, **kwargs):
        if self.blend_mode == "lerp":
            # With lerp, strength 0 or 1 selects one side entirely, so the
            # other filter need not run.  Filters are invoked through
            # .apply(); the Filter base class defines no __call__.
            if self.strength == 0:
                return self.filter1.apply(latent, ref_latent, refs=refs, **kwargs)
            if self.strength == 1:
                return self.filter2.apply(latent, ref_latent, refs=refs, **kwargs)
        return BLENDING_MODES[self.blend_mode](
            self.filter1.apply(latent, ref_latent, refs=refs, **kwargs),
            self.filter2.apply(latent, ref_latent, refs=refs, **kwargs),
            self.strength,
        )
raise ValueError("filters key must be a sequence") 274 | self.filters = tuple(make_filter(filt) for filt in self.filters) 275 | 276 | def filter(self, latent, ref_latent, *, refs, **kwargs): 277 | if not self.filters: 278 | return latent 279 | for filt in self.filters: 280 | latent = filt.apply(latent, ref_latent, refs=refs, **kwargs) 281 | return latent 282 | 283 | 284 | class NormalizeFilter(Filter): 285 | name = "normalize" 286 | uses_ref = True 287 | default_options = Filter.default_options | { 288 | "adjust_target": 0, 289 | "balance_scale": 1.0, 290 | "adjust_scale": 1.0, 291 | "dims": (-2, -1), 292 | } 293 | 294 | def __init__( 295 | self, 296 | start_step=0, 297 | end_step=9999, 298 | phase="after", 299 | adjust_target=0, 300 | balance_scale=1.0, 301 | adjust_scale=1.0, 302 | dims=(-2, -1), 303 | ): 304 | self.start_step = start_step 305 | self.end_step = end_step 306 | self.phase = phase.lower().strip() # before, after, all 307 | if isinstance(adjust_target, str): 308 | adjust_target = adjust_target.lower().strip() 309 | if adjust_target not in ("x",): 310 | raise ValueError("Bad target mean") 311 | # "x", scalar or array matching mean dims 312 | self.adjust_target = adjust_target 313 | # multiplier on adjustment, scalar or array matching mean dims 314 | self.adjust_scale = adjust_scale 315 | self.balance_scale = balance_scale 316 | self.dims = dims 317 | 318 | def __call__(self, ss, sigma, latent, phase, orig_x=None): 319 | if ss.step < self.start_step or ss.step > self.end_step: 320 | return latent 321 | if self.phase != "all" and phase != self.phase: 322 | return latent 323 | if self.adjust_target == "x" and orig_x is None: 324 | raise ValueError("Can only use source x in after phase") 325 | adjust_scale, balance_scale = ( 326 | torch.tensor(v, dtype=latent.dtype).to(latent) 327 | if isinstance(v, (list, tuple)) 328 | else v 329 | for v in (self.adjust_scale, self.balance_scale) 330 | ) 331 | latent_mean = latent.mean(dim=self.dims, keepdim=True) 332 | 
class NormalizeFilter_:
    """Callable that re-centres a latent's mean over ``dims``.

    The latent mean (scaled by ``balance_scale``) is subtracted, then a
    target value (scaled by ``adjust_scale``) is added back. The target is
    either the string ``"x"`` (use the mean of ``orig_x``), a scalar, or a
    list/tuple that is converted to a tensor matching the latent.
    """

    def __init__(
        self,
        start_step=0,
        end_step=9999,
        phase="after",
        adjust_target=0,
        balance_scale=1.0,
        adjust_scale=1.0,
        dims=(-2, -1),
    ):
        self.start_step = start_step
        self.end_step = end_step
        # Expected values: before, after, all
        self.phase = phase.lower().strip()
        if isinstance(adjust_target, str):
            adjust_target = adjust_target.lower().strip()
            # "x" is the only string target that is understood.
            if adjust_target != "x":
                raise ValueError("Bad target mean")
        self.adjust_target = adjust_target
        self.adjust_scale = adjust_scale
        self.balance_scale = balance_scale
        self.dims = dims

    def __call__(self, ss, sigma, latent, phase, orig_x=None):
        """Return the re-centred latent, or ``latent`` itself when inactive.

        The filter is inactive when ``ss.step`` is outside
        ``[start_step, end_step]`` or when ``phase`` does not match the
        configured phase (unless that is ``"all"``). ``sigma`` is accepted
        but not used.

        Raises:
            ValueError: the target is ``"x"`` but no ``orig_x`` was supplied.
        """
        out_of_range = ss.step < self.start_step or ss.step > self.end_step
        phase_mismatch = self.phase != "all" and phase != self.phase
        if out_of_range or phase_mismatch:
            return latent
        if self.adjust_target == "x" and orig_x is None:
            raise ValueError("Can only use source x in after phase")

        def as_operand(value):
            # Sequences become tensors on the latent's device/dtype;
            # anything else is used as-is.
            if isinstance(value, (list, tuple)):
                return torch.tensor(value, dtype=latent.dtype).to(latent)
            return value

        adjust_scale = as_operand(self.adjust_scale)
        balance_scale = as_operand(self.balance_scale)

        # Out-of-place subtraction: the caller's tensor is never modified,
        # which makes the in-place additions below safe.
        latent = latent - latent.mean(dim=self.dims, keepdim=True) * balance_scale
        if self.adjust_target == "x":
            offset = orig_x.mean(dim=self.dims, keepdim=True)
        elif isinstance(self.adjust_target, (list, tuple)):
            offset = torch.tensor(self.adjust_target, dtype=latent.dtype).to(latent)
        else:
            offset = self.adjust_target
        latent += offset * adjust_scale
        return latent
def make_filter(args):
    """Build a filter object from a definition dict.

    The ``filter_type`` key (default ``"simple"``) selects the entry in the
    module-level ``FILTER`` registry; every remaining key is forwarded to
    it as a keyword argument. The caller's dict is not modified.

    Raises:
        TypeError: ``args`` is not a dict.
        ValueError: ``filter_type`` is not a string or is not registered.
    """
    if not isinstance(args, dict):
        raise TypeError(f"Bad type for filter: {type(args)}")
    # Work on a copy so popping the type key does not leak to the caller.
    options = dict(args)
    kind = options.pop("filter_type", "simple")
    if not isinstance(kind, str):
        raise ValueError("Missing or invalid filter_type")
    factory = FILTER.get(kind)
    if factory is not None:
        return factory(**options)
    raise ValueError(f"Unknown filter_type: {kind}")
def normalize_to_scale(latent, target_min, target_max, *, dim=(-3, -2, -1)):
    """Linearly rescale ``latent`` into ``[target_min, target_max]``.

    The minimum and maximum are taken over ``dim`` (kept for broadcasting),
    so every slice outside those dimensions is rescaled independently. The
    input tensor is not modified; a new tensor is returned.

    A slice whose values are all identical has a zero value range. Dividing
    by that range would produce NaN (0/0) that survives the final clamp, so
    such slices use a range of 1 instead and therefore map to ``target_min``.
    """
    min_val = latent.amin(dim=dim, keepdim=True)
    span = latent.amax(dim=dim, keepdim=True) - min_val
    # Guard the degenerate constant-slice case against division by zero.
    span = torch.where(span == 0, torch.ones_like(span), span)
    # ``latent - min_val`` allocates a fresh tensor, so the in-place ops
    # below never touch the caller's data.
    normalized = (latent - min_val).div_(span)
    return (
        normalized.mul_(target_max - target_min)
        .add_(target_min)
        .clamp_(target_min, target_max)
    )
class OCSTAESD:
    """Class-level helpers for decoding and encoding latents with TAESD.

    Nothing is cached: every ``decode``/``encode`` call resolves the model
    file and constructs a fresh ``TAESD`` instance on the latent's device.
    """

    # Supported format keys mapped to their latent format descriptors.
    latent_formats = {
        "sd15": latent_formats.SD15(),
        "sdxl": latent_formats.SDXL(),
    }

    @classmethod
    def get_decoder_name(cls, fmt):
        """Return the TAESD decoder name declared by format ``fmt``."""
        return cls.latent_formats[fmt].taesd_decoder_name

    @classmethod
    def get_encoder_name(cls, fmt):
        """Derive the encoder name by swapping the ``decoder`` suffix.

        Raises:
            RuntimeError: the decoder name does not end in ``_decoder``.
        """
        decoder_name = cls.get_decoder_name(fmt)
        if decoder_name.endswith("_decoder"):
            # Drop the trailing "decoder" (7 chars) but keep the underscore.
            return f"{decoder_name[:-7]}encoder"
        raise RuntimeError(
            f"Could not determine TAESD encoder name from {decoder_name!r}"
        )

    @classmethod
    def get_taesd_path(cls, name):
        """Return the full path of the first ``vae_approx`` file starting with ``name``.

        Raises:
            RuntimeError: no file name starts with ``name``.
        """
        available = folder_paths.get_filename_list("vae_approx")
        match = next((fn for fn in available if fn.startswith(name)), "")
        if match == "":
            raise RuntimeError(f"Could not get TAESD path for {name!r}")
        return folder_paths.get_full_path("vae_approx", match)

    @classmethod
    def decode(cls, fmt, latent):
        """Decode ``latent`` and return an ``ImageBatch`` with one preview image per batch item."""
        fmt_info = cls.latent_formats[fmt]
        decoder_path = cls.get_taesd_path(cls.get_decoder_name(fmt))
        taesd = TAESD(
            decoder_path=decoder_path, latent_channels=fmt_info.latent_channels
        ).to(latent.device)
        # Move dim 1 to the end before handing each item to the preview helper.
        decoded = taesd.decode(latent).movedim(1, 3)
        return ImageBatch(
            latent_preview.preview_to_image(sample) for sample in decoded
        )

    @staticmethod
    def img_to_encoder_input(imgbatch):
        """Stack a batch of images into one float32 tensor clamped to [-1, 1].

        Each image is converted through ``np.array``, divided by 127, shifted
        by -1 and clamped; the last dimension of the stacked result is then
        moved to position 1.
        """
        # NOTE(review): presumably the images are HxWxC with 0..255 values;
        # dividing by 127 (not 127.5) overshoots slightly and relies on the
        # clamp - confirm this is intended.
        tensors = [
            torch.tensor(np.array(img), dtype=torch.float32)
            .div_(127)
            .sub_(1.0)
            .clamp_(-1, 1)
            for img in imgbatch
        ]
        return torch.stack(tensors, dim=0).movedim(-1, 1)

    @classmethod
    def encode(cls, fmt, imgbatch, latent, *, normalize_output=False):
        """Encode ``imgbatch`` with TAESD, matching ``latent``'s device and dtype.

        ``latent`` is only used as a device/dtype reference. The result is
        clamped to ``±latent_format.process_out(1.0)``. ``normalize_output``
        is accepted but currently has no effect.
        """
        fmt_info = cls.latent_formats[fmt]
        limit = fmt_info.process_out(1.0)
        encoder_path = cls.get_taesd_path(cls.get_encoder_name(fmt))
        taesd = TAESD(
            encoder_path=encoder_path, latent_channels=fmt_info.latent_channels
        ).to(device=latent.device)
        encoder_input = cls.img_to_encoder_input(imgbatch).to(latent.device)
        encoded = taesd.encode(encoder_input)
        return encoded.to(latent.dtype).clamp(-limit, limit)
if "nnlatentupscale" in EXT:

    def scale_nnlatentupscale(
        mode,
        latent,
        scale=2.0,
        *,
        scale_factor=0.13025,
        __nlu_module=EXT["nnlatentupscale"],
    ):
        """Upscale ``latent`` with the external NNLatentUpscale resizer model.

        ``mode`` must be ``"sdxl"`` or ``"sd1"``. The latent is multiplied by
        ``scale_factor`` before the model call and divided by it afterwards;
        the result is returned on the latent's device with its dtype. The
        model is loaded on every call.

        Raises:
            ValueError: ``mode`` is not one of the supported keys.
        """
        # NOTE(review): the same default scale_factor is used for both modes;
        # confirm callers pass an appropriate value for "sd1".
        weights_key = {"sdxl": "SDXL", "sd1": "SD 1.x"}.get(mode)
        if weights_key is None:
            raise ValueError("Bad mode")
        nlu = __nlu_module
        weights_path = nlu.NNLatentUpscale().weight_path[weights_key]
        resizer = nlu.latent_resizer.LatentResizer.load_model(
            weights_path, latent.device, latent.dtype
        ).to(device=latent.device)
        upscaled = resizer(scale_factor * latent, scale=scale).to(
            dtype=latent.dtype, device=latent.device
        )
        result = upscaled / scale_factor
        # Drop the local reference to the model before returning.
        del resizer
        return result
class History:
    """Bounded list of the most recently pushed values (oldest first)."""

    def __init__(self, size):
        self.history = []
        self.size = size

    def __len__(self):
        return len(self.history)

    def __getitem__(self, k):
        return self.history[k]

    def push(self, val):
        """Append ``val``, discarding the oldest entries beyond ``size``.

        Sizes below 1 are treated as 1. This matches what ``size == 0``
        always did, and fixes ``size == 1``, where the previous slice
        ``history[-0:]`` kept the whole list so nothing was ever discarded.
        """
        keep = max(self.size, 1) - 1
        if len(self.history) > keep:
            # Retain only the newest ``keep`` entries so that the append
            # below brings the length to exactly the limit.
            self.history = self.history[len(self.history) - keep :]
        self.history.append(val)

    def reset(self):
        """Forget every stored value."""
        self.history = []

    def clone(self):
        """Return an independent history holding the same values."""
        obj = self.__new__(self.__class__)
        obj.__init__(self.size)
        obj.history = self.history.copy()
        return obj


class ModelResult:
    """Inputs and outputs of a single model call.

    Holds ``call_idx``, ``sigma``, ``x`` and ``denoised`` plus the optional
    ``denoised_uncond``, ``denoised_cond``, ``tangents`` and ``jdenoised``
    values, each of which defaults to ``None``.
    """

    def __init__(
        self,
        call_idx,
        sigma,
        x,
        denoised,
        **kwargs,
    ):
        self.call_idx = call_idx
        self.sigma = sigma
        self.x = x
        self.denoised = denoised
        for k in ("denoised_uncond", "denoised_cond", "tangents", "jdenoised"):
            setattr(self, k, kwargs.pop(k, None))
        # Anything left over is a misspelled or unsupported field.
        if len(kwargs) != 0:
            raise ValueError(f"Unexpected keyword arguments: {tuple(kwargs.keys())}")

    def to_d(
        self,
        /,
        x=None,
        sigma=None,
        denoised=None,
        denoised_uncond=None,
        alt_cfgpp_scale=0,
        cfgpp=False,
    ):
        """Call the imported ``to_d`` helper using stored values as defaults.

        Arguments left as ``None`` fall back to the stored values. A non-zero
        ``alt_cfgpp_scale`` first shifts ``x`` by
        ``(denoised_uncond - denoised) * alt_cfgpp_scale``. With ``cfgpp``
        set, ``denoised_uncond`` is passed in place of ``denoised``.
        """
        x = fallback(x, self.x)
        sigma = fallback(sigma, self.sigma)
        denoised = fallback(denoised, self.denoised)
        denoised_uncond = fallback(denoised_uncond, self.denoised_uncond)
        if alt_cfgpp_scale != 0:
            x = x - denoised * alt_cfgpp_scale + denoised_uncond * alt_cfgpp_scale
        return to_d(x, sigma, denoised if not cfgpp else denoised_uncond)

    @property
    def d(self):
        """``to_d()`` evaluated with every default."""
        return self.to_d()

    def clone(self, deep=False):
        """Return a copy of this result.

        By default the copy shares its values with the original. With
        ``deep=True`` every tensor-valued field is duplicated so the copy
        can be modified in place without affecting the original.
        """
        obj = self.__new__(self.__class__)
        for k in (
            "denoised",
            "call_idx",
            "sigma",
            "x",
            "denoised_uncond",
            "denoised_cond",
            "tangents",
            "jdenoised",
        ):
            val = getattr(self, k)
            if deep and isinstance(val, torch.Tensor):
                # Tensors are duplicated with ``clone()``; they have no
                # ``copy()`` method, so the old call raised AttributeError.
                val = val.clone()
            setattr(obj, k, val)
        return obj

    def get_error(self, other, *, override=None, alt_cfgpp_scale=0, cfgpp=False):
        """Return a relative error between this result and ``other``.

        The result with the larger sigma is stepped to the other's sigma
        along its own ``d``; the norm of the difference between the
        re-evaluated ``d`` and the original ``d`` is divided by the norm of
        the original. Returns ``0.0`` when both sigmas are equal.
        ``override`` replaces ``self`` in the comparison when given.
        """
        slf = fallback(override, self)
        first, second = (other, slf) if other.sigma > slf.sigma else (slf, other)
        if first.sigma == second.sigma:
            return 0.0
        d = first.to_d(alt_cfgpp_scale=alt_cfgpp_scale, cfgpp=cfgpp)
        d_pred = second.to_d(
            x=first.x + d * (second.sigma - first.sigma),
            alt_cfgpp_scale=alt_cfgpp_scale,
            cfgpp=cfgpp,
        )
        return torch.linalg.norm(d_pred.sub_(d)).div_(torch.linalg.norm(d)).item()
def get(self, idx: int, *, jvp: bool = False) -> "None | ModelResult":
    """Return the cached result for model call ``idx``, or ``None`` on a miss.

    ``idx`` is offset by ``self.cache.threshold`` to pick a slot. A miss
    happens when the slot index is out of range, the slot is empty, its
    remaining use count is below 1, or ``jvp`` is requested but the cached
    result has no ``jdenoised``. A hit uses up one of the slot's uses.
    """
    slot_idx = idx - self.cache.threshold
    # Range check first: the slot lists are only indexed when it passes.
    if slot_idx >= self.cache.size or slot_idx < 0:
        return None
    cached = self.slot[slot_idx]
    if cached is None or self.slot_use[slot_idx] < 1:
        return None
    if jvp and cached.jdenoised is None:
        # A jvp caller cannot be served by this entry; no use is consumed.
        return None
    self.slot_use[slot_idx] -= 1
    return cached
self.model.inner_model.inner_model.model_sampling 215 | 216 | @property 217 | def inner_cfg_scale(self) -> None | int | float: 218 | maybe_cfg_scale = getattr(self.model.inner_model, "cfg", None) 219 | return maybe_cfg_scale if isinstance(maybe_cfg_scale, (int, float)) else None 220 | 221 | def set_inner_cfg_scale(self, scale: None | int | float) -> None | int | float: 222 | eff_scale = self.cfg_scale_override 223 | if scale is not None: 224 | eff_scale = None if scale < 0 else scale 225 | if eff_scale is None or eff_scale < 0: 226 | return None 227 | curr_cfg_scale = self.inner_cfg_scale 228 | if curr_cfg_scale is None: 229 | return None 230 | self.model.inner_model.cfg = eff_scale 231 | return curr_cfg_scale 232 | 233 | def __call__( 234 | self, 235 | x: torch.Tensor, 236 | sigma: torch.Tensor, 237 | *, 238 | call_index: int = 0, 239 | ss, 240 | s_in=None, 241 | tangents=None, 242 | require_uncond: bool = False, 243 | cfg_scale_override: None | int = None, 244 | **kwargs, 245 | ) -> ModelResult: 246 | filter_refs = ss.refs | filtering.FilterRefs({ 247 | "model_call": call_index, 248 | "orig_x": x, 249 | }) 250 | result = self.get(call_index, jvp=tangents is not None) 251 | # print( 252 | # f"MODEL: idx={call_index}, size={self.size}, threshold={self.threshold}, cached={result is not None}" 253 | # ) 254 | if result is not None: 255 | self._fr_add_mr(filter_refs, result) 256 | result = self.filter_result(result, default_ref=x, refs=filter_refs) 257 | return result 258 | 259 | comfy.model_management.throw_exception_if_processing_interrupted() 260 | 261 | model_options = self.extra_args.get("model_options", {}).copy() 262 | denoised_cond = denoised_uncond = None 263 | 264 | def postcfg(args): 265 | nonlocal denoised_cond, denoised_uncond 266 | denoised_uncond = args["uncond_denoised"] 267 | denoised_cond = args["cond_denoised"] 268 | if denoised_uncond is None: 269 | denoised_uncond = denoised_cond 270 | return args["denoised"] 271 | 272 | orig_cfg_scale = 
self.set_inner_cfg_scale(cfg_scale_override) 273 | 274 | model_options = comfy.model_patcher.set_model_options_post_cfg_function( 275 | model_options, 276 | postcfg, 277 | disable_cfg1_optimization=require_uncond 278 | or not self.cfg1_uncond_optimization, 279 | ) 280 | 281 | extra_args = self.extra_args | {"model_options": model_options} 282 | s_in = fallback(s_in, self.s_in) 283 | x = self.maybe_filter("input", x, refs=filter_refs) 284 | 285 | def call_model(x, sigma, **kwargs): 286 | return self.model(x, sigma * s_in, **extra_args | kwargs) 287 | 288 | if tangents is None: 289 | denoised = call_model(x, sigma, **kwargs) 290 | self.set_inner_cfg_scale(orig_cfg_scale) 291 | mr = ModelResult( 292 | call_index, 293 | sigma, 294 | x, 295 | denoised, 296 | denoised_uncond=denoised_uncond, 297 | denoised_cond=denoised_cond, 298 | ) 299 | self.set(call_index, mr) 300 | self._fr_add_mr(filter_refs, mr) 301 | mr = self.filter_result(mr, default_ref=x, refs=filter_refs) 302 | return mr 303 | denoised, denoised_prime = torch.func.jvp(call_model, (x, sigma), tangents) 304 | self.set_inner_cfg_scale(orig_cfg_scale) 305 | mr = ModelResult( 306 | call_index, 307 | sigma, 308 | x, 309 | denoised, 310 | jdenoised=denoised_prime, 311 | denoised_uncond=denoised_uncond, 312 | denoised_cond=denoised_cond, 313 | ) 314 | self.set(call_index, mr) 315 | self._fr_add_mr(filter_refs, mr) 316 | mr = self.filter_result(mr, default_ref=x, refs=filter_refs) 317 | return mr 318 | -------------------------------------------------------------------------------- /py/nodes.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | import comfy 4 | 5 | from .sampling import composable_sampler 6 | from .substep_sampling import StepSamplerChain, StepSamplerGroups, ParamGroup 7 | from .step_samplers import STEP_SAMPLERS 8 | from .substep_merging import MERGE_SUBSTEPS_CLASSES 9 | from .restart import Restart 10 | 11 | DEFAULT_YAML_PARAMS = """\ 12 | # 
JSON or YAML parameters 13 | s_noise: 1.0 14 | eta: 1.0 15 | """ 16 | 17 | 18 | class SamplerNode: 19 | RETURN_TYPES = ("SAMPLER",) 20 | CATEGORY = "sampling/custom_sampling/OCS" 21 | DESCRIPTION = "Overly Complicated Sampling main sampler node. Can be connected to a SamplerCustom or other sampler node that supports a SAMPLER input." 22 | OUTPUT_TOOLTIPS = ( 23 | "SAMPLER that can be connected to a SamplerCustom or other sampler node that supports a SAMPLER input.", 24 | ) 25 | 26 | FUNCTION = "go" 27 | 28 | @classmethod 29 | def INPUT_TYPES(cls): 30 | return { 31 | "required": { 32 | "groups": ( 33 | "OCS_GROUPS", 34 | { 35 | "tooltip": "Connect OCS substep groups here which are output from the OCS Group node." 36 | }, 37 | ), 38 | }, 39 | "optional": { 40 | "params_opt": ( 41 | "OCS_PARAMS", 42 | { 43 | "tooltip": "Optionally connect parameters like custom noise here. Output from the OCS Param or OCS MultiParam nodes.", 44 | }, 45 | ), 46 | "parameters": ( 47 | "STRING", 48 | { 49 | "default": DEFAULT_YAML_PARAMS, 50 | "multiline": True, 51 | "dynamicPrompts": False, 52 | "tooltip": "The text parameter block allows setting custom parameters using YAML (recommended) or JSON. 
class GroupNode:
    """Node that appends one substep chain to a group list.

    The incoming substeps chain is cloned, tagged with the merge method and
    time window, given the merged options and appended to a clone of the
    optional upstream group list (or to a new one).
    """

    RETURN_TYPES = ("OCS_GROUPS",)
    CATEGORY = "sampling/custom_sampling/OCS"
    DESCRIPTION = "Over Complicated Sampling group definition node."
    OUTPUT_TOOLTIPS = (
        "This output can be connect to another OCS Group node or an OCS Sampler node.",
    )

    FUNCTION = "go"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs."""
        return {
            "required": {
                "merge_method": (
                    tuple(MERGE_SUBSTEPS_CLASSES.keys()),
                    {
                        "tooltip": "The merge method determines how multiple substeps are combined together during sampling.",
                    },
                ),
                "time_mode": (
                    ("step", "step_pct", "sigma"),
                    {
                        "tooltip": "The time mode controls how the time_start and time_end parameters are interpreted. The default of step is generally easiest to use.",
                    },
                ),
                "time_start": (
                    "FLOAT",
                    {
                        "default": 0,
                        "min": 0.0,
                        "step": 0.1,
                        # The key was previously misspelled as "round'" and
                        # therefore never took effect.
                        "round": False,
                        "tooltip": "The start time this group will be active (inclusive).",
                    },
                ),
                "time_end": (
                    "FLOAT",
                    {
                        "default": 999,
                        "min": 0.0,
                        "step": 0.1,
                        # Same misspelled key as time_start, fixed likewise.
                        "round": False,
                        "tooltip": "The group will become inactive when the current time is GREATER than the specified end time.",
                    },
                ),
                "substeps": (
                    "OCS_SUBSTEPS",
                    {
                        "tooltip": "Connect output from an OCS Substeps node here.",
                    },
                ),
            },
            "optional": {
                "groups_opt": (
                    "OCS_GROUPS",
                    {
                        "tooltip": "You may optionally connect the output from another OCS Group node here. Only one group per step is used, matching (based on time or other constraints) starts with the OCS Group node furthest from the OCS Sampler.",
                    },
                ),
                "params_opt": (
                    "OCS_PARAMS",
                    {
                        "tooltip": "Optionally connect parameters like custom noise here. Output from the OCS Param or OCS MultiParam nodes.",
                    },
                ),
                "parameters": (
                    "STRING",
                    {
                        "default": DEFAULT_YAML_PARAMS,
                        "multiline": True,
                        "dynamicPrompts": False,
                        "tooltip": "The text parameter block allows setting custom parameters using YAML (recommended) or JSON. Optional, may be left blank.",
                    },
                ),
            },
        }

    def go(
        self,
        *,
        merge_method,
        time_mode,
        time_start,
        time_end,
        substeps,
        groups_opt=None,
        params_opt=None,
        parameters="",
    ):
        """Return a one-tuple holding the group list with the new chain appended.

        Options are merged in this order, later entries winning: the parsed
        ``parameters`` text, then ``params_opt.items``.

        Raises:
            ValueError: ``parameters`` parses to something other than a
                mapping (or nothing).
        """
        # Clone the inputs so upstream node outputs are never mutated.
        group = StepSamplerGroups() if groups_opt is None else groups_opt.clone()
        chain = substeps.clone()
        chain.merge_method = merge_method
        chain.time_mode = time_mode
        chain.time_start, chain.time_end = time_start, time_end
        options = {}
        parameters = parameters.strip()
        if parameters:
            extra_params = yaml.safe_load(parameters)
            # Text holding only comments parses to None, which is accepted.
            if extra_params is not None:
                if not isinstance(extra_params, dict):
                    raise ValueError("Parameters must be a JSON or YAML object")
                options |= extra_params
        if params_opt is not None:
            options |= params_opt.items
        chain.options |= options
        group.append(chain)
        return (group,)
197 | OUTPUT_TOOLTIPS = ( 198 | "This output can be connected to another OCS Substeps node or an OCS Group node.", 199 | ) 200 | 201 | FUNCTION = "go" 202 | 203 | @classmethod 204 | def INPUT_TYPES(cls): 205 | return { 206 | "required": { 207 | "substeps": ( 208 | "INT", 209 | { 210 | "default": 1, 211 | "min": 1, 212 | "max": 1000, 213 | "tooltip": "Number of substeps to use for each step, in other words (depending on the OCS Group merge strategy) it may split a step into multiple smaller steps.", 214 | }, 215 | ), 216 | "step_method": ( 217 | tuple(STEP_SAMPLERS.keys()), 218 | { 219 | "tooltip": "In other words, the sampler.", 220 | }, 221 | ), 222 | }, 223 | "optional": { 224 | "substeps_opt": ( 225 | "OCS_SUBSTEPS", 226 | { 227 | "tooltip": "Optionally connect another OCS Substeps node here. Substeps will run in order, starting from the OCS Substeps node FURTHEST from the OCS Group node.", 228 | }, 229 | ), 230 | "params_opt": ( 231 | "OCS_PARAMS", 232 | { 233 | "tooltip": "Optionally connect parameters like custom noise here. Output from the OCS Param or OCS MultiParam nodes.", 234 | }, 235 | ), 236 | "parameters": ( 237 | "STRING", 238 | { 239 | "default": DEFAULT_YAML_PARAMS, 240 | "multiline": True, 241 | "dynamicPrompts": False, 242 | "tooltip": "The text parameter block allows setting custom parameters using YAML (recommended) or JSON. 
class Wildcard(str):
    """String whose ``!=`` comparison is always false.

    Used as the type of ``ParamNode``'s ``value`` input.
    """

    __slots__ = ()

    def __ne__(self, _unused):
        return False


class ParamNode:
    """Node that stores one keyed value in an OCS parameter group."""

    RETURN_TYPES = ("OCS_PARAMS",)
    CATEGORY = "sampling/custom_sampling/OCS"
    DESCRIPTION = "Overly Complicated Sampling parameter definition node. Used to set parameters like custom noise types that require an input."
    # This text was previously only assigned to ``OUTPUT_TYPES``, unlike the
    # other nodes in this file, which all use ``OUTPUT_TOOLTIPS``.
    OUTPUT_TOOLTIPS = (
        "Can be connected to another OCS Param or OCS MultiParam node or any other OCS node that takes OCS_PARAMS as an input.",
    )
    # Old attribute name, kept as an alias for backward compatibility.
    OUTPUT_TYPES = OUTPUT_TOOLTIPS

    FUNCTION = "go"

    WC = Wildcard("*")

    # Parameter key -> predicate that accepts a connected value.
    OCS_PARAM_TYPES = {
        "custom_noise": lambda v: hasattr(v, "make_noise_sampler"),
        "merge_sampler": lambda v: isinstance(v, StepSamplerChain),
        "restart_custom_noise": lambda v: hasattr(v, "make_noise_sampler"),
        "SAMPLER": lambda _v: True,
    }

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs."""
        return {
            "required": {
                "key": (
                    tuple(cls.OCS_PARAM_TYPES.keys()),
                    {
                        "tooltip": "Used to set the type of custom parameter.",
                    },
                ),
                "value": (
                    cls.WC,
                    {
                        "tooltip": "Connect the type of value expected by the key. Allows connecting output from any type of node HOWEVER if it is the wrong type expected by the key you will get an error when you run the workflow.",
                    },
                ),
            },
            "optional": {
                "params_opt": (
                    "OCS_PARAMS",
                    {
                        "tooltip": "You may optionally connect the output from other OCS Param or OCS MultiParam nodes here to set multiple parameters.",
                    },
                ),
                "parameters": (
                    "STRING",
                    {
                        "default": "# Additional YAML or JSON parameters\n",
                        "multiline": True,
                        "dynamicPrompts": False,
                        "tooltip": "The text parameter block allows setting custom parameters using YAML (recommended) or JSON. Optional, may be left blank.",
                    },
                ),
            },
        }

    @classmethod
    def get_renamed_key(cls, key, params):
        """Return ``key``, suffixed with ``_<rename>`` when ``params`` has a ``rename`` entry.

        Raises:
            ValueError: ``rename`` is not a string, is blank, or contains
                characters other than alphanumerics and underscores.
        """
        rename = params.get("rename")
        if rename is None:
            return key
        if not isinstance(rename, str):
            raise ValueError("Param rename key must be a string if set")
        rename = rename.strip()
        if not rename or not all(c == "_" or c.isalnum() for c in rename):
            raise ValueError(
                "Param rename keys must consist of one or more alphanumeric or underscore characters"
            )
        return f"{key}_{rename}"

    def go(self, *, key, value, params_opt=None, parameters=""):
        """Return a one-tuple holding the parameter group with ``value`` stored.

        The value is stored under ``key`` (possibly renamed, see
        ``get_renamed_key``). When ``parameters`` parses to a mapping it is
        also stored under ``"<key>.params"``. An upstream group is cloned
        rather than modified.

        Raises:
            ValueError: ``value`` fails the predicate registered for
                ``key``, or ``parameters`` parses to a non-mapping.
        """
        if not self.OCS_PARAM_TYPES[key](value):
            raise ValueError(f"CSamplerParam: Bad value type for key {key}")
        if parameters:
            extra_params = yaml.safe_load(parameters)
            # Text holding only comments parses to None, which is accepted.
            if extra_params is not None:
                if not isinstance(extra_params, dict):
                    raise ValueError("Parameters must be a JSON or YAML object")
                key = self.get_renamed_key(key, extra_params)
        else:
            extra_params = None
        params = ParamGroup(items={}) if params_opt is None else params_opt.clone()
        params[key] = value
        if extra_params is not None:
            params[f"{key}.params"] = extra_params
        return (params,)
MultiParamNode(ParamNode): 368 | RETURN_TYPES = ("OCS_PARAMS",) 369 | CATEGORY = "sampling/custom_sampling/OCS" 370 | DESCRIPTION = "Overly Complicated Sampling parameter definition node. Used to set parameters like custom noise types that require an input. Like the OCS Param node but allows setting multiple parameters at the same time." 371 | OUTPUT_TYPES = ( 372 | "Can be connected to another OCS Param or OCS MultiParam node or any other OCS node that takes OCS_PARAMS as an input.", 373 | ) 374 | 375 | FUNCTION = "go" 376 | 377 | PARAM_COUNT = 5 378 | 379 | @classmethod 380 | def INPUT_TYPES(cls): 381 | param_keys = ( 382 | ("", *ParamNode.OCS_PARAM_TYPES.keys()), 383 | { 384 | "tooltip": "Used to set the type of custom parameter.", 385 | }, 386 | ) 387 | return { 388 | "required": { 389 | f"key_{idx}": param_keys for idx in range(1, cls.PARAM_COUNT + 1) 390 | }, 391 | "optional": { 392 | "params_opt": ( 393 | "OCS_PARAMS", 394 | { 395 | "tooltip": "You may optionally connect the output from other OCS MultiParam or OCS Param nodes here to set multiple parameters.", 396 | }, 397 | ), 398 | "parameters": ( 399 | "STRING", 400 | { 401 | "default": """\ 402 | # Additional YAML or JSON parameters 403 | # Should be an object with key corresponding to the index of the input 404 | """, 405 | "multiline": True, 406 | "dynamicPrompts": False, 407 | "tooltip": "The text parameter block allows setting custom parameters using YAML (recommended) or JSON. Optional, may be left blank.", 408 | }, 409 | ), 410 | } 411 | | { 412 | f"value_opt_{idx}": ( 413 | ParamNode.WC, 414 | { 415 | "tooltip": "Connect the type of value expected by the corresponding key. 
Allows connecting output from any type of node HOWEVER if it is the wrong type expected by the corresponding key you will get an error when you run the workflow.", 416 | }, 417 | ) 418 | for idx in range(1, cls.PARAM_COUNT + 1) 419 | }, 420 | } 421 | 422 | def go(self, *, params_opt=None, parameters="", **kwargs): 423 | params = ParamGroup(items={}) if params_opt is None else params_opt.clone() 424 | if parameters: 425 | extra_params = yaml.safe_load(parameters) 426 | if extra_params is not None: 427 | if not isinstance(extra_params, dict): 428 | raise ValueError("Parameters must be a JSON or YAML object") 429 | else: 430 | extra_params = {} 431 | else: 432 | extra_params = {} 433 | for idx in range(1, self.PARAM_COUNT + 1): 434 | key, value = kwargs.get(f"key_{idx}"), kwargs.get(f"value_opt_{idx}") 435 | if not key or value is None: 436 | continue 437 | if not self.OCS_PARAM_TYPES[key](value): 438 | raise ValueError(f"CSamplerParamGroup: Bad value type for key {key}") 439 | extra = extra_params.get(str(idx)) 440 | key = self.get_renamed_key(key, extra) 441 | params[key] = value 442 | if extra is not None: 443 | params[f"{key}.params"] = extra 444 | 445 | return (params,) 446 | 447 | 448 | class SimpleRestartSchedule: 449 | RETURN_TYPES = ("SIGMAS",) 450 | CATEGORY = "sampling/custom_sampling/OCS" 451 | DESCRIPTION = "Overly Complicated Sampling simple Restart schedule node. Allows generating a Restart sampling schedule based on a text definition." 452 | OUTPUT_TYPES = ( 453 | "Can be connected to an OCS Sampler or RestartSampler node. Do not connect directly to a sampler that doesn't have built-in support for Restart schedules.", 454 | ) 455 | 456 | FUNCTION = "go" 457 | 458 | @classmethod 459 | def INPUT_TYPES(cls): 460 | return { 461 | "required": { 462 | "sigmas": ( 463 | "SIGMAS", 464 | { 465 | "tooltip": "Connect the output from another scheduler node (i.e. 
class SimpleRestartSchedule:
    """Node that expands a sigma schedule into a Restart schedule from a text definition."""

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/OCS"
    DESCRIPTION = "Overly Complicated Sampling simple Restart schedule node. Allows generating a Restart sampling schedule based on a text definition."
    # Sibling node classes in this file publish their output tooltips under
    # OUTPUT_TOOLTIPS; use the same attribute name here.
    OUTPUT_TOOLTIPS = (
        "Can be connected to an OCS Sampler or RestartSampler node. Do not connect directly to a sampler that doesn't have built-in support for Restart schedules.",
    )
    # Previous attribute name, kept so existing readers keep working.
    OUTPUT_TYPES = OUTPUT_TOOLTIPS

    FUNCTION = "go"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs: base sigmas, a start step and the schedule text."""
        return {
            "required": {
                "sigmas": (
                    "SIGMAS",
                    {
                        "tooltip": "Connect the output from another scheduler node (i.e. BasicScheduler) here.",
                    },
                ),
                "start_step": (
                    "INT",
                    {
                        "min": 0,
                        "default": 0,
                        "tooltip": "Step the restart schedule definition starts applying. Zero-based.",
                    },
                ),
            },
            "optional": {
                "schedule": (
                    "STRING",
                    {
                        "default": """\
# YAML or JSON restart schedule
# Every 5 steps, jump back 3 steps
- [5, -3]
# Jump to schedule item 0
- 0
""",
                        "multiline": True,
                        "dynamicPrompts": False,
                        "tooltip": "Define a schedule here using YAML (recommended) or JSON.",
                    },
                ),
            },
        }

    def go(self, *, sigmas, start_step=0, schedule="[]"):
        """Parse ``schedule`` and return ``(Restart.simple_schedule(...),)``.

        Blank text, or text that parses to nothing (comments only), is
        treated as an empty schedule.

        Raises:
            ValueError: if the text parses to something other than a list.
        """
        parsed_schedule = yaml.safe_load(schedule) if schedule else None
        if parsed_schedule is None:
            parsed_schedule = []
        elif not isinstance(parsed_schedule, (list, tuple)):
            raise ValueError("Schedule must be a JSON or YAML list")
        return (Restart.simple_schedule(sigmas, start_step, parsed_schedule),)
class ModelSetMaxSigmaNode:
    """Node that patches a model's sampling object with rescaled sigmas."""

    RETURN_TYPES = ("MODEL",)
    CATEGORY = "hacks"
    DESCRIPTION = "Allows forcing a model's maximum and minimum sigmas to a specified value. You generally do NOT want to connect this to a sampler node. Connect it to a scheduler node (i.e. BasicScheduler) instead."
    OUTPUT_TOOLTIPS = (
        "Patched model. Can be connected to a scheduler node (i.e. BasicScheduler). Generally NOT recommended to connect to an actual sampler.",
    )

    FUNCTION = "go"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs: the model, the rescale mode and the two sigma limits."""
        return {
            "required": {
                "model": (
                    "MODEL",
                    {
                        "tooltip": "Model to patch with the min/max sigmas.",
                    },
                ),
                "mode": (
                    ("recalculate", "simple_multiply"),
                    {
                        "tooltip": "Mode used for setting sigmas in the patched model. Recalculate should generally be more accurate.",
                    },
                ),
                "sigma_max": (
                    "FLOAT",
                    {
                        "default": -1.0,
                        "min": -10000.0,
                        "max": 10000.0,
                        "step": 0.01,
                        "round": False,
                        # Wording follows what go() does: negative values are
                        # multipliers, positive values are absolute sigmas.
                        "tooltip": "You can set the maximum sigma here. If you use a negative value, its absolute value will be interpreted as a percentage of the model's original max sigma (where -1.0 signifies 100%). If you use a positive value it will be interpreted as the absolute value for the max sigma. Schedules generated with the patched model should start from sigma_max (or close to it).",
                    },
                ),
                "fake_sigma_min": (
                    "FLOAT",
                    {
                        "default": 0.0,
                        "min": 0.0,
                        "max": 1000.0,
                        "step": 0.01,
                        "round": False,
                        "tooltip": "You can set the minimum sigma here. Disabled if set to 0. The value is interpreted as the absolute value for the min sigma. Schedules generated with the patched model should end with [sigma_min, 0]. NOTE: May not work with some schedulers. I recommend leaving this at 0 unless you know you need it (and even then it may not work).",
                    },
                ),
            }
        }

    def go(self, model, mode="recalculate", sigma_max=-1.0, fake_sigma_min=0.0):
        """Return ``(patched_model,)`` with rescaled sigmas.

        Args:
            model: object providing ``get_model_object``, ``clone`` and
                ``add_object_patch``.
            mode: ``"simple_multiply"`` scales the sigmas directly;
                ``"recalculate"`` scales their squares and takes the root,
                and is only accepted for linear beta schedules.
            sigma_max: negative -> ``abs(sigma_max)`` is the multiplier;
                positive -> target max sigma (multiplier is
                ``sigma_max / original max``). Zero is rejected.
            fake_sigma_min: when non-zero, the patched object reports this
                value as ``sigma_min``.

        Raises:
            ValueError: on a zero ``sigma_max``, an unknown ``mode`` or a
                resulting min sigma that is not below the max.
            NotImplementedError: ``"recalculate"`` with a non-linear beta
                schedule.
        """
        if sigma_max == 0:
            raise ValueError("ModelSetMaxSigma: Invalid sigma_max value")
        if mode not in ("recalculate", "simple_multiply"):
            raise ValueError("ModelSetMaxSigma: Invalid mode value")
        orig_ms = model.get_model_object("model_sampling")
        model = model.clone()
        orig_max_sigma, orig_min_sigma = (
            orig_ms.sigma_max.item(),
            orig_ms.sigma_min.item(),
        )
        max_multiplier = abs(sigma_max) if sigma_max < 0 else sigma_max / orig_max_sigma
        # Nothing to patch only when neither limit changes; a requested
        # fake_sigma_min must still be applied when the multiplier is 1.
        if max_multiplier == 1 and fake_sigma_min == 0:
            return (model,)
        mcfg = model.get_model_object("model_config")
        orig_sigmas = orig_ms.sigmas
        fake_min_tensor = orig_sigmas.new_full((1,), fake_sigma_min)

        class NewModelSampling(orig_ms.__class__):
            # Only shadow sigma_min when a fake minimum was requested.
            if fake_sigma_min != 0:

                @property
                def sigma_min(self):
                    return fake_min_tensor

        ms = NewModelSampling(mcfg)
        if mode == "simple_multiply":
            ms.set_sigmas(orig_sigmas * max_multiplier)
        else:
            ss = getattr(mcfg, "sampling_setting", None) or {}
            if ss.get("beta_schedule", "linear") != "linear":
                raise NotImplementedError(
                    "ModelSetMaxSigma: Can only handle linear beta schedules in recalculate mode"
                )
            ms.set_sigmas((orig_sigmas**2 * max_multiplier**2) ** 0.5)
        new_max_sigma, new_min_sigma = ms.sigma_max.item(), ms.sigma_min.item()
        if new_min_sigma >= new_max_sigma:
            raise ValueError(
                "ModelSetMaxSigma: Invalid fake_sigma_min value, result max <= min"
            )
        model.add_object_patch("model_sampling", ms)
        print(
            f"ModelSetMaxSigma: Set model sigmas({mode}): old_max={orig_max_sigma:.04}, old_min={orig_min_sigma:.03}, new_max={new_max_sigma:.04}, new_min={new_min_sigma:.03}"
        )
        return (model,)
class ImmiscibleNoise(Filter):
    """Filter that reorders generated noise samples against a reference latent."""

    name = "immiscible"
    uses_ref = True
    default_options = Filter.default_options | {
        "size": 0,
        "batching": "channel",
        "maximize": False,
    }

    def __call__(self, noise_sampler, x_ref, *, refs=None):
        """Draw noise from ``noise_sampler`` and run it through the filter.

        When the filter does not apply, a single plain sample is returned.
        Otherwise ``size`` samples are concatenated (one sample if ``size``
        is not positive) and passed to ``apply`` with ``x_ref`` as reference.
        """
        if not self.check_applies(refs):
            return noise_sampler()
        if self.size > 0:
            candidates = torch.cat(tuple(noise_sampler() for _ in range(self.size)))
        else:
            candidates = noise_sampler()
        return self.apply(
            candidates,
            default_ref=x_ref,
            refs=refs,
            output_shape=x_ref.shape,
        )

    def filter(self, latent, ref_latent, *, refs, output_shape):
        """Return ``latent`` reordered against ``ref_latent`` (no-op when ``size`` is 0)."""
        if self.size == 0:
            return latent
        flat_latent = self.batch(latent)
        flat_ref = self.batch(ref_latent)
        return self.unbatch(self.immiscible(flat_latent, flat_ref), output_shape)

    def batch(self, latent):
        """Flatten leading dimensions according to the ``batching`` option."""
        mode = self.batching
        if mode == "batch":
            return latent
        if latent.ndim != 4:
            raise ValueError("Both latent and reference must be four-dimensional")
        d0, d1, d2, d3 = latent.shape
        if mode == "channel":
            return latent.view(d0 * d1, d2, d3)
        if mode == "row":
            return latent.view(d0 * d1 * d2, d3)
        if mode == "column":
            # Swap the last two axes first so each column becomes one item.
            return latent.permute(0, 1, 3, 2).reshape(d0 * d1 * d3, d2)
        raise ValueError("Bad Immmiscible noise batching type")

    def unbatch(self, latent, sz):
        """Undo ``batch``: view ``latent`` back as shape ``sz``."""
        if self.batching != "column":
            return latent.view(*sz)
        swapped = latent.view(*sz[:2], sz[3], sz[2])
        return swapped.permute(0, 1, 3, 2)

    # Based on implementation from https://github.com/kohya-ss/sd-scripts/pull/1395
    # Idea from https://github.com/Clybius
    def immiscible(self, latent, ref_latent):
        """Pick, per reference item, a noise item via linear sum assignment.

        "Immiscible Diffusion: Accelerating Diffusion Training with Noise
        Assignment" (2024) Li et al. arxiv.org/abs/2406.12303

        The cost matrix is the mean squared difference between every
        reference/noise pair (computed in half precision). If the solver
        rejects the matrix, the first ``len(ref_latent)`` noise items are
        returned unchanged.
        """
        noise_count = latent.shape[0]
        ref_count = ref_latent.shape[0]
        ref_grid = ref_latent.half().unsqueeze(1).expand(
            -1, noise_count, *ref_latent.shape[1:]
        )
        noise_grid = latent.half().unsqueeze(0).expand(ref_count, *latent.shape)
        cost = (ref_grid - noise_grid) ** 2
        cost = cost.mean(list(range(2, cost.dim()))).cpu()
        try:
            assignment = scipy.optimize.linear_sum_assignment(
                cost, maximize=self.maximize
            )
        except ValueError:
            return latent[:ref_count]
        return latent[assignment[1]]
class NoiseSamplerCache:
    """Builds noise samplers for a latent and caches them per (noise object, size)."""

    def __init__(
        self,
        x,
        seed,
        min_sigma,
        max_sigma,
        *,
        normalize_noise=True,
        cpu_noise=True,
        batch_size=1,
        caching=True,
        cache_reset_interval=9999,
        set_seed=False,
        scale=1.0,
        normalize_dims=(-3, -2, -1),
        immiscible=None,
        filter=None,
        **_unused,
    ):
        """Store the options and prepare the reference tensor.

        ``batch_size`` and ``cache_reset_interval`` are clamped to at least 1.
        ``immiscible`` holds keyword options for ``ImmiscibleNoise`` and
        ``filter`` is a definition passed to ``make_filter`` (or ``None``).
        With ``set_seed`` true, the ``random`` and ``torch`` global seeds are
        set to ``seed``. Unknown keyword options are ignored.
        """
        self.x = x
        self.mega_x = None
        self.seed = seed
        self.seed_offset = 0
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma
        self.cache = {}
        self.batch_size = max(1, batch_size)
        self.normalize_noise = normalize_noise
        self.cpu_noise = cpu_noise
        self.caching = caching
        self.cache_reset_interval = max(1, cache_reset_interval)
        self.scale = float(scale)
        self.normalize_dims = tuple(int(v) for v in normalize_dims)
        self.immiscible = ImmiscibleNoise(**fallback(immiscible, {}))
        if filter is None:
            self.filter = None
        else:
            self.filter = make_filter(filter)
        self.update_x(x)
        if set_seed:
            random.seed(seed)
            torch.manual_seed(seed)

    def reset_cache(self):
        """Drop every cached sampler and run a garbage collection pass."""
        self.cache = {}
        gc.collect()

    def scale_noise(self, noise, factor=1.0, normalized=None, normalize_dims=None):
        """Call ``scale_noise`` with this cache's normalisation settings as defaults."""
        normalized = self.normalize_noise if normalized is None else normalized
        normalize_dims = (
            self.normalize_dims if normalize_dims is None else normalize_dims
        )
        return scale_noise(
            noise, factor, normalized=normalized, normalize_dims=normalize_dims
        )

    def update_x(self, x):
        """Adopt ``x`` as the current latent.

        If the shape is unchanged and the batched reference already exists,
        only ``self.x`` is replaced. Otherwise the cache is reset and
        ``mega_x`` is rebuilt as ``batch_size`` copies of ``x`` stacked on
        the first dimension.
        """
        if self.x.shape == x.shape and self.mega_x is not None:
            self.x = x
            return
        self.x = x
        self.mega_x = None
        self.reset_cache()
        if self.batch_size == 1:
            self.mega_x = x
            return
        # Samplers slice at most x.shape[0] * batch_size rows from mega_x, so
        # batch_size repetitions are enough; repeating x.shape[0] * batch_size
        # times would allocate x.shape[0] times more than is ever read.
        self.mega_x = x.repeat(self.batch_size, *((1,) * (x.dim() - 1)))

    def set_cache(self, key, noise_sampler):
        """Remember ``noise_sampler`` under ``key`` unless caching is disabled."""
        if not self.caching:
            return
        self.cache[key] = noise_sampler

    def make_caching_noise_sampler(
        self,
        nsobj,
        size,
        sigma,
        sigma_next,
        immiscible=None,
    ):
        """Return a noise sampler for ``nsobj`` producing one latent-shaped sample per call.

        Args:
            nsobj: object with ``make_noise_sampler``, or ``None`` for plain
                ``torch.randn_like`` noise.
            size: samples generated per underlying call (capped at
                ``batch_size``); they are then handed out one at a time.
            sigma, sigma_next: defaults used when the returned sampler is
                called with ``None`` sigmas.
            immiscible: ``None`` to use this cache's ``ImmiscibleNoise``,
                ``False`` to disable it, or a replacement callable.

        The cache key is ``(nsobj, size)``; a cached sampler is returned
        as-is when caching is enabled.
        """
        size = min(size, self.batch_size)
        cache_key = (nsobj, size)
        if self.caching:
            noise_sampler = self.cache.get(cache_key)
            if noise_sampler:
                return noise_sampler
        # Every newly built sampler gets its own seed.
        curr_seed = self.seed + self.seed_offset
        self.seed_offset += 1
        curr_x = self.mega_x[: self.x.shape[0] * size, ...]
        if nsobj is None:

            def ns(_s, _sn, *_unused, **_unusedkwargs):
                return torch.randn_like(curr_x)

        else:
            ns = nsobj.make_noise_sampler(
                curr_x,
                self.min_sigma,
                self.max_sigma,
                seed=curr_seed,
                normalized=False,
                cpu=self.cpu_noise,
            )

        orig_h, orig_w = self.x.shape[-2:]
        remain = 0
        noise = None
        if immiscible is None:
            immiscible = self.immiscible

        def noise_sampler_(
            curr_sigma,
            curr_sigma_next,
            *_unused,
            out_hw=(orig_h, orig_w),
            **_unusedkwargs,
        ):
            nonlocal remain, noise
            if out_hw != (orig_h, orig_w):
                raise NotImplementedError(
                    f"Noise size mismatch: {out_hw} vs {(orig_h, orig_w)}"
                )
            if remain < 1:
                # Buffer exhausted: generate `size` samples in one call.
                curr_sigma = fallback(curr_sigma, sigma)
                curr_sigma_next = fallback(curr_sigma_next, sigma_next)
                noise = self.scale_noise(ns(curr_sigma, curr_sigma_next)).view(
                    size,
                    *self.x.shape,
                )
                remain = size
            result = noise[-remain]
            remain -= 1
            return result

        def noise_sampler(*args, x_ref=None, refs=None, **kwargs):
            if immiscible is False:
                noise = noise_sampler_(*args, **kwargs)
            else:
                noise = immiscible(
                    lambda args=args, kwargs=kwargs: noise_sampler_(*args, **kwargs),
                    fallback(x_ref, self.x),
                    refs=refs,
                )
            return (
                self.filter.apply(noise, refs=refs)
                if self.filter is not None
                else noise
            )

        self.set_cache(cache_key, noise_sampler)
        return noise_sampler
10 | 11 | def _gamma( 12 | n: int, 13 | ) -> int: 14 | """ 15 | https://en.wikipedia.org/wiki/Gamma_function 16 | for every positive integer n, 17 | Γ(n) = (n-1)! 18 | """ 19 | return math.factorial(n - 1) 20 | 21 | 22 | def _incomplete_gamma(s: int, x: float, gamma_s: Optional[int] = None) -> float: 23 | """ 24 | https://en.wikipedia.org/wiki/Incomplete_gamma_function#Special_values 25 | if s is a positive integer, 26 | Γ(s, x) = (s-1)!*∑{k=0..s-1}(x^k/k!) 27 | """ 28 | if gamma_s is None: 29 | gamma_s = _gamma(s) 30 | 31 | sum_: float = 0 32 | # {k=0..s-1} inclusive 33 | for k in range(s): 34 | numerator: float = x**k 35 | denom: int = math.factorial(k) 36 | quotient: float = numerator / denom 37 | sum_ += quotient 38 | incomplete_gamma_: float = sum_ * math.exp(-x) * gamma_s 39 | return incomplete_gamma_ 40 | 41 | 42 | # by Katherine Crowson 43 | def _phi_1(neg_h: FloatTensor): 44 | return torch.nan_to_num(torch.expm1(neg_h) / neg_h, nan=1.0) 45 | 46 | 47 | # by Katherine Crowson 48 | def _phi_2(neg_h: FloatTensor): 49 | return torch.nan_to_num((torch.expm1(neg_h) - neg_h) / neg_h**2, nan=0.5) 50 | 51 | 52 | # by Katherine Crowson 53 | def _phi_3(neg_h: FloatTensor): 54 | return torch.nan_to_num( 55 | (torch.expm1(neg_h) - neg_h - neg_h**2 / 2) / neg_h**3, nan=1 / 6 56 | ) 57 | 58 | 59 | def _phi( 60 | neg_h: float, 61 | j: int, 62 | ): 63 | """ 64 | For j={1,2,3}: you could alternatively use Kat's phi_1, phi_2, phi_3 which perform fewer steps 65 | 66 | Lemma 1 67 | https://arxiv.org/abs/2308.02157 68 | ϕj(-h) = 1/h^j*∫{0..h}(e^(τ-h)*(τ^(j-1))/((j-1)!)dτ) 69 | 70 | https://www.wolframalpha.com/input?i=integrate+e%5E%28%CF%84-h%29*%28%CF%84%5E%28j-1%29%2F%28j-1%29%21%29d%CF%84 71 | = 1/h^j*[(e^(-h)*(-τ)^(-j)*τ(j))/((j-1)!)]{0..h} 72 | https://www.wolframalpha.com/input?i=integrate+e%5E%28%CF%84-h%29*%28%CF%84%5E%28j-1%29%2F%28j-1%29%21%29d%CF%84+between+0+and+h 73 | = 1/h^j*((e^(-h)*(-h)^(-j)*h^j*(Γ(j)-Γ(j,-h)))/(j-1)!) 
74 | = (e^(-h)*(-h)^(-j)*h^j*(Γ(j)-Γ(j,-h))/((j-1)!*h^j) 75 | = (e^(-h)*(-h)^(-j)*(Γ(j)-Γ(j,-h))/(j-1)! 76 | = (e^(-h)*(-h)^(-j)*(Γ(j)-Γ(j,-h))/Γ(j) 77 | = (e^(-h)*(-h)^(-j)*(1-Γ(j,-h)/Γ(j)) 78 | 79 | requires j>0 80 | """ 81 | assert j > 0 82 | gamma_: float = _gamma(j) 83 | incomp_gamma_: float = _incomplete_gamma(j, neg_h, gamma_s=gamma_) 84 | 85 | phi_: float = math.exp(neg_h) * neg_h**-j * (1 - incomp_gamma_ / gamma_) 86 | 87 | return phi_ 88 | 89 | 90 | class RESDECoeffsSecondOrder(NamedTuple): 91 | a2_1: float 92 | b1: float 93 | b2: float 94 | 95 | 96 | def _de_second_order( 97 | h: float, 98 | c2: float, 99 | simple_phi_calc=False, 100 | ) -> RESDECoeffsSecondOrder: 101 | """ 102 | Table 3 103 | https://arxiv.org/abs/2308.02157 104 | ϕi,j := ϕi,j(-h) = ϕi(-cj*h) 105 | a2_1 = c2ϕ1,2 106 | = c2ϕ1(-c2*h) 107 | b1 = ϕ1 - ϕ2/c2 108 | """ 109 | if simple_phi_calc: 110 | # Kat computed simpler expressions for phi for cases j={1,2,3} 111 | a2_1: float = c2 * _phi_1(-c2 * h) 112 | phi1: float = _phi_1(-h) 113 | phi2: float = _phi_2(-h) 114 | else: 115 | # I computed general solution instead. 116 | # they're close, but there are slight differences. not sure which would be more prone to numerical error. 
class Restart:
    """Helpers for Restart sampling: schedule building, segment splitting and noise scale."""

    def __init__(self, *, s_noise=1.0, custom_noise=None, immiscible=False):
        """Store the restart options.

        ``immiscible`` is either ``False`` (left as-is) or a mapping of
        keyword options used to build an ``ImmiscibleNoise``.
        """
        # Imported at call time rather than module level
        # (presumably to avoid an import cycle - TODO confirm).
        from .noise import ImmiscibleNoise

        self.s_noise = s_noise
        if immiscible is not False:
            immiscible = ImmiscibleNoise(**immiscible)
        self.immiscible = immiscible
        self.custom_noise = custom_noise

    def get_noise_sampler(self, nsc):
        """Ask the noise sampler cache ``nsc`` for a size-1 sampler using this object's noise settings."""
        return nsc.make_caching_noise_sampler(
            self.custom_noise,
            1,
            nsc.max_sigma,
            nsc.min_sigma,
            immiscible=self.immiscible,
        )

    @staticmethod
    def get_segment(sigmas: torch.Tensor) -> torch.Tensor:
        """Return the leading run of ``sigmas`` that ends before the first increase."""
        last_sigma = sigmas[0]
        for idx in range(1, len(sigmas)):
            sigma = sigmas[idx]
            if sigma > last_sigma:
                return sigmas[:idx]
            last_sigma = sigma
        return sigmas

    def split_sigmas(self, sigmas):
        """Yield ``(noise_scale, segment)`` for each non-increasing run in ``sigmas``.

        ``noise_scale`` is 0.0 for the first segment and for any segment
        that does not start above the previous segment's last sigma.
        Iteration stops once one or zero sigmas remain.
        """
        prev_seg = None
        while len(sigmas) > 1:
            seg = self.get_segment(sigmas)
            sigmas = sigmas[len(seg) :]
            if prev_seg is not None and seg[0] > prev_seg[-1]:
                noise_scale = self.get_noise_scale(prev_seg[-1], seg[0])
            else:
                noise_scale = 0.0
            prev_seg = seg
            yield (noise_scale, seg)

    def get_noise_scale(self, s_min, s_max):
        """Return ``sqrt(s_max^2 - s_min^2) * s_noise`` as a plain number."""
        result = (s_max**2 - s_min**2) ** 0.5
        if isinstance(result, torch.Tensor):
            result = result.item()
        return result * self.s_noise

    def __repr__(self):
        # An empty repr made instances invisible in logs and debuggers.
        return (
            f"<Restart: s_noise={self.s_noise}, "
            f"custom_noise={self.custom_noise}, immiscible={self.immiscible}>"
        )

    @classmethod
    def simple_schedule(cls, sigmas, start_step, schedule=(), max_iter=1000):
        """Expand 1-D ``sigmas`` into a Restart schedule.

        ``schedule`` items are either ``[interval, jump]`` pairs (emit
        ``interval + 1`` sigmas from the current position, then move the
        position by ``interval + jump``, plus one when ``jump`` is not
        negative) or a bare integer, which jumps to that schedule item
        (negative values count from the end).

        ``sigmas`` is returned unchanged when the schedule is empty or
        ``start_step`` is at or beyond the last step.

        Raises:
            ValueError: ``sigmas`` is not one-dimensional, or a jump index
                is out of range.
            RuntimeError: more than ``max_iter`` schedule items were
                processed (the schedule probably loops forever).
        """
        if sigmas.ndim != 1:
            raise ValueError("Bad number of dimensions for sigmas")
        siglen = len(sigmas) - 1
        if siglen <= start_step or not len(schedule):
            return sigmas
        siglist = sigmas.cpu().tolist()
        out = siglist[:start_step]
        sched_len = len(schedule)
        sched_idx = 0
        sig_idx = start_step
        iter_count = 0
        while 0 <= sched_idx < sched_len:
            iter_count += 1
            if iter_count > max_iter:
                raise RuntimeError("Hit max iteration count. Loop in schedule?")
            item = schedule[sched_idx]
            if not isinstance(item, (list, tuple)):
                # Bare number: jump to another schedule item.
                if item < 0:
                    item = sched_len + item
                if item < 0 or item >= sched_len:
                    raise ValueError("Schedule jump index out of range")
                sched_idx = item
                continue
            if sig_idx >= siglen or sig_idx < 0:
                break
            interval, jump = item
            chunk = siglist[sig_idx : sig_idx + interval + 1]
            out += chunk
            sig_idx += interval + jump
            if jump >= 0:
                sig_idx += 1
            sched_idx += 1
        # Schedule ran out before the sigmas did: append the remainder.
        if sig_idx < siglen and sig_idx >= 0:
            out += siglist[sig_idx:]
        # Make sure the result finishes on the final sigma.
        if out[-1] > siglist[-1]:
            out.append(siglist[-1])
        return torch.tensor(out).to(sigmas)
def find_merge_sampler(merge_samplers, ss) -> object | None:
    """Return the first entry of ``merge_samplers`` whose ``check_match`` accepts ``ss``.

    Expression handlers are only created (once, from ``ss.refs``) when a
    sampler with a ``when`` expression is encountered. Returns ``None`` when
    nothing matches.
    """
    handlers = None
    for merge_sampler in merge_samplers:
        if merge_sampler.when is not None and handlers is None:
            handlers = FILTER_HANDLERS.clone(constants=ss.refs)
            # handlers = FILTER_HANDLERS.clone_with_refs(ss.refs)
        if merge_sampler.check_match(handlers, ss=ss):
            return merge_sampler
    return None


def composable_sampler(
    model,
    x,
    sigmas,
    *,
    s_noise=1.0,
    eta=1.0,
    overly_complicated_options,
    extra_args=None,
    callback=None,
    disable=None,
    noise_sampler=None,
    **kwargs,
):
    """Run the sampling loop over ``sigmas`` and return the final latent.

    ``sigmas`` is split into segments by ``Restart.split_sigmas``; inside
    each segment every step is delegated to the first matching merge
    sampler built from ``overly_complicated_options["_groups"]``.

    Args:
        model, x, sigmas: model callable, starting latent and sigma schedule.
        s_noise, eta: used as given unless equal to 1.0, in which case the
            value from the options (default 1.0) is used instead.
        overly_complicated_options: mapping read for the keys ``restart``,
            ``restart_custom_noise``, ``model``, ``eta``, ``s_noise``,
            ``reta``, ``noise`` and ``_groups``.
        extra_args: passed to the model cache and sampler state; ``seed``
            is read from it (default 42).
        callback, disable: forwarded to the sampler state.
        noise_sampler: defaults to ``torch.randn_like(x)``.
        **kwargs: accepted but not used in this function.

    Raises:
        RuntimeError: when no sampler group matches a step.
    """
    # Shallow copy of the options mapping.
    copts = overly_complicated_options.copy()
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:

        def noise_sampler(_s, _sn):
            return torch.randn_like(x)

    restart_params = copts.get("restart", {})
    restart_custom_noise = copts.get("restart_custom_noise")
    # A string value names another option key holding the noise object.
    if isinstance(restart_custom_noise, str):
        restart_custom_noise = copts.get(f"restart_custom_noise_{restart_custom_noise}")
    restart = Restart(
        s_noise=restart_params.get("s_noise", 1.0),
        custom_noise=restart_custom_noise,
        immiscible=restart_params.get("immiscible", False),
    )

    ss = SamplerState(
        ModelCallCache(
            model,
            x,
            x.new_ones((x.shape[0],)),
            extra_args,
            **copts.get("model", {}),
        ),
        sigmas,
        0,
        extra_args,
        noise_sampler=noise_sampler,
        callback=callback,
        eta=eta if eta != 1.0 else copts.get("eta", 1.0),
        s_noise=s_noise if s_noise != 1.0 else copts.get("s_noise", 1.0),
        reta=copts.get("reta", 1.0),
        disable_status=disable,
    )
    groups = copts["_groups"]
    # One merge sampler instance per configured group, in group order.
    merge_samplers = tuple(
        MERGE_SUBSTEPS_CLASSES[g.merge_method](ss, g) for g in groups.items
    )
    nsc = NoiseSamplerCache(
        x,
        extra_args.get("seed", 42),
        sigmas[-1],
        sigmas[0],
        **copts.get("noise", {}),
    )
    ss.noise = nsc
    sigma_chunks = tuple(restart.split_sigmas(sigmas))
    # Each segment of N sigmas contributes N - 1 steps.
    step_count = sum(len(chunk) - 1 for _noise, chunk in sigma_chunks)
    ss.total_steps = step_count
    step = 0
    with trange(step_count, disable=ss.disable_status) as pbar:
        for noise_scale, chunk_sigmas in sigma_chunks:
            if step != 0 and noise_scale != 0:
                # Snapshot the refs under "pre_restart_" names before the
                # state is updated for the new segment.
                prev_refs = FilterRefs({
                    f"pre_restart_{k}": v for k, v in ss.refs.items()
                })
            ss.sigmas = chunk_sigmas
            ss.update(0, step=step, substep=0)
            if step != 0:
                # Entering a later segment: clear cached noise samplers,
                # history and merge sampler state.
                nsc.reset_cache()
                nsc.update_x(x)
                ss.hist.reset()
                for ms in merge_samplers:
                    ms.reset()
            nsc.min_sigma, nsc.max_sigma = chunk_sigmas[-1], chunk_sigmas[0]
            if step != 0 and noise_scale != 0:
                # Add restart noise to x in place, scaled by noise_scale.
                restart_ns = restart.get_noise_sampler(nsc)
                x += nsc.scale_noise(
                    restart_ns(nsc.min_sigma, nsc.max_sigma, refs=prev_refs | ss.refs),
                    noise_scale,
                )
                del restart_ns
                del prev_refs
            for idx in range(len(chunk_sigmas) - 1):
                if idx > 0:
                    # Index 0 was already set up before the loop.
                    ss.update(idx, step=step, substep=0)
                    nsc.update_x(x)
                # print(
                #     f"STEP {step + 1:>3}: {ss.sigma.item():.03} -> {ss.sigma_next.item():.03} || up={ss.sigma_up.item():.03}, down={ss.sigma_down.item():.03}"
                # )
                ss.model.reset_cache()
                nsc.update_x(x)
                merge_sampler = find_merge_sampler(merge_samplers, ss)
                if merge_sampler is None:
                    raise RuntimeError(f"No matching sampler group for step {step + 1}")
                pbar.set_description(
                    f"{merge_sampler.name}: {ss.sigma.item():.03} -> {ss.sigma_next.item():.03}"
                )
                x = merge_sampler(x)
                # The interval is counted per segment (idx), not per global step.
                if (idx + 1) % nsc.cache_reset_interval == 0:
                    nsc.reset_cache()
                step += 1
                pbar.update(1)
    return x
import utils 8 | 9 | from .filtering import make_filter, FilterRefs, FILTER_HANDLERS 10 | from .noise import ImmiscibleNoise 11 | from .restart import Restart 12 | from .step_samplers import STEP_SAMPLERS, StepSamplerContext 13 | from .substep_sampling import StepSamplerChain 14 | from .utils import check_time, fallback 15 | 16 | 17 | class MergeSubstepsSampler: 18 | name = "unknown" 19 | 20 | def __init__(self, ss, group): 21 | samplers = tuple( 22 | STEP_SAMPLERS[sitem["step_method"]](**sitem) for sitem in group.items 23 | ) 24 | options = group.options.copy() 25 | self.group = group 26 | self.time_mode = group.time_mode 27 | self.time_start = group.time_start 28 | self.time_end = group.time_end 29 | self.ss = ss 30 | self.samplers = samplers 31 | self.substeps = sum(sampler.substeps for sampler in samplers) 32 | when_expr = options.pop("when", None) 33 | self.when = expr.Expression(when_expr) if when_expr else None 34 | pre_filter = options.pop("pre_filter", None) 35 | post_filter = options.pop("post_filter", None) 36 | self.pre_filter = None if pre_filter is None else make_filter(pre_filter) 37 | self.post_filter = None if post_filter is None else make_filter(post_filter) 38 | self.preview_mode = options.pop("preview_mode", "denoised") 39 | self.require_uncond = any(sampler.require_uncond for sampler in samplers) 40 | self.cfg_scale_override = options.pop("cfg_scale_override", None) 41 | self.options = options 42 | 43 | def check_match(self, handlers: None | object, *, ss: None | object = None): 44 | ss = fallback(ss, self.ss) 45 | if not check_time( 46 | self.time_mode, 47 | self.time_start, 48 | self.time_end, 49 | ss.sigma, 50 | ss.step, 51 | ss.total_steps, 52 | ): 53 | return False 54 | if self.when is None: 55 | return True 56 | if handlers is None: 57 | raise ValueError("Group has when expression but handlers not passed") 58 | return operator.truth(self.when.eval(handlers)) 59 | 60 | def step_input(self, x, *, ss=None): 61 | if self.pre_filter is None: 
62 | return x 63 | ss = fallback(ss, self.ss) 64 | return self.pre_filter.apply(x, refs=fallback(ss, self.ss).refs) 65 | 66 | def step_output(self, x, *, orig_x=None, ss=None): 67 | if self.post_filter is None: 68 | return x 69 | ss = fallback(ss, self.ss) 70 | refs = ss.refs if orig_x is None else ss.refs | FilterRefs({"orig_x": orig_x}) 71 | return self.post_filter.apply(x, refs=refs) 72 | 73 | def __call__(self, x): 74 | orig_x = x 75 | x = self.step_input(x) 76 | x = self.step(x) 77 | return self.step_output(x, orig_x=orig_x) 78 | 79 | def step(self, x): 80 | raise NotImplementedError 81 | 82 | def substep(self, x, sampler): 83 | sg = sampler(x) 84 | yield from utils.step_generator(sg, get_next=lambda sr: sr.x) 85 | 86 | def simple_substep(self, x, sampler): 87 | for sr in self.substep(x, sampler): 88 | if not sr.final: 89 | raise RuntimeError("Unexpected non-final sampler result in substep!") 90 | return sr 91 | 92 | def merge_steps(self, x, result=None, *, noise=None, ss=None, denoised=True): 93 | ss = ss if ss is not None else self.ss 94 | result = fallback(result, x) 95 | if noise is not None: 96 | result = result + noise 97 | return result 98 | 99 | def step_max_noise_samples(self): 100 | return sum( 101 | (1 + sampler.self_noise) * sampler.substeps for sampler in self.samplers 102 | ) 103 | 104 | def reset(self): 105 | pass 106 | 107 | def callback(self, *, ss=None, mr=None, preview_mode=None): 108 | ss = fallback(ss, self.ss) 109 | preview_mode = fallback(preview_mode, self.preview_mode) 110 | return ss.callback(hi=mr, preview_mode=preview_mode) 111 | 112 | def call_model(self, x, ss=None, sigma=None, **kwargs): 113 | ss = fallback(ss, self.ss) 114 | sigma = fallback(sigma, ss.sigma) 115 | return ss.call_model( 116 | x, 117 | sigma, 118 | ss=ss, 119 | cfg_scale_override=self.cfg_scale_override, 120 | require_uncond=self.require_uncond, 121 | ) 122 | 123 | 124 | class SimpleSubstepsSampler(MergeSubstepsSampler): 125 | name = "simple" 126 | 127 | def 
class SimpleSubstepsSampler(MergeSubstepsSampler):
    """Group that performs each step with its first step sampler only."""

    name = "simple"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if len(self.samplers) == 0:
            raise ValueError("Missing sampler")

    def step_max_noise_samples(self):
        first = self.samplers[0]
        return 1 + first.self_noise

    def step(self, x):
        """Call the model once, run the first sampler and return its noised result."""
        state = self.ss
        first = self.samplers[0]
        state.hist.push(self.call_model(x))
        state.refs = FilterRefs.from_ss(state, have_current=True)
        self.callback()
        with StepSamplerContext(first, state) as active:
            result = self.simple_substep(x, active)
            # NOTE(review): indentation was lost in the flattened source; the
            # return is assumed to sit inside the sampler context -- confirm.
            return self.merge_steps(result.noise_x(ss=state))


class SupremeAvgMergeSubstepsSampler(MergeSubstepsSampler):
    """Group that averages the outputs of all substeps taken from the same input."""

    name = "supreme_avg"

    def step(self, x):
        """Run every substep on ``x`` and return the equally weighted average.

        Noise returned by the substeps is accumulated with the same weight,
        rescaled once at the end and added only if any substep asked for noise.
        """
        state = self.ss
        substeps = self.substeps
        weight = 1.0 / substeps
        averaged = torch.zeros_like(x)
        noise = averaged.clone()
        noise_total = 0.0
        done = 0
        pbar = tqdm.tqdm(total=self.substeps, initial=1, disable=state.disable_status)
        # Single model call shared by all substeps of this step.
        state.hist.push(self.call_model(x))
        state.refs = FilterRefs.from_ss(state, have_current=True)
        self.callback()
        for configured in self.samplers:
            with StepSamplerContext(configured, state) as active:
                for _ in range(active.substeps):
                    pbar.set_description(f"{active.name}: {done + 1}/{substeps}")
                    result = self.simple_substep(x, active)
                    averaged += weight * result.x
                    if result.noise_scale != 0 and state.sigma_next != 0:
                        noise_total += weight * result.noise_scale
                        noise += weight * result.get_noise(ss=state)
                    done += 1
                    state.substep = done
                    pbar.update(1)

        noise = state.noise.scale_noise(
            noise,
            noise_total * self.options.get("s_noise", 1.0),
            normalized=True,
        )
        extra_noise = None if noise_total == 0 else noise
        return self.merge_steps(
            x, averaged, noise=extra_noise, denoised=state.denoised
        )
"average" 185 | 186 | # def __init__(self, ss, sitems, *, avgmerge_stretch=0.4, **kwargs): 187 | # super().__init__(ss, sitems, **kwargs) 188 | # self.stretch = avgmerge_stretch 189 | 190 | # def step_max_noise_samples(self): 191 | # return sum( 192 | # 1 + (2 + sampler.self_noise) * sampler.substeps for sampler in self.samplers 193 | # ) 194 | 195 | # def step(self, x): 196 | # ss = orig_ss = self.ss 197 | # substeps = self.substeps 198 | # renoise_weight = 1.0 / substeps 199 | # z_avg = torch.zeros_like(x) 200 | # noise = torch.zeros_like(x) 201 | # stretch = (ss.sigma - ss.sigma_next) * self.stretch 202 | # sig_adj = ss.sigma + stretch 203 | # ss = self.ss.clone_edit(sigma=sig_adj) 204 | # orig_x = x 205 | # stretch_strength = stretch * ss.s_noise 206 | # if stretch_strength != 0: 207 | # noise_sampler = ss.noise.make_caching_noise_sampler( 208 | # self.options.get("custom_noise"), 1, orig_ss.sigma, ss.sigma_next 209 | # ) 210 | # x = x + ( 211 | # noise_sampler(orig_ss.sigma, ss.sigma_next).mul_(stretch * ss.s_noise) 212 | # ) 213 | # self.ss.denoised = ss.denoised = ss.model(x, sig_adj) 214 | # noise_total = 0.0 215 | # substep = 0 216 | # for idx, ssampler in enumerate(self.samplers): 217 | # print( 218 | # f" SUBSTEP {substep + 1} .. 
{substep + ssampler.substeps}: {ssampler.name}, stretch={stretch}" 219 | # ) 220 | # custom_noise = ssampler.options.get( 221 | # "custom_noise", self.options.get("custom_noise") 222 | # ) 223 | # noise_sampler = ss.noise.make_caching_noise_sampler( 224 | # custom_noise, 225 | # ssampler.substeps 226 | # + (0 if ss.sigma_next == 0 else ssampler.max_noise_samples), 227 | # ss.sigma, 228 | # ss.sigma_next, 229 | # ) 230 | # ssampler.noise_sampler = noise_sampler 231 | # for sidx in range(ssampler.substeps): 232 | # curr_x = orig_x + noise_sampler(sig_adj, ss.sigma_next).mul_(stretch) 233 | # sr = self.simple_substep(curr_x, ssampler, ss=ss) 234 | # z_avg += renoise_weight * sr.x 235 | # noise_strength = sr.noise_scale 236 | # if ss.sigma_next == 0 or noise_strength == 0: 237 | # continue 238 | # if noise_strength != 0 and ss.sigma_next != 0: 239 | # noise_curr = sr.get_noise() 240 | # noise_total += noise_strength.item() * renoise_weight 241 | # noise += noise_curr 242 | # substep += 1 243 | # substep += ssampler.substeps 244 | # return self.merge_steps( 245 | # x, 246 | # z_avg, 247 | # noise=None 248 | # if not noise_total 249 | # else ss.noise.scale_noise(noise, noise_total * ss.s_noise, normalized=True), 250 | # ss=ss, 251 | # ) 252 | 253 | 254 | # class SampleMergeSubstepsSampler(AverageMergeSubstepsSampler): 255 | # name = "sample" 256 | # cache_model = True 257 | 258 | # def __init__(self, ss, sitems, *, merge_sampler=None, **kwargs): 259 | # super().__init__(ss, sitems, **kwargs) 260 | # if merge_sampler is None: 261 | # merge_sampler = STEP_SAMPLERS["euler"](step_method="euler") 262 | # else: 263 | # msitem = merge_sampler.items[0] 264 | # merge_sampler = STEP_SAMPLERS[msitem["step_method"]](**msitem) 265 | # self.merge_sampler = merge_sampler 266 | # self.merge_ss = None 267 | 268 | # def step(self, x): 269 | # ss = self.ss 270 | # substeps = self.substeps 271 | # renoise_weight = 1.0 / substeps 272 | # z_avg = torch.zeros_like(x) 273 | # curr_x = x 274 | # 
ss.denoised = None 275 | # stretch = (ss.sigma - ss.sigma_next) * self.stretch 276 | # sig_adj = ss.sigma + stretch 277 | # ss = self.ss.clone_edit(sigma=sig_adj) 278 | # step = 0 279 | # for idx, ssampler in enumerate(self.samplers): 280 | # print( 281 | # f" SUBSTEP {step + 1} .. {step + ssampler.substeps}: {ssampler.name}, stretch={stretch}" 282 | # ) 283 | # custom_noise = ssampler.options.get( 284 | # "custom_noise", self.options.get("custom_noise") 285 | # ) 286 | # noise_sampler = ss.noise.make_caching_noise_sampler( 287 | # custom_noise, 288 | # ssampler.max_noise_samples + ssampler.substeps, 289 | # ss.sigma, 290 | # ss.sigma_next, 291 | # ) 292 | # ssampler.noise_sampler = noise_sampler 293 | # for sidx in range(ssampler.substeps): 294 | # if idx + sidx == 0 or not self.cache_model: 295 | # self.ss.denoised = ss.denoised = ss.model( 296 | # curr_x, 297 | # ss.sigma, 298 | # # + ss.noise_sampler(sig_adj.sigma, ss.sigma_next) * stretch * ss.s_noise, 299 | # # sig_adj, 300 | # ) 301 | # curr_x = x + noise_sampler(sig_adj, ss.sigma_next).mul_( 302 | # ssampler.s_noise * stretch 303 | # ) 304 | # sr = self.simple_substep(curr_x, ssampler, ss=ss) 305 | # z_avg += renoise_weight * sr.x 306 | # curr_x = sr.noise_x(sr.x) 307 | # step += ssampler.substeps 308 | # return self.merge_steps(curr_x, z_avg) 309 | 310 | # def merge_steps(self, x, result): 311 | # ss = self.ss 312 | # ss.dhist.push(ss.denoised) 313 | # ss.denoised = None 314 | # ss.model.reset_cache() 315 | # msampler = self.merge_sampler 316 | # if self.merge_ss is None: 317 | # merge_ss = self.merge_ss = self.ss.clone_edit( 318 | # denoised=result, 319 | # dhist=History(x, 3), 320 | # xhist=History(x, 2), 321 | # s_noise=msampler.s_noise, 322 | # eta=msampler.eta, 323 | # ) 324 | # else: 325 | # merge_ss = self.merge_ss 326 | # merge_ss.denoised = result 327 | # merge_ss.update(self.ss.idx, step=self.ss.step) 328 | # final = merge_ss.sigma_next == 0 329 | # noise_sampler = 
class DivideMergeSubstepsSampler(MergeSubstepsSampler):
    """Group that subdivides a step into ``substeps`` smaller sequential steps."""

    name = "divide"

    def __init__(self, ss, group, **kwargs):
        super().__init__(ss, group, **kwargs)
        self.schedule_multiplier = self.options.pop("schedule_multiplier", 4)

    def make_schedule(self, ss):
        """Build the sub-schedule starting at ``ss.idx``.

        Takes up to ``schedule_multiplier`` sigmas from the main schedule,
        truncates at the first sigma that increases, and interpolates
        ``substeps`` linear intervals between each remaining pair.
        """
        stop = min(len(self.ss.sigmas), ss.idx + self.schedule_multiplier)
        window = ss.sigmas[ss.idx : stop]
        cut = utils.find_first_unsorted(window)
        if cut is not None:
            window = window[:cut]
        pieces = []
        for idx in range(len(window) - 1):
            piece = torch.linspace(
                window[idx],
                window[idx + 1],
                steps=self.substeps + 1,
                device=window.device,
                dtype=window.dtype,
            )
            # Later pieces drop their first point: it repeats the previous end.
            pieces.append(piece[0:] if idx == 0 else piece[1:])
        return torch.cat(tuple(pieces))

    def step(self, x):
        """Run every substep in order on the sub-schedule and return the result."""
        parent = self.ss
        child = self.ss.clone_edit(idx=0, sigmas=self.make_schedule(parent))
        # Keep the position in the main schedule reachable from the sub-state.
        child.main_idx = parent.idx
        child.main_sigmas = parent.sigmas
        done = 0
        pbar = tqdm.tqdm(total=self.substeps, initial=0, disable=parent.disable_status)
        for configured in self.samplers:
            with StepSamplerContext(configured, child) as active:
                for _ in range(active.substeps):
                    child.update(done, substep=done)
                    pbar.set_description(
                        f"substep({active.name}): {child.sigma.item():.03} -> {child.sigma_next.item():.03}"
                    )
                    child.hist.push(self.call_model(x, ss=child))
                    child.refs = FilterRefs.from_ss(child, have_current=True)
                    if done == 0:
                        # Only the first substep reports a preview.
                        self.callback(ss=child)
                    result = self.simple_substep(x, active)
                    x = result.x
                    if result.noise_scale != 0 and child.sigma_next != 0:
                        x = result.noise_x(ss=child)
                    done += 1
                    pbar.update(1)
        pbar.update(0)
        return x
class OvershootMergeSubstepsSampler(MergeSubstepsSampler):
    """Group that samples past the step target, then adds noise to come back."""

    name = "overshoot"

    def __init__(
        self,
        ss,
        group,
        **kwargs,
    ):
        super().__init__(ss, group, **kwargs)
        self.overshoot_expand_steps = self.options.pop("overshoot_expand_steps", 1)
        restart_opts = self.options.pop("restart", {})
        restart_noise = self.options.get("restart_custom_noise")
        if isinstance(restart_noise, str):
            # A string selects another option key that holds the noise object.
            restart_noise = self.options.get(f"restart_custom_noise_{restart_noise}")
        self.restart = Restart(
            s_noise=restart_opts.get("s_noise", 1.0),
            custom_noise=restart_noise,
            immiscible=restart_opts.get("immiscible", False),
        )

    def make_schedule(self, ss):
        """Return ``(sigmas, start_index)`` for the substeps.

        With ``overshoot_expand_steps`` below 2 the main schedule is used as
        is; otherwise every interval is split into that many linear pieces.

        Raises:
            ValueError: if ``overshoot_expand_steps`` exceeds ``substeps``.
        """
        expand = self.overshoot_expand_steps
        if expand > self.substeps:
            raise ValueError(
                "overshoot_expand_steps > substeps: can't make it to the end of step 1"
            )
        if expand < 2:
            return ss.sigmas, ss.idx
        cpu_sigmas = ss.sigmas.cpu()
        endpoints = torch.stack((cpu_sigmas[:-1], cpu_sigmas[1:]), dim=1)
        # Each piece omits its end point; it is the start of the next piece.
        pieces = [
            torch.linspace(start, stop, expand + 1)[:-1] for start, stop in endpoints
        ]
        pieces.append(cpu_sigmas[-1].unsqueeze(0))
        expanded = torch.cat(tuple(pieces))
        return expanded.to(ss.sigmas), ss.idx * expand

    def step(self, x):
        """Run the substeps, then renoise up to the main ``sigma_next`` if below it."""
        parent = self.ss
        sub_sigmas, sub_idx = self.make_schedule(parent)
        child = parent.clone_edit(idx=sub_idx, sigmas=sub_sigmas)
        # Private history copy so substeps do not fill the main history.
        child.hist = child.hist.clone()
        done = 0
        pbar = tqdm.tqdm(total=self.substeps, initial=0, disable=parent.disable_status)
        last_usable_idx = len(child.sigmas) - 2
        reached_sigma = None
        for configured in self.samplers:
            with StepSamplerContext(configured, child) as active:
                for _ in range(active.substeps):
                    child.update(child.idx + done, substep=done)
                    pbar.set_description(
                        f"substep({active.name}): {child.sigma.item():.03} -> {child.sigma_next.item():.03}"
                    )
                    child.hist.push(self.call_model(x, ss=child))
                    child.refs = FilterRefs.from_ss(child, have_current=True)
                    if done == 0:
                        # The first model result also goes into the main history.
                        parent.hist.push(child.hcur)
                        self.callback(ss=child)
                    result = self.simple_substep(x, active)
                    x = result.x
                    if result.noise_scale != 0 and child.sigma_next != 0:
                        x = result.noise_x(ss=child)
                    done += 1
                    pbar.update(1)
                    reached_sigma = child.sigma_next.item()
                    if child.idx + done >= last_usable_idx:
                        break
            if child.idx >= last_usable_idx:
                break
        if reached_sigma is not None and reached_sigma < parent.sigma_next:
            restart_ns = self.restart.get_noise_sampler(parent.noise)
            x += parent.noise.scale_noise(
                restart_ns(reached_sigma, parent.sigma_next, refs=parent.refs),
                self.restart.get_noise_scale(reached_sigma, parent.sigma_next),
            )
        pbar.update(0)
        return x
class LookaheadMergeSubstepsSampler(MergeSubstepsSampler):
    """Group that samples ahead, then steps from the original input using the last prediction."""

    name = "lookahead"

    def __init__(self, ss, group, **kwargs):
        super().__init__(ss, group, **kwargs)
        cfg = self.options.pop("lookahead", {}).copy()
        self.lookahead_eta = cfg.pop("eta", 0.0)
        self.lookahead_s_noise = cfg.pop("s_noise", 1.0)
        self.lookahead_dt_factor = cfg.pop("dt_factor", 1.0)
        immiscible_cfg = cfg.get("immiscible", False)
        if immiscible_cfg is not False:
            self.immiscible = ImmiscibleNoise(**immiscible_cfg)
        else:
            self.immiscible = False

        noise_ref = self.options.get("custom_noise")
        if isinstance(noise_ref, str):
            # A string selects another option key that holds the noise object.
            noise_ref = self.options.get(f"custom_noise_{noise_ref}")
        self.custom_noise = noise_ref

    def step(self, x):
        """Run up to ``substeps`` steps ahead, then take one step from the input.

        The final step starts from a copy of the input and uses the
        ``denoised`` value of the last substep result.
        """
        orig_x = x.clone()
        parent = self.ss
        child = self.ss.clone_edit(idx=parent.idx, sigmas=parent.sigmas)
        done = 0
        last_idx = len(parent.sigmas) - 1
        # Never look further ahead than the schedule has steps left.
        planned = min(last_idx - parent.idx, self.substeps)
        pbar = tqdm.tqdm(total=planned, initial=0, disable=parent.disable_status)
        for configured in self.samplers:
            remaining = planned - done
            if remaining == 0:
                break
            with StepSamplerContext(configured, child) as active:
                for _ in range(min(remaining, active.substeps)):
                    child.update(parent.idx + done, substep=done)
                    pbar.set_description(
                        f"substep({active.name}): {child.sigma.item():.03} -> {child.sigma_next.item():.03}"
                    )
                    child.hist.push(self.call_model(x, ss=child))
                    child.refs = FilterRefs.from_ss(child, have_current=True)
                    if done == 0:
                        self.callback(ss=child)
                    sr = self.simple_substep(x, active)
                    x = sr.x
                    if sr.noise_scale != 0 and child.sigma_next != 0:
                        x = sr.noise_x(ss=child)
                    done += 1
                    pbar.update(1)
                    # NOTE(review): indentation was lost in the flattened
                    # source; this break is assumed to belong to the inner
                    # loop -- confirm against the original file.
                    if remaining == 1:
                        break
        pbar.update(0)
        sigma_down, sigma_up = parent.get_ancestral_step(
            eta=self.lookahead_eta, sigma=parent.sigma, sigma_next=parent.sigma_next
        )
        if sr.sigma_next == sigma_down:
            # The lookahead already landed on the target: keep its output.
            return x
        gap = parent.sigma - sigma_down
        dt = (torch.sqrt(1.0 + gap**2) * 0.05 + gap * 0.95) * self.lookahead_dt_factor
        denoised = sr.denoised
        d = (orig_x - denoised) / parent.sigma
        x = orig_x + d * -dt
        if sigma_down == 0 or sigma_up == 0:
            return x
        noise_sampler = parent.noise.make_caching_noise_sampler(
            self.custom_noise,
            1,
            parent.sigma,
            parent.sigma_next,
            immiscible=fallback(self.immiscible, parent.noise.immiscible),
        )
        # FIXME: This sigma, sigma_next is probably wrong.
        x += parent.noise.scale_noise(
            noise_sampler(parent.sigma, parent.sigma_next, refs=parent.refs),
            sigma_up * self.lookahead_s_noise,
        )
        return x
class DynamicMergeSubstepsSampler(MergeSubstepsSampler):
    """Group whose merge method and options come from expressions evaluated per step."""

    name = "dynamic"

    def __init__(self, ss, group, **kwargs):
        """Compile the ``dynamic`` option into ``(when, expression)`` pairs.

        Raises:
            ValueError: if the ``dynamic`` option is missing, empty, of the
                wrong type, or contains a malformed entry.
        """
        super().__init__(ss, group, **kwargs)
        spec = self.options.get("dynamic")
        if spec is None:
            raise ValueError(
                "Dynamic group type requires specifying dynamic block in text parameters"
            )
        if isinstance(spec, str):
            # Shorthand: a bare string is one unconditional expression.
            spec = ({"expression": spec},)
        elif not isinstance(spec, (tuple, list)):
            raise ValueError(
                "Bad type for dynamic block: must be string or list of objects"
            )
        elif len(spec) == 0:
            raise ValueError("Dynamic block as a list cannot be empty")
        self.dynamic = tuple(
            self._compile_entry(idx, item) for idx, item in enumerate(spec)
        )

    @staticmethod
    def _compile_entry(idx, item):
        """Validate one dynamic entry and return its compiled ``(when, expression)``."""
        if not isinstance(item, dict):
            raise ValueError(
                f"Bad item in dynamic block at index {idx}: must be a dict"
            )
        condition = item.get("when")
        if isinstance(condition, str):
            condition = expr.Expression(condition)
        elif condition is not None:
            raise ValueError(
                f"Unexpected type for when key in dynamic block at index {idx}, must be string or null/unset"
            )
        source = item.get("expression")
        if not isinstance(source, str):
            raise ValueError(
                f"Missing or incorrectly typed expression key for dynamic block at index {idx}: must be a string"
            )
        return condition, expr.Expression(source)

    def step(self, x):
        """Evaluate the entries in order and delegate the step to the first result.

        Raises:
            RuntimeError: if no entry produces a non-None result.
            TypeError: if the selected result is not a dict.
            ValueError: if the result names an unknown merge method.
        """
        handlers = FILTER_HANDLERS.clone(constants=self.ss.refs)
        group_params = None
        for condition, expression in self.dynamic:
            if condition is not None and not bool(condition.eval(handlers)):
                continue
            group_params = expression.eval(handlers)
            if group_params is not None:
                break
        if group_params is None:
            raise RuntimeError(
                "Dynamic group could not find matching group: all expressions failed to return a result"
            )
        if not isinstance(group_params, dict):
            raise TypeError(
                f"Dynamic group expression must evaluate to a dict, got type {type(group_params)}"
            )
        opts = {}
        if bool(group_params.get("dynamic_inherit")):
            for key in ("preview_mode",):
                opts[key] = getattr(self, key)
        # NOTE(review): indentation was lost in the flattened source; the
        # custom-noise options are assumed to be passed on in both branches.
        noise_prefixes = ("custom_noise", "restart_custom_noise")
        for key, value in self.options.items():
            if key.startswith(noise_prefixes[0]) or key.startswith(noise_prefixes[1]):
                opts[key] = value
        opts |= group_params
        merge_method = opts.pop("merge_method", "simple").strip()
        if merge_method == "default":
            merge_method = "simple"
        group_class = MERGE_SUBSTEPS_CLASSES.get(merge_method)
        if group_class is None:
            raise ValueError(f"Unknown merge method {merge_method} in dynamic group")
        delegate_group = StepSamplerChain(
            merge_method=merge_method, items=self.group.items, **opts
        )
        return group_class(self.ss, delegate_group).step(x)
class Items:
    """Small wrapper around a list stored in ``self.items``."""

    def __init__(self, items=None):
        if items is None:
            items = []
        # A supplied container is stored as given, not copied.
        self.items = items

    def clone(self):
        """Return a new instance of the same class holding a shallow copy of the items."""
        copied = self.items.copy()
        return self.__class__(items=copied)

    def append(self, item):
        """Append ``item`` and return it."""
        self.items.append(item)
        return item

    def __getitem__(self, key):
        return self.items[key]

    def __setitem__(self, key, value):
        self.items[key] = value

    def __len__(self):
        return len(self.items)

    def __iter__(self):
        return self.items.__iter__()


class CommonOptionsItems(Items):
    """``Items`` with ``s_noise``, ``eta`` and a dict of any other keyword options."""

    def __init__(self, *, s_noise=1.0, eta=1.0, items=None, **kwargs):
        super().__init__(items=items)
        self.s_noise = s_noise
        self.eta = eta
        # Every keyword not named above is kept as a free-form option.
        self.options = kwargs

    def clone(self):
        """Clone the items and copy ``options``, ``s_noise`` and ``eta`` across."""
        duplicate = super().clone()
        duplicate.s_noise = self.s_noise
        duplicate.eta = self.eta
        duplicate.options = self.options.copy()
        return duplicate


class StepSamplerChain(CommonOptionsItems):
    """A sampler group definition: merge method, active time window and items."""

    def __init__(
        self,
        *,
        merge_method="divide",
        time_mode="step",
        time_start=0,
        time_end=999,
        **kwargs,
    ):
        """
        Raises:
            ValueError: if ``time_mode`` is not "step", "step_pct" or "sigma".
        """
        super().__init__(**kwargs)
        self.merge_method = merge_method
        if time_mode not in ("step", "step_pct", "sigma"):
            raise ValueError("Bad time mode")
        self.time_mode = time_mode
        self.time_start = time_start
        self.time_end = time_end

    def clone(self):
        """Clone the chain together with its merge method and time window."""
        duplicate = super().clone()
        duplicate.merge_method = self.merge_method
        duplicate.time_mode = self.time_mode
        duplicate.time_start = self.time_start
        duplicate.time_end = self.time_end
        duplicate.options = self.options.copy()
        return duplicate


class ParamGroup(Items):
    """Plain item container used for parameter groups."""


class StepSamplerGroups(CommonOptionsItems):
    """Container of sampler groups sharing common options."""
class SamplerState:
    """Mutable state for one sampling run: schedule position, history and settings."""

    # Attributes that clone_edit() copies from the source object unless an
    # override is passed for them.
    CLONE_KEYS = (
        "cfg_scale_override",
        "model",
        "hist",
        "extra_args",
        "disable_status",
        "eta",
        "reta",
        "s_noise",
        "sigmas",
        "callback_",
        "noise_sampler",
        "noise",
        "idx",
        "total_steps",
        "step",
        "substep",
        "sigma",
        "sigma_next",
        "sigma_prev",
        "sigma_down",
        "sigma_up",
        "refs",
    )

    def __init__(
        self,
        model,
        sigmas,
        idx,
        extra_args,
        *,
        step=0,
        substep=0,
        noise_sampler,
        callback=None,
        denoised=None,
        noise=None,
        eta=1.0,
        reta=1.0,
        s_noise=1.0,
        disable_status=False,
        history_size=4,
        cfg_scale_override=None,
    ):
        """Create the state and position it at ``idx`` in ``sigmas``.

        ``denoised`` is accepted for interface compatibility but not stored:
        the ``denoised`` property always reads from the history.
        """
        self.model = model
        self.hist = History(max(1, history_size))
        self.extra_args = extra_args
        self.eta = eta
        self.reta = reta
        self.s_noise = s_noise
        self.sigmas = sigmas
        self.callback_ = callback
        self.noise_sampler = noise_sampler
        self.noise = noise
        self.disable_status = disable_status
        # Honour the caller's starting counters instead of forcing them to 0.
        self.step = step
        self.substep = substep
        self.total_steps = len(sigmas) - 1
        self.cfg_scale_override = cfg_scale_override
        self.update(idx)  # Sets idx, sigma_prev, sigma, sigma_down, refs

    @property
    def hcur(self):
        """Newest history entry."""
        return self.hist[-1]

    @property
    def hprev(self):
        """Second newest history entry."""
        return self.hist[-2]

    @property
    def denoised(self):
        """``denoised`` value of the newest history entry."""
        return self.hcur.denoised

    @property
    def dt(self):
        """Difference between the next and the current sigma."""
        return self.sigma_next - self.sigma

    @property
    def d(self):
        """``d`` value of the newest history entry."""
        return self.hcur.d

    def update(self, idx=None, step=None, substep=None):
        """Move to schedule index ``idx`` (default: stay) and refresh derived values.

        ``step`` and ``substep`` are only overwritten when given.
        """
        idx = self.idx if idx is None else idx
        self.idx = idx
        self.sigma_prev = None if idx < 1 else self.sigmas[idx - 1]
        self.sigma, self.sigma_next = self.sigmas[idx], self.sigmas[idx + 1]
        self.sigma_down, self.sigma_up = get_ancestral_step(
            self.sigma, self.sigma_next, eta=self.eta
        )
        if step is not None:
            self.step = step
        if substep is not None:
            self.substep = substep
        self.refs = FilterRefs.from_ss(self)

    def get_ancestral_step(
        self, eta=1.0, sigma=None, sigma_next=None, retry_increment=0
    ):
        """Return ``(sigma_down, sigma_up)`` for an ancestral step.

        Delegates to ``get_ancestral_step_rf`` when the model reports
        rectified flow.  If the computed values are not both positive and
        ``retry_increment`` is positive, ``eta`` is lowered by that amount and
        the computation retried.  Falls back to ``(sigma_next, 0)``.
        """
        if self.model.is_rectified_flow:
            return self.get_ancestral_step_rf(
                eta=eta,
                sigma=sigma,
                sigma_next=sigma_next,
                retry_increment=retry_increment,
            )
        sigma = fallback(sigma, self.sigma)
        sigma_next = fallback(sigma_next, self.sigma_next)
        if eta <= 0 or sigma_next <= 0:
            return sigma_next, sigma_next.new_zeros(1)
        while eta > 0:
            # Normalise plain numbers to tensors so callers get one type back.
            sd, su = (
                v if isinstance(v, torch.Tensor) else sigma.new_full((1,), v)
                for v in get_ancestral_step(
                    sigma, sigma_next, eta=eta if sigma_next != 0 else 0
                )
            )
            if sd > 0 and su > 0:
                return sd, su
            if retry_increment <= 0:
                break
            eta -= retry_increment
        return sigma_next, sigma_next.new_zeros(1)

    # Referenced from Comfy dpmpp_2s_ancestral_RF
    def get_ancestral_step_rf(
        self, eta=1.0, sigma=None, sigma_next=None, retry_increment=0
    ):
        """Rectified-flow variant of ``get_ancestral_step`` with the same retry rule."""
        sigma = fallback(sigma, self.sigma)
        sigma_next = fallback(sigma_next, self.sigma_next)
        if eta <= 0 or sigma_next <= 0:
            return sigma_next, sigma_next.new_zeros(1)
        while eta > 0:
            sigma_down = sigma_next * (1 + (sigma_next / sigma - 1) * eta)
            alpha_ip1, alpha_down = 1 - sigma_next, 1 - sigma_down
            sigma_up = (
                sigma_next**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2
            ) ** 0.5
            if sigma_down > 0 and sigma_up > 0:
                return sigma_down, sigma_up
            if retry_increment <= 0:
                break
            eta -= retry_increment
        return sigma_next, sigma_next.new_zeros(1)

    def clone_edit(self, **kwargs):
        """Return a copy with the ``CLONE_KEYS`` attributes, overridden by ``kwargs``.

        The copy is created without running ``__init__`` and ``update()`` is
        called on it afterwards.  Attribute values are shared, not copied.
        """
        obj = self.__class__.__new__(self.__class__)
        for k in self.CLONE_KEYS:
            setattr(obj, k, kwargs[k] if k in kwargs else getattr(self, k))
        obj.update()
        return obj

    def callback(self, hi=None, *, preview_mode="denoised"):
        """Invoke the user callback with a preview chosen by ``preview_mode``.

        Returns None without doing anything when no callback is set.  ``hi``
        defaults to the newest history entry.  Unknown modes, and "diff" when
        either the cond or uncond prediction is missing, use ``hi.denoised``.
        """
        if not self.callback_:
            return None
        hi = self.hcur if hi is None else hi
        if preview_mode == "cond":
            preview = fallback(hi.denoised_cond, hi.denoised)
        elif preview_mode == "uncond":
            preview = fallback(hi.denoised_uncond, hi.denoised)
        elif preview_mode == "raw":
            preview = hi.x
        elif (
            preview_mode == "diff"
            and hi.denoised_uncond is not None
            and hi.denoised_cond is not None
        ):
            preview = (
                hi.denoised_uncond * 0.25 + (hi.denoised_uncond - hi.denoised_cond) * 16
            )
        elif preview_mode == "noisy":
            preview = (hi.x - hi.denoised) * 0.1 + hi.denoised
        else:
            preview = hi.denoised
        return self.callback_({
            "x": hi.x,
            "i": self.step,
            "sigma": hi.sigma,
            "sigma_hat": hi.sigma,
            "denoised": preview,
        })

    def reset(self):
        """Clear the model-call history.

        ``denoised`` is a read-only property backed by the history, so it is
        not (and cannot be) assigned here.
        """
        self.hist.reset()

    def call_model(self, *args, **kwargs):
        """Call the model, defaulting ``cfg_scale_override`` to this state's value."""
        cfg_scale_override = kwargs.pop("cfg_scale_override", self.cfg_scale_override)
        return self.model(*args, cfg_scale_override=cfg_scale_override, **kwargs)
def scale_noise(
    noise,
    factor=1.0,
    *,
    normalized=True,
    normalize_dims=(-3, -2, -1),
):
    """Scale ``noise`` by ``factor``, optionally standardising it first.

    With ``normalized`` true and a non-empty tensor, the noise is divided by
    its standard deviation and has its mean subtracted (both taken over
    ``normalize_dims``) before the multiplication; a new tensor is returned.
    Otherwise the noise is only multiplied, and the input object itself is
    returned when ``factor`` equals 1.
    """
    if normalized and noise.numel() != 0:
        # The division allocates a new tensor, so the in-place ops that follow
        # never touch the caller's data.
        standardised = noise / noise.std(dim=normalize_dims, keepdim=True)
        centre = standardised.mean(dim=normalize_dims, keepdim=True)
        return standardised.sub_(centre).mul_(factor)
    if factor != 1:
        return noise * factor
    return noise
#     latent.normalize_to_scale(
#         n(noise).clamp_(-1, 1), -1, 1, dim=normalize_dims
#     ).mul_(factor)


def find_first_unsorted(tensor, desc=True):
    """Return the index of the first element that breaks the expected order.

    Every element is compared with its predecessor along the first
    dimension. With ``desc=True`` an element is out of order when it is
    greater than the one before it; with ``desc=False`` when it is smaller.
    Equal neighbours count as sorted.

    Returns ``None`` when ``tensor`` is zero-dimensional, when its first
    dimension is empty, or when nothing is out of order.

    NOTE(review): the flattened ``nonzero()`` result is only a plain position
    for 1-D input -- presumably that is what callers pass; confirm.
    """
    if not (len(tensor.shape) and tensor.shape[0]):
        return None
    out_of_order = torch.gt if desc else torch.lt
    hits = out_of_order(tensor[1:], tensor[:-1]).nonzero().flatten()
    if len(hits) == 0:
        return None
    # ``hits`` indexes into ``tensor[1:]``, hence the shift by one.
    return hits[0].item() + 1


def fallback(val, default, exclude=None):
    """Return ``val`` unless it is ``exclude`` (identity test), else ``default``."""
    return default if val is exclude else val


def step_generator(gen, *, get_next, initial=None):
    """Drive the generator ``gen`` with ``send()`` and yield each result.

    The first value sent is ``initial``; Python requires that to be ``None``
    for a generator that has not been started yet. After every result,
    ``get_next(result)`` is computed *before* the result is yielded and is
    sent into ``gen`` on the following iteration.

    Iteration ends quietly when ``gen`` is exhausted. Because the whole loop
    sits inside ``contextlib.suppress(StopIteration)``, a ``StopIteration``
    raised by ``get_next`` ends iteration the same way.
    """
    next_val = initial
    with contextlib.suppress(StopIteration):
        while True:
            result = gen.send(next_val)
            next_val = get_next(result)
            yield result


# From Gaeros. Thanks!
def extract_pred(x_before, x_after, sigma_before, sigma_after):
    """Recover the prediction implied by a step from ``x_before`` to ``x_after``.

    With ``alpha = sigma_after / sigma_before`` this solves
    ``x_after = alpha * x_before + (1 - alpha) * denoised`` for ``denoised``
    and returns ``(denoised, to_d(x_after, sigma_after, denoised))``.

    When ``sigma_after`` is 0 the result is ``x_after`` itself paired with a
    zero tensor of the same shape.

    NOTE(review): ``sigma_before == sigma_after`` (non-zero) makes
    ``1 - alpha`` zero, so the division is degenerate -- presumably callers
    never pass equal sigmas; confirm.
    """
    if sigma_after == 0:
        return x_after, torch.zeros_like(x_after)
    alpha = sigma_after / sigma_before
    denoised = (x_after - alpha * x_before) / (1 - alpha)
    return denoised, to_d(x_after, sigma_after, denoised)


def resolve_value(keys, obj):
    """Follow a chain of attribute names from ``obj`` and return the final value.

    ``keys`` is a non-empty sequence of attribute names; ``("a", "b")``
    resolves to ``obj.a.b``.

    Raises:
        ValueError: ``keys`` is empty.
        AttributeError: one of the names does not exist on the value reached
            so far.
    """
    if len(keys) == 0:
        raise ValueError("Cannot resolve empty key list")
    # Private sentinel so that attributes whose value is None (or any other
    # falsy value) are still treated as present.
    missing = object()
    result = obj
    for key in keys:
        # Every Python object supports attribute lookup, so there is nothing
        # to pre-check; a missing attribute shows up as the sentinel.
        result = getattr(result, key, missing)
        if result is missing:
            raise AttributeError(f"Key {key} from path {'.'.join(keys)} does not exist")
    return result


def check_time(time_mode, time_start, time_end, sigma, step, steps):
    """Report whether the current sampling position lies inside a time window.

    ``time_mode`` selects what is compared (bounds are inclusive):

    * ``"step"``: ``time_start <= step <= time_end``
    * ``"step_pct"``: ``time_start <= step / steps <= time_end``, where the
      fraction is taken as 0.0 when ``steps`` is 0
    * ``"sigma"``: ``time_start >= sigma >= time_end`` (the window runs from
      the larger value down to the smaller one)

    Raises:
        ValueError: ``time_mode`` is none of the above.
    """
    if time_mode == "step":
        return time_start <= step <= time_end
    if time_mode == "step_pct":
        # Only this mode needs the fraction, so compute it here.
        step_pct = step / steps if steps != 0 else 0.0
        return time_start <= step_pct <= time_end
    if time_mode == "sigma":
        return time_start >= sigma >= time_end
    raise ValueError("Bad time mode")