├── .github
└── workflows
│ └── publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── adv_control
├── control.py
├── control_ctrlora.py
├── control_lllite.py
├── control_plusplus.py
├── control_reference.py
├── control_sparsectrl.py
├── control_svd.py
├── dinklink.py
├── documentation.py
├── logger.py
├── nodes.py
├── nodes_ctrlora.py
├── nodes_deprecated.py
├── nodes_keyframes.py
├── nodes_loosecontrol.py
├── nodes_main.py
├── nodes_plusplus.py
├── nodes_reference.py
├── nodes_sparsectrl.py
├── nodes_weight.py
├── sampling.py
└── utils.py
├── pyproject.toml
└── web
└── js
├── autosize.js
└── documentation.js
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | - "pyproject.toml"
9 |
10 | permissions:
11 | issues: write
12 |
13 | jobs:
14 | publish-node:
15 | name: Publish Custom Node to registry
16 | runs-on: ubuntu-latest
17 | if: ${{ github.repository_owner == 'Kosinkadink' }}
18 | steps:
19 | - name: Check out code
20 | uses: actions/checkout@v4
21 | - name: Publish Custom Node
22 | uses: Comfy-Org/publish-node-action@v1
23 | with:
24 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here.
25 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-Advanced-ControlNet
2 | Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks. The ControlNet nodes here fully support sliding context sampling, like the one used in the [ComfyUI-AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) nodes. Currently supports ControlNets, T2IAdapters, ControlLoRAs, ControlLLLite, SparseCtrls, SVD-ControlNets, and Reference.
3 |
4 | Custom weights allow replication of the "My prompt is more important" feature of Auto1111's sd-webui ControlNet extension via Soft Weights, and the "ControlNet is more important" feature can be granularly controlled by changing the uncond_multiplier on the same Soft Weights.
5 |
6 | ControlNet preprocessors are available through [comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux) nodes.
7 |
8 | ## Features
9 | - Timestep and latent strength scheduling
10 | - Attention masks
11 | - Replicate ***"My prompt is more important"*** feature from sd-webui-controlnet extension via ***Soft Weights***, and allow softness to be tweaked via ***base_multiplier***
12 | - Replicate ***"ControlNet is more important"*** feature from sd-webui-controlnet extension via ***uncond_multiplier*** on ***Soft Weights***
13 | - uncond_multiplier=0.0 gives identical results of auto1111's feature, but values between 0.0 and 1.0 can be used without issue to granularly control the setting.
14 | - ControlNet, T2IAdapter, and ControlLoRA support for sliding context windows
15 | - ControlLLLite support
16 | - ControlNet++ support
17 | - CtrLoRA support
18 | - Relevant models linked on [CtrLoRA github page](https://github.com/xyfJASON/ctrlora)
19 | - SparseCtrl support
20 | - SVD-ControlNet support
21 | - Stable Video Diffusion ControlNets trained by **CiaraRowles**: [Depth](https://huggingface.co/CiaraRowles/temporal-controlnet-depth-svd-v1/tree/main/controlnet), [Lineart](https://huggingface.co/CiaraRowles/temporal-controlnet-lineart-svd-v1/tree/main/controlnet)
22 | - Reference support
23 |   - Supports ```reference_attn```, ```reference_adain```, and ```reference_adain+attn``` modes. ```style_fidelity``` and ```ref_weight``` are equivalent to style_fidelity and control_weight in Auto1111, respectively, and strength of the Apply ControlNet is the balance between ref-influenced result and no-ref result. There is also a Reference ControlNet (Finetune) node that allows adjusting the style_fidelity, weight, and strength of attn and adain separately.
24 |
25 | ## Table of Contents:
26 | - [Scheduling Explanation](#scheduling-explanation)
27 | - [Nodes](#nodes)
28 | - [Usage](#usage) (will fill this out soon)
29 |
30 |
31 | # Scheduling Explanation
32 |
33 | The two core concepts for scheduling are ***Timestep Keyframes*** and ***Latent Keyframes***.
34 |
35 | ***Timestep Keyframes*** hold the values that guide the settings for a controlnet, and begin to take effect based on their start_percent, which corresponds to the percentage of the sampling process. They can contain masks for the strengths of each latent, control_net_weights, and latent_keyframes (specific strengths for each latent), all optional.
36 |
37 | ***Latent Keyframes*** determine the strength of the controlnet for specific latents - all they contain is the batch_index of the latent, and the strength the controlnet should apply for that latent. As a concept, latent keyframes achieve the same effect as a uniform mask with the chosen strength value.
38 |
39 | 
40 |
41 | # Nodes
42 |
43 | The ControlNet nodes provided here are the ***Apply Advanced ControlNet*** and ***Load Advanced ControlNet Model*** (or diff) nodes. The vanilla ControlNet nodes are also compatible, and can be used almost interchangeably - the only difference is that **at least one of these nodes must be used** for Advanced versions of ControlNets to be used (important for sliding context sampling, like with AnimateDiff-Evolved).
44 |
45 | Key:
46 | - 🟩 - required inputs
47 | - 🟨 - optional inputs
48 | - 🟦 - start as widgets, can be converted to inputs
49 | - 🟥 - optional input/output, but not recommended to use unless needed
50 | - 🟪 - output
51 |
52 | ## Apply Advanced ControlNet
53 | 
54 |
55 | Same functionality as the vanilla Apply Advanced ControlNet (Advanced) node, except with Advanced ControlNet features added to it. Automatically converts any ControlNet from ControlNet loaders into Advanced versions.
56 |
57 | ### Inputs
58 | - 🟩***positive***: conditioning (positive).
59 | - 🟩***negative***: conditioning (negative).
60 | - 🟩***control_net***: loaded controlnet; will be converted to Advanced version automatically by this node, if it's a supported type.
61 | - 🟩***image***: images to guide controlnets - if the loaded controlnet requires it, they must be preprocessed images. If one image provided, will be used for all latents. If more images provided, will use each image separately for each latent. If not enough images to meet latent count, will repeat the images from the beginning to match vanilla ControlNet functionality.
62 | - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as image input, if you provide more than one mask, each can apply to a different latent.
63 | - 🟨***timestep_kf***: timestep keyframes to guide controlnet effect throughout sampling steps.
64 | - 🟨***latent_kf_override***: override for latent keyframes, useful if no other features from timestep keyframes is needed. *NOTE: this latent keyframe will be applied to ALL timesteps, regardless if there are other latent keyframes attached to connected timestep keyframes.*
65 | - 🟨***weights_override***: override for weights, useful if no other features from timestep keyframes is needed. *NOTE: this weight will be applied to ALL timesteps, regardless if there are other weights attached to connected timestep keyframes.*
66 | - 🟦***strength***: strength of controlnet; 1.0 is full strength, 0.0 is no effect at all.
67 | - 🟦***start_percent***: sampling step percentage at which controlnet should start to be applied - no matter what start_percent is set on timestep keyframes, they won't take effect until this start_percent is reached.
68 | - 🟦***stop_percent***: sampling step percentage at which controlnet should stop being applied - no matter what start_percent is set on timestep keyframes, they won't take effect once this end_percent is reached.
69 |
70 | ### Outputs
71 | - 🟪***positive***: conditioning (positive) with applied controlnets
72 | - 🟪***negative***: conditioning (negative) with applied controlnets
73 |
74 | ## Load Advanced ControlNet Model
75 | 
76 |
77 | Loads a ControlNet model and converts it into an Advanced version that supports all the features in this repo. When used with **Apply Advanced ControlNet** node, there is no reason to use the timestep_keyframe input on this node - use timestep_kf on the Apply node instead.
78 |
79 | ### Inputs
80 | - 🟥***timestep_keyframe***: optional and likely unnecessary input to have ControlNet use selected timestep_keyframes - should not be used unless you need to. Useful if this node is not attached to **Apply Advanced ControlNet** node, but still want to use Timestep Keyframe, or to use TK_SHORTCUT outputs from ControlWeights in the same scenario. Will be overridden by the timestep_kf input on **Apply Advanced ControlNet** node, if one is provided there.
81 | - 🟨***model***: model to plug into the diff version of the node. Some controlnets are designed to receive the model; if you don't know what this does, you probably don't want to use the diff version of the node.
82 |
83 | ### Outputs
84 | - 🟪***CONTROL_NET***: loaded Advanced ControlNet
85 |
86 | ## Timestep Keyframe
87 | 
88 |
89 | Scheduling node across timesteps (sampling steps) based on the set start_percent. Chaining Timestep Keyframes allows ControlNet scheduling across sampling steps (percentage-wise), through a timestep keyframe schedule.
90 |
91 | ### Inputs
92 | - 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
93 | - 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
94 | - 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_kf_override will be used during sampling instead of latent_keyframe.*
95 | - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one mask to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overridden by mask_optional on the Apply Advanced ControlNet node; will be used together.
96 | - 🟦***start_percent***: sampling step percentage at which this Timestep Keyframe qualifies to be used. Acts as the 'key' for the Timestep Keyframe in the timestep keyframe schedule.
97 | - 🟦***strength***: strength of the controlnet; multiplies the controlnet by this value, basically, applied alongside the strength on the Apply ControlNet node. If set to 0.0 will not have any effect during the duration of this Timestep Keyframe's effect, and will increase sampling speed by not doing any work.
98 | - 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
99 | - 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
100 | - 🟦***guarantee_steps***: when 1 or greater, even if a Timestep Keyframe's start_percent ahead of this one in the schedule is closer to current sampling percentage, this Timestep Keyframe will still be used for the specified amount of steps before moving on to the next selected Timestep Keyframe in the following step. Whether the Timestep Keyframe is used or not, its inputs will still be accounted for inherit_missing purposes.
101 |
102 | ### Outputs
103 | - 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
104 |
105 | ## Timestep Keyframe Interpolation
106 | 
107 |
108 | Allows to create Timestep Keyframe with interpolated strength values in a given percent range. (The first generated keyframe will have guarantee_steps=1, rest that follow will have guarantee_steps=0).
109 |
110 | ### Inputs
111 | - 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
112 | - 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
113 | - 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_kf_override will be used during sampling instead of latent_keyframe.*
114 | - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one mask to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overridden by mask_optional on the Apply Advanced ControlNet node; will be used together.
115 | - 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used.
116 | - 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used.
117 | - 🟦***strength_start***: strength of the Timestep Keyframe at start of range.
118 | - 🟦***strength_end***: strength of the Timestep Keyframe at end of range.
119 | - 🟦***interpolation***: the method of interpolation.
120 | - 🟦***intervals***: the amount of keyframes to generate in total - the first will have its start_percent equal to start_percent, the last will have its start_percent equal to end_percent.
121 | - 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
122 | - 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
123 | - 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes.
124 |
125 | ### Outputs
126 | - 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
127 |
128 | ## Timestep Keyframe From List
129 | 
130 |
131 | Allows to create Timestep Keyframe via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes. (The first generated keyframe will have guarantee_steps=1, rest that follow will have guarantee_steps=0).
132 |
133 | ### Inputs
134 | - 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
135 | - 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
136 | - 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_kf_override will be used during sampling instead of latent_keyframe.*
137 | - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one mask to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overridden by mask_optional on the Apply Advanced ControlNet node; will be used together.
138 | - 🟩***float_strengths***: a list of floats, that will correspond to the strength of each Timestep Keyframe; first will be assigned to start_percent, last will be assigned to end_percent, and the rest spread linearly between.
139 | - 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used.
140 | - 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used.
141 | - 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
142 | - 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
143 | - 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes.
144 |
145 | ### Outputs
146 | - 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
147 |
148 | ## Latent Keyframe
149 | 
150 |
151 | A singular Latent Keyframe, selects the strength for a specific batch_index. If batch_index is not present during sampling, will simply have no effect. Can be chained with any other Latent Keyframe-type node to create a latent keyframe schedule.
152 |
153 | ### Inputs
154 | - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If a Latent Keyframe contained in prev_latent_keyframes have the same batch_index as this Latent Keyframe, they will take priority over this node's value.*
155 | - 🟦***batch_index***: index of latent in batch to apply controlnet strength to. Acts as the 'key' for the Latent Keyframe in the latent keyframe schedule.
156 | - 🟦***strength***: strength of controlnet to apply to the corresponding latent.
157 |
158 | ### Outputs
159 | - 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
160 |
161 | ## Latent Keyframe Group
162 | 
163 |
164 | Allows to create Latent Keyframes via individual indices or python-style ranges.
165 |
166 | ### Inputs
167 | - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
168 | - 🟨***latent_optional***: the latents expected to be passed in for sampling; only required if you wish to use negative indices (will be automatically converted to real values).
169 | - 🟦***index_strengths***: string list of indices or python-style ranges of indices to assign strengths to. If latent_optional is passed in, can contain negative indices or ranges that contain negative numbers, python-style. The different indices must be comma separated. Individual latents can be specified by ```batch_index=strength```, like ```0=0.9```. Ranges can be specified by ```start_index_inclusive:end_index_exclusive=strength```, like ```0:8=strength```. Negative indices are possible when latents_optional has an input, with a string such as ```0,-4=0.25```.
170 | - 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
171 |
172 | ### Outputs
173 | - 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
174 |
175 | ## Latent Keyframe Interpolation
176 | 
177 |
178 | Allows to create Latent Keyframes with interpolated values in a range.
179 |
180 | ### Inputs
181 | - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
182 | - 🟦***batch_index_from***: starting batch_index of range, included.
183 | - 🟦***batch_index_to***: end batch_index of range, excluded (python-style range).
184 | - 🟦***strength_from***: starting strength of interpolation.
185 | - 🟦***strength_to***: end strength of interpolation.
186 | - 🟦***interpolation***: the method of interpolation.
187 | - 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
188 |
189 | ### Outputs
190 | - 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
191 |
192 | ## Latent Keyframe From List
193 | 
194 |
195 | Allows to create Latent Keyframes via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes.
196 |
197 | ### Inputs
198 | - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
199 | - 🟩***float_strengths***: a list of floats, that will correspond to the strength of each Latent Keyframe; the batch_index is the index of each float value in the list.
200 | - 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
201 |
202 | ### Outputs
203 | - 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
204 |
205 | # There are more nodes to document and show usage - will add this soon! TODO
206 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .adv_control.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS  # node registry consumed by ComfyUI
2 | from .adv_control import documentation  # injects rich descriptions into node classes
3 | from .adv_control.dinklink import init_dinklink  # cross-custom-node communication setup
4 | from .adv_control.sampling import prepare_dinklink_acn_wrapper  # registers ACN sampling wrapper over dinklink
5 |
6 | WEB_DIRECTORY = "./web"  # ComfyUI serves JS extensions (autosize, documentation) from here
7 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', "WEB_DIRECTORY"]
8 | documentation.format_descriptions(NODE_CLASS_MAPPINGS)  # must run after mappings import so all nodes get formatted docs
9 |
10 | init_dinklink()  # NOTE: order matters — establish dinklink state before registering the wrapper below
11 | prepare_dinklink_acn_wrapper()
12 |
--------------------------------------------------------------------------------
/adv_control/control_ctrlora.py:
--------------------------------------------------------------------------------
1 | # Core code adapted from CtrLoRA github repo:
2 | # https://github.com/xyfJASON/ctrlora
3 | import torch
4 | from torch import Tensor
5 |
6 | from comfy.cldm.cldm import ControlNet as ControlNetCLDM
7 | import comfy.model_detection
8 | import comfy.model_management
9 | import comfy.ops
10 | import comfy.utils
11 |
12 | from comfy.ldm.modules.diffusionmodules.util import (
13 | zero_module,
14 | timestep_embedding,
15 | )
16 |
17 | from .control import ControlNetAdvanced
18 | from .utils import TimestepKeyframeGroup
19 | from .logger import logger
20 |
21 |
class ControlNetCtrLoRA(ControlNetCLDM):
    """CtrLoRA variant of the cldm ControlNet.

    Unlike the stock model, the hint image is fed straight into the input
    blocks, so the dedicated hint-encoding stack is removed entirely.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The hint is consumed directly by forward(); the hint encoder is dead weight.
        del self.input_hint_block

    def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
        # Timestep embedding, computed in the latent input's dtype.
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        # Class-conditional models fold the label embedding into the timestep embedding.
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        # Run the hint (not x) through the encoder blocks, collecting one
        # zero-conv projection per block for the controlled model's skip connections.
        outputs = []
        h = hint.to(dtype=x.dtype)
        for block, zero_conv in zip(self.input_blocks, self.zero_convs):
            h = block(h, emb, context)
            outputs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        middle = [self.middle_block_out(h, emb, context)]

        return {"middle": middle, "output": outputs}
48 |
49 |
class CtrLoRAAdvanced(ControlNetAdvanced):
    """Advanced wrapper for CtrLoRA control models.

    Conditions on VAE-encoded hints: the image is rescaled from [-1, 1] to
    [0, 1] before encoding, and the latent is processed via the model's
    LatentFormat (captured in pre_run_advanced).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # CtrLoRA expects hint pixels in [0, 1] rather than [-1, 1].
        self.preprocess_image = lambda img: (img + 1) / 2.0
        # Hints must pass through the VAE; do not rescale by latent ratio when doing so.
        self.require_vae = True
        self.mult_by_ratio_when_vae = False

    def pre_run_advanced(self, model, percent_to_timestep_function):
        super().pre_run_advanced(model, percent_to_timestep_function)
        # Keep the LatentFormat so the latent cond hint can be process_in'd during sampling.
        self.latent_format = model.latent_format

    def cleanup_advanced(self):
        super().cleanup_advanced()
        # Drop the borrowed LatentFormat reference once sampling is done.
        if self.latent_format is not None:
            del self.latent_format
            self.latent_format = None

    def copy(self):
        # Recreate with the same construction arguments, then mirror state onto the clone.
        c = CtrLoRAAdvanced(self.control_model, self.timestep_keyframes,
                            global_average_pooling=self.global_average_pooling,
                            load_device=self.load_device,
                            manual_cast_dtype=self.manual_cast_dtype)
        c.control_model = self.control_model
        c.control_model_wrapped = self.control_model_wrapped
        self.copy_to(c)
        self.copy_to_advanced(c)
        return c
75 |
def load_ctrlora(base_path: str, lora_path: str,
                 base_data: dict[str, Tensor]=None, lora_data: dict[str, Tensor]=None,
                 timestep_keyframe: TimestepKeyframeGroup=None, model=None, model_options={}):
    """Load a CtrLoRA base controlnet from base_path and optionally apply a lora from lora_path.

    base_data/lora_data may be passed in to skip reading the files from disk.
    If model is provided, it is used to resolve 'difference' (diff) controlnet weights.
    Returns a CtrLoRAAdvanced instance.
    Raises Exception if the base file does not look like a CtrLoRA base model.
    """
    if base_data is None:
        base_data = comfy.utils.load_torch_file(base_path, safe_load=True)
    controlnet_data = base_data

    # first, check that base_data contains keys with lora_layer
    contains_lora_layers = False
    for key in base_data:
        if "lora_layer" in key:
            contains_lora_layers = True
    if not contains_lora_layers:
        raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; does not contain any lora_layer keys.")

    controlnet_config = None
    supported_inference_dtypes = None

    # determine whether keys are prefixed with 'control_model.' (pth-style checkpoint)
    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        key = pth_key
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        # NOTE: previously raised an empty Exception and fell through to a dead
        # load_t2i_adapter call that made this loader always return None
        raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; could not find '{key}' in state dict.")

    if controlnet_config is None:
        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
        supported_inference_dtypes = list(model_config.supported_inference_dtypes)
        controlnet_config = model_config.unet_config

    # pick unet dtype from model_options, falling back to what weights/device support
    unet_dtype = model_options.get("dtype", None)
    if unet_dtype is None:
        weight_dtype = comfy.utils.weight_dtype(controlnet_data)

        if supported_inference_dtypes is None:
            supported_inference_dtypes = [comfy.model_management.unet_dtype()]

        if weight_dtype is not None:
            supported_inference_dtypes.append(weight_dtype)

        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)

    load_device = comfy.model_management.get_torch_device()

    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    operations = model_options.get("custom_operations", None)
    if operations is None:
        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)

    controlnet_config["operations"] = operations
    controlnet_config["dtype"] = unet_dtype
    controlnet_config["device"] = comfy.model_management.unet_offload_device()
    controlnet_config.pop("out_channels")
    # CtrLoRA always takes a 3-channel (RGB-shaped) hint
    controlnet_config["hint_channels"] = 3
    control_model = ControlNetCtrLoRA(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            # diff controlnet: stored weights are deltas on top of the base model's weights
            if model is not None:
                comfy.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                for x in controlnet_data:
                    c_m = "control_model."
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            cd = controlnet_data[x]
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logger.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        # WeightsLoader provides the 'control_model.' attribute prefix expected by the keys
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logger.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logger.debug("unexpected controlnet keys: {}".format(unexpected))

    global_average_pooling = model_options.get("global_average_pooling", False)
    control = CtrLoRAAdvanced(control_model, timestep_keyframe, global_average_pooling=global_average_pooling,
                              load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    # load lora data onto the controlnet
    if lora_path is not None:
        load_lora_data(control, lora_path)

    return control
179 |
180 |
def load_lora_data(control: "CtrLoRAAdvanced", lora_path: str, loaded_data: dict[str, Tensor]=None, lora_strength=1.0):
    """Apply CtrLoRA lora weights from lora_path (or pre-loaded loaded_data) onto control as model patches."""
    if loaded_data is None:
        loaded_data = comfy.utils.load_torch_file(lora_path, safe_load=True)
    # verify this is actually a CtrLoRA lora file
    if not any("lora_layer" in key for key in loaded_data):
        raise Exception(f"File '{lora_path}' is not a valid CtrLoRA lora model; does not contain any lora_layer keys.")

    # split keys into plain 'set' weights and lora up/down weights, consuming loaded_data
    data_set: dict[str, Tensor] = {}
    data_lora: dict[str, Tensor] = {}
    for key in list(loaded_data.keys()):
        target = data_lora if 'lora_layer' in key else data_set
        target[key] = loaded_data.pop(key)
    # everything should have been consumed above
    if len(loaded_data) > 0:
        logger.warning("Not all keys from CtrlLoRA lora model's loaded data were parsed!")

    patches = {}
    # 'set' weights replace the corresponding model weights outright;
    # model key is obtained by stripping the "control_model." prefix
    for key, value in data_set.items():
        patches[key.replace("control_model.", "")] = ("set", (value,))
    # lora weights contribute up@down deltas
    for down_key in data_lora:
        # skip up keys; each down key is processed together with its up counterpart
        if ".up." in down_key:
            continue
        up_key = down_key.replace(".down.", ".up.")
        # strip "lora_layer.down." and "control_model." to get the model key
        model_key = down_key.replace("lora_layer.down.", "").replace("control_model.", "")

        weight_down = data_lora[down_key]
        weight_up = data_lora[up_key]
        # ComfyUI currently expects 6 elements in 'lora' type patches; pad with Nones for future-proofing
        patches[model_key] = ("lora", (weight_up, weight_down, None, None, None, None,
                                       None, None, None, None, None, None, None, None))

    # register all patches on the wrapped control model
    control.control_model_wrapped.add_patches(patches, strength_patch=lora_strength)
232 |
--------------------------------------------------------------------------------
/adv_control/control_lllite.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
2 | # basically, all the LLLite core code is from there, which I then combined with
3 | # Advanced-ControlNet features and QoL
4 | import math
5 | from typing import Union
6 | from torch import Tensor
7 | import torch
8 | import os
9 |
10 | import comfy.utils
11 | import comfy.ops
12 | import comfy.model_management
13 | from comfy.model_patcher import ModelPatcher
14 | from comfy.controlnet import ControlBase
15 |
16 | from .logger import logger
17 | from .utils import (AdvancedControlBase, TimestepKeyframeGroup, ControlWeights, broadcast_image_to_extend, extend_to_batch_size,
18 | prepare_mask_batch)
19 |
20 |
21 | # based on set_model_patch code in comfy/model_patcher.py
def set_model_patch(transformer_options, patch, name):
    """Register patch under transformer_options["patches"][name], skipping duplicates."""
    patches_dict = transformer_options.setdefault("patches", {})
    existing = patches_dict.get(name, [])
    # never register the same patch object twice
    if patch in existing:
        return
    patches_dict[name] = existing + [patch]
32 |
def set_model_attn1_patch(transformer_options, patch):
    # register patch for self-attention (attn1) blocks, deduplicated
    set_model_patch(transformer_options, patch, "attn1_patch")
35 |
def set_model_attn2_patch(transformer_options, patch):
    # register patch for cross-attention (attn2) blocks, deduplicated
    set_model_patch(transformer_options, patch, "attn2_patch")
38 |
39 |
def extra_options_to_module_prefix(extra_options):
    """Build the LLLite module name prefix for the attention block described by extra_options.

    extra_options example: {'transformer_index': 2, 'block_index': 8,
    'original_shape': [2, 4, 128, 128], 'block': ('input', 7), 'n_heads': 20, 'dim_head': 64}
    block values seen: ('input', 4/5/7/8), ('middle', 0), ('output', 0..5);
    block_index depends on the block.
    """
    block = extra_options["block"]
    block_index = extra_options["block_index"]
    block_type = block[0]
    if block_type == "middle":
        # the middle block has no per-block id in its module name
        return f"lllite_unet_middle_block_1_transformer_blocks_{block_index}"
    if block_type in ("input", "output"):
        return f"lllite_unet_{block_type}_blocks_{block[1]}_1_transformer_blocks_{block_index}"
    raise Exception(f"ControlLLLite: invalid block name '{block_type}'. Expected 'input', 'middle', or 'output'.")
61 |
62 |
class LLLitePatch:
    """Attention patch that adds LLLite module outputs onto the q/k/v projections."""
    ATTN1 = "attn1"
    ATTN2 = "attn2"

    def __init__(self, modules: dict[str, 'LLLiteModule'], patch_type: str, control: "Union[AdvancedControlBase, ControlBase]"=None):
        self.modules = modules
        self.control = control
        self.patch_type = patch_type

    def __call__(self, q, k, v, extra_options):
        # do nothing when outside the control's active timestep range;
        # compare cached floats rather than sigma tensors (tensor-vs-float compares are very slow)
        if self.control.timestep_range is not None:
            if self.control.t > self.control.timestep_range[0] or self.control.t < self.control.timestep_range[1]:
                return q, k, v

        module_pfx = extra_options_to_module_prefix(extra_options)
        # matching q/k last dims means self attention (attn1); otherwise cross attention (attn2)
        module_pfx += "_attn1" if q.shape[-1] == k.shape[-1] else "_attn2"

        def apply(suffix, tensor):
            # add the module's correction when a module exists for this projection
            module = self.modules.get(module_pfx + suffix)
            return tensor if module is None else tensor + module(tensor, self.control)

        return apply("_to_q", q), apply("_to_k", k), apply("_to_v", v)

    def to(self, device):
        # modules dict is not a registered nn container, so move each module by hand
        for name in self.modules.keys():
            self.modules[name] = self.modules[name].to(device)
        return self

    def set_control(self, control: "Union[AdvancedControlBase, ControlBase]") -> 'LLLitePatch':
        """Bind this patch to a control object; returns self for chaining."""
        self.control = control
        return self

    def clone_with_control(self, control: "AdvancedControlBase"):
        # share the same modules dict, but bind to a different control object
        return LLLitePatch(self.modules, self.patch_type, control)

    def cleanup(self):
        # drop cached cond embeddings on all modules
        for module in self.modules.values():
            module.cleanup()
120 |
121 |
122 | # TODO: use comfy.ops to support fp8 properly
class LLLiteModule(torch.nn.Module):
    """Single ControlLLLite conditioning module for one attention projection.

    Encodes the control image into a small conditioning embedding (cond_emb),
    then computes an additive correction up(mid(cat(cond_emb, down(x)))) for
    the attention input x. Uses 1x1 convs when is_conv2d, Linear layers
    otherwise (token layout b, h*w, c).
    """
    def __init__(
        self,
        name: str,
        is_conv2d: bool,
        in_dim: int,
        depth: int,
        cond_emb_dim: int,
        mlp_dim: int,
    ):
        super().__init__()
        self.name = name
        self.is_conv2d = is_conv2d
        self.is_first = False

        # conditioning1 downsamples the 3-channel cond image into cond_emb_dim channels;
        # total spatial downscale is 8x (depth 1) or 16x (depth 2 and 3)
        modules = []
        modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))  # to latent (from VAE) size*2
        if depth == 1:
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
        elif depth == 2:
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
        elif depth == 3:
            # kernel size 8 is too large, so set it to 4
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))

        self.conditioning1 = torch.nn.Sequential(*modules)

        # down/mid/up form the correction MLP; conv2d variant uses 1x1 convs instead of Linear
        if self.is_conv2d:
            self.down = torch.nn.Sequential(
                torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
                torch.nn.ReLU(inplace=True),
            )
            self.mid = torch.nn.Sequential(
                torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
                torch.nn.ReLU(inplace=True),
            )
            self.up = torch.nn.Sequential(
                torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
            )
        else:
            self.down = torch.nn.Sequential(
                torch.nn.Linear(in_dim, mlp_dim),
                torch.nn.ReLU(inplace=True),
            )
            self.mid = torch.nn.Sequential(
                torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
                torch.nn.ReLU(inplace=True),
            )
            self.up = torch.nn.Sequential(
                torch.nn.Linear(mlp_dim, in_dim),
            )

        self.depth = depth
        # cached cond embedding plus bookkeeping used to detect when it must be recomputed
        self.cond_emb = None
        self.cx_shape = None
        self.prev_batch = 0
        self.prev_sub_idxs = None

    def cleanup(self):
        """Drop the cached cond embedding so the next forward recomputes it."""
        del self.cond_emb
        self.cond_emb = None
        self.cx_shape = None
        self.prev_batch = 0
        self.prev_sub_idxs = None

    def forward(self, x: Tensor, control: "Union[AdvancedControlBase, ControlBase]"):
        mask = None
        mask_tk = None
        # (re)compute cond embedding when missing, or when batch size / sub_idxs changed
        if self.cond_emb is None or control.sub_idxs != self.prev_sub_idxs or x.shape[0] != self.prev_batch:
            cond_hint = control.cond_hint.to(x.device, dtype=x.dtype)
            # when latent dims weren't divisible by 2/4, upscale the hint so cond_emb matches x;
            # x.shape[-1] == 1280 marks the deepest (div-by-4) blocks — TODO confirm for all models
            if control.latent_dims_div2 is not None and x.shape[-1] != 1280:
                cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div2[0] * 8, control.latent_dims_div2[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
            elif control.latent_dims_div4 is not None and x.shape[-1] == 1280:
                cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div4[0] * 8, control.latent_dims_div4[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
            cx = self.conditioning1(cond_hint)
            self.cx_shape = cx.shape
            if not self.is_conv2d:
                # reshape / b,c,h,w -> b,h*w,c to match attention token layout
                n, c, h, w = cx.shape
                cx = cx.view(n, c, h * w).permute(0, 2, 1)
            self.cond_emb = cx
            # save values used to detect cache invalidation
            self.prev_batch = x.shape[0]
            self.prev_sub_idxs = control.sub_idxs

        cx: torch.Tensor = self.cond_emb

        # TODO: make masks work for conv2d (could not find any ControlLLLites at this time that use them)
        # create masks
        if not self.is_conv2d:
            n, c, h, w = self.cx_shape
            if control.mask_cond_hint is not None:
                mask = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
                mask = mask.view(mask.shape[0], 1, h * w).permute(0, 2, 1)
            if control.tk_mask_cond_hint is not None:
                # FIX: previously read control.mask_cond_hint here, silently ignoring the
                # timestep-keyframe mask this branch is meant to apply
                mask_tk = prepare_mask_batch(control.tk_mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
                mask_tk = mask_tk.view(mask_tk.shape[0], 1, h * w).permute(0, 2, 1)

        # x in uncond/cond doubles batch size; repeat cond emb (and masks) to match
        if x.shape[0] != cx.shape[0]:
            if self.is_conv2d:
                cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
            else:
                cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
                if mask is not None:
                    mask = mask.repeat(x.shape[0] // mask.shape[0], 1, 1)
                if mask_tk is not None:
                    mask_tk = mask_tk.repeat(x.shape[0] // mask_tk.shape[0], 1, 1)

        if mask is None:
            mask = 1.0
        elif mask_tk is not None:
            mask = mask * mask_tk

        # correction = up(mid(cat(cond_emb, down(x))))
        cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
        cx = self.mid(cx)
        cx = self.up(cx)
        # per-latent keyframe strength multipliers
        if control.latent_keyframes is not None:
            cx = cx * control.calc_latent_keyframe_mults(x=cx, batched_number=control.batched_number)
        # uncond multiplier from weights, applied per cond/uncond slice of the batch
        if control.weights is not None and control.weights.has_uncond_multiplier:
            cond_or_uncond = control.batched_number.cond_or_uncond
            actual_length = cx.size(0) // control.batched_number
            for idx, cond_type in enumerate(cond_or_uncond):
                # if uncond, set to weight's uncond_multiplier
                if cond_type == 1:
                    cx[actual_length*idx:actual_length*(idx+1)] *= control.weights.uncond_multiplier
        return cx * mask * control.strength * control._current_timestep_keyframe.strength
260 |
261 |
class ControlLLLiteModules(torch.nn.Module):
    """Container that registers all LLLite patch modules so they are tracked and moved as one nn.Module."""
    def __init__(self, patch_attn1: "LLLitePatch", patch_attn2: "LLLitePatch"):
        super().__init__()
        # Sequential is used purely for parameter registration; it is never called
        self.patch_attn1_modules = torch.nn.Sequential(*patch_attn1.modules.values())
        self.patch_attn2_modules = torch.nn.Sequential(*patch_attn2.modules.values())
267 |
268 |
class ControlLLLiteAdvanced(ControlBase, AdvancedControlBase):
    # This ControlNet is more of an attention patch than a traditional controlnet:
    # it registers attn1/attn2 patches on transformer_options instead of returning residuals.
    def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_keyframes: TimestepKeyframeGroup, device, ops: comfy.ops.disable_weight_init):
        super().__init__()
        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite())
        self.device = device
        self.ops = ops
        # patches share their module dicts; clone binds them to this control instance
        self.patch_attn1 = patch_attn1.clone_with_control(self)
        self.patch_attn2 = patch_attn2.clone_with_control(self)
        # wrap all modules so Comfy's model management can load/offload them
        self.control_model = ControlLLLiteModules(self.patch_attn1, self.patch_attn2)
        self.control_model_wrapped = ModelPatcher(self.control_model, load_device=device, offload_device=comfy.model_management.unet_offload_device())
        # rounded-up latent dims, cached when latent h/w are not divisible by 2/4 (see get_control_advanced)
        self.latent_dims_div2 = None
        self.latent_dims_div4 = None

    def set_cond_hint_inject(self, *args, **kwargs):
        to_return = super().set_cond_hint_inject(*args, **kwargs)
        # cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
        self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
        return to_return

    def pre_run_advanced(self, *args, **kwargs):
        AdvancedControlBase.pre_run_advanced(self, *args, **kwargs)
        # rebind patches to self; copies may carry stale control references
        self.patch_attn1.set_control(self)
        self.patch_attn2.set_control(self)

    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int, transformer_options: dict):
        # normal ControlNet stuff
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)

        # outside the active timestep range, pass through previous controlnet's output only
        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        dtype = x_noisy.dtype
        # prepare cond_hint: re-scale to pixel size (latent dims * 8) whenever sub_idxs
        # are in play, no hint is cached yet, or the latent size changed
        if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
            if self.cond_hint is not None:
                del self.cond_hint
            self.cond_hint = None
            # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
            if self.sub_idxs is not None:
                actual_cond_hint_orig = self.cond_hint_original
                if self.cond_hint_original.size(0) < self.full_latent_length:
                    actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
                self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
            else:
                self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
        # some special logic here compared to other controlnets:
        # * The cond_emb in attn patches will divide latent dims by 2 or 4, integer
        # * Due to this loss, the cond_emb will become smaller than x input if latent dims are not divisble by 2 or 4
        divisible_by_2_h = x_noisy.shape[2]%2==0
        divisible_by_2_w = x_noisy.shape[3]%2==0
        if not (divisible_by_2_h and divisible_by_2_w):
            # round the non-divisible dim(s) up to the next multiple of 2;
            # LLLiteModule.forward upscales the hint to these dims to compensate
            new_h = (x_noisy.shape[2]//2)*2
            new_w = (x_noisy.shape[3]//2)*2
            if not divisible_by_2_h:
                new_h += 2
            if not divisible_by_2_w:
                new_w += 2
            self.latent_dims_div2 = (new_h, new_w)
        divisible_by_4_h = x_noisy.shape[2]%4==0
        divisible_by_4_w = x_noisy.shape[3]%4==0
        if not (divisible_by_4_h and divisible_by_4_w):
            # same, rounded up to a multiple of 4 (used by the div-by-4 path in LLLiteModule.forward)
            new_h = (x_noisy.shape[2]//4)*4
            new_w = (x_noisy.shape[3]//4)*4
            if not divisible_by_4_h:
                new_h += 4
            if not divisible_by_4_w:
                new_w += 4
            self.latent_dims_div4 = (new_h, new_w)
        # prepare mask
        self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
        # done preparing; model patches will take care of everything now
        set_model_attn1_patch(transformer_options, self.patch_attn1.set_control(self))
        set_model_attn2_patch(transformer_options, self.patch_attn2.set_control(self))
        # return normal controlnet stuff
        return control_prev

    def get_models(self):
        # expose the wrapped module patcher so Comfy manages its memory
        to_return: list = super().get_models()
        to_return.append(self.control_model_wrapped)
        return to_return

    def cleanup_advanced(self):
        super().cleanup_advanced()
        # drop cached cond embeddings and dim overrides between runs
        self.patch_attn1.cleanup()
        self.patch_attn2.cleanup()
        self.latent_dims_div2 = None
        self.latent_dims_div4 = None

    def copy(self):
        c = ControlLLLiteAdvanced(self.patch_attn1, self.patch_attn2, self.timestep_keyframes, self.device, self.ops)
        self.copy_to(c)
        self.copy_to_advanced(c)
        return c
372 |
373 |
def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None):
    """Load a ControlLLLite checkpoint and wrap it as a ControlLLLiteAdvanced patch-controlnet."""
    if controlnet_data is None:
        controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    # adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
    # group the flat state dict by top-level module name
    module_weights = {}
    for key, value in controlnet_data.items():
        module_name, _, weight_name = key.partition(".")
        module_weights.setdefault(module_name, {})[weight_name] = value

    unet_dtype = comfy.model_management.unet_dtype()
    load_device = comfy.model_management.get_torch_device()
    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    ops = comfy.ops.manual_cast if manual_cast_dtype is not None else comfy.ops.disable_weight_init

    # build and load each module, inferring its config from weight shapes
    modules = {}
    for module_name, weights in module_weights.items():
        # kohya planned to do something about how these should be chosen, so I'm not touching this
        # since I am not familiar with the logic for this
        if "conditioning1.4.weight" in weights:
            depth = 3
        elif weights["conditioning1.2.weight"].shape[-1] == 4:
            depth = 2
        else:
            depth = 1

        module = LLLiteModule(
            name=module_name,
            is_conv2d=weights["down.0.weight"].ndim == 4,
            in_dim=weights["down.0.weight"].shape[1],
            depth=depth,
            cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
            mlp_dim=weights["down.0.weight"].shape[0],
        )
        module.load_state_dict(weights)
        modules[module_name] = module.to(dtype=unet_dtype)
        # flag the first module loaded
        if len(modules) == 1:
            module.is_first = True

    # both patches share the same modules dict; patch type decides where they attach
    patch_attn1 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN1)
    patch_attn2 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN2)
    return ControlLLLiteAdvanced(patch_attn1=patch_attn1, patch_attn2=patch_attn2, timestep_keyframes=timestep_keyframe, device=load_device, ops=ops)
428 |
--------------------------------------------------------------------------------
/adv_control/control_sparsectrl.py:
--------------------------------------------------------------------------------
1 | #taken from: https://github.com/lllyasviel/ControlNet
2 | #and modified
3 | #and then taken from comfy/cldm/cldm.py and modified again
4 |
5 | from abc import ABC, abstractmethod
6 | import numpy as np
7 | import torch
8 | from torch import Tensor
9 |
10 | from comfy.ldm.modules.diffusionmodules.util import (
11 | zero_module,
12 | timestep_embedding,
13 | )
14 |
15 | from comfy.cldm.cldm import ControlNet as ControlNetCLDM
16 | from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential
17 | from comfy.model_patcher import ModelPatcher
18 | from comfy.patcher_extension import PatcherInjection
19 |
20 | from .dinklink import (InterfaceAnimateDiffInfo, InterfaceAnimateDiffModel,
21 | get_CreateMotionModelPatcher, get_AnimateDiffModel, get_AnimateDiffInfo)
22 | from .logger import logger
23 | from .utils import (BIGMAX, AbstractPreprocWrapper, disable_weight_init_clean_groupnorm, WrapperConsts)
24 |
25 |
class SparseMotionModelPatcher(ModelPatcher):
    '''Class only used for IDE type hints.'''
    def __init__(self, *args, **kwargs):
        # intentionally does NOT call super().__init__; this class is never instantiated directly —
        # it only exists so IDEs resolve .model as InterfaceAnimateDiffModel
        self.model = InterfaceAnimateDiffModel
30 |
31 |
class SparseConst:
    """String keys for SparseCtrl strength-multiplier options."""
    # multiplier applied to hint frames
    HINT_MULT = "sparse_hint_mult"
    # multiplier applied to non-hint frames
    NONHINT_MULT = "sparse_nonhint_mult"
    # multiplier applied to the sparse mask
    MASK_MULT = "sparse_mask_mult"
36 |
37 |
class SparseControlNet(ControlNetCLDM):
    """SparseCtrl controlnet: zeroes the noisy input and injects the encoded hint
    into the output of the first input block."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        hint_channels = kwargs.get("hint_channels")
        operations: disable_weight_init_clean_groupnorm = kwargs.get("operations", disable_weight_init_clean_groupnorm)
        device = kwargs.get("device", None)
        # simplified conditioning embedding: a single zero-initialized conv instead of the full stack
        self.use_simplified_conditioning_embedding = kwargs.get("use_simplified_conditioning_embedding", False)
        if self.use_simplified_conditioning_embedding:
            self.input_hint_block = TimestepEmbedSequential(
                zero_module(operations.conv_nd(self.dims, hint_channels, self.model_channels, 3, padding=1, dtype=self.dtype, device=device)),
            )

    def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
        """Run the controlnet; returns a dict with 'middle' and 'output' residual lists."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        # SparseCtrl sets noisy input to zeros
        x = torch.zeros_like(x)
        guided_hint = self.input_hint_block(hint, emb, context)

        out_output = []
        out_middle = []

        # optional class-label conditioning
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            h = module(h, emb, context)
            # add the encoded hint exactly once, after the first input block
            if guided_hint is not None:
                h += guided_hint
                guided_hint = None
            out_output.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        out_middle.append(self.middle_block_out(h, emb, context))

        return {"middle": out_middle, "output": out_output}
80 |
81 |
def load_sparsectrl_motionmodel(ckpt_path: str, motion_data: dict[str, Tensor], ops=None) -> InterfaceAnimateDiffModel:
    """Build the SD1.5 AnimateDiff v3 motion model used by SparseCtrl and load its weights."""
    mm_info: InterfaceAnimateDiffInfo = get_AnimateDiffInfo()("SD1.5", "AnimateDiff", "v3", ckpt_path)
    motion_model: InterfaceAnimateDiffModel = get_AnimateDiffModel()(
        mm_state_dict=motion_data,
        mm_info=mm_info,
        init_kwargs={"ops": ops, "get_unet_func": _get_unet_func},
    )
    missing, unexpected = motion_model.load_state_dict(motion_data)
    # surface any state-dict mismatch for debugging
    if len(missing) > 0 or len(unexpected) > 0:
        logger.info(f"SparseCtrl MotionModel: {missing}, {unexpected}")
    return motion_model
93 |
94 |
def create_sparse_modelpatcher(model, motion_model, load_device, offload_device):
    """Wrap model in a ModelPatcher; when motion_model is given, attach it as an
    additional model with inject/eject hooks under the ACN key."""
    patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
    if motion_model is None:
        return patcher
    mm_patcher = _create_sparse_motionmodelpatcher(motion_model, load_device, offload_device)
    patcher.set_additional_models(WrapperConsts.ACN, [mm_patcher])
    patcher.set_injections(WrapperConsts.ACN,
                           [PatcherInjection(inject=_inject_motion_models, eject=_eject_motion_models)])
    return patcher
103 |
def _create_sparse_motionmodelpatcher(motion_model, load_device, offload_device) -> SparseMotionModelPatcher:
    # delegate to the motion-model patcher factory resolved via dinklink
    return get_CreateMotionModelPatcher()(motion_model, load_device, offload_device)
106 |
107 |
def _inject_motion_models(patcher: ModelPatcher):
    """Inject every ACN-registered motion model into patcher's model."""
    for motion_patcher in patcher.get_additional_models_with_key(WrapperConsts.ACN):
        motion_patcher.model.inject(patcher)
112 |
def _eject_motion_models(patcher: ModelPatcher):
    """Eject every ACN-registered motion model from patcher's model."""
    for motion_patcher in patcher.get_additional_models_with_key(WrapperConsts.ACN):
        motion_patcher.model.eject(patcher)
117 |
def _get_unet_func(wrapper, model: ModelPatcher):
    # getter handed to the motion-model factory; the underlying diffusion model lives at model.model
    return model.model
120 |
121 |
class PreprocSparseRGBWrapper(AbstractPreprocWrapper):
    """Wrapper marking RGB SparseCtrl preprocessor output as a latent pretending to be an image."""
    # fix: was 'error_msg = error_msg = "..."' — an accidental duplicated chained assignment
    error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
    def __init__(self, condhint: Tensor):
        super().__init__(condhint)
126 |
127 |
class SparseContextAware:
    """Options for how SparseCtrl resolves hints across context windows."""
    NEAREST_HINT = "nearest_hint"
    OFF = "off"

    # selectable options, in display order
    LIST = [NEAREST_HINT, OFF]
133 |
134 |
class SparseSettings:
    """Bundle of SparseCtrl options: hint placement method, motion-model usage, and strength multipliers."""
    def __init__(self, sparse_method: 'SparseMethod', use_motion: bool=True, motion_strength=1.0, motion_scale=1.0, merged=False,
                 sparse_mask_mult=1.0, sparse_hint_mult=1.0, sparse_nonhint_mult=1.0, context_aware=SparseContextAware.NEAREST_HINT):
        # account for Steerable-Motion workflow incompatibility;
        # doing this for my own peace of mind (not an issue with my code)
        if isinstance(sparse_method, str):  # fix: was 'type(sparse_method) == str', which misses str subclasses
            logger.warning("Outdated Steerable-Motion workflow detected; attempting to auto-convert indexes input. If you experience an error here, consult Steerable-Motion github, NOT Advanced-ControlNet.")
            sparse_method = SparseIndexMethod(get_idx_list_from_str(sparse_method))
        self.sparse_method = sparse_method
        self.use_motion = use_motion
        self.motion_strength = motion_strength
        self.motion_scale = motion_scale
        self.merged = merged
        # force multipliers to float to guard against int inputs from workflows
        self.sparse_mask_mult = float(sparse_mask_mult)
        self.sparse_hint_mult = float(sparse_hint_mult)
        self.sparse_nonhint_mult = float(sparse_nonhint_mult)
        self.context_aware = context_aware

    def is_context_aware(self):
        """True unless context awareness is explicitly turned off."""
        return self.context_aware != SparseContextAware.OFF

    @classmethod
    def default(cls):
        """Spread hints evenly with motion enabled and all multipliers at 1.0."""
        return SparseSettings(sparse_method=SparseSpreadMethod(), use_motion=True)
159 |
160 |
class SparseMethod(ABC):
    """Base strategy for choosing which frame indexes (out of full_length) receive a
    sparse cond hint image, and for mapping those frames back to hint-batch indexes."""
    SPREAD = "spread"
    INDEX = "index"
    def __init__(self, method: str):
        # method identifier; subclasses pass SPREAD or INDEX
        self.method = method

    @abstractmethod
    def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
        """Return the full-length frame indexes the hint images should occupy."""
        pass

    def get_indexes(self, hint_length: int, full_length: int, sub_idxs: list[int]=None) -> tuple[list[int], list[int]]:
        """Resolve hint placement.

        Returns a tuple (frame indexes to receive hints, corresponding indexes into the
        cond-hint batch); the second element is None when sub_idxs is None.
        When sub_idxs (a context window of frame indexes) is given, only hints landing
        inside that window are kept; if none land inside, the nearest hint(s) are
        attached to the window's edge (or center) frames instead.
        """
        returned_idxs = self._get_indexes(hint_length, full_length)
        if sub_idxs is None:
            return returned_idxs, None
        # need to map full indexes to condhint indexes
        index_mapping = {}
        for i, value in enumerate(returned_idxs):
            index_mapping[value] = i
        def get_mapped_idxs(idxs: list[int]):
            # translate full-length frame indexes into positions within the hint batch
            return [index_mapping[idx] for idx in idxs]
        # check if returned_idxs fit within subidxs
        fitting_idxs = []
        for sub_idx in sub_idxs:
            if sub_idx in returned_idxs:
                fitting_idxs.append(sub_idx)
        # if have any fitting_idxs, deal with it
        if len(fitting_idxs) > 0:
            return fitting_idxs, get_mapped_idxs(fitting_idxs)

        # since no returned_idxs fit in sub_idxs, need to get the next-closest hint images based on strategy
        def get_closest_idx(target_idx: int, idxs: list[int]):
            # linear scan for the idx nearest to target_idx; min_dist == 1 is an early-out,
            # since dist 0 is impossible here (a hint idx inside sub_idxs would have been a fitting_idx)
            min_idx = -1
            min_dist = BIGMAX
            for idx in idxs:
                new_dist = abs(idx-target_idx)
                if new_dist < min_dist:
                    min_idx = idx
                    min_dist = new_dist
                if min_dist == 1:
                    return min_idx, min_dist
            return min_idx, min_dist
        # find hint nearest to each end of the context window
        start_closest_idx, start_dist = get_closest_idx(sub_idxs[0], returned_idxs)
        end_closest_idx, end_dist = get_closest_idx(sub_idxs[-1], returned_idxs)
        # if only one cond hint exists, do special behavior
        if hint_length == 1:
            # if same distance from start and end,
            if start_dist == end_dist:
                # find center index of sub_idxs
                center_idx = sub_idxs[np.linspace(0, len(sub_idxs)-1, 3, endpoint=True, dtype=int)[1]]
                return [center_idx], get_mapped_idxs([start_closest_idx])
            # otherwise, return closest
            if start_dist < end_dist:
                return [sub_idxs[0]], get_mapped_idxs([start_closest_idx])
            return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx])
        # otherwise, select up to two closest images, or just 1, whichever one applies best
        # if same distance from start and end, return two images to use
        if start_dist == end_dist:
            return [sub_idxs[0], sub_idxs[-1]], get_mapped_idxs([start_closest_idx, end_closest_idx])
        # else, use just one
        if start_dist < end_dist:
            return [sub_idxs[0]], get_mapped_idxs([start_closest_idx])
        return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx])
