├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── diffusers_helper ├── bucket_tools.py ├── dit_common.py ├── k_diffusion │ ├── uni_pc_fm.py │ └── wrapper.py ├── memory.py ├── models │ └── hunyuan_video_packed.py ├── pipelines │ └── k_diffusion_hunyuan.py └── utils.py ├── example_workflows └── framepack_hv_example.json ├── fp8_optimization.py ├── nodes.py ├── requirements.txt ├── transformer_config.json └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | hf_download/ 2 | outputs/ 3 | repo/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # UV 102 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | #uv.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 
111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 120 | .pdm.toml 121 | .pdm-python 122 | .pdm-build/ 123 | 124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 125 | __pypackages__/ 126 | 127 | # Celery stuff 128 | celerybeat-schedule 129 | celerybeat.pid 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 172 | .idea/ 173 | 174 | # Ruff stuff: 175 | .ruff_cache/ 176 | 177 | # PyPI configuration file 178 | .pypirc 179 | demo_gradio.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI Wrapper for [FramePack by lllyasviel](https://lllyasviel.github.io/frame_pack_gitpage/) 2 | 3 | # WORK IN PROGRESS 4 | 5 | Mostly working, took some liberties to make it run faster. 
6 | 7 | Uses all the native models for text encoders, VAE and sigclip: 8 | 9 | https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/tree/main/split_files 10 | 11 | https://huggingface.co/Comfy-Org/sigclip_vision_384/tree/main 12 | 13 | And the transformer model itself is either autodownloaded from here: 14 | 15 | https://huggingface.co/lllyasviel/FramePackI2V_HY/tree/main 16 | 17 | to `ComfyUI\models\diffusers\lllyasviel\FramePackI2V_HY` 18 | 19 | Or from single file, in `ComfyUI\models\diffusion_models`: 20 | 21 | https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_fp8_e4m3fn.safetensors 22 | https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_bf16.safetensors 23 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /diffusers_helper/bucket_tools.py: -------------------------------------------------------------------------------- 1 | bucket_options = { 2 | (416, 960), 3 | (448, 864), 4 | (480, 832), 5 | (512, 768), 6 | (544, 704), 7 | (576, 672), 8 | (608, 640), 9 | (640, 608), 10 | (672, 576), 11 | (704, 544), 12 | (768, 512), 13 | (832, 480), 14 | (864, 448), 15 | (960, 416), 16 | } 17 | 18 | 19 | def find_nearest_bucket(h, w, resolution=640): 20 | min_metric = float('inf') 21 | best_bucket = None 22 | for (bucket_h, bucket_w) in bucket_options: 23 | metric = abs(h * bucket_w - w * bucket_h) 24 | if metric <= min_metric: 25 | min_metric = metric 26 | best_bucket = (bucket_h, bucket_w) 27 | 28 | if resolution != 640: 29 | scale_factor = resolution / 640.0 30 | scaled_height = round(best_bucket[0] * scale_factor / 16) * 16 31 | scaled_width = round(best_bucket[1] * scale_factor / 16) * 16 32 | best_bucket = (scaled_height, scaled_width) 33 | print(f'Resolution: {best_bucket[1]} x {best_bucket[0]}') 34 | 35 | return best_bucket 36 | 37 | -------------------------------------------------------------------------------- /diffusers_helper/dit_common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import accelerate.accelerator 3 | 4 | from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous 5 | 6 | 7 | accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x 8 | 9 | 10 | def LayerNorm_forward(self, x): 11 | return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x) 12 | 13 | 14 | LayerNorm.forward = LayerNorm_forward 15 | torch.nn.LayerNorm.forward = LayerNorm_forward 16 | 17 | 18 | def FP32LayerNorm_forward(self, x): 19 | origin_dtype = x.dtype 20 | return torch.nn.functional.layer_norm( 21 | x.float(), 22 | self.normalized_shape, 23 | self.weight.float() if self.weight is not None else None, 24 | self.bias.float() if self.bias is not None else None, 25 | self.eps, 26 | ).to(origin_dtype) 27 | 28 | 29 | FP32LayerNorm.forward = FP32LayerNorm_forward 30 | 31 | 32 | def RMSNorm_forward(self, hidden_states): 33 | input_dtype = hidden_states.dtype 34 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 35 | hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 36 | 37 | if self.weight is None: 38 | return 
hidden_states.to(input_dtype) 39 | 40 | return hidden_states.to(input_dtype) * self.weight.to(input_dtype) 41 | 42 | 43 | RMSNorm.forward = RMSNorm_forward 44 | 45 | 46 | def AdaLayerNormContinuous_forward(self, x, conditioning_embedding): 47 | emb = self.linear(self.silu(conditioning_embedding)) 48 | scale, shift = emb.chunk(2, dim=1) 49 | x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] 50 | return x 51 | 52 | 53 | AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward 54 | -------------------------------------------------------------------------------- /diffusers_helper/k_diffusion/uni_pc_fm.py: -------------------------------------------------------------------------------- 1 | # Better Flow Matching UniPC by Lvmin Zhang 2 | # (c) 2025 3 | # CC BY-SA 4.0 4 | # Attribution-ShareAlike 4.0 International Licence 5 | 6 | 7 | import torch 8 | from comfy.utils import ProgressBar 9 | from tqdm.auto import trange 10 | 11 | 12 | def expand_dims(v, dims): 13 | return v[(...,) + (None,) * (dims - 1)] 14 | 15 | 16 | class FlowMatchUniPC: 17 | def __init__(self, model, extra_args, variant='bh1'): 18 | self.model = model 19 | self.variant = variant 20 | self.extra_args = extra_args 21 | 22 | def model_fn(self, x, t): 23 | return self.model(x, t, **self.extra_args) 24 | 25 | def update_fn(self, x, model_prev_list, t_prev_list, t, order): 26 | assert order <= len(model_prev_list) 27 | dims = x.dim() 28 | 29 | t_prev_0 = t_prev_list[-1] 30 | lambda_prev_0 = - torch.log(t_prev_0) 31 | lambda_t = - torch.log(t) 32 | model_prev_0 = model_prev_list[-1] 33 | 34 | h = lambda_t - lambda_prev_0 35 | 36 | rks = [] 37 | D1s = [] 38 | for i in range(1, order): 39 | t_prev_i = t_prev_list[-(i + 1)] 40 | model_prev_i = model_prev_list[-(i + 1)] 41 | lambda_prev_i = - torch.log(t_prev_i) 42 | rk = ((lambda_prev_i - lambda_prev_0) / h)[0] 43 | rks.append(rk) 44 | D1s.append((model_prev_i - model_prev_0) / rk) 45 | 46 | rks.append(1.) 
47 | rks = torch.tensor(rks, device=x.device) 48 | 49 | R = [] 50 | b = [] 51 | 52 | hh = -h[0] 53 | h_phi_1 = torch.expm1(hh) 54 | h_phi_k = h_phi_1 / hh - 1 55 | 56 | factorial_i = 1 57 | 58 | if self.variant == 'bh1': 59 | B_h = hh 60 | elif self.variant == 'bh2': 61 | B_h = torch.expm1(hh) 62 | else: 63 | raise NotImplementedError('Bad variant!') 64 | 65 | for i in range(1, order + 1): 66 | R.append(torch.pow(rks, i - 1)) 67 | b.append(h_phi_k * factorial_i / B_h) 68 | factorial_i *= (i + 1) 69 | h_phi_k = h_phi_k / hh - 1 / factorial_i 70 | 71 | R = torch.stack(R) 72 | b = torch.tensor(b, device=x.device) 73 | 74 | use_predictor = len(D1s) > 0 75 | 76 | if use_predictor: 77 | D1s = torch.stack(D1s, dim=1) 78 | if order == 2: 79 | rhos_p = torch.tensor([0.5], device=b.device) 80 | else: 81 | rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) 82 | else: 83 | D1s = None 84 | rhos_p = None 85 | 86 | if order == 1: 87 | rhos_c = torch.tensor([0.5], device=b.device) 88 | else: 89 | rhos_c = torch.linalg.solve(R, b) 90 | 91 | x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0 92 | 93 | if use_predictor: 94 | pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) 95 | else: 96 | pred_res = 0 97 | 98 | x_t = x_t_ - expand_dims(B_h, dims) * pred_res 99 | model_t = self.model_fn(x_t, t) 100 | 101 | if D1s is not None: 102 | corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) 103 | else: 104 | corr_res = 0 105 | 106 | D1_t = (model_t - model_prev_0) 107 | x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t) 108 | 109 | return x_t, model_t 110 | 111 | def sample(self, x, sigmas, callback=None, disable_pbar=False): 112 | order = min(3, len(sigmas) - 2) 113 | model_prev_list, t_prev_list = [], [] 114 | comfy_pbar = ProgressBar(len(sigmas)-1) 115 | for i in trange(len(sigmas) - 1, disable=disable_pbar): 116 | vec_t = sigmas[i].expand(x.shape[0]) 117 | 118 | if i == 0: 119 | model_prev_list = [self.model_fn(x, vec_t)] 120 | t_prev_list = [vec_t] 121 | elif i < order: 122 | init_order = i 123 | x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order) 124 | model_prev_list.append(model_x) 125 | t_prev_list.append(vec_t) 126 | else: 127 | x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order) 128 | model_prev_list.append(model_x) 129 | t_prev_list.append(vec_t) 130 | 131 | model_prev_list = model_prev_list[-order:] 132 | t_prev_list = t_prev_list[-order:] 133 | 134 | if callback is not None: 135 | callback_latent = model_prev_list[-1].detach()[0].permute(1,0,2,3) 136 | callback( 137 | i, 138 | callback_latent, 139 | None, 140 | len(sigmas) - 1 141 | ) 142 | comfy_pbar.update(1) 143 | 144 | return model_prev_list[-1] 145 | 146 | 147 | def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): 148 | assert variant in ['bh1', 'bh2'] 149 | return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable) 150 | -------------------------------------------------------------------------------- /diffusers_helper/k_diffusion/wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def append_dims(x, target_dims): 5 | return x[(...,) + (None,) * (target_dims - x.ndim)] 6 | 7 | 8 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0): 9 | if guidance_rescale == 0: 10 | return noise_cfg 11 | 12 | std_text = 
noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 13 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 14 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 15 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg 16 | return noise_cfg 17 | 18 | 19 | def fm_wrapper(transformer, t_scale=1000.0): 20 | def k_model(x, sigma, **extra_args): 21 | dtype = extra_args['dtype'] 22 | cfg_scale = extra_args['cfg_scale'] 23 | cfg_rescale = extra_args['cfg_rescale'] 24 | concat_latent = extra_args['concat_latent'] 25 | 26 | original_dtype = x.dtype 27 | sigma = sigma.float() 28 | 29 | x = x.to(dtype) 30 | timestep = (sigma * t_scale).to(dtype) 31 | 32 | if concat_latent is None: 33 | hidden_states = x 34 | else: 35 | hidden_states = torch.cat([x, concat_latent.to(x)], dim=1) 36 | 37 | pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float() 38 | 39 | if cfg_scale == 1.0: 40 | pred_negative = torch.zeros_like(pred_positive) 41 | else: 42 | pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float() 43 | 44 | pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative) 45 | pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale) 46 | 47 | x0 = x.float() - pred.float() * append_dims(sigma, x.ndim) 48 | 49 | return x0.to(dtype=original_dtype) 50 | 51 | return k_model 52 | -------------------------------------------------------------------------------- /diffusers_helper/memory.py: -------------------------------------------------------------------------------- 1 | # By lllyasviel 2 | 3 | 4 | import torch 5 | 6 | 7 | cpu = torch.device('cpu') 8 | gpu = torch.device(f'cuda:{torch.cuda.current_device()}') 9 | gpu_complete_modules = [] 10 | 11 | 12 | class DynamicSwapInstaller: 13 | @staticmethod 14 | def _install_module(module: torch.nn.Module, **kwargs): 15 | original_class = module.__class__ 16 | module.__dict__['forge_backup_original_class'] = original_class 17 | 18 | def hacked_get_attr(self, name: str): 19 | if '_parameters' in self.__dict__: 20 | _parameters = self.__dict__['_parameters'] 21 | if name in _parameters: 22 | p = _parameters[name] 23 | if p is None: 24 | return None 25 | if p.__class__ == torch.nn.Parameter: 26 | return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad) 27 | else: 28 | return p.to(**kwargs) 29 | if '_buffers' in self.__dict__: 30 | _buffers = self.__dict__['_buffers'] 31 | if name in _buffers: 32 | return _buffers[name].to(**kwargs) 33 | return super(original_class, self).__getattr__(name) 34 | 35 | module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), { 36 | '__getattr__': hacked_get_attr, 37 | }) 38 | 39 | return 40 | 41 | @staticmethod 42 | def _uninstall_module(module: torch.nn.Module): 43 | if 'forge_backup_original_class' in module.__dict__: 44 | module.__class__ = module.__dict__.pop('forge_backup_original_class') 45 | return 46 | 47 | @staticmethod 48 | def install_model(model: torch.nn.Module, **kwargs): 49 | for m in model.modules(): 50 | DynamicSwapInstaller._install_module(m, **kwargs) 51 | return 52 | 53 | @staticmethod 54 | def uninstall_model(model: torch.nn.Module): 55 | for m in model.modules(): 56 | DynamicSwapInstaller._uninstall_module(m) 57 | return 58 | 59 | 60 | def fake_diffusers_current_device(model: torch.nn.Module, target_device: 
torch.device): 61 | if hasattr(model, 'scale_shift_table'): 62 | model.scale_shift_table.data = model.scale_shift_table.data.to(target_device) 63 | return 64 | 65 | for k, p in model.named_modules(): 66 | if hasattr(p, 'weight'): 67 | p.to(target_device) 68 | return 69 | 70 | 71 | def get_cuda_free_memory_gb(device=None): 72 | if device is None: 73 | device = gpu 74 | 75 | memory_stats = torch.cuda.memory_stats(device) 76 | bytes_active = memory_stats['active_bytes.all.current'] 77 | bytes_reserved = memory_stats['reserved_bytes.all.current'] 78 | bytes_free_cuda, _ = torch.cuda.mem_get_info(device) 79 | bytes_inactive_reserved = bytes_reserved - bytes_active 80 | bytes_total_available = bytes_free_cuda + bytes_inactive_reserved 81 | return bytes_total_available / (1024 ** 3) 82 | 83 | 84 | def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0): 85 | print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB') 86 | 87 | for m in model.modules(): 88 | if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb: 89 | torch.cuda.empty_cache() 90 | return 91 | 92 | if hasattr(m, 'weight'): 93 | m.to(device=target_device) 94 | 95 | model.to(device=target_device) 96 | torch.cuda.empty_cache() 97 | return 98 | 99 | 100 | def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0): 101 | print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB') 102 | 103 | for m in model.modules(): 104 | if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb: 105 | torch.cuda.empty_cache() 106 | return 107 | 108 | if hasattr(m, 'weight'): 109 | m.to(device=cpu) 110 | 111 | model.to(device=cpu) 112 | torch.cuda.empty_cache() 113 | return 114 | 115 | 116 | def unload_complete_models(*args): 117 | for m in gpu_complete_modules + list(args): 118 | m.to(device=cpu) 119 | print(f'Unloaded {m.__class__.__name__} as complete.') 120 | 121 | gpu_complete_modules.clear() 122 | torch.cuda.empty_cache() 123 | return 124 | 125 | 126 | def load_model_as_complete(model, target_device, unload=True): 127 | if unload: 128 | unload_complete_models() 129 | 130 | model.to(device=target_device) 131 | print(f'Loaded {model.__class__.__name__} to {target_device} as complete.') 132 | 133 | gpu_complete_modules.append(model) 134 | return 135 | -------------------------------------------------------------------------------- /diffusers_helper/models/hunyuan_video_packed.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import einops 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | from diffusers.loaders import FromOriginalModelMixin 9 | from diffusers.configuration_utils import ConfigMixin, register_to_config 10 | from diffusers.loaders import PeftAdapterMixin 11 | from diffusers.utils import logging 12 | from diffusers.models.attention import FeedForward 13 | from diffusers.models.attention_processor import Attention 14 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection 15 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 16 | from diffusers.models.modeling_utils import ModelMixin 17 | from ...diffusers_helper.dit_common import LayerNorm 18 | 19 | 20 | enabled_backends = [] 21 | 22 | if torch.backends.cuda.flash_sdp_enabled(): 23 | 
enabled_backends.append("flash") 24 | if torch.backends.cuda.math_sdp_enabled(): 25 | enabled_backends.append("math") 26 | if torch.backends.cuda.mem_efficient_sdp_enabled(): 27 | enabled_backends.append("mem_efficient") 28 | if torch.backends.cuda.cudnn_sdp_enabled(): 29 | enabled_backends.append("cudnn") 30 | 31 | try: 32 | # raise NotImplementedError 33 | from flash_attn import flash_attn_varlen_func, flash_attn_func 34 | except: 35 | flash_attn_varlen_func = None 36 | flash_attn_func = None 37 | 38 | try: 39 | # raise NotImplementedError 40 | from sageattention import sageattn_varlen, sageattn 41 | except: 42 | sageattn_varlen = None 43 | sageattn = None 44 | 45 | 46 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 47 | 48 | 49 | def pad_for_3d_conv(x, kernel_size): 50 | b, c, t, h, w = x.shape 51 | pt, ph, pw = kernel_size 52 | pad_t = (pt - (t % pt)) % pt 53 | pad_h = (ph - (h % ph)) % ph 54 | pad_w = (pw - (w % pw)) % pw 55 | return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate') 56 | 57 | 58 | def center_down_sample_3d(x, kernel_size): 59 | # pt, ph, pw = kernel_size 60 | # cp = (pt * ph * pw) // 2 61 | # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw) 62 | # xc = xp[cp] 63 | # return xc 64 | return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size) 65 | 66 | 67 | def get_cu_seqlens(text_mask, img_len): 68 | batch_size = text_mask.shape[0] 69 | text_len = text_mask.sum(dim=1) 70 | max_len = text_mask.shape[1] + img_len 71 | 72 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") 73 | 74 | for i in range(batch_size): 75 | s = text_len[i] + img_len 76 | s1 = i * max_len + s 77 | s2 = (i + 1) * max_len 78 | cu_seqlens[2 * i + 1] = s1 79 | cu_seqlens[2 * i + 2] = s2 80 | 81 | return cu_seqlens 82 | 83 | 84 | def apply_rotary_emb_transposed(x, freqs_cis): 85 | cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1) 86 | x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1) 87 | x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) 88 | out = x.float() * cos + x_rotated.float() * sin 89 | out = out.to(x) 90 | return out 91 | 92 | 93 | def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attention_mode='sdpa'): 94 | if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None: 95 | if attention_mode == "sageattn": 96 | x = sageattn(q, k, v, tensor_layout='NHD') 97 | if attention_mode == "flash_attn": 98 | x = flash_attn_func(q, k, v) 99 | elif attention_mode == "sdpa": 100 | x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2) 101 | return x 102 | 103 | # batch_size = q.shape[0] 104 | # q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) 105 | # k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) 106 | # v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) 107 | # if sageattn_varlen is not None: 108 | # x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) 109 | # elif flash_attn_varlen_func is not None: 110 | # x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) 111 | # else: 112 | # raise NotImplementedError('No Attn Installed!') 113 | # x = x.view(batch_size, max_seqlen_q, *x.shape[2:]) 114 | # return x 115 | 116 | 117 | class HunyuanAttnProcessorFlashAttnDouble: 118 | def __init__(self, attention_mode): 119 | 
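        # attention_mode selects the kernel used in attn_varlen_func above:
        # 'sdpa' (PyTorch scaled_dot_product_attention), 'flash_attn', or 'sageattn'.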
self.attention_mode = attention_mode 120 | def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb): 121 | cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask 122 | 123 | query = attn.to_q(hidden_states) 124 | key = attn.to_k(hidden_states) 125 | value = attn.to_v(hidden_states) 126 | 127 | query = query.unflatten(2, (attn.heads, -1)) 128 | key = key.unflatten(2, (attn.heads, -1)) 129 | value = value.unflatten(2, (attn.heads, -1)) 130 | 131 | query = attn.norm_q(query) 132 | key = attn.norm_k(key) 133 | 134 | query = apply_rotary_emb_transposed(query, image_rotary_emb) 135 | key = apply_rotary_emb_transposed(key, image_rotary_emb) 136 | 137 | encoder_query = attn.add_q_proj(encoder_hidden_states) 138 | encoder_key = attn.add_k_proj(encoder_hidden_states) 139 | encoder_value = attn.add_v_proj(encoder_hidden_states) 140 | 141 | encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) 142 | encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) 143 | encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) 144 | 145 | encoder_query = attn.norm_added_q(encoder_query) 146 | encoder_key = attn.norm_added_k(encoder_key) 147 | 148 | query = torch.cat([query, encoder_query], dim=1) 149 | key = torch.cat([key, encoder_key], dim=1) 150 | value = torch.cat([value, encoder_value], dim=1) 151 | 152 | hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, self.attention_mode) 153 | hidden_states = hidden_states.flatten(-2) 154 | 155 | txt_length = encoder_hidden_states.shape[1] 156 | hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:] 157 | 158 | hidden_states = attn.to_out[0](hidden_states) 159 | hidden_states = attn.to_out[1](hidden_states) 160 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 161 | 162 | return hidden_states, encoder_hidden_states 163 | 164 | 165 | class HunyuanAttnProcessorFlashAttnSingle: 166 | def __init__(self, attention_mode): 167 | self.attention_mode = attention_mode 168 | def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb): 169 | cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask 170 | 171 | hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) 172 | 173 | query = attn.to_q(hidden_states) 174 | key = attn.to_k(hidden_states) 175 | value = attn.to_v(hidden_states) 176 | 177 | query = query.unflatten(2, (attn.heads, -1)) 178 | key = key.unflatten(2, (attn.heads, -1)) 179 | value = value.unflatten(2, (attn.heads, -1)) 180 | 181 | query = attn.norm_q(query) 182 | key = attn.norm_k(key) 183 | 184 | txt_length = encoder_hidden_states.shape[1] 185 | 186 | query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1) 187 | key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1) 188 | 189 | hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, self.attention_mode) 190 | hidden_states = hidden_states.flatten(-2) 191 | 192 | hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:] 193 | 194 | return hidden_states, encoder_hidden_states 195 | 196 | 197 | class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): 198 | def __init__(self, embedding_dim, pooled_projection_dim): 199 | 
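        # Sums a sinusoidal timestep embedding, a guidance-scale embedding, and a
        # projection of the pooled text embedding into a single conditioning vector
        # (see forward below).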
super().__init__() 200 | 201 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) 202 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 203 | self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 204 | self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") 205 | 206 | def forward(self, timestep, guidance, pooled_projection): 207 | timesteps_proj = self.time_proj(timestep) 208 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) 209 | 210 | guidance_proj = self.time_proj(guidance) 211 | guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) 212 | 213 | time_guidance_emb = timesteps_emb + guidance_emb 214 | 215 | pooled_projections = self.text_embedder(pooled_projection) 216 | conditioning = time_guidance_emb + pooled_projections 217 | 218 | return conditioning 219 | 220 | 221 | class CombinedTimestepTextProjEmbeddings(nn.Module): 222 | def __init__(self, embedding_dim, pooled_projection_dim): 223 | super().__init__() 224 | 225 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) 226 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 227 | self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") 228 | 229 | def forward(self, timestep, pooled_projection): 230 | timesteps_proj = self.time_proj(timestep) 231 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) 232 | 233 | pooled_projections = self.text_embedder(pooled_projection) 234 | 235 | conditioning = timesteps_emb + pooled_projections 236 | 237 | return conditioning 238 | 239 | 240 | class HunyuanVideoAdaNorm(nn.Module): 241 | def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: 242 | super().__init__() 243 | 244 | out_features = out_features or 2 * in_features 245 | self.linear = nn.Linear(in_features, out_features) 246 | self.nonlinearity = nn.SiLU() 247 | 248 | def forward( 249 | self, temb: torch.Tensor 250 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 251 | temb = self.linear(self.nonlinearity(temb)) 252 | gate_msa, gate_mlp = temb.chunk(2, dim=-1) 253 | gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) 254 | return gate_msa, gate_mlp 255 | 256 | 257 | class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): 258 | def __init__( 259 | self, 260 | num_attention_heads: int, 261 | attention_head_dim: int, 262 | mlp_width_ratio: str = 4.0, 263 | mlp_drop_rate: float = 0.0, 264 | attention_bias: bool = True, 265 | ) -> None: 266 | super().__init__() 267 | 268 | hidden_size = num_attention_heads * attention_head_dim 269 | 270 | self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) 271 | self.attn = Attention( 272 | query_dim=hidden_size, 273 | cross_attention_dim=None, 274 | heads=num_attention_heads, 275 | dim_head=attention_head_dim, 276 | bias=attention_bias, 277 | ) 278 | 279 | self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) 280 | self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) 281 | 282 | self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) 283 | 284 | def forward( 285 | self, 286 | hidden_states: torch.Tensor, 287 | temb: torch.Tensor, 288 | 
attention_mask: Optional[torch.Tensor] = None, 289 | ) -> torch.Tensor: 290 | norm_hidden_states = self.norm1(hidden_states) 291 | 292 | attn_output = self.attn( 293 | hidden_states=norm_hidden_states, 294 | encoder_hidden_states=None, 295 | attention_mask=attention_mask, 296 | ) 297 | 298 | gate_msa, gate_mlp = self.norm_out(temb) 299 | hidden_states = hidden_states + attn_output * gate_msa 300 | 301 | ff_output = self.ff(self.norm2(hidden_states)) 302 | hidden_states = hidden_states + ff_output * gate_mlp 303 | 304 | return hidden_states 305 | 306 | 307 | class HunyuanVideoIndividualTokenRefiner(nn.Module): 308 | def __init__( 309 | self, 310 | num_attention_heads: int, 311 | attention_head_dim: int, 312 | num_layers: int, 313 | mlp_width_ratio: float = 4.0, 314 | mlp_drop_rate: float = 0.0, 315 | attention_bias: bool = True, 316 | ) -> None: 317 | super().__init__() 318 | 319 | self.refiner_blocks = nn.ModuleList( 320 | [ 321 | HunyuanVideoIndividualTokenRefinerBlock( 322 | num_attention_heads=num_attention_heads, 323 | attention_head_dim=attention_head_dim, 324 | mlp_width_ratio=mlp_width_ratio, 325 | mlp_drop_rate=mlp_drop_rate, 326 | attention_bias=attention_bias, 327 | ) 328 | for _ in range(num_layers) 329 | ] 330 | ) 331 | 332 | def forward( 333 | self, 334 | hidden_states: torch.Tensor, 335 | temb: torch.Tensor, 336 | attention_mask: Optional[torch.Tensor] = None, 337 | ) -> None: 338 | self_attn_mask = None 339 | if attention_mask is not None: 340 | batch_size = attention_mask.shape[0] 341 | seq_len = attention_mask.shape[1] 342 | attention_mask = attention_mask.to(hidden_states.device).bool() 343 | self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) 344 | self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) 345 | self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() 346 | self_attn_mask[:, :, :, 0] = True 347 | 348 | for block in self.refiner_blocks: 349 | hidden_states = block(hidden_states, temb, self_attn_mask) 350 | 351 | return hidden_states 352 | 353 | 354 | class HunyuanVideoTokenRefiner(nn.Module): 355 | def __init__( 356 | self, 357 | in_channels: int, 358 | num_attention_heads: int, 359 | attention_head_dim: int, 360 | num_layers: int, 361 | mlp_ratio: float = 4.0, 362 | mlp_drop_rate: float = 0.0, 363 | attention_bias: bool = True, 364 | ) -> None: 365 | super().__init__() 366 | 367 | hidden_size = num_attention_heads * attention_head_dim 368 | 369 | self.time_text_embed = CombinedTimestepTextProjEmbeddings( 370 | embedding_dim=hidden_size, pooled_projection_dim=in_channels 371 | ) 372 | self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) 373 | self.token_refiner = HunyuanVideoIndividualTokenRefiner( 374 | num_attention_heads=num_attention_heads, 375 | attention_head_dim=attention_head_dim, 376 | num_layers=num_layers, 377 | mlp_width_ratio=mlp_ratio, 378 | mlp_drop_rate=mlp_drop_rate, 379 | attention_bias=attention_bias, 380 | ) 381 | 382 | def forward( 383 | self, 384 | hidden_states: torch.Tensor, 385 | timestep: torch.LongTensor, 386 | attention_mask: Optional[torch.LongTensor] = None, 387 | ) -> torch.Tensor: 388 | if attention_mask is None: 389 | pooled_projections = hidden_states.mean(dim=1) 390 | else: 391 | original_dtype = hidden_states.dtype 392 | mask_float = attention_mask.float().unsqueeze(-1) 393 | pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) 394 | pooled_projections = pooled_projections.to(original_dtype) 395 | 396 | temb = self.time_text_embed(timestep, 
pooled_projections) 397 | hidden_states = self.proj_in(hidden_states) 398 | hidden_states = self.token_refiner(hidden_states, temb, attention_mask) 399 | 400 | return hidden_states 401 | 402 | 403 | class HunyuanVideoRotaryPosEmbed(nn.Module): 404 | def __init__(self, rope_dim, theta): 405 | super().__init__() 406 | self.DT, self.DY, self.DX = rope_dim 407 | self.theta = theta 408 | 409 | @torch.no_grad() 410 | def get_frequency(self, dim, pos): 411 | T, H, W = pos.shape 412 | freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim)) 413 | freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0) 414 | return freqs.cos(), freqs.sin() 415 | 416 | @torch.no_grad() 417 | def forward_inner(self, frame_indices, height, width, device): 418 | GT, GY, GX = torch.meshgrid( 419 | frame_indices.to(device=device, dtype=torch.float32), 420 | torch.arange(0, height, device=device, dtype=torch.float32), 421 | torch.arange(0, width, device=device, dtype=torch.float32), 422 | indexing="ij" 423 | ) 424 | 425 | FCT, FST = self.get_frequency(self.DT, GT) 426 | FCY, FSY = self.get_frequency(self.DY, GY) 427 | FCX, FSX = self.get_frequency(self.DX, GX) 428 | 429 | result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0) 430 | 431 | return result.to(device) 432 | 433 | @torch.no_grad() 434 | def forward(self, frame_indices, height, width, device): 435 | frame_indices = frame_indices.unbind(0) 436 | results = [self.forward_inner(f, height, width, device) for f in frame_indices] 437 | results = torch.stack(results, dim=0) 438 | return results 439 | 440 | 441 | class AdaLayerNormZero(nn.Module): 442 | def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): 443 | super().__init__() 444 | self.silu = nn.SiLU() 445 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) 446 | if norm_type == "layer_norm": 447 | self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) 448 | else: 449 | raise ValueError(f"unknown norm_type {norm_type}") 450 | 451 | def forward( 452 | self, 453 | x: torch.Tensor, 454 | emb: Optional[torch.Tensor] = None, 455 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 456 | emb = emb.unsqueeze(-2) 457 | emb = self.linear(self.silu(emb)) 458 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) 459 | x = self.norm(x) * (1 + scale_msa) + shift_msa 460 | return x, gate_msa, shift_mlp, scale_mlp, gate_mlp 461 | 462 | 463 | class AdaLayerNormZeroSingle(nn.Module): 464 | def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): 465 | super().__init__() 466 | 467 | self.silu = nn.SiLU() 468 | self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) 469 | if norm_type == "layer_norm": 470 | self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) 471 | else: 472 | raise ValueError(f"unknown norm_type {norm_type}") 473 | 474 | def forward( 475 | self, 476 | x: torch.Tensor, 477 | emb: Optional[torch.Tensor] = None, 478 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 479 | emb = emb.unsqueeze(-2) 480 | emb = self.linear(self.silu(emb)) 481 | shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1) 482 | x = self.norm(x) * (1 + scale_msa) + shift_msa 483 | return x, gate_msa 484 | 485 | 486 | class AdaLayerNormContinuous(nn.Module): 487 | def __init__( 488 | self, 489 | embedding_dim: int, 490 | 
conditioning_embedding_dim: int, 491 | elementwise_affine=True, 492 | eps=1e-5, 493 | bias=True, 494 | norm_type="layer_norm", 495 | ): 496 | super().__init__() 497 | self.silu = nn.SiLU() 498 | self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) 499 | if norm_type == "layer_norm": 500 | self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) 501 | else: 502 | raise ValueError(f"unknown norm_type {norm_type}") 503 | 504 | def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: 505 | emb = emb.unsqueeze(-2) 506 | emb = self.linear(self.silu(emb)) 507 | scale, shift = emb.chunk(2, dim=-1) 508 | x = self.norm(x) * (1 + scale) + shift 509 | return x 510 | 511 | 512 | class HunyuanVideoSingleTransformerBlock(nn.Module): 513 | def __init__( 514 | self, 515 | num_attention_heads: int, 516 | attention_head_dim: int, 517 | mlp_ratio: float = 4.0, 518 | qk_norm: str = "rms_norm", 519 | attention_mode: str = "sdpa", 520 | ) -> None: 521 | super().__init__() 522 | 523 | hidden_size = num_attention_heads * attention_head_dim 524 | mlp_dim = int(hidden_size * mlp_ratio) 525 | 526 | self.attn = Attention( 527 | query_dim=hidden_size, 528 | cross_attention_dim=None, 529 | dim_head=attention_head_dim, 530 | heads=num_attention_heads, 531 | out_dim=hidden_size, 532 | bias=True, 533 | processor=HunyuanAttnProcessorFlashAttnSingle(attention_mode), 534 | qk_norm=qk_norm, 535 | eps=1e-6, 536 | pre_only=True, 537 | ) 538 | 539 | self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") 540 | self.proj_mlp = nn.Linear(hidden_size, mlp_dim) 541 | self.act_mlp = nn.GELU(approximate="tanh") 542 | self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) 543 | 544 | def forward( 545 | self, 546 | hidden_states: torch.Tensor, 547 | encoder_hidden_states: torch.Tensor, 548 | temb: torch.Tensor, 549 | attention_mask: Optional[torch.Tensor] = None, 550 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 551 | ) -> torch.Tensor: 552 | text_seq_length = encoder_hidden_states.shape[1] 553 | hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) 554 | 555 | residual = hidden_states 556 | 557 | # 1. Input normalization 558 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 559 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 560 | 561 | norm_hidden_states, norm_encoder_hidden_states = ( 562 | norm_hidden_states[:, :-text_seq_length, :], 563 | norm_hidden_states[:, -text_seq_length:, :], 564 | ) 565 | 566 | # 2. Attention 567 | attn_output, context_attn_output = self.attn( 568 | hidden_states=norm_hidden_states, 569 | encoder_hidden_states=norm_encoder_hidden_states, 570 | attention_mask=attention_mask, 571 | image_rotary_emb=image_rotary_emb, 572 | ) 573 | attn_output = torch.cat([attn_output, context_attn_output], dim=1) 574 | 575 | # 3. 
Modulation and residual connection 576 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 577 | hidden_states = gate * self.proj_out(hidden_states) 578 | hidden_states = hidden_states + residual 579 | 580 | hidden_states, encoder_hidden_states = ( 581 | hidden_states[:, :-text_seq_length, :], 582 | hidden_states[:, -text_seq_length:, :], 583 | ) 584 | return hidden_states, encoder_hidden_states 585 | 586 | 587 | class HunyuanVideoTransformerBlock(nn.Module): 588 | def __init__( 589 | self, 590 | num_attention_heads: int, 591 | attention_head_dim: int, 592 | mlp_ratio: float, 593 | qk_norm: str = "rms_norm", 594 | attention_mode: str = "sdpa", 595 | ) -> None: 596 | super().__init__() 597 | 598 | hidden_size = num_attention_heads * attention_head_dim 599 | 600 | self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") 601 | self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") 602 | 603 | self.attn = Attention( 604 | query_dim=hidden_size, 605 | cross_attention_dim=None, 606 | added_kv_proj_dim=hidden_size, 607 | dim_head=attention_head_dim, 608 | heads=num_attention_heads, 609 | out_dim=hidden_size, 610 | context_pre_only=False, 611 | bias=True, 612 | processor=HunyuanAttnProcessorFlashAttnDouble(attention_mode), 613 | qk_norm=qk_norm, 614 | eps=1e-6, 615 | ) 616 | 617 | self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 618 | self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") 619 | 620 | self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 621 | self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") 622 | 623 | def forward( 624 | self, 625 | hidden_states: torch.Tensor, 626 | encoder_hidden_states: torch.Tensor, 627 | temb: torch.Tensor, 628 | attention_mask: Optional[torch.Tensor] = None, 629 | freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 630 | ) -> Tuple[torch.Tensor, torch.Tensor]: 631 | # 1. Input normalization 632 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) 633 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb) 634 | 635 | # 2. Joint attention 636 | attn_output, context_attn_output = self.attn( 637 | hidden_states=norm_hidden_states, 638 | encoder_hidden_states=norm_encoder_hidden_states, 639 | attention_mask=attention_mask, 640 | image_rotary_emb=freqs_cis, 641 | ) 642 | 643 | # 3. Modulation and residual connection 644 | hidden_states = hidden_states + attn_output * gate_msa 645 | encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa 646 | 647 | norm_hidden_states = self.norm2(hidden_states) 648 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 649 | 650 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 651 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp 652 | 653 | # 4. 
Feed-forward 654 | ff_output = self.ff(norm_hidden_states) 655 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 656 | 657 | hidden_states = hidden_states + gate_mlp * ff_output 658 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output 659 | 660 | return hidden_states, encoder_hidden_states 661 | 662 | 663 | class ClipVisionProjection(nn.Module): 664 | def __init__(self, in_channels, out_channels): 665 | super().__init__() 666 | self.up = nn.Linear(in_channels, out_channels * 3) 667 | self.down = nn.Linear(out_channels * 3, out_channels) 668 | 669 | def forward(self, x): 670 | projected_x = self.down(nn.functional.silu(self.up(x))) 671 | return projected_x 672 | 673 | 674 | class HunyuanVideoPatchEmbed(nn.Module): 675 | def __init__(self, patch_size, in_chans, embed_dim): 676 | super().__init__() 677 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 678 | 679 | 680 | class HunyuanVideoPatchEmbedForCleanLatents(nn.Module): 681 | def __init__(self, inner_dim): 682 | super().__init__() 683 | self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) 684 | self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) 685 | self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) 686 | 687 | @torch.no_grad() 688 | def initialize_weight_from_another_conv3d(self, another_layer): 689 | weight = another_layer.weight.detach().clone() 690 | bias = another_layer.bias.detach().clone() 691 | 692 | sd = { 693 | 'proj.weight': weight.clone(), 694 | 'proj.bias': bias.clone(), 695 | 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0, 696 | 'proj_2x.bias': bias.clone(), 697 | 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0, 698 | 'proj_4x.bias': bias.clone(), 699 | } 700 | 701 | sd = {k: v.clone() for k, v in sd.items()} 702 | 703 | self.load_state_dict(sd) 704 | return 705 | 706 | 707 | class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 708 | @register_to_config 709 | def __init__( 710 | self, 711 | in_channels: int = 16, 712 | out_channels: int = 16, 713 | num_attention_heads: int = 24, 714 | attention_head_dim: int = 128, 715 | num_layers: int = 20, 716 | num_single_layers: int = 40, 717 | num_refiner_layers: int = 2, 718 | mlp_ratio: float = 4.0, 719 | patch_size: int = 2, 720 | patch_size_t: int = 1, 721 | qk_norm: str = "rms_norm", 722 | guidance_embeds: bool = True, 723 | text_embed_dim: int = 4096, 724 | pooled_projection_dim: int = 768, 725 | rope_theta: float = 256.0, 726 | rope_axes_dim: Tuple[int] = (16, 56, 56), 727 | has_image_proj=False, 728 | image_proj_dim=1152, 729 | has_clean_x_embedder=False, 730 | attention_mode="sdpa", 731 | ) -> None: 732 | super().__init__() 733 | 734 | inner_dim = num_attention_heads * attention_head_dim 735 | out_channels = out_channels or in_channels 736 | 737 | # 1. Latent and condition embedders 738 | self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) 739 | self.context_embedder = HunyuanVideoTokenRefiner( 740 | text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers 741 | ) 742 | self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) 743 | 744 | self.clean_x_embedder = None 745 | self.image_projection = None 746 | 747 | # 2. 
RoPE 748 | self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta) 749 | 750 | # 3. Dual stream transformer blocks 751 | self.transformer_blocks = nn.ModuleList( 752 | [ 753 | HunyuanVideoTransformerBlock( 754 | num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm, attention_mode=attention_mode 755 | ) 756 | for _ in range(num_layers) 757 | ] 758 | ) 759 | 760 | # 4. Single stream transformer blocks 761 | self.single_transformer_blocks = nn.ModuleList( 762 | [ 763 | HunyuanVideoSingleTransformerBlock( 764 | num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm, attention_mode=attention_mode 765 | ) 766 | for _ in range(num_single_layers) 767 | ] 768 | ) 769 | 770 | # 5. Output projection 771 | self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) 772 | self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) 773 | 774 | self.inner_dim = inner_dim 775 | self.use_gradient_checkpointing = False 776 | self.enable_teacache = False 777 | 778 | if has_image_proj: 779 | self.install_image_projection(image_proj_dim) 780 | 781 | if has_clean_x_embedder: 782 | self.install_clean_x_embedder() 783 | 784 | self.high_quality_fp32_output_for_inference = False 785 | 786 | def install_image_projection(self, in_channels): 787 | self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim) 788 | self.config['has_image_proj'] = True 789 | self.config['image_proj_dim'] = in_channels 790 | 791 | def install_clean_x_embedder(self): 792 | self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim) 793 | self.config['has_clean_x_embedder'] = True 794 | 795 | def enable_gradient_checkpointing(self): 796 | self.use_gradient_checkpointing = True 797 | print('self.use_gradient_checkpointing = True') 798 | 799 | def disable_gradient_checkpointing(self): 800 | self.use_gradient_checkpointing = False 801 | print('self.use_gradient_checkpointing = False') 802 | 803 | def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15): 804 | self.enable_teacache = enable_teacache 805 | self.cnt = 0 806 | self.num_steps = num_steps 807 | self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup 808 | self.accumulated_rel_l1_distance = 0 809 | self.previous_modulated_input = None 810 | self.previous_residual = None 811 | self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]) 812 | 813 | def gradient_checkpointing_method(self, block, *args): 814 | if self.use_gradient_checkpointing: 815 | result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False) 816 | else: 817 | result = block(*args) 818 | return result 819 | 820 | def process_input_hidden_states( 821 | self, 822 | latents, latent_indices=None, 823 | clean_latents=None, clean_latent_indices=None, 824 | clean_latents_2x=None, clean_latent_2x_indices=None, 825 | clean_latents_4x=None, clean_latent_4x_indices=None 826 | ): 827 | hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents) 828 | B, C, T, H, W = hidden_states.shape 829 | 830 | if latent_indices is None: 831 | latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1) 832 | 833 | hidden_states = hidden_states.flatten(2).transpose(1, 2) 834 | 835 | rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device) 836 | rope_freqs = 
rope_freqs.flatten(2).transpose(1, 2) 837 | 838 | if clean_latents is not None and clean_latent_indices is not None: 839 | clean_latents = clean_latents.to(hidden_states) 840 | clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents) 841 | clean_latents = clean_latents.flatten(2).transpose(1, 2) 842 | 843 | clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device) 844 | clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2) 845 | 846 | hidden_states = torch.cat([clean_latents, hidden_states], dim=1) 847 | rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1) 848 | 849 | if clean_latents_2x is not None and clean_latent_2x_indices is not None: 850 | clean_latents_2x = clean_latents_2x.to(hidden_states) 851 | clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4)) 852 | clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x) 853 | clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2) 854 | 855 | clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device) 856 | clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2)) 857 | clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2)) 858 | clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2) 859 | 860 | hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1) 861 | rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1) 862 | 863 | if clean_latents_4x is not None and clean_latent_4x_indices is not None: 864 | clean_latents_4x = clean_latents_4x.to(hidden_states) 865 | clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8)) 866 | clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x) 867 | clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2) 868 | 869 | clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device) 870 | clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4)) 871 | clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4)) 872 | clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2) 873 | 874 | hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1) 875 | rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1) 876 | 877 | return hidden_states, rope_freqs 878 | 879 | def forward( 880 | self, 881 | hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance, 882 | latent_indices=None, 883 | clean_latents=None, clean_latent_indices=None, 884 | clean_latents_2x=None, clean_latent_2x_indices=None, 885 | clean_latents_4x=None, clean_latent_4x_indices=None, 886 | image_embeddings=None, 887 | attention_kwargs=None, return_dict=True 888 | ): 889 | 890 | if attention_kwargs is None: 891 | attention_kwargs = {} 892 | 893 | batch_size, num_channels, num_frames, height, width = hidden_states.shape 894 | p, p_t = self.config['patch_size'], self.config['patch_size_t'] 895 | post_patch_num_frames = num_frames // p_t 896 | post_patch_height = height // p 897 | post_patch_width = width // p 898 | original_context_length = post_patch_num_frames * post_patch_height * post_patch_width 899 | 
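# Note: process_input_hidden_states (called just below) patchifies the noisy latents and, when
# provided, prepends the "clean" history latents at 1x, 2x and 4x compression as extra tokens,
# concatenating the matching RoPE frequencies in the same order so tokens and positional phases
# stay aligned. Rough shape sketch under illustrative, assumed sizes (not taken from this repo):
#   latents (B, 16, 9, 64, 64) with patch_size=2, patch_size_t=1 -> 9 * 32 * 32 = 9216 noisy tokens,
#   so hidden_states becomes roughly (B, n_clean_tokens + 9216, inner_dim), with rope_freqs
#   carrying one frequency entry per token in the same order.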
900 | hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices) 901 | 902 | temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections) 903 | encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask) 904 | 905 | if self.image_projection is not None and image_embeddings is not None: 906 | extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings) 907 | extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device) 908 | 909 | # must cat before (not after) encoder_hidden_states, due to attn masking 910 | encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1) 911 | encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1) 912 | 913 | with torch.no_grad(): 914 | if batch_size == 1: 915 | # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want 916 | # If they are not same, then their impls are wrong. Ours are always the correct one. 917 | text_len = encoder_attention_mask.sum().item() 918 | encoder_hidden_states = encoder_hidden_states[:, :text_len] 919 | attention_mask = None, None, None, None 920 | else: 921 | img_seq_len = hidden_states.shape[1] 922 | txt_seq_len = encoder_hidden_states.shape[1] 923 | 924 | cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len) 925 | cu_seqlens_kv = cu_seqlens_q 926 | max_seqlen_q = img_seq_len + txt_seq_len 927 | max_seqlen_kv = max_seqlen_q 928 | 929 | attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv 930 | 931 | if self.enable_teacache: 932 | modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0] 933 | 934 | if self.cnt == 0 or self.cnt == self.num_steps-1: 935 | should_calc = True 936 | self.accumulated_rel_l1_distance = 0 937 | else: 938 | curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item() 939 | self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1) 940 | should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh 941 | 942 | if should_calc: 943 | self.accumulated_rel_l1_distance = 0 944 | 945 | self.previous_modulated_input = modulated_inp 946 | self.cnt += 1 947 | 948 | if self.cnt == self.num_steps: 949 | self.cnt = 0 950 | 951 | if not should_calc: 952 | hidden_states = hidden_states + self.previous_residual 953 | else: 954 | ori_hidden_states = hidden_states.clone() 955 | 956 | for block_id, block in enumerate(self.transformer_blocks): 957 | hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( 958 | block, 959 | hidden_states, 960 | encoder_hidden_states, 961 | temb, 962 | attention_mask, 963 | rope_freqs 964 | ) 965 | 966 | for block_id, block in enumerate(self.single_transformer_blocks): 967 | hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( 968 | block, 969 | hidden_states, 970 | encoder_hidden_states, 971 | temb, 972 | attention_mask, 973 | rope_freqs 974 | ) 975 | 976 | self.previous_residual = hidden_states - ori_hidden_states 977 | else: 978 | for block_id, block in 
enumerate(self.transformer_blocks): 979 | hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( 980 | block, 981 | hidden_states, 982 | encoder_hidden_states, 983 | temb, 984 | attention_mask, 985 | rope_freqs 986 | ) 987 | 988 | for block_id, block in enumerate(self.single_transformer_blocks): 989 | hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( 990 | block, 991 | hidden_states, 992 | encoder_hidden_states, 993 | temb, 994 | attention_mask, 995 | rope_freqs 996 | ) 997 | 998 | hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb) 999 | 1000 | hidden_states = hidden_states[:, -original_context_length:, :] 1001 | 1002 | if self.high_quality_fp32_output_for_inference: 1003 | hidden_states = hidden_states.to(dtype=torch.float32) 1004 | if self.proj_out.weight.dtype != torch.float32: 1005 | self.proj_out.to(dtype=torch.float32) 1006 | 1007 | hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states) 1008 | 1009 | hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)', 1010 | t=post_patch_num_frames, h=post_patch_height, w=post_patch_width, 1011 | pt=p_t, ph=p, pw=p) 1012 | 1013 | if return_dict: 1014 | return Transformer2DModelOutput(sample=hidden_states) 1015 | 1016 | return hidden_states, 1017 | -------------------------------------------------------------------------------- /diffusers_helper/pipelines/k_diffusion_hunyuan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | from ..k_diffusion.uni_pc_fm import sample_unipc 5 | from ..k_diffusion.wrapper import fm_wrapper 6 | from ..utils import repeat_to_batch_size 7 | 8 | 9 | def flux_time_shift(t, mu=1.15, sigma=1.0): 10 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 11 | 12 | 13 | def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0): 14 | k = (y2 - y1) / (x2 - x1) 15 | b = y1 - k * x1 16 | mu = k * context_length + b 17 | mu = min(mu, math.log(exp_max)) 18 | return mu 19 | 20 | 21 | def get_flux_sigmas_from_mu(n, mu): 22 | sigmas = torch.linspace(1, 0, steps=n + 1) 23 | sigmas = flux_time_shift(sigmas, mu=mu) 24 | return sigmas 25 | 26 | 27 | @torch.inference_mode() 28 | def sample_hunyuan( 29 | transformer, 30 | sampler='unipc', 31 | initial_latent=None, 32 | concat_latent=None, 33 | strength=1.0, 34 | width=512, 35 | height=512, 36 | frames=16, 37 | real_guidance_scale=1.0, 38 | distilled_guidance_scale=6.0, 39 | guidance_rescale=0.0, 40 | shift=None, 41 | num_inference_steps=25, 42 | batch_size=None, 43 | generator=None, 44 | prompt_embeds=None, 45 | prompt_embeds_mask=None, 46 | prompt_poolers=None, 47 | negative_prompt_embeds=None, 48 | negative_prompt_embeds_mask=None, 49 | negative_prompt_poolers=None, 50 | dtype=torch.bfloat16, 51 | device=None, 52 | negative_kwargs=None, 53 | callback=None, 54 | **kwargs, 55 | ): 56 | device = device or transformer.device 57 | 58 | if batch_size is None: 59 | batch_size = int(prompt_embeds.shape[0]) 60 | 61 | latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device).to(device=device, dtype=torch.float32) 62 | 63 | B, C, T, H, W = latents.shape 64 | seq_length = T * H * W // 4 65 | 66 | if shift is None: 67 | mu = calculate_flux_mu(seq_length, exp_max=7.0) 68 | else: 69 | mu = math.log(shift) 70 | 71 | sigmas = get_flux_sigmas_from_mu(num_inference_steps, 
mu).to(device) 72 | 73 | k_model = fm_wrapper(transformer) 74 | 75 | if initial_latent is not None: 76 | sigmas = sigmas * strength 77 | first_sigma = sigmas[0].to(device=device, dtype=torch.float32) 78 | initial_latent = initial_latent.to(device=device, dtype=torch.float32) 79 | latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma 80 | 81 | if concat_latent is not None: 82 | concat_latent = concat_latent.to(latents) 83 | 84 | distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype) 85 | 86 | prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size) 87 | prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size) 88 | prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size) 89 | negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size) 90 | negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size) 91 | negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size) 92 | concat_latent = repeat_to_batch_size(concat_latent, batch_size) 93 | 94 | sampler_kwargs = dict( 95 | dtype=dtype, 96 | cfg_scale=real_guidance_scale, 97 | cfg_rescale=guidance_rescale, 98 | concat_latent=concat_latent, 99 | positive=dict( 100 | pooled_projections=prompt_poolers, 101 | encoder_hidden_states=prompt_embeds, 102 | encoder_attention_mask=prompt_embeds_mask, 103 | guidance=distilled_guidance, 104 | **kwargs, 105 | ), 106 | negative=dict( 107 | pooled_projections=negative_prompt_poolers, 108 | encoder_hidden_states=negative_prompt_embeds, 109 | encoder_attention_mask=negative_prompt_embeds_mask, 110 | guidance=distilled_guidance, 111 | **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}), 112 | ) 113 | ) 114 | 115 | if sampler == 'unipc_bh2': 116 | variant = 'bh2' 117 | else: 118 | variant = 'bh1' # the default 'unipc' and 'unipc_bh1' both use the bh1 variant 119 | results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, variant=variant, callback=callback) 120 | 121 | return results 122 | -------------------------------------------------------------------------------- /diffusers_helper/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | #import cv2 3 | import json 4 | import random 5 | import glob 6 | import torch 7 | import einops 8 | import numpy as np 9 | import datetime 10 | import torchvision 11 | 12 | import safetensors.torch as sf 13 | from PIL import Image 14 | 15 | 16 | # def min_resize(x, m): 17 | # if x.shape[0] < x.shape[1]: 18 | # s0 = m 19 | # s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1])) 20 | # else: 21 | # s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0])) 22 | # s1 = m 23 | # new_max = max(s1, s0) 24 | # raw_max = max(x.shape[0], x.shape[1]) 25 | # if new_max < raw_max: 26 | # interpolation = cv2.INTER_AREA 27 | # else: 28 | # interpolation = cv2.INTER_LANCZOS4 29 | # y = cv2.resize(x, (s1, s0), interpolation=interpolation) 30 | # return y 31 | 32 | 33 | # def d_resize(x, y): 34 | # H, W, C = y.shape 35 | # new_min = min(H, W) 36 | # raw_min = min(x.shape[0], x.shape[1]) 37 | # if new_min < raw_min: 38 | # interpolation = cv2.INTER_AREA 39 | # else: 40 | # interpolation = cv2.INTER_LANCZOS4 41 | # y = cv2.resize(x, (W, H), interpolation=interpolation) 42 | # return y 43 | 44 | 45 | def resize_and_center_crop(image, target_width, target_height): 46 | if target_height == image.shape[0] and target_width ==
image.shape[1]: 47 | return image 48 | 49 | pil_image = Image.fromarray(image) 50 | original_width, original_height = pil_image.size 51 | scale_factor = max(target_width / original_width, target_height / original_height) 52 | resized_width = int(round(original_width * scale_factor)) 53 | resized_height = int(round(original_height * scale_factor)) 54 | resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS) 55 | left = (resized_width - target_width) / 2 56 | top = (resized_height - target_height) / 2 57 | right = (resized_width + target_width) / 2 58 | bottom = (resized_height + target_height) / 2 59 | cropped_image = resized_image.crop((left, top, right, bottom)) 60 | return np.array(cropped_image) 61 | 62 | 63 | def resize_and_center_crop_pytorch(image, target_width, target_height): 64 | B, C, H, W = image.shape 65 | 66 | if H == target_height and W == target_width: 67 | return image 68 | 69 | scale_factor = max(target_width / W, target_height / H) 70 | resized_width = int(round(W * scale_factor)) 71 | resized_height = int(round(H * scale_factor)) 72 | 73 | resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False) 74 | 75 | top = (resized_height - target_height) // 2 76 | left = (resized_width - target_width) // 2 77 | cropped = resized[:, :, top:top + target_height, left:left + target_width] 78 | 79 | return cropped 80 | 81 | 82 | def resize_without_crop(image, target_width, target_height): 83 | if target_height == image.shape[0] and target_width == image.shape[1]: 84 | return image 85 | 86 | pil_image = Image.fromarray(image) 87 | resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS) 88 | return np.array(resized_image) 89 | 90 | 91 | def just_crop(image, w, h): 92 | if h == image.shape[0] and w == image.shape[1]: 93 | return image 94 | 95 | original_height, original_width = image.shape[:2] 96 | k = min(original_height / h, original_width / w) 97 | new_width = int(round(w * k)) 98 | new_height = int(round(h * k)) 99 | x_start = (original_width - new_width) // 2 100 | y_start = (original_height - new_height) // 2 101 | cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width] 102 | return cropped_image 103 | 104 | 105 | def write_to_json(data, file_path): 106 | temp_file_path = file_path + ".tmp" 107 | with open(temp_file_path, 'wt', encoding='utf-8') as temp_file: 108 | json.dump(data, temp_file, indent=4) 109 | os.replace(temp_file_path, file_path) 110 | return 111 | 112 | 113 | def read_from_json(file_path): 114 | with open(file_path, 'rt', encoding='utf-8') as file: 115 | data = json.load(file) 116 | return data 117 | 118 | 119 | def get_active_parameters(m): 120 | return {k: v for k, v in m.named_parameters() if v.requires_grad} 121 | 122 | 123 | def cast_training_params(m, dtype=torch.float32): 124 | result = {} 125 | for n, param in m.named_parameters(): 126 | if param.requires_grad: 127 | param.data = param.to(dtype) 128 | result[n] = param 129 | return result 130 | 131 | 132 | def separate_lora_AB(parameters, B_patterns=None): 133 | parameters_normal = {} 134 | parameters_B = {} 135 | 136 | if B_patterns is None: 137 | B_patterns = ['.lora_B.', '__zero__'] 138 | 139 | for k, v in parameters.items(): 140 | if any(B_pattern in k for B_pattern in B_patterns): 141 | parameters_B[k] = v 142 | else: 143 | parameters_normal[k] = v 144 | 145 | return parameters_normal, parameters_B 146 | 147 | 148 | def set_attr_recursive(obj, attr, value): 149 | 
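# Note: set_attr_recursive walks the dotted path and replaces only the final attribute.
# A hypothetical call (names are illustrative, not from this repo) such as
#   set_attr_recursive(model, "transformer_blocks.0.attn.to_q", patched_linear)
# also reaches nn.ModuleList entries, because their children are registered under the
# string names "0", "1", ... and are therefore visible to getattr/setattr.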
attrs = attr.split(".") 150 | for name in attrs[:-1]: 151 | obj = getattr(obj, name) 152 | setattr(obj, attrs[-1], value) 153 | return 154 | 155 | 156 | def print_tensor_list_size(tensors): 157 | total_size = 0 158 | total_elements = 0 159 | 160 | if isinstance(tensors, dict): 161 | tensors = tensors.values() 162 | 163 | for tensor in tensors: 164 | total_size += tensor.nelement() * tensor.element_size() 165 | total_elements += tensor.nelement() 166 | 167 | total_size_MB = total_size / (1024 ** 2) 168 | total_elements_B = total_elements / 1e9 169 | 170 | print(f"Total number of tensors: {len(tensors)}") 171 | print(f"Total size of tensors: {total_size_MB:.2f} MB") 172 | print(f"Total number of parameters: {total_elements_B:.3f} billion") 173 | return 174 | 175 | 176 | @torch.no_grad() 177 | def batch_mixture(a, b=None, probability_a=0.5, mask_a=None): 178 | batch_size = a.size(0) 179 | 180 | if b is None: 181 | b = torch.zeros_like(a) 182 | 183 | if mask_a is None: 184 | mask_a = torch.rand(batch_size) < probability_a 185 | 186 | mask_a = mask_a.to(a.device) 187 | mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1)) 188 | result = torch.where(mask_a, a, b) 189 | return result 190 | 191 | 192 | @torch.no_grad() 193 | def zero_module(module): 194 | for p in module.parameters(): 195 | p.detach().zero_() 196 | return module 197 | 198 | 199 | @torch.no_grad() 200 | def supress_lower_channels(m, k, alpha=0.01): 201 | data = m.weight.data.clone() 202 | 203 | assert int(data.shape[1]) >= k 204 | 205 | data[:, :k] = data[:, :k] * alpha 206 | m.weight.data = data.contiguous().clone() 207 | return m 208 | 209 | 210 | def freeze_module(m): 211 | if not hasattr(m, '_forward_inside_frozen_module'): 212 | m._forward_inside_frozen_module = m.forward 213 | m.requires_grad_(False) 214 | m.forward = torch.no_grad()(m.forward) 215 | return m 216 | 217 | 218 | def get_latest_safetensors(folder_path): 219 | safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors')) 220 | 221 | if not safetensors_files: 222 | raise ValueError('No file to resume!') 223 | 224 | latest_file = max(safetensors_files, key=os.path.getmtime) 225 | latest_file = os.path.abspath(os.path.realpath(latest_file)) 226 | return latest_file 227 | 228 | 229 | def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32): 230 | tags = tags_str.split(', ') 231 | tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags))) 232 | prompt = ', '.join(tags) 233 | return prompt 234 | 235 | 236 | def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0): 237 | numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma) 238 | if round_to_int: 239 | numbers = np.round(numbers).astype(int) 240 | return numbers.tolist() 241 | 242 | 243 | def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False): 244 | edges = np.linspace(0, 1, n + 1) 245 | points = np.random.uniform(edges[:-1], edges[1:]) 246 | numbers = inclusive + (exclusive - inclusive) * points 247 | if round_to_int: 248 | numbers = np.round(numbers).astype(int) 249 | return numbers.tolist() 250 | 251 | 252 | def soft_append_bcthw(history, current, overlap=0): 253 | if overlap <= 0: 254 | return torch.cat([history, current], dim=2) 255 | 256 | assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})" 257 | assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})" 258 | 259 | weights = torch.linspace(1, 0, overlap, 
dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1) 260 | blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap] 261 | output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2) 262 | 263 | return output.to(history) 264 | 265 | 266 | def save_bcthw_as_mp4(x, output_filename, fps=10): 267 | b, c, t, h, w = x.shape 268 | 269 | per_row = b 270 | for p in [6, 5, 4, 3, 2]: 271 | if b % p == 0: 272 | per_row = p 273 | break 274 | 275 | os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) 276 | x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 277 | x = x.detach().cpu().to(torch.uint8) 278 | x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row) 279 | torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '0'}) 280 | return x 281 | 282 | 283 | def save_bcthw_as_png(x, output_filename): 284 | os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) 285 | x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 286 | x = x.detach().cpu().to(torch.uint8) 287 | x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)') 288 | torchvision.io.write_png(x, output_filename) 289 | return output_filename 290 | 291 | 292 | def save_bchw_as_png(x, output_filename): 293 | os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) 294 | x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 295 | x = x.detach().cpu().to(torch.uint8) 296 | x = einops.rearrange(x, 'b c h w -> c h (b w)') 297 | torchvision.io.write_png(x, output_filename) 298 | return output_filename 299 | 300 | 301 | def add_tensors_with_padding(tensor1, tensor2): 302 | if tensor1.shape == tensor2.shape: 303 | return tensor1 + tensor2 304 | 305 | shape1 = tensor1.shape 306 | shape2 = tensor2.shape 307 | 308 | new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2)) 309 | 310 | padded_tensor1 = torch.zeros(new_shape) 311 | padded_tensor2 = torch.zeros(new_shape) 312 | 313 | padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1 314 | padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2 315 | 316 | result = padded_tensor1 + padded_tensor2 317 | return result 318 | 319 | 320 | def print_free_mem(): 321 | torch.cuda.empty_cache() 322 | free_mem, total_mem = torch.cuda.mem_get_info(0) 323 | free_mem_mb = free_mem / (1024 ** 2) 324 | total_mem_mb = total_mem / (1024 ** 2) 325 | print(f"Free memory: {free_mem_mb:.2f} MB") 326 | print(f"Total memory: {total_mem_mb:.2f} MB") 327 | return 328 | 329 | 330 | def print_gpu_parameters(device, state_dict, log_count=1): 331 | summary = {"device": device, "keys_count": len(state_dict)} 332 | 333 | logged_params = {} 334 | for i, (key, tensor) in enumerate(state_dict.items()): 335 | if i >= log_count: 336 | break 337 | logged_params[key] = tensor.flatten()[:3].tolist() 338 | 339 | summary["params"] = logged_params 340 | 341 | print(str(summary)) 342 | return 343 | 344 | 345 | def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18): 346 | from PIL import Image, ImageDraw, ImageFont 347 | 348 | txt = Image.new("RGB", (width, height), color="white") 349 | draw = ImageDraw.Draw(txt) 350 | font = ImageFont.truetype(font_path, size=size) 351 | 352 | if text == '': 353 | return np.array(txt) 354 | 355 | # Split text into lines that fit within the image width 356 | lines = [] 357 | words = text.split() 358 | current_line = words[0] 359 | 
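# Note: the loop below is a greedy word wrap -- keep appending words to current_line while
# draw.textbbox((0, 0), candidate, font=font)[2] (the right edge, i.e. the rendered pixel
# width of the candidate line) still fits inside the image width, and start a new line on
# the first word that overflows.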
360 | for word in words[1:]: 361 | line_with_word = f"{current_line} {word}" 362 | if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width: 363 | current_line = line_with_word 364 | else: 365 | lines.append(current_line) 366 | current_line = word 367 | 368 | lines.append(current_line) 369 | 370 | # Draw the text line by line 371 | y = 0 372 | line_height = draw.textbbox((0, 0), "A", font=font)[3] 373 | 374 | for line in lines: 375 | if y + line_height > height: 376 | break # stop drawing if the next line will be outside the image 377 | draw.text((0, y), line, fill="black", font=font) 378 | y += line_height 379 | 380 | return np.array(txt) 381 | 382 | 383 | # def blue_mark(x): 384 | # x = x.copy() 385 | # c = x[:, :, 2] 386 | # b = cv2.blur(c, (9, 9)) 387 | # x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1) 388 | # return x 389 | 390 | 391 | # def green_mark(x): 392 | # x = x.copy() 393 | # x[:, :, 2] = -1 394 | # x[:, :, 0] = -1 395 | # return x 396 | 397 | 398 | # def frame_mark(x): 399 | # x = x.copy() 400 | # x[:64] = -1 401 | # x[-64:] = -1 402 | # x[:, :8] = 1 403 | # x[:, -8:] = 1 404 | # return x 405 | 406 | 407 | @torch.inference_mode() 408 | def pytorch2numpy(imgs): 409 | results = [] 410 | for x in imgs: 411 | y = x.movedim(0, -1) 412 | y = y * 127.5 + 127.5 413 | y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8) 414 | results.append(y) 415 | return results 416 | 417 | 418 | @torch.inference_mode() 419 | def numpy2pytorch(imgs): 420 | h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0 421 | h = h.movedim(-1, 1) 422 | return h 423 | 424 | 425 | @torch.no_grad() 426 | def duplicate_prefix_to_suffix(x, count, zero_out=False): 427 | if zero_out: 428 | return torch.cat([x, torch.zeros_like(x[:count])], dim=0) 429 | else: 430 | return torch.cat([x, x[:count]], dim=0) 431 | 432 | 433 | def weighted_mse(a, b, weight): 434 | return torch.mean(weight.float() * (a.float() - b.float()) ** 2) 435 | 436 | 437 | def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0): 438 | x = (x - x_min) / (x_max - x_min) 439 | x = max(0.0, min(x, 1.0)) 440 | x = x ** sigma 441 | return y_min + x * (y_max - y_min) 442 | 443 | 444 | def expand_to_dims(x, target_dims): 445 | return x.view(*x.shape, *([1] * max(0, target_dims - x.dim()))) 446 | 447 | 448 | def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int): 449 | if tensor is None: 450 | return None 451 | 452 | first_dim = tensor.shape[0] 453 | 454 | if first_dim == batch_size: 455 | return tensor 456 | 457 | if batch_size % first_dim != 0: 458 | raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.") 459 | 460 | repeat_times = batch_size // first_dim 461 | 462 | return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1)) 463 | 464 | 465 | def dim5(x): 466 | return expand_to_dims(x, 5) 467 | 468 | 469 | def dim4(x): 470 | return expand_to_dims(x, 4) 471 | 472 | 473 | def dim3(x): 474 | return expand_to_dims(x, 3) 475 | 476 | 477 | def crop_or_pad_yield_mask(x, length): 478 | B, F, C = x.shape 479 | device = x.device 480 | dtype = x.dtype 481 | 482 | if F < length: 483 | y = torch.zeros((B, length, C), dtype=dtype, device=device) 484 | mask = torch.zeros((B, length), dtype=torch.bool, device=device) 485 | y[:, :F, :] = x 486 | mask[:, :F] = True 487 | return y, mask 488 | 489 | return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device) 490 | 491 | 492 | def extend_dim(x, dim, minimal_length, zero_pad=False): 493 | 
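# Note: extend_dim pads x along `dim` up to minimal_length, either with zeros or by repeating
# the last slice along that dim. Illustrative example with assumed shapes:
#   extend_dim(torch.ones(1, 3, 8), dim=2, minimal_length=10)
# returns shape (1, 3, 10), with the final column repeated twice when zero_pad=False.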
original_length = int(x.shape[dim]) 494 | 495 | if original_length >= minimal_length: 496 | return x 497 | 498 | if zero_pad: 499 | padding_shape = list(x.shape) 500 | padding_shape[dim] = minimal_length - original_length 501 | padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device) 502 | else: 503 | idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1) 504 | last_element = x[idx] 505 | padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim) 506 | 507 | return torch.cat([x, padding], dim=dim) 508 | 509 | 510 | def lazy_positional_encoding(t, repeats=None): 511 | if not isinstance(t, list): 512 | t = [t] 513 | 514 | from diffusers.models.embeddings import get_timestep_embedding 515 | 516 | te = torch.tensor(t) 517 | te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0) 518 | 519 | if repeats is None: 520 | return te 521 | 522 | te = te[:, None, :].expand(-1, repeats, -1) 523 | 524 | return te 525 | 526 | 527 | def state_dict_offset_merge(A, B, C=None): 528 | result = {} 529 | keys = A.keys() 530 | 531 | for key in keys: 532 | A_value = A[key] 533 | B_value = B[key].to(A_value) 534 | 535 | if C is None: 536 | result[key] = A_value + B_value 537 | else: 538 | C_value = C[key].to(A_value) 539 | result[key] = A_value + B_value - C_value 540 | 541 | return result 542 | 543 | 544 | def state_dict_weighted_merge(state_dicts, weights): 545 | if len(state_dicts) != len(weights): 546 | raise ValueError("Number of state dictionaries must match number of weights") 547 | 548 | if not state_dicts: 549 | return {} 550 | 551 | total_weight = sum(weights) 552 | 553 | if total_weight == 0: 554 | raise ValueError("Sum of weights cannot be zero") 555 | 556 | normalized_weights = [w / total_weight for w in weights] 557 | 558 | keys = state_dicts[0].keys() 559 | result = {} 560 | 561 | for key in keys: 562 | result[key] = state_dicts[0][key] * normalized_weights[0] 563 | 564 | for i in range(1, len(state_dicts)): 565 | state_dict_value = state_dicts[i][key].to(result[key]) 566 | result[key] += state_dict_value * normalized_weights[i] 567 | 568 | return result 569 | 570 | 571 | def group_files_by_folder(all_files): 572 | grouped_files = {} 573 | 574 | for file in all_files: 575 | folder_name = os.path.basename(os.path.dirname(file)) 576 | if folder_name not in grouped_files: 577 | grouped_files[folder_name] = [] 578 | grouped_files[folder_name].append(file) 579 | 580 | list_of_lists = list(grouped_files.values()) 581 | return list_of_lists 582 | 583 | 584 | def generate_timestamp(): 585 | now = datetime.datetime.now() 586 | timestamp = now.strftime('%y%m%d_%H%M%S') 587 | milliseconds = f"{int(now.microsecond / 1000):03d}" 588 | random_number = random.randint(0, 9999) 589 | return f"{timestamp}_{milliseconds}_{random_number}" 590 | 591 | 592 | def write_PIL_image_with_png_info(image, metadata, path): 593 | from PIL.PngImagePlugin import PngInfo 594 | 595 | png_info = PngInfo() 596 | for key, value in metadata.items(): 597 | png_info.add_text(key, value) 598 | 599 | image.save(path, "PNG", pnginfo=png_info) 600 | return image 601 | 602 | 603 | def torch_safe_save(content, path): 604 | torch.save(content, path + '_tmp') 605 | os.replace(path + '_tmp', path) 606 | return path 607 | 608 | 609 | def move_optimizer_to_device(optimizer, device): 610 | for state in optimizer.state.values(): 611 | for k, v in state.items(): 612 | if isinstance(v, torch.Tensor): 613 | state[k] = 
v.to(device) 614 | -------------------------------------------------------------------------------- /example_workflows/framepack_hv_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "ce2cb810-7775-4564-8928-dd5bed1053cd", 3 | "revision": 0, 4 | "last_node_id": 69, 5 | "last_link_id": 158, 6 | "nodes": [ 7 | { 8 | "id": 15, 9 | "type": "ConditioningZeroOut", 10 | "pos": [ 11 | 1346.0872802734375, 12 | 263.21856689453125 13 | ], 14 | "size": [ 15 | 317.4000244140625, 16 | 26 17 | ], 18 | "flags": { 19 | "collapsed": true 20 | }, 21 | "order": 18, 22 | "mode": 0, 23 | "inputs": [ 24 | { 25 | "name": "conditioning", 26 | "type": "CONDITIONING", 27 | "link": 118 28 | } 29 | ], 30 | "outputs": [ 31 | { 32 | "name": "CONDITIONING", 33 | "type": "CONDITIONING", 34 | "links": [ 35 | 108 36 | ] 37 | } 38 | ], 39 | "properties": { 40 | "cnr_id": "comfy-core", 41 | "ver": "0.3.28", 42 | "Node name for S&R": "ConditioningZeroOut" 43 | }, 44 | "widgets_values": [], 45 | "color": "#332922", 46 | "bgcolor": "#593930" 47 | }, 48 | { 49 | "id": 13, 50 | "type": "DualCLIPLoader", 51 | "pos": [ 52 | 320.9956359863281, 53 | 166.8336181640625 54 | ], 55 | "size": [ 56 | 340.2243957519531, 57 | 130 58 | ], 59 | "flags": {}, 60 | "order": 0, 61 | "mode": 0, 62 | "inputs": [], 63 | "outputs": [ 64 | { 65 | "name": "CLIP", 66 | "type": "CLIP", 67 | "links": [ 68 | 102 69 | ] 70 | } 71 | ], 72 | "properties": { 73 | "cnr_id": "comfy-core", 74 | "ver": "0.3.28", 75 | "Node name for S&R": "DualCLIPLoader" 76 | }, 77 | "widgets_values": [ 78 | "clip_l.safetensors", 79 | "llava_llama3_fp16.safetensors", 80 | "hunyuan_video", 81 | "default" 82 | ], 83 | "color": "#432", 84 | "bgcolor": "#653" 85 | }, 86 | { 87 | "id": 54, 88 | "type": "DownloadAndLoadFramePackModel", 89 | "pos": [ 90 | 1256.5235595703125, 91 | -277.76226806640625 92 | ], 93 | "size": [ 94 | 315, 95 | 130 96 | ], 97 | "flags": {}, 98 | "order": 1, 99 | "mode": 4, 100 | "inputs": [ 101 | { 102 | "name": "compile_args", 103 | "shape": 7, 104 | "type": "FRAMEPACKCOMPILEARGS", 105 | "link": null 106 | } 107 | ], 108 | "outputs": [ 109 | { 110 | "name": "model", 111 | "type": "FramePackMODEL", 112 | "links": null 113 | } 114 | ], 115 | "properties": { 116 | "aux_id": "kijai/ComfyUI-FramePackWrapper", 117 | "ver": "49fe507eca8246cc9d08a8093892f40c1180e88f", 118 | "Node name for S&R": "DownloadAndLoadFramePackModel" 119 | }, 120 | "widgets_values": [ 121 | "lllyasviel/FramePackI2V_HY", 122 | "bf16", 123 | "disabled", 124 | "sdpa" 125 | ] 126 | }, 127 | { 128 | "id": 55, 129 | "type": "MarkdownNote", 130 | "pos": [ 131 | 567.05908203125, 132 | -628.8865966796875 133 | ], 134 | "size": [ 135 | 459.8609619140625, 136 | 285.9714660644531 137 | ], 138 | "flags": {}, 139 | "order": 2, 140 | "mode": 0, 141 | "inputs": [], 142 | "outputs": [], 143 | "properties": {}, 144 | "widgets_values": [ 145 | "Model links:\n\n[https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_fp8_e4m3fn.safetensors](https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_fp8_e4m3fn.safetensors)\n\n[https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_bf16.safetensors](https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_bf16.safetensors)\n\nsigclip:\n\n[https://huggingface.co/Comfy-Org/sigclip_vision_384/tree/main](https://huggingface.co/Comfy-Org/sigclip_vision_384/tree/main)\n\ntext encoder and 
VAE:\n\n[https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/tree/main/split_files](https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/tree/main/split_files)" 146 | ], 147 | "color": "#432", 148 | "bgcolor": "#653" 149 | }, 150 | { 151 | "id": 17, 152 | "type": "CLIPVisionEncode", 153 | "pos": [ 154 | 1545.9541015625, 155 | 359.1331481933594 156 | ], 157 | "size": [ 158 | 380.4000244140625, 159 | 78 160 | ], 161 | "flags": {}, 162 | "order": 23, 163 | "mode": 0, 164 | "inputs": [ 165 | { 166 | "name": "clip_vision", 167 | "type": "CLIP_VISION", 168 | "link": 149 169 | }, 170 | { 171 | "name": "image", 172 | "type": "IMAGE", 173 | "link": 116 174 | } 175 | ], 176 | "outputs": [ 177 | { 178 | "name": "CLIP_VISION_OUTPUT", 179 | "type": "CLIP_VISION_OUTPUT", 180 | "links": [ 181 | 141 182 | ] 183 | } 184 | ], 185 | "properties": { 186 | "cnr_id": "comfy-core", 187 | "ver": "0.3.28", 188 | "Node name for S&R": "CLIPVisionEncode" 189 | }, 190 | "widgets_values": [ 191 | "center" 192 | ], 193 | "color": "#233", 194 | "bgcolor": "#355" 195 | }, 196 | { 197 | "id": 64, 198 | "type": "GetNode", 199 | "pos": [ 200 | 1554.2071533203125, 201 | 486.79547119140625 202 | ], 203 | "size": [ 204 | 210, 205 | 60 206 | ], 207 | "flags": { 208 | "collapsed": true 209 | }, 210 | "order": 3, 211 | "mode": 0, 212 | "inputs": [], 213 | "outputs": [ 214 | { 215 | "name": "CLIP_VISION", 216 | "type": "CLIP_VISION", 217 | "links": [ 218 | 149 219 | ] 220 | } 221 | ], 222 | "title": "Get_ClipVisionModle", 223 | "properties": {}, 224 | "widgets_values": [ 225 | "ClipVisionModle" 226 | ], 227 | "color": "#233", 228 | "bgcolor": "#355" 229 | }, 230 | { 231 | "id": 48, 232 | "type": "GetImageSizeAndCount", 233 | "pos": [ 234 | 1259.2060546875, 235 | 626.8657836914062 236 | ], 237 | "size": [ 238 | 277.20001220703125, 239 | 86 240 | ], 241 | "flags": {}, 242 | "order": 21, 243 | "mode": 0, 244 | "inputs": [ 245 | { 246 | "name": "image", 247 | "type": "IMAGE", 248 | "link": 125 249 | } 250 | ], 251 | "outputs": [ 252 | { 253 | "name": "image", 254 | "type": "IMAGE", 255 | "links": [ 256 | 116, 257 | 156 258 | ] 259 | }, 260 | { 261 | "label": "704 width", 262 | "name": "width", 263 | "type": "INT", 264 | "links": null 265 | }, 266 | { 267 | "label": "544 height", 268 | "name": "height", 269 | "type": "INT", 270 | "links": null 271 | }, 272 | { 273 | "label": "1 count", 274 | "name": "count", 275 | "type": "INT", 276 | "links": null 277 | } 278 | ], 279 | "properties": { 280 | "cnr_id": "comfyui-kjnodes", 281 | "ver": "8ecf5cd05e0a1012087b0da90eea9a13674668db", 282 | "Node name for S&R": "GetImageSizeAndCount" 283 | }, 284 | "widgets_values": [] 285 | }, 286 | { 287 | "id": 60, 288 | "type": "GetImageSizeAndCount", 289 | "pos": [ 290 | 1279.781494140625, 291 | 1060.245361328125 292 | ], 293 | "size": [ 294 | 277.20001220703125, 295 | 86 296 | ], 297 | "flags": {}, 298 | "order": 22, 299 | "mode": 0, 300 | "inputs": [ 301 | { 302 | "name": "image", 303 | "type": "IMAGE", 304 | "link": 139 305 | } 306 | ], 307 | "outputs": [ 308 | { 309 | "name": "image", 310 | "type": "IMAGE", 311 | "links": [ 312 | 151, 313 | 152 314 | ] 315 | }, 316 | { 317 | "label": "704 width", 318 | "name": "width", 319 | "type": "INT", 320 | "links": null 321 | }, 322 | { 323 | "label": "544 height", 324 | "name": "height", 325 | "type": "INT", 326 | "links": null 327 | }, 328 | { 329 | "label": "1 count", 330 | "name": "count", 331 | "type": "INT", 332 | "links": null 333 | } 334 | ], 335 | "properties": { 336 | "cnr_id": 
"comfyui-kjnodes", 337 | "ver": "8ecf5cd05e0a1012087b0da90eea9a13674668db", 338 | "Node name for S&R": "GetImageSizeAndCount" 339 | }, 340 | "widgets_values": [] 341 | }, 342 | { 343 | "id": 12, 344 | "type": "VAELoader", 345 | "pos": [ 346 | 570.5363159179688, 347 | -282.70068359375 348 | ], 349 | "size": [ 350 | 469.0488586425781, 351 | 58 352 | ], 353 | "flags": {}, 354 | "order": 4, 355 | "mode": 0, 356 | "inputs": [], 357 | "outputs": [ 358 | { 359 | "name": "VAE", 360 | "type": "VAE", 361 | "links": [ 362 | 153 363 | ] 364 | } 365 | ], 366 | "properties": { 367 | "cnr_id": "comfy-core", 368 | "ver": "0.3.28", 369 | "Node name for S&R": "VAELoader" 370 | }, 371 | "widgets_values": [ 372 | "hyvid\\hunyuan_video_vae_bf16_repack.safetensors" 373 | ], 374 | "color": "#322", 375 | "bgcolor": "#533" 376 | }, 377 | { 378 | "id": 66, 379 | "type": "SetNode", 380 | "pos": [ 381 | 1083.503173828125, 382 | -358.4913330078125 383 | ], 384 | "size": [ 385 | 210, 386 | 60 387 | ], 388 | "flags": { 389 | "collapsed": true 390 | }, 391 | "order": 15, 392 | "mode": 0, 393 | "inputs": [ 394 | { 395 | "name": "VAE", 396 | "type": "VAE", 397 | "link": 153 398 | } 399 | ], 400 | "outputs": [ 401 | { 402 | "name": "*", 403 | "type": "*", 404 | "links": null 405 | } 406 | ], 407 | "title": "Set_VAE", 408 | "properties": { 409 | "previousName": "VAE" 410 | }, 411 | "widgets_values": [ 412 | "VAE" 413 | ], 414 | "color": "#322", 415 | "bgcolor": "#533" 416 | }, 417 | { 418 | "id": 20, 419 | "type": "VAEEncode", 420 | "pos": [ 421 | 1733.111083984375, 422 | 633.30419921875 423 | ], 424 | "size": [ 425 | 210, 426 | 46 427 | ], 428 | "flags": {}, 429 | "order": 24, 430 | "mode": 0, 431 | "inputs": [ 432 | { 433 | "name": "pixels", 434 | "type": "IMAGE", 435 | "link": 156 436 | }, 437 | { 438 | "name": "vae", 439 | "type": "VAE", 440 | "link": 155 441 | } 442 | ], 443 | "outputs": [ 444 | { 445 | "name": "LATENT", 446 | "type": "LATENT", 447 | "links": [ 448 | 86 449 | ] 450 | } 451 | ], 452 | "properties": { 453 | "cnr_id": "comfy-core", 454 | "ver": "0.3.28", 455 | "Node name for S&R": "VAEEncode" 456 | }, 457 | "widgets_values": [], 458 | "color": "#322", 459 | "bgcolor": "#533" 460 | }, 461 | { 462 | "id": 68, 463 | "type": "GetNode", 464 | "pos": [ 465 | 1729.60693359375, 466 | 734.5352172851562 467 | ], 468 | "size": [ 469 | 210, 470 | 34 471 | ], 472 | "flags": { 473 | "collapsed": true 474 | }, 475 | "order": 5, 476 | "mode": 0, 477 | "inputs": [], 478 | "outputs": [ 479 | { 480 | "name": "VAE", 481 | "type": "VAE", 482 | "links": [ 483 | 155 484 | ] 485 | } 486 | ], 487 | "title": "Get_VAE", 488 | "properties": {}, 489 | "widgets_values": [ 490 | "VAE" 491 | ], 492 | "color": "#322", 493 | "bgcolor": "#533" 494 | }, 495 | { 496 | "id": 62, 497 | "type": "VAEEncode", 498 | "pos": [ 499 | 1612.563232421875, 500 | 1048.6236572265625 501 | ], 502 | "size": [ 503 | 210, 504 | 46 505 | ], 506 | "flags": {}, 507 | "order": 26, 508 | "mode": 0, 509 | "inputs": [ 510 | { 511 | "name": "pixels", 512 | "type": "IMAGE", 513 | "link": 152 514 | }, 515 | { 516 | "name": "vae", 517 | "type": "VAE", 518 | "link": 158 519 | } 520 | ], 521 | "outputs": [ 522 | { 523 | "name": "LATENT", 524 | "type": "LATENT", 525 | "links": [ 526 | 147 527 | ] 528 | } 529 | ], 530 | "properties": { 531 | "cnr_id": "comfy-core", 532 | "ver": "0.3.28", 533 | "Node name for S&R": "VAEEncode" 534 | }, 535 | "widgets_values": [], 536 | "color": "#322", 537 | "bgcolor": "#533" 538 | }, 539 | { 540 | "id": 57, 541 | "type": "CLIPVisionEncode", 
542 | "pos": [ 543 | 1600.4202880859375, 544 | 1181.36767578125 545 | ], 546 | "size": [ 547 | 380.4000244140625, 548 | 78 549 | ], 550 | "flags": {}, 551 | "order": 25, 552 | "mode": 0, 553 | "inputs": [ 554 | { 555 | "name": "clip_vision", 556 | "type": "CLIP_VISION", 557 | "link": 150 558 | }, 559 | { 560 | "name": "image", 561 | "type": "IMAGE", 562 | "link": 151 563 | } 564 | ], 565 | "outputs": [ 566 | { 567 | "name": "CLIP_VISION_OUTPUT", 568 | "type": "CLIP_VISION_OUTPUT", 569 | "links": [ 570 | 132 571 | ] 572 | } 573 | ], 574 | "properties": { 575 | "cnr_id": "comfy-core", 576 | "ver": "0.3.29", 577 | "Node name for S&R": "CLIPVisionEncode" 578 | }, 579 | "widgets_values": [ 580 | "center" 581 | ], 582 | "color": "#233", 583 | "bgcolor": "#355" 584 | }, 585 | { 586 | "id": 69, 587 | "type": "GetNode", 588 | "pos": [ 589 | 1619.6104736328125, 590 | 1137.854736328125 591 | ], 592 | "size": [ 593 | 210, 594 | 34 595 | ], 596 | "flags": { 597 | "collapsed": true 598 | }, 599 | "order": 6, 600 | "mode": 0, 601 | "inputs": [], 602 | "outputs": [ 603 | { 604 | "name": "VAE", 605 | "type": "VAE", 606 | "links": [ 607 | 158 608 | ] 609 | } 610 | ], 611 | "title": "Get_VAE", 612 | "properties": {}, 613 | "widgets_values": [ 614 | "VAE" 615 | ], 616 | "color": "#322", 617 | "bgcolor": "#533" 618 | }, 619 | { 620 | "id": 65, 621 | "type": "GetNode", 622 | "pos": [ 623 | 1604.746337890625, 624 | 1306.3175048828125 625 | ], 626 | "size": [ 627 | 210, 628 | 34 629 | ], 630 | "flags": { 631 | "collapsed": true 632 | }, 633 | "order": 7, 634 | "mode": 0, 635 | "inputs": [], 636 | "outputs": [ 637 | { 638 | "name": "CLIP_VISION", 639 | "type": "CLIP_VISION", 640 | "links": [ 641 | 150 642 | ] 643 | } 644 | ], 645 | "title": "Get_ClipVisionModle", 646 | "properties": {}, 647 | "widgets_values": [ 648 | "ClipVisionModle" 649 | ], 650 | "color": "#233", 651 | "bgcolor": "#355" 652 | }, 653 | { 654 | "id": 59, 655 | "type": "ImageResize+", 656 | "pos": [ 657 | 908.9832763671875, 658 | 1062.01123046875 659 | ], 660 | "size": [ 661 | 315, 662 | 218 663 | ], 664 | "flags": {}, 665 | "order": 20, 666 | "mode": 0, 667 | "inputs": [ 668 | { 669 | "name": "image", 670 | "type": "IMAGE", 671 | "link": 138 672 | }, 673 | { 674 | "name": "width", 675 | "type": "INT", 676 | "widget": { 677 | "name": "width" 678 | }, 679 | "link": 136 680 | }, 681 | { 682 | "name": "height", 683 | "type": "INT", 684 | "widget": { 685 | "name": "height" 686 | }, 687 | "link": 137 688 | } 689 | ], 690 | "outputs": [ 691 | { 692 | "name": "IMAGE", 693 | "type": "IMAGE", 694 | "links": [ 695 | 139 696 | ] 697 | }, 698 | { 699 | "name": "width", 700 | "type": "INT", 701 | "links": null 702 | }, 703 | { 704 | "name": "height", 705 | "type": "INT", 706 | "links": null 707 | } 708 | ], 709 | "properties": { 710 | "aux_id": "kijai/ComfyUI_essentials", 711 | "ver": "76e9d1e4399bd025ce8b12c290753d58f9f53e93", 712 | "Node name for S&R": "ImageResize+" 713 | }, 714 | "widgets_values": [ 715 | 512, 716 | 512, 717 | "lanczos", 718 | "stretch", 719 | "always", 720 | 0 721 | ] 722 | }, 723 | { 724 | "id": 50, 725 | "type": "ImageResize+", 726 | "pos": [ 727 | 907.2653198242188, 728 | 593.743896484375 729 | ], 730 | "size": [ 731 | 315, 732 | 218 733 | ], 734 | "flags": {}, 735 | "order": 19, 736 | "mode": 0, 737 | "inputs": [ 738 | { 739 | "name": "image", 740 | "type": "IMAGE", 741 | "link": 122 742 | }, 743 | { 744 | "name": "width", 745 | "type": "INT", 746 | "widget": { 747 | "name": "width" 748 | }, 749 | "link": 128 750 | }, 751 | { 752 | 
"name": "height", 753 | "type": "INT", 754 | "widget": { 755 | "name": "height" 756 | }, 757 | "link": 127 758 | } 759 | ], 760 | "outputs": [ 761 | { 762 | "name": "IMAGE", 763 | "type": "IMAGE", 764 | "links": [ 765 | 125 766 | ] 767 | }, 768 | { 769 | "name": "width", 770 | "type": "INT", 771 | "links": null 772 | }, 773 | { 774 | "name": "height", 775 | "type": "INT", 776 | "links": null 777 | } 778 | ], 779 | "properties": { 780 | "aux_id": "kijai/ComfyUI_essentials", 781 | "ver": "76e9d1e4399bd025ce8b12c290753d58f9f53e93", 782 | "Node name for S&R": "ImageResize+" 783 | }, 784 | "widgets_values": [ 785 | 512, 786 | 512, 787 | "lanczos", 788 | "stretch", 789 | "always", 790 | 0 791 | ] 792 | }, 793 | { 794 | "id": 58, 795 | "type": "LoadImage", 796 | "pos": [ 797 | 190.07057189941406, 798 | 1060.399169921875 799 | ], 800 | "size": [ 801 | 315, 802 | 314 803 | ], 804 | "flags": {}, 805 | "order": 8, 806 | "mode": 0, 807 | "inputs": [], 808 | "outputs": [ 809 | { 810 | "name": "IMAGE", 811 | "type": "IMAGE", 812 | "links": [ 813 | 138 814 | ] 815 | }, 816 | { 817 | "name": "MASK", 818 | "type": "MASK", 819 | "links": null 820 | } 821 | ], 822 | "title": "Load Image: End", 823 | "properties": { 824 | "cnr_id": "comfy-core", 825 | "ver": "0.3.28", 826 | "Node name for S&R": "LoadImage" 827 | }, 828 | "widgets_values": [ 829 | "sd3stag.png", 830 | "image" 831 | ] 832 | }, 833 | { 834 | "id": 51, 835 | "type": "FramePackFindNearestBucket", 836 | "pos": [ 837 | 550.0997314453125, 838 | 887.411376953125 839 | ], 840 | "size": [ 841 | 315, 842 | 78 843 | ], 844 | "flags": {}, 845 | "order": 16, 846 | "mode": 0, 847 | "inputs": [ 848 | { 849 | "name": "image", 850 | "type": "IMAGE", 851 | "link": 126 852 | } 853 | ], 854 | "outputs": [ 855 | { 856 | "name": "width", 857 | "type": "INT", 858 | "links": [ 859 | 128, 860 | 136 861 | ] 862 | }, 863 | { 864 | "name": "height", 865 | "type": "INT", 866 | "links": [ 867 | 127, 868 | 137 869 | ] 870 | } 871 | ], 872 | "properties": { 873 | "aux_id": "kijai/ComfyUI-FramePackWrapper", 874 | "ver": "4f9030a9f4c0bd67d86adf3d3dc07e37118c40bd", 875 | "Node name for S&R": "FramePackFindNearestBucket" 876 | }, 877 | "widgets_values": [ 878 | 640 879 | ] 880 | }, 881 | { 882 | "id": 19, 883 | "type": "LoadImage", 884 | "pos": [ 885 | 184.2612762451172, 886 | 591.6886596679688 887 | ], 888 | "size": [ 889 | 315, 890 | 314 891 | ], 892 | "flags": {}, 893 | "order": 9, 894 | "mode": 0, 895 | "inputs": [], 896 | "outputs": [ 897 | { 898 | "name": "IMAGE", 899 | "type": "IMAGE", 900 | "links": [ 901 | 122, 902 | 126 903 | ] 904 | }, 905 | { 906 | "name": "MASK", 907 | "type": "MASK", 908 | "links": null 909 | } 910 | ], 911 | "title": "Load Image: Start", 912 | "properties": { 913 | "cnr_id": "comfy-core", 914 | "ver": "0.3.28", 915 | "Node name for S&R": "LoadImage" 916 | }, 917 | "widgets_values": [ 918 | "sd3stag.png", 919 | "image" 920 | ] 921 | }, 922 | { 923 | "id": 18, 924 | "type": "CLIPVisionLoader", 925 | "pos": [ 926 | 33.149566650390625, 927 | 23.595293045043945 928 | ], 929 | "size": [ 930 | 388.87139892578125, 931 | 58 932 | ], 933 | "flags": {}, 934 | "order": 10, 935 | "mode": 0, 936 | "inputs": [], 937 | "outputs": [ 938 | { 939 | "name": "CLIP_VISION", 940 | "type": "CLIP_VISION", 941 | "links": [ 942 | 148 943 | ] 944 | } 945 | ], 946 | "properties": { 947 | "cnr_id": "comfy-core", 948 | "ver": "0.3.28", 949 | "Node name for S&R": "CLIPVisionLoader" 950 | }, 951 | "widgets_values": [ 952 | "sigclip_vision_patch14_384.safetensors" 953 | ], 954 | 
"color": "#2a363b", 955 | "bgcolor": "#3f5159" 956 | }, 957 | { 958 | "id": 63, 959 | "type": "SetNode", 960 | "pos": [ 961 | 247.1346435546875, 962 | -28.502397537231445 963 | ], 964 | "size": [ 965 | 210, 966 | 60 967 | ], 968 | "flags": { 969 | "collapsed": true 970 | }, 971 | "order": 17, 972 | "mode": 0, 973 | "inputs": [ 974 | { 975 | "name": "CLIP_VISION", 976 | "type": "CLIP_VISION", 977 | "link": 148 978 | } 979 | ], 980 | "outputs": [ 981 | { 982 | "name": "*", 983 | "type": "*", 984 | "links": null 985 | } 986 | ], 987 | "title": "Set_ClipVisionModle", 988 | "properties": { 989 | "previousName": "ClipVisionModle" 990 | }, 991 | "widgets_values": [ 992 | "ClipVisionModle" 993 | ], 994 | "color": "#233", 995 | "bgcolor": "#355" 996 | }, 997 | { 998 | "id": 27, 999 | "type": "FramePackTorchCompileSettings", 1000 | "pos": [ 1001 | 623.3660278320312, 1002 | -140.94215393066406 1003 | ], 1004 | "size": [ 1005 | 531.5999755859375, 1006 | 202 1007 | ], 1008 | "flags": {}, 1009 | "order": 11, 1010 | "mode": 0, 1011 | "inputs": [], 1012 | "outputs": [ 1013 | { 1014 | "name": "torch_compile_args", 1015 | "type": "FRAMEPACKCOMPILEARGS", 1016 | "links": [] 1017 | } 1018 | ], 1019 | "properties": { 1020 | "aux_id": "lllyasviel/FramePack", 1021 | "ver": "0e5fe5d7ca13c76fb8e13708f4b92e7c7a34f20c", 1022 | "Node name for S&R": "FramePackTorchCompileSettings" 1023 | }, 1024 | "widgets_values": [ 1025 | "inductor", 1026 | false, 1027 | "default", 1028 | false, 1029 | 64, 1030 | true, 1031 | true 1032 | ] 1033 | }, 1034 | { 1035 | "id": 33, 1036 | "type": "VAEDecodeTiled", 1037 | "pos": [ 1038 | 2328.923828125, 1039 | -22.08228874206543 1040 | ], 1041 | "size": [ 1042 | 315, 1043 | 150 1044 | ], 1045 | "flags": {}, 1046 | "order": 28, 1047 | "mode": 0, 1048 | "inputs": [ 1049 | { 1050 | "name": "samples", 1051 | "type": "LATENT", 1052 | "link": 85 1053 | }, 1054 | { 1055 | "name": "vae", 1056 | "type": "VAE", 1057 | "link": 154 1058 | } 1059 | ], 1060 | "outputs": [ 1061 | { 1062 | "name": "IMAGE", 1063 | "type": "IMAGE", 1064 | "links": [ 1065 | 96 1066 | ] 1067 | } 1068 | ], 1069 | "properties": { 1070 | "cnr_id": "comfy-core", 1071 | "ver": "0.3.28", 1072 | "Node name for S&R": "VAEDecodeTiled" 1073 | }, 1074 | "widgets_values": [ 1075 | 256, 1076 | 64, 1077 | 64, 1078 | 8 1079 | ], 1080 | "color": "#322", 1081 | "bgcolor": "#533" 1082 | }, 1083 | { 1084 | "id": 67, 1085 | "type": "GetNode", 1086 | "pos": [ 1087 | 2342.01806640625, 1088 | -76.06847381591797 1089 | ], 1090 | "size": [ 1091 | 210, 1092 | 60 1093 | ], 1094 | "flags": { 1095 | "collapsed": true 1096 | }, 1097 | "order": 12, 1098 | "mode": 0, 1099 | "inputs": [], 1100 | "outputs": [ 1101 | { 1102 | "name": "VAE", 1103 | "type": "VAE", 1104 | "links": [ 1105 | 154 1106 | ] 1107 | } 1108 | ], 1109 | "title": "Get_VAE", 1110 | "properties": {}, 1111 | "widgets_values": [ 1112 | "VAE" 1113 | ], 1114 | "color": "#322", 1115 | "bgcolor": "#533" 1116 | }, 1117 | { 1118 | "id": 23, 1119 | "type": "VHS_VideoCombine", 1120 | "pos": [ 1121 | 2726.849853515625, 1122 | -29.90264129638672 1123 | ], 1124 | "size": [ 1125 | 908.428955078125, 1126 | 334 1127 | ], 1128 | "flags": {}, 1129 | "order": 30, 1130 | "mode": 0, 1131 | "inputs": [ 1132 | { 1133 | "name": "images", 1134 | "type": "IMAGE", 1135 | "link": 97 1136 | }, 1137 | { 1138 | "name": "audio", 1139 | "shape": 7, 1140 | "type": "AUDIO", 1141 | "link": null 1142 | }, 1143 | { 1144 | "name": "meta_batch", 1145 | "shape": 7, 1146 | "type": "VHS_BatchManager", 1147 | "link": null 1148 | }, 1149 
| { 1150 | "name": "vae", 1151 | "shape": 7, 1152 | "type": "VAE", 1153 | "link": null 1154 | } 1155 | ], 1156 | "outputs": [ 1157 | { 1158 | "name": "Filenames", 1159 | "type": "VHS_FILENAMES", 1160 | "links": null 1161 | } 1162 | ], 1163 | "properties": { 1164 | "cnr_id": "comfyui-videohelpersuite", 1165 | "ver": "0a75c7958fe320efcb052f1d9f8451fd20c730a8", 1166 | "Node name for S&R": "VHS_VideoCombine" 1167 | }, 1168 | "widgets_values": { 1169 | "frame_rate": 30, 1170 | "loop_count": 0, 1171 | "filename_prefix": "FramePack", 1172 | "format": "video/h264-mp4", 1173 | "pix_fmt": "yuv420p", 1174 | "crf": 19, 1175 | "save_metadata": true, 1176 | "trim_to_audio": false, 1177 | "pingpong": false, 1178 | "save_output": false, 1179 | "videopreview": { 1180 | "hidden": false, 1181 | "paused": false, 1182 | "params": { 1183 | "filename": "FramePack_00001.mp4", 1184 | "subfolder": "", 1185 | "type": "temp", 1186 | "format": "video/h264-mp4", 1187 | "frame_rate": 30, 1188 | "workflow": "FramePack_00001.png", 1189 | "fullpath": "N:\\AI\\ComfyUI\\temp\\FramePack_00001.mp4" 1190 | } 1191 | } 1192 | } 1193 | }, 1194 | { 1195 | "id": 44, 1196 | "type": "GetImageSizeAndCount", 1197 | "pos": [ 1198 | 2501.023193359375, 1199 | -178.70773315429688 1200 | ], 1201 | "size": [ 1202 | 277.20001220703125, 1203 | 86 1204 | ], 1205 | "flags": {}, 1206 | "order": 29, 1207 | "mode": 0, 1208 | "inputs": [ 1209 | { 1210 | "name": "image", 1211 | "type": "IMAGE", 1212 | "link": 96 1213 | } 1214 | ], 1215 | "outputs": [ 1216 | { 1217 | "name": "image", 1218 | "type": "IMAGE", 1219 | "links": [ 1220 | 97 1221 | ] 1222 | }, 1223 | { 1224 | "label": "704 width", 1225 | "name": "width", 1226 | "type": "INT", 1227 | "links": null 1228 | }, 1229 | { 1230 | "label": "544 height", 1231 | "name": "height", 1232 | "type": "INT", 1233 | "links": null 1234 | }, 1235 | { 1236 | "label": "145 count", 1237 | "name": "count", 1238 | "type": "INT", 1239 | "links": null 1240 | } 1241 | ], 1242 | "properties": { 1243 | "cnr_id": "comfyui-kjnodes", 1244 | "ver": "8ecf5cd05e0a1012087b0da90eea9a13674668db", 1245 | "Node name for S&R": "GetImageSizeAndCount" 1246 | }, 1247 | "widgets_values": [] 1248 | }, 1249 | { 1250 | "id": 47, 1251 | "type": "CLIPTextEncode", 1252 | "pos": [ 1253 | 715.3054809570312, 1254 | 127.73457336425781 1255 | ], 1256 | "size": [ 1257 | 400, 1258 | 200 1259 | ], 1260 | "flags": {}, 1261 | "order": 14, 1262 | "mode": 0, 1263 | "inputs": [ 1264 | { 1265 | "name": "clip", 1266 | "type": "CLIP", 1267 | "link": 102 1268 | } 1269 | ], 1270 | "outputs": [ 1271 | { 1272 | "name": "CONDITIONING", 1273 | "type": "CONDITIONING", 1274 | "links": [ 1275 | 114, 1276 | 118 1277 | ] 1278 | } 1279 | ], 1280 | "properties": { 1281 | "cnr_id": "comfy-core", 1282 | "ver": "0.3.28", 1283 | "Node name for S&R": "CLIPTextEncode" 1284 | }, 1285 | "widgets_values": [ 1286 | "majestig stag in a forest" 1287 | ], 1288 | "color": "#232", 1289 | "bgcolor": "#353" 1290 | }, 1291 | { 1292 | "id": 52, 1293 | "type": "LoadFramePackModel", 1294 | "pos": [ 1295 | 1253.046630859375, 1296 | -82.57657623291016 1297 | ], 1298 | "size": [ 1299 | 480.7601013183594, 1300 | 174 1301 | ], 1302 | "flags": {}, 1303 | "order": 13, 1304 | "mode": 0, 1305 | "inputs": [ 1306 | { 1307 | "name": "compile_args", 1308 | "shape": 7, 1309 | "type": "FRAMEPACKCOMPILEARGS", 1310 | "link": null 1311 | }, 1312 | { 1313 | "name": "lora", 1314 | "shape": 7, 1315 | "type": "FPLORA", 1316 | "link": null 1317 | } 1318 | ], 1319 | "outputs": [ 1320 | { 1321 | "name": "model", 1322 
| "type": "FramePackMODEL", 1323 | "links": [ 1324 | 129 1325 | ] 1326 | } 1327 | ], 1328 | "properties": { 1329 | "aux_id": "kijai/ComfyUI-FramePackWrapper", 1330 | "ver": "49fe507eca8246cc9d08a8093892f40c1180e88f", 1331 | "Node name for S&R": "LoadFramePackModel" 1332 | }, 1333 | "widgets_values": [ 1334 | "Hyvid\\FramePackI2V_HY_fp8_e4m3fn.safetensors", 1335 | "bf16", 1336 | "fp8_e4m3fn", 1337 | "offload_device", 1338 | "sdpa" 1339 | ] 1340 | }, 1341 | { 1342 | "id": 39, 1343 | "type": "FramePackSampler", 1344 | "pos": [ 1345 | 2292.58837890625, 1346 | 194.90232849121094 1347 | ], 1348 | "size": [ 1349 | 365.07305908203125, 1350 | 814.6473388671875 1351 | ], 1352 | "flags": {}, 1353 | "order": 27, 1354 | "mode": 0, 1355 | "inputs": [ 1356 | { 1357 | "name": "model", 1358 | "type": "FramePackMODEL", 1359 | "link": 129 1360 | }, 1361 | { 1362 | "name": "positive", 1363 | "type": "CONDITIONING", 1364 | "link": 114 1365 | }, 1366 | { 1367 | "name": "negative", 1368 | "type": "CONDITIONING", 1369 | "link": 108 1370 | }, 1371 | { 1372 | "name": "start_latent", 1373 | "type": "LATENT", 1374 | "link": 86 1375 | }, 1376 | { 1377 | "name": "image_embeds", 1378 | "shape": 7, 1379 | "type": "CLIP_VISION_OUTPUT", 1380 | "link": 141 1381 | }, 1382 | { 1383 | "name": "end_latent", 1384 | "shape": 7, 1385 | "type": "LATENT", 1386 | "link": 147 1387 | }, 1388 | { 1389 | "name": "end_image_embeds", 1390 | "shape": 7, 1391 | "type": "CLIP_VISION_OUTPUT", 1392 | "link": 132 1393 | }, 1394 | { 1395 | "name": "initial_samples", 1396 | "shape": 7, 1397 | "type": "LATENT", 1398 | "link": null 1399 | } 1400 | ], 1401 | "outputs": [ 1402 | { 1403 | "name": "samples", 1404 | "type": "LATENT", 1405 | "links": [ 1406 | 85 1407 | ] 1408 | } 1409 | ], 1410 | "properties": { 1411 | "aux_id": "kijai/ComfyUI-FramePackWrapper", 1412 | "ver": "8e5ec6b7f3acf88255c5d93d062079f18b43aa2b", 1413 | "Node name for S&R": "FramePackSampler" 1414 | }, 1415 | "widgets_values": [ 1416 | 30, 1417 | true, 1418 | 0.15, 1419 | 1, 1420 | 10, 1421 | 0, 1422 | 47, 1423 | "fixed", 1424 | 9, 1425 | 5, 1426 | 6, 1427 | "unipc_bh1", 1428 | "weighted_average", 1429 | 0.5, 1430 | 1 1431 | ] 1432 | } 1433 | ], 1434 | "links": [ 1435 | [ 1436 | 85, 1437 | 39, 1438 | 0, 1439 | 33, 1440 | 0, 1441 | "LATENT" 1442 | ], 1443 | [ 1444 | 86, 1445 | 20, 1446 | 0, 1447 | 39, 1448 | 3, 1449 | "LATENT" 1450 | ], 1451 | [ 1452 | 96, 1453 | 33, 1454 | 0, 1455 | 44, 1456 | 0, 1457 | "IMAGE" 1458 | ], 1459 | [ 1460 | 97, 1461 | 44, 1462 | 0, 1463 | 23, 1464 | 0, 1465 | "IMAGE" 1466 | ], 1467 | [ 1468 | 102, 1469 | 13, 1470 | 0, 1471 | 47, 1472 | 0, 1473 | "CLIP" 1474 | ], 1475 | [ 1476 | 108, 1477 | 15, 1478 | 0, 1479 | 39, 1480 | 2, 1481 | "CONDITIONING" 1482 | ], 1483 | [ 1484 | 114, 1485 | 47, 1486 | 0, 1487 | 39, 1488 | 1, 1489 | "CONDITIONING" 1490 | ], 1491 | [ 1492 | 116, 1493 | 48, 1494 | 0, 1495 | 17, 1496 | 1, 1497 | "IMAGE" 1498 | ], 1499 | [ 1500 | 118, 1501 | 47, 1502 | 0, 1503 | 15, 1504 | 0, 1505 | "CONDITIONING" 1506 | ], 1507 | [ 1508 | 122, 1509 | 19, 1510 | 0, 1511 | 50, 1512 | 0, 1513 | "IMAGE" 1514 | ], 1515 | [ 1516 | 125, 1517 | 50, 1518 | 0, 1519 | 48, 1520 | 0, 1521 | "IMAGE" 1522 | ], 1523 | [ 1524 | 126, 1525 | 19, 1526 | 0, 1527 | 51, 1528 | 0, 1529 | "IMAGE" 1530 | ], 1531 | [ 1532 | 127, 1533 | 51, 1534 | 1, 1535 | 50, 1536 | 2, 1537 | "INT" 1538 | ], 1539 | [ 1540 | 128, 1541 | 51, 1542 | 0, 1543 | 50, 1544 | 1, 1545 | "INT" 1546 | ], 1547 | [ 1548 | 129, 1549 | 52, 1550 | 0, 1551 | 39, 1552 | 0, 1553 | "FramePackMODEL" 1554 | ], 
1555 | [ 1556 | 132, 1557 | 57, 1558 | 0, 1559 | 39, 1560 | 6, 1561 | "CLIP_VISION_OUTPUT" 1562 | ], 1563 | [ 1564 | 136, 1565 | 51, 1566 | 0, 1567 | 59, 1568 | 1, 1569 | "INT" 1570 | ], 1571 | [ 1572 | 137, 1573 | 51, 1574 | 1, 1575 | 59, 1576 | 2, 1577 | "INT" 1578 | ], 1579 | [ 1580 | 138, 1581 | 58, 1582 | 0, 1583 | 59, 1584 | 0, 1585 | "IMAGE" 1586 | ], 1587 | [ 1588 | 139, 1589 | 59, 1590 | 0, 1591 | 60, 1592 | 0, 1593 | "IMAGE" 1594 | ], 1595 | [ 1596 | 141, 1597 | 17, 1598 | 0, 1599 | 39, 1600 | 4, 1601 | "CLIP_VISION_OUTPUT" 1602 | ], 1603 | [ 1604 | 147, 1605 | 62, 1606 | 0, 1607 | 39, 1608 | 5, 1609 | "LATENT" 1610 | ], 1611 | [ 1612 | 148, 1613 | 18, 1614 | 0, 1615 | 63, 1616 | 0, 1617 | "*" 1618 | ], 1619 | [ 1620 | 149, 1621 | 64, 1622 | 0, 1623 | 17, 1624 | 0, 1625 | "CLIP_VISION" 1626 | ], 1627 | [ 1628 | 150, 1629 | 65, 1630 | 0, 1631 | 57, 1632 | 0, 1633 | "CLIP_VISION" 1634 | ], 1635 | [ 1636 | 151, 1637 | 60, 1638 | 0, 1639 | 57, 1640 | 1, 1641 | "IMAGE" 1642 | ], 1643 | [ 1644 | 152, 1645 | 60, 1646 | 0, 1647 | 62, 1648 | 0, 1649 | "IMAGE" 1650 | ], 1651 | [ 1652 | 153, 1653 | 12, 1654 | 0, 1655 | 66, 1656 | 0, 1657 | "*" 1658 | ], 1659 | [ 1660 | 154, 1661 | 67, 1662 | 0, 1663 | 33, 1664 | 1, 1665 | "VAE" 1666 | ], 1667 | [ 1668 | 155, 1669 | 68, 1670 | 0, 1671 | 20, 1672 | 1, 1673 | "VAE" 1674 | ], 1675 | [ 1676 | 156, 1677 | 48, 1678 | 0, 1679 | 20, 1680 | 0, 1681 | "IMAGE" 1682 | ], 1683 | [ 1684 | 158, 1685 | 69, 1686 | 0, 1687 | 62, 1688 | 1, 1689 | "VAE" 1690 | ] 1691 | ], 1692 | "groups": [ 1693 | { 1694 | "id": 1, 1695 | "title": "End Image", 1696 | "bounding": [ 1697 | 12.77297592163086, 1698 | 999.1203002929688, 1699 | 2038.674560546875, 1700 | 412.9618225097656 1701 | ], 1702 | "color": "#3f789e", 1703 | "font_size": 24, 1704 | "flags": {} 1705 | }, 1706 | { 1707 | "id": 2, 1708 | "title": "Start Image", 1709 | "bounding": [ 1710 | 11.781991958618164, 1711 | 531.3884887695312, 1712 | 2032.7288818359375, 1713 | 442.6904602050781 1714 | ], 1715 | "color": "#3f789e", 1716 | "font_size": 24, 1717 | "flags": {} 1718 | } 1719 | ], 1720 | "config": {}, 1721 | "extra": { 1722 | "ds": { 1723 | "scale": 0.6115909044841659, 1724 | "offset": [ 1725 | 21.57747102795121, 1726 | 375.7674957811538 1727 | ] 1728 | }, 1729 | "frontendVersion": "1.18.3", 1730 | "VHS_latentpreview": true, 1731 | "VHS_latentpreviewrate": 0, 1732 | "VHS_MetadataImage": true, 1733 | "VHS_KeepIntermediate": true 1734 | }, 1735 | "version": 0.4 1736 | } -------------------------------------------------------------------------------- /fp8_optimization.py: -------------------------------------------------------------------------------- 1 | #based on ComfyUI's and MinusZoneAI's fp8_linear optimization 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | def fp8_linear_forward(cls, original_dtype, input): 7 | weight_dtype = cls.weight.dtype 8 | if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: 9 | if len(input.shape) == 3: 10 | target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn 11 | inn = input.reshape(-1, input.shape[2]).to(target_dtype) 12 | w = cls.weight.t() 13 | 14 | scale = torch.ones((1), device=input.device, dtype=torch.float32) 15 | bias = cls.bias.to(original_dtype) if cls.bias is not None else None 16 | 17 | if bias is not None: 18 | o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale) 19 | else: 20 | o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale) 21 | 
22 | if isinstance(o, tuple): 23 | o = o[0] 24 | 25 | return o.reshape((-1, input.shape[1], cls.weight.shape[0])) 26 | else: 27 | return cls.original_forward(input.to(original_dtype)) 28 | else: 29 | return cls.original_forward(input) 30 | 31 | def convert_fp8_linear(module, original_dtype, params_to_keep={}): 32 | setattr(module, "fp8_matmul_enabled", True) 33 | 34 | for name, module in module.named_modules(): 35 | if not any(keyword in name for keyword in params_to_keep): 36 | if isinstance(module, nn.Linear): 37 | original_forward = module.forward 38 | setattr(module, "original_forward", original_forward) 39 | setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input)) 40 | -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import math 4 | from tqdm import tqdm 5 | 6 | from accelerate import init_empty_weights 7 | from accelerate.utils import set_module_tensor_to_device 8 | 9 | import folder_paths 10 | import comfy.model_management as mm 11 | from comfy.utils import load_torch_file, ProgressBar, common_upscale 12 | import comfy.model_base 13 | import comfy.latent_formats 14 | from comfy.cli_args import args, LatentPreviewMethod 15 | 16 | from .utils import log 17 | 18 | script_directory = os.path.dirname(os.path.abspath(__file__)) 19 | vae_scaling_factor = 0.476986 20 | 21 | from .diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModel 22 | from .diffusers_helper.memory import DynamicSwapInstaller, move_model_to_device_with_memory_preservation 23 | from .diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan 24 | from .diffusers_helper.utils import crop_or_pad_yield_mask 25 | from .diffusers_helper.bucket_tools import find_nearest_bucket 26 | 27 | from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers 28 | 29 | class HyVideoModel(comfy.model_base.BaseModel): 30 | def __init__(self, *args, **kwargs): 31 | super().__init__(*args, **kwargs) 32 | self.pipeline = {} 33 | self.load_device = mm.get_torch_device() 34 | 35 | def __getitem__(self, k): 36 | return self.pipeline[k] 37 | 38 | def __setitem__(self, k, v): 39 | self.pipeline[k] = v 40 | 41 | 42 | class HyVideoModelConfig: 43 | def __init__(self, dtype): 44 | self.unet_config = {} 45 | self.unet_extra_config = {} 46 | self.latent_format = comfy.latent_formats.HunyuanVideo 47 | self.latent_format.latent_channels = 16 48 | self.manual_cast_dtype = dtype 49 | self.sampling_settings = {"multiplier": 1.0} 50 | self.memory_usage_factor = 2.0 51 | self.unet_config["disable_unet_model_creation"] = True 52 | 53 | class FramePackTorchCompileSettings: 54 | @classmethod 55 | def INPUT_TYPES(s): 56 | return { 57 | "required": { 58 | "backend": (["inductor","cudagraphs"], {"default": "inductor"}), 59 | "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), 60 | "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), 61 | "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), 62 | "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), 63 | "compile_single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Enable single block compilation"}), 64 | "compile_double_blocks": ("BOOLEAN", {"default": 
True, "tooltip": "Enable double block compilation"}), 65 | }, 66 | } 67 | RETURN_TYPES = ("FRAMEPACKCOMPILEARGS",) 68 | RETURN_NAMES = ("torch_compile_args",) 69 | FUNCTION = "loadmodel" 70 | CATEGORY = "HunyuanVideoWrapper" 71 | DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended" 72 | 73 | def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_single_blocks, compile_double_blocks): 74 | 75 | compile_args = { 76 | "backend": backend, 77 | "fullgraph": fullgraph, 78 | "mode": mode, 79 | "dynamic": dynamic, 80 | "dynamo_cache_size_limit": dynamo_cache_size_limit, 81 | "compile_single_blocks": compile_single_blocks, 82 | "compile_double_blocks": compile_double_blocks 83 | } 84 | 85 | return (compile_args, ) 86 | 87 | #region Model loading 88 | class DownloadAndLoadFramePackModel: 89 | @classmethod 90 | def INPUT_TYPES(s): 91 | return { 92 | "required": { 93 | "model": (["lllyasviel/FramePackI2V_HY"],), 94 | 95 | "base_precision": (["fp32", "bf16", "fp16"], {"default": "bf16"}), 96 | "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2'], {"default": 'disabled', "tooltip": "optional quantization method"}), 97 | }, 98 | "optional": { 99 | "attention_mode": ([ 100 | "sdpa", 101 | "flash_attn", 102 | "sageattn", 103 | ], {"default": "sdpa"}), 104 | "compile_args": ("FRAMEPACKCOMPILEARGS", ), 105 | } 106 | } 107 | 108 | RETURN_TYPES = ("FramePackMODEL",) 109 | RETURN_NAMES = ("model", ) 110 | FUNCTION = "loadmodel" 111 | CATEGORY = "FramePackWrapper" 112 | 113 | def loadmodel(self, model, base_precision, quantization, 114 | compile_args=None, attention_mode="sdpa"): 115 | 116 | base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision] 117 | 118 | device = mm.get_torch_device() 119 | 120 | model_path = os.path.join(folder_paths.models_dir, "diffusers", "lllyasviel", "FramePackI2V_HY") 121 | if not os.path.exists(model_path): 122 | print(f"Downloading clip model to: {model_path}") 123 | from huggingface_hub import snapshot_download 124 | snapshot_download( 125 | repo_id=model, 126 | local_dir=model_path, 127 | local_dir_use_symlinks=False, 128 | ) 129 | 130 | transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_path, torch_dtype=base_dtype, attention_mode=attention_mode).cpu() 131 | params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"} 132 | if quantization == 'fp8_e4m3fn' or quantization == 'fp8_e4m3fn_fast': 133 | transformer = transformer.to(torch.float8_e4m3fn) 134 | if quantization == "fp8_e4m3fn_fast": 135 | from .fp8_optimization import convert_fp8_linear 136 | convert_fp8_linear(transformer, base_dtype, params_to_keep=params_to_keep) 137 | elif quantization == 'fp8_e5m2': 138 | transformer = transformer.to(torch.float8_e5m2) 139 | else: 140 | transformer = transformer.to(base_dtype) 141 | 142 | DynamicSwapInstaller.install_model(transformer, device=device) 143 | 144 | if compile_args is not None: 145 | if compile_args["compile_single_blocks"]: 146 | for i, block in enumerate(transformer.single_transformer_blocks): 147 | transformer.single_transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) 148 | if 
compile_args["compile_double_blocks"]: 149 | for i, block in enumerate(transformer.transformer_blocks): 150 | transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) 151 | 152 | #transformer = torch.compile(transformer, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) 153 | 154 | pipe = { 155 | "transformer": transformer.eval(), 156 | "dtype": base_dtype, 157 | } 158 | return (pipe, ) 159 | 160 | class FramePackLoraSelect: 161 | @classmethod 162 | def INPUT_TYPES(s): 163 | return { 164 | "required": { 165 | "lora": (folder_paths.get_filename_list("loras"), 166 | {"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}), 167 | "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), 168 | "fuse_lora": ("BOOLEAN", {"default": True, "tooltip": "Fuse the LORA model with the base model. This is recommended for better performance."}), 169 | }, 170 | "optional": { 171 | "prev_lora":("FPLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}), 172 | } 173 | } 174 | 175 | RETURN_TYPES = ("FPLORA",) 176 | RETURN_NAMES = ("lora", ) 177 | FUNCTION = "getlorapath" 178 | CATEGORY = "FramePackWrapper" 179 | DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras" 180 | 181 | def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=True): 182 | loras_list = [] 183 | 184 | lora = { 185 | "path": folder_paths.get_full_path("loras", lora), 186 | "strength": strength, 187 | "name": lora.split(".")[0], 188 | "fuse_lora": fuse_lora, 189 | } 190 | if prev_lora is not None: 191 | loras_list.extend(prev_lora) 192 | 193 | loras_list.append(lora) 194 | return (loras_list,) 195 | 196 | class LoadFramePackModel: 197 | @classmethod 198 | def INPUT_TYPES(s): 199 | return { 200 | "required": { 201 | "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}), 202 | 203 | "base_precision": (["fp32", "bf16", "fp16"], {"default": "bf16"}), 204 | "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2'], {"default": 'disabled', "tooltip": "optional quantization method"}), 205 | "load_device": (["main_device", "offload_device"], {"default": "cuda", "tooltip": "Initialize the model on the main device or offload device"}), 206 | }, 207 | "optional": { 208 | "attention_mode": ([ 209 | "sdpa", 210 | "flash_attn", 211 | "sageattn", 212 | ], {"default": "sdpa"}), 213 | "compile_args": ("FRAMEPACKCOMPILEARGS", ), 214 | "lora": ("FPLORA", {"default": None, "tooltip": "LORA model to load"}), 215 | } 216 | } 217 | 218 | RETURN_TYPES = ("FramePackMODEL",) 219 | RETURN_NAMES = ("model", ) 220 | FUNCTION = "loadmodel" 221 | CATEGORY = "FramePackWrapper" 222 | 223 | def loadmodel(self, model, base_precision, quantization, 224 | compile_args=None, attention_mode="sdpa", lora=None, load_device="main_device"): 225 | 226 | base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision] 227 | 228 | device = mm.get_torch_device() 229 | offload_device = mm.unet_offload_device() 230 | if load_device == "main_device": 231 | transformer_load_device 
= device 232 | else: 233 | transformer_load_device = offload_device 234 | 235 | model_path = folder_paths.get_full_path_or_raise("diffusion_models", model) 236 | model_config_path = os.path.join(script_directory, "transformer_config.json") 237 | import json 238 | with open(model_config_path, "r") as f: 239 | config = json.load(f) 240 | sd = load_torch_file(model_path, device=offload_device, safe_load=True) 241 | model_weight_dtype = sd['single_transformer_blocks.0.attn.to_k.weight'].dtype 242 | 243 | with init_empty_weights(): 244 | transformer = HunyuanVideoTransformer3DModel(**config, attention_mode=attention_mode) 245 | 246 | params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"} 247 | if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast" or quantization == "fp8_scaled": 248 | dtype = torch.float8_e4m3fn 249 | elif quantization == "fp8_e5m2": 250 | dtype = torch.float8_e5m2 251 | else: 252 | dtype = base_dtype 253 | 254 | if lora is not None: 255 | after_lora_dtype = dtype 256 | dtype = base_dtype 257 | 258 | print("Using accelerate to load and assign model weights to device...") 259 | param_count = sum(1 for _ in transformer.named_parameters()) 260 | for name, param in tqdm(transformer.named_parameters(), 261 | desc=f"Loading transformer parameters to {transformer_load_device}", 262 | total=param_count, 263 | leave=True): 264 | dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype 265 | 266 | set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name]) 267 | 268 | if lora is not None: 269 | adapter_list = [] 270 | adapter_weights = [] 271 | 272 | for l in lora: 273 | fuse = True if l["fuse_lora"] else False 274 | lora_sd = load_torch_file(l["path"]) 275 | 276 | if "lora_unet_single_transformer_blocks_0_attn_to_k.lora_up.weight" in lora_sd: 277 | from .utils import convert_to_diffusers 278 | lora_sd = convert_to_diffusers("lora_unet_", lora_sd) 279 | 280 | if "transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight" not in lora_sd: 281 | log.info(f"Converting LoRA weights from {l['path']} to diffusers format...") 282 | lora_sd = _convert_hunyuan_video_lora_to_diffusers(lora_sd) 283 | 284 | lora_rank = None 285 | for key, val in lora_sd.items(): 286 | if "lora_B" in key or "lora_up" in key: 287 | lora_rank = val.shape[1] 288 | break 289 | if lora_rank is not None: 290 | log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}") 291 | adapter_name = l['path'].split("/")[-1].split(".")[0] 292 | adapter_weight = l['strength'] 293 | transformer.load_lora_adapter(lora_sd, weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name) 294 | 295 | adapter_list.append(adapter_name) 296 | adapter_weights.append(adapter_weight) 297 | 298 | del lora_sd 299 | mm.soft_empty_cache() 300 | if adapter_list: 301 | transformer.set_adapters(adapter_list, weights=adapter_weights) 302 | if fuse: 303 | if model_weight_dtype not in [torch.float32, torch.float16, torch.bfloat16]: 304 | raise ValueError("Fusing LoRA doesn't work well with fp8 model weights. 
Please use a bf16 model file, or disable LoRA fusing.") 305 | lora_scale = 1 306 | transformer.fuse_lora(lora_scale=lora_scale) 307 | transformer.delete_adapters(adapter_list) 308 | 309 | if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast" or quantization == "fp8_e5m2": 310 | params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"} 311 | for name, param in transformer.named_parameters(): 312 | # Make sure to not cast the LoRA weights to fp8. 313 | if not any(keyword in name for keyword in params_to_keep) and not 'lora' in name: 314 | param.data = param.data.to(after_lora_dtype) 315 | 316 | if quantization == "fp8_e4m3fn_fast": 317 | from .fp8_optimization import convert_fp8_linear 318 | convert_fp8_linear(transformer, base_dtype, params_to_keep=params_to_keep) 319 | 320 | 321 | DynamicSwapInstaller.install_model(transformer, device=device) 322 | 323 | if compile_args is not None: 324 | if compile_args["compile_single_blocks"]: 325 | for i, block in enumerate(transformer.single_transformer_blocks): 326 | transformer.single_transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) 327 | if compile_args["compile_double_blocks"]: 328 | for i, block in enumerate(transformer.transformer_blocks): 329 | transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) 330 | 331 | #transformer = torch.compile(transformer, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) 332 | 333 | pipe = { 334 | "transformer": transformer.eval(), 335 | "dtype": base_dtype, 336 | } 337 | return (pipe, ) 338 | 339 | class FramePackFindNearestBucket: 340 | @classmethod 341 | def INPUT_TYPES(s): 342 | return {"required": { 343 | "image": ("IMAGE", {"tooltip": "Image to resize"}), 344 | "base_resolution": ("INT", {"default": 640, "min": 64, "max": 2048, "step": 16, "tooltip": "Base resolution to fit the image to when finding the nearest bucket"}), 345 | }, 346 | } 347 | 348 | RETURN_TYPES = ("INT", "INT", ) 349 | RETURN_NAMES = ("width","height",) 350 | FUNCTION = "process" 351 | CATEGORY = "FramePackWrapper" 352 | DESCRIPTION = "Finds the closest resolution bucket as defined in the original code" 353 | 354 | def process(self, image, base_resolution): 355 | 356 | H, W = image.shape[1], image.shape[2] 357 | 358 | new_height, new_width = find_nearest_bucket(H, W, resolution=base_resolution) 359 | 360 | return (new_width, new_height, ) 361 | 362 | 363 | class FramePackSampler: 364 | @classmethod 365 | def INPUT_TYPES(s): 366 | return { 367 | "required": { 368 | "model": ("FramePackMODEL",), 369 | "positive": ("CONDITIONING",), 370 | "negative": ("CONDITIONING",), 371 | "start_latent": ("LATENT", {"tooltip": "init Latents to use for image2video"} ), 372 | "steps": ("INT", {"default": 30, "min": 1}), 373 | "use_teacache": ("BOOLEAN", {"default": True, "tooltip": "Use teacache for faster sampling."}), 374 | "teacache_rel_l1_thresh": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The threshold for the relative L1 loss."}), 375 | "cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 30.0, "step": 0.01}), 376 | "guidance_scale": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 32.0, "step": 0.01}), 377 | "shift": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 
0.01}), 378 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 379 | "latent_window_size": ("INT", {"default": 9, "min": 1, "max": 33, "step": 1, "tooltip": "The size of the latent window to use for sampling."}), 380 | "total_second_length": ("FLOAT", {"default": 5, "min": 1, "max": 120, "step": 0.1, "tooltip": "The total length of the video in seconds."}), 381 | "gpu_memory_preservation": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 128.0, "step": 0.1, "tooltip": "The amount of GPU memory to preserve."}), 382 | "sampler": (["unipc_bh1", "unipc_bh2"], 383 | { 384 | "default": 'unipc_bh1' 385 | }), 386 | }, 387 | "optional": { 388 | "image_embeds": ("CLIP_VISION_OUTPUT", ), 389 | "end_latent": ("LATENT", {"tooltip": "end Latents to use for image2video"} ), 390 | "end_image_embeds": ("CLIP_VISION_OUTPUT", {"tooltip": "end Image's clip embeds"} ), 391 | "embed_interpolation": (["disabled", "weighted_average", "linear"], {"default": 'disabled', "tooltip": "Image embedding interpolation type. If linear, will smoothly interpolate with time, else it'll be weighted average with the specified weight."}), 392 | "start_embed_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Weighted average constant for image embed interpolation. If end image is not set, the embed's strength won't be affected"}), 393 | "initial_samples": ("LATENT", {"tooltip": "init Latents to use for video2video"} ), 394 | "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), 395 | } 396 | } 397 | 398 | RETURN_TYPES = ("LATENT", ) 399 | RETURN_NAMES = ("samples",) 400 | FUNCTION = "process" 401 | CATEGORY = "FramePackWrapper" 402 | 403 | def process(self, model, shift, positive, negative, latent_window_size, use_teacache, total_second_length, teacache_rel_l1_thresh, steps, cfg, 404 | guidance_scale, seed, sampler, gpu_memory_preservation, start_latent=None, image_embeds=None, end_latent=None, end_image_embeds=None, embed_interpolation="linear", start_embed_strength=1.0, initial_samples=None, denoise_strength=1.0): 405 | total_latent_sections = (total_second_length * 30) / (latent_window_size * 4) 406 | total_latent_sections = int(max(round(total_latent_sections), 1)) 407 | print("total_latent_sections: ", total_latent_sections) 408 | 409 | transformer = model["transformer"] 410 | base_dtype = model["dtype"] 411 | 412 | device = mm.get_torch_device() 413 | offload_device = mm.unet_offload_device() 414 | 415 | mm.unload_all_models() 416 | mm.cleanup_models() 417 | mm.soft_empty_cache() 418 | 419 | if start_latent is not None: 420 | start_latent = start_latent["samples"] * vae_scaling_factor 421 | if initial_samples is not None: 422 | initial_samples = initial_samples["samples"] * vae_scaling_factor 423 | if end_latent is not None: 424 | end_latent = end_latent["samples"] * vae_scaling_factor 425 | has_end_image = end_latent is not None 426 | print("start_latent", start_latent.shape) 427 | B, C, T, H, W = start_latent.shape 428 | 429 | if image_embeds is not None: 430 | start_image_encoder_last_hidden_state = image_embeds["last_hidden_state"].to(device, base_dtype) 431 | 432 | if has_end_image: 433 | assert end_image_embeds is not None 434 | end_image_encoder_last_hidden_state = end_image_embeds["last_hidden_state"].to(device, base_dtype) 435 | else: 436 | if image_embeds is not None: 437 | end_image_encoder_last_hidden_state = torch.zeros_like(start_image_encoder_last_hidden_state) 438 | 439 | llama_vec = positive[0][0].to(device, base_dtype) 
440 | clip_l_pooler = positive[0][1]["pooled_output"].to(device, base_dtype) 441 | 442 | if not math.isclose(cfg, 1.0): 443 | llama_vec_n = negative[0][0].to(device, base_dtype) 444 | clip_l_pooler_n = negative[0][1]["pooled_output"].to(device, base_dtype) 445 | else: 446 | llama_vec_n = torch.zeros_like(llama_vec, device=device) 447 | clip_l_pooler_n = torch.zeros_like(clip_l_pooler, device=device) 448 | 449 | llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512) 450 | llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512) 451 | 452 | 453 | # Sampling 454 | 455 | rnd = torch.Generator("cpu").manual_seed(seed) 456 | 457 | num_frames = latent_window_size * 4 - 3 458 | 459 | history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, H, W), dtype=torch.float32).cpu() 460 | 461 | total_generated_latent_frames = 0 462 | 463 | latent_paddings_list = list(reversed(range(total_latent_sections))) 464 | latent_paddings = latent_paddings_list.copy() # Create a copy for iteration 465 | 466 | comfy_model = HyVideoModel( 467 | HyVideoModelConfig(base_dtype), 468 | model_type=comfy.model_base.ModelType.FLOW, 469 | device=device, 470 | ) 471 | 472 | patcher = comfy.model_patcher.ModelPatcher(comfy_model, device, torch.device("cpu")) 473 | from latent_preview import prepare_callback 474 | callback = prepare_callback(patcher, steps) 475 | 476 | move_model_to_device_with_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation) 477 | 478 | if total_latent_sections > 4: 479 | # In theory the latent_paddings should follow the above sequence, but it seems that duplicating some 480 | # items looks better than expanding it when total_latent_sections > 4 481 | # One can try to remove below trick and just 482 | # use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare 483 | latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0] 484 | latent_paddings_list = latent_paddings.copy() 485 | 486 | for i, latent_padding in enumerate(latent_paddings): 487 | print(f"latent_padding: {latent_padding}") 488 | is_last_section = latent_padding == 0 489 | is_first_section = latent_padding == latent_paddings[0] 490 | latent_padding_size = latent_padding * latent_window_size 491 | 492 | if image_embeds is not None: 493 | if embed_interpolation != "disabled": 494 | if embed_interpolation == "linear": 495 | if total_latent_sections <= 1: 496 | frac = 1.0 # Handle case with only one section 497 | else: 498 | frac = 1 - i / (total_latent_sections - 1) # going backwards 499 | else: 500 | frac = start_embed_strength if has_end_image else 1.0 501 | 502 | image_encoder_last_hidden_state = start_image_encoder_last_hidden_state * frac + (1 - frac) * end_image_encoder_last_hidden_state 503 | else: 504 | image_encoder_last_hidden_state = start_image_encoder_last_hidden_state * start_embed_strength 505 | else: 506 | image_encoder_last_hidden_state = None 507 | 508 | print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}, is_first_section = {is_first_section}') 509 | 510 | start_latent_frames = T # 0 or 1 511 | indices = torch.arange(0, sum([start_latent_frames, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0) 512 | clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([start_latent_frames, latent_padding_size, latent_window_size, 1, 2, 16], dim=1) 513 | clean_latent_indices 
= torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1) 514 | 515 | clean_latents_pre = start_latent.to(history_latents) 516 | clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2) 517 | clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) 518 | 519 | # Use end image latent for the first section if provided 520 | if has_end_image and is_first_section: 521 | clean_latents_post = end_latent.to(history_latents) 522 | clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) 523 | 524 | #vid2vid WIP 525 | 526 | if initial_samples is not None: 527 | total_length = initial_samples.shape[2] 528 | 529 | # Get the max padding value for normalization 530 | max_padding = max(latent_paddings_list) 531 | 532 | if is_last_section: 533 | # Last section should capture the end of the sequence 534 | start_idx = max(0, total_length - latent_window_size) 535 | else: 536 | # Calculate windows that distribute more evenly across the sequence 537 | # This normalizes the padding values to create appropriate spacing 538 | if max_padding > 0: # Avoid division by zero 539 | progress = (max_padding - latent_padding) / max_padding 540 | start_idx = int(progress * max(0, total_length - latent_window_size)) 541 | else: 542 | start_idx = 0 543 | 544 | end_idx = min(start_idx + latent_window_size, total_length) 545 | print(f"start_idx: {start_idx}, end_idx: {end_idx}, total_length: {total_length}") 546 | input_init_latents = initial_samples[:, :, start_idx:end_idx, :, :].to(device) 547 | 548 | 549 | if use_teacache: 550 | transformer.initialize_teacache(enable_teacache=True, num_steps=steps, rel_l1_thresh=teacache_rel_l1_thresh) 551 | else: 552 | transformer.initialize_teacache(enable_teacache=False) 553 | 554 | with torch.autocast(device_type=mm.get_autocast_device(device), dtype=base_dtype, enabled=True): 555 | generated_latents = sample_hunyuan( 556 | transformer=transformer, 557 | sampler=sampler, 558 | initial_latent=input_init_latents if initial_samples is not None else None, 559 | strength=denoise_strength, 560 | width=W * 8, 561 | height=H * 8, 562 | frames=num_frames, 563 | real_guidance_scale=cfg, 564 | distilled_guidance_scale=guidance_scale, 565 | guidance_rescale=0, 566 | shift=shift if shift != 0 else None, 567 | num_inference_steps=steps, 568 | generator=rnd, 569 | prompt_embeds=llama_vec, 570 | prompt_embeds_mask=llama_attention_mask, 571 | prompt_poolers=clip_l_pooler, 572 | negative_prompt_embeds=llama_vec_n, 573 | negative_prompt_embeds_mask=llama_attention_mask_n, 574 | negative_prompt_poolers=clip_l_pooler_n, 575 | device=device, 576 | dtype=base_dtype, 577 | image_embeddings=image_encoder_last_hidden_state, 578 | latent_indices=latent_indices, 579 | clean_latents=clean_latents, 580 | clean_latent_indices=clean_latent_indices, 581 | clean_latents_2x=clean_latents_2x, 582 | clean_latent_2x_indices=clean_latent_2x_indices, 583 | clean_latents_4x=clean_latents_4x, 584 | clean_latent_4x_indices=clean_latent_4x_indices, 585 | callback=callback, 586 | ) 587 | 588 | if is_last_section: 589 | generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2) 590 | 591 | total_generated_latent_frames += int(generated_latents.shape[2]) 592 | history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2) 593 | 594 | real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :] 595 | 596 | if is_last_section: 597 | 
break 598 | 599 | transformer.to(offload_device) 600 | mm.soft_empty_cache() 601 | 602 | return {"samples": real_history_latents / vae_scaling_factor}, 603 | 604 | NODE_CLASS_MAPPINGS = { 605 | "DownloadAndLoadFramePackModel": DownloadAndLoadFramePackModel, 606 | "FramePackSampler": FramePackSampler, 607 | "FramePackTorchCompileSettings": FramePackTorchCompileSettings, 608 | "FramePackFindNearestBucket": FramePackFindNearestBucket, 609 | "LoadFramePackModel": LoadFramePackModel, 610 | "FramePackLoraSelect": FramePackLoraSelect, 611 | } 612 | NODE_DISPLAY_NAME_MAPPINGS = { 613 | "DownloadAndLoadFramePackModel": "(Down)Load FramePackModel", 614 | "FramePackSampler": "FramePackSampler", 615 | "FramePackTorchCompileSettings": "Torch Compile Settings", 616 | "FramePackFindNearestBucket": "Find Nearest Bucket", 617 | "LoadFramePackModel": "Load FramePackModel", 618 | "FramePackLoraSelect": "Select Lora", 619 | } 620 | 621 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate>=1.6.0 2 | diffusers>=0.33.1 3 | transformers>=4.46.2 4 | einops 5 | safetensors 6 | -------------------------------------------------------------------------------- /transformer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "HunyuanVideoTransformer3DModelPacked", 3 | "_diffusers_version": "0.33.0.dev0", 4 | "_name_or_path": "hunyuanvideo-community/HunyuanVideo", 5 | "attention_head_dim": 128, 6 | "guidance_embeds": true, 7 | "has_clean_x_embedder": true, 8 | "has_image_proj": true, 9 | "image_proj_dim": 1152, 10 | "in_channels": 16, 11 | "mlp_ratio": 4.0, 12 | "num_attention_heads": 24, 13 | "num_layers": 20, 14 | "num_refiner_layers": 2, 15 | "num_single_layers": 40, 16 | "out_channels": 16, 17 | "patch_size": 2, 18 | "patch_size_t": 1, 19 | "pooled_projection_dim": 768, 20 | "qk_norm": "rms_norm", 21 | "rope_axes_dim": [ 22 | 16, 23 | 56, 24 | 56 25 | ], 26 | "rope_theta": 256.0, 27 | "text_embed_dim": 4096 28 | } 29 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | import torch 3 | import logging 4 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 5 | log = logging.getLogger(__name__) 6 | 7 | def check_diffusers_version(): 8 | try: 9 | version = importlib.metadata.version('diffusers') 10 | required_version = '0.31.0' 11 | if version < required_version: 12 | raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.") 13 | except importlib.metadata.PackageNotFoundError: 14 | raise AssertionError("diffusers is not installed.") 15 | 16 | def print_memory(device): 17 | memory = torch.cuda.memory_allocated(device) / 1024**3 18 | max_memory = torch.cuda.max_memory_allocated(device) / 1024**3 19 | max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3 20 | log.info(f"-------------------------------") 21 | log.info(f"Allocated memory: {memory=:.3f} GB") 22 | log.info(f"Max allocated memory: {max_memory=:.3f} GB") 23 | log.info(f"Max reserved memory: {max_reserved=:.3f} GB") 24 | log.info(f"-------------------------------") 25 | #memory_summary = torch.cuda.memory_summary(device=device, abbreviated=False) 26 | #log.info(f"Memory 
Summary:\n{memory_summary}") 27 | 28 | def convert_to_diffusers(prefix, weights_sd): 29 | # convert from default LoRA to diffusers 30 | # https://github.com/kohya-ss/musubi-tuner/blob/main/convert_lora.py 31 | 32 | # get alphas 33 | lora_alphas = {} 34 | for key, weight in weights_sd.items(): 35 | if key.startswith(prefix): 36 | lora_name = key.split(".", 1)[0] # before first dot 37 | if lora_name not in lora_alphas and "alpha" in key: 38 | lora_alphas[lora_name] = weight 39 | 40 | new_weights_sd = {} 41 | for key, weight in weights_sd.items(): 42 | if key.startswith(prefix): 43 | if "alpha" in key: 44 | continue 45 | 46 | lora_name = key.split(".", 1)[0] # before first dot 47 | 48 | module_name = lora_name[len(prefix) :] # remove "lora_unet_" 49 | module_name = module_name.replace("_", ".") # replace "_" with "." 50 | 51 | # HunyuanVideo lora name to module name: ugly but works 52 | #module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks 53 | module_name = module_name.replace("single.transformer.blocks.", "single_transformer_blocks.") # fix single blocks 54 | module_name = module_name.replace("transformer.blocks.", "transformer_blocks.") # fix double blocks 55 | 56 | module_name = module_name.replace("img.", "img_") # fix img 57 | module_name = module_name.replace("txt.", "txt_") # fix txt 58 | module_name = module_name.replace("to.q", "to_q") # fix attn 59 | module_name = module_name.replace("to.k", "to_k") 60 | module_name = module_name.replace("to.v", "to_v") 61 | module_name = module_name.replace("to.add.out", "to_add_out") 62 | module_name = module_name.replace("add.k.proj", "add_k_proj") 63 | module_name = module_name.replace("add.q.proj", "add_q_proj") 64 | module_name = module_name.replace("add.v.proj", "add_v_proj") 65 | module_name = module_name.replace("add.out.proj", "add_out_proj") 66 | module_name = module_name.replace("proj.", "proj_") # fix proj 67 | module_name = module_name.replace("to.out", "to_out") # fix to_out 68 | module_name = module_name.replace("ff.context", "ff_context") # fix ff context 69 | 70 | diffusers_prefix = "transformer" 71 | if "lora_down" in key: 72 | new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight" 73 | dim = weight.shape[0] 74 | elif "lora_up" in key: 75 | new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight" 76 | dim = weight.shape[1] 77 | else: 78 | log.warning(f"unexpected key: {key} in default LoRA format") 79 | continue 80 | 81 | # scale weight by alpha 82 | if lora_name in lora_alphas: 83 | # we scale both down and up, so scale is sqrt 84 | scale = lora_alphas[lora_name] / dim 85 | scale = scale.sqrt() 86 | weight = weight * scale 87 | else: 88 | log.warning(f"missing alpha for {lora_name}") 89 | 90 | new_weights_sd[new_key] = weight 91 | 92 | return new_weights_sd 93 | 94 | --------------------------------------------------------------------------------