├── hunyuanvideo_foley ├── __init__.py ├── models │ ├── __init__.py │ ├── nn │ │ ├── __init__.py │ │ ├── activation_layers.py │ │ ├── modulate_layers.py │ │ ├── norm_layers.py │ │ ├── embed_layers.py │ │ ├── mlp_layers.py │ │ └── posemb_layers.py │ ├── synchformer │ │ ├── __init__.py │ │ ├── divided_224_16x4.yaml │ │ ├── utils.py │ │ ├── compute_desync_score.py │ │ ├── video_model_builder.py │ │ └── synchformer.py │ └── dac_vae │ │ ├── nn │ │ ├── __init__.py │ │ ├── layers.py │ │ ├── vae_utils.py │ │ ├── quantize.py │ │ └── loss.py │ │ ├── model │ │ ├── __init__.py │ │ ├── discriminator.py │ │ ├── base.py │ │ └── dac.py │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── utils │ │ ├── decode.py │ │ ├── encode.py │ │ └── __init__.py ├── utils │ ├── __init__.py │ ├── schedulers │ │ └── __init__.py │ ├── config_utils.py │ ├── helper.py │ ├── media_utils.py │ ├── feature_utils.py │ └── model_utils.py └── constants.py ├── requirements.txt ├── CONTRIBUTORS.md ├── pyproject.toml ├── .github └── workflows │ └── main.yml ├── __init__.py ├── .gitignore ├── configs └── hunyuanvideo-foley-xxl.yaml ├── model_urls.py ├── model_management.py ├── INSTALLATION_GUIDE.md ├── download_models_manual.py ├── install.py ├── test_node.py └── README.md /hunyuanvideo_foley/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .synchformer import Synchformer 2 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from . import layers 2 | from . import loss 3 | from . 
import quantize 4 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import CodecMixin 2 | from .base import DACFile 3 | from .dac import DAC 4 | from .discriminator import Discriminator 5 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from diffusers.schedulers import DDPMScheduler, EulerDiscreteScheduler 2 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | loguru 3 | tqdm 4 | accelerate 5 | transformers>=4.37.0 6 | safetensors 7 | requests 8 | opencv-python 9 | diffusers 10 | pyyaml 11 | einops 12 | omegaconf 13 | packaging 14 | pytorch-lightning 15 | descript-audio-codec 16 | scipy 17 | soxr 18 | ffmpy 19 | audiocraft 20 | descript-audio-codec 21 | -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | # Contributors 2 | 3 | We are grateful to the following individuals for their contributions to this project: 4 | 5 | - [@dasilva333](https://github.com/dasilva333) - Added `enabled` and `silent_audio` toggles for improved workflow control and error handling. 6 | - [@yichengup](https://github.com/yichengup) - Implemented image frame input/output and added a negative prompt field. 7 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | # preserved here for legacy reasons 4 | __model_version__ = "latest" 5 | 6 | import audiotools 7 | 8 | audiotools.ml.BaseModel.INTERN += ["dac.**"] 9 | audiotools.ml.BaseModel.EXTERN += ["einops"] 10 | 11 | 12 | from . import nn 13 | from . import model 14 | from . 
import utils 15 | from .model import DAC 16 | from .model import DACFile 17 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "hunyuanvideo-foley" 7 | description = "ComfyUI custom node for HunyuanVideo-Foley text-video-to-audio synthesis" 8 | version = "1.0.3" 9 | license = { file = "LICENSE.txt" } 10 | dependencies = [] 11 | 12 | [project.urls] 13 | Repository = "https://github.com/if-ai/ComfyUI_HunyuanVideoFoley" 14 | 15 | [tool.comfy] 16 | PublisherId = "impactframes" 17 | DisplayName = "HunyuanVideo-Foley" 18 | Icon = "" 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | jobs: 11 | publish-node: 12 | name: Publish Custom Node to registry 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Check out code 16 | uses: actions/checkout@v4 17 | - name: Publish Custom Node 18 | uses: Comfy-Org/publish-node-action@main 19 | with: 20 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here. 21 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from loguru import logger 4 | 5 | # Add the current directory to Python path to import hunyuanvideo_foley modules 6 | current_dir = os.path.dirname(os.path.abspath(__file__)) 7 | if current_dir not in sys.path: 8 | sys.path.insert(0, current_dir) 9 | 10 | # Import the individual nodes (with FP8 quantization and torch.compile support) 11 | logger.info("Loading HunyuanVideo-Foley nodes with FP8 quantization and torch.compile support") 12 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 13 | 14 | # Export the mappings 15 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import argbind 4 | 5 | from .utils import download 6 | from .utils.decode import decode 7 | from .utils.encode import encode 8 | 9 | STAGES = ["encode", "decode", "download"] 10 | 11 | 12 | def run(stage: str): 13 | """Run stages. 14 | 15 | Parameters 16 | ---------- 17 | stage : str 18 | Stage to run 19 | """ 20 | if stage not in STAGES: 21 | raise ValueError(f"Unknown command: {stage}. 
Allowed commands are {STAGES}") 22 | stage_fn = globals()[stage] 23 | 24 | if stage == "download": 25 | stage_fn() 26 | return 27 | 28 | stage_fn() 29 | 30 | 31 | if __name__ == "__main__": 32 | group = sys.argv.pop(1) 33 | args = argbind.parse_args(group=group) 34 | 35 | with argbind.scope(args): 36 | run(group) 37 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch.nn.utils import weight_norm 7 | 8 | 9 | def WNConv1d(*args, **kwargs): 10 | return weight_norm(nn.Conv1d(*args, **kwargs)) 11 | 12 | 13 | def WNConvTranspose1d(*args, **kwargs): 14 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 15 | 16 | 17 | # Scripting this brings model speed up 1.4x 18 | @torch.jit.script 19 | def snake(x, alpha): 20 | shape = x.shape 21 | x = x.reshape(shape[0], shape[1], -1) 22 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) 23 | x = x.reshape(shape) 24 | return x 25 | 26 | 27 | class Snake1d(nn.Module): 28 | def __init__(self, channels): 29 | super().__init__() 30 | self.alpha = nn.Parameter(torch.ones(1, channels, 1)) 31 | 32 | def forward(self, x): 33 | return snake(x, self.alpha) 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class$ 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # Environments 77 | .env 78 | .venv 79 | env/ 80 | venv/ 81 | ENV/ 82 | env.bak/ 83 | venv.bak/ 84 | 85 | # IDE-specific files 86 | .idea/ 87 | .vscode/ 88 | *.suo 89 | *.ntvs* 90 | *.njsproj 91 | *.sln 92 | *.swp 93 | 94 | # AI tool-specific files 95 | .claude/ 96 | .serena/ 97 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/activation_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | def get_activation_layer(act_type): 5 | if act_type == "gelu": 6 | return lambda: nn.GELU() 7 | elif act_type == "gelu_tanh": 8 | # Approximate `tanh` requires torch >= 1.13 9 | return lambda: nn.GELU(approximate="tanh") 10 | elif act_type == "relu": 11 | return nn.ReLU 12 | elif act_type == "silu": 13 | return nn.SiLU 14 | else: 15 | raise ValueError(f"Unknown activation type: {act_type}") 16 | 17 | class SwiGLU(nn.Module): 18 | def __init__( 19 | self, 20 | dim: int, 21 | hidden_dim: int, 22 | out_dim: int, 23 | ): 24 | """ 25 | Initialize the SwiGLU FeedForward module. 26 | 27 | Args: 28 | dim (int): Input dimension. 29 | hidden_dim (int): Hidden dimension of the feedforward layer. 30 | 31 | Attributes: 32 | w1: Linear transformation for the first layer. 33 | w2: Linear transformation for the second layer. 34 | w3: Linear transformation for the third layer. 
35 | 36 | """ 37 | super().__init__() 38 | 39 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 40 | self.w2 = nn.Linear(hidden_dim, out_dim, bias=False) 41 | self.w3 = nn.Linear(dim, hidden_dim, bias=False) 42 | 43 | def forward(self, x): 44 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 45 | -------------------------------------------------------------------------------- /configs/hunyuanvideo-foley-xxl.yaml: -------------------------------------------------------------------------------- 1 | model_config: 2 | model_name: HunyuanVideo-Foley-XXL 3 | model_type: 1d 4 | model_precision: bf16 5 | model_kwargs: 6 | depth_triple_blocks: 18 7 | depth_single_blocks: 36 8 | hidden_size: 1536 9 | num_heads: 12 10 | mlp_ratio: 4 11 | mlp_act_type: "gelu_tanh" 12 | qkv_bias: True 13 | qk_norm: True 14 | qk_norm_type: "rms" 15 | attn_mode: "torch" 16 | embedder_type: "default" 17 | interleaved_audio_visual_rope: True 18 | enable_learnable_empty_visual_feat: True 19 | sync_modulation: False 20 | add_sync_feat_to_audio: True 21 | cross_attention: True 22 | use_attention_mask: False 23 | condition_projection: "linear" 24 | sync_feat_dim: 768 # syncformer 768 dim 25 | condition_dim: 768 # clap 768 text condition dim (clip-text) 26 | clip_dim: 768 # siglip2 visual dim 27 | audio_vae_latent_dim: 128 28 | audio_frame_rate: 50 29 | patch_size: 1 30 | rope_dim_list: null 31 | rope_theta: 10000 32 | text_length: 77 33 | clip_length: 64 34 | sync_length: 192 35 | use_mmaudio_singleblock: True 36 | depth_triple_ssl_encoder: null 37 | depth_single_ssl_encoder: 8 38 | use_repa_with_audiossl: True 39 | 40 | diffusion_config: 41 | denoise_type: "flow" 42 | flow_path_type: "linear" 43 | flow_predict_type: "velocity" 44 | flow_reverse: True 45 | flow_solver: "euler" 46 | sample_flow_shift: 1.0 47 | sample_use_flux_shift: False 48 | flux_base_shift: 0.5 49 | flux_max_shift: 1.15 50 | -------------------------------------------------------------------------------- /model_urls.py: -------------------------------------------------------------------------------- 1 | # HunyuanVideo-Foley Model URLs Configuration 2 | # Update these URLs with the actual download links for the models 3 | 4 | MODEL_URLS = { 5 | "hunyuanvideo-foley-xxl": { 6 | "models": [ 7 | { 8 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/hunyuanvideo_foley.pth", 9 | "filename": "hunyuanvideo_foley.pth", 10 | "description": "Main HunyuanVideo-Foley model" 11 | }, 12 | { 13 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/synchformer_state_dict.pth", 14 | "filename": "synchformer_state_dict.pth", 15 | "description": "Synchformer model weights" 16 | }, 17 | { 18 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/vae_128d_48k.pth", 19 | "filename": "vae_128d_48k.pth", 20 | "description": "VAE model weights" 21 | } 22 | ], 23 | "extracted_dir": "hunyuanvideo-foley-xxl", 24 | "description": "HunyuanVideo-Foley XXL model for audio generation" 25 | } 26 | } 27 | 28 | # Alternative mirror URLs (if main URLs fail) 29 | MIRROR_URLS = { 30 | # Add mirror download sources here if needed 31 | } 32 | 33 | def get_model_url(model_name: str, use_mirror: bool = False) -> dict: 34 | """Get model URL configuration""" 35 | urls_dict = MIRROR_URLS if use_mirror else MODEL_URLS 36 | return urls_dict.get(model_name, {}) 37 | 38 | def list_available_models() -> list: 39 | """List all available model names""" 40 | return list(MODEL_URLS.keys()) 
-------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/modulate_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | import torch.nn as nn 4 | 5 | class ModulateDiT(nn.Module): 6 | def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None): 7 | factory_kwargs = {"dtype": dtype, "device": device} 8 | super().__init__() 9 | self.act = act_layer() 10 | self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) 11 | # Zero-initialize the modulation 12 | nn.init.zeros_(self.linear.weight) 13 | nn.init.zeros_(self.linear.bias) 14 | 15 | def forward(self, x: torch.Tensor) -> torch.Tensor: 16 | return self.linear(self.act(x)) 17 | 18 | 19 | def modulate(x, shift=None, scale=None): 20 | if x.ndim == 3: 21 | shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None 22 | scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None 23 | if scale is None and shift is None: 24 | return x 25 | elif shift is None: 26 | return x * (1 + scale) 27 | elif scale is None: 28 | return x + shift 29 | else: 30 | return x * (1 + scale) + shift 31 | 32 | 33 | def apply_gate(x, gate=None, tanh=False): 34 | if gate is None: 35 | return x 36 | if gate.ndim == 2 and x.ndim == 3: 37 | gate = gate.unsqueeze(1) 38 | if tanh: 39 | return x * gate.tanh() 40 | else: 41 | return x * gate 42 | 43 | 44 | def ckpt_wrapper(module): 45 | def ckpt_forward(*inputs): 46 | outputs = module(*inputs) 47 | return outputs 48 | 49 | return ckpt_forward 50 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/constants.py: -------------------------------------------------------------------------------- 1 | """Constants used throughout the HunyuanVideo-Foley project.""" 2 | 3 | from typing import Dict, List 4 | 5 | # Model configuration 6 | DEFAULT_AUDIO_SAMPLE_RATE = 48000 7 | DEFAULT_VIDEO_FPS = 25 8 | DEFAULT_AUDIO_CHANNELS = 2 9 | 10 | # Video processing 11 | MAX_VIDEO_DURATION_SECONDS = 15.0 12 | MIN_VIDEO_DURATION_SECONDS = 1.0 13 | 14 | # Audio processing 15 | AUDIO_VAE_LATENT_DIM = 128 16 | AUDIO_FRAME_RATE = 75 # frames per second in latent space 17 | 18 | # Visual features 19 | FPS_VISUAL: Dict[str, int] = { 20 | "siglip2": 8, 21 | "synchformer": 25 22 | } 23 | 24 | # Model paths (can be overridden by environment variables) 25 | DEFAULT_MODEL_PATH = "./pretrained_models/" 26 | DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml" 27 | 28 | # Inference parameters 29 | DEFAULT_GUIDANCE_SCALE = 4.5 30 | DEFAULT_NUM_INFERENCE_STEPS = 50 31 | MIN_GUIDANCE_SCALE = 1.0 32 | MAX_GUIDANCE_SCALE = 10.0 33 | MIN_INFERENCE_STEPS = 10 34 | MAX_INFERENCE_STEPS = 100 35 | 36 | # Text processing 37 | MAX_TEXT_LENGTH = 100 38 | DEFAULT_NEGATIVE_PROMPT = "noisy, harsh" 39 | 40 | # File extensions 41 | SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"] 42 | SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"] 43 | 44 | # Quality settings 45 | AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = { 46 | "high": ["-b:a", "192k"], 47 | "medium": ["-b:a", "128k"], 48 | "low": ["-b:a", "96k"] 49 | } 50 | 51 | # Error messages 52 | ERROR_MESSAGES: Dict[str, str] = { 53 | "model_not_loaded": "Model is not loaded. Please load the model first.", 54 | "invalid_video_format": "Unsupported video format. 
Supported formats: {formats}", 55 | "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds", 56 | "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html" 57 | } -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: Ssv2 4 | BATCH_SIZE: 32 5 | EVAL_PERIOD: 5 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | CHECKPOINT_EPOCH_RESET: True 9 | CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth 10 | DATA: 11 | NUM_FRAMES: 16 12 | SAMPLING_RATE: 4 13 | TRAIN_JITTER_SCALES: [256, 320] 14 | TRAIN_CROP_SIZE: 224 15 | TEST_CROP_SIZE: 224 16 | INPUT_CHANNEL_NUM: [3] 17 | MEAN: [0.5, 0.5, 0.5] 18 | STD: [0.5, 0.5, 0.5] 19 | PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2 20 | PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames 21 | INV_UNIFORM_SAMPLE: True 22 | RANDOM_FLIP: False 23 | REVERSE_INPUT_CHANNEL: True 24 | USE_RAND_AUGMENT: True 25 | RE_PROB: 0.0 26 | USE_REPEATED_AUG: False 27 | USE_RANDOM_RESIZE_CROPS: False 28 | COLORJITTER: False 29 | GRAYSCALE: False 30 | GAUSSIAN: False 31 | SOLVER: 32 | BASE_LR: 1e-4 33 | LR_POLICY: steps_with_relative_lrs 34 | LRS: [1, 0.1, 0.01] 35 | STEPS: [0, 20, 30] 36 | MAX_EPOCH: 35 37 | MOMENTUM: 0.9 38 | WEIGHT_DECAY: 5e-2 39 | WARMUP_EPOCHS: 0.0 40 | OPTIMIZING_METHOD: adamw 41 | USE_MIXED_PRECISION: True 42 | SMOOTHING: 0.2 43 | SLOWFAST: 44 | ALPHA: 8 45 | VIT: 46 | PATCH_SIZE: 16 47 | PATCH_SIZE_TEMP: 2 48 | CHANNELS: 3 49 | EMBED_DIM: 768 50 | DEPTH: 12 51 | NUM_HEADS: 12 52 | MLP_RATIO: 4 53 | QKV_BIAS: True 54 | VIDEO_INPUT: True 55 | TEMPORAL_RESOLUTION: 8 56 | USE_MLP: True 57 | DROP: 0.0 58 | POS_DROPOUT: 0.0 59 | DROP_PATH: 0.2 60 | IM_PRETRAINED: True 61 | HEAD_DROPOUT: 0.0 62 | HEAD_ACT: tanh 63 | PRETRAINED_WEIGHTS: vit_1k 64 | ATTN_LAYER: divided 65 | MODEL: 66 | NUM_CLASSES: 174 67 | ARCH: slow 68 | MODEL_NAME: VisionTransformer 69 | LOSS_FUNC: cross_entropy 70 | TEST: 71 | ENABLE: True 72 | DATASET: Ssv2 73 | BATCH_SIZE: 64 74 | NUM_ENSEMBLE_VIEWS: 1 75 | NUM_SPATIAL_CROPS: 3 76 | DATA_LOADER: 77 | NUM_WORKERS: 4 78 | PIN_MEMORY: True 79 | NUM_GPUS: 8 80 | NUM_SHARDS: 4 81 | RNG_SEED: 0 82 | OUTPUT_DIR: . 83 | TENSORBOARD: 84 | ENABLE: True 85 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/norm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class RMSNorm(nn.Module): 5 | def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, 6 | device=None, dtype=None): 7 | """ 8 | Initialize the RMSNorm normalization layer. 9 | 10 | Args: 11 | dim (int): The dimension of the input tensor. 12 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 13 | 14 | Attributes: 15 | eps (float): A small value added to the denominator for numerical stability. 16 | weight (nn.Parameter): Learnable scaling parameter. 
17 | 18 | """ 19 | factory_kwargs = {'device': device, 'dtype': dtype} 20 | super().__init__() 21 | self.eps = eps 22 | if elementwise_affine: 23 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) 24 | 25 | def _norm(self, x): 26 | """ 27 | Apply the RMSNorm normalization to the input tensor. 28 | 29 | Args: 30 | x (torch.Tensor): The input tensor. 31 | 32 | Returns: 33 | torch.Tensor: The normalized tensor. 34 | 35 | """ 36 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 37 | 38 | def forward(self, x): 39 | """ 40 | Forward pass through the RMSNorm layer. 41 | 42 | Args: 43 | x (torch.Tensor): The input tensor. 44 | 45 | Returns: 46 | torch.Tensor: The output tensor after applying RMSNorm. 47 | 48 | """ 49 | output = self._norm(x.float()).type_as(x) 50 | if hasattr(self, "weight"): 51 | output = output * self.weight 52 | return output 53 | 54 | 55 | def get_norm_layer(norm_layer): 56 | """ 57 | Get the normalization layer. 58 | 59 | Args: 60 | norm_layer (str): The type of normalization layer. 61 | 62 | Returns: 63 | norm_layer (nn.Module): The normalization layer. 64 | """ 65 | if norm_layer == "layer": 66 | return nn.LayerNorm 67 | elif norm_layer == "rms": 68 | return RMSNorm 69 | else: 70 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") 71 | -------------------------------------------------------------------------------- /model_management.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import comfy.utils 4 | from loguru import logger 5 | import folder_paths 6 | from huggingface_hub import hf_hub_download 7 | 8 | # --- Constants --- 9 | FOLEY_MODEL_NAMES = ["hunyuanvideo_foley.pth", "vae_128d_48k.pth", "synchformer_state_dict.pth"] 10 | SIGLIP_MODEL_REPO = "google/siglip-base-patch16-512" 11 | CLAP_MODEL_REPO = "laion/clap-htsat-unfused" 12 | 13 | # --- Path Management --- 14 | def get_model_dir(subfolder=""): 15 | """Returns the primary Foley models directory.""" 16 | return os.path.join(folder_paths.get_folder_paths("foley")[0], subfolder) 17 | 18 | def get_full_model_path(model_name, subfolder=""): 19 | """Returns the full path for a given model name.""" 20 | return os.path.join(get_model_dir(subfolder), model_name) 21 | 22 | # --- Core Functionality --- 23 | def find_or_download(model_name, repo_id, subfolder="", subfolder_in_repo=""): 24 | """ 25 | Finds a model file, downloading it if it's not found in standard locations. 26 | - Checks the main ComfyUI foley models directory first. 27 | - Falls back to downloading from Hugging Face. 28 | """ 29 | local_path = get_full_model_path(model_name, subfolder) 30 | 31 | if os.path.exists(local_path): 32 | logger.info(f"Found local model: {local_path}") 33 | return local_path 34 | 35 | logger.warning(f"Could not find {model_name} locally. Attempting to download from {repo_id}...") 36 | 37 | try: 38 | downloaded_path = hf_hub_download( 39 | repo_id=repo_id, 40 | filename=model_name, 41 | subfolder=subfolder_in_repo, 42 | local_dir=get_model_dir(subfolder), 43 | local_dir_use_symlinks=False 44 | ) 45 | logger.info(f"Successfully downloaded model to: {downloaded_path}") 46 | return downloaded_path 47 | except Exception as e: 48 | logger.error(f"Failed to download {model_name} from {repo_id}: {e}") 49 | raise FileNotFoundError(f"Could not find or download {model_name}. 
Please check your connection or download it manually.") 50 | 51 | def get_siglip_path(): 52 | """Special handling for the SigLIP model which is a directory.""" 53 | return find_or_download_directory(repo_id=SIGLIP_MODEL_REPO, local_dir_name="siglip-base-patch16-512") 54 | 55 | def get_clap_path(): 56 | """Special handling for the CLAP model which is a directory.""" 57 | return find_or_download_directory(repo_id=CLAP_MODEL_REPO, local_dir_name="clap-htsat-unfused") 58 | 59 | def find_or_download_directory(repo_id, local_dir_name): 60 | """ 61 | Finds a model directory, downloading it if it's not found. 62 | This is for models like SigLIP that are not single files. 63 | """ 64 | local_path = get_model_dir(local_dir_name) 65 | 66 | if os.path.exists(local_path) and os.listdir(local_path): 67 | logger.info(f"Found local model directory: {local_path}") 68 | return local_path 69 | 70 | logger.warning(f"Could not find {local_dir_name} directory locally. Attempting to download from {repo_id}...") 71 | 72 | # We can't use hf_hub_download for a whole directory in the same way, 73 | # but the transformers library will handle this caching for us automatically 74 | # when `from_pretrained` is called. We just need to return the repo_id. 75 | # The actual "download" is implicit. 76 | return repo_id 77 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.0]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.mean( 45 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 46 | dim=[1, 2], 47 | ) 48 | else: 49 | return 0.5 * torch.mean( 50 | torch.pow(self.mean - other.mean, 2) / other.var 51 | + self.var / other.var 52 | - 1.0 53 | - self.logvar 54 | + other.logvar, 55 | dim=[1, 2], 56 | ) 57 | 58 | def nll(self, sample, dims=[1, 2]): 59 | if self.deterministic: 60 | return torch.Tensor([0.0]) 61 | logtwopi = np.log(2.0 * np.pi) 62 | return 0.5 * torch.sum( 63 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 64 | dim=dims, 65 | ) 66 | 67 | def mode(self): 68 | return self.mean 69 | 70 | 71 | def normal_kl(mean1, logvar1, mean2, logvar2): 72 | """ 73 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 74 | 
Compute the KL divergence between two gaussians. 75 | Shapes are automatically broadcasted, so batches can be compared to 76 | scalars, among other use cases. 77 | """ 78 | tensor = None 79 | for obj in (mean1, logvar1, mean2, logvar2): 80 | if isinstance(obj, torch.Tensor): 81 | tensor = obj 82 | break 83 | assert tensor is not None, "at least one argument must be a Tensor" 84 | 85 | # Force variances to be Tensors. Broadcasting helps convert scalars to 86 | # Tensors, but it does not work for torch.exp(). 87 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] 88 | 89 | return 0.5 * ( 90 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 91 | ) 92 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/decode.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | 4 | import argbind 5 | import numpy as np 6 | import torch 7 | from audiotools import AudioSignal 8 | from tqdm import tqdm 9 | 10 | from ..model import DACFile 11 | from . import load_model 12 | 13 | warnings.filterwarnings("ignore", category=UserWarning) 14 | 15 | 16 | @argbind.bind(group="decode", positional=True, without_prefix=True) 17 | @torch.inference_mode() 18 | @torch.no_grad() 19 | def decode( 20 | input: str, 21 | output: str = "", 22 | weights_path: str = "", 23 | model_tag: str = "latest", 24 | model_bitrate: str = "8kbps", 25 | device: str = "cuda", 26 | model_type: str = "44khz", 27 | verbose: bool = False, 28 | ): 29 | """Decode audio from codes. 30 | 31 | Parameters 32 | ---------- 33 | input : str 34 | Path to input directory or file 35 | output : str, optional 36 | Path to output directory, by default "". 37 | If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 38 | weights_path : str, optional 39 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 40 | model_tag and model_type. 41 | model_tag : str, optional 42 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 43 | model_bitrate: str 44 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 45 | device : str, optional 46 | Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. 47 | model_type : str, optional 48 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
49 | """ 50 | generator = load_model( 51 | model_type=model_type, 52 | model_bitrate=model_bitrate, 53 | tag=model_tag, 54 | load_path=weights_path, 55 | ) 56 | generator.to(device) 57 | generator.eval() 58 | 59 | # Find all .dac files in input directory 60 | _input = Path(input) 61 | input_files = list(_input.glob("**/*.dac")) 62 | 63 | # If input is a .dac file, add it to the list 64 | if _input.suffix == ".dac": 65 | input_files.append(_input) 66 | 67 | # Create output directory 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(input_files)), desc=f"Decoding files"): 72 | # Load file 73 | artifact = DACFile.load(input_files[i]) 74 | 75 | # Reconstruct audio from codes 76 | recons = generator.decompress(artifact, verbose=verbose) 77 | 78 | # Compute output path 79 | relative_path = input_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = input_files[i] 84 | output_name = relative_path.with_suffix(".wav").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | # Write to file 89 | recons.write(output_path) 90 | 91 | 92 | if __name__ == "__main__": 93 | args = argbind.parse_args() 94 | with argbind.scope(args): 95 | decode() 96 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/encode.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from pathlib import Path 4 | 5 | import argbind 6 | import numpy as np 7 | import torch 8 | from audiotools import AudioSignal 9 | from audiotools.core import util 10 | from tqdm import tqdm 11 | 12 | from . import load_model 13 | 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | 16 | 17 | @argbind.bind(group="encode", positional=True, without_prefix=True) 18 | @torch.inference_mode() 19 | @torch.no_grad() 20 | def encode( 21 | input: str, 22 | output: str = "", 23 | weights_path: str = "", 24 | model_tag: str = "latest", 25 | model_bitrate: str = "8kbps", 26 | n_quantizers: int = None, 27 | device: str = "cuda", 28 | model_type: str = "44khz", 29 | win_duration: float = 5.0, 30 | verbose: bool = False, 31 | ): 32 | """Encode audio files in input path to .dac format. 33 | 34 | Parameters 35 | ---------- 36 | input : str 37 | Path to input audio file or directory 38 | output : str, optional 39 | Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 40 | weights_path : str, optional 41 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 42 | model_tag and model_type. 43 | model_tag : str, optional 44 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 45 | model_bitrate: str 46 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 47 | n_quantizers : int, optional 48 | Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. 49 | device : str, optional 50 | Device to use, by default "cuda" 51 | model_type : str, optional 52 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
53 | """ 54 | generator = load_model( 55 | model_type=model_type, 56 | model_bitrate=model_bitrate, 57 | tag=model_tag, 58 | load_path=weights_path, 59 | ) 60 | generator.to(device) 61 | generator.eval() 62 | kwargs = {"n_quantizers": n_quantizers} 63 | 64 | # Find all audio files in input path 65 | input = Path(input) 66 | audio_files = util.find_audio(input) 67 | 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(audio_files)), desc="Encoding files"): 72 | # Load file 73 | signal = AudioSignal(audio_files[i]) 74 | 75 | # Encode audio to .dac format 76 | artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) 77 | 78 | # Compute output path 79 | relative_path = audio_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = audio_files[i] 84 | output_name = relative_path.with_suffix(".dac").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | artifact.save(output_path) 89 | 90 | 91 | if __name__ == "__main__": 92 | args = argbind.parse_args() 93 | with argbind.scope(args): 94 | encode() 95 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import argbind 4 | from audiotools import ml 5 | 6 | from ..model import DAC 7 | Accelerator = ml.Accelerator 8 | 9 | __MODEL_LATEST_TAGS__ = { 10 | ("44khz", "8kbps"): "0.0.1", 11 | ("24khz", "8kbps"): "0.0.4", 12 | ("16khz", "8kbps"): "0.0.5", 13 | ("44khz", "16kbps"): "1.0.0", 14 | } 15 | 16 | __MODEL_URLS__ = { 17 | ( 18 | "44khz", 19 | "0.0.1", 20 | "8kbps", 21 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", 22 | ( 23 | "24khz", 24 | "0.0.4", 25 | "8kbps", 26 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", 27 | ( 28 | "16khz", 29 | "0.0.5", 30 | "8kbps", 31 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", 32 | ( 33 | "44khz", 34 | "1.0.0", 35 | "16kbps", 36 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", 37 | } 38 | 39 | 40 | @argbind.bind(group="download", positional=True, without_prefix=True) 41 | def download( 42 | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" 43 | ): 44 | """ 45 | Function that downloads the weights file from URL if a local cache is not found. 46 | 47 | Parameters 48 | ---------- 49 | model_type : str 50 | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". 51 | model_bitrate: str 52 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 53 | Only 44khz model supports 16kbps. 54 | tag : str 55 | The tag of the model to download. Defaults to "latest". 56 | 57 | Returns 58 | ------- 59 | Path 60 | Directory path required to load model via audiotools. 
61 | """ 62 | model_type = model_type.lower() 63 | tag = tag.lower() 64 | 65 | assert model_type in [ 66 | "44khz", 67 | "24khz", 68 | "16khz", 69 | ], "model_type must be one of '44khz', '24khz', or '16khz'" 70 | 71 | assert model_bitrate in [ 72 | "8kbps", 73 | "16kbps", 74 | ], "model_bitrate must be one of '8kbps', or '16kbps'" 75 | 76 | if tag == "latest": 77 | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] 78 | 79 | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) 80 | 81 | if download_link is None: 82 | raise ValueError( 83 | f"Could not find model with tag {tag} and model type {model_type}" 84 | ) 85 | 86 | local_path = ( 87 | Path.home() 88 | / ".cache" 89 | / "descript" 90 | / "dac" 91 | / f"weights_{model_type}_{model_bitrate}_{tag}.pth" 92 | ) 93 | if not local_path.exists(): 94 | local_path.parent.mkdir(parents=True, exist_ok=True) 95 | 96 | # Download the model 97 | import requests 98 | 99 | response = requests.get(download_link) 100 | 101 | if response.status_code != 200: 102 | raise ValueError( 103 | f"Could not download model. Received response code {response.status_code}" 104 | ) 105 | local_path.write_bytes(response.content) 106 | 107 | return local_path 108 | 109 | 110 | def load_model( 111 | model_type: str = "44khz", 112 | model_bitrate: str = "8kbps", 113 | tag: str = "latest", 114 | load_path: str = None, 115 | ): 116 | if not load_path: 117 | load_path = download( 118 | model_type=model_type, model_bitrate=model_bitrate, tag=tag 119 | ) 120 | generator = DAC.load(load_path) 121 | return generator 122 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | """Configuration utilities for the HunyuanVideo-Foley project.""" 2 | 3 | import yaml 4 | from pathlib import Path 5 | from typing import Any, Dict, List, Union 6 | 7 | class AttributeDict: 8 | 9 | def __init__(self, data: Union[Dict, List, Any]): 10 | if isinstance(data, dict): 11 | for key, value in data.items(): 12 | if isinstance(value, (dict, list)): 13 | value = AttributeDict(value) 14 | setattr(self, self._sanitize_key(key), value) 15 | elif isinstance(data, list): 16 | self._list = [AttributeDict(item) if isinstance(item, (dict, list)) else item 17 | for item in data] 18 | else: 19 | self._value = data 20 | 21 | def _sanitize_key(self, key: str) -> str: 22 | import re 23 | sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', str(key)) 24 | if sanitized[0].isdigit(): 25 | sanitized = f'_{sanitized}' 26 | return sanitized 27 | 28 | def __getitem__(self, key): 29 | if hasattr(self, '_list'): 30 | return self._list[key] 31 | return getattr(self, self._sanitize_key(key)) 32 | 33 | def __setitem__(self, key, value): 34 | if hasattr(self, '_list'): 35 | self._list[key] = value 36 | else: 37 | setattr(self, self._sanitize_key(key), value) 38 | 39 | def __iter__(self): 40 | if hasattr(self, '_list'): 41 | return iter(self._list) 42 | return iter(self.__dict__.keys()) 43 | 44 | def __len__(self): 45 | if hasattr(self, '_list'): 46 | return len(self._list) 47 | return len(self.__dict__) 48 | 49 | def get(self, key, default=None): 50 | try: 51 | return self[key] 52 | except (KeyError, AttributeError, IndexError): 53 | return default 54 | 55 | def keys(self): 56 | if hasattr(self, '_list'): 57 | return range(len(self._list)) 58 | elif hasattr(self, '_value'): 59 | return [] 60 | else: 61 | return [key for key in self.__dict__.keys() if not 
key.startswith('_')] 62 | 63 | def values(self): 64 | if hasattr(self, '_list'): 65 | return self._list 66 | elif hasattr(self, '_value'): 67 | return [self._value] 68 | else: 69 | return [value for key, value in self.__dict__.items() if not key.startswith('_')] 70 | 71 | def items(self): 72 | if hasattr(self, '_list'): 73 | return enumerate(self._list) 74 | elif hasattr(self, '_value'): 75 | return [] 76 | else: 77 | return [(key, value) for key, value in self.__dict__.items() if not key.startswith('_')] 78 | 79 | def __repr__(self): 80 | if hasattr(self, '_list'): 81 | return f"AttributeDict({self._list})" 82 | elif hasattr(self, '_value'): 83 | return f"AttributeDict({self._value})" 84 | return f"AttributeDict({dict(self.__dict__)})" 85 | 86 | def to_dict(self) -> Union[Dict, List, Any]: 87 | if hasattr(self, '_list'): 88 | return [item.to_dict() if isinstance(item, AttributeDict) else item 89 | for item in self._list] 90 | elif hasattr(self, '_value'): 91 | return self._value 92 | else: 93 | result = {} 94 | for key, value in self.__dict__.items(): 95 | if isinstance(value, AttributeDict): 96 | result[key] = value.to_dict() 97 | else: 98 | result[key] = value 99 | return result 100 | 101 | def load_yaml(file_path: str, encoding: str = 'utf-8') -> AttributeDict: 102 | try: 103 | with open(file_path, 'r', encoding=encoding) as file: 104 | data = yaml.safe_load(file) 105 | return AttributeDict(data) 106 | except FileNotFoundError: 107 | raise FileNotFoundError(f"YAML file not found: {file_path}") 108 | except yaml.YAMLError as e: 109 | raise yaml.YAMLError(f"YAML format error: {e}") 110 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/helper.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | from itertools import repeat 3 | import importlib 4 | import yaml 5 | import time 6 | 7 | def default(value, default_val): 8 | return default_val if value is None else value 9 | 10 | 11 | def default_dtype(value, default_val): 12 | if value is not None: 13 | assert isinstance(value, type(default_val)), f"Expect {type(default_val)}, got {type(value)}." 14 | return value 15 | return default_val 16 | 17 | 18 | def repeat_interleave(lst, num_repeats): 19 | return [item for item in lst for _ in range(num_repeats)] 20 | 21 | 22 | def _ntuple(n): 23 | def parse(x): 24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 25 | x = tuple(x) 26 | if len(x) == 1: 27 | x = tuple(repeat(x[0], n)) 28 | return x 29 | return tuple(repeat(x, n)) 30 | 31 | return parse 32 | 33 | 34 | to_1tuple = _ntuple(1) 35 | to_2tuple = _ntuple(2) 36 | to_3tuple = _ntuple(3) 37 | to_4tuple = _ntuple(4) 38 | 39 | 40 | def as_tuple(x): 41 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 42 | return tuple(x) 43 | if x is None or isinstance(x, (int, float, str)): 44 | return (x,) 45 | else: 46 | raise ValueError(f"Unknown type {type(x)}") 47 | 48 | 49 | def as_list_of_2tuple(x): 50 | x = as_tuple(x) 51 | if len(x) == 1: 52 | x = (x[0], x[0]) 53 | assert len(x) % 2 == 0, f"Expect even length, got {len(x)}." 
54 | lst = [] 55 | for i in range(0, len(x), 2): 56 | lst.append((x[i], x[i + 1])) 57 | return lst 58 | 59 | 60 | def find_multiple(n: int, k: int) -> int: 61 | assert k > 0 62 | if n % k == 0: 63 | return n 64 | return n - (n % k) + k 65 | 66 | 67 | def merge_dicts(dict1, dict2): 68 | for key, value in dict2.items(): 69 | if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict): 70 | merge_dicts(dict1[key], value) 71 | else: 72 | dict1[key] = value 73 | return dict1 74 | 75 | 76 | def merge_yaml_files(file_list): 77 | merged_config = {} 78 | 79 | for file in file_list: 80 | with open(file, "r", encoding="utf-8") as f: 81 | config = yaml.safe_load(f) 82 | if config: 83 | # Remove the first level 84 | for key, value in config.items(): 85 | if isinstance(value, dict): 86 | merged_config = merge_dicts(merged_config, value) 87 | else: 88 | merged_config[key] = value 89 | 90 | return merged_config 91 | 92 | 93 | def merge_dict(file_list): 94 | merged_config = {} 95 | 96 | for file in file_list: 97 | with open(file, "r", encoding="utf-8") as f: 98 | config = yaml.safe_load(f) 99 | if config: 100 | merged_config = merge_dicts(merged_config, config) 101 | 102 | return merged_config 103 | 104 | 105 | def get_obj_from_str(string, reload=False): 106 | module, cls = string.rsplit(".", 1) 107 | if reload: 108 | module_imp = importlib.import_module(module) 109 | importlib.reload(module_imp) 110 | return getattr(importlib.import_module(module, package=None), cls) 111 | 112 | 113 | def readable_time(seconds): 114 | """ Convert time seconds to a readable format: DD Days, HH Hours, MM Minutes, SS Seconds """ 115 | seconds = int(seconds) 116 | days, seconds = divmod(seconds, 86400) 117 | hours, seconds = divmod(seconds, 3600) 118 | minutes, seconds = divmod(seconds, 60) 119 | if days > 0: 120 | return f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds" 121 | if hours > 0: 122 | return f"{hours} Hours, {minutes} Minutes, {seconds} Seconds" 123 | if minutes > 0: 124 | return f"{minutes} Minutes, {seconds} Seconds" 125 | return f"{seconds} Seconds" 126 | 127 | 128 | def get_obj_from_cfg(cfg, reload=False): 129 | if isinstance(cfg, str): 130 | return get_obj_from_str(cfg, reload) 131 | elif isinstance(cfg, (list, tuple,)): 132 | return tuple([get_obj_from_str(c, reload) for c in cfg]) 133 | else: 134 | raise NotImplementedError(f"Not implemented for {type(cfg)}.") 135 | -------------------------------------------------------------------------------- /INSTALLATION_GUIDE.md: -------------------------------------------------------------------------------- 1 | # Installation Guide for ComfyUI HunyuanVideo-Foley Custom Node 2 | 3 | ## Overview 4 | 5 | This custom node wraps the HunyuanVideo-Foley model for use in ComfyUI, enabling text-video-to-audio synthesis directly within ComfyUI workflows. 6 | 7 | ## Prerequisites 8 | 9 | - ComfyUI installation 10 | - Python 3.8+ 11 | - CUDA-capable GPU (8GB+ VRAM recommended, can run on less with memory optimization) 12 | - At least 16GB system RAM 13 | 14 | ## Step-by-Step Installation 15 | 16 | ### 1. Clone the Custom Node 17 | 18 | Navigate to your ComfyUI `custom_nodes` directory and clone the repository: 19 | ```bash 20 | cd /path/to/ComfyUI/custom_nodes 21 | git clone https://github.com/if-ai/ComfyUI_HunyuanVideoFoley.git 22 | cd ComfyUI_HunyuanVideoFoley 23 | ``` 24 | 25 | ### 2. Install Dependencies 26 | 27 | Run the included installation script. This will check for and install any missing Python packages. 
28 | ```bash 29 | python install.py 30 | ``` 31 | 32 | ### 3. Model Handling (Automatic) 33 | 34 | **No manual download is required.** 35 | 36 | The first time you use a generator node, the necessary models will be automatically downloaded and placed in the correct directory: `ComfyUI/models/foley/`. 37 | 38 | The script will create this directory for you if it doesn't exist. 39 | 40 | ### 4. Restart ComfyUI 41 | 42 | After the installation is complete, restart ComfyUI to load the new custom nodes. 43 | 44 | ## Expected Directory Structure 45 | 46 | The installer will create a `foley` directory inside your main ComfyUI `models` folder for storing the downloaded models. The custom node directory will look like this: 47 | 48 | ``` 49 | ComfyUI/ 50 | ├── models/ 51 | │ └── foley/ 52 | │ └── hunyuanvideo-foley-xxl/ 53 | │ ├── hunyuanvideo_foley.pth 54 | │ ├── vae_128d_48k.pth 55 | │ └── synchformer_state_dict.pth 56 | └── custom_nodes/ 57 | └── ComfyUI_HunyuanVideoFoley/ 58 | ├── __init__.py 59 | ├── nodes.py 60 | ├── install.py 61 | └── ... (other node files) 62 | ``` 63 | 64 | ## Usage 65 | 66 | ### Nodes Available 67 | 68 | 1. **HunyuanVideo-Foley Generator**: The main, simplified node for audio generation. 69 | 2. **HunyuanVideo-Foley Generator (Advanced)**: An advanced version that can accept pre-loaded models from loader nodes for optimized workflows. 70 | 3. **HunyuanVideo-Foley Model Loader (FP8)**: Loads the model with optional memory-saving FP8 quantization. 71 | 4. **HunyuanVideo-Foley Dependencies**: Pre-loads model dependencies like text encoders. 72 | 5. **HunyuanVideo-Foley Torch Compile**: Optimizes the model with `torch.compile` for faster inference on compatible GPUs. 73 | 74 | ## Performance & Memory Optimization 75 | 76 | The model includes several features to manage VRAM usage, allowing it to run on a wider range of hardware. 77 | 78 | - **VRAM Usage**: While 8GB of VRAM is recommended for a smooth experience, you can run the model on GPUs with less memory by enabling the following options in the generator node: 79 | - **`memory_efficient`**: This checkbox aggressively unloads models from VRAM after each generation. This is the most effective way to save VRAM. 80 | - **`cpu_offload`**: This option keeps the models on the CPU and only moves them to the GPU when needed. It is slower but significantly reduces VRAM usage. 81 | 82 | - **Generation Time**: Audio generation can take time depending on video length, settings, and hardware. Use the `HunyuanVideo-Foley Torch Compile` node for a potential speedup on subsequent runs. 83 | 84 | ## Troubleshooting 85 | 86 | ### Common Issues 87 | 88 | 1. **"Failed to import..." errors**: 89 | Ensure the installation script completed successfully. You can run it again to be sure: 90 | ```bash 91 | python install.py 92 | ``` 93 | 94 | 2. **Model download issues**: 95 | If the automatic download fails, check your internet connection and the ComfyUI console for error messages. You can also manually download the models from [HuggingFace](https://huggingface.co/tencent/HunyuanVideo-Foley) and place them in `ComfyUI/models/foley/hunyuanvideo-foley-xxl/`. 96 | 97 | 3. **CUDA out of memory**: 98 | - Enable the `memory_efficient` checkbox in the node. 99 | - Enable `cpu_offload` if you still have issues (at the cost of speed). 100 | - Reduce `sample_nums` to 1. 101 | - Use shorter videos for testing. 
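
For issue 2 above, one way to fetch the three weight files manually is with `huggingface-cli` (installed alongside the `huggingface_hub` package); the target path below assumes a default ComfyUI layout, so adjust it to your installation:
```bash
huggingface-cli download tencent/HunyuanVideo-Foley \
  hunyuanvideo_foley.pth vae_128d_48k.pth synchformer_state_dict.pth \
  --local-dir /path/to/ComfyUI/models/foley/hunyuanvideo-foley-xxl
```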
-------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/media_utils.py: -------------------------------------------------------------------------------- 1 | """Media utilities for audio/video processing.""" 2 | 3 | import os 4 | import subprocess 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | from loguru import logger 9 | 10 | 11 | class MediaProcessingError(Exception): 12 | """Exception raised for media processing errors.""" 13 | pass 14 | 15 | 16 | def merge_audio_video( 17 | audio_path: str, 18 | video_path: str, 19 | output_path: str, 20 | overwrite: bool = True, 21 | quality: str = "high" 22 | ) -> str: 23 | """ 24 | Merge audio and video files using ffmpeg. 25 | 26 | Args: 27 | audio_path: Path to input audio file 28 | video_path: Path to input video file 29 | output_path: Path for output video file 30 | overwrite: Whether to overwrite existing output file 31 | quality: Quality setting ('high', 'medium', 'low') 32 | 33 | Returns: 34 | Path to the output file 35 | 36 | Raises: 37 | MediaProcessingError: If input files don't exist or ffmpeg fails 38 | FileNotFoundError: If ffmpeg is not installed 39 | """ 40 | # Validate input files 41 | if not os.path.exists(audio_path): 42 | raise MediaProcessingError(f"Audio file not found: {audio_path}") 43 | if not os.path.exists(video_path): 44 | raise MediaProcessingError(f"Video file not found: {video_path}") 45 | 46 | # Create output directory if needed 47 | output_dir = Path(output_path).parent 48 | output_dir.mkdir(parents=True, exist_ok=True) 49 | 50 | # Quality settings 51 | quality_settings = { 52 | "high": ["-b:a", "192k"], 53 | "medium": ["-b:a", "128k"], 54 | "low": ["-b:a", "96k"] 55 | } 56 | 57 | # Build ffmpeg command with more flexible stream handling 58 | ffmpeg_command = [ 59 | "ffmpeg", 60 | "-i", video_path, 61 | "-i", audio_path, 62 | "-c:v", "copy", 63 | "-c:a", "aac", 64 | "-ac", "2", 65 | "-shortest", # Use shortest stream to avoid hanging 66 | *quality_settings.get(quality, quality_settings["high"]), 67 | ] 68 | 69 | if overwrite: 70 | ffmpeg_command.append("-y") 71 | 72 | ffmpeg_command.append(output_path) 73 | 74 | try: 75 | logger.info(f"Merging audio '{audio_path}' with video '{video_path}'") 76 | logger.info(f"FFmpeg command: {' '.join(ffmpeg_command)}") 77 | 78 | process = subprocess.Popen( 79 | ffmpeg_command, 80 | stdout=subprocess.PIPE, 81 | stderr=subprocess.PIPE, 82 | text=True 83 | ) 84 | stdout, stderr = process.communicate() 85 | 86 | if process.returncode != 0: 87 | logger.error(f"Primary merge failed, trying fallback method...") 88 | logger.error(f"FFmpeg stderr: {stderr}") 89 | 90 | # Try a more compatible fallback approach 91 | fallback_command = [ 92 | "ffmpeg", "-y", 93 | "-i", video_path, 94 | "-i", audio_path, 95 | "-c:v", "libx264", # Re-encode video for compatibility 96 | "-c:a", "aac", 97 | "-b:a", "128k", 98 | "-preset", "fast", # Faster encoding 99 | "-shortest", 100 | output_path 101 | ] 102 | 103 | logger.info(f"Fallback FFmpeg command: {' '.join(fallback_command)}") 104 | 105 | fallback_process = subprocess.Popen( 106 | fallback_command, 107 | stdout=subprocess.PIPE, 108 | stderr=subprocess.PIPE, 109 | text=True 110 | ) 111 | fallback_stdout, fallback_stderr = fallback_process.communicate() 112 | 113 | if fallback_process.returncode != 0: 114 | error_msg = f"Both primary and fallback FFmpeg failed. 
Primary: {stderr}, Fallback: {fallback_stderr}" 115 | logger.error(error_msg) 116 | raise MediaProcessingError(error_msg) 117 | else: 118 | logger.info(f"Successfully merged video with fallback method: {output_path}") 119 | else: 120 | logger.info(f"Successfully merged video saved to: {output_path}") 121 | 122 | except FileNotFoundError: 123 | raise FileNotFoundError( 124 | "ffmpeg not found. Please install ffmpeg: " 125 | "https://ffmpeg.org/download.html" 126 | ) 127 | except Exception as e: 128 | raise MediaProcessingError(f"Unexpected error during media processing: {e}") 129 | 130 | return output_path 131 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/utils.py: -------------------------------------------------------------------------------- 1 | from hashlib import md5 2 | from pathlib import Path 3 | import subprocess 4 | 5 | import requests 6 | from tqdm import tqdm 7 | 8 | PARENT_LINK = "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a" 9 | FNAME2LINK = { 10 | # S3: Synchability: AudioSet (run 2) 11 | "24-01-22T20-34-52.pt": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt", 12 | "cfg-24-01-22T20-34-52.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml", 13 | # S2: Synchformer: AudioSet (run 2) 14 | "24-01-04T16-39-21.pt": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt", 15 | "cfg-24-01-04T16-39-21.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml", 16 | # S2: Synchformer: AudioSet (run 1) 17 | "23-08-28T11-23-23.pt": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt", 18 | "cfg-23-08-28T11-23-23.yaml": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml", 19 | # S2: Synchformer: LRS3 (run 2) 20 | "23-12-23T18-33-57.pt": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt", 21 | "cfg-23-12-23T18-33-57.yaml": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml", 22 | # S2: Synchformer: VGS (run 2) 23 | "24-01-02T10-00-53.pt": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt", 24 | "cfg-24-01-02T10-00-53.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml", 25 | # SparseSync: ft VGGSound-Full 26 | "22-09-21T21-00-52.pt": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt", 27 | "cfg-22-09-21T21-00-52.yaml": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml", 28 | # SparseSync: ft VGGSound-Sparse 29 | "22-07-28T15-49-45.pt": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt", 30 | "cfg-22-07-28T15-49-45.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml", 31 | # SparseSync: only pt on LRS3 32 | "22-07-13T22-25-49.pt": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt", 33 | "cfg-22-07-13T22-25-49.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml", 34 | # SparseSync: feature extractors 35 | "ResNetAudio-22-08-04T09-51-04.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt", # 2s 36 | "ResNetAudio-22-08-03T23-14-49.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt", # 3s 37 | "ResNetAudio-22-08-03T23-14-28.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt", # 4s 38 | "ResNetAudio-22-06-24T08-10-33.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt", 
# 5s
39 | "ResNetAudio-22-06-24T17-31-07.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt", # 6s
40 | "ResNetAudio-22-06-24T23-57-11.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt", # 7s
41 | "ResNetAudio-22-06-25T04-35-42.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt", # 8s
42 | }
43 | 
44 | 
45 | def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024):
46 | """Checks if the file exists; if not, downloads it from the link to the path"""
47 | path = Path(path)
48 | if not path.exists():
49 | path.parent.mkdir(exist_ok=True, parents=True)
50 | link = fname2link.get(path.name, None)
51 | if link is None:
52 | raise ValueError(
53 | f"Can't find the checkpoint file: {path}. Please download it manually and ensure the path exists."
54 | )
55 | with requests.get(link, stream=True) as r:
56 | total_size = int(r.headers.get("content-length", 0))
57 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
58 | with open(path, "wb") as f:
59 | for data in r.iter_content(chunk_size=chunk_size):
60 | if data:
61 | f.write(data)
62 | pbar.update(len(data))
63 | 
64 | 
65 | def which_ffmpeg() -> str:
66 | """Determines the path to the ffmpeg executable
67 | Returns:
68 | str -- path to the executable
69 | """
70 | result = subprocess.run(["which", "ffmpeg"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
71 | ffmpeg_path = result.stdout.decode("utf-8").replace("\n", "")
72 | return ffmpeg_path
73 | 
74 | 
75 | def get_md5sum(path):
76 | hash_md5 = md5()
77 | with open(path, "rb") as f:
78 | for chunk in iter(lambda: f.read(4096 * 8), b""):
79 | hash_md5.update(chunk)
80 | md5sum = hash_md5.hexdigest()
81 | return md5sum
82 | 
83 | 
84 | class Config:
85 | def __init__(self, **kwargs):
86 | for k, v in kwargs.items():
87 | setattr(self, k, v)
88 | 
-------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/embed_layers.py: --------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | 
5 | from ...utils.helper import to_2tuple, to_1tuple
6 | 
7 | class PatchEmbed1D(nn.Module):
8 | """1D Audio to Patch Embedding
9 | 
10 | A convolution based approach to patchifying a 1D audio w/ embedding projection.
11 | 
12 | Based on the impl in https://github.com/google-research/vision_transformer
13 | 
14 | Hacked together by / Copyright 2020 Ross Wightman
15 | """
16 | 
17 | def __init__(
18 | self,
19 | patch_size=1,
20 | in_chans=768,
21 | embed_dim=768,
22 | norm_layer=None,
23 | flatten=True,
24 | bias=True,
25 | dtype=None,
26 | device=None,
27 | ):
28 | factory_kwargs = {"dtype": dtype, "device": device}
29 | super().__init__()
30 | patch_size = to_1tuple(patch_size)
31 | self.patch_size = patch_size
32 | self.flatten = flatten
33 | 
34 | self.proj = nn.Conv1d(
35 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
36 | )
37 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
38 | if bias:
39 | nn.init.zeros_(self.proj.bias)
40 | 
41 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
42 | 
43 | def forward(self, x):
44 | assert (
45 | x.shape[2] % self.patch_size[0] == 0
46 | ), f"The token number ({x.shape[2]}) of x must be divisible by the patch size ({self.patch_size[0]})."
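# proj is a Conv1d with kernel_size == stride == patch_size, so it maps [B, in_chans, N] to [B, embed_dim, N // patch_size]: one embedding per non-overlapping patch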
47 | 48 | x = self.proj(x) 49 | if self.flatten: 50 | x = x.transpose(1, 2) # BCN -> BNC 51 | x = self.norm(x) 52 | return x 53 | 54 | 55 | class ConditionProjection(nn.Module): 56 | """ 57 | Projects condition embeddings. Also handles dropout for classifier-free guidance. 58 | 59 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py 60 | """ 61 | 62 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): 63 | factory_kwargs = {'dtype': dtype, 'device': device} 64 | super().__init__() 65 | self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs) 66 | self.act_1 = act_layer() 67 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs) 68 | 69 | def forward(self, caption): 70 | hidden_states = self.linear_1(caption) 71 | hidden_states = self.act_1(hidden_states) 72 | hidden_states = self.linear_2(hidden_states) 73 | return hidden_states 74 | 75 | 76 | def timestep_embedding(t, dim, max_period=10000): 77 | """ 78 | Create sinusoidal timestep embeddings. 79 | 80 | Args: 81 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. 82 | dim (int): the dimension of the output. 83 | max_period (int): controls the minimum frequency of the embeddings. 84 | 85 | Returns: 86 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. 87 | 88 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 89 | """ 90 | half = dim // 2 91 | freqs = torch.exp( 92 | -math.log(max_period) 93 | * torch.arange(start=0, end=half, dtype=torch.float32) 94 | / half 95 | ).to(device=t.device) 96 | args = t[:, None].float() * freqs[None] 97 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 98 | if dim % 2: 99 | embedding = torch.cat( 100 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 101 | ) 102 | return embedding 103 | 104 | 105 | class TimestepEmbedder(nn.Module): 106 | """ 107 | Embeds scalar timesteps into vector representations. 
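Example (an illustrative sketch; the sizes here are arbitrary and not taken from this repo's configs):
    >>> embedder = TimestepEmbedder(hidden_size=768, act_layer=nn.SiLU)
    >>> t = torch.tensor([0.0, 250.0, 999.0])
    >>> embedder(t).shape
    torch.Size([3, 768])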
108 | """ 109 | def __init__(self, 110 | hidden_size, 111 | act_layer, 112 | frequency_embedding_size=256, 113 | max_period=10000, 114 | out_size=None, 115 | dtype=None, 116 | device=None 117 | ): 118 | factory_kwargs = {'dtype': dtype, 'device': device} 119 | super().__init__() 120 | self.frequency_embedding_size = frequency_embedding_size 121 | self.max_period = max_period 122 | if out_size is None: 123 | out_size = hidden_size 124 | 125 | self.mlp = nn.Sequential( 126 | nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), 127 | act_layer(), 128 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), 129 | ) 130 | nn.init.normal_(self.mlp[0].weight, std=0.02) 131 | nn.init.normal_(self.mlp[2].weight, std=0.02) 132 | 133 | def forward(self, t): 134 | t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) 135 | t_emb = self.mlp(t_freq) 136 | return t_emb 137 | -------------------------------------------------------------------------------- /download_models_manual.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Manual download helper for HunyuanVideo-Foley models 4 | Run this script directly if the automatic download fails in ComfyUI 5 | """ 6 | 7 | import os 8 | import sys 9 | from pathlib import Path 10 | import urllib.request 11 | import time 12 | 13 | # Model download URLs 14 | MODELS = [ 15 | { 16 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/hunyuanvideo_foley.pth", 17 | "filename": "hunyuanvideo_foley.pth", 18 | "size_gb": 10.3, 19 | "description": "Main HunyuanVideo-Foley model" 20 | }, 21 | { 22 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/synchformer_state_dict.pth", 23 | "filename": "synchformer_state_dict.pth", 24 | "size_gb": 0.95, 25 | "description": "Synchformer model weights" 26 | }, 27 | { 28 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/vae_128d_48k.pth", 29 | "filename": "vae_128d_48k.pth", 30 | "size_gb": 1.49, 31 | "description": "VAE model weights" 32 | } 33 | ] 34 | 35 | def download_with_progress(url, dest_path): 36 | """Download file with progress display""" 37 | def progress_hook(block_num, block_size, total_size): 38 | if total_size > 0: 39 | downloaded = block_num * block_size 40 | percent = min(100, (downloaded * 100) // total_size) 41 | size_mb = downloaded / (1024 * 1024) 42 | total_mb = total_size / (1024 * 1024) 43 | 44 | # Print progress 45 | bar_len = 40 46 | filled_len = int(bar_len * percent // 100) 47 | bar = '█' * filled_len + '-' * (bar_len - filled_len) 48 | 49 | sys.stdout.write(f'\r[{bar}] {percent}% ({size_mb:.1f}/{total_mb:.1f} MB)') 50 | sys.stdout.flush() 51 | 52 | try: 53 | urllib.request.urlretrieve(url, dest_path, progress_hook) 54 | print() # New line after progress 55 | return True 56 | except Exception as e: 57 | print(f"\nError: {e}") 58 | return False 59 | 60 | def main(): 61 | # Determine ComfyUI models directory 62 | comfyui_root = Path(__file__).parent.parent.parent # Go up to ComfyUI root 63 | models_dir = comfyui_root / "models" / "foley" / "hunyuanvideo-foley-xxl" 64 | 65 | print("=" * 60) 66 | print("HunyuanVideo-Foley Model Downloader") 67 | print("=" * 60) 68 | print(f"\nModels will be downloaded to:") 69 | print(f" {models_dir}") 70 | 71 | # Create directory if it doesn't exist 72 | models_dir.mkdir(parents=True, exist_ok=True) 73 | 74 | # Check disk space 75 | import shutil 76 | 
stat = shutil.disk_usage(models_dir) 77 | available_gb = stat.free / (1024 ** 3) 78 | required_gb = sum(m["size_gb"] for m in MODELS) + 1 # Add 1GB buffer 79 | 80 | print(f"\nDisk space available: {available_gb:.1f} GB") 81 | print(f"Space required: {required_gb:.1f} GB") 82 | 83 | if available_gb < required_gb: 84 | print(f"\n⚠️ WARNING: Insufficient disk space!") 85 | print(f"Please free up at least {required_gb - available_gb:.1f} GB before continuing.") 86 | response = input("\nContinue anyway? (y/n): ") 87 | if response.lower() != 'y': 88 | return 89 | 90 | print("\n" + "=" * 60) 91 | 92 | # Download each model 93 | for i, model_info in enumerate(MODELS, 1): 94 | model_path = models_dir / model_info["filename"] 95 | 96 | print(f"\n[{i}/{len(MODELS)}] {model_info['description']}") 97 | print(f" File: {model_info['filename']} ({model_info['size_gb']} GB)") 98 | 99 | # Check if already downloaded 100 | if model_path.exists() and model_path.stat().st_size > 100 * 1024 * 1024: 101 | size_gb = model_path.stat().st_size / (1024 ** 3) 102 | print(f" ✓ Already downloaded ({size_gb:.2f} GB)") 103 | continue 104 | 105 | print(f" Downloading from: {model_info['url']}") 106 | 107 | # Try downloading with retries 108 | max_retries = 3 109 | for attempt in range(max_retries): 110 | if attempt > 0: 111 | print(f" Retry {attempt}/{max_retries - 1}...") 112 | time.sleep(5) 113 | 114 | success = download_with_progress(model_info["url"], model_path) 115 | if success: 116 | size_gb = model_path.stat().st_size / (1024 ** 3) 117 | print(f" ✓ Downloaded successfully ({size_gb:.2f} GB)") 118 | break 119 | else: 120 | if attempt == max_retries - 1: 121 | print(f" ✗ Failed to download after {max_retries} attempts") 122 | print(f"\n You can manually download from:") 123 | print(f" {model_info['url']}") 124 | print(f" And place it at:") 125 | print(f" {model_path}") 126 | 127 | print("\n" + "=" * 60) 128 | print("Download process completed!") 129 | print("\nYou can now use the HunyuanVideo-Foley node in ComfyUI.") 130 | print("=" * 60) 131 | 132 | if __name__ == "__main__": 133 | main() -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | """ 2 | Installation script for ComfyUI HunyuanVideo-Foley Custom Node 3 | """ 4 | 5 | import os 6 | import sys 7 | import subprocess 8 | import pkg_resources 9 | from pathlib import Path 10 | 11 | def parse_requirements(file_path): 12 | """Parse requirements file and handle git dependencies.""" 13 | requirements = [] 14 | with open(file_path, 'r') as f: 15 | for line in f: 16 | line = line.strip() 17 | if line and not line.startswith('#'): 18 | if line.startswith('git+'): 19 | # For git repos, find the package name from the egg fragment 20 | egg_name = None 21 | if '#egg=' in line: 22 | egg_name = line.split('#egg=')[-1] 23 | 24 | if egg_name: 25 | requirements.append((egg_name, line)) 26 | else: 27 | print(f"⚠️ Git requirement '{line}' is missing the '#egg=' part and cannot be checked. It will be installed regardless.") 28 | # Fallback: We can't check it, so we'll just try to install it. 29 | # The package name is passed as None to signal an install attempt. 
30 | requirements.append((None, line)) 31 | else: 32 | # Standard package 33 | req = pkg_resources.Requirement.parse(line) 34 | requirements.append((req.project_name, str(req))) 35 | return requirements 36 | 37 | def check_and_install_requirements(): 38 | """Check and install required packages without overriding existing ones.""" 39 | requirements_file = Path(__file__).parent / "requirements.txt" 40 | 41 | if not requirements_file.exists(): 42 | print("❌ Requirements file not found!") 43 | return False 44 | 45 | try: 46 | print("🚀 Checking and installing requirements...") 47 | 48 | # Get list of (package_name, requirement_string) 49 | requirements = parse_requirements(requirements_file) 50 | 51 | for pkg_name, requirement_str in requirements: 52 | # If pkg_name is None, it's a git URL we couldn't parse. Try installing. 53 | if pkg_name is None: 54 | print(f"Attempting to install from git: {requirement_str}") 55 | try: 56 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', requirement_str]) 57 | print(f"✅ Successfully installed {requirement_str}") 58 | except subprocess.CalledProcessError as e: 59 | print(f"❌ Failed to install {requirement_str}: {e}") 60 | continue 61 | 62 | # Check if the package is already installed 63 | try: 64 | pkg_resources.require(requirement_str) 65 | print(f"✅ {pkg_name} is already installed and meets version requirements.") 66 | except pkg_resources.DistributionNotFound: 67 | print(f"Installing {pkg_name}...") 68 | try: 69 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', requirement_str]) 70 | print(f"✅ Successfully installed {pkg_name}") 71 | except subprocess.CalledProcessError as e: 72 | print(f"❌ Failed to install {pkg_name}: {e}") 73 | except pkg_resources.VersionConflict as e: 74 | print(f"⚠️ Version conflict for {pkg_name}: {e.req} is required, but you have {e.dist}.") 75 | print(" Skipping upgrade to avoid conflicts with other nodes. If you encounter issues, please update this package manually.") 76 | except Exception as e: 77 | print(f"An unexpected error occurred while checking {pkg_name}: {e}") 78 | 79 | print("✅ All dependencies checked.") 80 | return True 81 | 82 | except Exception as e: 83 | print(f"❌ Error installing requirements: {e}") 84 | return False 85 | 86 | def setup_model_directories(): 87 | """Create necessary model directories""" 88 | base_dir = Path(__file__).parent.parent.parent # Go up to ComfyUI root 89 | 90 | # Create ComfyUI/models/foley directory for automatic downloads 91 | foley_models_dir = base_dir / "models" / "foley" 92 | foley_models_dir.mkdir(parents=True, exist_ok=True) 93 | print(f"✓ Created ComfyUI models directory: {foley_models_dir}") 94 | 95 | # Also create local fallback directories 96 | node_dir = Path(__file__).parent 97 | local_dirs = [ 98 | node_dir / "pretrained_models", 99 | node_dir / "configs" 100 | ] 101 | 102 | for dir_path in local_dirs: 103 | dir_path.mkdir(exist_ok=True) 104 | print(f"✓ Created local directory: {dir_path}") 105 | 106 | def main(): 107 | """Main installation function""" 108 | print("🚀 Installing ComfyUI HunyuanVideo-Foley Custom Node...") 109 | 110 | # Install requirements 111 | if not check_and_install_requirements(): 112 | print("❌ Failed to install requirements") 113 | return False 114 | 115 | # Setup directories 116 | setup_model_directories() 117 | 118 | print("📋 Installation completed!") 119 | print() 120 | print("📌 Next steps:") 121 | print("1. Restart ComfyUI to load the custom nodes") 122 | print("2. 
Models will be automatically downloaded when you first use the node") 123 | print("3. Alternatively, manually download models and place them in ComfyUI/models/foley/") 124 | print("4. Model URLs are configured in model_urls.py (can be updated as needed)") 125 | print() 126 | 127 | return True 128 | 129 | if __name__ == "__main__": 130 | main() -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/mlp_layers.py: -------------------------------------------------------------------------------- 1 | # Modified from timm library: 2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .modulate_layers import modulate 11 | from ...utils.helper import to_2tuple 12 | 13 | class MLP(nn.Module): 14 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 15 | 16 | def __init__( 17 | self, 18 | in_channels, 19 | hidden_channels=None, 20 | out_features=None, 21 | act_layer=nn.GELU, 22 | norm_layer=None, 23 | bias=True, 24 | drop=0.0, 25 | use_conv=False, 26 | device=None, 27 | dtype=None, 28 | ): 29 | factory_kwargs = {"device": device, "dtype": dtype} 30 | super().__init__() 31 | out_features = out_features or in_channels 32 | hidden_channels = hidden_channels or in_channels 33 | bias = to_2tuple(bias) 34 | drop_probs = to_2tuple(drop) 35 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 36 | 37 | self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs) 38 | self.act = act_layer() 39 | self.drop1 = nn.Dropout(drop_probs[0]) 40 | self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity() 41 | self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs) 42 | self.drop2 = nn.Dropout(drop_probs[1]) 43 | 44 | def forward(self, x): 45 | x = self.fc1(x) 46 | x = self.act(x) 47 | x = self.drop1(x) 48 | x = self.norm(x) 49 | x = self.fc2(x) 50 | x = self.drop2(x) 51 | return x 52 | 53 | 54 | # copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py 55 | # only used when use_vanilla is True 56 | class MLPEmbedder(nn.Module): 57 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): 58 | factory_kwargs = {"device": device, "dtype": dtype} 59 | super().__init__() 60 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) 61 | self.silu = nn.SiLU() 62 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) 63 | 64 | def forward(self, x: torch.Tensor) -> torch.Tensor: 65 | return self.out_layer(self.silu(self.in_layer(x))) 66 | 67 | 68 | class LinearWarpforSingle(nn.Module): 69 | def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None): 70 | factory_kwargs = {"device": device, "dtype": dtype} 71 | super().__init__() 72 | self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs) 73 | 74 | def forward(self, x, y): 75 | z = torch.cat([x, y], dim=2) 76 | return self.fc(z) 77 | 78 | class FinalLayer1D(nn.Module): 79 | def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None): 80 | factory_kwargs = {"device": device, "dtype": dtype} 81 | super().__init__() 82 | 83 | # Just use LayerNorm for the final layer 84 | self.norm_final = 
nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) 85 | self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs) 86 | nn.init.zeros_(self.linear.weight) 87 | nn.init.zeros_(self.linear.bias) 88 | 89 | # Here we don't distinguish between the modulate types. Just use the simple one. 90 | self.adaLN_modulation = nn.Sequential( 91 | act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs) 92 | ) 93 | # Zero-initialize the modulation 94 | nn.init.zeros_(self.adaLN_modulation[1].weight) 95 | nn.init.zeros_(self.adaLN_modulation[1].bias) 96 | 97 | def forward(self, x, c): 98 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 99 | x = modulate(self.norm_final(x), shift=shift, scale=scale) 100 | x = self.linear(x) 101 | return x 102 | 103 | 104 | class ChannelLastConv1d(nn.Conv1d): 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = x.permute(0, 2, 1) 108 | x = super().forward(x) 109 | x = x.permute(0, 2, 1) 110 | return x 111 | 112 | 113 | class ConvMLP(nn.Module): 114 | 115 | def __init__( 116 | self, 117 | dim: int, 118 | hidden_dim: int, 119 | multiple_of: int = 256, 120 | kernel_size: int = 3, 121 | padding: int = 1, 122 | device=None, 123 | dtype=None, 124 | ): 125 | """ 126 | Convolutional MLP module. 127 | 128 | Args: 129 | dim (int): Input dimension. 130 | hidden_dim (int): Hidden dimension of the feedforward layer. 131 | multiple_of (int): Value to ensure hidden dimension is a multiple of this value. 132 | 133 | Attributes: 134 | w1: Linear transformation for the first layer. 135 | w2: Linear transformation for the second layer. 136 | w3: Linear transformation for the third layer. 137 | 138 | """ 139 | factory_kwargs = {"device": device, "dtype": dtype} 140 | super().__init__() 141 | hidden_dim = int(2 * hidden_dim / 3) 142 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 143 | 144 | self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 145 | self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 146 | self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 147 | 148 | def forward(self, x): 149 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 150 | -------------------------------------------------------------------------------- /test_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Test script for ComfyUI HunyuanVideo-Foley custom node 4 | """ 5 | 6 | import sys 7 | import os 8 | import tempfile 9 | from pathlib import Path 10 | 11 | # Add the parent directory to path for imports 12 | current_dir = Path(__file__).parent 13 | parent_dir = current_dir.parent 14 | sys.path.insert(0, str(parent_dir)) 15 | 16 | def test_imports(): 17 | """Test that all required modules can be imported""" 18 | print("Testing imports...") 19 | 20 | try: 21 | from ComfyUI_HunyuanVideoFoley import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 22 | print("✅ Successfully imported node mappings") 23 | 24 | print(f"Available nodes: {list(NODE_CLASS_MAPPINGS.keys())}") 25 | print(f"Display names: {NODE_DISPLAY_NAME_MAPPINGS}") 26 | 27 | return True 28 | 29 | except ImportError as e: 30 | print(f"❌ Import failed: {e}") 31 | return False 32 | 33 | def test_node_structure(): 34 | """Test node class structure""" 
35 | print("\nTesting node structure...") 36 | 37 | try: 38 | from ComfyUI_HunyuanVideoFoley.nodes import HunyuanVideoFoleyNode, HunyuanVideoFoleyModelLoader 39 | 40 | # Test HunyuanVideoFoleyNode 41 | node = HunyuanVideoFoleyNode() 42 | input_types = node.INPUT_TYPES() 43 | 44 | print("✅ HunyuanVideoFoleyNode structure:") 45 | print(f" - Required inputs: {list(input_types['required'].keys())}") 46 | print(f" - Optional inputs: {list(input_types.get('optional', {}).keys())}") 47 | print(f" - Return types: {node.RETURN_TYPES}") 48 | print(f" - Function: {node.FUNCTION}") 49 | print(f" - Category: {node.CATEGORY}") 50 | 51 | # Test HunyuanVideoFoleyModelLoader 52 | loader = HunyuanVideoFoleyModelLoader() 53 | loader_input_types = loader.INPUT_TYPES() 54 | 55 | print("✅ HunyuanVideoFoleyModelLoader structure:") 56 | print(f" - Required inputs: {list(loader_input_types['required'].keys())}") 57 | print(f" - Return types: {loader.RETURN_TYPES}") 58 | print(f" - Function: {loader.FUNCTION}") 59 | 60 | return True 61 | 62 | except Exception as e: 63 | print(f"❌ Node structure test failed: {e}") 64 | return False 65 | 66 | def test_device_setup(): 67 | """Test device setup functionality""" 68 | print("\nTesting device setup...") 69 | 70 | try: 71 | from ComfyUI_HunyuanVideoFoley.nodes import HunyuanVideoFoleyNode 72 | 73 | device = HunyuanVideoFoleyNode.setup_device("auto") 74 | print(f"✅ Device setup successful: {device}") 75 | 76 | return True 77 | 78 | except Exception as e: 79 | print(f"❌ Device setup failed: {e}") 80 | return False 81 | 82 | def test_utils(): 83 | """Test utility functions""" 84 | print("\nTesting utility functions...") 85 | 86 | try: 87 | from ComfyUI_HunyuanVideoFoley.utils import ( 88 | get_optimal_device, 89 | check_memory_requirements, 90 | format_duration, 91 | validate_model_files 92 | ) 93 | 94 | # Test device detection 95 | device = get_optimal_device() 96 | print(f"✅ Optimal device: {device}") 97 | 98 | # Test memory check 99 | has_memory, msg = check_memory_requirements(device) 100 | print(f"✅ Memory check: {msg}") 101 | 102 | # Test duration formatting 103 | duration = format_duration(125.5) 104 | print(f"✅ Duration formatting: 125.5s -> {duration}") 105 | 106 | # Test model validation (will fail without models, but that's expected) 107 | is_valid, msg = validate_model_files("./pretrained_models/") 108 | print(f"✅ Model validation: {msg}") 109 | 110 | return True 111 | 112 | except Exception as e: 113 | print(f"❌ Utils test failed: {e}") 114 | return False 115 | 116 | def test_requirements(): 117 | """Test if key requirements are available""" 118 | print("\nTesting requirements...") 119 | 120 | required_packages = [ 121 | 'torch', 122 | 'torchaudio', 123 | 'numpy', 124 | 'loguru', 125 | 'diffusers', 126 | 'transformers' 127 | ] 128 | 129 | missing = [] 130 | for package in required_packages: 131 | try: 132 | __import__(package) 133 | print(f"✅ {package}") 134 | except ImportError: 135 | print(f"❌ {package} - not installed") 136 | missing.append(package) 137 | 138 | if missing: 139 | print(f"\nMissing packages: {', '.join(missing)}") 140 | print("Run: pip install -r requirements.txt") 141 | return False 142 | 143 | return True 144 | 145 | def main(): 146 | """Main test function""" 147 | print("🧪 Testing ComfyUI HunyuanVideo-Foley Custom Node") 148 | print("=" * 50) 149 | 150 | tests = [ 151 | ("Requirements", test_requirements), 152 | ("Imports", test_imports), 153 | ("Node Structure", test_node_structure), 154 | ("Device Setup", test_device_setup), 155 | ("Utils", 
test_utils), 156 | ] 157 | 158 | passed = 0 159 | failed = 0 160 | 161 | for test_name, test_func in tests: 162 | print(f"\n🔍 Running test: {test_name}") 163 | try: 164 | if test_func(): 165 | passed += 1 166 | print(f"✅ {test_name} PASSED") 167 | else: 168 | failed += 1 169 | print(f"❌ {test_name} FAILED") 170 | except Exception as e: 171 | failed += 1 172 | print(f"❌ {test_name} FAILED with exception: {e}") 173 | 174 | print("\n" + "=" * 50) 175 | print(f"📊 Test Results: {passed} passed, {failed} failed") 176 | 177 | if failed == 0: 178 | print("🎉 All tests passed! The custom node is ready for use.") 179 | else: 180 | print("⚠️ Some tests failed. Please check the issues above.") 181 | 182 | return failed == 0 183 | 184 | if __name__ == "__main__": 185 | success = main() 186 | sys.exit(0 if success else 1) -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/feature_utils.py: -------------------------------------------------------------------------------- 1 | """Feature extraction utilities for video and text processing.""" 2 | 3 | import os 4 | import numpy as np 5 | import torch 6 | import av 7 | from PIL import Image 8 | from einops import rearrange 9 | from typing import Any, Dict, List, Union, Tuple 10 | from loguru import logger 11 | 12 | from .config_utils import AttributeDict 13 | from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS 14 | 15 | 16 | class FeatureExtractionError(Exception): 17 | """Exception raised for feature extraction errors.""" 18 | pass 19 | 20 | def get_frames_av( 21 | video_path: str, 22 | fps: float, 23 | max_length: float = None, 24 | ) -> Tuple[np.ndarray, float]: 25 | end_sec = max_length if max_length is not None else 15 26 | next_frame_time_for_each_fps = 0.0 27 | time_delta_for_each_fps = 1 / fps 28 | 29 | all_frames = [] 30 | output_frames = [] 31 | 32 | with av.open(video_path) as container: 33 | stream = container.streams.video[0] 34 | ori_fps = stream.guessed_rate 35 | stream.thread_type = "AUTO" 36 | for packet in container.demux(stream): 37 | for frame in packet.decode(): 38 | frame_time = frame.time 39 | if frame_time < 0: 40 | continue 41 | if frame_time > end_sec: 42 | break 43 | 44 | frame_np = None 45 | 46 | this_time = frame_time 47 | while this_time >= next_frame_time_for_each_fps: 48 | if frame_np is None: 49 | frame_np = frame.to_ndarray(format="rgb24") 50 | 51 | output_frames.append(frame_np) 52 | next_frame_time_for_each_fps += time_delta_for_each_fps 53 | 54 | output_frames = np.stack(output_frames) 55 | 56 | vid_len_in_s = len(output_frames) / fps 57 | if max_length is not None and len(output_frames) > int(max_length * fps): 58 | output_frames = output_frames[: int(max_length * fps)] 59 | vid_len_in_s = max_length 60 | 61 | return output_frames, vid_len_in_s 62 | 63 | @torch.inference_mode() 64 | def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1): 65 | b, t, c, h, w = x.shape 66 | if batch_size < 0: 67 | batch_size = b * t 68 | x = rearrange(x, "b t c h w -> (b t) c h w") 69 | outputs = [] 70 | for i in range(0, b * t, batch_size): 71 | pixel_values = x[i : i + batch_size] 72 | # --- Transformers Compatibility Fix --- 73 | if hasattr(model_dict.siglip2_model, 'get_image_features'): 74 | # Older transformers versions 75 | features = model_dict.siglip2_model.get_image_features(pixel_values=pixel_values) 76 | else: 77 | # Newer transformers versions 78 | features = model_dict.siglip2_model(pixel_values=pixel_values).image_embeds 79 | 
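# features: [chunk, D] pooled SigLIP2 image embeddings; the chunks are concatenated below and reshaped back to [B, T, D]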
outputs.append(features) 80 | # --- End of Fix --- 81 | res = torch.cat(outputs, dim=0) 82 | res = rearrange(res, "(b t) d -> b t d", b=b) 83 | return res 84 | 85 | @torch.inference_mode() 86 | def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1): 87 | """ 88 | The input video of x is best to be in fps of 24 of greater than 24. 89 | Input: 90 | x: tensor in shape of [B, T, C, H, W] 91 | batch_size: the batch_size for synchformer inference 92 | """ 93 | b, t, c, h, w = x.shape 94 | assert c == 3 and h == 224 and w == 224 95 | 96 | segment_size = 16 97 | step_size = 8 98 | num_segments = (t - segment_size) // step_size + 1 99 | segments = [] 100 | for i in range(num_segments): 101 | segments.append(x[:, i * step_size : i * step_size + segment_size]) 102 | x = torch.stack(segments, dim=1).cuda() # (B, num_segments, segment_size, 3, 224, 224) 103 | 104 | outputs = [] 105 | if batch_size < 0: 106 | batch_size = b * num_segments 107 | x = rearrange(x, "b s t c h w -> (b s) 1 t c h w") 108 | for i in range(0, b * num_segments, batch_size): 109 | with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half): 110 | outputs.append(model_dict.syncformer_model(x[i : i + batch_size])) 111 | x = torch.cat(outputs, dim=0) # [b * num_segments, 1, 8, 768] 112 | x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b) 113 | return x 114 | 115 | 116 | @torch.inference_mode() 117 | def encode_video_features(video_path, model_dict): 118 | visual_features = {} 119 | # siglip2 visual features 120 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["siglip2"]) 121 | images = [Image.fromarray(frame).convert('RGB') for frame in frames] 122 | images = [model_dict.siglip2_preprocess(image) for image in images] # [T, C, H, W] 123 | clip_frames = torch.stack(images).to(model_dict.device).unsqueeze(0) 124 | visual_features['siglip2_feat'] = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device) 125 | 126 | # synchformer visual features 127 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["synchformer"]) 128 | images = torch.from_numpy(frames).permute(0, 3, 1, 2) # [T, C, H, W] 129 | sync_frames = model_dict.syncformer_preprocess(images).unsqueeze(0) # [1, T, 3, 224, 224] 130 | # [1, num_segments * 8, channel_dim], e.g. 
[1, 240, 768] for 10s video 131 | visual_features['syncformer_feat'] = encode_video_with_sync(sync_frames, model_dict) 132 | 133 | vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"] 134 | visual_features = AttributeDict(visual_features) 135 | 136 | return visual_features, vid_len_in_s 137 | 138 | @torch.inference_mode() 139 | def encode_text_feat(text: List[str], model_dict): 140 | # x: (B, L) 141 | inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device) 142 | outputs = model_dict.clap_model(**inputs, output_hidden_states=True, return_dict=True) 143 | return outputs.last_hidden_state, outputs.attentions 144 | 145 | 146 | def feature_process(video_path, prompt, model_dict, cfg): 147 | visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict) 148 | neg_prompt = "noisy, harsh" 149 | prompts = [neg_prompt, prompt] 150 | text_feat_res, text_feat_mask = encode_text_feat(prompts, model_dict) 151 | 152 | text_feat = text_feat_res[1:] 153 | uncond_text_feat = text_feat_res[:1] 154 | 155 | if cfg.model_config.model_kwargs.text_length < text_feat.shape[1]: 156 | text_seq_length = cfg.model_config.model_kwargs.text_length 157 | text_feat = text_feat[:, :text_seq_length] 158 | uncond_text_feat = uncond_text_feat[:, :text_seq_length] 159 | 160 | text_feats = AttributeDict({ 161 | 'text_feat': text_feat, 162 | 'uncond_text_feat': uncond_text_feat, 163 | }) 164 | 165 | return visual_feats, text_feats, audio_len_in_s 166 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/posemb_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Tuple 3 | 4 | 5 | def _to_tuple(x, dim=2): 6 | if isinstance(x, int): 7 | return (x,) * dim 8 | elif len(x) == dim: 9 | return x 10 | else: 11 | raise ValueError(f"Expected length {dim} or int, but got {x}") 12 | 13 | 14 | def get_meshgrid_nd(start, *args, dim=2): 15 | """ 16 | Get n-D meshgrid with start, stop and num. 17 | 18 | Args: 19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, 20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num 21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in 22 | n-tuples. 23 | *args: See above. 24 | dim (int): Dimension of the meshgrid. Defaults to 2. 25 | 26 | Returns: 27 | grid (np.ndarray): [dim, ...] 
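Example (illustrative; a 2-D grid of size 4 x 6):
    >>> grid = get_meshgrid_nd((4, 6))
    >>> tuple(grid.shape)
    (2, 4, 6)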
28 | """ 29 | if len(args) == 0: 30 | # start is grid_size 31 | num = _to_tuple(start, dim=dim) 32 | start = (0,) * dim 33 | stop = num 34 | elif len(args) == 1: 35 | # start is start, args[0] is stop, step is 1 36 | start = _to_tuple(start, dim=dim) 37 | stop = _to_tuple(args[0], dim=dim) 38 | num = [stop[i] - start[i] for i in range(dim)] 39 | elif len(args) == 2: 40 | # start is start, args[0] is stop, args[1] is num 41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 44 | else: 45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") 46 | 47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) 48 | axis_grid = [] 49 | for i in range(dim): 50 | a, b, n = start[i], stop[i], num[i] 51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] 52 | axis_grid.append(g) 53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] 54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D] 55 | 56 | return grid 57 | 58 | 59 | ################################################################################# 60 | # Rotary Positional Embedding Functions # 61 | ################################################################################# 62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 63 | 64 | 65 | def get_nd_rotary_pos_embed( 66 | rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0 67 | ): 68 | """ 69 | This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. 70 | 71 | Args: 72 | rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. 73 | sum(rope_dim_list) should equal to head_dim of attention layer. 74 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, 75 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. 76 | *args: See above. 77 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0. 78 | use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. 79 | Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real 80 | part and an imaginary part separately. 81 | theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. 82 | freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. 
83 | 84 | Returns: 85 | pos_embed (torch.Tensor): [HW, D/2] 86 | """ 87 | 88 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] 89 | 90 | # use 1/ndim of dimensions to encode grid_axis 91 | embs = [] 92 | for i in range(len(rope_dim_list)): 93 | emb = get_1d_rotary_pos_embed( 94 | rope_dim_list[i], 95 | grid[i].reshape(-1), 96 | theta, 97 | use_real=use_real, 98 | theta_rescale_factor=theta_rescale_factor, 99 | freq_scaling=freq_scaling, 100 | ) # 2 x [WHD, rope_dim_list[i]] 101 | embs.append(emb) 102 | 103 | if use_real: 104 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) 105 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) 106 | return cos, sin 107 | else: 108 | emb = torch.cat(embs, dim=1) # (WHD, D/2) 109 | return emb 110 | 111 | 112 | def get_1d_rotary_pos_embed( 113 | dim: int, 114 | pos: Union[torch.FloatTensor, int], 115 | theta: float = 10000.0, 116 | use_real: bool = False, 117 | theta_rescale_factor: float = 1.0, 118 | freq_scaling: float = 1.0, 119 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 120 | """ 121 | Precompute the frequency tensor for complex exponential (cis) with given dimensions. 122 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) 123 | 124 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim' 125 | and the end index 'end'. The 'theta' parameter scales the frequencies. 126 | The returned tensor contains complex values in complex64 data type. 127 | 128 | Args: 129 | dim (int): Dimension of the frequency tensor. 130 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar 131 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 132 | use_real (bool, optional): If True, return real part and imaginary part separately. 133 | Otherwise, return complex numbers. 134 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. 135 | freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. 136 | 137 | Returns: 138 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] 139 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] 140 | """ 141 | if isinstance(pos, int): 142 | pos = torch.arange(pos).float() 143 | 144 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 145 | # has some connection to NTK literature 146 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ 147 | if theta_rescale_factor != 1.0: 148 | theta *= theta_rescale_factor ** (dim / (dim - 1)) 149 | 150 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] 151 | freqs *= freq_scaling 152 | freqs = torch.outer(pos, freqs) # [S, D/2] 153 | if use_real: 154 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] 155 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] 156 | return freqs_cos, freqs_sin 157 | else: 158 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] 159 | return freqs_cis 160 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from audiotools import AudioSignal 5 | from audiotools import ml 6 | from audiotools import STFTParams 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | 11 | def WNConv1d(*args, **kwargs): 12 | act = kwargs.pop("act", True) 13 | conv = weight_norm(nn.Conv1d(*args, **kwargs)) 14 | if not act: 15 | return conv 16 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 17 | 18 | 19 | def WNConv2d(*args, **kwargs): 20 | act = kwargs.pop("act", True) 21 | conv = weight_norm(nn.Conv2d(*args, **kwargs)) 22 | if not act: 23 | return conv 24 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 25 | 26 | 27 | class MPD(nn.Module): 28 | def __init__(self, period): 29 | super().__init__() 30 | self.period = period 31 | self.convs = nn.ModuleList( 32 | [ 33 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), 34 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), 35 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), 36 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), 37 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), 38 | ] 39 | ) 40 | self.conv_post = WNConv2d( 41 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False 42 | ) 43 | 44 | def pad_to_period(self, x): 45 | t = x.shape[-1] 46 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect") 47 | return x 48 | 49 | def forward(self, x): 50 | fmap = [] 51 | 52 | x = self.pad_to_period(x) 53 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period) 54 | 55 | for layer in self.convs: 56 | x = layer(x) 57 | fmap.append(x) 58 | 59 | x = self.conv_post(x) 60 | fmap.append(x) 61 | 62 | return fmap 63 | 64 | 65 | class MSD(nn.Module): 66 | def __init__(self, rate: int = 1, sample_rate: int = 44100): 67 | super().__init__() 68 | self.convs = nn.ModuleList( 69 | [ 70 | WNConv1d(1, 16, 15, 1, padding=7), 71 | WNConv1d(16, 64, 41, 4, groups=4, padding=20), 72 | WNConv1d(64, 256, 41, 4, groups=16, padding=20), 73 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20), 74 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), 75 | WNConv1d(1024, 1024, 5, 1, padding=2), 76 | ] 77 | ) 78 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) 79 | self.sample_rate = sample_rate 80 | self.rate = rate 81 | 82 | def forward(self, x): 83 | x = AudioSignal(x, self.sample_rate) 84 | 
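# the code below relies on resample(...) adjusting the signal in place; downsampling by self.rate lets each MSD instance inspect the waveform at a different time scale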
x.resample(self.sample_rate // self.rate) 85 | x = x.audio_data 86 | 87 | fmap = [] 88 | 89 | for l in self.convs: 90 | x = l(x) 91 | fmap.append(x) 92 | x = self.conv_post(x) 93 | fmap.append(x) 94 | 95 | return fmap 96 | 97 | 98 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] 99 | 100 | 101 | class MRD(nn.Module): 102 | def __init__( 103 | self, 104 | window_length: int, 105 | hop_factor: float = 0.25, 106 | sample_rate: int = 44100, 107 | bands: list = BANDS, 108 | ): 109 | """Complex multi-band spectrogram discriminator. 110 | Parameters 111 | ---------- 112 | window_length : int 113 | Window length of STFT. 114 | hop_factor : float, optional 115 | Hop factor of the STFT, defaults to ``0.25 * window_length``. 116 | sample_rate : int, optional 117 | Sampling rate of audio in Hz, by default 44100 118 | bands : list, optional 119 | Bands to run discriminator over. 120 | """ 121 | super().__init__() 122 | 123 | self.window_length = window_length 124 | self.hop_factor = hop_factor 125 | self.sample_rate = sample_rate 126 | self.stft_params = STFTParams( 127 | window_length=window_length, 128 | hop_length=int(window_length * hop_factor), 129 | match_stride=True, 130 | ) 131 | 132 | n_fft = window_length // 2 + 1 133 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 134 | self.bands = bands 135 | 136 | ch = 32 137 | convs = lambda: nn.ModuleList( 138 | [ 139 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), 140 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 141 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 142 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 143 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), 144 | ] 145 | ) 146 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 147 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) 148 | 149 | def spectrogram(self, x): 150 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) 151 | x = torch.view_as_real(x.stft()) 152 | x = rearrange(x, "b 1 f t c -> (b 1) c t f") 153 | # Split into bands 154 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 155 | return x_bands 156 | 157 | def forward(self, x): 158 | x_bands = self.spectrogram(x) 159 | fmap = [] 160 | 161 | x = [] 162 | for band, stack in zip(x_bands, self.band_convs): 163 | for layer in stack: 164 | band = layer(band) 165 | fmap.append(band) 166 | x.append(band) 167 | 168 | x = torch.cat(x, dim=-1) 169 | x = self.conv_post(x) 170 | fmap.append(x) 171 | 172 | return fmap 173 | 174 | 175 | class Discriminator(ml.BaseModel): 176 | def __init__( 177 | self, 178 | rates: list = [], 179 | periods: list = [2, 3, 5, 7, 11], 180 | fft_sizes: list = [2048, 1024, 512], 181 | sample_rate: int = 44100, 182 | bands: list = BANDS, 183 | ): 184 | """Discriminator that combines multiple discriminators. 185 | 186 | Parameters 187 | ---------- 188 | rates : list, optional 189 | sampling rates (in Hz) to run MSD at, by default [] 190 | If empty, MSD is not used. 
191 | periods : list, optional
192 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193 | fft_sizes : list, optional
194 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195 | sample_rate : int, optional
196 | Sampling rate of audio in Hz, by default 44100
197 | bands : list, optional
198 | Bands to run MRD at, by default `BANDS`
199 | """
200 | super().__init__()
201 | discs = []
202 | discs += [MPD(p) for p in periods]
203 | discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205 | self.discriminators = nn.ModuleList(discs)
206 | 
207 | def preprocess(self, y):
208 | # Remove DC offset
209 | y = y - y.mean(dim=-1, keepdims=True)
210 | # Peak normalize the volume of input audio
211 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212 | return y
213 | 
214 | def forward(self, x):
215 | x = self.preprocess(x)
216 | fmaps = [d(x) for d in self.discriminators]
217 | return fmaps
218 | 
219 | 
220 | if __name__ == "__main__":
221 | disc = Discriminator()
222 | x = torch.zeros(1, 1, 44100)
223 | results = disc(x)
224 | for i, result in enumerate(results):
225 | print(f"disc{i}")
226 | for i, r in enumerate(result):
227 | print(r.shape, r.mean(), r.min(), r.max())
228 | print()
229 | 
-------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/compute_desync_score.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import subprocess
3 | from pathlib import Path
4 | import os
5 | import torch
6 | import torchaudio
7 | import torchvision
8 | from omegaconf import OmegaConf
9 | 
10 | from . import data_transforms
11 | from .synchformer import Synchformer
12 | from .data_transforms import make_class_grid, quantize_offset
13 | from .utils import check_if_file_exists_else_download, which_ffmpeg
14 | 
15 | 
16 | def prepare_inputs(batch, device):
17 | aud = batch["audio"].to(device)
18 | vid = batch["video"].to(device)
19 | 
20 | return aud, vid
21 | 
22 | 
23 | def get_test_transforms():
24 | ts = [
25 | data_transforms.EqualifyFromRight(),
26 | data_transforms.RGBSpatialCrop(input_size=224, is_random=False),
27 | data_transforms.TemporalCropAndOffset(
28 | crop_len_sec=5,
29 | max_off_sec=2, # https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
30 | max_wiggle_sec=0.0,
31 | do_offset=True,
32 | offset_type="grid",
33 | prob_oos="null",
34 | grid_size=21,
35 | segment_size_vframes=16,
36 | n_segments=14,
37 | step_size_seg=0.5,
38 | vfps=25,
39 | ),
40 | data_transforms.GenerateMultipleSegments(
41 | segment_size_vframes=16,
42 | n_segments=14,
43 | is_start_random=False,
44 | step_size_seg=0.5,
45 | ),
46 | data_transforms.RGBToHalfToZeroOne(),
47 | data_transforms.RGBNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # motionformer normalization
48 | data_transforms.AudioMelSpectrogram(
49 | sample_rate=16000,
50 | win_length=400, # 25 ms * 16 kHz
51 | hop_length=160, # 10 ms * 16 kHz
52 | n_fft=1024, # 2^(ceil(log2(window_size * sampling_rate)))
53 | n_mels=128, # as in AST
54 | ),
55 | data_transforms.AudioLog(),
56 | data_transforms.PadOrTruncate(max_spec_t=66),
57 | data_transforms.AudioNormalizeAST(mean=-4.2677393, std=4.5689974), # AST, pre-trained on AudioSet
58 | data_transforms.PermuteStreams(
59 | einops_order_audio="S F T -> S 1 F T", einops_order_rgb="S T C H W -> S T C H W" # same
60 | ),
61 | ] 62 | transforms = torchvision.transforms.Compose(ts) 63 | 64 | return transforms 65 | 66 | 67 | def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None): 68 | orig_path = path 69 | # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta) 70 | rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, "sec", output_format="TCHW") 71 | assert meta["video_fps"], f"No video fps for {orig_path}" 72 | # (Ta) <- (Ca, Ta) 73 | audio = audio.mean(dim=0) 74 | # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader. 75 | meta = { 76 | "video": {"fps": [meta["video_fps"]]}, 77 | "audio": {"framerate": [meta["audio_fps"]]}, 78 | } 79 | return rgb, audio, meta 80 | 81 | 82 | def reencode_video(path, vfps=25, afps=16000, in_size=256): 83 | assert which_ffmpeg() != "", "Is ffmpeg installed? Check if the conda environment is activated." 84 | new_path = Path.cwd() / "vis" / f"{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4" 85 | new_path.parent.mkdir(exist_ok=True) 86 | new_path = str(new_path) 87 | cmd = f"{which_ffmpeg()}" 88 | # no info/error printing 89 | cmd += " -hide_banner -loglevel panic" 90 | cmd += f" -y -i {path}" 91 | # 1) change fps, 2) resize: min(H,W)=MIN_SIDE (vertical vids are supported), 3) change audio framerate 92 | cmd += f" -vf fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2" 93 | cmd += f" -ar {afps}" 94 | cmd += f" {new_path}" 95 | subprocess.call(cmd.split()) 96 | cmd = f"{which_ffmpeg()}" 97 | cmd += " -hide_banner -loglevel panic" 98 | cmd += f" -y -i {new_path}" 99 | cmd += f" -acodec pcm_s16le -ac 1" 100 | cmd += f' {new_path.replace(".mp4", ".wav")}' 101 | subprocess.call(cmd.split()) 102 | return new_path 103 | 104 | 105 | def decode_single_video_prediction(off_logits, grid, item): 106 | label = item["targets"]["offset_label"].item() 107 | print("Ground Truth offset (sec):", f"{label:.2f} ({quantize_offset(grid, label)[-1].item()})") 108 | print() 109 | print("Prediction Results:") 110 | off_probs = torch.softmax(off_logits, dim=-1) 111 | k = min(off_probs.shape[-1], 5) 112 | topk_logits, topk_preds = torch.topk(off_logits, k) 113 | # remove batch dimension 114 | assert len(topk_logits) == 1, "batch is larger than 1" 115 | topk_logits = topk_logits[0] 116 | topk_preds = topk_preds[0] 117 | off_logits = off_logits[0] 118 | off_probs = off_probs[0] 119 | for target_hat in topk_preds: 120 | print(f'p={off_probs[target_hat]:.4f} ({off_logits[target_hat]:.4f}), "{grid[target_hat]:.2f}" ({target_hat})') 121 | return off_probs 122 | 123 | 124 | def main(args): 125 | vfps = 25 126 | afps = 16000 127 | in_size = 256 128 | # making the offset class grid similar to the one used in transforms, 129 | # refer to the used one: https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml 130 | max_off_sec = 2 131 | num_cls = 21 132 | 133 | # checking if the provided video has the correct frame rates 134 | print(f"Using video: {args.vid_path}") 135 | v, _, info = torchvision.io.read_video(args.vid_path, pts_unit="sec") 136 | _, H, W, _ = v.shape 137 | if info["video_fps"] != vfps or info["audio_fps"] != afps or min(H, W) != in_size: 138 | print(f'Reencoding. 
vfps: {info["video_fps"]} -> {vfps};', end=" ") 139 | print(f'afps: {info["audio_fps"]} -> {afps};', end=" ") 140 | print(f"{(H, W)} -> min(H, W)={in_size}") 141 | args.vid_path = reencode_video(args.vid_path, vfps, afps, in_size) 142 | else: 143 | print(f'Skipping reencoding. vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}') 144 | 145 | device = torch.device(args.device) 146 | 147 | # load visual and audio streams 148 | # rgb: (Tv, 3, H, W) in [0, 225], audio: (Ta,) in [-1, 1] 149 | rgb, audio, meta = get_video_and_audio(args.vid_path, get_meta=True) 150 | 151 | # making an item (dict) to apply transformations 152 | # NOTE: here is how it works: 153 | # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3` 154 | # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio 155 | # track by `args.offset_sec` seconds. It means that if `offset_sec` > 0, the audio will 156 | # start by `offset_sec` earlier than the rgb track. 157 | # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`) 158 | item = dict( 159 | video=rgb, 160 | audio=audio, 161 | meta=meta, 162 | path=args.vid_path, 163 | split="test", 164 | targets={ 165 | "v_start_i_sec": args.v_start_i_sec, 166 | "offset_sec": args.offset_sec, 167 | }, 168 | ) 169 | 170 | grid = make_class_grid(-max_off_sec, max_off_sec, num_cls) 171 | if not (min(grid) <= item["targets"]["offset_sec"] <= max(grid)): 172 | print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}') 173 | 174 | # applying the test-time transform 175 | item = get_test_transforms()(item) 176 | 177 | # prepare inputs for inference 178 | batch = torch.utils.data.default_collate([item]) 179 | aud, vid = prepare_inputs(batch, device) 180 | 181 | # TODO: 182 | # sanity check: we will take the input to the `model` and recontruct make a video from it. 
183 | # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified) 184 | # reconstruct_video_from_input(aud, vid, batch['meta'], args.vid_path, args.v_start_i_sec, args.offset_sec, 185 | # vfps, afps) 186 | 187 | # forward pass 188 | with torch.set_grad_enabled(False): 189 | with torch.autocast("cuda", enabled=True): 190 | _, logits = synchformer(vid, aud) 191 | 192 | # simply prints the results of the prediction 193 | decode_single_video_prediction(logits, grid, item) 194 | 195 | 196 | if __name__ == "__main__": 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument("--exp_name", required=True, help="In a format: xx-xx-xxTxx-xx-xx") 199 | parser.add_argument("--vid_path", required=True, help="A path to .mp4 video") 200 | parser.add_argument("--offset_sec", type=float, default=0.0) 201 | parser.add_argument("--v_start_i_sec", type=float, default=0.0) 202 | parser.add_argument("--device", default="cuda:0") 203 | args = parser.parse_args() 204 | 205 | synchformer = Synchformer().cuda().eval() 206 | synchformer.load_state_dict( 207 | torch.load( 208 | os.environ.get("SYNCHFORMER_WEIGHTS", f"weights/synchformer.pth"), 209 | weights_only=True, 210 | map_location="cpu", 211 | ) 212 | ) 213 | 214 | main(args) 215 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/quantize.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | from .layers import WNConv1d 11 | 12 | 13 | class VectorQuantize(nn.Module): 14 | """ 15 | Implementation of VQ similar to Karpathy's repo: 16 | https://github.com/karpathy/deep-vector-quantization 17 | Additionally uses following tricks from Improved VQGAN 18 | (https://arxiv.org/pdf/2110.04627.pdf): 19 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space 20 | for improved codebook usage 21 | 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which 22 | improves training stability 23 | """ 24 | 25 | def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): 26 | super().__init__() 27 | self.codebook_size = codebook_size 28 | self.codebook_dim = codebook_dim 29 | 30 | self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) 31 | self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) 32 | self.codebook = nn.Embedding(codebook_size, codebook_dim) 33 | 34 | def forward(self, z): 35 | """Quantized the input tensor using a fixed codebook and returns 36 | the corresponding codebook vectors 37 | 38 | Parameters 39 | ---------- 40 | z : Tensor[B x D x T] 41 | 42 | Returns 43 | ------- 44 | Tensor[B x D x T] 45 | Quantized continuous representation of input 46 | Tensor[1] 47 | Commitment loss to train encoder to predict vectors closer to codebook 48 | entries 49 | Tensor[1] 50 | Codebook loss to update the codebook 51 | Tensor[B x T] 52 | Codebook indices (quantized discrete representation of input) 53 | Tensor[B x D x T] 54 | Projected latents (continuous representation of input before quantization) 55 | """ 56 | 57 | # Factorized codes (ViT-VQGAN) Project input into low-dimensional space 58 | z_e = self.in_proj(z) # z_e : (B x D x T) 59 | z_q, indices = self.decode_latents(z_e) 60 | 61 | commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) 62 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) 63 | 64 | z_q = ( 65 | z_e + (z_q - z_e).detach() 66 | ) # noop in forward pass, straight-through gradient estimator in backward pass 67 | 68 | z_q = self.out_proj(z_q) 69 | 70 | return z_q, commitment_loss, codebook_loss, indices, z_e 71 | 72 | def embed_code(self, embed_id): 73 | return F.embedding(embed_id, self.codebook.weight) 74 | 75 | def decode_code(self, embed_id): 76 | return self.embed_code(embed_id).transpose(1, 2) 77 | 78 | def decode_latents(self, latents): 79 | encodings = rearrange(latents, "b d t -> (b t) d") 80 | codebook = self.codebook.weight # codebook: (N x D) 81 | 82 | # L2 normalize encodings and codebook (ViT-VQGAN) 83 | encodings = F.normalize(encodings) 84 | codebook = F.normalize(codebook) 85 | 86 | # Compute euclidean distance with codebook 87 | dist = ( 88 | encodings.pow(2).sum(1, keepdim=True) 89 | - 2 * encodings @ codebook.t() 90 | + codebook.pow(2).sum(1, keepdim=True).t() 91 | ) 92 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) 93 | z_q = self.decode_code(indices) 94 | return z_q, indices 95 | 96 | 97 | class ResidualVectorQuantize(nn.Module): 98 | """ 99 | Introduced in SoundStream: An end2end neural audio codec 100 | https://arxiv.org/abs/2107.03312 101 | """ 102 | 103 | def __init__( 104 | self, 105 | input_dim: int = 512, 106 | n_codebooks: int = 9, 107 | codebook_size: int = 1024, 108 | codebook_dim: Union[int, list] = 8, 109 | quantizer_dropout: float = 0.0, 110 | ): 111 | super().__init__() 112 | if isinstance(codebook_dim, int): 113 | codebook_dim = [codebook_dim for _ in range(n_codebooks)] 114 | 115 | self.n_codebooks = n_codebooks 116 | self.codebook_dim = codebook_dim 117 | self.codebook_size = codebook_size 118 | 119 | self.quantizers = nn.ModuleList( 120 | [ 121 | VectorQuantize(input_dim, codebook_size, codebook_dim[i]) 122 | for i in range(n_codebooks) 123 | ] 124 | ) 125 | self.quantizer_dropout = quantizer_dropout 126 | 127 | def forward(self, z, n_quantizers: int = None): 128 | """Quantized the input tensor 
using a fixed set of `n` codebooks and returns 129 | the corresponding codebook vectors 130 | Parameters 131 | ---------- 132 | z : Tensor[B x D x T] 133 | n_quantizers : int, optional 134 | No. of quantizers to use 135 | (n_quantizers < self.n_codebooks ex: for quantizer dropout) 136 | Note: if `self.quantizer_dropout` is True, this argument is ignored 137 | when in training mode, and a random number of quantizers is used. 138 | Returns 139 | ------- 140 | dict 141 | A dictionary with the following keys: 142 | 143 | "z" : Tensor[B x D x T] 144 | Quantized continuous representation of input 145 | "codes" : Tensor[B x N x T] 146 | Codebook indices for each codebook 147 | (quantized discrete representation of input) 148 | "latents" : Tensor[B x N*D x T] 149 | Projected latents (continuous representation of input before quantization) 150 | "vq/commitment_loss" : Tensor[1] 151 | Commitment loss to train encoder to predict vectors closer to codebook 152 | entries 153 | "vq/codebook_loss" : Tensor[1] 154 | Codebook loss to update the codebook 155 | """ 156 | z_q = 0 157 | residual = z 158 | commitment_loss = 0 159 | codebook_loss = 0 160 | 161 | codebook_indices = [] 162 | latents = [] 163 | 164 | if n_quantizers is None: 165 | n_quantizers = self.n_codebooks 166 | if self.training: 167 | n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 168 | dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) 169 | n_dropout = int(z.shape[0] * self.quantizer_dropout) 170 | n_quantizers[:n_dropout] = dropout[:n_dropout] 171 | n_quantizers = n_quantizers.to(z.device) 172 | 173 | for i, quantizer in enumerate(self.quantizers): 174 | if self.training is False and i >= n_quantizers: 175 | break 176 | 177 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( 178 | residual 179 | ) 180 | 181 | # Create mask to apply quantizer dropout 182 | mask = ( 183 | torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers 184 | ) 185 | z_q = z_q + z_q_i * mask[:, None, None] 186 | residual = residual - z_q_i 187 | 188 | # Sum losses 189 | commitment_loss += (commitment_loss_i * mask).mean() 190 | codebook_loss += (codebook_loss_i * mask).mean() 191 | 192 | codebook_indices.append(indices_i) 193 | latents.append(z_e_i) 194 | 195 | codes = torch.stack(codebook_indices, dim=1) 196 | latents = torch.cat(latents, dim=1) 197 | 198 | return z_q, codes, latents, commitment_loss, codebook_loss 199 | 200 | def from_codes(self, codes: torch.Tensor): 201 | """Given the quantized codes, reconstruct the continuous representation 202 | Parameters 203 | ---------- 204 | codes : Tensor[B x N x T] 205 | Quantized discrete representation of input 206 | Returns 207 | ------- 208 | Tensor[B x D x T] 209 | Quantized continuous representation of input 210 | """ 211 | z_q = 0.0 212 | z_p = [] 213 | n_codebooks = codes.shape[1] 214 | for i in range(n_codebooks): 215 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) 216 | z_p.append(z_p_i) 217 | 218 | z_q_i = self.quantizers[i].out_proj(z_p_i) 219 | z_q = z_q + z_q_i 220 | return z_q, torch.cat(z_p, dim=1), codes 221 | 222 | def from_latents(self, latents: torch.Tensor): 223 | """Given the unquantized latents, reconstruct the 224 | continuous representation after quantization. 
225 | 226 | Parameters 227 | ---------- 228 | latents : Tensor[B x N x T] 229 | Continuous representation of input after projection 230 | 231 | Returns 232 | ------- 233 | Tensor[B x D x T] 234 | Quantized representation of full-projected space 235 | Tensor[B x D x T] 236 | Quantized representation of latent space 237 | """ 238 | z_q = 0 239 | z_p = [] 240 | codes = [] 241 | dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) 242 | 243 | n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ 244 | 0 245 | ] 246 | for i in range(n_codebooks): 247 | j, k = dims[i], dims[i + 1] 248 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) 249 | z_p.append(z_p_i) 250 | codes.append(codes_i) 251 | 252 | z_q_i = self.quantizers[i].out_proj(z_p_i) 253 | z_q = z_q + z_q_i 254 | 255 | return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) 256 | 257 | 258 | if __name__ == "__main__": 259 | rvq = ResidualVectorQuantize(quantizer_dropout=True) 260 | x = torch.randn(16, 512, 80) 261 | y = rvq(x) 262 | print(y["latents"].shape) 263 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | from audiotools import AudioSignal 10 | from torch import nn 11 | 12 | SUPPORTED_VERSIONS = ["1.0.0"] 13 | 14 | 15 | @dataclass 16 | class DACFile: 17 | codes: torch.Tensor 18 | 19 | # Metadata 20 | chunk_length: int 21 | original_length: int 22 | input_db: float 23 | channels: int 24 | sample_rate: int 25 | padding: bool 26 | dac_version: str 27 | 28 | def save(self, path): 29 | artifacts = { 30 | "codes": self.codes.numpy().astype(np.uint16), 31 | "metadata": { 32 | "input_db": self.input_db.numpy().astype(np.float32), 33 | "original_length": self.original_length, 34 | "sample_rate": self.sample_rate, 35 | "chunk_length": self.chunk_length, 36 | "channels": self.channels, 37 | "padding": self.padding, 38 | "dac_version": SUPPORTED_VERSIONS[-1], 39 | }, 40 | } 41 | path = Path(path).with_suffix(".dac") 42 | with open(path, "wb") as f: 43 | np.save(f, artifacts) 44 | return path 45 | 46 | @classmethod 47 | def load(cls, path): 48 | artifacts = np.load(path, allow_pickle=True)[()] 49 | codes = torch.from_numpy(artifacts["codes"].astype(int)) 50 | if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: 51 | raise RuntimeError( 52 | f"Given file {path} can't be loaded with this version of descript-audio-codec." 
53 | ) 54 | return cls(codes=codes, **artifacts["metadata"]) 55 | 56 | 57 | class CodecMixin: 58 | @property 59 | def padding(self): 60 | if not hasattr(self, "_padding"): 61 | self._padding = True 62 | return self._padding 63 | 64 | @padding.setter 65 | def padding(self, value): 66 | assert isinstance(value, bool) 67 | 68 | layers = [ 69 | l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) 70 | ] 71 | 72 | for layer in layers: 73 | if value: 74 | if hasattr(layer, "original_padding"): 75 | layer.padding = layer.original_padding 76 | else: 77 | layer.original_padding = layer.padding 78 | layer.padding = tuple(0 for _ in range(len(layer.padding))) 79 | 80 | self._padding = value 81 | 82 | def get_delay(self): 83 | # Any number works here, delay is invariant to input length 84 | l_out = self.get_output_length(0) 85 | L = l_out 86 | 87 | layers = [] 88 | for layer in self.modules(): 89 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): 90 | layers.append(layer) 91 | 92 | for layer in reversed(layers): 93 | d = layer.dilation[0] 94 | k = layer.kernel_size[0] 95 | s = layer.stride[0] 96 | 97 | if isinstance(layer, nn.ConvTranspose1d): 98 | L = ((L - d * (k - 1) - 1) / s) + 1 99 | elif isinstance(layer, nn.Conv1d): 100 | L = (L - 1) * s + d * (k - 1) + 1 101 | 102 | L = math.ceil(L) 103 | 104 | l_in = L 105 | 106 | return (l_in - l_out) // 2 107 | 108 | def get_output_length(self, input_length): 109 | L = input_length 110 | # Calculate output length 111 | for layer in self.modules(): 112 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): 113 | d = layer.dilation[0] 114 | k = layer.kernel_size[0] 115 | s = layer.stride[0] 116 | 117 | if isinstance(layer, nn.Conv1d): 118 | L = ((L - d * (k - 1) - 1) / s) + 1 119 | elif isinstance(layer, nn.ConvTranspose1d): 120 | L = (L - 1) * s + d * (k - 1) + 1 121 | 122 | L = math.floor(L) 123 | return L 124 | 125 | @torch.no_grad() 126 | def compress( 127 | self, 128 | audio_path_or_signal: Union[str, Path, AudioSignal], 129 | win_duration: float = 1.0, 130 | verbose: bool = False, 131 | normalize_db: float = -16, 132 | n_quantizers: int = None, 133 | ) -> DACFile: 134 | """Processes an audio signal from a file or AudioSignal object into 135 | discrete codes. This function processes the signal in short windows, 136 | using constant GPU memory. 
137 | 138 | Parameters 139 | ---------- 140 | audio_path_or_signal : Union[str, Path, AudioSignal] 141 | audio signal to reconstruct 142 | win_duration : float, optional 143 | window duration in seconds, by default 5.0 144 | verbose : bool, optional 145 | by default False 146 | normalize_db : float, optional 147 | normalize db, by default -16 148 | 149 | Returns 150 | ------- 151 | DACFile 152 | Object containing compressed codes and metadata 153 | required for decompression 154 | """ 155 | audio_signal = audio_path_or_signal 156 | if isinstance(audio_signal, (str, Path)): 157 | audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) 158 | 159 | self.eval() 160 | original_padding = self.padding 161 | original_device = audio_signal.device 162 | 163 | audio_signal = audio_signal.clone() 164 | audio_signal = audio_signal.to_mono() 165 | original_sr = audio_signal.sample_rate 166 | 167 | resample_fn = audio_signal.resample 168 | loudness_fn = audio_signal.loudness 169 | 170 | # If audio is > 10 minutes long, use the ffmpeg versions 171 | if audio_signal.signal_duration >= 10 * 60 * 60: 172 | resample_fn = audio_signal.ffmpeg_resample 173 | loudness_fn = audio_signal.ffmpeg_loudness 174 | 175 | original_length = audio_signal.signal_length 176 | resample_fn(self.sample_rate) 177 | input_db = loudness_fn() 178 | 179 | if normalize_db is not None: 180 | audio_signal.normalize(normalize_db) 181 | audio_signal.ensure_max_of_audio() 182 | 183 | nb, nac, nt = audio_signal.audio_data.shape 184 | audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) 185 | win_duration = ( 186 | audio_signal.signal_duration if win_duration is None else win_duration 187 | ) 188 | 189 | if audio_signal.signal_duration <= win_duration: 190 | # Unchunked compression (used if signal length < win duration) 191 | self.padding = True 192 | n_samples = nt 193 | hop = nt 194 | else: 195 | # Chunked inference 196 | self.padding = False 197 | # Zero-pad signal on either side by the delay 198 | audio_signal.zero_pad(self.delay, self.delay) 199 | n_samples = int(win_duration * self.sample_rate) 200 | # Round n_samples to nearest hop length multiple 201 | n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) 202 | hop = self.get_output_length(n_samples) 203 | 204 | codes = [] 205 | range_fn = range if not verbose else tqdm.trange 206 | 207 | for i in range_fn(0, nt, hop): 208 | x = audio_signal[..., i : i + n_samples] 209 | x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) 210 | 211 | audio_data = x.audio_data.to(self.device) 212 | audio_data = self.preprocess(audio_data, self.sample_rate) 213 | _, c, _, _, _ = self.encode(audio_data, n_quantizers) 214 | codes.append(c.to(original_device)) 215 | chunk_length = c.shape[-1] 216 | 217 | codes = torch.cat(codes, dim=-1) 218 | 219 | dac_file = DACFile( 220 | codes=codes, 221 | chunk_length=chunk_length, 222 | original_length=original_length, 223 | input_db=input_db, 224 | channels=nac, 225 | sample_rate=original_sr, 226 | padding=self.padding, 227 | dac_version=SUPPORTED_VERSIONS[-1], 228 | ) 229 | 230 | if n_quantizers is not None: 231 | codes = codes[:, :n_quantizers, :] 232 | 233 | self.padding = original_padding 234 | return dac_file 235 | 236 | @torch.no_grad() 237 | def decompress( 238 | self, 239 | obj: Union[str, Path, DACFile], 240 | verbose: bool = False, 241 | ) -> AudioSignal: 242 | """Reconstruct audio from a given .dac file 243 | 244 | Parameters 245 | ---------- 246 | obj : Union[str, Path, DACFile] 247 | .dac file 
location or corresponding DACFile object. 248 | verbose : bool, optional 249 | Prints progress if True, by default False 250 | 251 | Returns 252 | ------- 253 | AudioSignal 254 | Object with the reconstructed audio 255 | """ 256 | self.eval() 257 | if isinstance(obj, (str, Path)): 258 | obj = DACFile.load(obj) 259 | 260 | original_padding = self.padding 261 | self.padding = obj.padding 262 | 263 | range_fn = range if not verbose else tqdm.trange 264 | codes = obj.codes 265 | original_device = codes.device 266 | chunk_length = obj.chunk_length 267 | recons = [] 268 | 269 | for i in range_fn(0, codes.shape[-1], chunk_length): 270 | c = codes[..., i : i + chunk_length].to(self.device) 271 | z = self.quantizer.from_codes(c)[0] 272 | r = self.decode(z) 273 | recons.append(r.to(original_device)) 274 | 275 | recons = torch.cat(recons, dim=-1) 276 | recons = AudioSignal(recons, self.sample_rate) 277 | 278 | resample_fn = recons.resample 279 | loudness_fn = recons.loudness 280 | 281 | # If audio is > 10 minutes long, use the ffmpeg versions 282 | if recons.signal_duration >= 10 * 60 * 60: 283 | resample_fn = recons.ffmpeg_resample 284 | loudness_fn = recons.ffmpeg_loudness 285 | 286 | if obj.input_db is not None: 287 | recons.normalize(obj.input_db) 288 | 289 | resample_fn(obj.sample_rate) 290 | 291 | if obj.original_length is not None: 292 | recons = recons[..., : obj.original_length] 293 | loudness_fn() 294 | recons.audio_data = recons.audio_data.reshape( 295 | -1, obj.channels, obj.original_length 296 | ) 297 | else: 298 | loudness_fn() 299 | 300 | self.padding = original_padding 301 | return recons 302 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/video_model_builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
3 | # Copyright 2020 Ross Wightman 4 | # Modified Model definition 5 | 6 | from collections import OrderedDict 7 | from functools import partial 8 | 9 | import torch 10 | import torch.nn as nn 11 | from timm.layers import trunc_normal_ 12 | 13 | from .vit_helper import PatchEmbed, PatchEmbed3D, DividedSpaceTimeBlock 14 | 15 | 16 | class VisionTransformer(nn.Module): 17 | """Vision Transformer with support for patch or hybrid CNN input stage""" 18 | 19 | def __init__(self, cfg): 20 | super().__init__() 21 | self.img_size = cfg.DATA.TRAIN_CROP_SIZE 22 | self.patch_size = cfg.VIT.PATCH_SIZE 23 | self.in_chans = cfg.VIT.CHANNELS 24 | if cfg.TRAIN.DATASET == "Epickitchens": 25 | self.num_classes = [97, 300] 26 | else: 27 | self.num_classes = cfg.MODEL.NUM_CLASSES 28 | self.embed_dim = cfg.VIT.EMBED_DIM 29 | self.depth = cfg.VIT.DEPTH 30 | self.num_heads = cfg.VIT.NUM_HEADS 31 | self.mlp_ratio = cfg.VIT.MLP_RATIO 32 | self.qkv_bias = cfg.VIT.QKV_BIAS 33 | self.drop_rate = cfg.VIT.DROP 34 | self.drop_path_rate = cfg.VIT.DROP_PATH 35 | self.head_dropout = cfg.VIT.HEAD_DROPOUT 36 | self.video_input = cfg.VIT.VIDEO_INPUT 37 | self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION 38 | self.use_mlp = cfg.VIT.USE_MLP 39 | self.num_features = self.embed_dim 40 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 41 | self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT 42 | self.head_act = cfg.VIT.HEAD_ACT 43 | self.cfg = cfg 44 | 45 | # Patch Embedding 46 | self.patch_embed = PatchEmbed( 47 | img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim 48 | ) 49 | 50 | # 3D Patch Embedding 51 | self.patch_embed_3d = PatchEmbed3D( 52 | img_size=self.img_size, 53 | temporal_resolution=self.temporal_resolution, 54 | patch_size=self.patch_size, 55 | in_chans=self.in_chans, 56 | embed_dim=self.embed_dim, 57 | z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP, 58 | ) 59 | self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data) 60 | 61 | # Number of patches 62 | if self.video_input: 63 | num_patches = self.patch_embed.num_patches * self.temporal_resolution 64 | else: 65 | num_patches = self.patch_embed.num_patches 66 | self.num_patches = num_patches 67 | 68 | # CLS token 69 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 70 | trunc_normal_(self.cls_token, std=0.02) 71 | 72 | # Positional embedding 73 | self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim)) 74 | self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT) 75 | trunc_normal_(self.pos_embed, std=0.02) 76 | 77 | if self.cfg.VIT.POS_EMBED == "joint": 78 | self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim)) 79 | trunc_normal_(self.st_embed, std=0.02) 80 | elif self.cfg.VIT.POS_EMBED == "separate": 81 | self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim)) 82 | 83 | # Layer Blocks 84 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 85 | if self.cfg.VIT.ATTN_LAYER == "divided": 86 | self.blocks = nn.ModuleList( 87 | [ 88 | DividedSpaceTimeBlock( 89 | attn_type=cfg.VIT.ATTN_LAYER, 90 | dim=self.embed_dim, 91 | num_heads=self.num_heads, 92 | mlp_ratio=self.mlp_ratio, 93 | qkv_bias=self.qkv_bias, 94 | drop=self.drop_rate, 95 | attn_drop=self.attn_drop_rate, 96 | drop_path=dpr[i], 97 | norm_layer=norm_layer, 98 | ) 99 | for i in range(self.depth) 100 | ] 101 | ) 102 | 103 | self.norm = norm_layer(self.embed_dim) 104 | 105 | # MLP head 106 | if self.use_mlp: 107 | 
hidden_dim = self.embed_dim 108 | if self.head_act == "tanh": 109 | # logging.info("Using TanH activation in MLP") 110 | act = nn.Tanh() 111 | elif self.head_act == "gelu": 112 | # logging.info("Using GELU activation in MLP") 113 | act = nn.GELU() 114 | else: 115 | # logging.info("Using ReLU activation in MLP") 116 | act = nn.ReLU() 117 | self.pre_logits = nn.Sequential( 118 | OrderedDict( 119 | [ 120 | ("fc", nn.Linear(self.embed_dim, hidden_dim)), 121 | ("act", act), 122 | ] 123 | ) 124 | ) 125 | else: 126 | self.pre_logits = nn.Identity() 127 | 128 | # Classifier Head 129 | self.head_drop = nn.Dropout(p=self.head_dropout) 130 | if isinstance(self.num_classes, (list,)) and len(self.num_classes) > 1: 131 | for a, i in enumerate(range(len(self.num_classes))): 132 | setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i])) 133 | else: 134 | self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 135 | 136 | # Initialize weights 137 | self.apply(self._init_weights) 138 | 139 | def _init_weights(self, m): 140 | if isinstance(m, nn.Linear): 141 | trunc_normal_(m.weight, std=0.02) 142 | if isinstance(m, nn.Linear) and m.bias is not None: 143 | nn.init.constant_(m.bias, 0) 144 | elif isinstance(m, nn.LayerNorm): 145 | nn.init.constant_(m.bias, 0) 146 | nn.init.constant_(m.weight, 1.0) 147 | 148 | @torch.jit.ignore 149 | def no_weight_decay(self): 150 | if self.cfg.VIT.POS_EMBED == "joint": 151 | return {"pos_embed", "cls_token", "st_embed"} 152 | else: 153 | return {"pos_embed", "cls_token", "temp_embed"} 154 | 155 | def get_classifier(self): 156 | return self.head 157 | 158 | def reset_classifier(self, num_classes, global_pool=""): 159 | self.num_classes = num_classes 160 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 161 | 162 | def forward_features(self, x): 163 | # if self.video_input: 164 | # x = x[0] 165 | B = x.shape[0] 166 | 167 | # Tokenize input 168 | # if self.cfg.VIT.PATCH_SIZE_TEMP > 1: 169 | # for simplicity of mapping between content dimensions (input x) and token dims (after patching) 170 | # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details): 171 | 172 | # apply patching on input 173 | x = self.patch_embed_3d(x) 174 | tok_mask = None 175 | 176 | # else: 177 | # tok_mask = None 178 | # # 2D tokenization 179 | # if self.video_input: 180 | # x = x.permute(0, 2, 1, 3, 4) 181 | # (B, T, C, H, W) = x.shape 182 | # x = x.reshape(B * T, C, H, W) 183 | 184 | # x = self.patch_embed(x) 185 | 186 | # if self.video_input: 187 | # (B2, T2, D2) = x.shape 188 | # x = x.reshape(B, T * T2, D2) 189 | 190 | # Append CLS token 191 | cls_tokens = self.cls_token.expand(B, -1, -1) 192 | x = torch.cat((cls_tokens, x), dim=1) 193 | # if tok_mask is not None: 194 | # # prepend 1(=keep) to the mask to account for the CLS token as well 195 | # tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1) 196 | 197 | # Interpolate positinoal embeddings 198 | # if self.cfg.DATA.TRAIN_CROP_SIZE != 224: 199 | # pos_embed = self.pos_embed 200 | # N = pos_embed.shape[1] - 1 201 | # npatch = int((x.size(1) - 1) / self.temporal_resolution) 202 | # class_emb = pos_embed[:, 0] 203 | # pos_embed = pos_embed[:, 1:] 204 | # dim = x.shape[-1] 205 | # pos_embed = torch.nn.functional.interpolate( 206 | # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 207 | # scale_factor=math.sqrt(npatch / N), 208 | # mode='bicubic', 209 | # ) 210 | # 
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 211 | # new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) 212 | # else: 213 | new_pos_embed = self.pos_embed 214 | npatch = self.patch_embed.num_patches 215 | 216 | # Add positional embeddings to input 217 | if self.video_input: 218 | if self.cfg.VIT.POS_EMBED == "separate": 219 | cls_embed = self.pos_embed[:, 0, :].unsqueeze(1) 220 | tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1) 221 | tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1) 222 | total_pos_embed = tile_pos_embed + tile_temporal_embed 223 | total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) 224 | x = x + total_pos_embed 225 | elif self.cfg.VIT.POS_EMBED == "joint": 226 | x = x + self.st_embed 227 | else: 228 | # image input 229 | x = x + new_pos_embed 230 | 231 | # Apply positional dropout 232 | x = self.pos_drop(x) 233 | 234 | # Encoding using transformer layers 235 | for i, blk in enumerate(self.blocks): 236 | x = blk( 237 | x, 238 | seq_len=npatch, 239 | num_frames=self.temporal_resolution, 240 | approx=self.cfg.VIT.APPROX_ATTN_TYPE, 241 | num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM, 242 | tok_mask=tok_mask, 243 | ) 244 | 245 | ### v-iashin: I moved it to the forward pass 246 | # x = self.norm(x)[:, 0] 247 | # x = self.pre_logits(x) 248 | ### 249 | return x, tok_mask 250 | 251 | # def forward(self, x): 252 | # x = self.forward_features(x) 253 | # ### v-iashin: here. This should leave the same forward output as before 254 | # x = self.norm(x)[:, 0] 255 | # x = self.pre_logits(x) 256 | # ### 257 | # x = self.head_drop(x) 258 | # if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1: 259 | # output = [] 260 | # for head in range(len(self.num_classes)): 261 | # x_out = getattr(self, "head%d" % head)(x) 262 | # if not self.training: 263 | # x_out = torch.nn.functional.softmax(x_out, dim=-1) 264 | # output.append(x_out) 265 | # return output 266 | # else: 267 | # x = self.head(x) 268 | # if not self.training: 269 | # x = torch.nn.functional.softmax(x, dim=-1) 270 | # return x 271 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from loguru import logger 4 | from torchvision import transforms 5 | from torchvision.transforms import v2 6 | from diffusers.utils.torch_utils import randn_tensor 7 | from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection 8 | from ..models.dac_vae.model.dac import DAC 9 | from ..models.synchformer import Synchformer 10 | from ..models.hifi_foley import HunyuanVideoFoley 11 | from .config_utils import load_yaml, AttributeDict 12 | from .schedulers import FlowMatchDiscreteScheduler 13 | from tqdm import tqdm 14 | 15 | def load_state_dict(model, model_path): 16 | logger.info(f"Loading model state dict from: {model_path}") 17 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False) 18 | 19 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 20 | 21 | if missing_keys: 22 | logger.warning(f"Missing keys in state dict ({len(missing_keys)} keys):") 23 | for key in missing_keys: 24 | logger.warning(f" - {key}") 25 | else: 26 | logger.info("No missing keys found") 27 | 28 | if unexpected_keys: 29 | logger.warning(f"Unexpected keys in state dict ({len(unexpected_keys)} 
keys):") 30 | for key in unexpected_keys: 31 | logger.warning(f" - {key}") 32 | else: 33 | logger.info("No unexpected keys found") 34 | 35 | logger.info("Model state dict loaded successfully") 36 | return model 37 | 38 | def load_model(model_path, config_path, device): 39 | logger.info("Starting model loading process...") 40 | logger.info(f"Configuration file: {config_path}") 41 | logger.info(f"Model weights dir: {model_path}") 42 | logger.info(f"Target device: {device}") 43 | 44 | cfg = load_yaml(config_path) 45 | logger.info("Configuration loaded successfully") 46 | 47 | # HunyuanVideoFoley 48 | logger.info("Loading HunyuanVideoFoley main model...") 49 | foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16) 50 | foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth")) 51 | foley_model.eval() 52 | logger.info("HunyuanVideoFoley model loaded and set to evaluation mode") 53 | 54 | # DAC-VAE 55 | dac_path = os.path.join(model_path, "vae_128d_48k.pth") 56 | logger.info(f"Loading DAC VAE model from: {dac_path}") 57 | try: 58 | # Try loading with the standard DAC.load method 59 | dac_model = DAC.load(dac_path) 60 | except TypeError as e: 61 | if "map_location" in str(e): 62 | # Handle the map_location conflict by manually loading the state dict 63 | logger.warning(f"DAC.load() failed with map_location conflict: {e}") 64 | logger.info("Attempting manual DAC model loading...") 65 | 66 | # Create DAC model instance with appropriate parameters for vae_128d_48k 67 | # Based on filename, this appears to be 128-dimensional latent space, 48kHz sample rate 68 | dac_model = DAC( 69 | encoder_dim=64, 70 | latent_dim=128, # 128d as indicated by filename 71 | decoder_dim=1536, 72 | sample_rate=48000, # 48k as indicated by filename 73 | continuous=False 74 | ) 75 | state_dict = torch.load(dac_path, map_location="cpu", weights_only=False) 76 | dac_model.load_state_dict(state_dict, strict=False) 77 | else: 78 | raise e 79 | 80 | dac_model = dac_model.to(device) 81 | dac_model.requires_grad_(False) 82 | dac_model.eval() 83 | logger.info("DAC VAE model loaded successfully") 84 | 85 | # Siglip2 visual-encoder 86 | logger.info("Loading SigLIP2 visual encoder...") 87 | siglip2_preprocess = transforms.Compose([ 88 | transforms.Resize((512, 512)), 89 | transforms.ToTensor(), 90 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 91 | ]) 92 | 93 | # Try multiple approaches to load SigLIP2 94 | siglip2_model = None 95 | 96 | # Method 1: Try with standard transformers AutoModel 97 | try: 98 | siglip2_model = AutoModel.from_pretrained("google/siglip2-base-patch16-512", trust_remote_code=True).to(device).eval() 99 | logger.info("SigLIP2 loaded using standard transformers") 100 | except Exception as e1: 101 | logger.warning(f"Standard transformers loading failed: {e1}") 102 | 103 | # Method 2: Try loading from local cache or downloaded weights 104 | try: 105 | from transformers import SiglipVisionModel 106 | siglip2_model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-512").to(device).eval() 107 | logger.info("SigLIP2 loaded using SiglipVisionModel (base variant)") 108 | except Exception as e2: 109 | logger.warning(f"SiglipVisionModel loading failed: {e2}") 110 | 111 | # Method 3: Try using a compatible CLIP model as fallback 112 | try: 113 | from transformers import CLIPVisionModel 114 | logger.warning("Falling back to CLIP vision model as SigLIP2 is not available") 115 | siglip2_model = 
CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336").to(device).eval() 116 | logger.info("Using CLIP vision model as fallback") 117 | except Exception as e3: 118 | logger.error(f"All vision model loading attempts failed: {e3}") 119 | raise RuntimeError( 120 | "Could not load SigLIP2 vision encoder. Please ensure you have a compatible " 121 | "transformers version installed. You can try:\n" 122 | "1. pip install transformers>=4.37.0\n" 123 | "2. Or manually download the model weights" 124 | ) 125 | 126 | logger.info("SigLIP2 model and preprocessing pipeline loaded successfully") 127 | 128 | # clap text-encoder 129 | logger.info("Loading CLAP text encoder...") 130 | clap_tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general") 131 | clap_model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device) 132 | logger.info("CLAP tokenizer and model loaded successfully") 133 | 134 | # syncformer 135 | syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth") 136 | logger.info(f"Loading Synchformer model from: {syncformer_path}") 137 | syncformer_preprocess = v2.Compose( 138 | [ 139 | v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC), 140 | v2.CenterCrop(224), 141 | v2.ToImage(), 142 | v2.ToDtype(torch.float32, scale=True), 143 | v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 144 | ] 145 | ) 146 | 147 | syncformer_model = Synchformer() 148 | syncformer_model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu")) 149 | syncformer_model = syncformer_model.to(device).eval() 150 | logger.info("Synchformer model and preprocessing pipeline loaded successfully") 151 | 152 | 153 | logger.info("Creating model dictionary with attribute access...") 154 | model_dict = AttributeDict({ 155 | 'foley_model': foley_model, 156 | 'dac_model': dac_model, 157 | 'siglip2_preprocess': siglip2_preprocess, 158 | 'siglip2_model': siglip2_model, 159 | 'clap_tokenizer': clap_tokenizer, 160 | 'clap_model': clap_model, 161 | 'syncformer_preprocess': syncformer_preprocess, 162 | 'syncformer_model': syncformer_model, 163 | 'device': device, 164 | }) 165 | 166 | logger.info("All models loaded successfully!") 167 | logger.info("Available model components:") 168 | for key in model_dict.keys(): 169 | logger.info(f" - {key}") 170 | logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)") 171 | 172 | return model_dict, cfg 173 | 174 | def retrieve_timesteps( 175 | scheduler, 176 | num_inference_steps, 177 | device, 178 | **kwargs, 179 | ): 180 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 181 | timesteps = scheduler.timesteps 182 | return timesteps, num_inference_steps 183 | 184 | 185 | def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device): 186 | shape = (batch_size, num_channels_latents, int(length)) 187 | latents = randn_tensor(shape, device=device, dtype=dtype) 188 | 189 | # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler 190 | if hasattr(scheduler, "init_noise_sigma"): 191 | # scale the initial noise by the standard deviation required by the scheduler 192 | latents = latents * scheduler.init_noise_sigma 193 | 194 | return latents 195 | 196 | 197 | @torch.no_grad() 198 | def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, batch_size=1): 199 | 200 | target_dtype = model_dict.foley_model.dtype 201 | autocast_enabled = 
target_dtype != torch.float32 202 | device = model_dict.device 203 | 204 | scheduler = FlowMatchDiscreteScheduler( 205 | shift=cfg.diffusion_config.sample_flow_shift, 206 | reverse=cfg.diffusion_config.flow_reverse, 207 | solver=cfg.diffusion_config.flow_solver, 208 | use_flux_shift=cfg.diffusion_config.sample_use_flux_shift, 209 | flux_base_shift=cfg.diffusion_config.flux_base_shift, 210 | flux_max_shift=cfg.diffusion_config.flux_max_shift, 211 | ) 212 | 213 | timesteps, num_inference_steps = retrieve_timesteps( 214 | scheduler, 215 | num_inference_steps, 216 | device, 217 | ) 218 | 219 | latents = prepare_latents( 220 | scheduler, 221 | batch_size=batch_size, 222 | num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim, 223 | length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate, 224 | dtype=target_dtype, 225 | device=device, 226 | ) 227 | 228 | # Denoise loop 229 | for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Denoising steps"): 230 | # noise latents 231 | latent_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents 232 | latent_input = scheduler.scale_model_input(latent_input, t) 233 | 234 | t_expand = t.repeat(latent_input.shape[0]) 235 | 236 | # siglip2 features 237 | siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 238 | uncond_siglip2_feat = model_dict.foley_model.get_empty_clip_sequence( 239 | bs=batch_size, len=siglip2_feat.shape[1] 240 | ).to(device) 241 | 242 | if guidance_scale is not None and guidance_scale > 1.0: 243 | siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat], dim=0) 244 | else: 245 | siglip2_feat_input = siglip2_feat 246 | 247 | # syncformer features 248 | syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 249 | uncond_syncformer_feat = model_dict.foley_model.get_empty_sync_sequence( 250 | bs=batch_size, len=syncformer_feat.shape[1] 251 | ).to(device) 252 | if guidance_scale is not None and guidance_scale > 1.0: 253 | syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat], dim=0) 254 | else: 255 | syncformer_feat_input = syncformer_feat 256 | 257 | # text features 258 | text_feat_repeated = text_feats.text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 259 | uncond_text_feat_repeated = text_feats.uncond_text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 260 | if guidance_scale is not None and guidance_scale > 1.0: 261 | text_feat_input = torch.cat([uncond_text_feat_repeated, text_feat_repeated], dim=0) 262 | else: 263 | text_feat_input = text_feat_repeated 264 | 265 | with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype): 266 | # Predict the noise residual 267 | noise_pred = model_dict.foley_model( 268 | x=latent_input, 269 | t=t_expand, 270 | cond=text_feat_input, 271 | clip_feat=siglip2_feat_input, 272 | sync_feat=syncformer_feat_input, 273 | return_dict=True, 274 | )["x"] 275 | 276 | noise_pred = noise_pred.to(dtype=torch.float32) 277 | 278 | if guidance_scale is not None and guidance_scale > 1.0: 279 | # Perform classifier-free guidance 280 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 281 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 282 | 283 | # Compute the previous noisy sample x_t -> x_t-1 284 | latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0] 285 | 286 | # Post-process the latents to audio 287 | 288 | with torch.no_grad(): 289 | audio 
= model_dict.dac_model.decode(latents)
290 | audio = audio.float().cpu()
291 | 
292 | audio = audio[:, :int(audio_len_in_s*model_dict.dac_model.sample_rate)]
293 | 
294 | return audio, model_dict.dac_model.sample_rate
295 | 
296 | 
297 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI HunyuanVideo-Foley Custom Node
2 | 
3 | This is a ComfyUI custom node wrapper for the HunyuanVideo-Foley model, which generates realistic audio from video and text descriptions.
4 | 
5 | ## Features
6 | 
7 | - **Text-Video-to-Audio Synthesis**: Generate realistic audio that matches your video content
8 | - **Flexible Text Prompts**: Use optional text descriptions to guide audio generation
9 | - **Multiple Samples**: Generate up to 6 different audio variations per inference
10 | - **Configurable Parameters**: Control guidance scale, inference steps, and sampling
11 | - **Seed Control**: Reproducible results with seed parameter
12 | - **Model Caching**: Efficient model loading and reuse across generations
13 | - **Automatic Model Downloads**: Models are automatically downloaded to `ComfyUI/models/foley/` when needed
26 | 
27 | ## Installation
28 | 
29 | 1. **Clone this repository** into your ComfyUI custom_nodes directory:
30 | ```bash
31 | cd ComfyUI/custom_nodes
32 | git clone https://github.com/if-ai/ComfyUI_HunyuanVideoFoley.git
33 | ```
34 | 
35 | 2. **Install dependencies**:
36 | ```bash
37 | cd ComfyUI_HunyuanVideoFoley
38 | pip install -r requirements.txt
39 | ```
40 | 
41 | 3. **Run the installation script** (recommended):
42 | ```bash
43 | python install.py
44 | ```
45 | 
46 | 4. **Restart ComfyUI** to load the new nodes.
47 | 
48 | ### Model Setup
49 | 
50 | The models can be obtained in two ways:
51 | 
52 | #### Option 1: Automatic Download (Recommended)
53 | - Models will be automatically downloaded to `ComfyUI/models/foley/` when you first run the node
54 | - No manual setup required
55 | - Progress will be shown in the ComfyUI console
56 | 
57 | #### Option 2: Manual Download
58 | - Download models from [HuggingFace](https://huggingface.co/tencent/HunyuanVideo-Foley)
59 | - Place models in `ComfyUI/models/foley/` (recommended) or `./pretrained_models/` directory
60 | - Ensure the config file is at `configs/hunyuanvideo-foley-xxl.yaml`
61 | 
62 | ## Operation Guide: How to Use the Nodes
63 | 
64 | This custom node package is designed in a modular way for maximum flexibility and efficiency. Here is the recommended workflow and an explanation of what each node does.
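For reference, the nodes described below are thin wrappers around the Python helpers in `hunyuanvideo_foley/utils/model_utils.py`. The sketch below is purely illustrative rather than a supported API: the visual/text feature extraction that the Generator node performs with the SigLIP2, Synchformer and CLAP encoders is stubbed out with placeholders, and the model path is only an example. It shows roughly how the loader and generator steps map onto `load_model()` and `denoise_process()`.

```python
# Illustrative sketch only (not a supported API); assumes the repo root is importable
# and that the model weights have already been downloaded.
import torch

from hunyuanvideo_foley.utils.model_utils import load_model, denoise_process
from hunyuanvideo_foley.utils.config_utils import AttributeDict

device = torch.device("cuda:0")

# Load the main Foley model plus its auxiliary encoders (DAC VAE, SigLIP2, CLAP,
# Synchformer) once, then reuse them for every generation.
model_dict, cfg = load_model(
    model_path="ComfyUI/models/foley/hunyuanvideo-foley-xxl",  # example location
    config_path="configs/hunyuanvideo-foley-xxl.yaml",
    device=device,
)

# In the real node these features come from the input video and text prompt via the
# encoders in `model_dict`; here they are placeholders with the attribute names that
# `denoise_process()` expects.
visual_feats = AttributeDict({"siglip2_feat": ..., "syncformer_feat": ...})
text_feats = AttributeDict({"text_feat": ..., "uncond_text_feat": ...})

audio, sample_rate = denoise_process(
    visual_feats,
    text_feats,
    audio_len_in_s=8.0,
    model_dict=model_dict,
    cfg=cfg,
    guidance_scale=4.5,
    num_inference_steps=50,
    batch_size=1,
)
```

Loading once and reusing `model_dict` across calls is exactly what the chained node workflow below does for you inside ComfyUI.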
65 | 66 | ### Recommended Workflow 67 | 68 | The most powerful and efficient way to use these nodes is to chain them together in the following order: 69 | 70 | `Model Loader` → `Dependencies Loader` → `Torch Compile` → `Generator (Advanced)` 71 | 72 | This setup allows you to load the models only once, apply performance optimizations, and then run the generator multiple times without reloading, saving significant time and VRAM. 73 | 74 | ### Node Details 75 | 76 | #### 1. HunyuanVideo-Foley Model Loader (FP8) 77 | This is the starting point. It loads the main (and very large) audio generation model into memory. 78 | 79 | - **quantization**: This is the most important setting for saving VRAM. 80 | - `none`: Loads the model in its original format (highest VRAM usage). 81 | - `fp8_e5m2` / `fp8_e4m3fn`: These options use **FP8 quantization**, a technique that stores the model's weights in a much smaller format. This can save several gigabytes of VRAM with a minimal impact on audio quality, making it possible to run on GPUs with less memory. 82 | - **cpu_offload**: If `True`, the model will be kept in your regular RAM instead of VRAM. This is not the same as the generator's offload setting; use this if you are loading multiple different models in your workflow and need to conserve VRAM. 83 | 84 | #### 2. HunyuanVideo-Foley Dependencies 85 | This node takes the main model from the loader and then loads all the smaller, auxiliary models required for the process (the VAE, text encoder, and visual feature extractors). 86 | 87 | #### 3. HunyuanVideo-Foley Torch Compile 88 | This is an optional but highly recommended performance-enhancing node. It uses `torch.compile` to optimize the model's code for your specific hardware. 89 | - **Note**: The very first time you run a workflow with this node, it will take a minute or two to perform the compilation. However, every subsequent run will be significantly faster (often 20-30%). 90 | 91 | - **`compile_mode`**: This controls the trade-off between compilation time and the amount of performance gain. 92 | - `default`: The best balance. It provides a good speedup with a reasonable initial compile time. 93 | - `reduce-overhead`: Compiles more slowly but can reduce the overhead of running the model, which might be faster for very small audio generations. 94 | - `max-autotune`: Takes the longest to compile initially, but it tries many different optimizations to find the absolute fastest option for your specific hardware. 95 | 96 | - **`backend`**: This is an advanced setting that changes the underlying compiler used by PyTorch. For most users, the default `inductor` is the best choice. 97 | 98 | #### 4. HunyuanVideo-Foley Generator (Advanced) 99 | This is the main workhorse node where the audio generation happens. 100 | 101 | - **video / images**: Your visual input. You can provide either a video file or a batch of images from another node. 102 | - **compiled_model**: The input for the model prepared by the upstream nodes. 103 | - **text_prompt / negative_prompt**: Your descriptions of the sound you want (and don't want). 104 | - **guidance_scale / num_inference_steps / seed**: Standard diffusion model controls for creativity vs. prompt adherence, quality vs. speed, and reproducibility. 105 | - **enabled**: A simple switch. If `False`, the node does nothing and passes through an empty/silent output. This is useful for disabling parts of a complex workflow without having to disconnect them. 
106 | - **silent_audio**: Controls what happens when the node is disabled or fails. If `True`, it outputs a valid, silent audio clip, which prevents downstream nodes (like video combiners) from failing. If `False`, it outputs `None`. 107 | 108 | ### Understanding the Memory Options 109 | 110 | The two memory-related checkboxes on the Generator node are crucial for managing your GPU's resources. Here is exactly what they do: 111 | 112 | - **`cpu_offload`**: 113 | - **What it does:** If this is `True`, the node will always move the models to your regular RAM (CPU) after the generation is complete. This is the best option for freeing up VRAM for other nodes in your workflow while still keeping the models ready for the next run without having to reload them from disk. 114 | - **Use this when:** You want to run other VRAM-intensive nodes after this one and plan to come back to the Foley generator later. 115 | 116 | - **`memory_efficient`**: 117 | - **What it does:** This is a more aggressive option. If `True`, the node will completely unload the models from memory (both VRAM and RAM) after the generation is finished. 118 | - **Important Distinction:** This process is smart. It will **only** unload the model if it was loaded by the generator node itself (the simple workflow). If the model was passed in from the `HunyuanVideoFoleyModelLoader` (the advanced workflow), it will **not** unload it, respecting the fact that you may want to reuse the pre-loaded model for another generation. 119 | - **Use this when:** You are finished with audio generation and want to free up as much memory as possible for completely different tasks. 120 | 121 | ### Performance Tuning & VRAM Usage 122 | 123 | The most memory-intensive part of the process is visual feature extraction. We've implemented batched processing to prevent out-of-memory errors with longer videos or on GPUs with less VRAM. You can control this with two settings on the **Generator (Advanced)** node: 124 | 125 | - **`feature_extraction_batch_size`**: This determines how many video frames are processed by the feature extractor models at once. 126 | - **Lower values** significantly reduce peak VRAM usage at the cost of slightly slower processing. 127 | - **Higher values** speed up processing but require more VRAM. 128 | 129 | - **`enable_profiling`**: If you check this box, the node will print detailed performance timings and peak VRAM usage for the feature extraction step to the console. This is highly recommended for finding the optimal batch size for your specific hardware. 130 | 131 | #### Recommended Batch Sizes 132 | 133 | These are general starting points. The optimal value can vary based on your exact GPU, driver version, and other running processes. 134 | 135 | | VRAM Tier | Video Resolution | Recommended Batch Size | Notes | 136 | | :--- | :--- | :--- | :--- | 137 | | **≤ 8 GB** | 480p | 4 - 8 | Start with 4. If successful, you can try increasing it. | 138 | | | 720p | 2 - 4 | Start with 2. 720p videos are demanding on low VRAM cards. | 139 | | **12-16 GB** | 480p | 16 - 32 | The default of 16 should work well. Can be increased for more speed. | 140 | | | 720p | 8 - 16 | Start with 8 or 16. | 141 | | **≥ 24 GB**| 480p | 32 - 64 | You can safely increase the batch size for maximum performance. | 142 | | | 720p | 16 - 32 | A batch size of 32 should be easily achievable. | 143 | 144 | ## Usage 145 | 146 | ### Node Types 147 | 148 | #### 1. HunyuanVideo-Foley Generator 149 | Main node for generating audio from video and text. 
150 | 151 | **Inputs:** 152 | - **video**: Video input (VIDEO type) 153 | - **text_prompt**: Text description of desired audio (STRING) 154 | - **guidance_scale**: CFG scale for generation control (1.0-10.0, default: 4.5) 155 | - **num_inference_steps**: Number of denoising steps (10-100, default: 50) 156 | - **sample_nums**: Number of audio samples to generate (1-6, default: 1) 157 | - **seed**: Random seed for reproducibility (INT) 158 | - **model_path**: Path to pretrained models (optional, leave empty for auto-download) 159 | - **enabled**: Enable or disable the entire node. If disabled, it will pass through a silent or null audio output without processing. (BOOLEAN, default: True) 160 | - **silent_audio**: Controls the output when the node is disabled or fails. If true, it outputs a silent audio clip. If false, it outputs `None`. (BOOLEAN, default: True) 161 | 162 | **Outputs:** 163 | - **video_with_audio**: Video with generated audio merged (VIDEO) 164 | - **audio_only**: Generated audio file (AUDIO) 165 | - **status_message**: Generation status and info (STRING) 166 | 167 | ## ⚠ Important Limitations 168 | 169 | ### **Frame Count & Duration Limits** 170 | - **Maximum Frames**: 450 frames (hard limit) 171 | - **Maximum Duration**: 15 seconds at 30fps 172 | - **Recommended**: Keep videos ≤15 seconds for best results 173 | 174 | ### **FPS Recommendations** 175 | - **30fps**: Max 15 seconds (450 frames) 176 | - **24fps**: Max 18.75 seconds (450 frames)   177 | - **15fps**: Max 30 seconds (450 frames) 178 | 179 | ### **Long Video Solutions** 180 | For videos longer than 15 seconds: 181 | 1. **Reduce FPS**: Lower FPS allows longer duration within frame limit 182 | 2. **Segment Processing**: Split long videos into 15s segments 183 | 3. **Audio Merging**: Combine generated audio segments in post-processing 184 | 185 | 186 | ## Example Workflow 187 | 188 | 1. **Load Video**: Use a "Load Video" node to input your video file 189 | 2. **Add Generator**: Add the "HunyuanVideo-Foley Generator" node 190 | 3. **Connect Video**: Connect the video output to the generator's video input 191 | 4. **Set Prompt**: Enter a text description (e.g., "A person walks on frozen ice") 192 | 5. **Adjust Settings**: Configure guidance scale, steps, and sample count as needed 193 | 6. **Generate**: Run the workflow to generate audio 194 | 195 | ## Model Requirements 196 | 197 | The node expects the following model structure: 198 | ``` 199 | ComfyUI\models\foley\hunyuanvideo-foley-xxl 200 | ├── hunyuanvideo_foley.pth # Main Foley model 201 | ├── vae_128d_48k.pth # DAC VAE model 202 | └── synchformer_state_dict.pth # Synchformer model 203 | 204 | configs/ 205 | └── hunyuanvideo-foley-xxl.yaml # Configuration file 206 | ``` 207 | 208 | ## TODO 209 | - [x] ADD VHS INPUT/OUTPUTS (Thanks to YC) 210 | - [x] NEGATIVE PROMPT (Thanks to YC) 211 | - [x] MODEL OFFLOADING OPS 212 | - [x] TORCH COMPILE 213 | - [ ] QUANTISE MODEL 214 | 215 | 216 | ## Support 217 | 218 | If you find this tool useful, please consider supporting my work by: 219 | 220 | - Starring this repository on GitHub 221 | - Subscribing to my YouTube channel: [Impact Frames](https://youtube.com/@impactframes?si=DrBu3tOAC2-YbEvc) 222 | - Following on X: [@ImpactFrames](https://x.com/ImpactFramesX) 223 | 224 | You can also support by reporting issues or suggesting features. Your contributions help me bring updates and improvements to the project. 225 | 226 | 227 | 228 | ## License 229 | 230 | This custom node is based on the HunyuanVideo-Foley project. 
Please check the original project's license terms. 231 | 232 | ## Credits 233 | 234 | Based on the HunyuanVideo-Foley project by Tencent. Original paper and code available at: 235 | - Paper: [HunyuanVideo-Foley: Text-Video-to-Audio Synthesis] 236 | 237 | - Code: [https://github.com/tencent/HunyuanVideo-Foley] 238 | 239 | :IFAIloadImages_comfy 240 | 241 | 242 | 243 | 244 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/loss.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from audiotools import AudioSignal 7 | from audiotools import STFTParams 8 | from torch import nn 9 | 10 | 11 | class L1Loss(nn.L1Loss): 12 | """L1 Loss between AudioSignals. Defaults 13 | to comparing ``audio_data``, but any 14 | attribute of an AudioSignal can be used. 15 | 16 | Parameters 17 | ---------- 18 | attribute : str, optional 19 | Attribute of signal to compare, defaults to ``audio_data``. 20 | weight : float, optional 21 | Weight of this loss, defaults to 1.0. 22 | 23 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py 24 | """ 25 | 26 | def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): 27 | self.attribute = attribute 28 | self.weight = weight 29 | super().__init__(**kwargs) 30 | 31 | def forward(self, x: AudioSignal, y: AudioSignal): 32 | """ 33 | Parameters 34 | ---------- 35 | x : AudioSignal 36 | Estimate AudioSignal 37 | y : AudioSignal 38 | Reference AudioSignal 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | L1 loss between AudioSignal attributes. 44 | """ 45 | if isinstance(x, AudioSignal): 46 | x = getattr(x, self.attribute) 47 | y = getattr(y, self.attribute) 48 | return super().forward(x, y) 49 | 50 | 51 | class SISDRLoss(nn.Module): 52 | """ 53 | Computes the Scale-Invariant Source-to-Distortion Ratio between a batch 54 | of estimated and reference audio signals or aligned features. 55 | 56 | Parameters 57 | ---------- 58 | scaling : int, optional 59 | Whether to use scale-invariant (True) or 60 | signal-to-noise ratio (False), by default True 61 | reduction : str, optional 62 | How to reduce across the batch (either 'mean', 63 | 'sum', or none).], by default ' mean' 64 | zero_mean : int, optional 65 | Zero mean the references and estimates before 66 | computing the loss, by default True 67 | clip_min : int, optional 68 | The minimum possible loss value. Helps network 69 | to not focus on making already good examples better, by default None 70 | weight : float, optional 71 | Weight of this loss, defaults to 1.0. 
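In symbols (summarizing ``forward`` below, up to small :math:`\epsilon` terms added for
numerical stability), the returned value is the negative SI-SDR

.. math::

    \mathcal{L} = -10 \log_{10}
        \frac{\lVert \alpha s \rVert^2}{\lVert \hat{s} - \alpha s \rVert^2},
    \qquad
    \alpha = \frac{\langle \hat{s}, s \rangle}{\lVert s \rVert^2},

where :math:`s` is the reference, :math:`\hat{s}` is the estimate (both zero-meaned
when ``zero_mean`` is set), and :math:`\alpha` is fixed to 1 when ``scaling`` is False.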
72 | 73 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py 74 | """ 75 | 76 | def __init__( 77 | self, 78 | scaling: int = True, 79 | reduction: str = "mean", 80 | zero_mean: int = True, 81 | clip_min: int = None, 82 | weight: float = 1.0, 83 | ): 84 | self.scaling = scaling 85 | self.reduction = reduction 86 | self.zero_mean = zero_mean 87 | self.clip_min = clip_min 88 | self.weight = weight 89 | super().__init__() 90 | 91 | def forward(self, x: AudioSignal, y: AudioSignal): 92 | eps = 1e-8 93 | # nb, nc, nt 94 | if isinstance(x, AudioSignal): 95 | references = x.audio_data 96 | estimates = y.audio_data 97 | else: 98 | references = x 99 | estimates = y 100 | 101 | nb = references.shape[0] 102 | references = references.reshape(nb, 1, -1).permute(0, 2, 1) 103 | estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) 104 | 105 | # samples now on axis 1 106 | if self.zero_mean: 107 | mean_reference = references.mean(dim=1, keepdim=True) 108 | mean_estimate = estimates.mean(dim=1, keepdim=True) 109 | else: 110 | mean_reference = 0 111 | mean_estimate = 0 112 | 113 | _references = references - mean_reference 114 | _estimates = estimates - mean_estimate 115 | 116 | references_projection = (_references**2).sum(dim=-2) + eps 117 | references_on_estimates = (_estimates * _references).sum(dim=-2) + eps 118 | 119 | scale = ( 120 | (references_on_estimates / references_projection).unsqueeze(1) 121 | if self.scaling 122 | else 1 123 | ) 124 | 125 | e_true = scale * _references 126 | e_res = _estimates - e_true 127 | 128 | signal = (e_true**2).sum(dim=1) 129 | noise = (e_res**2).sum(dim=1) 130 | sdr = -10 * torch.log10(signal / noise + eps) 131 | 132 | if self.clip_min is not None: 133 | sdr = torch.clamp(sdr, min=self.clip_min) 134 | 135 | if self.reduction == "mean": 136 | sdr = sdr.mean() 137 | elif self.reduction == "sum": 138 | sdr = sdr.sum() 139 | return sdr 140 | 141 | 142 | class MultiScaleSTFTLoss(nn.Module): 143 | """Computes the multi-scale STFT loss from [1]. 144 | 145 | Parameters 146 | ---------- 147 | window_lengths : List[int], optional 148 | Length of each window of each STFT, by default [2048, 512] 149 | loss_fn : typing.Callable, optional 150 | How to compare each loss, by default nn.L1Loss() 151 | clamp_eps : float, optional 152 | Clamp on the log magnitude, below, by default 1e-5 153 | mag_weight : float, optional 154 | Weight of raw magnitude portion of loss, by default 1.0 155 | log_weight : float, optional 156 | Weight of log magnitude portion of loss, by default 1.0 157 | pow : float, optional 158 | Power to raise magnitude to before taking log, by default 2.0 159 | weight : float, optional 160 | Weight of this loss, by default 1.0 161 | match_stride : bool, optional 162 | Whether to match the stride of convolutional layers, by default False 163 | 164 | References 165 | ---------- 166 | 167 | 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. 168 | "DDSP: Differentiable Digital Signal Processing." 169 | International Conference on Learning Representations. 2019. 
170 | 171 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 172 | """ 173 | 174 | def __init__( 175 | self, 176 | window_lengths: List[int] = [2048, 512], 177 | loss_fn: typing.Callable = nn.L1Loss(), 178 | clamp_eps: float = 1e-5, 179 | mag_weight: float = 1.0, 180 | log_weight: float = 1.0, 181 | pow: float = 2.0, 182 | weight: float = 1.0, 183 | match_stride: bool = False, 184 | window_type: str = None, 185 | ): 186 | super().__init__() 187 | self.stft_params = [ 188 | STFTParams( 189 | window_length=w, 190 | hop_length=w // 4, 191 | match_stride=match_stride, 192 | window_type=window_type, 193 | ) 194 | for w in window_lengths 195 | ] 196 | self.loss_fn = loss_fn 197 | self.log_weight = log_weight 198 | self.mag_weight = mag_weight 199 | self.clamp_eps = clamp_eps 200 | self.weight = weight 201 | self.pow = pow 202 | 203 | def forward(self, x: AudioSignal, y: AudioSignal): 204 | """Computes multi-scale STFT between an estimate and a reference 205 | signal. 206 | 207 | Parameters 208 | ---------- 209 | x : AudioSignal 210 | Estimate signal 211 | y : AudioSignal 212 | Reference signal 213 | 214 | Returns 215 | ------- 216 | torch.Tensor 217 | Multi-scale STFT loss. 218 | """ 219 | loss = 0.0 220 | for s in self.stft_params: 221 | x.stft(s.window_length, s.hop_length, s.window_type) 222 | y.stft(s.window_length, s.hop_length, s.window_type) 223 | loss += self.log_weight * self.loss_fn( 224 | x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), 225 | y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), 226 | ) 227 | loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) 228 | return loss 229 | 230 | 231 | class MelSpectrogramLoss(nn.Module): 232 | """Compute distance between mel spectrograms. Can be used 233 | in a multi-scale way. 
234 | 235 | Parameters 236 | ---------- 237 | n_mels : List[int] 238 | Number of mels per STFT, by default [150, 80], 239 | window_lengths : List[int], optional 240 | Length of each window of each STFT, by default [2048, 512] 241 | loss_fn : typing.Callable, optional 242 | How to compare each loss, by default nn.L1Loss() 243 | clamp_eps : float, optional 244 | Clamp on the log magnitude, below, by default 1e-5 245 | mag_weight : float, optional 246 | Weight of raw magnitude portion of loss, by default 1.0 247 | log_weight : float, optional 248 | Weight of log magnitude portion of loss, by default 1.0 249 | pow : float, optional 250 | Power to raise magnitude to before taking log, by default 2.0 251 | weight : float, optional 252 | Weight of this loss, by default 1.0 253 | match_stride : bool, optional 254 | Whether to match the stride of convolutional layers, by default False 255 | 256 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 257 | """ 258 | 259 | def __init__( 260 | self, 261 | n_mels: List[int] = [150, 80], 262 | window_lengths: List[int] = [2048, 512], 263 | loss_fn: typing.Callable = nn.L1Loss(), 264 | clamp_eps: float = 1e-5, 265 | mag_weight: float = 1.0, 266 | log_weight: float = 1.0, 267 | pow: float = 2.0, 268 | weight: float = 1.0, 269 | match_stride: bool = False, 270 | mel_fmin: List[float] = [0.0, 0.0], 271 | mel_fmax: List[float] = [None, None], 272 | window_type: str = None, 273 | ): 274 | super().__init__() 275 | self.stft_params = [ 276 | STFTParams( 277 | window_length=w, 278 | hop_length=w // 4, 279 | match_stride=match_stride, 280 | window_type=window_type, 281 | ) 282 | for w in window_lengths 283 | ] 284 | self.n_mels = n_mels 285 | self.loss_fn = loss_fn 286 | self.clamp_eps = clamp_eps 287 | self.log_weight = log_weight 288 | self.mag_weight = mag_weight 289 | self.weight = weight 290 | self.mel_fmin = mel_fmin 291 | self.mel_fmax = mel_fmax 292 | self.pow = pow 293 | 294 | def forward(self, x: AudioSignal, y: AudioSignal): 295 | """Computes mel loss between an estimate and a reference 296 | signal. 297 | 298 | Parameters 299 | ---------- 300 | x : AudioSignal 301 | Estimate signal 302 | y : AudioSignal 303 | Reference signal 304 | 305 | Returns 306 | ------- 307 | torch.Tensor 308 | Mel loss. 309 | """ 310 | loss = 0.0 311 | for n_mels, fmin, fmax, s in zip( 312 | self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params 313 | ): 314 | kwargs = { 315 | "window_length": s.window_length, 316 | "hop_length": s.hop_length, 317 | "window_type": s.window_type, 318 | } 319 | x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) 320 | y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) 321 | 322 | loss += self.log_weight * self.loss_fn( 323 | x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), 324 | y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), 325 | ) 326 | loss += self.mag_weight * self.loss_fn(x_mels, y_mels) 327 | return loss 328 | 329 | 330 | class GANLoss(nn.Module): 331 | """ 332 | Computes a discriminator loss, given a discriminator on 333 | generated waveforms/spectrograms compared to ground truth 334 | waveforms/spectrograms. Computes the loss for both the 335 | discriminator and the generator in separate functions. 
336 | """ 337 | 338 | def __init__(self, discriminator): 339 | super().__init__() 340 | self.discriminator = discriminator 341 | 342 | def forward(self, fake, real): 343 | d_fake = self.discriminator(fake.audio_data) 344 | d_real = self.discriminator(real.audio_data) 345 | return d_fake, d_real 346 | 347 | def discriminator_loss(self, fake, real): 348 | d_fake, d_real = self.forward(fake.clone().detach(), real) 349 | 350 | loss_d = 0 351 | for x_fake, x_real in zip(d_fake, d_real): 352 | loss_d += torch.mean(x_fake[-1] ** 2) 353 | loss_d += torch.mean((1 - x_real[-1]) ** 2) 354 | return loss_d 355 | 356 | def generator_loss(self, fake, real): 357 | d_fake, d_real = self.forward(fake, real) 358 | 359 | loss_g = 0 360 | for x_fake in d_fake: 361 | loss_g += torch.mean((1 - x_fake[-1]) ** 2) 362 | 363 | loss_feature = 0 364 | 365 | for i in range(len(d_fake)): 366 | for j in range(len(d_fake[i]) - 1): 367 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) 368 | return loss_g, loss_feature 369 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/synchformer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Any, Mapping 4 | 5 | import einops 6 | import numpy as np 7 | import torch 8 | import torchaudio 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from .motionformer import MotionFormer 13 | from .ast_model import AST 14 | from .utils import Config 15 | 16 | 17 | class Synchformer(nn.Module): 18 | 19 | def __init__(self): 20 | super().__init__() 21 | 22 | self.vfeat_extractor = MotionFormer( 23 | extract_features=True, 24 | factorize_space_time=True, 25 | agg_space_module="TransformerEncoderLayer", 26 | agg_time_module="torch.nn.Identity", 27 | add_global_repr=False, 28 | ) 29 | self.afeat_extractor = AST( 30 | extract_features=True, 31 | max_spec_t=66, 32 | factorize_freq_time=True, 33 | agg_freq_module="TransformerEncoderLayer", 34 | agg_time_module="torch.nn.Identity", 35 | add_global_repr=False, 36 | ) 37 | 38 | # # bridging the s3d latent dim (1024) into what is specified in the config 39 | # # to match e.g. the transformer dim 40 | self.vproj = nn.Linear(in_features=768, out_features=768) 41 | self.aproj = nn.Linear(in_features=768, out_features=768) 42 | self.transformer = GlobalTransformer( 43 | tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768 44 | ) 45 | 46 | def forward(self, vis): 47 | B, S, Tv, C, H, W = vis.shape 48 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) 49 | # feat extractors return a tuple of segment-level and global features (ignored for sync) 50 | # (B, S, tv, D), e.g. 
(B, 7, 8, 768) 51 | vis = self.vfeat_extractor(vis) 52 | return vis 53 | 54 | def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor): 55 | vis = self.vproj(vis) 56 | aud = self.aproj(aud) 57 | 58 | B, S, tv, D = vis.shape 59 | B, S, ta, D = aud.shape 60 | vis = vis.view(B, S * tv, D) # (B, S*tv, D) 61 | aud = aud.view(B, S * ta, D) # (B, S*ta, D) 62 | # print(vis.shape, aud.shape) 63 | 64 | # self.transformer will concatenate the vis and aud in one sequence with aux tokens, 65 | # ie `CvvvvMaaaaaa`, and will return the logits for the CLS tokens 66 | logits = self.transformer(vis, aud) # (B, cls); or (B, cls) and (B, 2) if DoubtingTransformer 67 | 68 | return logits 69 | 70 | def extract_vfeats(self, vis): 71 | B, S, Tv, C, H, W = vis.shape 72 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) 73 | # feat extractors return a tuple of segment-level and global features (ignored for sync) 74 | # (B, S, tv, D), e.g. (B, 7, 8, 768) 75 | vis = self.vfeat_extractor(vis) 76 | return vis 77 | 78 | def extract_afeats(self, aud): 79 | B, S, _, Fa, Ta = aud.shape 80 | aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2) # (B, S, Ta, F) 81 | # (B, S, ta, D), e.g. (B, 7, 6, 768) 82 | aud, _ = self.afeat_extractor(aud) 83 | return aud 84 | 85 | def compute_loss(self, logits, targets, loss_fn: str = None): 86 | loss = None 87 | if targets is not None: 88 | if loss_fn is None or loss_fn == "cross_entropy": 89 | # logits: (B, cls) and targets: (B,) 90 | loss = F.cross_entropy(logits, targets) 91 | else: 92 | raise NotImplementedError(f"Loss {loss_fn} not implemented") 93 | return loss 94 | 95 | def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True): 96 | # discard all entries except vfeat_extractor 97 | # sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')} 98 | 99 | return super().load_state_dict(sd, strict) 100 | 101 | 102 | class RandInitPositionalEncoding(nn.Module): 103 | """Random inited trainable pos embedding. 
It is just applied on the sequence, thus respects no priors.""" 104 | 105 | def __init__(self, block_shape: list, n_embd: int): 106 | super().__init__() 107 | self.block_shape = block_shape 108 | self.n_embd = n_embd 109 | self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd)) 110 | 111 | def forward(self, token_embeddings): 112 | return token_embeddings + self.pos_emb 113 | 114 | 115 | class GlobalTransformer(torch.nn.Module): 116 | """Same as in SparseSync but without the selector transformers and the head""" 117 | 118 | def __init__( 119 | self, 120 | tok_pdrop=0.0, 121 | embd_pdrop=0.1, 122 | resid_pdrop=0.1, 123 | attn_pdrop=0.1, 124 | n_layer=3, 125 | n_head=8, 126 | n_embd=768, 127 | pos_emb_block_shape=[ 128 | 198, 129 | ], 130 | n_off_head_out=21, 131 | ) -> None: 132 | super().__init__() 133 | self.config = Config( 134 | embd_pdrop=embd_pdrop, 135 | resid_pdrop=resid_pdrop, 136 | attn_pdrop=attn_pdrop, 137 | n_layer=n_layer, 138 | n_head=n_head, 139 | n_embd=n_embd, 140 | ) 141 | # input norm 142 | self.vis_in_lnorm = torch.nn.LayerNorm(n_embd) 143 | self.aud_in_lnorm = torch.nn.LayerNorm(n_embd) 144 | # aux tokens 145 | self.OFF_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) 146 | self.MOD_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) 147 | # whole token dropout 148 | self.tok_pdrop = tok_pdrop 149 | self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop) 150 | self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop) 151 | # maybe add pos emb 152 | self.pos_emb_cfg = RandInitPositionalEncoding( 153 | block_shape=pos_emb_block_shape, 154 | n_embd=n_embd, 155 | ) 156 | # the stem 157 | self.drop = torch.nn.Dropout(embd_pdrop) 158 | self.blocks = torch.nn.Sequential(*[Block(self.config) for _ in range(n_layer)]) 159 | # pre-output norm 160 | self.ln_f = torch.nn.LayerNorm(n_embd) 161 | # maybe add a head 162 | self.off_head = torch.nn.Linear(in_features=n_embd, out_features=n_off_head_out) 163 | 164 | def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True): 165 | B, Sv, D = v.shape 166 | B, Sa, D = a.shape 167 | # broadcasting special tokens to the batch size 168 | off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B) 169 | mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B) 170 | # norm 171 | v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a) 172 | # maybe whole token dropout 173 | if self.tok_pdrop > 0: 174 | v, a = self.tok_drop_vis(v), self.tok_drop_aud(a) 175 | # (B, 1+Sv+1+Sa, D) 176 | x = torch.cat((off_tok, v, mod_tok, a), dim=1) 177 | # maybe add pos emb 178 | if hasattr(self, "pos_emb_cfg"): 179 | x = self.pos_emb_cfg(x) 180 | # dropout -> stem -> norm 181 | x = self.drop(x) 182 | x = self.blocks(x) 183 | x = self.ln_f(x) 184 | # maybe add heads 185 | if attempt_to_apply_heads and hasattr(self, "off_head"): 186 | x = self.off_head(x[:, 0, :]) 187 | return x 188 | 189 | 190 | class SelfAttention(nn.Module): 191 | """ 192 | A vanilla multi-head masked self-attention layer with a projection at the end. 193 | It is possible to use torch.nn.MultiheadAttention here but I am including an 194 | explicit implementation here to show that there is nothing too scary here. 
195 | """ 196 | 197 | def __init__(self, config): 198 | super().__init__() 199 | assert config.n_embd % config.n_head == 0 200 | # key, query, value projections for all heads 201 | self.key = nn.Linear(config.n_embd, config.n_embd) 202 | self.query = nn.Linear(config.n_embd, config.n_embd) 203 | self.value = nn.Linear(config.n_embd, config.n_embd) 204 | # regularization 205 | self.attn_drop = nn.Dropout(config.attn_pdrop) 206 | self.resid_drop = nn.Dropout(config.resid_pdrop) 207 | # output projection 208 | self.proj = nn.Linear(config.n_embd, config.n_embd) 209 | # # causal mask to ensure that attention is only applied to the left in the input sequence 210 | # mask = torch.tril(torch.ones(config.block_size, 211 | # config.block_size)) 212 | # if hasattr(config, "n_unmasked"): 213 | # mask[:config.n_unmasked, :config.n_unmasked] = 1 214 | # self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) 215 | self.n_head = config.n_head 216 | 217 | def forward(self, x): 218 | B, T, C = x.size() 219 | 220 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 221 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 222 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 223 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 224 | 225 | # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 226 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 227 | # att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) 228 | att = F.softmax(att, dim=-1) 229 | y = self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 230 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 231 | 232 | # output projection 233 | y = self.resid_drop(self.proj(y)) 234 | 235 | return y 236 | 237 | 238 | class Block(nn.Module): 239 | """an unassuming Transformer block""" 240 | 241 | def __init__(self, config): 242 | super().__init__() 243 | self.ln1 = nn.LayerNorm(config.n_embd) 244 | self.ln2 = nn.LayerNorm(config.n_embd) 245 | self.attn = SelfAttention(config) 246 | self.mlp = nn.Sequential( 247 | nn.Linear(config.n_embd, 4 * config.n_embd), 248 | nn.GELU(), # nice 249 | nn.Linear(4 * config.n_embd, config.n_embd), 250 | nn.Dropout(config.resid_pdrop), 251 | ) 252 | 253 | def forward(self, x): 254 | x = x + self.attn(self.ln1(x)) 255 | x = x + self.mlp(self.ln2(x)) 256 | return x 257 | 258 | 259 | def make_class_grid( 260 | leftmost_val, 261 | rightmost_val, 262 | grid_size, 263 | add_extreme_offset: bool = False, 264 | seg_size_vframes: int = None, 265 | nseg: int = None, 266 | step_size_seg: float = None, 267 | vfps: float = None, 268 | ): 269 | assert grid_size >= 3, f"grid_size: {grid_size} doesnot make sense. 
If =2 -> (-1,1); =1 -> (-1); =0 -> ()" 270 | grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float() 271 | if add_extreme_offset: 272 | assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}" 273 | seg_size_sec = seg_size_vframes / vfps 274 | trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1) 275 | extreme_value = trim_size_in_seg * seg_size_sec 276 | grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid 277 | return grid 278 | 279 | 280 | # from synchformer 281 | def pad_or_truncate(audio: torch.Tensor, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0): 282 | difference = max_spec_t - audio.shape[-1] # safe for batched input 283 | # pad or truncate, depending on difference 284 | if difference > 0: 285 | # pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input 286 | pad_dims = (0, difference) 287 | audio = torch.nn.functional.pad(audio, pad_dims, pad_mode, pad_value) 288 | elif difference < 0: 289 | print(f"Truncating spec ({audio.shape}) to max_spec_t ({max_spec_t}).") 290 | audio = audio[..., :max_spec_t] # safe for batched input 291 | return audio 292 | 293 | 294 | def encode_audio_with_sync( 295 | synchformer: Synchformer, x: torch.Tensor, mel: torchaudio.transforms.MelSpectrogram 296 | ) -> torch.Tensor: 297 | b, t = x.shape 298 | 299 | # partition the audio waveform into half-overlapping segments 300 | segment_size = 10240 301 | step_size = 10240 // 2 302 | num_segments = (t - segment_size) // step_size + 1 303 | segments = [] 304 | for i in range(num_segments): 305 | segments.append(x[:, i * step_size : i * step_size + segment_size]) 306 | x = torch.stack(segments, dim=1) # (B, S, segment_samples) 307 | 308 | x = mel(x) 309 | x = torch.log(x + 1e-6) 310 | x = pad_or_truncate(x, 66) 311 | 312 | mean = -4.2677393 313 | std = 4.5689974 314 | x = (x - mean) / (2 * std) 315 | # x: B * S * 128 * 66 316 | x = synchformer.extract_afeats(x.unsqueeze(2)) 317 | return x 318 | 319 | 320 | def read_audio(filename, expected_length=int(16000 * 4)): 321 | waveform, sr = torchaudio.load(filename) 322 | waveform = waveform.mean(dim=0) 323 | 324 | if sr != 16000: 325 | resampler = torchaudio.transforms.Resample(sr, 16000) 326 | waveform = resampler(waveform) # call the Resample module directly on the waveform 327 | 328 | waveform = waveform[:expected_length] 329 | if waveform.shape[0] != expected_length: 330 | raise ValueError(f"Audio {filename} is too short") 331 | 332 | waveform = waveform.squeeze() 333 | 334 | return waveform 335 | 336 | 337 | if __name__ == "__main__": 338 | import os # only needed here, for the optional SYNCHFORMER_WEIGHTS override 339 | synchformer = Synchformer().cuda().eval() 340 | # mmaudio provided synchformer ckpt 341 | synchformer.load_state_dict( 342 | torch.load( 343 | os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"), 344 | weights_only=True, 345 | map_location="cpu", 346 | ) 347 | ) 348 | 349 | sync_mel_spectrogram = torchaudio.transforms.MelSpectrogram( 350 | sample_rate=16000, 351 | win_length=400, 352 | hop_length=160, 353 | n_fft=1024, 354 | n_mels=128, 355 | ) 356 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/dac.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | from typing import Union 4 | 5 | import numpy as np 6 | import torch 7 | from audiotools import AudioSignal 8 | from audiotools.ml import BaseModel 9 | from torch import nn 10 | 11 | from .base import CodecMixin 12 | from
..nn.layers import Snake1d 13 | from ..nn.layers import WNConv1d 14 | from ..nn.layers import WNConvTranspose1d 15 | from ..nn.quantize import ResidualVectorQuantize 16 | from ..nn.vae_utils import DiagonalGaussianDistribution 17 | 18 | 19 | def init_weights(m): 20 | if isinstance(m, nn.Conv1d): 21 | nn.init.trunc_normal_(m.weight, std=0.02) 22 | nn.init.constant_(m.bias, 0) 23 | 24 | 25 | class ResidualUnit(nn.Module): 26 | def __init__(self, dim: int = 16, dilation: int = 1): 27 | super().__init__() 28 | pad = ((7 - 1) * dilation) // 2 29 | self.block = nn.Sequential( 30 | Snake1d(dim), 31 | WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), 32 | Snake1d(dim), 33 | WNConv1d(dim, dim, kernel_size=1), 34 | ) 35 | 36 | def forward(self, x): 37 | y = self.block(x) 38 | pad = (x.shape[-1] - y.shape[-1]) // 2 39 | if pad > 0: 40 | x = x[..., pad:-pad] 41 | return x + y 42 | 43 | 44 | class EncoderBlock(nn.Module): 45 | def __init__(self, dim: int = 16, stride: int = 1): 46 | super().__init__() 47 | self.block = nn.Sequential( 48 | ResidualUnit(dim // 2, dilation=1), 49 | ResidualUnit(dim // 2, dilation=3), 50 | ResidualUnit(dim // 2, dilation=9), 51 | Snake1d(dim // 2), 52 | WNConv1d( 53 | dim // 2, 54 | dim, 55 | kernel_size=2 * stride, 56 | stride=stride, 57 | padding=math.ceil(stride / 2), 58 | ), 59 | ) 60 | 61 | def forward(self, x): 62 | return self.block(x) 63 | 64 | 65 | class Encoder(nn.Module): 66 | def __init__( 67 | self, 68 | d_model: int = 64, 69 | strides: list = [2, 4, 8, 8], 70 | d_latent: int = 64, 71 | ): 72 | super().__init__() 73 | # Create first convolution 74 | self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] 75 | 76 | # Create EncoderBlocks that double channels as they downsample by `stride` 77 | for stride in strides: 78 | d_model *= 2 79 | self.block += [EncoderBlock(d_model, stride=stride)] 80 | 81 | # Create last convolution 82 | self.block += [ 83 | Snake1d(d_model), 84 | WNConv1d(d_model, d_latent, kernel_size=3, padding=1), 85 | ] 86 | 87 | # Wrap black into nn.Sequential 88 | self.block = nn.Sequential(*self.block) 89 | self.enc_dim = d_model 90 | 91 | def forward(self, x): 92 | return self.block(x) 93 | 94 | 95 | class DecoderBlock(nn.Module): 96 | def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): 97 | super().__init__() 98 | self.block = nn.Sequential( 99 | Snake1d(input_dim), 100 | WNConvTranspose1d( 101 | input_dim, 102 | output_dim, 103 | kernel_size=2 * stride, 104 | stride=stride, 105 | padding=math.ceil(stride / 2), 106 | output_padding=stride % 2, 107 | ), 108 | ResidualUnit(output_dim, dilation=1), 109 | ResidualUnit(output_dim, dilation=3), 110 | ResidualUnit(output_dim, dilation=9), 111 | ) 112 | 113 | def forward(self, x): 114 | return self.block(x) 115 | 116 | 117 | class Decoder(nn.Module): 118 | def __init__( 119 | self, 120 | input_channel, 121 | channels, 122 | rates, 123 | d_out: int = 1, 124 | ): 125 | super().__init__() 126 | 127 | # Add first conv layer 128 | layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] 129 | 130 | # Add upsampling + MRF blocks 131 | for i, stride in enumerate(rates): 132 | input_dim = channels // 2**i 133 | output_dim = channels // 2 ** (i + 1) 134 | layers += [DecoderBlock(input_dim, output_dim, stride)] 135 | 136 | # Add final conv layer 137 | layers += [ 138 | Snake1d(output_dim), 139 | WNConv1d(output_dim, d_out, kernel_size=7, padding=3), 140 | nn.Tanh(), 141 | ] 142 | 143 | self.model = nn.Sequential(*layers) 144 | 145 | def 
forward(self, x): 146 | return self.model(x) 147 | 148 | 149 | class DAC(BaseModel, CodecMixin): 150 | def __init__( 151 | self, 152 | encoder_dim: int = 64, 153 | encoder_rates: List[int] = [2, 4, 8, 8], 154 | latent_dim: int = None, 155 | decoder_dim: int = 1536, 156 | decoder_rates: List[int] = [8, 8, 4, 2], 157 | n_codebooks: int = 9, 158 | codebook_size: int = 1024, 159 | codebook_dim: Union[int, list] = 8, 160 | quantizer_dropout: bool = False, 161 | sample_rate: int = 44100, 162 | continuous: bool = False, 163 | ): 164 | super().__init__() 165 | 166 | self.encoder_dim = encoder_dim 167 | self.encoder_rates = encoder_rates 168 | self.decoder_dim = decoder_dim 169 | self.decoder_rates = decoder_rates 170 | self.sample_rate = sample_rate 171 | self.continuous = continuous 172 | 173 | if latent_dim is None: 174 | latent_dim = encoder_dim * (2 ** len(encoder_rates)) 175 | 176 | self.latent_dim = latent_dim 177 | 178 | self.hop_length = np.prod(encoder_rates) 179 | self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) 180 | 181 | if not continuous: 182 | self.n_codebooks = n_codebooks 183 | self.codebook_size = codebook_size 184 | self.codebook_dim = codebook_dim 185 | self.quantizer = ResidualVectorQuantize( 186 | input_dim=latent_dim, 187 | n_codebooks=n_codebooks, 188 | codebook_size=codebook_size, 189 | codebook_dim=codebook_dim, 190 | quantizer_dropout=quantizer_dropout, 191 | ) 192 | else: 193 | self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1) 194 | self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1) 195 | 196 | self.decoder = Decoder( 197 | latent_dim, 198 | decoder_dim, 199 | decoder_rates, 200 | ) 201 | self.sample_rate = sample_rate 202 | self.apply(init_weights) 203 | 204 | self.delay = self.get_delay() 205 | 206 | @property 207 | def dtype(self): 208 | """Get the dtype of the model parameters.""" 209 | # Return the dtype of the first parameter found 210 | for param in self.parameters(): 211 | return param.dtype 212 | return torch.float32 # fallback 213 | 214 | @property 215 | def device(self): 216 | """Get the device of the model parameters.""" 217 | # Return the device of the first parameter found 218 | for param in self.parameters(): 219 | return param.device 220 | return torch.device('cpu') # fallback 221 | 222 | def preprocess(self, audio_data, sample_rate): 223 | if sample_rate is None: 224 | sample_rate = self.sample_rate 225 | assert sample_rate == self.sample_rate 226 | 227 | length = audio_data.shape[-1] 228 | right_pad = math.ceil(length / self.hop_length) * self.hop_length - length 229 | audio_data = nn.functional.pad(audio_data, (0, right_pad)) 230 | 231 | return audio_data 232 | 233 | def encode( 234 | self, 235 | audio_data: torch.Tensor, 236 | n_quantizers: int = None, 237 | ): 238 | """Encode given audio data and return quantized latent codes 239 | 240 | Parameters 241 | ---------- 242 | audio_data : Tensor[B x 1 x T] 243 | Audio data to encode 244 | n_quantizers : int, optional 245 | Number of quantizers to use, by default None 246 | If None, all quantizers are used. 
247 | 248 | Returns 249 | ------- 250 | dict 251 | A dictionary with the following keys: 252 | "z" : Tensor[B x D x T] 253 | Quantized continuous representation of input 254 | "codes" : Tensor[B x N x T] 255 | Codebook indices for each codebook 256 | (quantized discrete representation of input) 257 | "latents" : Tensor[B x N*D x T] 258 | Projected latents (continuous representation of input before quantization) 259 | "vq/commitment_loss" : Tensor[1] 260 | Commitment loss to train encoder to predict vectors closer to codebook 261 | entries 262 | "vq/codebook_loss" : Tensor[1] 263 | Codebook loss to update the codebook 264 | "length" : int 265 | Number of samples in input audio 266 | """ 267 | z = self.encoder(audio_data) # [B x D x T] 268 | if not self.continuous: 269 | z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers) 270 | else: 271 | z = self.quant_conv(z) # [B x 2D x T] 272 | z = DiagonalGaussianDistribution(z) 273 | codes, latents, commitment_loss, codebook_loss = None, None, 0, 0 274 | 275 | return z, codes, latents, commitment_loss, codebook_loss 276 | 277 | def decode(self, z: torch.Tensor): 278 | """Decode given latent codes and return audio data 279 | 280 | Parameters 281 | ---------- 282 | z : Tensor[B x D x T] 283 | Quantized continuous representation of input 284 | length : int, optional 285 | Number of samples in output audio, by default None 286 | 287 | Returns 288 | ------- 289 | dict 290 | A dictionary with the following keys: 291 | "audio" : Tensor[B x 1 x length] 292 | Decoded audio data. 293 | """ 294 | if not self.continuous: 295 | audio = self.decoder(z) 296 | else: 297 | z = self.post_quant_conv(z) 298 | audio = self.decoder(z) 299 | 300 | return audio 301 | 302 | def forward( 303 | self, 304 | audio_data: torch.Tensor, 305 | sample_rate: int = None, 306 | n_quantizers: int = None, 307 | ): 308 | """Model forward pass 309 | 310 | Parameters 311 | ---------- 312 | audio_data : Tensor[B x 1 x T] 313 | Audio data to encode 314 | sample_rate : int, optional 315 | Sample rate of audio data in Hz, by default None 316 | If None, defaults to `self.sample_rate` 317 | n_quantizers : int, optional 318 | Number of quantizers to use, by default None. 319 | If None, all quantizers are used. 320 | 321 | Returns 322 | ------- 323 | dict 324 | A dictionary with the following keys: 325 | "z" : Tensor[B x D x T] 326 | Quantized continuous representation of input 327 | "codes" : Tensor[B x N x T] 328 | Codebook indices for each codebook 329 | (quantized discrete representation of input) 330 | "latents" : Tensor[B x N*D x T] 331 | Projected latents (continuous representation of input before quantization) 332 | "vq/commitment_loss" : Tensor[1] 333 | Commitment loss to train encoder to predict vectors closer to codebook 334 | entries 335 | "vq/codebook_loss" : Tensor[1] 336 | Codebook loss to update the codebook 337 | "length" : int 338 | Number of samples in input audio 339 | "audio" : Tensor[B x 1 x length] 340 | Decoded audio data. 
341 | """ 342 | length = audio_data.shape[-1] 343 | audio_data = self.preprocess(audio_data, sample_rate) 344 | if not self.continuous: 345 | z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers) 346 | 347 | x = self.decode(z) 348 | return { 349 | "audio": x[..., :length], 350 | "z": z, 351 | "codes": codes, 352 | "latents": latents, 353 | "vq/commitment_loss": commitment_loss, 354 | "vq/codebook_loss": codebook_loss, 355 | } 356 | else: 357 | posterior, _, _, _, _ = self.encode(audio_data, n_quantizers) 358 | z = posterior.sample() 359 | x = self.decode(z) 360 | 361 | kl_loss = posterior.kl() 362 | kl_loss = kl_loss.mean() 363 | 364 | return { 365 | "audio": x[..., :length], 366 | "z": z, 367 | "kl_loss": kl_loss, 368 | } 369 | 370 | 371 | if __name__ == "__main__": 372 | import numpy as np 373 | from functools import partial 374 | 375 | model = DAC().to("cpu") 376 | 377 | for n, m in model.named_modules(): 378 | o = m.extra_repr() 379 | p = sum([np.prod(p.size()) for p in m.parameters()]) 380 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 381 | setattr(m, "extra_repr", partial(fn, o=o, p=p)) 382 | print(model) 383 | print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) 384 | 385 | length = 88200 * 2 386 | x = torch.randn(1, 1, length).to(model.device) 387 | x.requires_grad_(True) 388 | x.retain_grad() 389 | 390 | # Make a forward pass 391 | out = model(x)["audio"] 392 | print("Input shape:", x.shape) 393 | print("Output shape:", out.shape) 394 | 395 | # Create gradient variable 396 | grad = torch.zeros_like(out) 397 | grad[:, :, grad.shape[-1] // 2] = 1 398 | 399 | # Make a backward pass 400 | out.backward(grad) 401 | 402 | # Check non-zero values 403 | gradmap = x.grad.squeeze(0) 404 | gradmap = (gradmap != 0).sum(0) # sum across features 405 | rf = (gradmap != 0).sum() 406 | 407 | print(f"Receptive field: {rf.item()}") 408 | 409 | x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) 410 | model.decompress(model.compress(x, verbose=True), verbose=True) 411 | --------------------------------------------------------------------------------
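
For readers wiring these modules together, the following is a minimal, hypothetical sketch of how the losses in `hunyuanvideo_foley/models/dac_vae/nn/loss.py` can be evaluated on a pair of `AudioSignal`s. The import path assumes the package root is on `sys.path`; the random signals and the final weighting are purely illustrative and are not taken from the project's training code.

```python
import torch
from audiotools import AudioSignal

from hunyuanvideo_foley.models.dac_vae.nn.loss import (
    L1Loss,
    MelSpectrogramLoss,
    MultiScaleSTFTLoss,
    SISDRLoss,
)

sr = 44100
reference = AudioSignal(torch.randn(1, 1, sr), sr)  # stand-in for ground-truth audio
estimate = AudioSignal(torch.randn(1, 1, sr), sr)   # stand-in for decoded audio

waveform_l1 = L1Loss()(estimate, reference)            # x = estimate, y = reference
stft_loss = MultiScaleSTFTLoss()(estimate, reference)  # raw + log magnitude terms
mel_loss = MelSpectrogramLoss()(estimate, reference)   # multi-scale mel distance
sisdr = SISDRLoss()(reference, estimate)               # this class reads its first argument as the reference

# Illustrative weighting only; actual training recipes differ.
total = waveform_l1 + stft_loss + 15.0 * mel_loss
print(float(total), float(sisdr))
```

Note the asymmetry: `L1Loss`, `MultiScaleSTFTLoss`, and `MelSpectrogramLoss` document `x` as the estimate, while `SISDRLoss.forward` treats `x` as the reference.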
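
The `off_head` in `GlobalTransformer` emits `n_off_head_out=21` logits over a grid of candidate audio-visual offsets built by `make_class_grid`. Below is a small, hypothetical post-processing sketch for turning such logits into an offset in seconds; the ±2.0 s range mirrored from `make_class_grid(-2.0, 2.0, 21)` is an assumption for illustration and is not read from this repository's configs or checkpoints.

```python
import torch

num_classes = 21                               # matches n_off_head_out above
grid = torch.linspace(-2.0, 2.0, num_classes)  # assumed offset grid, in seconds

logits = torch.randn(1, num_classes)           # stand-in for Synchformer.compare_v_a output
probs = logits.softmax(dim=-1)

hard_offset = grid[probs.argmax(dim=-1)]       # most likely class, in seconds
soft_offset = (probs * grid).sum(dim=-1)       # probability-weighted estimate, in seconds
print(float(hard_offset), float(soft_offset))
```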
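
Finally, a minimal sketch of the two `encode`/`decode` paths exposed by `DAC` in `model/dac.py`, using the default constructor arguments on random audio. The one-second 44.1 kHz input is an arbitrary assumption and no pretrained weights are loaded; which path a given checkpoint expects depends on how it was trained (its `continuous` setting).

```python
import torch
from hunyuanvideo_foley.models.dac_vae.model import DAC

audio = torch.randn(1, 1, 44100)  # (B, 1, T) mono audio at 44.1 kHz

# Continuous (VAE) path: encode() returns a DiagonalGaussianDistribution.
vae = DAC(continuous=True).eval()
with torch.no_grad():
    padded = vae.preprocess(audio, sample_rate=44100)  # pad T up to a multiple of hop_length (512)
    posterior, *_ = vae.encode(padded)
    z = posterior.sample()                             # (B, latent_dim, padded T // hop_length)
    recon = vae.decode(z)                              # (B, 1, padded T)

# Discrete residual-VQ path: encode() also returns codebook indices.
codec = DAC(continuous=False).eval()
with torch.no_grad():
    z_q, codes, *_ = codec.encode(codec.preprocess(audio, sample_rate=44100))

print(recon.shape, z.shape, codes.shape)
```

The `forward` method above chains these two steps and, in the continuous case, additionally returns the KL term.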