├── hunyuanvideo_foley
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── activation_layers.py
│   │   │   ├── modulate_layers.py
│   │   │   ├── norm_layers.py
│   │   │   ├── embed_layers.py
│   │   │   ├── mlp_layers.py
│   │   │   └── posemb_layers.py
│   │   ├── synchformer
│   │   │   ├── __init__.py
│   │   │   ├── divided_224_16x4.yaml
│   │   │   ├── utils.py
│   │   │   ├── compute_desync_score.py
│   │   │   ├── video_model_builder.py
│   │   │   └── synchformer.py
│   │   └── dac_vae
│   │       ├── nn
│   │       │   ├── __init__.py
│   │       │   ├── layers.py
│   │       │   ├── vae_utils.py
│   │       │   ├── quantize.py
│   │       │   └── loss.py
│   │       ├── model
│   │       │   ├── __init__.py
│   │       │   ├── discriminator.py
│   │       │   ├── base.py
│   │       │   └── dac.py
│   │       ├── __init__.py
│   │       ├── __main__.py
│   │       └── utils
│   │           ├── decode.py
│   │           ├── encode.py
│   │           └── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── schedulers
│   │   │   └── __init__.py
│   │   ├── config_utils.py
│   │   ├── helper.py
│   │   ├── media_utils.py
│   │   ├── feature_utils.py
│   │   └── model_utils.py
│   └── constants.py
├── requirements.txt
├── CONTRIBUTORS.md
├── pyproject.toml
├── .github
│   └── workflows
│       └── main.yml
├── __init__.py
├── .gitignore
├── configs
│   └── hunyuanvideo-foley-xxl.yaml
├── model_urls.py
├── model_management.py
├── INSTALLATION_GUIDE.md
├── download_models_manual.py
├── install.py
├── test_node.py
└── README.md
/hunyuanvideo_foley/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/nn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .synchformer import Synchformer
2 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from . import layers
2 | from . import loss
3 | from . import quantize
4 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import CodecMixin
2 | from .base import DACFile
3 | from .dac import DAC
4 | from .discriminator import Discriminator
5 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from diffusers.schedulers import DDPMScheduler, EulerDiscreteScheduler
2 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | loguru
3 | tqdm
4 | accelerate
5 | transformers>=4.37.0
6 | safetensors
7 | requests
8 | opencv-python
9 | diffusers
10 | pyyaml
11 | einops
12 | omegaconf
13 | packaging
14 | pytorch-lightning
15 | descript-audio-codec
16 | scipy
17 | soxr
18 | ffmpy
19 | audiocraft
20 | descript-audio-codec
21 |
--------------------------------------------------------------------------------
/CONTRIBUTORS.md:
--------------------------------------------------------------------------------
1 | # Contributors
2 |
3 | We are grateful to the following individuals for their contributions to this project:
4 |
5 | - [@dasilva333](https://github.com/dasilva333) - Added `enabled` and `silent_audio` toggles for improved workflow control and error handling.
6 | - [@yichengup](https://github.com/yichengup) - Implemented image frame input/output and added a negative prompt field.
7 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0"
2 |
3 | # preserved here for legacy reasons
4 | __model_version__ = "latest"
5 |
6 | import audiotools
7 |
8 | audiotools.ml.BaseModel.INTERN += ["dac.**"]
9 | audiotools.ml.BaseModel.EXTERN += ["einops"]
10 |
11 |
12 | from . import nn
13 | from . import model
14 | from . import utils
15 | from .model import DAC
16 | from .model import DACFile
17 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "hunyuanvideo-foley"
7 | description = "ComfyUI custom node for HunyuanVideo-Foley text-video-to-audio synthesis"
8 | version = "1.0.3"
9 | license = { file = "LICENSE.txt" }
10 | dependencies = []
11 |
12 | [project.urls]
13 | Repository = "https://github.com/if-ai/ComfyUI_HunyuanVideoFoley"
14 |
15 | [tool.comfy]
16 | PublisherId = "impactframes"
17 | DisplayName = "HunyuanVideo-Foley"
18 | Icon = ""
19 |
20 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | - "pyproject.toml"
9 |
10 | jobs:
11 | publish-node:
12 | name: Publish Custom Node to registry
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: Check out code
16 | uses: actions/checkout@v4
17 | - name: Publish Custom Node
18 | uses: Comfy-Org/publish-node-action@main
19 | with:
20 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here.
21 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from loguru import logger
4 |
5 | # Add the current directory to Python path to import hunyuanvideo_foley modules
6 | current_dir = os.path.dirname(os.path.abspath(__file__))
7 | if current_dir not in sys.path:
8 | sys.path.insert(0, current_dir)
9 |
10 | # Import the individual nodes (with FP8 quantization and torch.compile support)
11 | logger.info("Loading HunyuanVideo-Foley nodes with FP8 quantization and torch.compile support")
12 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
13 |
14 | # Export the mappings
15 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/__main__.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import argbind
4 |
5 | from .utils import download
6 | from .utils.decode import decode
7 | from .utils.encode import encode
8 |
9 | STAGES = ["encode", "decode", "download"]
10 |
11 |
12 | def run(stage: str):
13 | """Run stages.
14 |
15 | Parameters
16 | ----------
17 | stage : str
18 | Stage to run
19 | """
20 | if stage not in STAGES:
21 | raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22 | stage_fn = globals()[stage]
23 |
24 | if stage == "download":
25 | stage_fn()
26 | return
27 |
28 | stage_fn()
29 |
30 |
31 | if __name__ == "__main__":
32 | group = sys.argv.pop(1)
33 | args = argbind.parse_args(group=group)
34 |
35 | with argbind.scope(args):
36 | run(group)
37 |
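38 | # Example invocation (illustrative; the module path assumes this package layout):
39 | #   python -m hunyuanvideo_foley.models.dac_vae encode /path/to/audio --output /path/to/codes
40 | #   python -m hunyuanvideo_foley.models.dac_vae decode /path/to/codes --output /path/to/wavs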
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/nn/layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 | from torch.nn.utils import weight_norm
7 |
8 |
9 | def WNConv1d(*args, **kwargs):
10 | return weight_norm(nn.Conv1d(*args, **kwargs))
11 |
12 |
13 | def WNConvTranspose1d(*args, **kwargs):
14 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15 |
16 |
17 | # Scripting this brings model speed up 1.4x
18 | @torch.jit.script
19 | def snake(x, alpha):
20 | shape = x.shape
21 | x = x.reshape(shape[0], shape[1], -1)
22 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23 | x = x.reshape(shape)
24 | return x
25 |
26 |
27 | class Snake1d(nn.Module):
28 | def __init__(self, channels):
29 | super().__init__()
30 | self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31 |
32 | def forward(self, x):
33 | return snake(x, self.alpha)
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # Environments
77 | .env
78 | .venv
79 | env/
80 | venv/
81 | ENV/
82 | env.bak/
83 | venv.bak/
84 |
85 | # IDE-specific files
86 | .idea/
87 | .vscode/
88 | *.suo
89 | *.ntvs*
90 | *.njsproj
91 | *.sln
92 | *.swp
93 |
94 | # AI tool-specific files
95 | .claude/
96 | .serena/
97 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/nn/activation_layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | def get_activation_layer(act_type):
5 | if act_type == "gelu":
6 | return lambda: nn.GELU()
7 | elif act_type == "gelu_tanh":
8 | # Approximate `tanh` requires torch >= 1.13
9 | return lambda: nn.GELU(approximate="tanh")
10 | elif act_type == "relu":
11 | return nn.ReLU
12 | elif act_type == "silu":
13 | return nn.SiLU
14 | else:
15 | raise ValueError(f"Unknown activation type: {act_type}")
16 |
17 | class SwiGLU(nn.Module):
18 | def __init__(
19 | self,
20 | dim: int,
21 | hidden_dim: int,
22 | out_dim: int,
23 | ):
24 | """
25 | Initialize the SwiGLU FeedForward module.
26 |
27 | Args:
28 | dim (int): Input dimension.
29 | hidden_dim (int): Hidden dimension of the feedforward layer.
30 |
31 | Attributes:
32 | w1: Linear transformation for the first layer.
33 | w2: Linear transformation for the second layer.
34 | w3: Linear transformation for the third layer.
35 |
36 | """
37 | super().__init__()
38 |
39 | self.w1 = nn.Linear(dim, hidden_dim, bias=False)
40 | self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
41 | self.w3 = nn.Linear(dim, hidden_dim, bias=False)
42 |
43 | def forward(self, x):
44 | return self.w2(F.silu(self.w1(x)) * self.w3(x))
45 |
--------------------------------------------------------------------------------
/configs/hunyuanvideo-foley-xxl.yaml:
--------------------------------------------------------------------------------
1 | model_config:
2 | model_name: HunyuanVideo-Foley-XXL
3 | model_type: 1d
4 | model_precision: bf16
5 | model_kwargs:
6 | depth_triple_blocks: 18
7 | depth_single_blocks: 36
8 | hidden_size: 1536
9 | num_heads: 12
10 | mlp_ratio: 4
11 | mlp_act_type: "gelu_tanh"
12 | qkv_bias: True
13 | qk_norm: True
14 | qk_norm_type: "rms"
15 | attn_mode: "torch"
16 | embedder_type: "default"
17 | interleaved_audio_visual_rope: True
18 | enable_learnable_empty_visual_feat: True
19 | sync_modulation: False
20 | add_sync_feat_to_audio: True
21 | cross_attention: True
22 | use_attention_mask: False
23 | condition_projection: "linear"
24 |     sync_feat_dim: 768 # synchformer 768 dim
25 | condition_dim: 768 # clap 768 text condition dim (clip-text)
26 | clip_dim: 768 # siglip2 visual dim
27 | audio_vae_latent_dim: 128
28 | audio_frame_rate: 50
29 | patch_size: 1
30 | rope_dim_list: null
31 | rope_theta: 10000
32 | text_length: 77
33 | clip_length: 64
34 | sync_length: 192
35 | use_mmaudio_singleblock: True
36 | depth_triple_ssl_encoder: null
37 | depth_single_ssl_encoder: 8
38 | use_repa_with_audiossl: True
39 |
40 | diffusion_config:
41 | denoise_type: "flow"
42 | flow_path_type: "linear"
43 | flow_predict_type: "velocity"
44 | flow_reverse: True
45 | flow_solver: "euler"
46 | sample_flow_shift: 1.0
47 | sample_use_flux_shift: False
48 | flux_base_shift: 0.5
49 | flux_max_shift: 1.15
50 |
--------------------------------------------------------------------------------
/model_urls.py:
--------------------------------------------------------------------------------
1 | # HunyuanVideo-Foley Model URLs Configuration
2 | # Update these URLs with the actual download links for the models
3 |
4 | MODEL_URLS = {
5 | "hunyuanvideo-foley-xxl": {
6 | "models": [
7 | {
8 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/hunyuanvideo_foley.pth",
9 | "filename": "hunyuanvideo_foley.pth",
10 | "description": "Main HunyuanVideo-Foley model"
11 | },
12 | {
13 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/synchformer_state_dict.pth",
14 | "filename": "synchformer_state_dict.pth",
15 | "description": "Synchformer model weights"
16 | },
17 | {
18 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/vae_128d_48k.pth",
19 | "filename": "vae_128d_48k.pth",
20 | "description": "VAE model weights"
21 | }
22 | ],
23 | "extracted_dir": "hunyuanvideo-foley-xxl",
24 | "description": "HunyuanVideo-Foley XXL model for audio generation"
25 | }
26 | }
27 |
28 | # Alternative mirror URLs (if main URLs fail)
29 | MIRROR_URLS = {
30 | # Add mirror download sources here if needed
31 | }
32 |
33 | def get_model_url(model_name: str, use_mirror: bool = False) -> dict:
34 | """Get model URL configuration"""
35 | urls_dict = MIRROR_URLS if use_mirror else MODEL_URLS
36 | return urls_dict.get(model_name, {})
37 |
38 | def list_available_models() -> list:
39 | """List all available model names"""
40 | return list(MODEL_URLS.keys())
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/nn/modulate_layers.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 | import torch
3 | import torch.nn as nn
4 |
5 | class ModulateDiT(nn.Module):
6 | def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None):
7 | factory_kwargs = {"dtype": dtype, "device": device}
8 | super().__init__()
9 | self.act = act_layer()
10 | self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
11 | # Zero-initialize the modulation
12 | nn.init.zeros_(self.linear.weight)
13 | nn.init.zeros_(self.linear.bias)
14 |
15 | def forward(self, x: torch.Tensor) -> torch.Tensor:
16 | return self.linear(self.act(x))
17 |
18 |
19 | def modulate(x, shift=None, scale=None):
20 | if x.ndim == 3:
21 | shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None
22 | scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None
23 | if scale is None and shift is None:
24 | return x
25 | elif shift is None:
26 | return x * (1 + scale)
27 | elif scale is None:
28 | return x + shift
29 | else:
30 | return x * (1 + scale) + shift
31 |
32 |
33 | def apply_gate(x, gate=None, tanh=False):
34 | if gate is None:
35 | return x
36 | if gate.ndim == 2 and x.ndim == 3:
37 | gate = gate.unsqueeze(1)
38 | if tanh:
39 | return x * gate.tanh()
40 | else:
41 | return x * gate
42 |
43 |
44 | def ckpt_wrapper(module):
45 | def ckpt_forward(*inputs):
46 | outputs = module(*inputs)
47 | return outputs
48 |
49 | return ckpt_forward
50 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/constants.py:
--------------------------------------------------------------------------------
1 | """Constants used throughout the HunyuanVideo-Foley project."""
2 |
3 | from typing import Dict, List
4 |
5 | # Model configuration
6 | DEFAULT_AUDIO_SAMPLE_RATE = 48000
7 | DEFAULT_VIDEO_FPS = 25
8 | DEFAULT_AUDIO_CHANNELS = 2
9 |
10 | # Video processing
11 | MAX_VIDEO_DURATION_SECONDS = 15.0
12 | MIN_VIDEO_DURATION_SECONDS = 1.0
13 |
14 | # Audio processing
15 | AUDIO_VAE_LATENT_DIM = 128
16 | AUDIO_FRAME_RATE = 75 # frames per second in latent space
17 |
18 | # Visual features
19 | FPS_VISUAL: Dict[str, int] = {
20 | "siglip2": 8,
21 | "synchformer": 25
22 | }
23 |
24 | # Model paths (can be overridden by environment variables)
25 | DEFAULT_MODEL_PATH = "./pretrained_models/"
26 | DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"
27 |
28 | # Inference parameters
29 | DEFAULT_GUIDANCE_SCALE = 4.5
30 | DEFAULT_NUM_INFERENCE_STEPS = 50
31 | MIN_GUIDANCE_SCALE = 1.0
32 | MAX_GUIDANCE_SCALE = 10.0
33 | MIN_INFERENCE_STEPS = 10
34 | MAX_INFERENCE_STEPS = 100
35 |
36 | # Text processing
37 | MAX_TEXT_LENGTH = 100
38 | DEFAULT_NEGATIVE_PROMPT = "noisy, harsh"
39 |
40 | # File extensions
41 | SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"]
42 | SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"]
43 |
44 | # Quality settings
45 | AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = {
46 | "high": ["-b:a", "192k"],
47 | "medium": ["-b:a", "128k"],
48 | "low": ["-b:a", "96k"]
49 | }
50 |
51 | # Error messages
52 | ERROR_MESSAGES: Dict[str, str] = {
53 | "model_not_loaded": "Model is not loaded. Please load the model first.",
54 | "invalid_video_format": "Unsupported video format. Supported formats: {formats}",
55 | "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds",
56 | "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html"
57 | }
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: True
3 | DATASET: Ssv2
4 | BATCH_SIZE: 32
5 | EVAL_PERIOD: 5
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | CHECKPOINT_EPOCH_RESET: True
9 | CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth
10 | DATA:
11 | NUM_FRAMES: 16
12 | SAMPLING_RATE: 4
13 | TRAIN_JITTER_SCALES: [256, 320]
14 | TRAIN_CROP_SIZE: 224
15 | TEST_CROP_SIZE: 224
16 | INPUT_CHANNEL_NUM: [3]
17 | MEAN: [0.5, 0.5, 0.5]
18 | STD: [0.5, 0.5, 0.5]
19 | PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2
20 | PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames
21 | INV_UNIFORM_SAMPLE: True
22 | RANDOM_FLIP: False
23 | REVERSE_INPUT_CHANNEL: True
24 | USE_RAND_AUGMENT: True
25 | RE_PROB: 0.0
26 | USE_REPEATED_AUG: False
27 | USE_RANDOM_RESIZE_CROPS: False
28 | COLORJITTER: False
29 | GRAYSCALE: False
30 | GAUSSIAN: False
31 | SOLVER:
32 | BASE_LR: 1e-4
33 | LR_POLICY: steps_with_relative_lrs
34 | LRS: [1, 0.1, 0.01]
35 | STEPS: [0, 20, 30]
36 | MAX_EPOCH: 35
37 | MOMENTUM: 0.9
38 | WEIGHT_DECAY: 5e-2
39 | WARMUP_EPOCHS: 0.0
40 | OPTIMIZING_METHOD: adamw
41 | USE_MIXED_PRECISION: True
42 | SMOOTHING: 0.2
43 | SLOWFAST:
44 | ALPHA: 8
45 | VIT:
46 | PATCH_SIZE: 16
47 | PATCH_SIZE_TEMP: 2
48 | CHANNELS: 3
49 | EMBED_DIM: 768
50 | DEPTH: 12
51 | NUM_HEADS: 12
52 | MLP_RATIO: 4
53 | QKV_BIAS: True
54 | VIDEO_INPUT: True
55 | TEMPORAL_RESOLUTION: 8
56 | USE_MLP: True
57 | DROP: 0.0
58 | POS_DROPOUT: 0.0
59 | DROP_PATH: 0.2
60 | IM_PRETRAINED: True
61 | HEAD_DROPOUT: 0.0
62 | HEAD_ACT: tanh
63 | PRETRAINED_WEIGHTS: vit_1k
64 | ATTN_LAYER: divided
65 | MODEL:
66 | NUM_CLASSES: 174
67 | ARCH: slow
68 | MODEL_NAME: VisionTransformer
69 | LOSS_FUNC: cross_entropy
70 | TEST:
71 | ENABLE: True
72 | DATASET: Ssv2
73 | BATCH_SIZE: 64
74 | NUM_ENSEMBLE_VIEWS: 1
75 | NUM_SPATIAL_CROPS: 3
76 | DATA_LOADER:
77 | NUM_WORKERS: 4
78 | PIN_MEMORY: True
79 | NUM_GPUS: 8
80 | NUM_SHARDS: 4
81 | RNG_SEED: 0
82 | OUTPUT_DIR: .
83 | TENSORBOARD:
84 | ENABLE: True
85 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/nn/norm_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class RMSNorm(nn.Module):
5 | def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6,
6 | device=None, dtype=None):
7 | """
8 | Initialize the RMSNorm normalization layer.
9 |
10 | Args:
11 | dim (int): The dimension of the input tensor.
12 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
13 |
14 | Attributes:
15 | eps (float): A small value added to the denominator for numerical stability.
16 | weight (nn.Parameter): Learnable scaling parameter.
17 |
18 | """
19 | factory_kwargs = {'device': device, 'dtype': dtype}
20 | super().__init__()
21 | self.eps = eps
22 | if elementwise_affine:
23 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
24 |
25 | def _norm(self, x):
26 | """
27 | Apply the RMSNorm normalization to the input tensor.
28 |
29 | Args:
30 | x (torch.Tensor): The input tensor.
31 |
32 | Returns:
33 | torch.Tensor: The normalized tensor.
34 |
35 | """
36 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
37 |
38 | def forward(self, x):
39 | """
40 | Forward pass through the RMSNorm layer.
41 |
42 | Args:
43 | x (torch.Tensor): The input tensor.
44 |
45 | Returns:
46 | torch.Tensor: The output tensor after applying RMSNorm.
47 |
48 | """
49 | output = self._norm(x.float()).type_as(x)
50 | if hasattr(self, "weight"):
51 | output = output * self.weight
52 | return output
53 |
54 |
55 | def get_norm_layer(norm_layer):
56 | """
57 | Get the normalization layer.
58 |
59 | Args:
60 | norm_layer (str): The type of normalization layer.
61 |
62 | Returns:
63 | norm_layer (nn.Module): The normalization layer.
64 | """
65 | if norm_layer == "layer":
66 | return nn.LayerNorm
67 | elif norm_layer == "rms":
68 | return RMSNorm
69 | else:
70 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
71 |
--------------------------------------------------------------------------------
/model_management.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import comfy.utils
4 | from loguru import logger
5 | import folder_paths
6 | from huggingface_hub import hf_hub_download
7 |
8 | # --- Constants ---
9 | FOLEY_MODEL_NAMES = ["hunyuanvideo_foley.pth", "vae_128d_48k.pth", "synchformer_state_dict.pth"]
10 | SIGLIP_MODEL_REPO = "google/siglip-base-patch16-512"
11 | CLAP_MODEL_REPO = "laion/clap-htsat-unfused"
12 |
13 | # --- Path Management ---
14 | def get_model_dir(subfolder=""):
15 | """Returns the primary Foley models directory."""
16 | return os.path.join(folder_paths.get_folder_paths("foley")[0], subfolder)
17 |
18 | def get_full_model_path(model_name, subfolder=""):
19 | """Returns the full path for a given model name."""
20 | return os.path.join(get_model_dir(subfolder), model_name)
21 |
22 | # --- Core Functionality ---
23 | def find_or_download(model_name, repo_id, subfolder="", subfolder_in_repo=""):
24 | """
25 | Finds a model file, downloading it if it's not found in standard locations.
26 | - Checks the main ComfyUI foley models directory first.
27 | - Falls back to downloading from Hugging Face.
28 | """
29 | local_path = get_full_model_path(model_name, subfolder)
30 |
31 | if os.path.exists(local_path):
32 | logger.info(f"Found local model: {local_path}")
33 | return local_path
34 |
35 | logger.warning(f"Could not find {model_name} locally. Attempting to download from {repo_id}...")
36 |
37 | try:
38 | downloaded_path = hf_hub_download(
39 | repo_id=repo_id,
40 | filename=model_name,
41 | subfolder=subfolder_in_repo,
42 | local_dir=get_model_dir(subfolder),
43 | local_dir_use_symlinks=False
44 | )
45 | logger.info(f"Successfully downloaded model to: {downloaded_path}")
46 | return downloaded_path
47 | except Exception as e:
48 | logger.error(f"Failed to download {model_name} from {repo_id}: {e}")
49 | raise FileNotFoundError(f"Could not find or download {model_name}. Please check your connection or download it manually.")
50 |
51 | def get_siglip_path():
52 | """Special handling for the SigLIP model which is a directory."""
53 | return find_or_download_directory(repo_id=SIGLIP_MODEL_REPO, local_dir_name="siglip-base-patch16-512")
54 |
55 | def get_clap_path():
56 | """Special handling for the CLAP model which is a directory."""
57 | return find_or_download_directory(repo_id=CLAP_MODEL_REPO, local_dir_name="clap-htsat-unfused")
58 |
59 | def find_or_download_directory(repo_id, local_dir_name):
60 | """
61 | Finds a model directory, downloading it if it's not found.
62 | This is for models like SigLIP that are not single files.
63 | """
64 | local_path = get_model_dir(local_dir_name)
65 |
66 | if os.path.exists(local_path) and os.listdir(local_path):
67 | logger.info(f"Found local model directory: {local_path}")
68 | return local_path
69 |
70 | logger.warning(f"Could not find {local_dir_name} directory locally. Attempting to download from {repo_id}...")
71 |
72 | # We can't use hf_hub_download for a whole directory in the same way,
73 | # but the transformers library will handle this caching for us automatically
74 | # when `from_pretrained` is called. We just need to return the repo_id.
75 | # The actual "download" is implicit.
76 | return repo_id
77 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.0])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.mean(
45 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
46 | dim=[1, 2],
47 | )
48 | else:
49 | return 0.5 * torch.mean(
50 | torch.pow(self.mean - other.mean, 2) / other.var
51 | + self.var / other.var
52 | - 1.0
53 | - self.logvar
54 | + other.logvar,
55 | dim=[1, 2],
56 | )
57 |
58 | def nll(self, sample, dims=[1, 2]):
59 | if self.deterministic:
60 | return torch.Tensor([0.0])
61 | logtwopi = np.log(2.0 * np.pi)
62 | return 0.5 * torch.sum(
63 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
64 | dim=dims,
65 | )
66 |
67 | def mode(self):
68 | return self.mean
69 |
70 |
71 | def normal_kl(mean1, logvar1, mean2, logvar2):
72 | """
73 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
74 | Compute the KL divergence between two gaussians.
75 | Shapes are automatically broadcasted, so batches can be compared to
76 | scalars, among other use cases.
77 | """
78 | tensor = None
79 | for obj in (mean1, logvar1, mean2, logvar2):
80 | if isinstance(obj, torch.Tensor):
81 | tensor = obj
82 | break
83 | assert tensor is not None, "at least one argument must be a Tensor"
84 |
85 | # Force variances to be Tensors. Broadcasting helps convert scalars to
86 | # Tensors, but it does not work for torch.exp().
87 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
88 |
89 | return 0.5 * (
90 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
91 | )
92 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/utils/decode.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from pathlib import Path
3 |
4 | import argbind
5 | import numpy as np
6 | import torch
7 | from audiotools import AudioSignal
8 | from tqdm import tqdm
9 |
10 | from ..model import DACFile
11 | from . import load_model
12 |
13 | warnings.filterwarnings("ignore", category=UserWarning)
14 |
15 |
16 | @argbind.bind(group="decode", positional=True, without_prefix=True)
17 | @torch.inference_mode()
18 | @torch.no_grad()
19 | def decode(
20 | input: str,
21 | output: str = "",
22 | weights_path: str = "",
23 | model_tag: str = "latest",
24 | model_bitrate: str = "8kbps",
25 | device: str = "cuda",
26 | model_type: str = "44khz",
27 | verbose: bool = False,
28 | ):
29 | """Decode audio from codes.
30 |
31 | Parameters
32 | ----------
33 | input : str
34 | Path to input directory or file
35 | output : str, optional
36 | Path to output directory, by default "".
37 | If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
38 | weights_path : str, optional
39 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
40 | model_tag and model_type.
41 | model_tag : str, optional
42 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
43 | model_bitrate: str
44 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
45 | device : str, optional
46 | Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
47 | model_type : str, optional
48 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
49 | """
50 | generator = load_model(
51 | model_type=model_type,
52 | model_bitrate=model_bitrate,
53 | tag=model_tag,
54 | load_path=weights_path,
55 | )
56 | generator.to(device)
57 | generator.eval()
58 |
59 | # Find all .dac files in input directory
60 | _input = Path(input)
61 | input_files = list(_input.glob("**/*.dac"))
62 |
63 | # If input is a .dac file, add it to the list
64 | if _input.suffix == ".dac":
65 | input_files.append(_input)
66 |
67 | # Create output directory
68 | output = Path(output)
69 | output.mkdir(parents=True, exist_ok=True)
70 |
71 | for i in tqdm(range(len(input_files)), desc=f"Decoding files"):
72 | # Load file
73 | artifact = DACFile.load(input_files[i])
74 |
75 | # Reconstruct audio from codes
76 | recons = generator.decompress(artifact, verbose=verbose)
77 |
78 | # Compute output path
79 | relative_path = input_files[i].relative_to(input)
80 | output_dir = output / relative_path.parent
81 | if not relative_path.name:
82 | output_dir = output
83 | relative_path = input_files[i]
84 | output_name = relative_path.with_suffix(".wav").name
85 | output_path = output_dir / output_name
86 | output_path.parent.mkdir(parents=True, exist_ok=True)
87 |
88 | # Write to file
89 | recons.write(output_path)
90 |
91 |
92 | if __name__ == "__main__":
93 | args = argbind.parse_args()
94 | with argbind.scope(args):
95 | decode()
96 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/utils/encode.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | from pathlib import Path
4 |
5 | import argbind
6 | import numpy as np
7 | import torch
8 | from audiotools import AudioSignal
9 | from audiotools.core import util
10 | from tqdm import tqdm
11 |
12 | from . import load_model
13 |
14 | warnings.filterwarnings("ignore", category=UserWarning)
15 |
16 |
17 | @argbind.bind(group="encode", positional=True, without_prefix=True)
18 | @torch.inference_mode()
19 | @torch.no_grad()
20 | def encode(
21 | input: str,
22 | output: str = "",
23 | weights_path: str = "",
24 | model_tag: str = "latest",
25 | model_bitrate: str = "8kbps",
26 | n_quantizers: int = None,
27 | device: str = "cuda",
28 | model_type: str = "44khz",
29 | win_duration: float = 5.0,
30 | verbose: bool = False,
31 | ):
32 | """Encode audio files in input path to .dac format.
33 |
34 | Parameters
35 | ----------
36 | input : str
37 | Path to input audio file or directory
38 | output : str, optional
39 | Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
40 | weights_path : str, optional
41 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
42 | model_tag and model_type.
43 | model_tag : str, optional
44 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
45 | model_bitrate: str
46 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
47 | n_quantizers : int, optional
48 | Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
49 | device : str, optional
50 | Device to use, by default "cuda"
51 | model_type : str, optional
52 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
53 | """
54 | generator = load_model(
55 | model_type=model_type,
56 | model_bitrate=model_bitrate,
57 | tag=model_tag,
58 | load_path=weights_path,
59 | )
60 | generator.to(device)
61 | generator.eval()
62 | kwargs = {"n_quantizers": n_quantizers}
63 |
64 | # Find all audio files in input path
65 | input = Path(input)
66 | audio_files = util.find_audio(input)
67 |
68 | output = Path(output)
69 | output.mkdir(parents=True, exist_ok=True)
70 |
71 | for i in tqdm(range(len(audio_files)), desc="Encoding files"):
72 | # Load file
73 | signal = AudioSignal(audio_files[i])
74 |
75 | # Encode audio to .dac format
76 | artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
77 |
78 | # Compute output path
79 | relative_path = audio_files[i].relative_to(input)
80 | output_dir = output / relative_path.parent
81 | if not relative_path.name:
82 | output_dir = output
83 | relative_path = audio_files[i]
84 | output_name = relative_path.with_suffix(".dac").name
85 | output_path = output_dir / output_name
86 | output_path.parent.mkdir(parents=True, exist_ok=True)
87 |
88 | artifact.save(output_path)
89 |
90 |
91 | if __name__ == "__main__":
92 | args = argbind.parse_args()
93 | with argbind.scope(args):
94 | encode()
95 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import argbind
4 | from audiotools import ml
5 |
6 | from ..model import DAC
7 | Accelerator = ml.Accelerator
8 |
9 | __MODEL_LATEST_TAGS__ = {
10 | ("44khz", "8kbps"): "0.0.1",
11 | ("24khz", "8kbps"): "0.0.4",
12 | ("16khz", "8kbps"): "0.0.5",
13 | ("44khz", "16kbps"): "1.0.0",
14 | }
15 |
16 | __MODEL_URLS__ = {
17 | (
18 | "44khz",
19 | "0.0.1",
20 | "8kbps",
21 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
22 | (
23 | "24khz",
24 | "0.0.4",
25 | "8kbps",
26 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
27 | (
28 | "16khz",
29 | "0.0.5",
30 | "8kbps",
31 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
32 | (
33 | "44khz",
34 | "1.0.0",
35 | "16kbps",
36 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
37 | }
38 |
39 |
40 | @argbind.bind(group="download", positional=True, without_prefix=True)
41 | def download(
42 | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
43 | ):
44 | """
45 | Function that downloads the weights file from URL if a local cache is not found.
46 |
47 | Parameters
48 | ----------
49 | model_type : str
50 | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
51 | model_bitrate: str
52 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
53 | Only 44khz model supports 16kbps.
54 | tag : str
55 | The tag of the model to download. Defaults to "latest".
56 |
57 | Returns
58 | -------
59 | Path
60 | Directory path required to load model via audiotools.
61 | """
62 | model_type = model_type.lower()
63 | tag = tag.lower()
64 |
65 | assert model_type in [
66 | "44khz",
67 | "24khz",
68 | "16khz",
69 | ], "model_type must be one of '44khz', '24khz', or '16khz'"
70 |
71 | assert model_bitrate in [
72 | "8kbps",
73 | "16kbps",
74 | ], "model_bitrate must be one of '8kbps', or '16kbps'"
75 |
76 | if tag == "latest":
77 | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
78 |
79 | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
80 |
81 | if download_link is None:
82 | raise ValueError(
83 | f"Could not find model with tag {tag} and model type {model_type}"
84 | )
85 |
86 | local_path = (
87 | Path.home()
88 | / ".cache"
89 | / "descript"
90 | / "dac"
91 | / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
92 | )
93 | if not local_path.exists():
94 | local_path.parent.mkdir(parents=True, exist_ok=True)
95 |
96 | # Download the model
97 | import requests
98 |
99 | response = requests.get(download_link)
100 |
101 | if response.status_code != 200:
102 | raise ValueError(
103 | f"Could not download model. Received response code {response.status_code}"
104 | )
105 | local_path.write_bytes(response.content)
106 |
107 | return local_path
108 |
109 |
110 | def load_model(
111 | model_type: str = "44khz",
112 | model_bitrate: str = "8kbps",
113 | tag: str = "latest",
114 | load_path: str = None,
115 | ):
116 | if not load_path:
117 | load_path = download(
118 | model_type=model_type, model_bitrate=model_bitrate, tag=tag
119 | )
120 | generator = DAC.load(load_path)
121 | return generator
122 |
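123 | # Example (hypothetical usage): load the default 44 kHz / 8 kbps generator, or a
124 | # local checkpoint by passing load_path explicitly.
125 | #   generator = load_model(model_type="44khz", model_bitrate="8kbps", tag="latest")
126 | #   generator = load_model(load_path="/path/to/weights.pth")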
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | """Configuration utilities for the HunyuanVideo-Foley project."""
2 |
3 | import yaml
4 | from pathlib import Path
5 | from typing import Any, Dict, List, Union
6 |
7 | class AttributeDict:
8 |
9 | def __init__(self, data: Union[Dict, List, Any]):
10 | if isinstance(data, dict):
11 | for key, value in data.items():
12 | if isinstance(value, (dict, list)):
13 | value = AttributeDict(value)
14 | setattr(self, self._sanitize_key(key), value)
15 | elif isinstance(data, list):
16 | self._list = [AttributeDict(item) if isinstance(item, (dict, list)) else item
17 | for item in data]
18 | else:
19 | self._value = data
20 |
21 | def _sanitize_key(self, key: str) -> str:
22 | import re
23 | sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', str(key))
24 | if sanitized[0].isdigit():
25 | sanitized = f'_{sanitized}'
26 | return sanitized
27 |
28 | def __getitem__(self, key):
29 | if hasattr(self, '_list'):
30 | return self._list[key]
31 | return getattr(self, self._sanitize_key(key))
32 |
33 | def __setitem__(self, key, value):
34 | if hasattr(self, '_list'):
35 | self._list[key] = value
36 | else:
37 | setattr(self, self._sanitize_key(key), value)
38 |
39 | def __iter__(self):
40 | if hasattr(self, '_list'):
41 | return iter(self._list)
42 | return iter(self.__dict__.keys())
43 |
44 | def __len__(self):
45 | if hasattr(self, '_list'):
46 | return len(self._list)
47 | return len(self.__dict__)
48 |
49 | def get(self, key, default=None):
50 | try:
51 | return self[key]
52 | except (KeyError, AttributeError, IndexError):
53 | return default
54 |
55 | def keys(self):
56 | if hasattr(self, '_list'):
57 | return range(len(self._list))
58 | elif hasattr(self, '_value'):
59 | return []
60 | else:
61 | return [key for key in self.__dict__.keys() if not key.startswith('_')]
62 |
63 | def values(self):
64 | if hasattr(self, '_list'):
65 | return self._list
66 | elif hasattr(self, '_value'):
67 | return [self._value]
68 | else:
69 | return [value for key, value in self.__dict__.items() if not key.startswith('_')]
70 |
71 | def items(self):
72 | if hasattr(self, '_list'):
73 | return enumerate(self._list)
74 | elif hasattr(self, '_value'):
75 | return []
76 | else:
77 | return [(key, value) for key, value in self.__dict__.items() if not key.startswith('_')]
78 |
79 | def __repr__(self):
80 | if hasattr(self, '_list'):
81 | return f"AttributeDict({self._list})"
82 | elif hasattr(self, '_value'):
83 | return f"AttributeDict({self._value})"
84 | return f"AttributeDict({dict(self.__dict__)})"
85 |
86 | def to_dict(self) -> Union[Dict, List, Any]:
87 | if hasattr(self, '_list'):
88 | return [item.to_dict() if isinstance(item, AttributeDict) else item
89 | for item in self._list]
90 | elif hasattr(self, '_value'):
91 | return self._value
92 | else:
93 | result = {}
94 | for key, value in self.__dict__.items():
95 | if isinstance(value, AttributeDict):
96 | result[key] = value.to_dict()
97 | else:
98 | result[key] = value
99 | return result
100 |
101 | def load_yaml(file_path: str, encoding: str = 'utf-8') -> AttributeDict:
102 | try:
103 | with open(file_path, 'r', encoding=encoding) as file:
104 | data = yaml.safe_load(file)
105 | return AttributeDict(data)
106 | except FileNotFoundError:
107 | raise FileNotFoundError(f"YAML file not found: {file_path}")
108 | except yaml.YAMLError as e:
109 | raise yaml.YAMLError(f"YAML format error: {e}")
110 |
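111 | # Example (hypothetical usage): load the bundled config and read nested values
112 | # via attribute access.
113 | #   cfg = load_yaml("configs/hunyuanvideo-foley-xxl.yaml")
114 | #   cfg.model_config.model_name                # "HunyuanVideo-Foley-XXL"
115 | #   cfg.model_config.model_kwargs.hidden_size  # 1536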
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/helper.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | from itertools import repeat
3 | import importlib
4 | import yaml
5 | import time
6 |
7 | def default(value, default_val):
8 | return default_val if value is None else value
9 |
10 |
11 | def default_dtype(value, default_val):
12 | if value is not None:
13 | assert isinstance(value, type(default_val)), f"Expect {type(default_val)}, got {type(value)}."
14 | return value
15 | return default_val
16 |
17 |
18 | def repeat_interleave(lst, num_repeats):
19 | return [item for item in lst for _ in range(num_repeats)]
20 |
21 |
22 | def _ntuple(n):
23 | def parse(x):
24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25 | x = tuple(x)
26 | if len(x) == 1:
27 | x = tuple(repeat(x[0], n))
28 | return x
29 | return tuple(repeat(x, n))
30 |
31 | return parse
32 |
33 |
34 | to_1tuple = _ntuple(1)
35 | to_2tuple = _ntuple(2)
36 | to_3tuple = _ntuple(3)
37 | to_4tuple = _ntuple(4)
38 |
39 |
40 | def as_tuple(x):
41 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
42 | return tuple(x)
43 | if x is None or isinstance(x, (int, float, str)):
44 | return (x,)
45 | else:
46 | raise ValueError(f"Unknown type {type(x)}")
47 |
48 |
49 | def as_list_of_2tuple(x):
50 | x = as_tuple(x)
51 | if len(x) == 1:
52 | x = (x[0], x[0])
53 | assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
54 | lst = []
55 | for i in range(0, len(x), 2):
56 | lst.append((x[i], x[i + 1]))
57 | return lst
58 |
59 |
60 | def find_multiple(n: int, k: int) -> int:
61 | assert k > 0
62 | if n % k == 0:
63 | return n
64 | return n - (n % k) + k
65 |
66 |
67 | def merge_dicts(dict1, dict2):
68 | for key, value in dict2.items():
69 | if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict):
70 | merge_dicts(dict1[key], value)
71 | else:
72 | dict1[key] = value
73 | return dict1
74 |
75 |
76 | def merge_yaml_files(file_list):
77 | merged_config = {}
78 |
79 | for file in file_list:
80 | with open(file, "r", encoding="utf-8") as f:
81 | config = yaml.safe_load(f)
82 | if config:
83 | # Remove the first level
84 | for key, value in config.items():
85 | if isinstance(value, dict):
86 | merged_config = merge_dicts(merged_config, value)
87 | else:
88 | merged_config[key] = value
89 |
90 | return merged_config
91 |
92 |
93 | def merge_dict(file_list):
94 | merged_config = {}
95 |
96 | for file in file_list:
97 | with open(file, "r", encoding="utf-8") as f:
98 | config = yaml.safe_load(f)
99 | if config:
100 | merged_config = merge_dicts(merged_config, config)
101 |
102 | return merged_config
103 |
104 |
105 | def get_obj_from_str(string, reload=False):
106 | module, cls = string.rsplit(".", 1)
107 | if reload:
108 | module_imp = importlib.import_module(module)
109 | importlib.reload(module_imp)
110 | return getattr(importlib.import_module(module, package=None), cls)
111 |
112 |
113 | def readable_time(seconds):
114 | """ Convert time seconds to a readable format: DD Days, HH Hours, MM Minutes, SS Seconds """
115 | seconds = int(seconds)
116 | days, seconds = divmod(seconds, 86400)
117 | hours, seconds = divmod(seconds, 3600)
118 | minutes, seconds = divmod(seconds, 60)
119 | if days > 0:
120 | return f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds"
121 | if hours > 0:
122 | return f"{hours} Hours, {minutes} Minutes, {seconds} Seconds"
123 | if minutes > 0:
124 | return f"{minutes} Minutes, {seconds} Seconds"
125 | return f"{seconds} Seconds"
126 |
127 |
128 | def get_obj_from_cfg(cfg, reload=False):
129 | if isinstance(cfg, str):
130 | return get_obj_from_str(cfg, reload)
131 | elif isinstance(cfg, (list, tuple,)):
132 | return tuple([get_obj_from_str(c, reload) for c in cfg])
133 | else:
134 | raise NotImplementedError(f"Not implemented for {type(cfg)}.")
135 |
--------------------------------------------------------------------------------
/INSTALLATION_GUIDE.md:
--------------------------------------------------------------------------------
1 | # Installation Guide for ComfyUI HunyuanVideo-Foley Custom Node
2 |
3 | ## Overview
4 |
5 | This custom node wraps the HunyuanVideo-Foley model for use in ComfyUI, enabling text-video-to-audio synthesis directly within ComfyUI workflows.
6 |
7 | ## Prerequisites
8 |
9 | - ComfyUI installation
10 | - Python 3.8+
11 | - CUDA-capable GPU (8GB+ VRAM recommended, can run on less with memory optimization)
12 | - At least 16GB system RAM
13 |
14 | ## Step-by-Step Installation
15 |
16 | ### 1. Clone the Custom Node
17 |
18 | Navigate to your ComfyUI `custom_nodes` directory and clone the repository:
19 | ```bash
20 | cd /path/to/ComfyUI/custom_nodes
21 | git clone https://github.com/if-ai/ComfyUI_HunyuanVideoFoley.git
22 | cd ComfyUI_HunyuanVideoFoley
23 | ```
24 |
25 | ### 2. Install Dependencies
26 |
27 | Run the included installation script. This will check for and install any missing Python packages.
28 | ```bash
29 | python install.py
30 | ```
31 |
32 | ### 3. Model Handling (Automatic)
33 |
34 | **No manual download is required.**
35 |
36 | The first time you use a generator node, the necessary models will be automatically downloaded and placed in the correct directory: `ComfyUI/models/foley/`.
37 |
38 | The script will create this directory for you if it doesn't exist.
39 |
40 | ### 4. Restart ComfyUI
41 |
42 | After the installation is complete, restart ComfyUI to load the new custom nodes.
43 |
44 | ## Expected Directory Structure
45 |
46 | The installer will create a `foley` directory inside your main ComfyUI `models` folder for storing the downloaded models. The resulting layout looks like this:
47 |
48 | ```
49 | ComfyUI/
50 | ├── models/
51 | │ └── foley/
52 | │ └── hunyuanvideo-foley-xxl/
53 | │ ├── hunyuanvideo_foley.pth
54 | │ ├── vae_128d_48k.pth
55 | │ └── synchformer_state_dict.pth
56 | └── custom_nodes/
57 | └── ComfyUI_HunyuanVideoFoley/
58 | ├── __init__.py
59 | ├── nodes.py
60 | ├── install.py
61 | └── ... (other node files)
62 | ```
63 |
64 | ## Usage
65 |
66 | ### Nodes Available
67 |
68 | 1. **HunyuanVideo-Foley Generator**: The main, simplified node for audio generation.
69 | 2. **HunyuanVideo-Foley Generator (Advanced)**: An advanced version that can accept pre-loaded models from loader nodes for optimized workflows.
70 | 3. **HunyuanVideo-Foley Model Loader (FP8)**: Loads the model with optional memory-saving FP8 quantization.
71 | 4. **HunyuanVideo-Foley Dependencies**: Pre-loads model dependencies like text encoders.
72 | 5. **HunyuanVideo-Foley Torch Compile**: Optimizes the model with `torch.compile` for faster inference on compatible GPUs.
73 |
74 | ## Performance & Memory Optimization
75 |
76 | The model includes several features to manage VRAM usage, allowing it to run on a wider range of hardware.
77 |
78 | - **VRAM Usage**: While 8GB of VRAM is recommended for a smooth experience, you can run the model on GPUs with less memory by enabling the following options in the generator node:
79 | - **`memory_efficient`**: This checkbox aggressively unloads models from VRAM after each generation. This is the most effective way to save VRAM.
80 | - **`cpu_offload`**: This option keeps the models on the CPU and only moves them to the GPU when needed. It is slower but significantly reduces VRAM usage.
81 |
82 | - **Generation Time**: Audio generation can take time depending on video length, settings, and hardware. Use the `HunyuanVideo-Foley Torch Compile` node for a potential speedup on subsequent runs.
83 |
84 | ## Troubleshooting
85 |
86 | ### Common Issues
87 |
88 | 1. **"Failed to import..." errors**:
89 | Ensure the installation script completed successfully. You can run it again to be sure:
90 | ```bash
91 | python install.py
92 | ```
93 |
94 | 2. **Model download issues**:
95 |    If the automatic download fails, check your internet connection and the ComfyUI console for error messages. You can also manually download the models from [HuggingFace](https://huggingface.co/tencent/HunyuanVideo-Foley) and place them in `ComfyUI/models/foley/hunyuanvideo-foley-xxl/` (see the manual download sketch after this list).
96 |
97 | 3. **CUDA out of memory**:
98 | - Enable the `memory_efficient` checkbox in the node.
99 | - Enable `cpu_offload` if you still have issues (at the cost of speed).
100 | - Reduce `sample_nums` to 1.
101 | - Use shorter videos for testing.
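102 | 
103 | ### Manual Model Download (Sketch)
104 | 
105 | If the automatic download keeps failing, the snippet below is a minimal sketch of how the three checkpoints could be fetched by hand. It assumes the `huggingface_hub` package (already used by `model_management.py`) and a `MODELS_DIR` path that you should adjust to point at your own ComfyUI installation.
106 | 
107 | ```python
108 | from pathlib import Path
109 | from huggingface_hub import hf_hub_download
110 | 
111 | # Hypothetical target path; point this at your own ComfyUI install.
112 | MODELS_DIR = Path("ComfyUI/models/foley/hunyuanvideo-foley-xxl")
113 | MODELS_DIR.mkdir(parents=True, exist_ok=True)
114 | 
115 | for filename in [
116 |     "hunyuanvideo_foley.pth",
117 |     "vae_128d_48k.pth",
118 |     "synchformer_state_dict.pth",
119 | ]:
120 |     hf_hub_download(
121 |         repo_id="tencent/HunyuanVideo-Foley",
122 |         filename=filename,
123 |         local_dir=MODELS_DIR,
124 |     )
125 | ```
126 | 
127 | After the files are in place, restart ComfyUI and the loader nodes should pick them up without attempting another download.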
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/media_utils.py:
--------------------------------------------------------------------------------
1 | """Media utilities for audio/video processing."""
2 |
3 | import os
4 | import subprocess
5 | from pathlib import Path
6 | from typing import Optional
7 |
8 | from loguru import logger
9 |
10 |
11 | class MediaProcessingError(Exception):
12 | """Exception raised for media processing errors."""
13 | pass
14 |
15 |
16 | def merge_audio_video(
17 | audio_path: str,
18 | video_path: str,
19 | output_path: str,
20 | overwrite: bool = True,
21 | quality: str = "high"
22 | ) -> str:
23 | """
24 | Merge audio and video files using ffmpeg.
25 |
26 | Args:
27 | audio_path: Path to input audio file
28 | video_path: Path to input video file
29 | output_path: Path for output video file
30 | overwrite: Whether to overwrite existing output file
31 | quality: Quality setting ('high', 'medium', 'low')
32 |
33 | Returns:
34 | Path to the output file
35 |
36 | Raises:
37 | MediaProcessingError: If input files don't exist or ffmpeg fails
38 | FileNotFoundError: If ffmpeg is not installed
39 | """
40 | # Validate input files
41 | if not os.path.exists(audio_path):
42 | raise MediaProcessingError(f"Audio file not found: {audio_path}")
43 | if not os.path.exists(video_path):
44 | raise MediaProcessingError(f"Video file not found: {video_path}")
45 |
46 | # Create output directory if needed
47 | output_dir = Path(output_path).parent
48 | output_dir.mkdir(parents=True, exist_ok=True)
49 |
50 | # Quality settings
51 | quality_settings = {
52 | "high": ["-b:a", "192k"],
53 | "medium": ["-b:a", "128k"],
54 | "low": ["-b:a", "96k"]
55 | }
56 |
57 | # Build ffmpeg command with more flexible stream handling
58 | ffmpeg_command = [
59 | "ffmpeg",
60 | "-i", video_path,
61 | "-i", audio_path,
62 | "-c:v", "copy",
63 | "-c:a", "aac",
64 | "-ac", "2",
65 | "-shortest", # Use shortest stream to avoid hanging
66 | *quality_settings.get(quality, quality_settings["high"]),
67 | ]
68 |
69 | if overwrite:
70 | ffmpeg_command.append("-y")
71 |
72 | ffmpeg_command.append(output_path)
73 |
74 | try:
75 | logger.info(f"Merging audio '{audio_path}' with video '{video_path}'")
76 | logger.info(f"FFmpeg command: {' '.join(ffmpeg_command)}")
77 |
78 | process = subprocess.Popen(
79 | ffmpeg_command,
80 | stdout=subprocess.PIPE,
81 | stderr=subprocess.PIPE,
82 | text=True
83 | )
84 | stdout, stderr = process.communicate()
85 |
86 | if process.returncode != 0:
87 | logger.error(f"Primary merge failed, trying fallback method...")
88 | logger.error(f"FFmpeg stderr: {stderr}")
89 |
90 | # Try a more compatible fallback approach
91 | fallback_command = [
92 | "ffmpeg", "-y",
93 | "-i", video_path,
94 | "-i", audio_path,
95 | "-c:v", "libx264", # Re-encode video for compatibility
96 | "-c:a", "aac",
97 | "-b:a", "128k",
98 | "-preset", "fast", # Faster encoding
99 | "-shortest",
100 | output_path
101 | ]
102 |
103 | logger.info(f"Fallback FFmpeg command: {' '.join(fallback_command)}")
104 |
105 | fallback_process = subprocess.Popen(
106 | fallback_command,
107 | stdout=subprocess.PIPE,
108 | stderr=subprocess.PIPE,
109 | text=True
110 | )
111 | fallback_stdout, fallback_stderr = fallback_process.communicate()
112 |
113 | if fallback_process.returncode != 0:
114 | error_msg = f"Both primary and fallback FFmpeg failed. Primary: {stderr}, Fallback: {fallback_stderr}"
115 | logger.error(error_msg)
116 | raise MediaProcessingError(error_msg)
117 | else:
118 | logger.info(f"Successfully merged video with fallback method: {output_path}")
119 | else:
120 | logger.info(f"Successfully merged video saved to: {output_path}")
121 |
122 | except FileNotFoundError:
123 | raise FileNotFoundError(
124 | "ffmpeg not found. Please install ffmpeg: "
125 | "https://ffmpeg.org/download.html"
126 | )
127 | except Exception as e:
128 | raise MediaProcessingError(f"Unexpected error during media processing: {e}")
129 |
130 | return output_path
131 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/utils.py:
--------------------------------------------------------------------------------
1 | from hashlib import md5
2 | from pathlib import Path
3 | import subprocess
4 |
5 | import requests
6 | from tqdm import tqdm
7 |
8 | PARENT_LINK = "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a"
9 | FNAME2LINK = {
10 | # S3: Synchability: AudioSet (run 2)
11 | "24-01-22T20-34-52.pt": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt",
12 | "cfg-24-01-22T20-34-52.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml",
13 | # S2: Synchformer: AudioSet (run 2)
14 | "24-01-04T16-39-21.pt": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt",
15 | "cfg-24-01-04T16-39-21.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml",
16 | # S2: Synchformer: AudioSet (run 1)
17 | "23-08-28T11-23-23.pt": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt",
18 | "cfg-23-08-28T11-23-23.yaml": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml",
19 | # S2: Synchformer: LRS3 (run 2)
20 | "23-12-23T18-33-57.pt": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt",
21 | "cfg-23-12-23T18-33-57.yaml": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml",
22 | # S2: Synchformer: VGS (run 2)
23 | "24-01-02T10-00-53.pt": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt",
24 | "cfg-24-01-02T10-00-53.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml",
25 | # SparseSync: ft VGGSound-Full
26 | "22-09-21T21-00-52.pt": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt",
27 | "cfg-22-09-21T21-00-52.yaml": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml",
28 | # SparseSync: ft VGGSound-Sparse
29 | "22-07-28T15-49-45.pt": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt",
30 | "cfg-22-07-28T15-49-45.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml",
31 | # SparseSync: only pt on LRS3
32 | "22-07-13T22-25-49.pt": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt",
33 | "cfg-22-07-13T22-25-49.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml",
34 | # SparseSync: feature extractors
35 | "ResNetAudio-22-08-04T09-51-04.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt", # 2s
36 | "ResNetAudio-22-08-03T23-14-49.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt", # 3s
37 | "ResNetAudio-22-08-03T23-14-28.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt", # 4s
38 | "ResNetAudio-22-06-24T08-10-33.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt", # 5s
39 | "ResNetAudio-22-06-24T17-31-07.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt", # 6s
40 | "ResNetAudio-22-06-24T23-57-11.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt", # 7s
41 | "ResNetAudio-22-06-25T04-35-42.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt", # 8s
42 | }
43 |
44 |
45 | def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024):
46 |     """Checks if the file exists; if not, downloads it from the link to the given path."""
47 | path = Path(path)
48 | if not path.exists():
49 | path.parent.mkdir(exist_ok=True, parents=True)
50 | link = fname2link.get(path.name, None)
51 | if link is None:
52 |             raise ValueError(
53 |                 f"Can't find the checkpoint file: {path}. Please download it manually and ensure the path exists."
54 |             )
55 | with requests.get(fname2link[path.name], stream=True) as r:
56 | total_size = int(r.headers.get("content-length", 0))
57 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
58 | with open(path, "wb") as f:
59 | for data in r.iter_content(chunk_size=chunk_size):
60 | if data:
61 | f.write(data)
62 |                             pbar.update(len(data))
63 |
64 |
65 | def which_ffmpeg() -> str:
66 |     """Determines the path to the ffmpeg executable
67 |     Returns:
68 |         str -- path to the ffmpeg binary
69 |     """
70 | result = subprocess.run(["which", "ffmpeg"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
71 | ffmpeg_path = result.stdout.decode("utf-8").replace("\n", "")
72 | return ffmpeg_path
73 |
74 |
75 | def get_md5sum(path):
76 | hash_md5 = md5()
77 | with open(path, "rb") as f:
78 | for chunk in iter(lambda: f.read(4096 * 8), b""):
79 | hash_md5.update(chunk)
80 | md5sum = hash_md5.hexdigest()
81 | return md5sum
82 |
83 |
84 | class Config:
85 | def __init__(self, **kwargs):
86 | for k, v in kwargs.items():
87 | setattr(self, k, v)
88 |
--------------------------------------------------------------------------------
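
A usage sketch for the downloader above. The checkpoint name is one of the FNAME2LINK keys, the destination directory is arbitrary, and the import path assumes the repository is importable as `hunyuanvideo_foley`:

    # Sketch: fetch one Synchformer checkpoint into a local cache directory,
    # skipping the download if the file is already present, then report its md5.
    from hunyuanvideo_foley.models.synchformer.utils import (
        check_if_file_exists_else_download,
        get_md5sum,
    )

    ckpt_path = "./checkpoints/24-01-04T16-39-21.pt"  # S2: Synchformer, AudioSet (run 2)
    check_if_file_exists_else_download(ckpt_path)
    print("md5:", get_md5sum(ckpt_path))
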
/hunyuanvideo_foley/models/nn/embed_layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | from ...utils.helper import to_2tuple, to_1tuple
6 |
7 | class PatchEmbed1D(nn.Module):
8 | """1D Audio to Patch Embedding
9 |
10 | A convolution based approach to patchifying a 1D audio w/ embedding projection.
11 |
12 | Based on the impl in https://github.com/google-research/vision_transformer
13 |
14 | Hacked together by / Copyright 2020 Ross Wightman
15 | """
16 |
17 | def __init__(
18 | self,
19 | patch_size=1,
20 | in_chans=768,
21 | embed_dim=768,
22 | norm_layer=None,
23 | flatten=True,
24 | bias=True,
25 | dtype=None,
26 | device=None,
27 | ):
28 | factory_kwargs = {"dtype": dtype, "device": device}
29 | super().__init__()
30 | patch_size = to_1tuple(patch_size)
31 | self.patch_size = patch_size
32 | self.flatten = flatten
33 |
34 | self.proj = nn.Conv1d(
35 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
36 | )
37 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
38 | if bias:
39 | nn.init.zeros_(self.proj.bias)
40 |
41 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
42 |
43 | def forward(self, x):
44 | assert (
45 | x.shape[2] % self.patch_size[0] == 0
46 |         ), f"The token number ({x.shape[2]}) of x must be divisible by the patch_size ({self.patch_size[0]})."
47 |
48 | x = self.proj(x)
49 | if self.flatten:
50 | x = x.transpose(1, 2) # BCN -> BNC
51 | x = self.norm(x)
52 | return x
53 |
54 |
55 | class ConditionProjection(nn.Module):
56 | """
57 | Projects condition embeddings. Also handles dropout for classifier-free guidance.
58 |
59 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
60 | """
61 |
62 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
63 | factory_kwargs = {'dtype': dtype, 'device': device}
64 | super().__init__()
65 | self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
66 | self.act_1 = act_layer()
67 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
68 |
69 | def forward(self, caption):
70 | hidden_states = self.linear_1(caption)
71 | hidden_states = self.act_1(hidden_states)
72 | hidden_states = self.linear_2(hidden_states)
73 | return hidden_states
74 |
75 |
76 | def timestep_embedding(t, dim, max_period=10000):
77 | """
78 | Create sinusoidal timestep embeddings.
79 |
80 | Args:
81 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
82 | dim (int): the dimension of the output.
83 | max_period (int): controls the minimum frequency of the embeddings.
84 |
85 | Returns:
86 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
87 |
88 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
89 | """
90 | half = dim // 2
91 | freqs = torch.exp(
92 | -math.log(max_period)
93 | * torch.arange(start=0, end=half, dtype=torch.float32)
94 | / half
95 | ).to(device=t.device)
96 | args = t[:, None].float() * freqs[None]
97 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
98 | if dim % 2:
99 | embedding = torch.cat(
100 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
101 | )
102 | return embedding
103 |
104 |
105 | class TimestepEmbedder(nn.Module):
106 | """
107 | Embeds scalar timesteps into vector representations.
108 | """
109 | def __init__(self,
110 | hidden_size,
111 | act_layer,
112 | frequency_embedding_size=256,
113 | max_period=10000,
114 | out_size=None,
115 | dtype=None,
116 | device=None
117 | ):
118 | factory_kwargs = {'dtype': dtype, 'device': device}
119 | super().__init__()
120 | self.frequency_embedding_size = frequency_embedding_size
121 | self.max_period = max_period
122 | if out_size is None:
123 | out_size = hidden_size
124 |
125 | self.mlp = nn.Sequential(
126 | nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
127 | act_layer(),
128 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
129 | )
130 | nn.init.normal_(self.mlp[0].weight, std=0.02)
131 | nn.init.normal_(self.mlp[2].weight, std=0.02)
132 |
133 | def forward(self, t):
134 | t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
135 | t_emb = self.mlp(t_freq)
136 | return t_emb
137 |
--------------------------------------------------------------------------------
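
A short usage sketch of the timestep embedding path above. The hidden size, batch size, and SiLU activation are arbitrary choices for illustration, and the import assumes the package is importable as `hunyuanvideo_foley`:

    # Sketch: sinusoidal features followed by the two-layer MLP of TimestepEmbedder.
    import torch
    import torch.nn as nn
    from hunyuanvideo_foley.models.nn.embed_layers import TimestepEmbedder, timestep_embedding

    embedder = TimestepEmbedder(hidden_size=512, act_layer=nn.SiLU)
    t = torch.rand(4) * 1000                 # four fractional timesteps
    t_emb = embedder(t)                      # shape: (4, 512)

    # The raw sinusoidal features on their own:
    t_freq = timestep_embedding(t, dim=256)  # shape: (4, 256)
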
/download_models_manual.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Manual download helper for HunyuanVideo-Foley models
4 | Run this script directly if the automatic download fails in ComfyUI
5 | """
6 |
7 | import os
8 | import sys
9 | from pathlib import Path
10 | import urllib.request
11 | import time
12 |
13 | # Model download URLs
14 | MODELS = [
15 | {
16 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/hunyuanvideo_foley.pth",
17 | "filename": "hunyuanvideo_foley.pth",
18 | "size_gb": 10.3,
19 | "description": "Main HunyuanVideo-Foley model"
20 | },
21 | {
22 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/synchformer_state_dict.pth",
23 | "filename": "synchformer_state_dict.pth",
24 | "size_gb": 0.95,
25 | "description": "Synchformer model weights"
26 | },
27 | {
28 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/vae_128d_48k.pth",
29 | "filename": "vae_128d_48k.pth",
30 | "size_gb": 1.49,
31 | "description": "VAE model weights"
32 | }
33 | ]
34 |
35 | def download_with_progress(url, dest_path):
36 | """Download file with progress display"""
37 | def progress_hook(block_num, block_size, total_size):
38 | if total_size > 0:
39 | downloaded = block_num * block_size
40 | percent = min(100, (downloaded * 100) // total_size)
41 | size_mb = downloaded / (1024 * 1024)
42 | total_mb = total_size / (1024 * 1024)
43 |
44 | # Print progress
45 | bar_len = 40
46 | filled_len = int(bar_len * percent // 100)
47 | bar = '█' * filled_len + '-' * (bar_len - filled_len)
48 |
49 | sys.stdout.write(f'\r[{bar}] {percent}% ({size_mb:.1f}/{total_mb:.1f} MB)')
50 | sys.stdout.flush()
51 |
52 | try:
53 | urllib.request.urlretrieve(url, dest_path, progress_hook)
54 | print() # New line after progress
55 | return True
56 | except Exception as e:
57 | print(f"\nError: {e}")
58 | return False
59 |
60 | def main():
61 | # Determine ComfyUI models directory
62 | comfyui_root = Path(__file__).parent.parent.parent # Go up to ComfyUI root
63 | models_dir = comfyui_root / "models" / "foley" / "hunyuanvideo-foley-xxl"
64 |
65 | print("=" * 60)
66 | print("HunyuanVideo-Foley Model Downloader")
67 | print("=" * 60)
68 | print(f"\nModels will be downloaded to:")
69 | print(f" {models_dir}")
70 |
71 | # Create directory if it doesn't exist
72 | models_dir.mkdir(parents=True, exist_ok=True)
73 |
74 | # Check disk space
75 | import shutil
76 | stat = shutil.disk_usage(models_dir)
77 | available_gb = stat.free / (1024 ** 3)
78 | required_gb = sum(m["size_gb"] for m in MODELS) + 1 # Add 1GB buffer
79 |
80 | print(f"\nDisk space available: {available_gb:.1f} GB")
81 | print(f"Space required: {required_gb:.1f} GB")
82 |
83 | if available_gb < required_gb:
84 | print(f"\n⚠️ WARNING: Insufficient disk space!")
85 | print(f"Please free up at least {required_gb - available_gb:.1f} GB before continuing.")
86 | response = input("\nContinue anyway? (y/n): ")
87 | if response.lower() != 'y':
88 | return
89 |
90 | print("\n" + "=" * 60)
91 |
92 | # Download each model
93 | for i, model_info in enumerate(MODELS, 1):
94 | model_path = models_dir / model_info["filename"]
95 |
96 | print(f"\n[{i}/{len(MODELS)}] {model_info['description']}")
97 | print(f" File: {model_info['filename']} ({model_info['size_gb']} GB)")
98 |
99 | # Check if already downloaded
100 | if model_path.exists() and model_path.stat().st_size > 100 * 1024 * 1024:
101 | size_gb = model_path.stat().st_size / (1024 ** 3)
102 | print(f" ✓ Already downloaded ({size_gb:.2f} GB)")
103 | continue
104 |
105 | print(f" Downloading from: {model_info['url']}")
106 |
107 | # Try downloading with retries
108 | max_retries = 3
109 | for attempt in range(max_retries):
110 | if attempt > 0:
111 | print(f" Retry {attempt}/{max_retries - 1}...")
112 | time.sleep(5)
113 |
114 | success = download_with_progress(model_info["url"], model_path)
115 | if success:
116 | size_gb = model_path.stat().st_size / (1024 ** 3)
117 | print(f" ✓ Downloaded successfully ({size_gb:.2f} GB)")
118 | break
119 | else:
120 | if attempt == max_retries - 1:
121 | print(f" ✗ Failed to download after {max_retries} attempts")
122 | print(f"\n You can manually download from:")
123 | print(f" {model_info['url']}")
124 | print(f" And place it at:")
125 | print(f" {model_path}")
126 |
127 | print("\n" + "=" * 60)
128 | print("Download process completed!")
129 | print("\nYou can now use the HunyuanVideo-Foley node in ComfyUI.")
130 | print("=" * 60)
131 |
132 | if __name__ == "__main__":
133 | main()
--------------------------------------------------------------------------------
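
After running the script, a quick sanity check like the sketch below can confirm that all three files from the MODELS table landed in the expected directory. The ComfyUI root path is an assumption; adjust it to your installation:

    # Sketch: report the presence and size of the three expected checkpoints.
    from pathlib import Path

    comfyui_root = Path(".")  # assumed ComfyUI root; adjust as needed
    models_dir = comfyui_root / "models" / "foley" / "hunyuanvideo-foley-xxl"

    for name in ["hunyuanvideo_foley.pth", "synchformer_state_dict.pth", "vae_128d_48k.pth"]:
        path = models_dir / name
        if path.exists():
            print(f"{name}: {path.stat().st_size / 1024**3:.2f} GB")
        else:
            print(f"{name}: missing")
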
/install.py:
--------------------------------------------------------------------------------
1 | """
2 | Installation script for ComfyUI HunyuanVideo-Foley Custom Node
3 | """
4 |
5 | import os
6 | import sys
7 | import subprocess
8 | import pkg_resources
9 | from pathlib import Path
10 |
11 | def parse_requirements(file_path):
12 | """Parse requirements file and handle git dependencies."""
13 | requirements = []
14 | with open(file_path, 'r') as f:
15 | for line in f:
16 | line = line.strip()
17 | if line and not line.startswith('#'):
18 | if line.startswith('git+'):
19 | # For git repos, find the package name from the egg fragment
20 | egg_name = None
21 | if '#egg=' in line:
22 | egg_name = line.split('#egg=')[-1]
23 |
24 | if egg_name:
25 | requirements.append((egg_name, line))
26 | else:
27 | print(f"⚠️ Git requirement '{line}' is missing the '#egg=' part and cannot be checked. It will be installed regardless.")
28 | # Fallback: We can't check it, so we'll just try to install it.
29 | # The package name is passed as None to signal an install attempt.
30 | requirements.append((None, line))
31 | else:
32 | # Standard package
33 | req = pkg_resources.Requirement.parse(line)
34 | requirements.append((req.project_name, str(req)))
35 | return requirements
36 |
37 | def check_and_install_requirements():
38 | """Check and install required packages without overriding existing ones."""
39 | requirements_file = Path(__file__).parent / "requirements.txt"
40 |
41 | if not requirements_file.exists():
42 | print("❌ Requirements file not found!")
43 | return False
44 |
45 | try:
46 | print("🚀 Checking and installing requirements...")
47 |
48 | # Get list of (package_name, requirement_string)
49 | requirements = parse_requirements(requirements_file)
50 |
51 | for pkg_name, requirement_str in requirements:
52 | # If pkg_name is None, it's a git URL we couldn't parse. Try installing.
53 | if pkg_name is None:
54 | print(f"Attempting to install from git: {requirement_str}")
55 | try:
56 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', requirement_str])
57 | print(f"✅ Successfully installed {requirement_str}")
58 | except subprocess.CalledProcessError as e:
59 | print(f"❌ Failed to install {requirement_str}: {e}")
60 | continue
61 |
62 | # Check if the package is already installed
63 | try:
64 | pkg_resources.require(requirement_str)
65 | print(f"✅ {pkg_name} is already installed and meets version requirements.")
66 | except pkg_resources.DistributionNotFound:
67 | print(f"Installing {pkg_name}...")
68 | try:
69 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', requirement_str])
70 | print(f"✅ Successfully installed {pkg_name}")
71 | except subprocess.CalledProcessError as e:
72 | print(f"❌ Failed to install {pkg_name}: {e}")
73 | except pkg_resources.VersionConflict as e:
74 | print(f"⚠️ Version conflict for {pkg_name}: {e.req} is required, but you have {e.dist}.")
75 | print(" Skipping upgrade to avoid conflicts with other nodes. If you encounter issues, please update this package manually.")
76 | except Exception as e:
77 | print(f"An unexpected error occurred while checking {pkg_name}: {e}")
78 |
79 | print("✅ All dependencies checked.")
80 | return True
81 |
82 | except Exception as e:
83 | print(f"❌ Error installing requirements: {e}")
84 | return False
85 |
86 | def setup_model_directories():
87 | """Create necessary model directories"""
88 | base_dir = Path(__file__).parent.parent.parent # Go up to ComfyUI root
89 |
90 | # Create ComfyUI/models/foley directory for automatic downloads
91 | foley_models_dir = base_dir / "models" / "foley"
92 | foley_models_dir.mkdir(parents=True, exist_ok=True)
93 | print(f"✓ Created ComfyUI models directory: {foley_models_dir}")
94 |
95 | # Also create local fallback directories
96 | node_dir = Path(__file__).parent
97 | local_dirs = [
98 | node_dir / "pretrained_models",
99 | node_dir / "configs"
100 | ]
101 |
102 | for dir_path in local_dirs:
103 | dir_path.mkdir(exist_ok=True)
104 | print(f"✓ Created local directory: {dir_path}")
105 |
106 | def main():
107 | """Main installation function"""
108 | print("🚀 Installing ComfyUI HunyuanVideo-Foley Custom Node...")
109 |
110 | # Install requirements
111 | if not check_and_install_requirements():
112 | print("❌ Failed to install requirements")
113 | return False
114 |
115 | # Setup directories
116 | setup_model_directories()
117 |
118 | print("📋 Installation completed!")
119 | print()
120 | print("📌 Next steps:")
121 | print("1. Restart ComfyUI to load the custom nodes")
122 | print("2. Models will be automatically downloaded when you first use the node")
123 | print("3. Alternatively, manually download models and place them in ComfyUI/models/foley/")
124 | print("4. Model URLs are configured in model_urls.py (can be updated as needed)")
125 | print()
126 |
127 | return True
128 |
129 | if __name__ == "__main__":
130 | main()
--------------------------------------------------------------------------------
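
A small sketch of what parse_requirements returns for a pinned package versus a git URL with an `#egg=` fragment. The requirement strings are illustrative, and the `from install import ...` line assumes the script's directory is on sys.path:

    # Sketch: write a throwaway requirements file and inspect the
    # (package_name, requirement_string) pairs that parse_requirements yields.
    import tempfile
    from install import parse_requirements

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("transformers>=4.37.0\n")
        f.write("git+https://github.com/example/foo.git#egg=foo\n")  # illustrative URL
        tmp_path = f.name

    for pkg_name, req_str in parse_requirements(tmp_path):
        print(pkg_name, "->", req_str)
    # Expected pairs: ("transformers", "transformers>=4.37.0") and ("foo", "git+...#egg=foo")
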
/hunyuanvideo_foley/models/nn/mlp_layers.py:
--------------------------------------------------------------------------------
1 | # Modified from timm library:
2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3 |
4 | from functools import partial
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from .modulate_layers import modulate
11 | from ...utils.helper import to_2tuple
12 |
13 | class MLP(nn.Module):
14 | """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15 |
16 | def __init__(
17 | self,
18 | in_channels,
19 | hidden_channels=None,
20 | out_features=None,
21 | act_layer=nn.GELU,
22 | norm_layer=None,
23 | bias=True,
24 | drop=0.0,
25 | use_conv=False,
26 | device=None,
27 | dtype=None,
28 | ):
29 | factory_kwargs = {"device": device, "dtype": dtype}
30 | super().__init__()
31 | out_features = out_features or in_channels
32 | hidden_channels = hidden_channels or in_channels
33 | bias = to_2tuple(bias)
34 | drop_probs = to_2tuple(drop)
35 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36 |
37 | self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
38 | self.act = act_layer()
39 | self.drop1 = nn.Dropout(drop_probs[0])
40 | self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
41 | self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
42 | self.drop2 = nn.Dropout(drop_probs[1])
43 |
44 | def forward(self, x):
45 | x = self.fc1(x)
46 | x = self.act(x)
47 | x = self.drop1(x)
48 | x = self.norm(x)
49 | x = self.fc2(x)
50 | x = self.drop2(x)
51 | return x
52 |
53 |
54 | # copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
55 | # only used when use_vanilla is True
56 | class MLPEmbedder(nn.Module):
57 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
58 | factory_kwargs = {"device": device, "dtype": dtype}
59 | super().__init__()
60 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
61 | self.silu = nn.SiLU()
62 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
63 |
64 | def forward(self, x: torch.Tensor) -> torch.Tensor:
65 | return self.out_layer(self.silu(self.in_layer(x)))
66 |
67 |
68 | class LinearWarpforSingle(nn.Module):
69 | def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None):
70 | factory_kwargs = {"device": device, "dtype": dtype}
71 | super().__init__()
72 | self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs)
73 |
74 | def forward(self, x, y):
75 | z = torch.cat([x, y], dim=2)
76 | return self.fc(z)
77 |
78 | class FinalLayer1D(nn.Module):
79 | def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
80 | factory_kwargs = {"device": device, "dtype": dtype}
81 | super().__init__()
82 |
83 | # Just use LayerNorm for the final layer
84 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
85 | self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs)
86 | nn.init.zeros_(self.linear.weight)
87 | nn.init.zeros_(self.linear.bias)
88 |
89 | # Here we don't distinguish between the modulate types. Just use the simple one.
90 | self.adaLN_modulation = nn.Sequential(
91 | act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
92 | )
93 | # Zero-initialize the modulation
94 | nn.init.zeros_(self.adaLN_modulation[1].weight)
95 | nn.init.zeros_(self.adaLN_modulation[1].bias)
96 |
97 | def forward(self, x, c):
98 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
99 | x = modulate(self.norm_final(x), shift=shift, scale=scale)
100 | x = self.linear(x)
101 | return x
102 |
103 |
104 | class ChannelLastConv1d(nn.Conv1d):
105 |
106 | def forward(self, x: torch.Tensor) -> torch.Tensor:
107 | x = x.permute(0, 2, 1)
108 | x = super().forward(x)
109 | x = x.permute(0, 2, 1)
110 | return x
111 |
112 |
113 | class ConvMLP(nn.Module):
114 |
115 | def __init__(
116 | self,
117 | dim: int,
118 | hidden_dim: int,
119 | multiple_of: int = 256,
120 | kernel_size: int = 3,
121 | padding: int = 1,
122 | device=None,
123 | dtype=None,
124 | ):
125 | """
126 | Convolutional MLP module.
127 |
128 | Args:
129 | dim (int): Input dimension.
130 | hidden_dim (int): Hidden dimension of the feedforward layer.
131 |             multiple_of (int): The hidden dimension is rounded up to a multiple of this value.
132 |
133 |         Attributes:
134 |             w1: Channels-last 1D convolution whose output is passed through SiLU (the gate).
135 |             w2: Channels-last 1D convolution projecting the product back to the input dimension.
136 |             w3: Channels-last 1D convolution multiplied element-wise with the gated branch.
137 |
138 | """
139 | factory_kwargs = {"device": device, "dtype": dtype}
140 | super().__init__()
141 | hidden_dim = int(2 * hidden_dim / 3)
142 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
143 |
144 | self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
145 | self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
146 | self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
147 |
148 | def forward(self, x):
149 | return self.w2(F.silu(self.w1(x)) * self.w3(x))
150 |
--------------------------------------------------------------------------------
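
A usage sketch of the channels-last convolutional MLP above. The dimensions are arbitrary, and the import assumes the package is importable as `hunyuanvideo_foley`:

    # Sketch: ConvMLP takes channels-last input [B, T, C]; ChannelLastConv1d
    # permutes to [B, C, T] for the convolution and back, so the layout is preserved.
    import torch
    from hunyuanvideo_foley.models.nn.mlp_layers import ConvMLP

    mlp = ConvMLP(dim=768, hidden_dim=4 * 768, kernel_size=3, padding=1)
    x = torch.randn(2, 128, 768)   # [batch, tokens, channels]
    y = mlp(x)                     # same shape: [2, 128, 768]
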
/test_node.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Test script for ComfyUI HunyuanVideo-Foley custom node
4 | """
5 |
6 | import sys
7 | import os
8 | import tempfile
9 | from pathlib import Path
10 |
11 | # Add the parent directory to path for imports
12 | current_dir = Path(__file__).parent
13 | parent_dir = current_dir.parent
14 | sys.path.insert(0, str(parent_dir))
15 |
16 | def test_imports():
17 | """Test that all required modules can be imported"""
18 | print("Testing imports...")
19 |
20 | try:
21 | from ComfyUI_HunyuanVideoFoley import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
22 | print("✅ Successfully imported node mappings")
23 |
24 | print(f"Available nodes: {list(NODE_CLASS_MAPPINGS.keys())}")
25 | print(f"Display names: {NODE_DISPLAY_NAME_MAPPINGS}")
26 |
27 | return True
28 |
29 | except ImportError as e:
30 | print(f"❌ Import failed: {e}")
31 | return False
32 |
33 | def test_node_structure():
34 | """Test node class structure"""
35 | print("\nTesting node structure...")
36 |
37 | try:
38 | from ComfyUI_HunyuanVideoFoley.nodes import HunyuanVideoFoleyNode, HunyuanVideoFoleyModelLoader
39 |
40 | # Test HunyuanVideoFoleyNode
41 | node = HunyuanVideoFoleyNode()
42 | input_types = node.INPUT_TYPES()
43 |
44 | print("✅ HunyuanVideoFoleyNode structure:")
45 | print(f" - Required inputs: {list(input_types['required'].keys())}")
46 | print(f" - Optional inputs: {list(input_types.get('optional', {}).keys())}")
47 | print(f" - Return types: {node.RETURN_TYPES}")
48 | print(f" - Function: {node.FUNCTION}")
49 | print(f" - Category: {node.CATEGORY}")
50 |
51 | # Test HunyuanVideoFoleyModelLoader
52 | loader = HunyuanVideoFoleyModelLoader()
53 | loader_input_types = loader.INPUT_TYPES()
54 |
55 | print("✅ HunyuanVideoFoleyModelLoader structure:")
56 | print(f" - Required inputs: {list(loader_input_types['required'].keys())}")
57 | print(f" - Return types: {loader.RETURN_TYPES}")
58 | print(f" - Function: {loader.FUNCTION}")
59 |
60 | return True
61 |
62 | except Exception as e:
63 | print(f"❌ Node structure test failed: {e}")
64 | return False
65 |
66 | def test_device_setup():
67 | """Test device setup functionality"""
68 | print("\nTesting device setup...")
69 |
70 | try:
71 | from ComfyUI_HunyuanVideoFoley.nodes import HunyuanVideoFoleyNode
72 |
73 | device = HunyuanVideoFoleyNode.setup_device("auto")
74 | print(f"✅ Device setup successful: {device}")
75 |
76 | return True
77 |
78 | except Exception as e:
79 | print(f"❌ Device setup failed: {e}")
80 | return False
81 |
82 | def test_utils():
83 | """Test utility functions"""
84 | print("\nTesting utility functions...")
85 |
86 | try:
87 | from ComfyUI_HunyuanVideoFoley.utils import (
88 | get_optimal_device,
89 | check_memory_requirements,
90 | format_duration,
91 | validate_model_files
92 | )
93 |
94 | # Test device detection
95 | device = get_optimal_device()
96 | print(f"✅ Optimal device: {device}")
97 |
98 | # Test memory check
99 | has_memory, msg = check_memory_requirements(device)
100 | print(f"✅ Memory check: {msg}")
101 |
102 | # Test duration formatting
103 | duration = format_duration(125.5)
104 | print(f"✅ Duration formatting: 125.5s -> {duration}")
105 |
106 | # Test model validation (will fail without models, but that's expected)
107 | is_valid, msg = validate_model_files("./pretrained_models/")
108 | print(f"✅ Model validation: {msg}")
109 |
110 | return True
111 |
112 | except Exception as e:
113 | print(f"❌ Utils test failed: {e}")
114 | return False
115 |
116 | def test_requirements():
117 | """Test if key requirements are available"""
118 | print("\nTesting requirements...")
119 |
120 | required_packages = [
121 | 'torch',
122 | 'torchaudio',
123 | 'numpy',
124 | 'loguru',
125 | 'diffusers',
126 | 'transformers'
127 | ]
128 |
129 | missing = []
130 | for package in required_packages:
131 | try:
132 | __import__(package)
133 | print(f"✅ {package}")
134 | except ImportError:
135 | print(f"❌ {package} - not installed")
136 | missing.append(package)
137 |
138 | if missing:
139 | print(f"\nMissing packages: {', '.join(missing)}")
140 | print("Run: pip install -r requirements.txt")
141 | return False
142 |
143 | return True
144 |
145 | def main():
146 | """Main test function"""
147 | print("🧪 Testing ComfyUI HunyuanVideo-Foley Custom Node")
148 | print("=" * 50)
149 |
150 | tests = [
151 | ("Requirements", test_requirements),
152 | ("Imports", test_imports),
153 | ("Node Structure", test_node_structure),
154 | ("Device Setup", test_device_setup),
155 | ("Utils", test_utils),
156 | ]
157 |
158 | passed = 0
159 | failed = 0
160 |
161 | for test_name, test_func in tests:
162 | print(f"\n🔍 Running test: {test_name}")
163 | try:
164 | if test_func():
165 | passed += 1
166 | print(f"✅ {test_name} PASSED")
167 | else:
168 | failed += 1
169 | print(f"❌ {test_name} FAILED")
170 | except Exception as e:
171 | failed += 1
172 | print(f"❌ {test_name} FAILED with exception: {e}")
173 |
174 | print("\n" + "=" * 50)
175 | print(f"📊 Test Results: {passed} passed, {failed} failed")
176 |
177 | if failed == 0:
178 | print("🎉 All tests passed! The custom node is ready for use.")
179 | else:
180 | print("⚠️ Some tests failed. Please check the issues above.")
181 |
182 | return failed == 0
183 |
184 | if __name__ == "__main__":
185 | success = main()
186 | sys.exit(0 if success else 1)
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/feature_utils.py:
--------------------------------------------------------------------------------
1 | """Feature extraction utilities for video and text processing."""
2 |
3 | import os
4 | import numpy as np
5 | import torch
6 | import av
7 | from PIL import Image
8 | from einops import rearrange
9 | from typing import Any, Dict, List, Union, Tuple
10 | from loguru import logger
11 |
12 | from .config_utils import AttributeDict
13 | from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS
14 |
15 |
16 | class FeatureExtractionError(Exception):
17 | """Exception raised for feature extraction errors."""
18 | pass
19 |
20 | def get_frames_av(
21 | video_path: str,
22 | fps: float,
23 | max_length: float = None,
24 | ) -> Tuple[np.ndarray, float]:
25 | end_sec = max_length if max_length is not None else 15
26 | next_frame_time_for_each_fps = 0.0
27 | time_delta_for_each_fps = 1 / fps
28 |
29 | all_frames = []
30 | output_frames = []
31 |
32 | with av.open(video_path) as container:
33 | stream = container.streams.video[0]
34 | ori_fps = stream.guessed_rate
35 | stream.thread_type = "AUTO"
36 | for packet in container.demux(stream):
37 | for frame in packet.decode():
38 | frame_time = frame.time
39 | if frame_time < 0:
40 | continue
41 | if frame_time > end_sec:
42 | break
43 |
44 | frame_np = None
45 |
46 | this_time = frame_time
47 | while this_time >= next_frame_time_for_each_fps:
48 | if frame_np is None:
49 | frame_np = frame.to_ndarray(format="rgb24")
50 |
51 | output_frames.append(frame_np)
52 | next_frame_time_for_each_fps += time_delta_for_each_fps
53 |
54 | output_frames = np.stack(output_frames)
55 |
56 | vid_len_in_s = len(output_frames) / fps
57 | if max_length is not None and len(output_frames) > int(max_length * fps):
58 | output_frames = output_frames[: int(max_length * fps)]
59 | vid_len_in_s = max_length
60 |
61 | return output_frames, vid_len_in_s
62 |
63 | @torch.inference_mode()
64 | def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1):
65 | b, t, c, h, w = x.shape
66 | if batch_size < 0:
67 | batch_size = b * t
68 | x = rearrange(x, "b t c h w -> (b t) c h w")
69 | outputs = []
70 | for i in range(0, b * t, batch_size):
71 | pixel_values = x[i : i + batch_size]
72 | # --- Transformers Compatibility Fix ---
73 | if hasattr(model_dict.siglip2_model, 'get_image_features'):
74 | # Older transformers versions
75 | features = model_dict.siglip2_model.get_image_features(pixel_values=pixel_values)
76 | else:
77 | # Newer transformers versions
78 | features = model_dict.siglip2_model(pixel_values=pixel_values).image_embeds
79 | outputs.append(features)
80 | # --- End of Fix ---
81 | res = torch.cat(outputs, dim=0)
82 | res = rearrange(res, "(b t) d -> b t d", b=b)
83 | return res
84 |
85 | @torch.inference_mode()
86 | def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1):
87 | """
88 |     The input video x should have an fps of 24 or greater.
89 | Input:
90 | x: tensor in shape of [B, T, C, H, W]
91 | batch_size: the batch_size for synchformer inference
92 | """
93 | b, t, c, h, w = x.shape
94 | assert c == 3 and h == 224 and w == 224
95 |
96 | segment_size = 16
97 | step_size = 8
98 | num_segments = (t - segment_size) // step_size + 1
99 | segments = []
100 | for i in range(num_segments):
101 | segments.append(x[:, i * step_size : i * step_size + segment_size])
102 | x = torch.stack(segments, dim=1).cuda() # (B, num_segments, segment_size, 3, 224, 224)
103 |
104 | outputs = []
105 | if batch_size < 0:
106 | batch_size = b * num_segments
107 | x = rearrange(x, "b s t c h w -> (b s) 1 t c h w")
108 | for i in range(0, b * num_segments, batch_size):
109 | with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half):
110 | outputs.append(model_dict.syncformer_model(x[i : i + batch_size]))
111 | x = torch.cat(outputs, dim=0) # [b * num_segments, 1, 8, 768]
112 | x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b)
113 | return x
114 |
115 |
116 | @torch.inference_mode()
117 | def encode_video_features(video_path, model_dict):
118 | visual_features = {}
119 | # siglip2 visual features
120 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["siglip2"])
121 | images = [Image.fromarray(frame).convert('RGB') for frame in frames]
122 | images = [model_dict.siglip2_preprocess(image) for image in images] # [T, C, H, W]
123 | clip_frames = torch.stack(images).to(model_dict.device).unsqueeze(0)
124 | visual_features['siglip2_feat'] = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device)
125 |
126 | # synchformer visual features
127 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["synchformer"])
128 | images = torch.from_numpy(frames).permute(0, 3, 1, 2) # [T, C, H, W]
129 | sync_frames = model_dict.syncformer_preprocess(images).unsqueeze(0) # [1, T, 3, 224, 224]
130 | # [1, num_segments * 8, channel_dim], e.g. [1, 240, 768] for 10s video
131 | visual_features['syncformer_feat'] = encode_video_with_sync(sync_frames, model_dict)
132 |
133 | vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"]
134 | visual_features = AttributeDict(visual_features)
135 |
136 | return visual_features, vid_len_in_s
137 |
138 | @torch.inference_mode()
139 | def encode_text_feat(text: List[str], model_dict):
140 | # x: (B, L)
141 | inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device)
142 | outputs = model_dict.clap_model(**inputs, output_hidden_states=True, return_dict=True)
143 | return outputs.last_hidden_state, outputs.attentions
144 |
145 |
146 | def feature_process(video_path, prompt, model_dict, cfg):
147 | visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict)
148 | neg_prompt = "noisy, harsh"
149 | prompts = [neg_prompt, prompt]
150 | text_feat_res, text_feat_mask = encode_text_feat(prompts, model_dict)
151 |
152 | text_feat = text_feat_res[1:]
153 | uncond_text_feat = text_feat_res[:1]
154 |
155 | if cfg.model_config.model_kwargs.text_length < text_feat.shape[1]:
156 | text_seq_length = cfg.model_config.model_kwargs.text_length
157 | text_feat = text_feat[:, :text_seq_length]
158 | uncond_text_feat = uncond_text_feat[:, :text_seq_length]
159 |
160 | text_feats = AttributeDict({
161 | 'text_feat': text_feat,
162 | 'uncond_text_feat': uncond_text_feat,
163 | })
164 |
165 | return visual_feats, text_feats, audio_len_in_s
166 |
--------------------------------------------------------------------------------
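
encode_video_with_sync above slices the frame sequence into overlapping 16-frame segments with a stride of 8, and each segment contributes 8 feature tokens. A small sketch of that arithmetic; the 25 fps figure is an assumption about FPS_VISUAL["synchformer"]:

    # Sketch: Synchformer token count for a clip of t frames, mirroring
    # encode_video_with_sync (segment_size=16, step_size=8, 8 tokens per segment).
    def sync_token_count(t: int, segment_size: int = 16, step_size: int = 8,
                         tokens_per_segment: int = 8) -> int:
        num_segments = (t - segment_size) // step_size + 1
        return num_segments * tokens_per_segment

    # Assuming 25 fps visual input, a 10 s clip has 250 frames:
    print(sync_token_count(250))  # 240, matching the "[1, 240, 768] for 10s video" note above
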
/hunyuanvideo_foley/models/nn/posemb_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, Tuple
3 |
4 |
5 | def _to_tuple(x, dim=2):
6 | if isinstance(x, int):
7 | return (x,) * dim
8 | elif len(x) == dim:
9 | return x
10 | else:
11 | raise ValueError(f"Expected length {dim} or int, but got {x}")
12 |
13 |
14 | def get_meshgrid_nd(start, *args, dim=2):
15 | """
16 | Get n-D meshgrid with start, stop and num.
17 |
18 | Args:
19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22 | n-tuples.
23 | *args: See above.
24 | dim (int): Dimension of the meshgrid. Defaults to 2.
25 |
26 | Returns:
27 |         grid (torch.Tensor): [dim, ...]
28 | """
29 | if len(args) == 0:
30 | # start is grid_size
31 | num = _to_tuple(start, dim=dim)
32 | start = (0,) * dim
33 | stop = num
34 | elif len(args) == 1:
35 | # start is start, args[0] is stop, step is 1
36 | start = _to_tuple(start, dim=dim)
37 | stop = _to_tuple(args[0], dim=dim)
38 | num = [stop[i] - start[i] for i in range(dim)]
39 | elif len(args) == 2:
40 | # start is start, args[0] is stop, args[1] is num
41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44 | else:
45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46 |
47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48 | axis_grid = []
49 | for i in range(dim):
50 | a, b, n = start[i], stop[i], num[i]
51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52 | axis_grid.append(g)
53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55 |
56 | return grid
57 |
58 |
59 | #################################################################################
60 | # Rotary Positional Embedding Functions #
61 | #################################################################################
62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63 |
64 |
65 | def get_nd_rotary_pos_embed(
66 | rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0
67 | ):
68 | """
69 | This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
70 |
71 | Args:
72 |         rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
73 |             sum(rope_dim_list) should equal the head_dim of the attention layer.
74 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
75 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
76 | *args: See above.
77 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
78 |         use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
79 |             Some libraries, such as TensorRT, do not support the complex64 data type, so it is useful to provide
80 |             the real and imaginary parts separately.
81 |         theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
82 |         freq_scaling (float, optional): Frequency rescale factor, as proposed in MMAudio. Defaults to 1.0.
83 |
84 | Returns:
85 | pos_embed (torch.Tensor): [HW, D/2]
86 | """
87 |
88 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
89 |
90 | # use 1/ndim of dimensions to encode grid_axis
91 | embs = []
92 | for i in range(len(rope_dim_list)):
93 | emb = get_1d_rotary_pos_embed(
94 | rope_dim_list[i],
95 | grid[i].reshape(-1),
96 | theta,
97 | use_real=use_real,
98 | theta_rescale_factor=theta_rescale_factor,
99 | freq_scaling=freq_scaling,
100 | ) # 2 x [WHD, rope_dim_list[i]]
101 | embs.append(emb)
102 |
103 | if use_real:
104 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
105 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
106 | return cos, sin
107 | else:
108 | emb = torch.cat(embs, dim=1) # (WHD, D/2)
109 | return emb
110 |
111 |
112 | def get_1d_rotary_pos_embed(
113 | dim: int,
114 | pos: Union[torch.FloatTensor, int],
115 | theta: float = 10000.0,
116 | use_real: bool = False,
117 | theta_rescale_factor: float = 1.0,
118 | freq_scaling: float = 1.0,
119 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
120 | """
121 | Precompute the frequency tensor for complex exponential (cis) with given dimensions.
122 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
123 |
124 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
125 | and the end index 'end'. The 'theta' parameter scales the frequencies.
126 | The returned tensor contains complex values in complex64 data type.
127 |
128 | Args:
129 | dim (int): Dimension of the frequency tensor.
130 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
131 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
132 | use_real (bool, optional): If True, return real part and imaginary part separately.
133 | Otherwise, return complex numbers.
134 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
135 |         freq_scaling (float, optional): Frequency rescale factor, as proposed in MMAudio. Defaults to 1.0.
136 |
137 | Returns:
138 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
139 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
140 | """
141 | if isinstance(pos, int):
142 | pos = torch.arange(pos).float()
143 |
144 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
145 | # has some connection to NTK literature
146 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
147 | if theta_rescale_factor != 1.0:
148 | theta *= theta_rescale_factor ** (dim / (dim - 1))
149 |
150 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
151 | freqs *= freq_scaling
152 | freqs = torch.outer(pos, freqs) # [S, D/2]
153 | if use_real:
154 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
155 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
156 | return freqs_cos, freqs_sin
157 | else:
158 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
159 | return freqs_cis
160 |
--------------------------------------------------------------------------------
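
A usage sketch of the 1-D rotary embedding helper above. The head dimension and sequence length are arbitrary, and the import assumes the package is importable as `hunyuanvideo_foley`:

    # Sketch: precompute rotary frequencies for a 64-dim head over 128 positions,
    # in complex form and as separate cos/sin tensors.
    from hunyuanvideo_foley.models.nn.posemb_layers import get_1d_rotary_pos_embed

    freqs_cis = get_1d_rotary_pos_embed(64, 128)                # complex64, shape [128, 32]
    cos, sin = get_1d_rotary_pos_embed(64, 128, use_real=True)  # each real, shape [128, 64]
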
/hunyuanvideo_foley/models/dac_vae/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from audiotools import AudioSignal
5 | from audiotools import ml
6 | from audiotools import STFTParams
7 | from einops import rearrange
8 | from torch.nn.utils import weight_norm
9 |
10 |
11 | def WNConv1d(*args, **kwargs):
12 | act = kwargs.pop("act", True)
13 | conv = weight_norm(nn.Conv1d(*args, **kwargs))
14 | if not act:
15 | return conv
16 | return nn.Sequential(conv, nn.LeakyReLU(0.1))
17 |
18 |
19 | def WNConv2d(*args, **kwargs):
20 | act = kwargs.pop("act", True)
21 | conv = weight_norm(nn.Conv2d(*args, **kwargs))
22 | if not act:
23 | return conv
24 | return nn.Sequential(conv, nn.LeakyReLU(0.1))
25 |
26 |
27 | class MPD(nn.Module):
28 | def __init__(self, period):
29 | super().__init__()
30 | self.period = period
31 | self.convs = nn.ModuleList(
32 | [
33 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38 | ]
39 | )
40 | self.conv_post = WNConv2d(
41 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42 | )
43 |
44 | def pad_to_period(self, x):
45 | t = x.shape[-1]
46 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47 | return x
48 |
49 | def forward(self, x):
50 | fmap = []
51 |
52 | x = self.pad_to_period(x)
53 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54 |
55 | for layer in self.convs:
56 | x = layer(x)
57 | fmap.append(x)
58 |
59 | x = self.conv_post(x)
60 | fmap.append(x)
61 |
62 | return fmap
63 |
64 |
65 | class MSD(nn.Module):
66 | def __init__(self, rate: int = 1, sample_rate: int = 44100):
67 | super().__init__()
68 | self.convs = nn.ModuleList(
69 | [
70 | WNConv1d(1, 16, 15, 1, padding=7),
71 | WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72 | WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75 | WNConv1d(1024, 1024, 5, 1, padding=2),
76 | ]
77 | )
78 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79 | self.sample_rate = sample_rate
80 | self.rate = rate
81 |
82 | def forward(self, x):
83 | x = AudioSignal(x, self.sample_rate)
84 | x.resample(self.sample_rate // self.rate)
85 | x = x.audio_data
86 |
87 | fmap = []
88 |
89 | for l in self.convs:
90 | x = l(x)
91 | fmap.append(x)
92 | x = self.conv_post(x)
93 | fmap.append(x)
94 |
95 | return fmap
96 |
97 |
98 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99 |
100 |
101 | class MRD(nn.Module):
102 | def __init__(
103 | self,
104 | window_length: int,
105 | hop_factor: float = 0.25,
106 | sample_rate: int = 44100,
107 | bands: list = BANDS,
108 | ):
109 | """Complex multi-band spectrogram discriminator.
110 | Parameters
111 | ----------
112 | window_length : int
113 | Window length of STFT.
114 | hop_factor : float, optional
115 |         Hop factor of the STFT; the hop length is ``hop_factor * window_length``. Defaults to 0.25.
116 | sample_rate : int, optional
117 | Sampling rate of audio in Hz, by default 44100
118 | bands : list, optional
119 | Bands to run discriminator over.
120 | """
121 | super().__init__()
122 |
123 | self.window_length = window_length
124 | self.hop_factor = hop_factor
125 | self.sample_rate = sample_rate
126 | self.stft_params = STFTParams(
127 | window_length=window_length,
128 | hop_length=int(window_length * hop_factor),
129 | match_stride=True,
130 | )
131 |
132 | n_fft = window_length // 2 + 1
133 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134 | self.bands = bands
135 |
136 | ch = 32
137 | convs = lambda: nn.ModuleList(
138 | [
139 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144 | ]
145 | )
146 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148 |
149 | def spectrogram(self, x):
150 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151 | x = torch.view_as_real(x.stft())
152 | x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153 | # Split into bands
154 | x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155 | return x_bands
156 |
157 | def forward(self, x):
158 | x_bands = self.spectrogram(x)
159 | fmap = []
160 |
161 | x = []
162 | for band, stack in zip(x_bands, self.band_convs):
163 | for layer in stack:
164 | band = layer(band)
165 | fmap.append(band)
166 | x.append(band)
167 |
168 | x = torch.cat(x, dim=-1)
169 | x = self.conv_post(x)
170 | fmap.append(x)
171 |
172 | return fmap
173 |
174 |
175 | class Discriminator(ml.BaseModel):
176 | def __init__(
177 | self,
178 | rates: list = [],
179 | periods: list = [2, 3, 5, 7, 11],
180 | fft_sizes: list = [2048, 1024, 512],
181 | sample_rate: int = 44100,
182 | bands: list = BANDS,
183 | ):
184 | """Discriminator that combines multiple discriminators.
185 |
186 | Parameters
187 | ----------
188 | rates : list, optional
189 | sampling rates (in Hz) to run MSD at, by default []
190 | If empty, MSD is not used.
191 | periods : list, optional
192 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193 | fft_sizes : list, optional
194 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195 | sample_rate : int, optional
196 | Sampling rate of audio in Hz, by default 44100
197 | bands : list, optional
198 | Bands to run MRD at, by default `BANDS`
199 | """
200 | super().__init__()
201 | discs = []
202 | discs += [MPD(p) for p in periods]
203 | discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205 | self.discriminators = nn.ModuleList(discs)
206 |
207 | def preprocess(self, y):
208 | # Remove DC offset
209 | y = y - y.mean(dim=-1, keepdims=True)
210 | # Peak normalize the volume of input audio
211 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212 | return y
213 |
214 | def forward(self, x):
215 | x = self.preprocess(x)
216 | fmaps = [d(x) for d in self.discriminators]
217 | return fmaps
218 |
219 |
220 | if __name__ == "__main__":
221 | disc = Discriminator()
222 | x = torch.zeros(1, 1, 44100)
223 | results = disc(x)
224 | for i, result in enumerate(results):
225 | print(f"disc{i}")
226 | for i, r in enumerate(result):
227 | print(r.shape, r.mean(), r.min(), r.max())
228 | print()
229 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/compute_desync_score.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import subprocess
4 | from pathlib import Path
5 | import torch
6 | import torchaudio
7 | import torchvision
8 | from omegaconf import OmegaConf
9 |
10 | from . import data_transforms
11 | from .synchformer import Synchformer
12 | from .data_transforms import make_class_grid, quantize_offset
13 | from .utils import check_if_file_exists_else_download, which_ffmpeg
14 |
15 |
16 | def prepare_inputs(batch, device):
17 | aud = batch["audio"].to(device)
18 | vid = batch["video"].to(device)
19 |
20 | return aud, vid
21 |
22 |
23 | def get_test_transforms():
24 | ts = [
25 | data_transforms.EqualifyFromRight(),
26 | data_transforms.RGBSpatialCrop(input_size=224, is_random=False),
27 | data_transforms.TemporalCropAndOffset(
28 | crop_len_sec=5,
29 | max_off_sec=2, # https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
30 | max_wiggle_sec=0.0,
31 | do_offset=True,
32 | offset_type="grid",
33 | prob_oos="null",
34 | grid_size=21,
35 | segment_size_vframes=16,
36 | n_segments=14,
37 | step_size_seg=0.5,
38 | vfps=25,
39 | ),
40 | data_transforms.GenerateMultipleSegments(
41 | segment_size_vframes=16,
42 | n_segments=14,
43 | is_start_random=False,
44 | step_size_seg=0.5,
45 | ),
46 | data_transforms.RGBToHalfToZeroOne(),
47 | data_transforms.RGBNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # motionformer normalization
48 | data_transforms.AudioMelSpectrogram(
49 | sample_rate=16000,
50 | win_length=400, # 25 ms * 16 kHz
51 | hop_length=160, # 10 ms * 16 kHz
52 | n_fft=1024, # 2^(ceil(log2(window_size * sampling_rate)))
53 | n_mels=128, # as in AST
54 | ),
55 | data_transforms.AudioLog(),
56 | data_transforms.PadOrTruncate(max_spec_t=66),
57 | data_transforms.AudioNormalizeAST(mean=-4.2677393, std=4.5689974), # AST, pre-trained on AudioSet
58 | data_transforms.PermuteStreams(
59 | einops_order_audio="S F T -> S 1 F T", einops_order_rgb="S T C H W -> S T C H W" # same
60 | ),
61 | ]
62 | transforms = torchvision.transforms.Compose(ts)
63 |
64 | return transforms
65 |
66 |
67 | def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None):
68 | orig_path = path
69 | # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta)
70 | rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, "sec", output_format="TCHW")
71 | assert meta["video_fps"], f"No video fps for {orig_path}"
72 | # (Ta) <- (Ca, Ta)
73 | audio = audio.mean(dim=0)
74 | # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader.
75 | meta = {
76 | "video": {"fps": [meta["video_fps"]]},
77 | "audio": {"framerate": [meta["audio_fps"]]},
78 | }
79 | return rgb, audio, meta
80 |
81 |
82 | def reencode_video(path, vfps=25, afps=16000, in_size=256):
83 | assert which_ffmpeg() != "", "Is ffmpeg installed? Check if the conda environment is activated."
84 | new_path = Path.cwd() / "vis" / f"{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4"
85 | new_path.parent.mkdir(exist_ok=True)
86 | new_path = str(new_path)
87 | cmd = f"{which_ffmpeg()}"
88 | # no info/error printing
89 | cmd += " -hide_banner -loglevel panic"
90 | cmd += f" -y -i {path}"
91 | # 1) change fps, 2) resize: min(H,W)=MIN_SIDE (vertical vids are supported), 3) change audio framerate
92 | cmd += f" -vf fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2"
93 | cmd += f" -ar {afps}"
94 | cmd += f" {new_path}"
95 | subprocess.call(cmd.split())
96 | cmd = f"{which_ffmpeg()}"
97 | cmd += " -hide_banner -loglevel panic"
98 | cmd += f" -y -i {new_path}"
99 | cmd += f" -acodec pcm_s16le -ac 1"
100 | cmd += f' {new_path.replace(".mp4", ".wav")}'
101 | subprocess.call(cmd.split())
102 | return new_path
103 |
104 |
105 | def decode_single_video_prediction(off_logits, grid, item):
106 | label = item["targets"]["offset_label"].item()
107 | print("Ground Truth offset (sec):", f"{label:.2f} ({quantize_offset(grid, label)[-1].item()})")
108 | print()
109 | print("Prediction Results:")
110 | off_probs = torch.softmax(off_logits, dim=-1)
111 | k = min(off_probs.shape[-1], 5)
112 | topk_logits, topk_preds = torch.topk(off_logits, k)
113 | # remove batch dimension
114 | assert len(topk_logits) == 1, "batch is larger than 1"
115 | topk_logits = topk_logits[0]
116 | topk_preds = topk_preds[0]
117 | off_logits = off_logits[0]
118 | off_probs = off_probs[0]
119 | for target_hat in topk_preds:
120 | print(f'p={off_probs[target_hat]:.4f} ({off_logits[target_hat]:.4f}), "{grid[target_hat]:.2f}" ({target_hat})')
121 | return off_probs
122 |
123 |
124 | def main(args):
125 | vfps = 25
126 | afps = 16000
127 | in_size = 256
128 | # making the offset class grid similar to the one used in transforms,
129 | # refer to the used one: https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
130 | max_off_sec = 2
131 | num_cls = 21
132 |
133 | # checking if the provided video has the correct frame rates
134 | print(f"Using video: {args.vid_path}")
135 | v, _, info = torchvision.io.read_video(args.vid_path, pts_unit="sec")
136 | _, H, W, _ = v.shape
137 | if info["video_fps"] != vfps or info["audio_fps"] != afps or min(H, W) != in_size:
138 | print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=" ")
139 | print(f'afps: {info["audio_fps"]} -> {afps};', end=" ")
140 | print(f"{(H, W)} -> min(H, W)={in_size}")
141 | args.vid_path = reencode_video(args.vid_path, vfps, afps, in_size)
142 | else:
143 | print(f'Skipping reencoding. vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}')
144 |
145 | device = torch.device(args.device)
146 |
147 | # load visual and audio streams
148 |     # rgb: (Tv, 3, H, W) in [0, 255], audio: (Ta,) in [-1, 1]
149 | rgb, audio, meta = get_video_and_audio(args.vid_path, get_meta=True)
150 |
151 | # making an item (dict) to apply transformations
152 | # NOTE: here is how it works:
153 | # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3`
154 | # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio
155 |     # track by `args.offset_sec` seconds. This means that if `offset_sec` > 0, the audio will
156 |     # start `offset_sec` seconds earlier than the rgb track.
157 | # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`)
158 | item = dict(
159 | video=rgb,
160 | audio=audio,
161 | meta=meta,
162 | path=args.vid_path,
163 | split="test",
164 | targets={
165 | "v_start_i_sec": args.v_start_i_sec,
166 | "offset_sec": args.offset_sec,
167 | },
168 | )
169 |
170 | grid = make_class_grid(-max_off_sec, max_off_sec, num_cls)
171 | if not (min(grid) <= item["targets"]["offset_sec"] <= max(grid)):
172 | print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}')
173 |
174 | # applying the test-time transform
175 | item = get_test_transforms()(item)
176 |
177 | # prepare inputs for inference
178 | batch = torch.utils.data.default_collate([item])
179 | aud, vid = prepare_inputs(batch, device)
180 |
181 | # TODO:
182 |     # sanity check: take the input to the `model` and reconstruct a video from it.
183 | # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified)
184 | # reconstruct_video_from_input(aud, vid, batch['meta'], args.vid_path, args.v_start_i_sec, args.offset_sec,
185 | # vfps, afps)
186 |
187 | # forward pass
188 | with torch.set_grad_enabled(False):
189 | with torch.autocast("cuda", enabled=True):
190 | _, logits = synchformer(vid, aud)
191 |
192 | # simply prints the results of the prediction
193 | decode_single_video_prediction(logits, grid, item)
194 |
195 |
196 | if __name__ == "__main__":
197 | parser = argparse.ArgumentParser()
198 | parser.add_argument("--exp_name", required=True, help="In a format: xx-xx-xxTxx-xx-xx")
199 | parser.add_argument("--vid_path", required=True, help="A path to .mp4 video")
200 | parser.add_argument("--offset_sec", type=float, default=0.0)
201 | parser.add_argument("--v_start_i_sec", type=float, default=0.0)
202 | parser.add_argument("--device", default="cuda:0")
203 | args = parser.parse_args()
204 |
205 | synchformer = Synchformer().cuda().eval()
206 | synchformer.load_state_dict(
207 | torch.load(
208 |             os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"),
209 | weights_only=True,
210 | map_location="cpu",
211 | )
212 | )
213 |
214 | main(args)
215 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/nn/quantize.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 | from torch.nn.utils import weight_norm
9 |
10 | from .layers import WNConv1d
11 |
12 |
13 | class VectorQuantize(nn.Module):
14 | """
15 | Implementation of VQ similar to Karpathy's repo:
16 | https://github.com/karpathy/deep-vector-quantization
17 | Additionally uses following tricks from Improved VQGAN
18 | (https://arxiv.org/pdf/2110.04627.pdf):
19 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20 | for improved codebook usage
21 | 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22 | improves training stability
23 | """
24 |
25 | def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26 | super().__init__()
27 | self.codebook_size = codebook_size
28 | self.codebook_dim = codebook_dim
29 |
30 | self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31 | self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32 | self.codebook = nn.Embedding(codebook_size, codebook_dim)
33 |
34 | def forward(self, z):
35 |         """Quantizes the input tensor using a fixed codebook and returns
36 | the corresponding codebook vectors
37 |
38 | Parameters
39 | ----------
40 | z : Tensor[B x D x T]
41 |
42 | Returns
43 | -------
44 | Tensor[B x D x T]
45 | Quantized continuous representation of input
46 | Tensor[1]
47 | Commitment loss to train encoder to predict vectors closer to codebook
48 | entries
49 | Tensor[1]
50 | Codebook loss to update the codebook
51 | Tensor[B x T]
52 | Codebook indices (quantized discrete representation of input)
53 | Tensor[B x D x T]
54 | Projected latents (continuous representation of input before quantization)
55 | """
56 |
57 | # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58 | z_e = self.in_proj(z) # z_e : (B x D x T)
59 | z_q, indices = self.decode_latents(z_e)
60 |
61 | commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63 |
64 | z_q = (
65 | z_e + (z_q - z_e).detach()
66 | ) # noop in forward pass, straight-through gradient estimator in backward pass
67 |
68 | z_q = self.out_proj(z_q)
69 |
70 | return z_q, commitment_loss, codebook_loss, indices, z_e
71 |
72 | def embed_code(self, embed_id):
73 | return F.embedding(embed_id, self.codebook.weight)
74 |
75 | def decode_code(self, embed_id):
76 | return self.embed_code(embed_id).transpose(1, 2)
77 |
78 | def decode_latents(self, latents):
79 | encodings = rearrange(latents, "b d t -> (b t) d")
80 | codebook = self.codebook.weight # codebook: (N x D)
81 |
82 | # L2 normalize encodings and codebook (ViT-VQGAN)
83 | encodings = F.normalize(encodings)
84 | codebook = F.normalize(codebook)
85 |
86 | # Compute euclidean distance with codebook
87 | dist = (
88 | encodings.pow(2).sum(1, keepdim=True)
89 | - 2 * encodings @ codebook.t()
90 | + codebook.pow(2).sum(1, keepdim=True).t()
91 | )
92 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
93 | z_q = self.decode_code(indices)
94 | return z_q, indices
95 |
96 |
97 | class ResidualVectorQuantize(nn.Module):
98 | """
99 | Introduced in SoundStream: An end2end neural audio codec
100 | https://arxiv.org/abs/2107.03312
101 | """
102 |
103 | def __init__(
104 | self,
105 | input_dim: int = 512,
106 | n_codebooks: int = 9,
107 | codebook_size: int = 1024,
108 | codebook_dim: Union[int, list] = 8,
109 | quantizer_dropout: float = 0.0,
110 | ):
111 | super().__init__()
112 | if isinstance(codebook_dim, int):
113 | codebook_dim = [codebook_dim for _ in range(n_codebooks)]
114 |
115 | self.n_codebooks = n_codebooks
116 | self.codebook_dim = codebook_dim
117 | self.codebook_size = codebook_size
118 |
119 | self.quantizers = nn.ModuleList(
120 | [
121 | VectorQuantize(input_dim, codebook_size, codebook_dim[i])
122 | for i in range(n_codebooks)
123 | ]
124 | )
125 | self.quantizer_dropout = quantizer_dropout
126 |
127 | def forward(self, z, n_quantizers: int = None):
128 |         """Quantizes the input tensor using a fixed set of `n` codebooks and returns
129 | the corresponding codebook vectors
130 | Parameters
131 | ----------
132 | z : Tensor[B x D x T]
133 | n_quantizers : int, optional
134 | No. of quantizers to use
135 | (n_quantizers < self.n_codebooks ex: for quantizer dropout)
136 | Note: if `self.quantizer_dropout` is True, this argument is ignored
137 | when in training mode, and a random number of quantizers is used.
138 | Returns
139 | -------
140 | dict
141 | A dictionary with the following keys:
142 |
143 | "z" : Tensor[B x D x T]
144 | Quantized continuous representation of input
145 | "codes" : Tensor[B x N x T]
146 | Codebook indices for each codebook
147 | (quantized discrete representation of input)
148 | "latents" : Tensor[B x N*D x T]
149 | Projected latents (continuous representation of input before quantization)
150 | "vq/commitment_loss" : Tensor[1]
151 | Commitment loss to train encoder to predict vectors closer to codebook
152 | entries
153 | "vq/codebook_loss" : Tensor[1]
154 | Codebook loss to update the codebook
155 | """
156 | z_q = 0
157 | residual = z
158 | commitment_loss = 0
159 | codebook_loss = 0
160 |
161 | codebook_indices = []
162 | latents = []
163 |
164 | if n_quantizers is None:
165 | n_quantizers = self.n_codebooks
166 | if self.training:
167 | n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
168 | dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
169 | n_dropout = int(z.shape[0] * self.quantizer_dropout)
170 | n_quantizers[:n_dropout] = dropout[:n_dropout]
171 | n_quantizers = n_quantizers.to(z.device)
172 |
173 | for i, quantizer in enumerate(self.quantizers):
174 | if self.training is False and i >= n_quantizers:
175 | break
176 |
177 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
178 | residual
179 | )
180 |
181 | # Create mask to apply quantizer dropout
182 | mask = (
183 | torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
184 | )
185 | z_q = z_q + z_q_i * mask[:, None, None]
186 | residual = residual - z_q_i
187 |
188 | # Sum losses
189 | commitment_loss += (commitment_loss_i * mask).mean()
190 | codebook_loss += (codebook_loss_i * mask).mean()
191 |
192 | codebook_indices.append(indices_i)
193 | latents.append(z_e_i)
194 |
195 | codes = torch.stack(codebook_indices, dim=1)
196 | latents = torch.cat(latents, dim=1)
197 |
198 | return z_q, codes, latents, commitment_loss, codebook_loss
199 |
200 | def from_codes(self, codes: torch.Tensor):
201 | """Given the quantized codes, reconstruct the continuous representation
202 | Parameters
203 | ----------
204 | codes : Tensor[B x N x T]
205 | Quantized discrete representation of input
206 | Returns
207 | -------
208 | Tensor[B x D x T]
209 | Quantized continuous representation of input
210 | """
211 | z_q = 0.0
212 | z_p = []
213 | n_codebooks = codes.shape[1]
214 | for i in range(n_codebooks):
215 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
216 | z_p.append(z_p_i)
217 |
218 | z_q_i = self.quantizers[i].out_proj(z_p_i)
219 | z_q = z_q + z_q_i
220 | return z_q, torch.cat(z_p, dim=1), codes
221 |
222 | def from_latents(self, latents: torch.Tensor):
223 | """Given the unquantized latents, reconstruct the
224 | continuous representation after quantization.
225 |
226 | Parameters
227 | ----------
228 | latents : Tensor[B x N x T]
229 | Continuous representation of input after projection
230 |
231 | Returns
232 | -------
233 | Tensor[B x D x T]
234 | Quantized representation of full-projected space
235 | Tensor[B x D x T]
236 | Quantized representation of latent space
237 | """
238 | z_q = 0
239 | z_p = []
240 | codes = []
241 | dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
242 |
243 | n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
244 | 0
245 | ]
246 | for i in range(n_codebooks):
247 | j, k = dims[i], dims[i + 1]
248 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
249 | z_p.append(z_p_i)
250 | codes.append(codes_i)
251 |
252 | z_q_i = self.quantizers[i].out_proj(z_p_i)
253 | z_q = z_q + z_q_i
254 |
255 | return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
256 |
257 |
258 | if __name__ == "__main__":
259 | rvq = ResidualVectorQuantize(quantizer_dropout=True)
260 | x = torch.randn(16, 512, 80)
261 |     z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
262 |     print(latents.shape)
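    # Round-trip sketch (assumption: reuses `rvq`, `x`, and `codes` from the lines above):
    # `from_codes` rebuilds the continuous representation from the discrete codes alone.
    z_q_rec, z_p, _ = rvq.from_codes(codes)
    print(z_q_rec.shape, z_p.shape)  # (16, 512, 80), (16, sum(codebook_dims), 80)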
263 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/model/base.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from pathlib import Path
4 | from typing import Union
5 |
6 | import numpy as np
7 | import torch
8 | import tqdm
9 | from audiotools import AudioSignal
10 | from torch import nn
11 |
12 | SUPPORTED_VERSIONS = ["1.0.0"]
13 |
14 |
15 | @dataclass
16 | class DACFile:
17 | codes: torch.Tensor
18 |
19 | # Metadata
20 | chunk_length: int
21 | original_length: int
22 | input_db: float
23 | channels: int
24 | sample_rate: int
25 | padding: bool
26 | dac_version: str
27 |
28 | def save(self, path):
29 | artifacts = {
30 | "codes": self.codes.numpy().astype(np.uint16),
31 | "metadata": {
32 | "input_db": self.input_db.numpy().astype(np.float32),
33 | "original_length": self.original_length,
34 | "sample_rate": self.sample_rate,
35 | "chunk_length": self.chunk_length,
36 | "channels": self.channels,
37 | "padding": self.padding,
38 | "dac_version": SUPPORTED_VERSIONS[-1],
39 | },
40 | }
41 | path = Path(path).with_suffix(".dac")
42 | with open(path, "wb") as f:
43 | np.save(f, artifacts)
44 | return path
45 |
46 | @classmethod
47 | def load(cls, path):
48 | artifacts = np.load(path, allow_pickle=True)[()]
49 | codes = torch.from_numpy(artifacts["codes"].astype(int))
50 | if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51 | raise RuntimeError(
52 | f"Given file {path} can't be loaded with this version of descript-audio-codec."
53 | )
54 | return cls(codes=codes, **artifacts["metadata"])
55 |
56 |
57 | class CodecMixin:
58 | @property
59 | def padding(self):
60 | if not hasattr(self, "_padding"):
61 | self._padding = True
62 | return self._padding
63 |
64 | @padding.setter
65 | def padding(self, value):
66 | assert isinstance(value, bool)
67 |
68 | layers = [
69 | l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
70 | ]
71 |
72 | for layer in layers:
73 | if value:
74 | if hasattr(layer, "original_padding"):
75 | layer.padding = layer.original_padding
76 | else:
77 | layer.original_padding = layer.padding
78 | layer.padding = tuple(0 for _ in range(len(layer.padding)))
79 |
80 | self._padding = value
81 |
82 | def get_delay(self):
83 | # Any number works here, delay is invariant to input length
84 | l_out = self.get_output_length(0)
85 | L = l_out
86 |
87 | layers = []
88 | for layer in self.modules():
89 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
90 | layers.append(layer)
91 |
92 | for layer in reversed(layers):
93 | d = layer.dilation[0]
94 | k = layer.kernel_size[0]
95 | s = layer.stride[0]
96 |
97 | if isinstance(layer, nn.ConvTranspose1d):
98 | L = ((L - d * (k - 1) - 1) / s) + 1
99 | elif isinstance(layer, nn.Conv1d):
100 | L = (L - 1) * s + d * (k - 1) + 1
101 |
102 | L = math.ceil(L)
103 |
104 | l_in = L
105 |
106 | return (l_in - l_out) // 2
107 |
108 | def get_output_length(self, input_length):
109 | L = input_length
110 | # Calculate output length
111 | for layer in self.modules():
112 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113 | d = layer.dilation[0]
114 | k = layer.kernel_size[0]
115 | s = layer.stride[0]
116 |
117 | if isinstance(layer, nn.Conv1d):
118 | L = ((L - d * (k - 1) - 1) / s) + 1
119 | elif isinstance(layer, nn.ConvTranspose1d):
120 | L = (L - 1) * s + d * (k - 1) + 1
121 |
122 | L = math.floor(L)
123 | return L
124 |
125 | @torch.no_grad()
126 | def compress(
127 | self,
128 | audio_path_or_signal: Union[str, Path, AudioSignal],
129 | win_duration: float = 1.0,
130 | verbose: bool = False,
131 | normalize_db: float = -16,
132 | n_quantizers: int = None,
133 | ) -> DACFile:
134 | """Processes an audio signal from a file or AudioSignal object into
135 | discrete codes. This function processes the signal in short windows,
136 | using constant GPU memory.
137 |
138 | Parameters
139 | ----------
140 | audio_path_or_signal : Union[str, Path, AudioSignal]
141 |             audio signal to compress into discrete codes
142 | win_duration : float, optional
143 |             window duration in seconds, by default 1.0
144 | verbose : bool, optional
145 | by default False
146 | normalize_db : float, optional
147 | normalize db, by default -16
148 |
149 | Returns
150 | -------
151 | DACFile
152 | Object containing compressed codes and metadata
153 | required for decompression
154 | """
155 | audio_signal = audio_path_or_signal
156 | if isinstance(audio_signal, (str, Path)):
157 | audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
158 |
159 | self.eval()
160 | original_padding = self.padding
161 | original_device = audio_signal.device
162 |
163 | audio_signal = audio_signal.clone()
164 | audio_signal = audio_signal.to_mono()
165 | original_sr = audio_signal.sample_rate
166 |
167 | resample_fn = audio_signal.resample
168 | loudness_fn = audio_signal.loudness
169 |
170 |         # If audio is very long (>= 10 hours, per the check below), use the ffmpeg versions
171 | if audio_signal.signal_duration >= 10 * 60 * 60:
172 | resample_fn = audio_signal.ffmpeg_resample
173 | loudness_fn = audio_signal.ffmpeg_loudness
174 |
175 | original_length = audio_signal.signal_length
176 | resample_fn(self.sample_rate)
177 | input_db = loudness_fn()
178 |
179 | if normalize_db is not None:
180 | audio_signal.normalize(normalize_db)
181 | audio_signal.ensure_max_of_audio()
182 |
183 | nb, nac, nt = audio_signal.audio_data.shape
184 | audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
185 | win_duration = (
186 | audio_signal.signal_duration if win_duration is None else win_duration
187 | )
188 |
189 | if audio_signal.signal_duration <= win_duration:
190 | # Unchunked compression (used if signal length < win duration)
191 | self.padding = True
192 | n_samples = nt
193 | hop = nt
194 | else:
195 | # Chunked inference
196 | self.padding = False
197 | # Zero-pad signal on either side by the delay
198 | audio_signal.zero_pad(self.delay, self.delay)
199 | n_samples = int(win_duration * self.sample_rate)
200 | # Round n_samples to nearest hop length multiple
201 | n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
202 | hop = self.get_output_length(n_samples)
203 |
204 | codes = []
205 | range_fn = range if not verbose else tqdm.trange
206 |
207 | for i in range_fn(0, nt, hop):
208 | x = audio_signal[..., i : i + n_samples]
209 | x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
210 |
211 | audio_data = x.audio_data.to(self.device)
212 | audio_data = self.preprocess(audio_data, self.sample_rate)
213 | _, c, _, _, _ = self.encode(audio_data, n_quantizers)
214 | codes.append(c.to(original_device))
215 | chunk_length = c.shape[-1]
216 |
217 | codes = torch.cat(codes, dim=-1)
218 |
219 | dac_file = DACFile(
220 | codes=codes,
221 | chunk_length=chunk_length,
222 | original_length=original_length,
223 | input_db=input_db,
224 | channels=nac,
225 | sample_rate=original_sr,
226 | padding=self.padding,
227 | dac_version=SUPPORTED_VERSIONS[-1],
228 | )
229 |
230 | if n_quantizers is not None:
231 | codes = codes[:, :n_quantizers, :]
232 |
233 | self.padding = original_padding
234 | return dac_file
235 |
236 | @torch.no_grad()
237 | def decompress(
238 | self,
239 | obj: Union[str, Path, DACFile],
240 | verbose: bool = False,
241 | ) -> AudioSignal:
242 | """Reconstruct audio from a given .dac file
243 |
244 | Parameters
245 | ----------
246 | obj : Union[str, Path, DACFile]
247 | .dac file location or corresponding DACFile object.
248 | verbose : bool, optional
249 | Prints progress if True, by default False
250 |
251 | Returns
252 | -------
253 | AudioSignal
254 | Object with the reconstructed audio
255 | """
256 | self.eval()
257 | if isinstance(obj, (str, Path)):
258 | obj = DACFile.load(obj)
259 |
260 | original_padding = self.padding
261 | self.padding = obj.padding
262 |
263 | range_fn = range if not verbose else tqdm.trange
264 | codes = obj.codes
265 | original_device = codes.device
266 | chunk_length = obj.chunk_length
267 | recons = []
268 |
269 | for i in range_fn(0, codes.shape[-1], chunk_length):
270 | c = codes[..., i : i + chunk_length].to(self.device)
271 | z = self.quantizer.from_codes(c)[0]
272 | r = self.decode(z)
273 | recons.append(r.to(original_device))
274 |
275 | recons = torch.cat(recons, dim=-1)
276 | recons = AudioSignal(recons, self.sample_rate)
277 |
278 | resample_fn = recons.resample
279 | loudness_fn = recons.loudness
280 |
281 |         # If audio is very long (>= 10 hours, per the check below), use the ffmpeg versions
282 | if recons.signal_duration >= 10 * 60 * 60:
283 | resample_fn = recons.ffmpeg_resample
284 | loudness_fn = recons.ffmpeg_loudness
285 |
286 | if obj.input_db is not None:
287 | recons.normalize(obj.input_db)
288 |
289 | resample_fn(obj.sample_rate)
290 |
291 | if obj.original_length is not None:
292 | recons = recons[..., : obj.original_length]
293 | loudness_fn()
294 | recons.audio_data = recons.audio_data.reshape(
295 | -1, obj.channels, obj.original_length
296 | )
297 | else:
298 | loudness_fn()
299 |
300 | self.padding = original_padding
301 | return recons
302 |
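# Usage sketch (assumption: `model` is a trained DAC instance that mixes in CodecMixin
# and "input.wav" exists on disk):
#
#   dac_file = model.compress("input.wav", win_duration=1.0)
#   dac_file.save("input.dac")
#   recons = model.decompress("input.dac")   # returns an AudioSignal
#   recons.write("reconstruction.wav")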
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/video_model_builder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 | # Copyright 2020 Ross Wightman
4 | # Modified Model definition
5 |
6 | from collections import OrderedDict
7 | from functools import partial
8 |
9 | import torch
10 | import torch.nn as nn
11 | from timm.layers import trunc_normal_
12 |
13 | from .vit_helper import PatchEmbed, PatchEmbed3D, DividedSpaceTimeBlock
14 |
15 |
16 | class VisionTransformer(nn.Module):
17 | """Vision Transformer with support for patch or hybrid CNN input stage"""
18 |
19 | def __init__(self, cfg):
20 | super().__init__()
21 | self.img_size = cfg.DATA.TRAIN_CROP_SIZE
22 | self.patch_size = cfg.VIT.PATCH_SIZE
23 | self.in_chans = cfg.VIT.CHANNELS
24 | if cfg.TRAIN.DATASET == "Epickitchens":
25 | self.num_classes = [97, 300]
26 | else:
27 | self.num_classes = cfg.MODEL.NUM_CLASSES
28 | self.embed_dim = cfg.VIT.EMBED_DIM
29 | self.depth = cfg.VIT.DEPTH
30 | self.num_heads = cfg.VIT.NUM_HEADS
31 | self.mlp_ratio = cfg.VIT.MLP_RATIO
32 | self.qkv_bias = cfg.VIT.QKV_BIAS
33 | self.drop_rate = cfg.VIT.DROP
34 | self.drop_path_rate = cfg.VIT.DROP_PATH
35 | self.head_dropout = cfg.VIT.HEAD_DROPOUT
36 | self.video_input = cfg.VIT.VIDEO_INPUT
37 | self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
38 | self.use_mlp = cfg.VIT.USE_MLP
39 | self.num_features = self.embed_dim
40 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
41 | self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
42 | self.head_act = cfg.VIT.HEAD_ACT
43 | self.cfg = cfg
44 |
45 | # Patch Embedding
46 | self.patch_embed = PatchEmbed(
47 | img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim
48 | )
49 |
50 | # 3D Patch Embedding
51 | self.patch_embed_3d = PatchEmbed3D(
52 | img_size=self.img_size,
53 | temporal_resolution=self.temporal_resolution,
54 | patch_size=self.patch_size,
55 | in_chans=self.in_chans,
56 | embed_dim=self.embed_dim,
57 | z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP,
58 | )
59 | self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data)
60 |
61 | # Number of patches
62 | if self.video_input:
63 | num_patches = self.patch_embed.num_patches * self.temporal_resolution
64 | else:
65 | num_patches = self.patch_embed.num_patches
66 | self.num_patches = num_patches
67 |
68 | # CLS token
69 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
70 | trunc_normal_(self.cls_token, std=0.02)
71 |
72 | # Positional embedding
73 | self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
74 | self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT)
75 | trunc_normal_(self.pos_embed, std=0.02)
76 |
77 | if self.cfg.VIT.POS_EMBED == "joint":
78 | self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
79 | trunc_normal_(self.st_embed, std=0.02)
80 | elif self.cfg.VIT.POS_EMBED == "separate":
81 | self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim))
82 |
83 | # Layer Blocks
84 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
85 | if self.cfg.VIT.ATTN_LAYER == "divided":
86 | self.blocks = nn.ModuleList(
87 | [
88 | DividedSpaceTimeBlock(
89 | attn_type=cfg.VIT.ATTN_LAYER,
90 | dim=self.embed_dim,
91 | num_heads=self.num_heads,
92 | mlp_ratio=self.mlp_ratio,
93 | qkv_bias=self.qkv_bias,
94 | drop=self.drop_rate,
95 | attn_drop=self.attn_drop_rate,
96 | drop_path=dpr[i],
97 | norm_layer=norm_layer,
98 | )
99 | for i in range(self.depth)
100 | ]
101 | )
102 |
103 | self.norm = norm_layer(self.embed_dim)
104 |
105 | # MLP head
106 | if self.use_mlp:
107 | hidden_dim = self.embed_dim
108 | if self.head_act == "tanh":
109 | # logging.info("Using TanH activation in MLP")
110 | act = nn.Tanh()
111 | elif self.head_act == "gelu":
112 | # logging.info("Using GELU activation in MLP")
113 | act = nn.GELU()
114 | else:
115 | # logging.info("Using ReLU activation in MLP")
116 | act = nn.ReLU()
117 | self.pre_logits = nn.Sequential(
118 | OrderedDict(
119 | [
120 | ("fc", nn.Linear(self.embed_dim, hidden_dim)),
121 | ("act", act),
122 | ]
123 | )
124 | )
125 | else:
126 | self.pre_logits = nn.Identity()
127 |
128 | # Classifier Head
129 | self.head_drop = nn.Dropout(p=self.head_dropout)
130 | if isinstance(self.num_classes, (list,)) and len(self.num_classes) > 1:
131 |             for i in range(len(self.num_classes)):
132 |                 setattr(self, "head%d" % i, nn.Linear(self.embed_dim, self.num_classes[i]))
133 | else:
134 | self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
135 |
136 | # Initialize weights
137 | self.apply(self._init_weights)
138 |
139 | def _init_weights(self, m):
140 | if isinstance(m, nn.Linear):
141 | trunc_normal_(m.weight, std=0.02)
142 | if isinstance(m, nn.Linear) and m.bias is not None:
143 | nn.init.constant_(m.bias, 0)
144 | elif isinstance(m, nn.LayerNorm):
145 | nn.init.constant_(m.bias, 0)
146 | nn.init.constant_(m.weight, 1.0)
147 |
148 | @torch.jit.ignore
149 | def no_weight_decay(self):
150 | if self.cfg.VIT.POS_EMBED == "joint":
151 | return {"pos_embed", "cls_token", "st_embed"}
152 | else:
153 | return {"pos_embed", "cls_token", "temp_embed"}
154 |
155 | def get_classifier(self):
156 | return self.head
157 |
158 | def reset_classifier(self, num_classes, global_pool=""):
159 | self.num_classes = num_classes
160 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
161 |
162 | def forward_features(self, x):
163 | # if self.video_input:
164 | # x = x[0]
165 | B = x.shape[0]
166 |
167 | # Tokenize input
168 | # if self.cfg.VIT.PATCH_SIZE_TEMP > 1:
169 | # for simplicity of mapping between content dimensions (input x) and token dims (after patching)
170 | # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details):
171 |
172 | # apply patching on input
173 | x = self.patch_embed_3d(x)
174 | tok_mask = None
175 |
176 | # else:
177 | # tok_mask = None
178 | # # 2D tokenization
179 | # if self.video_input:
180 | # x = x.permute(0, 2, 1, 3, 4)
181 | # (B, T, C, H, W) = x.shape
182 | # x = x.reshape(B * T, C, H, W)
183 |
184 | # x = self.patch_embed(x)
185 |
186 | # if self.video_input:
187 | # (B2, T2, D2) = x.shape
188 | # x = x.reshape(B, T * T2, D2)
189 |
190 | # Append CLS token
191 | cls_tokens = self.cls_token.expand(B, -1, -1)
192 | x = torch.cat((cls_tokens, x), dim=1)
193 | # if tok_mask is not None:
194 | # # prepend 1(=keep) to the mask to account for the CLS token as well
195 | # tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1)
196 |
197 |         # Interpolate positional embeddings
198 | # if self.cfg.DATA.TRAIN_CROP_SIZE != 224:
199 | # pos_embed = self.pos_embed
200 | # N = pos_embed.shape[1] - 1
201 | # npatch = int((x.size(1) - 1) / self.temporal_resolution)
202 | # class_emb = pos_embed[:, 0]
203 | # pos_embed = pos_embed[:, 1:]
204 | # dim = x.shape[-1]
205 | # pos_embed = torch.nn.functional.interpolate(
206 | # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
207 | # scale_factor=math.sqrt(npatch / N),
208 | # mode='bicubic',
209 | # )
210 | # pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
211 | # new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
212 | # else:
213 | new_pos_embed = self.pos_embed
214 | npatch = self.patch_embed.num_patches
215 |
216 | # Add positional embeddings to input
217 | if self.video_input:
218 | if self.cfg.VIT.POS_EMBED == "separate":
219 | cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
220 | tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
221 | tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
222 | total_pos_embed = tile_pos_embed + tile_temporal_embed
223 | total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
224 | x = x + total_pos_embed
225 | elif self.cfg.VIT.POS_EMBED == "joint":
226 | x = x + self.st_embed
227 | else:
228 | # image input
229 | x = x + new_pos_embed
230 |
231 | # Apply positional dropout
232 | x = self.pos_drop(x)
233 |
234 | # Encoding using transformer layers
235 | for i, blk in enumerate(self.blocks):
236 | x = blk(
237 | x,
238 | seq_len=npatch,
239 | num_frames=self.temporal_resolution,
240 | approx=self.cfg.VIT.APPROX_ATTN_TYPE,
241 | num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM,
242 | tok_mask=tok_mask,
243 | )
244 |
245 | ### v-iashin: I moved it to the forward pass
246 | # x = self.norm(x)[:, 0]
247 | # x = self.pre_logits(x)
248 | ###
249 | return x, tok_mask
250 |
251 | # def forward(self, x):
252 | # x = self.forward_features(x)
253 | # ### v-iashin: here. This should leave the same forward output as before
254 | # x = self.norm(x)[:, 0]
255 | # x = self.pre_logits(x)
256 | # ###
257 | # x = self.head_drop(x)
258 | # if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
259 | # output = []
260 | # for head in range(len(self.num_classes)):
261 | # x_out = getattr(self, "head%d" % head)(x)
262 | # if not self.training:
263 | # x_out = torch.nn.functional.softmax(x_out, dim=-1)
264 | # output.append(x_out)
265 | # return output
266 | # else:
267 | # x = self.head(x)
268 | # if not self.training:
269 | # x = torch.nn.functional.softmax(x, dim=-1)
270 | # return x
271 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from loguru import logger
4 | from torchvision import transforms
5 | from torchvision.transforms import v2
6 | from diffusers.utils.torch_utils import randn_tensor
7 | from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection
8 | from ..models.dac_vae.model.dac import DAC
9 | from ..models.synchformer import Synchformer
10 | from ..models.hifi_foley import HunyuanVideoFoley
11 | from .config_utils import load_yaml, AttributeDict
12 | from .schedulers import FlowMatchDiscreteScheduler
13 | from tqdm import tqdm
14 |
15 | def load_state_dict(model, model_path):
16 | logger.info(f"Loading model state dict from: {model_path}")
17 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
18 |
19 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
20 |
21 | if missing_keys:
22 | logger.warning(f"Missing keys in state dict ({len(missing_keys)} keys):")
23 | for key in missing_keys:
24 | logger.warning(f" - {key}")
25 | else:
26 | logger.info("No missing keys found")
27 |
28 | if unexpected_keys:
29 | logger.warning(f"Unexpected keys in state dict ({len(unexpected_keys)} keys):")
30 | for key in unexpected_keys:
31 | logger.warning(f" - {key}")
32 | else:
33 | logger.info("No unexpected keys found")
34 |
35 | logger.info("Model state dict loaded successfully")
36 | return model
37 |
38 | def load_model(model_path, config_path, device):
39 | logger.info("Starting model loading process...")
40 | logger.info(f"Configuration file: {config_path}")
41 | logger.info(f"Model weights dir: {model_path}")
42 | logger.info(f"Target device: {device}")
43 |
44 | cfg = load_yaml(config_path)
45 | logger.info("Configuration loaded successfully")
46 |
47 | # HunyuanVideoFoley
48 | logger.info("Loading HunyuanVideoFoley main model...")
49 | foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16)
50 | foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth"))
51 | foley_model.eval()
52 | logger.info("HunyuanVideoFoley model loaded and set to evaluation mode")
53 |
54 | # DAC-VAE
55 | dac_path = os.path.join(model_path, "vae_128d_48k.pth")
56 | logger.info(f"Loading DAC VAE model from: {dac_path}")
57 | try:
58 | # Try loading with the standard DAC.load method
59 | dac_model = DAC.load(dac_path)
60 | except TypeError as e:
61 | if "map_location" in str(e):
62 | # Handle the map_location conflict by manually loading the state dict
63 | logger.warning(f"DAC.load() failed with map_location conflict: {e}")
64 | logger.info("Attempting manual DAC model loading...")
65 |
66 | # Create DAC model instance with appropriate parameters for vae_128d_48k
67 | # Based on filename, this appears to be 128-dimensional latent space, 48kHz sample rate
68 | dac_model = DAC(
69 | encoder_dim=64,
70 | latent_dim=128, # 128d as indicated by filename
71 | decoder_dim=1536,
72 | sample_rate=48000, # 48k as indicated by filename
73 | continuous=False
74 | )
75 | state_dict = torch.load(dac_path, map_location="cpu", weights_only=False)
76 | dac_model.load_state_dict(state_dict, strict=False)
77 | else:
78 | raise e
79 |
80 | dac_model = dac_model.to(device)
81 | dac_model.requires_grad_(False)
82 | dac_model.eval()
83 | logger.info("DAC VAE model loaded successfully")
84 |
85 | # Siglip2 visual-encoder
86 | logger.info("Loading SigLIP2 visual encoder...")
87 | siglip2_preprocess = transforms.Compose([
88 | transforms.Resize((512, 512)),
89 | transforms.ToTensor(),
90 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
91 | ])
92 |
93 | # Try multiple approaches to load SigLIP2
94 | siglip2_model = None
95 |
96 | # Method 1: Try with standard transformers AutoModel
97 | try:
98 | siglip2_model = AutoModel.from_pretrained("google/siglip2-base-patch16-512", trust_remote_code=True).to(device).eval()
99 | logger.info("SigLIP2 loaded using standard transformers")
100 | except Exception as e1:
101 | logger.warning(f"Standard transformers loading failed: {e1}")
102 |
103 |         # Method 2: Fall back to the original SigLIP base checkpoint via SiglipVisionModel
104 | try:
105 | from transformers import SiglipVisionModel
106 | siglip2_model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-512").to(device).eval()
107 |             logger.info("SigLIP (v1) base model loaded via SiglipVisionModel as a substitute for SigLIP2")
108 | except Exception as e2:
109 | logger.warning(f"SiglipVisionModel loading failed: {e2}")
110 |
111 | # Method 3: Try using a compatible CLIP model as fallback
112 | try:
113 | from transformers import CLIPVisionModel
114 | logger.warning("Falling back to CLIP vision model as SigLIP2 is not available")
115 | siglip2_model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336").to(device).eval()
116 | logger.info("Using CLIP vision model as fallback")
117 | except Exception as e3:
118 | logger.error(f"All vision model loading attempts failed: {e3}")
119 | raise RuntimeError(
120 | "Could not load SigLIP2 vision encoder. Please ensure you have a compatible "
121 | "transformers version installed. You can try:\n"
122 | "1. pip install transformers>=4.37.0\n"
123 | "2. Or manually download the model weights"
124 | )
125 |
126 | logger.info("SigLIP2 model and preprocessing pipeline loaded successfully")
127 |
128 | # clap text-encoder
129 | logger.info("Loading CLAP text encoder...")
130 | clap_tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general")
131 | clap_model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device)
132 | logger.info("CLAP tokenizer and model loaded successfully")
133 |
134 | # syncformer
135 | syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth")
136 | logger.info(f"Loading Synchformer model from: {syncformer_path}")
137 | syncformer_preprocess = v2.Compose(
138 | [
139 | v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
140 | v2.CenterCrop(224),
141 | v2.ToImage(),
142 | v2.ToDtype(torch.float32, scale=True),
143 | v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
144 | ]
145 | )
146 |
147 | syncformer_model = Synchformer()
148 | syncformer_model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu"))
149 | syncformer_model = syncformer_model.to(device).eval()
150 | logger.info("Synchformer model and preprocessing pipeline loaded successfully")
151 |
152 |
153 | logger.info("Creating model dictionary with attribute access...")
154 | model_dict = AttributeDict({
155 | 'foley_model': foley_model,
156 | 'dac_model': dac_model,
157 | 'siglip2_preprocess': siglip2_preprocess,
158 | 'siglip2_model': siglip2_model,
159 | 'clap_tokenizer': clap_tokenizer,
160 | 'clap_model': clap_model,
161 | 'syncformer_preprocess': syncformer_preprocess,
162 | 'syncformer_model': syncformer_model,
163 | 'device': device,
164 | })
165 |
166 | logger.info("All models loaded successfully!")
167 | logger.info("Available model components:")
168 | for key in model_dict.keys():
169 | logger.info(f" - {key}")
170 | logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)")
171 |
172 | return model_dict, cfg
173 |
174 | def retrieve_timesteps(
175 | scheduler,
176 | num_inference_steps,
177 | device,
178 | **kwargs,
179 | ):
180 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
181 | timesteps = scheduler.timesteps
182 | return timesteps, num_inference_steps
183 |
184 |
185 | def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device):
186 | shape = (batch_size, num_channels_latents, int(length))
187 | latents = randn_tensor(shape, device=device, dtype=dtype)
188 |
189 | # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
190 | if hasattr(scheduler, "init_noise_sigma"):
191 | # scale the initial noise by the standard deviation required by the scheduler
192 | latents = latents * scheduler.init_noise_sigma
193 |
194 | return latents
195 |
196 |
197 | @torch.no_grad()
198 | def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, batch_size=1):
199 |
200 | target_dtype = model_dict.foley_model.dtype
201 | autocast_enabled = target_dtype != torch.float32
202 | device = model_dict.device
203 |
204 | scheduler = FlowMatchDiscreteScheduler(
205 | shift=cfg.diffusion_config.sample_flow_shift,
206 | reverse=cfg.diffusion_config.flow_reverse,
207 | solver=cfg.diffusion_config.flow_solver,
208 | use_flux_shift=cfg.diffusion_config.sample_use_flux_shift,
209 | flux_base_shift=cfg.diffusion_config.flux_base_shift,
210 | flux_max_shift=cfg.diffusion_config.flux_max_shift,
211 | )
212 |
213 | timesteps, num_inference_steps = retrieve_timesteps(
214 | scheduler,
215 | num_inference_steps,
216 | device,
217 | )
218 |
219 | latents = prepare_latents(
220 | scheduler,
221 | batch_size=batch_size,
222 | num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim,
223 | length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate,
224 | dtype=target_dtype,
225 | device=device,
226 | )
227 |
228 | # Denoise loop
229 | for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Denoising steps"):
230 |         # duplicate latents for classifier-free guidance (unconditional + conditional)
231 | latent_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
232 | latent_input = scheduler.scale_model_input(latent_input, t)
233 |
234 | t_expand = t.repeat(latent_input.shape[0])
235 |
236 | # siglip2 features
237 | siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
238 | uncond_siglip2_feat = model_dict.foley_model.get_empty_clip_sequence(
239 | bs=batch_size, len=siglip2_feat.shape[1]
240 | ).to(device)
241 |
242 | if guidance_scale is not None and guidance_scale > 1.0:
243 | siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat], dim=0)
244 | else:
245 | siglip2_feat_input = siglip2_feat
246 |
247 | # syncformer features
248 | syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
249 | uncond_syncformer_feat = model_dict.foley_model.get_empty_sync_sequence(
250 | bs=batch_size, len=syncformer_feat.shape[1]
251 | ).to(device)
252 | if guidance_scale is not None and guidance_scale > 1.0:
253 | syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat], dim=0)
254 | else:
255 | syncformer_feat_input = syncformer_feat
256 |
257 | # text features
258 | text_feat_repeated = text_feats.text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
259 | uncond_text_feat_repeated = text_feats.uncond_text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
260 | if guidance_scale is not None and guidance_scale > 1.0:
261 | text_feat_input = torch.cat([uncond_text_feat_repeated, text_feat_repeated], dim=0)
262 | else:
263 | text_feat_input = text_feat_repeated
264 |
265 | with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype):
266 | # Predict the noise residual
267 | noise_pred = model_dict.foley_model(
268 | x=latent_input,
269 | t=t_expand,
270 | cond=text_feat_input,
271 | clip_feat=siglip2_feat_input,
272 | sync_feat=syncformer_feat_input,
273 | return_dict=True,
274 | )["x"]
275 |
276 | noise_pred = noise_pred.to(dtype=torch.float32)
277 |
278 | if guidance_scale is not None and guidance_scale > 1.0:
279 | # Perform classifier-free guidance
280 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
281 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
282 |
283 | # Compute the previous noisy sample x_t -> x_t-1
284 | latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
285 |
286 | # Post-process the latents to audio
287 |
288 | with torch.no_grad():
289 | audio = model_dict.dac_model.decode(latents)
290 | audio = audio.float().cpu()
291 |
292 | audio = audio[:, :int(audio_len_in_s*model_dict.dac_model.sample_rate)]
293 |
294 | return audio, model_dict.dac_model.sample_rate
295 |
296 |
297 |
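# Usage sketch (assumption: weights were downloaded to ComfyUI/models/foley/hunyuanvideo-foley-xxl
# and `visual_feats` / `text_feats` come from the project's feature extraction utilities):
#
#   model_dict, cfg = load_model(
#       "ComfyUI/models/foley/hunyuanvideo-foley-xxl",
#       "configs/hunyuanvideo-foley-xxl.yaml",
#       torch.device("cuda"),
#   )
#   audio, sample_rate = denoise_process(
#       visual_feats, text_feats, audio_len_in_s=8.0,
#       model_dict=model_dict, cfg=cfg,
#       guidance_scale=4.5, num_inference_steps=50,
#   )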
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI HunyuanVideo-Foley Custom Node
2 |
3 | This is a ComfyUI custom node wrapper for the HunyuanVideo-Foley model, which generates realistic audio from video and text descriptions.
4 |
5 | ## Features
6 |
7 | - **Text-Video-to-Audio Synthesis**: Generate realistic audio that matches your video content
8 | - **Flexible Text Prompts**: Use optional text descriptions to guide audio generation
9 | - **Multiple Samples**: Generate up to 6 different audio variations per inference
10 | - **Configurable Parameters**: Control guidance scale, inference steps, and sampling
11 | - **Seed Control**: Reproducible results with seed parameter
12 | - **Model Caching**: Efficient model loading and reuse across generations
13 | - **Automatic Model Downloads**: Models are automatically downloaded to `ComfyUI/models/foley/` when needed
14 |
26 |
27 | ## Installation
28 |
29 | 1. **Clone this repository** into your ComfyUI custom_nodes directory:
30 | ```bash
31 | cd ComfyUI/custom_nodes
32 | git clone https://github.com/if-ai/ComfyUI_HunyuanVideoFoley.git
33 | ```
34 |
35 | 2. **Install dependencies**:
36 | ```bash
37 | cd ComfyUI_HunyuanVideoFoley
38 | pip install -r requirements.txt
39 | ```
40 |
41 | 3. **Run the installation script** (recommended):
42 | ```bash
43 | python install.py
44 | ```
45 |
46 | 4. **Restart ComfyUI** to load the new nodes.
47 |
48 | ### Model Setup
49 |
50 | The models can be obtained in two ways:
51 |
52 | #### Option 1: Automatic Download (Recommended)
53 | - Models will be automatically downloaded to `ComfyUI/models/foley/` when you first run the node
54 | - No manual setup required
55 | - Progress will be shown in the ComfyUI console
56 |
57 | #### Option 2: Manual Download
58 | - Download models from [HuggingFace](https://huggingface.co/tencent/HunyuanVideo-Foley)
59 | - Place models in `ComfyUI/models/foley/` (recommended) or `./pretrained_models/` directory
60 | - Ensure the config file is at `configs/hunyuanvideo-foley-xxl.yaml`
61 |
62 | ## Operation Guide: How to Use the Nodes
63 |
64 | This custom node package is designed in a modular way for maximum flexibility and efficiency. Here is the recommended workflow and an explanation of what each node does.
65 |
66 | ### Recommended Workflow
67 |
68 | The most powerful and efficient way to use these nodes is to chain them together in the following order:
69 |
70 | `Model Loader` → `Dependencies Loader` → `Torch Compile` → `Generator (Advanced)`
71 |
72 | This setup allows you to load the models only once, apply performance optimizations, and then run the generator multiple times without reloading, saving significant time and VRAM.
73 |
74 | ### Node Details
75 |
76 | #### 1. HunyuanVideo-Foley Model Loader (FP8)
77 | This is the starting point. It loads the main (and very large) audio generation model into memory.
78 |
79 | - **quantization**: This is the most important setting for saving VRAM.
80 | - `none`: Loads the model in its original format (highest VRAM usage).
81 | - `fp8_e5m2` / `fp8_e4m3fn`: These options use **FP8 quantization**, a technique that stores the model's weights in a much smaller format. This can save several gigabytes of VRAM with a minimal impact on audio quality, making it possible to run on GPUs with less memory.
82 | - **cpu_offload**: If `True`, the model will be kept in your regular RAM instead of VRAM. This is not the same as the generator's offload setting; use this if you are loading multiple different models in your workflow and need to conserve VRAM.
83 |
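For readers curious what the FP8 trick looks like in practice, here is a minimal, illustrative sketch (not this node's actual implementation; assumes a PyTorch build with `float8` support): weights are stored in an FP8 dtype and upcast to the activation dtype right before each matmul.

```python
import torch
import torch.nn.functional as F

class FP8Linear(torch.nn.Linear):
    """Illustrative only: keeps the weight in FP8 storage and upcasts it at matmul time."""
    def forward(self, x):
        w = self.weight.to(x.dtype)  # FP8 -> bf16/fp16 just for the matmul
        b = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, w, b)

def to_fp8_storage(model: torch.nn.Module, fp8_dtype=torch.float8_e4m3fn):
    """Swap Linear layers to FP8 storage, roughly halving their memory vs. fp16/bf16."""
    for m in model.modules():
        if isinstance(m, torch.nn.Linear):
            m.__class__ = FP8Linear                 # reuse the module, change only its forward
            m.weight.data = m.weight.data.to(fp8_dtype)
    return model
```
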
84 | #### 2. HunyuanVideo-Foley Dependencies
85 | This node takes the main model from the loader and then loads all the smaller, auxiliary models required for the process (the VAE, text encoder, and visual feature extractors).
86 |
87 | #### 3. HunyuanVideo-Foley Torch Compile
88 | This is an optional but highly recommended performance-enhancing node. It uses `torch.compile` to optimize the model's code for your specific hardware.
89 | - **Note**: The very first time you run a workflow with this node, it will take a minute or two to perform the compilation. However, every subsequent run will be significantly faster (often 20-30%).
90 |
91 | - **`compile_mode`**: This controls the trade-off between compilation time and the amount of performance gain.
92 | - `default`: The best balance. It provides a good speedup with a reasonable initial compile time.
93 | - `reduce-overhead`: Compiles more slowly but can reduce the overhead of running the model, which might be faster for very small audio generations.
94 | - `max-autotune`: Takes the longest to compile initially, but it tries many different optimizations to find the absolute fastest option for your specific hardware.
95 |
96 | - **`backend`**: This is an advanced setting that changes the underlying compiler used by PyTorch. For most users, the default `inductor` is the best choice.
97 |
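Conceptually, the node boils down to a call to `torch.compile`; a minimal sketch (assuming PyTorch 2.x, not the node's exact code) looks like this:

```python
import torch

def compile_foley_model(model: torch.nn.Module,
                        mode: str = "default",
                        backend: str = "inductor") -> torch.nn.Module:
    """Return a compiled wrapper: the first forward pass triggers the (slow) compilation,
    and subsequent passes reuse the optimized kernels."""
    return torch.compile(model, mode=mode, backend=backend)
```
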
98 | #### 4. HunyuanVideo-Foley Generator (Advanced)
99 | This is the main workhorse node where the audio generation happens.
100 |
101 | - **video / images**: Your visual input. You can provide either a video file or a batch of images from another node.
102 | - **compiled_model**: The input for the model prepared by the upstream nodes.
103 | - **text_prompt / negative_prompt**: Your descriptions of the sound you want (and don't want).
104 | - **guidance_scale / num_inference_steps / seed**: Standard diffusion model controls for creativity vs. prompt adherence, quality vs. speed, and reproducibility.
105 | - **enabled**: A simple switch. If `False`, the node does nothing and passes through an empty/silent output. This is useful for disabling parts of a complex workflow without having to disconnect them.
106 | - **silent_audio**: Controls what happens when the node is disabled or fails. If `True`, it outputs a valid, silent audio clip, which prevents downstream nodes (like video combiners) from failing. If `False`, it outputs `None`.
107 |
108 | ### Understanding the Memory Options
109 |
110 | The two memory-related checkboxes on the Generator node are crucial for managing your GPU's resources. Here is exactly what they do:
111 |
112 | - **`cpu_offload`**:
113 | - **What it does:** If this is `True`, the node will always move the models to your regular RAM (CPU) after the generation is complete. This is the best option for freeing up VRAM for other nodes in your workflow while still keeping the models ready for the next run without having to reload them from disk.
114 | - **Use this when:** You want to run other VRAM-intensive nodes after this one and plan to come back to the Foley generator later.
115 |
116 | - **`memory_efficient`**:
117 | - **What it does:** This is a more aggressive option. If `True`, the node will completely unload the models from memory (both VRAM and RAM) after the generation is finished.
118 | - **Important Distinction:** This process is smart. It will **only** unload the model if it was loaded by the generator node itself (the simple workflow). If the model was passed in from the `HunyuanVideoFoleyModelLoader` (the advanced workflow), it will **not** unload it, respecting the fact that you may want to reuse the pre-loaded model for another generation.
119 | - **Use this when:** You are finished with audio generation and want to free up as much memory as possible for completely different tasks.
120 |
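The difference between the two options can be summarized with a rough sketch (function and variable names here are illustrative, not the node's actual API):

```python
import torch

def release_after_generation(model_dict, cpu_offload: bool, memory_efficient: bool,
                             loaded_by_generator: bool):
    """Illustrative cleanup: cpu_offload parks weights in RAM, memory_efficient drops them."""
    if memory_efficient and loaded_by_generator:
        # Aggressive path: forget the models entirely so both VRAM and RAM are freed.
        torch.cuda.empty_cache()
        return None
    if cpu_offload:
        # Keep the weights resident in system RAM, ready for the next run.
        for name, component in list(model_dict.items()):
            if isinstance(component, torch.nn.Module):
                model_dict[name] = component.to("cpu")
        torch.cuda.empty_cache()
    return model_dict
```
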
121 | ### Performance Tuning & VRAM Usage
122 |
123 | The most memory-intensive part of the process is visual feature extraction. We've implemented batched processing to prevent out-of-memory errors with longer videos or on GPUs with less VRAM. You can control this with two settings on the **Generator (Advanced)** node:
124 |
125 | - **`feature_extraction_batch_size`**: This determines how many video frames are processed by the feature extractor models at once.
126 | - **Lower values** significantly reduce peak VRAM usage at the cost of slightly slower processing.
127 | - **Higher values** speed up processing but require more VRAM.
128 |
129 | - **`enable_profiling`**: If you check this box, the node will print detailed performance timings and peak VRAM usage for the feature extraction step to the console. This is highly recommended for finding the optimal batch size for your specific hardware.
130 |
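The batching itself is simple; the sketch below is illustrative (a generic `encoder` module that returns a tensor stands in for the SigLIP2/Synchformer extractors):

```python
import torch

@torch.no_grad()
def extract_features_in_batches(frames: torch.Tensor,
                                encoder: torch.nn.Module,
                                batch_size: int = 16,
                                device: str = "cuda") -> torch.Tensor:
    """Run the encoder over (N, C, H, W) frames in chunks so peak VRAM stays bounded."""
    outputs = []
    for chunk in torch.split(frames, batch_size):
        feats = encoder(chunk.to(device))   # only `batch_size` frames sit on the GPU at once
        outputs.append(feats.cpu())
    return torch.cat(outputs, dim=0)
```
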
131 | #### Recommended Batch Sizes
132 |
133 | These are general starting points. The optimal value can vary based on your exact GPU, driver version, and other running processes.
134 |
135 | | VRAM Tier | Video Resolution | Recommended Batch Size | Notes |
136 | | :--- | :--- | :--- | :--- |
137 | | **≤ 8 GB** | 480p | 4 - 8 | Start with 4. If successful, you can try increasing it. |
138 | | | 720p | 2 - 4 | Start with 2. 720p videos are demanding on low VRAM cards. |
139 | | **12-16 GB** | 480p | 16 - 32 | The default of 16 should work well. Can be increased for more speed. |
140 | | | 720p | 8 - 16 | Start with 8 or 16. |
141 | | **≥ 24 GB**| 480p | 32 - 64 | You can safely increase the batch size for maximum performance. |
142 | | | 720p | 16 - 32 | A batch size of 32 should be easily achievable. |
143 |
144 | ## Usage
145 |
146 | ### Node Types
147 |
148 | #### 1. HunyuanVideo-Foley Generator
149 | Main node for generating audio from video and text.
150 |
151 | **Inputs:**
152 | - **video**: Video input (VIDEO type)
153 | - **text_prompt**: Text description of desired audio (STRING)
154 | - **guidance_scale**: CFG scale for generation control (1.0-10.0, default: 4.5)
155 | - **num_inference_steps**: Number of denoising steps (10-100, default: 50)
156 | - **sample_nums**: Number of audio samples to generate (1-6, default: 1)
157 | - **seed**: Random seed for reproducibility (INT)
158 | - **model_path**: Path to pretrained models (optional, leave empty for auto-download)
159 | - **enabled**: Enable or disable the entire node. If disabled, it will pass through a silent or null audio output without processing. (BOOLEAN, default: True)
160 | - **silent_audio**: Controls the output when the node is disabled or fails. If true, it outputs a silent audio clip. If false, it outputs `None`. (BOOLEAN, default: True)
161 |
162 | **Outputs:**
163 | - **video_with_audio**: Video with generated audio merged (VIDEO)
164 | - **audio_only**: Generated audio file (AUDIO)
165 | - **status_message**: Generation status and info (STRING)
166 |
167 | ## ⚠ Important Limitations
168 |
169 | ### **Frame Count & Duration Limits**
170 | - **Maximum Frames**: 450 frames (hard limit)
171 | - **Maximum Duration**: 15 seconds at 30fps
172 | - **Recommended**: Keep videos ≤15 seconds for best results
173 |
174 | ### **FPS Recommendations**
175 | - **30fps**: Max 15 seconds (450 frames)
176 | - **24fps**: Max 18.75 seconds (450 frames)
177 | - **15fps**: Max 30 seconds (450 frames)
178 |
179 | ### **Long Video Solutions**
180 | For videos longer than 15 seconds:
181 | 1. **Reduce FPS**: Lower FPS allows longer duration within frame limit
182 | 2. **Segment Processing**: Split long videos into 15s segments
183 | 3. **Audio Merging**: Combine generated audio segments in post-processing
184 |
185 |
186 | ## Example Workflow
187 |
188 | 1. **Load Video**: Use a "Load Video" node to input your video file
189 | 2. **Add Generator**: Add the "HunyuanVideo-Foley Generator" node
190 | 3. **Connect Video**: Connect the video output to the generator's video input
191 | 4. **Set Prompt**: Enter a text description (e.g., "A person walks on frozen ice")
192 | 5. **Adjust Settings**: Configure guidance scale, steps, and sample count as needed
193 | 6. **Generate**: Run the workflow to generate audio
194 |
195 | ## Model Requirements
196 |
197 | The node expects the following model structure:
198 | ```
199 | ComfyUI/models/foley/hunyuanvideo-foley-xxl
200 | ├── hunyuanvideo_foley.pth # Main Foley model
201 | ├── vae_128d_48k.pth # DAC VAE model
202 | └── synchformer_state_dict.pth # Synchformer model
203 |
204 | configs/
205 | └── hunyuanvideo-foley-xxl.yaml # Configuration file
206 | ```
207 |
208 | ## TODO
209 | - [x] ADD VHS INPUT/OUTPUTS (Thanks to YC)
210 | - [x] NEGATIVE PROMPT (Thanks to YC)
211 | - [x] MODEL OFFLOADING OPS
212 | - [x] TORCH COMPILE
213 | - [ ] QUANTISE MODEL
214 |
215 |
216 | ## Support
217 |
218 | If you find this tool useful, please consider supporting my work by:
219 |
220 | - Starring this repository on GitHub
221 | - Subscribing to my YouTube channel: [Impact Frames](https://youtube.com/@impactframes?si=DrBu3tOAC2-YbEvc)
222 | - Following on X: [@ImpactFrames](https://x.com/ImpactFramesX)
223 |
224 | You can also support by reporting issues or suggesting features. Your contributions help me bring updates and improvements to the project.
225 |
226 |
227 |
228 | ## License
229 |
230 | This custom node is based on the HunyuanVideo-Foley project. Please check the original project's license terms.
231 |
232 | ## Credits
233 |
234 | Based on the HunyuanVideo-Foley project by Tencent. Original paper and code available at:
235 | - Paper: HunyuanVideo-Foley: Text-Video-to-Audio Synthesis
236 | 
237 | - Code: https://github.com/tencent/HunyuanVideo-Foley
238 |
239 |
240 |
241 |
242 |
243 |
244 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/nn/loss.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from typing import List
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from audiotools import AudioSignal
7 | from audiotools import STFTParams
8 | from torch import nn
9 |
10 |
11 | class L1Loss(nn.L1Loss):
12 | """L1 Loss between AudioSignals. Defaults
13 | to comparing ``audio_data``, but any
14 | attribute of an AudioSignal can be used.
15 |
16 | Parameters
17 | ----------
18 | attribute : str, optional
19 | Attribute of signal to compare, defaults to ``audio_data``.
20 | weight : float, optional
21 | Weight of this loss, defaults to 1.0.
22 |
23 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
24 | """
25 |
26 | def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
27 | self.attribute = attribute
28 | self.weight = weight
29 | super().__init__(**kwargs)
30 |
31 | def forward(self, x: AudioSignal, y: AudioSignal):
32 | """
33 | Parameters
34 | ----------
35 | x : AudioSignal
36 | Estimate AudioSignal
37 | y : AudioSignal
38 | Reference AudioSignal
39 |
40 | Returns
41 | -------
42 | torch.Tensor
43 | L1 loss between AudioSignal attributes.
44 | """
45 | if isinstance(x, AudioSignal):
46 | x = getattr(x, self.attribute)
47 | y = getattr(y, self.attribute)
48 | return super().forward(x, y)
49 |
50 |
51 | class SISDRLoss(nn.Module):
52 | """
53 | Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
54 | of estimated and reference audio signals or aligned features.
55 |
56 | Parameters
57 | ----------
58 |     scaling : bool, optional
59 | Whether to use scale-invariant (True) or
60 | signal-to-noise ratio (False), by default True
61 | reduction : str, optional
62 | How to reduce across the batch (either 'mean',
63 |         'sum', or 'none'), by default 'mean'
64 |     zero_mean : bool, optional
65 | Zero mean the references and estimates before
66 | computing the loss, by default True
67 |     clip_min : float, optional
68 | The minimum possible loss value. Helps network
69 | to not focus on making already good examples better, by default None
70 | weight : float, optional
71 | Weight of this loss, defaults to 1.0.
72 |
73 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
74 | """
75 |
76 | def __init__(
77 | self,
78 |         scaling: bool = True,
79 |         reduction: str = "mean",
80 |         zero_mean: bool = True,
81 |         clip_min: float = None,
82 | weight: float = 1.0,
83 | ):
84 | self.scaling = scaling
85 | self.reduction = reduction
86 | self.zero_mean = zero_mean
87 | self.clip_min = clip_min
88 | self.weight = weight
89 | super().__init__()
90 |
91 | def forward(self, x: AudioSignal, y: AudioSignal):
92 | eps = 1e-8
93 | # nb, nc, nt
94 | if isinstance(x, AudioSignal):
95 | references = x.audio_data
96 | estimates = y.audio_data
97 | else:
98 | references = x
99 | estimates = y
100 |
101 | nb = references.shape[0]
102 | references = references.reshape(nb, 1, -1).permute(0, 2, 1)
103 | estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
104 |
105 | # samples now on axis 1
106 | if self.zero_mean:
107 | mean_reference = references.mean(dim=1, keepdim=True)
108 | mean_estimate = estimates.mean(dim=1, keepdim=True)
109 | else:
110 | mean_reference = 0
111 | mean_estimate = 0
112 |
113 | _references = references - mean_reference
114 | _estimates = estimates - mean_estimate
115 |
116 | references_projection = (_references**2).sum(dim=-2) + eps
117 | references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
118 |
119 | scale = (
120 | (references_on_estimates / references_projection).unsqueeze(1)
121 | if self.scaling
122 | else 1
123 | )
124 |
125 | e_true = scale * _references
126 | e_res = _estimates - e_true
127 |
128 | signal = (e_true**2).sum(dim=1)
129 | noise = (e_res**2).sum(dim=1)
130 |         sdr = -10 * torch.log10(signal / noise + eps)  # negative SI-SDR: minimizing this loss maximizes SI-SDR
131 |
132 | if self.clip_min is not None:
133 | sdr = torch.clamp(sdr, min=self.clip_min)
134 |
135 | if self.reduction == "mean":
136 | sdr = sdr.mean()
137 | elif self.reduction == "sum":
138 | sdr = sdr.sum()
139 | return sdr
140 |
141 |
142 | class MultiScaleSTFTLoss(nn.Module):
143 | """Computes the multi-scale STFT loss from [1].
144 |
145 | Parameters
146 | ----------
147 | window_lengths : List[int], optional
148 | Length of each window of each STFT, by default [2048, 512]
149 | loss_fn : typing.Callable, optional
150 | How to compare each loss, by default nn.L1Loss()
151 | clamp_eps : float, optional
152 | Clamp on the log magnitude, below, by default 1e-5
153 | mag_weight : float, optional
154 | Weight of raw magnitude portion of loss, by default 1.0
155 | log_weight : float, optional
156 | Weight of log magnitude portion of loss, by default 1.0
157 | pow : float, optional
158 | Power to raise magnitude to before taking log, by default 2.0
159 | weight : float, optional
160 | Weight of this loss, by default 1.0
161 | match_stride : bool, optional
162 | Whether to match the stride of convolutional layers, by default False
163 |
164 | References
165 | ----------
166 |
167 | 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
168 | "DDSP: Differentiable Digital Signal Processing."
169 | International Conference on Learning Representations. 2019.
170 |
171 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
172 | """
173 |
174 | def __init__(
175 | self,
176 | window_lengths: List[int] = [2048, 512],
177 | loss_fn: typing.Callable = nn.L1Loss(),
178 | clamp_eps: float = 1e-5,
179 | mag_weight: float = 1.0,
180 | log_weight: float = 1.0,
181 | pow: float = 2.0,
182 | weight: float = 1.0,
183 | match_stride: bool = False,
184 | window_type: str = None,
185 | ):
186 | super().__init__()
187 | self.stft_params = [
188 | STFTParams(
189 | window_length=w,
190 | hop_length=w // 4,
191 | match_stride=match_stride,
192 | window_type=window_type,
193 | )
194 | for w in window_lengths
195 | ]
196 | self.loss_fn = loss_fn
197 | self.log_weight = log_weight
198 | self.mag_weight = mag_weight
199 | self.clamp_eps = clamp_eps
200 | self.weight = weight
201 | self.pow = pow
202 |
203 | def forward(self, x: AudioSignal, y: AudioSignal):
204 | """Computes multi-scale STFT between an estimate and a reference
205 | signal.
206 |
207 | Parameters
208 | ----------
209 | x : AudioSignal
210 | Estimate signal
211 | y : AudioSignal
212 | Reference signal
213 |
214 | Returns
215 | -------
216 | torch.Tensor
217 | Multi-scale STFT loss.
218 | """
219 | loss = 0.0
220 | for s in self.stft_params:
221 | x.stft(s.window_length, s.hop_length, s.window_type)
222 | y.stft(s.window_length, s.hop_length, s.window_type)
223 | loss += self.log_weight * self.loss_fn(
224 | x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
225 | y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
226 | )
227 | loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
228 | return loss
229 |
230 |
231 | class MelSpectrogramLoss(nn.Module):
232 | """Compute distance between mel spectrograms. Can be used
233 | in a multi-scale way.
234 |
235 | Parameters
236 | ----------
237 | n_mels : List[int]
238 | Number of mels per STFT, by default [150, 80],
239 | window_lengths : List[int], optional
240 | Length of each window of each STFT, by default [2048, 512]
241 | loss_fn : typing.Callable, optional
242 | How to compare each loss, by default nn.L1Loss()
243 | clamp_eps : float, optional
244 | Clamp on the log magnitude, below, by default 1e-5
245 | mag_weight : float, optional
246 | Weight of raw magnitude portion of loss, by default 1.0
247 | log_weight : float, optional
248 | Weight of log magnitude portion of loss, by default 1.0
249 | pow : float, optional
250 | Power to raise magnitude to before taking log, by default 2.0
251 | weight : float, optional
252 | Weight of this loss, by default 1.0
253 | match_stride : bool, optional
254 | Whether to match the stride of convolutional layers, by default False
255 |
256 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
257 | """
258 |
259 | def __init__(
260 | self,
261 | n_mels: List[int] = [150, 80],
262 | window_lengths: List[int] = [2048, 512],
263 | loss_fn: typing.Callable = nn.L1Loss(),
264 | clamp_eps: float = 1e-5,
265 | mag_weight: float = 1.0,
266 | log_weight: float = 1.0,
267 | pow: float = 2.0,
268 | weight: float = 1.0,
269 | match_stride: bool = False,
270 | mel_fmin: List[float] = [0.0, 0.0],
271 | mel_fmax: List[float] = [None, None],
272 | window_type: str = None,
273 | ):
274 | super().__init__()
275 | self.stft_params = [
276 | STFTParams(
277 | window_length=w,
278 | hop_length=w // 4,
279 | match_stride=match_stride,
280 | window_type=window_type,
281 | )
282 | for w in window_lengths
283 | ]
284 | self.n_mels = n_mels
285 | self.loss_fn = loss_fn
286 | self.clamp_eps = clamp_eps
287 | self.log_weight = log_weight
288 | self.mag_weight = mag_weight
289 | self.weight = weight
290 | self.mel_fmin = mel_fmin
291 | self.mel_fmax = mel_fmax
292 | self.pow = pow
293 |
294 | def forward(self, x: AudioSignal, y: AudioSignal):
295 | """Computes mel loss between an estimate and a reference
296 | signal.
297 |
298 | Parameters
299 | ----------
300 | x : AudioSignal
301 | Estimate signal
302 | y : AudioSignal
303 | Reference signal
304 |
305 | Returns
306 | -------
307 | torch.Tensor
308 | Mel loss.
309 | """
310 | loss = 0.0
311 | for n_mels, fmin, fmax, s in zip(
312 | self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
313 | ):
314 | kwargs = {
315 | "window_length": s.window_length,
316 | "hop_length": s.hop_length,
317 | "window_type": s.window_type,
318 | }
319 | x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
320 | y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
321 |
322 | loss += self.log_weight * self.loss_fn(
323 | x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
324 | y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
325 | )
326 | loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
327 | return loss
328 |
329 |
330 | class GANLoss(nn.Module):
331 | """
332 | Computes a discriminator loss, given a discriminator on
333 | generated waveforms/spectrograms compared to ground truth
334 | waveforms/spectrograms. Computes the loss for both the
335 | discriminator and the generator in separate functions.
336 | """
337 |
338 | def __init__(self, discriminator):
339 | super().__init__()
340 | self.discriminator = discriminator
341 |
342 | def forward(self, fake, real):
343 | d_fake = self.discriminator(fake.audio_data)
344 | d_real = self.discriminator(real.audio_data)
345 | return d_fake, d_real
346 |
347 | def discriminator_loss(self, fake, real):
348 | d_fake, d_real = self.forward(fake.clone().detach(), real)
349 |
350 | loss_d = 0
351 | for x_fake, x_real in zip(d_fake, d_real):
352 | loss_d += torch.mean(x_fake[-1] ** 2)
353 | loss_d += torch.mean((1 - x_real[-1]) ** 2)
354 | return loss_d
355 |
356 | def generator_loss(self, fake, real):
357 | d_fake, d_real = self.forward(fake, real)
358 |
359 | loss_g = 0
360 | for x_fake in d_fake:
361 | loss_g += torch.mean((1 - x_fake[-1]) ** 2)
362 |
363 | loss_feature = 0
364 |
365 | for i in range(len(d_fake)):
366 | for j in range(len(d_fake[i]) - 1):
367 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
368 | return loss_g, loss_feature
369 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/synchformer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from typing import Any, Mapping
4 |
5 | import einops
6 | import numpy as np
7 | import torch
8 | import torchaudio
9 | from torch import nn
10 | from torch.nn import functional as F
11 |
12 | from .motionformer import MotionFormer
13 | from .ast_model import AST
14 | from .utils import Config
15 |
16 |
17 | class Synchformer(nn.Module):
18 |
19 | def __init__(self):
20 | super().__init__()
21 |
22 | self.vfeat_extractor = MotionFormer(
23 | extract_features=True,
24 | factorize_space_time=True,
25 | agg_space_module="TransformerEncoderLayer",
26 | agg_time_module="torch.nn.Identity",
27 | add_global_repr=False,
28 | )
29 | self.afeat_extractor = AST(
30 | extract_features=True,
31 | max_spec_t=66,
32 | factorize_freq_time=True,
33 | agg_freq_module="TransformerEncoderLayer",
34 | agg_time_module="torch.nn.Identity",
35 | add_global_repr=False,
36 | )
37 |
38 |         # project the visual and audio features into the transformer embedding dim
39 |         # (both are already 768-dim here, so these are identity-sized projections)
40 | self.vproj = nn.Linear(in_features=768, out_features=768)
41 | self.aproj = nn.Linear(in_features=768, out_features=768)
42 | self.transformer = GlobalTransformer(
43 | tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768
44 | )
45 |
46 | def forward(self, vis):
47 | B, S, Tv, C, H, W = vis.shape
48 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
49 | # feat extractors return a tuple of segment-level and global features (ignored for sync)
50 | # (B, S, tv, D), e.g. (B, 7, 8, 768)
51 | vis = self.vfeat_extractor(vis)
52 | return vis
53 |
54 | def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor):
55 | vis = self.vproj(vis)
56 | aud = self.aproj(aud)
57 |
58 | B, S, tv, D = vis.shape
59 | B, S, ta, D = aud.shape
60 | vis = vis.view(B, S * tv, D) # (B, S*tv, D)
61 | aud = aud.view(B, S * ta, D) # (B, S*ta, D)
62 | # print(vis.shape, aud.shape)
63 |
64 | # self.transformer will concatenate the vis and aud in one sequence with aux tokens,
65 | # ie `CvvvvMaaaaaa`, and will return the logits for the CLS tokens
66 | logits = self.transformer(vis, aud) # (B, cls); or (B, cls) and (B, 2) if DoubtingTransformer
67 |
68 | return logits
69 |
70 | def extract_vfeats(self, vis):
71 | B, S, Tv, C, H, W = vis.shape
72 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
73 | # feat extractors return a tuple of segment-level and global features (ignored for sync)
74 | # (B, S, tv, D), e.g. (B, 7, 8, 768)
75 | vis = self.vfeat_extractor(vis)
76 | return vis
77 |
78 | def extract_afeats(self, aud):
79 | B, S, _, Fa, Ta = aud.shape
80 | aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2) # (B, S, Ta, F)
81 | # (B, S, ta, D), e.g. (B, 7, 6, 768)
82 | aud, _ = self.afeat_extractor(aud)
83 | return aud
84 |
85 | def compute_loss(self, logits, targets, loss_fn: str = None):
86 | loss = None
87 | if targets is not None:
88 | if loss_fn is None or loss_fn == "cross_entropy":
89 | # logits: (B, cls) and targets: (B,)
90 | loss = F.cross_entropy(logits, targets)
91 | else:
92 | raise NotImplementedError(f"Loss {loss_fn} not implemented")
93 | return loss
94 |
95 | def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
96 | # discard all entries except vfeat_extractor
97 | # sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')}
98 |
99 | return super().load_state_dict(sd, strict)
100 |
101 |
102 | class RandInitPositionalEncoding(nn.Module):
103 | """Random inited trainable pos embedding. It is just applied on the sequence, thus respects no priors."""
104 |
105 | def __init__(self, block_shape: list, n_embd: int):
106 | super().__init__()
107 | self.block_shape = block_shape
108 | self.n_embd = n_embd
109 | self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd))
110 |
111 | def forward(self, token_embeddings):
112 | return token_embeddings + self.pos_emb
113 |
114 |
115 | class GlobalTransformer(torch.nn.Module):
116 | """Same as in SparseSync but without the selector transformers and the head"""
117 |
118 | def __init__(
119 | self,
120 | tok_pdrop=0.0,
121 | embd_pdrop=0.1,
122 | resid_pdrop=0.1,
123 | attn_pdrop=0.1,
124 | n_layer=3,
125 | n_head=8,
126 | n_embd=768,
127 | pos_emb_block_shape=[
128 | 198,
129 | ],
130 | n_off_head_out=21,
131 | ) -> None:
132 | super().__init__()
133 | self.config = Config(
134 | embd_pdrop=embd_pdrop,
135 | resid_pdrop=resid_pdrop,
136 | attn_pdrop=attn_pdrop,
137 | n_layer=n_layer,
138 | n_head=n_head,
139 | n_embd=n_embd,
140 | )
141 | # input norm
142 | self.vis_in_lnorm = torch.nn.LayerNorm(n_embd)
143 | self.aud_in_lnorm = torch.nn.LayerNorm(n_embd)
144 | # aux tokens
145 | self.OFF_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd))
146 | self.MOD_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd))
147 | # whole token dropout
148 | self.tok_pdrop = tok_pdrop
149 | self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop)
150 | self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop)
151 | # maybe add pos emb
152 | self.pos_emb_cfg = RandInitPositionalEncoding(
153 | block_shape=pos_emb_block_shape,
154 | n_embd=n_embd,
155 | )
156 | # the stem
157 | self.drop = torch.nn.Dropout(embd_pdrop)
158 | self.blocks = torch.nn.Sequential(*[Block(self.config) for _ in range(n_layer)])
159 | # pre-output norm
160 | self.ln_f = torch.nn.LayerNorm(n_embd)
161 | # maybe add a head
162 | self.off_head = torch.nn.Linear(in_features=n_embd, out_features=n_off_head_out)
163 |
164 | def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True):
165 | B, Sv, D = v.shape
166 | B, Sa, D = a.shape
167 | # broadcasting special tokens to the batch size
168 | off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B)
169 | mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B)
170 | # norm
171 | v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a)
172 | # maybe whole token dropout
173 | if self.tok_pdrop > 0:
174 | v, a = self.tok_drop_vis(v), self.tok_drop_aud(a)
175 | # (B, 1+Sv+1+Sa, D)
176 | x = torch.cat((off_tok, v, mod_tok, a), dim=1)
177 | # maybe add pos emb
178 | if hasattr(self, "pos_emb_cfg"):
179 | x = self.pos_emb_cfg(x)
180 | # dropout -> stem -> norm
181 | x = self.drop(x)
182 | x = self.blocks(x)
183 | x = self.ln_f(x)
184 | # maybe add heads
185 | if attempt_to_apply_heads and hasattr(self, "off_head"):
186 | x = self.off_head(x[:, 0, :])
187 | return x
188 |
189 |
190 | class SelfAttention(nn.Module):
191 | """
192 | A vanilla multi-head masked self-attention layer with a projection at the end.
193 | It is possible to use torch.nn.MultiheadAttention here but I am including an
194 | explicit implementation here to show that there is nothing too scary here.
195 | """
196 |
197 | def __init__(self, config):
198 | super().__init__()
199 | assert config.n_embd % config.n_head == 0
200 | # key, query, value projections for all heads
201 | self.key = nn.Linear(config.n_embd, config.n_embd)
202 | self.query = nn.Linear(config.n_embd, config.n_embd)
203 | self.value = nn.Linear(config.n_embd, config.n_embd)
204 | # regularization
205 | self.attn_drop = nn.Dropout(config.attn_pdrop)
206 | self.resid_drop = nn.Dropout(config.resid_pdrop)
207 | # output projection
208 | self.proj = nn.Linear(config.n_embd, config.n_embd)
209 | # # causal mask to ensure that attention is only applied to the left in the input sequence
210 | # mask = torch.tril(torch.ones(config.block_size,
211 | # config.block_size))
212 | # if hasattr(config, "n_unmasked"):
213 | # mask[:config.n_unmasked, :config.n_unmasked] = 1
214 | # self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
215 | self.n_head = config.n_head
216 |
217 | def forward(self, x):
218 | B, T, C = x.size()
219 |
220 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
221 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
222 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
223 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
224 |
225 | # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
226 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
227 | # att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
228 | att = F.softmax(att, dim=-1)
229 | y = self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
230 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
231 |
232 | # output projection
233 | y = self.resid_drop(self.proj(y))
234 |
235 | return y
236 |
237 |
238 | class Block(nn.Module):
239 | """an unassuming Transformer block"""
240 |
241 | def __init__(self, config):
242 | super().__init__()
243 | self.ln1 = nn.LayerNorm(config.n_embd)
244 | self.ln2 = nn.LayerNorm(config.n_embd)
245 | self.attn = SelfAttention(config)
246 | self.mlp = nn.Sequential(
247 | nn.Linear(config.n_embd, 4 * config.n_embd),
248 | nn.GELU(), # nice
249 | nn.Linear(4 * config.n_embd, config.n_embd),
250 | nn.Dropout(config.resid_pdrop),
251 | )
252 |
253 | def forward(self, x):
254 | x = x + self.attn(self.ln1(x))
255 | x = x + self.mlp(self.ln2(x))
256 | return x
257 |
258 |
259 | def make_class_grid(
260 | leftmost_val,
261 | rightmost_val,
262 | grid_size,
263 | add_extreme_offset: bool = False,
264 | seg_size_vframes: int = None,
265 | nseg: int = None,
266 | step_size_seg: float = None,
267 | vfps: float = None,
268 | ):
269 |     assert grid_size >= 3, f"grid_size: {grid_size} does not make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()"
270 | grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float()
271 | if add_extreme_offset:
272 | assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}"
273 | seg_size_sec = seg_size_vframes / vfps
274 | trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1)
275 | extreme_value = trim_size_in_seg * seg_size_sec
276 | grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid
277 | return grid
278 |
279 |
280 | # from synchformer
281 | def pad_or_truncate(audio: torch.Tensor, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0):
282 | difference = max_spec_t - audio.shape[-1] # safe for batched input
283 | # pad or truncate, depending on difference
284 | if difference > 0:
285 | # pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input
286 | pad_dims = (0, difference)
287 | audio = torch.nn.functional.pad(audio, pad_dims, pad_mode, pad_value)
288 | elif difference < 0:
289 | print(f"Truncating spec ({audio.shape}) to max_spec_t ({max_spec_t}).")
290 | audio = audio[..., :max_spec_t] # safe for batched input
291 | return audio
292 |
293 |
294 | def encode_audio_with_sync(
295 | synchformer: Synchformer, x: torch.Tensor, mel: torchaudio.transforms.MelSpectrogram
296 | ) -> torch.Tensor:
297 | b, t = x.shape
298 |
299 |     # partition the audio waveform into overlapping segments
300 |     segment_size = 10240  # 0.64 s of 16 kHz audio per segment
301 |     step_size = 10240 // 2  # 50% overlap between consecutive segments
302 |     num_segments = (t - segment_size) // step_size + 1
303 | segments = []
304 | for i in range(num_segments):
305 | segments.append(x[:, i * step_size : i * step_size + segment_size])
306 |     x = torch.stack(segments, dim=1)  # (B, S, segment_size) audio segments
307 |
308 | x = mel(x)
309 | x = torch.log(x + 1e-6)
310 | x = pad_or_truncate(x, 66)
311 |
312 |     mean = -4.2677393  # fixed log-mel normalization statistics for the sync feature extractor
313 |     std = 4.5689974
314 |     x = (x - mean) / (2 * std)  # note the extra factor of 2 applied to the std
315 | # x: B * S * 128 * 66
316 | x = synchformer.extract_afeats(x.unsqueeze(2))
317 | return x
318 |
319 |
320 | def read_audio(filename, expected_length=int(16000 * 4)):
321 | waveform, sr = torchaudio.load(filename)
322 | waveform = waveform.mean(dim=0)
323 |
324 | if sr != 16000:
325 | resampler = torchaudio.transforms.Resample(sr, 16000)
326 |         waveform = resampler(waveform)
327 |
328 | waveform = waveform[:expected_length]
329 | if waveform.shape[0] != expected_length:
330 | raise ValueError(f"Audio {filename} is too short")
331 |
332 | waveform = waveform.squeeze()
333 |
334 | return waveform
335 |
336 |
337 | if __name__ == "__main__":
338 |     import os  # needed for os.environ below (missing from the top-level imports)
339 |     synchformer = Synchformer().cuda().eval()
340 | # mmaudio provided synchformer ckpt
341 | synchformer.load_state_dict(
342 | torch.load(
343 |             os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"),
344 | weights_only=True,
345 | map_location="cpu",
346 | )
347 | )
348 |
349 | sync_mel_spectrogram = torchaudio.transforms.MelSpectrogram(
350 | sample_rate=16000,
351 | win_length=400,
352 | hop_length=160,
353 | n_fft=1024,
354 | n_mels=128,
355 | )
356 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/model/dac.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List
3 | from typing import Union
4 |
5 | import numpy as np
6 | import torch
7 | from audiotools import AudioSignal
8 | from audiotools.ml import BaseModel
9 | from torch import nn
10 |
11 | from .base import CodecMixin
12 | from ..nn.layers import Snake1d
13 | from ..nn.layers import WNConv1d
14 | from ..nn.layers import WNConvTranspose1d
15 | from ..nn.quantize import ResidualVectorQuantize
16 | from ..nn.vae_utils import DiagonalGaussianDistribution
17 |
18 |
19 | def init_weights(m):
20 | if isinstance(m, nn.Conv1d):
21 | nn.init.trunc_normal_(m.weight, std=0.02)
22 | nn.init.constant_(m.bias, 0)
23 |
24 |
25 | class ResidualUnit(nn.Module):
26 | def __init__(self, dim: int = 16, dilation: int = 1):
27 | super().__init__()
28 | pad = ((7 - 1) * dilation) // 2
29 | self.block = nn.Sequential(
30 | Snake1d(dim),
31 | WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
32 | Snake1d(dim),
33 | WNConv1d(dim, dim, kernel_size=1),
34 | )
35 |
36 | def forward(self, x):
37 | y = self.block(x)
38 | pad = (x.shape[-1] - y.shape[-1]) // 2
39 | if pad > 0:
40 | x = x[..., pad:-pad]
41 | return x + y
42 |
43 |
44 | class EncoderBlock(nn.Module):
45 | def __init__(self, dim: int = 16, stride: int = 1):
46 | super().__init__()
47 | self.block = nn.Sequential(
48 | ResidualUnit(dim // 2, dilation=1),
49 | ResidualUnit(dim // 2, dilation=3),
50 | ResidualUnit(dim // 2, dilation=9),
51 | Snake1d(dim // 2),
52 | WNConv1d(
53 | dim // 2,
54 | dim,
55 | kernel_size=2 * stride,
56 | stride=stride,
57 | padding=math.ceil(stride / 2),
58 | ),
59 | )
60 |
61 | def forward(self, x):
62 | return self.block(x)
63 |
64 |
65 | class Encoder(nn.Module):
66 | def __init__(
67 | self,
68 | d_model: int = 64,
69 | strides: list = [2, 4, 8, 8],
70 | d_latent: int = 64,
71 | ):
72 | super().__init__()
73 | # Create first convolution
74 | self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
75 |
76 | # Create EncoderBlocks that double channels as they downsample by `stride`
77 | for stride in strides:
78 | d_model *= 2
79 | self.block += [EncoderBlock(d_model, stride=stride)]
80 |
81 | # Create last convolution
82 | self.block += [
83 | Snake1d(d_model),
84 | WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
85 | ]
86 |
87 |         # Wrap the block list into nn.Sequential
88 | self.block = nn.Sequential(*self.block)
89 | self.enc_dim = d_model
90 |
91 | def forward(self, x):
92 | return self.block(x)
93 |
94 |
95 | class DecoderBlock(nn.Module):
96 | def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
97 | super().__init__()
98 | self.block = nn.Sequential(
99 | Snake1d(input_dim),
100 | WNConvTranspose1d(
101 | input_dim,
102 | output_dim,
103 | kernel_size=2 * stride,
104 | stride=stride,
105 | padding=math.ceil(stride / 2),
106 | output_padding=stride % 2,
107 | ),
108 | ResidualUnit(output_dim, dilation=1),
109 | ResidualUnit(output_dim, dilation=3),
110 | ResidualUnit(output_dim, dilation=9),
111 | )
112 |
113 | def forward(self, x):
114 | return self.block(x)
115 |
116 |
117 | class Decoder(nn.Module):
118 | def __init__(
119 | self,
120 | input_channel,
121 | channels,
122 | rates,
123 | d_out: int = 1,
124 | ):
125 | super().__init__()
126 |
127 | # Add first conv layer
128 | layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
129 |
130 | # Add upsampling + MRF blocks
131 | for i, stride in enumerate(rates):
132 | input_dim = channels // 2**i
133 | output_dim = channels // 2 ** (i + 1)
134 | layers += [DecoderBlock(input_dim, output_dim, stride)]
135 |
136 | # Add final conv layer
137 | layers += [
138 | Snake1d(output_dim),
139 | WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
140 | nn.Tanh(),
141 | ]
142 |
143 | self.model = nn.Sequential(*layers)
144 |
145 | def forward(self, x):
146 | return self.model(x)
147 |
148 |
149 | class DAC(BaseModel, CodecMixin):
150 | def __init__(
151 | self,
152 | encoder_dim: int = 64,
153 | encoder_rates: List[int] = [2, 4, 8, 8],
154 | latent_dim: int = None,
155 | decoder_dim: int = 1536,
156 | decoder_rates: List[int] = [8, 8, 4, 2],
157 | n_codebooks: int = 9,
158 | codebook_size: int = 1024,
159 | codebook_dim: Union[int, list] = 8,
160 | quantizer_dropout: bool = False,
161 | sample_rate: int = 44100,
162 | continuous: bool = False,
163 | ):
164 | super().__init__()
165 |
166 | self.encoder_dim = encoder_dim
167 | self.encoder_rates = encoder_rates
168 | self.decoder_dim = decoder_dim
169 | self.decoder_rates = decoder_rates
170 | self.sample_rate = sample_rate
171 | self.continuous = continuous
172 |
173 | if latent_dim is None:
174 | latent_dim = encoder_dim * (2 ** len(encoder_rates))
175 |
176 | self.latent_dim = latent_dim
177 |
178 | self.hop_length = np.prod(encoder_rates)
179 | self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
180 |
181 | if not continuous:
182 | self.n_codebooks = n_codebooks
183 | self.codebook_size = codebook_size
184 | self.codebook_dim = codebook_dim
185 | self.quantizer = ResidualVectorQuantize(
186 | input_dim=latent_dim,
187 | n_codebooks=n_codebooks,
188 | codebook_size=codebook_size,
189 | codebook_dim=codebook_dim,
190 | quantizer_dropout=quantizer_dropout,
191 | )
192 | else:
193 | self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
194 | self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
195 |
196 | self.decoder = Decoder(
197 | latent_dim,
198 | decoder_dim,
199 | decoder_rates,
200 | )
201 | self.sample_rate = sample_rate
202 | self.apply(init_weights)
203 |
204 | self.delay = self.get_delay()
205 |
206 | @property
207 | def dtype(self):
208 | """Get the dtype of the model parameters."""
209 | # Return the dtype of the first parameter found
210 | for param in self.parameters():
211 | return param.dtype
212 | return torch.float32 # fallback
213 |
214 | @property
215 | def device(self):
216 | """Get the device of the model parameters."""
217 | # Return the device of the first parameter found
218 | for param in self.parameters():
219 | return param.device
220 | return torch.device('cpu') # fallback
221 |
222 | def preprocess(self, audio_data, sample_rate):
223 | if sample_rate is None:
224 | sample_rate = self.sample_rate
225 | assert sample_rate == self.sample_rate
226 |
227 | length = audio_data.shape[-1]
228 | right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
229 | audio_data = nn.functional.pad(audio_data, (0, right_pad))
230 |
231 | return audio_data
232 |
233 | def encode(
234 | self,
235 | audio_data: torch.Tensor,
236 | n_quantizers: int = None,
237 | ):
238 | """Encode given audio data and return quantized latent codes
239 |
240 | Parameters
241 | ----------
242 | audio_data : Tensor[B x 1 x T]
243 | Audio data to encode
244 | n_quantizers : int, optional
245 | Number of quantizers to use, by default None
246 | If None, all quantizers are used.
247 |
248 | Returns
249 | -------
250 |         tuple
251 |             A tuple of (z, codes, latents, commitment_loss, codebook_loss):
252 | "z" : Tensor[B x D x T]
253 | Quantized continuous representation of input
254 | "codes" : Tensor[B x N x T]
255 | Codebook indices for each codebook
256 | (quantized discrete representation of input)
257 | "latents" : Tensor[B x N*D x T]
258 | Projected latents (continuous representation of input before quantization)
259 | "vq/commitment_loss" : Tensor[1]
260 | Commitment loss to train encoder to predict vectors closer to codebook
261 | entries
262 | "vq/codebook_loss" : Tensor[1]
263 | Codebook loss to update the codebook
264 | "length" : int
265 | Number of samples in input audio
266 | """
267 | z = self.encoder(audio_data) # [B x D x T]
268 | if not self.continuous:
269 | z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
270 | else:
271 | z = self.quant_conv(z) # [B x 2D x T]
272 | z = DiagonalGaussianDistribution(z)
273 | codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
274 |
275 | return z, codes, latents, commitment_loss, codebook_loss
276 |
277 | def decode(self, z: torch.Tensor):
278 | """Decode given latent codes and return audio data
279 |
280 | Parameters
281 | ----------
282 | z : Tensor[B x D x T]
283 | Quantized continuous representation of input
286 |
287 | Returns
288 | -------
289 |         Tensor[B x 1 x length]
290 |             Decoded audio data.
293 | """
294 | if not self.continuous:
295 | audio = self.decoder(z)
296 | else:
297 | z = self.post_quant_conv(z)
298 | audio = self.decoder(z)
299 |
300 | return audio
301 |
302 | def forward(
303 | self,
304 | audio_data: torch.Tensor,
305 | sample_rate: int = None,
306 | n_quantizers: int = None,
307 | ):
308 | """Model forward pass
309 |
310 | Parameters
311 | ----------
312 | audio_data : Tensor[B x 1 x T]
313 | Audio data to encode
314 | sample_rate : int, optional
315 | Sample rate of audio data in Hz, by default None
316 | If None, defaults to `self.sample_rate`
317 | n_quantizers : int, optional
318 | Number of quantizers to use, by default None.
319 | If None, all quantizers are used.
320 |
321 | Returns
322 | -------
323 | dict
324 | A dictionary with the following keys:
325 | "z" : Tensor[B x D x T]
326 | Quantized continuous representation of input
327 | "codes" : Tensor[B x N x T]
328 | Codebook indices for each codebook
329 | (quantized discrete representation of input)
330 | "latents" : Tensor[B x N*D x T]
331 | Projected latents (continuous representation of input before quantization)
332 | "vq/commitment_loss" : Tensor[1]
333 | Commitment loss to train encoder to predict vectors closer to codebook
334 | entries
335 | "vq/codebook_loss" : Tensor[1]
336 | Codebook loss to update the codebook
337 | "length" : int
338 | Number of samples in input audio
339 | "audio" : Tensor[B x 1 x length]
340 | Decoded audio data.
341 | """
342 | length = audio_data.shape[-1]
343 | audio_data = self.preprocess(audio_data, sample_rate)
344 | if not self.continuous:
345 | z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
346 |
347 | x = self.decode(z)
348 | return {
349 | "audio": x[..., :length],
350 | "z": z,
351 | "codes": codes,
352 | "latents": latents,
353 | "vq/commitment_loss": commitment_loss,
354 | "vq/codebook_loss": codebook_loss,
355 | }
356 | else:
357 | posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
358 | z = posterior.sample()
359 | x = self.decode(z)
360 |
361 | kl_loss = posterior.kl()
362 | kl_loss = kl_loss.mean()
363 |
364 | return {
365 | "audio": x[..., :length],
366 | "z": z,
367 | "kl_loss": kl_loss,
368 | }
369 |
370 |
371 | if __name__ == "__main__":
372 | import numpy as np
373 | from functools import partial
374 |
375 | model = DAC().to("cpu")
376 |
377 | for n, m in model.named_modules():
378 | o = m.extra_repr()
379 | p = sum([np.prod(p.size()) for p in m.parameters()])
380 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
381 | setattr(m, "extra_repr", partial(fn, o=o, p=p))
382 | print(model)
383 | print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
384 |
385 | length = 88200 * 2
386 | x = torch.randn(1, 1, length).to(model.device)
387 | x.requires_grad_(True)
388 | x.retain_grad()
389 |
390 | # Make a forward pass
391 | out = model(x)["audio"]
392 | print("Input shape:", x.shape)
393 | print("Output shape:", out.shape)
394 |
395 | # Create gradient variable
396 | grad = torch.zeros_like(out)
397 | grad[:, :, grad.shape[-1] // 2] = 1
398 |
399 | # Make a backward pass
400 | out.backward(grad)
401 |
402 | # Check non-zero values
403 | gradmap = x.grad.squeeze(0)
404 | gradmap = (gradmap != 0).sum(0) # sum across features
405 | rf = (gradmap != 0).sum()
406 |
407 | print(f"Receptive field: {rf.item()}")
408 |
409 | x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
410 | model.decompress(model.compress(x, verbose=True), verbose=True)
411 |
--------------------------------------------------------------------------------