├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── adv_control ├── control.py ├── control_ctrlora.py ├── control_lllite.py ├── control_plusplus.py ├── control_reference.py ├── control_sparsectrl.py ├── control_svd.py ├── dinklink.py ├── documentation.py ├── logger.py ├── nodes.py ├── nodes_ctrlora.py ├── nodes_deprecated.py ├── nodes_keyframes.py ├── nodes_loosecontrol.py ├── nodes_main.py ├── nodes_plusplus.py ├── nodes_reference.py ├── nodes_sparsectrl.py ├── nodes_weight.py ├── sampling.py └── utils.py ├── pyproject.toml └── web └── js ├── autosize.js └── documentation.js /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'Kosinkadink' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here. 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-Advanced-ControlNet 2 | Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks. The ControlNet nodes here fully support sliding context sampling, like the one used in the [ComfyUI-AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) nodes. Currently supports ControlNets, T2IAdapters, ControlLoRAs, ControlLLLite, SparseCtrls, SVD-ControlNets, and Reference. 3 | 4 | Custom weights allow replication of the "My prompt is more important" feature of Auto1111's sd-webui ControlNet extension via Soft Weights, and the "ControlNet is more important" feature can be granularly controlled by changing the uncond_multiplier on the same Soft Weights. 5 | 6 | ControlNet preprocessors are available through [comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux) nodes. 
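Conceptually, the two Soft Weights settings above come down to how the controlnet's output is scaled on the cond vs. uncond (negative) pass. Below is a rough sketch of the assumed semantics - the function and argument names are illustrative, not this repo's actual API:

```python
# rough sketch of the assumed semantics; `control_output`, `strength`, and
# `is_uncond` are illustrative names, not this repo's actual API
def scale_control(control_output, strength, is_uncond, uncond_multiplier=1.0):
    out = control_output * strength
    if is_uncond:
        # 0.0 removes control from the uncond pass, so CFG amplifies the controlnet
        # (A1111's "ControlNet is more important"); 1.0 is the balanced default,
        # and values in between granularly blend the two behaviors
        out = out * uncond_multiplier
    return out
```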
7 | 8 | ## Features 9 | - Timestep and latent strength scheduling 10 | - Attention masks 11 | - Replicate ***"My prompt is more important"*** feature from sd-webui-controlnet extension via ***Soft Weights***, and allow softness to be tweaked via ***base_multiplier*** 12 | - Replicate ***"ControlNet is more important"*** feature from sd-webui-controlnet extension via ***uncond_multiplier*** on ***Soft Weights*** 13 | - uncond_multiplier=0.0 gives results identical to Auto1111's feature, but values between 0.0 and 1.0 can be used without issue to granularly control the setting. 14 | - ControlNet, T2IAdapter, and ControlLoRA support for sliding context windows 15 | - ControlLLLite support 16 | - ControlNet++ support 17 | - CtrLoRA support 18 | - Relevant models linked on [CtrLoRA github page](https://github.com/xyfJASON/ctrlora) 19 | - SparseCtrl support 20 | - SVD-ControlNet support 21 | - Stable Video Diffusion ControlNets trained by **CiaraRowles**: [Depth](https://huggingface.co/CiaraRowles/temporal-controlnet-depth-svd-v1/tree/main/controlnet), [Lineart](https://huggingface.co/CiaraRowles/temporal-controlnet-lineart-svd-v1/tree/main/controlnet) 22 | - Reference support 23 | - Supports ```reference_attn```, ```reference_adain```, and ```refrence_adain+attn``` modes. ```style_fidelity``` and ```ref_weight``` are equivalent to style_fidelity and control_weight in Auto1111, respectively, and the strength of the Apply ControlNet node is the balance between the ref-influenced result and the no-ref result. There is also a Reference ControlNet (Finetune) node that allows adjusting the style_fidelity, weight, and strength of attn and adain separately. 24 | 25 | ## Table of Contents: 26 | - [Scheduling Explanation](#scheduling-explanation) 27 | - [Nodes](#nodes) 28 | - [Usage](#usage) (will fill this out soon) 29 | 30 | 31 | # Scheduling Explanation 32 | 33 | The two core concepts for scheduling are ***Timestep Keyframes*** and ***Latent Keyframes***. 34 | 35 | ***Timestep Keyframes*** hold the values that guide the settings for a controlnet, and begin to take effect based on their start_percent, which corresponds to the percentage of the sampling process. They can contain masks for the strengths of each latent, control_net_weights, and latent_keyframes (specific strengths for each latent), all optional. 36 | 37 | ***Latent Keyframes*** determine the strength of the controlnet for specific latents - all they contain is the batch_index of the latent, and the strength the controlnet should apply for that latent. As a concept, latent keyframes achieve the same effect as a uniform mask with the chosen strength value. 38 | 39 | ![advcn_image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/e6275264-6c3f-4246-a319-111ee48f4cd9) 40 | 41 | # Nodes 42 | 43 | The ControlNet nodes provided here are the ***Apply Advanced ControlNet*** and ***Load Advanced ControlNet Model*** (or diff) nodes. The vanilla ControlNet nodes are also compatible, and can be used almost interchangeably - the only difference is that **at least one of these nodes must be used** for Advanced versions of ControlNets to be used (important for sliding context sampling, like with AnimateDiff-Evolved).
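For intuition before the node reference below, the scheduling behavior from the previous section boils down to roughly the following. This is a conceptual Python sketch with illustrative names, not the repo's actual implementation:

```python
# conceptual sketch with illustrative names - not the repo's real API
def effective_strength(sampling_percent: float, batch_index: int, schedule: list,
                       null_latent_kf_strength: float = 0.0) -> float:
    # the active Timestep Keyframe is the latest one whose start_percent
    # has already been reached (keyframes are sorted by start_percent)
    active = None
    for tk in sorted(schedule, key=lambda tk: tk.start_percent):
        if tk.start_percent <= sampling_percent:
            active = tk
    if active is None:
        return 0.0
    # a Latent Keyframe acts like a uniform mask for one batch index;
    # unlisted indices fall back to null_latent_kf_strength
    per_latent = 1.0
    if active.latent_keyframes is not None:
        per_latent = active.latent_keyframes.get(batch_index, null_latent_kf_strength)
    # the strength on the Apply Advanced ControlNet node multiplies on top of this
    return active.strength * per_latent
```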
44 | 45 | Key: 46 | - 🟩 - required inputs 47 | - 🟨 - optional inputs 48 | - 🟦 - start as widgets, can be converted to inputs 49 | - 🟥 - optional input/output, but not recommended to use unless needed 50 | - 🟪 - output 51 | 52 | ## Apply Advanced ControlNet 53 | ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/dc541d41-70df-4a71-b832-efa65af98f06) 54 | 55 | Same functionality as the vanilla Apply ControlNet (Advanced) node, except with Advanced ControlNet features added to it. Automatically converts any ControlNet from ControlNet loaders into Advanced versions. 56 | 57 | ### Inputs 58 | - 🟩***positive***: conditioning (positive). 59 | - 🟩***negative***: conditioning (negative). 60 | - 🟩***control_net***: loaded controlnet; will be converted to Advanced version automatically by this node, if it's a supported type. 61 | - 🟩***image***: images to guide controlnets - if the loaded controlnet requires it, they must be preprocessed images. If one image is provided, it will be used for all latents. If more images are provided, each will be used for its corresponding latent. If there are not enough images to match the latent count, the images will repeat from the beginning to match vanilla ControlNet functionality. 62 | - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet applies to (and the relative strength, if the mask is not binary). Same as the image input: if you provide more than one mask, each can apply to a different latent. 63 | - 🟨***timestep_kf***: timestep keyframes to guide controlnet effect throughout sampling steps. 64 | - 🟨***latent_kf_override***: override for latent keyframes, useful if no other features from timestep keyframes are needed. *NOTE: this latent keyframe will be applied to ALL timesteps, regardless of whether there are other latent keyframes attached to connected timestep keyframes.* 65 | - 🟨***weights_override***: override for weights, useful if no other features from timestep keyframes are needed. *NOTE: this weight will be applied to ALL timesteps, regardless of whether there are other weights attached to connected timestep keyframes.* 66 | - 🟦***strength***: strength of controlnet; 1.0 is full strength, 0.0 is no effect at all. 67 | - 🟦***start_percent***: sampling step percentage at which controlnet should start to be applied - no matter what start_percent is set on timestep keyframes, they won't take effect until this start_percent is reached. 68 | - 🟦***stop_percent***: sampling step percentage at which controlnet should stop being applied - no matter what start_percent is set on timestep keyframes, they will stop taking effect once this stop_percent is reached. 69 | 70 | ### Outputs 71 | - 🟪***positive***: conditioning (positive) with applied controlnets 72 | - 🟪***negative***: conditioning (negative) with applied controlnets 73 | 74 | ## Load Advanced ControlNet Model 75 | ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/4a7f58a9-783d-4da4-bf82-bc9c167e4722) 76 | 77 | Loads a ControlNet model and converts it into an Advanced version that supports all the features in this repo. When used with the **Apply Advanced ControlNet** node, there is no reason to use the timestep_keyframe input on this node - use timestep_kf on the Apply node instead. 78 | 79 | ### Inputs 80 | - 🟥***timestep_keyframe***: optional and likely unnecessary input to have ControlNet use selected timestep_keyframes - should not be used unless you need to.
Useful if this node is not attached to the **Apply Advanced ControlNet** node but you still want to use a Timestep Keyframe, or to use TK_SHORTCUT outputs from ControlWeights in the same scenario. Will be overridden by the timestep_kf input on the **Apply Advanced ControlNet** node, if one is provided there. 81 | - 🟨***model***: model to plug into the diff version of the node. Some controlnets are designed to receive the model; if you don't know what this does, you probably don't want to use the diff version of the node. 82 | 83 | ### Outputs 84 | - 🟪***CONTROL_NET***: loaded Advanced ControlNet 85 | 86 | ## Timestep Keyframe 87 | ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/404f3cfe-5852-4eed-935b-37e32493d1b5) 88 | 89 | Scheduling node across timesteps (sampling steps) based on the set start_percent. Chaining Timestep Keyframes allows ControlNet scheduling across sampling steps (percentage-wise), through a timestep keyframe schedule. 90 | 91 | ### Inputs 92 | - 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_kf that has the same start_percent as this Timestep Keyframe will be overwritten.* 93 | - 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True and no cn_weights are passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If the Apply Advanced ControlNet node has a weights_override, the weights_override will be used during sampling instead of cn_weights.* 94 | - 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True and no latent_keyframe is passed in, will attempt to reuse the last-used latent keyframes in the timestep keyframe schedule. *If the Apply Advanced ControlNet node has a latent_kf_override, the latent_kf_override will be used during sampling instead of latent_keyframe.* 95 | - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet applies to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node - can apply either one mask to all latents, or individual masks for each latent. If inherit_missing is True and no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overridden by mask_optional on the Apply Advanced ControlNet node; they will be used together. 96 | - 🟦***start_percent***: sampling step percentage at which this Timestep Keyframe qualifies to be used. Acts as the 'key' for the Timestep Keyframe in the timestep keyframe schedule. 97 | - 🟦***strength***: strength of the controlnet; multiplies the controlnet by this value, basically, applied alongside the strength on the Apply ControlNet node. If set to 0.0, will not have any effect for the duration of this Timestep Keyframe's effect, and will increase sampling speed by not doing any work. 98 | - 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed-in latent_keyframes.
Has no effect if no latent_keyframes are passed in, or if no batch_indices are unaccounted for in the latent_keyframes during sampling. 99 | - 🟦***inherit_missing***: determines whether to reuse values from previous Timestep Keyframes for optional inputs (cn_weights, latent_keyframe, and mask_optional) that are not included on this Timestep Keyframe. To inherit only specific inputs, use default inputs. 100 | - 🟦***guarantee_steps***: when 1 or greater, even if another Timestep Keyframe ahead of this one in the schedule has a start_percent closer to the current sampling percentage, this Timestep Keyframe will still be used for the specified number of steps before moving on to the next selected Timestep Keyframe in the following step. Whether the Timestep Keyframe is used or not, its inputs will still be taken into account for inherit_missing purposes. 101 | 102 | ### Outputs 103 | - 🟪***TIMESTEP_KF***: the created Timestep Keyframe, which can be chained to another Timestep Keyframe node or plugged into a Timestep Keyframe input. 104 | 105 | ## Timestep Keyframe Interpolation 106 | ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/9789617c-202c-4271-92a2-0909bcf9b108) 107 | 108 | Allows creating Timestep Keyframes with interpolated strength values across a given percent range. (The first generated keyframe will have guarantee_steps=1; the rest that follow will have guarantee_steps=0.) 109 | 110 | ### Inputs 111 | - 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_kf that has the same start_percent as this Timestep Keyframe will be overwritten.* 112 | - 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True and no cn_weights are passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If the Apply Advanced ControlNet node has a weights_override, the weights_override will be used during sampling instead of cn_weights.* 113 | - 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True and no latent_keyframe is passed in, will attempt to reuse the last-used latent keyframes in the timestep keyframe schedule. *If the Apply Advanced ControlNet node has a latent_kf_override, the latent_kf_override will be used during sampling instead of latent_keyframe.* 114 | - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet applies to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node - can apply either one mask to all latents, or individual masks for each latent. If inherit_missing is True and no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overridden by mask_optional on the Apply Advanced ControlNet node; they will be used together. 115 | - 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used. 116 | - 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used.
117 | - 🟦***strength_start***: strength of the Timestep Keyframe at the start of the range. 118 | - 🟦***strength_end***: strength of the Timestep Keyframe at the end of the range. 119 | - 🟦***interpolation***: the method of interpolation. 120 | - 🟦***intervals***: the number of keyframes to generate in total - the first will have its start_percent equal to start_percent, the last will have its start_percent equal to end_percent. 121 | - 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed-in latent_keyframes. Has no effect if no latent_keyframes are passed in, or if no batch_indices are unaccounted for in the latent_keyframes during sampling. 122 | - 🟦***inherit_missing***: determines whether to reuse values from previous Timestep Keyframes for optional inputs (cn_weights, latent_keyframe, and mask_optional) that are not included on this Timestep Keyframe. To inherit only specific inputs, use default inputs. 123 | - 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes. 124 | 125 | ### Outputs 126 | - 🟪***TIMESTEP_KF***: the created Timestep Keyframe, which can be chained to another Timestep Keyframe node or plugged into a Timestep Keyframe input. 127 | 128 | ## Timestep Keyframe From List 129 | ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/9e9c23bf-6f82-4ce7-b4d1-3016fd14707d) 130 | 131 | Allows creating Timestep Keyframes via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes. (The first generated keyframe will have guarantee_steps=1; the rest that follow will have guarantee_steps=0.) 132 | 133 | ### Inputs 134 | - 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_kf that has the same start_percent as this Timestep Keyframe will be overwritten.* 135 | - 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True and no cn_weights are passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If the Apply Advanced ControlNet node has a weights_override, the weights_override will be used during sampling instead of cn_weights.* 136 | - 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True and no latent_keyframe is passed in, will attempt to reuse the last-used latent keyframes in the timestep keyframe schedule. *If the Apply Advanced ControlNet node has a latent_kf_override, the latent_kf_override will be used during sampling instead of latent_keyframe.* 137 | - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet applies to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node - can apply either one mask to all latents, or individual masks for each latent. If inherit_missing is True and no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overridden by mask_optional on the Apply Advanced ControlNet node; they will be used together.
138 | - 🟩***float_strengths***: a list of floats that will correspond to the strength of each Timestep Keyframe; the first will be assigned to start_percent, the last will be assigned to end_percent, and the rest will be spread linearly in between. 139 | - 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used. 140 | - 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used. 141 | - 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed-in latent_keyframes. Has no effect if no latent_keyframes are passed in, or if no batch_indices are unaccounted for in the latent_keyframes during sampling. 142 | - 🟦***inherit_missing***: determines whether to reuse values from previous Timestep Keyframes for optional inputs (cn_weights, latent_keyframe, and mask_optional) that are not included on this Timestep Keyframe. To inherit only specific inputs, use default inputs. 143 | - 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes. 144 | 145 | ### Outputs 146 | - 🟪***TIMESTEP_KF***: the created Timestep Keyframe, which can be chained to another Timestep Keyframe node or plugged into a Timestep Keyframe input. 147 | 148 | ## Latent Keyframe 149 | ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/7eb2cc4c-255c-4f32-b09b-699f713fada3) 150 | 151 | A singular Latent Keyframe; selects the strength for a specific batch_index. If the batch_index is not present during sampling, it will simply have no effect. Can be chained with any other Latent Keyframe-type node to create a latent keyframe schedule. 152 | 153 | ### Inputs 154 | - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If a Latent Keyframe contained in prev_latent_kf has the same batch_index as this Latent Keyframe, it will take priority over this node's value.* 155 | - 🟦***batch_index***: index of latent in batch to apply controlnet strength to. Acts as the 'key' for the Latent Keyframe in the latent keyframe schedule. 156 | - 🟦***strength***: strength of controlnet to apply to the corresponding latent. 157 | 158 | ### Outputs 159 | - 🟪***LATENT_KF***: the created Latent Keyframe, which can be chained to another Latent Keyframe node or plugged into a Latent Keyframe input. 160 | 161 | ## Latent Keyframe Group 162 | ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/5ce3b795-f5fc-4dc3-ae30-a4c7f87e278c) 163 | 164 | Allows creating Latent Keyframes via individual indices or Python-style ranges. 165 | 166 | ### Inputs 167 | - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_kf have the same batch_index as this Latent Keyframe, they will take priority over this node's version.* 168 | - 🟨***latent_optional***: the latents expected to be passed in for sampling; only required if you wish to use negative indices (they will be automatically converted to real values). 169 | - 🟦***index_strengths***: string list of indices or Python-style ranges of indices to assign strengths to. If latent_optional is passed in, can contain negative indices or ranges that contain negative numbers, Python-style. The different indices must be comma-separated. Individual latents can be specified by ```batch_index=strength```, like ```0=0.9```.
Ranges can be specified by ```start_index_inclusive:end_index_exclusive=strength```, like ```0:8=strength```. Negative indices are possible when latent_optional has an input, with a string such as ```0,-4=0.25```. 170 | - 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes. 171 | 172 | ### Outputs 173 | - 🟪***LATENT_KF***: the created Latent Keyframe, which can be chained to another Latent Keyframe node or plugged into a Latent Keyframe input. 174 | 175 | ## Latent Keyframe Interpolation 176 | ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/7986c737-83b9-46bc-aab0-ae4c368df446) 177 | 178 | Allows creating Latent Keyframes with interpolated strength values across a range. 179 | 180 | ### Inputs 181 | - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_kf have the same batch_index as this Latent Keyframe, they will take priority over this node's version.* 182 | - 🟦***batch_index_from***: starting batch_index of range, included. 183 | - 🟦***batch_index_to***: end batch_index of range, excluded (Python-style range). 184 | - 🟦***strength_from***: starting strength of interpolation. 185 | - 🟦***strength_to***: end strength of interpolation. 186 | - 🟦***interpolation***: the method of interpolation. 187 | - 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes. 188 | 189 | ### Outputs 190 | - 🟪***LATENT_KF***: the created Latent Keyframe, which can be chained to another Latent Keyframe node or plugged into a Latent Keyframe input. 191 | 192 | ## Latent Keyframe From List 193 | ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/6cec701f-6183-4aeb-af5c-cac76f5591b7) 194 | 195 | Allows creating Latent Keyframes via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes. 196 | 197 | ### Inputs 198 | - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_kf have the same batch_index as this Latent Keyframe, they will take priority over this node's version.* 199 | - 🟩***float_strengths***: a list of floats that will correspond to the strength of each Latent Keyframe; the batch_index is the index of each float value in the list. 200 | - 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes. 201 | 202 | ### Outputs 203 | - 🟪***LATENT_KF***: the created Latent Keyframe, which can be chained to another Latent Keyframe node or plugged into a Latent Keyframe input. 204 | 205 | # There are more nodes to document and show usage - will add this soon!
TODO 206 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .adv_control.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | from .adv_control import documentation 3 | from .adv_control.dinklink import init_dinklink 4 | from .adv_control.sampling import prepare_dinklink_acn_wrapper 5 | 6 | WEB_DIRECTORY = "./web" 7 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', "WEB_DIRECTORY"] 8 | documentation.format_descriptions(NODE_CLASS_MAPPINGS) 9 | 10 | init_dinklink() 11 | prepare_dinklink_acn_wrapper() 12 | -------------------------------------------------------------------------------- /adv_control/control_ctrlora.py: -------------------------------------------------------------------------------- 1 | # Core code adapted from CtrLoRA github repo: 2 | # https://github.com/xyfJASON/ctrlora 3 | import torch 4 | from torch import Tensor 5 | 6 | from comfy.cldm.cldm import ControlNet as ControlNetCLDM 7 | import comfy.model_detection 8 | import comfy.model_management 9 | import comfy.ops 10 | import comfy.utils 11 | 12 | from comfy.ldm.modules.diffusionmodules.util import ( 13 | zero_module, 14 | timestep_embedding, 15 | ) 16 | 17 | from .control import ControlNetAdvanced 18 | from .utils import TimestepKeyframeGroup 19 | from .logger import logger 20 | 21 | 22 | class ControlNetCtrLoRA(ControlNetCLDM): 23 | def __init__(self, *args, **kwargs): 24 | super().__init__(*args, **kwargs) 25 | # delete input hint block 26 | del self.input_hint_block 27 | 28 | def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs): 29 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) 30 | emb = self.time_embed(t_emb) 31 | 32 | out_output = [] 33 | out_middle = [] 34 | 35 | if self.num_classes is not None: 36 | assert y.shape[0] == x.shape[0] 37 | emb = emb + self.label_emb(y) 38 | 39 | h = hint.to(dtype=x.dtype) 40 | for module, zero_conv in zip(self.input_blocks, self.zero_convs): 41 | h = module(h, emb, context) 42 | out_output.append(zero_conv(h, emb, context)) 43 | 44 | h = self.middle_block(h, emb, context) 45 | out_middle.append(self.middle_block_out(h, emb, context)) 46 | 47 | return {"middle": out_middle, "output": out_output} 48 | 49 | 50 | class CtrLoRAAdvanced(ControlNetAdvanced): 51 | def __init__(self, *args, **kwargs): 52 | super().__init__(*args, **kwargs) 53 | self.preprocess_image = lambda a: (a + 1) / 2.0 54 | self.require_vae = True 55 | self.mult_by_ratio_when_vae = False 56 | 57 | def pre_run_advanced(self, model, percent_to_timestep_function): 58 | super().pre_run_advanced(model, percent_to_timestep_function) 59 | self.latent_format = model.latent_format # LatentFormat object, used to process_in latent cond hint 60 | 61 | def cleanup_advanced(self): 62 | super().cleanup_advanced() 63 | if self.latent_format is not None: 64 | del self.latent_format 65 | self.latent_format = None 66 | 67 | def copy(self): 68 | c = CtrLoRAAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) 69 | c.control_model = self.control_model 70 | c.control_model_wrapped = self.control_model_wrapped 71 | self.copy_to(c) 72 | self.copy_to_advanced(c) 73 | return c 74 | 75 | 76 | def load_ctrlora(base_path: str, lora_path: str, 77 | base_data: dict[str, Tensor]=None, lora_data: 
dict[str, Tensor]=None, 78 | timestep_keyframe: TimestepKeyframeGroup=None, model=None, model_options={}): 79 | if base_data is None: 80 | base_data = comfy.utils.load_torch_file(base_path, safe_load=True) 81 | controlnet_data = base_data 82 | 83 | # first, check that base_data contains keys with lora_layer 84 | contains_lora_layers = False 85 | for key in base_data: 86 | if "lora_layer" in key: 87 | contains_lora_layers = True 88 | if not contains_lora_layers: 89 | raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; does not contain any lora_layer keys.") 90 | 91 | controlnet_config = None 92 | supported_inference_dtypes = None 93 | 94 | pth_key = 'control_model.zero_convs.0.0.weight' 95 | pth = False 96 | key = 'zero_convs.0.0.weight' 97 | if pth_key in controlnet_data: 98 | pth = True 99 | key = pth_key 100 | prefix = "control_model." 101 | elif key in controlnet_data: 102 | prefix = "" 103 | else: 104 | # state dict contains lora_layer keys but no recognizable zero_convs entries, 105 | # so it cannot be loaded as a CtrLoRA base model 106 | raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; could not find expected zero_convs keys.") 107 | 108 | 109 | 110 | if controlnet_config is None: 111 | model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True) 112 | supported_inference_dtypes = list(model_config.supported_inference_dtypes) 113 | controlnet_config = model_config.unet_config 114 | 115 | unet_dtype = model_options.get("dtype", None) 116 | if unet_dtype is None: 117 | weight_dtype = comfy.utils.weight_dtype(controlnet_data) 118 | 119 | if supported_inference_dtypes is None: 120 | supported_inference_dtypes = [comfy.model_management.unet_dtype()] 121 | 122 | if weight_dtype is not None: 123 | supported_inference_dtypes.append(weight_dtype) 124 | 125 | unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes) 126 | 127 | load_device = comfy.model_management.get_torch_device() 128 | 129 | manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) 130 | operations = model_options.get("custom_operations", None) 131 | if operations is None: 132 | operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype) 133 | 134 | controlnet_config["operations"] = operations 135 | controlnet_config["dtype"] = unet_dtype 136 | controlnet_config["device"] = comfy.model_management.unet_offload_device() 137 | controlnet_config.pop("out_channels") 138 | controlnet_config["hint_channels"] = 3 139 | #controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] 140 | control_model = ControlNetCtrLoRA(**controlnet_config) 141 | 142 | if pth: 143 | if 'difference' in controlnet_data: 144 | if model is not None: 145 | comfy.model_management.load_models_gpu([model]) 146 | model_sd = model.model_state_dict() 147 | for x in controlnet_data: 148 | c_m = "control_model." 149 | if x.startswith(c_m): 150 | sd_key = "diffusion_model.{}".format(x[len(c_m):]) 151 | if sd_key in model_sd: 152 | cd = controlnet_data[x] 153 | cd += model_sd[sd_key].type(cd.dtype).to(cd.device) 154 | else: 155 | logger.warning("WARNING: Loaded a diff controlnet without a model.
It will very likely not work.") 156 | 157 | class WeightsLoader(torch.nn.Module): 158 | pass 159 | w = WeightsLoader() 160 | w.control_model = control_model 161 | missing, unexpected = w.load_state_dict(controlnet_data, strict=False) 162 | else: 163 | missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) 164 | 165 | if len(missing) > 0: 166 | logger.warning("missing controlnet keys: {}".format(missing)) 167 | 168 | if len(unexpected) > 0: 169 | logger.debug("unexpected controlnet keys: {}".format(unexpected)) 170 | 171 | global_average_pooling = model_options.get("global_average_pooling", False) 172 | control = CtrLoRAAdvanced(control_model, timestep_keyframe, global_average_pooling=global_average_pooling, 173 | load_device=load_device, manual_cast_dtype=manual_cast_dtype) 174 | # load lora data onto the controlnet 175 | if lora_path is not None: 176 | load_lora_data(control, lora_path) 177 | 178 | return control 179 | 180 | 181 | def load_lora_data(control: CtrLoRAAdvanced, lora_path: str, loaded_data: dict[str, Tensor]=None, lora_strength=1.0): 182 | if loaded_data is None: 183 | loaded_data = comfy.utils.load_torch_file(lora_path, safe_load=True) 184 | # check that lora_data contains keys with lora_layer 185 | contains_lora_layers = False 186 | for key in loaded_data: 187 | if "lora_layer" in key: 188 | contains_lora_layers = True 189 | if not contains_lora_layers: 190 | raise Exception(f"File '{lora_path}' is not a valid CtrLoRA lora model; does not contain any lora_layer keys.") 191 | 192 | # now that we know we have a ctrlora file, separate keys into 'set' and 'lora' keys 193 | data_set: dict[str, Tensor] = {} 194 | data_lora: dict[str, Tensor] = {} 195 | 196 | for key in list(loaded_data.keys()): 197 | if 'lora_layer' in key: 198 | data_lora[key] = loaded_data.pop(key) 199 | else: 200 | data_set[key] = loaded_data.pop(key) 201 | # no keys should be left over 202 | if len(loaded_data) > 0: 203 | logger.warning("Not all keys from CtrLoRA lora model's loaded data were parsed!") 204 | 205 | # turn set/lora data into corresponding patches; 206 | patches = {} 207 | # set will replace the values 208 | for key, value in data_set.items(): 209 | # parse model key from key; 210 | # remove "control_model." 211 | model_key = key.replace("control_model.", "") 212 | patches[model_key] = ("set", (value,)) 213 | # lora will do a matmul (mm) of the up and down tensors 214 | for down_key in data_lora: 215 | # only process lora down keys; we will process both up+down at the same time 216 | if ".up." in down_key: 217 | continue 218 | # get up version of down key 219 | up_key = down_key.replace(".down.", ".up.") 220 | # get key that will match up with model key; 221 | # remove "lora_layer.down." and "control_model."
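# e.g. "control_model.foo.lora_layer.down.weight" -> "foo.weight" ("foo" is a hypothetical key name for illustration)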
222 | model_key = down_key.replace("lora_layer.down.", "").replace("control_model.", "") 223 | 224 | weight_down = data_lora[down_key] 225 | weight_up = data_lora[up_key] 226 | # currently, ComfyUI expects 6 elements in 'lora' type, but for future-proofing add a bunch more with None 227 | patches[model_key] = ("lora", (weight_up, weight_down, None, None, None, None, 228 | None, None, None, None, None, None, None, None)) 229 | 230 | # now that patches are made, add them to model 231 | control.control_model_wrapped.add_patches(patches, strength_patch=lora_strength) 232 | -------------------------------------------------------------------------------- /adv_control/control_lllite.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI 2 | # basically, all the LLLite core code is from there, which I then combined with 3 | # Advanced-ControlNet features and QoL 4 | import math 5 | from typing import Union 6 | from torch import Tensor 7 | import torch 8 | import os 9 | 10 | import comfy.utils 11 | import comfy.ops 12 | import comfy.model_management 13 | from comfy.model_patcher import ModelPatcher 14 | from comfy.controlnet import ControlBase 15 | 16 | from .logger import logger 17 | from .utils import (AdvancedControlBase, TimestepKeyframeGroup, ControlWeights, broadcast_image_to_extend, extend_to_batch_size, 18 | prepare_mask_batch) 19 | 20 | 21 | # based on set_model_patch code in comfy/model_patcher.py 22 | def set_model_patch(transformer_options, patch, name): 23 | to = transformer_options 24 | # check if patch was already added 25 | if "patches" in to: 26 | current_patches = to["patches"].get(name, []) 27 | if patch in current_patches: 28 | return 29 | if "patches" not in to: 30 | to["patches"] = {} 31 | to["patches"][name] = to["patches"].get(name, []) + [patch] 32 | 33 | def set_model_attn1_patch(transformer_options, patch): 34 | set_model_patch(transformer_options, patch, "attn1_patch") 35 | 36 | def set_model_attn2_patch(transformer_options, patch): 37 | set_model_patch(transformer_options, patch, "attn2_patch") 38 | 39 | 40 | def extra_options_to_module_prefix(extra_options): 41 | # extra_options = {'transformer_index': 2, 'block_index': 8, 'original_shape': [2, 4, 128, 128], 'block': ('input', 7), 'n_heads': 20, 'dim_head': 64} 42 | 43 | # block is: [('input', 4), ('input', 5), ('input', 7), ('input', 8), ('middle', 0), 44 | # ('output', 0), ('output', 1), ('output', 2), ('output', 3), ('output', 4), ('output', 5)] 45 | # transformer_index is: [0, 1, 2, 3, 4, 5, 6, 7, 8], for each block 46 | # block_index is: 0-1 or 0-9, depends on the block 47 | # input 7 and 8, middle has 10 blocks 48 | 49 | # make module name from extra_options 50 | block = extra_options["block"] 51 | block_index = extra_options["block_index"] 52 | if block[0] == "input": 53 | module_pfx = f"lllite_unet_input_blocks_{block[1]}_1_transformer_blocks_{block_index}" 54 | elif block[0] == "middle": 55 | module_pfx = f"lllite_unet_middle_block_1_transformer_blocks_{block_index}" 56 | elif block[0] == "output": 57 | module_pfx = f"lllite_unet_output_blocks_{block[1]}_1_transformer_blocks_{block_index}" 58 | else: 59 | raise Exception(f"ControlLLLite: invalid block name '{block[0]}'. 
Expected 'input', 'middle', or 'output'.") 60 | return module_pfx 61 | 62 | 63 | class LLLitePatch: 64 | ATTN1 = "attn1" 65 | ATTN2 = "attn2" 66 | def __init__(self, modules: dict[str, 'LLLiteModule'], patch_type: str, control: Union[AdvancedControlBase, ControlBase]=None): 67 | self.modules = modules 68 | self.control = control 69 | self.patch_type = patch_type 70 | #logger.error(f"create LLLitePatch: {id(self)},{control}") 71 | 72 | def __call__(self, q, k, v, extra_options): 73 | #logger.error(f"in __call__: {id(self)}") 74 | # determine if have anything to run 75 | if self.control.timestep_range is not None: 76 | # it turns out comparing single-value tensors to floats is extremely slow 77 | # a: Tensor = extra_options["sigmas"][0] 78 | if self.control.t > self.control.timestep_range[0] or self.control.t < self.control.timestep_range[1]: 79 | return q, k, v 80 | 81 | module_pfx = extra_options_to_module_prefix(extra_options) 82 | 83 | is_attn1 = q.shape[-1] == k.shape[-1] # self attention 84 | if is_attn1: 85 | module_pfx = module_pfx + "_attn1" 86 | else: 87 | module_pfx = module_pfx + "_attn2" 88 | 89 | module_pfx_to_q = module_pfx + "_to_q" 90 | module_pfx_to_k = module_pfx + "_to_k" 91 | module_pfx_to_v = module_pfx + "_to_v" 92 | 93 | if module_pfx_to_q in self.modules: 94 | q = q + self.modules[module_pfx_to_q](q, self.control) 95 | if module_pfx_to_k in self.modules: 96 | k = k + self.modules[module_pfx_to_k](k, self.control) 97 | if module_pfx_to_v in self.modules: 98 | v = v + self.modules[module_pfx_to_v](v, self.control) 99 | 100 | return q, k, v 101 | 102 | def to(self, device): 103 | #logger.info(f"to... has control? {self.control}") 104 | for d in self.modules.keys(): 105 | self.modules[d] = self.modules[d].to(device) 106 | return self 107 | 108 | def set_control(self, control: Union[AdvancedControlBase, ControlBase]) -> 'LLLitePatch': 109 | self.control = control 110 | return self 111 | #logger.error(f"set control for LLLitePatch: {id(self)}, cn: {id(control)}") 112 | 113 | def clone_with_control(self, control: AdvancedControlBase): 114 | #logger.error(f"clone-set control for LLLitePatch: {id(self)},{id(control)}") 115 | return LLLitePatch(self.modules, self.patch_type, control) 116 | 117 | def cleanup(self): 118 | for module in self.modules.values(): 119 | module.cleanup() 120 | 121 | 122 | # TODO: use comfy.ops to support fp8 properly 123 | class LLLiteModule(torch.nn.Module): 124 | def __init__( 125 | self, 126 | name: str, 127 | is_conv2d: bool, 128 | in_dim: int, 129 | depth: int, 130 | cond_emb_dim: int, 131 | mlp_dim: int, 132 | ): 133 | super().__init__() 134 | self.name = name 135 | self.is_conv2d = is_conv2d 136 | self.is_first = False 137 | 138 | modules = [] 139 | modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2 140 | if depth == 1: 141 | modules.append(torch.nn.ReLU(inplace=True)) 142 | modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) 143 | elif depth == 2: 144 | modules.append(torch.nn.ReLU(inplace=True)) 145 | modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0)) 146 | elif depth == 3: 147 | # kernel size 8 is too large, so set it to 4 148 | modules.append(torch.nn.ReLU(inplace=True)) 149 | modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) 150 | modules.append(torch.nn.ReLU(inplace=True)) 151 | 
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) 152 | 153 | self.conditioning1 = torch.nn.Sequential(*modules) 154 | 155 | if self.is_conv2d: 156 | self.down = torch.nn.Sequential( 157 | torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0), 158 | torch.nn.ReLU(inplace=True), 159 | ) 160 | self.mid = torch.nn.Sequential( 161 | torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0), 162 | torch.nn.ReLU(inplace=True), 163 | ) 164 | self.up = torch.nn.Sequential( 165 | torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0), 166 | ) 167 | else: 168 | self.down = torch.nn.Sequential( 169 | torch.nn.Linear(in_dim, mlp_dim), 170 | torch.nn.ReLU(inplace=True), 171 | ) 172 | self.mid = torch.nn.Sequential( 173 | torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim), 174 | torch.nn.ReLU(inplace=True), 175 | ) 176 | self.up = torch.nn.Sequential( 177 | torch.nn.Linear(mlp_dim, in_dim), 178 | ) 179 | 180 | self.depth = depth 181 | self.cond_emb = None 182 | self.cx_shape = None 183 | self.prev_batch = 0 184 | self.prev_sub_idxs = None 185 | 186 | def cleanup(self): 187 | del self.cond_emb 188 | self.cond_emb = None 189 | self.cx_shape = None 190 | self.prev_batch = 0 191 | self.prev_sub_idxs = None 192 | 193 | def forward(self, x: Tensor, control: Union[AdvancedControlBase, ControlBase]): 194 | mask = None 195 | mask_tk = None 196 | #logger.info(x.shape) 197 | if self.cond_emb is None or control.sub_idxs != self.prev_sub_idxs or x.shape[0] != self.prev_batch: 198 | # print(f"cond_emb is None, {self.name}") 199 | cond_hint = control.cond_hint.to(x.device, dtype=x.dtype) 200 | if control.latent_dims_div2 is not None and x.shape[-1] != 1280: 201 | cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div2[0] * 8, control.latent_dims_div2[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype) 202 | elif control.latent_dims_div4 is not None and x.shape[-1] == 1280: 203 | cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div4[0] * 8, control.latent_dims_div4[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype) 204 | cx = self.conditioning1(cond_hint) 205 | self.cx_shape = cx.shape 206 | if not self.is_conv2d: 207 | # reshape / b,c,h,w -> b,h*w,c 208 | n, c, h, w = cx.shape 209 | cx = cx.view(n, c, h * w).permute(0, 2, 1) 210 | self.cond_emb = cx 211 | # save prev values 212 | self.prev_batch = x.shape[0] 213 | self.prev_sub_idxs = control.sub_idxs 214 | 215 | cx: torch.Tensor = self.cond_emb 216 | # print(f"forward {self.name}, {cx.shape}, {x.shape}") 217 | 218 | # TODO: make masks work for conv2d (could not find any ControlLLLites at this time that use them) 219 | # create masks 220 | if not self.is_conv2d: 221 | n, c, h, w = self.cx_shape 222 | if control.mask_cond_hint is not None: 223 | mask = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype) 224 | mask = mask.view(mask.shape[0], 1, h * w).permute(0, 2, 1) 225 | if control.tk_mask_cond_hint is not None: 226 | mask_tk = prepare_mask_batch(control.tk_mask_cond_hint, (1, 1, h, w)).to(cx.dtype) 227 | mask_tk = mask_tk.view(mask_tk.shape[0], 1, h * w).permute(0, 2, 1) 228 | 229 | # x in uncond/cond doubles batch size 230 | if x.shape[0] != cx.shape[0]: 231 | if self.is_conv2d: 232 | cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1) 233 | else: 234 | # print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0]) 235 | cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1) 236 | if mask is
not None: 237 | mask = mask.repeat(x.shape[0] // mask.shape[0], 1, 1) 238 | if mask_tk is not None: 239 | mask_tk = mask_tk.repeat(x.shape[0] // mask_tk.shape[0], 1, 1) 240 | 241 | if mask is None: 242 | mask = 1.0 243 | elif mask_tk is not None: 244 | mask = mask * mask_tk 245 | 246 | #logger.info(f"cs: {cx.shape}, x: {x.shape}, is_conv2d: {self.is_conv2d}") 247 | cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2) 248 | cx = self.mid(cx) 249 | cx = self.up(cx) 250 | if control.latent_keyframes is not None: 251 | cx = cx * control.calc_latent_keyframe_mults(x=cx, batched_number=control.batched_number) 252 | if control.weights is not None and control.weights.has_uncond_multiplier: 253 | cond_or_uncond = control.batched_number.cond_or_uncond 254 | actual_length = cx.size(0) // control.batched_number 255 | for idx, cond_type in enumerate(cond_or_uncond): 256 | # if uncond, set to weight's uncond_multiplier 257 | if cond_type == 1: 258 | cx[actual_length*idx:actual_length*(idx+1)] *= control.weights.uncond_multiplier 259 | return cx * mask * control.strength * control._current_timestep_keyframe.strength 260 | 261 | 262 | class ControlLLLiteModules(torch.nn.Module): 263 | def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch): 264 | super().__init__() 265 | self.patch_attn1_modules = torch.nn.Sequential(*list(patch_attn1.modules.values())) 266 | self.patch_attn2_modules = torch.nn.Sequential(*list(patch_attn2.modules.values())) 267 | 268 | 269 | class ControlLLLiteAdvanced(ControlBase, AdvancedControlBase): 270 | # This ControlNet is more of an attention patch than a traditional controlnet 271 | def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_keyframes: TimestepKeyframeGroup, device, ops: comfy.ops.disable_weight_init): 272 | super().__init__() 273 | AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite()) 274 | self.device = device 275 | self.ops = ops 276 | self.patch_attn1 = patch_attn1.clone_with_control(self) 277 | self.patch_attn2 = patch_attn2.clone_with_control(self) 278 | self.control_model = ControlLLLiteModules(self.patch_attn1, self.patch_attn2) 279 | self.control_model_wrapped = ModelPatcher(self.control_model, load_device=device, offload_device=comfy.model_management.unet_offload_device()) 280 | self.latent_dims_div2 = None 281 | self.latent_dims_div4 = None 282 | 283 | def set_cond_hint_inject(self, *args, **kwargs): 284 | to_return = super().set_cond_hint_inject(*args, **kwargs) 285 | # cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1) 286 | self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0 287 | return to_return 288 | 289 | def pre_run_advanced(self, *args, **kwargs): 290 | AdvancedControlBase.pre_run_advanced(self, *args, **kwargs) 291 | #logger.error(f"in cn: {id(self.patch_attn1)},{id(self.patch_attn2)}") 292 | self.patch_attn1.set_control(self) 293 | self.patch_attn2.set_control(self) 294 | #logger.warn(f"in pre_run_advanced: {id(self)}") 295 | 296 | def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int, transformer_options: dict): 297 | # normal ControlNet stuff 298 | control_prev = None 299 | if self.previous_controlnet is not None: 300 | control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options) 301 | 302 | if self.timestep_range is not None: 303 | if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: 304 | return 
control_prev 305 | 306 | dtype = x_noisy.dtype 307 | # prepare cond_hint 308 | if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: 309 | if self.cond_hint is not None: 310 | del self.cond_hint 311 | self.cond_hint = None 312 | # if self.cond_hint_original length is greater than or equal to the real latent count, subdivide it before scaling 313 | if self.sub_idxs is not None: 314 | actual_cond_hint_orig = self.cond_hint_original 315 | if self.cond_hint_original.size(0) < self.full_latent_length: 316 | actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length) 317 | self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device) 318 | else: 319 | self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device) 320 | if x_noisy.shape[0] != self.cond_hint.shape[0]: 321 | self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number) 322 | # some special logic here compared to other controlnets: 323 | # * The cond_emb in attn patches will divide latent dims by 2 or 4 (integer division) 324 | # * Due to this loss, the cond_emb will become smaller than the x input if latent dims are not divisible by 2 or 4 325 | divisible_by_2_h = x_noisy.shape[2]%2==0 326 | divisible_by_2_w = x_noisy.shape[3]%2==0 327 | if not (divisible_by_2_h and divisible_by_2_w): 328 | #logger.warn(f"{x_noisy.shape} not divisible by 2!") 329 | new_h = (x_noisy.shape[2]//2)*2 330 | new_w = (x_noisy.shape[3]//2)*2 331 | if not divisible_by_2_h: 332 | new_h += 2 333 | if not divisible_by_2_w: 334 | new_w += 2 335 | self.latent_dims_div2 = (new_h, new_w) 336 | divisible_by_4_h = x_noisy.shape[2]%4==0 337 | divisible_by_4_w = x_noisy.shape[3]%4==0 338 | if not (divisible_by_4_h and divisible_by_4_w): 339 | #logger.warn(f"{x_noisy.shape} not divisible by 4!") 340 | new_h = (x_noisy.shape[2]//4)*4 341 | new_w = (x_noisy.shape[3]//4)*4 342 | if not divisible_by_4_h: 343 | new_h += 4 344 | if not divisible_by_4_w: 345 | new_w += 4 346 | self.latent_dims_div4 = (new_h, new_w) 347 | # prepare mask 348 | self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number) 349 | # done preparing; model patches will take care of everything now 350 | set_model_attn1_patch(transformer_options, self.patch_attn1.set_control(self)) 351 | set_model_attn2_patch(transformer_options, self.patch_attn2.set_control(self)) 352 | # return normal controlnet stuff 353 | return control_prev 354 | 355 | def get_models(self): 356 | to_return: list = super().get_models() 357 | to_return.append(self.control_model_wrapped) 358 | return to_return 359 | 360 | def cleanup_advanced(self): 361 | super().cleanup_advanced() 362 | self.patch_attn1.cleanup() 363 | self.patch_attn2.cleanup() 364 | self.latent_dims_div2 = None 365 | self.latent_dims_div4 = None 366 | 367 | def copy(self): 368 | c = ControlLLLiteAdvanced(self.patch_attn1, self.patch_attn2, self.timestep_keyframes, self.device, self.ops) 369 | self.copy_to(c) 370 | self.copy_to_advanced(c) 371 | return c 372 | 373 | 374 | def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None): 375 | if controlnet_data is None: 376 | controlnet_data =
379 |     module_weights = {}
380 |     for key, value in controlnet_data.items():
381 |         fragments = key.split(".")
382 |         module_name = fragments[0]
383 |         weight_name = ".".join(fragments[1:])
384 | 
385 |         if module_name not in module_weights:
386 |             module_weights[module_name] = {}
387 |         module_weights[module_name][weight_name] = value
388 | 
389 |     unet_dtype = comfy.model_management.unet_dtype()
390 |     load_device = comfy.model_management.get_torch_device()
391 |     manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
392 |     ops = comfy.ops.disable_weight_init
393 |     if manual_cast_dtype is not None:
394 |         ops = comfy.ops.manual_cast
395 | 
396 |     # next, load each module
397 |     modules = {}
398 |     for module_name, weights in module_weights.items():
399 |         # kohya planned to change how depth should be chosen, so this heuristic is left as-is;
400 |         # I am not familiar enough with the selection logic to alter it
401 |         if "conditioning1.4.weight" in weights:
402 |             depth = 3
403 |         elif weights["conditioning1.2.weight"].shape[-1] == 4:
404 |             depth = 2
405 |         else:
406 |             depth = 1
407 | 
408 |         module = LLLiteModule(
409 |             name=module_name,
410 |             is_conv2d=weights["down.0.weight"].ndim == 4,
411 |             in_dim=weights["down.0.weight"].shape[1],
412 |             depth=depth,
413 |             cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
414 |             mlp_dim=weights["down.0.weight"].shape[0],
415 |         )
416 |         # load weights into module
417 |         module.load_state_dict(weights)
418 |         modules[module_name] = module.to(dtype=unet_dtype)
419 |         if len(modules) == 1:
420 |             module.is_first = True
421 | 
422 |     #logger.info(f"loaded {ckpt_path} successfully, {len(modules)} modules")
423 | 
424 |     patch_attn1 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN1)
425 |     patch_attn2 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN2)
426 |     control = ControlLLLiteAdvanced(patch_attn1=patch_attn1, patch_attn2=patch_attn2, timestep_keyframes=timestep_keyframe, device=load_device, ops=ops)
427 |     return control
428 | 
--------------------------------------------------------------------------------
/adv_control/control_sparsectrl.py:
--------------------------------------------------------------------------------
1 | #taken from: https://github.com/lllyasviel/ControlNet
2 | #and modified
3 | #and then taken from comfy/cldm/cldm.py and modified again
4 | 
5 | from abc import ABC, abstractmethod
6 | import numpy as np
7 | import torch
8 | from torch import Tensor
9 | 
10 | from comfy.ldm.modules.diffusionmodules.util import (
11 |     zero_module,
12 |     timestep_embedding,
13 | )
14 | 
15 | from comfy.cldm.cldm import ControlNet as ControlNetCLDM
16 | from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential
17 | from comfy.model_patcher import ModelPatcher
18 | from comfy.patcher_extension import PatcherInjection
19 | 
20 | from .dinklink import (InterfaceAnimateDiffInfo, InterfaceAnimateDiffModel,
21 |                        get_CreateMotionModelPatcher, get_AnimateDiffModel, get_AnimateDiffInfo)
22 | from .logger import logger
23 | from .utils import (BIGMAX, AbstractPreprocWrapper, disable_weight_init_clean_groupnorm, WrapperConsts)
24 | 
25 | 
26 | class SparseMotionModelPatcher(ModelPatcher):
27 |     '''Class only used for IDE type hints.'''
28 |     def __init__(self, *args, **kwargs):
29 |         self.model = InterfaceAnimateDiffModel
30 | 
31 | 
32 | class SparseConst:
33 | 
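    # Added descriptive note: the constants below are string keys for SparseCtrl's
    # per-frame strength multipliers; HINT_MULT scales frames that carry a cond hint,
    # NONHINT_MULT scales frames without one, and MASK_MULT scales the optional mask
    # (they correspond to the same-named SparseSettings fields further below).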
HINT_MULT = "sparse_hint_mult" 34 | NONHINT_MULT = "sparse_nonhint_mult" 35 | MASK_MULT = "sparse_mask_mult" 36 | 37 | 38 | class SparseControlNet(ControlNetCLDM): 39 | def __init__(self, *args,**kwargs): 40 | super().__init__(*args, **kwargs) 41 | hint_channels = kwargs.get("hint_channels") 42 | operations: disable_weight_init_clean_groupnorm = kwargs.get("operations", disable_weight_init_clean_groupnorm) 43 | device = kwargs.get("device", None) 44 | self.use_simplified_conditioning_embedding = kwargs.get("use_simplified_conditioning_embedding", False) 45 | if self.use_simplified_conditioning_embedding: 46 | self.input_hint_block = TimestepEmbedSequential( 47 | zero_module(operations.conv_nd(self.dims, hint_channels, self.model_channels, 3, padding=1, dtype=self.dtype, device=device)), 48 | ) 49 | 50 | def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs): 51 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) 52 | emb = self.time_embed(t_emb) 53 | 54 | # SparseCtrl sets noisy input to zeros 55 | x = torch.zeros_like(x) 56 | guided_hint = self.input_hint_block(hint, emb, context) 57 | 58 | out_output = [] 59 | out_middle = [] 60 | 61 | hs = [] 62 | if self.num_classes is not None: 63 | assert y.shape[0] == x.shape[0] 64 | emb = emb + self.label_emb(y) 65 | 66 | h = x 67 | for module, zero_conv in zip(self.input_blocks, self.zero_convs): 68 | if guided_hint is not None: 69 | h = module(h, emb, context) 70 | h += guided_hint 71 | guided_hint = None 72 | else: 73 | h = module(h, emb, context) 74 | out_output.append(zero_conv(h, emb, context)) 75 | 76 | h = self.middle_block(h, emb, context) 77 | out_middle.append(self.middle_block_out(h, emb, context)) 78 | 79 | return {"middle": out_middle, "output": out_output} 80 | 81 | 82 | def load_sparsectrl_motionmodel(ckpt_path: str, motion_data: dict[str, Tensor], ops=None) -> InterfaceAnimateDiffModel: 83 | mm_info: InterfaceAnimateDiffInfo = get_AnimateDiffInfo()("SD1.5", "AnimateDiff", "v3", ckpt_path) 84 | init_kwargs = { 85 | "ops": ops, 86 | "get_unet_func": _get_unet_func, 87 | } 88 | motion_model: InterfaceAnimateDiffModel = get_AnimateDiffModel()(mm_state_dict=motion_data, mm_info=mm_info, init_kwargs=init_kwargs) 89 | missing, unexpected = motion_model.load_state_dict(motion_data) 90 | if len(missing) > 0 or len(unexpected) > 0: 91 | logger.info(f"SparseCtrl MotionModel: {missing}, {unexpected}") 92 | return motion_model 93 | 94 | 95 | def create_sparse_modelpatcher(model, motion_model, load_device, offload_device): 96 | patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) 97 | if motion_model is not None: 98 | _motionpatcher = _create_sparse_motionmodelpatcher(motion_model, load_device, offload_device) 99 | patcher.set_additional_models(WrapperConsts.ACN, [_motionpatcher]) 100 | patcher.set_injections(WrapperConsts.ACN, 101 | [PatcherInjection(inject=_inject_motion_models, eject=_eject_motion_models)]) 102 | return patcher 103 | 104 | def _create_sparse_motionmodelpatcher(motion_model, load_device, offload_device) -> SparseMotionModelPatcher: 105 | return get_CreateMotionModelPatcher()(motion_model, load_device, offload_device) 106 | 107 | 108 | def _inject_motion_models(patcher: ModelPatcher): 109 | motion_models: list[SparseMotionModelPatcher] = patcher.get_additional_models_with_key(WrapperConsts.ACN) 110 | for mm in motion_models: 111 | mm.model.inject(patcher) 112 | 113 | def _eject_motion_models(patcher: ModelPatcher): 114 | 
    motion_models: list[SparseMotionModelPatcher] = patcher.get_additional_models_with_key(WrapperConsts.ACN)
115 |     for mm in motion_models:
116 |         mm.model.eject(patcher)
117 | 
118 | def _get_unet_func(wrapper, model: ModelPatcher):
119 |     return model.model
120 | 
121 | 
122 | class PreprocSparseRGBWrapper(AbstractPreprocWrapper):
123 |     error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
124 |     def __init__(self, condhint: Tensor):
125 |         super().__init__(condhint)
126 | 
127 | 
128 | class SparseContextAware:
129 |     NEAREST_HINT = "nearest_hint"
130 |     OFF = "off"
131 | 
132 |     LIST = [NEAREST_HINT, OFF]
133 | 
134 | 
135 | class SparseSettings:
136 |     def __init__(self, sparse_method: 'SparseMethod', use_motion: bool=True, motion_strength=1.0, motion_scale=1.0, merged=False,
137 |                  sparse_mask_mult=1.0, sparse_hint_mult=1.0, sparse_nonhint_mult=1.0, context_aware=SparseContextAware.NEAREST_HINT):
138 |         # account for Steerable-Motion workflow incompatibility;
139 |         # doing this for my own peace of mind (not an issue with my code)
140 |         if type(sparse_method) == str:
141 |             logger.warn("Outdated Steerable-Motion workflow detected; attempting to auto-convert indexes input. If you experience an error here, consult Steerable-Motion github, NOT Advanced-ControlNet.")
142 |             sparse_method = SparseIndexMethod(get_idx_list_from_str(sparse_method))
143 |         self.sparse_method = sparse_method
144 |         self.use_motion = use_motion
145 |         self.motion_strength = motion_strength
146 |         self.motion_scale = motion_scale
147 |         self.merged = merged
148 |         self.sparse_mask_mult = float(sparse_mask_mult)
149 |         self.sparse_hint_mult = float(sparse_hint_mult)
150 |         self.sparse_nonhint_mult = float(sparse_nonhint_mult)
151 |         self.context_aware = context_aware
152 | 
153 |     def is_context_aware(self):
154 |         return self.context_aware != SparseContextAware.OFF
155 | 
156 |     @classmethod
157 |     def default(cls):
158 |         return SparseSettings(sparse_method=SparseSpreadMethod(), use_motion=True)
159 | 
160 | 
161 | class SparseMethod(ABC):
162 |     SPREAD = "spread"
163 |     INDEX = "index"
164 |     def __init__(self, method: str):
165 |         self.method = method
166 | 
167 |     @abstractmethod
168 |     def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
169 |         pass
170 | 
171 |     def get_indexes(self, hint_length: int, full_length: int, sub_idxs: list[int]=None) -> tuple[list[int], list[int]]:
172 |         returned_idxs = self._get_indexes(hint_length, full_length)
173 |         if sub_idxs is None:
174 |             return returned_idxs, None
175 |         # need to map full indexes to condhint indexes
176 |         index_mapping = {}
177 |         for i, value in enumerate(returned_idxs):
178 |             index_mapping[value] = i
179 |         def get_mapped_idxs(idxs: list[int]):
180 |             return [index_mapping[idx] for idx in idxs]
181 |         # check if returned_idxs fit within sub_idxs
182 |         fitting_idxs = []
183 |         for sub_idx in sub_idxs:
184 |             if sub_idx in returned_idxs:
185 |                 fitting_idxs.append(sub_idx)
186 |         # if there are any fitting_idxs, use them
187 |         if len(fitting_idxs) > 0:
188 |             return fitting_idxs, get_mapped_idxs(fitting_idxs)
189 | 
190 |         # since no returned_idxs fit in sub_idxs, need to get the next-closest hint images based on strategy
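        # Worked example (added; values are hypothetical): with returned_idxs=[0, 31]
        # and a sliding context window of sub_idxs=[10..17], no hint index lands inside
        # the window, so the helper below finds the hint closest to each end of the
        # window, and the window's first/last frames are anchored to those hints.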
191 |         def get_closest_idx(target_idx: int, idxs: list[int]):
192 |             min_idx = -1
193 |             min_dist = BIGMAX
194 |             for idx in idxs:
195 |                 new_dist = abs(idx-target_idx)
196 |                 if new_dist < min_dist:
197 |                     min_idx = idx
198 |                     min_dist = new_dist
199 |                     if min_dist == 1:
200 |                         return min_idx, min_dist
201 |             return min_idx, min_dist
202 |         start_closest_idx, start_dist = get_closest_idx(sub_idxs[0], returned_idxs)
203 |         end_closest_idx, end_dist = get_closest_idx(sub_idxs[-1], returned_idxs)
204 |         # if only one cond hint exists, do special behavior
205 |         if hint_length == 1:
206 |             # if same distance from start and end,
207 |             if start_dist == end_dist:
208 |                 # find center index of sub_idxs
209 |                 center_idx = sub_idxs[np.linspace(0, len(sub_idxs)-1, 3, endpoint=True, dtype=int)[1]]
210 |                 return [center_idx], get_mapped_idxs([start_closest_idx])
211 |             # otherwise, return closest
212 |             if start_dist < end_dist:
213 |                 return [sub_idxs[0]], get_mapped_idxs([start_closest_idx])
214 |             return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx])
215 |         # otherwise, select up to two closest images, or just 1, whichever one applies best
216 |         # if same distance from start and end, return two images to use
217 |         if start_dist == end_dist:
218 |             return [sub_idxs[0], sub_idxs[-1]], get_mapped_idxs([start_closest_idx, end_closest_idx])
219 |         # else, use just one
220 |         if start_dist < end_dist:
221 |             return [sub_idxs[0]], get_mapped_idxs([start_closest_idx])
222 |         return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx])
223 | 
224 | 
225 | class SparseSpreadMethod(SparseMethod):
226 |     UNIFORM = "uniform"
227 |     STARTING = "starting"
228 |     ENDING = "ending"
229 |     CENTER = "center"
230 | 
231 |     LIST = [UNIFORM, STARTING, ENDING, CENTER]
232 | 
233 |     def __init__(self, spread=UNIFORM):
234 |         super().__init__(self.SPREAD)
235 |         self.spread = spread
236 | 
237 |     def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
238 |         # if hint_length >= full_length, limit hints to full_length
239 |         if hint_length >= full_length:
240 |             return list(range(full_length))
241 |         # handle special case of 1 hint image
242 |         if hint_length == 1:
243 |             if self.spread in [self.UNIFORM, self.STARTING]:
244 |                 return [0]
245 |             elif self.spread == self.ENDING:
246 |                 return [full_length-1]
247 |             elif self.spread == self.CENTER:
248 |                 # return second (of three) values as the center
249 |                 return [np.linspace(0, full_length-1, 3, endpoint=True, dtype=int)[1]]
250 |             else:
251 |                 raise ValueError(f"Unrecognized spread: {self.spread}")
252 |         # otherwise, handle other cases
253 |         if self.spread == self.UNIFORM:
254 |             return list(np.linspace(0, full_length-1, hint_length, endpoint=True, dtype=int))
255 |         elif self.spread == self.STARTING:
256 |             # make split 1 larger, remove last element
257 |             return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
258 |         elif self.spread == self.ENDING:
259 |             # make split 1 larger, remove first element
260 |             return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[1:]
261 |         elif self.spread == self.CENTER:
262 |             # if full_length is not at least 3 greater than hint_length, do STARTING behavior
263 |             if full_length-hint_length < 3:
264 |                 return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
265 |             # otherwise, get linspace of 2 greater than needed, then cut off first and last
266 |             return list(np.linspace(0, full_length-1, hint_length+2, endpoint=True, dtype=int))[1:-1]
267 |         raise ValueError(f"Unrecognized spread: {self.spread}")
268 | 
269 | 
270 | class SparseIndexMethod(SparseMethod):
271 |     def __init__(self, idxs: list[int]):
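        # Added note: indexes may be negative (counted from the end, Python-style) and
        # must be unique after resolution; e.g. SparseIndexMethod([0, 16, -1]) on a
        # 32-frame batch pins hints to frames 0, 16, and 31 (hypothetical values).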
272 |         super().__init__(self.INDEX)
273 |         self.idxs = idxs
274 | 
275 |     def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
276 |         orig_hint_length = hint_length
277 |         if hint_length > full_length:
278 |             hint_length = full_length
279 |         # if fewer idxs than hint_length are provided, throw an error
280 |         if len(self.idxs) < hint_length:
281 |             err_msg = f"There are not enough indexes ({len(self.idxs)}) provided to fit the usable {hint_length} input images."
282 |             if orig_hint_length != hint_length:
283 |                 err_msg = f"{err_msg} (original input images: {orig_hint_length})"
284 |             raise ValueError(err_msg)
285 |         # cap idxs to hint_length
286 |         idxs = self.idxs[:hint_length]
287 |         new_idxs = []
288 |         real_idxs = set()
289 |         for idx in idxs:
290 |             if idx < 0:
291 |                 real_idx = full_length+idx
292 |                 if real_idx in real_idxs:
293 |                     raise ValueError(f"Index '{idx}' maps to '{real_idx}' and is duplicate - indexes in Sparse Index Method must be unique.")
294 |             else:
295 |                 real_idx = idx
296 |                 if real_idx in real_idxs:
297 |                     raise ValueError(f"Index '{idx}' is duplicate (or a negative index is equivalent) - indexes in Sparse Index Method must be unique.")
298 |             real_idxs.add(real_idx)
299 |             new_idxs.append(real_idx)
300 |         return new_idxs
301 | 
302 | 
303 | def get_idx_list_from_str(indexes: str) -> list[int]:
304 |     idxs = []
305 |     unique_idxs = set()
306 |     # get indexes from string
307 |     str_idxs = [x.strip() for x in indexes.strip().split(",")]
308 |     for str_idx in str_idxs:
309 |         try:
310 |             idx = int(str_idx)
311 |         except ValueError:
312 |             raise ValueError(f"'{str_idx}' is not a valid integer index.")
313 |         # check uniqueness outside the try block so this ValueError is not
314 |         # swallowed and re-reported as an invalid-integer error
315 |         if idx in unique_idxs:
316 |             raise ValueError(f"'{idx}' is duplicated; indexes must be unique.")
317 |         idxs.append(idx)
318 |         unique_idxs.add(idx)
319 |     if len(idxs) == 0:
320 |         raise ValueError("No indexes were listed in Sparse Index Method.")
321 |     return idxs
322 | 
--------------------------------------------------------------------------------
/adv_control/dinklink.py:
--------------------------------------------------------------------------------
1 | ####################################################################################################
2 | # DinkLink is my method of sharing classes/functions between my nodes.
3 | #
4 | # My DinkLink-compatible nodes will inject comfy.hooks with a __DINKLINK attr
5 | # that stores a dictionary, where any of my node packs can store their stuff.
6 | #
7 | # It is not intended to be accessed by node packs that I don't develop, so things may change
8 | # at any time.
9 | #
10 | # DinkLink also serves as a proof-of-concept for a future ComfyUI implementation of
11 | # purposely exposing node pack classes/functions to other node packs.
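#
# A minimal sketch of the registration side (added for illustration; the actual
# registration happens in AnimateDiff-Evolved, and names here are hypothetical):
#
#     d = get_dinklink()
#     link_ade = d.setdefault(DinkLinkConst.ADE, {})
#     link_ade[DinkLinkConst.ADE_ANIMATEDIFFMODEL] = AnimateDiffModel
#
# The get_* helpers further below in this file perform the matching lookups.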
12 | ####################################################################################################
13 | from __future__ import annotations
14 | from typing import Any, Union
15 | from torch import Tensor, nn
16 | 
17 | from comfy.model_patcher import ModelPatcher
18 | import comfy.hooks
19 | 
20 | DINKLINK = "__DINKLINK"
21 | 
22 | 
23 | def init_dinklink():
24 |     create_dinklink()
25 |     prepare_dinklink()
26 | 
27 | def create_dinklink():
28 |     if not hasattr(comfy.hooks, DINKLINK):
29 |         setattr(comfy.hooks, DINKLINK, {})
30 | 
31 | def get_dinklink() -> dict[str, dict[str, Any]]:
32 |     create_dinklink()
33 |     return getattr(comfy.hooks, DINKLINK)
34 | 
35 | 
36 | class DinkLinkConst:
37 |     VERSION = "version"
38 |     # ADE
39 |     ADE = "ADE"
40 |     ADE_ANIMATEDIFFMODEL = "AnimateDiffModel"
41 |     ADE_ANIMATEDIFFINFO = "AnimateDiffInfo"
42 |     ADE_CREATE_MOTIONMODELPATCHER = "create_MotionModelPatcher"
43 | 
44 | def prepare_dinklink():
45 |     pass
46 | 
47 | 
48 | class InterfaceAnimateDiffInfo:
49 |     '''Class only used for IDE type hints; interface of ADE's AnimateDiffInfo'''
50 |     def __init__(self, sd_type: str, mm_format: str, mm_version: str, mm_name: str):
51 |         self.sd_type = sd_type
52 |         self.mm_format = mm_format
53 |         self.mm_version = mm_version
54 |         self.mm_name = mm_name
55 | 
56 | 
57 | class InterfaceAnimateDiffModel(nn.Module):
58 |     '''Class only used for IDE type hints; interface of ADE's AnimateDiffModel'''
59 |     def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: InterfaceAnimateDiffInfo, init_kwargs: dict[str, Any]={}):
60 |         pass
61 | 
62 |     def set_video_length(self, video_length: int, full_length: int) -> None:
63 |         raise NotImplementedError()
64 | 
65 |     def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list, None]=None) -> None:
66 |         raise NotImplementedError()
67 | 
68 |     def set_effect(self, multival: Union[float, Tensor, None], per_block_list: Union[list, None]=None) -> None:
69 |         raise NotImplementedError()
70 | 
71 |     def cleanup(self):
72 |         raise NotImplementedError()
73 | 
74 |     def inject(self, model: ModelPatcher):
75 |         pass
76 | 
77 |     def eject(self, model: ModelPatcher):
78 |         pass
79 | 
80 | 
81 | def get_CreateMotionModelPatcher(throw_exception=True):
82 |     d = get_dinklink()
83 |     try:
84 |         link_ade = d[DinkLinkConst.ADE]
85 |         return link_ade[DinkLinkConst.ADE_CREATE_MOTIONMODELPATCHER]
86 |     except KeyError:
87 |         if throw_exception:
88 |             raise Exception("Could not get create_MotionModelPatcher function. AnimateDiff-Evolved nodes need to be installed to use SparseCtrl; " + \
89 |                 "they are either not installed or are of an insufficient version.")
90 |         return None
91 | 
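# Added note: the getters here all share one pattern: look up ADE's DinkLink
# entry and either raise an informative install/version error or return None.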
92 | def get_AnimateDiffModel(throw_exception=True):
93 |     d = get_dinklink()
94 |     try:
95 |         link_ade = d[DinkLinkConst.ADE]
96 |         return link_ade[DinkLinkConst.ADE_ANIMATEDIFFMODEL]
97 |     except KeyError:
98 |         if throw_exception:
99 |             raise Exception("Could not get AnimateDiffModel class. AnimateDiff-Evolved nodes need to be installed to use SparseCtrl; " + \
100 |                 "they are either not installed or are of an insufficient version.")
101 |         return None
102 | 
103 | def get_AnimateDiffInfo(throw_exception=True) -> type[InterfaceAnimateDiffInfo]:
104 |     d = get_dinklink()
105 |     try:
106 |         link_ade = d[DinkLinkConst.ADE]
107 |         return link_ade[DinkLinkConst.ADE_ANIMATEDIFFINFO]
108 |     except KeyError:
109 |         if throw_exception:
110 |             raise Exception("Could not get AnimateDiffInfo class. AnimateDiff-Evolved nodes need to be installed to use SparseCtrl; " + \
111 |                 "they are either not installed or are of an insufficient version.")
112 |         return None
113 | 
--------------------------------------------------------------------------------
/adv_control/documentation.py:
--------------------------------------------------------------------------------
1 | from .logger import logger
2 | 
3 | def image(src):
4 |     return f''
5 | def video(src):
6 |     return f'