223 |
224 |
class SparseSpreadMethod(SparseMethod):
    """Sparse method that spreads hint images across full_length frames according to a
    chosen distribution (uniform / starting / ending / center)."""
    UNIFORM = "uniform"
    STARTING = "starting"
    ENDING = "ending"
    CENTER = "center"

    # all recognized spread values
    LIST = [UNIFORM, STARTING, ENDING, CENTER]

    def __init__(self, spread=UNIFORM):
        super().__init__(self.SPREAD)
        self.spread = spread

    def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
        """Return hint_length frame indexes spread over [0, full_length) per self.spread.

        Raises:
            ValueError: if self.spread is not one of LIST.
        """
        # if hint_length >= full_length, limit hints to full_length
        if hint_length >= full_length:
            return list(range(full_length))
        # handle special case of 1 hint image
        if hint_length == 1:
            if self.spread in [self.UNIFORM, self.STARTING]:
                return [0]
            elif self.spread == self.ENDING:
                return [full_length-1]
            elif self.spread == self.CENTER:
                # return second (of three) values as the center
                return [np.linspace(0, full_length-1, 3, endpoint=True, dtype=int)[1]]
            else:
                raise ValueError(f"Unrecognized spread: {self.spread}")
        # otherwise, handle other cases
        if self.spread == self.UNIFORM:
            return list(np.linspace(0, full_length-1, hint_length, endpoint=True, dtype=int))
        elif self.spread == self.STARTING:
            # make split 1 larger, remove last element
            return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
        elif self.spread == self.ENDING:
            # make split 1 larger, remove first element
            return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[1:]
        elif self.spread == self.CENTER:
            # if hint length is not 3 greater than full length, do STARTING behavior
            if full_length-hint_length < 3:
                return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
            # otherwise, get linspace of 2 greater than needed, then cut off first and last
            return list(np.linspace(0, full_length-1, hint_length+2, endpoint=True, dtype=int))[1:-1]
        # fix: this line previously RETURNED the ValueError instead of raising it,
        # silently handing callers an exception object as the index list
        raise ValueError(f"Unrecognized spread: {self.spread}")
