├── hunyuanvideo_foley ├── __init__.py ├── models │ ├── __init__.py │ ├── nn │ │ ├── __init__.py │ │ ├── activation_layers.py │ │ ├── modulate_layers.py │ │ ├── norm_layers.py │ │ ├── embed_layers.py │ │ ├── mlp_layers.py │ │ └── posemb_layers.py │ ├── synchformer │ │ ├── __init__.py │ │ ├── divided_224_16x4.yaml │ │ ├── utils.py │ │ ├── compute_desync_score.py │ │ ├── video_model_builder.py │ │ └── synchformer.py │ └── dac_vae │ │ ├── nn │ │ ├── __init__.py │ │ ├── layers.py │ │ ├── vae_utils.py │ │ ├── quantize.py │ │ └── loss.py │ │ ├── model │ │ ├── __init__.py │ │ ├── discriminator.py │ │ ├── base.py │ │ └── dac.py │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── utils │ │ ├── decode.py │ │ ├── encode.py │ │ └── __init__.py ├── utils │ ├── __init__.py │ ├── schedulers │ │ └── __init__.py │ ├── media_utils.py │ ├── config_utils.py │ ├── helper.py │ ├── feature_utils.py │ └── model_utils.py └── constants.py ├── __init__.py ├── requirements.txt ├── configs └── hunyuanvideo-foley-xxl.yaml ├── model_urls.py ├── install.py ├── download_models_manual.py ├── README.md ├── INSTALLATION_GUIDE.md ├── test_node.py ├── utils.py └── example_workflows └── HunyuanVideo-Foley.json /hunyuanvideo_foley/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .synchformer import Synchformer 2 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from . import layers 2 | from . import loss 3 | from . 
import quantize 4 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import CodecMixin 2 | from .base import DACFile 3 | from .dac import DAC 4 | from .discriminator import Discriminator 5 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from diffusers.schedulers import DDPMScheduler, EulerDiscreteScheduler 2 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | # preserved here for legacy reasons 4 | __model_version__ = "latest" 5 | 6 | import audiotools 7 | 8 | audiotools.ml.BaseModel.INTERN += ["dac.**"] 9 | audiotools.ml.BaseModel.EXTERN += ["einops"] 10 | 11 | 12 | from . import nn 13 | from . import model 14 | from . import utils 15 | from .model import DAC 16 | from .model import DACFile 17 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # Add the current directory to Python path to import hunyuanvideo_foley modules 5 | current_dir = os.path.dirname(os.path.abspath(__file__)) 6 | if current_dir not in sys.path: 7 | sys.path.insert(0, current_dir) 8 | 9 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 10 | 11 | # Export the mappings 12 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core ML dependencies 2 | #torch>=2.0.0 3 | #torchvision>=0.15.0 4 | #torchaudio>=2.0.0 5 | numpy>=1.26.4 6 | scipy 7 | 8 | # Deep Learning frameworks 9 | diffusers 10 | timm 11 | accelerate 12 | 13 | # Transformers and NLP 14 | transformers>=4.37.0 15 | sentencepiece 16 | huggingface_hub 17 | 18 | # Audio processing 19 | git+https://github.com/descriptinc/audiotools 20 | 21 | # Video/Image processing 22 | pillow 23 | av 24 | einops 25 | 26 | # Configuration and utilities 27 | pyyaml 28 | omegaconf 29 | easydict 30 | loguru 31 | tqdm 32 | setuptools 33 | 34 | # Data handling 35 | pandas 36 | pyarrow 37 | 38 | # Network 39 | urllib3==2.4.0 40 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import argbind 4 | 5 | from .utils import download 6 | from .utils.decode import decode 7 | from .utils.encode import encode 8 | 9 | STAGES = ["encode", "decode", "download"] 10 | 11 | 12 | def run(stage: str): 13 | """Run stages. 14 | 15 | Parameters 16 | ---------- 17 | stage : str 18 | Stage to run 19 | """ 20 | if stage not in STAGES: 21 | raise ValueError(f"Unknown command: {stage}. 
Allowed commands are {STAGES}") 22 | stage_fn = globals()[stage] 23 | 24 | if stage == "download": 25 | stage_fn() 26 | return 27 | 28 | stage_fn() 29 | 30 | 31 | if __name__ == "__main__": 32 | group = sys.argv.pop(1) 33 | args = argbind.parse_args(group=group) 34 | 35 | with argbind.scope(args): 36 | run(group) 37 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch.nn.utils import weight_norm 7 | 8 | 9 | def WNConv1d(*args, **kwargs): 10 | return weight_norm(nn.Conv1d(*args, **kwargs)) 11 | 12 | 13 | def WNConvTranspose1d(*args, **kwargs): 14 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 15 | 16 | 17 | # Scripting this brings model speed up 1.4x 18 | @torch.jit.script 19 | def snake(x, alpha): 20 | shape = x.shape 21 | x = x.reshape(shape[0], shape[1], -1) 22 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) 23 | x = x.reshape(shape) 24 | return x 25 | 26 | 27 | class Snake1d(nn.Module): 28 | def __init__(self, channels): 29 | super().__init__() 30 | self.alpha = nn.Parameter(torch.ones(1, channels, 1)) 31 | 32 | def forward(self, x): 33 | return snake(x, self.alpha) 34 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/activation_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | def get_activation_layer(act_type): 5 | if act_type == "gelu": 6 | return lambda: nn.GELU() 7 | elif act_type == "gelu_tanh": 8 | # Approximate `tanh` requires torch >= 1.13 9 | return lambda: nn.GELU(approximate="tanh") 10 | elif act_type == "relu": 11 | return nn.ReLU 12 | elif act_type == "silu": 13 | return nn.SiLU 14 | else: 15 | raise ValueError(f"Unknown activation type: {act_type}") 16 | 17 | class SwiGLU(nn.Module): 18 | def __init__( 19 | self, 20 | dim: int, 21 | hidden_dim: int, 22 | out_dim: int, 23 | ): 24 | """ 25 | Initialize the SwiGLU FeedForward module. 26 | 27 | Args: 28 | dim (int): Input dimension. 29 | hidden_dim (int): Hidden dimension of the feedforward layer. 30 | 31 | Attributes: 32 | w1: Linear transformation for the first layer. 33 | w2: Linear transformation for the second layer. 34 | w3: Linear transformation for the third layer. 
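Together these compute w2(SiLU(w1(x)) * w3(x)), i.e. a SiLU-gated feed-forward block; out_dim sets the output width.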
35 | 36 | """ 37 | super().__init__() 38 | 39 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 40 | self.w2 = nn.Linear(hidden_dim, out_dim, bias=False) 41 | self.w3 = nn.Linear(dim, hidden_dim, bias=False) 42 | 43 | def forward(self, x): 44 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 45 | -------------------------------------------------------------------------------- /configs/hunyuanvideo-foley-xxl.yaml: -------------------------------------------------------------------------------- 1 | model_config: 2 | model_name: HunyuanVideo-Foley-XXL 3 | model_type: 1d 4 | model_precision: bf16 5 | model_kwargs: 6 | depth_triple_blocks: 18 7 | depth_single_blocks: 36 8 | hidden_size: 1536 9 | num_heads: 12 10 | mlp_ratio: 4 11 | mlp_act_type: "gelu_tanh" 12 | qkv_bias: True 13 | qk_norm: True 14 | qk_norm_type: "rms" 15 | attn_mode: "torch" 16 | embedder_type: "default" 17 | interleaved_audio_visual_rope: True 18 | enable_learnable_empty_visual_feat: True 19 | sync_modulation: False 20 | add_sync_feat_to_audio: True 21 | cross_attention: True 22 | use_attention_mask: False 23 | condition_projection: "linear" 24 | sync_feat_dim: 768 # syncformer 768 dim 25 | condition_dim: 768 # clap 768 text condition dim (clip-text) 26 | clip_dim: 768 # siglip2 visual dim 27 | audio_vae_latent_dim: 128 28 | audio_frame_rate: 50 29 | patch_size: 1 30 | rope_dim_list: null 31 | rope_theta: 10000 32 | text_length: 77 33 | clip_length: 64 34 | sync_length: 192 35 | use_mmaudio_singleblock: True 36 | depth_triple_ssl_encoder: null 37 | depth_single_ssl_encoder: 8 38 | use_repa_with_audiossl: True 39 | 40 | diffusion_config: 41 | denoise_type: "flow" 42 | flow_path_type: "linear" 43 | flow_predict_type: "velocity" 44 | flow_reverse: True 45 | flow_solver: "euler" 46 | sample_flow_shift: 1.0 47 | sample_use_flux_shift: False 48 | flux_base_shift: 0.5 49 | flux_max_shift: 1.15 50 | -------------------------------------------------------------------------------- /model_urls.py: -------------------------------------------------------------------------------- 1 | # HunyuanVideo-Foley Model URLs Configuration 2 | # Update these URLs with the actual download links for the models 3 | 4 | MODEL_URLS = { 5 | "hunyuanvideo-foley-xxl": { 6 | "models": [ 7 | { 8 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/hunyuanvideo_foley.pth", 9 | "filename": "hunyuanvideo_foley.pth", 10 | "description": "Main HunyuanVideo-Foley model" 11 | }, 12 | { 13 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/synchformer_state_dict.pth", 14 | "filename": "synchformer_state_dict.pth", 15 | "description": "Synchformer model weights" 16 | }, 17 | { 18 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/vae_128d_48k.pth", 19 | "filename": "vae_128d_48k.pth", 20 | "description": "VAE model weights" 21 | } 22 | ], 23 | "extracted_dir": "hunyuanvideo-foley-xxl", 24 | "description": "HunyuanVideo-Foley XXL model for audio generation" 25 | } 26 | } 27 | 28 | # Alternative mirror URLs (if main URLs fail) 29 | MIRROR_URLS = { 30 | # Add mirror download sources here if needed 31 | } 32 | 33 | def get_model_url(model_name: str, use_mirror: bool = False) -> dict: 34 | """Get model URL configuration""" 35 | urls_dict = MIRROR_URLS if use_mirror else MODEL_URLS 36 | return urls_dict.get(model_name, {}) 37 | 38 | def list_available_models() -> list: 39 | """List all available model names""" 40 | return list(MODEL_URLS.keys()) 
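# --- Illustrative usage sketch (not part of the original model_urls.py) ---
# Assuming the helpers defined above are in scope, this prints each configured
# checkpoint and the URL it would be downloaded from; the names and keys used
# here follow only the MODEL_URLS layout defined in this file.
if __name__ == "__main__":
    for name in list_available_models():
        cfg = get_model_url(name)                  # returns {} for unknown model names
        print(name, "-", cfg.get("description", ""))
        for entry in cfg.get("models", []):        # each entry has url / filename / description
            print("  ", entry["filename"], "<-", entry["url"])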
-------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/modulate_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | import torch.nn as nn 4 | 5 | class ModulateDiT(nn.Module): 6 | def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None): 7 | factory_kwargs = {"dtype": dtype, "device": device} 8 | super().__init__() 9 | self.act = act_layer() 10 | self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) 11 | # Zero-initialize the modulation 12 | nn.init.zeros_(self.linear.weight) 13 | nn.init.zeros_(self.linear.bias) 14 | 15 | def forward(self, x: torch.Tensor) -> torch.Tensor: 16 | return self.linear(self.act(x)) 17 | 18 | 19 | def modulate(x, shift=None, scale=None): 20 | if x.ndim == 3: 21 | shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None 22 | scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None 23 | if scale is None and shift is None: 24 | return x 25 | elif shift is None: 26 | return x * (1 + scale) 27 | elif scale is None: 28 | return x + shift 29 | else: 30 | return x * (1 + scale) + shift 31 | 32 | 33 | def apply_gate(x, gate=None, tanh=False): 34 | if gate is None: 35 | return x 36 | if gate.ndim == 2 and x.ndim == 3: 37 | gate = gate.unsqueeze(1) 38 | if tanh: 39 | return x * gate.tanh() 40 | else: 41 | return x * gate 42 | 43 | 44 | def ckpt_wrapper(module): 45 | def ckpt_forward(*inputs): 46 | outputs = module(*inputs) 47 | return outputs 48 | 49 | return ckpt_forward 50 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/constants.py: -------------------------------------------------------------------------------- 1 | """Constants used throughout the HunyuanVideo-Foley project.""" 2 | 3 | from typing import Dict, List 4 | 5 | # Model configuration 6 | DEFAULT_AUDIO_SAMPLE_RATE = 48000 7 | DEFAULT_VIDEO_FPS = 25 8 | DEFAULT_AUDIO_CHANNELS = 2 9 | 10 | # Video processing 11 | MAX_VIDEO_DURATION_SECONDS = 15.0 12 | MIN_VIDEO_DURATION_SECONDS = 1.0 13 | 14 | # Audio processing 15 | AUDIO_VAE_LATENT_DIM = 128 16 | AUDIO_FRAME_RATE = 75 # frames per second in latent space 17 | 18 | # Visual features 19 | FPS_VISUAL: Dict[str, int] = { 20 | "siglip2": 8, 21 | "synchformer": 25 22 | } 23 | 24 | # Model paths (can be overridden by environment variables) 25 | DEFAULT_MODEL_PATH = "./pretrained_models/" 26 | DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml" 27 | 28 | # Inference parameters 29 | DEFAULT_GUIDANCE_SCALE = 4.5 30 | DEFAULT_NUM_INFERENCE_STEPS = 50 31 | MIN_GUIDANCE_SCALE = 1.0 32 | MAX_GUIDANCE_SCALE = 10.0 33 | MIN_INFERENCE_STEPS = 10 34 | MAX_INFERENCE_STEPS = 100 35 | 36 | # Text processing 37 | MAX_TEXT_LENGTH = 100 38 | DEFAULT_NEGATIVE_PROMPT = "noisy, harsh" 39 | 40 | # File extensions 41 | SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"] 42 | SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"] 43 | 44 | # Quality settings 45 | AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = { 46 | "high": ["-b:a", "192k"], 47 | "medium": ["-b:a", "128k"], 48 | "low": ["-b:a", "96k"] 49 | } 50 | 51 | # Error messages 52 | ERROR_MESSAGES: Dict[str, str] = { 53 | "model_not_loaded": "Model is not loaded. Please load the model first.", 54 | "invalid_video_format": "Unsupported video format. 
Supported formats: {formats}", 55 | "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds", 56 | "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html" 57 | } -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: Ssv2 4 | BATCH_SIZE: 32 5 | EVAL_PERIOD: 5 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | CHECKPOINT_EPOCH_RESET: True 9 | CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth 10 | DATA: 11 | NUM_FRAMES: 16 12 | SAMPLING_RATE: 4 13 | TRAIN_JITTER_SCALES: [256, 320] 14 | TRAIN_CROP_SIZE: 224 15 | TEST_CROP_SIZE: 224 16 | INPUT_CHANNEL_NUM: [3] 17 | MEAN: [0.5, 0.5, 0.5] 18 | STD: [0.5, 0.5, 0.5] 19 | PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2 20 | PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames 21 | INV_UNIFORM_SAMPLE: True 22 | RANDOM_FLIP: False 23 | REVERSE_INPUT_CHANNEL: True 24 | USE_RAND_AUGMENT: True 25 | RE_PROB: 0.0 26 | USE_REPEATED_AUG: False 27 | USE_RANDOM_RESIZE_CROPS: False 28 | COLORJITTER: False 29 | GRAYSCALE: False 30 | GAUSSIAN: False 31 | SOLVER: 32 | BASE_LR: 1e-4 33 | LR_POLICY: steps_with_relative_lrs 34 | LRS: [1, 0.1, 0.01] 35 | STEPS: [0, 20, 30] 36 | MAX_EPOCH: 35 37 | MOMENTUM: 0.9 38 | WEIGHT_DECAY: 5e-2 39 | WARMUP_EPOCHS: 0.0 40 | OPTIMIZING_METHOD: adamw 41 | USE_MIXED_PRECISION: True 42 | SMOOTHING: 0.2 43 | SLOWFAST: 44 | ALPHA: 8 45 | VIT: 46 | PATCH_SIZE: 16 47 | PATCH_SIZE_TEMP: 2 48 | CHANNELS: 3 49 | EMBED_DIM: 768 50 | DEPTH: 12 51 | NUM_HEADS: 12 52 | MLP_RATIO: 4 53 | QKV_BIAS: True 54 | VIDEO_INPUT: True 55 | TEMPORAL_RESOLUTION: 8 56 | USE_MLP: True 57 | DROP: 0.0 58 | POS_DROPOUT: 0.0 59 | DROP_PATH: 0.2 60 | IM_PRETRAINED: True 61 | HEAD_DROPOUT: 0.0 62 | HEAD_ACT: tanh 63 | PRETRAINED_WEIGHTS: vit_1k 64 | ATTN_LAYER: divided 65 | MODEL: 66 | NUM_CLASSES: 174 67 | ARCH: slow 68 | MODEL_NAME: VisionTransformer 69 | LOSS_FUNC: cross_entropy 70 | TEST: 71 | ENABLE: True 72 | DATASET: Ssv2 73 | BATCH_SIZE: 64 74 | NUM_ENSEMBLE_VIEWS: 1 75 | NUM_SPATIAL_CROPS: 3 76 | DATA_LOADER: 77 | NUM_WORKERS: 4 78 | PIN_MEMORY: True 79 | NUM_GPUS: 8 80 | NUM_SHARDS: 4 81 | RNG_SEED: 0 82 | OUTPUT_DIR: . 83 | TENSORBOARD: 84 | ENABLE: True 85 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/norm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class RMSNorm(nn.Module): 5 | def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, 6 | device=None, dtype=None): 7 | """ 8 | Initialize the RMSNorm normalization layer. 9 | 10 | Args: 11 | dim (int): The dimension of the input tensor. 12 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 13 | 14 | Attributes: 15 | eps (float): A small value added to the denominator for numerical stability. 16 | weight (nn.Parameter): Learnable scaling parameter. 
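Normalization scales the input by the reciprocal root-mean-square over its last dimension, x * rsqrt(mean(x^2) + eps), and multiplies by `weight` when elementwise_affine is enabled.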
17 | 18 | """ 19 | factory_kwargs = {'device': device, 'dtype': dtype} 20 | super().__init__() 21 | self.eps = eps 22 | if elementwise_affine: 23 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) 24 | 25 | def _norm(self, x): 26 | """ 27 | Apply the RMSNorm normalization to the input tensor. 28 | 29 | Args: 30 | x (torch.Tensor): The input tensor. 31 | 32 | Returns: 33 | torch.Tensor: The normalized tensor. 34 | 35 | """ 36 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 37 | 38 | def forward(self, x): 39 | """ 40 | Forward pass through the RMSNorm layer. 41 | 42 | Args: 43 | x (torch.Tensor): The input tensor. 44 | 45 | Returns: 46 | torch.Tensor: The output tensor after applying RMSNorm. 47 | 48 | """ 49 | output = self._norm(x.float()).type_as(x) 50 | if hasattr(self, "weight"): 51 | output = output * self.weight 52 | return output 53 | 54 | 55 | def get_norm_layer(norm_layer): 56 | """ 57 | Get the normalization layer. 58 | 59 | Args: 60 | norm_layer (str): The type of normalization layer. 61 | 62 | Returns: 63 | norm_layer (nn.Module): The normalization layer. 64 | """ 65 | if norm_layer == "layer": 66 | return nn.LayerNorm 67 | elif norm_layer == "rms": 68 | return RMSNorm 69 | else: 70 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") 71 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.0]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.mean( 45 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 46 | dim=[1, 2], 47 | ) 48 | else: 49 | return 0.5 * torch.mean( 50 | torch.pow(self.mean - other.mean, 2) / other.var 51 | + self.var / other.var 52 | - 1.0 53 | - self.logvar 54 | + other.logvar, 55 | dim=[1, 2], 56 | ) 57 | 58 | def nll(self, sample, dims=[1, 2]): 59 | if self.deterministic: 60 | return torch.Tensor([0.0]) 61 | logtwopi = np.log(2.0 * np.pi) 62 | return 0.5 * torch.sum( 63 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 64 | dim=dims, 65 | ) 66 | 67 | def mode(self): 68 | return self.mean 69 | 70 | 71 | def normal_kl(mean1, logvar1, mean2, logvar2): 72 | """ 73 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 74 | Compute the KL divergence between two gaussians. 75 | Shapes are automatically broadcasted, so batches can be compared to 76 | scalars, among other use cases. 77 | """ 78 | tensor = None 79 | for obj in (mean1, logvar1, mean2, logvar2): 80 | if isinstance(obj, torch.Tensor): 81 | tensor = obj 82 | break 83 | assert tensor is not None, "at least one argument must be a Tensor" 84 | 85 | # Force variances to be Tensors. Broadcasting helps convert scalars to 86 | # Tensors, but it does not work for torch.exp(). 87 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] 88 | 89 | return 0.5 * ( 90 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 91 | ) 92 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/decode.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | 4 | import argbind 5 | import numpy as np 6 | import torch 7 | from audiotools import AudioSignal 8 | from tqdm import tqdm 9 | 10 | from ..model import DACFile 11 | from . import load_model 12 | 13 | warnings.filterwarnings("ignore", category=UserWarning) 14 | 15 | 16 | @argbind.bind(group="decode", positional=True, without_prefix=True) 17 | @torch.inference_mode() 18 | @torch.no_grad() 19 | def decode( 20 | input: str, 21 | output: str = "", 22 | weights_path: str = "", 23 | model_tag: str = "latest", 24 | model_bitrate: str = "8kbps", 25 | device: str = "cuda", 26 | model_type: str = "44khz", 27 | verbose: bool = False, 28 | ): 29 | """Decode audio from codes. 30 | 31 | Parameters 32 | ---------- 33 | input : str 34 | Path to input directory or file 35 | output : str, optional 36 | Path to output directory, by default "". 37 | If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 38 | weights_path : str, optional 39 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 40 | model_tag and model_type. 41 | model_tag : str, optional 42 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 43 | model_bitrate: str 44 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 45 | device : str, optional 46 | Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. 47 | model_type : str, optional 48 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
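verbose : bool, optional
    If True, request more verbose output from the model's `decompress` call, by default False.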
49 | """ 50 | generator = load_model( 51 | model_type=model_type, 52 | model_bitrate=model_bitrate, 53 | tag=model_tag, 54 | load_path=weights_path, 55 | ) 56 | generator.to(device) 57 | generator.eval() 58 | 59 | # Find all .dac files in input directory 60 | _input = Path(input) 61 | input_files = list(_input.glob("**/*.dac")) 62 | 63 | # If input is a .dac file, add it to the list 64 | if _input.suffix == ".dac": 65 | input_files.append(_input) 66 | 67 | # Create output directory 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(input_files)), desc=f"Decoding files"): 72 | # Load file 73 | artifact = DACFile.load(input_files[i]) 74 | 75 | # Reconstruct audio from codes 76 | recons = generator.decompress(artifact, verbose=verbose) 77 | 78 | # Compute output path 79 | relative_path = input_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = input_files[i] 84 | output_name = relative_path.with_suffix(".wav").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | # Write to file 89 | recons.write(output_path) 90 | 91 | 92 | if __name__ == "__main__": 93 | args = argbind.parse_args() 94 | with argbind.scope(args): 95 | decode() 96 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/media_utils.py: -------------------------------------------------------------------------------- 1 | """Media utilities for audio/video processing.""" 2 | 3 | import os 4 | import subprocess 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | from loguru import logger 9 | 10 | 11 | class MediaProcessingError(Exception): 12 | """Exception raised for media processing errors.""" 13 | pass 14 | 15 | 16 | def merge_audio_video( 17 | audio_path: str, 18 | video_path: str, 19 | output_path: str, 20 | overwrite: bool = True, 21 | quality: str = "high" 22 | ) -> str: 23 | """ 24 | Merge audio and video files using ffmpeg. 
25 | 26 | Args: 27 | audio_path: Path to input audio file 28 | video_path: Path to input video file 29 | output_path: Path for output video file 30 | overwrite: Whether to overwrite existing output file 31 | quality: Quality setting ('high', 'medium', 'low') 32 | 33 | Returns: 34 | Path to the output file 35 | 36 | Raises: 37 | MediaProcessingError: If input files don't exist or ffmpeg fails 38 | FileNotFoundError: If ffmpeg is not installed 39 | """ 40 | # Validate input files 41 | if not os.path.exists(audio_path): 42 | raise MediaProcessingError(f"Audio file not found: {audio_path}") 43 | if not os.path.exists(video_path): 44 | raise MediaProcessingError(f"Video file not found: {video_path}") 45 | 46 | # Create output directory if needed 47 | output_dir = Path(output_path).parent 48 | output_dir.mkdir(parents=True, exist_ok=True) 49 | 50 | # Quality settings 51 | quality_settings = { 52 | "high": ["-b:a", "192k"], 53 | "medium": ["-b:a", "128k"], 54 | "low": ["-b:a", "96k"] 55 | } 56 | 57 | # Build ffmpeg command 58 | ffmpeg_command = [ 59 | "ffmpeg", 60 | "-i", video_path, 61 | "-i", audio_path, 62 | "-c:v", "copy", 63 | "-c:a", "aac", 64 | "-ac", "2", 65 | "-af", "pan=stereo|c0=c0|c1=c0", 66 | "-map", "0:v:0", 67 | "-map", "1:a:0", 68 | *quality_settings.get(quality, quality_settings["high"]), 69 | ] 70 | 71 | if overwrite: 72 | ffmpeg_command.append("-y") 73 | 74 | ffmpeg_command.append(output_path) 75 | 76 | try: 77 | logger.info(f"Merging audio '{audio_path}' with video '{video_path}'") 78 | process = subprocess.Popen( 79 | ffmpeg_command, 80 | stdout=subprocess.PIPE, 81 | stderr=subprocess.PIPE, 82 | text=True 83 | ) 84 | stdout, stderr = process.communicate() 85 | 86 | if process.returncode != 0: 87 | error_msg = f"FFmpeg failed with return code {process.returncode}: {stderr}" 88 | logger.error(error_msg) 89 | raise MediaProcessingError(error_msg) 90 | else: 91 | logger.info(f"Successfully merged video saved to: {output_path}") 92 | 93 | except FileNotFoundError: 94 | raise FileNotFoundError( 95 | "ffmpeg not found. Please install ffmpeg: " 96 | "https://ffmpeg.org/download.html" 97 | ) 98 | except Exception as e: 99 | raise MediaProcessingError(f"Unexpected error during media processing: {e}") 100 | 101 | return output_path 102 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/encode.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from pathlib import Path 4 | 5 | import argbind 6 | import numpy as np 7 | import torch 8 | from audiotools import AudioSignal 9 | from audiotools.core import util 10 | from tqdm import tqdm 11 | 12 | from . import load_model 13 | 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | 16 | 17 | @argbind.bind(group="encode", positional=True, without_prefix=True) 18 | @torch.inference_mode() 19 | @torch.no_grad() 20 | def encode( 21 | input: str, 22 | output: str = "", 23 | weights_path: str = "", 24 | model_tag: str = "latest", 25 | model_bitrate: str = "8kbps", 26 | n_quantizers: int = None, 27 | device: str = "cuda", 28 | model_type: str = "44khz", 29 | win_duration: float = 5.0, 30 | verbose: bool = False, 31 | ): 32 | """Encode audio files in input path to .dac format. 33 | 34 | Parameters 35 | ---------- 36 | input : str 37 | Path to input audio file or directory 38 | output : str, optional 39 | Path to output directory, by default "". 
If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 40 | weights_path : str, optional 41 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 42 | model_tag and model_type. 43 | model_tag : str, optional 44 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 45 | model_bitrate: str 46 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 47 | n_quantizers : int, optional 48 | Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. 49 | device : str, optional 50 | Device to use, by default "cuda" 51 | model_type : str, optional 52 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 53 | """ 54 | generator = load_model( 55 | model_type=model_type, 56 | model_bitrate=model_bitrate, 57 | tag=model_tag, 58 | load_path=weights_path, 59 | ) 60 | generator.to(device) 61 | generator.eval() 62 | kwargs = {"n_quantizers": n_quantizers} 63 | 64 | # Find all audio files in input path 65 | input = Path(input) 66 | audio_files = util.find_audio(input) 67 | 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(audio_files)), desc="Encoding files"): 72 | # Load file 73 | signal = AudioSignal(audio_files[i]) 74 | 75 | # Encode audio to .dac format 76 | artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) 77 | 78 | # Compute output path 79 | relative_path = audio_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = audio_files[i] 84 | output_name = relative_path.with_suffix(".dac").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | artifact.save(output_path) 89 | 90 | 91 | if __name__ == "__main__": 92 | args = argbind.parse_args() 93 | with argbind.scope(args): 94 | encode() 95 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import argbind 4 | from audiotools import ml 5 | 6 | from ..model import DAC 7 | Accelerator = ml.Accelerator 8 | 9 | __MODEL_LATEST_TAGS__ = { 10 | ("44khz", "8kbps"): "0.0.1", 11 | ("24khz", "8kbps"): "0.0.4", 12 | ("16khz", "8kbps"): "0.0.5", 13 | ("44khz", "16kbps"): "1.0.0", 14 | } 15 | 16 | __MODEL_URLS__ = { 17 | ( 18 | "44khz", 19 | "0.0.1", 20 | "8kbps", 21 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", 22 | ( 23 | "24khz", 24 | "0.0.4", 25 | "8kbps", 26 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", 27 | ( 28 | "16khz", 29 | "0.0.5", 30 | "8kbps", 31 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", 32 | ( 33 | "44khz", 34 | "1.0.0", 35 | "16kbps", 36 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", 37 | } 38 | 39 | 40 | @argbind.bind(group="download", positional=True, without_prefix=True) 41 | def download( 42 | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str 
= "latest" 43 | ): 44 | """ 45 | Function that downloads the weights file from URL if a local cache is not found. 46 | 47 | Parameters 48 | ---------- 49 | model_type : str 50 | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". 51 | model_bitrate: str 52 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 53 | Only 44khz model supports 16kbps. 54 | tag : str 55 | The tag of the model to download. Defaults to "latest". 56 | 57 | Returns 58 | ------- 59 | Path 60 | Directory path required to load model via audiotools. 61 | """ 62 | model_type = model_type.lower() 63 | tag = tag.lower() 64 | 65 | assert model_type in [ 66 | "44khz", 67 | "24khz", 68 | "16khz", 69 | ], "model_type must be one of '44khz', '24khz', or '16khz'" 70 | 71 | assert model_bitrate in [ 72 | "8kbps", 73 | "16kbps", 74 | ], "model_bitrate must be one of '8kbps', or '16kbps'" 75 | 76 | if tag == "latest": 77 | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] 78 | 79 | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) 80 | 81 | if download_link is None: 82 | raise ValueError( 83 | f"Could not find model with tag {tag} and model type {model_type}" 84 | ) 85 | 86 | local_path = ( 87 | Path.home() 88 | / ".cache" 89 | / "descript" 90 | / "dac" 91 | / f"weights_{model_type}_{model_bitrate}_{tag}.pth" 92 | ) 93 | if not local_path.exists(): 94 | local_path.parent.mkdir(parents=True, exist_ok=True) 95 | 96 | # Download the model 97 | import requests 98 | 99 | response = requests.get(download_link) 100 | 101 | if response.status_code != 200: 102 | raise ValueError( 103 | f"Could not download model. Received response code {response.status_code}" 104 | ) 105 | local_path.write_bytes(response.content) 106 | 107 | return local_path 108 | 109 | 110 | def load_model( 111 | model_type: str = "44khz", 112 | model_bitrate: str = "8kbps", 113 | tag: str = "latest", 114 | load_path: str = None, 115 | ): 116 | if not load_path: 117 | load_path = download( 118 | model_type=model_type, model_bitrate=model_bitrate, tag=tag 119 | ) 120 | generator = DAC.load(load_path) 121 | return generator 122 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | """Configuration utilities for the HunyuanVideo-Foley project.""" 2 | 3 | import yaml 4 | from pathlib import Path 5 | from typing import Any, Dict, List, Union 6 | 7 | class AttributeDict: 8 | 9 | def __init__(self, data: Union[Dict, List, Any]): 10 | if isinstance(data, dict): 11 | for key, value in data.items(): 12 | if isinstance(value, (dict, list)): 13 | value = AttributeDict(value) 14 | setattr(self, self._sanitize_key(key), value) 15 | elif isinstance(data, list): 16 | self._list = [AttributeDict(item) if isinstance(item, (dict, list)) else item 17 | for item in data] 18 | else: 19 | self._value = data 20 | 21 | def _sanitize_key(self, key: str) -> str: 22 | import re 23 | sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', str(key)) 24 | if sanitized[0].isdigit(): 25 | sanitized = f'_{sanitized}' 26 | return sanitized 27 | 28 | def __getitem__(self, key): 29 | if hasattr(self, '_list'): 30 | return self._list[key] 31 | return getattr(self, self._sanitize_key(key)) 32 | 33 | def __setitem__(self, key, value): 34 | if hasattr(self, '_list'): 35 | self._list[key] = value 36 | else: 37 | setattr(self, self._sanitize_key(key), value) 38 | 39 | 
def __iter__(self): 40 | if hasattr(self, '_list'): 41 | return iter(self._list) 42 | return iter(self.__dict__.keys()) 43 | 44 | def __len__(self): 45 | if hasattr(self, '_list'): 46 | return len(self._list) 47 | return len(self.__dict__) 48 | 49 | def get(self, key, default=None): 50 | try: 51 | return self[key] 52 | except (KeyError, AttributeError, IndexError): 53 | return default 54 | 55 | def keys(self): 56 | if hasattr(self, '_list'): 57 | return range(len(self._list)) 58 | elif hasattr(self, '_value'): 59 | return [] 60 | else: 61 | return [key for key in self.__dict__.keys() if not key.startswith('_')] 62 | 63 | def values(self): 64 | if hasattr(self, '_list'): 65 | return self._list 66 | elif hasattr(self, '_value'): 67 | return [self._value] 68 | else: 69 | return [value for key, value in self.__dict__.items() if not key.startswith('_')] 70 | 71 | def items(self): 72 | if hasattr(self, '_list'): 73 | return enumerate(self._list) 74 | elif hasattr(self, '_value'): 75 | return [] 76 | else: 77 | return [(key, value) for key, value in self.__dict__.items() if not key.startswith('_')] 78 | 79 | def __repr__(self): 80 | if hasattr(self, '_list'): 81 | return f"AttributeDict({self._list})" 82 | elif hasattr(self, '_value'): 83 | return f"AttributeDict({self._value})" 84 | return f"AttributeDict({dict(self.__dict__)})" 85 | 86 | def to_dict(self) -> Union[Dict, List, Any]: 87 | if hasattr(self, '_list'): 88 | return [item.to_dict() if isinstance(item, AttributeDict) else item 89 | for item in self._list] 90 | elif hasattr(self, '_value'): 91 | return self._value 92 | else: 93 | result = {} 94 | for key, value in self.__dict__.items(): 95 | if isinstance(value, AttributeDict): 96 | result[key] = value.to_dict() 97 | else: 98 | result[key] = value 99 | return result 100 | 101 | def load_yaml(file_path: str, encoding: str = 'utf-8') -> AttributeDict: 102 | try: 103 | with open(file_path, 'r', encoding=encoding) as file: 104 | data = yaml.safe_load(file) 105 | return AttributeDict(data) 106 | except FileNotFoundError: 107 | raise FileNotFoundError(f"YAML file not found: {file_path}") 108 | except yaml.YAMLError as e: 109 | raise yaml.YAMLError(f"YAML format error: {e}") 110 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | """ 2 | Installation script for ComfyUI HunyuanVideo-Foley Custom Node 3 | """ 4 | 5 | import os 6 | import sys 7 | import subprocess 8 | import pkg_resources 9 | from pathlib import Path 10 | 11 | def check_and_install_requirements(): 12 | """Check and install required packages""" 13 | requirements_file = Path(__file__).parent / "requirements.txt" 14 | 15 | if not requirements_file.exists(): 16 | print("Requirements file not found!") 17 | return False 18 | 19 | try: 20 | print("Checking and installing requirements...") 21 | 22 | # Read requirements 23 | with open(requirements_file, 'r') as f: 24 | requirements = f.read().splitlines() 25 | 26 | # Filter out comments and empty lines 27 | requirements = [line.strip() for line in requirements 28 | if line.strip() and not line.strip().startswith('#')] 29 | 30 | # Install packages 31 | for requirement in requirements: 32 | try: 33 | # Skip git+ requirements for now (they need special handling) 34 | if requirement.startswith('git+'): 35 | print(f"Installing git requirement: {requirement}") 36 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', requirement]) 37 | else: 38 | # Check if 
package is already installed 39 | try: 40 | pkg_resources.require([requirement]) 41 | print(f"✓ {requirement} already installed") 42 | except pkg_resources.DistributionNotFound: 43 | print(f"Installing {requirement}...") 44 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', requirement]) 45 | except pkg_resources.VersionConflict: 46 | print(f"Updating {requirement}...") 47 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--upgrade', requirement]) 48 | 49 | except subprocess.CalledProcessError as e: 50 | print(f"Failed to install {requirement}: {e}") 51 | return False 52 | 53 | print("✅ All requirements installed successfully!") 54 | return True 55 | 56 | except Exception as e: 57 | print(f"Error installing requirements: {e}") 58 | return False 59 | 60 | def setup_model_directories(): 61 | """Create necessary model directories""" 62 | base_dir = Path(__file__).parent.parent.parent # Go up to ComfyUI root 63 | 64 | # Create ComfyUI/models/foley directory for automatic downloads 65 | foley_models_dir = base_dir / "models" / "foley" 66 | foley_models_dir.mkdir(parents=True, exist_ok=True) 67 | print(f"✓ Created ComfyUI models directory: {foley_models_dir}") 68 | 69 | # Also create local fallback directories 70 | node_dir = Path(__file__).parent 71 | local_dirs = [ 72 | node_dir / "pretrained_models", 73 | node_dir / "configs" 74 | ] 75 | 76 | for dir_path in local_dirs: 77 | dir_path.mkdir(exist_ok=True) 78 | print(f"✓ Created local directory: {dir_path}") 79 | 80 | def main(): 81 | """Main installation function""" 82 | print("🚀 Installing ComfyUI HunyuanVideo-Foley Custom Node...") 83 | 84 | # Install requirements 85 | if not check_and_install_requirements(): 86 | print("❌ Failed to install requirements") 87 | return False 88 | 89 | # Setup directories 90 | setup_model_directories() 91 | 92 | print("📁 Directory structure created") 93 | print("📋 Installation completed!") 94 | print() 95 | print("📌 Next steps:") 96 | print("1. Restart ComfyUI to load the custom nodes") 97 | print("2. Models will be automatically downloaded when you first use the node") 98 | print("3. Alternatively, manually download models and place them in ComfyUI/models/foley/") 99 | print("4. Model URLs are configured in model_urls.py (can be updated as needed)") 100 | print() 101 | 102 | return True 103 | 104 | if __name__ == "__main__": 105 | main() -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/helper.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | from itertools import repeat 3 | import importlib 4 | import yaml 5 | import time 6 | 7 | def default(value, default_val): 8 | return default_val if value is None else value 9 | 10 | 11 | def default_dtype(value, default_val): 12 | if value is not None: 13 | assert isinstance(value, type(default_val)), f"Expect {type(default_val)}, got {type(value)}." 
14 | return value 15 | return default_val 16 | 17 | 18 | def repeat_interleave(lst, num_repeats): 19 | return [item for item in lst for _ in range(num_repeats)] 20 | 21 | 22 | def _ntuple(n): 23 | def parse(x): 24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 25 | x = tuple(x) 26 | if len(x) == 1: 27 | x = tuple(repeat(x[0], n)) 28 | return x 29 | return tuple(repeat(x, n)) 30 | 31 | return parse 32 | 33 | 34 | to_1tuple = _ntuple(1) 35 | to_2tuple = _ntuple(2) 36 | to_3tuple = _ntuple(3) 37 | to_4tuple = _ntuple(4) 38 | 39 | 40 | def as_tuple(x): 41 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 42 | return tuple(x) 43 | if x is None or isinstance(x, (int, float, str)): 44 | return (x,) 45 | else: 46 | raise ValueError(f"Unknown type {type(x)}") 47 | 48 | 49 | def as_list_of_2tuple(x): 50 | x = as_tuple(x) 51 | if len(x) == 1: 52 | x = (x[0], x[0]) 53 | assert len(x) % 2 == 0, f"Expect even length, got {len(x)}." 54 | lst = [] 55 | for i in range(0, len(x), 2): 56 | lst.append((x[i], x[i + 1])) 57 | return lst 58 | 59 | 60 | def find_multiple(n: int, k: int) -> int: 61 | assert k > 0 62 | if n % k == 0: 63 | return n 64 | return n - (n % k) + k 65 | 66 | 67 | def merge_dicts(dict1, dict2): 68 | for key, value in dict2.items(): 69 | if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict): 70 | merge_dicts(dict1[key], value) 71 | else: 72 | dict1[key] = value 73 | return dict1 74 | 75 | 76 | def merge_yaml_files(file_list): 77 | merged_config = {} 78 | 79 | for file in file_list: 80 | with open(file, "r", encoding="utf-8") as f: 81 | config = yaml.safe_load(f) 82 | if config: 83 | # Remove the first level 84 | for key, value in config.items(): 85 | if isinstance(value, dict): 86 | merged_config = merge_dicts(merged_config, value) 87 | else: 88 | merged_config[key] = value 89 | 90 | return merged_config 91 | 92 | 93 | def merge_dict(file_list): 94 | merged_config = {} 95 | 96 | for file in file_list: 97 | with open(file, "r", encoding="utf-8") as f: 98 | config = yaml.safe_load(f) 99 | if config: 100 | merged_config = merge_dicts(merged_config, config) 101 | 102 | return merged_config 103 | 104 | 105 | def get_obj_from_str(string, reload=False): 106 | module, cls = string.rsplit(".", 1) 107 | if reload: 108 | module_imp = importlib.import_module(module) 109 | importlib.reload(module_imp) 110 | return getattr(importlib.import_module(module, package=None), cls) 111 | 112 | 113 | def readable_time(seconds): 114 | """ Convert time seconds to a readable format: DD Days, HH Hours, MM Minutes, SS Seconds """ 115 | seconds = int(seconds) 116 | days, seconds = divmod(seconds, 86400) 117 | hours, seconds = divmod(seconds, 3600) 118 | minutes, seconds = divmod(seconds, 60) 119 | if days > 0: 120 | return f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds" 121 | if hours > 0: 122 | return f"{hours} Hours, {minutes} Minutes, {seconds} Seconds" 123 | if minutes > 0: 124 | return f"{minutes} Minutes, {seconds} Seconds" 125 | return f"{seconds} Seconds" 126 | 127 | 128 | def get_obj_from_cfg(cfg, reload=False): 129 | if isinstance(cfg, str): 130 | return get_obj_from_str(cfg, reload) 131 | elif isinstance(cfg, (list, tuple,)): 132 | return tuple([get_obj_from_str(c, reload) for c in cfg]) 133 | else: 134 | raise NotImplementedError(f"Not implemented for {type(cfg)}.") 135 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/utils.py: 
-------------------------------------------------------------------------------- 1 | from hashlib import md5 2 | from pathlib import Path 3 | import subprocess 4 | 5 | import requests 6 | from tqdm import tqdm 7 | 8 | PARENT_LINK = "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a" 9 | FNAME2LINK = { 10 | # S3: Synchability: AudioSet (run 2) 11 | "24-01-22T20-34-52.pt": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt", 12 | "cfg-24-01-22T20-34-52.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml", 13 | # S2: Synchformer: AudioSet (run 2) 14 | "24-01-04T16-39-21.pt": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt", 15 | "cfg-24-01-04T16-39-21.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml", 16 | # S2: Synchformer: AudioSet (run 1) 17 | "23-08-28T11-23-23.pt": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt", 18 | "cfg-23-08-28T11-23-23.yaml": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml", 19 | # S2: Synchformer: LRS3 (run 2) 20 | "23-12-23T18-33-57.pt": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt", 21 | "cfg-23-12-23T18-33-57.yaml": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml", 22 | # S2: Synchformer: VGS (run 2) 23 | "24-01-02T10-00-53.pt": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt", 24 | "cfg-24-01-02T10-00-53.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml", 25 | # SparseSync: ft VGGSound-Full 26 | "22-09-21T21-00-52.pt": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt", 27 | "cfg-22-09-21T21-00-52.yaml": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml", 28 | # SparseSync: ft VGGSound-Sparse 29 | "22-07-28T15-49-45.pt": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt", 30 | "cfg-22-07-28T15-49-45.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml", 31 | # SparseSync: only pt on LRS3 32 | "22-07-13T22-25-49.pt": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt", 33 | "cfg-22-07-13T22-25-49.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml", 34 | # SparseSync: feature extractors 35 | "ResNetAudio-22-08-04T09-51-04.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt", # 2s 36 | "ResNetAudio-22-08-03T23-14-49.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt", # 3s 37 | "ResNetAudio-22-08-03T23-14-28.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt", # 4s 38 | "ResNetAudio-22-06-24T08-10-33.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt", # 5s 39 | "ResNetAudio-22-06-24T17-31-07.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt", # 6s 40 | "ResNetAudio-22-06-24T23-57-11.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt", # 7s 41 | "ResNetAudio-22-06-25T04-35-42.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt", # 8s 42 | } 43 | 44 | 45 | def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024): 46 | """Checks if file exists, if not downloads it from the link to the path""" 47 | path = Path(path) 48 | if not path.exists(): 49 | path.parent.mkdir(exist_ok=True, parents=True) 50 | link = fname2link.get(path.name, None) 51 | if link is None: 52 | raise ValueError( 53 | f"Cant find the checkpoint file: {path}.", f"Please 
download it manually and ensure the path exists." 54 | ) 55 | with requests.get(fname2link[path.name], stream=True) as r: 56 | total_size = int(r.headers.get("content-length", 0)) 57 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 58 | with open(path, "wb") as f: 59 | for data in r.iter_content(chunk_size=chunk_size): 60 | if data: 61 | f.write(data) 62 | pbar.update(chunk_size) 63 | 64 | 65 | def which_ffmpeg() -> str: 66 | """Determines the path to ffmpeg library 67 | Returns: 68 | str -- path to the library 69 | """ 70 | result = subprocess.run(["which", "ffmpeg"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 71 | ffmpeg_path = result.stdout.decode("utf-8").replace("\n", "") 72 | return ffmpeg_path 73 | 74 | 75 | def get_md5sum(path): 76 | hash_md5 = md5() 77 | with open(path, "rb") as f: 78 | for chunk in iter(lambda: f.read(4096 * 8), b""): 79 | hash_md5.update(chunk) 80 | md5sum = hash_md5.hexdigest() 81 | return md5sum 82 | 83 | 84 | class Config: 85 | def __init__(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | setattr(self, k, v) 88 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/embed_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ...utils.helper import to_2tuple, to_1tuple 6 | 7 | class PatchEmbed1D(nn.Module): 8 | """1D Audio to Patch Embedding 9 | 10 | A convolution based approach to patchifying a 1D audio w/ embedding projection. 11 | 12 | Based on the impl in https://github.com/google-research/vision_transformer 13 | 14 | Hacked together by / Copyright 2020 Ross Wightman 15 | """ 16 | 17 | def __init__( 18 | self, 19 | patch_size=1, 20 | in_chans=768, 21 | embed_dim=768, 22 | norm_layer=None, 23 | flatten=True, 24 | bias=True, 25 | dtype=None, 26 | device=None, 27 | ): 28 | factory_kwargs = {"dtype": dtype, "device": device} 29 | super().__init__() 30 | patch_size = to_1tuple(patch_size) 31 | self.patch_size = patch_size 32 | self.flatten = flatten 33 | 34 | self.proj = nn.Conv1d( 35 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs 36 | ) 37 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) 38 | if bias: 39 | nn.init.zeros_(self.proj.bias) 40 | 41 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 42 | 43 | def forward(self, x): 44 | assert ( 45 | x.shape[2] % self.patch_size[0] == 0 46 | ), f"The patch_size of {self.patch_size[0]} must be divisible by the token number ({x.shape[2]}) of x." 47 | 48 | x = self.proj(x) 49 | if self.flatten: 50 | x = x.transpose(1, 2) # BCN -> BNC 51 | x = self.norm(x) 52 | return x 53 | 54 | 55 | class ConditionProjection(nn.Module): 56 | """ 57 | Projects condition embeddings. Also handles dropout for classifier-free guidance. 
58 | 59 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py 60 | """ 61 | 62 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): 63 | factory_kwargs = {'dtype': dtype, 'device': device} 64 | super().__init__() 65 | self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs) 66 | self.act_1 = act_layer() 67 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs) 68 | 69 | def forward(self, caption): 70 | hidden_states = self.linear_1(caption) 71 | hidden_states = self.act_1(hidden_states) 72 | hidden_states = self.linear_2(hidden_states) 73 | return hidden_states 74 | 75 | 76 | def timestep_embedding(t, dim, max_period=10000): 77 | """ 78 | Create sinusoidal timestep embeddings. 79 | 80 | Args: 81 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. 82 | dim (int): the dimension of the output. 83 | max_period (int): controls the minimum frequency of the embeddings. 84 | 85 | Returns: 86 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. 87 | 88 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 89 | """ 90 | half = dim // 2 91 | freqs = torch.exp( 92 | -math.log(max_period) 93 | * torch.arange(start=0, end=half, dtype=torch.float32) 94 | / half 95 | ).to(device=t.device) 96 | args = t[:, None].float() * freqs[None] 97 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 98 | if dim % 2: 99 | embedding = torch.cat( 100 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 101 | ) 102 | return embedding 103 | 104 | 105 | class TimestepEmbedder(nn.Module): 106 | """ 107 | Embeds scalar timesteps into vector representations. 
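Timesteps are first mapped to sinusoidal frequency embeddings via `timestep_embedding` and then projected by a two-layer MLP to `out_size` (defaults to `hidden_size`).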
108 | """ 109 | def __init__(self, 110 | hidden_size, 111 | act_layer, 112 | frequency_embedding_size=256, 113 | max_period=10000, 114 | out_size=None, 115 | dtype=None, 116 | device=None 117 | ): 118 | factory_kwargs = {'dtype': dtype, 'device': device} 119 | super().__init__() 120 | self.frequency_embedding_size = frequency_embedding_size 121 | self.max_period = max_period 122 | if out_size is None: 123 | out_size = hidden_size 124 | 125 | self.mlp = nn.Sequential( 126 | nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), 127 | act_layer(), 128 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), 129 | ) 130 | nn.init.normal_(self.mlp[0].weight, std=0.02) 131 | nn.init.normal_(self.mlp[2].weight, std=0.02) 132 | 133 | def forward(self, t): 134 | t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) 135 | t_emb = self.mlp(t_freq) 136 | return t_emb 137 | -------------------------------------------------------------------------------- /download_models_manual.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Manual download helper for HunyuanVideo-Foley models 4 | Run this script directly if the automatic download fails in ComfyUI 5 | """ 6 | 7 | import os 8 | import sys 9 | from pathlib import Path 10 | import urllib.request 11 | import time 12 | 13 | # Model download URLs 14 | MODELS = [ 15 | { 16 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/hunyuanvideo_foley.pth", 17 | "filename": "hunyuanvideo_foley.pth", 18 | "size_gb": 10.3, 19 | "description": "Main HunyuanVideo-Foley model" 20 | }, 21 | { 22 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/synchformer_state_dict.pth", 23 | "filename": "synchformer_state_dict.pth", 24 | "size_gb": 0.95, 25 | "description": "Synchformer model weights" 26 | }, 27 | { 28 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/vae_128d_48k.pth", 29 | "filename": "vae_128d_48k.pth", 30 | "size_gb": 1.49, 31 | "description": "VAE model weights" 32 | } 33 | ] 34 | 35 | def download_with_progress(url, dest_path): 36 | """Download file with progress display""" 37 | def progress_hook(block_num, block_size, total_size): 38 | if total_size > 0: 39 | downloaded = block_num * block_size 40 | percent = min(100, (downloaded * 100) // total_size) 41 | size_mb = downloaded / (1024 * 1024) 42 | total_mb = total_size / (1024 * 1024) 43 | 44 | # Print progress 45 | bar_len = 40 46 | filled_len = int(bar_len * percent // 100) 47 | bar = '█' * filled_len + '-' * (bar_len - filled_len) 48 | 49 | sys.stdout.write(f'\r[{bar}] {percent}% ({size_mb:.1f}/{total_mb:.1f} MB)') 50 | sys.stdout.flush() 51 | 52 | try: 53 | urllib.request.urlretrieve(url, dest_path, progress_hook) 54 | print() # New line after progress 55 | return True 56 | except Exception as e: 57 | print(f"\nError: {e}") 58 | return False 59 | 60 | def main(): 61 | # Determine ComfyUI models directory 62 | comfyui_root = Path(__file__).parent.parent.parent # Go up to ComfyUI root 63 | models_dir = comfyui_root / "models" / "foley" / "hunyuanvideo-foley-xxl" 64 | 65 | print("=" * 60) 66 | print("HunyuanVideo-Foley Model Downloader") 67 | print("=" * 60) 68 | print(f"\nModels will be downloaded to:") 69 | print(f" {models_dir}") 70 | 71 | # Create directory if it doesn't exist 72 | models_dir.mkdir(parents=True, exist_ok=True) 73 | 74 | # Check disk space 75 | import shutil 76 | 
stat = shutil.disk_usage(models_dir) 77 | available_gb = stat.free / (1024 ** 3) 78 | required_gb = sum(m["size_gb"] for m in MODELS) + 1 # Add 1GB buffer 79 | 80 | print(f"\nDisk space available: {available_gb:.1f} GB") 81 | print(f"Space required: {required_gb:.1f} GB") 82 | 83 | if available_gb < required_gb: 84 | print(f"\n⚠️ WARNING: Insufficient disk space!") 85 | print(f"Please free up at least {required_gb - available_gb:.1f} GB before continuing.") 86 | response = input("\nContinue anyway? (y/n): ") 87 | if response.lower() != 'y': 88 | return 89 | 90 | print("\n" + "=" * 60) 91 | 92 | # Download each model 93 | for i, model_info in enumerate(MODELS, 1): 94 | model_path = models_dir / model_info["filename"] 95 | 96 | print(f"\n[{i}/{len(MODELS)}] {model_info['description']}") 97 | print(f" File: {model_info['filename']} ({model_info['size_gb']} GB)") 98 | 99 | # Check if already downloaded 100 | if model_path.exists() and model_path.stat().st_size > 100 * 1024 * 1024: 101 | size_gb = model_path.stat().st_size / (1024 ** 3) 102 | print(f" ✓ Already downloaded ({size_gb:.2f} GB)") 103 | continue 104 | 105 | print(f" Downloading from: {model_info['url']}") 106 | 107 | # Try downloading with retries 108 | max_retries = 3 109 | for attempt in range(max_retries): 110 | if attempt > 0: 111 | print(f" Retry {attempt}/{max_retries - 1}...") 112 | time.sleep(5) 113 | 114 | success = download_with_progress(model_info["url"], model_path) 115 | if success: 116 | size_gb = model_path.stat().st_size / (1024 ** 3) 117 | print(f" ✓ Downloaded successfully ({size_gb:.2f} GB)") 118 | break 119 | else: 120 | if attempt == max_retries - 1: 121 | print(f" ✗ Failed to download after {max_retries} attempts") 122 | print(f"\n You can manually download from:") 123 | print(f" {model_info['url']}") 124 | print(f" And place it at:") 125 | print(f" {model_path}") 126 | 127 | print("\n" + "=" * 60) 128 | print("Download process completed!") 129 | print("\nYou can now use the HunyuanVideo-Foley node in ComfyUI.") 130 | print("=" * 60) 131 | 132 | if __name__ == "__main__": 133 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI HunyuanVideo-Foley Custom Node 2 | 3 | This is a ComfyUI custom node wrapper for the HunyuanVideo-Foley model, which generates realistic audio from video and text descriptions. 4 | 5 | 6 | ## ✨ The original author's plugin has been updated; just install the original plugin instead (https://github.com/if-ai/ComfyUI_HunyuanVideoFoley) 7 | 8 | (image) 9 | ## Model Unloading 10 | (image) 11 | 12 | 13 | 14 | 15 | ## Features 16 | 17 | - **Text-Video-to-Audio Synthesis**: Generate realistic audio that matches your video content 18 | - **Flexible Text Prompts**: Use optional text descriptions to guide audio generation 19 | - **Multiple Samples**: Generate up to 6 different audio variations per inference 20 | - **Configurable Parameters**: Control guidance scale, inference steps, and sampling 21 | - **Seed Control**: Reproducible results with seed parameter 22 | - **Model Caching**: Efficient model loading and reuse across generations 23 | - **Automatic Model Downloads**: Models are automatically downloaded to `ComfyUI/models/foley/` when needed 24 | 25 | ## Installation 26 | 27 | 1.
**Clone this repository** into your ComfyUI custom_nodes directory: 28 | ```bash 29 | cd ComfyUI/custom_nodes 30 | git clone https://github.com/yichengup/ComfyUI_ycHunyuanVideoFoley.git 31 | ``` 32 | 33 | 2. **Install dependencies**: 34 | ```bash 35 | cd ComfyUI_ycHunyuanVideoFoley 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | 3. **Run the installation script** (recommended): 40 | ```bash 41 | python install.py 42 | ``` 43 | 44 | 4. **Restart ComfyUI** to load the new nodes. 45 | 46 | ### Model Setup 47 | 48 | The models can be obtained in two ways: 49 | 50 | #### Option 1: Automatic Download (Recommended) 51 | - Models will be automatically downloaded to `ComfyUI/models/foley/` when you first run the node 52 | - No manual setup required 53 | - Progress will be shown in the ComfyUI console 54 | 55 | #### Option 2: Manual Download 56 | - Download models from [HuggingFace](https://huggingface.co/tencent/HunyuanVideo-Foley) 57 | - Place models in `ComfyUI/models/foley/` (recommended) or `./pretrained_models/` directory 58 | - Ensure the config file is at `configs/hunyuanvideo-foley-xxl.yaml` 59 | 60 | ## Usage 61 | 62 | ### Node Types 63 | 64 | #### 1. HunyuanVideo-Foley Generator 65 | Main node for generating audio from video and text. 66 | 67 | **Inputs:** 68 | - **video**: Video input (VIDEO type) 69 | - **text_prompt**: Text description of desired audio (STRING) 70 | - **guidance_scale**: CFG scale for generation control (1.0-10.0, default: 4.5) 71 | - **num_inference_steps**: Number of denoising steps (10-100, default: 50) 72 | - **sample_nums**: Number of audio samples to generate (1-6, default: 1) 73 | - **seed**: Random seed for reproducibility (INT) 74 | - **model_path**: Path to pretrained models (optional, leave empty for auto-download) 75 | 76 | **Outputs:** 77 | - **video_with_audio**: Video with generated audio merged (VIDEO) 78 | - **audio_only**: Generated audio file (AUDIO) 79 | - **status_message**: Generation status and info (STRING) 80 | 81 | ## ⚠ Important Limitations 82 | 83 | ### **Frame Count & Duration Limits** 84 | - **Maximum Frames**: 450 frames (hard limit) 85 | - **Maximum Duration**: 15 seconds at 30fps 86 | - **Recommended**: Keep videos ≤15 seconds for best results 87 | 88 | ### **FPS Recommendations** 89 | - **30fps**: Max 15 seconds (450 frames) 90 | - **24fps**: Max 18.75 seconds (450 frames)   91 | - **15fps**: Max 30 seconds (450 frames) 92 | 93 | ### **Long Video Solutions** 94 | For videos longer than 15 seconds: 95 | 1. **Reduce FPS**: Lower FPS allows longer duration within frame limit 96 | 2. **Segment Processing**: Split long videos into 15s segments 97 | 3. **Audio Merging**: Combine generated audio segments in post-processing 98 | 99 | 100 | ## Example Workflow 101 | 102 | 1. **Load Video**: Use a "Load Video" node to input your video file 103 | 2. **Add Generator**: Add the "HunyuanVideo-Foley Generator" node 104 | 3. **Connect Video**: Connect the video output to the generator's video input 105 | 4. **Set Prompt**: Enter a text description (e.g., "A person walks on frozen ice") 106 | 5. **Adjust Settings**: Configure guidance scale, steps, and sample count as needed 107 | 6. 
**Generate**: Run the workflow to generate audio 108 | 109 | ## Model Requirements 110 | 111 | The node expects the following model structure: 112 | ``` 113 | pretrained_models/ 114 | ├── hunyuanvideo_foley.pth # Main Foley model 115 | ├── vae_128d_48k.pth # DAC VAE model 116 | └── synchformer_state_dict.pth # Synchformer model 117 | 118 | configs/ 119 | └── hunyuanvideo-foley-xxl.yaml # Configuration file 120 | ``` 121 | 122 | 123 | ## License 124 | 125 | This custom node is based on the HunyuanVideo-Foley project. Please check the original project's license terms. 126 | 127 | ## Credits 128 | 129 | Based on the HunyuanVideo-Foley project by Tencent. Original paper and code available at: 130 | - Paper: [HunyuanVideo-Foley: Text-Video-to-Audio Synthesis] 131 | 132 | - Code: [https://github.com/tencent/HunyuanVideo-Foley] 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /INSTALLATION_GUIDE.md: -------------------------------------------------------------------------------- 1 | # Installation Guide for ComfyUI HunyuanVideo-Foley Custom Node 2 | 3 | ## Overview 4 | 5 | This custom node wraps the HunyuanVideo-Foley model for use in ComfyUI, enabling text-video-to-audio synthesis directly within ComfyUI workflows. 6 | 7 | ## Prerequisites 8 | 9 | - ComfyUI installation 10 | - Python 3.8+ 11 | - CUDA-capable GPU (recommended 16GB+ VRAM) 12 | - At least 32GB system RAM 13 | 14 | ## Step-by-Step Installation 15 | 16 | ### 1. Clone the Custom Node 17 | 18 | ```bash 19 | cd /path/to/ComfyUI/custom_nodes 20 | git clone ComfyUI_HunyuanVideoFoley 21 | cd ComfyUI_HunyuanVideoFoley 22 | ``` 23 | 24 | ### 2. Install Dependencies 25 | 26 | ```bash 27 | # Install Python dependencies 28 | pip install -r requirements.txt 29 | 30 | # Or run the installation script 31 | python install.py 32 | ``` 33 | 34 | ### 3. Download Models 35 | 36 | You need to download the HunyuanVideo-Foley models: 37 | 38 | ```bash 39 | # Option 1: Manual download from HuggingFace 40 | # Visit: https://huggingface.co/tencent/HunyuanVideo-Foley 41 | # Download and place files in the following structure: 42 | 43 | mkdir -p pretrained_models 44 | mkdir -p configs 45 | 46 | # Download these files to pretrained_models/: 47 | # - hunyuanvideo_foley.pth 48 | # - vae_128d_48k.pth 49 | # - synchformer_state_dict.pth 50 | 51 | # Download config file to configs/: 52 | # - hunyuanvideo-foley-xxl.yaml 53 | ``` 54 | 55 | Or use the Gradio app's auto-download feature by setting the environment variable: 56 | 57 | ```bash 58 | export HIFI_FOLEY_MODEL_PATH="/path/to/ComfyUI/custom_nodes/ComfyUI_HunyuanVideoFoley/pretrained_models" 59 | ``` 60 | 61 | ### 4. Verify Installation 62 | 63 | ```bash 64 | # Run the test script to verify everything is working 65 | python test_node.py 66 | ``` 67 | 68 | ### 5. Restart ComfyUI 69 | 70 | After installation, restart ComfyUI to load the new custom nodes. 
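
Before restarting, it can be worth double-checking that the weights ended up where the node looks for them. A minimal check along these lines (file names taken from the model list above; the default `pretrained_models/` and `configs/` locations are assumed, run from the custom node directory) prints anything that is still missing:

```python
from pathlib import Path

# Expected files, relative to the custom node directory (see the structure below)
required = [
    "pretrained_models/hunyuanvideo_foley.pth",
    "pretrained_models/vae_128d_48k.pth",
    "pretrained_models/synchformer_state_dict.pth",
    "configs/hunyuanvideo-foley-xxl.yaml",
]

for rel_path in required:
    status = "OK     " if Path(rel_path).is_file() else "MISSING"
    print(f"{status} {rel_path}")
```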
71 | 72 | ## Expected Directory Structure 73 | 74 | After installation, your directory should look like this: 75 | 76 | ``` 77 | ComfyUI_HunyuanVideoFoley/ 78 | ├── __init__.py 79 | ├── nodes.py 80 | ├── utils.py 81 | ├── requirements.txt 82 | ├── install.py 83 | ├── test_node.py 84 | ├── example_workflow.json 85 | ├── README.md 86 | ├── INSTALLATION_GUIDE.md 87 | └── pyproject.toml 88 | 89 | pretrained_models/ 90 | ├── hunyuanvideo_foley.pth 91 | ├── vae_128d_48k.pth 92 | └── synchformer_state_dict.pth 93 | 94 | configs/ 95 | └── hunyuanvideo-foley-xxl.yaml 96 | ``` 97 | 98 | ## Usage 99 | 100 | ### Nodes Available 101 | 102 | 1. **HunyuanVideo-Foley Generator** 103 | - Main node for audio generation 104 | - Inputs: video, text prompt, generation parameters 105 | - Outputs: video with audio, audio only, status message 106 | 107 | 2. **HunyuanVideo-Foley Model Loader** 108 | - Separate model loading node 109 | - Useful for sharing models between multiple generator nodes 110 | - Inputs: model path, config path 111 | - Outputs: model handle, status message 112 | 113 | ### Basic Workflow 114 | 115 | 1. Load a video using ComfyUI's video input nodes 116 | 2. Add the "HunyuanVideo-Foley Generator" node 117 | 3. Connect the video to the generator 118 | 4. Set your text prompt (e.g., "A person walks on frozen ice") 119 | 5. Adjust parameters as needed 120 | 6. Run the workflow 121 | 122 | ### Example Workflow 123 | 124 | An example workflow is provided in `example_workflow.json`. Load this file in ComfyUI to see a basic setup. 125 | 126 | ## Performance Tips 127 | 128 | - **VRAM Usage**: The model requires significant GPU memory. If you encounter CUDA out of memory errors: 129 | - Reduce `sample_nums` parameter 130 | - Lower `num_inference_steps` 131 | - Use CPU mode (slower but works with less memory) 132 | 133 | - **Generation Time**: Audio generation can take several minutes depending on: 134 | - Video length 135 | - Number of inference steps 136 | - Number of samples generated 137 | - Hardware specifications 138 | 139 | ## Troubleshooting 140 | 141 | ### Common Issues 142 | 143 | 1. **"Failed to import HunyuanVideo-Foley modules"** 144 | ```bash 145 | # Make sure you're in the correct directory and have all dependencies 146 | pip install -r requirements.txt 147 | ``` 148 | 149 | 2. **"Model path does not exist"** 150 | ```bash 151 | # Download models from HuggingFace and verify directory structure 152 | ls pretrained_models/ 153 | # Should show: hunyuanvideo_foley.pth, vae_128d_48k.pth, synchformer_state_dict.pth 154 | ``` 155 | 156 | 3. **CUDA out of memory** 157 | ```bash 158 | # Reduce memory usage by adjusting parameters: 159 | # - Lower sample_nums to 1 160 | # - Reduce num_inference_steps to 25 161 | # - Use shorter videos for testing 162 | ``` 163 | 164 | 4. 
**Slow generation** 165 | ```bash 166 | # Normal for first run (model loading) 167 | # Subsequent runs should be faster 168 | # Consider using fewer inference steps for faster results 169 | ``` 170 | 171 | ### Getting Help 172 | 173 | - Check the test script output: `python test_node.py` 174 | - Review ComfyUI console output for detailed error messages 175 | - Ensure all model files are downloaded correctly 176 | - Verify GPU memory availability 177 | 178 | ## Model Information 179 | 180 | The HunyuanVideo-Foley model consists of several components: 181 | 182 | - **Main Foley Model**: Core text-video-to-audio generation 183 | - **DAC VAE**: Audio encoding/decoding 184 | - **SigLIP2**: Visual feature extraction 185 | - **CLAP**: Text feature extraction 186 | - **Synchformer**: Video-audio synchronization 187 | 188 | All components are automatically loaded when using the custom node. 189 | 190 | ## License & Credits 191 | 192 | This custom node is based on the HunyuanVideo-Foley project by Tencent. Please respect the original project's license terms when using this implementation. 193 | 194 | Original project: https://github.com/tencent/HunyuanVideo-Foley -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/mlp_layers.py: -------------------------------------------------------------------------------- 1 | # Modified from timm library: 2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .modulate_layers import modulate 11 | from ...utils.helper import to_2tuple 12 | 13 | class MLP(nn.Module): 14 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 15 | 16 | def __init__( 17 | self, 18 | in_channels, 19 | hidden_channels=None, 20 | out_features=None, 21 | act_layer=nn.GELU, 22 | norm_layer=None, 23 | bias=True, 24 | drop=0.0, 25 | use_conv=False, 26 | device=None, 27 | dtype=None, 28 | ): 29 | factory_kwargs = {"device": device, "dtype": dtype} 30 | super().__init__() 31 | out_features = out_features or in_channels 32 | hidden_channels = hidden_channels or in_channels 33 | bias = to_2tuple(bias) 34 | drop_probs = to_2tuple(drop) 35 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 36 | 37 | self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs) 38 | self.act = act_layer() 39 | self.drop1 = nn.Dropout(drop_probs[0]) 40 | self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity() 41 | self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs) 42 | self.drop2 = nn.Dropout(drop_probs[1]) 43 | 44 | def forward(self, x): 45 | x = self.fc1(x) 46 | x = self.act(x) 47 | x = self.drop1(x) 48 | x = self.norm(x) 49 | x = self.fc2(x) 50 | x = self.drop2(x) 51 | return x 52 | 53 | 54 | # copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py 55 | # only used when use_vanilla is True 56 | class MLPEmbedder(nn.Module): 57 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): 58 | factory_kwargs = {"device": device, "dtype": dtype} 59 | super().__init__() 60 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) 61 | self.silu = nn.SiLU() 62 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, 
**factory_kwargs) 63 | 64 | def forward(self, x: torch.Tensor) -> torch.Tensor: 65 | return self.out_layer(self.silu(self.in_layer(x))) 66 | 67 | 68 | class LinearWarpforSingle(nn.Module): 69 | def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None): 70 | factory_kwargs = {"device": device, "dtype": dtype} 71 | super().__init__() 72 | self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs) 73 | 74 | def forward(self, x, y): 75 | z = torch.cat([x, y], dim=2) 76 | return self.fc(z) 77 | 78 | class FinalLayer1D(nn.Module): 79 | def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None): 80 | factory_kwargs = {"device": device, "dtype": dtype} 81 | super().__init__() 82 | 83 | # Just use LayerNorm for the final layer 84 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) 85 | self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs) 86 | nn.init.zeros_(self.linear.weight) 87 | nn.init.zeros_(self.linear.bias) 88 | 89 | # Here we don't distinguish between the modulate types. Just use the simple one. 90 | self.adaLN_modulation = nn.Sequential( 91 | act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs) 92 | ) 93 | # Zero-initialize the modulation 94 | nn.init.zeros_(self.adaLN_modulation[1].weight) 95 | nn.init.zeros_(self.adaLN_modulation[1].bias) 96 | 97 | def forward(self, x, c): 98 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 99 | x = modulate(self.norm_final(x), shift=shift, scale=scale) 100 | x = self.linear(x) 101 | return x 102 | 103 | 104 | class ChannelLastConv1d(nn.Conv1d): 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = x.permute(0, 2, 1) 108 | x = super().forward(x) 109 | x = x.permute(0, 2, 1) 110 | return x 111 | 112 | 113 | class ConvMLP(nn.Module): 114 | 115 | def __init__( 116 | self, 117 | dim: int, 118 | hidden_dim: int, 119 | multiple_of: int = 256, 120 | kernel_size: int = 3, 121 | padding: int = 1, 122 | device=None, 123 | dtype=None, 124 | ): 125 | """ 126 | Convolutional MLP module. 127 | 128 | Args: 129 | dim (int): Input dimension. 130 | hidden_dim (int): Hidden dimension of the feedforward layer. 131 | multiple_of (int): Value to ensure hidden dimension is a multiple of this value. 132 | 133 | Attributes: 134 | w1: Linear transformation for the first layer. 135 | w2: Linear transformation for the second layer. 136 | w3: Linear transformation for the third layer. 
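A minimal shape sketch (hypothetical sizes). The forward pass is the gated, SwiGLU-style feed-forward w2(SiLU(w1(x)) * w3(x)) applied with channel-last 1-D convolutions, so sequence length and channel count are preserved:
    >>> mlp = ConvMLP(dim=1536, hidden_dim=4 * 1536)
    >>> x = torch.randn(2, 128, 1536)   # (batch, sequence, channels)
    >>> mlp(x).shape
    torch.Size([2, 128, 1536])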
137 | 138 | """ 139 | factory_kwargs = {"device": device, "dtype": dtype} 140 | super().__init__() 141 | hidden_dim = int(2 * hidden_dim / 3) 142 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 143 | 144 | self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 145 | self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 146 | self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 147 | 148 | def forward(self, x): 149 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 150 | -------------------------------------------------------------------------------- /test_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Test script for ComfyUI HunyuanVideo-Foley custom node 4 | """ 5 | 6 | import sys 7 | import os 8 | import tempfile 9 | from pathlib import Path 10 | 11 | # Add the parent directory to path for imports 12 | current_dir = Path(__file__).parent 13 | parent_dir = current_dir.parent 14 | sys.path.insert(0, str(parent_dir)) 15 | 16 | def test_imports(): 17 | """Test that all required modules can be imported""" 18 | print("Testing imports...") 19 | 20 | try: 21 | from ComfyUI_HunyuanVideoFoley import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 22 | print("✅ Successfully imported node mappings") 23 | 24 | print(f"Available nodes: {list(NODE_CLASS_MAPPINGS.keys())}") 25 | print(f"Display names: {NODE_DISPLAY_NAME_MAPPINGS}") 26 | 27 | return True 28 | 29 | except ImportError as e: 30 | print(f"❌ Import failed: {e}") 31 | return False 32 | 33 | def test_node_structure(): 34 | """Test node class structure""" 35 | print("\nTesting node structure...") 36 | 37 | try: 38 | from ComfyUI_HunyuanVideoFoley.nodes import HunyuanVideoFoleyNode, HunyuanVideoFoleyModelLoader 39 | 40 | # Test HunyuanVideoFoleyNode 41 | node = HunyuanVideoFoleyNode() 42 | input_types = node.INPUT_TYPES() 43 | 44 | print("✅ HunyuanVideoFoleyNode structure:") 45 | print(f" - Required inputs: {list(input_types['required'].keys())}") 46 | print(f" - Optional inputs: {list(input_types.get('optional', {}).keys())}") 47 | print(f" - Return types: {node.RETURN_TYPES}") 48 | print(f" - Function: {node.FUNCTION}") 49 | print(f" - Category: {node.CATEGORY}") 50 | 51 | # Test HunyuanVideoFoleyModelLoader 52 | loader = HunyuanVideoFoleyModelLoader() 53 | loader_input_types = loader.INPUT_TYPES() 54 | 55 | print("✅ HunyuanVideoFoleyModelLoader structure:") 56 | print(f" - Required inputs: {list(loader_input_types['required'].keys())}") 57 | print(f" - Return types: {loader.RETURN_TYPES}") 58 | print(f" - Function: {loader.FUNCTION}") 59 | 60 | return True 61 | 62 | except Exception as e: 63 | print(f"❌ Node structure test failed: {e}") 64 | return False 65 | 66 | def test_device_setup(): 67 | """Test device setup functionality""" 68 | print("\nTesting device setup...") 69 | 70 | try: 71 | from ComfyUI_HunyuanVideoFoley.nodes import HunyuanVideoFoleyNode 72 | 73 | device = HunyuanVideoFoleyNode.setup_device("auto") 74 | print(f"✅ Device setup successful: {device}") 75 | 76 | return True 77 | 78 | except Exception as e: 79 | print(f"❌ Device setup failed: {e}") 80 | return False 81 | 82 | def test_utils(): 83 | """Test utility functions""" 84 | print("\nTesting utility functions...") 85 | 86 | try: 87 | from ComfyUI_HunyuanVideoFoley.utils import ( 88 
| get_optimal_device, 89 | check_memory_requirements, 90 | format_duration, 91 | validate_model_files 92 | ) 93 | 94 | # Test device detection 95 | device = get_optimal_device() 96 | print(f"✅ Optimal device: {device}") 97 | 98 | # Test memory check 99 | has_memory, msg = check_memory_requirements(device) 100 | print(f"✅ Memory check: {msg}") 101 | 102 | # Test duration formatting 103 | duration = format_duration(125.5) 104 | print(f"✅ Duration formatting: 125.5s -> {duration}") 105 | 106 | # Test model validation (will fail without models, but that's expected) 107 | is_valid, msg = validate_model_files("./pretrained_models/") 108 | print(f"✅ Model validation: {msg}") 109 | 110 | return True 111 | 112 | except Exception as e: 113 | print(f"❌ Utils test failed: {e}") 114 | return False 115 | 116 | def test_requirements(): 117 | """Test if key requirements are available""" 118 | print("\nTesting requirements...") 119 | 120 | required_packages = [ 121 | 'torch', 122 | 'torchaudio', 123 | 'numpy', 124 | 'loguru', 125 | 'diffusers', 126 | 'transformers' 127 | ] 128 | 129 | missing = [] 130 | for package in required_packages: 131 | try: 132 | __import__(package) 133 | print(f"✅ {package}") 134 | except ImportError: 135 | print(f"❌ {package} - not installed") 136 | missing.append(package) 137 | 138 | if missing: 139 | print(f"\nMissing packages: {', '.join(missing)}") 140 | print("Run: pip install -r requirements.txt") 141 | return False 142 | 143 | return True 144 | 145 | def main(): 146 | """Main test function""" 147 | print("🧪 Testing ComfyUI HunyuanVideo-Foley Custom Node") 148 | print("=" * 50) 149 | 150 | tests = [ 151 | ("Requirements", test_requirements), 152 | ("Imports", test_imports), 153 | ("Node Structure", test_node_structure), 154 | ("Device Setup", test_device_setup), 155 | ("Utils", test_utils), 156 | ] 157 | 158 | passed = 0 159 | failed = 0 160 | 161 | for test_name, test_func in tests: 162 | print(f"\n🔍 Running test: {test_name}") 163 | try: 164 | if test_func(): 165 | passed += 1 166 | print(f"✅ {test_name} PASSED") 167 | else: 168 | failed += 1 169 | print(f"❌ {test_name} FAILED") 170 | except Exception as e: 171 | failed += 1 172 | print(f"❌ {test_name} FAILED with exception: {e}") 173 | 174 | print("\n" + "=" * 50) 175 | print(f"📊 Test Results: {passed} passed, {failed} failed") 176 | 177 | if failed == 0: 178 | print("🎉 All tests passed! The custom node is ready for use.") 179 | else: 180 | print("⚠️ Some tests failed. 
Please check the issues above.") 181 | 182 | return failed == 0 183 | 184 | if __name__ == "__main__": 185 | success = main() 186 | sys.exit(0 if success else 1) -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/feature_utils.py: -------------------------------------------------------------------------------- 1 | """Feature extraction utilities for video and text processing.""" 2 | 3 | import os 4 | import numpy as np 5 | import torch 6 | import av 7 | from PIL import Image 8 | from einops import rearrange 9 | from typing import Any, Dict, List, Union, Tuple 10 | from loguru import logger 11 | 12 | from .config_utils import AttributeDict 13 | from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS 14 | 15 | 16 | class FeatureExtractionError(Exception): 17 | """Exception raised for feature extraction errors.""" 18 | pass 19 | 20 | def get_frames_av( 21 | video_path: str, 22 | fps: float, 23 | max_length: float = None, 24 | ) -> Tuple[np.ndarray, float]: 25 | end_sec = max_length if max_length is not None else 15 26 | next_frame_time_for_each_fps = 0.0 27 | time_delta_for_each_fps = 1 / fps 28 | 29 | all_frames = [] 30 | output_frames = [] 31 | 32 | with av.open(video_path) as container: 33 | stream = container.streams.video[0] 34 | ori_fps = stream.guessed_rate 35 | stream.thread_type = "AUTO" 36 | for packet in container.demux(stream): 37 | for frame in packet.decode(): 38 | frame_time = frame.time 39 | if frame_time < 0: 40 | continue 41 | if frame_time > end_sec: 42 | break 43 | 44 | frame_np = None 45 | 46 | this_time = frame_time 47 | while this_time >= next_frame_time_for_each_fps: 48 | if frame_np is None: 49 | frame_np = frame.to_ndarray(format="rgb24") 50 | 51 | output_frames.append(frame_np) 52 | next_frame_time_for_each_fps += time_delta_for_each_fps 53 | 54 | output_frames = np.stack(output_frames) 55 | 56 | vid_len_in_s = len(output_frames) / fps 57 | if max_length is not None and len(output_frames) > int(max_length * fps): 58 | output_frames = output_frames[: int(max_length * fps)] 59 | vid_len_in_s = max_length 60 | 61 | return output_frames, vid_len_in_s 62 | 63 | @torch.inference_mode() 64 | def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1): 65 | b, t, c, h, w = x.shape 66 | if batch_size < 0: 67 | batch_size = b * t 68 | x = rearrange(x, "b t c h w -> (b t) c h w") 69 | outputs = [] 70 | for i in range(0, b * t, batch_size): 71 | outputs.append(model_dict.siglip2_model.get_image_features(pixel_values=x[i : i + batch_size])) 72 | res = torch.cat(outputs, dim=0) 73 | res = rearrange(res, "(b t) d -> b t d", b=b) 74 | return res 75 | 76 | @torch.inference_mode() 77 | def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1): 78 | """ 79 | The input video of x is best to be in fps of 24 of greater than 24. 
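Frames are grouped into overlapping 16-frame segments with a stride of 8 frames, so a clip with t frames yields (t - 16) // 8 + 1 segments, each contributing 8 output tokens. Worked example (assuming FPS_VISUAL["synchformer"] is 25 fps, which matches the 240-token comment in encode_video_features below): a 10 s clip gives t = 250 frames, hence (250 - 16) // 8 + 1 = 30 segments and 30 * 8 = 240 tokens of dimension 768.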
80 | Input: 81 | x: tensor in shape of [B, T, C, H, W] 82 | batch_size: the batch_size for synchformer inference 83 | """ 84 | b, t, c, h, w = x.shape 85 | assert c == 3 and h == 224 and w == 224 86 | 87 | segment_size = 16 88 | step_size = 8 89 | num_segments = (t - segment_size) // step_size + 1 90 | segments = [] 91 | for i in range(num_segments): 92 | segments.append(x[:, i * step_size : i * step_size + segment_size]) 93 | x = torch.stack(segments, dim=1).cuda() # (B, num_segments, segment_size, 3, 224, 224) 94 | 95 | outputs = [] 96 | if batch_size < 0: 97 | batch_size = b * num_segments 98 | x = rearrange(x, "b s t c h w -> (b s) 1 t c h w") 99 | for i in range(0, b * num_segments, batch_size): 100 | with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half): 101 | outputs.append(model_dict.syncformer_model(x[i : i + batch_size])) 102 | x = torch.cat(outputs, dim=0) # [b * num_segments, 1, 8, 768] 103 | x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b) 104 | return x 105 | 106 | 107 | @torch.inference_mode() 108 | def encode_video_features(video_path, model_dict): 109 | visual_features = {} 110 | # siglip2 visual features 111 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["siglip2"]) 112 | images = [Image.fromarray(frame).convert('RGB') for frame in frames] 113 | images = [model_dict.siglip2_preprocess(image) for image in images] # [T, C, H, W] 114 | clip_frames = torch.stack(images).to(model_dict.device).unsqueeze(0) 115 | visual_features['siglip2_feat'] = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device) 116 | 117 | # synchformer visual features 118 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["synchformer"]) 119 | images = torch.from_numpy(frames).permute(0, 3, 1, 2) # [T, C, H, W] 120 | sync_frames = model_dict.syncformer_preprocess(images).unsqueeze(0) # [1, T, 3, 224, 224] 121 | # [1, num_segments * 8, channel_dim], e.g. 
[1, 240, 768] for 10s video 122 | visual_features['syncformer_feat'] = encode_video_with_sync(sync_frames, model_dict) 123 | 124 | vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"] 125 | visual_features = AttributeDict(visual_features) 126 | 127 | return visual_features, vid_len_in_s 128 | 129 | @torch.inference_mode() 130 | def encode_text_feat(text: List[str], model_dict): 131 | # x: (B, L) 132 | inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device) 133 | outputs = model_dict.clap_model(**inputs, output_hidden_states=True, return_dict=True) 134 | return outputs.last_hidden_state, outputs.attentions 135 | 136 | 137 | def feature_process(video_path, prompt, model_dict, cfg): 138 | visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict) 139 | neg_prompt = "noisy, harsh" 140 | prompts = [neg_prompt, prompt] 141 | text_feat_res, text_feat_mask = encode_text_feat(prompts, model_dict) 142 | 143 | text_feat = text_feat_res[1:] 144 | uncond_text_feat = text_feat_res[:1] 145 | 146 | if cfg.model_config.model_kwargs.text_length < text_feat.shape[1]: 147 | text_seq_length = cfg.model_config.model_kwargs.text_length 148 | text_feat = text_feat[:, :text_seq_length] 149 | uncond_text_feat = uncond_text_feat[:, :text_seq_length] 150 | 151 | text_feats = AttributeDict({ 152 | 'text_feat': text_feat, 153 | 'uncond_text_feat': uncond_text_feat, 154 | }) 155 | 156 | return visual_feats, text_feats, audio_len_in_s 157 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/posemb_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Tuple 3 | 4 | 5 | def _to_tuple(x, dim=2): 6 | if isinstance(x, int): 7 | return (x,) * dim 8 | elif len(x) == dim: 9 | return x 10 | else: 11 | raise ValueError(f"Expected length {dim} or int, but got {x}") 12 | 13 | 14 | def get_meshgrid_nd(start, *args, dim=2): 15 | """ 16 | Get n-D meshgrid with start, stop and num. 17 | 18 | Args: 19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, 20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num 21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in 22 | n-tuples. 23 | *args: See above. 24 | dim (int): Dimension of the meshgrid. Defaults to 2. 25 | 26 | Returns: 27 | grid (np.ndarray): [dim, ...] 
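Example (hypothetical sizes; note the grid is returned as a torch.Tensor):
    >>> get_meshgrid_nd((4, 6)).shape            # num only: start defaults to 0
    torch.Size([2, 4, 6])
    >>> get_meshgrid_nd(2, 10, 32, dim=1).shape  # start=2, stop=10, num=32
    torch.Size([1, 32])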
28 | """ 29 | if len(args) == 0: 30 | # start is grid_size 31 | num = _to_tuple(start, dim=dim) 32 | start = (0,) * dim 33 | stop = num 34 | elif len(args) == 1: 35 | # start is start, args[0] is stop, step is 1 36 | start = _to_tuple(start, dim=dim) 37 | stop = _to_tuple(args[0], dim=dim) 38 | num = [stop[i] - start[i] for i in range(dim)] 39 | elif len(args) == 2: 40 | # start is start, args[0] is stop, args[1] is num 41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 44 | else: 45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") 46 | 47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) 48 | axis_grid = [] 49 | for i in range(dim): 50 | a, b, n = start[i], stop[i], num[i] 51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] 52 | axis_grid.append(g) 53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] 54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D] 55 | 56 | return grid 57 | 58 | 59 | ################################################################################# 60 | # Rotary Positional Embedding Functions # 61 | ################################################################################# 62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 63 | 64 | 65 | def get_nd_rotary_pos_embed( 66 | rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0 67 | ): 68 | """ 69 | This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. 70 | 71 | Args: 72 | rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. 73 | sum(rope_dim_list) should equal to head_dim of attention layer. 74 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, 75 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. 76 | *args: See above. 77 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0. 78 | use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. 79 | Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real 80 | part and an imaginary part separately. 81 | theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. 82 | freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. 
83 | 84 | Returns: 85 | pos_embed (torch.Tensor): [HW, D/2] 86 | """ 87 | 88 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] 89 | 90 | # use 1/ndim of dimensions to encode grid_axis 91 | embs = [] 92 | for i in range(len(rope_dim_list)): 93 | emb = get_1d_rotary_pos_embed( 94 | rope_dim_list[i], 95 | grid[i].reshape(-1), 96 | theta, 97 | use_real=use_real, 98 | theta_rescale_factor=theta_rescale_factor, 99 | freq_scaling=freq_scaling, 100 | ) # 2 x [WHD, rope_dim_list[i]] 101 | embs.append(emb) 102 | 103 | if use_real: 104 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) 105 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) 106 | return cos, sin 107 | else: 108 | emb = torch.cat(embs, dim=1) # (WHD, D/2) 109 | return emb 110 | 111 | 112 | def get_1d_rotary_pos_embed( 113 | dim: int, 114 | pos: Union[torch.FloatTensor, int], 115 | theta: float = 10000.0, 116 | use_real: bool = False, 117 | theta_rescale_factor: float = 1.0, 118 | freq_scaling: float = 1.0, 119 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 120 | """ 121 | Precompute the frequency tensor for complex exponential (cis) with given dimensions. 122 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) 123 | 124 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim' 125 | and the end index 'end'. The 'theta' parameter scales the frequencies. 126 | The returned tensor contains complex values in complex64 data type. 127 | 128 | Args: 129 | dim (int): Dimension of the frequency tensor. 130 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar 131 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 132 | use_real (bool, optional): If True, return real part and imaginary part separately. 133 | Otherwise, return complex numbers. 134 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. 135 | freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. 136 | 137 | Returns: 138 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] 139 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] 140 | """ 141 | if isinstance(pos, int): 142 | pos = torch.arange(pos).float() 143 | 144 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 145 | # has some connection to NTK literature 146 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ 147 | if theta_rescale_factor != 1.0: 148 | theta *= theta_rescale_factor ** (dim / (dim - 1)) 149 | 150 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] 151 | freqs *= freq_scaling 152 | freqs = torch.outer(pos, freqs) # [S, D/2] 153 | if use_real: 154 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] 155 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] 156 | return freqs_cos, freqs_sin 157 | else: 158 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] 159 | return freqs_cis 160 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from audiotools import AudioSignal 5 | from audiotools import ml 6 | from audiotools import STFTParams 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | 11 | def WNConv1d(*args, **kwargs): 12 | act = kwargs.pop("act", True) 13 | conv = weight_norm(nn.Conv1d(*args, **kwargs)) 14 | if not act: 15 | return conv 16 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 17 | 18 | 19 | def WNConv2d(*args, **kwargs): 20 | act = kwargs.pop("act", True) 21 | conv = weight_norm(nn.Conv2d(*args, **kwargs)) 22 | if not act: 23 | return conv 24 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 25 | 26 | 27 | class MPD(nn.Module): 28 | def __init__(self, period): 29 | super().__init__() 30 | self.period = period 31 | self.convs = nn.ModuleList( 32 | [ 33 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), 34 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), 35 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), 36 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), 37 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), 38 | ] 39 | ) 40 | self.conv_post = WNConv2d( 41 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False 42 | ) 43 | 44 | def pad_to_period(self, x): 45 | t = x.shape[-1] 46 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect") 47 | return x 48 | 49 | def forward(self, x): 50 | fmap = [] 51 | 52 | x = self.pad_to_period(x) 53 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period) 54 | 55 | for layer in self.convs: 56 | x = layer(x) 57 | fmap.append(x) 58 | 59 | x = self.conv_post(x) 60 | fmap.append(x) 61 | 62 | return fmap 63 | 64 | 65 | class MSD(nn.Module): 66 | def __init__(self, rate: int = 1, sample_rate: int = 44100): 67 | super().__init__() 68 | self.convs = nn.ModuleList( 69 | [ 70 | WNConv1d(1, 16, 15, 1, padding=7), 71 | WNConv1d(16, 64, 41, 4, groups=4, padding=20), 72 | WNConv1d(64, 256, 41, 4, groups=16, padding=20), 73 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20), 74 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), 75 | WNConv1d(1024, 1024, 5, 1, padding=2), 76 | ] 77 | ) 78 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) 79 | self.sample_rate = sample_rate 80 | self.rate = rate 81 | 82 | def forward(self, x): 83 | x = AudioSignal(x, self.sample_rate) 84 | 
x.resample(self.sample_rate // self.rate) 85 | x = x.audio_data 86 | 87 | fmap = [] 88 | 89 | for l in self.convs: 90 | x = l(x) 91 | fmap.append(x) 92 | x = self.conv_post(x) 93 | fmap.append(x) 94 | 95 | return fmap 96 | 97 | 98 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] 99 | 100 | 101 | class MRD(nn.Module): 102 | def __init__( 103 | self, 104 | window_length: int, 105 | hop_factor: float = 0.25, 106 | sample_rate: int = 44100, 107 | bands: list = BANDS, 108 | ): 109 | """Complex multi-band spectrogram discriminator. 110 | Parameters 111 | ---------- 112 | window_length : int 113 | Window length of STFT. 114 | hop_factor : float, optional 115 | Hop factor of the STFT, defaults to ``0.25 * window_length``. 116 | sample_rate : int, optional 117 | Sampling rate of audio in Hz, by default 44100 118 | bands : list, optional 119 | Bands to run discriminator over. 120 | """ 121 | super().__init__() 122 | 123 | self.window_length = window_length 124 | self.hop_factor = hop_factor 125 | self.sample_rate = sample_rate 126 | self.stft_params = STFTParams( 127 | window_length=window_length, 128 | hop_length=int(window_length * hop_factor), 129 | match_stride=True, 130 | ) 131 | 132 | n_fft = window_length // 2 + 1 133 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 134 | self.bands = bands 135 | 136 | ch = 32 137 | convs = lambda: nn.ModuleList( 138 | [ 139 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), 140 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 141 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 142 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 143 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), 144 | ] 145 | ) 146 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 147 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) 148 | 149 | def spectrogram(self, x): 150 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) 151 | x = torch.view_as_real(x.stft()) 152 | x = rearrange(x, "b 1 f t c -> (b 1) c t f") 153 | # Split into bands 154 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 155 | return x_bands 156 | 157 | def forward(self, x): 158 | x_bands = self.spectrogram(x) 159 | fmap = [] 160 | 161 | x = [] 162 | for band, stack in zip(x_bands, self.band_convs): 163 | for layer in stack: 164 | band = layer(band) 165 | fmap.append(band) 166 | x.append(band) 167 | 168 | x = torch.cat(x, dim=-1) 169 | x = self.conv_post(x) 170 | fmap.append(x) 171 | 172 | return fmap 173 | 174 | 175 | class Discriminator(ml.BaseModel): 176 | def __init__( 177 | self, 178 | rates: list = [], 179 | periods: list = [2, 3, 5, 7, 11], 180 | fft_sizes: list = [2048, 1024, 512], 181 | sample_rate: int = 44100, 182 | bands: list = BANDS, 183 | ): 184 | """Discriminator that combines multiple discriminators. 185 | 186 | Parameters 187 | ---------- 188 | rates : list, optional 189 | sampling rates (in Hz) to run MSD at, by default [] 190 | If empty, MSD is not used. 
191 | periods : list, optional 192 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] 193 | fft_sizes : list, optional 194 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] 195 | sample_rate : int, optional 196 | Sampling rate of audio in Hz, by default 44100 197 | bands : list, optional 198 | Bands to run MRD at, by default `BANDS` 199 | """ 200 | super().__init__() 201 | discs = [] 202 | discs += [MPD(p) for p in periods] 203 | discs += [MSD(r, sample_rate=sample_rate) for r in rates] 204 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] 205 | self.discriminators = nn.ModuleList(discs) 206 | 207 | def preprocess(self, y): 208 | # Remove DC offset 209 | y = y - y.mean(dim=-1, keepdims=True) 210 | # Peak normalize the volume of input audio 211 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 212 | return y 213 | 214 | def forward(self, x): 215 | x = self.preprocess(x) 216 | fmaps = [d(x) for d in self.discriminators] 217 | return fmaps 218 | 219 | 220 | if __name__ == "__main__": 221 | disc = Discriminator() 222 | x = torch.zeros(1, 1, 44100) 223 | results = disc(x) 224 | for i, result in enumerate(results): 225 | print(f"disc{i}") 226 | for i, r in enumerate(result): 227 | print(r.shape, r.mean(), r.min(), r.max()) 228 | print() 229 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for ComfyUI HunyuanVideo-Foley custom node 3 | """ 4 | 5 | import os 6 | import tempfile 7 | import torch 8 | import numpy as np 9 | from typing import Union, Optional, Tuple 10 | from loguru import logger 11 | 12 | def tensor_to_video(video_tensor: torch.Tensor, output_path: str, fps: int = 30) -> str: 13 | """ 14 | Convert a video tensor to a video file 15 | 16 | Args: 17 | video_tensor: Video tensor with shape (frames, channels, height, width) 18 | output_path: Output video file path 19 | fps: Frame rate for the output video 20 | 21 | Returns: 22 | Path to the saved video file 23 | """ 24 | try: 25 | import cv2 26 | 27 | # Convert tensor to numpy and handle different formats 28 | if isinstance(video_tensor, torch.Tensor): 29 | video_np = video_tensor.detach().cpu().numpy() 30 | else: 31 | video_np = np.array(video_tensor) 32 | 33 | # Handle different tensor formats 34 | if video_np.ndim == 4: # (frames, channels, height, width) 35 | if video_np.shape[1] == 3: # RGB 36 | video_np = np.transpose(video_np, (0, 2, 3, 1)) # (frames, height, width, channels) 37 | elif video_np.shape[1] == 1: # Grayscale 38 | video_np = np.transpose(video_np, (0, 2, 3, 1)) 39 | video_np = np.repeat(video_np, 3, axis=3) # Convert to RGB 40 | elif video_np.ndim == 5: # (batch, frames, channels, height, width) 41 | video_np = video_np[0] # Take first batch 42 | if video_np.shape[1] == 3: 43 | video_np = np.transpose(video_np, (0, 2, 3, 1)) 44 | 45 | # Normalize values to 0-255 range 46 | if video_np.max() <= 1.0: 47 | video_np = (video_np * 255).astype(np.uint8) 48 | else: 49 | video_np = video_np.astype(np.uint8) 50 | 51 | # Get video dimensions 52 | frames, height, width, channels = video_np.shape 53 | 54 | # Create video writer 55 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 56 | out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) 57 | 58 | # Write frames 59 | for i in range(frames): 60 | frame = video_np[i] 61 | if channels == 3: 62 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) # Convert RGB 
to BGR for OpenCV 63 | out.write(frame) 64 | 65 | out.release() 66 | logger.info(f"Video saved to: {output_path}") 67 | return output_path 68 | 69 | except Exception as e: 70 | logger.error(f"Failed to convert tensor to video: {e}") 71 | raise 72 | 73 | 74 | def get_video_info(video_path: str) -> dict: 75 | """ 76 | Get information about a video file 77 | 78 | Args: 79 | video_path: Path to video file 80 | 81 | Returns: 82 | Dictionary with video information 83 | """ 84 | try: 85 | import cv2 86 | 87 | cap = cv2.VideoCapture(video_path) 88 | 89 | info = { 90 | 'fps': cap.get(cv2.CAP_PROP_FPS), 91 | 'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 92 | 'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), 93 | 'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), 94 | 'duration': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) / cap.get(cv2.CAP_PROP_FPS) 95 | } 96 | 97 | cap.release() 98 | return info 99 | 100 | except Exception as e: 101 | logger.error(f"Failed to get video info: {e}") 102 | return {} 103 | 104 | 105 | def ensure_video_file(video_input: Union[str, torch.Tensor, np.ndarray]) -> str: 106 | """ 107 | Ensure the video input is converted to a file path 108 | 109 | Args: 110 | video_input: Video input (path, tensor, or array) 111 | 112 | Returns: 113 | Path to video file 114 | """ 115 | if isinstance(video_input, str): 116 | # Already a file path 117 | if os.path.exists(video_input): 118 | return video_input 119 | else: 120 | raise FileNotFoundError(f"Video file not found: {video_input}") 121 | 122 | elif isinstance(video_input, (torch.Tensor, np.ndarray)): 123 | # Convert tensor/array to video file 124 | temp_dir = tempfile.mkdtemp() 125 | output_path = os.path.join(temp_dir, "input_video.mp4") 126 | return tensor_to_video(video_input, output_path) 127 | 128 | else: 129 | raise ValueError(f"Unsupported video input type: {type(video_input)}") 130 | 131 | 132 | def validate_model_files(model_path: str) -> Tuple[bool, str]: 133 | """ 134 | Validate that all required model files exist 135 | 136 | Args: 137 | model_path: Path to model directory 138 | 139 | Returns: 140 | Tuple of (is_valid, error_message) 141 | """ 142 | required_files = [ 143 | "hunyuanvideo_foley.pth", 144 | "vae_128d_48k.pth", 145 | "synchformer_state_dict.pth" 146 | ] 147 | 148 | missing_files = [] 149 | 150 | for file_name in required_files: 151 | file_path = os.path.join(model_path, file_name) 152 | if not os.path.exists(file_path): 153 | missing_files.append(file_name) 154 | 155 | if missing_files: 156 | return False, f"Missing model files: {', '.join(missing_files)}" 157 | 158 | return True, "All required model files found" 159 | 160 | 161 | def get_optimal_device() -> torch.device: 162 | """ 163 | Get the optimal device for model execution 164 | 165 | Returns: 166 | PyTorch device 167 | """ 168 | if torch.cuda.is_available(): 169 | # Get the device with most free memory 170 | max_memory = 0 171 | best_device = 0 172 | 173 | for i in range(torch.cuda.device_count()): 174 | memory_free = torch.cuda.get_device_properties(i).total_memory 175 | if memory_free > max_memory: 176 | max_memory = memory_free 177 | best_device = i 178 | 179 | device = torch.device(f"cuda:{best_device}") 180 | logger.info(f"Using CUDA device: {device} with {max_memory / 1e9:.1f}GB memory") 181 | return device 182 | 183 | elif torch.backends.mps.is_available(): 184 | device = torch.device("mps") 185 | logger.info("Using MPS device (Apple Silicon)") 186 | return device 187 | 188 | else: 189 | device = torch.device("cpu") 190 | logger.info("Using 
CPU device") 191 | return device 192 | 193 | 194 | def check_memory_requirements(device: torch.device, required_gb: float = 16.0) -> Tuple[bool, str]: 195 | """ 196 | Check if the device has enough memory for model execution 197 | 198 | Args: 199 | device: PyTorch device 200 | required_gb: Required memory in GB 201 | 202 | Returns: 203 | Tuple of (has_enough_memory, message) 204 | """ 205 | if device.type == "cuda": 206 | properties = torch.cuda.get_device_properties(device) 207 | total_memory = properties.total_memory / 1e9 # Convert to GB 208 | 209 | if total_memory < required_gb: 210 | return False, f"GPU has {total_memory:.1f}GB memory, but {required_gb}GB is recommended" 211 | else: 212 | return True, f"GPU has {total_memory:.1f}GB memory (sufficient)" 213 | 214 | elif device.type == "mps": 215 | # MPS doesn't have a direct way to check memory, assume it's sufficient 216 | return True, "Using MPS device (memory check not available)" 217 | 218 | else: 219 | # CPU - assume it has enough memory 220 | return True, "Using CPU (no memory limit)" 221 | 222 | 223 | def format_duration(seconds: float) -> str: 224 | """ 225 | Format duration in seconds to human readable format 226 | 227 | Args: 228 | seconds: Duration in seconds 229 | 230 | Returns: 231 | Formatted duration string 232 | """ 233 | if seconds < 60: 234 | return f"{seconds:.1f}s" 235 | elif seconds < 3600: 236 | minutes = int(seconds // 60) 237 | remaining_seconds = seconds % 60 238 | return f"{minutes}m {remaining_seconds:.1f}s" 239 | else: 240 | hours = int(seconds // 3600) 241 | remaining_minutes = int((seconds % 3600) // 60) 242 | return f"{hours}h {remaining_minutes}m" -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/compute_desync_score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | from pathlib import Path 4 | 5 | import torch 6 | import torchaudio 7 | import torchvision 8 | from omegaconf import OmegaConf 9 | 10 | import data_transforms 11 | from .synchformer import Synchformer 12 | from .data_transforms import make_class_grid, quantize_offset 13 | from .utils import check_if_file_exists_else_download, which_ffmpeg 14 | 15 | 16 | def prepare_inputs(batch, device): 17 | aud = batch["audio"].to(device) 18 | vid = batch["video"].to(device) 19 | 20 | return aud, vid 21 | 22 | 23 | def get_test_transforms(): 24 | ts = [ 25 | data_transforms.EqualifyFromRight(), 26 | data_transforms.RGBSpatialCrop(input_size=224, is_random=False), 27 | data_transforms.TemporalCropAndOffset( 28 | crop_len_sec=5, 29 | max_off_sec=2, # https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml 30 | max_wiggle_sec=0.0, 31 | do_offset=True, 32 | offset_type="grid", 33 | prob_oos="null", 34 | grid_size=21, 35 | segment_size_vframes=16, 36 | n_segments=14, 37 | step_size_seg=0.5, 38 | vfps=25, 39 | ), 40 | data_transforms.GenerateMultipleSegments( 41 | segment_size_vframes=16, 42 | n_segments=14, 43 | is_start_random=False, 44 | step_size_seg=0.5, 45 | ), 46 | data_transforms.RGBToHalfToZeroOne(), 47 | data_transforms.RGBNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # motionformer normalization 48 | data_transforms.AudioMelSpectrogram( 49 | sample_rate=16000, 50 | win_length=400, # 25 ms * 16 kHz 51 | hop_length=160, # 10 ms * 16 kHz 52 | n_fft=1024, # 2^(ceil(log2(window_size * sampling_rate))) 53 | n_mels=128, # as in AST 54 | 
), 55 | data_transforms.AudioLog(), 56 | data_transforms.PadOrTruncate(max_spec_t=66), 57 | data_transforms.AudioNormalizeAST(mean=-4.2677393, std=4.5689974), # AST, pre-trained on AudioSet 58 | data_transforms.PermuteStreams( 59 | einops_order_audio="S F T -> S 1 F T", einops_order_rgb="S T C H W -> S T C H W" # same 60 | ), 61 | ] 62 | transforms = torchvision.transforms.Compose(ts) 63 | 64 | return transforms 65 | 66 | 67 | def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None): 68 | orig_path = path 69 | # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta) 70 | rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, "sec", output_format="TCHW") 71 | assert meta["video_fps"], f"No video fps for {orig_path}" 72 | # (Ta) <- (Ca, Ta) 73 | audio = audio.mean(dim=0) 74 | # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader. 75 | meta = { 76 | "video": {"fps": [meta["video_fps"]]}, 77 | "audio": {"framerate": [meta["audio_fps"]]}, 78 | } 79 | return rgb, audio, meta 80 | 81 | 82 | def reencode_video(path, vfps=25, afps=16000, in_size=256): 83 | assert which_ffmpeg() != "", "Is ffmpeg installed? Check if the conda environment is activated." 84 | new_path = Path.cwd() / "vis" / f"{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4" 85 | new_path.parent.mkdir(exist_ok=True) 86 | new_path = str(new_path) 87 | cmd = f"{which_ffmpeg()}" 88 | # no info/error printing 89 | cmd += " -hide_banner -loglevel panic" 90 | cmd += f" -y -i {path}" 91 | # 1) change fps, 2) resize: min(H,W)=MIN_SIDE (vertical vids are supported), 3) change audio framerate 92 | cmd += f" -vf fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2" 93 | cmd += f" -ar {afps}" 94 | cmd += f" {new_path}" 95 | subprocess.call(cmd.split()) 96 | cmd = f"{which_ffmpeg()}" 97 | cmd += " -hide_banner -loglevel panic" 98 | cmd += f" -y -i {new_path}" 99 | cmd += f" -acodec pcm_s16le -ac 1" 100 | cmd += f' {new_path.replace(".mp4", ".wav")}' 101 | subprocess.call(cmd.split()) 102 | return new_path 103 | 104 | 105 | def decode_single_video_prediction(off_logits, grid, item): 106 | label = item["targets"]["offset_label"].item() 107 | print("Ground Truth offset (sec):", f"{label:.2f} ({quantize_offset(grid, label)[-1].item()})") 108 | print() 109 | print("Prediction Results:") 110 | off_probs = torch.softmax(off_logits, dim=-1) 111 | k = min(off_probs.shape[-1], 5) 112 | topk_logits, topk_preds = torch.topk(off_logits, k) 113 | # remove batch dimension 114 | assert len(topk_logits) == 1, "batch is larger than 1" 115 | topk_logits = topk_logits[0] 116 | topk_preds = topk_preds[0] 117 | off_logits = off_logits[0] 118 | off_probs = off_probs[0] 119 | for target_hat in topk_preds: 120 | print(f'p={off_probs[target_hat]:.4f} ({off_logits[target_hat]:.4f}), "{grid[target_hat]:.2f}" ({target_hat})') 121 | return off_probs 122 | 123 | 124 | def main(args): 125 | vfps = 25 126 | afps = 16000 127 | in_size = 256 128 | # making the offset class grid similar to the one used in transforms, 129 | # refer to the used one: https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml 130 | max_off_sec = 2 131 | num_cls = 21 132 | 133 | # checking if the provided video has the correct frame rates 134 | print(f"Using video: {args.vid_path}") 135 | v, _, info = torchvision.io.read_video(args.vid_path, pts_unit="sec") 136 | _, H, W, _ = v.shape 137 | if info["video_fps"] != vfps or 
info["audio_fps"] != afps or min(H, W) != in_size: 138 | print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=" ") 139 | print(f'afps: {info["audio_fps"]} -> {afps};', end=" ") 140 | print(f"{(H, W)} -> min(H, W)={in_size}") 141 | args.vid_path = reencode_video(args.vid_path, vfps, afps, in_size) 142 | else: 143 | print(f'Skipping reencoding. vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}') 144 | 145 | device = torch.device(args.device) 146 | 147 | # load visual and audio streams 148 | # rgb: (Tv, 3, H, W) in [0, 225], audio: (Ta,) in [-1, 1] 149 | rgb, audio, meta = get_video_and_audio(args.vid_path, get_meta=True) 150 | 151 | # making an item (dict) to apply transformations 152 | # NOTE: here is how it works: 153 | # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3` 154 | # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio 155 | # track by `args.offset_sec` seconds. It means that if `offset_sec` > 0, the audio will 156 | # start by `offset_sec` earlier than the rgb track. 157 | # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`) 158 | item = dict( 159 | video=rgb, 160 | audio=audio, 161 | meta=meta, 162 | path=args.vid_path, 163 | split="test", 164 | targets={ 165 | "v_start_i_sec": args.v_start_i_sec, 166 | "offset_sec": args.offset_sec, 167 | }, 168 | ) 169 | 170 | grid = make_class_grid(-max_off_sec, max_off_sec, num_cls) 171 | if not (min(grid) <= item["targets"]["offset_sec"] <= max(grid)): 172 | print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}') 173 | 174 | # applying the test-time transform 175 | item = get_test_transforms()(item) 176 | 177 | # prepare inputs for inference 178 | batch = torch.utils.data.default_collate([item]) 179 | aud, vid = prepare_inputs(batch, device) 180 | 181 | # TODO: 182 | # sanity check: we will take the input to the `model` and recontruct make a video from it. 
183 | # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified) 184 | # reconstruct_video_from_input(aud, vid, batch['meta'], args.vid_path, args.v_start_i_sec, args.offset_sec, 185 | # vfps, afps) 186 | 187 | # forward pass 188 | with torch.set_grad_enabled(False): 189 | with torch.autocast("cuda", enabled=True): 190 | _, logits = synchformer(vid, aud) 191 | 192 | # simply prints the results of the prediction 193 | decode_single_video_prediction(logits, grid, item) 194 | 195 | 196 | if __name__ == "__main__": 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument("--exp_name", required=True, help="In a format: xx-xx-xxTxx-xx-xx") 199 | parser.add_argument("--vid_path", required=True, help="A path to .mp4 video") 200 | parser.add_argument("--offset_sec", type=float, default=0.0) 201 | parser.add_argument("--v_start_i_sec", type=float, default=0.0) 202 | parser.add_argument("--device", default="cuda:0") 203 | args = parser.parse_args() 204 | 205 | synchformer = Synchformer().cuda().eval() 206 | synchformer.load_state_dict( 207 | torch.load( 208 | os.environ.get("SYNCHFORMER_WEIGHTS", f"weights/synchformer.pth"), 209 | weights_only=True, 210 | map_location="cpu", 211 | ) 212 | ) 213 | 214 | main(args) 215 | -------------------------------------------------------------------------------- /example_workflows/HunyuanVideo-Foley.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "b97ba031-d866-4f4e-baf5-3017bd4ef4b5", 3 | "revision": 0, 4 | "last_node_id": 36, 5 | "last_link_id": 82, 6 | "nodes": [ 7 | { 8 | "id": 19, 9 | "type": "VHS_VideoCombine", 10 | "pos": [ 11 | 768.0491943359375, 12 | -51.64681625366211 13 | ], 14 | "size": [ 15 | 214.7587890625, 16 | 477.5454406738281 17 | ], 18 | "flags": {}, 19 | "order": 3, 20 | "mode": 0, 21 | "inputs": [ 22 | { 23 | "name": "images", 24 | "type": "IMAGE", 25 | "link": 79 26 | }, 27 | { 28 | "name": "audio", 29 | "shape": 7, 30 | "type": "AUDIO", 31 | "link": 80 32 | }, 33 | { 34 | "name": "meta_batch", 35 | "shape": 7, 36 | "type": "VHS_BatchManager", 37 | "link": null 38 | }, 39 | { 40 | "name": "vae", 41 | "shape": 7, 42 | "type": "VAE", 43 | "link": null 44 | }, 45 | { 46 | "name": "frame_rate", 47 | "type": "FLOAT", 48 | "widget": { 49 | "name": "frame_rate" 50 | }, 51 | "link": 59 52 | } 53 | ], 54 | "outputs": [ 55 | { 56 | "name": "Filenames", 57 | "type": "VHS_FILENAMES", 58 | "links": null 59 | } 60 | ], 61 | "properties": { 62 | "Node name for S&R": "VHS_VideoCombine" 63 | }, 64 | "widgets_values": { 65 | "frame_rate": 24, 66 | "loop_count": 0, 67 | "filename_prefix": "AnimateDiff", 68 | "format": "video/h264-mp4", 69 | "pix_fmt": "yuv420p", 70 | "crf": 19, 71 | "save_metadata": true, 72 | "trim_to_audio": false, 73 | "pingpong": false, 74 | "save_output": true, 75 | "videopreview": { 76 | "hidden": false, 77 | "paused": false, 78 | "params": { 79 | "filename": "AnimateDiff_00044-audio.mp4", 80 | "subfolder": "", 81 | "type": "output", 82 | "format": "video/h264-mp4", 83 | "frame_rate": 24, 84 | "workflow": "AnimateDiff_00044.png", 85 | "fullpath": "/root/ComfyUI/output/AnimateDiff_00044-audio.mp4" 86 | } 87 | } 88 | } 89 | }, 90 | { 91 | "id": 36, 92 | "type": "HunyuanVideoFoley", 93 | "pos": [ 94 | 280.91168212890625, 95 | -55.041656494140625 96 | ], 97 | "size": [ 98 | 400, 99 | 398 100 | ], 101 | "flags": {}, 102 | "order": 2, 103 | "mode": 0, 104 | "inputs": [ 105 | { 106 | "name": "images", 107 | "type": "IMAGE", 108 | "link": 
81 109 | }, 110 | { 111 | "name": "fps", 112 | "shape": 7, 113 | "type": "FLOAT", 114 | "widget": { 115 | "name": "fps" 116 | }, 117 | "link": 82 118 | } 119 | ], 120 | "outputs": [ 121 | { 122 | "name": "video_frames", 123 | "type": "IMAGE", 124 | "links": [ 125 | 79 126 | ] 127 | }, 128 | { 129 | "name": "audio", 130 | "type": "AUDIO", 131 | "links": [ 132 | 80 133 | ] 134 | }, 135 | { 136 | "name": "status_message", 137 | "type": "STRING", 138 | "links": null 139 | } 140 | ], 141 | "properties": { 142 | "Node name for S&R": "HunyuanVideoFoley" 143 | }, 144 | "widgets_values": [ 145 | "Woman holding a gun and fighting monster", 146 | 4.5, 147 | 50, 148 | 1, 149 | 1183713675, 150 | "randomize", 151 | "Sharp sounds, noise, messy sounds, and weak frame-related sounds", 152 | 24, 153 | "hunyuan_foley", 154 | "foley_", 155 | false 156 | ] 157 | }, 158 | { 159 | "id": 16, 160 | "type": "VHS_LoadVideo", 161 | "pos": [ 162 | -87.1226577758789, 163 | -49.90601348876953 164 | ], 165 | "size": [ 166 | 252.2359619140625, 167 | 479.15765380859375 168 | ], 169 | "flags": {}, 170 | "order": 0, 171 | "mode": 0, 172 | "inputs": [ 173 | { 174 | "name": "meta_batch", 175 | "shape": 7, 176 | "type": "VHS_BatchManager", 177 | "link": null 178 | }, 179 | { 180 | "name": "vae", 181 | "shape": 7, 182 | "type": "VAE", 183 | "link": null 184 | } 185 | ], 186 | "outputs": [ 187 | { 188 | "name": "IMAGE", 189 | "type": "IMAGE", 190 | "links": [ 191 | 81 192 | ] 193 | }, 194 | { 195 | "name": "frame_count", 196 | "type": "INT", 197 | "links": [] 198 | }, 199 | { 200 | "name": "audio", 201 | "type": "AUDIO", 202 | "links": null 203 | }, 204 | { 205 | "name": "video_info", 206 | "type": "VHS_VIDEOINFO", 207 | "links": [ 208 | 51 209 | ] 210 | } 211 | ], 212 | "properties": { 213 | "Node name for S&R": "VHS_LoadVideo" 214 | }, 215 | "widgets_values": { 216 | "video": "70973466-d29f-4789-a957-99b44aa36f0b.mp4", 217 | "force_rate": 0, 218 | "custom_width": 0, 219 | "custom_height": 0, 220 | "frame_load_cap": 0, 221 | "skip_first_frames": 0, 222 | "select_every_nth": 1, 223 | "format": "AnimateDiff", 224 | "choose video to upload": "image", 225 | "videopreview": { 226 | "hidden": false, 227 | "paused": false, 228 | "params": { 229 | "force_rate": 0, 230 | "custom_width": 0, 231 | "custom_height": 0, 232 | "frame_load_cap": 0, 233 | "skip_first_frames": 0, 234 | "select_every_nth": 1, 235 | "filename": "70973466-d29f-4789-a957-99b44aa36f0b.mp4", 236 | "type": "input", 237 | "format": "video/mp4" 238 | } 239 | } 240 | } 241 | }, 242 | { 243 | "id": 26, 244 | "type": "VHS_VideoInfo", 245 | "pos": [ 246 | 310.0038146972656, 247 | 443.6332092285156 248 | ], 249 | "size": [ 250 | 218.12109375, 251 | 206 252 | ], 253 | "flags": {}, 254 | "order": 1, 255 | "mode": 0, 256 | "inputs": [ 257 | { 258 | "name": "video_info", 259 | "type": "VHS_VIDEOINFO", 260 | "link": 51 261 | } 262 | ], 263 | "outputs": [ 264 | { 265 | "name": "source_fps🟨", 266 | "type": "FLOAT", 267 | "links": [ 268 | 59, 269 | 82 270 | ] 271 | }, 272 | { 273 | "name": "source_frame_count🟨", 274 | "type": "INT", 275 | "links": null 276 | }, 277 | { 278 | "name": "source_duration🟨", 279 | "type": "FLOAT", 280 | "links": null 281 | }, 282 | { 283 | "name": "source_width🟨", 284 | "type": "INT", 285 | "links": null 286 | }, 287 | { 288 | "name": "source_height🟨", 289 | "type": "INT", 290 | "links": null 291 | }, 292 | { 293 | "name": "loaded_fps🟦", 294 | "type": "FLOAT", 295 | "links": null 296 | }, 297 | { 298 | "name": "loaded_frame_count🟦", 299 | "type": "INT", 300 
| "links": null 301 | }, 302 | { 303 | "name": "loaded_duration🟦", 304 | "type": "FLOAT", 305 | "links": null 306 | }, 307 | { 308 | "name": "loaded_width🟦", 309 | "type": "INT", 310 | "links": null 311 | }, 312 | { 313 | "name": "loaded_height🟦", 314 | "type": "INT", 315 | "links": null 316 | } 317 | ], 318 | "properties": { 319 | "Node name for S&R": "VHS_VideoInfo" 320 | }, 321 | "widgets_values": {} 322 | } 323 | ], 324 | "links": [ 325 | [ 326 | 51, 327 | 16, 328 | 3, 329 | 26, 330 | 0, 331 | "VHS_VIDEOINFO" 332 | ], 333 | [ 334 | 59, 335 | 26, 336 | 0, 337 | 19, 338 | 4, 339 | "FLOAT" 340 | ], 341 | [ 342 | 79, 343 | 36, 344 | 0, 345 | 19, 346 | 0, 347 | "IMAGE" 348 | ], 349 | [ 350 | 80, 351 | 36, 352 | 1, 353 | 19, 354 | 1, 355 | "AUDIO" 356 | ], 357 | [ 358 | 81, 359 | 16, 360 | 0, 361 | 36, 362 | 0, 363 | "IMAGE" 364 | ], 365 | [ 366 | 82, 367 | 26, 368 | 0, 369 | 36, 370 | 1, 371 | "FLOAT" 372 | ] 373 | ], 374 | "groups": [], 375 | "config": {}, 376 | "extra": { 377 | "ds": { 378 | "scale": 0.9507547298682385, 379 | "offset": [ 380 | 188.57921583361022, 381 | 197.62529662436822 382 | ] 383 | }, 384 | "frontendVersion": "1.23.4", 385 | "VHS_latentpreview": false, 386 | "VHS_latentpreviewrate": 0, 387 | "VHS_MetadataImage": true, 388 | "VHS_KeepIntermediate": true 389 | }, 390 | "version": 0.4 391 | } -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/quantize.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | from .layers import WNConv1d 11 | 12 | 13 | class VectorQuantize(nn.Module): 14 | """ 15 | Implementation of VQ similar to Karpathy's repo: 16 | https://github.com/karpathy/deep-vector-quantization 17 | Additionally uses following tricks from Improved VQGAN 18 | (https://arxiv.org/pdf/2110.04627.pdf): 19 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space 20 | for improved codebook usage 21 | 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which 22 | improves training stability 23 | """ 24 | 25 | def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): 26 | super().__init__() 27 | self.codebook_size = codebook_size 28 | self.codebook_dim = codebook_dim 29 | 30 | self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) 31 | self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) 32 | self.codebook = nn.Embedding(codebook_size, codebook_dim) 33 | 34 | def forward(self, z): 35 | """Quantized the input tensor using a fixed codebook and returns 36 | the corresponding codebook vectors 37 | 38 | Parameters 39 | ---------- 40 | z : Tensor[B x D x T] 41 | 42 | Returns 43 | ------- 44 | Tensor[B x D x T] 45 | Quantized continuous representation of input 46 | Tensor[1] 47 | Commitment loss to train encoder to predict vectors closer to codebook 48 | entries 49 | Tensor[1] 50 | Codebook loss to update the codebook 51 | Tensor[B x T] 52 | Codebook indices (quantized discrete representation of input) 53 | Tensor[B x D x T] 54 | Projected latents (continuous representation of input before quantization) 55 | """ 56 | 57 | # Factorized codes (ViT-VQGAN) Project input into low-dimensional space 58 | z_e = self.in_proj(z) # z_e : (B x D x T) 59 | z_q, indices = self.decode_latents(z_e) 60 | 61 | commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) 62 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) 63 | 64 | z_q = ( 65 | z_e + (z_q - z_e).detach() 66 | ) # noop in forward pass, straight-through gradient estimator in backward pass 67 | 68 | z_q = self.out_proj(z_q) 69 | 70 | return z_q, commitment_loss, codebook_loss, indices, z_e 71 | 72 | def embed_code(self, embed_id): 73 | return F.embedding(embed_id, self.codebook.weight) 74 | 75 | def decode_code(self, embed_id): 76 | return self.embed_code(embed_id).transpose(1, 2) 77 | 78 | def decode_latents(self, latents): 79 | encodings = rearrange(latents, "b d t -> (b t) d") 80 | codebook = self.codebook.weight # codebook: (N x D) 81 | 82 | # L2 normalize encodings and codebook (ViT-VQGAN) 83 | encodings = F.normalize(encodings) 84 | codebook = F.normalize(codebook) 85 | 86 | # Compute euclidean distance with codebook 87 | dist = ( 88 | encodings.pow(2).sum(1, keepdim=True) 89 | - 2 * encodings @ codebook.t() 90 | + codebook.pow(2).sum(1, keepdim=True).t() 91 | ) 92 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) 93 | z_q = self.decode_code(indices) 94 | return z_q, indices 95 | 96 | 97 | class ResidualVectorQuantize(nn.Module): 98 | """ 99 | Introduced in SoundStream: An end2end neural audio codec 100 | https://arxiv.org/abs/2107.03312 101 | """ 102 | 103 | def __init__( 104 | self, 105 | input_dim: int = 512, 106 | n_codebooks: int = 9, 107 | codebook_size: int = 1024, 108 | codebook_dim: Union[int, list] = 8, 109 | quantizer_dropout: float = 0.0, 110 | ): 111 | super().__init__() 112 | if isinstance(codebook_dim, int): 113 | codebook_dim = [codebook_dim for _ in range(n_codebooks)] 114 | 115 | self.n_codebooks = n_codebooks 116 | self.codebook_dim = codebook_dim 117 | self.codebook_size = codebook_size 118 | 119 | self.quantizers = nn.ModuleList( 120 | [ 121 | VectorQuantize(input_dim, codebook_size, codebook_dim[i]) 122 | for i in range(n_codebooks) 123 | ] 124 | ) 125 | self.quantizer_dropout = quantizer_dropout 126 | 127 | def forward(self, z, n_quantizers: int = None): 128 | """Quantized the input tensor 
using a fixed set of `n` codebooks and returns 129 | the corresponding codebook vectors 130 | Parameters 131 | ---------- 132 | z : Tensor[B x D x T] 133 | n_quantizers : int, optional 134 | No. of quantizers to use 135 | (n_quantizers < self.n_codebooks ex: for quantizer dropout) 136 | Note: if `self.quantizer_dropout` is True, this argument is ignored 137 | when in training mode, and a random number of quantizers is used. 138 | Returns 139 | ------- 140 | dict 141 | A dictionary with the following keys: 142 | 143 | "z" : Tensor[B x D x T] 144 | Quantized continuous representation of input 145 | "codes" : Tensor[B x N x T] 146 | Codebook indices for each codebook 147 | (quantized discrete representation of input) 148 | "latents" : Tensor[B x N*D x T] 149 | Projected latents (continuous representation of input before quantization) 150 | "vq/commitment_loss" : Tensor[1] 151 | Commitment loss to train encoder to predict vectors closer to codebook 152 | entries 153 | "vq/codebook_loss" : Tensor[1] 154 | Codebook loss to update the codebook 155 | """ 156 | z_q = 0 157 | residual = z 158 | commitment_loss = 0 159 | codebook_loss = 0 160 | 161 | codebook_indices = [] 162 | latents = [] 163 | 164 | if n_quantizers is None: 165 | n_quantizers = self.n_codebooks 166 | if self.training: 167 | n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 168 | dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) 169 | n_dropout = int(z.shape[0] * self.quantizer_dropout) 170 | n_quantizers[:n_dropout] = dropout[:n_dropout] 171 | n_quantizers = n_quantizers.to(z.device) 172 | 173 | for i, quantizer in enumerate(self.quantizers): 174 | if self.training is False and i >= n_quantizers: 175 | break 176 | 177 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( 178 | residual 179 | ) 180 | 181 | # Create mask to apply quantizer dropout 182 | mask = ( 183 | torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers 184 | ) 185 | z_q = z_q + z_q_i * mask[:, None, None] 186 | residual = residual - z_q_i 187 | 188 | # Sum losses 189 | commitment_loss += (commitment_loss_i * mask).mean() 190 | codebook_loss += (codebook_loss_i * mask).mean() 191 | 192 | codebook_indices.append(indices_i) 193 | latents.append(z_e_i) 194 | 195 | codes = torch.stack(codebook_indices, dim=1) 196 | latents = torch.cat(latents, dim=1) 197 | 198 | return z_q, codes, latents, commitment_loss, codebook_loss 199 | 200 | def from_codes(self, codes: torch.Tensor): 201 | """Given the quantized codes, reconstruct the continuous representation 202 | Parameters 203 | ---------- 204 | codes : Tensor[B x N x T] 205 | Quantized discrete representation of input 206 | Returns 207 | ------- 208 | Tensor[B x D x T] 209 | Quantized continuous representation of input 210 | """ 211 | z_q = 0.0 212 | z_p = [] 213 | n_codebooks = codes.shape[1] 214 | for i in range(n_codebooks): 215 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) 216 | z_p.append(z_p_i) 217 | 218 | z_q_i = self.quantizers[i].out_proj(z_p_i) 219 | z_q = z_q + z_q_i 220 | return z_q, torch.cat(z_p, dim=1), codes 221 | 222 | def from_latents(self, latents: torch.Tensor): 223 | """Given the unquantized latents, reconstruct the 224 | continuous representation after quantization. 
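        Note: the per-codebook dimensions are concatenated along the channel axis, so the
        number of codebooks to decode is inferred from latents.shape[1] via the cumulative
        sum of codebook_dim values computed below.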
225 | 226 | Parameters 227 | ---------- 228 | latents : Tensor[B x N x T] 229 | Continuous representation of input after projection 230 | 231 | Returns 232 | ------- 233 | Tensor[B x D x T] 234 | Quantized representation of full-projected space 235 | Tensor[B x D x T] 236 | Quantized representation of latent space 237 | """ 238 | z_q = 0 239 | z_p = [] 240 | codes = [] 241 | dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) 242 | 243 | n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ 244 | 0 245 | ] 246 | for i in range(n_codebooks): 247 | j, k = dims[i], dims[i + 1] 248 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) 249 | z_p.append(z_p_i) 250 | codes.append(codes_i) 251 | 252 | z_q_i = self.quantizers[i].out_proj(z_p_i) 253 | z_q = z_q + z_q_i 254 | 255 | return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) 256 | 257 | 258 | if __name__ == "__main__": 259 | rvq = ResidualVectorQuantize(quantizer_dropout=True) 260 | x = torch.randn(16, 512, 80) 261 | y = rvq(x) 262 | print(y["latents"].shape) 263 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from loguru import logger 4 | from torchvision import transforms 5 | from torchvision.transforms import v2 6 | from diffusers.utils.torch_utils import randn_tensor 7 | from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection 8 | from ..models.dac_vae.model.dac import DAC 9 | from ..models.synchformer import Synchformer 10 | from ..models.hifi_foley import HunyuanVideoFoley 11 | from .config_utils import load_yaml, AttributeDict 12 | from .schedulers import FlowMatchDiscreteScheduler 13 | from tqdm import tqdm 14 | 15 | def load_state_dict(model, model_path): 16 | logger.info(f"Loading model state dict from: {model_path}") 17 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False) 18 | 19 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 20 | 21 | if missing_keys: 22 | logger.warning(f"Missing keys in state dict ({len(missing_keys)} keys):") 23 | for key in missing_keys: 24 | logger.warning(f" - {key}") 25 | else: 26 | logger.info("No missing keys found") 27 | 28 | if unexpected_keys: 29 | logger.warning(f"Unexpected keys in state dict ({len(unexpected_keys)} keys):") 30 | for key in unexpected_keys: 31 | logger.warning(f" - {key}") 32 | else: 33 | logger.info("No unexpected keys found") 34 | 35 | logger.info("Model state dict loaded successfully") 36 | return model 37 | 38 | def load_model(model_path, config_path, device): 39 | logger.info("Starting model loading process...") 40 | logger.info(f"Configuration file: {config_path}") 41 | logger.info(f"Model weights dir: {model_path}") 42 | logger.info(f"Target device: {device}") 43 | 44 | cfg = load_yaml(config_path) 45 | logger.info("Configuration loaded successfully") 46 | 47 | # HunyuanVideoFoley 48 | logger.info("Loading HunyuanVideoFoley main model...") 49 | foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16) 50 | foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth")) 51 | foley_model.eval() 52 | logger.info("HunyuanVideoFoley model loaded and set to evaluation mode") 53 | 54 | # DAC-VAE 55 | dac_path = os.path.join(model_path, 
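    # Expected files under model_path (as referenced in this function):
    #   hunyuanvideo_foley.pth, vae_128d_48k.pth, synchformer_state_dict.pth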
"vae_128d_48k.pth") 56 | logger.info(f"Loading DAC VAE model from: {dac_path}") 57 | dac_model = DAC.load(dac_path) 58 | dac_model = dac_model.to(device) 59 | dac_model.requires_grad_(False) 60 | dac_model.eval() 61 | logger.info("DAC VAE model loaded successfully") 62 | 63 | # Siglip2 visual-encoder 64 | logger.info("Loading SigLIP2 visual encoder...") 65 | siglip2_preprocess = transforms.Compose([ 66 | transforms.Resize((512, 512)), 67 | transforms.ToTensor(), 68 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 69 | ]) 70 | siglip2_model = AutoModel.from_pretrained("google/siglip2-base-patch16-512").to(device).eval() 71 | logger.info("SigLIP2 model and preprocessing pipeline loaded successfully") 72 | 73 | # clap text-encoder 74 | logger.info("Loading CLAP text encoder...") 75 | clap_tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general") 76 | clap_model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device) 77 | logger.info("CLAP tokenizer and model loaded successfully") 78 | 79 | # syncformer 80 | syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth") 81 | logger.info(f"Loading Synchformer model from: {syncformer_path}") 82 | syncformer_preprocess = v2.Compose( 83 | [ 84 | v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC), 85 | v2.CenterCrop(224), 86 | v2.ToImage(), 87 | v2.ToDtype(torch.float32, scale=True), 88 | v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 89 | ] 90 | ) 91 | 92 | syncformer_model = Synchformer() 93 | syncformer_model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu")) 94 | syncformer_model = syncformer_model.to(device).eval() 95 | logger.info("Synchformer model and preprocessing pipeline loaded successfully") 96 | 97 | 98 | logger.info("Creating model dictionary with attribute access...") 99 | model_dict = AttributeDict({ 100 | 'foley_model': foley_model, 101 | 'dac_model': dac_model, 102 | 'siglip2_preprocess': siglip2_preprocess, 103 | 'siglip2_model': siglip2_model, 104 | 'clap_tokenizer': clap_tokenizer, 105 | 'clap_model': clap_model, 106 | 'syncformer_preprocess': syncformer_preprocess, 107 | 'syncformer_model': syncformer_model, 108 | 'device': device, 109 | }) 110 | 111 | logger.info("All models loaded successfully!") 112 | logger.info("Available model components:") 113 | for key in model_dict.keys(): 114 | logger.info(f" - {key}") 115 | logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)") 116 | 117 | return model_dict, cfg 118 | 119 | def retrieve_timesteps( 120 | scheduler, 121 | num_inference_steps, 122 | device, 123 | **kwargs, 124 | ): 125 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 126 | timesteps = scheduler.timesteps 127 | return timesteps, num_inference_steps 128 | 129 | 130 | def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device): 131 | shape = (batch_size, num_channels_latents, int(length)) 132 | latents = randn_tensor(shape, device=device, dtype=dtype) 133 | 134 | # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler 135 | if hasattr(scheduler, "init_noise_sigma"): 136 | # scale the initial noise by the standard deviation required by the scheduler 137 | latents = latents * scheduler.init_noise_sigma 138 | 139 | return latents 140 | 141 | 142 | @torch.no_grad() 143 | def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, 
batch_size=1): 144 | 145 | target_dtype = model_dict.foley_model.dtype 146 | autocast_enabled = target_dtype != torch.float32 147 | device = model_dict.device 148 | 149 | scheduler = FlowMatchDiscreteScheduler( 150 | shift=cfg.diffusion_config.sample_flow_shift, 151 | reverse=cfg.diffusion_config.flow_reverse, 152 | solver=cfg.diffusion_config.flow_solver, 153 | use_flux_shift=cfg.diffusion_config.sample_use_flux_shift, 154 | flux_base_shift=cfg.diffusion_config.flux_base_shift, 155 | flux_max_shift=cfg.diffusion_config.flux_max_shift, 156 | ) 157 | 158 | timesteps, num_inference_steps = retrieve_timesteps( 159 | scheduler, 160 | num_inference_steps, 161 | device, 162 | ) 163 | 164 | latents = prepare_latents( 165 | scheduler, 166 | batch_size=batch_size, 167 | num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim, 168 | length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate, 169 | dtype=target_dtype, 170 | device=device, 171 | ) 172 | 173 | # Denoise loop 174 | for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Denoising steps"): 175 | # noise latents 176 | latent_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents 177 | latent_input = scheduler.scale_model_input(latent_input, t) 178 | 179 | t_expand = t.repeat(latent_input.shape[0]) 180 | 181 | # siglip2 features 182 | siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 183 | uncond_siglip2_feat = model_dict.foley_model.get_empty_clip_sequence( 184 | bs=batch_size, len=siglip2_feat.shape[1] 185 | ).to(device) 186 | 187 | if guidance_scale is not None and guidance_scale > 1.0: 188 | siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat], dim=0) 189 | else: 190 | siglip2_feat_input = siglip2_feat 191 | 192 | # syncformer features 193 | syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 194 | uncond_syncformer_feat = model_dict.foley_model.get_empty_sync_sequence( 195 | bs=batch_size, len=syncformer_feat.shape[1] 196 | ).to(device) 197 | if guidance_scale is not None and guidance_scale > 1.0: 198 | syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat], dim=0) 199 | else: 200 | syncformer_feat_input = syncformer_feat 201 | 202 | # text features 203 | text_feat_repeated = text_feats.text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 204 | uncond_text_feat_repeated = text_feats.uncond_text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 205 | if guidance_scale is not None and guidance_scale > 1.0: 206 | text_feat_input = torch.cat([uncond_text_feat_repeated, text_feat_repeated], dim=0) 207 | else: 208 | text_feat_input = text_feat_repeated 209 | 210 | with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype): 211 | # Predict the noise residual 212 | noise_pred = model_dict.foley_model( 213 | x=latent_input, 214 | t=t_expand, 215 | cond=text_feat_input, 216 | clip_feat=siglip2_feat_input, 217 | sync_feat=syncformer_feat_input, 218 | return_dict=True, 219 | )["x"] 220 | 221 | noise_pred = noise_pred.to(dtype=torch.float32) 222 | 223 | if guidance_scale is not None and guidance_scale > 1.0: 224 | # Perform classifier-free guidance 225 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 226 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 227 | 228 | # Compute the previous noisy sample x_t -> x_t-1 229 | latents = scheduler.step(noise_pred, t, latents, 
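            # For the default Euler solver this presumably amounts to one integration step along
            # the predicted flow, roughly latents <- latents + (sigma_next - sigma_t) * noise_pred
            # (see FlowMatchDiscreteScheduler in utils/schedulers).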
return_dict=False)[0] 230 | 231 | # Post-process the latents to audio 232 | 233 | with torch.no_grad(): 234 | audio = model_dict.dac_model.decode(latents) 235 | audio = audio.float().cpu() 236 | 237 | audio = audio[:, :int(audio_len_in_s*model_dict.dac_model.sample_rate)] 238 | 239 | return audio, model_dict.dac_model.sample_rate 240 | 241 | 242 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | from audiotools import AudioSignal 10 | from torch import nn 11 | 12 | SUPPORTED_VERSIONS = ["1.0.0"] 13 | 14 | 15 | @dataclass 16 | class DACFile: 17 | codes: torch.Tensor 18 | 19 | # Metadata 20 | chunk_length: int 21 | original_length: int 22 | input_db: float 23 | channels: int 24 | sample_rate: int 25 | padding: bool 26 | dac_version: str 27 | 28 | def save(self, path): 29 | artifacts = { 30 | "codes": self.codes.numpy().astype(np.uint16), 31 | "metadata": { 32 | "input_db": self.input_db.numpy().astype(np.float32), 33 | "original_length": self.original_length, 34 | "sample_rate": self.sample_rate, 35 | "chunk_length": self.chunk_length, 36 | "channels": self.channels, 37 | "padding": self.padding, 38 | "dac_version": SUPPORTED_VERSIONS[-1], 39 | }, 40 | } 41 | path = Path(path).with_suffix(".dac") 42 | with open(path, "wb") as f: 43 | np.save(f, artifacts) 44 | return path 45 | 46 | @classmethod 47 | def load(cls, path): 48 | artifacts = np.load(path, allow_pickle=True)[()] 49 | codes = torch.from_numpy(artifacts["codes"].astype(int)) 50 | if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: 51 | raise RuntimeError( 52 | f"Given file {path} can't be loaded with this version of descript-audio-codec." 
53 | ) 54 | return cls(codes=codes, **artifacts["metadata"]) 55 | 56 | 57 | class CodecMixin: 58 | @property 59 | def padding(self): 60 | if not hasattr(self, "_padding"): 61 | self._padding = True 62 | return self._padding 63 | 64 | @padding.setter 65 | def padding(self, value): 66 | assert isinstance(value, bool) 67 | 68 | layers = [ 69 | l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) 70 | ] 71 | 72 | for layer in layers: 73 | if value: 74 | if hasattr(layer, "original_padding"): 75 | layer.padding = layer.original_padding 76 | else: 77 | layer.original_padding = layer.padding 78 | layer.padding = tuple(0 for _ in range(len(layer.padding))) 79 | 80 | self._padding = value 81 | 82 | def get_delay(self): 83 | # Any number works here, delay is invariant to input length 84 | l_out = self.get_output_length(0) 85 | L = l_out 86 | 87 | layers = [] 88 | for layer in self.modules(): 89 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): 90 | layers.append(layer) 91 | 92 | for layer in reversed(layers): 93 | d = layer.dilation[0] 94 | k = layer.kernel_size[0] 95 | s = layer.stride[0] 96 | 97 | if isinstance(layer, nn.ConvTranspose1d): 98 | L = ((L - d * (k - 1) - 1) / s) + 1 99 | elif isinstance(layer, nn.Conv1d): 100 | L = (L - 1) * s + d * (k - 1) + 1 101 | 102 | L = math.ceil(L) 103 | 104 | l_in = L 105 | 106 | return (l_in - l_out) // 2 107 | 108 | def get_output_length(self, input_length): 109 | L = input_length 110 | # Calculate output length 111 | for layer in self.modules(): 112 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): 113 | d = layer.dilation[0] 114 | k = layer.kernel_size[0] 115 | s = layer.stride[0] 116 | 117 | if isinstance(layer, nn.Conv1d): 118 | L = ((L - d * (k - 1) - 1) / s) + 1 119 | elif isinstance(layer, nn.ConvTranspose1d): 120 | L = (L - 1) * s + d * (k - 1) + 1 121 | 122 | L = math.floor(L) 123 | return L 124 | 125 | @torch.no_grad() 126 | def compress( 127 | self, 128 | audio_path_or_signal: Union[str, Path, AudioSignal], 129 | win_duration: float = 1.0, 130 | verbose: bool = False, 131 | normalize_db: float = -16, 132 | n_quantizers: int = None, 133 | ) -> DACFile: 134 | """Processes an audio signal from a file or AudioSignal object into 135 | discrete codes. This function processes the signal in short windows, 136 | using constant GPU memory. 
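        A minimal usage sketch (illustrative only; assumes a pretrained checkpoint such as
        the vae_128d_48k.pth referenced elsewhere in this repo):

        >>> model = DAC.load("vae_128d_48k.pth").to("cuda")
        >>> dac_file = model.compress("input.wav", win_duration=1.0)
        >>> dac_file.save("input.dac")
        >>> recons = model.decompress("input.dac")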
137 | 138 | Parameters 139 | ---------- 140 | audio_path_or_signal : Union[str, Path, AudioSignal] 141 | audio signal to reconstruct 142 | win_duration : float, optional 143 | window duration in seconds, by default 5.0 144 | verbose : bool, optional 145 | by default False 146 | normalize_db : float, optional 147 | normalize db, by default -16 148 | 149 | Returns 150 | ------- 151 | DACFile 152 | Object containing compressed codes and metadata 153 | required for decompression 154 | """ 155 | audio_signal = audio_path_or_signal 156 | if isinstance(audio_signal, (str, Path)): 157 | audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) 158 | 159 | self.eval() 160 | original_padding = self.padding 161 | original_device = audio_signal.device 162 | 163 | audio_signal = audio_signal.clone() 164 | audio_signal = audio_signal.to_mono() 165 | original_sr = audio_signal.sample_rate 166 | 167 | resample_fn = audio_signal.resample 168 | loudness_fn = audio_signal.loudness 169 | 170 | # If audio is > 10 minutes long, use the ffmpeg versions 171 | if audio_signal.signal_duration >= 10 * 60 * 60: 172 | resample_fn = audio_signal.ffmpeg_resample 173 | loudness_fn = audio_signal.ffmpeg_loudness 174 | 175 | original_length = audio_signal.signal_length 176 | resample_fn(self.sample_rate) 177 | input_db = loudness_fn() 178 | 179 | if normalize_db is not None: 180 | audio_signal.normalize(normalize_db) 181 | audio_signal.ensure_max_of_audio() 182 | 183 | nb, nac, nt = audio_signal.audio_data.shape 184 | audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) 185 | win_duration = ( 186 | audio_signal.signal_duration if win_duration is None else win_duration 187 | ) 188 | 189 | if audio_signal.signal_duration <= win_duration: 190 | # Unchunked compression (used if signal length < win duration) 191 | self.padding = True 192 | n_samples = nt 193 | hop = nt 194 | else: 195 | # Chunked inference 196 | self.padding = False 197 | # Zero-pad signal on either side by the delay 198 | audio_signal.zero_pad(self.delay, self.delay) 199 | n_samples = int(win_duration * self.sample_rate) 200 | # Round n_samples to nearest hop length multiple 201 | n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) 202 | hop = self.get_output_length(n_samples) 203 | 204 | codes = [] 205 | range_fn = range if not verbose else tqdm.trange 206 | 207 | for i in range_fn(0, nt, hop): 208 | x = audio_signal[..., i : i + n_samples] 209 | x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) 210 | 211 | audio_data = x.audio_data.to(self.device) 212 | audio_data = self.preprocess(audio_data, self.sample_rate) 213 | _, c, _, _, _ = self.encode(audio_data, n_quantizers) 214 | codes.append(c.to(original_device)) 215 | chunk_length = c.shape[-1] 216 | 217 | codes = torch.cat(codes, dim=-1) 218 | 219 | dac_file = DACFile( 220 | codes=codes, 221 | chunk_length=chunk_length, 222 | original_length=original_length, 223 | input_db=input_db, 224 | channels=nac, 225 | sample_rate=original_sr, 226 | padding=self.padding, 227 | dac_version=SUPPORTED_VERSIONS[-1], 228 | ) 229 | 230 | if n_quantizers is not None: 231 | codes = codes[:, :n_quantizers, :] 232 | 233 | self.padding = original_padding 234 | return dac_file 235 | 236 | @torch.no_grad() 237 | def decompress( 238 | self, 239 | obj: Union[str, Path, DACFile], 240 | verbose: bool = False, 241 | ) -> AudioSignal: 242 | """Reconstruct audio from a given .dac file 243 | 244 | Parameters 245 | ---------- 246 | obj : Union[str, Path, DACFile] 247 | .dac file 
location or corresponding DACFile object. 248 | verbose : bool, optional 249 | Prints progress if True, by default False 250 | 251 | Returns 252 | ------- 253 | AudioSignal 254 | Object with the reconstructed audio 255 | """ 256 | self.eval() 257 | if isinstance(obj, (str, Path)): 258 | obj = DACFile.load(obj) 259 | 260 | original_padding = self.padding 261 | self.padding = obj.padding 262 | 263 | range_fn = range if not verbose else tqdm.trange 264 | codes = obj.codes 265 | original_device = codes.device 266 | chunk_length = obj.chunk_length 267 | recons = [] 268 | 269 | for i in range_fn(0, codes.shape[-1], chunk_length): 270 | c = codes[..., i : i + chunk_length].to(self.device) 271 | z = self.quantizer.from_codes(c)[0] 272 | r = self.decode(z) 273 | recons.append(r.to(original_device)) 274 | 275 | recons = torch.cat(recons, dim=-1) 276 | recons = AudioSignal(recons, self.sample_rate) 277 | 278 | resample_fn = recons.resample 279 | loudness_fn = recons.loudness 280 | 281 | # If audio is > 10 minutes long, use the ffmpeg versions 282 | if recons.signal_duration >= 10 * 60 * 60: 283 | resample_fn = recons.ffmpeg_resample 284 | loudness_fn = recons.ffmpeg_loudness 285 | 286 | if obj.input_db is not None: 287 | recons.normalize(obj.input_db) 288 | 289 | resample_fn(obj.sample_rate) 290 | 291 | if obj.original_length is not None: 292 | recons = recons[..., : obj.original_length] 293 | loudness_fn() 294 | recons.audio_data = recons.audio_data.reshape( 295 | -1, obj.channels, obj.original_length 296 | ) 297 | else: 298 | loudness_fn() 299 | 300 | self.padding = original_padding 301 | return recons 302 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/video_model_builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
3 | # Copyright 2020 Ross Wightman 4 | # Modified Model definition 5 | 6 | from collections import OrderedDict 7 | from functools import partial 8 | 9 | import torch 10 | import torch.nn as nn 11 | from timm.layers import trunc_normal_ 12 | 13 | from .vit_helper import PatchEmbed, PatchEmbed3D, DividedSpaceTimeBlock 14 | 15 | 16 | class VisionTransformer(nn.Module): 17 | """Vision Transformer with support for patch or hybrid CNN input stage""" 18 | 19 | def __init__(self, cfg): 20 | super().__init__() 21 | self.img_size = cfg.DATA.TRAIN_CROP_SIZE 22 | self.patch_size = cfg.VIT.PATCH_SIZE 23 | self.in_chans = cfg.VIT.CHANNELS 24 | if cfg.TRAIN.DATASET == "Epickitchens": 25 | self.num_classes = [97, 300] 26 | else: 27 | self.num_classes = cfg.MODEL.NUM_CLASSES 28 | self.embed_dim = cfg.VIT.EMBED_DIM 29 | self.depth = cfg.VIT.DEPTH 30 | self.num_heads = cfg.VIT.NUM_HEADS 31 | self.mlp_ratio = cfg.VIT.MLP_RATIO 32 | self.qkv_bias = cfg.VIT.QKV_BIAS 33 | self.drop_rate = cfg.VIT.DROP 34 | self.drop_path_rate = cfg.VIT.DROP_PATH 35 | self.head_dropout = cfg.VIT.HEAD_DROPOUT 36 | self.video_input = cfg.VIT.VIDEO_INPUT 37 | self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION 38 | self.use_mlp = cfg.VIT.USE_MLP 39 | self.num_features = self.embed_dim 40 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 41 | self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT 42 | self.head_act = cfg.VIT.HEAD_ACT 43 | self.cfg = cfg 44 | 45 | # Patch Embedding 46 | self.patch_embed = PatchEmbed( 47 | img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim 48 | ) 49 | 50 | # 3D Patch Embedding 51 | self.patch_embed_3d = PatchEmbed3D( 52 | img_size=self.img_size, 53 | temporal_resolution=self.temporal_resolution, 54 | patch_size=self.patch_size, 55 | in_chans=self.in_chans, 56 | embed_dim=self.embed_dim, 57 | z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP, 58 | ) 59 | self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data) 60 | 61 | # Number of patches 62 | if self.video_input: 63 | num_patches = self.patch_embed.num_patches * self.temporal_resolution 64 | else: 65 | num_patches = self.patch_embed.num_patches 66 | self.num_patches = num_patches 67 | 68 | # CLS token 69 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 70 | trunc_normal_(self.cls_token, std=0.02) 71 | 72 | # Positional embedding 73 | self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim)) 74 | self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT) 75 | trunc_normal_(self.pos_embed, std=0.02) 76 | 77 | if self.cfg.VIT.POS_EMBED == "joint": 78 | self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim)) 79 | trunc_normal_(self.st_embed, std=0.02) 80 | elif self.cfg.VIT.POS_EMBED == "separate": 81 | self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim)) 82 | 83 | # Layer Blocks 84 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 85 | if self.cfg.VIT.ATTN_LAYER == "divided": 86 | self.blocks = nn.ModuleList( 87 | [ 88 | DividedSpaceTimeBlock( 89 | attn_type=cfg.VIT.ATTN_LAYER, 90 | dim=self.embed_dim, 91 | num_heads=self.num_heads, 92 | mlp_ratio=self.mlp_ratio, 93 | qkv_bias=self.qkv_bias, 94 | drop=self.drop_rate, 95 | attn_drop=self.attn_drop_rate, 96 | drop_path=dpr[i], 97 | norm_layer=norm_layer, 98 | ) 99 | for i in range(self.depth) 100 | ] 101 | ) 102 | 103 | self.norm = norm_layer(self.embed_dim) 104 | 105 | # MLP head 106 | if self.use_mlp: 107 | 
hidden_dim = self.embed_dim 108 | if self.head_act == "tanh": 109 | # logging.info("Using TanH activation in MLP") 110 | act = nn.Tanh() 111 | elif self.head_act == "gelu": 112 | # logging.info("Using GELU activation in MLP") 113 | act = nn.GELU() 114 | else: 115 | # logging.info("Using ReLU activation in MLP") 116 | act = nn.ReLU() 117 | self.pre_logits = nn.Sequential( 118 | OrderedDict( 119 | [ 120 | ("fc", nn.Linear(self.embed_dim, hidden_dim)), 121 | ("act", act), 122 | ] 123 | ) 124 | ) 125 | else: 126 | self.pre_logits = nn.Identity() 127 | 128 | # Classifier Head 129 | self.head_drop = nn.Dropout(p=self.head_dropout) 130 | if isinstance(self.num_classes, (list,)) and len(self.num_classes) > 1: 131 | for a, i in enumerate(range(len(self.num_classes))): 132 | setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i])) 133 | else: 134 | self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 135 | 136 | # Initialize weights 137 | self.apply(self._init_weights) 138 | 139 | def _init_weights(self, m): 140 | if isinstance(m, nn.Linear): 141 | trunc_normal_(m.weight, std=0.02) 142 | if isinstance(m, nn.Linear) and m.bias is not None: 143 | nn.init.constant_(m.bias, 0) 144 | elif isinstance(m, nn.LayerNorm): 145 | nn.init.constant_(m.bias, 0) 146 | nn.init.constant_(m.weight, 1.0) 147 | 148 | @torch.jit.ignore 149 | def no_weight_decay(self): 150 | if self.cfg.VIT.POS_EMBED == "joint": 151 | return {"pos_embed", "cls_token", "st_embed"} 152 | else: 153 | return {"pos_embed", "cls_token", "temp_embed"} 154 | 155 | def get_classifier(self): 156 | return self.head 157 | 158 | def reset_classifier(self, num_classes, global_pool=""): 159 | self.num_classes = num_classes 160 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 161 | 162 | def forward_features(self, x): 163 | # if self.video_input: 164 | # x = x[0] 165 | B = x.shape[0] 166 | 167 | # Tokenize input 168 | # if self.cfg.VIT.PATCH_SIZE_TEMP > 1: 169 | # for simplicity of mapping between content dimensions (input x) and token dims (after patching) 170 | # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details): 171 | 172 | # apply patching on input 173 | x = self.patch_embed_3d(x) 174 | tok_mask = None 175 | 176 | # else: 177 | # tok_mask = None 178 | # # 2D tokenization 179 | # if self.video_input: 180 | # x = x.permute(0, 2, 1, 3, 4) 181 | # (B, T, C, H, W) = x.shape 182 | # x = x.reshape(B * T, C, H, W) 183 | 184 | # x = self.patch_embed(x) 185 | 186 | # if self.video_input: 187 | # (B2, T2, D2) = x.shape 188 | # x = x.reshape(B, T * T2, D2) 189 | 190 | # Append CLS token 191 | cls_tokens = self.cls_token.expand(B, -1, -1) 192 | x = torch.cat((cls_tokens, x), dim=1) 193 | # if tok_mask is not None: 194 | # # prepend 1(=keep) to the mask to account for the CLS token as well 195 | # tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1) 196 | 197 | # Interpolate positinoal embeddings 198 | # if self.cfg.DATA.TRAIN_CROP_SIZE != 224: 199 | # pos_embed = self.pos_embed 200 | # N = pos_embed.shape[1] - 1 201 | # npatch = int((x.size(1) - 1) / self.temporal_resolution) 202 | # class_emb = pos_embed[:, 0] 203 | # pos_embed = pos_embed[:, 1:] 204 | # dim = x.shape[-1] 205 | # pos_embed = torch.nn.functional.interpolate( 206 | # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 207 | # scale_factor=math.sqrt(npatch / N), 208 | # mode='bicubic', 209 | # ) 210 | # 
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 211 | # new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) 212 | # else: 213 | new_pos_embed = self.pos_embed 214 | npatch = self.patch_embed.num_patches 215 | 216 | # Add positional embeddings to input 217 | if self.video_input: 218 | if self.cfg.VIT.POS_EMBED == "separate": 219 | cls_embed = self.pos_embed[:, 0, :].unsqueeze(1) 220 | tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1) 221 | tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1) 222 | total_pos_embed = tile_pos_embed + tile_temporal_embed 223 | total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) 224 | x = x + total_pos_embed 225 | elif self.cfg.VIT.POS_EMBED == "joint": 226 | x = x + self.st_embed 227 | else: 228 | # image input 229 | x = x + new_pos_embed 230 | 231 | # Apply positional dropout 232 | x = self.pos_drop(x) 233 | 234 | # Encoding using transformer layers 235 | for i, blk in enumerate(self.blocks): 236 | x = blk( 237 | x, 238 | seq_len=npatch, 239 | num_frames=self.temporal_resolution, 240 | approx=self.cfg.VIT.APPROX_ATTN_TYPE, 241 | num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM, 242 | tok_mask=tok_mask, 243 | ) 244 | 245 | ### v-iashin: I moved it to the forward pass 246 | # x = self.norm(x)[:, 0] 247 | # x = self.pre_logits(x) 248 | ### 249 | return x, tok_mask 250 | 251 | # def forward(self, x): 252 | # x = self.forward_features(x) 253 | # ### v-iashin: here. This should leave the same forward output as before 254 | # x = self.norm(x)[:, 0] 255 | # x = self.pre_logits(x) 256 | # ### 257 | # x = self.head_drop(x) 258 | # if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1: 259 | # output = [] 260 | # for head in range(len(self.num_classes)): 261 | # x_out = getattr(self, "head%d" % head)(x) 262 | # if not self.training: 263 | # x_out = torch.nn.functional.softmax(x_out, dim=-1) 264 | # output.append(x_out) 265 | # return output 266 | # else: 267 | # x = self.head(x) 268 | # if not self.training: 269 | # x = torch.nn.functional.softmax(x, dim=-1) 270 | # return x 271 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/loss.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from audiotools import AudioSignal 7 | from audiotools import STFTParams 8 | from torch import nn 9 | 10 | 11 | class L1Loss(nn.L1Loss): 12 | """L1 Loss between AudioSignals. Defaults 13 | to comparing ``audio_data``, but any 14 | attribute of an AudioSignal can be used. 15 | 16 | Parameters 17 | ---------- 18 | attribute : str, optional 19 | Attribute of signal to compare, defaults to ``audio_data``. 20 | weight : float, optional 21 | Weight of this loss, defaults to 1.0. 
22 | 23 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py 24 | """ 25 | 26 | def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): 27 | self.attribute = attribute 28 | self.weight = weight 29 | super().__init__(**kwargs) 30 | 31 | def forward(self, x: AudioSignal, y: AudioSignal): 32 | """ 33 | Parameters 34 | ---------- 35 | x : AudioSignal 36 | Estimate AudioSignal 37 | y : AudioSignal 38 | Reference AudioSignal 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | L1 loss between AudioSignal attributes. 44 | """ 45 | if isinstance(x, AudioSignal): 46 | x = getattr(x, self.attribute) 47 | y = getattr(y, self.attribute) 48 | return super().forward(x, y) 49 | 50 | 51 | class SISDRLoss(nn.Module): 52 | """ 53 | Computes the Scale-Invariant Source-to-Distortion Ratio between a batch 54 | of estimated and reference audio signals or aligned features. 55 | 56 | Parameters 57 | ---------- 58 | scaling : int, optional 59 | Whether to use scale-invariant (True) or 60 | signal-to-noise ratio (False), by default True 61 | reduction : str, optional 62 | How to reduce across the batch (either 'mean', 63 | 'sum', or none).], by default ' mean' 64 | zero_mean : int, optional 65 | Zero mean the references and estimates before 66 | computing the loss, by default True 67 | clip_min : int, optional 68 | The minimum possible loss value. Helps network 69 | to not focus on making already good examples better, by default None 70 | weight : float, optional 71 | Weight of this loss, defaults to 1.0. 72 | 73 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py 74 | """ 75 | 76 | def __init__( 77 | self, 78 | scaling: int = True, 79 | reduction: str = "mean", 80 | zero_mean: int = True, 81 | clip_min: int = None, 82 | weight: float = 1.0, 83 | ): 84 | self.scaling = scaling 85 | self.reduction = reduction 86 | self.zero_mean = zero_mean 87 | self.clip_min = clip_min 88 | self.weight = weight 89 | super().__init__() 90 | 91 | def forward(self, x: AudioSignal, y: AudioSignal): 92 | eps = 1e-8 93 | # nb, nc, nt 94 | if isinstance(x, AudioSignal): 95 | references = x.audio_data 96 | estimates = y.audio_data 97 | else: 98 | references = x 99 | estimates = y 100 | 101 | nb = references.shape[0] 102 | references = references.reshape(nb, 1, -1).permute(0, 2, 1) 103 | estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) 104 | 105 | # samples now on axis 1 106 | if self.zero_mean: 107 | mean_reference = references.mean(dim=1, keepdim=True) 108 | mean_estimate = estimates.mean(dim=1, keepdim=True) 109 | else: 110 | mean_reference = 0 111 | mean_estimate = 0 112 | 113 | _references = references - mean_reference 114 | _estimates = estimates - mean_estimate 115 | 116 | references_projection = (_references**2).sum(dim=-2) + eps 117 | references_on_estimates = (_estimates * _references).sum(dim=-2) + eps 118 | 119 | scale = ( 120 | (references_on_estimates / references_projection).unsqueeze(1) 121 | if self.scaling 122 | else 1 123 | ) 124 | 125 | e_true = scale * _references 126 | e_res = _estimates - e_true 127 | 128 | signal = (e_true**2).sum(dim=1) 129 | noise = (e_res**2).sum(dim=1) 130 | sdr = -10 * torch.log10(signal / noise + eps) 131 | 132 | if self.clip_min is not None: 133 | sdr = torch.clamp(sdr, min=self.clip_min) 134 | 135 | if self.reduction == "mean": 136 | sdr = sdr.mean() 
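        # Note: `sdr` is the negated SI-SDR in dB (-10 * log10(signal / noise)), so the value
        # returned below is already a quantity to minimize during training.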
137 | elif self.reduction == "sum": 138 | sdr = sdr.sum() 139 | return sdr 140 | 141 | 142 | class MultiScaleSTFTLoss(nn.Module): 143 | """Computes the multi-scale STFT loss from [1]. 144 | 145 | Parameters 146 | ---------- 147 | window_lengths : List[int], optional 148 | Length of each window of each STFT, by default [2048, 512] 149 | loss_fn : typing.Callable, optional 150 | How to compare each loss, by default nn.L1Loss() 151 | clamp_eps : float, optional 152 | Clamp on the log magnitude, below, by default 1e-5 153 | mag_weight : float, optional 154 | Weight of raw magnitude portion of loss, by default 1.0 155 | log_weight : float, optional 156 | Weight of log magnitude portion of loss, by default 1.0 157 | pow : float, optional 158 | Power to raise magnitude to before taking log, by default 2.0 159 | weight : float, optional 160 | Weight of this loss, by default 1.0 161 | match_stride : bool, optional 162 | Whether to match the stride of convolutional layers, by default False 163 | 164 | References 165 | ---------- 166 | 167 | 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. 168 | "DDSP: Differentiable Digital Signal Processing." 169 | International Conference on Learning Representations. 2019. 170 | 171 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 172 | """ 173 | 174 | def __init__( 175 | self, 176 | window_lengths: List[int] = [2048, 512], 177 | loss_fn: typing.Callable = nn.L1Loss(), 178 | clamp_eps: float = 1e-5, 179 | mag_weight: float = 1.0, 180 | log_weight: float = 1.0, 181 | pow: float = 2.0, 182 | weight: float = 1.0, 183 | match_stride: bool = False, 184 | window_type: str = None, 185 | ): 186 | super().__init__() 187 | self.stft_params = [ 188 | STFTParams( 189 | window_length=w, 190 | hop_length=w // 4, 191 | match_stride=match_stride, 192 | window_type=window_type, 193 | ) 194 | for w in window_lengths 195 | ] 196 | self.loss_fn = loss_fn 197 | self.log_weight = log_weight 198 | self.mag_weight = mag_weight 199 | self.clamp_eps = clamp_eps 200 | self.weight = weight 201 | self.pow = pow 202 | 203 | def forward(self, x: AudioSignal, y: AudioSignal): 204 | """Computes multi-scale STFT between an estimate and a reference 205 | signal. 206 | 207 | Parameters 208 | ---------- 209 | x : AudioSignal 210 | Estimate signal 211 | y : AudioSignal 212 | Reference signal 213 | 214 | Returns 215 | ------- 216 | torch.Tensor 217 | Multi-scale STFT loss. 218 | """ 219 | loss = 0.0 220 | for s in self.stft_params: 221 | x.stft(s.window_length, s.hop_length, s.window_type) 222 | y.stft(s.window_length, s.hop_length, s.window_type) 223 | loss += self.log_weight * self.loss_fn( 224 | x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), 225 | y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), 226 | ) 227 | loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) 228 | return loss 229 | 230 | 231 | class MelSpectrogramLoss(nn.Module): 232 | """Compute distance between mel spectrograms. Can be used 233 | in a multi-scale way. 
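    A minimal usage sketch (illustrative; random tensors stand in for real audio):

    >>> import torch
    >>> from audiotools import AudioSignal
    >>> loss_fn = MelSpectrogramLoss(n_mels=[150, 80], window_lengths=[2048, 512])
    >>> x = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
    >>> y = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
    >>> loss = loss_fn(x, y)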
234 | 235 | Parameters 236 | ---------- 237 | n_mels : List[int] 238 | Number of mels per STFT, by default [150, 80], 239 | window_lengths : List[int], optional 240 | Length of each window of each STFT, by default [2048, 512] 241 | loss_fn : typing.Callable, optional 242 | How to compare each loss, by default nn.L1Loss() 243 | clamp_eps : float, optional 244 | Clamp on the log magnitude, below, by default 1e-5 245 | mag_weight : float, optional 246 | Weight of raw magnitude portion of loss, by default 1.0 247 | log_weight : float, optional 248 | Weight of log magnitude portion of loss, by default 1.0 249 | pow : float, optional 250 | Power to raise magnitude to before taking log, by default 2.0 251 | weight : float, optional 252 | Weight of this loss, by default 1.0 253 | match_stride : bool, optional 254 | Whether to match the stride of convolutional layers, by default False 255 | 256 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 257 | """ 258 | 259 | def __init__( 260 | self, 261 | n_mels: List[int] = [150, 80], 262 | window_lengths: List[int] = [2048, 512], 263 | loss_fn: typing.Callable = nn.L1Loss(), 264 | clamp_eps: float = 1e-5, 265 | mag_weight: float = 1.0, 266 | log_weight: float = 1.0, 267 | pow: float = 2.0, 268 | weight: float = 1.0, 269 | match_stride: bool = False, 270 | mel_fmin: List[float] = [0.0, 0.0], 271 | mel_fmax: List[float] = [None, None], 272 | window_type: str = None, 273 | ): 274 | super().__init__() 275 | self.stft_params = [ 276 | STFTParams( 277 | window_length=w, 278 | hop_length=w // 4, 279 | match_stride=match_stride, 280 | window_type=window_type, 281 | ) 282 | for w in window_lengths 283 | ] 284 | self.n_mels = n_mels 285 | self.loss_fn = loss_fn 286 | self.clamp_eps = clamp_eps 287 | self.log_weight = log_weight 288 | self.mag_weight = mag_weight 289 | self.weight = weight 290 | self.mel_fmin = mel_fmin 291 | self.mel_fmax = mel_fmax 292 | self.pow = pow 293 | 294 | def forward(self, x: AudioSignal, y: AudioSignal): 295 | """Computes mel loss between an estimate and a reference 296 | signal. 297 | 298 | Parameters 299 | ---------- 300 | x : AudioSignal 301 | Estimate signal 302 | y : AudioSignal 303 | Reference signal 304 | 305 | Returns 306 | ------- 307 | torch.Tensor 308 | Mel loss. 309 | """ 310 | loss = 0.0 311 | for n_mels, fmin, fmax, s in zip( 312 | self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params 313 | ): 314 | kwargs = { 315 | "window_length": s.window_length, 316 | "hop_length": s.hop_length, 317 | "window_type": s.window_type, 318 | } 319 | x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) 320 | y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) 321 | 322 | loss += self.log_weight * self.loss_fn( 323 | x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), 324 | y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), 325 | ) 326 | loss += self.mag_weight * self.loss_fn(x_mels, y_mels) 327 | return loss 328 | 329 | 330 | class GANLoss(nn.Module): 331 | """ 332 | Computes a discriminator loss, given a discriminator on 333 | generated waveforms/spectrograms compared to ground truth 334 | waveforms/spectrograms. Computes the loss for both the 335 | discriminator and the generator in separate functions. 
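    A rough sketch of how the two losses are typically used in a training step (the
    `discriminator`, `fake`, and `real` AudioSignals are assumed to be defined elsewhere):

    >>> gan = GANLoss(discriminator)
    >>> d_loss = gan.discriminator_loss(fake, real)          # step the discriminator optimizer
    >>> g_loss, feat_loss = gan.generator_loss(fake, real)   # step the generator optimizer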
336 | """ 337 | 338 | def __init__(self, discriminator): 339 | super().__init__() 340 | self.discriminator = discriminator 341 | 342 | def forward(self, fake, real): 343 | d_fake = self.discriminator(fake.audio_data) 344 | d_real = self.discriminator(real.audio_data) 345 | return d_fake, d_real 346 | 347 | def discriminator_loss(self, fake, real): 348 | d_fake, d_real = self.forward(fake.clone().detach(), real) 349 | 350 | loss_d = 0 351 | for x_fake, x_real in zip(d_fake, d_real): 352 | loss_d += torch.mean(x_fake[-1] ** 2) 353 | loss_d += torch.mean((1 - x_real[-1]) ** 2) 354 | return loss_d 355 | 356 | def generator_loss(self, fake, real): 357 | d_fake, d_real = self.forward(fake, real) 358 | 359 | loss_g = 0 360 | for x_fake in d_fake: 361 | loss_g += torch.mean((1 - x_fake[-1]) ** 2) 362 | 363 | loss_feature = 0 364 | 365 | for i in range(len(d_fake)): 366 | for j in range(len(d_fake[i]) - 1): 367 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) 368 | return loss_g, loss_feature 369 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/synchformer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Any, Mapping 4 | 5 | import einops 6 | import numpy as np 7 | import torch 8 | import torchaudio 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from .motionformer import MotionFormer 13 | from .ast_model import AST 14 | from .utils import Config 15 | 16 | 17 | class Synchformer(nn.Module): 18 | 19 | def __init__(self): 20 | super().__init__() 21 | 22 | self.vfeat_extractor = MotionFormer( 23 | extract_features=True, 24 | factorize_space_time=True, 25 | agg_space_module="TransformerEncoderLayer", 26 | agg_time_module="torch.nn.Identity", 27 | add_global_repr=False, 28 | ) 29 | self.afeat_extractor = AST( 30 | extract_features=True, 31 | max_spec_t=66, 32 | factorize_freq_time=True, 33 | agg_freq_module="TransformerEncoderLayer", 34 | agg_time_module="torch.nn.Identity", 35 | add_global_repr=False, 36 | ) 37 | 38 | # # bridging the s3d latent dim (1024) into what is specified in the config 39 | # # to match e.g. the transformer dim 40 | self.vproj = nn.Linear(in_features=768, out_features=768) 41 | self.aproj = nn.Linear(in_features=768, out_features=768) 42 | self.transformer = GlobalTransformer( 43 | tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768 44 | ) 45 | 46 | def forward(self, vis): 47 | B, S, Tv, C, H, W = vis.shape 48 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) 49 | # feat extractors return a tuple of segment-level and global features (ignored for sync) 50 | # (B, S, tv, D), e.g. 
(B, 7, 8, 768) 51 | vis = self.vfeat_extractor(vis) 52 | return vis 53 | 54 | def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor): 55 | vis = self.vproj(vis) 56 | aud = self.aproj(aud) 57 | 58 | B, S, tv, D = vis.shape 59 | B, S, ta, D = aud.shape 60 | vis = vis.view(B, S * tv, D) # (B, S*tv, D) 61 | aud = aud.view(B, S * ta, D) # (B, S*ta, D) 62 | # print(vis.shape, aud.shape) 63 | 64 | # self.transformer will concatenate the vis and aud in one sequence with aux tokens, 65 | # ie `CvvvvMaaaaaa`, and will return the logits for the CLS tokens 66 | logits = self.transformer(vis, aud) # (B, cls); or (B, cls) and (B, 2) if DoubtingTransformer 67 | 68 | return logits 69 | 70 | def extract_vfeats(self, vis): 71 | B, S, Tv, C, H, W = vis.shape 72 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) 73 | # feat extractors return a tuple of segment-level and global features (ignored for sync) 74 | # (B, S, tv, D), e.g. (B, 7, 8, 768) 75 | vis = self.vfeat_extractor(vis) 76 | return vis 77 | 78 | def extract_afeats(self, aud): 79 | B, S, _, Fa, Ta = aud.shape 80 | aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2) # (B, S, Ta, F) 81 | # (B, S, ta, D), e.g. (B, 7, 6, 768) 82 | aud, _ = self.afeat_extractor(aud) 83 | return aud 84 | 85 | def compute_loss(self, logits, targets, loss_fn: str = None): 86 | loss = None 87 | if targets is not None: 88 | if loss_fn is None or loss_fn == "cross_entropy": 89 | # logits: (B, cls) and targets: (B,) 90 | loss = F.cross_entropy(logits, targets) 91 | else: 92 | raise NotImplementedError(f"Loss {loss_fn} not implemented") 93 | return loss 94 | 95 | def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True): 96 | # discard all entries except vfeat_extractor 97 | # sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')} 98 | 99 | return super().load_state_dict(sd, strict) 100 | 101 | 102 | class RandInitPositionalEncoding(nn.Module): 103 | """Random inited trainable pos embedding. 
It is just applied on the sequence, thus respects no priors.""" 104 | 105 | def __init__(self, block_shape: list, n_embd: int): 106 | super().__init__() 107 | self.block_shape = block_shape 108 | self.n_embd = n_embd 109 | self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd)) 110 | 111 | def forward(self, token_embeddings): 112 | return token_embeddings + self.pos_emb 113 | 114 | 115 | class GlobalTransformer(torch.nn.Module): 116 | """Same as in SparseSync but without the selector transformers and the head""" 117 | 118 | def __init__( 119 | self, 120 | tok_pdrop=0.0, 121 | embd_pdrop=0.1, 122 | resid_pdrop=0.1, 123 | attn_pdrop=0.1, 124 | n_layer=3, 125 | n_head=8, 126 | n_embd=768, 127 | pos_emb_block_shape=[ 128 | 198, 129 | ], 130 | n_off_head_out=21, 131 | ) -> None: 132 | super().__init__() 133 | self.config = Config( 134 | embd_pdrop=embd_pdrop, 135 | resid_pdrop=resid_pdrop, 136 | attn_pdrop=attn_pdrop, 137 | n_layer=n_layer, 138 | n_head=n_head, 139 | n_embd=n_embd, 140 | ) 141 | # input norm 142 | self.vis_in_lnorm = torch.nn.LayerNorm(n_embd) 143 | self.aud_in_lnorm = torch.nn.LayerNorm(n_embd) 144 | # aux tokens 145 | self.OFF_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) 146 | self.MOD_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) 147 | # whole token dropout 148 | self.tok_pdrop = tok_pdrop 149 | self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop) 150 | self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop) 151 | # maybe add pos emb 152 | self.pos_emb_cfg = RandInitPositionalEncoding( 153 | block_shape=pos_emb_block_shape, 154 | n_embd=n_embd, 155 | ) 156 | # the stem 157 | self.drop = torch.nn.Dropout(embd_pdrop) 158 | self.blocks = torch.nn.Sequential(*[Block(self.config) for _ in range(n_layer)]) 159 | # pre-output norm 160 | self.ln_f = torch.nn.LayerNorm(n_embd) 161 | # maybe add a head 162 | self.off_head = torch.nn.Linear(in_features=n_embd, out_features=n_off_head_out) 163 | 164 | def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True): 165 | B, Sv, D = v.shape 166 | B, Sa, D = a.shape 167 | # broadcasting special tokens to the batch size 168 | off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B) 169 | mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B) 170 | # norm 171 | v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a) 172 | # maybe whole token dropout 173 | if self.tok_pdrop > 0: 174 | v, a = self.tok_drop_vis(v), self.tok_drop_aud(a) 175 | # (B, 1+Sv+1+Sa, D) 176 | x = torch.cat((off_tok, v, mod_tok, a), dim=1) 177 | # maybe add pos emb 178 | if hasattr(self, "pos_emb_cfg"): 179 | x = self.pos_emb_cfg(x) 180 | # dropout -> stem -> norm 181 | x = self.drop(x) 182 | x = self.blocks(x) 183 | x = self.ln_f(x) 184 | # maybe add heads 185 | if attempt_to_apply_heads and hasattr(self, "off_head"): 186 | x = self.off_head(x[:, 0, :]) 187 | return x 188 | 189 | 190 | class SelfAttention(nn.Module): 191 | """ 192 | A vanilla multi-head masked self-attention layer with a projection at the end. 193 | It is possible to use torch.nn.MultiheadAttention here but I am including an 194 | explicit implementation here to show that there is nothing too scary here. 
195 | """ 196 | 197 | def __init__(self, config): 198 | super().__init__() 199 | assert config.n_embd % config.n_head == 0 200 | # key, query, value projections for all heads 201 | self.key = nn.Linear(config.n_embd, config.n_embd) 202 | self.query = nn.Linear(config.n_embd, config.n_embd) 203 | self.value = nn.Linear(config.n_embd, config.n_embd) 204 | # regularization 205 | self.attn_drop = nn.Dropout(config.attn_pdrop) 206 | self.resid_drop = nn.Dropout(config.resid_pdrop) 207 | # output projection 208 | self.proj = nn.Linear(config.n_embd, config.n_embd) 209 | # # causal mask to ensure that attention is only applied to the left in the input sequence 210 | # mask = torch.tril(torch.ones(config.block_size, 211 | # config.block_size)) 212 | # if hasattr(config, "n_unmasked"): 213 | # mask[:config.n_unmasked, :config.n_unmasked] = 1 214 | # self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) 215 | self.n_head = config.n_head 216 | 217 | def forward(self, x): 218 | B, T, C = x.size() 219 | 220 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 221 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 222 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 223 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 224 | 225 | # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 226 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 227 | # att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) 228 | att = F.softmax(att, dim=-1) 229 | y = self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 230 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 231 | 232 | # output projection 233 | y = self.resid_drop(self.proj(y)) 234 | 235 | return y 236 | 237 | 238 | class Block(nn.Module): 239 | """an unassuming Transformer block""" 240 | 241 | def __init__(self, config): 242 | super().__init__() 243 | self.ln1 = nn.LayerNorm(config.n_embd) 244 | self.ln2 = nn.LayerNorm(config.n_embd) 245 | self.attn = SelfAttention(config) 246 | self.mlp = nn.Sequential( 247 | nn.Linear(config.n_embd, 4 * config.n_embd), 248 | nn.GELU(), # nice 249 | nn.Linear(4 * config.n_embd, config.n_embd), 250 | nn.Dropout(config.resid_pdrop), 251 | ) 252 | 253 | def forward(self, x): 254 | x = x + self.attn(self.ln1(x)) 255 | x = x + self.mlp(self.ln2(x)) 256 | return x 257 | 258 | 259 | def make_class_grid( 260 | leftmost_val, 261 | rightmost_val, 262 | grid_size, 263 | add_extreme_offset: bool = False, 264 | seg_size_vframes: int = None, 265 | nseg: int = None, 266 | step_size_seg: float = None, 267 | vfps: float = None, 268 | ): 269 | assert grid_size >= 3, f"grid_size: {grid_size} doesnot make sense. 
If =2 -> (-1,1); =1 -> (-1); =0 -> ()" 270 | grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float() 271 | if add_extreme_offset: 272 | assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}" 273 | seg_size_sec = seg_size_vframes / vfps 274 | trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1) 275 | extreme_value = trim_size_in_seg * seg_size_sec 276 | grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid 277 | return grid 278 | 279 | 280 | # from synchformer 281 | def pad_or_truncate(audio: torch.Tensor, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0): 282 | difference = max_spec_t - audio.shape[-1] # safe for batched input 283 | # pad or truncate, depending on difference 284 | if difference > 0: 285 | # pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input 286 | pad_dims = (0, difference) 287 | audio = torch.nn.functional.pad(audio, pad_dims, pad_mode, pad_value) 288 | elif difference < 0: 289 | print(f"Truncating spec ({audio.shape}) to max_spec_t ({max_spec_t}).") 290 | audio = audio[..., :max_spec_t] # safe for batched input 291 | return audio 292 | 293 | 294 | def encode_audio_with_sync( 295 | synchformer: Synchformer, x: torch.Tensor, mel: torchaudio.transforms.MelSpectrogram 296 | ) -> torch.Tensor: 297 | b, t = x.shape 298 | 299 | # partition the video 300 | segment_size = 10240 301 | step_size = 10240 // 2 302 | num_segments = (t - segment_size) // step_size + 1 303 | segments = [] 304 | for i in range(num_segments): 305 | segments.append(x[:, i * step_size : i * step_size + segment_size]) 306 | x = torch.stack(segments, dim=1) # (B, S, T, C, H, W) 307 | 308 | x = mel(x) 309 | x = torch.log(x + 1e-6) 310 | x = pad_or_truncate(x, 66) 311 | 312 | mean = -4.2677393 313 | std = 4.5689974 314 | x = (x - mean) / (2 * std) 315 | # x: B * S * 128 * 66 316 | x = synchformer.extract_afeats(x.unsqueeze(2)) 317 | return x 318 | 319 | 320 | def read_audio(filename, expected_length=int(16000 * 4)): 321 | waveform, sr = torchaudio.load(filename) 322 | waveform = waveform.mean(dim=0) 323 | 324 | if sr != 16000: 325 | resampler = torchaudio.transforms.Resample(sr, 16000) 326 | waveform = resampler[sr](waveform) 327 | 328 | waveform = waveform[:expected_length] 329 | if waveform.shape[0] != expected_length: 330 | raise ValueError(f"Audio {filename} is too short") 331 | 332 | waveform = waveform.squeeze() 333 | 334 | return waveform 335 | 336 | 337 | if __name__ == "__main__": 338 | synchformer = Synchformer().cuda().eval() 339 | 340 | # mmaudio provided synchformer ckpt 341 | synchformer.load_state_dict( 342 | torch.load( 343 | os.environ.get("SYNCHFORMER_WEIGHTS", f"weights/synchformer.pth"), 344 | weights_only=True, 345 | map_location="cpu", 346 | ) 347 | ) 348 | 349 | sync_mel_spectrogram = torchaudio.transforms.MelSpectrogram( 350 | sample_rate=16000, 351 | win_length=400, 352 | hop_length=160, 353 | n_fft=1024, 354 | n_mels=128, 355 | ) 356 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/dac.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | from typing import Union 4 | 5 | import numpy as np 6 | import torch 7 | from audiotools import AudioSignal 8 | from audiotools.ml import BaseModel 9 | from torch import nn 10 | 11 | from .base import CodecMixin 12 | from 
..nn.layers import Snake1d 13 | from ..nn.layers import WNConv1d 14 | from ..nn.layers import WNConvTranspose1d 15 | from ..nn.quantize import ResidualVectorQuantize 16 | from ..nn.vae_utils import DiagonalGaussianDistribution 17 | 18 | 19 | def init_weights(m): 20 | if isinstance(m, nn.Conv1d): 21 | nn.init.trunc_normal_(m.weight, std=0.02) 22 | nn.init.constant_(m.bias, 0) 23 | 24 | 25 | class ResidualUnit(nn.Module): 26 | def __init__(self, dim: int = 16, dilation: int = 1): 27 | super().__init__() 28 | pad = ((7 - 1) * dilation) // 2 29 | self.block = nn.Sequential( 30 | Snake1d(dim), 31 | WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), 32 | Snake1d(dim), 33 | WNConv1d(dim, dim, kernel_size=1), 34 | ) 35 | 36 | def forward(self, x): 37 | y = self.block(x) 38 | pad = (x.shape[-1] - y.shape[-1]) // 2 39 | if pad > 0: 40 | x = x[..., pad:-pad] 41 | return x + y 42 | 43 | 44 | class EncoderBlock(nn.Module): 45 | def __init__(self, dim: int = 16, stride: int = 1): 46 | super().__init__() 47 | self.block = nn.Sequential( 48 | ResidualUnit(dim // 2, dilation=1), 49 | ResidualUnit(dim // 2, dilation=3), 50 | ResidualUnit(dim // 2, dilation=9), 51 | Snake1d(dim // 2), 52 | WNConv1d( 53 | dim // 2, 54 | dim, 55 | kernel_size=2 * stride, 56 | stride=stride, 57 | padding=math.ceil(stride / 2), 58 | ), 59 | ) 60 | 61 | def forward(self, x): 62 | return self.block(x) 63 | 64 | 65 | class Encoder(nn.Module): 66 | def __init__( 67 | self, 68 | d_model: int = 64, 69 | strides: list = [2, 4, 8, 8], 70 | d_latent: int = 64, 71 | ): 72 | super().__init__() 73 | # Create first convolution 74 | self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] 75 | 76 | # Create EncoderBlocks that double channels as they downsample by `stride` 77 | for stride in strides: 78 | d_model *= 2 79 | self.block += [EncoderBlock(d_model, stride=stride)] 80 | 81 | # Create last convolution 82 | self.block += [ 83 | Snake1d(d_model), 84 | WNConv1d(d_model, d_latent, kernel_size=3, padding=1), 85 | ] 86 | 87 | # Wrap black into nn.Sequential 88 | self.block = nn.Sequential(*self.block) 89 | self.enc_dim = d_model 90 | 91 | def forward(self, x): 92 | return self.block(x) 93 | 94 | 95 | class DecoderBlock(nn.Module): 96 | def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): 97 | super().__init__() 98 | self.block = nn.Sequential( 99 | Snake1d(input_dim), 100 | WNConvTranspose1d( 101 | input_dim, 102 | output_dim, 103 | kernel_size=2 * stride, 104 | stride=stride, 105 | padding=math.ceil(stride / 2), 106 | output_padding=stride % 2, 107 | ), 108 | ResidualUnit(output_dim, dilation=1), 109 | ResidualUnit(output_dim, dilation=3), 110 | ResidualUnit(output_dim, dilation=9), 111 | ) 112 | 113 | def forward(self, x): 114 | return self.block(x) 115 | 116 | 117 | class Decoder(nn.Module): 118 | def __init__( 119 | self, 120 | input_channel, 121 | channels, 122 | rates, 123 | d_out: int = 1, 124 | ): 125 | super().__init__() 126 | 127 | # Add first conv layer 128 | layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] 129 | 130 | # Add upsampling + MRF blocks 131 | for i, stride in enumerate(rates): 132 | input_dim = channels // 2**i 133 | output_dim = channels // 2 ** (i + 1) 134 | layers += [DecoderBlock(input_dim, output_dim, stride)] 135 | 136 | # Add final conv layer 137 | layers += [ 138 | Snake1d(output_dim), 139 | WNConv1d(output_dim, d_out, kernel_size=7, padding=3), 140 | nn.Tanh(), 141 | ] 142 | 143 | self.model = nn.Sequential(*layers) 144 | 145 | def 
forward(self, x): 146 | return self.model(x) 147 | 148 | 149 | class DAC(BaseModel, CodecMixin): 150 | def __init__( 151 | self, 152 | encoder_dim: int = 64, 153 | encoder_rates: List[int] = [2, 4, 8, 8], 154 | latent_dim: int = None, 155 | decoder_dim: int = 1536, 156 | decoder_rates: List[int] = [8, 8, 4, 2], 157 | n_codebooks: int = 9, 158 | codebook_size: int = 1024, 159 | codebook_dim: Union[int, list] = 8, 160 | quantizer_dropout: bool = False, 161 | sample_rate: int = 44100, 162 | continuous: bool = False, 163 | ): 164 | super().__init__() 165 | 166 | self.encoder_dim = encoder_dim 167 | self.encoder_rates = encoder_rates 168 | self.decoder_dim = decoder_dim 169 | self.decoder_rates = decoder_rates 170 | self.sample_rate = sample_rate 171 | self.continuous = continuous 172 | 173 | if latent_dim is None: 174 | latent_dim = encoder_dim * (2 ** len(encoder_rates)) 175 | 176 | self.latent_dim = latent_dim 177 | 178 | self.hop_length = np.prod(encoder_rates) 179 | self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) 180 | 181 | if not continuous: 182 | self.n_codebooks = n_codebooks 183 | self.codebook_size = codebook_size 184 | self.codebook_dim = codebook_dim 185 | self.quantizer = ResidualVectorQuantize( 186 | input_dim=latent_dim, 187 | n_codebooks=n_codebooks, 188 | codebook_size=codebook_size, 189 | codebook_dim=codebook_dim, 190 | quantizer_dropout=quantizer_dropout, 191 | ) 192 | else: 193 | self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1) 194 | self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1) 195 | 196 | self.decoder = Decoder( 197 | latent_dim, 198 | decoder_dim, 199 | decoder_rates, 200 | ) 201 | self.sample_rate = sample_rate 202 | self.apply(init_weights) 203 | 204 | self.delay = self.get_delay() 205 | 206 | @property 207 | def dtype(self): 208 | """Get the dtype of the model parameters.""" 209 | # Return the dtype of the first parameter found 210 | for param in self.parameters(): 211 | return param.dtype 212 | return torch.float32 # fallback 213 | 214 | @property 215 | def device(self): 216 | """Get the device of the model parameters.""" 217 | # Return the device of the first parameter found 218 | for param in self.parameters(): 219 | return param.device 220 | return torch.device('cpu') # fallback 221 | 222 | def preprocess(self, audio_data, sample_rate): 223 | if sample_rate is None: 224 | sample_rate = self.sample_rate 225 | assert sample_rate == self.sample_rate 226 | 227 | length = audio_data.shape[-1] 228 | right_pad = math.ceil(length / self.hop_length) * self.hop_length - length 229 | audio_data = nn.functional.pad(audio_data, (0, right_pad)) 230 | 231 | return audio_data 232 | 233 | def encode( 234 | self, 235 | audio_data: torch.Tensor, 236 | n_quantizers: int = None, 237 | ): 238 | """Encode given audio data and return quantized latent codes 239 | 240 | Parameters 241 | ---------- 242 | audio_data : Tensor[B x 1 x T] 243 | Audio data to encode 244 | n_quantizers : int, optional 245 | Number of quantizers to use, by default None 246 | If None, all quantizers are used. 
247 | 248 | Returns 249 | ------- 250 | dict 251 | A dictionary with the following keys: 252 | "z" : Tensor[B x D x T] 253 | Quantized continuous representation of input 254 | "codes" : Tensor[B x N x T] 255 | Codebook indices for each codebook 256 | (quantized discrete representation of input) 257 | "latents" : Tensor[B x N*D x T] 258 | Projected latents (continuous representation of input before quantization) 259 | "vq/commitment_loss" : Tensor[1] 260 | Commitment loss to train encoder to predict vectors closer to codebook 261 | entries 262 | "vq/codebook_loss" : Tensor[1] 263 | Codebook loss to update the codebook 264 | "length" : int 265 | Number of samples in input audio 266 | """ 267 | z = self.encoder(audio_data) # [B x D x T] 268 | if not self.continuous: 269 | z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers) 270 | else: 271 | z = self.quant_conv(z) # [B x 2D x T] 272 | z = DiagonalGaussianDistribution(z) 273 | codes, latents, commitment_loss, codebook_loss = None, None, 0, 0 274 | 275 | return z, codes, latents, commitment_loss, codebook_loss 276 | 277 | def decode(self, z: torch.Tensor): 278 | """Decode given latent codes and return audio data 279 | 280 | Parameters 281 | ---------- 282 | z : Tensor[B x D x T] 283 | Quantized continuous representation of input 284 | length : int, optional 285 | Number of samples in output audio, by default None 286 | 287 | Returns 288 | ------- 289 | dict 290 | A dictionary with the following keys: 291 | "audio" : Tensor[B x 1 x length] 292 | Decoded audio data. 293 | """ 294 | if not self.continuous: 295 | audio = self.decoder(z) 296 | else: 297 | z = self.post_quant_conv(z) 298 | audio = self.decoder(z) 299 | 300 | return audio 301 | 302 | def forward( 303 | self, 304 | audio_data: torch.Tensor, 305 | sample_rate: int = None, 306 | n_quantizers: int = None, 307 | ): 308 | """Model forward pass 309 | 310 | Parameters 311 | ---------- 312 | audio_data : Tensor[B x 1 x T] 313 | Audio data to encode 314 | sample_rate : int, optional 315 | Sample rate of audio data in Hz, by default None 316 | If None, defaults to `self.sample_rate` 317 | n_quantizers : int, optional 318 | Number of quantizers to use, by default None. 319 | If None, all quantizers are used. 320 | 321 | Returns 322 | ------- 323 | dict 324 | A dictionary with the following keys: 325 | "z" : Tensor[B x D x T] 326 | Quantized continuous representation of input 327 | "codes" : Tensor[B x N x T] 328 | Codebook indices for each codebook 329 | (quantized discrete representation of input) 330 | "latents" : Tensor[B x N*D x T] 331 | Projected latents (continuous representation of input before quantization) 332 | "vq/commitment_loss" : Tensor[1] 333 | Commitment loss to train encoder to predict vectors closer to codebook 334 | entries 335 | "vq/codebook_loss" : Tensor[1] 336 | Codebook loss to update the codebook 337 | "length" : int 338 | Number of samples in input audio 339 | "audio" : Tensor[B x 1 x length] 340 | Decoded audio data. 
341 | """ 342 | length = audio_data.shape[-1] 343 | audio_data = self.preprocess(audio_data, sample_rate) 344 | if not self.continuous: 345 | z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers) 346 | 347 | x = self.decode(z) 348 | return { 349 | "audio": x[..., :length], 350 | "z": z, 351 | "codes": codes, 352 | "latents": latents, 353 | "vq/commitment_loss": commitment_loss, 354 | "vq/codebook_loss": codebook_loss, 355 | } 356 | else: 357 | posterior, _, _, _, _ = self.encode(audio_data, n_quantizers) 358 | z = posterior.sample() 359 | x = self.decode(z) 360 | 361 | kl_loss = posterior.kl() 362 | kl_loss = kl_loss.mean() 363 | 364 | return { 365 | "audio": x[..., :length], 366 | "z": z, 367 | "kl_loss": kl_loss, 368 | } 369 | 370 | 371 | if __name__ == "__main__": 372 | import numpy as np 373 | from functools import partial 374 | 375 | model = DAC().to("cpu") 376 | 377 | for n, m in model.named_modules(): 378 | o = m.extra_repr() 379 | p = sum([np.prod(p.size()) for p in m.parameters()]) 380 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 381 | setattr(m, "extra_repr", partial(fn, o=o, p=p)) 382 | print(model) 383 | print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) 384 | 385 | length = 88200 * 2 386 | x = torch.randn(1, 1, length).to(model.device) 387 | x.requires_grad_(True) 388 | x.retain_grad() 389 | 390 | # Make a forward pass 391 | out = model(x)["audio"] 392 | print("Input shape:", x.shape) 393 | print("Output shape:", out.shape) 394 | 395 | # Create gradient variable 396 | grad = torch.zeros_like(out) 397 | grad[:, :, grad.shape[-1] // 2] = 1 398 | 399 | # Make a backward pass 400 | out.backward(grad) 401 | 402 | # Check non-zero values 403 | gradmap = x.grad.squeeze(0) 404 | gradmap = (gradmap != 0).sum(0) # sum across features 405 | rf = (gradmap != 0).sum() 406 | 407 | print(f"Receptive field: {rf.item()}") 408 | 409 | x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) 410 | model.decompress(model.compress(x, verbose=True), verbose=True) 411 | --------------------------------------------------------------------------------
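The continuous (`continuous=True`) branch of `DAC` above behaves as a VAE: `encode` returns a `DiagonalGaussianDistribution` over latents and `decode` maps sampled latents back to a waveform. A minimal round-trip sketch, assuming 44.1 kHz input and untrained weights (checkpoint loading omitted; variable names are illustrative):

```python
# Minimal sketch: round-tripping audio through the continuous (VAE) branch of DAC.
# Assumptions: 44.1 kHz input, untrained weights (checkpoint loading omitted).
import torch

from hunyuanvideo_foley.models.dac_vae.model import DAC

dac = DAC(continuous=True).eval()

audio = torch.randn(1, 1, 44100)                    # (B, 1, T) waveform
audio = dac.preprocess(audio, sample_rate=44100)    # right-pad T to a multiple of hop_length

with torch.no_grad():
    posterior, *_ = dac.encode(audio)               # DiagonalGaussianDistribution over latents
    z = posterior.sample()                          # (B, latent_dim, T // hop_length)
    recon = dac.decode(z)                           # (B, 1, ~T)

print(z.shape, recon.shape)
```

With the default `encoder_rates=[2, 4, 8, 8]`, `hop_length` is 512, so one second of 44.1 kHz audio maps to roughly 44100 / 512 ≈ 86 latent frames of dimension `latent_dim` (1024 by default, i.e. `encoder_dim * 2**len(encoder_rates)`).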
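For context, the loss modules in `loss.py` (`MultiScaleSTFTLoss`, `MelSpectrogramLoss`, `GANLoss`) are the pieces typically combined when training a DAC-style codec; this ComfyUI node only uses the model for inference, so the sketch below is purely illustrative. It assumes the `Discriminator` default constructor arguments and uses placeholder unit loss weights and optimizer settings, not a tuned recipe:

```python
# Purely illustrative training-step sketch: combining the loss modules above for a
# continuous DAC VAE. Loss weights, optimizer settings, and the Discriminator defaults
# are placeholder assumptions.
import torch
from audiotools import AudioSignal

from hunyuanvideo_foley.models.dac_vae.model import DAC, Discriminator
from hunyuanvideo_foley.models.dac_vae.nn.loss import (
    GANLoss,
    MelSpectrogramLoss,
    MultiScaleSTFTLoss,
)

model = DAC(continuous=True)
disc = Discriminator()
gan_loss = GANLoss(disc)
stft_loss = MultiScaleSTFTLoss()
mel_loss = MelSpectrogramLoss()

opt_g = torch.optim.Adam(model.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-4)

signal = AudioSignal(torch.randn(1, 1, 44100), 44100)   # dummy 1 s batch

# Discriminator step: the fake signal is detached inside discriminator_loss().
out = model(signal.audio_data, sample_rate=signal.sample_rate)
recons = AudioSignal(out["audio"], signal.sample_rate)
loss_d = gan_loss.discriminator_loss(recons, signal)
opt_d.zero_grad()
loss_d.backward()
opt_d.step()

# Generator step: spectral reconstruction + adversarial + feature matching + KL.
out = model(signal.audio_data, sample_rate=signal.sample_rate)
recons = AudioSignal(out["audio"], signal.sample_rate)
loss_adv, loss_feat = gan_loss.generator_loss(recons, signal)
loss_g = (
    stft_loss(recons, signal)
    + mel_loss(recons, signal)
    + loss_adv
    + loss_feat
    + out["kl_loss"]          # unit weights are placeholders
)
opt_g.zero_grad()
loss_g.backward()
opt_g.step()
```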
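The shape bookkeeping in `GlobalTransformer.forward` (in `synchformer.py` above) is easiest to see with concrete numbers: with 14 segments of 8 visual and 6 audio tokens, the concatenated sequence `OFF + vis + MOD + aud` is exactly 198 tokens, matching the default `pos_emb_block_shape=[198]`. A small sketch with illustrative segment counts, assuming the synchformer module's own imports (MotionFormer, AST) resolve in your checkout:

```python
# Shape sketch for GlobalTransformer (segment counts are illustrative; assumes the
# synchformer module's own imports resolve in your checkout).
import torch

from hunyuanvideo_foley.models.synchformer.synchformer import GlobalTransformer

gt = GlobalTransformer(n_layer=3, n_head=8, n_embd=768).eval()

B, S, tv, ta, D = 2, 14, 8, 6, 768
vis = torch.randn(B, S * tv, D)   # flattened per-segment visual tokens (B, S*tv, D)
aud = torch.randn(B, S * ta, D)   # flattened per-segment audio tokens  (B, S*ta, D)

# Internally the sequence is [OFF] + vis + [MOD] + aud = 1 + 112 + 1 + 84 = 198 tokens,
# which is why the default pos_emb_block_shape is [198].
with torch.no_grad():
    logits = gt(vis, aud)         # (B, 21) offset-class logits read from the OFF token

print(logits.shape)
```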