├── .gitattributes ├── .github └── workflows │ └── publish.yml ├── .gitignore ├── __init__.py ├── configs └── hy_vae_config.json ├── context.py ├── enhance_a_video ├── __init__.py ├── enhance.py └── globals.py ├── example_workflows ├── example_input.png ├── example_output.mp4 ├── hunhyuan_rf_inversion_testing_01.json ├── hyvideo_custom_testing_01.json ├── hyvideo_dashtoon_keyframe_example_01.json ├── hyvideo_i2v_example_01.json ├── hyvideo_i2v_example_fixed_model_02.json ├── hyvideo_ip2v_experimental_dango.json ├── hyvideo_leapfusion_img2vid_example_01.json ├── hyvideo_lowvram_blockswap_test.json ├── hyvideo_prompt_mix_experimental.json ├── hyvideo_skyreel_img2vid_example_01.json ├── hyvideo_t2v_example_01.json └── hyvideo_v2v_example_01.json ├── fp8_optimization.py ├── hunyuan_empty_prompt_embeds_dict.pt ├── hyvideo ├── __init__.py ├── config.py ├── constants.py ├── diffusion │ ├── __init__.py │ ├── pipelines │ │ ├── __init__.py │ │ └── pipeline_hunyuan_video.py │ └── schedulers │ │ ├── __init__.py │ │ ├── scheduling_dpmsolver_multistep.py │ │ ├── scheduling_flow_match_discrete.py │ │ ├── scheduling_sasolver.py │ │ └── scheduling_unipc_multistep.py ├── modules │ ├── __init__.py │ ├── activation_layers.py │ ├── attention.py │ ├── embed_layers.py │ ├── fp8_map.safetensors │ ├── fp8_optimization.py │ ├── mlp_layers.py │ ├── models.py │ ├── modulate_layers.py │ ├── norm_layers.py │ ├── posemb_layers.py │ └── token_refiner.py ├── prompt_rewrite.py ├── text_encoder │ ├── __init__.py │ ├── configuration_llava.py │ ├── modeling_llava.py │ └── processing_llava.py ├── utils │ ├── __init__.py │ ├── data_utils.py │ ├── file_utils.py │ ├── helpers.py │ ├── preprocess_text_encoder_tokenizer_utils.py │ └── token_helper.py └── vae │ ├── __init__.py │ ├── autoencoder_kl_causal_3d.py │ ├── unet_causal_3d_blocks.py │ └── vae.py ├── nodes.py ├── nodes_rf_inversion.py ├── pyproject.toml ├── readme.md ├── requirements.txt └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | # if this is a forked repository. Skipping the workflow. 16 | if: github.event.repository.fork == false 17 | steps: 18 | - name: Check out code 19 | uses: actions/checkout@v4 20 | - name: Publish Custom Node 21 | uses: Comfy-Org/publish-node-action@main 22 | with: 23 | ## Add your own personal access token to your Github Repository secrets and reference it here. 
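## A minimal sketch of one way to create that secret (assuming the GitHub CLI is installed and authenticated;
## the secret name must match the reference on the next line, and <owner> / <your Comfy Registry token> are placeholders):
##   gh secret set REGISTRY_ACCESS_TOKEN --repo <owner>/ComfyUI-HunyuanVideoWrapper --body "<your Comfy Registry token>"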
24 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | *__pycache__/ 3 | samples*/ 4 | runs/ 5 | checkpoints/ 6 | master_ip 7 | logs/ 8 | *.DS_Store 9 | .idea 10 | tools/ 11 | .vscode/ 12 | convert_* -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS as NODES_CLASS, NODE_DISPLAY_NAME_MAPPINGS as NODES_DISPLAY 2 | from .nodes_rf_inversion import NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS_RF_INVERSION, NODE_DISPLAY_NAME_MAPPINGS as NODE_DISPLAY_NAME_MAPPINGS_RF_INVERSION 3 | 4 | NODE_CLASS_MAPPINGS = {**NODES_CLASS, **NODE_CLASS_MAPPINGS_RF_INVERSION} 5 | NODE_DISPLAY_NAME_MAPPINGS = {**NODES_DISPLAY, **NODE_DISPLAY_NAME_MAPPINGS_RF_INVERSION} 6 | 7 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /configs/hy_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKLCausal3D", 3 | "_diffusers_version": "0.4.2", 4 | "act_fn": "silu", 5 | "block_out_channels": [ 6 | 128, 7 | 256, 8 | 512, 9 | 512 10 | ], 11 | "down_block_types": [ 12 | "DownEncoderBlockCausal3D", 13 | "DownEncoderBlockCausal3D", 14 | "DownEncoderBlockCausal3D", 15 | "DownEncoderBlockCausal3D" 16 | ], 17 | "in_channels": 3, 18 | "latent_channels": 16, 19 | "layers_per_block": 2, 20 | "norm_num_groups": 32, 21 | "out_channels": 3, 22 | "tile_sample_min_size": 256, 23 | "sample_tsize": 64, 24 | "up_block_types": [ 25 | "UpDecoderBlockCausal3D", 26 | "UpDecoderBlockCausal3D", 27 | "UpDecoderBlockCausal3D", 28 | "UpDecoderBlockCausal3D" 29 | ], 30 | "scaling_factor": 0.476986, 31 | "time_compression_ratio": 4, 32 | "mid_block_add_attention": true 33 | } 34 | -------------------------------------------------------------------------------- /context.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Callable, Optional, List 3 | 4 | 5 | def ordered_halving(val): 6 | bin_str = f"{val:064b}" 7 | bin_flip = bin_str[::-1] 8 | as_int = int(bin_flip, 2) 9 | 10 | return as_int / (1 << 64) 11 | 12 | def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]: 13 | prev_val = -1 14 | for i, val in enumerate(window): 15 | val = val % num_frames 16 | if val < prev_val: 17 | return True, i 18 | prev_val = val 19 | return False, -1 20 | 21 | def shift_window_to_start(window: list[int], num_frames: int): 22 | start_val = window[0] 23 | for i in range(len(window)): 24 | # 1) subtract each element by start_val to move vals relative to the start of all frames 25 | # 2) add num_frames and take modulus to get adjusted vals 26 | window[i] = ((window[i] - start_val) + num_frames) % num_frames 27 | 28 | def shift_window_to_end(window: list[int], num_frames: int): 29 | # 1) shift window to start 30 | shift_window_to_start(window, num_frames) 31 | end_val = window[-1] 32 | end_delta = num_frames - end_val - 1 33 | for i in range(len(window)): 34 | # 2) add end_delta to each val to slide windows to end 35 | window[i] = window[i] + end_delta 36 | 37 | def get_missing_indexes(windows: list[list[int]], num_frames: int) -> 
list[int]: 38 | all_indexes = list(range(num_frames)) 39 | for w in windows: 40 | for val in w: 41 | try: 42 | all_indexes.remove(val) 43 | except ValueError: 44 | pass 45 | return all_indexes 46 | 47 | def uniform_looped( 48 | step: int = ..., 49 | num_steps: Optional[int] = None, 50 | num_frames: int = ..., 51 | context_size: Optional[int] = None, 52 | context_stride: int = 3, 53 | context_overlap: int = 4, 54 | closed_loop: bool = True, 55 | ): 56 | if num_frames <= context_size: 57 | yield list(range(num_frames)) 58 | return 59 | 60 | context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) 61 | 62 | for context_step in 1 << np.arange(context_stride): 63 | pad = int(round(num_frames * ordered_halving(step))) 64 | for j in range( 65 | int(ordered_halving(step) * context_step) + pad, 66 | num_frames + pad + (0 if closed_loop else -context_overlap), 67 | (context_size * context_step - context_overlap), 68 | ): 69 | yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)] 70 | 71 | #from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) 72 | def uniform_standard( 73 | step: int = ..., 74 | num_steps: Optional[int] = None, 75 | num_frames: int = ..., 76 | context_size: Optional[int] = None, 77 | context_stride: int = 3, 78 | context_overlap: int = 4, 79 | closed_loop: bool = True, 80 | ): 81 | windows = [] 82 | if num_frames <= context_size: 83 | windows.append(list(range(num_frames))) 84 | return windows 85 | 86 | context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) 87 | 88 | for context_step in 1 << np.arange(context_stride): 89 | pad = int(round(num_frames * ordered_halving(step))) 90 | for j in range( 91 | int(ordered_halving(step) * context_step) + pad, 92 | num_frames + pad + (0 if closed_loop else -context_overlap), 93 | (context_size * context_step - context_overlap), 94 | ): 95 | windows.append([e % num_frames for e in range(j, j + context_size * context_step, context_step)]) 96 | 97 | # now that windows are created, shift any windows that loop, and delete duplicate windows 98 | delete_idxs = [] 99 | win_i = 0 100 | while win_i < len(windows): 101 | # if window is rolls over itself, need to shift it 102 | is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames) 103 | if is_roll: 104 | roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides 105 | shift_window_to_end(windows[win_i], num_frames=num_frames) 106 | # check if next window (cyclical) is missing roll_val 107 | if roll_val not in windows[(win_i+1) % len(windows)]: 108 | # need to insert new window here - just insert window starting at roll_val 109 | windows.insert(win_i+1, list(range(roll_val, roll_val + context_size))) 110 | # delete window if it's not unique 111 | for pre_i in range(0, win_i): 112 | if windows[win_i] == windows[pre_i]: 113 | delete_idxs.append(win_i) 114 | break 115 | win_i += 1 116 | 117 | # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation 118 | delete_idxs.reverse() 119 | for i in delete_idxs: 120 | windows.pop(i) 121 | return windows 122 | 123 | def static_standard( 124 | step: int = ..., 125 | num_steps: Optional[int] = None, 126 | num_frames: int = ..., 127 | context_size: Optional[int] = None, 128 | context_stride: int = 3, 129 | context_overlap: int = 4, 130 | closed_loop: bool = True, 131 | ): 132 | windows = [] 133 | if num_frames <= context_size: 134 | 
windows.append(list(range(num_frames))) 135 | return windows 136 | # always return the same set of windows 137 | delta = context_size - context_overlap 138 | for start_idx in range(0, num_frames, delta): 139 | # if past the end of frames, move start_idx back to allow same context_length 140 | ending = start_idx + context_size 141 | if ending >= num_frames: 142 | final_delta = ending - num_frames 143 | final_start_idx = start_idx - final_delta 144 | windows.append(list(range(final_start_idx, final_start_idx + context_size))) 145 | break 146 | windows.append(list(range(start_idx, start_idx + context_size))) 147 | return windows 148 | 149 | def get_context_scheduler(name: str) -> Callable: 150 | if name == "uniform_looped": 151 | return uniform_looped 152 | elif name == "uniform_standard": 153 | return uniform_standard 154 | elif name == "static_standard": 155 | return static_standard 156 | else: 157 | raise ValueError(f"Unknown context_overlap policy {name}") 158 | 159 | 160 | def get_total_steps( 161 | scheduler, 162 | timesteps: List[int], 163 | num_steps: Optional[int] = None, 164 | num_frames: int = ..., 165 | context_size: Optional[int] = None, 166 | context_stride: int = 3, 167 | context_overlap: int = 4, 168 | closed_loop: bool = True, 169 | ): 170 | return sum( 171 | len( 172 | list( 173 | scheduler( 174 | i, 175 | num_steps, 176 | num_frames, 177 | context_size, 178 | context_stride, 179 | context_overlap, 180 | ) 181 | ) 182 | ) 183 | for i in range(len(timesteps)) 184 | ) 185 | -------------------------------------------------------------------------------- /enhance_a_video/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-HunyuanVideoWrapper/3ce9640497139b14462910b0dbf2be1df855c1d6/enhance_a_video/__init__.py -------------------------------------------------------------------------------- /enhance_a_video/enhance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from .globals import get_enhance_weight, get_num_frames 4 | 5 | def get_feta_scores(query, key): 6 | img_q, img_k = query, key 7 | 8 | num_frames = get_num_frames() 9 | 10 | B, S, N, C = img_q.shape 11 | 12 | # Calculate spatial dimension 13 | spatial_dim = S // num_frames 14 | 15 | # Add time dimension between spatial and head dims 16 | query_image = img_q.reshape(B, spatial_dim, num_frames, N, C) 17 | key_image = img_k.reshape(B, spatial_dim, num_frames, N, C) 18 | 19 | # Expand time dimension 20 | query_image = query_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C] 21 | key_image = key_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C] 22 | 23 | # Reshape to match feta_score input format: [(B S) N T C] 24 | query_image = rearrange(query_image, "b s t n c -> (b s) n t c") #torch.Size([3200, 24, 5, 128]) 25 | key_image = rearrange(key_image, "b s t n c -> (b s) n t c") 26 | 27 | return feta_score(query_image, key_image, C, num_frames) 28 | 29 | def feta_score(query_image, key_image, head_dim, num_frames): 30 | scale = head_dim**-0.5 31 | query_image = query_image * scale 32 | attn_temp = query_image @ key_image.transpose(-2, -1) # translate attn to float32 33 | attn_temp = attn_temp.to(torch.float32) 34 | attn_temp = attn_temp.softmax(dim=-1) 35 | 36 | # Reshape to [batch_size * num_tokens, num_frames, num_frames] 37 | attn_temp = attn_temp.reshape(-1, num_frames, num_frames) 38 | 39 | # Create a mask for diagonal elements 40 
| diag_mask = torch.eye(num_frames, device=attn_temp.device).bool() 41 | diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1) 42 | 43 | # Zero out diagonal elements 44 | attn_wo_diag = attn_temp.masked_fill(diag_mask, 0) 45 | 46 | # Calculate mean for each token's attention matrix 47 | # Number of off-diagonal elements per matrix is n*n - n 48 | num_off_diag = num_frames * num_frames - num_frames 49 | mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag 50 | 51 | enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight()) 52 | enhance_scores = enhance_scores.clamp(min=1) 53 | return enhance_scores 54 | -------------------------------------------------------------------------------- /enhance_a_video/globals.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | NUM_FRAMES = None 4 | FETA_WEIGHT = None 5 | ENABLE_FETA_SINGLE = False 6 | ENABLE_FETA_DOUBLE = False 7 | 8 | @torch.compiler.disable() 9 | def set_num_frames(num_frames: int): 10 | global NUM_FRAMES 11 | NUM_FRAMES = num_frames 12 | 13 | @torch.compiler.disable() 14 | def get_num_frames() -> int: 15 | return NUM_FRAMES 16 | 17 | 18 | def enable_enhance(single, double): 19 | global ENABLE_FETA_SINGLE, ENABLE_FETA_DOUBLE 20 | ENABLE_FETA_SINGLE = single 21 | ENABLE_FETA_DOUBLE = double 22 | 23 | def disable_enhance(): 24 | global ENABLE_FETA_SINGLE, ENABLE_FETA_DOUBLE 25 | ENABLE_FETA_SINGLE = False 26 | ENABLE_FETA_DOUBLE = False 27 | 28 | @torch.compiler.disable() 29 | def is_enhance_enabled_single() -> bool: 30 | return ENABLE_FETA_SINGLE 31 | 32 | @torch.compiler.disable() 33 | def is_enhance_enabled_double() -> bool: 34 | return ENABLE_FETA_DOUBLE 35 | 36 | @torch.compiler.disable() 37 | def set_enhance_weight(feta_weight: float): 38 | global FETA_WEIGHT 39 | FETA_WEIGHT = feta_weight 40 | 41 | @torch.compiler.disable() 42 | def get_enhance_weight() -> float: 43 | return FETA_WEIGHT 44 | -------------------------------------------------------------------------------- /example_workflows/example_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-HunyuanVideoWrapper/3ce9640497139b14462910b0dbf2be1df855c1d6/example_workflows/example_input.png -------------------------------------------------------------------------------- /example_workflows/example_output.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-HunyuanVideoWrapper/3ce9640497139b14462910b0dbf2be1df855c1d6/example_workflows/example_output.mp4 -------------------------------------------------------------------------------- /example_workflows/hyvideo_ip2v_experimental_dango.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 73, 3 | "last_link_id": 82, 4 | "nodes": [ 5 | { 6 | "id": 3, 7 | "type": "HyVideoSampler", 8 | "pos": [ 9 | 260, 10 | -230 11 | ], 12 | "size": [ 13 | 315, 14 | 546 15 | ], 16 | "flags": {}, 17 | "order": 5, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "model", 22 | "type": "HYVIDEOMODEL", 23 | "link": 2 24 | }, 25 | { 26 | "name": "hyvid_embeds", 27 | "type": "HYVIDEMBEDS", 28 | "link": 82 29 | }, 30 | { 31 | "name": "samples", 32 | "type": "LATENT", 33 | "link": null, 34 | "shape": 7 35 | }, 36 | { 37 | "name": "stg_args", 38 | "type": "STGARGS", 39 | "link": null, 40 | "shape": 7 41 | } 42 | ], 43 | "outputs": [ 44 | { 45 | 
"name": "samples", 46 | "type": "LATENT", 47 | "links": [ 48 | 4 49 | ], 50 | "slot_index": 0 51 | } 52 | ], 53 | "properties": { 54 | "Node name for S&R": "HyVideoSampler" 55 | }, 56 | "widgets_values": [ 57 | 720, 58 | 480, 59 | 61, 60 | 30, 61 | 7.5, 62 | 7.5, 63 | 233, 64 | "fixed", 65 | true, 66 | 1 67 | ] 68 | }, 69 | { 70 | "id": 7, 71 | "type": "HyVideoVAELoader", 72 | "pos": [ 73 | -277, 74 | -284 75 | ], 76 | "size": [ 77 | 379.166748046875, 78 | 82 79 | ], 80 | "flags": {}, 81 | "order": 0, 82 | "mode": 0, 83 | "inputs": [ 84 | { 85 | "name": "compile_args", 86 | "type": "COMPILEARGS", 87 | "link": null, 88 | "shape": 7 89 | } 90 | ], 91 | "outputs": [ 92 | { 93 | "name": "vae", 94 | "type": "VAE", 95 | "links": [ 96 | 6 97 | ], 98 | "slot_index": 0 99 | } 100 | ], 101 | "properties": { 102 | "Node name for S&R": "HyVideoVAELoader" 103 | }, 104 | "widgets_values": [ 105 | "hyvid\\hunyuan_video_vae_bf16.safetensors", 106 | "bf16" 107 | ] 108 | }, 109 | { 110 | "id": 1, 111 | "type": "HyVideoModelLoader", 112 | "pos": [ 113 | -285, 114 | -94 115 | ], 116 | "size": [ 117 | 426.1773986816406, 118 | 194 119 | ], 120 | "flags": {}, 121 | "order": 1, 122 | "mode": 0, 123 | "inputs": [ 124 | { 125 | "name": "compile_args", 126 | "type": "COMPILEARGS", 127 | "link": null, 128 | "shape": 7 129 | }, 130 | { 131 | "name": "block_swap_args", 132 | "type": "BLOCKSWAPARGS", 133 | "link": null, 134 | "shape": 7 135 | }, 136 | { 137 | "name": "lora", 138 | "type": "HYVIDLORA", 139 | "link": null, 140 | "shape": 7 141 | } 142 | ], 143 | "outputs": [ 144 | { 145 | "name": "model", 146 | "type": "HYVIDEOMODEL", 147 | "links": [ 148 | 2 149 | ], 150 | "slot_index": 0 151 | } 152 | ], 153 | "properties": { 154 | "Node name for S&R": "HyVideoModelLoader" 155 | }, 156 | "widgets_values": [ 157 | "hyvideo\\hunyuan_video_720_fp8_e4m3fn.safetensors", 158 | "bf16", 159 | "fp8_e4m3fn", 160 | "offload_device", 161 | "sdpa" 162 | ] 163 | }, 164 | { 165 | "id": 71, 166 | "type": "DownloadAndLoadHyVideoTextEncoder", 167 | "pos": [ 168 | -637.5891723632812, 169 | 201.5082244873047 170 | ], 171 | "size": [ 172 | 441, 173 | 178 174 | ], 175 | "flags": {}, 176 | "order": 2, 177 | "mode": 0, 178 | "inputs": [], 179 | "outputs": [ 180 | { 181 | "name": "hyvid_text_encoder", 182 | "type": "HYVIDTEXTENCODER", 183 | "links": [ 184 | 80 185 | ], 186 | "slot_index": 0 187 | } 188 | ], 189 | "properties": { 190 | "Node name for S&R": "DownloadAndLoadHyVideoTextEncoder" 191 | }, 192 | "widgets_values": [ 193 | "xtuner/llava-llama-3-8b-v1_1-transformers", 194 | "openai/clip-vit-large-patch14", 195 | "bf16", 196 | false, 197 | 2, 198 | "disabled" 199 | ] 200 | }, 201 | { 202 | "id": 73, 203 | "type": "HyVideoTextImageEncode", 204 | "pos": [ 205 | -38.233642578125, 206 | 414.9195556640625 207 | ], 208 | "size": [ 209 | 493.3573303222656, 210 | 382.35430908203125 211 | ], 212 | "flags": {}, 213 | "order": 4, 214 | "mode": 0, 215 | "inputs": [ 216 | { 217 | "name": "text_encoders", 218 | "type": "HYVIDTEXTENCODER", 219 | "link": 80 220 | }, 221 | { 222 | "name": "custom_prompt_template", 223 | "type": "PROMPT_TEMPLATE", 224 | "link": null, 225 | "shape": 7 226 | }, 227 | { 228 | "name": "clip_l", 229 | "type": "CLIP", 230 | "link": null, 231 | "shape": 7 232 | }, 233 | { 234 | "name": "image1", 235 | "type": "IMAGE", 236 | "link": 81, 237 | "shape": 7 238 | }, 239 | { 240 | "name": "image2", 241 | "type": "IMAGE", 242 | "link": null, 243 | "shape": 7 244 | }, 245 | { 246 | "name": "hyvid_cfg", 247 | "type": "HYVID_CFG", 248 | 
"link": null, 249 | "shape": 7 250 | } 251 | ], 252 | "outputs": [ 253 | { 254 | "name": "hyvid_embeds", 255 | "type": "HYVIDEMBEDS", 256 | "links": [ 257 | 82 258 | ], 259 | "slot_index": 0 260 | } 261 | ], 262 | "properties": { 263 | "Node name for S&R": "HyVideoTextImageEncode" 264 | }, 265 | "widgets_values": [ 266 | "Astonishing promotion video of a toy movie, high quality video 4k A fluffy plushie stuffed animal of , furry fox ears, dancing on grass land with blue sky. cinematic realistic rendering ", 267 | "::3", 268 | true, 269 | "video", 270 | "" 271 | ] 272 | }, 273 | { 274 | "id": 65, 275 | "type": "LoadImage", 276 | "pos": [ 277 | -540, 278 | 530 279 | ], 280 | "size": [ 281 | 315, 282 | 314 283 | ], 284 | "flags": {}, 285 | "order": 3, 286 | "mode": 0, 287 | "inputs": [], 288 | "outputs": [ 289 | { 290 | "name": "IMAGE", 291 | "type": "IMAGE", 292 | "links": [ 293 | 81 294 | ], 295 | "slot_index": 0 296 | }, 297 | { 298 | "name": "MASK", 299 | "type": "MASK", 300 | "links": null 301 | } 302 | ], 303 | "properties": { 304 | "Node name for S&R": "LoadImage" 305 | }, 306 | "widgets_values": [ 307 | "example.png", 308 | "image" 309 | ] 310 | }, 311 | { 312 | "id": 5, 313 | "type": "HyVideoDecode", 314 | "pos": [ 315 | 690, 316 | -230 317 | ], 318 | "size": [ 319 | 345.4285888671875, 320 | 150 321 | ], 322 | "flags": {}, 323 | "order": 6, 324 | "mode": 0, 325 | "inputs": [ 326 | { 327 | "name": "vae", 328 | "type": "VAE", 329 | "link": 6 330 | }, 331 | { 332 | "name": "samples", 333 | "type": "LATENT", 334 | "link": 4 335 | } 336 | ], 337 | "outputs": [ 338 | { 339 | "name": "images", 340 | "type": "IMAGE", 341 | "links": [ 342 | 42 343 | ], 344 | "slot_index": 0 345 | } 346 | ], 347 | "properties": { 348 | "Node name for S&R": "HyVideoDecode" 349 | }, 350 | "widgets_values": [ 351 | true, 352 | 64, 353 | 128, 354 | false 355 | ] 356 | }, 357 | { 358 | "id": 34, 359 | "type": "VHS_VideoCombine", 360 | "pos": [ 361 | 660, 362 | 30 363 | ], 364 | "size": [ 365 | 580.7774658203125, 366 | 697.8516235351562 367 | ], 368 | "flags": {}, 369 | "order": 7, 370 | "mode": 0, 371 | "inputs": [ 372 | { 373 | "name": "images", 374 | "type": "IMAGE", 375 | "link": 42 376 | }, 377 | { 378 | "name": "audio", 379 | "type": "AUDIO", 380 | "link": null, 381 | "shape": 7 382 | }, 383 | { 384 | "name": "meta_batch", 385 | "type": "VHS_BatchManager", 386 | "link": null, 387 | "shape": 7 388 | }, 389 | { 390 | "name": "vae", 391 | "type": "VAE", 392 | "link": null, 393 | "shape": 7 394 | } 395 | ], 396 | "outputs": [ 397 | { 398 | "name": "Filenames", 399 | "type": "VHS_FILENAMES", 400 | "links": null 401 | } 402 | ], 403 | "properties": { 404 | "Node name for S&R": "VHS_VideoCombine" 405 | }, 406 | "widgets_values": { 407 | "frame_rate": 24, 408 | "loop_count": 0, 409 | "filename_prefix": "HunyuanVideo", 410 | "format": "video/h264-mp4", 411 | "pix_fmt": "yuv420p", 412 | "crf": 20, 413 | "save_metadata": true, 414 | "pingpong": false, 415 | "save_output": true, 416 | "videopreview": { 417 | "hidden": false, 418 | "paused": false, 419 | "params": { 420 | "filename": "HunyuanVideo_00204.mp4", 421 | "subfolder": "", 422 | "type": "output", 423 | "format": "video/h264-mp4", 424 | "frame_rate": 24 425 | }, 426 | "muted": false 427 | } 428 | } 429 | } 430 | ], 431 | "links": [ 432 | [ 433 | 2, 434 | 1, 435 | 0, 436 | 3, 437 | 0, 438 | "HYVIDEOMODEL" 439 | ], 440 | [ 441 | 4, 442 | 3, 443 | 0, 444 | 5, 445 | 1, 446 | "LATENT" 447 | ], 448 | [ 449 | 6, 450 | 7, 451 | 0, 452 | 5, 453 | 0, 454 | "VAE" 455 | ], 456 
| [ 457 | 42, 458 | 5, 459 | 0, 460 | 34, 461 | 0, 462 | "IMAGE" 463 | ], 464 | [ 465 | 80, 466 | 71, 467 | 0, 468 | 73, 469 | 0, 470 | "HYVIDTEXTENCODER" 471 | ], 472 | [ 473 | 81, 474 | 65, 475 | 0, 476 | 73, 477 | 3, 478 | "IMAGE" 479 | ], 480 | [ 481 | 82, 482 | 73, 483 | 0, 484 | 3, 485 | 1, 486 | "HYVIDEMBEDS" 487 | ] 488 | ], 489 | "groups": [], 490 | "config": {}, 491 | "extra": { 492 | "ds": { 493 | "scale": 0.895430243255319, 494 | "offset": [ 495 | 869.3809790113617, 496 | 317.12236180775994 497 | ] 498 | }, 499 | "workspace_info": { 500 | "id": "kZ4q7BpZY-s3NIJ0k8OPz" 501 | } 502 | }, 503 | "version": 0.4 504 | } -------------------------------------------------------------------------------- /example_workflows/hyvideo_lowvram_blockswap_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 40, 3 | "last_link_id": 44, 4 | "nodes": [ 5 | { 6 | "id": 7, 7 | "type": "HyVideoVAELoader", 8 | "pos": [ 9 | 442, 10 | -282 11 | ], 12 | "size": [ 13 | 379.166748046875, 14 | 82 15 | ], 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "compile_args", 22 | "type": "COMPILEARGS", 23 | "link": null, 24 | "shape": 7 25 | } 26 | ], 27 | "outputs": [ 28 | { 29 | "name": "vae", 30 | "type": "VAE", 31 | "links": [ 32 | 6 33 | ], 34 | "slot_index": 0 35 | } 36 | ], 37 | "properties": { 38 | "Node name for S&R": "HyVideoVAELoader" 39 | }, 40 | "widgets_values": [ 41 | "hyvid\\hunyuan_video_vae_bf16.safetensors", 42 | "bf16" 43 | ] 44 | }, 45 | { 46 | "id": 3, 47 | "type": "HyVideoSampler", 48 | "pos": [ 49 | 668, 50 | -62 51 | ], 52 | "size": [ 53 | 315, 54 | 334 55 | ], 56 | "flags": {}, 57 | "order": 10, 58 | "mode": 0, 59 | "inputs": [ 60 | { 61 | "name": "model", 62 | "type": "HYVIDEOMODEL", 63 | "link": 2 64 | }, 65 | { 66 | "name": "hyvid_embeds", 67 | "type": "HYVIDEMBEDS", 68 | "link": 36 69 | }, 70 | { 71 | "name": "samples", 72 | "type": "LATENT", 73 | "link": null, 74 | "shape": 7 75 | }, 76 | { 77 | "name": "stg_args", 78 | "type": "STGARGS", 79 | "link": null, 80 | "shape": 7 81 | } 82 | ], 83 | "outputs": [ 84 | { 85 | "name": "samples", 86 | "type": "LATENT", 87 | "links": [ 88 | 4 89 | ], 90 | "slot_index": 0 91 | } 92 | ], 93 | "properties": { 94 | "Node name for S&R": "HyVideoSampler" 95 | }, 96 | "widgets_values": [ 97 | 512, 98 | 512, 99 | 33, 100 | 20, 101 | 6, 102 | 9, 103 | 2, 104 | "fixed", 105 | 1, 106 | 1 107 | ] 108 | }, 109 | { 110 | "id": 35, 111 | "type": "HyVideoBlockSwap", 112 | "pos": [ 113 | -351, 114 | -44 115 | ], 116 | "size": [ 117 | 315, 118 | 130 119 | ], 120 | "flags": {}, 121 | "order": 1, 122 | "mode": 0, 123 | "inputs": [], 124 | "outputs": [ 125 | { 126 | "name": "block_swap_args", 127 | "type": "BLOCKSWAPARGS", 128 | "links": [ 129 | 43 130 | ] 131 | } 132 | ], 133 | "properties": { 134 | "Node name for S&R": "HyVideoBlockSwap" 135 | }, 136 | "widgets_values": [ 137 | 20, 138 | 0, 139 | true, 140 | true 141 | ] 142 | }, 143 | { 144 | "id": 16, 145 | "type": "DownloadAndLoadHyVideoTextEncoder", 146 | "pos": [ 147 | -300, 148 | 380 149 | ], 150 | "size": [ 151 | 441, 152 | 178 153 | ], 154 | "flags": {}, 155 | "order": 2, 156 | "mode": 0, 157 | "inputs": [], 158 | "outputs": [ 159 | { 160 | "name": "hyvid_text_encoder", 161 | "type": "HYVIDTEXTENCODER", 162 | "links": [ 163 | 35 164 | ] 165 | } 166 | ], 167 | "properties": { 168 | "Node name for S&R": "DownloadAndLoadHyVideoTextEncoder" 169 | }, 170 | "widgets_values": [ 171 | 
"Kijai/llava-llama-3-8b-text-encoder-tokenizer", 172 | "openai/clip-vit-large-patch14", 173 | "fp16", 174 | false, 175 | 2, 176 | "bnb_nf4" 177 | ] 178 | }, 179 | { 180 | "id": 30, 181 | "type": "HyVideoTextEncode", 182 | "pos": [ 183 | 210, 184 | 380 185 | ], 186 | "size": [ 187 | 400, 188 | 200 189 | ], 190 | "flags": {}, 191 | "order": 9, 192 | "mode": 0, 193 | "inputs": [ 194 | { 195 | "name": "text_encoders", 196 | "type": "HYVIDTEXTENCODER", 197 | "link": 35 198 | }, 199 | { 200 | "name": "custom_prompt_template", 201 | "type": "PROMPT_TEMPLATE", 202 | "link": null, 203 | "shape": 7 204 | }, 205 | { 206 | "name": "clip_l", 207 | "type": "CLIP", 208 | "link": null, 209 | "shape": 7 210 | } 211 | ], 212 | "outputs": [ 213 | { 214 | "name": "hyvid_embeds", 215 | "type": "HYVIDEMBEDS", 216 | "links": [ 217 | 36 218 | ] 219 | } 220 | ], 221 | "properties": { 222 | "Node name for S&R": "HyVideoTextEncode" 223 | }, 224 | "widgets_values": [ 225 | "high quality anime style movie featuring a wolf in a forest", 226 | "bad quality video", 227 | "video" 228 | ] 229 | }, 230 | { 231 | "id": 36, 232 | "type": "Note", 233 | "pos": [ 234 | -140, 235 | 600 236 | ], 237 | "size": [ 238 | 262.6133117675781, 239 | 105.17330932617188 240 | ], 241 | "flags": {}, 242 | "order": 3, 243 | "mode": 0, 244 | "inputs": [], 245 | "outputs": [], 246 | "properties": {}, 247 | "widgets_values": [ 248 | "bnb_nf4 requires bitsandbytes installed, it is necessary if you don't have VRAM to run the text encoding, it will not affect VRAM usage during video sampling" 249 | ], 250 | "color": "#432", 251 | "bgcolor": "#653" 252 | }, 253 | { 254 | "id": 38, 255 | "type": "Note", 256 | "pos": [ 257 | 241.49647521972656, 258 | 163.09042358398438 259 | ], 260 | "size": [ 261 | 313.94659423828125, 262 | 76.57331085205078 263 | ], 264 | "flags": {}, 265 | "order": 4, 266 | "mode": 0, 267 | "inputs": [], 268 | "outputs": [], 269 | "properties": {}, 270 | "widgets_values": [ 271 | "If you have sageattention working, use that for fastest sampling." 272 | ], 273 | "color": "#432", 274 | "bgcolor": "#653" 275 | }, 276 | { 277 | "id": 1, 278 | "type": "HyVideoModelLoader", 279 | "pos": [ 280 | 24, 281 | -63 282 | ], 283 | "size": [ 284 | 509.7506103515625, 285 | 178 286 | ], 287 | "flags": {}, 288 | "order": 8, 289 | "mode": 0, 290 | "inputs": [ 291 | { 292 | "name": "compile_args", 293 | "type": "COMPILEARGS", 294 | "link": null, 295 | "shape": 7 296 | }, 297 | { 298 | "name": "block_swap_args", 299 | "type": "BLOCKSWAPARGS", 300 | "link": 43, 301 | "shape": 7 302 | } 303 | ], 304 | "outputs": [ 305 | { 306 | "name": "model", 307 | "type": "HYVIDEOMODEL", 308 | "links": [ 309 | 2 310 | ], 311 | "slot_index": 0 312 | } 313 | ], 314 | "properties": { 315 | "Node name for S&R": "HyVideoModelLoader" 316 | }, 317 | "widgets_values": [ 318 | "hyvideo\\hunyuan_video_720_fp8_e4m3fn.safetensors", 319 | "bf16", 320 | "fp8_e4m3fn", 321 | "offload_device", 322 | "sdpa" 323 | ] 324 | }, 325 | { 326 | "id": 37, 327 | "type": "Note", 328 | "pos": [ 329 | -321.4975280761719, 330 | -211.55799865722656 331 | ], 332 | "size": [ 333 | 262.6133117675781, 334 | 105.17330932617188 335 | ], 336 | "flags": {}, 337 | "order": 5, 338 | "mode": 0, 339 | "inputs": [], 340 | "outputs": [], 341 | "properties": {}, 342 | "widgets_values": [ 343 | "To further reduce VRAM usage, you can also enable swapping for up to 40 single blocks. More you swap, slower it gets though." 
344 | ], 345 | "color": "#432", 346 | "bgcolor": "#653" 347 | }, 348 | { 349 | "id": 39, 350 | "type": "HyVideoTorchCompileSettings", 351 | "pos": [ 352 | -353.23699951171875, 353 | -576.843017578125 354 | ], 355 | "size": [ 356 | 441, 357 | 274 358 | ], 359 | "flags": {}, 360 | "order": 6, 361 | "mode": 0, 362 | "inputs": [], 363 | "outputs": [ 364 | { 365 | "name": "torch_compile_args", 366 | "type": "COMPILEARGS", 367 | "links": [] 368 | } 369 | ], 370 | "properties": { 371 | "Node name for S&R": "HyVideoTorchCompileSettings" 372 | }, 373 | "widgets_values": [ 374 | "inductor", 375 | false, 376 | "default", 377 | false, 378 | 64, 379 | true, 380 | true, 381 | true, 382 | true, 383 | true 384 | ] 385 | }, 386 | { 387 | "id": 40, 388 | "type": "Note", 389 | "pos": [ 390 | 113.16300201416016, 391 | -577.575927734375 392 | ], 393 | "size": [ 394 | 262.6133117675781, 395 | 105.17330932617188 396 | ], 397 | "flags": {}, 398 | "order": 7, 399 | "mode": 0, 400 | "inputs": [], 401 | "outputs": [], 402 | "properties": {}, 403 | "widgets_values": [ 404 | "If you have working Triton installation, compilation can further reduce VRAM use, granted you can run the actual compilation process. To enable compilation, connect this to the model loader" 405 | ], 406 | "color": "#432", 407 | "bgcolor": "#653" 408 | }, 409 | { 410 | "id": 34, 411 | "type": "VHS_VideoCombine", 412 | "pos": [ 413 | 1367, 414 | -275 415 | ], 416 | "size": [ 417 | 371.7926940917969, 418 | 310 419 | ], 420 | "flags": {}, 421 | "order": 12, 422 | "mode": 0, 423 | "inputs": [ 424 | { 425 | "name": "images", 426 | "type": "IMAGE", 427 | "link": 42 428 | }, 429 | { 430 | "name": "audio", 431 | "type": "AUDIO", 432 | "link": null, 433 | "shape": 7 434 | }, 435 | { 436 | "name": "meta_batch", 437 | "type": "VHS_BatchManager", 438 | "link": null, 439 | "shape": 7 440 | }, 441 | { 442 | "name": "vae", 443 | "type": "VAE", 444 | "link": null, 445 | "shape": 7 446 | } 447 | ], 448 | "outputs": [ 449 | { 450 | "name": "Filenames", 451 | "type": "VHS_FILENAMES", 452 | "links": null 453 | } 454 | ], 455 | "properties": { 456 | "Node name for S&R": "VHS_VideoCombine" 457 | }, 458 | "widgets_values": { 459 | "frame_rate": 24, 460 | "loop_count": 0, 461 | "filename_prefix": "HunyuanVideo", 462 | "format": "video/h264-mp4", 463 | "pix_fmt": "yuv420p", 464 | "crf": 19, 465 | "save_metadata": true, 466 | "pingpong": false, 467 | "save_output": true, 468 | "videopreview": { 469 | "hidden": false, 470 | "paused": false, 471 | "params": { 472 | "filename": "HunyuanVideo_00059.mp4", 473 | "subfolder": "", 474 | "type": "temp", 475 | "format": "video/h264-mp4", 476 | "frame_rate": 16 477 | }, 478 | "muted": false 479 | } 480 | } 481 | }, 482 | { 483 | "id": 5, 484 | "type": "HyVideoDecode", 485 | "pos": [ 486 | 920, 487 | -279 488 | ], 489 | "size": [ 490 | 345.4285888671875, 491 | 150 492 | ], 493 | "flags": {}, 494 | "order": 11, 495 | "mode": 0, 496 | "inputs": [ 497 | { 498 | "name": "vae", 499 | "type": "VAE", 500 | "link": 6 501 | }, 502 | { 503 | "name": "samples", 504 | "type": "LATENT", 505 | "link": 4 506 | } 507 | ], 508 | "outputs": [ 509 | { 510 | "name": "images", 511 | "type": "IMAGE", 512 | "links": [ 513 | 42 514 | ], 515 | "slot_index": 0 516 | } 517 | ], 518 | "properties": { 519 | "Node name for S&R": "HyVideoDecode" 520 | }, 521 | "widgets_values": [ 522 | true, 523 | 64, 524 | 128, 525 | false 526 | ] 527 | } 528 | ], 529 | "links": [ 530 | [ 531 | 2, 532 | 1, 533 | 0, 534 | 3, 535 | 0, 536 | "HYVIDEOMODEL" 537 | ], 538 | [ 539 | 4, 
540 | 3, 541 | 0, 542 | 5, 543 | 1, 544 | "LATENT" 545 | ], 546 | [ 547 | 6, 548 | 7, 549 | 0, 550 | 5, 551 | 0, 552 | "VAE" 553 | ], 554 | [ 555 | 35, 556 | 16, 557 | 0, 558 | 30, 559 | 0, 560 | "HYVIDTEXTENCODER" 561 | ], 562 | [ 563 | 36, 564 | 30, 565 | 0, 566 | 3, 567 | 1, 568 | "HYVIDEMBEDS" 569 | ], 570 | [ 571 | 42, 572 | 5, 573 | 0, 574 | 34, 575 | 0, 576 | "IMAGE" 577 | ], 578 | [ 579 | 43, 580 | 35, 581 | 0, 582 | 1, 583 | 1, 584 | "BLOCKSWAPARGS" 585 | ] 586 | ], 587 | "groups": [], 588 | "config": {}, 589 | "extra": { 590 | "ds": { 591 | "scale": 1, 592 | "offset": [ 593 | 499.8829252621208, 594 | 382.71199061362694 595 | ] 596 | } 597 | }, 598 | "version": 0.4 599 | } -------------------------------------------------------------------------------- /example_workflows/hyvideo_prompt_mix_experimental.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 69, 3 | "last_link_id": 105, 4 | "nodes": [ 5 | { 6 | "id": 7, 7 | "type": "HyVideoVAELoader", 8 | "pos": [ 9 | -275.3620910644531, 10 | -360.9834899902344 11 | ], 12 | "size": [ 13 | 379.166748046875, 14 | 82 15 | ], 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "compile_args", 22 | "type": "COMPILEARGS", 23 | "link": null, 24 | "shape": 7 25 | } 26 | ], 27 | "outputs": [ 28 | { 29 | "name": "vae", 30 | "type": "VAE", 31 | "links": [ 32 | 6 33 | ], 34 | "slot_index": 0 35 | } 36 | ], 37 | "properties": { 38 | "Node name for S&R": "HyVideoVAELoader" 39 | }, 40 | "widgets_values": [ 41 | "hyvid/hunyuan_video_vae_bf16.safetensors", 42 | "bf16" 43 | ] 44 | }, 45 | { 46 | "id": 5, 47 | "type": "HyVideoDecode", 48 | "pos": [ 49 | 651, 50 | -285 51 | ], 52 | "size": [ 53 | 345.4285888671875, 54 | 150 55 | ], 56 | "flags": {}, 57 | "order": 8, 58 | "mode": 0, 59 | "inputs": [ 60 | { 61 | "name": "vae", 62 | "type": "VAE", 63 | "link": 6 64 | }, 65 | { 66 | "name": "samples", 67 | "type": "LATENT", 68 | "link": 105 69 | } 70 | ], 71 | "outputs": [ 72 | { 73 | "name": "images", 74 | "type": "IMAGE", 75 | "links": [ 76 | 42 77 | ], 78 | "slot_index": 0 79 | } 80 | ], 81 | "properties": { 82 | "Node name for S&R": "HyVideoDecode" 83 | }, 84 | "widgets_values": [ 85 | true, 86 | 64, 87 | 128, 88 | false 89 | ] 90 | }, 91 | { 92 | "id": 1, 93 | "type": "HyVideoModelLoader", 94 | "pos": [ 95 | -263.7066345214844, 96 | -193.9147491455078 97 | ], 98 | "size": [ 99 | 426.1773986816406, 100 | 194 101 | ], 102 | "flags": {}, 103 | "order": 4, 104 | "mode": 0, 105 | "inputs": [ 106 | { 107 | "name": "compile_args", 108 | "type": "COMPILEARGS", 109 | "link": 72, 110 | "shape": 7 111 | }, 112 | { 113 | "name": "block_swap_args", 114 | "type": "BLOCKSWAPARGS", 115 | "link": null, 116 | "shape": 7 117 | }, 118 | { 119 | "name": "lora", 120 | "type": "HYVIDLORA", 121 | "link": null, 122 | "shape": 7 123 | } 124 | ], 125 | "outputs": [ 126 | { 127 | "name": "model", 128 | "type": "HYVIDEOMODEL", 129 | "links": [ 130 | 101 131 | ], 132 | "slot_index": 0 133 | } 134 | ], 135 | "properties": { 136 | "Node name for S&R": "HyVideoModelLoader" 137 | }, 138 | "widgets_values": [ 139 | "hyvideo/hunyuan_video_720_fp8_e4m3fn.safetensors", 140 | "bf16", 141 | "fp8_e4m3fn_fast", 142 | "offload_device", 143 | "sageattn_varlen" 144 | ] 145 | }, 146 | { 147 | "id": 55, 148 | "type": "HyVideoTorchCompileSettings", 149 | "pos": [ 150 | -856.2830200195312, 151 | -202.7902069091797 152 | ], 153 | "size": [ 154 | 441, 155 | 274 156 | ], 157 | "flags": {}, 158 | "order": 1, 
159 | "mode": 0, 160 | "inputs": [], 161 | "outputs": [ 162 | { 163 | "name": "torch_compile_args", 164 | "type": "COMPILEARGS", 165 | "links": [ 166 | 72 167 | ] 168 | } 169 | ], 170 | "properties": { 171 | "Node name for S&R": "HyVideoTorchCompileSettings" 172 | }, 173 | "widgets_values": [ 174 | "inductor", 175 | false, 176 | "default", 177 | false, 178 | 64, 179 | true, 180 | true, 181 | false, 182 | false, 183 | false 184 | ] 185 | }, 186 | { 187 | "id": 34, 188 | "type": "VHS_VideoCombine", 189 | "pos": [ 190 | 1090.357421875, 191 | -361.1742858886719 192 | ], 193 | "size": [ 194 | 704.4984741210938, 195 | 780.3323364257812 196 | ], 197 | "flags": {}, 198 | "order": 9, 199 | "mode": 0, 200 | "inputs": [ 201 | { 202 | "name": "images", 203 | "type": "IMAGE", 204 | "link": 42 205 | }, 206 | { 207 | "name": "audio", 208 | "type": "AUDIO", 209 | "link": null, 210 | "shape": 7 211 | }, 212 | { 213 | "name": "meta_batch", 214 | "type": "VHS_BatchManager", 215 | "link": null, 216 | "shape": 7 217 | }, 218 | { 219 | "name": "vae", 220 | "type": "VAE", 221 | "link": null, 222 | "shape": 7 223 | } 224 | ], 225 | "outputs": [ 226 | { 227 | "name": "Filenames", 228 | "type": "VHS_FILENAMES", 229 | "links": null 230 | } 231 | ], 232 | "properties": { 233 | "Node name for S&R": "VHS_VideoCombine" 234 | }, 235 | "widgets_values": { 236 | "frame_rate": 24, 237 | "loop_count": 0, 238 | "filename_prefix": "HunyuanVideo", 239 | "format": "video/h264-mp4", 240 | "pix_fmt": "yuv420p", 241 | "crf": 19, 242 | "save_metadata": true, 243 | "pingpong": false, 244 | "save_output": true, 245 | "videopreview": { 246 | "hidden": false, 247 | "paused": false, 248 | "params": { 249 | "filename": "HunyuanVideo_00144.mp4", 250 | "subfolder": "", 251 | "type": "output", 252 | "format": "video/h264-mp4", 253 | "frame_rate": 24 254 | }, 255 | "muted": false 256 | } 257 | } 258 | }, 259 | { 260 | "id": 54, 261 | "type": "SplineEditor", 262 | "pos": [ 263 | 256.6975402832031, 264 | 415.37872314453125 265 | ], 266 | "size": [ 267 | 557, 268 | 942 269 | ], 270 | "flags": {}, 271 | "order": 2, 272 | "mode": 0, 273 | "inputs": [ 274 | { 275 | "name": "bg_image", 276 | "type": "IMAGE", 277 | "link": null, 278 | "shape": 7 279 | } 280 | ], 281 | "outputs": [ 282 | { 283 | "name": "mask", 284 | "type": "MASK", 285 | "links": null 286 | }, 287 | { 288 | "name": "coord_str", 289 | "type": "STRING", 290 | "links": null 291 | }, 292 | { 293 | "name": "float", 294 | "type": "FLOAT", 295 | "links": [ 296 | 104 297 | ], 298 | "slot_index": 2 299 | }, 300 | { 301 | "name": "count", 302 | "type": "INT", 303 | "links": null, 304 | "slot_index": 3 305 | }, 306 | { 307 | "name": "normalized_str", 308 | "type": "STRING", 309 | "links": null 310 | } 311 | ], 312 | "properties": { 313 | "Node name for S&R": "SplineEditor", 314 | "points": "SplineEditor" 315 | }, 316 | "widgets_values": [ 317 | "[{\"x\":0,\"y\":512},{\"x\":227.27610724891395,\"y\":89.92939495460624},{\"x\":512,\"y\":0}]", 318 | 
"[{\"x\":0.8097066879272461,\"y\":510.08172607421875},{\"x\":31.270395278930664,\"y\":439.3046569824219},{\"x\":64.21856689453125,\"y\":366.2564392089844},{\"x\":96.14781951904297,\"y\":299.5324401855469},{\"x\":127.18185424804688,\"y\":239.24691772460938},{\"x\":160.6988525390625,\"y\":180.31076049804688},{\"x\":191.04666137695312,\"y\":133.8830108642578},{\"x\":224.03964233398438,\"y\":93.23545837402344},{\"x\":256.7614440917969,\"y\":65.20995330810547},{\"x\":288.4270935058594,\"y\":46.15709686279297},{\"x\":319.38470458984375,\"y\":32.372684478759766},{\"x\":351.5347900390625,\"y\":21.644088745117188},{\"x\":384.4419250488281,\"y\":13.514848709106445},{\"x\":416.2929992675781,\"y\":7.808127403259277},{\"x\":448.40399169921875,\"y\":3.8050320148468018},{\"x\":479.1258850097656,\"y\":1.3530210256576538},{\"x\":511.4590759277344,\"y\":0.012910770252346992}]", 319 | 512, 320 | 512, 321 | 17, 322 | "time", 323 | "cardinal", 324 | 0.5, 325 | 1, 326 | "list", 327 | 0, 328 | 1, 329 | null, 330 | null, 331 | null 332 | ] 333 | }, 334 | { 335 | "id": 69, 336 | "type": "HyVideoPromptMixSampler", 337 | "pos": [ 338 | 266.9952392578125, 339 | -224.5370330810547 340 | ], 341 | "size": [ 342 | 315, 343 | 566 344 | ], 345 | "flags": {}, 346 | "order": 7, 347 | "mode": 0, 348 | "inputs": [ 349 | { 350 | "name": "model", 351 | "type": "HYVIDEOMODEL", 352 | "link": 101 353 | }, 354 | { 355 | "name": "hyvid_embeds", 356 | "type": "HYVIDEMBEDS", 357 | "link": 102 358 | }, 359 | { 360 | "name": "hyvid_embeds_2", 361 | "type": "HYVIDEMBEDS", 362 | "link": 103 363 | }, 364 | { 365 | "name": "interpolation_curve", 366 | "type": "FLOAT", 367 | "link": 104, 368 | "widget": { 369 | "name": "interpolation_curve" 370 | }, 371 | "shape": 7 372 | } 373 | ], 374 | "outputs": [ 375 | { 376 | "name": "samples", 377 | "type": "LATENT", 378 | "links": [ 379 | 105 380 | ] 381 | } 382 | ], 383 | "properties": { 384 | "Node name for S&R": "HyVideoPromptMixSampler" 385 | }, 386 | "widgets_values": [ 387 | 720, 388 | 480, 389 | 65, 390 | 20, 391 | 6, 392 | 9, 393 | true, 394 | 5, 395 | "fixed", 396 | 0.7000000000000001, 397 | 0 398 | ] 399 | }, 400 | { 401 | "id": 16, 402 | "type": "DownloadAndLoadHyVideoTextEncoder", 403 | "pos": [ 404 | -1210, 405 | 300 406 | ], 407 | "size": [ 408 | 441, 409 | 178 410 | ], 411 | "flags": {}, 412 | "order": 3, 413 | "mode": 0, 414 | "inputs": [], 415 | "outputs": [ 416 | { 417 | "name": "hyvid_text_encoder", 418 | "type": "HYVIDTEXTENCODER", 419 | "links": [ 420 | 35, 421 | 54 422 | ], 423 | "slot_index": 0 424 | } 425 | ], 426 | "properties": { 427 | "Node name for S&R": "DownloadAndLoadHyVideoTextEncoder" 428 | }, 429 | "widgets_values": [ 430 | "Kijai/llava-llama-3-8b-text-encoder-tokenizer", 431 | "openai/clip-vit-large-patch14", 432 | "fp16", 433 | false, 434 | 2, 435 | "disabled" 436 | ] 437 | }, 438 | { 439 | "id": 30, 440 | "type": "HyVideoTextEncode", 441 | "pos": [ 442 | -640, 443 | 300 444 | ], 445 | "size": [ 446 | 415.4671630859375, 447 | 232.31947326660156 448 | ], 449 | "flags": {}, 450 | "order": 5, 451 | "mode": 0, 452 | "inputs": [ 453 | { 454 | "name": "text_encoders", 455 | "type": "HYVIDTEXTENCODER", 456 | "link": 35 457 | }, 458 | { 459 | "name": "custom_prompt_template", 460 | "type": "PROMPT_TEMPLATE", 461 | "link": null, 462 | "shape": 7 463 | }, 464 | { 465 | "name": "clip_l", 466 | "type": "CLIP", 467 | "link": null, 468 | "shape": 7 469 | }, 470 | { 471 | "name": "hyvid_cfg", 472 | "type": "HYVID_CFG", 473 | "link": null, 474 | "shape": 7 475 | } 476 | ], 477 | 
"outputs": [ 478 | { 479 | "name": "hyvid_embeds", 480 | "type": "HYVIDEMBEDS", 481 | "links": [ 482 | 102 483 | ], 484 | "slot_index": 0 485 | } 486 | ], 487 | "properties": { 488 | "Node name for S&R": "HyVideoTextEncode" 489 | }, 490 | "widgets_values": [ 491 | "high quality nature video of a red panda balancing on a bamboo stick while a bird lands on the panda's head, there's a waterfall in the background", 492 | true, 493 | "video" 494 | ] 495 | }, 496 | { 497 | "id": 43, 498 | "type": "HyVideoTextEncode", 499 | "pos": [ 500 | -640, 501 | 590 502 | ], 503 | "size": [ 504 | 435.94158935546875, 505 | 286.3716125488281 506 | ], 507 | "flags": {}, 508 | "order": 6, 509 | "mode": 0, 510 | "inputs": [ 511 | { 512 | "name": "text_encoders", 513 | "type": "HYVIDTEXTENCODER", 514 | "link": 54 515 | }, 516 | { 517 | "name": "custom_prompt_template", 518 | "type": "PROMPT_TEMPLATE", 519 | "link": null, 520 | "shape": 7 521 | }, 522 | { 523 | "name": "clip_l", 524 | "type": "CLIP", 525 | "link": null, 526 | "shape": 7 527 | }, 528 | { 529 | "name": "hyvid_cfg", 530 | "type": "HYVID_CFG", 531 | "link": null, 532 | "shape": 7 533 | } 534 | ], 535 | "outputs": [ 536 | { 537 | "name": "hyvid_embeds", 538 | "type": "HYVIDEMBEDS", 539 | "links": [ 540 | 103 541 | ], 542 | "slot_index": 0 543 | } 544 | ], 545 | "properties": { 546 | "Node name for S&R": "HyVideoTextEncode" 547 | }, 548 | "widgets_values": [ 549 | "high quality nature video of a cat balancing on a bamboo stick while a bird lands on the cat's head, there's a waterfall in the background", 550 | "bad quality video", 551 | "video" 552 | ] 553 | } 554 | ], 555 | "links": [ 556 | [ 557 | 6, 558 | 7, 559 | 0, 560 | 5, 561 | 0, 562 | "VAE" 563 | ], 564 | [ 565 | 35, 566 | 16, 567 | 0, 568 | 30, 569 | 0, 570 | "HYVIDTEXTENCODER" 571 | ], 572 | [ 573 | 42, 574 | 5, 575 | 0, 576 | 34, 577 | 0, 578 | "IMAGE" 579 | ], 580 | [ 581 | 54, 582 | 16, 583 | 0, 584 | 43, 585 | 0, 586 | "HYVIDTEXTENCODER" 587 | ], 588 | [ 589 | 72, 590 | 55, 591 | 0, 592 | 1, 593 | 0, 594 | "COMPILEARGS" 595 | ], 596 | [ 597 | 101, 598 | 1, 599 | 0, 600 | 69, 601 | 0, 602 | "HYVIDEOMODEL" 603 | ], 604 | [ 605 | 102, 606 | 30, 607 | 0, 608 | 69, 609 | 1, 610 | "HYVIDEMBEDS" 611 | ], 612 | [ 613 | 103, 614 | 43, 615 | 0, 616 | 69, 617 | 2, 618 | "HYVIDEMBEDS" 619 | ], 620 | [ 621 | 104, 622 | 54, 623 | 2, 624 | 69, 625 | 3, 626 | "FLOAT" 627 | ], 628 | [ 629 | 105, 630 | 69, 631 | 0, 632 | 5, 633 | 1, 634 | "LATENT" 635 | ] 636 | ], 637 | "groups": [], 638 | "config": {}, 639 | "extra": { 640 | "ds": { 641 | "scale": 0.6727499949325655, 642 | "offset": [ 643 | 1357.072821515262, 644 | 431.32537108482603 645 | ] 646 | } 647 | }, 648 | "version": 0.4 649 | } -------------------------------------------------------------------------------- /example_workflows/hyvideo_t2v_example_01.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 40, 3 | "last_link_id": 50, 4 | "nodes": [ 5 | { 6 | "id": 1, 7 | "type": "HyVideoModelLoader", 8 | "pos": [ 9 | -285, 10 | -94 11 | ], 12 | "size": [ 13 | 426.1773986816406, 14 | 174 15 | ], 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "compile_args", 22 | "type": "COMPILEARGS", 23 | "link": null, 24 | "shape": 7 25 | }, 26 | { 27 | "name": "block_swap_args", 28 | "type": "BLOCKSWAPARGS", 29 | "link": null, 30 | "shape": 7 31 | } 32 | ], 33 | "outputs": [ 34 | { 35 | "name": "model", 36 | "type": "HYVIDEOMODEL", 37 | "links": [ 38 | 2 39 | ], 40 | 
"slot_index": 0 41 | } 42 | ], 43 | "properties": { 44 | "Node name for S&R": "HyVideoModelLoader" 45 | }, 46 | "widgets_values": [ 47 | "hyvideo\\hunyuan_video_720_fp8_e4m3fn.safetensors", 48 | "bf16", 49 | "fp8_e4m3fn", 50 | "offload_device", 51 | "sdpa" 52 | ] 53 | }, 54 | { 55 | "id": 3, 56 | "type": "HyVideoSampler", 57 | "pos": [ 58 | 266, 59 | -141 60 | ], 61 | "size": [ 62 | 315, 63 | 334 64 | ], 65 | "flags": {}, 66 | "order": 4, 67 | "mode": 0, 68 | "inputs": [ 69 | { 70 | "name": "model", 71 | "type": "HYVIDEOMODEL", 72 | "link": 2 73 | }, 74 | { 75 | "name": "hyvid_embeds", 76 | "type": "HYVIDEMBEDS", 77 | "link": 36 78 | }, 79 | { 80 | "name": "samples", 81 | "type": "LATENT", 82 | "link": null, 83 | "shape": 7 84 | }, 85 | { 86 | "name": "stg_args", 87 | "type": "STGARGS", 88 | "link": null, 89 | "shape": 7 90 | } 91 | ], 92 | "outputs": [ 93 | { 94 | "name": "samples", 95 | "type": "LATENT", 96 | "links": [ 97 | 4 98 | ], 99 | "slot_index": 0 100 | } 101 | ], 102 | "properties": { 103 | "Node name for S&R": "HyVideoSampler" 104 | }, 105 | "widgets_values": [ 106 | 512, 107 | 320, 108 | 85, 109 | 30, 110 | 6, 111 | 9, 112 | 6, 113 | "fixed", 114 | 1, 115 | 1 116 | ] 117 | }, 118 | { 119 | "id": 7, 120 | "type": "HyVideoVAELoader", 121 | "pos": [ 122 | -277, 123 | -284 124 | ], 125 | "size": [ 126 | 379.166748046875, 127 | 82 128 | ], 129 | "flags": {}, 130 | "order": 1, 131 | "mode": 0, 132 | "inputs": [ 133 | { 134 | "name": "compile_args", 135 | "type": "COMPILEARGS", 136 | "link": null, 137 | "shape": 7 138 | } 139 | ], 140 | "outputs": [ 141 | { 142 | "name": "vae", 143 | "type": "VAE", 144 | "links": [ 145 | 6 146 | ], 147 | "slot_index": 0 148 | } 149 | ], 150 | "properties": { 151 | "Node name for S&R": "HyVideoVAELoader" 152 | }, 153 | "widgets_values": [ 154 | "hyvid\\hunyuan_video_vae_bf16.safetensors", 155 | "bf16" 156 | ] 157 | }, 158 | { 159 | "id": 16, 160 | "type": "DownloadAndLoadHyVideoTextEncoder", 161 | "pos": [ 162 | -312, 163 | 243 164 | ], 165 | "size": [ 166 | 441, 167 | 178 168 | ], 169 | "flags": {}, 170 | "order": 2, 171 | "mode": 0, 172 | "inputs": [], 173 | "outputs": [ 174 | { 175 | "name": "hyvid_text_encoder", 176 | "type": "HYVIDTEXTENCODER", 177 | "links": [ 178 | 35 179 | ] 180 | } 181 | ], 182 | "properties": { 183 | "Node name for S&R": "DownloadAndLoadHyVideoTextEncoder" 184 | }, 185 | "widgets_values": [ 186 | "Kijai/llava-llama-3-8b-text-encoder-tokenizer", 187 | "openai/clip-vit-large-patch14", 188 | "fp16", 189 | false, 190 | 2, 191 | "disabled" 192 | ] 193 | }, 194 | { 195 | "id": 30, 196 | "type": "HyVideoTextEncode", 197 | "pos": [ 198 | 179, 199 | 242 200 | ], 201 | "size": [ 202 | 408.91546630859375, 203 | 172.84060668945312 204 | ], 205 | "flags": {}, 206 | "order": 3, 207 | "mode": 0, 208 | "inputs": [ 209 | { 210 | "name": "text_encoders", 211 | "type": "HYVIDTEXTENCODER", 212 | "link": 35 213 | }, 214 | { 215 | "name": "custom_prompt_template", 216 | "type": "PROMPT_TEMPLATE", 217 | "link": null, 218 | "shape": 7 219 | }, 220 | { 221 | "name": "clip_l", 222 | "type": "CLIP", 223 | "link": null, 224 | "shape": 7 225 | } 226 | ], 227 | "outputs": [ 228 | { 229 | "name": "hyvid_embeds", 230 | "type": "HYVIDEMBEDS", 231 | "links": [ 232 | 36 233 | ] 234 | } 235 | ], 236 | "properties": { 237 | "Node name for S&R": "HyVideoTextEncode" 238 | }, 239 | "widgets_values": [ 240 | "high quality nature video of a red panda balancing on a bamboo stick while a bird lands on the panda's head, there's a waterfall in the background", 241 | 
"bad quality video", 242 | "video" 243 | ] 244 | }, 245 | { 246 | "id": 34, 247 | "type": "VHS_VideoCombine", 248 | "pos": [ 249 | 673.133544921875, 250 | -37.19999694824219 251 | ], 252 | "size": [ 253 | 580.7774658203125, 254 | 310 255 | ], 256 | "flags": {}, 257 | "order": 6, 258 | "mode": 0, 259 | "inputs": [ 260 | { 261 | "name": "images", 262 | "type": "IMAGE", 263 | "link": 42 264 | }, 265 | { 266 | "name": "audio", 267 | "type": "AUDIO", 268 | "link": null, 269 | "shape": 7 270 | }, 271 | { 272 | "name": "meta_batch", 273 | "type": "VHS_BatchManager", 274 | "link": null, 275 | "shape": 7 276 | }, 277 | { 278 | "name": "vae", 279 | "type": "VAE", 280 | "link": null, 281 | "shape": 7 282 | } 283 | ], 284 | "outputs": [ 285 | { 286 | "name": "Filenames", 287 | "type": "VHS_FILENAMES", 288 | "links": null 289 | } 290 | ], 291 | "properties": { 292 | "Node name for S&R": "VHS_VideoCombine" 293 | }, 294 | "widgets_values": { 295 | "frame_rate": 24, 296 | "loop_count": 0, 297 | "filename_prefix": "HunyuanVideo", 298 | "format": "video/h264-mp4", 299 | "pix_fmt": "yuv420p", 300 | "crf": 19, 301 | "save_metadata": true, 302 | "pingpong": false, 303 | "save_output": true, 304 | "videopreview": { 305 | "hidden": false, 306 | "paused": false, 307 | "params": { 308 | "filename": "HunyuanVideo_00009.mp4", 309 | "subfolder": "", 310 | "type": "temp", 311 | "format": "video/h264-mp4", 312 | "frame_rate": 24 313 | }, 314 | "muted": false 315 | } 316 | } 317 | }, 318 | { 319 | "id": 5, 320 | "type": "HyVideoDecode", 321 | "pos": [ 322 | 651, 323 | -285 324 | ], 325 | "size": [ 326 | 345.4285888671875, 327 | 150 328 | ], 329 | "flags": {}, 330 | "order": 5, 331 | "mode": 0, 332 | "inputs": [ 333 | { 334 | "name": "vae", 335 | "type": "VAE", 336 | "link": 6 337 | }, 338 | { 339 | "name": "samples", 340 | "type": "LATENT", 341 | "link": 4 342 | } 343 | ], 344 | "outputs": [ 345 | { 346 | "name": "images", 347 | "type": "IMAGE", 348 | "links": [ 349 | 42 350 | ], 351 | "slot_index": 0 352 | } 353 | ], 354 | "properties": { 355 | "Node name for S&R": "HyVideoDecode" 356 | }, 357 | "widgets_values": [ 358 | true, 359 | 64, 360 | 256, 361 | true 362 | ] 363 | } 364 | ], 365 | "links": [ 366 | [ 367 | 2, 368 | 1, 369 | 0, 370 | 3, 371 | 0, 372 | "HYVIDEOMODEL" 373 | ], 374 | [ 375 | 4, 376 | 3, 377 | 0, 378 | 5, 379 | 1, 380 | "LATENT" 381 | ], 382 | [ 383 | 6, 384 | 7, 385 | 0, 386 | 5, 387 | 0, 388 | "VAE" 389 | ], 390 | [ 391 | 35, 392 | 16, 393 | 0, 394 | 30, 395 | 0, 396 | "HYVIDTEXTENCODER" 397 | ], 398 | [ 399 | 36, 400 | 30, 401 | 0, 402 | 3, 403 | 1, 404 | "HYVIDEMBEDS" 405 | ], 406 | [ 407 | 42, 408 | 5, 409 | 0, 410 | 34, 411 | 0, 412 | "IMAGE" 413 | ] 414 | ], 415 | "groups": [], 416 | "config": {}, 417 | "extra": { 418 | "ds": { 419 | "scale": 0.9090909090909091, 420 | "offset": [ 421 | 740.8495512386833, 422 | 610.811990613627 423 | ] 424 | } 425 | }, 426 | "version": 0.4 427 | } -------------------------------------------------------------------------------- /fp8_optimization.py: -------------------------------------------------------------------------------- 1 | #based on ComfyUI's and MinusZoneAI's fp8_linear optimization 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | def fp8_linear_forward(cls, original_dtype, input): 7 | weight_dtype = cls.weight.dtype 8 | if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: 9 | if len(input.shape) == 3: 10 | #target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn 11 | #inn = input.reshape(-1, 
input.shape[2]).to(target_dtype) 12 | inn = input.reshape(-1, input.shape[2]).to(weight_dtype) 13 | w = cls.weight.t() 14 | 15 | scale = torch.ones((1), device=input.device, dtype=torch.float32) 16 | bias = cls.bias.to(original_dtype) if cls.bias is not None else None 17 | 18 | if bias is not None: 19 | o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale) 20 | else: 21 | o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale) 22 | 23 | if isinstance(o, tuple): 24 | o = o[0] 25 | 26 | return o.reshape((-1, input.shape[1], cls.weight.shape[0])) 27 | else: 28 | return cls.original_forward(input.to(original_dtype)) 29 | else: 30 | return cls.original_forward(input) 31 | 32 | def convert_fp8_linear(module, original_dtype, params_to_keep={}): 33 | setattr(module, "fp8_matmul_enabled", True) 34 | 35 | for name, module in module.named_modules(): 36 | if not any(keyword in name for keyword in params_to_keep): 37 | if isinstance(module, nn.Linear): 38 | original_forward = module.forward 39 | setattr(module, "original_forward", original_forward) 40 | setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input)) 41 | -------------------------------------------------------------------------------- /hunyuan_empty_prompt_embeds_dict.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-HunyuanVideoWrapper/3ce9640497139b14462910b0dbf2be1df855c1d6/hunyuan_empty_prompt_embeds_dict.pt -------------------------------------------------------------------------------- /hyvideo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-HunyuanVideoWrapper/3ce9640497139b14462910b0dbf2be1df855c1d6/hyvideo/__init__.py -------------------------------------------------------------------------------- /hyvideo/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from .constants import * 3 | import re 4 | from .modules.models import HUNYUAN_VIDEO_CONFIG 5 | 6 | 7 | def parse_args(namespace=None): 8 | parser = argparse.ArgumentParser(description="HunyuanVideo inference script") 9 | 10 | parser = add_network_args(parser) 11 | parser = add_extra_models_args(parser) 12 | parser = add_denoise_schedule_args(parser) 13 | parser = add_inference_args(parser) 14 | 15 | args = parser.parse_args(namespace=namespace) 16 | args = sanity_check_args(args) 17 | 18 | return args 19 | 20 | 21 | def add_network_args(parser: argparse.ArgumentParser): 22 | group = parser.add_argument_group(title="HunyuanVideo network args") 23 | 24 | # Main model 25 | group.add_argument( 26 | "--model", 27 | type=str, 28 | choices=list(HUNYUAN_VIDEO_CONFIG.keys()), 29 | default="HYVideo-T/2-cfgdistill", 30 | ) 31 | group.add_argument( 32 | "--latent-channels", 33 | type=str, 34 | default=16, 35 | help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, " 36 | "it still needs to match the latent channels of the VAE model.", 37 | ) 38 | group.add_argument( 39 | "--precision", 40 | type=str, 41 | default="bf16", 42 | choices=PRECISIONS, 43 | help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.", 44 | ) 45 | 46 | # RoPE 47 | group.add_argument( 48 | "--rope-theta", type=int, default=256, help="Theta used in RoPE." 
49 | ) 50 | return parser 51 | 52 | 53 | def add_extra_models_args(parser: argparse.ArgumentParser): 54 | group = parser.add_argument_group( 55 | title="Extra models args, including vae, text encoders and tokenizers)" 56 | ) 57 | 58 | # - VAE 59 | group.add_argument( 60 | "--vae", 61 | type=str, 62 | default="884-16c-hy", 63 | choices=list(VAE_PATH), 64 | help="Name of the VAE model.", 65 | ) 66 | group.add_argument( 67 | "--vae-precision", 68 | type=str, 69 | default="fp16", 70 | choices=PRECISIONS, 71 | help="Precision mode for the VAE model.", 72 | ) 73 | group.add_argument( 74 | "--vae-tiling", 75 | action="store_true", 76 | help="Enable tiling for the VAE model to save GPU memory.", 77 | ) 78 | group.set_defaults(vae_tiling=True) 79 | 80 | group.add_argument( 81 | "--text-encoder", 82 | type=str, 83 | default="llm", 84 | choices=list(TEXT_ENCODER_PATH), 85 | help="Name of the text encoder model.", 86 | ) 87 | group.add_argument( 88 | "--text-encoder-precision", 89 | type=str, 90 | default="fp16", 91 | choices=PRECISIONS, 92 | help="Precision mode for the text encoder model.", 93 | ) 94 | group.add_argument( 95 | "--text-states-dim", 96 | type=int, 97 | default=4096, 98 | help="Dimension of the text encoder hidden states.", 99 | ) 100 | group.add_argument( 101 | "--text-len", type=int, default=256, help="Maximum length of the text input." 102 | ) 103 | group.add_argument( 104 | "--tokenizer", 105 | type=str, 106 | default="llm", 107 | choices=list(TOKENIZER_PATH), 108 | help="Name of the tokenizer model.", 109 | ) 110 | group.add_argument( 111 | "--prompt-template", 112 | type=str, 113 | default="dit-llm-encode", 114 | choices=PROMPT_TEMPLATE, 115 | help="Image prompt template for the decoder-only text encoder model.", 116 | ) 117 | group.add_argument( 118 | "--prompt-template-video", 119 | type=str, 120 | default="dit-llm-encode-video", 121 | choices=PROMPT_TEMPLATE, 122 | help="Video prompt template for the decoder-only text encoder model.", 123 | ) 124 | group.add_argument( 125 | "--hidden-state-skip-layer", 126 | type=int, 127 | default=2, 128 | help="Skip layer for hidden states.", 129 | ) 130 | group.add_argument( 131 | "--apply-final-norm", 132 | action="store_true", 133 | help="Apply final normalization to the used text encoder hidden states.", 134 | ) 135 | 136 | # - CLIP 137 | group.add_argument( 138 | "--text-encoder-2", 139 | type=str, 140 | default="clipL", 141 | choices=list(TEXT_ENCODER_PATH), 142 | help="Name of the second text encoder model.", 143 | ) 144 | group.add_argument( 145 | "--text-encoder-precision-2", 146 | type=str, 147 | default="fp16", 148 | choices=PRECISIONS, 149 | help="Precision mode for the second text encoder model.", 150 | ) 151 | group.add_argument( 152 | "--text-states-dim-2", 153 | type=int, 154 | default=768, 155 | help="Dimension of the second text encoder hidden states.", 156 | ) 157 | group.add_argument( 158 | "--tokenizer-2", 159 | type=str, 160 | default="clipL", 161 | choices=list(TOKENIZER_PATH), 162 | help="Name of the second tokenizer model.", 163 | ) 164 | group.add_argument( 165 | "--text-len-2", 166 | type=int, 167 | default=77, 168 | help="Maximum length of the second text input.", 169 | ) 170 | 171 | return parser 172 | 173 | 174 | def add_denoise_schedule_args(parser: argparse.ArgumentParser): 175 | group = parser.add_argument_group(title="Denoise schedule args") 176 | 177 | group.add_argument( 178 | "--denoise-type", 179 | type=str, 180 | default="flow", 181 | help="Denoise type for noised inputs.", 182 | ) 183 | 184 | # Flow 
Matching 185 | group.add_argument( 186 | "--flow-shift", 187 | type=float, 188 | default=9.0, 189 | help="Shift factor for flow matching schedulers.", 190 | ) 191 | group.add_argument( 192 | "--flow-reverse", 193 | action="store_true", 194 | help="If reverse, learning/sampling from t=1 -> t=0.", 195 | ) 196 | group.add_argument( 197 | "--flow-solver", 198 | type=str, 199 | default="euler", 200 | help="Solver for flow matching.", 201 | ) 202 | group.add_argument( 203 | "--use-linear-quadratic-schedule", 204 | action="store_true", 205 | help="Use linear quadratic schedule for flow matching." 206 | "Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)", 207 | ) 208 | group.add_argument( 209 | "--linear-schedule-end", 210 | type=int, 211 | default=25, 212 | help="End step for linear quadratic schedule for flow matching.", 213 | ) 214 | 215 | return parser 216 | 217 | 218 | def add_inference_args(parser: argparse.ArgumentParser): 219 | group = parser.add_argument_group(title="Inference args") 220 | 221 | # ======================== Model loads ======================== 222 | group.add_argument( 223 | "--model-base", 224 | type=str, 225 | default="ckpts", 226 | help="Root path of all the models, including t2v models and extra models.", 227 | ) 228 | group.add_argument( 229 | "--dit-weight", 230 | type=str, 231 | default="ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", 232 | help="Path to the HunyuanVideo model. If None, search the model in the args.model_root." 233 | "1. If it is a file, load the model directly." 234 | "2. If it is a directory, search the model in the directory. Support two types of models: " 235 | "1) named `pytorch_model_*.pt`" 236 | "2) named `*_model_states.pt`, where * can be `mp_rank_00`.", 237 | ) 238 | group.add_argument( 239 | "--model-resolution", 240 | type=str, 241 | default="540p", 242 | choices=["540p", "720p"], 243 | help="Root path of all the models, including t2v models and extra models.", 244 | ) 245 | group.add_argument( 246 | "--load-key", 247 | type=str, 248 | default="module", 249 | help="Key to load the model states. 
'module' for the main model, 'ema' for the EMA model.", 250 | ) 251 | group.add_argument( 252 | "--use-cpu-offload", 253 | action="store_true", 254 | help="Use CPU offload for the model load.", 255 | ) 256 | 257 | # ======================== Inference general setting ======================== 258 | group.add_argument( 259 | "--batch-size", 260 | type=int, 261 | default=1, 262 | help="Batch size for inference and evaluation.", 263 | ) 264 | group.add_argument( 265 | "--infer-steps", 266 | type=int, 267 | default=30, 268 | help="Number of denoising steps for inference.", 269 | ) 270 | group.add_argument( 271 | "--disable-autocast", 272 | action="store_true", 273 | help="Disable autocast for denoising loop and vae decoding in pipeline sampling.", 274 | ) 275 | group.add_argument( 276 | "--save-path", 277 | type=str, 278 | default="./results", 279 | help="Path to save the generated samples.", 280 | ) 281 | group.add_argument( 282 | "--save-path-suffix", 283 | type=str, 284 | default="", 285 | help="Suffix for the directory of saved samples.", 286 | ) 287 | group.add_argument( 288 | "--name-suffix", 289 | type=str, 290 | default="", 291 | help="Suffix for the names of saved samples.", 292 | ) 293 | group.add_argument( 294 | "--num-videos", 295 | type=int, 296 | default=1, 297 | help="Number of videos to generate for each prompt.", 298 | ) 299 | # ---sample size--- 300 | group.add_argument( 301 | "--video-size", 302 | type=int, 303 | nargs="+", 304 | default=(720, 1280), 305 | help="Video size for training. If a single value is provided, it will be used for both height " 306 | "and width. If two values are provided, they will be used for height and width " 307 | "respectively.", 308 | ) 309 | group.add_argument( 310 | "--video-length", 311 | type=int, 312 | default=129, 313 | help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1", 314 | ) 315 | # --- prompt --- 316 | group.add_argument( 317 | "--prompt", 318 | type=str, 319 | default=None, 320 | help="Prompt for sampling during evaluation.", 321 | ) 322 | group.add_argument( 323 | "--seed-type", 324 | type=str, 325 | default="auto", 326 | choices=["file", "random", "fixed", "auto"], 327 | help="Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a " 328 | "random seed. If fixed, use the fixed seed given by `--seed`. If auto, `csv` will use the " 329 | "seed column if available, otherwise use the fixed `seed` value. `prompt` will use the " 330 | "fixed `seed` value.", 331 | ) 332 | group.add_argument("--seed", type=int, default=0, help="Seed for evaluation.") 333 | 334 | # Classifier-Free Guidance 335 | group.add_argument( 336 | "--neg-prompt", type=str, default=None, help="Negative prompt for sampling." 337 | ) 338 | group.add_argument( 339 | "--cfg-scale", type=float, default=1.0, help="Classifier free guidance scale." 340 | ) 341 | group.add_argument( 342 | "--embedded-cfg-scale", 343 | type=float, 344 | default=6.0, 345 | help="Embeded classifier free guidance scale.", 346 | ) 347 | 348 | group.add_argument( 349 | "--reproduce", 350 | action="store_true", 351 | help="Enable reproducibility by setting random seeds and deterministic algorithms.", 352 | ) 353 | 354 | return parser 355 | 356 | 357 | def sanity_check_args(args): 358 | # VAE channels 359 | vae_pattern = r"\d{2,3}-\d{1,2}c-\w+" 360 | if not re.match(vae_pattern, args.vae): 361 | raise ValueError( 362 | f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'." 
363 | ) 364 | vae_channels = int(args.vae.split("-")[1][:-1]) 365 | if args.latent_channels is None: 366 | args.latent_channels = vae_channels 367 | if vae_channels != args.latent_channels: 368 | raise ValueError( 369 | f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})." 370 | ) 371 | return args 372 | -------------------------------------------------------------------------------- /hyvideo/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | __all__ = [ 5 | "C_SCALE", 6 | "PROMPT_TEMPLATE", 7 | "MODEL_BASE", 8 | "PRECISIONS", 9 | "NORMALIZATION_TYPE", 10 | "ACTIVATION_TYPE", 11 | "VAE_PATH", 12 | "TEXT_ENCODER_PATH", 13 | "TOKENIZER_PATH", 14 | "TEXT_PROJECTION", 15 | "DATA_TYPE", 16 | "NEGATIVE_PROMPT", 17 | "NEGATIVE_PROMPT_I2V", 18 | "FLOW_PATH_TYPE", 19 | "FLOW_PREDICT_TYPE", 20 | "FLOW_LOSS_WEIGHT", 21 | "FLOW_SNR_TYPE", 22 | "FLOW_SOLVER", 23 | ] 24 | 25 | PRECISION_TO_TYPE = { 26 | 'fp32': torch.float32, 27 | 'fp16': torch.float16, 28 | 'bf16': torch.bfloat16, 29 | 'fp8_e4m3fn': torch.float8_e4m3fn, 30 | } 31 | 32 | # =================== Constant Values ===================== 33 | # Computation scale factor, 1P = 1_000_000_000_000_000. Tensorboard will display the value in PetaFLOPS to avoid 34 | # overflow error when tensorboard logging values. 35 | C_SCALE = 1_000_000_000_000_000 36 | 37 | # When using decoder-only models, we must provide a prompt template to instruct the text encoder 38 | # on how to generate the text. 39 | # -------------------------------------------------------------------- 40 | PROMPT_TEMPLATE_ENCODE = ( 41 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, " 42 | "quantity, text, spatial relationships of the objects and background:<|eot_id|>" 43 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 44 | ) 45 | PROMPT_TEMPLATE_ENCODE_VIDEO = ( 46 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " 47 | "1. The main content and theme of the video." 48 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." 49 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." 50 | "4. background environment, light, style and atmosphere." 51 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>" 52 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 53 | ) 54 | 55 | PROMPT_TEMPLATE_ENCODE_I2V = ( 56 | "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the image by detailing the color, shape, size, texture, " 57 | "quantity, text, spatial relationships of the objects and background:<|eot_id|>" 58 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 59 | "<|start_header_id|>assistant<|end_header_id|>\n\n" 60 | ) 61 | 62 | PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( 63 | "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " 64 | "1. The main content and theme of the video." 65 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." 66 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." 67 | "4. background environment, light, style and atmosphere." 68 | "5. 
camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" 69 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 70 | "<|start_header_id|>assistant<|end_header_id|>\n\n" 71 | ) 72 | 73 | NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" 74 | NEGATIVE_PROMPT_I2V = "deformation, a poor composition and deformed video, bad teeth, bad eyes, bad limbs" 75 | 76 | PROMPT_TEMPLATE = { 77 | "dit-llm-encode": { 78 | "template": PROMPT_TEMPLATE_ENCODE, 79 | "crop_start": 36, 80 | }, 81 | "dit-llm-encode-video": { 82 | "template": PROMPT_TEMPLATE_ENCODE_VIDEO, 83 | "crop_start": 95, 84 | }, 85 | "dit-llm-encode-i2v": { 86 | "template": PROMPT_TEMPLATE_ENCODE_I2V, 87 | "crop_start": 36, 88 | "image_emb_start": 5, 89 | "image_emb_end": 581, 90 | "image_emb_len": 576, 91 | "double_return_token_id": 271 92 | }, 93 | "dit-llm-encode-video-i2v": { 94 | "template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, 95 | "crop_start": 103, 96 | "image_emb_start": 5, 97 | "image_emb_end": 581, 98 | "image_emb_len": 576, 99 | "double_return_token_id": 271 100 | }, 101 | } 102 | 103 | # ======================= Model ====================== 104 | PRECISIONS = {"fp32", "fp16", "bf16"} 105 | NORMALIZATION_TYPE = {"layer", "rms"} 106 | ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"} 107 | 108 | # =================== Model Path ===================== 109 | MODEL_BASE = os.getenv("MODEL_BASE", "./ckpts") 110 | 111 | # =================== Data ======================= 112 | DATA_TYPE = {"image", "video", "image_video"} 113 | 114 | # 3D VAE 115 | VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"} 116 | 117 | # Text Encoder 118 | TEXT_ENCODER_PATH = { 119 | "clipL": f"{MODEL_BASE}/text_encoder_2", 120 | "llm": f"{MODEL_BASE}/text_encoder", 121 | "llm-i2v": f"{MODEL_BASE}/text_encoder_i2v", 122 | } 123 | 124 | # Tokenizer 125 | TOKENIZER_PATH = { 126 | "clipL": f"{MODEL_BASE}/text_encoder_2", 127 | "llm": f"{MODEL_BASE}/text_encoder", 128 | "llm-i2v": f"{MODEL_BASE}/text_encoder_i2v", 129 | } 130 | 131 | TEXT_PROJECTION = { 132 | "linear", # Default, an nn.Linear() layer 133 | "single_refiner", # Single TokenRefiner. 
Refer to LI-DiT 134 | } 135 | 136 | # Flow Matching path type 137 | FLOW_PATH_TYPE = { 138 | "linear", # Linear trajectory between noise and data 139 | "gvp", # Generalized variance-preserving SDE 140 | "vp", # Variance-preserving SDE 141 | } 142 | 143 | # Flow Matching predict type 144 | FLOW_PREDICT_TYPE = { 145 | "velocity", # Predict velocity 146 | "score", # Predict score 147 | "noise", # Predict noise 148 | } 149 | 150 | # Flow Matching loss weight 151 | FLOW_LOSS_WEIGHT = { 152 | "velocity", # Weight loss by velocity 153 | "likelihood", # Weight loss by likelihood 154 | } 155 | 156 | # Flow Matching SNR type 157 | FLOW_SNR_TYPE = { 158 | "lognorm", # Log-normal SNR 159 | "uniform", # Uniform SNR 160 | } 161 | 162 | # Flow Matching solvers 163 | FLOW_SOLVER = { 164 | "euler", # Euler solver 165 | } -------------------------------------------------------------------------------- /hyvideo/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipelines import HunyuanVideoPipeline 2 | from .schedulers import FlowMatchDiscreteScheduler 3 | -------------------------------------------------------------------------------- /hyvideo/diffusion/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_hunyuan_video import HunyuanVideoPipeline 2 | -------------------------------------------------------------------------------- /hyvideo/diffusion/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler 2 | -------------------------------------------------------------------------------- /hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # 16 | # Modified from diffusers==0.29.2 17 | # 18 | # ============================================================================== 19 | 20 | from dataclasses import dataclass 21 | from typing import Optional, Tuple, Union 22 | 23 | import numpy as np 24 | import torch 25 | 26 | from diffusers.configuration_utils import ConfigMixin, register_to_config 27 | from diffusers.utils import BaseOutput, logging 28 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 29 | 30 | 31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 32 | 33 | 34 | @dataclass 35 | class FlowMatchDiscreteSchedulerOutput(BaseOutput): 36 | """ 37 | Output class for the scheduler's `step` function output. 
38 | 39 | Args: 40 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 41 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 42 | denoising loop. 43 | """ 44 | 45 | prev_sample: torch.FloatTensor 46 | 47 | 48 | class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin): 49 | """ 50 | Euler scheduler. 51 | 52 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 53 | methods the library implements for all schedulers such as loading and saving. 54 | 55 | Args: 56 | num_train_timesteps (`int`, defaults to 1000): 57 | The number of diffusion steps to train the model. 58 | timestep_spacing (`str`, defaults to `"linspace"`): 59 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and 60 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 61 | shift (`float`, defaults to 1.0): 62 | The shift value for the timestep schedule. 63 | reverse (`bool`, defaults to `True`): 64 | Whether to reverse the timestep schedule. 65 | """ 66 | 67 | _compatibles = [] 68 | order = 1 69 | 70 | @register_to_config 71 | def __init__( 72 | self, 73 | num_train_timesteps: int = 1000, 74 | flow_shift: float = 1.0, 75 | reverse: bool = True, 76 | solver: str = "euler", 77 | n_tokens: Optional[int] = None, 78 | ): 79 | sigmas = torch.linspace(1, 0, num_train_timesteps + 1) 80 | print("Scheduler config:", self.config) 81 | if not reverse: 82 | sigmas = sigmas.flip(0) 83 | self.flow_shift = flow_shift 84 | 85 | self.sigmas = sigmas 86 | # the value fed to model 87 | self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) 88 | 89 | self._step_index = None 90 | self._begin_index = None 91 | 92 | self.supported_solver = ["euler"] 93 | if solver not in self.supported_solver: 94 | raise ValueError( 95 | f"Solver {solver} not supported. Supported solvers: {self.supported_solver}" 96 | ) 97 | 98 | @property 99 | def step_index(self): 100 | """ 101 | The index counter for current timestep. It will increase 1 after each scheduler step. 102 | """ 103 | return self._step_index 104 | 105 | @property 106 | def begin_index(self): 107 | """ 108 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method. 109 | """ 110 | return self._begin_index 111 | 112 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 113 | def set_begin_index(self, begin_index: int = 0): 114 | """ 115 | Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 116 | 117 | Args: 118 | begin_index (`int`): 119 | The begin index for the scheduler. 120 | """ 121 | self._begin_index = begin_index 122 | 123 | def _sigma_to_t(self, sigma): 124 | return sigma * self.config.num_train_timesteps 125 | 126 | def set_timesteps( 127 | self, 128 | num_inference_steps: int, 129 | device: Union[str, torch.device] = None, 130 | n_tokens: int = None, 131 | ): 132 | """ 133 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 134 | 135 | Args: 136 | num_inference_steps (`int`): 137 | The number of diffusion steps used when generating samples with a pre-trained model. 138 | device (`str` or `torch.device`, *optional*): 139 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
140 | n_tokens (`int`, *optional*): 141 | Number of tokens in the input sequence. 142 | """ 143 | self.num_inference_steps = num_inference_steps 144 | 145 | sigmas = torch.linspace(1, 0, num_inference_steps + 1) 146 | sigmas = self.sd3_time_shift(sigmas) 147 | 148 | if not self.config.reverse: 149 | sigmas = 1 - sigmas 150 | 151 | self.sigmas = sigmas 152 | self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to( 153 | dtype=torch.float32, device=device 154 | ) 155 | 156 | # Reset step index 157 | self._step_index = None 158 | 159 | def index_for_timestep(self, timestep, schedule_timesteps=None): 160 | if schedule_timesteps is None: 161 | schedule_timesteps = self.timesteps 162 | 163 | indices = (schedule_timesteps == timestep).nonzero() 164 | 165 | # The sigma index that is taken for the **very** first `step` 166 | # is always the second index (or the last index if there is only 1) 167 | # This way we can ensure we don't accidentally skip a sigma in 168 | # case we start in the middle of the denoising schedule (e.g. for image-to-image) 169 | pos = 1 if len(indices) > 1 else 0 170 | 171 | return indices[pos].item() 172 | 173 | def _init_step_index(self, timestep): 174 | if self.begin_index is None: 175 | if isinstance(timestep, torch.Tensor): 176 | timestep = timestep.to(self.timesteps.device) 177 | self._step_index = self.index_for_timestep(timestep) 178 | else: 179 | self._step_index = self._begin_index 180 | 181 | def scale_model_input( 182 | self, sample: torch.Tensor, timestep: Optional[int] = None 183 | ) -> torch.Tensor: 184 | return sample 185 | 186 | def sd3_time_shift(self, t: torch.Tensor): 187 | return (self.flow_shift * t) / (1 + (self.flow_shift - 1) * t) 188 | 189 | def step( 190 | self, 191 | model_output: torch.FloatTensor, 192 | timestep: Union[float, torch.FloatTensor], 193 | sample: torch.FloatTensor, 194 | return_dict: bool = True, 195 | ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]: 196 | """ 197 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 198 | process from the learned model outputs (most often the predicted noise). 199 | 200 | Args: 201 | model_output (`torch.FloatTensor`): 202 | The direct output from learned diffusion model. 203 | timestep (`float`): 204 | The current discrete timestep in the diffusion chain. 205 | sample (`torch.FloatTensor`): 206 | A current instance of a sample created by the diffusion process. 207 | generator (`torch.Generator`, *optional*): 208 | A random number generator. 209 | n_tokens (`int`, *optional*): 210 | Number of tokens in the input sequence. 211 | return_dict (`bool`): 212 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 213 | tuple. 214 | 215 | Returns: 216 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 217 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 218 | returned, otherwise a tuple is returned where the first element is the sample tensor. 219 | """ 220 | 221 | if ( 222 | isinstance(timestep, int) 223 | or isinstance(timestep, torch.IntTensor) 224 | or isinstance(timestep, torch.LongTensor) 225 | ): 226 | raise ValueError( 227 | ( 228 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 229 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 230 | " one of the `scheduler.timesteps` as a timestep." 
231 | ), 232 | ) 233 | 234 | if self.step_index is None: 235 | self._init_step_index(timestep) 236 | 237 | # Upcast to avoid precision issues when computing prev_sample 238 | sample = sample.to(torch.float32) 239 | 240 | dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] 241 | 242 | if self.config.solver == "euler": 243 | prev_sample = sample + model_output.to(torch.float32) * dt 244 | else: 245 | raise ValueError( 246 | f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}" 247 | ) 248 | 249 | # upon completion increase step index by one 250 | self._step_index += 1 251 | 252 | if not return_dict: 253 | return (prev_sample,) 254 | 255 | return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) 256 | 257 | def __len__(self): 258 | return self.config.num_train_timesteps 259 | -------------------------------------------------------------------------------- /hyvideo/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import HYVideoDiffusionTransformer 2 | -------------------------------------------------------------------------------- /hyvideo/modules/activation_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def get_activation_layer(act_type): 5 | """get activation layer 6 | 7 | Args: 8 | act_type (str): the activation type 9 | 10 | Returns: 11 | torch.nn.functional: the activation layer 12 | """ 13 | if act_type == "gelu": 14 | return lambda: nn.GELU() 15 | elif act_type == "gelu_tanh": 16 | # Approximate `tanh` requires torch >= 1.13 17 | return lambda: nn.GELU(approximate="tanh") 18 | elif act_type == "relu": 19 | return nn.ReLU 20 | elif act_type == "silu": 21 | return nn.SiLU 22 | else: 23 | raise ValueError(f"Unknown activation type: {act_type}") 24 | -------------------------------------------------------------------------------- /hyvideo/modules/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | try: 7 | from flash_attn.flash_attn_interface import flash_attn_varlen_func 8 | except ImportError: 9 | flash_attn_varlen_func = None 10 | 11 | try: 12 | from sageattention import sageattn_varlen, sageattn 13 | @torch.compiler.disable() 14 | def sageattn_varlen_func( 15 | q, 16 | k, 17 | v, 18 | cu_seqlens_q, 19 | cu_seqlens_kv, 20 | max_seqlen_q, 21 | max_seqlen_kv, 22 | ): 23 | return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) 24 | @torch.compiler.disable() 25 | def sageattn_func(q, k, v, attn_mask=None, dropout_p=0, is_causal=False): 26 | return sageattn(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) 27 | except ImportError: 28 | sageattn_varlen_func = None 29 | 30 | from comfy.ldm.modules.attention import optimized_attention 31 | 32 | MEMORY_LAYOUT = { 33 | "flash_attn_varlen": ( 34 | lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), 35 | lambda x: x, 36 | ), 37 | "sageattn_varlen": ( 38 | lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), 39 | lambda x: x, 40 | ), 41 | "sdpa": ( 42 | lambda x: x.transpose(1, 2), 43 | lambda x: x.transpose(1, 2), 44 | ), 45 | "sageattn": ( 46 | lambda x: x.transpose(1, 2), 47 | lambda x: x.transpose(1, 2), 48 | ), 49 | "comfy": ( 50 | lambda x: x.transpose(1, 2), 51 | lambda x: x.transpose(1, 2), 52 | ), 53 | "vanilla": ( 54 | lambda x: x.transpose(1, 
2), 55 | lambda x: x.transpose(1, 2), 56 | ), 57 | } 58 | 59 | 60 | def get_cu_seqlens(text_mask, img_len): 61 | """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len 62 | 63 | Args: 64 | text_mask (torch.Tensor): the mask of text 65 | img_len (int): the length of image 66 | 67 | Returns: 68 | torch.Tensor: the calculated cu_seqlens for flash attention 69 | """ 70 | batch_size = text_mask.shape[0] 71 | text_len = text_mask.sum(dim=1) 72 | max_len = text_mask.shape[1] + img_len 73 | 74 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") 75 | 76 | for i in range(batch_size): 77 | s = text_len[i] + img_len 78 | s1 = i * max_len + s 79 | s2 = (i + 1) * max_len 80 | cu_seqlens[2 * i + 1] = s1 81 | cu_seqlens[2 * i + 2] = s2 82 | 83 | return cu_seqlens 84 | 85 | 86 | def attention( 87 | q, 88 | k, 89 | v, 90 | heads, 91 | mode="sdpa", 92 | drop_rate=0, 93 | attn_mask=None, 94 | causal=False, 95 | cu_seqlens_q=None, 96 | cu_seqlens_kv=None, 97 | max_seqlen_q=None, 98 | max_seqlen_kv=None, 99 | batch_size=1, 100 | do_stg=False, 101 | txt_len=-1, 102 | ): 103 | """ 104 | Perform QKV self attention. 105 | 106 | Args: 107 | q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. 108 | k (torch.Tensor): Key tensor with shape [b, s1, a, d] 109 | v (torch.Tensor): Value tensor with shape [b, s1, a, d] 110 | mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. 111 | drop_rate (float): Dropout rate in attention map. (default: 0) 112 | attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). 113 | (default: None) 114 | causal (bool): Whether to use causal attention. (default: False) 115 | cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, 116 | used to index into q. 117 | cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, 118 | used to index into kv. 119 | max_seqlen_q (int): The maximum sequence length in the batch of q. 120 | max_seqlen_kv (int): The maximum sequence length in the batch of k and v. 
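        heads (int): Number of attention heads; forwarded to ComfyUI's optimized_attention in 'comfy' mode.
        batch_size (int): Batch size used to reshape the packed varlen outputs back to [b, s, a, d]. (default: 1)
        do_stg (bool): If True, the last batch element is treated as the STG perturbed branch: its image-token
            queries attend only to themselves and to the text tokens (attention between different image tokens
            is masked out with -inf), while the remaining batch elements run normal attention. (default: False)
        txt_len (int): Number of text tokens at the end of the sequence; used to size the identity block of the
            STG mask when do_stg is True. (default: -1)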
121 | 122 | Returns: 123 | torch.Tensor: Output tensor after self attention with shape [b, s, ad] 124 | """ 125 | pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] 126 | q = pre_attn_layout(q) 127 | k = pre_attn_layout(k) 128 | v = pre_attn_layout(v) 129 | 130 | if mode == "sdpa": 131 | if attn_mask is not None and attn_mask.dtype != torch.bool: 132 | attn_mask = attn_mask.to(q.dtype) 133 | 134 | if do_stg: 135 | batch_size = q.shape[0] 136 | q, q_perturb = q[:batch_size-1], q[batch_size-1:] 137 | k, k_perturb = k[:batch_size-1], k[batch_size-1:] 138 | v, v_perturb = v[:batch_size-1], v[batch_size-1:] 139 | if attn_mask is not None: 140 | attn_mask = attn_mask[:batch_size-1] 141 | #print(f"q: {q.shape}") 142 | #print(f"q_perturb: {q_perturb.shape}") 143 | #print(f"txt_len: {txt_len}") 144 | x = F.scaled_dot_product_attention( 145 | q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal 146 | ) 147 | #print(f"x: {x.shape}") 148 | batch_size = q_perturb.shape[0] 149 | seq_len = q_perturb.shape[2] 150 | num_heads = q_perturb.shape[1] 151 | identity_block_size = seq_len - txt_len 152 | full_mask = torch.zeros((seq_len, seq_len), dtype=q_perturb.dtype, device=q_perturb.device) 153 | full_mask[:identity_block_size, :identity_block_size] = float("-inf") 154 | full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0) 155 | 156 | full_mask = full_mask.unsqueeze(0).unsqueeze(0) 157 | full_mask = full_mask.expand(batch_size, num_heads, seq_len, seq_len) 158 | #print(f"full_mask: {full_mask.shape} is_causal: {causal}") 159 | x_perturb = F.scaled_dot_product_attention( 160 | q_perturb, k_perturb, v_perturb, attn_mask=full_mask, dropout_p=drop_rate, is_causal=causal, 161 | ) 162 | #print(f"x_perturb: {x_perturb.shape}") 163 | x = torch.cat([x, x_perturb], dim=0) 164 | else: 165 | x = F.scaled_dot_product_attention( 166 | q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal 167 | ) 168 | elif mode == "sageattn_varlen": 169 | x = sageattn_varlen_func( 170 | q, 171 | k, 172 | v, 173 | cu_seqlens_q, 174 | cu_seqlens_kv, 175 | max_seqlen_q, 176 | max_seqlen_kv, 177 | ) 178 | # x with shape [(bxs), a, d] 179 | x = x.view( 180 | batch_size, max_seqlen_q, x.shape[-2], x.shape[-1] 181 | ) # reshape x to [b, s, a, d] 182 | elif mode == "comfy": 183 | x = optimized_attention(q, k, v, mask=attn_mask, heads=heads, skip_reshape=True) 184 | elif mode == "sageattn": 185 | if attn_mask is not None and attn_mask.dtype != torch.bool: 186 | attn_mask = attn_mask.to(q.dtype) 187 | x = sageattn_func(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) 188 | elif mode == "flash_attn_varlen": 189 | x = flash_attn_varlen_func( 190 | q, 191 | k, 192 | v, 193 | cu_seqlens_q, 194 | cu_seqlens_kv, 195 | max_seqlen_q, 196 | max_seqlen_kv, 197 | ) 198 | # x with shape [(bxs), a, d] 199 | x = x.view( 200 | batch_size, max_seqlen_q, x.shape[-2], x.shape[-1] 201 | ) # reshape x to [b, s, a, d] 202 | elif mode == "vanilla": 203 | scale_factor = 1 / math.sqrt(q.size(-1)) 204 | 205 | b, a, s, _ = q.shape 206 | s1 = k.size(2) 207 | attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) 208 | if causal: 209 | # Only applied to self attention 210 | assert ( 211 | attn_mask is None 212 | ), "Causal mask and attn_mask cannot be used together" 213 | temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril( 214 | diagonal=0 215 | ) 216 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 217 | attn_bias.to(q.dtype) 218 | 219 | if attn_mask 
is not None: 220 | if attn_mask.dtype == torch.bool: 221 | attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) 222 | else: 223 | attn_bias += attn_mask 224 | 225 | # TODO: Maybe force q and k to be float32 to avoid numerical overflow 226 | attn = (q @ k.transpose(-2, -1)) * scale_factor 227 | attn += attn_bias 228 | attn = attn.softmax(dim=-1) 229 | attn = torch.dropout(attn, p=drop_rate, train=True) 230 | x = attn @ v 231 | else: 232 | raise NotImplementedError(f"Unsupported attention mode: {mode}") 233 | 234 | if mode != "comfy": 235 | x = post_attn_layout(x) 236 | b, s, a, d = x.shape 237 | return x.reshape(b, s, -1) 238 | return x 239 | -------------------------------------------------------------------------------- /hyvideo/modules/embed_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange, repeat 5 | 6 | from ..utils.helpers import to_2tuple 7 | 8 | 9 | class PatchEmbed(nn.Module): 10 | """2D Image to Patch Embedding 11 | 12 | Image to Patch Embedding using Conv2d 13 | 14 | A convolution based approach to patchifying a 2D image w/ embedding projection. 15 | 16 | Based on the impl in https://github.com/google-research/vision_transformer 17 | 18 | Hacked together by / Copyright 2020 Ross Wightman 19 | 20 | Remove the _assert function in forward function to be compatible with multi-resolution images. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | patch_size=16, 26 | in_chans=3, 27 | embed_dim=768, 28 | norm_layer=None, 29 | flatten=True, 30 | bias=True, 31 | dtype=None, 32 | device=None, 33 | ): 34 | factory_kwargs = {"dtype": dtype, "device": device} 35 | super().__init__() 36 | patch_size = to_2tuple(patch_size) 37 | self.patch_size = patch_size 38 | self.flatten = flatten 39 | 40 | self.proj = nn.Conv3d( 41 | in_chans, 42 | embed_dim, 43 | kernel_size=patch_size, 44 | stride=patch_size, 45 | bias=bias, 46 | **factory_kwargs 47 | ) 48 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) 49 | if bias: 50 | nn.init.zeros_(self.proj.bias) 51 | 52 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 53 | 54 | def forward(self, x): 55 | x = self.proj(x) 56 | if self.flatten: 57 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 58 | x = self.norm(x) 59 | return x 60 | 61 | 62 | class TextProjection(nn.Module): 63 | """ 64 | Projects text embeddings. Also handles dropout for classifier-free guidance. 65 | 66 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py 67 | """ 68 | 69 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): 70 | factory_kwargs = {"dtype": dtype, "device": device} 71 | super().__init__() 72 | self.linear_1 = nn.Linear( 73 | in_features=in_channels, 74 | out_features=hidden_size, 75 | bias=True, 76 | **factory_kwargs 77 | ) 78 | self.act_1 = act_layer() 79 | self.linear_2 = nn.Linear( 80 | in_features=hidden_size, 81 | out_features=hidden_size, 82 | bias=True, 83 | **factory_kwargs 84 | ) 85 | 86 | def forward(self, caption): 87 | hidden_states = self.linear_1(caption) 88 | hidden_states = self.act_1(hidden_states) 89 | hidden_states = self.linear_2(hidden_states) 90 | return hidden_states 91 | 92 | 93 | def timestep_embedding(t, dim, max_period=10000): 94 | """ 95 | Create sinusoidal timestep embeddings. 96 | 97 | Args: 98 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. 
These may be fractional. 99 | dim (int): the dimension of the output. 100 | max_period (int): controls the minimum frequency of the embeddings. 101 | 102 | Returns: 103 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. 104 | 105 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 106 | """ 107 | half = dim // 2 108 | freqs = torch.exp( 109 | -math.log(max_period) 110 | * torch.arange(start=0, end=half, dtype=torch.float32) 111 | / half 112 | ).to(device=t.device) 113 | args = t[:, None].float() * freqs[None] 114 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 115 | if dim % 2: 116 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 117 | return embedding 118 | 119 | 120 | class TimestepEmbedder(nn.Module): 121 | """ 122 | Embeds scalar timesteps into vector representations. 123 | """ 124 | 125 | def __init__( 126 | self, 127 | hidden_size, 128 | act_layer, 129 | frequency_embedding_size=256, 130 | max_period=10000, 131 | out_size=None, 132 | dtype=None, 133 | device=None, 134 | ): 135 | factory_kwargs = {"dtype": dtype, "device": device} 136 | super().__init__() 137 | self.frequency_embedding_size = frequency_embedding_size 138 | self.max_period = max_period 139 | if out_size is None: 140 | out_size = hidden_size 141 | 142 | self.mlp = nn.Sequential( 143 | nn.Linear( 144 | frequency_embedding_size, hidden_size, bias=True, **factory_kwargs 145 | ), 146 | act_layer(), 147 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), 148 | ) 149 | nn.init.normal_(self.mlp[0].weight, std=0.02) 150 | nn.init.normal_(self.mlp[2].weight, std=0.02) 151 | 152 | def forward(self, t): 153 | t_freq = timestep_embedding( 154 | t, self.frequency_embedding_size, self.max_period 155 | ).type(self.mlp[0].weight.dtype) 156 | t_emb = self.mlp(t_freq) 157 | return t_emb 158 | -------------------------------------------------------------------------------- /hyvideo/modules/fp8_map.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-HunyuanVideoWrapper/3ce9640497139b14462910b0dbf2be1df855c1d6/hyvideo/modules/fp8_map.safetensors -------------------------------------------------------------------------------- /hyvideo/modules/fp8_optimization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from comfy.utils import load_torch_file 6 | 7 | @torch.compiler.disable() 8 | def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1): 9 | _bits = torch.tensor(bits) 10 | _mantissa_bit = torch.tensor(mantissa_bit) 11 | _sign_bits = torch.tensor(sign_bits) 12 | M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits) 13 | E = _bits - _sign_bits - M 14 | bias = 2 ** (E - 1) - 1 15 | mantissa = 1 16 | for i in range(mantissa_bit - 1): 17 | mantissa += 1 / (2 ** (i+1)) 18 | maxval = mantissa * 2 ** (2**E - 1 - bias) 19 | return maxval 20 | 21 | @torch.compiler.disable() 22 | def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1): 23 | """ 24 | Default is E4M3. 
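    With the defaults bits=8, mantissa_bit=3, sign_bits=1 (i.e. E4M3): E = 8 - 1 - 3 = 4 exponent bits,
    bias = 2**(E - 1) - 1 = 7 and mantissa = 1 + 1/2 + 1/4 = 1.75, so values are clamped to
    +/- 1.75 * 2**(2**E - 1 - bias) = +/- 448, the finite range of torch.float8_e4m3fn.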
25 | """ 26 | bits = torch.tensor(bits) 27 | mantissa_bit = torch.tensor(mantissa_bit) 28 | sign_bits = torch.tensor(sign_bits) 29 | M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits) 30 | E = bits - sign_bits - M 31 | bias = 2 ** (E - 1) - 1 32 | mantissa = 1 33 | for i in range(mantissa_bit - 1): 34 | mantissa += 1 / (2 ** (i+1)) 35 | maxval = mantissa * 2 ** (2**E - 1 - bias) 36 | minval = - maxval 37 | minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval) 38 | input_clamp = torch.min(torch.max(x, minval), maxval) 39 | log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0) 40 | log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype)) 41 | # dequant 42 | qdq_out = torch.round(input_clamp / log_scales) * log_scales 43 | return qdq_out, log_scales 44 | 45 | @torch.compiler.disable() 46 | def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1): 47 | for i in range(len(x.shape) - 1): 48 | scale = scale.unsqueeze(-1) 49 | new_x = x / scale 50 | quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits) 51 | return quant_dequant_x, scale, log_scales 52 | 53 | @torch.compiler.disable() 54 | def fp8_activation_dequant(qdq_out, dtype): 55 | qdq_out = qdq_out.type(dtype) 56 | return qdq_out 57 | 58 | def fp8_linear_forward(cls, original_dtype, input): 59 | weight_dtype = cls.weight.dtype 60 | ##### 61 | if cls.weight.dtype != torch.float8_e4m3fn: 62 | maxval = get_fp_maxval() 63 | scale = torch.max(torch.abs(cls.weight.flatten())) / maxval 64 | linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale) 65 | linear_weight = linear_weight.to(torch.float8_e4m3fn) 66 | weight_dtype = linear_weight.dtype 67 | else: 68 | scale = cls.fp8_scale#.to(cls.weight.device) 69 | linear_weight = cls.weight 70 | ##### 71 | 72 | #if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0: 73 | if weight_dtype == torch.float8_e4m3fn: 74 | qdq_out = fp8_activation_dequant(linear_weight, original_dtype) 75 | cls_dequant = qdq_out * scale 76 | if cls.bias != None: 77 | output = F.linear(input, cls_dequant, cls.bias) 78 | else: 79 | output = F.linear(input, cls_dequant) 80 | return output 81 | else: 82 | return cls.original_forward(input) 83 | 84 | def convert_fp8_linear(module, original_dtype, device, fp8_scale_map={}): 85 | setattr(module, "fp8_matmul_enabled", True) 86 | script_directory = os.path.dirname(os.path.abspath(__file__)) 87 | 88 | # loading fp8 mapping file 89 | if not fp8_scale_map: 90 | fp8_map_path = os.path.join(script_directory,"fp8_map.safetensors") 91 | if os.path.exists(fp8_map_path): 92 | fp8_map = load_torch_file(fp8_map_path, safe_load=True) 93 | else: 94 | raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.") 95 | 96 | #fp8_layers = [] 97 | for key, layer in module.named_modules(): 98 | if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key): 99 | #fp8_layers.append(key) 100 | original_forward = layer.forward 101 | #layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn)) 102 | setattr(layer, "fp8_scale", fp8_map[key].to(device=device, dtype=original_dtype)) 103 | setattr(layer, "original_forward", original_forward) 104 | setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input)) 105 | -------------------------------------------------------------------------------- /hyvideo/modules/mlp_layers.py: 
-------------------------------------------------------------------------------- 1 | # Modified from timm library: 2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .modulate_layers import modulate 10 | from ..utils.helpers import to_2tuple 11 | 12 | 13 | class MLP(nn.Module): 14 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 15 | 16 | def __init__( 17 | self, 18 | in_channels, 19 | hidden_channels=None, 20 | out_features=None, 21 | act_layer=nn.GELU, 22 | norm_layer=None, 23 | bias=True, 24 | drop=0.0, 25 | use_conv=False, 26 | device=None, 27 | dtype=None, 28 | ): 29 | factory_kwargs = {"device": device, "dtype": dtype} 30 | super().__init__() 31 | out_features = out_features or in_channels 32 | hidden_channels = hidden_channels or in_channels 33 | bias = to_2tuple(bias) 34 | drop_probs = to_2tuple(drop) 35 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 36 | 37 | self.fc1 = linear_layer( 38 | in_channels, hidden_channels, bias=bias[0], **factory_kwargs 39 | ) 40 | self.act = act_layer() 41 | self.drop1 = nn.Dropout(drop_probs[0]) 42 | self.norm = ( 43 | norm_layer(hidden_channels, **factory_kwargs) 44 | if norm_layer is not None 45 | else nn.Identity() 46 | ) 47 | self.fc2 = linear_layer( 48 | hidden_channels, out_features, bias=bias[1], **factory_kwargs 49 | ) 50 | self.drop2 = nn.Dropout(drop_probs[1]) 51 | 52 | def forward(self, x): 53 | x = self.fc1(x) 54 | x = self.act(x) 55 | x = self.drop1(x) 56 | x = self.norm(x) 57 | x = self.fc2(x) 58 | x = self.drop2(x) 59 | return x 60 | 61 | 62 | # 63 | class MLPEmbedder(nn.Module): 64 | """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" 65 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): 66 | factory_kwargs = {"device": device, "dtype": dtype} 67 | super().__init__() 68 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) 69 | self.silu = nn.SiLU() 70 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) 71 | 72 | def forward(self, x: torch.Tensor) -> torch.Tensor: 73 | return self.out_layer(self.silu(self.in_layer(x))) 74 | 75 | 76 | class FinalLayer(nn.Module): 77 | """The final layer of DiT.""" 78 | 79 | def __init__( 80 | self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None 81 | ): 82 | factory_kwargs = {"device": device, "dtype": dtype} 83 | super().__init__() 84 | 85 | # Just use LayerNorm for the final layer 86 | self.norm_final = nn.LayerNorm( 87 | hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs 88 | ) 89 | if isinstance(patch_size, int): 90 | self.linear = nn.Linear( 91 | hidden_size, 92 | patch_size * patch_size * out_channels, 93 | bias=True, 94 | **factory_kwargs 95 | ) 96 | else: 97 | self.linear = nn.Linear( 98 | hidden_size, 99 | patch_size[0] * patch_size[1] * patch_size[2] * out_channels, 100 | bias=True, 101 | ) 102 | nn.init.zeros_(self.linear.weight) 103 | nn.init.zeros_(self.linear.bias) 104 | 105 | # Here we don't distinguish between the modulate types. Just use the simple one. 
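        # adaLN modulation: a zero-initialized Linear predicts a (shift, scale) pair from the
        # conditioning vector c (see forward below). Together with the zero-initialized
        # self.linear above, the final layer starts out producing zeros, the usual DiT-style
        # initialization for stable training.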
106 | self.adaLN_modulation = nn.Sequential( 107 | act_layer(), 108 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), 109 | ) 110 | # Zero-initialize the modulation 111 | nn.init.zeros_(self.adaLN_modulation[1].weight) 112 | nn.init.zeros_(self.adaLN_modulation[1].bias) 113 | 114 | def forward(self, x, c): 115 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 116 | x = modulate(self.norm_final(x), shift=shift, scale=scale) 117 | x = self.linear(x) 118 | return x 119 | -------------------------------------------------------------------------------- /hyvideo/modules/modulate_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class ModulateDiT(nn.Module): 7 | """Modulation layer for DiT.""" 8 | def __init__( 9 | self, 10 | hidden_size: int, 11 | factor: int, 12 | act_layer: Callable, 13 | dtype=None, 14 | device=None, 15 | ): 16 | factory_kwargs = {"dtype": dtype, "device": device} 17 | super().__init__() 18 | self.act = act_layer() 19 | self.linear = nn.Linear( 20 | hidden_size, factor * hidden_size, bias=True, **factory_kwargs 21 | ) 22 | # Zero-initialize the modulation 23 | nn.init.zeros_(self.linear.weight) 24 | nn.init.zeros_(self.linear.bias) 25 | 26 | def forward(self, x: torch.Tensor, condition_type=None, token_replace_vec=None) -> torch.Tensor: 27 | 28 | x_out = self.linear(self.act(x)) 29 | 30 | if condition_type == "token_replace": 31 | x_token_replace_out = self.linear(self.act(token_replace_vec)) 32 | return x_out, x_token_replace_out 33 | else: 34 | return x_out 35 | 36 | def modulate(x, shift=None, scale=None, condition_type=None, 37 | tr_shift=None, tr_scale=None, 38 | first_frame_token_num=None): 39 | """modulate by shift and scale 40 | 41 | Args: 42 | x (torch.Tensor): input tensor. 43 | shift (torch.Tensor, optional): shift tensor. Defaults to None. 44 | scale (torch.Tensor, optional): scale tensor. Defaults to None. 45 | 46 | Returns: 47 | torch.Tensor: the output tensor after modulate. 48 | """ 49 | if condition_type == "token_replace": 50 | x_zero = x[:, :first_frame_token_num] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1) 51 | x_orig = x[:, first_frame_token_num:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 52 | x = torch.concat((x_zero, x_orig), dim=1) 53 | return x 54 | else: 55 | if scale is None and shift is None: 56 | return x 57 | elif shift is None: 58 | return x * (1 + scale.unsqueeze(1)) 59 | elif scale is None: 60 | return x + shift.unsqueeze(1) 61 | else: 62 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 63 | 64 | 65 | def apply_gate(x, gate=None, tanh=False, condition_type=None, tr_gate=None, first_frame_token_num=None): 66 | """AI is creating summary for apply_gate 67 | 68 | Args: 69 | x (torch.Tensor): input tensor. 70 | gate (torch.Tensor, optional): gate tensor. Defaults to None. 71 | tanh (bool, optional): whether to use tanh function. Defaults to False. 72 | 73 | Returns: 74 | torch.Tensor: the output tensor after apply gate. 
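    Example (illustrative shapes; the gate is broadcast over the sequence dimension):
        >>> x = torch.randn(2, 16, 8)    # [B, L, D]
        >>> gate = torch.ones(2, 8)      # [B, D] -> unsqueeze(1) -> [B, 1, D]
        >>> apply_gate(x, gate).shape
        torch.Size([2, 16, 8])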
75 | """ 76 | if condition_type == "token_replace": 77 | if gate is None: 78 | return x 79 | if tanh: 80 | x_zero = x[:, :first_frame_token_num] * tr_gate.unsqueeze(1).tanh() 81 | x_orig = x[:, first_frame_token_num:] * gate.unsqueeze(1).tanh() 82 | x = torch.concat((x_zero, x_orig), dim=1) 83 | return x 84 | else: 85 | x_zero = x[:, :first_frame_token_num] * tr_gate.unsqueeze(1) 86 | x_orig = x[:, first_frame_token_num:] * gate.unsqueeze(1) 87 | x = torch.concat((x_zero, x_orig), dim=1) 88 | return x 89 | else: 90 | if gate is None: 91 | return x 92 | if tanh: 93 | return x * gate.unsqueeze(1).tanh() 94 | else: 95 | return x * gate.unsqueeze(1) 96 | 97 | 98 | def ckpt_wrapper(module): 99 | def ckpt_forward(*inputs): 100 | outputs = module(*inputs) 101 | return outputs 102 | 103 | return ckpt_forward -------------------------------------------------------------------------------- /hyvideo/modules/norm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RMSNorm(nn.Module): 6 | def __init__( 7 | self, 8 | dim: int, 9 | elementwise_affine=True, 10 | eps: float = 1e-6, 11 | device=None, 12 | dtype=None, 13 | ): 14 | """ 15 | Initialize the RMSNorm normalization layer. 16 | 17 | Args: 18 | dim (int): The dimension of the input tensor. 19 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 20 | 21 | Attributes: 22 | eps (float): A small value added to the denominator for numerical stability. 23 | weight (nn.Parameter): Learnable scaling parameter. 24 | 25 | """ 26 | factory_kwargs = {"device": device, "dtype": dtype} 27 | super().__init__() 28 | self.eps = eps 29 | if elementwise_affine: 30 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) 31 | 32 | def _norm(self, x): 33 | """ 34 | Apply the RMSNorm normalization to the input tensor. 35 | 36 | Args: 37 | x (torch.Tensor): The input tensor. 38 | 39 | Returns: 40 | torch.Tensor: The normalized tensor. 41 | 42 | """ 43 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 44 | 45 | def forward(self, x): 46 | """ 47 | Forward pass through the RMSNorm layer. 48 | 49 | Args: 50 | x (torch.Tensor): The input tensor. 51 | 52 | Returns: 53 | torch.Tensor: The output tensor after applying RMSNorm. 54 | 55 | """ 56 | output = self._norm(x.float()).type_as(x) 57 | if hasattr(self, "weight"): 58 | output = output * self.weight 59 | return output 60 | 61 | 62 | def get_norm_layer(norm_layer): 63 | """ 64 | Get the normalization layer. 65 | 66 | Args: 67 | norm_layer (str): The type of normalization layer. 68 | 69 | Returns: 70 | norm_layer (nn.Module): The normalization layer. 71 | """ 72 | if norm_layer == "layer": 73 | return nn.LayerNorm 74 | elif norm_layer == "rms": 75 | return RMSNorm 76 | else: 77 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") 78 | -------------------------------------------------------------------------------- /hyvideo/modules/posemb_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Tuple, List 3 | 4 | 5 | def _to_tuple(x, dim=2): 6 | if isinstance(x, int): 7 | return (x,) * dim 8 | elif len(x) == dim: 9 | return x 10 | else: 11 | raise ValueError(f"Expected length {dim} or int, but got {x}") 12 | 13 | 14 | def get_meshgrid_nd(start, *args, dim=2): 15 | """ 16 | Get n-D meshgrid with start, stop and num. 
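    Each axis is sampled with a PyTorch equivalent of np.linspace(start[i], stop[i], num[i], endpoint=False),
    and the per-axis grids are stacked into a single tensor of shape [dim, *num].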
17 | 18 | Args: 19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, 20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num 21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in 22 | n-tuples. 23 | *args: See above. 24 | dim (int): Dimension of the meshgrid. Defaults to 2. 25 | 26 | Returns: 27 | grid (np.ndarray): [dim, ...] 28 | """ 29 | if len(args) == 0: 30 | # start is grid_size 31 | num = _to_tuple(start, dim=dim) 32 | start = (0,) * dim 33 | stop = num 34 | elif len(args) == 1: 35 | # start is start, args[0] is stop, step is 1 36 | start = _to_tuple(start, dim=dim) 37 | stop = _to_tuple(args[0], dim=dim) 38 | num = [stop[i] - start[i] for i in range(dim)] 39 | elif len(args) == 2: 40 | # start is start, args[0] is stop, args[1] is num 41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 44 | else: 45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") 46 | 47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) 48 | axis_grid = [] 49 | for i in range(dim): 50 | a, b, n = start[i], stop[i], num[i] 51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] 52 | axis_grid.append(g) 53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] 54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D] 55 | 56 | return grid 57 | 58 | 59 | ################################################################################# 60 | # Rotary Positional Embedding Functions # 61 | ################################################################################# 62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 63 | 64 | 65 | def apply_rotary(x, cos, sin): 66 | x_reshaped = x.view(*x.shape[:-1], -1, 2) 67 | x1, x2 = x_reshaped.unbind(-1) 68 | x_rotated = torch.stack([-x2, x1], dim=-1).flatten(3) 69 | return (x * cos) + (x_rotated * sin) 70 | 71 | def apply_rotary_emb( 72 | xq: torch.Tensor, 73 | xk: torch.Tensor, 74 | freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], 75 | upcast: bool = False, 76 | ) -> Tuple[torch.Tensor, torch.Tensor]: 77 | """ 78 | Apply rotary embeddings to input tensors using the given frequency tensor. 79 | 80 | This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided 81 | frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor 82 | is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are 83 | returned as real tensors. 84 | 85 | Args: 86 | xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] 87 | xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] 88 | freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. 89 | 90 | Returns: 91 | Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
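    Example (illustrative shapes; a real-valued (cos, sin) pair with sin = 0 is the identity rotation):
        >>> B, S, H, D = 1, 4, 2, 8
        >>> xq, xk = torch.randn(B, S, H, D), torch.randn(B, S, H, D)
        >>> cos, sin = torch.ones(S, D), torch.zeros(S, D)
        >>> out_q, out_k = apply_rotary_emb(xq, xk, (cos, sin))
        >>> torch.allclose(out_q, xq) and torch.allclose(out_k, xk)
        True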
92 | 93 | """ 94 | 95 | shape = [d if i == 1 or i == xq.ndim - 1 else 1 for i, d in enumerate(xq.shape)] 96 | cos, sin = freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) 97 | 98 | if upcast: 99 | xq_out = apply_rotary(xq.float(), cos, sin).to(xq.dtype) 100 | xk_out = apply_rotary(xk.float(), cos, sin).to(xk.dtype) 101 | else: 102 | xq_out = apply_rotary(xq, cos, sin) 103 | xk_out = apply_rotary(xk, cos, sin) 104 | 105 | return xq_out, xk_out 106 | 107 | 108 | def get_nd_rotary_pos_embed( 109 | rope_dim_list, 110 | start, 111 | *args, 112 | theta=10000.0, 113 | use_real=False, 114 | theta_rescale_factor: Union[float, List[float]] = 1.0, 115 | interpolation_factor: Union[float, List[float]] = 1.0, 116 | num_frames: int = 129, 117 | k: int = 0, 118 | ): 119 | """ 120 | This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. 121 | 122 | Args: 123 | rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. 124 | sum(rope_dim_list) should equal to head_dim of attention layer. 125 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, 126 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. 127 | *args: See above. 128 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0. 129 | use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. 130 | Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real 131 | part and an imaginary part separately. 132 | theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. 133 | 134 | Returns: 135 | pos_embed (torch.Tensor): [HW, D/2] 136 | """ 137 | 138 | grid = get_meshgrid_nd( 139 | start, *args, dim=len(rope_dim_list) 140 | ) # [3, W, H, D] / [2, W, H] 141 | 142 | if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): 143 | theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) 144 | elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: 145 | theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) 146 | assert len(theta_rescale_factor) == len( 147 | rope_dim_list 148 | ), "len(theta_rescale_factor) should equal to len(rope_dim_list)" 149 | 150 | if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): 151 | interpolation_factor = [interpolation_factor] * len(rope_dim_list) 152 | elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: 153 | interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) 154 | assert len(interpolation_factor) == len( 155 | rope_dim_list 156 | ), "len(interpolation_factor) should equal to len(rope_dim_list)" 157 | 158 | # use 1/ndim of dimensions to encode grid_axis 159 | embs = [] 160 | for i in range(len(rope_dim_list)): 161 | if i == 0: 162 | emb = get_1d_rotary_pos_embed_riflex( 163 | rope_dim_list[i], 164 | grid[i].reshape(-1), 165 | theta, 166 | use_real=use_real, 167 | theta_rescale_factor=theta_rescale_factor[i], 168 | interpolation_factor=interpolation_factor[i], 169 | L_test=num_frames, 170 | k=k, 171 | ) # 2 x [WHD, rope_dim_list[i]] 172 | else: 173 | emb = get_1d_rotary_pos_embed( 174 | rope_dim_list[i], 175 | grid[i].reshape(-1), 176 | theta, 177 | use_real=use_real, 178 | theta_rescale_factor=theta_rescale_factor[i], 179 | interpolation_factor=interpolation_factor[i], 180 
| ) 181 | embs.append(emb) 182 | 183 | if use_real: 184 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) 185 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) 186 | return cos, sin 187 | else: 188 | emb = torch.cat(embs, dim=1) # (WHD, D/2) 189 | return emb 190 | 191 | 192 | def get_1d_rotary_pos_embed( 193 | dim: int, 194 | pos: Union[torch.FloatTensor, int], 195 | theta: float = 10000.0, 196 | use_real: bool = False, 197 | theta_rescale_factor: float = 1.0, 198 | interpolation_factor: float = 1.0, 199 | L_test: int = 100, 200 | k: int = 0, 201 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 202 | """ 203 | Precompute the frequency tensor for complex exponential (cis) with given dimensions. 204 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) 205 | 206 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim' 207 | and the end index 'end'. The 'theta' parameter scales the frequencies. 208 | The returned tensor contains complex values in complex64 data type. 209 | 210 | Args: 211 | dim (int): Dimension of the frequency tensor. 212 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar 213 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 214 | use_real (bool, optional): If True, return real part and imaginary part separately. 215 | Otherwise, return complex numbers. 216 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. 217 | 218 | Returns: 219 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] 220 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D] 221 | """ 222 | if isinstance(pos, int): 223 | pos = torch.arange(pos).float() 224 | 225 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 226 | # has some connection to NTK literature 227 | if theta_rescale_factor != 1.0: 228 | theta *= theta_rescale_factor ** (dim / (dim - 2)) 229 | 230 | freqs = 1.0 / ( 231 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) 232 | ) # [D/2] 233 | # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" 234 | 235 | freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] 236 | if use_real: 237 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] 238 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] 239 | return freqs_cos, freqs_sin 240 | else: 241 | freqs_cis = torch.polar( 242 | torch.ones_like(freqs), freqs 243 | ) # complex64 # [S, D/2] 244 | return freqs_cis 245 | 246 | def get_1d_rotary_pos_embed_riflex( 247 | dim: int, 248 | pos: Union[torch.FloatTensor, int], 249 | theta: float = 10000.0, 250 | use_real: bool = False, 251 | theta_rescale_factor: float = 1.0, 252 | interpolation_factor: float = 1.0, 253 | L_test: int = 66, 254 | k: int = 0, 255 | N_k: int=50 256 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 257 | """ 258 | Precompute the frequency tensor for complex exponential (cis) with given dimensions. 259 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) 260 | 261 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim' 262 | and the end index 'end'. The 'theta' parameter scales the frequencies. 263 | The returned tensor contains complex values in complex64 data type. 
264 | 265 | Args: 266 | dim (int): Dimension of the frequency tensor. 267 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar 268 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 269 | use_real (bool, optional): If True, return real part and imaginary part separately. 270 | Otherwise, return complex numbers. 271 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. 272 | 273 | Returns: 274 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] 275 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D] 276 | """ 277 | if isinstance(pos, int): 278 | pos = torch.arange(pos).float() 279 | 280 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 281 | # has some connection to NTK literature 282 | if theta_rescale_factor != 1.0: 283 | theta *= theta_rescale_factor ** (dim / (dim - 2)) 284 | 285 | freqs = 1.0 / ( 286 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) 287 | ) # [D/2] 288 | # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" 289 | 290 | #RIFLEx https://github.com/thu-ml/RIFLEx 291 | if k > 0 and L_test > N_k: 292 | freqs[k-1] = 0.9 * 2 * torch.pi / L_test 293 | 294 | 295 | freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] 296 | if use_real: 297 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] 298 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] 299 | return freqs_cos, freqs_sin 300 | else: 301 | freqs_cis = torch.polar( 302 | torch.ones_like(freqs), freqs 303 | ) # complex64 # [S, D/2] 304 | return freqs_cis 305 | 306 | def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000., use_real=False, 307 | theta_rescale_factor: Union[float, List[float]]=1.0, 308 | interpolation_factor: Union[float, List[float]]=1.0, 309 | concat_dict = {'mode': 'timecat-w', 'bias': -1}, num_frames: int = 129, k: int = 0, 310 | ): 311 | 312 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] 313 | if len(concat_dict)<1: 314 | pass 315 | else: 316 | if concat_dict['mode']=='timecat': 317 | bias = grid[:,:1].clone() 318 | bias[0] = concat_dict['bias']*torch.ones_like(bias[0]) 319 | grid = torch.cat([bias, grid], dim=1) 320 | 321 | elif concat_dict['mode']=='timecat-w': 322 | bias = grid[:,:1].clone() 323 | bias[0] = concat_dict['bias']*torch.ones_like(bias[0]) 324 | bias[2] += start[-1] ## ref https://github.com/Yuanshi9815/OminiControl/blob/main/src/generate.py#L178 325 | grid = torch.cat([bias, grid], dim=1) 326 | if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): 327 | theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) 328 | elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: 329 | theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) 330 | assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)" 331 | 332 | if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): 333 | interpolation_factor = [interpolation_factor] * len(rope_dim_list) 334 | elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: 335 | interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) 336 | assert len(interpolation_factor) == 
len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)" 337 | 338 | # use 1/ndim of dimensions to encode grid_axis 339 | embs = [] 340 | for i in range(len(rope_dim_list)): 341 | emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real, 342 | theta_rescale_factor=theta_rescale_factor[i], 343 | interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]] 344 | 345 | embs.append(emb) 346 | 347 | if use_real: 348 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) 349 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) 350 | return cos, sin 351 | else: 352 | emb = torch.cat(embs, dim=1) # (WHD, D/2) 353 | return emb -------------------------------------------------------------------------------- /hyvideo/modules/token_refiner.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from einops import rearrange 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .activation_layers import get_activation_layer 8 | from .attention import attention 9 | from .norm_layers import get_norm_layer 10 | from .embed_layers import TimestepEmbedder, TextProjection 11 | from .attention import attention 12 | from .mlp_layers import MLP 13 | from .modulate_layers import modulate, apply_gate 14 | 15 | 16 | class IndividualTokenRefinerBlock(nn.Module): 17 | def __init__( 18 | self, 19 | hidden_size, 20 | heads_num, 21 | mlp_width_ratio: str = 4.0, 22 | mlp_drop_rate: float = 0.0, 23 | act_type: str = "silu", 24 | qk_norm: bool = False, 25 | qk_norm_type: str = "layer", 26 | qkv_bias: bool = True, 27 | dtype: Optional[torch.dtype] = None, 28 | device: Optional[torch.device] = None, 29 | ): 30 | factory_kwargs = {"device": device, "dtype": dtype} 31 | super().__init__() 32 | self.heads_num = heads_num 33 | head_dim = hidden_size // heads_num 34 | mlp_hidden_dim = int(hidden_size * mlp_width_ratio) 35 | 36 | self.norm1 = nn.LayerNorm( 37 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs 38 | ) 39 | self.self_attn_qkv = nn.Linear( 40 | hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs 41 | ) 42 | qk_norm_layer = get_norm_layer(qk_norm_type) 43 | self.self_attn_q_norm = ( 44 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 45 | if qk_norm 46 | else nn.Identity() 47 | ) 48 | self.self_attn_k_norm = ( 49 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 50 | if qk_norm 51 | else nn.Identity() 52 | ) 53 | self.self_attn_proj = nn.Linear( 54 | hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs 55 | ) 56 | 57 | self.norm2 = nn.LayerNorm( 58 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs 59 | ) 60 | act_layer = get_activation_layer(act_type) 61 | self.mlp = MLP( 62 | in_channels=hidden_size, 63 | hidden_channels=mlp_hidden_dim, 64 | act_layer=act_layer, 65 | drop=mlp_drop_rate, 66 | **factory_kwargs, 67 | ) 68 | 69 | self.adaLN_modulation = nn.Sequential( 70 | act_layer(), 71 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), 72 | ) 73 | # Zero-initialize the modulation 74 | nn.init.zeros_(self.adaLN_modulation[1].weight) 75 | nn.init.zeros_(self.adaLN_modulation[1].bias) 76 | 77 | def forward( 78 | self, 79 | x: torch.Tensor, 80 | c: torch.Tensor, # timestep_aware_representations + context_aware_representations 81 | attn_mask: torch.Tensor = None, 82 | ): 83 | gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, 
dim=1) 84 | 85 | norm_x = self.norm1(x) 86 | qkv = self.self_attn_qkv(norm_x) 87 | q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) 88 | # Apply QK-Norm if needed 89 | q = self.self_attn_q_norm(q).to(v) 90 | k = self.self_attn_k_norm(k).to(v) 91 | 92 | # Self-Attention 93 | attn = attention(q, k, v, heads = self.heads_num,mode="sdpa", attn_mask=attn_mask) 94 | 95 | x = x + apply_gate(self.self_attn_proj(attn), gate_msa) 96 | 97 | # FFN Layer 98 | x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) 99 | 100 | return x 101 | 102 | 103 | class IndividualTokenRefiner(nn.Module): 104 | def __init__( 105 | self, 106 | hidden_size, 107 | heads_num, 108 | depth, 109 | mlp_width_ratio: float = 4.0, 110 | mlp_drop_rate: float = 0.0, 111 | act_type: str = "silu", 112 | qk_norm: bool = False, 113 | qk_norm_type: str = "layer", 114 | qkv_bias: bool = True, 115 | dtype: Optional[torch.dtype] = None, 116 | device: Optional[torch.device] = None, 117 | ): 118 | factory_kwargs = {"device": device, "dtype": dtype} 119 | super().__init__() 120 | self.blocks = nn.ModuleList( 121 | [ 122 | IndividualTokenRefinerBlock( 123 | hidden_size=hidden_size, 124 | heads_num=heads_num, 125 | mlp_width_ratio=mlp_width_ratio, 126 | mlp_drop_rate=mlp_drop_rate, 127 | act_type=act_type, 128 | qk_norm=qk_norm, 129 | qk_norm_type=qk_norm_type, 130 | qkv_bias=qkv_bias, 131 | **factory_kwargs, 132 | ) 133 | for _ in range(depth) 134 | ] 135 | ) 136 | 137 | def forward( 138 | self, 139 | x: torch.Tensor, 140 | c: torch.LongTensor, 141 | mask: Optional[torch.Tensor] = None, 142 | ): 143 | self_attn_mask = None 144 | if mask is not None: 145 | batch_size = mask.shape[0] 146 | seq_len = mask.shape[1] 147 | mask = mask.to(x.device) 148 | # batch_size x 1 x seq_len x seq_len 149 | self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat( 150 | 1, 1, seq_len, 1 151 | ) 152 | # batch_size x 1 x seq_len x seq_len 153 | self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) 154 | # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num 155 | self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() 156 | # avoids self-attention weight being NaN for padding tokens 157 | self_attn_mask[:, :, :, 0] = True 158 | 159 | for block in self.blocks: 160 | x = block(x, c, self_attn_mask) 161 | return x 162 | 163 | 164 | class SingleTokenRefiner(nn.Module): 165 | """ 166 | A single token refiner block for llm text embedding refine. 167 | """ 168 | def __init__( 169 | self, 170 | in_channels, 171 | hidden_size, 172 | heads_num, 173 | depth, 174 | mlp_width_ratio: float = 4.0, 175 | mlp_drop_rate: float = 0.0, 176 | act_type: str = "silu", 177 | qk_norm: bool = False, 178 | qk_norm_type: str = "layer", 179 | qkv_bias: bool = True, 180 | attn_mode: str = "sdpa", 181 | dtype: Optional[torch.dtype] = None, 182 | device: Optional[torch.device] = None, 183 | ): 184 | factory_kwargs = {"device": device, "dtype": dtype} 185 | super().__init__() 186 | self.attn_mode = attn_mode 187 | assert self.attn_mode == "sdpa", "Only support 'torch sdpa' mode for token refiner." 
188 | 189 | self.input_embedder = nn.Linear( 190 | in_channels, hidden_size, bias=True, **factory_kwargs 191 | ) 192 | 193 | act_layer = get_activation_layer(act_type) 194 | # Build timestep embedding layer 195 | self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) 196 | # Build context embedding layer 197 | self.c_embedder = TextProjection( 198 | in_channels, hidden_size, act_layer, **factory_kwargs 199 | ) 200 | 201 | self.individual_token_refiner = IndividualTokenRefiner( 202 | hidden_size=hidden_size, 203 | heads_num=heads_num, 204 | depth=depth, 205 | mlp_width_ratio=mlp_width_ratio, 206 | mlp_drop_rate=mlp_drop_rate, 207 | act_type=act_type, 208 | qk_norm=qk_norm, 209 | qk_norm_type=qk_norm_type, 210 | qkv_bias=qkv_bias, 211 | **factory_kwargs, 212 | ) 213 | 214 | def forward( 215 | self, 216 | x: torch.Tensor, 217 | t: torch.LongTensor, 218 | mask: Optional[torch.LongTensor] = None, 219 | ): 220 | in_dtype = x.dtype 221 | timestep_aware_representations = self.t_embedder(t) 222 | 223 | if mask is None: 224 | context_aware_representations = x.mean(dim=1) 225 | else: 226 | mask_float = mask.float().unsqueeze(-1) # [b, s1, 1] 227 | context_aware_representations = (x * mask_float).sum( 228 | dim=1 229 | ) / mask_float.sum(dim=1) 230 | context_aware_representations = self.c_embedder(context_aware_representations.to(in_dtype)) 231 | c = timestep_aware_representations + context_aware_representations 232 | 233 | x = self.input_embedder(x) 234 | 235 | x = self.individual_token_refiner(x, c, mask) 236 | 237 | return x 238 | -------------------------------------------------------------------------------- /hyvideo/prompt_rewrite.py: -------------------------------------------------------------------------------- 1 | normal_mode_prompt = """Normal mode - Video Recaption Task: 2 | 3 | You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description. 4 | 5 | 0. Preserve ALL information, including style words and technical terms. 6 | 7 | 1. If the input is in Chinese, translate the entire description to English. 8 | 9 | 2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences. 10 | 11 | 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations. 12 | 13 | 4. Output ALL must be in English. 14 | 15 | Given Input: 16 | input: "{input}" 17 | """ 18 | 19 | 20 | master_mode_prompt = """Master mode - Video Recaption Task: 21 | 22 | You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description. 23 | 24 | 0. Preserve ALL information, including style words and technical terms. 25 | 26 | 1. If the input is in Chinese, translate the entire description to English. 27 | 28 | 2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences. 29 | 30 | 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations. 31 | 32 | 4. Output ALL must be in English. 
33 | 34 | Given Input: 35 | input: "{input}" 36 | """ 37 | 38 | def get_rewrite_prompt(ori_prompt, mode="Normal"): 39 | if mode == "Normal": 40 | prompt = normal_mode_prompt.format(input=ori_prompt) 41 | elif mode == "Master": 42 | prompt = master_mode_prompt.format(input=ori_prompt) 43 | else: 44 | raise Exception("Only supports Normal and Master modes", mode) 45 | return prompt 46 | 47 | ori_prompt = "一只小狗在草地上奔跑。"  # "A puppy is running on the grass." 48 | normal_prompt = get_rewrite_prompt(ori_prompt, mode="Normal") 49 | master_prompt = get_rewrite_prompt(ori_prompt, mode="Master") 50 | 51 | # Then you can use the normal_prompt or master_prompt to access the hunyuan-large rewrite model to get the final prompt. -------------------------------------------------------------------------------- /hyvideo/text_encoder/configuration_llava.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Llava model configuration""" 15 | 16 | from transformers.configuration_utils import PretrainedConfig 17 | from transformers.utils import logging 18 | from transformers.models.auto import CONFIG_MAPPING, AutoConfig 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | 24 | class LlavaConfig(PretrainedConfig): 25 | r""" 26 | This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate a 27 | Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration 28 | with the defaults will yield a similar configuration to that of the Llava-9B. 29 | 30 | e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b) 31 | 32 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 33 | documentation from [`PretrainedConfig`] for more information. 34 | 35 | Args: 36 | vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): 37 | The config object or dictionary of the vision backbone. 38 | text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): 39 | The config object or dictionary of the text backbone. 40 | ignore_index (`int`, *optional*, defaults to -100): 41 | The ignore index for the loss function. 42 | image_token_index (`int`, *optional*, defaults to 32000): 43 | The image token index to encode the image prompt. 44 | projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): 45 | The activation function used by the multimodal projector. 46 | vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): 47 | The feature selection strategy used to select the vision feature from the vision backbone. 48 | Can be one of `"default"` or `"full"`.
49 | vision_feature_layer (`int`, *optional*, defaults to -2): 50 | The index of the layer to select the vision feature. 51 | image_seq_length (`int`, *optional*, defaults to 576): 52 | Sequence length of one image embedding. 53 | 54 | Example: 55 | 56 | ```python 57 | >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig 58 | 59 | >>> # Initializing a CLIP-vision config 60 | >>> vision_config = CLIPVisionConfig() 61 | 62 | >>> # Initializing a Llama config 63 | >>> text_config = LlamaConfig() 64 | 65 | >>> # Initializing a Llava llava-1.5-7b style configuration 66 | >>> configuration = LlavaConfig(vision_config, text_config) 67 | 68 | >>> # Initializing a model from the llava-1.5-7b style configuration 69 | >>> model = LlavaForConditionalGeneration(configuration) 70 | 71 | >>> # Accessing the model configuration 72 | >>> configuration = model.config 73 | ```""" 74 | 75 | model_type = "llava" 76 | sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} 77 | 78 | def __init__( 79 | self, 80 | vision_config=None, 81 | text_config=None, 82 | ignore_index=-100, 83 | image_token_index=32000, 84 | projector_hidden_act="gelu", 85 | vision_feature_select_strategy="default", 86 | vision_feature_layer=-2, 87 | image_seq_length=576, 88 | **kwargs, 89 | ): 90 | self.ignore_index = ignore_index 91 | self.image_token_index = image_token_index 92 | self.projector_hidden_act = projector_hidden_act 93 | self.image_seq_length = image_seq_length 94 | 95 | if vision_feature_select_strategy not in ["default", "full"]: 96 | raise ValueError( 97 | "vision_feature_select_strategy should be one of 'default', 'full'." 98 | f"Got: {vision_feature_select_strategy}" 99 | ) 100 | 101 | self.vision_feature_select_strategy = vision_feature_select_strategy 102 | self.vision_feature_layer = vision_feature_layer 103 | 104 | if isinstance(vision_config, dict): 105 | vision_config["model_type"] = ( 106 | vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" 107 | ) 108 | vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) 109 | elif vision_config is None: 110 | vision_config = CONFIG_MAPPING["clip_vision_model"]( 111 | intermediate_size=4096, 112 | hidden_size=1024, 113 | patch_size=14, 114 | image_size=336, 115 | num_hidden_layers=24, 116 | num_attention_heads=16, 117 | vocab_size=32000, 118 | projection_dim=768, 119 | ) 120 | 121 | self.vision_config = vision_config 122 | 123 | if isinstance(text_config, dict): 124 | text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" 125 | text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) 126 | elif text_config is None: 127 | text_config = CONFIG_MAPPING["llama"]() 128 | 129 | self.text_config = text_config 130 | 131 | super().__init__(**kwargs) 132 | -------------------------------------------------------------------------------- /hyvideo/text_encoder/processing_llava.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Processor class for Llava. 17 | """ 18 | 19 | from typing import List, Union 20 | 21 | from transformers.feature_extraction_utils import BatchFeature 22 | from transformers.image_utils import ImageInput, get_image_size, to_numpy_array 23 | from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order 24 | from transformers.tokenization_utils_base import PreTokenizedInput, TextInput 25 | from transformers.utils import logging 26 | 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | 31 | class LlavaProcessorKwargs(ProcessingKwargs, total=False): 32 | _defaults = { 33 | "text_kwargs": { 34 | "padding": False, 35 | }, 36 | "images_kwargs": {}, 37 | } 38 | 39 | 40 | class LlavaProcessor(ProcessorMixin): 41 | r""" 42 | Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor. 43 | 44 | [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the 45 | [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information. 46 | 47 | Args: 48 | image_processor ([`CLIPImageProcessor`], *optional*): 49 | The image processor is a required input. 50 | tokenizer ([`LlamaTokenizerFast`], *optional*): 51 | The tokenizer is a required input. 52 | patch_size (`int`, *optional*): 53 | Patch size from the vision tower. 54 | vision_feature_select_strategy (`str`, *optional*): 55 | The feature selection strategy used to select the vision feature from the vision backbone. 56 | Shoudl be same as in model's config 57 | chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages 58 | in a chat into a tokenizable string. 59 | image_token (`str`, *optional*, defaults to `""`): 60 | Special token used to denote image location. 61 | num_additional_image_tokens (`int`, *optional*, defaults to 0): 62 | Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other 63 | extra tokens appended, no need to set this arg. 
64 | """ 65 | 66 | attributes = ["image_processor", "tokenizer"] 67 | valid_kwargs = [ 68 | "chat_template", 69 | "patch_size", 70 | "vision_feature_select_strategy", 71 | "image_token", 72 | "num_additional_image_tokens", 73 | ] 74 | image_processor_class = "AutoImageProcessor" 75 | tokenizer_class = "AutoTokenizer" 76 | 77 | def __init__( 78 | self, 79 | image_processor=None, 80 | tokenizer=None, 81 | patch_size=None, 82 | vision_feature_select_strategy=None, 83 | chat_template=None, 84 | image_token="", # set the default and let users change if they have peculiar special tokens in rare cases 85 | num_additional_image_tokens=0, 86 | **kwargs, 87 | ): 88 | self.patch_size = patch_size 89 | self.num_additional_image_tokens = num_additional_image_tokens 90 | self.vision_feature_select_strategy = vision_feature_select_strategy 91 | self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token 92 | super().__init__(image_processor, tokenizer, chat_template=chat_template) 93 | 94 | def __call__( 95 | self, 96 | images: ImageInput = None, 97 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, 98 | audio=None, 99 | videos=None, 100 | **kwargs: Unpack[LlavaProcessorKwargs], 101 | ) -> BatchFeature: 102 | """ 103 | Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` 104 | and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode 105 | the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to 106 | CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring 107 | of the above two methods for more information. 108 | 109 | Args: 110 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): 111 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch 112 | tensor. Both channels-first and channels-last formats are supported. 113 | text (`str`, `List[str]`, `List[List[str]]`): 114 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings 115 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set 116 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 117 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 118 | If set, will return tensors of a particular framework. Acceptable values are: 119 | - `'tf'`: Return TensorFlow `tf.constant` objects. 120 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 121 | - `'np'`: Return NumPy `np.ndarray` objects. 122 | - `'jax'`: Return JAX `jnp.ndarray` objects. 123 | 124 | Returns: 125 | [`BatchFeature`]: A [`BatchFeature`] with the following fields: 126 | 127 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. 128 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when 129 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not 130 | `None`). 131 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
132 | """ 133 | if images is None and text is None: 134 | raise ValueError("You have to specify at least one of `images` or `text`.") 135 | 136 | # check if images and text inputs are reversed for BC 137 | images, text = _validate_images_text_input_order(images, text) 138 | 139 | output_kwargs = self._merge_kwargs( 140 | LlavaProcessorKwargs, 141 | tokenizer_init_kwargs=self.tokenizer.init_kwargs, 142 | **kwargs, 143 | ) 144 | if images is not None: 145 | image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) 146 | else: 147 | image_inputs = {} 148 | 149 | if isinstance(text, str): 150 | text = [text] 151 | elif not isinstance(text, list) and not isinstance(text[0], str): 152 | raise ValueError("Invalid input text. Please provide a string, or a list of strings") 153 | 154 | # try to expand inputs in processing if we have the necessary parts 155 | prompt_strings = text 156 | if image_inputs.get("pixel_values") is not None: 157 | if self.patch_size is not None and self.vision_feature_select_strategy is not None: 158 | # Replace the image token with the expanded image token sequence 159 | pixel_values = image_inputs["pixel_values"] 160 | height, width = get_image_size(to_numpy_array(pixel_values[0])) 161 | num_image_tokens = (height // self.patch_size) * ( 162 | width // self.patch_size 163 | ) + self.num_additional_image_tokens 164 | if self.vision_feature_select_strategy == "default": 165 | num_image_tokens -= self.num_additional_image_tokens 166 | 167 | prompt_strings = [] 168 | for sample in text: 169 | sample = sample.replace(self.image_token, self.image_token * num_image_tokens) 170 | prompt_strings.append(sample) 171 | else: 172 | logger.warning_once( 173 | "Expanding inputs for image tokens in LLaVa should be done in processing. " 174 | "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " 175 | "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " 176 | "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 177 | ) 178 | 179 | text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) 180 | return BatchFeature(data={**text_inputs, **image_inputs}) 181 | 182 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama 183 | def batch_decode(self, *args, **kwargs): 184 | """ 185 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please 186 | refer to the docstring of this method for more information. 187 | """ 188 | return self.tokenizer.batch_decode(*args, **kwargs) 189 | 190 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama 191 | def decode(self, *args, **kwargs): 192 | """ 193 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to 194 | the docstring of this method for more information. 
195 | """ 196 | return self.tokenizer.decode(*args, **kwargs) 197 | 198 | @property 199 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names 200 | def model_input_names(self): 201 | tokenizer_input_names = self.tokenizer.model_input_names 202 | image_processor_input_names = self.image_processor.model_input_names 203 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) 204 | -------------------------------------------------------------------------------- /hyvideo/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-HunyuanVideoWrapper/3ce9640497139b14462910b0dbf2be1df855c1d6/hyvideo/utils/__init__.py -------------------------------------------------------------------------------- /hyvideo/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | 5 | def align_to(value, alignment): 6 | """align hight, width according to alignment 7 | 8 | Args: 9 | value (int): height or width 10 | alignment (int): target alignment factor 11 | 12 | Returns: 13 | int: the aligned value 14 | """ 15 | return int(math.ceil(value / alignment) * alignment) 16 | -------------------------------------------------------------------------------- /hyvideo/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from einops import rearrange 4 | 5 | import torch 6 | import torchvision 7 | import numpy as np 8 | import imageio 9 | 10 | CODE_SUFFIXES = { 11 | ".py", # Python codes 12 | ".sh", # Shell scripts 13 | ".yaml", 14 | ".yml", # Configuration files 15 | } 16 | 17 | 18 | def safe_dir(path): 19 | """ 20 | Create a directory (or the parent directory of a file) if it does not exist. 21 | 22 | Args: 23 | path (str or Path): Path to the directory. 24 | 25 | Returns: 26 | path (Path): Path object of the directory. 27 | """ 28 | path = Path(path) 29 | path.mkdir(exist_ok=True, parents=True) 30 | return path 31 | 32 | 33 | def safe_file(path): 34 | """ 35 | Create the parent directory of a file if it does not exist. 36 | 37 | Args: 38 | path (str or Path): Path to the file. 39 | 40 | Returns: 41 | path (Path): Path object of the file. 42 | """ 43 | path = Path(path) 44 | path.parent.mkdir(exist_ok=True, parents=True) 45 | return path 46 | 47 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24): 48 | """save videos by video tensor 49 | copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61 50 | 51 | Args: 52 | videos (torch.Tensor): video tensor predicted by the model 53 | path (str): path to save video 54 | rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False. 55 | n_rows (int, optional): Defaults to 1. 56 | fps (int, optional): video save fps. Defaults to 8. 
57 | """ 58 | videos = rearrange(videos, "b c t h w -> t b c h w") 59 | outputs = [] 60 | for x in videos: 61 | x = torchvision.utils.make_grid(x, nrow=n_rows) 62 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 63 | if rescale: 64 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 65 | x = torch.clamp(x, 0, 1) 66 | x = (x * 255).numpy().astype(np.uint8) 67 | outputs.append(x) 68 | 69 | os.makedirs(os.path.dirname(path), exist_ok=True) 70 | imageio.mimsave(path, outputs, fps=fps) 71 | -------------------------------------------------------------------------------- /hyvideo/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | 3 | from itertools import repeat 4 | 5 | 6 | def _ntuple(n): 7 | def parse(x): 8 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 9 | x = tuple(x) 10 | if len(x) == 1: 11 | x = tuple(repeat(x[0], n)) 12 | return x 13 | return tuple(repeat(x, n)) 14 | return parse 15 | 16 | 17 | to_1tuple = _ntuple(1) 18 | to_2tuple = _ntuple(2) 19 | to_3tuple = _ntuple(3) 20 | to_4tuple = _ntuple(4) 21 | 22 | 23 | def as_tuple(x): 24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 25 | return tuple(x) 26 | if x is None or isinstance(x, (int, float, str)): 27 | return (x,) 28 | else: 29 | raise ValueError(f"Unknown type {type(x)}") 30 | 31 | 32 | def as_list_of_2tuple(x): 33 | x = as_tuple(x) 34 | if len(x) == 1: 35 | x = (x[0], x[0]) 36 | assert len(x) % 2 == 0, f"Expect even length, got {len(x)}." 37 | lst = [] 38 | for i in range(0, len(x), 2): 39 | lst.append((x[i], x[i + 1])) 40 | return lst 41 | -------------------------------------------------------------------------------- /hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from transformers import ( 4 | AutoProcessor, 5 | LlavaForConditionalGeneration, 6 | ) 7 | 8 | 9 | def preprocess_text_encoder_tokenizer(args): 10 | 11 | processor = AutoProcessor.from_pretrained(args.input_dir) 12 | model = LlavaForConditionalGeneration.from_pretrained( 13 | args.input_dir, 14 | torch_dtype=torch.float16, 15 | low_cpu_mem_usage=True, 16 | ).to(0) 17 | 18 | model.language_model.save_pretrained( 19 | f"{args.output_dir}" 20 | ) 21 | processor.tokenizer.save_pretrained( 22 | f"{args.output_dir}" 23 | ) 24 | 25 | if __name__ == "__main__": 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | "--input_dir", 30 | type=str, 31 | required=True, 32 | help="The path to the llava-llama-3-8b-v1_1-transformers.", 33 | ) 34 | parser.add_argument( 35 | "--output_dir", 36 | type=str, 37 | default="", 38 | help="The output path of the llava-llama-3-8b-text-encoder-tokenizer." 
39 | "if '', the parent dir of output will be the same as input dir.", 40 | ) 41 | args = parser.parse_args() 42 | 43 | if len(args.output_dir) == 0: 44 | args.output_dir = "/".join(args.input_dir.split("/")[:-1]) 45 | 46 | preprocess_text_encoder_tokenizer(args) 47 | -------------------------------------------------------------------------------- /hyvideo/utils/token_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def find_subsequence(sequence, sub_sequence): 5 | 6 | assert sequence.shape[0]==1 7 | sequence = sequence[0] 8 | sub_sequence = sub_sequence[0] 9 | 10 | sub_len = len(sub_sequence) 11 | indices = [] 12 | 13 | windows = sequence.unfold(0, sub_len, 1) 14 | matches = (windows == sub_sequence).all(dim=1) 15 | indices = matches.nonzero().flatten().tolist() 16 | 17 | return indices, len(indices), sub_len 18 | 19 | import ast 20 | import torch 21 | 22 | def multi_slice_to_mask(expr, length): 23 | def process_single_slice(s): 24 | s = s.replace(':', ',').replace(' ', '') 25 | while ',,' in s: 26 | s = s.replace(',,', ',None,') 27 | if s.startswith(','): 28 | s = 'None' + s 29 | if s.endswith(','): 30 | s = s + 'None' 31 | return s 32 | 33 | try: 34 | slices = expr.split(',') 35 | mask = torch.zeros(length, dtype=torch.bool) 36 | if expr == "": 37 | return mask 38 | i = 0 39 | while i < len(slices): 40 | if ':' in slices[i]: 41 | slice_expr = process_single_slice(slices[i]) 42 | slice_args = ast.literal_eval(f"({slice_expr})") 43 | s = slice(*slice_args) 44 | mask[s] = True 45 | i += 1 46 | else: 47 | idx = ast.literal_eval(slices[i]) 48 | if idx < 0: 49 | idx = length + idx 50 | if 0 <= idx < length: 51 | mask[idx] = True 52 | i += 1 53 | 54 | return mask 55 | except Exception as e: 56 | raise ValueError(f"Invalid slice expression: {e}") 57 | -------------------------------------------------------------------------------- /hyvideo/vae/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | 5 | from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D 6 | from ..constants import VAE_PATH, PRECISION_TO_TYPE 7 | 8 | def load_vae(vae_type: str="884-16c-hy", 9 | vae_precision: str=None, 10 | sample_size: tuple=None, 11 | vae_path: str=None, 12 | logger=None, 13 | device=None 14 | ): 15 | """the fucntion to load the 3D VAE model 16 | 17 | Args: 18 | vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy". 19 | vae_precision (str, optional): the precision to load vae. Defaults to None. 20 | sample_size (tuple, optional): the tiling size. Defaults to None. 21 | vae_path (str, optional): the path to vae. Defaults to None. 22 | logger (_type_, optional): logger. Defaults to None. 23 | device (_type_, optional): device to load vae. Defaults to None. 
24 | """ 25 | if vae_path is None: 26 | vae_path = VAE_PATH[vae_type] 27 | 28 | if logger is not None: 29 | logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}") 30 | config = AutoencoderKLCausal3D.load_config(vae_path) 31 | if sample_size: 32 | vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size) 33 | else: 34 | vae = AutoencoderKLCausal3D.from_config(config) 35 | 36 | vae_ckpt = Path(vae_path) / "pytorch_model.pt" 37 | assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}" 38 | 39 | ckpt = torch.load(vae_ckpt, map_location=vae.device) 40 | if "state_dict" in ckpt: 41 | ckpt = ckpt["state_dict"] 42 | vae_ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")} 43 | vae.load_state_dict(vae_ckpt) 44 | 45 | spatial_compression_ratio = vae.config.spatial_compression_ratio 46 | time_compression_ratio = vae.config.time_compression_ratio 47 | 48 | if vae_precision is not None: 49 | vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision]) 50 | 51 | vae.requires_grad_(False) 52 | 53 | if logger is not None: 54 | logger.info(f"VAE to dtype: {vae.dtype}") 55 | 56 | if device is not None: 57 | vae = vae.to(device) 58 | 59 | vae.eval() 60 | 61 | return vae, vae_path, spatial_compression_ratio, time_compression_ratio 62 | -------------------------------------------------------------------------------- /hyvideo/vae/vae.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from diffusers.utils import BaseOutput, is_torch_version 9 | from diffusers.utils.torch_utils import randn_tensor 10 | from diffusers.models.attention_processor import SpatialNorm 11 | from .unet_causal_3d_blocks import ( 12 | CausalConv3d, 13 | UNetMidBlockCausal3D, 14 | get_down_block3d, 15 | get_up_block3d, 16 | ) 17 | 18 | 19 | @dataclass 20 | class DecoderOutput(BaseOutput): 21 | r""" 22 | Output of decoding method. 23 | 24 | Args: 25 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 26 | The decoded output sample from the last layer of the model. 27 | """ 28 | 29 | sample: torch.FloatTensor 30 | 31 | 32 | class EncoderCausal3D(nn.Module): 33 | r""" 34 | The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | in_channels: int = 3, 40 | out_channels: int = 3, 41 | down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",), 42 | block_out_channels: Tuple[int, ...] 
= (64,), 43 | layers_per_block: int = 2, 44 | norm_num_groups: int = 32, 45 | act_fn: str = "silu", 46 | double_z: bool = True, 47 | mid_block_add_attention=True, 48 | time_compression_ratio: int = 4, 49 | spatial_compression_ratio: int = 8, 50 | ): 51 | super().__init__() 52 | self.layers_per_block = layers_per_block 53 | 54 | self.conv_in = CausalConv3d( 55 | in_channels, block_out_channels[0], kernel_size=3, stride=1) 56 | self.mid_block = None 57 | self.down_blocks = nn.ModuleList([]) 58 | 59 | # down 60 | output_channel = block_out_channels[0] 61 | for i, down_block_type in enumerate(down_block_types): 62 | input_channel = output_channel 63 | output_channel = block_out_channels[i] 64 | is_final_block = i == len(block_out_channels) - 1 65 | num_spatial_downsample_layers = int( 66 | np.log2(spatial_compression_ratio)) 67 | num_time_downsample_layers = int(np.log2(time_compression_ratio)) 68 | 69 | if time_compression_ratio == 4: 70 | add_spatial_downsample = bool( 71 | i < num_spatial_downsample_layers) 72 | add_time_downsample = bool(i >= ( 73 | len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block) 74 | elif time_compression_ratio == 8: 75 | add_spatial_downsample = bool( 76 | i < num_spatial_downsample_layers) 77 | add_time_downsample = bool(i < num_time_downsample_layers) 78 | else: 79 | raise ValueError( 80 | f"Unsupported time_compression_ratio: {time_compression_ratio}") 81 | 82 | downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) 83 | downsample_stride_T = (2, ) if add_time_downsample else (1, ) 84 | downsample_stride = tuple( 85 | downsample_stride_T + downsample_stride_HW) 86 | down_block = get_down_block3d( 87 | down_block_type, 88 | num_layers=self.layers_per_block, 89 | in_channels=input_channel, 90 | out_channels=output_channel, 91 | add_downsample=bool( 92 | add_spatial_downsample or add_time_downsample), 93 | downsample_stride=downsample_stride, 94 | resnet_eps=1e-6, 95 | downsample_padding=0, 96 | resnet_act_fn=act_fn, 97 | resnet_groups=norm_num_groups, 98 | attention_head_dim=output_channel, 99 | temb_channels=None, 100 | ) 101 | self.down_blocks.append(down_block) 102 | 103 | # mid 104 | self.mid_block = UNetMidBlockCausal3D( 105 | in_channels=block_out_channels[-1], 106 | resnet_eps=1e-6, 107 | resnet_act_fn=act_fn, 108 | output_scale_factor=1, 109 | resnet_time_scale_shift="default", 110 | attention_head_dim=block_out_channels[-1], 111 | resnet_groups=norm_num_groups, 112 | temb_channels=None, 113 | add_attention=mid_block_add_attention, 114 | ) 115 | 116 | # out 117 | self.conv_norm_out = nn.GroupNorm( 118 | num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) 119 | self.conv_act = nn.SiLU() 120 | 121 | conv_out_channels = 2 * out_channels if double_z else out_channels 122 | self.conv_out = CausalConv3d( 123 | block_out_channels[-1], conv_out_channels, kernel_size=3) 124 | 125 | def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: 126 | r"""The forward method of the `EncoderCausal3D` class.""" 127 | assert len(sample.shape) == 5, "The input tensor should have 5 dimensions" 128 | 129 | sample = self.conv_in(sample) 130 | 131 | # down 132 | for down_block in self.down_blocks: 133 | sample = down_block(sample) 134 | 135 | # middle 136 | sample = self.mid_block(sample) 137 | 138 | # post-process 139 | sample = self.conv_norm_out(sample) 140 | sample = self.conv_act(sample) 141 | sample = self.conv_out(sample) 142 | 143 | return sample 144 | 145 | 146 | class DecoderCausal3D(nn.Module): 147 
| r""" 148 | The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample. 149 | """ 150 | 151 | def __init__( 152 | self, 153 | in_channels: int = 3, 154 | out_channels: int = 3, 155 | up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",), 156 | block_out_channels: Tuple[int, ...] = (64,), 157 | layers_per_block: int = 2, 158 | norm_num_groups: int = 32, 159 | act_fn: str = "silu", 160 | norm_type: str = "group", # group, spatial 161 | mid_block_add_attention=True, 162 | time_compression_ratio: int = 4, 163 | spatial_compression_ratio: int = 8, 164 | ): 165 | super().__init__() 166 | self.layers_per_block = layers_per_block 167 | 168 | self.conv_in = CausalConv3d( 169 | in_channels, block_out_channels[-1], kernel_size=3, stride=1) 170 | self.mid_block = None 171 | self.up_blocks = nn.ModuleList([]) 172 | 173 | temb_channels = in_channels if norm_type == "spatial" else None 174 | 175 | # mid 176 | self.mid_block = UNetMidBlockCausal3D( 177 | in_channels=block_out_channels[-1], 178 | resnet_eps=1e-6, 179 | resnet_act_fn=act_fn, 180 | output_scale_factor=1, 181 | resnet_time_scale_shift="default" if norm_type == "group" else norm_type, 182 | attention_head_dim=block_out_channels[-1], 183 | resnet_groups=norm_num_groups, 184 | temb_channels=temb_channels, 185 | add_attention=mid_block_add_attention, 186 | ) 187 | 188 | # up 189 | reversed_block_out_channels = list(reversed(block_out_channels)) 190 | output_channel = reversed_block_out_channels[0] 191 | for i, up_block_type in enumerate(up_block_types): 192 | prev_output_channel = output_channel 193 | output_channel = reversed_block_out_channels[i] 194 | is_final_block = i == len(block_out_channels) - 1 195 | num_spatial_upsample_layers = int( 196 | np.log2(spatial_compression_ratio)) 197 | num_time_upsample_layers = int(np.log2(time_compression_ratio)) 198 | 199 | if time_compression_ratio == 4: 200 | add_spatial_upsample = bool(i < num_spatial_upsample_layers) 201 | add_time_upsample = bool(i >= len( 202 | block_out_channels) - 1 - num_time_upsample_layers and not is_final_block) 203 | else: 204 | raise ValueError( 205 | f"Unsupported time_compression_ratio: {time_compression_ratio}") 206 | 207 | upsample_scale_factor_HW = ( 208 | 2, 2) if add_spatial_upsample else (1, 1) 209 | upsample_scale_factor_T = (2, ) if add_time_upsample else (1, ) 210 | upsample_scale_factor = tuple( 211 | upsample_scale_factor_T + upsample_scale_factor_HW) 212 | up_block = get_up_block3d( 213 | up_block_type, 214 | num_layers=self.layers_per_block + 1, 215 | in_channels=prev_output_channel, 216 | out_channels=output_channel, 217 | prev_output_channel=None, 218 | add_upsample=bool(add_spatial_upsample or add_time_upsample), 219 | upsample_scale_factor=upsample_scale_factor, 220 | resnet_eps=1e-6, 221 | resnet_act_fn=act_fn, 222 | resnet_groups=norm_num_groups, 223 | attention_head_dim=output_channel, 224 | temb_channels=temb_channels, 225 | resnet_time_scale_shift=norm_type, 226 | ) 227 | self.up_blocks.append(up_block) 228 | prev_output_channel = output_channel 229 | 230 | # out 231 | if norm_type == "spatial": 232 | self.conv_norm_out = SpatialNorm( 233 | block_out_channels[0], temb_channels) 234 | else: 235 | self.conv_norm_out = nn.GroupNorm( 236 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) 237 | self.conv_act = nn.SiLU() 238 | self.conv_out = CausalConv3d( 239 | block_out_channels[0], out_channels, kernel_size=3) 240 | 241 | self.gradient_checkpointing = False 242 | 
243 | def forward( 244 | self, 245 | sample: torch.FloatTensor, 246 | latent_embeds: Optional[torch.FloatTensor] = None, 247 | ) -> torch.FloatTensor: 248 | r"""The forward method of the `DecoderCausal3D` class.""" 249 | assert len(sample.shape) == 5, "The input tensor should have 5 dimensions" 250 | 251 | sample = self.conv_in(sample) 252 | 253 | upscale_dtype = next(iter(self.up_blocks.parameters())).dtype 254 | if self.training and self.gradient_checkpointing: 255 | 256 | def create_custom_forward(module): 257 | def custom_forward(*inputs): 258 | return module(*inputs) 259 | 260 | return custom_forward 261 | 262 | if is_torch_version(">=", "1.11.0"): 263 | # middle 264 | sample = torch.utils.checkpoint.checkpoint( 265 | create_custom_forward(self.mid_block), 266 | sample, 267 | latent_embeds, 268 | use_reentrant=False, 269 | ) 270 | sample = sample.to(upscale_dtype) 271 | 272 | # up 273 | for up_block in self.up_blocks: 274 | sample = torch.utils.checkpoint.checkpoint( 275 | create_custom_forward(up_block), 276 | sample, 277 | latent_embeds, 278 | use_reentrant=False, 279 | ) 280 | else: 281 | # middle 282 | sample = torch.utils.checkpoint.checkpoint( 283 | create_custom_forward( 284 | self.mid_block), sample, latent_embeds 285 | ) 286 | sample = sample.to(upscale_dtype) 287 | 288 | # up 289 | for up_block in self.up_blocks: 290 | sample = torch.utils.checkpoint.checkpoint( 291 | create_custom_forward(up_block), sample, latent_embeds) 292 | else: 293 | # middle 294 | sample = self.mid_block(sample, latent_embeds) 295 | sample = sample.to(upscale_dtype) 296 | 297 | # up 298 | for up_block in self.up_blocks: 299 | sample = up_block(sample, latent_embeds) 300 | 301 | # post-process 302 | if latent_embeds is None: 303 | sample = self.conv_norm_out(sample) 304 | else: 305 | sample = self.conv_norm_out(sample, latent_embeds) 306 | sample = self.conv_act(sample) 307 | sample = self.conv_out(sample) 308 | 309 | return sample 310 | 311 | 312 | class DiagonalGaussianDistribution(object): 313 | def __init__(self, parameters: torch.Tensor, deterministic: bool = False): 314 | if parameters.ndim == 3: 315 | dim = 2 # (B, L, C) 316 | elif parameters.ndim == 5 or parameters.ndim == 4: 317 | dim = 1 # (B, C, T, H ,W) / (B, C, H, W) 318 | else: 319 | raise NotImplementedError 320 | self.parameters = parameters 321 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim) 322 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 323 | self.deterministic = deterministic 324 | self.std = torch.exp(0.5 * self.logvar) 325 | self.var = torch.exp(self.logvar) 326 | if self.deterministic: 327 | self.var = self.std = torch.zeros_like( 328 | self.mean, device=self.parameters.device, dtype=self.parameters.dtype 329 | ) 330 | 331 | def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: 332 | # make sure sample is on the same device as the parameters and has same dtype 333 | sample = randn_tensor( 334 | self.mean.shape, 335 | generator=generator, 336 | device=self.parameters.device, 337 | dtype=self.parameters.dtype, 338 | ) 339 | x = self.mean + self.std * sample 340 | return x 341 | 342 | def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: 343 | if self.deterministic: 344 | return torch.Tensor([0.0]) 345 | else: 346 | reduce_dim = list(range(1, self.mean.ndim)) 347 | if other is None: 348 | return 0.5 * torch.sum( 349 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 350 | dim=reduce_dim, 351 | ) 352 | else: 353 | return 0.5 * torch.sum( 354 | 
torch.pow(self.mean - other.mean, 2) / other.var 355 | + self.var / other.var 356 | - 1.0 357 | - self.logvar 358 | + other.logvar, 359 | dim=reduce_dim, 360 | ) 361 | 362 | def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor: 363 | if self.deterministic: 364 | return torch.Tensor([0.0]) 365 | logtwopi = np.log(2.0 * np.pi) 366 | return 0.5 * torch.sum( 367 | logtwopi + self.logvar + 368 | torch.pow(sample - self.mean, 2) / self.var, 369 | dim=dims, 370 | ) 371 | 372 | def mode(self) -> torch.Tensor: 373 | return self.mean 374 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-hunyuanvideowrapper" 3 | description = "ComfyUI diffusers wrapper nodes for [a/HunyuanVideo](https://github.com/Tencent/HunyuanVideo)" 4 | version = "1.0.8" 5 | license = {file = "LICENSE"} 6 | dependencies = ["accelerate >= 1.2.1", "diffusers >= 0.31.0", "transformers >= 4.49.0", "jax >= 0.4.28", "timm >= 1.0.15"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/kijai/ComfyUI-HunyuanVideoWrapper" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "kijai" 14 | DisplayName = "ComfyUI-HunyuanVideoWrapper" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # ComfyUI wrapper nodes for [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) 2 | 3 | # Update 5 4 | 5 | So I know I said I'd stop working on this, but with all the new stuff out I wanted to work on those as well, and have included the official I2V, its "fixed" version 2, and the [LoRAs](https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/hyvid_I2V_lora_embrace.safetensors) they included in the release 6 | 7 | https://github.com/user-attachments/assets/8ce4b1ee-fb63-49a2-83b4-ba8ef1a8b842 8 | 9 | 10 | 11 | 12 | and the [dashtoon keyframe LoRA](https://github.com/dashtoon/hunyuan-video-keyframe-control-lora). 13 | 14 | https://github.com/user-attachments/assets/2b6e32e4-470f-4feb-b299-5a453e2b4fa1 15 | 16 | Also, because there has been so much trouble using the transformer model for text encoding, I figured out a way to use the text embeds from native ComfyUI text encoding, like this: 17 | 18 | ![image](https://github.com/user-attachments/assets/80b23087-a66d-4937-bb2c-d15d5a20304b) 19 | 20 | Note that this does give somewhat different results, and using the nodes this way can no longer be considered a wrapper of the original implementation. 21 | 22 | # Update 4, the non-update: 23 | 24 | 25 | As the native implementation exists, and has support for most features by now, I will mostly stop working on these nodes for anything but their main purpose: early access and testing of potential new features that are difficult (at least for me) to implement natively. 26 | 27 | ## Some resources for native workflows: 28 | 29 | Flowedit and enhance-a-video can be found in these nodes: https://github.com/logtd/ComfyUI-HunyuanLoom 30 | 31 | TeaCache-equivalent FirstBlockCache, as well as torch.compile with LoRA support: https://github.com/chengzeyi/Comfy-WaveSpeed 32 | 33 | Sageattention can be enabled with the `--use-sage-attention` startup argument for ComfyUI, or with a patcher node found in [KJNodes](https://github.com/kijai/ComfyUI-KJNodes) as well as some other node packs.
34 | 35 | Leapfusion I2V can be used with my patcher node found in KJNodes as well; example workflow: https://github.com/kijai/ComfyUI-KJNodes/blob/main/example_workflows/leapfusion_hunyuuanvideo_i2v_native_testing.json 36 | 37 | What currently remains missing from the native implementation: 38 | - context windowing 39 | - direct image embed support through IP2V 40 | - manual memory management 41 | 42 | # Update 3: 43 | 44 | It's been a hectic couple of weeks with this model, and I've lost track of everything that has happened since the start, but I'll try to present some of the more important updates: 45 | 46 | ## Official scaled fp8 weights were released: 47 | 48 | https://huggingface.co/tencent/HunyuanVideo/blob/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt 49 | 50 | Even though this file is a .pt, it's completely safe since it is loaded with weights_only, and the scale map is included with the nodes. To use this model you have to select the `fp8_scaled` quantization option in the model loader. 51 | The quality of these weights is much closer to the original bf16; the downside is that they do not currently support fp8 fast mode or LoRAs. 52 | 53 | ## Almost free quality increase with [Enhance-A-Video](https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video): 54 | 55 | This has a very slight hit on inference speed and zero hit on memory use; initial tests indicate it's absolutely worth using. 56 | 57 | ![image](https://github.com/user-attachments/assets/68f0b5eb-aa23-49e1-a48f-fd3c4b1108ed) 58 | 59 | https://github.com/user-attachments/assets/e19b30e1-5f67-4e75-9c73-716d4569c319 60 | 61 | https://github.com/user-attachments/assets/083353a2-e9aa-43e9-a916-ff3af1d581c1 62 | 63 | 64 | 65 | # Update 2: Experimental IP2V - Image Prompting to Video via VLM by @Dango233 66 | ## WORK IN PROGRESS - But it should work now! 67 | 68 | Now you can feed an image to the VLM as a condition for generation! This is different from image2video, where the image becomes the first frame of the video. IP2V uses the image as part of the prompt, to extract the concept and style of the image. 69 | So - very much like IPAdapter - but the VLM does the heavy lifting for you! 70 | 71 | This is a tuning-free approach, but with further task-specific tuning the use scenarios can be expanded. 72 | 73 | ## Guide to Using `xtuner/llava-llama-3-8b-v1_1-transformers` for Image-Text Tasks 74 | 75 | ## Step 1: Model Selection 76 | Use the original `xtuner/llava-llama-3-8b-v1_1-transformers` model, which includes the vision tower. You have two options: 77 | - Download the model and place it in the `models/LLM` folder. 78 | - Rely on the auto-download mechanism. 79 | 80 | **Note:** It's recommended to offload the text encoder, since the vision tower requires additional VRAM. 81 | 82 | ## Step 2: Load and Connect Image 83 | - Use the ComfyUI native node to load the image. 84 | - Connect the loaded image to the `Hunyuan TextImageEncode` node. 85 | - You can connect up to 2 images to this node. 86 | 87 | ## Step 3: Prompting with Images 88 | - Reference the image in your prompt by including `<image>`. 89 | - The number of `<image>` tags should match the number of images provided to the sampler. 90 | - Example prompt: `Describe this <image> in great detail.` 91 | 92 | You can also choose to give CLIP a separate prompt that does not reference the image. 93 | 94 | ## Step 4: Advanced Configuration - `image_token_selection_expression` 95 | This expression is for advanced users and serves as a boolean mask to select which part of the image hidden state will be used for conditioning.
Here are some details and recommendations: 96 | 97 | - The hidden state sequence length (or number of tokens) per image in llava-llama-3 is 576. 98 | - The default setting is `::4`, meaning one out of every four tokens goes into conditioning (interleaved), resulting in 144 tokens per image. 99 | - Generally, more tokens bias the generation more strongly towards the conditioning image. 100 | - However, too many tokens (especially if the overall token count exceeds 256) will degrade generation quality. It's recommended not to use more than half of the tokens (`::2`). 101 | - Interleaved tokens generally perform better, but you might also want to try the following expressions: 102 | - `:128` - First 128 tokens. 103 | - `-128:` - Last 128 tokens. 104 | - `:128, -128:` - First 128 tokens and last 128 tokens. 105 | - With a proper prompting strategy, even passing in no image tokens at all (leaving the expression blank) can yield decent results. A minimal sketch of how these expressions can be interpreted is included at the end of this readme. 106 | 107 | # Update 108 | 109 | Scaled dot product attention (sdpa) should now be working (only tested on Windows, torch 2.5.1+cu124 on a 4090); sageattention is still recommended for speed, but it is no longer necessary, which makes installation much easier. 110 | 111 | Vid2vid test: 112 | [source video](https://www.pexels.com/video/a-4x4-vehicle-speeding-on-a-dirt-road-during-a-competition-15604814/) 113 | 114 | https://github.com/user-attachments/assets/12940721-4168-4e2b-8a71-31b4b0432314 115 | 116 | 117 | text2vid (old test): 118 | 119 | https://github.com/user-attachments/assets/3750da65-9753-4bd2-aae2-a688d2b86115 120 | 121 | 122 | Transformer and VAE (single files, no autodownload): 123 | 124 | https://huggingface.co/Kijai/HunyuanVideo_comfy/tree/main 125 | 126 | These go to the usual ComfyUI folders (`diffusion_models` and `vae`). 127 | 128 | LLM text encoder (has autodownload): 129 | 130 | https://huggingface.co/Kijai/llava-llama-3-8b-text-encoder-tokenizer 131 | 132 | Files go to `ComfyUI/models/LLM/llava-llama-3-8b-text-encoder-tokenizer` 133 | 134 | Clip text encoder (has autodownload): 135 | 136 | Either use any Clip_L model supported by ComfyUI by disabling the clip_model in the text encoder loader and plugging a ClipLoader into the text encoder node, or 137 | allow the autodownloader to fetch the original clip model from: 138 | 139 | https://huggingface.co/openai/clip-vit-large-patch14 (only the .safetensors file from the weights is needed, plus all the config files) to: 140 | 141 | `ComfyUI/models/clip/clip-vit-large-patch14` 142 | 143 | Memory use is entirely dependent on resolution and frame count; don't expect to be able to go very high even on 24GB. 144 | 145 | The good news is that the model can do functional videos even at really low resolutions.
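
## Appendix: interpreting `image_token_selection_expression`

Returning to Step 4 above: the expressions are essentially Python-style slices over the 576 per-image hidden-state tokens. The sketch below is only an illustration of how such an expression could be interpreted; the helper name `select_image_tokens`, the parsing details, and the 4096 hidden size are assumptions for demonstration, not the actual node implementation.

```python
import torch

def select_image_tokens(hidden_state: torch.Tensor, expression: str) -> torch.Tensor:
    """Illustrative only: hidden_state is the (num_tokens, dim) per-image hidden state."""
    num_tokens = hidden_state.shape[0]
    if not expression.strip():
        # Blank expression: pass no image tokens into conditioning.
        return hidden_state[:0]
    mask = torch.zeros(num_tokens, dtype=torch.bool)
    for part in expression.split(","):
        start, _, rest = part.partition(":")
        stop, _, step = rest.partition(":")
        # Build a slice from the "start:stop:step" pieces and mark those tokens as selected
        mask[slice(
            int(start) if start.strip() else None,
            int(stop) if stop.strip() else None,
            int(step) if step.strip() else None,
        )] = True
    return hidden_state[mask]

tokens = torch.randn(576, 4096)  # 576 image tokens; hidden size assumed for illustration
print(select_image_tokens(tokens, "::4").shape)          # torch.Size([144, 4096])
print(select_image_tokens(tokens, ":128, -128:").shape)  # torch.Size([256, 4096])
```

For example, `::4` keeps 144 interleaved tokens per image and `::2` keeps 288, which is why the interleaved defaults stay comfortably within the token budget discussed above.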
146 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate >= 1.2.1 2 | diffusers >= 0.31.0 3 | transformers >= 4.49.0 4 | jax >= 0.4.28 5 | timm >= 1.0.15 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | import torch 3 | import logging 4 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 5 | log = logging.getLogger(__name__) 6 | 7 | def check_diffusers_version(): 8 | try: 9 | version = importlib.metadata.version('diffusers') 10 | required_version = '0.31.0' 11 | if tuple(int(p) for p in version.split('.')[:2]) < tuple(int(p) for p in required_version.split('.')[:2]):  # compare numerically, not lexicographically 12 | raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.") 13 | except importlib.metadata.PackageNotFoundError: 14 | raise AssertionError("diffusers is not installed.") 15 | 16 | def print_memory(device): 17 | memory = torch.cuda.memory_allocated(device) / 1024**3 18 | max_memory = torch.cuda.max_memory_allocated(device) / 1024**3 19 | max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3 20 | log.info(f"-------------------------------") 21 | log.info(f"Allocated memory: {memory=:.3f} GB") 22 | log.info(f"Max allocated memory: {max_memory=:.3f} GB") 23 | log.info(f"Max reserved memory: {max_reserved=:.3f} GB") 24 | log.info(f"-------------------------------") 25 | #memory_summary = torch.cuda.memory_summary(device=device, abbreviated=False) 26 | #log.info(f"Memory Summary:\n{memory_summary}") 27 | 28 | def optimized_scale(positive_flat, negative_flat): 29 | 30 | # Calculate dot product 31 | dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) 32 | 33 | # Squared norm of the unconditional (negative) vector 34 | squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 35 | 36 | # st_star = v_cond^T * v_uncond / ||v_uncond||^2 37 | st_star = dot_product / squared_norm 38 | 39 | return st_star 40 | 41 | # Code based on https://github.com/WikiChao/FreSca (MIT License) 42 | import torch 43 | import torch.fft as fft 44 | 45 | def fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): 46 | """ 47 | Apply frequency-dependent scaling to a video tensor using Fourier transforms. 48 | 49 | Parameters: 50 | x: Input tensor of shape (B, C, T, H, W) 51 | scale_low: Scaling factor for low-frequency components (default: 1.0) 52 | scale_high: Scaling factor for high-frequency components (default: 1.5) 53 | freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20) 54 | 55 | Returns: 56 | x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
57 | """ 58 | # Preserve input dtype and device 59 | dtype, device = x.dtype, x.device 60 | 61 | # Convert to float32 for FFT computations 62 | x = x.to(torch.float32) 63 | 64 | # 1) Apply FFT and shift low frequencies to center 65 | x_freq = fft.fftn(x, dim=(-2, -1)) 66 | x_freq = fft.fftshift(x_freq, dim=(-2, -1)) 67 | 68 | # 2) Create a mask to scale frequencies differently 69 | B, C, T, H, W = x_freq.shape 70 | crow, ccol = H // 2, W // 2 71 | 72 | # Initialize mask with high-frequency scaling factor 73 | mask = torch.ones((B, C, T, H, W), device=device) * scale_high 74 | 75 | # Apply low-frequency scaling factor to center region 76 | mask[ 77 | ..., 78 | crow - freq_cutoff : crow + freq_cutoff, 79 | ccol - freq_cutoff : ccol + freq_cutoff, 80 | ] = scale_low 81 | 82 | # 3) Apply frequency-specific scaling 83 | x_freq = x_freq * mask 84 | 85 | # 4) Convert back to spatial domain 86 | x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) 87 | x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real 88 | 89 | # 5) Restore original dtype 90 | x_filtered = x_filtered.to(dtype) 91 | 92 | return x_filtered --------------------------------------------------------------------------------
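
A brief usage note on the `fourier_filter` helper above: it scales the low- and high-frequency components of a 5D video tensor independently. Below is a minimal usage sketch under assumed shapes (a single 16-channel latent with 9 frames at 64x64); the values are illustrative only and not part of the repository.

```python
import torch
from utils import fourier_filter  # assumes this repository's utils.py is importable

# A hypothetical video latent of shape (B, C, T, H, W)
latents = torch.randn(1, 16, 9, 64, 64)

# Keep low frequencies as-is and boost high frequencies by 25%
filtered = fourier_filter(latents, scale_low=1.0, scale_high=1.25, freq_cutoff=20)

print(filtered.shape, filtered.dtype)  # torch.Size([1, 16, 9, 64, 64]) torch.float32
```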