268 |
269 |
class SparseIndexMethod(SparseMethod):
    """Sparse method that places cond hints at explicit, user-provided frame indexes."""
    def __init__(self, idxs: list[int]):
        super().__init__(self.INDEX)
        self.idxs = idxs

    def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
        """Resolve stored indexes (negatives wrap from the end) for the usable hint count."""
        orig_hint_length = hint_length
        hint_length = min(hint_length, full_length)
        # fewer provided indexes than usable hint images is an error
        if len(self.idxs) < hint_length:
            err_msg = f"There are not enough indexes ({len(self.idxs)}) provided to fit the usable {hint_length} input images."
            if orig_hint_length != hint_length:
                err_msg = f"{err_msg} (original input images: {orig_hint_length})"
            raise ValueError(err_msg)
        resolved = []
        seen = set()
        # only the first hint_length indexes are used
        for idx in self.idxs[:hint_length]:
            if idx < 0:
                # negative indexes count back from full_length
                actual = full_length+idx
                if actual in seen:
                    raise ValueError(f"Index '{idx}' maps to '{actual}' and is duplicate - indexes in Sparse Index Method must be unique.")
            else:
                actual = idx
            if actual in seen:
                raise ValueError(f"Index '{idx}' is duplicate (or a negative index is equivalent) - indexes in Sparse Index Method must be unique.")
            seen.add(actual)
            resolved.append(actual)
        return resolved
