├── hunyuanvideo_foley
│   ├── models
│   │   ├── __init__.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── activation_layers.py
│   │   │   ├── modulate_layers.py
│   │   │   ├── norm_layers.py
│   │   │   ├── embed_layers.py
│   │   │   ├── mlp_layers.py
│   │   │   └── posemb_layers.py
│   │   ├── synchformer
│   │   │   ├── __init__.py
│   │   │   ├── divided_224_16x4.yaml
│   │   │   ├── utils.py
│   │   │   ├── compute_desync_score.py
│   │   │   ├── video_model_builder.py
│   │   │   └── synchformer.py
│   │   └── dac_vae
│   │       ├── nn
│   │       │   ├── __init__.py
│   │       │   ├── layers.py
│   │       │   ├── vae_utils.py
│   │       │   ├── quantize.py
│   │       │   └── loss.py
│   │       ├── model
│   │       │   ├── __init__.py
│   │       │   ├── discriminator.py
│   │       │   ├── base.py
│   │       │   └── dac.py
│   │       ├── __init__.py
│   │       ├── __main__.py
│   │       └── utils
│   │           ├── decode.py
│   │           ├── encode.py
│   │           └── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── schedulers
│   │   │   └── __init__.py
│   │   ├── media_utils.py
│   │   ├── config_utils.py
│   │   ├── helper.py
│   │   └── feature_utils.py
│   ├── __init__.py
│   ├── constants.py
│   └── cli.py
├── tests
│   ├── __init__.py
│   ├── test_config_utils.py
│   └── test_media_utils.py
├── assets
│   ├── logo.png
│   ├── pan_chart.png
│   ├── data_pipeline.png
│   ├── model_arch.png
│   └── test.csv
├── .gitattributes
├── pytest.ini
├── download_test_videos.sh
├── requirements.txt
├── MANIFEST.in
├── .pre-commit-config.yaml
├── pyproject.toml
├── configs
│   ├── hunyuanvideo-foley-xl.yaml
│   └── hunyuanvideo-foley-xxl.yaml
├── NOTICE
├── .gitignore
├── INSTALL.md
├── DEVELOPMENT.md
├── setup.py
└── infer.py
/hunyuanvideo_foley/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/nn/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Test suite for HunyuanVideo-Foley
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .synchformer import Synchformer
2 | 
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from . import layers
2 | from . import loss
3 | from . 
import quantize 4 | -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:619860d15d8b8aa34d35dd60d547afe4d96dde78cb5d77ef5217d5507950cc49 3 | size 214408 4 | -------------------------------------------------------------------------------- /assets/pan_chart.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:16019d3355051f5b470532809a0cf9046d22170d30c860dd01929f6921d29ead 3 | size 303974 4 | -------------------------------------------------------------------------------- /assets/data_pipeline.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d5c9e5cd92a7ac24d1e8f39db09e0eaa9ee84bedade8ff08bd1d50141fc7867c 3 | size 384649 4 | -------------------------------------------------------------------------------- /assets/model_arch.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4709a32df5b115e7806e0eb102aaf2e396a0978e12b31fba338730068d6454d7 3 | size 542135 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | assets/data_pipeline.png filter=lfs diff=lfs merge=lfs -text 2 | assets/model_arch.png filter=lfs diff=lfs merge=lfs -text 3 | *.png filter=lfs diff=lfs merge=lfs -text 4 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import CodecMixin 2 | from .base import DACFile 3 | from .dac import DAC 4 | from .discriminator import Discriminator 5 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from diffusers.schedulers import DDPMScheduler, EulerDiscreteScheduler 2 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | testpaths = tests 3 | python_files = test_*.py 4 | python_functions = test_* 5 | addopts = 6 | --verbose 7 | --tb=short 8 | --strict-markers 9 | --disable-warnings 10 | markers = 11 | slow: marks tests as slow (deselect with '-m "not slow"') -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | # preserved here for legacy reasons 4 | __model_version__ = "latest" 5 | 6 | import audiotools 7 | 8 | audiotools.ml.BaseModel.INTERN += ["dac.**"] 9 | audiotools.ml.BaseModel.EXTERN += ["einops"] 10 | 11 | 12 | from . import nn 13 | from . import model 14 | from . 
import utils 15 | from .model import DAC 16 | from .model import DACFile 17 | -------------------------------------------------------------------------------- /download_test_videos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download MoviegenAudioBenchSfx 10 videos 4 | curl -O https://texttoaudio-train-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuanvideo-foley_demo/MovieGenAudioBenchSfx.tar.gz 5 | tar -xzvf MovieGenAudioBenchSfx.tar.gz -C ./assets 6 | rm MovieGenAudioBenchSfx.tar.gz 7 | 8 | # Download gradio example video 9 | curl -O https://texttoaudio-train-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuanvideo-foley_demo/examples.tar.gz 10 | tar -xvzf examples.tar.gz 11 | rm examples.tar.gz 12 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import argbind 4 | 5 | from .utils import download 6 | from .utils.decode import decode 7 | from .utils.encode import encode 8 | 9 | STAGES = ["encode", "decode", "download"] 10 | 11 | 12 | def run(stage: str): 13 | """Run stages. 14 | 15 | Parameters 16 | ---------- 17 | stage : str 18 | Stage to run 19 | """ 20 | if stage not in STAGES: 21 | raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}") 22 | stage_fn = globals()[stage] 23 | 24 | if stage == "download": 25 | stage_fn() 26 | return 27 | 28 | stage_fn() 29 | 30 | 31 | if __name__ == "__main__": 32 | group = sys.argv.pop(1) 33 | args = argbind.parse_args(group=group) 34 | 35 | with argbind.scope(args): 36 | run(group) 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core ML dependencies 2 | torch>=2.0.0 3 | torchvision>=0.15.0 4 | torchaudio>=2.0.0 5 | numpy==1.26.4 6 | scipy 7 | 8 | # Deep Learning frameworks 9 | diffusers 10 | timm 11 | accelerate 12 | 13 | # Transformers and NLP 14 | git+https://github.com/huggingface/transformers@v4.49.0-SigLIP-2 15 | sentencepiece 16 | 17 | # Audio processing 18 | git+https://github.com/descriptinc/audiotools 19 | 20 | # Video/Image processing 21 | pillow 22 | av 23 | einops 24 | 25 | # Configuration and utilities 26 | pyyaml 27 | omegaconf 28 | easydict 29 | loguru 30 | tqdm 31 | setuptools 32 | 33 | # Data handling 34 | pandas 35 | pyarrow 36 | 37 | # Web interface 38 | gradio==3.50.2 39 | 40 | # Network 41 | urllib3==2.4.0 42 | 43 | # Development dependencies (optional) 44 | black>=23.0.0 45 | isort>=5.12.0 46 | flake8>=6.0.0 47 | mypy>=1.3.0 48 | pre-commit>=3.0.0 49 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Include package metadata and documentation 2 | include README.md 3 | include LICENSE 4 | include NOTICE 5 | include DEVELOPMENT.md 6 | include CLAUDE.md 7 | include requirements.txt 8 | include pyproject.toml 9 | include pytest.ini 10 | 11 | # Include configuration files 12 | include configs/*.yaml 13 | include configs/*.yml 14 | recursive-include hunyuanvideo_foley/configs *.yaml *.yml 15 | 16 | # Include test assets if any 17 | include assets/*.csv 18 | include assets/*.txt 19 | recursive-include assets/test_videos * 20 | 21 | # Include example scripts 22 | include *.py 23 | include *.sh 24 | 25 
| # Include test files 26 | recursive-include tests *.py 27 | 28 | # Exclude unnecessary files 29 | global-exclude *.pyc 30 | global-exclude *.pyo 31 | global-exclude *~ 32 | global-exclude .DS_Store 33 | global-exclude __pycache__ 34 | prune .git 35 | prune .github 36 | prune examples/*/outputs 37 | prune **/__pycache__ 38 | prune **/*.pyc 39 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch.nn.utils import weight_norm 7 | 8 | 9 | def WNConv1d(*args, **kwargs): 10 | return weight_norm(nn.Conv1d(*args, **kwargs)) 11 | 12 | 13 | def WNConvTranspose1d(*args, **kwargs): 14 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 15 | 16 | 17 | # Scripting this brings model speed up 1.4x 18 | @torch.jit.script 19 | def snake(x, alpha): 20 | shape = x.shape 21 | x = x.reshape(shape[0], shape[1], -1) 22 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) 23 | x = x.reshape(shape) 24 | return x 25 | 26 | 27 | class Snake1d(nn.Module): 28 | def __init__(self, channels): 29 | super().__init__() 30 | self.alpha = nn.Parameter(torch.ones(1, channels, 1)) 31 | 32 | def forward(self, x): 33 | return snake(x, self.alpha) 34 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment 3 | for High-Fidelity Foley Audio Generation 4 | 5 | This package provides tools for generating high-quality Foley audio effects 6 | from video content using multimodal diffusion models. 
7 | """ 8 | 9 | __version__ = "1.0.0" 10 | __author__ = "Tencent Hunyuan Team" 11 | __email__ = "hunyuan@tencent.com" 12 | 13 | # Import main components for easy access 14 | try: 15 | from .utils.model_utils import load_model, denoise_process 16 | from .utils.feature_utils import feature_process 17 | from .utils.media_utils import merge_audio_video 18 | from .utils.config_utils import AttributeDict 19 | 20 | __all__ = [ 21 | "__version__", 22 | "load_model", 23 | "denoise_process", 24 | "feature_process", 25 | "merge_audio_video", 26 | "AttributeDict" 27 | ] 28 | except ImportError: 29 | # Handle missing dependencies gracefully during installation 30 | __all__ = ["__version__"] -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - id: check-added-large-files 9 | - id: check-merge-conflict 10 | - id: debug-statements 11 | - id: check-docstring-first 12 | 13 | - repo: https://github.com/psf/black 14 | rev: 23.3.0 15 | hooks: 16 | - id: black 17 | language_version: python3 18 | args: [--line-length=120] 19 | 20 | - repo: https://github.com/pycqa/isort 21 | rev: 5.12.0 22 | hooks: 23 | - id: isort 24 | args: [--profile, black, --line-length=120] 25 | 26 | - repo: https://github.com/pycqa/flake8 27 | rev: 6.0.0 28 | hooks: 29 | - id: flake8 30 | args: [--max-line-length=120] 31 | additional_dependencies: [flake8-docstrings] 32 | 33 | - repo: https://github.com/pre-commit/mirrors-mypy 34 | rev: v1.3.0 35 | hooks: 36 | - id: mypy 37 | additional_dependencies: [types-all] 38 | args: [--ignore-missing-imports] -------------------------------------------------------------------------------- /assets/test.csv: -------------------------------------------------------------------------------- 1 | index,video,prompt 2 | 0,assets/MovieGenAudioBenchSfx/video_with_audio/0.mp4,"juicy crunches of the apple being bitten into and chewed on." 3 | 1,assets/MovieGenAudioBenchSfx/video_with_audio/1.mp4,"mashed potatoes being scooped, high-quality" 4 | 2,assets/MovieGenAudioBenchSfx/video_with_audio/2.mp4,"the slurping of noodles as the man eats with gusto, and the soft clinking of the fork against the plate as the man twirls the noodles." 5 | 3,assets/MovieGenAudioBenchSfx/video_with_audio/3.mp4,"ice cubes clinking against the glass." 6 | 4,assets/MovieGenAudioBenchSfx/video_with_audio/4.mp4,"peoples' footsteps and the gentle hum of the city in the background." 7 | 5,assets/MovieGenAudioBenchSfx/video_with_audio/5.mp4,"the slurping and smacking of the child's tongue on the ice pop." 8 | 6,assets/MovieGenAudioBenchSfx/video_with_audio/6.mp4,"gentle licking the fur, high-quality" 9 | 7,assets/MovieGenAudioBenchSfx/video_with_audio/7.mp4,"dog's tongue lapping against the bowl." 10 | 8,assets/MovieGenAudioBenchSfx/video_with_audio/8.mp4,"boy's tongue darts out and licks the cone with a slight slurping sound." 11 | 9,assets/MovieGenAudioBenchSfx/video_with_audio/9.mp4,"straw sucking sounds with slurping noises." 
12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py38', 'py39', 'py310', 'py311'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | /( 7 | \.eggs 8 | | \.git 9 | | \.hg 10 | | \.mypy_cache 11 | | \.tox 12 | | \.venv 13 | | _build 14 | | buck-out 15 | | build 16 | | dist 17 | )/ 18 | ''' 19 | 20 | [tool.isort] 21 | profile = "black" 22 | line_length = 120 23 | multi_line_output = 3 24 | include_trailing_comma = true 25 | force_grid_wrap = 0 26 | use_parentheses = true 27 | ensure_newline_before_comments = true 28 | 29 | [tool.flake8] 30 | max-line-length = 120 31 | select = ["E", "W", "F"] 32 | ignore = [ 33 | "E203", # whitespace before ':' 34 | "E501", # line too long 35 | "W503", # line break before binary operator 36 | ] 37 | exclude = [ 38 | ".git", 39 | "__pycache__", 40 | "build", 41 | "dist", 42 | ".eggs", 43 | "*.egg-info", 44 | ".venv", 45 | ".tox", 46 | ] 47 | 48 | [tool.mypy] 49 | python_version = "3.8" 50 | warn_return_any = true 51 | warn_unused_configs = true 52 | disallow_untyped_defs = false 53 | disallow_incomplete_defs = false 54 | check_untyped_defs = true 55 | disallow_untyped_decorators = false 56 | no_implicit_optional = true 57 | warn_redundant_casts = true 58 | warn_unused_ignores = true 59 | warn_no_return = true 60 | warn_unreachable = true 61 | strict_equality = true -------------------------------------------------------------------------------- /configs/hunyuanvideo-foley-xl.yaml: -------------------------------------------------------------------------------- 1 | model_config: 2 | model_name: HunyuanVideo-Foley-XL 3 | model_type: 1d 4 | model_precision: bf16 5 | model_kwargs: 6 | depth_triple_blocks: 12 7 | depth_single_blocks: 24 8 | hidden_size: 1408 9 | num_heads: 11 10 | mlp_ratio: 4 11 | mlp_act_type: "gelu_tanh" 12 | qkv_bias: True 13 | qk_norm: True 14 | qk_norm_type: "rms" 15 | attn_mode: "torch" 16 | embedder_type: "default" 17 | interleaved_audio_visual_rope: True 18 | enable_learnable_empty_visual_feat: True 19 | sync_modulation: False 20 | add_sync_feat_to_audio: True 21 | cross_attention: True 22 | use_attention_mask: False 23 | condition_projection: "linear" 24 | sync_feat_dim: 768 # syncformer 768 dim 25 | condition_dim: 768 # clap 768 text condition dim (clip-text) 26 | clip_dim: 768 # siglip2 visual dim 27 | audio_vae_latent_dim: 128 28 | audio_frame_rate: 50 29 | patch_size: 1 30 | rope_dim_list: null 31 | rope_theta: 10000 32 | text_length: 77 33 | clip_length: 64 34 | sync_length: 192 35 | depth_triple_ssl_encoder: null 36 | depth_single_ssl_encoder: 8 37 | use_repa_with_audiossl: True 38 | 39 | diffusion_config: 40 | denoise_type: "flow" 41 | flow_path_type: "linear" 42 | flow_predict_type: "velocity" 43 | flow_reverse: True 44 | flow_solver: "euler" 45 | sample_flow_shift: 1.0 46 | sample_use_flux_shift: False 47 | flux_base_shift: 0.5 48 | flux_max_shift: 1.15 49 | -------------------------------------------------------------------------------- /configs/hunyuanvideo-foley-xxl.yaml: -------------------------------------------------------------------------------- 1 | model_config: 2 | model_name: HunyuanVideo-Foley-XXL 3 | model_type: 1d 4 | model_precision: bf16 5 | model_kwargs: 6 | depth_triple_blocks: 18 7 | depth_single_blocks: 36 8 | hidden_size: 1536 9 | num_heads: 12 10 | mlp_ratio: 4 11 | mlp_act_type: "gelu_tanh" 12 | qkv_bias: True 13 | 
qk_norm: True 14 | qk_norm_type: "rms" 15 | attn_mode: "torch" 16 | embedder_type: "default" 17 | interleaved_audio_visual_rope: True 18 | enable_learnable_empty_visual_feat: True 19 | sync_modulation: False 20 | add_sync_feat_to_audio: True 21 | cross_attention: True 22 | use_attention_mask: False 23 | condition_projection: "linear" 24 | sync_feat_dim: 768 # syncformer 768 dim 25 | condition_dim: 768 # clap 768 text condition dim (clip-text) 26 | clip_dim: 768 # siglip2 visual dim 27 | audio_vae_latent_dim: 128 28 | audio_frame_rate: 50 29 | patch_size: 1 30 | rope_dim_list: null 31 | rope_theta: 10000 32 | text_length: 77 33 | clip_length: 64 34 | sync_length: 192 35 | depth_triple_ssl_encoder: null 36 | depth_single_ssl_encoder: 8 37 | use_repa_with_audiossl: True 38 | 39 | diffusion_config: 40 | denoise_type: "flow" 41 | flow_path_type: "linear" 42 | flow_predict_type: "velocity" 43 | flow_reverse: True 44 | flow_solver: "euler" 45 | sample_flow_shift: 1.0 46 | sample_use_flux_shift: False 47 | flux_base_shift: 0.5 48 | flux_max_shift: 1.15 49 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/activation_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | def get_activation_layer(act_type): 5 | if act_type == "gelu": 6 | return lambda: nn.GELU() 7 | elif act_type == "gelu_tanh": 8 | # Approximate `tanh` requires torch >= 1.13 9 | return lambda: nn.GELU(approximate="tanh") 10 | elif act_type == "relu": 11 | return nn.ReLU 12 | elif act_type == "silu": 13 | return nn.SiLU 14 | else: 15 | raise ValueError(f"Unknown activation type: {act_type}") 16 | 17 | class SwiGLU(nn.Module): 18 | def __init__( 19 | self, 20 | dim: int, 21 | hidden_dim: int, 22 | out_dim: int, 23 | ): 24 | """ 25 | Initialize the SwiGLU FeedForward module. 26 | 27 | Args: 28 | dim (int): Input dimension. 29 | hidden_dim (int): Hidden dimension of the feedforward layer. 30 | 31 | Attributes: 32 | w1: Linear transformation for the first layer. 33 | w2: Linear transformation for the second layer. 34 | w3: Linear transformation for the third layer. 
35 | 36 | """ 37 | super().__init__() 38 | 39 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 40 | self.w2 = nn.Linear(hidden_dim, out_dim, bias=False) 41 | self.w3 = nn.Linear(dim, hidden_dim, bias=False) 42 | 43 | def forward(self, x): 44 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 45 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/modulate_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | import torch.nn as nn 4 | 5 | class ModulateDiT(nn.Module): 6 | def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None): 7 | factory_kwargs = {"dtype": dtype, "device": device} 8 | super().__init__() 9 | self.act = act_layer() 10 | self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) 11 | # Zero-initialize the modulation 12 | nn.init.zeros_(self.linear.weight) 13 | nn.init.zeros_(self.linear.bias) 14 | 15 | def forward(self, x: torch.Tensor) -> torch.Tensor: 16 | return self.linear(self.act(x)) 17 | 18 | 19 | def modulate(x, shift=None, scale=None): 20 | if x.ndim == 3: 21 | shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None 22 | scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None 23 | if scale is None and shift is None: 24 | return x 25 | elif shift is None: 26 | return x * (1 + scale) 27 | elif scale is None: 28 | return x + shift 29 | else: 30 | return x * (1 + scale) + shift 31 | 32 | 33 | def apply_gate(x, gate=None, tanh=False): 34 | if gate is None: 35 | return x 36 | if gate.ndim == 2 and x.ndim == 3: 37 | gate = gate.unsqueeze(1) 38 | if tanh: 39 | return x * gate.tanh() 40 | else: 41 | return x * gate 42 | 43 | 44 | def ckpt_wrapper(module): 45 | def ckpt_forward(*inputs): 46 | outputs = module(*inputs) 47 | return outputs 48 | 49 | return ckpt_forward 50 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/constants.py: -------------------------------------------------------------------------------- 1 | """Constants used throughout the HunyuanVideo-Foley project.""" 2 | 3 | from typing import Dict, List 4 | 5 | # Model configuration 6 | DEFAULT_AUDIO_SAMPLE_RATE = 48000 7 | DEFAULT_VIDEO_FPS = 25 8 | DEFAULT_AUDIO_CHANNELS = 2 9 | 10 | # Video processing 11 | MAX_VIDEO_DURATION_SECONDS = 15.0 12 | MIN_VIDEO_DURATION_SECONDS = 1.0 13 | 14 | # Audio processing 15 | AUDIO_VAE_LATENT_DIM = 128 16 | AUDIO_FRAME_RATE = 75 # frames per second in latent space 17 | 18 | # Visual features 19 | FPS_VISUAL: Dict[str, int] = { 20 | "siglip2": 8, 21 | "synchformer": 25 22 | } 23 | 24 | # Model paths (can be overridden by environment variables) 25 | DEFAULT_MODEL_PATH = "./pretrained_models/" 26 | DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml" 27 | 28 | # Inference parameters 29 | DEFAULT_GUIDANCE_SCALE = 4.5 30 | DEFAULT_NUM_INFERENCE_STEPS = 50 31 | MIN_GUIDANCE_SCALE = 1.0 32 | MAX_GUIDANCE_SCALE = 10.0 33 | MIN_INFERENCE_STEPS = 10 34 | MAX_INFERENCE_STEPS = 100 35 | 36 | # Text processing 37 | MAX_TEXT_LENGTH = 100 38 | DEFAULT_NEGATIVE_PROMPT = "noisy, harsh" 39 | 40 | # File extensions 41 | SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"] 42 | SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"] 43 | 44 | # Quality settings 45 | AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = { 46 | "high": 
["-b:a", "192k"], 47 | "medium": ["-b:a", "128k"], 48 | "low": ["-b:a", "96k"] 49 | } 50 | 51 | # Error messages 52 | ERROR_MESSAGES: Dict[str, str] = { 53 | "model_not_loaded": "Model is not loaded. Please load the model first.", 54 | "invalid_video_format": "Unsupported video format. Supported formats: {formats}", 55 | "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds", 56 | "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html" 57 | } -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: Ssv2 4 | BATCH_SIZE: 32 5 | EVAL_PERIOD: 5 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | CHECKPOINT_EPOCH_RESET: True 9 | CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth 10 | DATA: 11 | NUM_FRAMES: 16 12 | SAMPLING_RATE: 4 13 | TRAIN_JITTER_SCALES: [256, 320] 14 | TRAIN_CROP_SIZE: 224 15 | TEST_CROP_SIZE: 224 16 | INPUT_CHANNEL_NUM: [3] 17 | MEAN: [0.5, 0.5, 0.5] 18 | STD: [0.5, 0.5, 0.5] 19 | PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2 20 | PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames 21 | INV_UNIFORM_SAMPLE: True 22 | RANDOM_FLIP: False 23 | REVERSE_INPUT_CHANNEL: True 24 | USE_RAND_AUGMENT: True 25 | RE_PROB: 0.0 26 | USE_REPEATED_AUG: False 27 | USE_RANDOM_RESIZE_CROPS: False 28 | COLORJITTER: False 29 | GRAYSCALE: False 30 | GAUSSIAN: False 31 | SOLVER: 32 | BASE_LR: 1e-4 33 | LR_POLICY: steps_with_relative_lrs 34 | LRS: [1, 0.1, 0.01] 35 | STEPS: [0, 20, 30] 36 | MAX_EPOCH: 35 37 | MOMENTUM: 0.9 38 | WEIGHT_DECAY: 5e-2 39 | WARMUP_EPOCHS: 0.0 40 | OPTIMIZING_METHOD: adamw 41 | USE_MIXED_PRECISION: True 42 | SMOOTHING: 0.2 43 | SLOWFAST: 44 | ALPHA: 8 45 | VIT: 46 | PATCH_SIZE: 16 47 | PATCH_SIZE_TEMP: 2 48 | CHANNELS: 3 49 | EMBED_DIM: 768 50 | DEPTH: 12 51 | NUM_HEADS: 12 52 | MLP_RATIO: 4 53 | QKV_BIAS: True 54 | VIDEO_INPUT: True 55 | TEMPORAL_RESOLUTION: 8 56 | USE_MLP: True 57 | DROP: 0.0 58 | POS_DROPOUT: 0.0 59 | DROP_PATH: 0.2 60 | IM_PRETRAINED: True 61 | HEAD_DROPOUT: 0.0 62 | HEAD_ACT: tanh 63 | PRETRAINED_WEIGHTS: vit_1k 64 | ATTN_LAYER: divided 65 | MODEL: 66 | NUM_CLASSES: 174 67 | ARCH: slow 68 | MODEL_NAME: VisionTransformer 69 | LOSS_FUNC: cross_entropy 70 | TEST: 71 | ENABLE: True 72 | DATASET: Ssv2 73 | BATCH_SIZE: 64 74 | NUM_ENSEMBLE_VIEWS: 1 75 | NUM_SPATIAL_CROPS: 3 76 | DATA_LOADER: 77 | NUM_WORKERS: 4 78 | PIN_MEMORY: True 79 | NUM_GPUS: 8 80 | NUM_SHARDS: 4 81 | RNG_SEED: 0 82 | OUTPUT_DIR: . 83 | TENSORBOARD: 84 | ENABLE: True 85 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/norm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class RMSNorm(nn.Module): 5 | def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, 6 | device=None, dtype=None): 7 | """ 8 | Initialize the RMSNorm normalization layer. 9 | 10 | Args: 11 | dim (int): The dimension of the input tensor. 12 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 13 | 14 | Attributes: 15 | eps (float): A small value added to the denominator for numerical stability. 
16 | weight (nn.Parameter): Learnable scaling parameter. 17 | 18 | """ 19 | factory_kwargs = {'device': device, 'dtype': dtype} 20 | super().__init__() 21 | self.eps = eps 22 | if elementwise_affine: 23 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) 24 | 25 | def _norm(self, x): 26 | """ 27 | Apply the RMSNorm normalization to the input tensor. 28 | 29 | Args: 30 | x (torch.Tensor): The input tensor. 31 | 32 | Returns: 33 | torch.Tensor: The normalized tensor. 34 | 35 | """ 36 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 37 | 38 | def forward(self, x): 39 | """ 40 | Forward pass through the RMSNorm layer. 41 | 42 | Args: 43 | x (torch.Tensor): The input tensor. 44 | 45 | Returns: 46 | torch.Tensor: The output tensor after applying RMSNorm. 47 | 48 | """ 49 | output = self._norm(x.float()).type_as(x) 50 | if hasattr(self, "weight"): 51 | output = output * self.weight 52 | return output 53 | 54 | 55 | def get_norm_layer(norm_layer): 56 | """ 57 | Get the normalization layer. 58 | 59 | Args: 60 | norm_layer (str): The type of normalization layer. 61 | 62 | Returns: 63 | norm_layer (nn.Module): The normalization layer. 64 | """ 65 | if norm_layer == "layer": 66 | return nn.LayerNorm 67 | elif norm_layer == "rms": 68 | return RMSNorm 69 | else: 70 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") 71 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Usage and Legal Notices: 2 | 3 | Tencent is pleased to support the open source community by making Tencent HunyuanVideo-Foley available. 4 | 5 | Copyright (C) 2025 Tencent. All rights reserved. 6 | 7 | Tencent HunyuanVideo-Foley is licensed under TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT, which can be found in this repository called "LICENSE", except for the third-party components listed below. Tencent HunyuanVideo-Foley does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations. 8 | 9 | For avoidance of doubts, Tencent HunyuanVideo-Foley means the large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. 10 | 11 | 12 | Other dependencies and licenses: 13 | 14 | 15 | Open Source Software Licensed under the MIT License: 16 | -------------------------------------------------------------------- 17 | 1. 
syncformer 18 | Copyright (c) 2024 Vladimir Iashin 19 | 20 | 21 | Terms of the MIT License: 22 | -------------------------------------------------------------------- 23 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 24 | 25 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 26 | 27 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # ========================================== 132 | # Custom settings 133 | # ========================================== 134 | 135 | # For MacOS 136 | .DS_Store 137 | 138 | # For IDEs 139 | .idea/ 140 | .vscode/ 141 | pyrightconfig.json 142 | .cursorignore 143 | 144 | assets/ 145 | examples/ 146 | 147 | # For global settings 148 | __*/ 149 | **/my_* 150 | tmp*.* 151 | .my* 152 | # Model checkpoints 153 | *.pt 154 | *.ckpt 155 | *.pth 156 | *.safetensors 157 | 158 | 159 | CLAUDE.md 160 | -------------------------------------------------------------------------------- /tests/test_config_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for configuration utilities.""" 2 | 3 | import pytest 4 | import tempfile 5 | import yaml 6 | from pathlib import Path 7 | 8 | from hunyuanvideo_foley.utils.config_utils import AttributeDict, load_yaml 9 | 10 | 11 | class TestAttributeDict: 12 | """Test cases for AttributeDict class.""" 13 | 14 | def test_dict_access(self): 15 | """Test dictionary-style access.""" 16 | data = {"key1": "value1", "key2": {"nested": "value2"}} 17 | attr_dict = AttributeDict(data) 18 | 19 | assert attr_dict["key1"] == "value1" 20 | assert attr_dict["key2"]["nested"] == "value2" 21 | 22 | def test_attribute_access(self): 23 | """Test attribute-style access.""" 24 | data = {"key1": "value1", "key2": {"nested": "value2"}} 25 | attr_dict = AttributeDict(data) 26 | 27 | assert attr_dict.key1 == "value1" 28 | assert attr_dict.key2.nested == "value2" 29 | 30 | def test_list_handling(self): 31 | """Test list data handling.""" 32 | data = [1, 2, {"nested": "value"}] 33 | attr_dict = AttributeDict(data) 34 | 35 | assert attr_dict[0] == 1 36 | assert attr_dict[2].nested == "value" 37 | 38 | def test_keys_method(self): 39 | """Test keys() method.""" 40 | data = {"key1": "value1", "key2": "value2"} 41 | attr_dict = AttributeDict(data) 42 | 43 | keys = list(attr_dict.keys()) 44 | assert "key1" in keys 45 | assert "key2" in keys 46 | 47 | def test_get_method(self): 48 | """Test get() method.""" 49 | data = {"key1": "value1"} 50 | attr_dict = AttributeDict(data) 51 | 52 | assert attr_dict.get("key1") == "value1" 53 | assert attr_dict.get("nonexistent", "default") == "default" 54 | 55 | 56 | class TestLoadYaml: 57 | """Test cases for load_yaml function.""" 58 | 59 | def test_load_valid_yaml(self): 60 | """Test loading valid YAML file.""" 61 | data = {"model": {"name": "test_model", "params": {"lr": 0.001}}} 62 | 63 | with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: 64 | yaml.dump(data, f) 65 | yaml_path = f.name 66 | 67 | try: 68 | result = load_yaml(yaml_path) 69 | assert result.model.name == "test_model" 70 | assert result.model.params.lr == 0.001 71 | finally: 72 | Path(yaml_path).unlink() 73 | 74 | def test_load_nonexistent_file(self): 75 | """Test loading non-existent file.""" 76 | with pytest.raises(FileNotFoundError): 77 
| load_yaml("nonexistent.yaml") 78 | 79 | def test_load_invalid_yaml(self): 80 | """Test loading invalid YAML file.""" 81 | with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: 82 | f.write("invalid: yaml: content: [\n") # Invalid YAML 83 | yaml_path = f.name 84 | 85 | try: 86 | with pytest.raises(yaml.YAMLError): 87 | load_yaml(yaml_path) 88 | finally: 89 | Path(yaml_path).unlink() -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.0]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.mean( 45 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 46 | dim=[1, 2], 47 | ) 48 | else: 49 | return 0.5 * torch.mean( 50 | torch.pow(self.mean - other.mean, 2) / other.var 51 | + self.var / other.var 52 | - 1.0 53 | - self.logvar 54 | + other.logvar, 55 | dim=[1, 2], 56 | ) 57 | 58 | def nll(self, sample, dims=[1, 2]): 59 | if self.deterministic: 60 | return torch.Tensor([0.0]) 61 | logtwopi = np.log(2.0 * np.pi) 62 | return 0.5 * torch.sum( 63 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 64 | dim=dims, 65 | ) 66 | 67 | def mode(self): 68 | return self.mean 69 | 70 | 71 | def normal_kl(mean1, logvar1, mean2, logvar2): 72 | """ 73 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 74 | Compute the KL divergence between two gaussians. 75 | Shapes are automatically broadcasted, so batches can be compared to 76 | scalars, among other use cases. 77 | """ 78 | tensor = None 79 | for obj in (mean1, logvar1, mean2, logvar2): 80 | if isinstance(obj, torch.Tensor): 81 | tensor = obj 82 | break 83 | assert tensor is not None, "at least one argument must be a Tensor" 84 | 85 | # Force variances to be Tensors. Broadcasting helps convert scalars to 86 | # Tensors, but it does not work for torch.exp(). 
87 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] 88 | 89 | return 0.5 * ( 90 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 91 | ) 92 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/decode.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | 4 | import argbind 5 | import numpy as np 6 | import torch 7 | from audiotools import AudioSignal 8 | from tqdm import tqdm 9 | 10 | from ..model import DACFile 11 | from . import load_model 12 | 13 | warnings.filterwarnings("ignore", category=UserWarning) 14 | 15 | 16 | @argbind.bind(group="decode", positional=True, without_prefix=True) 17 | @torch.inference_mode() 18 | @torch.no_grad() 19 | def decode( 20 | input: str, 21 | output: str = "", 22 | weights_path: str = "", 23 | model_tag: str = "latest", 24 | model_bitrate: str = "8kbps", 25 | device: str = "cuda", 26 | model_type: str = "44khz", 27 | verbose: bool = False, 28 | ): 29 | """Decode audio from codes. 30 | 31 | Parameters 32 | ---------- 33 | input : str 34 | Path to input directory or file 35 | output : str, optional 36 | Path to output directory, by default "". 37 | If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 38 | weights_path : str, optional 39 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 40 | model_tag and model_type. 41 | model_tag : str, optional 42 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 43 | model_bitrate: str 44 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 45 | device : str, optional 46 | Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. 47 | model_type : str, optional 48 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
49 | """ 50 | generator = load_model( 51 | model_type=model_type, 52 | model_bitrate=model_bitrate, 53 | tag=model_tag, 54 | load_path=weights_path, 55 | ) 56 | generator.to(device) 57 | generator.eval() 58 | 59 | # Find all .dac files in input directory 60 | _input = Path(input) 61 | input_files = list(_input.glob("**/*.dac")) 62 | 63 | # If input is a .dac file, add it to the list 64 | if _input.suffix == ".dac": 65 | input_files.append(_input) 66 | 67 | # Create output directory 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(input_files)), desc=f"Decoding files"): 72 | # Load file 73 | artifact = DACFile.load(input_files[i]) 74 | 75 | # Reconstruct audio from codes 76 | recons = generator.decompress(artifact, verbose=verbose) 77 | 78 | # Compute output path 79 | relative_path = input_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = input_files[i] 84 | output_name = relative_path.with_suffix(".wav").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | # Write to file 89 | recons.write(output_path) 90 | 91 | 92 | if __name__ == "__main__": 93 | args = argbind.parse_args() 94 | with argbind.scope(args): 95 | decode() 96 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/media_utils.py: -------------------------------------------------------------------------------- 1 | """Media utilities for audio/video processing.""" 2 | 3 | import os 4 | import subprocess 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | from loguru import logger 9 | 10 | 11 | class MediaProcessingError(Exception): 12 | """Exception raised for media processing errors.""" 13 | pass 14 | 15 | 16 | def merge_audio_video( 17 | audio_path: str, 18 | video_path: str, 19 | output_path: str, 20 | overwrite: bool = True, 21 | quality: str = "high" 22 | ) -> str: 23 | """ 24 | Merge audio and video files using ffmpeg. 
25 | 26 | Args: 27 | audio_path: Path to input audio file 28 | video_path: Path to input video file 29 | output_path: Path for output video file 30 | overwrite: Whether to overwrite existing output file 31 | quality: Quality setting ('high', 'medium', 'low') 32 | 33 | Returns: 34 | Path to the output file 35 | 36 | Raises: 37 | MediaProcessingError: If input files don't exist or ffmpeg fails 38 | FileNotFoundError: If ffmpeg is not installed 39 | """ 40 | # Validate input files 41 | if not os.path.exists(audio_path): 42 | raise MediaProcessingError(f"Audio file not found: {audio_path}") 43 | if not os.path.exists(video_path): 44 | raise MediaProcessingError(f"Video file not found: {video_path}") 45 | 46 | # Create output directory if needed 47 | output_dir = Path(output_path).parent 48 | output_dir.mkdir(parents=True, exist_ok=True) 49 | 50 | # Quality settings 51 | quality_settings = { 52 | "high": ["-b:a", "192k"], 53 | "medium": ["-b:a", "128k"], 54 | "low": ["-b:a", "96k"] 55 | } 56 | 57 | # Build ffmpeg command 58 | ffmpeg_command = [ 59 | "ffmpeg", 60 | "-i", video_path, 61 | "-i", audio_path, 62 | "-c:v", "copy", 63 | "-c:a", "aac", 64 | "-ac", "2", 65 | "-af", "pan=stereo|c0=c0|c1=c0", 66 | "-map", "0:v:0", 67 | "-map", "1:a:0", 68 | *quality_settings.get(quality, quality_settings["high"]), 69 | ] 70 | 71 | if overwrite: 72 | ffmpeg_command.append("-y") 73 | 74 | ffmpeg_command.append(output_path) 75 | 76 | try: 77 | logger.info(f"Merging audio '{audio_path}' with video '{video_path}'") 78 | process = subprocess.Popen( 79 | ffmpeg_command, 80 | stdout=subprocess.PIPE, 81 | stderr=subprocess.PIPE, 82 | text=True 83 | ) 84 | stdout, stderr = process.communicate() 85 | 86 | if process.returncode != 0: 87 | error_msg = f"FFmpeg failed with return code {process.returncode}: {stderr}" 88 | logger.error(error_msg) 89 | raise MediaProcessingError(error_msg) 90 | else: 91 | logger.info(f"Successfully merged video saved to: {output_path}") 92 | 93 | except FileNotFoundError: 94 | raise FileNotFoundError( 95 | "ffmpeg not found. Please install ffmpeg: " 96 | "https://ffmpeg.org/download.html" 97 | ) 98 | except Exception as e: 99 | raise MediaProcessingError(f"Unexpected error during media processing: {e}") 100 | 101 | return output_path 102 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/encode.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from pathlib import Path 4 | 5 | import argbind 6 | import numpy as np 7 | import torch 8 | from audiotools import AudioSignal 9 | from audiotools.core import util 10 | from tqdm import tqdm 11 | 12 | from . import load_model 13 | 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | 16 | 17 | @argbind.bind(group="encode", positional=True, without_prefix=True) 18 | @torch.inference_mode() 19 | @torch.no_grad() 20 | def encode( 21 | input: str, 22 | output: str = "", 23 | weights_path: str = "", 24 | model_tag: str = "latest", 25 | model_bitrate: str = "8kbps", 26 | n_quantizers: int = None, 27 | device: str = "cuda", 28 | model_type: str = "44khz", 29 | win_duration: float = 5.0, 30 | verbose: bool = False, 31 | ): 32 | """Encode audio files in input path to .dac format. 33 | 34 | Parameters 35 | ---------- 36 | input : str 37 | Path to input audio file or directory 38 | output : str, optional 39 | Path to output directory, by default "". 
If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 40 | weights_path : str, optional 41 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 42 | model_tag and model_type. 43 | model_tag : str, optional 44 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 45 | model_bitrate: str 46 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 47 | n_quantizers : int, optional 48 | Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. 49 | device : str, optional 50 | Device to use, by default "cuda" 51 | model_type : str, optional 52 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 53 | """ 54 | generator = load_model( 55 | model_type=model_type, 56 | model_bitrate=model_bitrate, 57 | tag=model_tag, 58 | load_path=weights_path, 59 | ) 60 | generator.to(device) 61 | generator.eval() 62 | kwargs = {"n_quantizers": n_quantizers} 63 | 64 | # Find all audio files in input path 65 | input = Path(input) 66 | audio_files = util.find_audio(input) 67 | 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(audio_files)), desc="Encoding files"): 72 | # Load file 73 | signal = AudioSignal(audio_files[i]) 74 | 75 | # Encode audio to .dac format 76 | artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) 77 | 78 | # Compute output path 79 | relative_path = audio_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = audio_files[i] 84 | output_name = relative_path.with_suffix(".dac").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | artifact.save(output_path) 89 | 90 | 91 | if __name__ == "__main__": 92 | args = argbind.parse_args() 93 | with argbind.scope(args): 94 | encode() 95 | -------------------------------------------------------------------------------- /tests/test_media_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for media utilities.""" 2 | 3 | import pytest 4 | import tempfile 5 | import os 6 | from unittest.mock import patch, MagicMock 7 | 8 | from hunyuanvideo_foley.utils.media_utils import merge_audio_video, MediaProcessingError 9 | 10 | 11 | class TestMergeAudioVideo: 12 | """Test cases for merge_audio_video function.""" 13 | 14 | def test_invalid_audio_path(self): 15 | """Test with non-existent audio file.""" 16 | with pytest.raises(MediaProcessingError, match="Audio file not found"): 17 | merge_audio_video("nonexistent.wav", "video.mp4", "output.mp4") 18 | 19 | def test_invalid_video_path(self): 20 | """Test with non-existent video file.""" 21 | with tempfile.NamedTemporaryFile(suffix='.wav') as audio_file: 22 | with pytest.raises(MediaProcessingError, match="Video file not found"): 23 | merge_audio_video(audio_file.name, "nonexistent.mp4", "output.mp4") 24 | 25 | @patch('subprocess.Popen') 26 | def test_successful_merge(self, mock_popen): 27 | """Test successful merge operation.""" 28 | # Create temporary files 29 | with tempfile.NamedTemporaryFile(suffix='.wav') as audio_file, \ 30 | tempfile.NamedTemporaryFile(suffix='.mp4') as video_file, \ 31 | 
tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as output_file: 32 | 33 | # Mock successful subprocess 34 | mock_process = MagicMock() 35 | mock_process.returncode = 0 36 | mock_process.communicate.return_value = ("", "") 37 | mock_popen.return_value = mock_process 38 | 39 | result = merge_audio_video( 40 | audio_file.name, 41 | video_file.name, 42 | output_file.name 43 | ) 44 | 45 | assert result == output_file.name 46 | mock_popen.assert_called_once() 47 | 48 | # Cleanup 49 | os.unlink(output_file.name) 50 | 51 | @patch('subprocess.Popen') 52 | def test_ffmpeg_failure(self, mock_popen): 53 | """Test ffmpeg failure handling.""" 54 | # Create temporary files 55 | with tempfile.NamedTemporaryFile(suffix='.wav') as audio_file, \ 56 | tempfile.NamedTemporaryFile(suffix='.mp4') as video_file: 57 | 58 | # Mock failed subprocess 59 | mock_process = MagicMock() 60 | mock_process.returncode = 1 61 | mock_process.communicate.return_value = ("", "FFmpeg error") 62 | mock_popen.return_value = mock_process 63 | 64 | with pytest.raises(MediaProcessingError, match="FFmpeg failed"): 65 | merge_audio_video( 66 | audio_file.name, 67 | video_file.name, 68 | "output.mp4" 69 | ) 70 | 71 | @patch('subprocess.Popen', side_effect=FileNotFoundError) 72 | def test_ffmpeg_not_found(self, mock_popen): 73 | """Test ffmpeg not found error.""" 74 | with tempfile.NamedTemporaryFile(suffix='.wav') as audio_file, \ 75 | tempfile.NamedTemporaryFile(suffix='.mp4') as video_file: 76 | 77 | with pytest.raises(FileNotFoundError, match="ffmpeg not found"): 78 | merge_audio_video( 79 | audio_file.name, 80 | video_file.name, 81 | "output.mp4" 82 | ) -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import argbind 4 | from audiotools import ml 5 | 6 | from ..model import DAC 7 | Accelerator = ml.Accelerator 8 | 9 | __MODEL_LATEST_TAGS__ = { 10 | ("44khz", "8kbps"): "0.0.1", 11 | ("24khz", "8kbps"): "0.0.4", 12 | ("16khz", "8kbps"): "0.0.5", 13 | ("44khz", "16kbps"): "1.0.0", 14 | } 15 | 16 | __MODEL_URLS__ = { 17 | ( 18 | "44khz", 19 | "0.0.1", 20 | "8kbps", 21 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", 22 | ( 23 | "24khz", 24 | "0.0.4", 25 | "8kbps", 26 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", 27 | ( 28 | "16khz", 29 | "0.0.5", 30 | "8kbps", 31 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", 32 | ( 33 | "44khz", 34 | "1.0.0", 35 | "16kbps", 36 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", 37 | } 38 | 39 | 40 | @argbind.bind(group="download", positional=True, without_prefix=True) 41 | def download( 42 | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" 43 | ): 44 | """ 45 | Function that downloads the weights file from URL if a local cache is not found. 46 | 47 | Parameters 48 | ---------- 49 | model_type : str 50 | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". 51 | model_bitrate: str 52 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 53 | Only 44khz model supports 16kbps. 54 | tag : str 55 | The tag of the model to download. Defaults to "latest". 
56 | 57 | Returns 58 | ------- 59 | Path 60 | Directory path required to load model via audiotools. 61 | """ 62 | model_type = model_type.lower() 63 | tag = tag.lower() 64 | 65 | assert model_type in [ 66 | "44khz", 67 | "24khz", 68 | "16khz", 69 | ], "model_type must be one of '44khz', '24khz', or '16khz'" 70 | 71 | assert model_bitrate in [ 72 | "8kbps", 73 | "16kbps", 74 | ], "model_bitrate must be one of '8kbps', or '16kbps'" 75 | 76 | if tag == "latest": 77 | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] 78 | 79 | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) 80 | 81 | if download_link is None: 82 | raise ValueError( 83 | f"Could not find model with tag {tag} and model type {model_type}" 84 | ) 85 | 86 | local_path = ( 87 | Path.home() 88 | / ".cache" 89 | / "descript" 90 | / "dac" 91 | / f"weights_{model_type}_{model_bitrate}_{tag}.pth" 92 | ) 93 | if not local_path.exists(): 94 | local_path.parent.mkdir(parents=True, exist_ok=True) 95 | 96 | # Download the model 97 | import requests 98 | 99 | response = requests.get(download_link) 100 | 101 | if response.status_code != 200: 102 | raise ValueError( 103 | f"Could not download model. Received response code {response.status_code}" 104 | ) 105 | local_path.write_bytes(response.content) 106 | 107 | return local_path 108 | 109 | 110 | def load_model( 111 | model_type: str = "44khz", 112 | model_bitrate: str = "8kbps", 113 | tag: str = "latest", 114 | load_path: str = None, 115 | ): 116 | if not load_path: 117 | load_path = download( 118 | model_type=model_type, model_bitrate=model_bitrate, tag=tag 119 | ) 120 | generator = DAC.load(load_path) 121 | return generator 122 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | """Configuration utilities for the HunyuanVideo-Foley project.""" 2 | 3 | import yaml 4 | from pathlib import Path 5 | from typing import Any, Dict, List, Union 6 | 7 | class AttributeDict: 8 | 9 | def __init__(self, data: Union[Dict, List, Any]): 10 | if isinstance(data, dict): 11 | for key, value in data.items(): 12 | if isinstance(value, (dict, list)): 13 | value = AttributeDict(value) 14 | setattr(self, self._sanitize_key(key), value) 15 | elif isinstance(data, list): 16 | self._list = [AttributeDict(item) if isinstance(item, (dict, list)) else item 17 | for item in data] 18 | else: 19 | self._value = data 20 | 21 | def _sanitize_key(self, key: str) -> str: 22 | import re 23 | sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', str(key)) 24 | if sanitized[0].isdigit(): 25 | sanitized = f'_{sanitized}' 26 | return sanitized 27 | 28 | def __getitem__(self, key): 29 | if hasattr(self, '_list'): 30 | return self._list[key] 31 | return getattr(self, self._sanitize_key(key)) 32 | 33 | def __setitem__(self, key, value): 34 | if hasattr(self, '_list'): 35 | self._list[key] = value 36 | else: 37 | setattr(self, self._sanitize_key(key), value) 38 | 39 | def __iter__(self): 40 | if hasattr(self, '_list'): 41 | return iter(self._list) 42 | return iter(self.__dict__.keys()) 43 | 44 | def __len__(self): 45 | if hasattr(self, '_list'): 46 | return len(self._list) 47 | return len(self.__dict__) 48 | 49 | def get(self, key, default=None): 50 | try: 51 | return self[key] 52 | except (KeyError, AttributeError, IndexError): 53 | return default 54 | 55 | def keys(self): 56 | if hasattr(self, '_list'): 57 | return range(len(self._list)) 58 | elif 
hasattr(self, '_value'): 59 | return [] 60 | else: 61 | return [key for key in self.__dict__.keys() if not key.startswith('_')] 62 | 63 | def values(self): 64 | if hasattr(self, '_list'): 65 | return self._list 66 | elif hasattr(self, '_value'): 67 | return [self._value] 68 | else: 69 | return [value for key, value in self.__dict__.items() if not key.startswith('_')] 70 | 71 | def items(self): 72 | if hasattr(self, '_list'): 73 | return enumerate(self._list) 74 | elif hasattr(self, '_value'): 75 | return [] 76 | else: 77 | return [(key, value) for key, value in self.__dict__.items() if not key.startswith('_')] 78 | 79 | def __repr__(self): 80 | if hasattr(self, '_list'): 81 | return f"AttributeDict({self._list})" 82 | elif hasattr(self, '_value'): 83 | return f"AttributeDict({self._value})" 84 | return f"AttributeDict({dict(self.__dict__)})" 85 | 86 | def to_dict(self) -> Union[Dict, List, Any]: 87 | if hasattr(self, '_list'): 88 | return [item.to_dict() if isinstance(item, AttributeDict) else item 89 | for item in self._list] 90 | elif hasattr(self, '_value'): 91 | return self._value 92 | else: 93 | result = {} 94 | for key, value in self.__dict__.items(): 95 | if isinstance(value, AttributeDict): 96 | result[key] = value.to_dict() 97 | else: 98 | result[key] = value 99 | return result 100 | 101 | def load_yaml(file_path: str, encoding: str = 'utf-8') -> AttributeDict: 102 | try: 103 | with open(file_path, 'r', encoding=encoding) as file: 104 | data = yaml.safe_load(file) 105 | return AttributeDict(data) 106 | except FileNotFoundError: 107 | raise FileNotFoundError(f"YAML file not found: {file_path}") 108 | except yaml.YAMLError as e: 109 | raise yaml.YAMLError(f"YAML format error: {e}") 110 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/helper.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | from itertools import repeat 3 | import importlib 4 | import yaml 5 | import time 6 | 7 | def default(value, default_val): 8 | return default_val if value is None else value 9 | 10 | 11 | def default_dtype(value, default_val): 12 | if value is not None: 13 | assert isinstance(value, type(default_val)), f"Expect {type(default_val)}, got {type(value)}." 14 | return value 15 | return default_val 16 | 17 | 18 | def repeat_interleave(lst, num_repeats): 19 | return [item for item in lst for _ in range(num_repeats)] 20 | 21 | 22 | def _ntuple(n): 23 | def parse(x): 24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 25 | x = tuple(x) 26 | if len(x) == 1: 27 | x = tuple(repeat(x[0], n)) 28 | return x 29 | return tuple(repeat(x, n)) 30 | 31 | return parse 32 | 33 | 34 | to_1tuple = _ntuple(1) 35 | to_2tuple = _ntuple(2) 36 | to_3tuple = _ntuple(3) 37 | to_4tuple = _ntuple(4) 38 | 39 | 40 | def as_tuple(x): 41 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 42 | return tuple(x) 43 | if x is None or isinstance(x, (int, float, str)): 44 | return (x,) 45 | else: 46 | raise ValueError(f"Unknown type {type(x)}") 47 | 48 | 49 | def as_list_of_2tuple(x): 50 | x = as_tuple(x) 51 | if len(x) == 1: 52 | x = (x[0], x[0]) 53 | assert len(x) % 2 == 0, f"Expect even length, got {len(x)}." 
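    # Pair consecutive elements into 2-tuples: (x[0], x[1]), (x[2], x[3]), ...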
54 |     lst = []
55 |     for i in range(0, len(x), 2):
56 |         lst.append((x[i], x[i + 1]))
57 |     return lst
58 | 
59 | 
60 | def find_multiple(n: int, k: int) -> int:
61 |     assert k > 0
62 |     if n % k == 0:
63 |         return n
64 |     return n - (n % k) + k
65 | 
66 | 
67 | def merge_dicts(dict1, dict2):
68 |     for key, value in dict2.items():
69 |         if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict):
70 |             merge_dicts(dict1[key], value)
71 |         else:
72 |             dict1[key] = value
73 |     return dict1
74 | 
75 | 
76 | def merge_yaml_files(file_list):
77 |     merged_config = {}
78 | 
79 |     for file in file_list:
80 |         with open(file, "r", encoding="utf-8") as f:
81 |             config = yaml.safe_load(f)
82 |             if config:
83 |                 # Remove the first level
84 |                 for key, value in config.items():
85 |                     if isinstance(value, dict):
86 |                         merged_config = merge_dicts(merged_config, value)
87 |                     else:
88 |                         merged_config[key] = value
89 | 
90 |     return merged_config
91 | 
92 | 
93 | def merge_dict(file_list):
94 |     merged_config = {}
95 | 
96 |     for file in file_list:
97 |         with open(file, "r", encoding="utf-8") as f:
98 |             config = yaml.safe_load(f)
99 |             if config:
100 |                 merged_config = merge_dicts(merged_config, config)
101 | 
102 |     return merged_config
103 | 
104 | 
105 | def get_obj_from_str(string, reload=False):
106 |     module, cls = string.rsplit(".", 1)
107 |     if reload:
108 |         module_imp = importlib.import_module(module)
109 |         importlib.reload(module_imp)
110 |     return getattr(importlib.import_module(module, package=None), cls)
111 | 
112 | 
113 | def readable_time(seconds):
114 |     """Convert a duration in seconds to a readable format: DD Days, HH Hours, MM Minutes, SS Seconds"""
115 |     seconds = int(seconds)
116 |     days, seconds = divmod(seconds, 86400)
117 |     hours, seconds = divmod(seconds, 3600)
118 |     minutes, seconds = divmod(seconds, 60)
119 |     if days > 0:
120 |         return f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds"
121 |     if hours > 0:
122 |         return f"{hours} Hours, {minutes} Minutes, {seconds} Seconds"
123 |     if minutes > 0:
124 |         return f"{minutes} Minutes, {seconds} Seconds"
125 |     return f"{seconds} Seconds"
126 | 
127 | 
128 | def get_obj_from_cfg(cfg, reload=False):
129 |     if isinstance(cfg, str):
130 |         return get_obj_from_str(cfg, reload)
131 |     elif isinstance(cfg, (list, tuple,)):
132 |         return tuple([get_obj_from_str(c, reload) for c in cfg])
133 |     else:
134 |         raise NotImplementedError(f"Not implemented for {type(cfg)}.")
135 | 
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation Guide - HunyuanVideo-Foley
2 | 
3 | This document describes how to install and use HunyuanVideo-Foley as a Python package.
4 | 
5 | ## Installation Options
6 | 
7 | ### Option 1: Install from source (recommended)
8 | 
9 | ```bash
10 | # Clone the repository
11 | git clone https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley
12 | cd HunyuanVideo-Foley
13 | 
14 | # Install the package (development mode)
15 | pip install -e .
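# (editable mode: local code changes take effect without reinstalling)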
16 | 
17 | # Or install with all optional dependencies
18 | pip install -e .[all]
19 | ```
20 | 
21 | ### Option 2: Install directly from GitHub
22 | 
23 | ```bash
24 | pip install git+https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley.git
25 | ```
26 | 
27 | ### Option 3: Build and install a wheel
28 | 
29 | ```bash
30 | # From the project root
31 | python setup.py bdist_wheel
32 | pip install dist/hunyuanvideo_foley-1.0.0-py3-none-any.whl
33 | ```
34 | 
35 | ## Special Dependencies
36 | 
37 | Some dependencies are not available on PyPI and must be installed separately:
38 | 
39 | ```bash
40 | # Install audiotools (required)
41 | pip install git+https://github.com/descriptinc/audiotools
42 | 
43 | # Install the pinned transformers build (SigLIP2 support)
44 | pip install git+https://github.com/huggingface/transformers@v4.49.0-SigLIP-2
45 | ```
46 | 
47 | ## Optional Dependencies
48 | 
49 | ```bash
50 | # Development dependencies
51 | pip install hunyuanvideo-foley[dev]
52 | 
53 | # Test dependencies
54 | pip install hunyuanvideo-foley[test]
55 | 
56 | # Gradio interface dependencies
57 | pip install hunyuanvideo-foley[gradio]
58 | 
59 | # All optional dependencies
60 | pip install hunyuanvideo-foley[all]
61 | ```
62 | 
63 | ## Verifying the Installation
64 | 
65 | ```bash
66 | # Check that the package is installed correctly
67 | python -c "import hunyuanvideo_foley; print(hunyuanvideo_foley.__version__)"
68 | 
69 | # Check the command-line tool
70 | hunyuanvideo-foley --help
71 | ```
72 | 
73 | ## Usage
74 | 
75 | ### 1. As a Python package
76 | 
77 | ```python
78 | import hunyuanvideo_foley as hvf
79 | 
80 | # Load the model
81 | model_dict, cfg = hvf.load_model(
82 |     model_path="path/to/model",
83 |     config_path="configs/hunyuanvideo-foley-xxl.yaml"
84 | )
85 | 
86 | # Process features
87 | visual_feats, text_feats, audio_len = hvf.feature_process(
88 |     video_path="video.mp4",
89 |     prompt="footsteps on gravel",
90 |     model_dict=model_dict,
91 |     cfg=cfg
92 | )
93 | 
94 | # Generate audio
95 | audio, sample_rate = hvf.denoise_process(
96 |     visual_feats, text_feats, audio_len,
97 |     model_dict, cfg
98 | )
99 | ```
100 | 
101 | ### 2. Command-line tool
102 | 
103 | ```bash
104 | # Process a single video
105 | hunyuanvideo-foley \
106 |     --model_path ./pretrained_models \
107 |     --single_video video.mp4 \
108 |     --single_prompt "footsteps on gravel" \
109 |     --output_dir ./outputs
110 | 
111 | # Batch processing
112 | hunyuanvideo-foley \
113 |     --model_path ./pretrained_models \
114 |     --csv_path batch_videos.csv \
115 |     --output_dir ./outputs
116 | 
117 | # Launch the Gradio interface
118 | hunyuanvideo-foley --gradio --model_path ./pretrained_models
119 | ```
120 | 
121 | ### 3. Original scripts (backwards compatible)
122 | 
123 | ```bash
124 | # Use the original infer.py script
125 | python infer.py --model_path ./pretrained_models --single_video video.mp4 --single_prompt "audio description"
126 | 
127 | # Launch the Gradio app
128 | export HIFI_FOLEY_MODEL_PATH=./pretrained_models
129 | python gradio_app.py
130 | ```
131 | 
132 | ## Development Setup
133 | 
134 | If you want to contribute to development:
135 | 
136 | ```bash
137 | # Clone the project
138 | git clone https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley
139 | cd HunyuanVideo-Foley
140 | 
141 | # Install the development version
142 | pip install -e .[dev]
143 | 
144 | # Install pre-commit hooks
145 | pre-commit install
146 | 
147 | # Run the tests
148 | python -m pytest
149 | 
150 | # Code formatting
151 | black --line-length 120 .
152 | isort --profile black .
153 | 
154 | # Type checking
155 | mypy --ignore-missing-imports .
156 | ```
157 | 
158 | ## System Requirements
159 | 
160 | - **Python**: 3.8+
161 | - **Operating system**: Linux (primary support), macOS, Windows
162 | - **GPU memory**: ≥24 GB VRAM recommended (e.g. RTX 3090/4090)
163 | - **CUDA version**: 12.4 or 11.8 (recommended)
164 | 
165 | ## Troubleshooting
166 | 
167 | ### Common Issues
168 | 
169 | 1. **ImportError: No module named 'audiotools'**
170 |    ```bash
171 |    pip install git+https://github.com/descriptinc/audiotools
172 |    ```
173 | 
174 | 2. **CUDA out of memory**
175 |    - Use a smaller batch size
176 |    - Make sure the GPU has enough VRAM (24 GB+ recommended)
177 | 
178 | 3.
**transformers version issue**
179 |    ```bash
180 |    pip install git+https://github.com/huggingface/transformers@v4.49.0-SigLIP-2
181 |    ```
182 | 
183 | ### Getting Help
184 | 
185 | - Project README: [GitHub](https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley)
186 | - Report issues: [GitHub Issues](https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley/issues)
187 | - Paper: [arXiv:2508.16930](https://arxiv.org/abs/2508.16930)
188 | 
189 | ## Model Download
190 | 
191 | ```bash
192 | # Using the HuggingFace Hub
193 | git clone https://huggingface.co/tencent/HunyuanVideo-Foley
194 | 
195 | # Or with huggingface-cli
196 | huggingface-cli download tencent/HunyuanVideo-Foley
197 | ```
198 | 
199 | ## Configuration Files
200 | 
201 | After installation, the configuration files are located in:
202 | - the `hunyuanvideo_foley/configs/` directory
203 | - default config: `configs/hunyuanvideo-foley-xxl.yaml`
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/utils.py:
--------------------------------------------------------------------------------
1 | from hashlib import md5
2 | from pathlib import Path
3 | import subprocess
4 | 
5 | import requests
6 | from tqdm import tqdm
7 | 
8 | PARENT_LINK = "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a"
9 | FNAME2LINK = {
10 |     # S3: Synchability: AudioSet (run 2)
11 |     "24-01-22T20-34-52.pt": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt",
12 |     "cfg-24-01-22T20-34-52.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml",
13 |     # S2: Synchformer: AudioSet (run 2)
14 |     "24-01-04T16-39-21.pt": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt",
15 |     "cfg-24-01-04T16-39-21.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml",
16 |     # S2: Synchformer: AudioSet (run 1)
17 |     "23-08-28T11-23-23.pt": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt",
18 |     "cfg-23-08-28T11-23-23.yaml": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml",
19 |     # S2: Synchformer: LRS3 (run 2)
20 |     "23-12-23T18-33-57.pt": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt",
21 |     "cfg-23-12-23T18-33-57.yaml": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml",
22 |     # S2: Synchformer: VGS (run 2)
23 |     "24-01-02T10-00-53.pt": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt",
24 |     "cfg-24-01-02T10-00-53.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml",
25 |     # SparseSync: ft VGGSound-Full
26 |     "22-09-21T21-00-52.pt": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt",
27 |     "cfg-22-09-21T21-00-52.yaml": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml",
28 |     # SparseSync: ft VGGSound-Sparse
29 |     "22-07-28T15-49-45.pt": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt",
30 |     "cfg-22-07-28T15-49-45.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml",
31 |     # SparseSync: only pt on LRS3
32 |     "22-07-13T22-25-49.pt": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt",
33 |     "cfg-22-07-13T22-25-49.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml",
34 |     # SparseSync: feature extractors
35 |     "ResNetAudio-22-08-04T09-51-04.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt",  # 2s
36 |     "ResNetAudio-22-08-03T23-14-49.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt",  # 3s
37 |     "ResNetAudio-22-08-03T23-14-28.pt":
f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt", # 4s 38 | "ResNetAudio-22-06-24T08-10-33.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt", # 5s 39 | "ResNetAudio-22-06-24T17-31-07.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt", # 6s 40 | "ResNetAudio-22-06-24T23-57-11.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt", # 7s 41 | "ResNetAudio-22-06-25T04-35-42.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt", # 8s 42 | } 43 | 44 | 45 | def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024): 46 | """Checks if file exists, if not downloads it from the link to the path""" 47 | path = Path(path) 48 | if not path.exists(): 49 | path.parent.mkdir(exist_ok=True, parents=True) 50 | link = fname2link.get(path.name, None) 51 | if link is None: 52 | raise ValueError( 53 | f"Cant find the checkpoint file: {path}.", f"Please download it manually and ensure the path exists." 54 | ) 55 | with requests.get(fname2link[path.name], stream=True) as r: 56 | total_size = int(r.headers.get("content-length", 0)) 57 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 58 | with open(path, "wb") as f: 59 | for data in r.iter_content(chunk_size=chunk_size): 60 | if data: 61 | f.write(data) 62 | pbar.update(chunk_size) 63 | 64 | 65 | def which_ffmpeg() -> str: 66 | """Determines the path to ffmpeg library 67 | Returns: 68 | str -- path to the library 69 | """ 70 | result = subprocess.run(["which", "ffmpeg"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 71 | ffmpeg_path = result.stdout.decode("utf-8").replace("\n", "") 72 | return ffmpeg_path 73 | 74 | 75 | def get_md5sum(path): 76 | hash_md5 = md5() 77 | with open(path, "rb") as f: 78 | for chunk in iter(lambda: f.read(4096 * 8), b""): 79 | hash_md5.update(chunk) 80 | md5sum = hash_md5.hexdigest() 81 | return md5sum 82 | 83 | 84 | class Config: 85 | def __init__(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | setattr(self, k, v) 88 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/embed_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ...utils.helper import to_2tuple, to_1tuple 6 | 7 | class PatchEmbed1D(nn.Module): 8 | """1D Audio to Patch Embedding 9 | 10 | A convolution based approach to patchifying a 1D audio w/ embedding projection. 
11 | 12 | Based on the impl in https://github.com/google-research/vision_transformer 13 | 14 | Hacked together by / Copyright 2020 Ross Wightman 15 | """ 16 | 17 | def __init__( 18 | self, 19 | patch_size=1, 20 | in_chans=768, 21 | embed_dim=768, 22 | norm_layer=None, 23 | flatten=True, 24 | bias=True, 25 | dtype=None, 26 | device=None, 27 | ): 28 | factory_kwargs = {"dtype": dtype, "device": device} 29 | super().__init__() 30 | patch_size = to_1tuple(patch_size) 31 | self.patch_size = patch_size 32 | self.flatten = flatten 33 | 34 | self.proj = nn.Conv1d( 35 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs 36 | ) 37 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) 38 | if bias: 39 | nn.init.zeros_(self.proj.bias) 40 | 41 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 42 | 43 | def forward(self, x): 44 | assert ( 45 | x.shape[2] % self.patch_size[0] == 0 46 | ), f"The patch_size of {self.patch_size[0]} must be divisible by the token number ({x.shape[2]}) of x." 47 | 48 | x = self.proj(x) 49 | if self.flatten: 50 | x = x.transpose(1, 2) # BCN -> BNC 51 | x = self.norm(x) 52 | return x 53 | 54 | 55 | class ConditionProjection(nn.Module): 56 | """ 57 | Projects condition embeddings. Also handles dropout for classifier-free guidance. 58 | 59 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py 60 | """ 61 | 62 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): 63 | factory_kwargs = {'dtype': dtype, 'device': device} 64 | super().__init__() 65 | self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs) 66 | self.act_1 = act_layer() 67 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs) 68 | 69 | def forward(self, caption): 70 | hidden_states = self.linear_1(caption) 71 | hidden_states = self.act_1(hidden_states) 72 | hidden_states = self.linear_2(hidden_states) 73 | return hidden_states 74 | 75 | 76 | def timestep_embedding(t, dim, max_period=10000): 77 | """ 78 | Create sinusoidal timestep embeddings. 79 | 80 | Args: 81 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. 82 | dim (int): the dimension of the output. 83 | max_period (int): controls the minimum frequency of the embeddings. 84 | 85 | Returns: 86 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. 87 | 88 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 89 | """ 90 | half = dim // 2 91 | freqs = torch.exp( 92 | -math.log(max_period) 93 | * torch.arange(start=0, end=half, dtype=torch.float32) 94 | / half 95 | ).to(device=t.device) 96 | args = t[:, None].float() * freqs[None] 97 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 98 | if dim % 2: 99 | embedding = torch.cat( 100 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 101 | ) 102 | return embedding 103 | 104 | 105 | class TimestepEmbedder(nn.Module): 106 | """ 107 | Embeds scalar timesteps into vector representations. 
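    Sinusoidal frequency features of size `frequency_embedding_size` are computed first and then projected by a two-layer MLP to `out_size` (which defaults to `hidden_size`).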
108 | """ 109 | def __init__(self, 110 | hidden_size, 111 | act_layer, 112 | frequency_embedding_size=256, 113 | max_period=10000, 114 | out_size=None, 115 | dtype=None, 116 | device=None 117 | ): 118 | factory_kwargs = {'dtype': dtype, 'device': device} 119 | super().__init__() 120 | self.frequency_embedding_size = frequency_embedding_size 121 | self.max_period = max_period 122 | if out_size is None: 123 | out_size = hidden_size 124 | 125 | self.mlp = nn.Sequential( 126 | nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), 127 | act_layer(), 128 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), 129 | ) 130 | nn.init.normal_(self.mlp[0].weight, std=0.02) 131 | nn.init.normal_(self.mlp[2].weight, std=0.02) 132 | 133 | def forward(self, t): 134 | t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) 135 | t_emb = self.mlp(t_freq) 136 | return t_emb 137 | -------------------------------------------------------------------------------- /DEVELOPMENT.md: -------------------------------------------------------------------------------- 1 | # Development Guide 2 | 3 | This document provides guidelines for developing and contributing to the HunyuanVideo-Foley project. 4 | 5 | ## Code Style and Quality 6 | 7 | ### Code Formatting 8 | 9 | We use the following tools to maintain consistent code style: 10 | 11 | - **Black**: Code formatter with 120 character line length 12 | - **isort**: Import sorter compatible with Black 13 | - **flake8**: Linting and style checking 14 | - **mypy**: Static type checking 15 | 16 | ### Pre-commit Hooks 17 | 18 | Install pre-commit hooks to automatically format code before commits: 19 | 20 | ```bash 21 | pip install pre-commit 22 | pre-commit install 23 | ``` 24 | 25 | ### Manual Code Formatting 26 | 27 | Format code manually: 28 | 29 | ```bash 30 | # Format all Python files 31 | black --line-length 120 . 32 | 33 | # Sort imports 34 | isort --profile black --line-length 120 . 35 | 36 | # Check code style 37 | flake8 --max-line-length 120 38 | 39 | # Type checking 40 | mypy --ignore-missing-imports . 
41 | ``` 42 | 43 | ## Project Structure 44 | 45 | ``` 46 | hunyuanvideo_foley/ 47 | ├── models/ # Model implementations 48 | │ ├── hifi_foley.py # Main model 49 | │ ├── nn/ # Neural network layers 50 | │ ├── dac_vae/ # Audio VAE 51 | │ └── synchformer/ # Synchronization model 52 | ├── utils/ # Utilities 53 | │ ├── config_utils.py # Configuration handling 54 | │ ├── feature_utils.py # Feature extraction 55 | │ ├── model_utils.py # Model loading/saving 56 | │ └── media_utils.py # Audio/video processing 57 | └── constants.py # Project constants 58 | ``` 59 | 60 | ## Coding Standards 61 | 62 | ### Error Handling 63 | 64 | - Use custom exceptions for domain-specific errors 65 | - Always validate inputs at function boundaries 66 | - Log errors with appropriate levels (ERROR, WARNING, INFO) 67 | - Provide helpful error messages to users 68 | 69 | ### Type Hints 70 | 71 | - Add type hints to all function parameters and return values 72 | - Use `Optional[Type]` for nullable parameters 73 | - Import types from `typing` module 74 | 75 | ### Documentation 76 | 77 | - Add docstrings to all public functions and classes 78 | - Use Google-style docstrings 79 | - Document parameters, return values, and exceptions 80 | 81 | ### Example Function 82 | 83 | ```python 84 | def process_video( 85 | video_path: str, 86 | max_duration: Optional[float] = None 87 | ) -> Tuple[np.ndarray, float]: 88 | """ 89 | Process video file and extract frames. 90 | 91 | Args: 92 | video_path: Path to input video file 93 | max_duration: Maximum duration in seconds (optional) 94 | 95 | Returns: 96 | Tuple of (frames array, duration in seconds) 97 | 98 | Raises: 99 | FileNotFoundError: If video file doesn't exist 100 | VideoProcessingError: If video processing fails 101 | """ 102 | if not os.path.exists(video_path): 103 | raise FileNotFoundError(f"Video file not found: {video_path}") 104 | 105 | # Implementation here... 106 | ``` 107 | 108 | ## Testing 109 | 110 | ### Running Tests 111 | 112 | ```bash 113 | # Run all tests 114 | python -m pytest 115 | 116 | # Run specific test file 117 | python -m pytest tests/test_feature_utils.py 118 | 119 | # Run with coverage 120 | python -m pytest --cov=hunyuanvideo_foley 121 | ``` 122 | 123 | ### Writing Tests 124 | 125 | - Place tests in `tests/` directory 126 | - Name test files as `test_*.py` 127 | - Use descriptive test function names 128 | - Test edge cases and error conditions 129 | 130 | ## Development Workflow 131 | 132 | 1. **Setup Environment** 133 | ```bash 134 | python -m venv venv 135 | source venv/bin/activate # Linux/Mac 136 | # or 137 | venv\Scripts\activate # Windows 138 | 139 | pip install -r requirements.txt 140 | pip install -e . 141 | ``` 142 | 143 | 2. **Install Development Tools** 144 | ```bash 145 | pre-commit install 146 | ``` 147 | 148 | 3. **Make Changes** 149 | - Follow the coding standards above 150 | - Add tests for new functionality 151 | - Update documentation as needed 152 | 153 | 4. **Run Quality Checks** 154 | ```bash 155 | black --check --line-length 120 . 156 | isort --check-only --profile black . 157 | flake8 --max-line-length 120 158 | mypy --ignore-missing-imports . 159 | pytest 160 | ``` 161 | 162 | 5. **Commit Changes** 163 | ```bash 164 | git add . 
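   git status   # optionally review what is staged before committing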
165 | git commit -m "feat: add new feature" 166 | ``` 167 | 168 | ## Performance Considerations 169 | 170 | - Use `torch.no_grad()` for inference-only code 171 | - Leverage GPU when available 172 | - Implement batch processing where possible 173 | - Profile code to identify bottlenecks 174 | 175 | ## Dependencies 176 | 177 | - Keep dependencies minimal and well-maintained 178 | - Pin versions for reproducibility 179 | - Separate development dependencies from runtime dependencies 180 | - Document any special installation requirements 181 | 182 | ## Configuration 183 | 184 | - Use centralized configuration in `constants.py` 185 | - Support environment variable overrides 186 | - Provide sensible defaults for all parameters 187 | - Validate configuration at startup -------------------------------------------------------------------------------- /hunyuanvideo_foley/cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Command Line Interface for HunyuanVideo-Foley 4 | 5 | Provides command-line access to the main inference functionality. 6 | """ 7 | 8 | import sys 9 | import argparse 10 | from pathlib import Path 11 | 12 | def main(): 13 | """Main CLI entry point.""" 14 | parser = argparse.ArgumentParser( 15 | description="HunyuanVideo-Foley: Generate Foley audio from video and text", 16 | formatter_class=argparse.RawDescriptionHelpFormatter, 17 | epilog=""" 18 | Examples: 19 | # Single video generation 20 | hunyuanvideo-foley --model_path ./models --single_video video.mp4 --single_prompt "footsteps on gravel" 21 | 22 | # Batch processing 23 | hunyuanvideo-foley --model_path ./models --csv_path batch.csv --output_dir ./outputs 24 | 25 | # Start Gradio interface 26 | hunyuanvideo-foley --gradio --model_path ./models 27 | """ 28 | ) 29 | 30 | parser.add_argument("--model_path", type=str, required=True, 31 | help="Path to the pretrained model directory") 32 | parser.add_argument("--config_path", type=str, 33 | default="configs/hunyuanvideo-foley-xxl.yaml", 34 | help="Path to the model configuration file") 35 | 36 | # Input options 37 | group_input = parser.add_mutually_exclusive_group(required=True) 38 | group_input.add_argument("--single_video", type=str, 39 | help="Path to single video file for processing") 40 | group_input.add_argument("--csv_path", type=str, 41 | help="Path to CSV file with video paths and prompts") 42 | group_input.add_argument("--gradio", action="store_true", 43 | help="Launch Gradio web interface") 44 | 45 | # Generation options 46 | parser.add_argument("--single_prompt", type=str, 47 | help="Text prompt for single video (required with --single_video)") 48 | parser.add_argument("--output_dir", type=str, default="./outputs", 49 | help="Output directory for generated audio files") 50 | parser.add_argument("--guidance_scale", type=float, default=4.5, 51 | help="Guidance scale for generation (default: 4.5)") 52 | parser.add_argument("--num_inference_steps", type=int, default=50, 53 | help="Number of inference steps (default: 50)") 54 | parser.add_argument("--neg_prompt", type=str, 55 | help="Negative prompt to avoid certain audio characteristics") 56 | 57 | # System options 58 | parser.add_argument("--device", type=str, default="auto", 59 | choices=["auto", "cpu", "cuda"], 60 | help="Device to use for inference") 61 | parser.add_argument("--gpu_id", type=int, default=0, 62 | help="GPU ID to use (default: 0)") 63 | parser.add_argument("--seed", type=int, default=42, 64 | help="Random seed for reproducible 
generation") 65 | 66 | args = parser.parse_args() 67 | 68 | # Validate arguments 69 | if args.single_video and not args.single_prompt: 70 | parser.error("--single_prompt is required when using --single_video") 71 | 72 | # Import here to avoid import errors if dependencies are missing 73 | try: 74 | if args.gradio: 75 | _launch_gradio(args) 76 | elif args.single_video: 77 | _process_single_video(args) 78 | elif args.csv_path: 79 | _process_batch(args) 80 | except ImportError as e: 81 | print(f"Error: Missing required dependencies. Please install with: pip install hunyuanvideo-foley[all]") 82 | print(f"Import error: {e}") 83 | sys.exit(1) 84 | except Exception as e: 85 | print(f"Error: {e}") 86 | sys.exit(1) 87 | 88 | def _launch_gradio(args): 89 | """Launch Gradio web interface.""" 90 | import os 91 | os.environ["HIFI_FOLEY_MODEL_PATH"] = args.model_path 92 | 93 | # Import and launch gradio app 94 | import subprocess 95 | gradio_script = Path(__file__).parent.parent / "gradio_app.py" 96 | subprocess.run([sys.executable, str(gradio_script)]) 97 | 98 | def _process_single_video(args): 99 | """Process a single video file.""" 100 | from . import infer 101 | 102 | print(f"Processing video: {args.single_video}") 103 | print(f"Prompt: {args.single_prompt}") 104 | 105 | # This would need to be implemented to match the actual infer.py interface 106 | # For now, redirect to the original script 107 | import subprocess 108 | cmd = [ 109 | sys.executable, "infer.py", 110 | "--model_path", args.model_path, 111 | "--config_path", args.config_path, 112 | "--single_video", args.single_video, 113 | "--single_prompt", args.single_prompt, 114 | "--output_dir", args.output_dir, 115 | "--guidance_scale", str(args.guidance_scale), 116 | "--num_inference_steps", str(args.num_inference_steps) 117 | ] 118 | if args.neg_prompt: 119 | cmd.extend(["--neg_prompt", args.neg_prompt]) 120 | 121 | subprocess.run(cmd) 122 | 123 | def _process_batch(args): 124 | """Process a batch of videos from CSV.""" 125 | import subprocess 126 | cmd = [ 127 | sys.executable, "infer.py", 128 | "--model_path", args.model_path, 129 | "--config_path", args.config_path, 130 | "--csv_path", args.csv_path, 131 | "--output_dir", args.output_dir, 132 | "--guidance_scale", str(args.guidance_scale), 133 | "--num_inference_steps", str(args.num_inference_steps) 134 | ] 135 | if args.neg_prompt: 136 | cmd.extend(["--neg_prompt", args.neg_prompt]) 137 | 138 | subprocess.run(cmd) 139 | 140 | if __name__ == "__main__": 141 | main() -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/mlp_layers.py: -------------------------------------------------------------------------------- 1 | # Modified from timm library: 2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .modulate_layers import modulate 11 | from ...utils.helper import to_2tuple 12 | 13 | class MLP(nn.Module): 14 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 15 | 16 | def __init__( 17 | self, 18 | in_channels, 19 | hidden_channels=None, 20 | out_features=None, 21 | act_layer=nn.GELU, 22 | norm_layer=None, 23 | bias=True, 24 | drop=0.0, 25 | use_conv=False, 26 | device=None, 27 | dtype=None, 28 | ): 29 | factory_kwargs = {"device": device, "dtype": dtype} 30 | super().__init__() 31 | out_features = 
out_features or in_channels 32 | hidden_channels = hidden_channels or in_channels 33 | bias = to_2tuple(bias) 34 | drop_probs = to_2tuple(drop) 35 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 36 | 37 | self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs) 38 | self.act = act_layer() 39 | self.drop1 = nn.Dropout(drop_probs[0]) 40 | self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity() 41 | self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs) 42 | self.drop2 = nn.Dropout(drop_probs[1]) 43 | 44 | def forward(self, x): 45 | x = self.fc1(x) 46 | x = self.act(x) 47 | x = self.drop1(x) 48 | x = self.norm(x) 49 | x = self.fc2(x) 50 | x = self.drop2(x) 51 | return x 52 | 53 | 54 | # copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py 55 | # only used when use_vanilla is True 56 | class MLPEmbedder(nn.Module): 57 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): 58 | factory_kwargs = {"device": device, "dtype": dtype} 59 | super().__init__() 60 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) 61 | self.silu = nn.SiLU() 62 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) 63 | 64 | def forward(self, x: torch.Tensor) -> torch.Tensor: 65 | return self.out_layer(self.silu(self.in_layer(x))) 66 | 67 | 68 | class LinearWarpforSingle(nn.Module): 69 | def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None): 70 | factory_kwargs = {"device": device, "dtype": dtype} 71 | super().__init__() 72 | self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs) 73 | 74 | def forward(self, x, y): 75 | z = torch.cat([x, y], dim=2) 76 | return self.fc(z) 77 | 78 | class FinalLayer1D(nn.Module): 79 | def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None): 80 | factory_kwargs = {"device": device, "dtype": dtype} 81 | super().__init__() 82 | 83 | # Just use LayerNorm for the final layer 84 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) 85 | self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs) 86 | nn.init.zeros_(self.linear.weight) 87 | nn.init.zeros_(self.linear.bias) 88 | 89 | # Here we don't distinguish between the modulate types. Just use the simple one. 90 | self.adaLN_modulation = nn.Sequential( 91 | act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs) 92 | ) 93 | # Zero-initialize the modulation 94 | nn.init.zeros_(self.adaLN_modulation[1].weight) 95 | nn.init.zeros_(self.adaLN_modulation[1].bias) 96 | 97 | def forward(self, x, c): 98 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 99 | x = modulate(self.norm_final(x), shift=shift, scale=scale) 100 | x = self.linear(x) 101 | return x 102 | 103 | 104 | class ChannelLastConv1d(nn.Conv1d): 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = x.permute(0, 2, 1) 108 | x = super().forward(x) 109 | x = x.permute(0, 2, 1) 110 | return x 111 | 112 | 113 | class ConvMLP(nn.Module): 114 | 115 | def __init__( 116 | self, 117 | dim: int, 118 | hidden_dim: int, 119 | multiple_of: int = 256, 120 | kernel_size: int = 3, 121 | padding: int = 1, 122 | device=None, 123 | dtype=None, 124 | ): 125 | """ 126 | Convolutional MLP module. 127 | 128 | Args: 129 | dim (int): Input dimension. 
130 | hidden_dim (int): Hidden dimension of the feedforward layer. 131 | multiple_of (int): Value to ensure hidden dimension is a multiple of this value. 132 | 133 | Attributes: 134 | w1: Linear transformation for the first layer. 135 | w2: Linear transformation for the second layer. 136 | w3: Linear transformation for the third layer. 137 | 138 | """ 139 | factory_kwargs = {"device": device, "dtype": dtype} 140 | super().__init__() 141 | hidden_dim = int(2 * hidden_dim / 3) 142 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 143 | 144 | self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 145 | self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 146 | self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 147 | 148 | def forward(self, x): 149 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 150 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment 4 | for High-Fidelity Foley Audio Generation 5 | 6 | Setup script for building and installing the HunyuanVideo-Foley package. 7 | """ 8 | 9 | import os 10 | import re 11 | from typing import List 12 | from setuptools import setup, find_packages 13 | 14 | def read_file(filename: str) -> str: 15 | """Read content from a file.""" 16 | here = os.path.abspath(os.path.dirname(__file__)) 17 | with open(os.path.join(here, filename), 'r', encoding='utf-8') as f: 18 | return f.read() 19 | 20 | def get_version() -> str: 21 | """Extract version from constants.py or use default.""" 22 | try: 23 | constants_path = os.path.join('hunyuanvideo_foley', 'constants.py') 24 | content = read_file(constants_path) 25 | version_match = re.search(r"__version__\s*=\s*['\"]([^'\"]*)['\"]", content) 26 | if version_match: 27 | return version_match.group(1) 28 | except FileNotFoundError: 29 | pass 30 | return "1.0.0" 31 | 32 | def parse_requirements(filename: str) -> List[str]: 33 | """Parse requirements from requirements.txt file.""" 34 | try: 35 | content = read_file(filename) 36 | lines = content.splitlines() 37 | requirements = [] 38 | 39 | for line in lines: 40 | line = line.strip() 41 | if not line or line.startswith('#'): 42 | continue 43 | 44 | # Handle git+https dependencies - convert to standard package names 45 | if line.startswith('git+'): 46 | if 'transformers' in line: 47 | requirements.append('transformers>=4.49.0') 48 | elif 'audiotools' in line: 49 | # Use a placeholder for audiotools since it's not on PyPI 50 | # Users will need to install it separately 51 | continue # Skip for now 52 | else: 53 | continue # Skip other git dependencies 54 | else: 55 | requirements.append(line) 56 | 57 | return requirements 58 | except FileNotFoundError: 59 | return [] 60 | 61 | def get_long_description() -> str: 62 | """Get long description from README.md.""" 63 | try: 64 | readme = read_file("README.md") 65 | # Remove HTML tags and excessive styling for PyPI compatibility 66 | readme = re.sub(r'<[^>]+>', '', readme) 67 | return readme 68 | except FileNotFoundError: 69 | return "Multimodal Diffusion with Representation Alignment for High-Fidelity Foley Audio Generation" 70 | 71 | # Read requirements 72 | install_requires = 
parse_requirements("requirements.txt") 73 | 74 | # Separate development requirements 75 | dev_requirements = [ 76 | "black>=23.0.0", 77 | "isort>=5.12.0", 78 | "flake8>=6.0.0", 79 | "mypy>=1.3.0", 80 | "pre-commit>=3.0.0", 81 | "pytest>=7.0.0", 82 | "pytest-cov>=4.0.0", 83 | ] 84 | 85 | # Optional dependencies for different features 86 | extras_require = { 87 | "dev": dev_requirements, 88 | "test": [ 89 | "pytest>=7.0.0", 90 | "pytest-cov>=4.0.0", 91 | ], 92 | "gradio": [ 93 | "gradio==3.50.2", 94 | ], 95 | "comfyui": [ 96 | # ComfyUI specific dependencies can be added here 97 | ], 98 | "all": dev_requirements + ["gradio==3.50.2"], 99 | } 100 | 101 | setup( 102 | name="hunyuanvideo-foley", 103 | version=get_version(), 104 | 105 | # Package metadata 106 | author="Tencent Hunyuan Team", 107 | author_email="hunyuan@tencent.com", 108 | description="Multimodal Diffusion with Representation Alignment for High-Fidelity Foley Audio Generation", 109 | long_description=get_long_description(), 110 | long_description_content_type="text/markdown", 111 | 112 | # URLs 113 | url="https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley", 114 | project_urls={ 115 | "Homepage": "https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley", 116 | "Repository": "https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley", 117 | "Documentation": "https://szczesnys.github.io/hunyuanvideo-foley", 118 | "Paper": "https://arxiv.org/abs/2508.16930", 119 | "Demo": "https://huggingface.co/spaces/tencent/HunyuanVideo-Foley", 120 | "Models": "https://huggingface.co/tencent/HunyuanVideo-Foley", 121 | }, 122 | 123 | # Package discovery 124 | packages=find_packages( 125 | include=["hunyuanvideo_foley", "hunyuanvideo_foley.*"] 126 | ), 127 | include_package_data=True, 128 | 129 | # Package requirements 130 | python_requires=">=3.8", 131 | install_requires=install_requires, 132 | extras_require=extras_require, 133 | 134 | # Entry points for command line scripts 135 | entry_points={ 136 | "console_scripts": [ 137 | "hunyuanvideo-foley=hunyuanvideo_foley.cli:main", 138 | ], 139 | }, 140 | 141 | # Package data 142 | package_data={ 143 | "hunyuanvideo_foley": [ 144 | "configs/*.yaml", 145 | "configs/*.yml", 146 | "*.yaml", 147 | "*.yml", 148 | ], 149 | }, 150 | 151 | # Classification 152 | classifiers=[ 153 | "Development Status :: 4 - Beta", 154 | "Intended Audience :: Developers", 155 | "Intended Audience :: Science/Research", 156 | "License :: OSI Approved :: Apache Software License", 157 | "Operating System :: OS Independent", 158 | "Programming Language :: Python :: 3", 159 | "Programming Language :: Python :: 3.8", 160 | "Programming Language :: Python :: 3.9", 161 | "Programming Language :: Python :: 3.10", 162 | "Programming Language :: Python :: 3.11", 163 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 164 | "Topic :: Multimedia :: Sound/Audio :: Analysis", 165 | "Topic :: Multimedia :: Video", 166 | ], 167 | 168 | # Keywords for discoverability 169 | keywords=[ 170 | "artificial intelligence", 171 | "machine learning", 172 | "deep learning", 173 | "multimodal", 174 | "diffusion models", 175 | "audio generation", 176 | "foley audio", 177 | "video-to-audio", 178 | "text-to-audio", 179 | "pytorch", 180 | "huggingface", 181 | "tencent", 182 | "hunyuan" 183 | ], 184 | 185 | # Licensing 186 | license="Apache-2.0", 187 | 188 | # Build configuration 189 | zip_safe=False, 190 | 191 | # Additional metadata 192 | platforms=["any"], 193 | ) -------------------------------------------------------------------------------- 
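To make the `git+` handling in `parse_requirements` above concrete, here is a minimal standalone sketch of the same filtering applied to a hypothetical requirements list (the package names are illustrative, not the project's actual pin set):

```python
# Minimal sketch (assumption: mirrors the parse_requirements logic above) showing
# how git+ dependencies in a hypothetical requirements.txt are rewritten or skipped.
lines = [
    "torch>=2.0.0",
    "loguru",
    "git+https://github.com/huggingface/transformers@v4.49.0-SigLIP-2",
    "git+https://github.com/descriptinc/audiotools",
]

requirements = []
for line in lines:
    if line.startswith("git+"):
        if "transformers" in line:
            requirements.append("transformers>=4.49.0")  # mapped to a PyPI-style pin
        # audiotools (and other git deps) are skipped; install them manually per INSTALL.md
    else:
        requirements.append(line)

print(requirements)  # ['torch>=2.0.0', 'loguru', 'transformers>=4.49.0']
```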
/hunyuanvideo_foley/utils/feature_utils.py: -------------------------------------------------------------------------------- 1 | """Feature extraction utilities for video and text processing.""" 2 | 3 | import os 4 | import numpy as np 5 | import torch 6 | import av 7 | from PIL import Image 8 | from einops import rearrange 9 | from typing import Any, Dict, List, Union, Tuple 10 | from loguru import logger 11 | 12 | from .config_utils import AttributeDict 13 | from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS, DEFAULT_NEGATIVE_PROMPT 14 | 15 | 16 | class FeatureExtractionError(Exception): 17 | """Exception raised for feature extraction errors.""" 18 | pass 19 | 20 | def get_frames_av( 21 | video_path: str, 22 | fps: float, 23 | max_length: float = None, 24 | ) -> Tuple[np.ndarray, float]: 25 | end_sec = max_length if max_length is not None else 15 26 | next_frame_time_for_each_fps = 0.0 27 | time_delta_for_each_fps = 1 / fps 28 | 29 | all_frames = [] 30 | output_frames = [] 31 | 32 | with av.open(video_path) as container: 33 | stream = container.streams.video[0] 34 | ori_fps = stream.guessed_rate 35 | stream.thread_type = "AUTO" 36 | for packet in container.demux(stream): 37 | for frame in packet.decode(): 38 | frame_time = frame.time 39 | if frame_time < 0: 40 | continue 41 | if frame_time > end_sec: 42 | break 43 | 44 | frame_np = None 45 | 46 | this_time = frame_time 47 | while this_time >= next_frame_time_for_each_fps: 48 | if frame_np is None: 49 | frame_np = frame.to_ndarray(format="rgb24") 50 | 51 | output_frames.append(frame_np) 52 | next_frame_time_for_each_fps += time_delta_for_each_fps 53 | 54 | output_frames = np.stack(output_frames) 55 | 56 | vid_len_in_s = len(output_frames) / fps 57 | if max_length is not None and len(output_frames) > int(max_length * fps): 58 | output_frames = output_frames[: int(max_length * fps)] 59 | vid_len_in_s = max_length 60 | 61 | return output_frames, vid_len_in_s 62 | 63 | @torch.inference_mode() 64 | def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1): 65 | b, t, c, h, w = x.shape 66 | if batch_size < 0: 67 | batch_size = b * t 68 | x = rearrange(x, "b t c h w -> (b t) c h w") 69 | outputs = [] 70 | for i in range(0, b * t, batch_size): 71 | outputs.append(model_dict.siglip2_model.get_image_features(pixel_values=x[i : i + batch_size])) 72 | res = torch.cat(outputs, dim=0) 73 | res = rearrange(res, "(b t) d -> b t d", b=b) 74 | return res 75 | 76 | @torch.inference_mode() 77 | def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1): 78 | """ 79 | The input video of x is best to be in fps of 24 of greater than 24. 
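    With the hard-coded segment_size=16 and step_size=8 below, a clip of t frames is split into (t - 16) // 8 + 1 overlapping segments, and each segment contributes 8 feature tokens to the output.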
80 | Input: 81 | x: tensor in shape of [B, T, C, H, W] 82 | batch_size: the batch_size for synchformer inference 83 | """ 84 | b, t, c, h, w = x.shape 85 | assert c == 3 and h == 224 and w == 224 86 | 87 | segment_size = 16 88 | step_size = 8 89 | num_segments = (t - segment_size) // step_size + 1 90 | segments = [] 91 | for i in range(num_segments): 92 | segments.append(x[:, i * step_size : i * step_size + segment_size]) 93 | x = torch.stack(segments, dim=1).cuda() # (B, num_segments, segment_size, 3, 224, 224) 94 | 95 | outputs = [] 96 | if batch_size < 0: 97 | batch_size = b * num_segments 98 | x = rearrange(x, "b s t c h w -> (b s) 1 t c h w") 99 | for i in range(0, b * num_segments, batch_size): 100 | with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half): 101 | outputs.append(model_dict.syncformer_model(x[i : i + batch_size])) 102 | x = torch.cat(outputs, dim=0) # [b * num_segments, 1, 8, 768] 103 | x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b) 104 | return x 105 | 106 | 107 | @torch.inference_mode() 108 | def encode_video_features(video_path, model_dict): 109 | visual_features = {} 110 | # siglip2 visual features 111 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["siglip2"]) 112 | images = [Image.fromarray(frame).convert('RGB') for frame in frames] 113 | images = [model_dict.siglip2_preprocess(image) for image in images] # [T, C, H, W] 114 | clip_frames = torch.stack(images).to(model_dict.device).unsqueeze(0) 115 | visual_features['siglip2_feat'] = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device) 116 | 117 | # synchformer visual features 118 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["synchformer"]) 119 | images = torch.from_numpy(frames).permute(0, 3, 1, 2) # [T, C, H, W] 120 | sync_frames = model_dict.syncformer_preprocess(images).unsqueeze(0) # [1, T, 3, 224, 224] 121 | # [1, num_segments * 8, channel_dim], e.g. 
[1, 240, 768] for 10s video 122 | visual_features['syncformer_feat'] = encode_video_with_sync(sync_frames, model_dict) 123 | 124 | vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"] 125 | visual_features = AttributeDict(visual_features) 126 | 127 | return visual_features, vid_len_in_s 128 | 129 | @torch.inference_mode() 130 | def encode_text_feat(text: List[str], model_dict): 131 | # x: (B, L) 132 | inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device) 133 | outputs = model_dict.clap_model(**inputs, output_hidden_states=True, return_dict=True) 134 | return outputs.last_hidden_state, outputs.attentions 135 | 136 | 137 | def feature_process(video_path, prompt, model_dict, cfg, neg_prompt=None): 138 | visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict) 139 | if neg_prompt is None: 140 | neg_prompt = DEFAULT_NEGATIVE_PROMPT # 使用常量中的默认值 141 | prompts = [neg_prompt, prompt] 142 | text_feat_res, text_feat_mask = encode_text_feat(prompts, model_dict) 143 | 144 | text_feat = text_feat_res[1:] 145 | uncond_text_feat = text_feat_res[:1] 146 | 147 | if cfg.model_config.model_kwargs.text_length < text_feat.shape[1]: 148 | text_seq_length = cfg.model_config.model_kwargs.text_length 149 | text_feat = text_feat[:, :text_seq_length] 150 | uncond_text_feat = uncond_text_feat[:, :text_seq_length] 151 | 152 | text_feats = AttributeDict({ 153 | 'text_feat': text_feat, 154 | 'uncond_text_feat': uncond_text_feat, 155 | }) 156 | 157 | if hasattr(model_dict, 'manager') and hasattr(model_dict.manager, 'release_feature_models'): 158 | model_dict.manager.release_feature_models() 159 | 160 | return visual_feats, text_feats, audio_len_in_s 161 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/posemb_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Tuple 3 | 4 | 5 | def _to_tuple(x, dim=2): 6 | if isinstance(x, int): 7 | return (x,) * dim 8 | elif len(x) == dim: 9 | return x 10 | else: 11 | raise ValueError(f"Expected length {dim} or int, but got {x}") 12 | 13 | 14 | def get_meshgrid_nd(start, *args, dim=2): 15 | """ 16 | Get n-D meshgrid with start, stop and num. 17 | 18 | Args: 19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, 20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num 21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in 22 | n-tuples. 23 | *args: See above. 24 | dim (int): Dimension of the meshgrid. Defaults to 2. 25 | 26 | Returns: 27 | grid (np.ndarray): [dim, ...] 
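    Example:
        >>> grid = get_meshgrid_nd((4, 6))   # len(args) == 0, so (4, 6) is the grid size
        >>> grid.shape
        torch.Size([2, 4, 6])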
28 | """ 29 | if len(args) == 0: 30 | # start is grid_size 31 | num = _to_tuple(start, dim=dim) 32 | start = (0,) * dim 33 | stop = num 34 | elif len(args) == 1: 35 | # start is start, args[0] is stop, step is 1 36 | start = _to_tuple(start, dim=dim) 37 | stop = _to_tuple(args[0], dim=dim) 38 | num = [stop[i] - start[i] for i in range(dim)] 39 | elif len(args) == 2: 40 | # start is start, args[0] is stop, args[1] is num 41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 44 | else: 45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") 46 | 47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) 48 | axis_grid = [] 49 | for i in range(dim): 50 | a, b, n = start[i], stop[i], num[i] 51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] 52 | axis_grid.append(g) 53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] 54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D] 55 | 56 | return grid 57 | 58 | 59 | ################################################################################# 60 | # Rotary Positional Embedding Functions # 61 | ################################################################################# 62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 63 | 64 | 65 | def get_nd_rotary_pos_embed( 66 | rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0 67 | ): 68 | """ 69 | This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. 70 | 71 | Args: 72 | rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. 73 | sum(rope_dim_list) should equal to head_dim of attention layer. 74 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, 75 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. 76 | *args: See above. 77 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0. 78 | use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. 79 | Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real 80 | part and an imaginary part separately. 81 | theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. 82 | freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. 
83 | 84 | Returns: 85 | pos_embed (torch.Tensor): [HW, D/2] 86 | """ 87 | 88 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] 89 | 90 | # use 1/ndim of dimensions to encode grid_axis 91 | embs = [] 92 | for i in range(len(rope_dim_list)): 93 | emb = get_1d_rotary_pos_embed( 94 | rope_dim_list[i], 95 | grid[i].reshape(-1), 96 | theta, 97 | use_real=use_real, 98 | theta_rescale_factor=theta_rescale_factor, 99 | freq_scaling=freq_scaling, 100 | ) # 2 x [WHD, rope_dim_list[i]] 101 | embs.append(emb) 102 | 103 | if use_real: 104 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) 105 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) 106 | return cos, sin 107 | else: 108 | emb = torch.cat(embs, dim=1) # (WHD, D/2) 109 | return emb 110 | 111 | 112 | def get_1d_rotary_pos_embed( 113 | dim: int, 114 | pos: Union[torch.FloatTensor, int], 115 | theta: float = 10000.0, 116 | use_real: bool = False, 117 | theta_rescale_factor: float = 1.0, 118 | freq_scaling: float = 1.0, 119 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 120 | """ 121 | Precompute the frequency tensor for complex exponential (cis) with given dimensions. 122 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) 123 | 124 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim' 125 | and the end index 'end'. The 'theta' parameter scales the frequencies. 126 | The returned tensor contains complex values in complex64 data type. 127 | 128 | Args: 129 | dim (int): Dimension of the frequency tensor. 130 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar 131 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 132 | use_real (bool, optional): If True, return real part and imaginary part separately. 133 | Otherwise, return complex numbers. 134 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. 135 | freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. 136 | 137 | Returns: 138 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] 139 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] 140 | """ 141 | if isinstance(pos, int): 142 | pos = torch.arange(pos).float() 143 | 144 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 145 | # has some connection to NTK literature 146 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ 147 | if theta_rescale_factor != 1.0: 148 | theta *= theta_rescale_factor ** (dim / (dim - 1)) 149 | 150 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] 151 | freqs *= freq_scaling 152 | freqs = torch.outer(pos, freqs) # [S, D/2] 153 | if use_real: 154 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] 155 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] 156 | return freqs_cos, freqs_sin 157 | else: 158 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] 159 | return freqs_cis 160 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from audiotools import AudioSignal 5 | from audiotools import ml 6 | from audiotools import STFTParams 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | 11 | def WNConv1d(*args, **kwargs): 12 | act = kwargs.pop("act", True) 13 | conv = weight_norm(nn.Conv1d(*args, **kwargs)) 14 | if not act: 15 | return conv 16 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 17 | 18 | 19 | def WNConv2d(*args, **kwargs): 20 | act = kwargs.pop("act", True) 21 | conv = weight_norm(nn.Conv2d(*args, **kwargs)) 22 | if not act: 23 | return conv 24 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 25 | 26 | 27 | class MPD(nn.Module): 28 | def __init__(self, period): 29 | super().__init__() 30 | self.period = period 31 | self.convs = nn.ModuleList( 32 | [ 33 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), 34 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), 35 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), 36 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), 37 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), 38 | ] 39 | ) 40 | self.conv_post = WNConv2d( 41 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False 42 | ) 43 | 44 | def pad_to_period(self, x): 45 | t = x.shape[-1] 46 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect") 47 | return x 48 | 49 | def forward(self, x): 50 | fmap = [] 51 | 52 | x = self.pad_to_period(x) 53 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period) 54 | 55 | for layer in self.convs: 56 | x = layer(x) 57 | fmap.append(x) 58 | 59 | x = self.conv_post(x) 60 | fmap.append(x) 61 | 62 | return fmap 63 | 64 | 65 | class MSD(nn.Module): 66 | def __init__(self, rate: int = 1, sample_rate: int = 44100): 67 | super().__init__() 68 | self.convs = nn.ModuleList( 69 | [ 70 | WNConv1d(1, 16, 15, 1, padding=7), 71 | WNConv1d(16, 64, 41, 4, groups=4, padding=20), 72 | WNConv1d(64, 256, 41, 4, groups=16, padding=20), 73 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20), 74 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), 75 | WNConv1d(1024, 1024, 5, 1, padding=2), 76 | ] 77 | ) 78 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) 79 | self.sample_rate = sample_rate 80 | self.rate = rate 81 | 82 | def forward(self, x): 83 | x = AudioSignal(x, self.sample_rate) 84 | 
x.resample(self.sample_rate // self.rate) 85 | x = x.audio_data 86 | 87 | fmap = [] 88 | 89 | for l in self.convs: 90 | x = l(x) 91 | fmap.append(x) 92 | x = self.conv_post(x) 93 | fmap.append(x) 94 | 95 | return fmap 96 | 97 | 98 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] 99 | 100 | 101 | class MRD(nn.Module): 102 | def __init__( 103 | self, 104 | window_length: int, 105 | hop_factor: float = 0.25, 106 | sample_rate: int = 44100, 107 | bands: list = BANDS, 108 | ): 109 | """Complex multi-band spectrogram discriminator. 110 | Parameters 111 | ---------- 112 | window_length : int 113 | Window length of STFT. 114 | hop_factor : float, optional 115 | Hop factor of the STFT, defaults to ``0.25 * window_length``. 116 | sample_rate : int, optional 117 | Sampling rate of audio in Hz, by default 44100 118 | bands : list, optional 119 | Bands to run discriminator over. 120 | """ 121 | super().__init__() 122 | 123 | self.window_length = window_length 124 | self.hop_factor = hop_factor 125 | self.sample_rate = sample_rate 126 | self.stft_params = STFTParams( 127 | window_length=window_length, 128 | hop_length=int(window_length * hop_factor), 129 | match_stride=True, 130 | ) 131 | 132 | n_fft = window_length // 2 + 1 133 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 134 | self.bands = bands 135 | 136 | ch = 32 137 | convs = lambda: nn.ModuleList( 138 | [ 139 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), 140 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 141 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 142 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 143 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), 144 | ] 145 | ) 146 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 147 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) 148 | 149 | def spectrogram(self, x): 150 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) 151 | x = torch.view_as_real(x.stft()) 152 | x = rearrange(x, "b 1 f t c -> (b 1) c t f") 153 | # Split into bands 154 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 155 | return x_bands 156 | 157 | def forward(self, x): 158 | x_bands = self.spectrogram(x) 159 | fmap = [] 160 | 161 | x = [] 162 | for band, stack in zip(x_bands, self.band_convs): 163 | for layer in stack: 164 | band = layer(band) 165 | fmap.append(band) 166 | x.append(band) 167 | 168 | x = torch.cat(x, dim=-1) 169 | x = self.conv_post(x) 170 | fmap.append(x) 171 | 172 | return fmap 173 | 174 | 175 | class Discriminator(ml.BaseModel): 176 | def __init__( 177 | self, 178 | rates: list = [], 179 | periods: list = [2, 3, 5, 7, 11], 180 | fft_sizes: list = [2048, 1024, 512], 181 | sample_rate: int = 44100, 182 | bands: list = BANDS, 183 | ): 184 | """Discriminator that combines multiple discriminators. 185 | 186 | Parameters 187 | ---------- 188 | rates : list, optional 189 | sampling rates (in Hz) to run MSD at, by default [] 190 | If empty, MSD is not used. 
191 | periods : list, optional 192 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] 193 | fft_sizes : list, optional 194 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] 195 | sample_rate : int, optional 196 | Sampling rate of audio in Hz, by default 44100 197 | bands : list, optional 198 | Bands to run MRD at, by default `BANDS` 199 | """ 200 | super().__init__() 201 | discs = [] 202 | discs += [MPD(p) for p in periods] 203 | discs += [MSD(r, sample_rate=sample_rate) for r in rates] 204 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] 205 | self.discriminators = nn.ModuleList(discs) 206 | 207 | def preprocess(self, y): 208 | # Remove DC offset 209 | y = y - y.mean(dim=-1, keepdims=True) 210 | # Peak normalize the volume of input audio 211 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 212 | return y 213 | 214 | def forward(self, x): 215 | x = self.preprocess(x) 216 | fmaps = [d(x) for d in self.discriminators] 217 | return fmaps 218 | 219 | 220 | if __name__ == "__main__": 221 | disc = Discriminator() 222 | x = torch.zeros(1, 1, 44100) 223 | results = disc(x) 224 | for i, result in enumerate(results): 225 | print(f"disc{i}") 226 | for i, r in enumerate(result): 227 | print(r.shape, r.mean(), r.min(), r.max()) 228 | print() 229 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/compute_desync_score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | from pathlib import Path 4 | 5 | import torch 6 | import torchaudio 7 | import torchvision 8 | from omegaconf import OmegaConf 9 | 10 | import data_transforms 11 | from .synchformer import Synchformer 12 | from .data_transforms import make_class_grid, quantize_offset 13 | from .utils import check_if_file_exists_else_download, which_ffmpeg 14 | 15 | 16 | def prepare_inputs(batch, device): 17 | aud = batch["audio"].to(device) 18 | vid = batch["video"].to(device) 19 | 20 | return aud, vid 21 | 22 | 23 | def get_test_transforms(): 24 | ts = [ 25 | data_transforms.EqualifyFromRight(), 26 | data_transforms.RGBSpatialCrop(input_size=224, is_random=False), 27 | data_transforms.TemporalCropAndOffset( 28 | crop_len_sec=5, 29 | max_off_sec=2, # https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml 30 | max_wiggle_sec=0.0, 31 | do_offset=True, 32 | offset_type="grid", 33 | prob_oos="null", 34 | grid_size=21, 35 | segment_size_vframes=16, 36 | n_segments=14, 37 | step_size_seg=0.5, 38 | vfps=25, 39 | ), 40 | data_transforms.GenerateMultipleSegments( 41 | segment_size_vframes=16, 42 | n_segments=14, 43 | is_start_random=False, 44 | step_size_seg=0.5, 45 | ), 46 | data_transforms.RGBToHalfToZeroOne(), 47 | data_transforms.RGBNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # motionformer normalization 48 | data_transforms.AudioMelSpectrogram( 49 | sample_rate=16000, 50 | win_length=400, # 25 ms * 16 kHz 51 | hop_length=160, # 10 ms * 16 kHz 52 | n_fft=1024, # 2^(ceil(log2(window_size * sampling_rate))) 53 | n_mels=128, # as in AST 54 | ), 55 | data_transforms.AudioLog(), 56 | data_transforms.PadOrTruncate(max_spec_t=66), 57 | data_transforms.AudioNormalizeAST(mean=-4.2677393, std=4.5689974), # AST, pre-trained on AudioSet 58 | data_transforms.PermuteStreams( 59 | einops_order_audio="S F T -> S 1 F T", einops_order_rgb="S T C H W -> S T C H W" # same 60 | ), 
61 | ] 62 | transforms = torchvision.transforms.Compose(ts) 63 | 64 | return transforms 65 | 66 | 67 | def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None): 68 | orig_path = path 69 | # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta) 70 | rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, "sec", output_format="TCHW") 71 | assert meta["video_fps"], f"No video fps for {orig_path}" 72 | # (Ta) <- (Ca, Ta) 73 | audio = audio.mean(dim=0) 74 | # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader. 75 | meta = { 76 | "video": {"fps": [meta["video_fps"]]}, 77 | "audio": {"framerate": [meta["audio_fps"]]}, 78 | } 79 | return rgb, audio, meta 80 | 81 | 82 | def reencode_video(path, vfps=25, afps=16000, in_size=256): 83 | assert which_ffmpeg() != "", "Is ffmpeg installed? Check if the conda environment is activated." 84 | new_path = Path.cwd() / "vis" / f"{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4" 85 | new_path.parent.mkdir(exist_ok=True) 86 | new_path = str(new_path) 87 | cmd = f"{which_ffmpeg()}" 88 | # no info/error printing 89 | cmd += " -hide_banner -loglevel panic" 90 | cmd += f" -y -i {path}" 91 | # 1) change fps, 2) resize: min(H,W)=MIN_SIDE (vertical vids are supported), 3) change audio framerate 92 | cmd += f" -vf fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2" 93 | cmd += f" -ar {afps}" 94 | cmd += f" {new_path}" 95 | subprocess.call(cmd.split()) 96 | cmd = f"{which_ffmpeg()}" 97 | cmd += " -hide_banner -loglevel panic" 98 | cmd += f" -y -i {new_path}" 99 | cmd += f" -acodec pcm_s16le -ac 1" 100 | cmd += f' {new_path.replace(".mp4", ".wav")}' 101 | subprocess.call(cmd.split()) 102 | return new_path 103 | 104 | 105 | def decode_single_video_prediction(off_logits, grid, item): 106 | label = item["targets"]["offset_label"].item() 107 | print("Ground Truth offset (sec):", f"{label:.2f} ({quantize_offset(grid, label)[-1].item()})") 108 | print() 109 | print("Prediction Results:") 110 | off_probs = torch.softmax(off_logits, dim=-1) 111 | k = min(off_probs.shape[-1], 5) 112 | topk_logits, topk_preds = torch.topk(off_logits, k) 113 | # remove batch dimension 114 | assert len(topk_logits) == 1, "batch is larger than 1" 115 | topk_logits = topk_logits[0] 116 | topk_preds = topk_preds[0] 117 | off_logits = off_logits[0] 118 | off_probs = off_probs[0] 119 | for target_hat in topk_preds: 120 | print(f'p={off_probs[target_hat]:.4f} ({off_logits[target_hat]:.4f}), "{grid[target_hat]:.2f}" ({target_hat})') 121 | return off_probs 122 | 123 | 124 | def main(args): 125 | vfps = 25 126 | afps = 16000 127 | in_size = 256 128 | # making the offset class grid similar to the one used in transforms, 129 | # refer to the used one: https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml 130 | max_off_sec = 2 131 | num_cls = 21 132 | 133 | # checking if the provided video has the correct frame rates 134 | print(f"Using video: {args.vid_path}") 135 | v, _, info = torchvision.io.read_video(args.vid_path, pts_unit="sec") 136 | _, H, W, _ = v.shape 137 | if info["video_fps"] != vfps or info["audio_fps"] != afps or min(H, W) != in_size: 138 | print(f'Reencoding. 
vfps: {info["video_fps"]} -> {vfps};', end=" ") 139 | print(f'afps: {info["audio_fps"]} -> {afps};', end=" ") 140 | print(f"{(H, W)} -> min(H, W)={in_size}") 141 | args.vid_path = reencode_video(args.vid_path, vfps, afps, in_size) 142 | else: 143 | print(f'Skipping reencoding. vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}') 144 | 145 | device = torch.device(args.device) 146 | 147 | # load visual and audio streams 148 | # rgb: (Tv, 3, H, W) in [0, 225], audio: (Ta,) in [-1, 1] 149 | rgb, audio, meta = get_video_and_audio(args.vid_path, get_meta=True) 150 | 151 | # making an item (dict) to apply transformations 152 | # NOTE: here is how it works: 153 | # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3` 154 | # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio 155 | # track by `args.offset_sec` seconds. It means that if `offset_sec` > 0, the audio will 156 | # start by `offset_sec` earlier than the rgb track. 157 | # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`) 158 | item = dict( 159 | video=rgb, 160 | audio=audio, 161 | meta=meta, 162 | path=args.vid_path, 163 | split="test", 164 | targets={ 165 | "v_start_i_sec": args.v_start_i_sec, 166 | "offset_sec": args.offset_sec, 167 | }, 168 | ) 169 | 170 | grid = make_class_grid(-max_off_sec, max_off_sec, num_cls) 171 | if not (min(grid) <= item["targets"]["offset_sec"] <= max(grid)): 172 | print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}') 173 | 174 | # applying the test-time transform 175 | item = get_test_transforms()(item) 176 | 177 | # prepare inputs for inference 178 | batch = torch.utils.data.default_collate([item]) 179 | aud, vid = prepare_inputs(batch, device) 180 | 181 | # TODO: 182 | # sanity check: we will take the input to the `model` and recontruct make a video from it. 
183 | # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified) 184 | # reconstruct_video_from_input(aud, vid, batch['meta'], args.vid_path, args.v_start_i_sec, args.offset_sec, 185 | # vfps, afps) 186 | 187 | # forward pass 188 | with torch.set_grad_enabled(False): 189 | with torch.autocast("cuda", enabled=True): 190 | _, logits = synchformer(vid, aud) 191 | 192 | # simply prints the results of the prediction 193 | decode_single_video_prediction(logits, grid, item) 194 | 195 | 196 | if __name__ == "__main__": 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument("--exp_name", required=True, help="In a format: xx-xx-xxTxx-xx-xx") 199 | parser.add_argument("--vid_path", required=True, help="A path to .mp4 video") 200 | parser.add_argument("--offset_sec", type=float, default=0.0) 201 | parser.add_argument("--v_start_i_sec", type=float, default=0.0) 202 | parser.add_argument("--device", default="cuda:0") 203 | args = parser.parse_args() 204 | 205 | synchformer = Synchformer().cuda().eval() 206 | synchformer.load_state_dict( 207 | torch.load( 208 | os.environ.get("SYNCHFORMER_WEIGHTS", f"weights/synchformer.pth"), 209 | weights_only=True, 210 | map_location="cpu", 211 | ) 212 | ) 213 | 214 | main(args) 215 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/quantize.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | from .layers import WNConv1d 11 | 12 | 13 | class VectorQuantize(nn.Module): 14 | """ 15 | Implementation of VQ similar to Karpathy's repo: 16 | https://github.com/karpathy/deep-vector-quantization 17 | Additionally uses following tricks from Improved VQGAN 18 | (https://arxiv.org/pdf/2110.04627.pdf): 19 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space 20 | for improved codebook usage 21 | 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which 22 | improves training stability 23 | """ 24 | 25 | def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): 26 | super().__init__() 27 | self.codebook_size = codebook_size 28 | self.codebook_dim = codebook_dim 29 | 30 | self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) 31 | self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) 32 | self.codebook = nn.Embedding(codebook_size, codebook_dim) 33 | 34 | def forward(self, z): 35 | """Quantized the input tensor using a fixed codebook and returns 36 | the corresponding codebook vectors 37 | 38 | Parameters 39 | ---------- 40 | z : Tensor[B x D x T] 41 | 42 | Returns 43 | ------- 44 | Tensor[B x D x T] 45 | Quantized continuous representation of input 46 | Tensor[1] 47 | Commitment loss to train encoder to predict vectors closer to codebook 48 | entries 49 | Tensor[1] 50 | Codebook loss to update the codebook 51 | Tensor[B x T] 52 | Codebook indices (quantized discrete representation of input) 53 | Tensor[B x D x T] 54 | Projected latents (continuous representation of input before quantization) 55 | """ 56 | 57 | # Factorized codes (ViT-VQGAN) Project input into low-dimensional space 58 | z_e = self.in_proj(z) # z_e : (B x D x T) 59 | z_q, indices = self.decode_latents(z_e) 60 | 61 | commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) 62 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) 63 | 64 | z_q = ( 65 | z_e + (z_q - z_e).detach() 66 | ) # noop in forward pass, straight-through gradient estimator in backward pass 67 | 68 | z_q = self.out_proj(z_q) 69 | 70 | return z_q, commitment_loss, codebook_loss, indices, z_e 71 | 72 | def embed_code(self, embed_id): 73 | return F.embedding(embed_id, self.codebook.weight) 74 | 75 | def decode_code(self, embed_id): 76 | return self.embed_code(embed_id).transpose(1, 2) 77 | 78 | def decode_latents(self, latents): 79 | encodings = rearrange(latents, "b d t -> (b t) d") 80 | codebook = self.codebook.weight # codebook: (N x D) 81 | 82 | # L2 normalize encodings and codebook (ViT-VQGAN) 83 | encodings = F.normalize(encodings) 84 | codebook = F.normalize(codebook) 85 | 86 | # Compute euclidean distance with codebook 87 | dist = ( 88 | encodings.pow(2).sum(1, keepdim=True) 89 | - 2 * encodings @ codebook.t() 90 | + codebook.pow(2).sum(1, keepdim=True).t() 91 | ) 92 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) 93 | z_q = self.decode_code(indices) 94 | return z_q, indices 95 | 96 | 97 | class ResidualVectorQuantize(nn.Module): 98 | """ 99 | Introduced in SoundStream: An end2end neural audio codec 100 | https://arxiv.org/abs/2107.03312 101 | """ 102 | 103 | def __init__( 104 | self, 105 | input_dim: int = 512, 106 | n_codebooks: int = 9, 107 | codebook_size: int = 1024, 108 | codebook_dim: Union[int, list] = 8, 109 | quantizer_dropout: float = 0.0, 110 | ): 111 | super().__init__() 112 | if isinstance(codebook_dim, int): 113 | codebook_dim = [codebook_dim for _ in range(n_codebooks)] 114 | 115 | self.n_codebooks = n_codebooks 116 | self.codebook_dim = codebook_dim 117 | self.codebook_size = codebook_size 118 | 119 | self.quantizers = nn.ModuleList( 120 | [ 121 | VectorQuantize(input_dim, codebook_size, codebook_dim[i]) 122 | for i in range(n_codebooks) 123 | ] 124 | ) 125 | self.quantizer_dropout = quantizer_dropout 126 | 127 | def forward(self, z, n_quantizers: int = None): 128 | """Quantized the input tensor 
using a fixed set of `n` codebooks and returns 129 | the corresponding codebook vectors 130 | Parameters 131 | ---------- 132 | z : Tensor[B x D x T] 133 | n_quantizers : int, optional 134 | No. of quantizers to use 135 | (n_quantizers < self.n_codebooks ex: for quantizer dropout) 136 | Note: if `self.quantizer_dropout` is True, this argument is ignored 137 | when in training mode, and a random number of quantizers is used. 138 | Returns 139 | ------- 140 | dict 141 | A dictionary with the following keys: 142 | 143 | "z" : Tensor[B x D x T] 144 | Quantized continuous representation of input 145 | "codes" : Tensor[B x N x T] 146 | Codebook indices for each codebook 147 | (quantized discrete representation of input) 148 | "latents" : Tensor[B x N*D x T] 149 | Projected latents (continuous representation of input before quantization) 150 | "vq/commitment_loss" : Tensor[1] 151 | Commitment loss to train encoder to predict vectors closer to codebook 152 | entries 153 | "vq/codebook_loss" : Tensor[1] 154 | Codebook loss to update the codebook 155 | """ 156 | z_q = 0 157 | residual = z 158 | commitment_loss = 0 159 | codebook_loss = 0 160 | 161 | codebook_indices = [] 162 | latents = [] 163 | 164 | if n_quantizers is None: 165 | n_quantizers = self.n_codebooks 166 | if self.training: 167 | n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 168 | dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) 169 | n_dropout = int(z.shape[0] * self.quantizer_dropout) 170 | n_quantizers[:n_dropout] = dropout[:n_dropout] 171 | n_quantizers = n_quantizers.to(z.device) 172 | 173 | for i, quantizer in enumerate(self.quantizers): 174 | if self.training is False and i >= n_quantizers: 175 | break 176 | 177 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( 178 | residual 179 | ) 180 | 181 | # Create mask to apply quantizer dropout 182 | mask = ( 183 | torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers 184 | ) 185 | z_q = z_q + z_q_i * mask[:, None, None] 186 | residual = residual - z_q_i 187 | 188 | # Sum losses 189 | commitment_loss += (commitment_loss_i * mask).mean() 190 | codebook_loss += (codebook_loss_i * mask).mean() 191 | 192 | codebook_indices.append(indices_i) 193 | latents.append(z_e_i) 194 | 195 | codes = torch.stack(codebook_indices, dim=1) 196 | latents = torch.cat(latents, dim=1) 197 | 198 | return z_q, codes, latents, commitment_loss, codebook_loss 199 | 200 | def from_codes(self, codes: torch.Tensor): 201 | """Given the quantized codes, reconstruct the continuous representation 202 | Parameters 203 | ---------- 204 | codes : Tensor[B x N x T] 205 | Quantized discrete representation of input 206 | Returns 207 | ------- 208 | Tensor[B x D x T] 209 | Quantized continuous representation of input 210 | """ 211 | z_q = 0.0 212 | z_p = [] 213 | n_codebooks = codes.shape[1] 214 | for i in range(n_codebooks): 215 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) 216 | z_p.append(z_p_i) 217 | 218 | z_q_i = self.quantizers[i].out_proj(z_p_i) 219 | z_q = z_q + z_q_i 220 | return z_q, torch.cat(z_p, dim=1), codes 221 | 222 | def from_latents(self, latents: torch.Tensor): 223 | """Given the unquantized latents, reconstruct the 224 | continuous representation after quantization. 
225 | 226 | Parameters 227 | ---------- 228 | latents : Tensor[B x N x T] 229 | Continuous representation of input after projection 230 | 231 | Returns 232 | ------- 233 | Tensor[B x D x T] 234 | Quantized representation of full-projected space 235 | Tensor[B x D x T] 236 | Quantized representation of latent space 237 | """ 238 | z_q = 0 239 | z_p = [] 240 | codes = [] 241 | dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) 242 | 243 | n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ 244 | 0 245 | ] 246 | for i in range(n_codebooks): 247 | j, k = dims[i], dims[i + 1] 248 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) 249 | z_p.append(z_p_i) 250 | codes.append(codes_i) 251 | 252 | z_q_i = self.quantizers[i].out_proj(z_p_i) 253 | z_q = z_q + z_q_i 254 | 255 | return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) 256 | 257 | 258 | if __name__ == "__main__": 259 | rvq = ResidualVectorQuantize(quantizer_dropout=True) 260 | x = torch.randn(16, 512, 80) 261 | y = rvq(x) 262 | print(y["latents"].shape) 263 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | from audiotools import AudioSignal 10 | from torch import nn 11 | 12 | SUPPORTED_VERSIONS = ["1.0.0"] 13 | 14 | 15 | @dataclass 16 | class DACFile: 17 | codes: torch.Tensor 18 | 19 | # Metadata 20 | chunk_length: int 21 | original_length: int 22 | input_db: float 23 | channels: int 24 | sample_rate: int 25 | padding: bool 26 | dac_version: str 27 | 28 | def save(self, path): 29 | artifacts = { 30 | "codes": self.codes.numpy().astype(np.uint16), 31 | "metadata": { 32 | "input_db": self.input_db.numpy().astype(np.float32), 33 | "original_length": self.original_length, 34 | "sample_rate": self.sample_rate, 35 | "chunk_length": self.chunk_length, 36 | "channels": self.channels, 37 | "padding": self.padding, 38 | "dac_version": SUPPORTED_VERSIONS[-1], 39 | }, 40 | } 41 | path = Path(path).with_suffix(".dac") 42 | with open(path, "wb") as f: 43 | np.save(f, artifacts) 44 | return path 45 | 46 | @classmethod 47 | def load(cls, path): 48 | artifacts = np.load(path, allow_pickle=True)[()] 49 | codes = torch.from_numpy(artifacts["codes"].astype(int)) 50 | if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: 51 | raise RuntimeError( 52 | f"Given file {path} can't be loaded with this version of descript-audio-codec." 
53 | ) 54 | return cls(codes=codes, **artifacts["metadata"]) 55 | 56 | 57 | class CodecMixin: 58 | @property 59 | def padding(self): 60 | if not hasattr(self, "_padding"): 61 | self._padding = True 62 | return self._padding 63 | 64 | @padding.setter 65 | def padding(self, value): 66 | assert isinstance(value, bool) 67 | 68 | layers = [ 69 | l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) 70 | ] 71 | 72 | for layer in layers: 73 | if value: 74 | if hasattr(layer, "original_padding"): 75 | layer.padding = layer.original_padding 76 | else: 77 | layer.original_padding = layer.padding 78 | layer.padding = tuple(0 for _ in range(len(layer.padding))) 79 | 80 | self._padding = value 81 | 82 | def get_delay(self): 83 | # Any number works here, delay is invariant to input length 84 | l_out = self.get_output_length(0) 85 | L = l_out 86 | 87 | layers = [] 88 | for layer in self.modules(): 89 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): 90 | layers.append(layer) 91 | 92 | for layer in reversed(layers): 93 | d = layer.dilation[0] 94 | k = layer.kernel_size[0] 95 | s = layer.stride[0] 96 | 97 | if isinstance(layer, nn.ConvTranspose1d): 98 | L = ((L - d * (k - 1) - 1) / s) + 1 99 | elif isinstance(layer, nn.Conv1d): 100 | L = (L - 1) * s + d * (k - 1) + 1 101 | 102 | L = math.ceil(L) 103 | 104 | l_in = L 105 | 106 | return (l_in - l_out) // 2 107 | 108 | def get_output_length(self, input_length): 109 | L = input_length 110 | # Calculate output length 111 | for layer in self.modules(): 112 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): 113 | d = layer.dilation[0] 114 | k = layer.kernel_size[0] 115 | s = layer.stride[0] 116 | 117 | if isinstance(layer, nn.Conv1d): 118 | L = ((L - d * (k - 1) - 1) / s) + 1 119 | elif isinstance(layer, nn.ConvTranspose1d): 120 | L = (L - 1) * s + d * (k - 1) + 1 121 | 122 | L = math.floor(L) 123 | return L 124 | 125 | @torch.no_grad() 126 | def compress( 127 | self, 128 | audio_path_or_signal: Union[str, Path, AudioSignal], 129 | win_duration: float = 1.0, 130 | verbose: bool = False, 131 | normalize_db: float = -16, 132 | n_quantizers: int = None, 133 | ) -> DACFile: 134 | """Processes an audio signal from a file or AudioSignal object into 135 | discrete codes. This function processes the signal in short windows, 136 | using constant GPU memory. 
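        A minimal usage sketch (added for illustration, not part of the upstream
        docstring; assumes ``model`` is a loaded ``DAC`` instance from this
        package and that ``audiotools`` is installed):

            >>> import torch
            >>> from audiotools import AudioSignal
            >>> signal = AudioSignal(torch.randn(1, 1, 44100), 44100)  # 1 s of noise at 44.1 kHz
            >>> dac_file = model.compress(signal, win_duration=1.0)    # windowed encode -> DACFile
            >>> path = dac_file.save("example.dac")                    # codes + metadata on disk
            >>> recons = model.decompress(path)                        # AudioSignal reconstruction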
137 | 138 | Parameters 139 | ---------- 140 | audio_path_or_signal : Union[str, Path, AudioSignal] 141 | audio signal to reconstruct 142 | win_duration : float, optional 143 | window duration in seconds, by default 5.0 144 | verbose : bool, optional 145 | by default False 146 | normalize_db : float, optional 147 | normalize db, by default -16 148 | 149 | Returns 150 | ------- 151 | DACFile 152 | Object containing compressed codes and metadata 153 | required for decompression 154 | """ 155 | audio_signal = audio_path_or_signal 156 | if isinstance(audio_signal, (str, Path)): 157 | audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) 158 | 159 | self.eval() 160 | original_padding = self.padding 161 | original_device = audio_signal.device 162 | 163 | audio_signal = audio_signal.clone() 164 | audio_signal = audio_signal.to_mono() 165 | original_sr = audio_signal.sample_rate 166 | 167 | resample_fn = audio_signal.resample 168 | loudness_fn = audio_signal.loudness 169 | 170 | # If audio is > 10 minutes long, use the ffmpeg versions 171 | if audio_signal.signal_duration >= 10 * 60 * 60: 172 | resample_fn = audio_signal.ffmpeg_resample 173 | loudness_fn = audio_signal.ffmpeg_loudness 174 | 175 | original_length = audio_signal.signal_length 176 | resample_fn(self.sample_rate) 177 | input_db = loudness_fn() 178 | 179 | if normalize_db is not None: 180 | audio_signal.normalize(normalize_db) 181 | audio_signal.ensure_max_of_audio() 182 | 183 | nb, nac, nt = audio_signal.audio_data.shape 184 | audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) 185 | win_duration = ( 186 | audio_signal.signal_duration if win_duration is None else win_duration 187 | ) 188 | 189 | if audio_signal.signal_duration <= win_duration: 190 | # Unchunked compression (used if signal length < win duration) 191 | self.padding = True 192 | n_samples = nt 193 | hop = nt 194 | else: 195 | # Chunked inference 196 | self.padding = False 197 | # Zero-pad signal on either side by the delay 198 | audio_signal.zero_pad(self.delay, self.delay) 199 | n_samples = int(win_duration * self.sample_rate) 200 | # Round n_samples to nearest hop length multiple 201 | n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) 202 | hop = self.get_output_length(n_samples) 203 | 204 | codes = [] 205 | range_fn = range if not verbose else tqdm.trange 206 | 207 | for i in range_fn(0, nt, hop): 208 | x = audio_signal[..., i : i + n_samples] 209 | x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) 210 | 211 | audio_data = x.audio_data.to(self.device) 212 | audio_data = self.preprocess(audio_data, self.sample_rate) 213 | _, c, _, _, _ = self.encode(audio_data, n_quantizers) 214 | codes.append(c.to(original_device)) 215 | chunk_length = c.shape[-1] 216 | 217 | codes = torch.cat(codes, dim=-1) 218 | 219 | dac_file = DACFile( 220 | codes=codes, 221 | chunk_length=chunk_length, 222 | original_length=original_length, 223 | input_db=input_db, 224 | channels=nac, 225 | sample_rate=original_sr, 226 | padding=self.padding, 227 | dac_version=SUPPORTED_VERSIONS[-1], 228 | ) 229 | 230 | if n_quantizers is not None: 231 | codes = codes[:, :n_quantizers, :] 232 | 233 | self.padding = original_padding 234 | return dac_file 235 | 236 | @torch.no_grad() 237 | def decompress( 238 | self, 239 | obj: Union[str, Path, DACFile], 240 | verbose: bool = False, 241 | ) -> AudioSignal: 242 | """Reconstruct audio from a given .dac file 243 | 244 | Parameters 245 | ---------- 246 | obj : Union[str, Path, DACFile] 247 | .dac file 
location or corresponding DACFile object. 248 | verbose : bool, optional 249 | Prints progress if True, by default False 250 | 251 | Returns 252 | ------- 253 | AudioSignal 254 | Object with the reconstructed audio 255 | """ 256 | self.eval() 257 | if isinstance(obj, (str, Path)): 258 | obj = DACFile.load(obj) 259 | 260 | original_padding = self.padding 261 | self.padding = obj.padding 262 | 263 | range_fn = range if not verbose else tqdm.trange 264 | codes = obj.codes 265 | original_device = codes.device 266 | chunk_length = obj.chunk_length 267 | recons = [] 268 | 269 | for i in range_fn(0, codes.shape[-1], chunk_length): 270 | c = codes[..., i : i + chunk_length].to(self.device) 271 | z = self.quantizer.from_codes(c)[0] 272 | r = self.decode(z) 273 | recons.append(r.to(original_device)) 274 | 275 | recons = torch.cat(recons, dim=-1) 276 | recons = AudioSignal(recons, self.sample_rate) 277 | 278 | resample_fn = recons.resample 279 | loudness_fn = recons.loudness 280 | 281 | # If audio is > 10 minutes long, use the ffmpeg versions 282 | if recons.signal_duration >= 10 * 60 * 60: 283 | resample_fn = recons.ffmpeg_resample 284 | loudness_fn = recons.ffmpeg_loudness 285 | 286 | if obj.input_db is not None: 287 | recons.normalize(obj.input_db) 288 | 289 | resample_fn(obj.sample_rate) 290 | 291 | if obj.original_length is not None: 292 | recons = recons[..., : obj.original_length] 293 | loudness_fn() 294 | recons.audio_data = recons.audio_data.reshape( 295 | -1, obj.channels, obj.original_length 296 | ) 297 | else: 298 | loudness_fn() 299 | 300 | self.padding = original_padding 301 | return recons 302 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import numpy as np 5 | import torch 6 | import pandas as pd 7 | import torchaudio 8 | from loguru import logger 9 | from hunyuanvideo_foley.utils.model_utils import load_model 10 | from hunyuanvideo_foley.utils.feature_utils import feature_process 11 | from hunyuanvideo_foley.utils.model_utils import denoise_process 12 | from hunyuanvideo_foley.utils.media_utils import merge_audio_video 13 | 14 | def set_manual_seed(global_seed): 15 | random.seed(global_seed) 16 | np.random.seed(global_seed) 17 | torch.manual_seed(global_seed) 18 | 19 | def infer(video_path, prompt, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, neg_prompt=None): 20 | visual_feats, text_feats, audio_len_in_s = feature_process( 21 | video_path, 22 | prompt, 23 | model_dict, 24 | cfg, 25 | neg_prompt=neg_prompt 26 | ) 27 | 28 | audio, sample_rate = denoise_process( 29 | visual_feats, 30 | text_feats, 31 | audio_len_in_s, 32 | model_dict, 33 | cfg, 34 | guidance_scale=guidance_scale, 35 | num_inference_steps=num_inference_steps 36 | ) 37 | return audio[0], sample_rate 38 | 39 | 40 | def generate_audio(model_dict, cfg, csv_path, output_dir, guidance_scale=4.5, num_inference_steps=50, neg_prompt=None): 41 | 42 | os.makedirs(output_dir, exist_ok=True) 43 | test_df = pd.read_csv(csv_path) 44 | 45 | for index, row in test_df.iterrows(): 46 | video_path = row['video'] 47 | prompt = row['prompt'] 48 | 49 | logger.info(f"Processing video: {video_path}") 50 | logger.info(f"Prompt: {prompt}") 51 | 52 | output_audio_path = os.path.join(output_dir, f"{index:04d}.wav") 53 | output_video_path = os.path.join(output_dir, f"{index:04d}.mp4") 54 | 55 | if not os.path.exists(output_audio_path) or not 
os.path.exists(output_video_path): 56 | audio, sample_rate = infer(video_path, prompt, model_dict, cfg, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, neg_prompt=neg_prompt) 57 | torchaudio.save(output_audio_path, audio, sample_rate) 58 | 59 | merge_audio_video(output_audio_path, video_path, output_video_path) 60 | 61 | logger.info(f"All audio files saved to {output_dir}") 62 | 63 | 64 | def parse_args(): 65 | parser = argparse.ArgumentParser( 66 | description="HunyuanVideo-Foley: Generate audio from video and text prompts", 67 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 68 | ) 69 | 70 | parser.add_argument( 71 | "--model_path", 72 | type=str, 73 | required=True, 74 | help="Path to the pretrained model dir" 75 | ) 76 | parser.add_argument( 77 | "--config_path", 78 | type=str, 79 | help="Path to the configuration file (.yaml file). If not specified, will be inferred from model_size" 80 | ) 81 | parser.add_argument( 82 | "--model_size", 83 | type=str, 84 | choices=["xl", "xxl"], 85 | default="xxl", 86 | help="Model size (xl/xxl). Auto-selects config and model file (default: xxl)" 87 | ) 88 | 89 | input_group = parser.add_mutually_exclusive_group(required=True) 90 | input_group.add_argument( 91 | "--csv_path", 92 | type=str, 93 | help="Path to CSV file containing video paths and text prompts (columns: 'video', 'text')" 94 | ) 95 | input_group.add_argument( 96 | "--single_video", 97 | type=str, 98 | help="Path to a single video file for inference" 99 | ) 100 | parser.add_argument( 101 | "--single_prompt", 102 | type=str, 103 | help="Text prompt for single video (required when using --single_video)" 104 | ) 105 | parser.add_argument( 106 | "--neg_prompt", 107 | type=str, 108 | default=None, 109 | help="Negative prompt to avoid during generation (default: 'noisy, harsh')" 110 | ) 111 | 112 | parser.add_argument( 113 | "--output_dir", 114 | type=str, 115 | required=True, 116 | help="Directory to save generated audio and video files" 117 | ) 118 | 119 | parser.add_argument( 120 | "--guidance_scale", 121 | type=float, 122 | default=4.5, 123 | help="Guidance scale for classifier-free guidance (higher = more text adherence)" 124 | ) 125 | parser.add_argument( 126 | "--num_inference_steps", 127 | type=int, 128 | default=50, 129 | help="Number of denoising steps for diffusion sampling" 130 | ) 131 | parser.add_argument( 132 | "--audio_length", 133 | type=float, 134 | default=None, 135 | help="Maximum audio length in seconds (default: video length)" 136 | ) 137 | 138 | parser.add_argument( 139 | "--device", 140 | type=str, 141 | default="auto", 142 | choices=["auto", "cpu", "cuda", "mps"], 143 | help="Device to use for inference" 144 | ) 145 | parser.add_argument( 146 | "--gpu_id", 147 | type=int, 148 | default=0, 149 | help="GPU ID to use when device is cuda" 150 | ) 151 | 152 | parser.add_argument( 153 | "--batch_size", 154 | type=int, 155 | default=1, 156 | help="Batch size for processing multiple videos" 157 | ) 158 | parser.add_argument( 159 | "--skip_existing", 160 | action="store_true", 161 | help="Skip processing if output files already exist" 162 | ) 163 | parser.add_argument( 164 | "--save_video", 165 | action="store_true", 166 | default=True, 167 | help="Save video with generated audio merged" 168 | ) 169 | parser.add_argument( 170 | "--log_level", 171 | type=str, 172 | default="INFO", 173 | choices=["DEBUG", "INFO", "WARNING", "ERROR"], 174 | help="Logging level" 175 | ) 176 | parser.add_argument( 177 | "--enable_offload", 178 | action="store_true", 179 
| help="Enable model offloading to reduce peak memory usage (good for small VRAM GPUs)" 180 | ) 181 | 182 | args = parser.parse_args() 183 | 184 | if args.single_video and not args.single_prompt: 185 | parser.error("--single_prompt is required when using --single_video") 186 | 187 | # If model_size is specified, automatically infer config_path and the model file 188 | if args.model_size: 189 | config_mapping = { 190 | "xl": "configs/hunyuanvideo-foley-xl.yaml", 191 | "xxl": "configs/hunyuanvideo-foley-xxl.yaml" 192 | } 193 | 194 | if not args.config_path: 195 | args.config_path = config_mapping[args.model_size] 196 | logger.info(f"Auto-selected config for {args.model_size} model: {args.config_path}") 197 | elif not args.config_path: 198 | args.model_size = "xxl" 199 | args.config_path = "configs/hunyuanvideo-foley-xxl.yaml" 200 | logger.info(f"Using default {args.model_size} model: {args.config_path}") 201 | 202 | return args 203 | 204 | 205 | def setup_device(device_str, gpu_id=0): 206 | if device_str == "auto": 207 | if torch.cuda.is_available(): 208 | device = torch.device(f"cuda:{gpu_id}") 209 | logger.info(f"Using CUDA device: {device}") 210 | elif torch.backends.mps.is_available(): 211 | device = torch.device("mps") 212 | logger.info("Using MPS device") 213 | else: 214 | device = torch.device("cpu") 215 | logger.info("Using CPU device") 216 | else: 217 | if device_str == "cuda": 218 | device = torch.device(f"cuda:{gpu_id}") 219 | else: 220 | device = torch.device(device_str) 221 | logger.info(f"Using specified device: {device}") 222 | 223 | return device 224 | 225 | 226 | def process_single_video(video_path, prompt, model_dict, cfg, output_dir, args): 227 | logger.info(f"Processing single video: {video_path}") 228 | logger.info(f"Text prompt: {prompt}") 229 | 230 | video_name = os.path.splitext(os.path.basename(video_path))[0] 231 | output_audio_path = os.path.join(output_dir, f"{video_name}_generated.wav") 232 | output_video_path = os.path.join(output_dir, f"{video_name}_with_audio.mp4") 233 | 234 | if args.skip_existing and os.path.exists(output_audio_path): 235 | logger.info(f"Skipping existing audio file: {output_audio_path}") 236 | if args.save_video and os.path.exists(output_video_path): 237 | logger.info(f"Skipping existing video file: {output_video_path}") 238 | return 239 | 240 | audio, sample_rate = infer( 241 | video_path, prompt, model_dict, cfg, 242 | guidance_scale=args.guidance_scale, 243 | num_inference_steps=args.num_inference_steps, 244 | neg_prompt=args.neg_prompt 245 | ) 246 | 247 | torchaudio.save(output_audio_path, audio, sample_rate) 248 | logger.info(f"Audio saved to: {output_audio_path}") 249 | 250 | if args.save_video: 251 | merge_audio_video(output_audio_path, video_path, output_video_path) 252 | logger.info(f"Video with audio saved to: {output_video_path}") 253 | 254 | def main(): 255 | set_manual_seed(1) 256 | args = parse_args() 257 | 258 | logger.remove() 259 | logger.add(lambda msg: print(msg, end=''), level=args.log_level) 260 | 261 | device = setup_device(args.device, args.gpu_id) 262 | 263 | if not os.path.exists(args.model_path): 264 | logger.error(f"Model file not found: {args.model_path}") 265 | exit(1) 266 | if not os.path.exists(args.config_path): 267 | logger.error(f"Config file not found: {args.config_path}") 268 | exit(1) 269 | 270 | if args.csv_path: 271 | if not os.path.exists(args.csv_path): 272 | logger.error(f"CSV file not found: {args.csv_path}") 273 | exit(1) 274 | elif args.single_video: 275 | if not os.path.exists(args.single_video): 276 | logger.error(f"Video file not 
found: {args.single_video}") 277 | exit(1) 278 | 279 | os.makedirs(args.output_dir, exist_ok=True) 280 | logger.info(f"Output directory: {args.output_dir}") 281 | 282 | logger.info("Loading models...") 283 | model_dict, cfg = load_model(args.model_path, args.config_path, device, enable_offload=args.enable_offload, model_size=args.model_size) 284 | 285 | if args.single_video: 286 | process_single_video( 287 | args.single_video, args.single_prompt, 288 | model_dict, cfg, args.output_dir, args 289 | ) 290 | else: 291 | generate_audio( 292 | model_dict, cfg, 293 | args.csv_path, args.output_dir, 294 | guidance_scale=args.guidance_scale, 295 | num_inference_steps=args.num_inference_steps, 296 | neg_prompt=args.neg_prompt 297 | ) 298 | 299 | logger.info("Processing completed!") 300 | 301 | 302 | 303 | if __name__ == "__main__": 304 | main() -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/video_model_builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | # Copyright 2020 Ross Wightman 4 | # Modified Model definition 5 | 6 | from collections import OrderedDict 7 | from functools import partial 8 | 9 | import torch 10 | import torch.nn as nn 11 | from timm.layers import trunc_normal_ 12 | 13 | from .vit_helper import PatchEmbed, PatchEmbed3D, DividedSpaceTimeBlock 14 | 15 | 16 | class VisionTransformer(nn.Module): 17 | """Vision Transformer with support for patch or hybrid CNN input stage""" 18 | 19 | def __init__(self, cfg): 20 | super().__init__() 21 | self.img_size = cfg.DATA.TRAIN_CROP_SIZE 22 | self.patch_size = cfg.VIT.PATCH_SIZE 23 | self.in_chans = cfg.VIT.CHANNELS 24 | if cfg.TRAIN.DATASET == "Epickitchens": 25 | self.num_classes = [97, 300] 26 | else: 27 | self.num_classes = cfg.MODEL.NUM_CLASSES 28 | self.embed_dim = cfg.VIT.EMBED_DIM 29 | self.depth = cfg.VIT.DEPTH 30 | self.num_heads = cfg.VIT.NUM_HEADS 31 | self.mlp_ratio = cfg.VIT.MLP_RATIO 32 | self.qkv_bias = cfg.VIT.QKV_BIAS 33 | self.drop_rate = cfg.VIT.DROP 34 | self.drop_path_rate = cfg.VIT.DROP_PATH 35 | self.head_dropout = cfg.VIT.HEAD_DROPOUT 36 | self.video_input = cfg.VIT.VIDEO_INPUT 37 | self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION 38 | self.use_mlp = cfg.VIT.USE_MLP 39 | self.num_features = self.embed_dim 40 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 41 | self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT 42 | self.head_act = cfg.VIT.HEAD_ACT 43 | self.cfg = cfg 44 | 45 | # Patch Embedding 46 | self.patch_embed = PatchEmbed( 47 | img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim 48 | ) 49 | 50 | # 3D Patch Embedding 51 | self.patch_embed_3d = PatchEmbed3D( 52 | img_size=self.img_size, 53 | temporal_resolution=self.temporal_resolution, 54 | patch_size=self.patch_size, 55 | in_chans=self.in_chans, 56 | embed_dim=self.embed_dim, 57 | z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP, 58 | ) 59 | self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data) 60 | 61 | # Number of patches 62 | if self.video_input: 63 | num_patches = self.patch_embed.num_patches * self.temporal_resolution 64 | else: 65 | num_patches = self.patch_embed.num_patches 66 | self.num_patches = num_patches 67 | 68 | # CLS token 69 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 70 | trunc_normal_(self.cls_token, std=0.02) 71 | 72 | # Positional 
embedding 73 | self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim)) 74 | self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT) 75 | trunc_normal_(self.pos_embed, std=0.02) 76 | 77 | if self.cfg.VIT.POS_EMBED == "joint": 78 | self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim)) 79 | trunc_normal_(self.st_embed, std=0.02) 80 | elif self.cfg.VIT.POS_EMBED == "separate": 81 | self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim)) 82 | 83 | # Layer Blocks 84 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 85 | if self.cfg.VIT.ATTN_LAYER == "divided": 86 | self.blocks = nn.ModuleList( 87 | [ 88 | DividedSpaceTimeBlock( 89 | attn_type=cfg.VIT.ATTN_LAYER, 90 | dim=self.embed_dim, 91 | num_heads=self.num_heads, 92 | mlp_ratio=self.mlp_ratio, 93 | qkv_bias=self.qkv_bias, 94 | drop=self.drop_rate, 95 | attn_drop=self.attn_drop_rate, 96 | drop_path=dpr[i], 97 | norm_layer=norm_layer, 98 | ) 99 | for i in range(self.depth) 100 | ] 101 | ) 102 | 103 | self.norm = norm_layer(self.embed_dim) 104 | 105 | # MLP head 106 | if self.use_mlp: 107 | hidden_dim = self.embed_dim 108 | if self.head_act == "tanh": 109 | # logging.info("Using TanH activation in MLP") 110 | act = nn.Tanh() 111 | elif self.head_act == "gelu": 112 | # logging.info("Using GELU activation in MLP") 113 | act = nn.GELU() 114 | else: 115 | # logging.info("Using ReLU activation in MLP") 116 | act = nn.ReLU() 117 | self.pre_logits = nn.Sequential( 118 | OrderedDict( 119 | [ 120 | ("fc", nn.Linear(self.embed_dim, hidden_dim)), 121 | ("act", act), 122 | ] 123 | ) 124 | ) 125 | else: 126 | self.pre_logits = nn.Identity() 127 | 128 | # Classifier Head 129 | self.head_drop = nn.Dropout(p=self.head_dropout) 130 | if isinstance(self.num_classes, (list,)) and len(self.num_classes) > 1: 131 | for a, i in enumerate(range(len(self.num_classes))): 132 | setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i])) 133 | else: 134 | self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 135 | 136 | # Initialize weights 137 | self.apply(self._init_weights) 138 | 139 | def _init_weights(self, m): 140 | if isinstance(m, nn.Linear): 141 | trunc_normal_(m.weight, std=0.02) 142 | if isinstance(m, nn.Linear) and m.bias is not None: 143 | nn.init.constant_(m.bias, 0) 144 | elif isinstance(m, nn.LayerNorm): 145 | nn.init.constant_(m.bias, 0) 146 | nn.init.constant_(m.weight, 1.0) 147 | 148 | @torch.jit.ignore 149 | def no_weight_decay(self): 150 | if self.cfg.VIT.POS_EMBED == "joint": 151 | return {"pos_embed", "cls_token", "st_embed"} 152 | else: 153 | return {"pos_embed", "cls_token", "temp_embed"} 154 | 155 | def get_classifier(self): 156 | return self.head 157 | 158 | def reset_classifier(self, num_classes, global_pool=""): 159 | self.num_classes = num_classes 160 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 161 | 162 | def forward_features(self, x): 163 | # if self.video_input: 164 | # x = x[0] 165 | B = x.shape[0] 166 | 167 | # Tokenize input 168 | # if self.cfg.VIT.PATCH_SIZE_TEMP > 1: 169 | # for simplicity of mapping between content dimensions (input x) and token dims (after patching) 170 | # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details): 171 | 172 | # apply patching on input 173 | x = self.patch_embed_3d(x) 174 | tok_mask = None 175 | 176 | # else: 177 | # tok_mask = None 178 | # # 2D 
tokenization 179 | # if self.video_input: 180 | # x = x.permute(0, 2, 1, 3, 4) 181 | # (B, T, C, H, W) = x.shape 182 | # x = x.reshape(B * T, C, H, W) 183 | 184 | # x = self.patch_embed(x) 185 | 186 | # if self.video_input: 187 | # (B2, T2, D2) = x.shape 188 | # x = x.reshape(B, T * T2, D2) 189 | 190 | # Append CLS token 191 | cls_tokens = self.cls_token.expand(B, -1, -1) 192 | x = torch.cat((cls_tokens, x), dim=1) 193 | # if tok_mask is not None: 194 | # # prepend 1(=keep) to the mask to account for the CLS token as well 195 | # tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1) 196 | 197 | # Interpolate positinoal embeddings 198 | # if self.cfg.DATA.TRAIN_CROP_SIZE != 224: 199 | # pos_embed = self.pos_embed 200 | # N = pos_embed.shape[1] - 1 201 | # npatch = int((x.size(1) - 1) / self.temporal_resolution) 202 | # class_emb = pos_embed[:, 0] 203 | # pos_embed = pos_embed[:, 1:] 204 | # dim = x.shape[-1] 205 | # pos_embed = torch.nn.functional.interpolate( 206 | # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 207 | # scale_factor=math.sqrt(npatch / N), 208 | # mode='bicubic', 209 | # ) 210 | # pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 211 | # new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) 212 | # else: 213 | new_pos_embed = self.pos_embed 214 | npatch = self.patch_embed.num_patches 215 | 216 | # Add positional embeddings to input 217 | if self.video_input: 218 | if self.cfg.VIT.POS_EMBED == "separate": 219 | cls_embed = self.pos_embed[:, 0, :].unsqueeze(1) 220 | tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1) 221 | tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1) 222 | total_pos_embed = tile_pos_embed + tile_temporal_embed 223 | total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) 224 | x = x + total_pos_embed 225 | elif self.cfg.VIT.POS_EMBED == "joint": 226 | x = x + self.st_embed 227 | else: 228 | # image input 229 | x = x + new_pos_embed 230 | 231 | # Apply positional dropout 232 | x = self.pos_drop(x) 233 | 234 | # Encoding using transformer layers 235 | for i, blk in enumerate(self.blocks): 236 | x = blk( 237 | x, 238 | seq_len=npatch, 239 | num_frames=self.temporal_resolution, 240 | approx=self.cfg.VIT.APPROX_ATTN_TYPE, 241 | num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM, 242 | tok_mask=tok_mask, 243 | ) 244 | 245 | ### v-iashin: I moved it to the forward pass 246 | # x = self.norm(x)[:, 0] 247 | # x = self.pre_logits(x) 248 | ### 249 | return x, tok_mask 250 | 251 | # def forward(self, x): 252 | # x = self.forward_features(x) 253 | # ### v-iashin: here. 
This should leave the same forward output as before 254 | # x = self.norm(x)[:, 0] 255 | # x = self.pre_logits(x) 256 | # ### 257 | # x = self.head_drop(x) 258 | # if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1: 259 | # output = [] 260 | # for head in range(len(self.num_classes)): 261 | # x_out = getattr(self, "head%d" % head)(x) 262 | # if not self.training: 263 | # x_out = torch.nn.functional.softmax(x_out, dim=-1) 264 | # output.append(x_out) 265 | # return output 266 | # else: 267 | # x = self.head(x) 268 | # if not self.training: 269 | # x = torch.nn.functional.softmax(x, dim=-1) 270 | # return x 271 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/loss.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from audiotools import AudioSignal 7 | from audiotools import STFTParams 8 | from torch import nn 9 | 10 | 11 | class L1Loss(nn.L1Loss): 12 | """L1 Loss between AudioSignals. Defaults 13 | to comparing ``audio_data``, but any 14 | attribute of an AudioSignal can be used. 15 | 16 | Parameters 17 | ---------- 18 | attribute : str, optional 19 | Attribute of signal to compare, defaults to ``audio_data``. 20 | weight : float, optional 21 | Weight of this loss, defaults to 1.0. 22 | 23 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py 24 | """ 25 | 26 | def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): 27 | self.attribute = attribute 28 | self.weight = weight 29 | super().__init__(**kwargs) 30 | 31 | def forward(self, x: AudioSignal, y: AudioSignal): 32 | """ 33 | Parameters 34 | ---------- 35 | x : AudioSignal 36 | Estimate AudioSignal 37 | y : AudioSignal 38 | Reference AudioSignal 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | L1 loss between AudioSignal attributes. 44 | """ 45 | if isinstance(x, AudioSignal): 46 | x = getattr(x, self.attribute) 47 | y = getattr(y, self.attribute) 48 | return super().forward(x, y) 49 | 50 | 51 | class SISDRLoss(nn.Module): 52 | """ 53 | Computes the Scale-Invariant Source-to-Distortion Ratio between a batch 54 | of estimated and reference audio signals or aligned features. 55 | 56 | Parameters 57 | ---------- 58 | scaling : int, optional 59 | Whether to use scale-invariant (True) or 60 | signal-to-noise ratio (False), by default True 61 | reduction : str, optional 62 | How to reduce across the batch (either 'mean', 63 | 'sum', or none).], by default ' mean' 64 | zero_mean : int, optional 65 | Zero mean the references and estimates before 66 | computing the loss, by default True 67 | clip_min : int, optional 68 | The minimum possible loss value. Helps network 69 | to not focus on making already good examples better, by default None 70 | weight : float, optional 71 | Weight of this loss, defaults to 1.0. 
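    Illustrative example (a sketch added here; it is not part of the upstream
    docstring). Note that this implementation reads the *first* argument as the
    reference and the *second* as the estimate:

        >>> import torch
        >>> from audiotools import AudioSignal
        >>> loss_fn = SISDRLoss()
        >>> reference = AudioSignal(torch.randn(2, 1, 16000), 16000)
        >>> estimate = AudioSignal(torch.randn(2, 1, 16000), 16000)
        >>> loss = loss_fn(reference, estimate)  # negative SI-SDR averaged over the batch; lower is better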
72 | 73 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py 74 | """ 75 | 76 | def __init__( 77 | self, 78 | scaling: int = True, 79 | reduction: str = "mean", 80 | zero_mean: int = True, 81 | clip_min: int = None, 82 | weight: float = 1.0, 83 | ): 84 | self.scaling = scaling 85 | self.reduction = reduction 86 | self.zero_mean = zero_mean 87 | self.clip_min = clip_min 88 | self.weight = weight 89 | super().__init__() 90 | 91 | def forward(self, x: AudioSignal, y: AudioSignal): 92 | eps = 1e-8 93 | # nb, nc, nt 94 | if isinstance(x, AudioSignal): 95 | references = x.audio_data 96 | estimates = y.audio_data 97 | else: 98 | references = x 99 | estimates = y 100 | 101 | nb = references.shape[0] 102 | references = references.reshape(nb, 1, -1).permute(0, 2, 1) 103 | estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) 104 | 105 | # samples now on axis 1 106 | if self.zero_mean: 107 | mean_reference = references.mean(dim=1, keepdim=True) 108 | mean_estimate = estimates.mean(dim=1, keepdim=True) 109 | else: 110 | mean_reference = 0 111 | mean_estimate = 0 112 | 113 | _references = references - mean_reference 114 | _estimates = estimates - mean_estimate 115 | 116 | references_projection = (_references**2).sum(dim=-2) + eps 117 | references_on_estimates = (_estimates * _references).sum(dim=-2) + eps 118 | 119 | scale = ( 120 | (references_on_estimates / references_projection).unsqueeze(1) 121 | if self.scaling 122 | else 1 123 | ) 124 | 125 | e_true = scale * _references 126 | e_res = _estimates - e_true 127 | 128 | signal = (e_true**2).sum(dim=1) 129 | noise = (e_res**2).sum(dim=1) 130 | sdr = -10 * torch.log10(signal / noise + eps) 131 | 132 | if self.clip_min is not None: 133 | sdr = torch.clamp(sdr, min=self.clip_min) 134 | 135 | if self.reduction == "mean": 136 | sdr = sdr.mean() 137 | elif self.reduction == "sum": 138 | sdr = sdr.sum() 139 | return sdr 140 | 141 | 142 | class MultiScaleSTFTLoss(nn.Module): 143 | """Computes the multi-scale STFT loss from [1]. 144 | 145 | Parameters 146 | ---------- 147 | window_lengths : List[int], optional 148 | Length of each window of each STFT, by default [2048, 512] 149 | loss_fn : typing.Callable, optional 150 | How to compare each loss, by default nn.L1Loss() 151 | clamp_eps : float, optional 152 | Clamp on the log magnitude, below, by default 1e-5 153 | mag_weight : float, optional 154 | Weight of raw magnitude portion of loss, by default 1.0 155 | log_weight : float, optional 156 | Weight of log magnitude portion of loss, by default 1.0 157 | pow : float, optional 158 | Power to raise magnitude to before taking log, by default 2.0 159 | weight : float, optional 160 | Weight of this loss, by default 1.0 161 | match_stride : bool, optional 162 | Whether to match the stride of convolutional layers, by default False 163 | 164 | References 165 | ---------- 166 | 167 | 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. 168 | "DDSP: Differentiable Digital Signal Processing." 169 | International Conference on Learning Representations. 2019. 
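    An illustrative usage sketch (added as an aside, not from the upstream
    docstring), comparing two AudioSignals at 44.1 kHz:

        >>> import torch
        >>> from audiotools import AudioSignal
        >>> stft_loss = MultiScaleSTFTLoss(window_lengths=[2048, 512])
        >>> x = AudioSignal(torch.randn(1, 1, 44100), 44100)  # estimate
        >>> y = AudioSignal(torch.randn(1, 1, 44100), 44100)  # reference
        >>> loss = stft_loss(x, y)  # L1 on log magnitudes plus L1 on raw magnitudes, summed over both window sizes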
170 | 171 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 172 | """ 173 | 174 | def __init__( 175 | self, 176 | window_lengths: List[int] = [2048, 512], 177 | loss_fn: typing.Callable = nn.L1Loss(), 178 | clamp_eps: float = 1e-5, 179 | mag_weight: float = 1.0, 180 | log_weight: float = 1.0, 181 | pow: float = 2.0, 182 | weight: float = 1.0, 183 | match_stride: bool = False, 184 | window_type: str = None, 185 | ): 186 | super().__init__() 187 | self.stft_params = [ 188 | STFTParams( 189 | window_length=w, 190 | hop_length=w // 4, 191 | match_stride=match_stride, 192 | window_type=window_type, 193 | ) 194 | for w in window_lengths 195 | ] 196 | self.loss_fn = loss_fn 197 | self.log_weight = log_weight 198 | self.mag_weight = mag_weight 199 | self.clamp_eps = clamp_eps 200 | self.weight = weight 201 | self.pow = pow 202 | 203 | def forward(self, x: AudioSignal, y: AudioSignal): 204 | """Computes multi-scale STFT between an estimate and a reference 205 | signal. 206 | 207 | Parameters 208 | ---------- 209 | x : AudioSignal 210 | Estimate signal 211 | y : AudioSignal 212 | Reference signal 213 | 214 | Returns 215 | ------- 216 | torch.Tensor 217 | Multi-scale STFT loss. 218 | """ 219 | loss = 0.0 220 | for s in self.stft_params: 221 | x.stft(s.window_length, s.hop_length, s.window_type) 222 | y.stft(s.window_length, s.hop_length, s.window_type) 223 | loss += self.log_weight * self.loss_fn( 224 | x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), 225 | y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), 226 | ) 227 | loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) 228 | return loss 229 | 230 | 231 | class MelSpectrogramLoss(nn.Module): 232 | """Compute distance between mel spectrograms. Can be used 233 | in a multi-scale way. 
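    As a brief illustrative sketch (not part of the upstream docstring): the
    per-scale lists are zipped with ``window_lengths``, so ``n_mels``,
    ``mel_fmin`` and ``mel_fmax`` should each have one entry per window length.
    ``x`` and ``y`` are assumed to be estimate/reference AudioSignals as in the
    other losses in this module:

        >>> mel_loss = MelSpectrogramLoss(n_mels=[150, 80],
        ...                               window_lengths=[2048, 512],
        ...                               mel_fmin=[0.0, 0.0],
        ...                               mel_fmax=[None, None])
        >>> loss = mel_loss(x, y)  # log-mel L1 + mel-magnitude L1, accumulated over the two scales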
234 | 235 | Parameters 236 | ---------- 237 | n_mels : List[int] 238 | Number of mels per STFT, by default [150, 80], 239 | window_lengths : List[int], optional 240 | Length of each window of each STFT, by default [2048, 512] 241 | loss_fn : typing.Callable, optional 242 | How to compare each loss, by default nn.L1Loss() 243 | clamp_eps : float, optional 244 | Clamp on the log magnitude, below, by default 1e-5 245 | mag_weight : float, optional 246 | Weight of raw magnitude portion of loss, by default 1.0 247 | log_weight : float, optional 248 | Weight of log magnitude portion of loss, by default 1.0 249 | pow : float, optional 250 | Power to raise magnitude to before taking log, by default 2.0 251 | weight : float, optional 252 | Weight of this loss, by default 1.0 253 | match_stride : bool, optional 254 | Whether to match the stride of convolutional layers, by default False 255 | 256 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 257 | """ 258 | 259 | def __init__( 260 | self, 261 | n_mels: List[int] = [150, 80], 262 | window_lengths: List[int] = [2048, 512], 263 | loss_fn: typing.Callable = nn.L1Loss(), 264 | clamp_eps: float = 1e-5, 265 | mag_weight: float = 1.0, 266 | log_weight: float = 1.0, 267 | pow: float = 2.0, 268 | weight: float = 1.0, 269 | match_stride: bool = False, 270 | mel_fmin: List[float] = [0.0, 0.0], 271 | mel_fmax: List[float] = [None, None], 272 | window_type: str = None, 273 | ): 274 | super().__init__() 275 | self.stft_params = [ 276 | STFTParams( 277 | window_length=w, 278 | hop_length=w // 4, 279 | match_stride=match_stride, 280 | window_type=window_type, 281 | ) 282 | for w in window_lengths 283 | ] 284 | self.n_mels = n_mels 285 | self.loss_fn = loss_fn 286 | self.clamp_eps = clamp_eps 287 | self.log_weight = log_weight 288 | self.mag_weight = mag_weight 289 | self.weight = weight 290 | self.mel_fmin = mel_fmin 291 | self.mel_fmax = mel_fmax 292 | self.pow = pow 293 | 294 | def forward(self, x: AudioSignal, y: AudioSignal): 295 | """Computes mel loss between an estimate and a reference 296 | signal. 297 | 298 | Parameters 299 | ---------- 300 | x : AudioSignal 301 | Estimate signal 302 | y : AudioSignal 303 | Reference signal 304 | 305 | Returns 306 | ------- 307 | torch.Tensor 308 | Mel loss. 309 | """ 310 | loss = 0.0 311 | for n_mels, fmin, fmax, s in zip( 312 | self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params 313 | ): 314 | kwargs = { 315 | "window_length": s.window_length, 316 | "hop_length": s.hop_length, 317 | "window_type": s.window_type, 318 | } 319 | x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) 320 | y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) 321 | 322 | loss += self.log_weight * self.loss_fn( 323 | x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), 324 | y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), 325 | ) 326 | loss += self.mag_weight * self.loss_fn(x_mels, y_mels) 327 | return loss 328 | 329 | 330 | class GANLoss(nn.Module): 331 | """ 332 | Computes a discriminator loss, given a discriminator on 333 | generated waveforms/spectrograms compared to ground truth 334 | waveforms/spectrograms. Computes the loss for both the 335 | discriminator and the generator in separate functions. 
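
    ``discriminator_loss`` implements a least-squares GAN objective (fake
    outputs are pushed toward 0, real outputs toward 1), while
    ``generator_loss`` returns the adversarial term together with an L1
    feature-matching loss over the intermediate discriminator activations.

    A minimal training-step sketch (illustrative; assumes ``recons`` and
    ``signal`` are ``AudioSignal`` batches and ``Discriminator`` is the model
    from the sibling ``discriminator`` module):

        gan_loss = GANLoss(Discriminator())
        loss_d = gan_loss.discriminator_loss(recons, signal)         # discriminator step
        loss_g, loss_feat = gan_loss.generator_loss(recons, signal)  # generator step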
336 | """ 337 | 338 | def __init__(self, discriminator): 339 | super().__init__() 340 | self.discriminator = discriminator 341 | 342 | def forward(self, fake, real): 343 | d_fake = self.discriminator(fake.audio_data) 344 | d_real = self.discriminator(real.audio_data) 345 | return d_fake, d_real 346 | 347 | def discriminator_loss(self, fake, real): 348 | d_fake, d_real = self.forward(fake.clone().detach(), real) 349 | 350 | loss_d = 0 351 | for x_fake, x_real in zip(d_fake, d_real): 352 | loss_d += torch.mean(x_fake[-1] ** 2) 353 | loss_d += torch.mean((1 - x_real[-1]) ** 2) 354 | return loss_d 355 | 356 | def generator_loss(self, fake, real): 357 | d_fake, d_real = self.forward(fake, real) 358 | 359 | loss_g = 0 360 | for x_fake in d_fake: 361 | loss_g += torch.mean((1 - x_fake[-1]) ** 2) 362 | 363 | loss_feature = 0 364 | 365 | for i in range(len(d_fake)): 366 | for j in range(len(d_fake[i]) - 1): 367 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) 368 | return loss_g, loss_feature 369 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/synchformer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Any, Mapping 4 | 5 | import einops 6 | import numpy as np 7 | import torch 8 | import torchaudio 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from .motionformer import MotionFormer 13 | from .ast_model import AST 14 | from .utils import Config 15 | 16 | 17 | class Synchformer(nn.Module): 18 | 19 | def __init__(self): 20 | super().__init__() 21 | 22 | self.vfeat_extractor = MotionFormer( 23 | extract_features=True, 24 | factorize_space_time=True, 25 | agg_space_module="TransformerEncoderLayer", 26 | agg_time_module="torch.nn.Identity", 27 | add_global_repr=False, 28 | ) 29 | self.afeat_extractor = AST( 30 | extract_features=True, 31 | max_spec_t=66, 32 | factorize_freq_time=True, 33 | agg_freq_module="TransformerEncoderLayer", 34 | agg_time_module="torch.nn.Identity", 35 | add_global_repr=False, 36 | ) 37 | 38 | # # bridging the s3d latent dim (1024) into what is specified in the config 39 | # # to match e.g. the transformer dim 40 | self.vproj = nn.Linear(in_features=768, out_features=768) 41 | self.aproj = nn.Linear(in_features=768, out_features=768) 42 | self.transformer = GlobalTransformer( 43 | tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768 44 | ) 45 | 46 | def forward(self, vis): 47 | B, S, Tv, C, H, W = vis.shape 48 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) 49 | # feat extractors return a tuple of segment-level and global features (ignored for sync) 50 | # (B, S, tv, D), e.g. 
(B, 7, 8, 768) 51 | vis = self.vfeat_extractor(vis) 52 | return vis 53 | 54 | def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor): 55 | vis = self.vproj(vis) 56 | aud = self.aproj(aud) 57 | 58 | B, S, tv, D = vis.shape 59 | B, S, ta, D = aud.shape 60 | vis = vis.view(B, S * tv, D) # (B, S*tv, D) 61 | aud = aud.view(B, S * ta, D) # (B, S*ta, D) 62 | # print(vis.shape, aud.shape) 63 | 64 | # self.transformer will concatenate the vis and aud in one sequence with aux tokens, 65 | # ie `CvvvvMaaaaaa`, and will return the logits for the CLS tokens 66 | logits = self.transformer(vis, aud) # (B, cls); or (B, cls) and (B, 2) if DoubtingTransformer 67 | 68 | return logits 69 | 70 | def extract_vfeats(self, vis): 71 | B, S, Tv, C, H, W = vis.shape 72 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) 73 | # feat extractors return a tuple of segment-level and global features (ignored for sync) 74 | # (B, S, tv, D), e.g. (B, 7, 8, 768) 75 | vis = self.vfeat_extractor(vis) 76 | return vis 77 | 78 | def extract_afeats(self, aud): 79 | B, S, _, Fa, Ta = aud.shape 80 | aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2) # (B, S, Ta, F) 81 | # (B, S, ta, D), e.g. (B, 7, 6, 768) 82 | aud, _ = self.afeat_extractor(aud) 83 | return aud 84 | 85 | def compute_loss(self, logits, targets, loss_fn: str = None): 86 | loss = None 87 | if targets is not None: 88 | if loss_fn is None or loss_fn == "cross_entropy": 89 | # logits: (B, cls) and targets: (B,) 90 | loss = F.cross_entropy(logits, targets) 91 | else: 92 | raise NotImplementedError(f"Loss {loss_fn} not implemented") 93 | return loss 94 | 95 | def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True): 96 | # discard all entries except vfeat_extractor 97 | # sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')} 98 | 99 | return super().load_state_dict(sd, strict) 100 | 101 | 102 | class RandInitPositionalEncoding(nn.Module): 103 | """Random inited trainable pos embedding. 
It is just applied on the sequence, thus respects no priors.""" 104 | 105 | def __init__(self, block_shape: list, n_embd: int): 106 | super().__init__() 107 | self.block_shape = block_shape 108 | self.n_embd = n_embd 109 | self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd)) 110 | 111 | def forward(self, token_embeddings): 112 | return token_embeddings + self.pos_emb 113 | 114 | 115 | class GlobalTransformer(torch.nn.Module): 116 | """Same as in SparseSync but without the selector transformers and the head""" 117 | 118 | def __init__( 119 | self, 120 | tok_pdrop=0.0, 121 | embd_pdrop=0.1, 122 | resid_pdrop=0.1, 123 | attn_pdrop=0.1, 124 | n_layer=3, 125 | n_head=8, 126 | n_embd=768, 127 | pos_emb_block_shape=[ 128 | 198, 129 | ], 130 | n_off_head_out=21, 131 | ) -> None: 132 | super().__init__() 133 | self.config = Config( 134 | embd_pdrop=embd_pdrop, 135 | resid_pdrop=resid_pdrop, 136 | attn_pdrop=attn_pdrop, 137 | n_layer=n_layer, 138 | n_head=n_head, 139 | n_embd=n_embd, 140 | ) 141 | # input norm 142 | self.vis_in_lnorm = torch.nn.LayerNorm(n_embd) 143 | self.aud_in_lnorm = torch.nn.LayerNorm(n_embd) 144 | # aux tokens 145 | self.OFF_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) 146 | self.MOD_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) 147 | # whole token dropout 148 | self.tok_pdrop = tok_pdrop 149 | self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop) 150 | self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop) 151 | # maybe add pos emb 152 | self.pos_emb_cfg = RandInitPositionalEncoding( 153 | block_shape=pos_emb_block_shape, 154 | n_embd=n_embd, 155 | ) 156 | # the stem 157 | self.drop = torch.nn.Dropout(embd_pdrop) 158 | self.blocks = torch.nn.Sequential(*[Block(self.config) for _ in range(n_layer)]) 159 | # pre-output norm 160 | self.ln_f = torch.nn.LayerNorm(n_embd) 161 | # maybe add a head 162 | self.off_head = torch.nn.Linear(in_features=n_embd, out_features=n_off_head_out) 163 | 164 | def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True): 165 | B, Sv, D = v.shape 166 | B, Sa, D = a.shape 167 | # broadcasting special tokens to the batch size 168 | off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B) 169 | mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B) 170 | # norm 171 | v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a) 172 | # maybe whole token dropout 173 | if self.tok_pdrop > 0: 174 | v, a = self.tok_drop_vis(v), self.tok_drop_aud(a) 175 | # (B, 1+Sv+1+Sa, D) 176 | x = torch.cat((off_tok, v, mod_tok, a), dim=1) 177 | # maybe add pos emb 178 | if hasattr(self, "pos_emb_cfg"): 179 | x = self.pos_emb_cfg(x) 180 | # dropout -> stem -> norm 181 | x = self.drop(x) 182 | x = self.blocks(x) 183 | x = self.ln_f(x) 184 | # maybe add heads 185 | if attempt_to_apply_heads and hasattr(self, "off_head"): 186 | x = self.off_head(x[:, 0, :]) 187 | return x 188 | 189 | 190 | class SelfAttention(nn.Module): 191 | """ 192 | A vanilla multi-head masked self-attention layer with a projection at the end. 193 | It is possible to use torch.nn.MultiheadAttention here but I am including an 194 | explicit implementation here to show that there is nothing too scary here. 
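    Concretely, each head computes ``softmax(q @ k.transpose(-2, -1) / sqrt(d_head)) @ v``
    over the full, unmasked sequence; the causal-mask code below is kept only
    as a commented-out reference.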
195 | """ 196 | 197 | def __init__(self, config): 198 | super().__init__() 199 | assert config.n_embd % config.n_head == 0 200 | # key, query, value projections for all heads 201 | self.key = nn.Linear(config.n_embd, config.n_embd) 202 | self.query = nn.Linear(config.n_embd, config.n_embd) 203 | self.value = nn.Linear(config.n_embd, config.n_embd) 204 | # regularization 205 | self.attn_drop = nn.Dropout(config.attn_pdrop) 206 | self.resid_drop = nn.Dropout(config.resid_pdrop) 207 | # output projection 208 | self.proj = nn.Linear(config.n_embd, config.n_embd) 209 | # # causal mask to ensure that attention is only applied to the left in the input sequence 210 | # mask = torch.tril(torch.ones(config.block_size, 211 | # config.block_size)) 212 | # if hasattr(config, "n_unmasked"): 213 | # mask[:config.n_unmasked, :config.n_unmasked] = 1 214 | # self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) 215 | self.n_head = config.n_head 216 | 217 | def forward(self, x): 218 | B, T, C = x.size() 219 | 220 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 221 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 222 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 223 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 224 | 225 | # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 226 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 227 | # att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) 228 | att = F.softmax(att, dim=-1) 229 | y = self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 230 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 231 | 232 | # output projection 233 | y = self.resid_drop(self.proj(y)) 234 | 235 | return y 236 | 237 | 238 | class Block(nn.Module): 239 | """an unassuming Transformer block""" 240 | 241 | def __init__(self, config): 242 | super().__init__() 243 | self.ln1 = nn.LayerNorm(config.n_embd) 244 | self.ln2 = nn.LayerNorm(config.n_embd) 245 | self.attn = SelfAttention(config) 246 | self.mlp = nn.Sequential( 247 | nn.Linear(config.n_embd, 4 * config.n_embd), 248 | nn.GELU(), # nice 249 | nn.Linear(4 * config.n_embd, config.n_embd), 250 | nn.Dropout(config.resid_pdrop), 251 | ) 252 | 253 | def forward(self, x): 254 | x = x + self.attn(self.ln1(x)) 255 | x = x + self.mlp(self.ln2(x)) 256 | return x 257 | 258 | 259 | def make_class_grid( 260 | leftmost_val, 261 | rightmost_val, 262 | grid_size, 263 | add_extreme_offset: bool = False, 264 | seg_size_vframes: int = None, 265 | nseg: int = None, 266 | step_size_seg: float = None, 267 | vfps: float = None, 268 | ): 269 | assert grid_size >= 3, f"grid_size: {grid_size} doesnot make sense. 
If =2 -> (-1,1); =1 -> (-1); =0 -> ()"
270 |     grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float()
271 |     if add_extreme_offset:
272 |         assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}"
273 |         seg_size_sec = seg_size_vframes / vfps
274 |         trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1)
275 |         extreme_value = trim_size_in_seg * seg_size_sec
276 |         grid = torch.cat([grid, torch.tensor([extreme_value])])  # adding extreme offset to the class grid
277 |     return grid
278 | 
279 | 
280 | # from synchformer
281 | def pad_or_truncate(audio: torch.Tensor, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0):
282 |     difference = max_spec_t - audio.shape[-1]  # safe for batched input
283 |     # pad or truncate, depending on difference
284 |     if difference > 0:
285 |         # pad the last dim (time) -> (..., n_mels, 0+time+difference)  # safe for batched input
286 |         pad_dims = (0, difference)
287 |         audio = torch.nn.functional.pad(audio, pad_dims, pad_mode, pad_value)
288 |     elif difference < 0:
289 |         print(f"Truncating spec ({audio.shape}) to max_spec_t ({max_spec_t}).")
290 |         audio = audio[..., :max_spec_t]  # safe for batched input
291 |     return audio
292 | 
293 | 
294 | def encode_audio_with_sync(
295 |     synchformer: Synchformer, x: torch.Tensor, mel: torchaudio.transforms.MelSpectrogram
296 | ) -> torch.Tensor:
297 |     b, t = x.shape
298 | 
299 |     # partition the audio waveform into half-overlapping segments
300 |     segment_size = 10240
301 |     step_size = 10240 // 2
302 |     num_segments = (t - segment_size) // step_size + 1
303 |     segments = []
304 |     for i in range(num_segments):
305 |         segments.append(x[:, i * step_size : i * step_size + segment_size])
306 |     x = torch.stack(segments, dim=1)  # (B, S, T)
307 | 
308 |     x = mel(x)
309 |     x = torch.log(x + 1e-6)
310 |     x = pad_or_truncate(x, 66)
311 | 
312 |     mean = -4.2677393
313 |     std = 4.5689974
314 |     x = (x - mean) / (2 * std)
315 |     # x: (B, S, 128, 66)
316 |     x = synchformer.extract_afeats(x.unsqueeze(2))
317 |     return x
318 | 
319 | 
320 | def read_audio(filename, expected_length=int(16000 * 4)):
321 |     waveform, sr = torchaudio.load(filename)
322 |     waveform = waveform.mean(dim=0)
323 | 
324 |     if sr != 16000:
325 |         resampler = torchaudio.transforms.Resample(sr, 16000)
326 |         waveform = resampler(waveform)
327 | 
328 |     waveform = waveform[:expected_length]
329 |     if waveform.shape[0] != expected_length:
330 |         raise ValueError(f"Audio {filename} is too short")
331 | 
332 |     waveform = waveform.squeeze()
333 | 
334 |     return waveform
335 | 
336 | 
337 | if __name__ == "__main__":
338 |     import os  # needed for the SYNCHFORMER_WEIGHTS override below
339 | 
340 |     synchformer = Synchformer().cuda().eval()
341 | 
342 |     # MMAudio-provided Synchformer checkpoint
343 |     synchformer.load_state_dict(
344 |         torch.load(
345 |             os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"),
346 |             weights_only=True,
347 |             map_location="cpu",
348 |         )
349 |     )
350 | 
351 |     sync_mel_spectrogram = torchaudio.transforms.MelSpectrogram(
352 |         sample_rate=16000,
353 |         win_length=400,
354 |         hop_length=160,
355 |         n_fft=1024,
356 |         n_mels=128,
357 |     )
358 | 
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/model/dac.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List
3 | from typing import Union
4 | 
5 | import numpy as np
6 | import torch
7 | from audiotools import AudioSignal
8 | from audiotools.ml import BaseModel
9 | from torch import nn
10 | 
11 | from .base import CodecMixin
12 | from
..nn.layers import Snake1d 13 | from ..nn.layers import WNConv1d 14 | from ..nn.layers import WNConvTranspose1d 15 | from ..nn.quantize import ResidualVectorQuantize 16 | from ..nn.vae_utils import DiagonalGaussianDistribution 17 | 18 | 19 | def init_weights(m): 20 | if isinstance(m, nn.Conv1d): 21 | nn.init.trunc_normal_(m.weight, std=0.02) 22 | nn.init.constant_(m.bias, 0) 23 | 24 | 25 | class ResidualUnit(nn.Module): 26 | def __init__(self, dim: int = 16, dilation: int = 1): 27 | super().__init__() 28 | pad = ((7 - 1) * dilation) // 2 29 | self.block = nn.Sequential( 30 | Snake1d(dim), 31 | WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), 32 | Snake1d(dim), 33 | WNConv1d(dim, dim, kernel_size=1), 34 | ) 35 | 36 | def forward(self, x): 37 | y = self.block(x) 38 | pad = (x.shape[-1] - y.shape[-1]) // 2 39 | if pad > 0: 40 | x = x[..., pad:-pad] 41 | return x + y 42 | 43 | 44 | class EncoderBlock(nn.Module): 45 | def __init__(self, dim: int = 16, stride: int = 1): 46 | super().__init__() 47 | self.block = nn.Sequential( 48 | ResidualUnit(dim // 2, dilation=1), 49 | ResidualUnit(dim // 2, dilation=3), 50 | ResidualUnit(dim // 2, dilation=9), 51 | Snake1d(dim // 2), 52 | WNConv1d( 53 | dim // 2, 54 | dim, 55 | kernel_size=2 * stride, 56 | stride=stride, 57 | padding=math.ceil(stride / 2), 58 | ), 59 | ) 60 | 61 | def forward(self, x): 62 | return self.block(x) 63 | 64 | 65 | class Encoder(nn.Module): 66 | def __init__( 67 | self, 68 | d_model: int = 64, 69 | strides: list = [2, 4, 8, 8], 70 | d_latent: int = 64, 71 | ): 72 | super().__init__() 73 | # Create first convolution 74 | self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] 75 | 76 | # Create EncoderBlocks that double channels as they downsample by `stride` 77 | for stride in strides: 78 | d_model *= 2 79 | self.block += [EncoderBlock(d_model, stride=stride)] 80 | 81 | # Create last convolution 82 | self.block += [ 83 | Snake1d(d_model), 84 | WNConv1d(d_model, d_latent, kernel_size=3, padding=1), 85 | ] 86 | 87 | # Wrap black into nn.Sequential 88 | self.block = nn.Sequential(*self.block) 89 | self.enc_dim = d_model 90 | 91 | def forward(self, x): 92 | return self.block(x) 93 | 94 | 95 | class DecoderBlock(nn.Module): 96 | def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): 97 | super().__init__() 98 | self.block = nn.Sequential( 99 | Snake1d(input_dim), 100 | WNConvTranspose1d( 101 | input_dim, 102 | output_dim, 103 | kernel_size=2 * stride, 104 | stride=stride, 105 | padding=math.ceil(stride / 2), 106 | output_padding=stride % 2, 107 | ), 108 | ResidualUnit(output_dim, dilation=1), 109 | ResidualUnit(output_dim, dilation=3), 110 | ResidualUnit(output_dim, dilation=9), 111 | ) 112 | 113 | def forward(self, x): 114 | return self.block(x) 115 | 116 | 117 | class Decoder(nn.Module): 118 | def __init__( 119 | self, 120 | input_channel, 121 | channels, 122 | rates, 123 | d_out: int = 1, 124 | ): 125 | super().__init__() 126 | 127 | # Add first conv layer 128 | layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] 129 | 130 | # Add upsampling + MRF blocks 131 | for i, stride in enumerate(rates): 132 | input_dim = channels // 2**i 133 | output_dim = channels // 2 ** (i + 1) 134 | layers += [DecoderBlock(input_dim, output_dim, stride)] 135 | 136 | # Add final conv layer 137 | layers += [ 138 | Snake1d(output_dim), 139 | WNConv1d(output_dim, d_out, kernel_size=7, padding=3), 140 | nn.Tanh(), 141 | ] 142 | 143 | self.model = nn.Sequential(*layers) 144 | 145 | def 
forward(self, x): 146 | return self.model(x) 147 | 148 | 149 | class DAC(BaseModel, CodecMixin): 150 | def __init__( 151 | self, 152 | encoder_dim: int = 64, 153 | encoder_rates: List[int] = [2, 4, 8, 8], 154 | latent_dim: int = None, 155 | decoder_dim: int = 1536, 156 | decoder_rates: List[int] = [8, 8, 4, 2], 157 | n_codebooks: int = 9, 158 | codebook_size: int = 1024, 159 | codebook_dim: Union[int, list] = 8, 160 | quantizer_dropout: bool = False, 161 | sample_rate: int = 44100, 162 | continuous: bool = False, 163 | ): 164 | super().__init__() 165 | 166 | self.encoder_dim = encoder_dim 167 | self.encoder_rates = encoder_rates 168 | self.decoder_dim = decoder_dim 169 | self.decoder_rates = decoder_rates 170 | self.sample_rate = sample_rate 171 | self.continuous = continuous 172 | 173 | if latent_dim is None: 174 | latent_dim = encoder_dim * (2 ** len(encoder_rates)) 175 | 176 | self.latent_dim = latent_dim 177 | 178 | self.hop_length = np.prod(encoder_rates) 179 | self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) 180 | 181 | if not continuous: 182 | self.n_codebooks = n_codebooks 183 | self.codebook_size = codebook_size 184 | self.codebook_dim = codebook_dim 185 | self.quantizer = ResidualVectorQuantize( 186 | input_dim=latent_dim, 187 | n_codebooks=n_codebooks, 188 | codebook_size=codebook_size, 189 | codebook_dim=codebook_dim, 190 | quantizer_dropout=quantizer_dropout, 191 | ) 192 | else: 193 | self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1) 194 | self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1) 195 | 196 | self.decoder = Decoder( 197 | latent_dim, 198 | decoder_dim, 199 | decoder_rates, 200 | ) 201 | self.sample_rate = sample_rate 202 | self.apply(init_weights) 203 | 204 | self.delay = self.get_delay() 205 | 206 | @property 207 | def dtype(self): 208 | """Get the dtype of the model parameters.""" 209 | # Return the dtype of the first parameter found 210 | for param in self.parameters(): 211 | return param.dtype 212 | return torch.float32 # fallback 213 | 214 | @property 215 | def device(self): 216 | """Get the device of the model parameters.""" 217 | # Return the device of the first parameter found 218 | for param in self.parameters(): 219 | return param.device 220 | return torch.device('cpu') # fallback 221 | 222 | def preprocess(self, audio_data, sample_rate): 223 | if sample_rate is None: 224 | sample_rate = self.sample_rate 225 | assert sample_rate == self.sample_rate 226 | 227 | length = audio_data.shape[-1] 228 | right_pad = math.ceil(length / self.hop_length) * self.hop_length - length 229 | audio_data = nn.functional.pad(audio_data, (0, right_pad)) 230 | 231 | return audio_data 232 | 233 | def encode( 234 | self, 235 | audio_data: torch.Tensor, 236 | n_quantizers: int = None, 237 | ): 238 | """Encode given audio data and return quantized latent codes 239 | 240 | Parameters 241 | ---------- 242 | audio_data : Tensor[B x 1 x T] 243 | Audio data to encode 244 | n_quantizers : int, optional 245 | Number of quantizers to use, by default None 246 | If None, all quantizers are used. 
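            Ignored when the model was built with ``continuous=True``.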
247 | 
248 |         Returns
249 |         -------
250 |         tuple
251 |             A tuple with the following entries, in order:
252 |             z : Tensor[B x D x T]
253 |                 Quantized continuous representation of input
254 |             codes : Tensor[B x N x T]
255 |                 Codebook indices for each codebook
256 |                 (quantized discrete representation of input)
257 |             latents : Tensor[B x N*D x T]
258 |                 Projected latents (continuous representation of input before quantization)
259 |             commitment_loss : Tensor[1]
260 |                 Commitment loss to train encoder to predict vectors closer to codebook
261 |                 entries
262 |             codebook_loss : Tensor[1]
263 |                 Codebook loss to update the codebook
264 |             In continuous (VAE) mode, ``z`` is instead a ``DiagonalGaussianDistribution``,
265 |             ``codes``/``latents`` are ``None`` and both losses are 0.
266 |         """
267 |         z = self.encoder(audio_data)  # [B x D x T]
268 |         if not self.continuous:
269 |             z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
270 |         else:
271 |             z = self.quant_conv(z)  # [B x 2D x T]
272 |             z = DiagonalGaussianDistribution(z)
273 |             codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
274 | 
275 |         return z, codes, latents, commitment_loss, codebook_loss
276 | 
277 |     def decode(self, z: torch.Tensor):
278 |         """Decode given latent codes and return audio data
279 | 
280 |         Parameters
281 |         ----------
282 |         z : Tensor[B x D x T]
283 |             Quantized continuous representation of input
284 |             (in continuous/VAE mode, a latent sampled from the posterior
285 |             rather than a quantized code)
286 | 
287 |         Returns
288 |         -------
289 |         Tensor[B x 1 x length]
290 |             Decoded audio data. The output covers the zero-padded input
291 |             length; callers such as ``forward`` trim it back to the
292 |             original number of samples.
293 |         """
294 |         if not self.continuous:
295 |             audio = self.decoder(z)
296 |         else:
297 |             z = self.post_quant_conv(z)
298 |             audio = self.decoder(z)
299 | 
300 |         return audio
301 | 
302 |     def forward(
303 |         self,
304 |         audio_data: torch.Tensor,
305 |         sample_rate: int = None,
306 |         n_quantizers: int = None,
307 |     ):
308 |         """Model forward pass
309 | 
310 |         Parameters
311 |         ----------
312 |         audio_data : Tensor[B x 1 x T]
313 |             Audio data to encode
314 |         sample_rate : int, optional
315 |             Sample rate of audio data in Hz, by default None
316 |             If None, defaults to `self.sample_rate`
317 |         n_quantizers : int, optional
318 |             Number of quantizers to use, by default None.
319 |             If None, all quantizers are used.
320 | 
321 |         Returns
322 |         -------
323 |         dict
324 |             A dictionary with the following keys:
325 |             "z" : Tensor[B x D x T]
326 |                 Quantized continuous representation of input
327 |             "codes" : Tensor[B x N x T]
328 |                 Codebook indices for each codebook
329 |                 (quantized discrete representation of input)
330 |             "latents" : Tensor[B x N*D x T]
331 |                 Projected latents (continuous representation of input before quantization)
332 |             "vq/commitment_loss" : Tensor[1]
333 |                 Commitment loss to train encoder to predict vectors closer to codebook
334 |                 entries
335 |             "vq/codebook_loss" : Tensor[1]
336 |                 Codebook loss to update the codebook
337 |             "audio" : Tensor[B x 1 x length]
338 |                 Decoded audio data, trimmed to the original input length
339 |             In continuous (VAE) mode the dictionary instead contains only
340 |             ``"audio"``, ``"z"`` and ``"kl_loss"``.
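
        A minimal forward-pass sketch in the continuous (VAE) mode
        (illustrative; the 1-second random input at 44.1 kHz is an arbitrary
        choice, not a project default):

            model = DAC(continuous=True, sample_rate=44100)
            out = model(torch.randn(1, 1, 44100))
            recon, kl = out["audio"], out["kl_loss"]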
341 | """ 342 | length = audio_data.shape[-1] 343 | audio_data = self.preprocess(audio_data, sample_rate) 344 | if not self.continuous: 345 | z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers) 346 | 347 | x = self.decode(z) 348 | return { 349 | "audio": x[..., :length], 350 | "z": z, 351 | "codes": codes, 352 | "latents": latents, 353 | "vq/commitment_loss": commitment_loss, 354 | "vq/codebook_loss": codebook_loss, 355 | } 356 | else: 357 | posterior, _, _, _, _ = self.encode(audio_data, n_quantizers) 358 | z = posterior.sample() 359 | x = self.decode(z) 360 | 361 | kl_loss = posterior.kl() 362 | kl_loss = kl_loss.mean() 363 | 364 | return { 365 | "audio": x[..., :length], 366 | "z": z, 367 | "kl_loss": kl_loss, 368 | } 369 | 370 | 371 | if __name__ == "__main__": 372 | import numpy as np 373 | from functools import partial 374 | 375 | model = DAC().to("cpu") 376 | 377 | for n, m in model.named_modules(): 378 | o = m.extra_repr() 379 | p = sum([np.prod(p.size()) for p in m.parameters()]) 380 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 381 | setattr(m, "extra_repr", partial(fn, o=o, p=p)) 382 | print(model) 383 | print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) 384 | 385 | length = 88200 * 2 386 | x = torch.randn(1, 1, length).to(model.device) 387 | x.requires_grad_(True) 388 | x.retain_grad() 389 | 390 | # Make a forward pass 391 | out = model(x)["audio"] 392 | print("Input shape:", x.shape) 393 | print("Output shape:", out.shape) 394 | 395 | # Create gradient variable 396 | grad = torch.zeros_like(out) 397 | grad[:, :, grad.shape[-1] // 2] = 1 398 | 399 | # Make a backward pass 400 | out.backward(grad) 401 | 402 | # Check non-zero values 403 | gradmap = x.grad.squeeze(0) 404 | gradmap = (gradmap != 0).sum(0) # sum across features 405 | rf = (gradmap != 0).sum() 406 | 407 | print(f"Receptive field: {rf.item()}") 408 | 409 | x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) 410 | model.decompress(model.compress(x, verbose=True), verbose=True) 411 | --------------------------------------------------------------------------------