301 |
302 |
def get_idx_list_from_str(indexes: str) -> list[int]:
    """Parse a comma-separated string of (possibly negative) frame indexes.

    Args:
        indexes: e.g. "0, 8, -1"; whitespace around entries is ignored.

    Returns:
        The parsed indexes, in input order.

    Raises:
        ValueError: if an entry is not an integer, an index is duplicated,
            or no indexes are present.
    """
    idxs = []
    unique_idxs = set()
    # get indexes from string
    str_idxs = [x.strip() for x in indexes.strip().split(",")]
    for str_idx in str_idxs:
        # fix: the duplicate check used to live inside this try, so its ValueError was
        # caught by the except below and re-raised with the misleading
        # 'not a valid integer index' message; parse and validate separately instead
        try:
            idx = int(str_idx)
        except ValueError:
            raise ValueError(f"'{str_idx}' is not a valid integer index.") from None
        if idx in unique_idxs:
            raise ValueError(f"'{idx}' is duplicated; indexes must be unique.")
        idxs.append(idx)
        unique_idxs.add(idx)
    if len(idxs) == 0:
        raise ValueError(f"No indexes were listed in Sparse Index Method.")
    return idxs
320 |
--------------------------------------------------------------------------------
/adv_control/dinklink.py:
--------------------------------------------------------------------------------
1 | ####################################################################################################
2 | # DinkLink is my method of sharing classes/functions between my nodes.
3 | #
4 | # My DinkLink-compatible nodes will inject comfy.hooks with a __DINKLINK attr
5 | # that stores a dictionary, where any of my node packs can store their stuff.
6 | #
7 | # It is not intended to be accessed by node packs that I don't develop, so things may change
8 | # at any time.
9 | #
10 | # DinkLink also serves as a proof-of-concept for a future ComfyUI implementation of
11 | # purposely exposing node pack classes/functions with other node packs.
12 | ####################################################################################################
13 | from __future__ import annotations
14 | from typing import Union
15 | from torch import Tensor, nn
16 |
17 | from comfy.model_patcher import ModelPatcher
18 | import comfy.hooks
19 |
20 | DINKLINK = "__DINKLINK"
21 |
22 |
def init_dinklink():
    """Create the shared DinkLink storage, then run this pack's preparation step."""
    create_dinklink()
    prepare_dinklink()
26 |
def create_dinklink():
    """Attach an empty __DINKLINK dict to comfy.hooks if one is not already present."""
    if hasattr(comfy.hooks, DINKLINK):
        return
    setattr(comfy.hooks, DINKLINK, {})
30 |
def get_dinklink() -> dict[str, dict[str]]:
    """Return the shared DinkLink dict stored on comfy.hooks, creating it if needed."""
    create_dinklink()
    link: dict[str, dict[str]] = getattr(comfy.hooks, DINKLINK)
    return link
34 |
35 |
class DinkLinkConst:
    """String keys used within the shared DinkLink dict."""
    VERSION = "version"
    # ADE (AnimateDiff-Evolved) namespace key and the entries expected under it
    ADE = "ADE"
    ADE_ANIMATEDIFFMODEL = "AnimateDiffModel"
    ADE_ANIMATEDIFFINFO = "AnimateDiffInfo"
    ADE_CREATE_MOTIONMODELPATCHER = "create_MotionModelPatcher"
43 |
def prepare_dinklink():
    # intentionally a no-op for now; kept so init_dinklink has a stable preparation hook
    pass
46 |
47 |
class InterfaceAnimateDiffInfo:
    '''Class only used for IDE type hints; interface of ADE's AnimateDiffInfo'''
    def __init__(self, sd_type: str, mm_format: str, mm_version: str, mm_name: str):
        # store each metadata field on the instance, mirroring ADE's real AnimateDiffInfo
        for field, value in (("sd_type", sd_type), ("mm_format", mm_format),
                             ("mm_version", mm_version), ("mm_name", mm_name)):
            setattr(self, field, value)
55 |
56 |
class InterfaceAnimateDiffModel(nn.Module):
    '''Class only used for IDE type hints; interface of ADE's AnimateDiffModel'''
    def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: InterfaceAnimateDiffInfo, init_kwargs: dict[str]={}):
        # stub only - never instantiated; the real implementation lives in AnimateDiff-Evolved
        # (the mutable default is harmless here since the body never touches it)
        pass

    def set_video_length(self, video_length: int, full_length: int) -> None:
        # fix: was 'raise NotImplemented()' - NotImplemented is a sentinel value, not an
        # exception, so calling it raised a confusing TypeError; NotImplementedError is correct
        raise NotImplementedError()

    def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list, None]=None) -> None:
        raise NotImplementedError()

    def set_effect(self, multival: Union[float, Tensor, None], per_block_list: Union[list, None]=None) -> None:
        raise NotImplementedError()

    def cleanup(self):
        raise NotImplementedError()

    def inject(self, model: ModelPatcher):
        # intentionally a no-op in the interface stub
        pass

    def eject(self, model: ModelPatcher):
        # intentionally a no-op in the interface stub
        pass
79 |
80 |
def get_CreateMotionModelPatcher(throw_exception=True):
    """Fetch ADE's create_MotionModelPatcher factory from the DinkLink shared dict;
    raise (or return None when throw_exception=False) if ADE has not registered it."""
    d = get_dinklink()
    try:
        return d[DinkLinkConst.ADE][DinkLinkConst.ADE_CREATE_MOTIONMODELPATCHER]
    except KeyError:
        if throw_exception:
            raise Exception(
                "Could not get create_MotionModelPatcher function. AnimateDiff-Evolved nodes need to be installed to use SparseCtrl; "
                "they are either not installed or are of an insufficient version.")
        return None
91 |
def get_AnimateDiffModel(throw_exception=True):
    """Fetch ADE's AnimateDiffModel class from the DinkLink shared dict;
    raise (or return None when throw_exception=False) if ADE has not registered it."""
    d = get_dinklink()
    try:
        return d[DinkLinkConst.ADE][DinkLinkConst.ADE_ANIMATEDIFFMODEL]
    except KeyError:
        if throw_exception:
            raise Exception(
                "Could not get AnimateDiffModel class. AnimateDiff-Evolved nodes need to be installed to use SparseCtrl; "
                "they are either not installed or are of an insufficient version.")
        return None
102 |
def get_AnimateDiffInfo(throw_exception=True) -> InterfaceAnimateDiffInfo:
    """Fetch ADE's AnimateDiffInfo class from the DinkLink shared dict;
    raise (or return None when throw_exception=False) if ADE has not registered it."""
    d = get_dinklink()
    try:
        return d[DinkLinkConst.ADE][DinkLinkConst.ADE_ANIMATEDIFFINFO]
    except KeyError:
        if throw_exception:
            raise Exception(
                "Could not get AnimateDiffInfo class - AnimateDiff-Evolved nodes need to be installed to use SparseCtrl; "
                "they are either not installed or are of an insufficient version.")
        return None
113 |
--------------------------------------------------------------------------------
/adv_control/documentation.py:
--------------------------------------------------------------------------------
1 | from .logger import logger
2 |
def image(src):
    # NOTE(review): the f-string body is empty and ignores `src` - this looks like HTML
    # markup lost in text extraction (presumably an <img> tag referencing {src});
    # confirm against the original repository before shipping.
    return f''
5 | def video(src):
6 | return f'