├── hunyuanvideo_foley
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── activation_layers.py
│   │   │   ├── modulate_layers.py
│   │   │   ├── norm_layers.py
│   │   │   ├── embed_layers.py
│   │   │   ├── mlp_layers.py
│   │   │   └── posemb_layers.py
│   │   ├── synchformer
│   │   │   ├── __init__.py
│   │   │   ├── divided_224_16x4.yaml
│   │   │   ├── utils.py
│   │   │   ├── compute_desync_score.py
│   │   │   ├── video_model_builder.py
│   │   │   └── synchformer.py
│   │   └── dac_vae
│   │       ├── nn
│   │       │   ├── __init__.py
│   │       │   ├── layers.py
│   │       │   ├── vae_utils.py
│   │       │   ├── quantize.py
│   │       │   └── loss.py
│   │       ├── model
│   │       │   ├── __init__.py
│   │       │   ├── discriminator.py
│   │       │   ├── base.py
│   │       │   └── dac.py
│   │       ├── __init__.py
│   │       ├── __main__.py
│   │       └── utils
│   │           ├── decode.py
│   │           ├── encode.py
│   │           └── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── schedulers
│   │   │   └── __init__.py
│   │   ├── media_utils.py
│   │   ├── config_utils.py
│   │   ├── helper.py
│   │   ├── feature_utils.py
│   │   └── model_utils.py
│   └── constants.py
├── __init__.py
├── requirements.txt
├── configs
│   └── hunyuanvideo-foley-xxl.yaml
├── model_urls.py
├── install.py
├── download_models_manual.py
├── README.md
├── INSTALLATION_GUIDE.md
├── test_node.py
├── utils.py
└── example_workflows
    └── HunyuanVideo-Foley.json
/hunyuanvideo_foley/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/nn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .synchformer import Synchformer
2 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from . import layers
2 | from . import loss
3 | from . import quantize
4 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import CodecMixin
2 | from .base import DACFile
3 | from .dac import DAC
4 | from .discriminator import Discriminator
5 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from diffusers.schedulers import DDPMScheduler, EulerDiscreteScheduler
2 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0"
2 |
3 | # preserved here for legacy reasons
4 | __model_version__ = "latest"
5 |
6 | import audiotools
7 |
8 | audiotools.ml.BaseModel.INTERN += ["dac.**"]
9 | audiotools.ml.BaseModel.EXTERN += ["einops"]
10 |
11 |
12 | from . import nn
13 | from . import model
14 | from . import utils
15 | from .model import DAC
16 | from .model import DACFile
17 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # Add the current directory to Python path to import hunyuanvideo_foley modules
5 | current_dir = os.path.dirname(os.path.abspath(__file__))
6 | if current_dir not in sys.path:
7 | sys.path.insert(0, current_dir)
8 |
9 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
10 |
11 | # Export the mappings
12 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Core ML dependencies
2 | #torch>=2.0.0
3 | #torchvision>=0.15.0
4 | #torchaudio>=2.0.0
5 | numpy>=1.26.4
6 | scipy
7 |
8 | # Deep Learning frameworks
9 | diffusers
10 | timm
11 | accelerate
12 |
13 | # Transformers and NLP
14 | transformers>=4.37.0
15 | sentencepiece
16 | huggingface_hub
17 |
18 | # Audio processing
19 | git+https://github.com/descriptinc/audiotools
20 |
21 | # Video/Image processing
22 | pillow
23 | av
24 | einops
25 |
26 | # Configuration and utilities
27 | pyyaml
28 | omegaconf
29 | easydict
30 | loguru
31 | tqdm
32 | setuptools
33 |
34 | # Data handling
35 | pandas
36 | pyarrow
37 |
38 | # Network
39 | urllib3==2.4.0
40 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/__main__.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import argbind
4 |
5 | from .utils import download
6 | from .utils.decode import decode
7 | from .utils.encode import encode
8 |
9 | STAGES = ["encode", "decode", "download"]
10 |
11 |
12 | def run(stage: str):
13 | """Run stages.
14 |
15 | Parameters
16 | ----------
17 | stage : str
18 | Stage to run
19 | """
20 | if stage not in STAGES:
21 | raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22 | stage_fn = globals()[stage]
23 |
24 | if stage == "download":
25 | stage_fn()
26 | return
27 |
28 | stage_fn()
29 |
30 |
31 | if __name__ == "__main__":
32 | group = sys.argv.pop(1)
33 | args = argbind.parse_args(group=group)
34 |
35 | with argbind.scope(args):
36 | run(group)
37 |
--------------------------------------------------------------------------------
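Editor's note: a minimal usage sketch (not part of the repository) of the encode/decode helpers this CLI wraps, calling them directly from Python rather than through the argbind command line. The directories are placeholders, and pointing weights_path at the repo's vae_128d_48k.pth assumes that checkpoint is loadable through DAC.load.

# Sketch only: calling the DAC-VAE helpers directly (paths and weights are placeholders/assumptions).
from hunyuanvideo_foley.models.dac_vae.utils.encode import encode
from hunyuanvideo_foley.models.dac_vae.utils.decode import decode

# Compress a folder of audio files to .dac artifacts, then reconstruct them as .wav.
encode(input="samples/", output="codes/", weights_path="vae_128d_48k.pth", device="cuda")
decode(input="codes/", output="recons/", weights_path="vae_128d_48k.pth", device="cuda")
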
/hunyuanvideo_foley/models/dac_vae/nn/layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 | from torch.nn.utils import weight_norm
7 |
8 |
9 | def WNConv1d(*args, **kwargs):
10 | return weight_norm(nn.Conv1d(*args, **kwargs))
11 |
12 |
13 | def WNConvTranspose1d(*args, **kwargs):
14 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15 |
16 |
17 | # Scripting this brings model speed up 1.4x
18 | @torch.jit.script
19 | def snake(x, alpha):
20 | shape = x.shape
21 | x = x.reshape(shape[0], shape[1], -1)
22 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23 | x = x.reshape(shape)
24 | return x
25 |
26 |
27 | class Snake1d(nn.Module):
28 | def __init__(self, channels):
29 | super().__init__()
30 | self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31 |
32 | def forward(self, x):
33 | return snake(x, self.alpha)
34 |
--------------------------------------------------------------------------------
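Editor's note: a quick shape sketch (not from the repo) of how Snake1d and the weight-normalized conv wrappers compose; the tensor sizes are illustrative only.

# Illustrative only: Snake1d is an element-wise activation over (B, C, T) tensors,
# and WNConv1d is a drop-in weight-normalized nn.Conv1d.
import torch

x = torch.randn(2, 64, 16000)                    # (batch, channels, time)
act = Snake1d(channels=64)
conv = WNConv1d(64, 128, kernel_size=7, padding=3)
y = conv(act(x))                                 # -> (2, 128, 16000)
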
/hunyuanvideo_foley/models/nn/activation_layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | def get_activation_layer(act_type):
5 | if act_type == "gelu":
6 | return lambda: nn.GELU()
7 | elif act_type == "gelu_tanh":
8 | # Approximate `tanh` requires torch >= 1.13
9 | return lambda: nn.GELU(approximate="tanh")
10 | elif act_type == "relu":
11 | return nn.ReLU
12 | elif act_type == "silu":
13 | return nn.SiLU
14 | else:
15 | raise ValueError(f"Unknown activation type: {act_type}")
16 |
17 | class SwiGLU(nn.Module):
18 | def __init__(
19 | self,
20 | dim: int,
21 | hidden_dim: int,
22 | out_dim: int,
23 | ):
24 | """
25 | Initialize the SwiGLU FeedForward module.
26 |
27 | Args:
28 | dim (int): Input dimension.
 29 |             hidden_dim (int): Hidden dimension of the feedforward layer.
 30 |             out_dim (int): Output dimension of the feedforward layer.
31 | Attributes:
32 | w1: Linear transformation for the first layer.
33 | w2: Linear transformation for the second layer.
34 | w3: Linear transformation for the third layer.
35 |
36 | """
37 | super().__init__()
38 |
39 | self.w1 = nn.Linear(dim, hidden_dim, bias=False)
40 | self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
41 | self.w3 = nn.Linear(dim, hidden_dim, bias=False)
42 |
43 | def forward(self, x):
44 | return self.w2(F.silu(self.w1(x)) * self.w3(x))
45 |
--------------------------------------------------------------------------------
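Editor's note: a small usage sketch (not in the source). get_activation_layer returns a callable that constructs the activation, so it is instantiated with `()`; SwiGLU maps dim -> out_dim through a gated hidden layer. The 1536/77 sizes mirror the XXL config quoted later; mlp_ratio 4 for hidden_dim is likewise taken from that config.

# Usage sketch: get_activation_layer returns a factory, so call it to build the module.
import torch

act = get_activation_layer("gelu_tanh")()        # nn.GELU(approximate="tanh")
ffn = SwiGLU(dim=1536, hidden_dim=4 * 1536, out_dim=1536)
y = ffn(torch.randn(2, 77, 1536))                # -> (2, 77, 1536)
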
/configs/hunyuanvideo-foley-xxl.yaml:
--------------------------------------------------------------------------------
1 | model_config:
2 | model_name: HunyuanVideo-Foley-XXL
3 | model_type: 1d
4 | model_precision: bf16
5 | model_kwargs:
6 | depth_triple_blocks: 18
7 | depth_single_blocks: 36
8 | hidden_size: 1536
9 | num_heads: 12
10 | mlp_ratio: 4
11 | mlp_act_type: "gelu_tanh"
12 | qkv_bias: True
13 | qk_norm: True
14 | qk_norm_type: "rms"
15 | attn_mode: "torch"
16 | embedder_type: "default"
17 | interleaved_audio_visual_rope: True
18 | enable_learnable_empty_visual_feat: True
19 | sync_modulation: False
20 | add_sync_feat_to_audio: True
21 | cross_attention: True
22 | use_attention_mask: False
23 | condition_projection: "linear"
24 | sync_feat_dim: 768 # syncformer 768 dim
25 | condition_dim: 768 # clap 768 text condition dim (clip-text)
26 | clip_dim: 768 # siglip2 visual dim
27 | audio_vae_latent_dim: 128
28 | audio_frame_rate: 50
29 | patch_size: 1
30 | rope_dim_list: null
31 | rope_theta: 10000
32 | text_length: 77
33 | clip_length: 64
34 | sync_length: 192
35 | use_mmaudio_singleblock: True
36 | depth_triple_ssl_encoder: null
37 | depth_single_ssl_encoder: 8
38 | use_repa_with_audiossl: True
39 |
40 | diffusion_config:
41 | denoise_type: "flow"
42 | flow_path_type: "linear"
43 | flow_predict_type: "velocity"
44 | flow_reverse: True
45 | flow_solver: "euler"
46 | sample_flow_shift: 1.0
47 | sample_use_flux_shift: False
48 | flux_base_shift: 0.5
49 | flux_max_shift: 1.15
50 |
--------------------------------------------------------------------------------
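Editor's note: this config is consumed through the AttributeDict loader defined in hunyuanvideo_foley/utils/config_utils.py (shown later in this dump). A minimal reading sketch, assuming it is run from the node's root directory:

# Sketch: reading the XXL config with the project's own YAML loader.
from hunyuanvideo_foley.utils.config_utils import load_yaml

cfg = load_yaml("configs/hunyuanvideo-foley-xxl.yaml")
print(cfg.model_config.model_name)                   # "HunyuanVideo-Foley-XXL"
print(cfg.model_config.model_kwargs.hidden_size)     # 1536
print(cfg.diffusion_config.denoise_type)             # "flow"
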
/model_urls.py:
--------------------------------------------------------------------------------
1 | # HunyuanVideo-Foley Model URLs Configuration
2 | # Update these URLs with the actual download links for the models
3 |
4 | MODEL_URLS = {
5 | "hunyuanvideo-foley-xxl": {
6 | "models": [
7 | {
8 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/hunyuanvideo_foley.pth",
9 | "filename": "hunyuanvideo_foley.pth",
10 | "description": "Main HunyuanVideo-Foley model"
11 | },
12 | {
13 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/synchformer_state_dict.pth",
14 | "filename": "synchformer_state_dict.pth",
15 | "description": "Synchformer model weights"
16 | },
17 | {
18 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/vae_128d_48k.pth",
19 | "filename": "vae_128d_48k.pth",
20 | "description": "VAE model weights"
21 | }
22 | ],
23 | "extracted_dir": "hunyuanvideo-foley-xxl",
24 | "description": "HunyuanVideo-Foley XXL model for audio generation"
25 | }
26 | }
27 |
28 | # Alternative mirror URLs (if main URLs fail)
29 | MIRROR_URLS = {
30 | # Add mirror download sources here if needed
31 | }
32 |
33 | def get_model_url(model_name: str, use_mirror: bool = False) -> dict:
34 | """Get model URL configuration"""
35 | urls_dict = MIRROR_URLS if use_mirror else MODEL_URLS
36 | return urls_dict.get(model_name, {})
37 |
38 | def list_available_models() -> list:
39 | """List all available model names"""
40 | return list(MODEL_URLS.keys())
--------------------------------------------------------------------------------
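Editor's note: a short illustration (not in the source) of how these helpers are queried; the download loop that consumes them lives elsewhere in the node.

# Illustration: listing models and iterating over their download entries.
print(list_available_models())                       # ['hunyuanvideo-foley-xxl']

cfg = get_model_url("hunyuanvideo-foley-xxl")
for entry in cfg.get("models", []):
    print(entry["filename"], "->", entry["url"])
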
/hunyuanvideo_foley/models/nn/modulate_layers.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 | import torch
3 | import torch.nn as nn
4 |
5 | class ModulateDiT(nn.Module):
6 | def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None):
7 | factory_kwargs = {"dtype": dtype, "device": device}
8 | super().__init__()
9 | self.act = act_layer()
10 | self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
11 | # Zero-initialize the modulation
12 | nn.init.zeros_(self.linear.weight)
13 | nn.init.zeros_(self.linear.bias)
14 |
15 | def forward(self, x: torch.Tensor) -> torch.Tensor:
16 | return self.linear(self.act(x))
17 |
18 |
19 | def modulate(x, shift=None, scale=None):
20 | if x.ndim == 3:
21 | shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None
22 | scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None
23 | if scale is None and shift is None:
24 | return x
25 | elif shift is None:
26 | return x * (1 + scale)
27 | elif scale is None:
28 | return x + shift
29 | else:
30 | return x * (1 + scale) + shift
31 |
32 |
33 | def apply_gate(x, gate=None, tanh=False):
34 | if gate is None:
35 | return x
36 | if gate.ndim == 2 and x.ndim == 3:
37 | gate = gate.unsqueeze(1)
38 | if tanh:
39 | return x * gate.tanh()
40 | else:
41 | return x * gate
42 |
43 |
44 | def ckpt_wrapper(module):
45 | def ckpt_forward(*inputs):
46 | outputs = module(*inputs)
47 | return outputs
48 |
49 | return ckpt_forward
50 |
--------------------------------------------------------------------------------
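Editor's note: a shape-level sketch (not from the repo) of how these AdaLN-style helpers combine. The 3-way shift/scale/gate chunking and factor=3 are assumptions based on common DiT practice; this file only defines the helpers themselves. hidden_size 1536 follows the XXL config.

# Sketch (assumption: shift/scale/gate come from chunking the ModulateDiT output).
import torch
import torch.nn as nn

mod = ModulateDiT(hidden_size=1536, factor=3, act_layer=nn.SiLU)
cond = torch.randn(2, 1536)                          # per-sample conditioning vector
shift, scale, gate = mod(cond).chunk(3, dim=-1)      # each (2, 1536)

x = torch.randn(2, 128, 1536)                        # (batch, tokens, hidden)
x = modulate(x, shift=shift, scale=scale)            # x * (1 + scale) + shift, broadcast over tokens
x = apply_gate(x, gate=gate)                         # gated residual contribution
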
/hunyuanvideo_foley/constants.py:
--------------------------------------------------------------------------------
1 | """Constants used throughout the HunyuanVideo-Foley project."""
2 |
3 | from typing import Dict, List
4 |
5 | # Model configuration
6 | DEFAULT_AUDIO_SAMPLE_RATE = 48000
7 | DEFAULT_VIDEO_FPS = 25
8 | DEFAULT_AUDIO_CHANNELS = 2
9 |
10 | # Video processing
11 | MAX_VIDEO_DURATION_SECONDS = 15.0
12 | MIN_VIDEO_DURATION_SECONDS = 1.0
13 |
14 | # Audio processing
15 | AUDIO_VAE_LATENT_DIM = 128
16 | AUDIO_FRAME_RATE = 75 # frames per second in latent space
17 |
18 | # Visual features
19 | FPS_VISUAL: Dict[str, int] = {
20 | "siglip2": 8,
21 | "synchformer": 25
22 | }
23 |
24 | # Model paths (can be overridden by environment variables)
25 | DEFAULT_MODEL_PATH = "./pretrained_models/"
26 | DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"
27 |
28 | # Inference parameters
29 | DEFAULT_GUIDANCE_SCALE = 4.5
30 | DEFAULT_NUM_INFERENCE_STEPS = 50
31 | MIN_GUIDANCE_SCALE = 1.0
32 | MAX_GUIDANCE_SCALE = 10.0
33 | MIN_INFERENCE_STEPS = 10
34 | MAX_INFERENCE_STEPS = 100
35 |
36 | # Text processing
37 | MAX_TEXT_LENGTH = 100
38 | DEFAULT_NEGATIVE_PROMPT = "noisy, harsh"
39 |
40 | # File extensions
41 | SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"]
42 | SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"]
43 |
44 | # Quality settings
45 | AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = {
46 | "high": ["-b:a", "192k"],
47 | "medium": ["-b:a", "128k"],
48 | "low": ["-b:a", "96k"]
49 | }
50 |
51 | # Error messages
52 | ERROR_MESSAGES: Dict[str, str] = {
53 | "model_not_loaded": "Model is not loaded. Please load the model first.",
54 | "invalid_video_format": "Unsupported video format. Supported formats: {formats}",
55 | "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds",
56 | "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html"
57 | }
--------------------------------------------------------------------------------
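Editor's note: a brief sketch of how the templated error string is presumably filled in; the `.format()` call is an assumption consistent with the `{formats}` placeholder above.

# Sketch: filling in the templated error message (assumed usage of the {formats} placeholder).
msg = ERROR_MESSAGES["invalid_video_format"].format(
    formats=", ".join(SUPPORTED_VIDEO_EXTENSIONS)
)
print(msg)   # "Unsupported video format. Supported formats: .mp4, .avi, .mov, .mkv, .webm"
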
/hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: True
3 | DATASET: Ssv2
4 | BATCH_SIZE: 32
5 | EVAL_PERIOD: 5
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | CHECKPOINT_EPOCH_RESET: True
9 | CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth
10 | DATA:
11 | NUM_FRAMES: 16
12 | SAMPLING_RATE: 4
13 | TRAIN_JITTER_SCALES: [256, 320]
14 | TRAIN_CROP_SIZE: 224
15 | TEST_CROP_SIZE: 224
16 | INPUT_CHANNEL_NUM: [3]
17 | MEAN: [0.5, 0.5, 0.5]
18 | STD: [0.5, 0.5, 0.5]
19 | PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2
20 | PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames
21 | INV_UNIFORM_SAMPLE: True
22 | RANDOM_FLIP: False
23 | REVERSE_INPUT_CHANNEL: True
24 | USE_RAND_AUGMENT: True
25 | RE_PROB: 0.0
26 | USE_REPEATED_AUG: False
27 | USE_RANDOM_RESIZE_CROPS: False
28 | COLORJITTER: False
29 | GRAYSCALE: False
30 | GAUSSIAN: False
31 | SOLVER:
32 | BASE_LR: 1e-4
33 | LR_POLICY: steps_with_relative_lrs
34 | LRS: [1, 0.1, 0.01]
35 | STEPS: [0, 20, 30]
36 | MAX_EPOCH: 35
37 | MOMENTUM: 0.9
38 | WEIGHT_DECAY: 5e-2
39 | WARMUP_EPOCHS: 0.0
40 | OPTIMIZING_METHOD: adamw
41 | USE_MIXED_PRECISION: True
42 | SMOOTHING: 0.2
43 | SLOWFAST:
44 | ALPHA: 8
45 | VIT:
46 | PATCH_SIZE: 16
47 | PATCH_SIZE_TEMP: 2
48 | CHANNELS: 3
49 | EMBED_DIM: 768
50 | DEPTH: 12
51 | NUM_HEADS: 12
52 | MLP_RATIO: 4
53 | QKV_BIAS: True
54 | VIDEO_INPUT: True
55 | TEMPORAL_RESOLUTION: 8
56 | USE_MLP: True
57 | DROP: 0.0
58 | POS_DROPOUT: 0.0
59 | DROP_PATH: 0.2
60 | IM_PRETRAINED: True
61 | HEAD_DROPOUT: 0.0
62 | HEAD_ACT: tanh
63 | PRETRAINED_WEIGHTS: vit_1k
64 | ATTN_LAYER: divided
65 | MODEL:
66 | NUM_CLASSES: 174
67 | ARCH: slow
68 | MODEL_NAME: VisionTransformer
69 | LOSS_FUNC: cross_entropy
70 | TEST:
71 | ENABLE: True
72 | DATASET: Ssv2
73 | BATCH_SIZE: 64
74 | NUM_ENSEMBLE_VIEWS: 1
75 | NUM_SPATIAL_CROPS: 3
76 | DATA_LOADER:
77 | NUM_WORKERS: 4
78 | PIN_MEMORY: True
79 | NUM_GPUS: 8
80 | NUM_SHARDS: 4
81 | RNG_SEED: 0
82 | OUTPUT_DIR: .
83 | TENSORBOARD:
84 | ENABLE: True
85 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/nn/norm_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class RMSNorm(nn.Module):
5 | def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6,
6 | device=None, dtype=None):
7 | """
8 | Initialize the RMSNorm normalization layer.
9 |
10 | Args:
11 | dim (int): The dimension of the input tensor.
12 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
13 |
14 | Attributes:
15 | eps (float): A small value added to the denominator for numerical stability.
16 | weight (nn.Parameter): Learnable scaling parameter.
17 |
18 | """
19 | factory_kwargs = {'device': device, 'dtype': dtype}
20 | super().__init__()
21 | self.eps = eps
22 | if elementwise_affine:
23 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
24 |
25 | def _norm(self, x):
26 | """
27 | Apply the RMSNorm normalization to the input tensor.
28 |
29 | Args:
30 | x (torch.Tensor): The input tensor.
31 |
32 | Returns:
33 | torch.Tensor: The normalized tensor.
34 |
35 | """
36 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
37 |
38 | def forward(self, x):
39 | """
40 | Forward pass through the RMSNorm layer.
41 |
42 | Args:
43 | x (torch.Tensor): The input tensor.
44 |
45 | Returns:
46 | torch.Tensor: The output tensor after applying RMSNorm.
47 |
48 | """
49 | output = self._norm(x.float()).type_as(x)
50 | if hasattr(self, "weight"):
51 | output = output * self.weight
52 | return output
53 |
54 |
55 | def get_norm_layer(norm_layer):
56 | """
57 | Get the normalization layer.
58 |
59 | Args:
60 | norm_layer (str): The type of normalization layer.
61 |
62 | Returns:
63 | norm_layer (nn.Module): The normalization layer.
64 | """
65 | if norm_layer == "layer":
66 | return nn.LayerNorm
67 | elif norm_layer == "rms":
68 | return RMSNorm
69 | else:
70 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
71 |
--------------------------------------------------------------------------------
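Editor's note: a minimal usage sketch (not in the source) of resolving the norm class by name and applying RMSNorm; the per-head sizing (1536 / 12 heads = 128) is taken from the XXL config, and the qk-norm usage is illustrative.

# Sketch: resolving the norm class by name and applying it over the last dimension.
import torch

norm_cls = get_norm_layer("rms")                     # -> RMSNorm
qk_norm = norm_cls(128, eps=1e-6)                    # e.g. per-head dim = hidden_size / num_heads
q = torch.randn(2, 12, 64, 128)                      # (batch, heads, tokens, head_dim)
q = qk_norm(q)                                       # RMS-normalized over head_dim
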
/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.0])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.mean(
45 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
46 | dim=[1, 2],
47 | )
48 | else:
49 | return 0.5 * torch.mean(
50 | torch.pow(self.mean - other.mean, 2) / other.var
51 | + self.var / other.var
52 | - 1.0
53 | - self.logvar
54 | + other.logvar,
55 | dim=[1, 2],
56 | )
57 |
58 | def nll(self, sample, dims=[1, 2]):
59 | if self.deterministic:
60 | return torch.Tensor([0.0])
61 | logtwopi = np.log(2.0 * np.pi)
62 | return 0.5 * torch.sum(
63 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
64 | dim=dims,
65 | )
66 |
67 | def mode(self):
68 | return self.mean
69 |
70 |
71 | def normal_kl(mean1, logvar1, mean2, logvar2):
72 | """
73 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
74 | Compute the KL divergence between two gaussians.
75 | Shapes are automatically broadcasted, so batches can be compared to
76 | scalars, among other use cases.
77 | """
78 | tensor = None
79 | for obj in (mean1, logvar1, mean2, logvar2):
80 | if isinstance(obj, torch.Tensor):
81 | tensor = obj
82 | break
83 | assert tensor is not None, "at least one argument must be a Tensor"
84 |
85 | # Force variances to be Tensors. Broadcasting helps convert scalars to
86 | # Tensors, but it does not work for torch.exp().
87 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
88 |
89 | return 0.5 * (
90 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
91 | )
92 |
--------------------------------------------------------------------------------
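Editor's note: a shape sketch (not from the repo). The distribution expects mean and log-variance stacked along dim=1, so a 128-dim latent takes a 256-channel parameter tensor; the frame count is illustrative.

# Sketch: DiagonalGaussianDistribution splits a (B, 2*D, T) parameter tensor into mean/logvar.
import torch

params = torch.randn(2, 2 * 128, 250)                # 128-dim audio latent, 250 latent frames (illustrative)
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()                               # (2, 128, 250)
kl = posterior.kl()                                  # per-sample KL against a standard normal
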
/hunyuanvideo_foley/models/dac_vae/utils/decode.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from pathlib import Path
3 |
4 | import argbind
5 | import numpy as np
6 | import torch
7 | from audiotools import AudioSignal
8 | from tqdm import tqdm
9 |
10 | from ..model import DACFile
11 | from . import load_model
12 |
13 | warnings.filterwarnings("ignore", category=UserWarning)
14 |
15 |
16 | @argbind.bind(group="decode", positional=True, without_prefix=True)
17 | @torch.inference_mode()
18 | @torch.no_grad()
19 | def decode(
20 | input: str,
21 | output: str = "",
22 | weights_path: str = "",
23 | model_tag: str = "latest",
24 | model_bitrate: str = "8kbps",
25 | device: str = "cuda",
26 | model_type: str = "44khz",
27 | verbose: bool = False,
28 | ):
29 | """Decode audio from codes.
30 |
31 | Parameters
32 | ----------
33 | input : str
34 | Path to input directory or file
35 | output : str, optional
36 | Path to output directory, by default "".
37 | If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
38 | weights_path : str, optional
39 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
40 | model_tag and model_type.
41 | model_tag : str, optional
42 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
43 | model_bitrate: str
44 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
45 | device : str, optional
46 | Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
47 | model_type : str, optional
48 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
49 | """
50 | generator = load_model(
51 | model_type=model_type,
52 | model_bitrate=model_bitrate,
53 | tag=model_tag,
54 | load_path=weights_path,
55 | )
56 | generator.to(device)
57 | generator.eval()
58 |
59 | # Find all .dac files in input directory
60 | _input = Path(input)
61 | input_files = list(_input.glob("**/*.dac"))
62 |
63 | # If input is a .dac file, add it to the list
64 | if _input.suffix == ".dac":
65 | input_files.append(_input)
66 |
67 | # Create output directory
68 | output = Path(output)
69 | output.mkdir(parents=True, exist_ok=True)
70 |
 71 |     for i in tqdm(range(len(input_files)), desc="Decoding files"):
72 | # Load file
73 | artifact = DACFile.load(input_files[i])
74 |
75 | # Reconstruct audio from codes
76 | recons = generator.decompress(artifact, verbose=verbose)
77 |
78 | # Compute output path
79 | relative_path = input_files[i].relative_to(input)
80 | output_dir = output / relative_path.parent
81 | if not relative_path.name:
82 | output_dir = output
83 | relative_path = input_files[i]
84 | output_name = relative_path.with_suffix(".wav").name
85 | output_path = output_dir / output_name
86 | output_path.parent.mkdir(parents=True, exist_ok=True)
87 |
88 | # Write to file
89 | recons.write(output_path)
90 |
91 |
92 | if __name__ == "__main__":
93 | args = argbind.parse_args()
94 | with argbind.scope(args):
95 | decode()
96 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/media_utils.py:
--------------------------------------------------------------------------------
1 | """Media utilities for audio/video processing."""
2 |
3 | import os
4 | import subprocess
5 | from pathlib import Path
6 | from typing import Optional
7 |
8 | from loguru import logger
9 |
10 |
11 | class MediaProcessingError(Exception):
12 | """Exception raised for media processing errors."""
13 | pass
14 |
15 |
16 | def merge_audio_video(
17 | audio_path: str,
18 | video_path: str,
19 | output_path: str,
20 | overwrite: bool = True,
21 | quality: str = "high"
22 | ) -> str:
23 | """
24 | Merge audio and video files using ffmpeg.
25 |
26 | Args:
27 | audio_path: Path to input audio file
28 | video_path: Path to input video file
29 | output_path: Path for output video file
30 | overwrite: Whether to overwrite existing output file
31 | quality: Quality setting ('high', 'medium', 'low')
32 |
33 | Returns:
34 | Path to the output file
35 |
36 | Raises:
37 | MediaProcessingError: If input files don't exist or ffmpeg fails
38 | FileNotFoundError: If ffmpeg is not installed
39 | """
40 | # Validate input files
41 | if not os.path.exists(audio_path):
42 | raise MediaProcessingError(f"Audio file not found: {audio_path}")
43 | if not os.path.exists(video_path):
44 | raise MediaProcessingError(f"Video file not found: {video_path}")
45 |
46 | # Create output directory if needed
47 | output_dir = Path(output_path).parent
48 | output_dir.mkdir(parents=True, exist_ok=True)
49 |
50 | # Quality settings
51 | quality_settings = {
52 | "high": ["-b:a", "192k"],
53 | "medium": ["-b:a", "128k"],
54 | "low": ["-b:a", "96k"]
55 | }
56 |
57 | # Build ffmpeg command
58 | ffmpeg_command = [
59 | "ffmpeg",
60 | "-i", video_path,
61 | "-i", audio_path,
62 | "-c:v", "copy",
63 | "-c:a", "aac",
64 | "-ac", "2",
65 | "-af", "pan=stereo|c0=c0|c1=c0",
66 | "-map", "0:v:0",
67 | "-map", "1:a:0",
68 | *quality_settings.get(quality, quality_settings["high"]),
69 | ]
70 |
71 | if overwrite:
72 | ffmpeg_command.append("-y")
73 |
74 | ffmpeg_command.append(output_path)
75 |
76 | try:
77 | logger.info(f"Merging audio '{audio_path}' with video '{video_path}'")
78 | process = subprocess.Popen(
79 | ffmpeg_command,
80 | stdout=subprocess.PIPE,
81 | stderr=subprocess.PIPE,
82 | text=True
83 | )
84 | stdout, stderr = process.communicate()
85 |
86 | if process.returncode != 0:
87 | error_msg = f"FFmpeg failed with return code {process.returncode}: {stderr}"
88 | logger.error(error_msg)
89 | raise MediaProcessingError(error_msg)
90 | else:
91 | logger.info(f"Successfully merged video saved to: {output_path}")
92 |
93 | except FileNotFoundError:
94 | raise FileNotFoundError(
95 | "ffmpeg not found. Please install ffmpeg: "
96 | "https://ffmpeg.org/download.html"
97 | )
98 | except Exception as e:
99 | raise MediaProcessingError(f"Unexpected error during media processing: {e}")
100 |
101 | return output_path
102 |
--------------------------------------------------------------------------------
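Editor's note: a minimal call sketch; the paths are placeholders, and ffmpeg must be on PATH, as the error handling above indicates.

# Sketch: muxing generated audio back onto the source video (paths are placeholders).
out = merge_audio_video(
    audio_path="output/foley.wav",
    video_path="input/clip.mp4",
    output_path="output/clip_with_foley.mp4",
    quality="high",
)
print(out)   # "output/clip_with_foley.mp4"
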
/hunyuanvideo_foley/models/dac_vae/utils/encode.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | from pathlib import Path
4 |
5 | import argbind
6 | import numpy as np
7 | import torch
8 | from audiotools import AudioSignal
9 | from audiotools.core import util
10 | from tqdm import tqdm
11 |
12 | from . import load_model
13 |
14 | warnings.filterwarnings("ignore", category=UserWarning)
15 |
16 |
17 | @argbind.bind(group="encode", positional=True, without_prefix=True)
18 | @torch.inference_mode()
19 | @torch.no_grad()
20 | def encode(
21 | input: str,
22 | output: str = "",
23 | weights_path: str = "",
24 | model_tag: str = "latest",
25 | model_bitrate: str = "8kbps",
26 | n_quantizers: int = None,
27 | device: str = "cuda",
28 | model_type: str = "44khz",
29 | win_duration: float = 5.0,
30 | verbose: bool = False,
31 | ):
32 | """Encode audio files in input path to .dac format.
33 |
34 | Parameters
35 | ----------
36 | input : str
37 | Path to input audio file or directory
38 | output : str, optional
39 | Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
40 | weights_path : str, optional
41 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
42 | model_tag and model_type.
43 | model_tag : str, optional
44 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
45 | model_bitrate: str
46 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
47 | n_quantizers : int, optional
48 | Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
49 | device : str, optional
50 | Device to use, by default "cuda"
51 | model_type : str, optional
52 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
53 | """
54 | generator = load_model(
55 | model_type=model_type,
56 | model_bitrate=model_bitrate,
57 | tag=model_tag,
58 | load_path=weights_path,
59 | )
60 | generator.to(device)
61 | generator.eval()
62 | kwargs = {"n_quantizers": n_quantizers}
63 |
64 | # Find all audio files in input path
65 | input = Path(input)
66 | audio_files = util.find_audio(input)
67 |
68 | output = Path(output)
69 | output.mkdir(parents=True, exist_ok=True)
70 |
71 | for i in tqdm(range(len(audio_files)), desc="Encoding files"):
72 | # Load file
73 | signal = AudioSignal(audio_files[i])
74 |
75 | # Encode audio to .dac format
76 | artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
77 |
78 | # Compute output path
79 | relative_path = audio_files[i].relative_to(input)
80 | output_dir = output / relative_path.parent
81 | if not relative_path.name:
82 | output_dir = output
83 | relative_path = audio_files[i]
84 | output_name = relative_path.with_suffix(".dac").name
85 | output_path = output_dir / output_name
86 | output_path.parent.mkdir(parents=True, exist_ok=True)
87 |
88 | artifact.save(output_path)
89 |
90 |
91 | if __name__ == "__main__":
92 | args = argbind.parse_args()
93 | with argbind.scope(args):
94 | encode()
95 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import argbind
4 | from audiotools import ml
5 |
6 | from ..model import DAC
7 | Accelerator = ml.Accelerator
8 |
9 | __MODEL_LATEST_TAGS__ = {
10 | ("44khz", "8kbps"): "0.0.1",
11 | ("24khz", "8kbps"): "0.0.4",
12 | ("16khz", "8kbps"): "0.0.5",
13 | ("44khz", "16kbps"): "1.0.0",
14 | }
15 |
16 | __MODEL_URLS__ = {
17 | (
18 | "44khz",
19 | "0.0.1",
20 | "8kbps",
21 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
22 | (
23 | "24khz",
24 | "0.0.4",
25 | "8kbps",
26 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
27 | (
28 | "16khz",
29 | "0.0.5",
30 | "8kbps",
31 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
32 | (
33 | "44khz",
34 | "1.0.0",
35 | "16kbps",
36 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
37 | }
38 |
39 |
40 | @argbind.bind(group="download", positional=True, without_prefix=True)
41 | def download(
42 | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
43 | ):
44 | """
45 | Function that downloads the weights file from URL if a local cache is not found.
46 |
47 | Parameters
48 | ----------
49 | model_type : str
50 | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
51 | model_bitrate: str
52 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
53 | Only 44khz model supports 16kbps.
54 | tag : str
55 | The tag of the model to download. Defaults to "latest".
56 |
57 | Returns
58 | -------
59 | Path
60 | Directory path required to load model via audiotools.
61 | """
62 | model_type = model_type.lower()
63 | tag = tag.lower()
64 |
65 | assert model_type in [
66 | "44khz",
67 | "24khz",
68 | "16khz",
69 | ], "model_type must be one of '44khz', '24khz', or '16khz'"
70 |
71 | assert model_bitrate in [
72 | "8kbps",
73 | "16kbps",
74 | ], "model_bitrate must be one of '8kbps', or '16kbps'"
75 |
76 | if tag == "latest":
77 | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
78 |
79 | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
80 |
81 | if download_link is None:
82 | raise ValueError(
83 | f"Could not find model with tag {tag} and model type {model_type}"
84 | )
85 |
86 | local_path = (
87 | Path.home()
88 | / ".cache"
89 | / "descript"
90 | / "dac"
91 | / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
92 | )
93 | if not local_path.exists():
94 | local_path.parent.mkdir(parents=True, exist_ok=True)
95 |
96 | # Download the model
97 | import requests
98 |
99 | response = requests.get(download_link)
100 |
101 | if response.status_code != 200:
102 | raise ValueError(
103 | f"Could not download model. Received response code {response.status_code}"
104 | )
105 | local_path.write_bytes(response.content)
106 |
107 | return local_path
108 |
109 |
110 | def load_model(
111 | model_type: str = "44khz",
112 | model_bitrate: str = "8kbps",
113 | tag: str = "latest",
114 | load_path: str = None,
115 | ):
116 | if not load_path:
117 | load_path = download(
118 | model_type=model_type, model_bitrate=model_bitrate, tag=tag
119 | )
120 | generator = DAC.load(load_path)
121 | return generator
122 |
--------------------------------------------------------------------------------
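Editor's note: a sketch of loading the codec from a local checkpoint instead of the download path. The directory is a placeholder, and whether the repo's vae_128d_48k.pth is in a format DAC.load accepts is an assumption made for illustration.

# Sketch: bypassing the download helper by pointing load_model at a local checkpoint.
generator = load_model(load_path="pretrained_models/hunyuanvideo-foley-xxl/vae_128d_48k.pth")
generator = generator.to("cuda").eval()
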
/hunyuanvideo_foley/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | """Configuration utilities for the HunyuanVideo-Foley project."""
2 |
3 | import yaml
4 | from pathlib import Path
5 | from typing import Any, Dict, List, Union
6 |
7 | class AttributeDict:
8 |
9 | def __init__(self, data: Union[Dict, List, Any]):
10 | if isinstance(data, dict):
11 | for key, value in data.items():
12 | if isinstance(value, (dict, list)):
13 | value = AttributeDict(value)
14 | setattr(self, self._sanitize_key(key), value)
15 | elif isinstance(data, list):
16 | self._list = [AttributeDict(item) if isinstance(item, (dict, list)) else item
17 | for item in data]
18 | else:
19 | self._value = data
20 |
21 | def _sanitize_key(self, key: str) -> str:
22 | import re
23 | sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', str(key))
24 | if sanitized[0].isdigit():
25 | sanitized = f'_{sanitized}'
26 | return sanitized
27 |
28 | def __getitem__(self, key):
29 | if hasattr(self, '_list'):
30 | return self._list[key]
31 | return getattr(self, self._sanitize_key(key))
32 |
33 | def __setitem__(self, key, value):
34 | if hasattr(self, '_list'):
35 | self._list[key] = value
36 | else:
37 | setattr(self, self._sanitize_key(key), value)
38 |
39 | def __iter__(self):
40 | if hasattr(self, '_list'):
41 | return iter(self._list)
42 | return iter(self.__dict__.keys())
43 |
44 | def __len__(self):
45 | if hasattr(self, '_list'):
46 | return len(self._list)
47 | return len(self.__dict__)
48 |
49 | def get(self, key, default=None):
50 | try:
51 | return self[key]
52 | except (KeyError, AttributeError, IndexError):
53 | return default
54 |
55 | def keys(self):
56 | if hasattr(self, '_list'):
57 | return range(len(self._list))
58 | elif hasattr(self, '_value'):
59 | return []
60 | else:
61 | return [key for key in self.__dict__.keys() if not key.startswith('_')]
62 |
63 | def values(self):
64 | if hasattr(self, '_list'):
65 | return self._list
66 | elif hasattr(self, '_value'):
67 | return [self._value]
68 | else:
69 | return [value for key, value in self.__dict__.items() if not key.startswith('_')]
70 |
71 | def items(self):
72 | if hasattr(self, '_list'):
73 | return enumerate(self._list)
74 | elif hasattr(self, '_value'):
75 | return []
76 | else:
77 | return [(key, value) for key, value in self.__dict__.items() if not key.startswith('_')]
78 |
79 | def __repr__(self):
80 | if hasattr(self, '_list'):
81 | return f"AttributeDict({self._list})"
82 | elif hasattr(self, '_value'):
83 | return f"AttributeDict({self._value})"
84 | return f"AttributeDict({dict(self.__dict__)})"
85 |
86 | def to_dict(self) -> Union[Dict, List, Any]:
87 | if hasattr(self, '_list'):
88 | return [item.to_dict() if isinstance(item, AttributeDict) else item
89 | for item in self._list]
90 | elif hasattr(self, '_value'):
91 | return self._value
92 | else:
93 | result = {}
94 | for key, value in self.__dict__.items():
95 | if isinstance(value, AttributeDict):
96 | result[key] = value.to_dict()
97 | else:
98 | result[key] = value
99 | return result
100 |
101 | def load_yaml(file_path: str, encoding: str = 'utf-8') -> AttributeDict:
102 | try:
103 | with open(file_path, 'r', encoding=encoding) as file:
104 | data = yaml.safe_load(file)
105 | return AttributeDict(data)
106 | except FileNotFoundError:
107 | raise FileNotFoundError(f"YAML file not found: {file_path}")
108 | except yaml.YAMLError as e:
109 | raise yaml.YAMLError(f"YAML format error: {e}")
110 |
--------------------------------------------------------------------------------
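Editor's note: a small sketch (not in the source) of AttributeDict behavior beyond YAML loading; the keys are illustrative. Note that key sanitization means to_dict() returns sanitized key names, not the originals.

# Sketch: AttributeDict wraps nested dicts/lists; to_dict() converts back to plain containers.
cfg = AttributeDict({"model": {"hidden-size": 1536, "layers": [1, 2, 3]}})
print(cfg.model.hidden_size)    # 1536 ("hidden-size" is sanitized to "hidden_size")
print(cfg.model.layers[0])      # 1   (lists are wrapped but stay indexable)
print(cfg.to_dict())            # {'model': {'hidden_size': 1536, 'layers': [1, 2, 3]}}
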
/install.py:
--------------------------------------------------------------------------------
1 | """
2 | Installation script for ComfyUI HunyuanVideo-Foley Custom Node
3 | """
4 |
5 | import os
6 | import sys
7 | import subprocess
8 | import pkg_resources
9 | from pathlib import Path
10 |
11 | def check_and_install_requirements():
12 | """Check and install required packages"""
13 | requirements_file = Path(__file__).parent / "requirements.txt"
14 |
15 | if not requirements_file.exists():
16 | print("Requirements file not found!")
17 | return False
18 |
19 | try:
20 | print("Checking and installing requirements...")
21 |
22 | # Read requirements
23 | with open(requirements_file, 'r') as f:
24 | requirements = f.read().splitlines()
25 |
26 | # Filter out comments and empty lines
27 | requirements = [line.strip() for line in requirements
28 | if line.strip() and not line.strip().startswith('#')]
29 |
30 | # Install packages
31 | for requirement in requirements:
32 | try:
 33 |                 # git+ requirements can't be version-checked; install them directly
34 | if requirement.startswith('git+'):
35 | print(f"Installing git requirement: {requirement}")
36 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', requirement])
37 | else:
38 | # Check if package is already installed
39 | try:
40 | pkg_resources.require([requirement])
41 | print(f"✓ {requirement} already installed")
42 | except pkg_resources.DistributionNotFound:
43 | print(f"Installing {requirement}...")
44 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', requirement])
45 | except pkg_resources.VersionConflict:
46 | print(f"Updating {requirement}...")
47 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--upgrade', requirement])
48 |
49 | except subprocess.CalledProcessError as e:
50 | print(f"Failed to install {requirement}: {e}")
51 | return False
52 |
53 | print("✅ All requirements installed successfully!")
54 | return True
55 |
56 | except Exception as e:
57 | print(f"Error installing requirements: {e}")
58 | return False
59 |
60 | def setup_model_directories():
61 | """Create necessary model directories"""
62 | base_dir = Path(__file__).parent.parent.parent # Go up to ComfyUI root
63 |
64 | # Create ComfyUI/models/foley directory for automatic downloads
65 | foley_models_dir = base_dir / "models" / "foley"
66 | foley_models_dir.mkdir(parents=True, exist_ok=True)
67 | print(f"✓ Created ComfyUI models directory: {foley_models_dir}")
68 |
69 | # Also create local fallback directories
70 | node_dir = Path(__file__).parent
71 | local_dirs = [
72 | node_dir / "pretrained_models",
73 | node_dir / "configs"
74 | ]
75 |
76 | for dir_path in local_dirs:
77 | dir_path.mkdir(exist_ok=True)
78 | print(f"✓ Created local directory: {dir_path}")
79 |
80 | def main():
81 | """Main installation function"""
82 | print("🚀 Installing ComfyUI HunyuanVideo-Foley Custom Node...")
83 |
84 | # Install requirements
85 | if not check_and_install_requirements():
86 | print("❌ Failed to install requirements")
87 | return False
88 |
89 | # Setup directories
90 | setup_model_directories()
91 |
92 | print("📁 Directory structure created")
93 | print("📋 Installation completed!")
94 | print()
95 | print("📌 Next steps:")
96 | print("1. Restart ComfyUI to load the custom nodes")
97 | print("2. Models will be automatically downloaded when you first use the node")
98 | print("3. Alternatively, manually download models and place them in ComfyUI/models/foley/")
99 | print("4. Model URLs are configured in model_urls.py (can be updated as needed)")
100 | print()
101 |
102 | return True
103 |
104 | if __name__ == "__main__":
105 | main()
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/helper.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | from itertools import repeat
3 | import importlib
4 | import yaml
5 | import time
6 |
7 | def default(value, default_val):
8 | return default_val if value is None else value
9 |
10 |
11 | def default_dtype(value, default_val):
12 | if value is not None:
13 | assert isinstance(value, type(default_val)), f"Expect {type(default_val)}, got {type(value)}."
14 | return value
15 | return default_val
16 |
17 |
18 | def repeat_interleave(lst, num_repeats):
19 | return [item for item in lst for _ in range(num_repeats)]
20 |
21 |
22 | def _ntuple(n):
23 | def parse(x):
24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25 | x = tuple(x)
26 | if len(x) == 1:
27 | x = tuple(repeat(x[0], n))
28 | return x
29 | return tuple(repeat(x, n))
30 |
31 | return parse
32 |
33 |
34 | to_1tuple = _ntuple(1)
35 | to_2tuple = _ntuple(2)
36 | to_3tuple = _ntuple(3)
37 | to_4tuple = _ntuple(4)
38 |
39 |
40 | def as_tuple(x):
41 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
42 | return tuple(x)
43 | if x is None or isinstance(x, (int, float, str)):
44 | return (x,)
45 | else:
46 | raise ValueError(f"Unknown type {type(x)}")
47 |
48 |
49 | def as_list_of_2tuple(x):
50 | x = as_tuple(x)
51 | if len(x) == 1:
52 | x = (x[0], x[0])
53 | assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
54 | lst = []
55 | for i in range(0, len(x), 2):
56 | lst.append((x[i], x[i + 1]))
57 | return lst
58 |
59 |
60 | def find_multiple(n: int, k: int) -> int:
61 | assert k > 0
62 | if n % k == 0:
63 | return n
64 | return n - (n % k) + k
65 |
66 |
67 | def merge_dicts(dict1, dict2):
68 | for key, value in dict2.items():
69 | if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict):
70 | merge_dicts(dict1[key], value)
71 | else:
72 | dict1[key] = value
73 | return dict1
74 |
75 |
76 | def merge_yaml_files(file_list):
77 | merged_config = {}
78 |
79 | for file in file_list:
80 | with open(file, "r", encoding="utf-8") as f:
81 | config = yaml.safe_load(f)
82 | if config:
83 | # Remove the first level
84 | for key, value in config.items():
85 | if isinstance(value, dict):
86 | merged_config = merge_dicts(merged_config, value)
87 | else:
88 | merged_config[key] = value
89 |
90 | return merged_config
91 |
92 |
93 | def merge_dict(file_list):
94 | merged_config = {}
95 |
96 | for file in file_list:
97 | with open(file, "r", encoding="utf-8") as f:
98 | config = yaml.safe_load(f)
99 | if config:
100 | merged_config = merge_dicts(merged_config, config)
101 |
102 | return merged_config
103 |
104 |
105 | def get_obj_from_str(string, reload=False):
106 | module, cls = string.rsplit(".", 1)
107 | if reload:
108 | module_imp = importlib.import_module(module)
109 | importlib.reload(module_imp)
110 | return getattr(importlib.import_module(module, package=None), cls)
111 |
112 |
113 | def readable_time(seconds):
114 | """ Convert time seconds to a readable format: DD Days, HH Hours, MM Minutes, SS Seconds """
115 | seconds = int(seconds)
116 | days, seconds = divmod(seconds, 86400)
117 | hours, seconds = divmod(seconds, 3600)
118 | minutes, seconds = divmod(seconds, 60)
119 | if days > 0:
120 | return f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds"
121 | if hours > 0:
122 | return f"{hours} Hours, {minutes} Minutes, {seconds} Seconds"
123 | if minutes > 0:
124 | return f"{minutes} Minutes, {seconds} Seconds"
125 | return f"{seconds} Seconds"
126 |
127 |
128 | def get_obj_from_cfg(cfg, reload=False):
129 | if isinstance(cfg, str):
130 | return get_obj_from_str(cfg, reload)
131 | elif isinstance(cfg, (list, tuple,)):
132 | return tuple([get_obj_from_str(c, reload) for c in cfg])
133 | else:
134 | raise NotImplementedError(f"Not implemented for {type(cfg)}.")
135 |
--------------------------------------------------------------------------------
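Editor's note: a few illustrative calls (not in the source) for the small helpers used across the model code.

# Illustrative calls for the tuple/merge/time helpers.
print(to_2tuple(3))                                   # (3, 3)
print(as_list_of_2tuple((1, 2, 3, 4)))                # [(1, 2), (3, 4)]
print(find_multiple(50, 8))                           # 56
print(merge_dicts({"a": {"x": 1}}, {"a": {"y": 2}}))  # {'a': {'x': 1, 'y': 2}}
print(readable_time(3723))                            # "1 Hours, 2 Minutes, 3 Seconds"
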
/hunyuanvideo_foley/models/synchformer/utils.py:
--------------------------------------------------------------------------------
1 | from hashlib import md5
2 | from pathlib import Path
3 | import subprocess
4 |
5 | import requests
6 | from tqdm import tqdm
7 |
8 | PARENT_LINK = "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a"
9 | FNAME2LINK = {
10 | # S3: Synchability: AudioSet (run 2)
11 | "24-01-22T20-34-52.pt": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt",
12 | "cfg-24-01-22T20-34-52.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml",
13 | # S2: Synchformer: AudioSet (run 2)
14 | "24-01-04T16-39-21.pt": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt",
15 | "cfg-24-01-04T16-39-21.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml",
16 | # S2: Synchformer: AudioSet (run 1)
17 | "23-08-28T11-23-23.pt": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt",
18 | "cfg-23-08-28T11-23-23.yaml": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml",
19 | # S2: Synchformer: LRS3 (run 2)
20 | "23-12-23T18-33-57.pt": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt",
21 | "cfg-23-12-23T18-33-57.yaml": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml",
22 | # S2: Synchformer: VGS (run 2)
23 | "24-01-02T10-00-53.pt": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt",
24 | "cfg-24-01-02T10-00-53.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml",
25 | # SparseSync: ft VGGSound-Full
26 | "22-09-21T21-00-52.pt": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt",
27 | "cfg-22-09-21T21-00-52.yaml": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml",
28 | # SparseSync: ft VGGSound-Sparse
29 | "22-07-28T15-49-45.pt": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt",
30 | "cfg-22-07-28T15-49-45.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml",
31 | # SparseSync: only pt on LRS3
32 | "22-07-13T22-25-49.pt": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt",
33 | "cfg-22-07-13T22-25-49.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml",
34 | # SparseSync: feature extractors
35 | "ResNetAudio-22-08-04T09-51-04.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt", # 2s
36 | "ResNetAudio-22-08-03T23-14-49.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt", # 3s
37 | "ResNetAudio-22-08-03T23-14-28.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt", # 4s
38 | "ResNetAudio-22-06-24T08-10-33.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt", # 5s
39 | "ResNetAudio-22-06-24T17-31-07.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt", # 6s
40 | "ResNetAudio-22-06-24T23-57-11.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt", # 7s
41 | "ResNetAudio-22-06-25T04-35-42.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt", # 8s
42 | }
43 |
44 |
45 | def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024):
 46 |     """Checks if the file exists; if not, downloads it from the link to the path."""
47 | path = Path(path)
48 | if not path.exists():
49 | path.parent.mkdir(exist_ok=True, parents=True)
50 | link = fname2link.get(path.name, None)
51 | if link is None:
52 | raise ValueError(
 53 |                 f"Can't find the checkpoint file: {path}.", "Please download it manually and ensure the path exists."
54 | )
55 | with requests.get(fname2link[path.name], stream=True) as r:
56 | total_size = int(r.headers.get("content-length", 0))
57 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
58 | with open(path, "wb") as f:
59 | for data in r.iter_content(chunk_size=chunk_size):
60 | if data:
61 | f.write(data)
62 | pbar.update(chunk_size)
63 |
64 |
65 | def which_ffmpeg() -> str:
66 | """Determines the path to ffmpeg library
67 | Returns:
68 | str -- path to the library
69 | """
70 | result = subprocess.run(["which", "ffmpeg"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
71 | ffmpeg_path = result.stdout.decode("utf-8").replace("\n", "")
72 | return ffmpeg_path
73 |
74 |
75 | def get_md5sum(path):
76 | hash_md5 = md5()
77 | with open(path, "rb") as f:
78 | for chunk in iter(lambda: f.read(4096 * 8), b""):
79 | hash_md5.update(chunk)
80 | md5sum = hash_md5.hexdigest()
81 | return md5sum
82 |
83 |
84 | class Config:
85 | def __init__(self, **kwargs):
86 | for k, v in kwargs.items():
87 | setattr(self, k, v)
88 |
--------------------------------------------------------------------------------
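Editor's note: a brief sketch; the local checkpoint directory is a placeholder, and the filename must be one of the FNAME2LINK keys for the auto-download to resolve.

# Sketch: fetch a Synchformer checkpoint on first use (target directory is a placeholder).
ckpt = "./checkpoints/24-01-04T16-39-21.pt"           # key must exist in FNAME2LINK
check_if_file_exists_else_download(ckpt)
print(get_md5sum(ckpt))
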
/hunyuanvideo_foley/models/nn/embed_layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | from ...utils.helper import to_2tuple, to_1tuple
6 |
7 | class PatchEmbed1D(nn.Module):
8 | """1D Audio to Patch Embedding
9 |
10 | A convolution based approach to patchifying a 1D audio w/ embedding projection.
11 |
12 | Based on the impl in https://github.com/google-research/vision_transformer
13 |
14 | Hacked together by / Copyright 2020 Ross Wightman
15 | """
16 |
17 | def __init__(
18 | self,
19 | patch_size=1,
20 | in_chans=768,
21 | embed_dim=768,
22 | norm_layer=None,
23 | flatten=True,
24 | bias=True,
25 | dtype=None,
26 | device=None,
27 | ):
28 | factory_kwargs = {"dtype": dtype, "device": device}
29 | super().__init__()
30 | patch_size = to_1tuple(patch_size)
31 | self.patch_size = patch_size
32 | self.flatten = flatten
33 |
34 | self.proj = nn.Conv1d(
35 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
36 | )
37 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
38 | if bias:
39 | nn.init.zeros_(self.proj.bias)
40 |
41 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
42 |
43 | def forward(self, x):
44 | assert (
45 | x.shape[2] % self.patch_size[0] == 0
 46 |         ), f"The token number ({x.shape[2]}) of x must be divisible by the patch_size ({self.patch_size[0]})."
47 |
48 | x = self.proj(x)
49 | if self.flatten:
50 | x = x.transpose(1, 2) # BCN -> BNC
51 | x = self.norm(x)
52 | return x
53 |
54 |
55 | class ConditionProjection(nn.Module):
56 | """
57 | Projects condition embeddings. Also handles dropout for classifier-free guidance.
58 |
59 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
60 | """
61 |
62 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
63 | factory_kwargs = {'dtype': dtype, 'device': device}
64 | super().__init__()
65 | self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
66 | self.act_1 = act_layer()
67 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
68 |
69 | def forward(self, caption):
70 | hidden_states = self.linear_1(caption)
71 | hidden_states = self.act_1(hidden_states)
72 | hidden_states = self.linear_2(hidden_states)
73 | return hidden_states
74 |
75 |
76 | def timestep_embedding(t, dim, max_period=10000):
77 | """
78 | Create sinusoidal timestep embeddings.
79 |
80 | Args:
81 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
82 | dim (int): the dimension of the output.
83 | max_period (int): controls the minimum frequency of the embeddings.
84 |
85 | Returns:
86 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
87 |
88 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
89 | """
90 | half = dim // 2
91 | freqs = torch.exp(
92 | -math.log(max_period)
93 | * torch.arange(start=0, end=half, dtype=torch.float32)
94 | / half
95 | ).to(device=t.device)
96 | args = t[:, None].float() * freqs[None]
97 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
98 | if dim % 2:
99 | embedding = torch.cat(
100 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
101 | )
102 | return embedding
103 |
104 |
105 | class TimestepEmbedder(nn.Module):
106 | """
107 | Embeds scalar timesteps into vector representations.
108 | """
109 | def __init__(self,
110 | hidden_size,
111 | act_layer,
112 | frequency_embedding_size=256,
113 | max_period=10000,
114 | out_size=None,
115 | dtype=None,
116 | device=None
117 | ):
118 | factory_kwargs = {'dtype': dtype, 'device': device}
119 | super().__init__()
120 | self.frequency_embedding_size = frequency_embedding_size
121 | self.max_period = max_period
122 | if out_size is None:
123 | out_size = hidden_size
124 |
125 | self.mlp = nn.Sequential(
126 | nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
127 | act_layer(),
128 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
129 | )
130 | nn.init.normal_(self.mlp[0].weight, std=0.02)
131 | nn.init.normal_(self.mlp[2].weight, std=0.02)
132 |
133 | def forward(self, t):
134 | t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
135 | t_emb = self.mlp(t_freq)
136 | return t_emb
137 |
--------------------------------------------------------------------------------
/download_models_manual.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Manual download helper for HunyuanVideo-Foley models
4 | Run this script directly if the automatic download fails in ComfyUI
5 | """
6 |
7 | import os
8 | import sys
9 | from pathlib import Path
10 | import urllib.request
11 | import time
12 |
13 | # Model download URLs
14 | MODELS = [
15 | {
16 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/hunyuanvideo_foley.pth",
17 | "filename": "hunyuanvideo_foley.pth",
18 | "size_gb": 10.3,
19 | "description": "Main HunyuanVideo-Foley model"
20 | },
21 | {
22 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/synchformer_state_dict.pth",
23 | "filename": "synchformer_state_dict.pth",
24 | "size_gb": 0.95,
25 | "description": "Synchformer model weights"
26 | },
27 | {
28 | "url": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/vae_128d_48k.pth",
29 | "filename": "vae_128d_48k.pth",
30 | "size_gb": 1.49,
31 | "description": "VAE model weights"
32 | }
33 | ]
34 |
35 | def download_with_progress(url, dest_path):
36 | """Download file with progress display"""
37 | def progress_hook(block_num, block_size, total_size):
38 | if total_size > 0:
39 | downloaded = block_num * block_size
40 | percent = min(100, (downloaded * 100) // total_size)
41 | size_mb = downloaded / (1024 * 1024)
42 | total_mb = total_size / (1024 * 1024)
43 |
44 | # Print progress
45 | bar_len = 40
46 | filled_len = int(bar_len * percent // 100)
47 | bar = '█' * filled_len + '-' * (bar_len - filled_len)
48 |
49 | sys.stdout.write(f'\r[{bar}] {percent}% ({size_mb:.1f}/{total_mb:.1f} MB)')
50 | sys.stdout.flush()
51 |
52 | try:
53 | urllib.request.urlretrieve(url, dest_path, progress_hook)
54 | print() # New line after progress
55 | return True
56 | except Exception as e:
57 | print(f"\nError: {e}")
58 | return False
59 |
60 | def main():
61 | # Determine ComfyUI models directory
62 | comfyui_root = Path(__file__).parent.parent.parent # Go up to ComfyUI root
63 | models_dir = comfyui_root / "models" / "foley" / "hunyuanvideo-foley-xxl"
64 |
65 | print("=" * 60)
66 | print("HunyuanVideo-Foley Model Downloader")
67 | print("=" * 60)
68 | print(f"\nModels will be downloaded to:")
69 | print(f" {models_dir}")
70 |
71 | # Create directory if it doesn't exist
72 | models_dir.mkdir(parents=True, exist_ok=True)
73 |
74 | # Check disk space
75 | import shutil
76 | stat = shutil.disk_usage(models_dir)
77 | available_gb = stat.free / (1024 ** 3)
78 | required_gb = sum(m["size_gb"] for m in MODELS) + 1 # Add 1GB buffer
79 |
80 | print(f"\nDisk space available: {available_gb:.1f} GB")
81 | print(f"Space required: {required_gb:.1f} GB")
82 |
83 | if available_gb < required_gb:
84 | print(f"\n⚠️ WARNING: Insufficient disk space!")
85 | print(f"Please free up at least {required_gb - available_gb:.1f} GB before continuing.")
86 | response = input("\nContinue anyway? (y/n): ")
87 | if response.lower() != 'y':
88 | return
89 |
90 | print("\n" + "=" * 60)
91 |
92 | # Download each model
93 | for i, model_info in enumerate(MODELS, 1):
94 | model_path = models_dir / model_info["filename"]
95 |
96 | print(f"\n[{i}/{len(MODELS)}] {model_info['description']}")
97 | print(f" File: {model_info['filename']} ({model_info['size_gb']} GB)")
98 |
99 | # Check if already downloaded
100 | if model_path.exists() and model_path.stat().st_size > 100 * 1024 * 1024:
101 | size_gb = model_path.stat().st_size / (1024 ** 3)
102 | print(f" ✓ Already downloaded ({size_gb:.2f} GB)")
103 | continue
104 |
105 | print(f" Downloading from: {model_info['url']}")
106 |
107 | # Try downloading with retries
108 | max_retries = 3
109 | for attempt in range(max_retries):
110 | if attempt > 0:
111 | print(f" Retry {attempt}/{max_retries - 1}...")
112 | time.sleep(5)
113 |
114 | success = download_with_progress(model_info["url"], model_path)
115 | if success:
116 | size_gb = model_path.stat().st_size / (1024 ** 3)
117 | print(f" ✓ Downloaded successfully ({size_gb:.2f} GB)")
118 | break
119 | else:
120 | if attempt == max_retries - 1:
121 | print(f" ✗ Failed to download after {max_retries} attempts")
122 | print(f"\n You can manually download from:")
123 | print(f" {model_info['url']}")
124 | print(f" And place it at:")
125 | print(f" {model_path}")
126 |
127 | print("\n" + "=" * 60)
128 | print("Download process completed!")
129 | print("\nYou can now use the HunyuanVideo-Foley node in ComfyUI.")
130 | print("=" * 60)
131 |
132 | if __name__ == "__main__":
133 | main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI HunyuanVideo-Foley Custom Node
2 |
3 | This is a ComfyUI custom node wrapper for the HunyuanVideo-Foley model, which generates realistic audio from video and text descriptions.
4 |
5 |
6 | ## ✨ The original plugin has been updated; just install the original plugin instead (https://github.com/if-ai/ComfyUI_HunyuanVideoFoley)
7 | 
8 | 
9 | ## Model Unloading
10 |
11 |
12 |
13 |
14 |
15 | ## Features
16 |
17 | - **Text-Video-to-Audio Synthesis**: Generate realistic audio that matches your video content
18 | - **Flexible Text Prompts**: Use optional text descriptions to guide audio generation
19 | - **Multiple Samples**: Generate up to 6 different audio variations per inference
20 | - **Configurable Parameters**: Control guidance scale, inference steps, and sampling
21 | - **Seed Control**: Reproducible results with seed parameter
22 | - **Model Caching**: Efficient model loading and reuse across generations
23 | - **Automatic Model Downloads**: Models are automatically downloaded to `ComfyUI/models/foley/` when needed
24 |
25 | ## Installation
26 |
27 | 1. **Clone this repository** into your ComfyUI custom_nodes directory:
28 | ```bash
29 | cd ComfyUI/custom_nodes
30 | git clone https://github.com/yichengup/ComfyUI_ycHunyuanVideoFoley.git
31 | ```
32 |
33 | 2. **Install dependencies**:
34 | ```bash
35 | cd ComfyUI_ycHunyuanVideoFoley
36 | pip install -r requirements.txt
37 | ```
38 |
39 | 3. **Run the installation script** (recommended):
40 | ```bash
41 | python install.py
42 | ```
43 |
44 | 4. **Restart ComfyUI** to load the new nodes.
45 |
46 | ### Model Setup
47 |
48 | The models can be obtained in two ways:
49 |
50 | #### Option 1: Automatic Download (Recommended)
51 | - Models will be automatically downloaded to `ComfyUI/models/foley/` when you first run the node
52 | - No manual setup required
53 | - Progress will be shown in the ComfyUI console
54 |
55 | #### Option 2: Manual Download
56 | - Download models from [HuggingFace](https://huggingface.co/tencent/HunyuanVideo-Foley)
57 | - Place models in `ComfyUI/models/foley/` (recommended) or `./pretrained_models/` directory
58 | - Ensure the config file is at `configs/hunyuanvideo-foley-xxl.yaml`
59 |
60 | ## Usage
61 |
62 | ### Node Types
63 |
64 | #### 1. HunyuanVideo-Foley Generator
65 | Main node for generating audio from video and text.
66 |
67 | **Inputs:**
68 | - **video**: Video input (VIDEO type)
69 | - **text_prompt**: Text description of desired audio (STRING)
70 | - **guidance_scale**: CFG scale for generation control (1.0-10.0, default: 4.5)
71 | - **num_inference_steps**: Number of denoising steps (10-100, default: 50)
72 | - **sample_nums**: Number of audio samples to generate (1-6, default: 1)
73 | - **seed**: Random seed for reproducibility (INT)
74 | - **model_path**: Path to pretrained models (optional, leave empty for auto-download)
75 |
76 | **Outputs:**
77 | - **video_with_audio**: Video with generated audio merged (VIDEO)
78 | - **audio_only**: Generated audio file (AUDIO)
79 | - **status_message**: Generation status and info (STRING)
80 |
81 | ## ⚠ Important Limitations
82 |
83 | ### **Frame Count & Duration Limits**
84 | - **Maximum Frames**: 450 frames (hard limit)
85 | - **Maximum Duration**: 15 seconds at 30fps
86 | - **Recommended**: Keep videos ≤15 seconds for best results
87 |
88 | ### **FPS Recommendations**
89 | - **30fps**: Max 15 seconds (450 frames)
90 | - **24fps**: Max 18.75 seconds (450 frames)
91 | - **15fps**: Max 30 seconds (450 frames)
92 |
93 | ### **Long Video Solutions**
94 | For videos longer than 15 seconds (a minimal segmentation sketch follows this list):
95 | 1. **Reduce FPS**: Lower FPS allows longer duration within frame limit
96 | 2. **Segment Processing**: Split long videos into 15s segments
97 | 3. **Audio Merging**: Combine generated audio segments in post-processing
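
The snippet below is a minimal sketch, not part of this node, of how a long clip could be cut into 15-second segments before running each one through the generator. It assumes `ffmpeg` and `ffprobe` are on your PATH; all file names are illustrative.

```python
# Standalone sketch: split a long video into 15-second chunks with ffmpeg.
import math
import subprocess
from pathlib import Path
from typing import List

MAX_SECONDS = 15  # stays under the 450-frame limit at 30 fps


def split_video(src: str, out_dir: str, chunk_s: int = MAX_SECONDS) -> List[str]:
    """Cut `src` into chunk_s-second segments and return their file paths."""
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    # Probe the total duration in seconds with ffprobe
    duration = float(subprocess.check_output([
        "ffprobe", "-v", "error", "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1", src,
    ]).strip())
    chunks = []
    for i in range(math.ceil(duration / chunk_s)):
        out_path = str(Path(out_dir) / f"segment_{i:03d}.mp4")
        # Stream-copy a chunk_s-second window starting at i * chunk_s
        subprocess.run([
            "ffmpeg", "-y", "-ss", str(i * chunk_s), "-t", str(chunk_s),
            "-i", src, "-c", "copy", out_path,
        ], check=True)
        chunks.append(out_path)
    return chunks


# Hypothetical usage: feed each segment to the Foley node, then merge the
# generated audio tracks in post-processing (e.g. with ffmpeg's concat demuxer).
segments = split_video("my_long_video.mp4", "segments")
```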
98 |
99 |
100 | ## Example Workflow
101 |
102 | 1. **Load Video**: Use a "Load Video" node to input your video file
103 | 2. **Add Generator**: Add the "HunyuanVideo-Foley Generator" node
104 | 3. **Connect Video**: Connect the video output to the generator's video input
105 | 4. **Set Prompt**: Enter a text description (e.g., "A person walks on frozen ice")
106 | 5. **Adjust Settings**: Configure guidance scale, steps, and sample count as needed
107 | 6. **Generate**: Run the workflow to generate audio
108 |
109 | ## Model Requirements
110 |
111 | The node expects the following model structure:
112 | ```
113 | pretrained_models/
114 | ├── hunyuanvideo_foley.pth # Main Foley model
115 | ├── vae_128d_48k.pth # DAC VAE model
116 | └── synchformer_state_dict.pth # Synchformer model
117 |
118 | configs/
119 | └── hunyuanvideo-foley-xxl.yaml # Configuration file
120 | ```
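
As a quick sanity check of this layout, the sketch below simply verifies the expected files exist (run it from the node's directory); the node's own `validate_model_files` helper in `utils.py` performs a similar check for the weights.

```python
# Minimal sanity check for the layout above; run from the node directory.
from pathlib import Path

REQUIRED = [
    "pretrained_models/hunyuanvideo_foley.pth",
    "pretrained_models/vae_128d_48k.pth",
    "pretrained_models/synchformer_state_dict.pth",
    "configs/hunyuanvideo-foley-xxl.yaml",
]

missing = [p for p in REQUIRED if not Path(p).exists()]
print("All model files found." if not missing else f"Missing files: {missing}")
```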
121 |
122 |
123 | ## License
124 |
125 | This custom node is based on the HunyuanVideo-Foley project. Please check the original project's license terms.
126 |
127 | ## Credits
128 |
129 | Based on the HunyuanVideo-Foley project by Tencent. Original paper and code available at:
130 | - Paper: HunyuanVideo-Foley: Text-Video-to-Audio Synthesis
131 | 
132 | - Code: https://github.com/tencent/HunyuanVideo-Foley
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
--------------------------------------------------------------------------------
/INSTALLATION_GUIDE.md:
--------------------------------------------------------------------------------
1 | # Installation Guide for ComfyUI HunyuanVideo-Foley Custom Node
2 |
3 | ## Overview
4 |
5 | This custom node wraps the HunyuanVideo-Foley model for use in ComfyUI, enabling text-video-to-audio synthesis directly within ComfyUI workflows.
6 |
7 | ## Prerequisites
8 |
9 | - ComfyUI installation
10 | - Python 3.8+
11 | - CUDA-capable GPU (recommended 16GB+ VRAM)
12 | - At least 32GB system RAM
13 |
14 | ## Step-by-Step Installation
15 |
16 | ### 1. Clone the Custom Node
17 |
18 | ```bash
19 | cd /path/to/ComfyUI/custom_nodes
20 | git clone <repository-url> ComfyUI_HunyuanVideoFoley
21 | cd ComfyUI_HunyuanVideoFoley
22 | ```
23 |
24 | ### 2. Install Dependencies
25 |
26 | ```bash
27 | # Install Python dependencies
28 | pip install -r requirements.txt
29 |
30 | # Or run the installation script
31 | python install.py
32 | ```
33 |
34 | ### 3. Download Models
35 |
36 | You need to download the HunyuanVideo-Foley models:
37 |
38 | ```bash
39 | # Option 1: Manual download from HuggingFace
40 | # Visit: https://huggingface.co/tencent/HunyuanVideo-Foley
41 | # Download and place files in the following structure:
42 |
43 | mkdir -p pretrained_models
44 | mkdir -p configs
45 |
46 | # Download these files to pretrained_models/:
47 | # - hunyuanvideo_foley.pth
48 | # - vae_128d_48k.pth
49 | # - synchformer_state_dict.pth
50 |
51 | # Download config file to configs/:
52 | # - hunyuanvideo-foley-xxl.yaml
53 | ```
54 |
55 | Or use the Gradio app's auto-download feature by setting the environment variable:
56 |
57 | ```bash
58 | export HIFI_FOLEY_MODEL_PATH="/path/to/ComfyUI/custom_nodes/ComfyUI_HunyuanVideoFoley/pretrained_models"
59 | ```
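
Alternatively, the checkpoints can be fetched programmatically. The sketch below uses the `huggingface_hub` package (a separate `pip install huggingface_hub`, reasonably recent version assumed); the repository id and file names match the URLs used by `download_models_manual.py`.

```python
# Sketch: download the three checkpoints with huggingface_hub.
from huggingface_hub import hf_hub_download

REPO_ID = "tencent/HunyuanVideo-Foley"
FILES = [
    "hunyuanvideo_foley.pth",
    "vae_128d_48k.pth",
    "synchformer_state_dict.pth",
]

for name in FILES:
    local_path = hf_hub_download(repo_id=REPO_ID, filename=name, local_dir="pretrained_models")
    print(f"Downloaded {name} -> {local_path}")
```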
60 |
61 | ### 4. Verify Installation
62 |
63 | ```bash
64 | # Run the test script to verify everything is working
65 | python test_node.py
66 | ```
67 |
68 | ### 5. Restart ComfyUI
69 |
70 | After installation, restart ComfyUI to load the new custom nodes.
71 |
72 | ## Expected Directory Structure
73 |
74 | After installation, your directory should look like this:
75 |
76 | ```
77 | ComfyUI_HunyuanVideoFoley/
78 | ├── __init__.py
79 | ├── nodes.py
80 | ├── utils.py
81 | ├── requirements.txt
82 | ├── install.py
83 | ├── test_node.py
84 | ├── example_workflow.json
85 | ├── README.md
86 | ├── INSTALLATION_GUIDE.md
87 | └── pyproject.toml
88 |
89 | pretrained_models/
90 | ├── hunyuanvideo_foley.pth
91 | ├── vae_128d_48k.pth
92 | └── synchformer_state_dict.pth
93 |
94 | configs/
95 | └── hunyuanvideo-foley-xxl.yaml
96 | ```
97 |
98 | ## Usage
99 |
100 | ### Nodes Available
101 |
102 | 1. **HunyuanVideo-Foley Generator**
103 | - Main node for audio generation
104 | - Inputs: video, text prompt, generation parameters
105 | - Outputs: video with audio, audio only, status message
106 |
107 | 2. **HunyuanVideo-Foley Model Loader**
108 | - Separate model loading node
109 | - Useful for sharing models between multiple generator nodes
110 | - Inputs: model path, config path
111 | - Outputs: model handle, status message
112 |
113 | ### Basic Workflow
114 |
115 | 1. Load a video using ComfyUI's video input nodes
116 | 2. Add the "HunyuanVideo-Foley Generator" node
117 | 3. Connect the video to the generator
118 | 4. Set your text prompt (e.g., "A person walks on frozen ice")
119 | 5. Adjust parameters as needed
120 | 6. Run the workflow
121 |
122 | ### Example Workflow
123 |
124 | An example workflow is provided in `example_workflow.json`. Load this file in ComfyUI to see a basic setup.
125 |
126 | ## Performance Tips
127 |
128 | - **VRAM Usage**: The model requires significant GPU memory (a quick VRAM check is sketched at the end of this section). If you encounter CUDA out of memory errors:
129 | - Reduce `sample_nums` parameter
130 | - Lower `num_inference_steps`
131 | - Use CPU mode (slower but works with less memory)
132 |
133 | - **Generation Time**: Audio generation can take several minutes depending on:
134 | - Video length
135 | - Number of inference steps
136 | - Number of samples generated
137 | - Hardware specifications
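
The following standalone sketch (assuming a recent PyTorch; it is not part of the node) reports free GPU memory before you queue a generation, using the 16 GB figure from the prerequisites as a rough threshold.

```python
# Rough VRAM check before queuing a generation (standalone sketch).
import torch

if torch.cuda.is_available():
    free_b, total_b = torch.cuda.mem_get_info()  # (free, total) in bytes
    print(f"GPU memory: {free_b / 1e9:.1f} GB free of {total_b / 1e9:.1f} GB")
    if free_b / 1e9 < 16:
        print("Under ~16 GB free: consider sample_nums=1 and fewer inference steps.")
else:
    print("No CUDA device found; generation will run on CPU and be much slower.")
```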
138 |
139 | ## Troubleshooting
140 |
141 | ### Common Issues
142 |
143 | 1. **"Failed to import HunyuanVideo-Foley modules"**
144 | ```bash
145 | # Make sure you're in the correct directory and have all dependencies
146 | pip install -r requirements.txt
147 | ```
148 |
149 | 2. **"Model path does not exist"**
150 | ```bash
151 | # Download models from HuggingFace and verify directory structure
152 | ls pretrained_models/
153 | # Should show: hunyuanvideo_foley.pth, vae_128d_48k.pth, synchformer_state_dict.pth
154 | ```
155 |
156 | 3. **CUDA out of memory**
157 | ```bash
158 | # Reduce memory usage by adjusting parameters:
159 | # - Lower sample_nums to 1
160 | # - Reduce num_inference_steps to 25
161 | # - Use shorter videos for testing
162 | ```
163 |
164 | 4. **Slow generation**
165 | ```bash
166 | # Normal for first run (model loading)
167 | # Subsequent runs should be faster
168 | # Consider using fewer inference steps for faster results
169 | ```
170 |
171 | ### Getting Help
172 |
173 | - Check the test script output: `python test_node.py`
174 | - Review ComfyUI console output for detailed error messages
175 | - Ensure all model files are downloaded correctly
176 | - Verify GPU memory availability
177 |
178 | ## Model Information
179 |
180 | The HunyuanVideo-Foley model consists of several components:
181 |
182 | - **Main Foley Model**: Core text-video-to-audio generation
183 | - **DAC VAE**: Audio encoding/decoding
184 | - **SigLIP2**: Visual feature extraction
185 | - **CLAP**: Text feature extraction
186 | - **Synchformer**: Video-audio synchronization
187 |
188 | All components are automatically loaded when using the custom node.
189 |
190 | ## License & Credits
191 |
192 | This custom node is based on the HunyuanVideo-Foley project by Tencent. Please respect the original project's license terms when using this implementation.
193 |
194 | Original project: https://github.com/tencent/HunyuanVideo-Foley
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/nn/mlp_layers.py:
--------------------------------------------------------------------------------
1 | # Modified from timm library:
2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3 |
4 | from functools import partial
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from .modulate_layers import modulate
11 | from ...utils.helper import to_2tuple
12 |
13 | class MLP(nn.Module):
14 | """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15 |
16 | def __init__(
17 | self,
18 | in_channels,
19 | hidden_channels=None,
20 | out_features=None,
21 | act_layer=nn.GELU,
22 | norm_layer=None,
23 | bias=True,
24 | drop=0.0,
25 | use_conv=False,
26 | device=None,
27 | dtype=None,
28 | ):
29 | factory_kwargs = {"device": device, "dtype": dtype}
30 | super().__init__()
31 | out_features = out_features or in_channels
32 | hidden_channels = hidden_channels or in_channels
33 | bias = to_2tuple(bias)
34 | drop_probs = to_2tuple(drop)
35 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36 |
37 | self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
38 | self.act = act_layer()
39 | self.drop1 = nn.Dropout(drop_probs[0])
40 | self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
41 | self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
42 | self.drop2 = nn.Dropout(drop_probs[1])
43 |
44 | def forward(self, x):
45 | x = self.fc1(x)
46 | x = self.act(x)
47 | x = self.drop1(x)
48 | x = self.norm(x)
49 | x = self.fc2(x)
50 | x = self.drop2(x)
51 | return x
52 |
53 |
54 | # copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
55 | # only used when use_vanilla is True
56 | class MLPEmbedder(nn.Module):
57 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
58 | factory_kwargs = {"device": device, "dtype": dtype}
59 | super().__init__()
60 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
61 | self.silu = nn.SiLU()
62 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
63 |
64 | def forward(self, x: torch.Tensor) -> torch.Tensor:
65 | return self.out_layer(self.silu(self.in_layer(x)))
66 |
67 |
68 | class LinearWarpforSingle(nn.Module):
69 | def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None):
70 | factory_kwargs = {"device": device, "dtype": dtype}
71 | super().__init__()
72 | self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs)
73 |
74 | def forward(self, x, y):
75 | z = torch.cat([x, y], dim=2)
76 | return self.fc(z)
77 |
78 | class FinalLayer1D(nn.Module):
79 | def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
80 | factory_kwargs = {"device": device, "dtype": dtype}
81 | super().__init__()
82 |
83 | # Just use LayerNorm for the final layer
84 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
85 | self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs)
86 | nn.init.zeros_(self.linear.weight)
87 | nn.init.zeros_(self.linear.bias)
88 |
89 | # Here we don't distinguish between the modulate types. Just use the simple one.
90 | self.adaLN_modulation = nn.Sequential(
91 | act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
92 | )
93 | # Zero-initialize the modulation
94 | nn.init.zeros_(self.adaLN_modulation[1].weight)
95 | nn.init.zeros_(self.adaLN_modulation[1].bias)
96 |
97 | def forward(self, x, c):
98 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
99 | x = modulate(self.norm_final(x), shift=shift, scale=scale)
100 | x = self.linear(x)
101 | return x
102 |
103 |
104 | class ChannelLastConv1d(nn.Conv1d):
105 |
106 | def forward(self, x: torch.Tensor) -> torch.Tensor:
107 | x = x.permute(0, 2, 1)
108 | x = super().forward(x)
109 | x = x.permute(0, 2, 1)
110 | return x
111 |
112 |
113 | class ConvMLP(nn.Module):
114 |
115 | def __init__(
116 | self,
117 | dim: int,
118 | hidden_dim: int,
119 | multiple_of: int = 256,
120 | kernel_size: int = 3,
121 | padding: int = 1,
122 | device=None,
123 | dtype=None,
124 | ):
125 | """
126 | Convolutional MLP module.
127 |
128 | Args:
129 | dim (int): Input dimension.
130 | hidden_dim (int): Hidden dimension of the feedforward layer.
131 | multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
132 |
133 | Attributes:
134 | w1: Linear transformation for the first layer.
135 | w2: Linear transformation for the second layer.
136 | w3: Linear transformation for the third layer.
137 |
138 | """
139 | factory_kwargs = {"device": device, "dtype": dtype}
140 | super().__init__()
141 | hidden_dim = int(2 * hidden_dim / 3)
142 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
143 |
144 | self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
145 | self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
146 | self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
147 |
148 | def forward(self, x):
149 | return self.w2(F.silu(self.w1(x)) * self.w3(x))
150 |
--------------------------------------------------------------------------------
/test_node.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Test script for ComfyUI HunyuanVideo-Foley custom node
4 | """
5 |
6 | import sys
7 | import os
8 | import tempfile
9 | from pathlib import Path
10 |
11 | # Add the parent directory to path for imports
12 | current_dir = Path(__file__).parent
13 | parent_dir = current_dir.parent
14 | sys.path.insert(0, str(parent_dir))
15 |
16 | def test_imports():
17 | """Test that all required modules can be imported"""
18 | print("Testing imports...")
19 |
20 | try:
21 | from ComfyUI_HunyuanVideoFoley import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
22 | print("✅ Successfully imported node mappings")
23 |
24 | print(f"Available nodes: {list(NODE_CLASS_MAPPINGS.keys())}")
25 | print(f"Display names: {NODE_DISPLAY_NAME_MAPPINGS}")
26 |
27 | return True
28 |
29 | except ImportError as e:
30 | print(f"❌ Import failed: {e}")
31 | return False
32 |
33 | def test_node_structure():
34 | """Test node class structure"""
35 | print("\nTesting node structure...")
36 |
37 | try:
38 | from ComfyUI_HunyuanVideoFoley.nodes import HunyuanVideoFoleyNode, HunyuanVideoFoleyModelLoader
39 |
40 | # Test HunyuanVideoFoleyNode
41 | node = HunyuanVideoFoleyNode()
42 | input_types = node.INPUT_TYPES()
43 |
44 | print("✅ HunyuanVideoFoleyNode structure:")
45 | print(f" - Required inputs: {list(input_types['required'].keys())}")
46 | print(f" - Optional inputs: {list(input_types.get('optional', {}).keys())}")
47 | print(f" - Return types: {node.RETURN_TYPES}")
48 | print(f" - Function: {node.FUNCTION}")
49 | print(f" - Category: {node.CATEGORY}")
50 |
51 | # Test HunyuanVideoFoleyModelLoader
52 | loader = HunyuanVideoFoleyModelLoader()
53 | loader_input_types = loader.INPUT_TYPES()
54 |
55 | print("✅ HunyuanVideoFoleyModelLoader structure:")
56 | print(f" - Required inputs: {list(loader_input_types['required'].keys())}")
57 | print(f" - Return types: {loader.RETURN_TYPES}")
58 | print(f" - Function: {loader.FUNCTION}")
59 |
60 | return True
61 |
62 | except Exception as e:
63 | print(f"❌ Node structure test failed: {e}")
64 | return False
65 |
66 | def test_device_setup():
67 | """Test device setup functionality"""
68 | print("\nTesting device setup...")
69 |
70 | try:
71 | from ComfyUI_HunyuanVideoFoley.nodes import HunyuanVideoFoleyNode
72 |
73 | device = HunyuanVideoFoleyNode.setup_device("auto")
74 | print(f"✅ Device setup successful: {device}")
75 |
76 | return True
77 |
78 | except Exception as e:
79 | print(f"❌ Device setup failed: {e}")
80 | return False
81 |
82 | def test_utils():
83 | """Test utility functions"""
84 | print("\nTesting utility functions...")
85 |
86 | try:
87 | from ComfyUI_HunyuanVideoFoley.utils import (
88 | get_optimal_device,
89 | check_memory_requirements,
90 | format_duration,
91 | validate_model_files
92 | )
93 |
94 | # Test device detection
95 | device = get_optimal_device()
96 | print(f"✅ Optimal device: {device}")
97 |
98 | # Test memory check
99 | has_memory, msg = check_memory_requirements(device)
100 | print(f"✅ Memory check: {msg}")
101 |
102 | # Test duration formatting
103 | duration = format_duration(125.5)
104 | print(f"✅ Duration formatting: 125.5s -> {duration}")
105 |
106 | # Test model validation (will fail without models, but that's expected)
107 | is_valid, msg = validate_model_files("./pretrained_models/")
108 | print(f"✅ Model validation: {msg}")
109 |
110 | return True
111 |
112 | except Exception as e:
113 | print(f"❌ Utils test failed: {e}")
114 | return False
115 |
116 | def test_requirements():
117 | """Test if key requirements are available"""
118 | print("\nTesting requirements...")
119 |
120 | required_packages = [
121 | 'torch',
122 | 'torchaudio',
123 | 'numpy',
124 | 'loguru',
125 | 'diffusers',
126 | 'transformers'
127 | ]
128 |
129 | missing = []
130 | for package in required_packages:
131 | try:
132 | __import__(package)
133 | print(f"✅ {package}")
134 | except ImportError:
135 | print(f"❌ {package} - not installed")
136 | missing.append(package)
137 |
138 | if missing:
139 | print(f"\nMissing packages: {', '.join(missing)}")
140 | print("Run: pip install -r requirements.txt")
141 | return False
142 |
143 | return True
144 |
145 | def main():
146 | """Main test function"""
147 | print("🧪 Testing ComfyUI HunyuanVideo-Foley Custom Node")
148 | print("=" * 50)
149 |
150 | tests = [
151 | ("Requirements", test_requirements),
152 | ("Imports", test_imports),
153 | ("Node Structure", test_node_structure),
154 | ("Device Setup", test_device_setup),
155 | ("Utils", test_utils),
156 | ]
157 |
158 | passed = 0
159 | failed = 0
160 |
161 | for test_name, test_func in tests:
162 | print(f"\n🔍 Running test: {test_name}")
163 | try:
164 | if test_func():
165 | passed += 1
166 | print(f"✅ {test_name} PASSED")
167 | else:
168 | failed += 1
169 | print(f"❌ {test_name} FAILED")
170 | except Exception as e:
171 | failed += 1
172 | print(f"❌ {test_name} FAILED with exception: {e}")
173 |
174 | print("\n" + "=" * 50)
175 | print(f"📊 Test Results: {passed} passed, {failed} failed")
176 |
177 | if failed == 0:
178 | print("🎉 All tests passed! The custom node is ready for use.")
179 | else:
180 | print("⚠️ Some tests failed. Please check the issues above.")
181 |
182 | return failed == 0
183 |
184 | if __name__ == "__main__":
185 | success = main()
186 | sys.exit(0 if success else 1)
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/feature_utils.py:
--------------------------------------------------------------------------------
1 | """Feature extraction utilities for video and text processing."""
2 |
3 | import os
4 | import numpy as np
5 | import torch
6 | import av
7 | from PIL import Image
8 | from einops import rearrange
9 | from typing import Any, Dict, List, Union, Tuple
10 | from loguru import logger
11 |
12 | from .config_utils import AttributeDict
13 | from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS
14 |
15 |
16 | class FeatureExtractionError(Exception):
17 | """Exception raised for feature extraction errors."""
18 | pass
19 |
20 | def get_frames_av(
21 | video_path: str,
22 | fps: float,
23 | max_length: float = None,
24 | ) -> Tuple[np.ndarray, float]:
25 | end_sec = max_length if max_length is not None else 15
26 | next_frame_time_for_each_fps = 0.0
27 | time_delta_for_each_fps = 1 / fps
28 |
29 | all_frames = []
30 | output_frames = []
31 |
32 | with av.open(video_path) as container:
33 | stream = container.streams.video[0]
34 | ori_fps = stream.guessed_rate
35 | stream.thread_type = "AUTO"
36 | for packet in container.demux(stream):
37 | for frame in packet.decode():
38 | frame_time = frame.time
39 | if frame_time < 0:
40 | continue
41 | if frame_time > end_sec:
42 | break
43 |
44 | frame_np = None
45 |
46 | this_time = frame_time
47 | while this_time >= next_frame_time_for_each_fps:
48 | if frame_np is None:
49 | frame_np = frame.to_ndarray(format="rgb24")
50 |
51 | output_frames.append(frame_np)
52 | next_frame_time_for_each_fps += time_delta_for_each_fps
53 |
54 | output_frames = np.stack(output_frames)
55 |
56 | vid_len_in_s = len(output_frames) / fps
57 | if max_length is not None and len(output_frames) > int(max_length * fps):
58 | output_frames = output_frames[: int(max_length * fps)]
59 | vid_len_in_s = max_length
60 |
61 | return output_frames, vid_len_in_s
62 |
63 | @torch.inference_mode()
64 | def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1):
65 | b, t, c, h, w = x.shape
66 | if batch_size < 0:
67 | batch_size = b * t
68 | x = rearrange(x, "b t c h w -> (b t) c h w")
69 | outputs = []
70 | for i in range(0, b * t, batch_size):
71 | outputs.append(model_dict.siglip2_model.get_image_features(pixel_values=x[i : i + batch_size]))
72 | res = torch.cat(outputs, dim=0)
73 | res = rearrange(res, "(b t) d -> b t d", b=b)
74 | return res
75 |
76 | @torch.inference_mode()
77 | def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1):
78 | """
79 |     The input video x should ideally have an fps of 24 or greater.
80 | Input:
81 | x: tensor in shape of [B, T, C, H, W]
82 | batch_size: the batch_size for synchformer inference
83 | """
84 | b, t, c, h, w = x.shape
85 | assert c == 3 and h == 224 and w == 224
86 |
87 | segment_size = 16
88 | step_size = 8
89 | num_segments = (t - segment_size) // step_size + 1
90 | segments = []
91 | for i in range(num_segments):
92 | segments.append(x[:, i * step_size : i * step_size + segment_size])
93 | x = torch.stack(segments, dim=1).cuda() # (B, num_segments, segment_size, 3, 224, 224)
94 |
95 | outputs = []
96 | if batch_size < 0:
97 | batch_size = b * num_segments
98 | x = rearrange(x, "b s t c h w -> (b s) 1 t c h w")
99 | for i in range(0, b * num_segments, batch_size):
100 | with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half):
101 | outputs.append(model_dict.syncformer_model(x[i : i + batch_size]))
102 | x = torch.cat(outputs, dim=0) # [b * num_segments, 1, 8, 768]
103 | x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b)
104 | return x
105 |
106 |
107 | @torch.inference_mode()
108 | def encode_video_features(video_path, model_dict):
109 | visual_features = {}
110 | # siglip2 visual features
111 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["siglip2"])
112 | images = [Image.fromarray(frame).convert('RGB') for frame in frames]
113 | images = [model_dict.siglip2_preprocess(image) for image in images] # [T, C, H, W]
114 | clip_frames = torch.stack(images).to(model_dict.device).unsqueeze(0)
115 | visual_features['siglip2_feat'] = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device)
116 |
117 | # synchformer visual features
118 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["synchformer"])
119 | images = torch.from_numpy(frames).permute(0, 3, 1, 2) # [T, C, H, W]
120 | sync_frames = model_dict.syncformer_preprocess(images).unsqueeze(0) # [1, T, 3, 224, 224]
121 | # [1, num_segments * 8, channel_dim], e.g. [1, 240, 768] for 10s video
122 | visual_features['syncformer_feat'] = encode_video_with_sync(sync_frames, model_dict)
123 |
124 | vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"]
125 | visual_features = AttributeDict(visual_features)
126 |
127 | return visual_features, vid_len_in_s
128 |
129 | @torch.inference_mode()
130 | def encode_text_feat(text: List[str], model_dict):
131 | # x: (B, L)
132 | inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device)
133 | outputs = model_dict.clap_model(**inputs, output_hidden_states=True, return_dict=True)
134 | return outputs.last_hidden_state, outputs.attentions
135 |
136 |
137 | def feature_process(video_path, prompt, model_dict, cfg):
138 | visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict)
139 | neg_prompt = "noisy, harsh"
140 | prompts = [neg_prompt, prompt]
141 | text_feat_res, text_feat_mask = encode_text_feat(prompts, model_dict)
142 |
143 | text_feat = text_feat_res[1:]
144 | uncond_text_feat = text_feat_res[:1]
145 |
146 | if cfg.model_config.model_kwargs.text_length < text_feat.shape[1]:
147 | text_seq_length = cfg.model_config.model_kwargs.text_length
148 | text_feat = text_feat[:, :text_seq_length]
149 | uncond_text_feat = uncond_text_feat[:, :text_seq_length]
150 |
151 | text_feats = AttributeDict({
152 | 'text_feat': text_feat,
153 | 'uncond_text_feat': uncond_text_feat,
154 | })
155 |
156 | return visual_feats, text_feats, audio_len_in_s
157 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/nn/posemb_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, Tuple
3 |
4 |
5 | def _to_tuple(x, dim=2):
6 | if isinstance(x, int):
7 | return (x,) * dim
8 | elif len(x) == dim:
9 | return x
10 | else:
11 | raise ValueError(f"Expected length {dim} or int, but got {x}")
12 |
13 |
14 | def get_meshgrid_nd(start, *args, dim=2):
15 | """
16 | Get n-D meshgrid with start, stop and num.
17 |
18 | Args:
19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22 | n-tuples.
23 | *args: See above.
24 | dim (int): Dimension of the meshgrid. Defaults to 2.
25 |
26 | Returns:
27 |         grid (torch.Tensor): [dim, ...]
28 | """
29 | if len(args) == 0:
30 | # start is grid_size
31 | num = _to_tuple(start, dim=dim)
32 | start = (0,) * dim
33 | stop = num
34 | elif len(args) == 1:
35 | # start is start, args[0] is stop, step is 1
36 | start = _to_tuple(start, dim=dim)
37 | stop = _to_tuple(args[0], dim=dim)
38 | num = [stop[i] - start[i] for i in range(dim)]
39 | elif len(args) == 2:
40 | # start is start, args[0] is stop, args[1] is num
41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44 | else:
45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46 |
47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48 | axis_grid = []
49 | for i in range(dim):
50 | a, b, n = start[i], stop[i], num[i]
51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52 | axis_grid.append(g)
53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55 |
56 | return grid
57 |
58 |
59 | #################################################################################
60 | # Rotary Positional Embedding Functions #
61 | #################################################################################
62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63 |
64 |
65 | def get_nd_rotary_pos_embed(
66 | rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0
67 | ):
68 | """
69 | This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
70 |
71 | Args:
72 | rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
73 | sum(rope_dim_list) should equal to head_dim of attention layer.
74 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
75 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
76 | *args: See above.
77 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
78 | use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
79 |             Some libraries, such as TensorRT, do not support the complex64 data type, so it is useful to provide
80 |             the real part and the imaginary part separately.
81 | theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
82 |         freq_scaling (float, optional): Frequency rescale factor, which is proposed in mmaudio. Defaults to 1.0.
83 |
84 | Returns:
85 | pos_embed (torch.Tensor): [HW, D/2]
86 | """
87 |
88 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
89 |
90 | # use 1/ndim of dimensions to encode grid_axis
91 | embs = []
92 | for i in range(len(rope_dim_list)):
93 | emb = get_1d_rotary_pos_embed(
94 | rope_dim_list[i],
95 | grid[i].reshape(-1),
96 | theta,
97 | use_real=use_real,
98 | theta_rescale_factor=theta_rescale_factor,
99 | freq_scaling=freq_scaling,
100 | ) # 2 x [WHD, rope_dim_list[i]]
101 | embs.append(emb)
102 |
103 | if use_real:
104 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
105 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
106 | return cos, sin
107 | else:
108 | emb = torch.cat(embs, dim=1) # (WHD, D/2)
109 | return emb
110 |
111 |
112 | def get_1d_rotary_pos_embed(
113 | dim: int,
114 | pos: Union[torch.FloatTensor, int],
115 | theta: float = 10000.0,
116 | use_real: bool = False,
117 | theta_rescale_factor: float = 1.0,
118 | freq_scaling: float = 1.0,
119 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
120 | """
121 | Precompute the frequency tensor for complex exponential (cis) with given dimensions.
122 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
123 |
124 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
125 |     and the given position indices 'pos'. The 'theta' parameter scales the frequencies.
126 | The returned tensor contains complex values in complex64 data type.
127 |
128 | Args:
129 | dim (int): Dimension of the frequency tensor.
130 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
131 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
132 | use_real (bool, optional): If True, return real part and imaginary part separately.
133 | Otherwise, return complex numbers.
134 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
135 |         freq_scaling (float, optional): Frequency rescale factor, which is proposed in mmaudio. Defaults to 1.0.
136 |
137 | Returns:
138 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
139 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
140 | """
141 | if isinstance(pos, int):
142 | pos = torch.arange(pos).float()
143 |
144 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
145 | # has some connection to NTK literature
146 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
147 | if theta_rescale_factor != 1.0:
148 | theta *= theta_rescale_factor ** (dim / (dim - 1))
149 |
150 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
151 | freqs *= freq_scaling
152 | freqs = torch.outer(pos, freqs) # [S, D/2]
153 | if use_real:
154 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
155 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
156 | return freqs_cos, freqs_sin
157 | else:
158 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
159 | return freqs_cis
160 |
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from audiotools import AudioSignal
5 | from audiotools import ml
6 | from audiotools import STFTParams
7 | from einops import rearrange
8 | from torch.nn.utils import weight_norm
9 |
10 |
11 | def WNConv1d(*args, **kwargs):
12 | act = kwargs.pop("act", True)
13 | conv = weight_norm(nn.Conv1d(*args, **kwargs))
14 | if not act:
15 | return conv
16 | return nn.Sequential(conv, nn.LeakyReLU(0.1))
17 |
18 |
19 | def WNConv2d(*args, **kwargs):
20 | act = kwargs.pop("act", True)
21 | conv = weight_norm(nn.Conv2d(*args, **kwargs))
22 | if not act:
23 | return conv
24 | return nn.Sequential(conv, nn.LeakyReLU(0.1))
25 |
26 |
27 | class MPD(nn.Module):
28 | def __init__(self, period):
29 | super().__init__()
30 | self.period = period
31 | self.convs = nn.ModuleList(
32 | [
33 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38 | ]
39 | )
40 | self.conv_post = WNConv2d(
41 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42 | )
43 |
44 | def pad_to_period(self, x):
45 | t = x.shape[-1]
46 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47 | return x
48 |
49 | def forward(self, x):
50 | fmap = []
51 |
52 | x = self.pad_to_period(x)
53 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54 |
55 | for layer in self.convs:
56 | x = layer(x)
57 | fmap.append(x)
58 |
59 | x = self.conv_post(x)
60 | fmap.append(x)
61 |
62 | return fmap
63 |
64 |
65 | class MSD(nn.Module):
66 | def __init__(self, rate: int = 1, sample_rate: int = 44100):
67 | super().__init__()
68 | self.convs = nn.ModuleList(
69 | [
70 | WNConv1d(1, 16, 15, 1, padding=7),
71 | WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72 | WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75 | WNConv1d(1024, 1024, 5, 1, padding=2),
76 | ]
77 | )
78 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79 | self.sample_rate = sample_rate
80 | self.rate = rate
81 |
82 | def forward(self, x):
83 | x = AudioSignal(x, self.sample_rate)
84 | x.resample(self.sample_rate // self.rate)
85 | x = x.audio_data
86 |
87 | fmap = []
88 |
89 | for l in self.convs:
90 | x = l(x)
91 | fmap.append(x)
92 | x = self.conv_post(x)
93 | fmap.append(x)
94 |
95 | return fmap
96 |
97 |
98 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99 |
100 |
101 | class MRD(nn.Module):
102 | def __init__(
103 | self,
104 | window_length: int,
105 | hop_factor: float = 0.25,
106 | sample_rate: int = 44100,
107 | bands: list = BANDS,
108 | ):
109 | """Complex multi-band spectrogram discriminator.
110 | Parameters
111 | ----------
112 | window_length : int
113 | Window length of STFT.
114 | hop_factor : float, optional
115 |             Hop factor of the STFT; the hop length is ``hop_factor * window_length``. Defaults to 0.25.
116 | sample_rate : int, optional
117 | Sampling rate of audio in Hz, by default 44100
118 | bands : list, optional
119 | Bands to run discriminator over.
120 | """
121 | super().__init__()
122 |
123 | self.window_length = window_length
124 | self.hop_factor = hop_factor
125 | self.sample_rate = sample_rate
126 | self.stft_params = STFTParams(
127 | window_length=window_length,
128 | hop_length=int(window_length * hop_factor),
129 | match_stride=True,
130 | )
131 |
132 | n_fft = window_length // 2 + 1
133 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134 | self.bands = bands
135 |
136 | ch = 32
137 | convs = lambda: nn.ModuleList(
138 | [
139 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144 | ]
145 | )
146 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148 |
149 | def spectrogram(self, x):
150 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151 | x = torch.view_as_real(x.stft())
152 | x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153 | # Split into bands
154 | x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155 | return x_bands
156 |
157 | def forward(self, x):
158 | x_bands = self.spectrogram(x)
159 | fmap = []
160 |
161 | x = []
162 | for band, stack in zip(x_bands, self.band_convs):
163 | for layer in stack:
164 | band = layer(band)
165 | fmap.append(band)
166 | x.append(band)
167 |
168 | x = torch.cat(x, dim=-1)
169 | x = self.conv_post(x)
170 | fmap.append(x)
171 |
172 | return fmap
173 |
174 |
175 | class Discriminator(ml.BaseModel):
176 | def __init__(
177 | self,
178 | rates: list = [],
179 | periods: list = [2, 3, 5, 7, 11],
180 | fft_sizes: list = [2048, 1024, 512],
181 | sample_rate: int = 44100,
182 | bands: list = BANDS,
183 | ):
184 | """Discriminator that combines multiple discriminators.
185 |
186 | Parameters
187 | ----------
188 | rates : list, optional
189 | sampling rates (in Hz) to run MSD at, by default []
190 | If empty, MSD is not used.
191 | periods : list, optional
192 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193 | fft_sizes : list, optional
194 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195 | sample_rate : int, optional
196 | Sampling rate of audio in Hz, by default 44100
197 | bands : list, optional
198 | Bands to run MRD at, by default `BANDS`
199 | """
200 | super().__init__()
201 | discs = []
202 | discs += [MPD(p) for p in periods]
203 | discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205 | self.discriminators = nn.ModuleList(discs)
206 |
207 | def preprocess(self, y):
208 | # Remove DC offset
209 | y = y - y.mean(dim=-1, keepdims=True)
210 | # Peak normalize the volume of input audio
211 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212 | return y
213 |
214 | def forward(self, x):
215 | x = self.preprocess(x)
216 | fmaps = [d(x) for d in self.discriminators]
217 | return fmaps
218 |
219 |
220 | if __name__ == "__main__":
221 | disc = Discriminator()
222 | x = torch.zeros(1, 1, 44100)
223 | results = disc(x)
224 | for i, result in enumerate(results):
225 | print(f"disc{i}")
226 | for i, r in enumerate(result):
227 | print(r.shape, r.mean(), r.min(), r.max())
228 | print()
229 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for ComfyUI HunyuanVideo-Foley custom node
3 | """
4 |
5 | import os
6 | import tempfile
7 | import torch
8 | import numpy as np
9 | from typing import Union, Optional, Tuple
10 | from loguru import logger
11 |
12 | def tensor_to_video(video_tensor: torch.Tensor, output_path: str, fps: int = 30) -> str:
13 | """
14 | Convert a video tensor to a video file
15 |
16 | Args:
17 | video_tensor: Video tensor with shape (frames, channels, height, width)
18 | output_path: Output video file path
19 | fps: Frame rate for the output video
20 |
21 | Returns:
22 | Path to the saved video file
23 | """
24 | try:
25 | import cv2
26 |
27 | # Convert tensor to numpy and handle different formats
28 | if isinstance(video_tensor, torch.Tensor):
29 | video_np = video_tensor.detach().cpu().numpy()
30 | else:
31 | video_np = np.array(video_tensor)
32 |
33 | # Handle different tensor formats
34 | if video_np.ndim == 4: # (frames, channels, height, width)
35 | if video_np.shape[1] == 3: # RGB
36 | video_np = np.transpose(video_np, (0, 2, 3, 1)) # (frames, height, width, channels)
37 | elif video_np.shape[1] == 1: # Grayscale
38 | video_np = np.transpose(video_np, (0, 2, 3, 1))
39 | video_np = np.repeat(video_np, 3, axis=3) # Convert to RGB
40 | elif video_np.ndim == 5: # (batch, frames, channels, height, width)
41 | video_np = video_np[0] # Take first batch
42 | if video_np.shape[1] == 3:
43 | video_np = np.transpose(video_np, (0, 2, 3, 1))
44 |
45 | # Normalize values to 0-255 range
46 | if video_np.max() <= 1.0:
47 | video_np = (video_np * 255).astype(np.uint8)
48 | else:
49 | video_np = video_np.astype(np.uint8)
50 |
51 | # Get video dimensions
52 | frames, height, width, channels = video_np.shape
53 |
54 | # Create video writer
55 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
56 | out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
57 |
58 | # Write frames
59 | for i in range(frames):
60 | frame = video_np[i]
61 | if channels == 3:
62 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) # Convert RGB to BGR for OpenCV
63 | out.write(frame)
64 |
65 | out.release()
66 | logger.info(f"Video saved to: {output_path}")
67 | return output_path
68 |
69 | except Exception as e:
70 | logger.error(f"Failed to convert tensor to video: {e}")
71 | raise
72 |
73 |
74 | def get_video_info(video_path: str) -> dict:
75 | """
76 | Get information about a video file
77 |
78 | Args:
79 | video_path: Path to video file
80 |
81 | Returns:
82 | Dictionary with video information
83 | """
84 | try:
85 | import cv2
86 |
87 | cap = cv2.VideoCapture(video_path)
88 |
89 | info = {
90 | 'fps': cap.get(cv2.CAP_PROP_FPS),
91 | 'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
92 | 'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
93 | 'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
94 | 'duration': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) / cap.get(cv2.CAP_PROP_FPS)
95 | }
96 |
97 | cap.release()
98 | return info
99 |
100 | except Exception as e:
101 | logger.error(f"Failed to get video info: {e}")
102 | return {}
103 |
104 |
105 | def ensure_video_file(video_input: Union[str, torch.Tensor, np.ndarray]) -> str:
106 | """
107 | Ensure the video input is converted to a file path
108 |
109 | Args:
110 | video_input: Video input (path, tensor, or array)
111 |
112 | Returns:
113 | Path to video file
114 | """
115 | if isinstance(video_input, str):
116 | # Already a file path
117 | if os.path.exists(video_input):
118 | return video_input
119 | else:
120 | raise FileNotFoundError(f"Video file not found: {video_input}")
121 |
122 | elif isinstance(video_input, (torch.Tensor, np.ndarray)):
123 | # Convert tensor/array to video file
124 | temp_dir = tempfile.mkdtemp()
125 | output_path = os.path.join(temp_dir, "input_video.mp4")
126 | return tensor_to_video(video_input, output_path)
127 |
128 | else:
129 | raise ValueError(f"Unsupported video input type: {type(video_input)}")
130 |
131 |
132 | def validate_model_files(model_path: str) -> Tuple[bool, str]:
133 | """
134 | Validate that all required model files exist
135 |
136 | Args:
137 | model_path: Path to model directory
138 |
139 | Returns:
140 | Tuple of (is_valid, error_message)
141 | """
142 | required_files = [
143 | "hunyuanvideo_foley.pth",
144 | "vae_128d_48k.pth",
145 | "synchformer_state_dict.pth"
146 | ]
147 |
148 | missing_files = []
149 |
150 | for file_name in required_files:
151 | file_path = os.path.join(model_path, file_name)
152 | if not os.path.exists(file_path):
153 | missing_files.append(file_name)
154 |
155 | if missing_files:
156 | return False, f"Missing model files: {', '.join(missing_files)}"
157 |
158 | return True, "All required model files found"
159 |
160 |
161 | def get_optimal_device() -> torch.device:
162 | """
163 | Get the optimal device for model execution
164 |
165 | Returns:
166 | PyTorch device
167 | """
168 | if torch.cuda.is_available():
169 |         # Pick the CUDA device with the largest total memory (used as a proxy for free memory)
170 | max_memory = 0
171 | best_device = 0
172 |
173 | for i in range(torch.cuda.device_count()):
174 |             total_memory = torch.cuda.get_device_properties(i).total_memory
175 |             if total_memory > max_memory:
176 |                 max_memory = total_memory
177 | best_device = i
178 |
179 | device = torch.device(f"cuda:{best_device}")
180 | logger.info(f"Using CUDA device: {device} with {max_memory / 1e9:.1f}GB memory")
181 | return device
182 |
183 | elif torch.backends.mps.is_available():
184 | device = torch.device("mps")
185 | logger.info("Using MPS device (Apple Silicon)")
186 | return device
187 |
188 | else:
189 | device = torch.device("cpu")
190 | logger.info("Using CPU device")
191 | return device
192 |
193 |
194 | def check_memory_requirements(device: torch.device, required_gb: float = 16.0) -> Tuple[bool, str]:
195 | """
196 | Check if the device has enough memory for model execution
197 |
198 | Args:
199 | device: PyTorch device
200 | required_gb: Required memory in GB
201 |
202 | Returns:
203 | Tuple of (has_enough_memory, message)
204 | """
205 | if device.type == "cuda":
206 | properties = torch.cuda.get_device_properties(device)
207 | total_memory = properties.total_memory / 1e9 # Convert to GB
208 |
209 | if total_memory < required_gb:
210 | return False, f"GPU has {total_memory:.1f}GB memory, but {required_gb}GB is recommended"
211 | else:
212 | return True, f"GPU has {total_memory:.1f}GB memory (sufficient)"
213 |
214 | elif device.type == "mps":
215 | # MPS doesn't have a direct way to check memory, assume it's sufficient
216 | return True, "Using MPS device (memory check not available)"
217 |
218 | else:
219 | # CPU - assume it has enough memory
220 | return True, "Using CPU (no memory limit)"
221 |
222 |
223 | def format_duration(seconds: float) -> str:
224 | """
225 | Format duration in seconds to human readable format
226 |
227 | Args:
228 | seconds: Duration in seconds
229 |
230 | Returns:
231 | Formatted duration string
232 | """
233 | if seconds < 60:
234 | return f"{seconds:.1f}s"
235 | elif seconds < 3600:
236 | minutes = int(seconds // 60)
237 | remaining_seconds = seconds % 60
238 | return f"{minutes}m {remaining_seconds:.1f}s"
239 | else:
240 | hours = int(seconds // 3600)
241 | remaining_minutes = int((seconds % 3600) // 60)
242 | return f"{hours}h {remaining_minutes}m"
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/compute_desync_score.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import subprocess
3 | from pathlib import Path
4 |
5 | import torch
6 | import torchaudio
7 | import torchvision
8 | from omegaconf import OmegaConf
9 |
10 | from . import data_transforms
11 | from .synchformer import Synchformer
12 | from .data_transforms import make_class_grid, quantize_offset
13 | from .utils import check_if_file_exists_else_download, which_ffmpeg
14 |
15 |
16 | def prepare_inputs(batch, device):
17 | aud = batch["audio"].to(device)
18 | vid = batch["video"].to(device)
19 |
20 | return aud, vid
21 |
22 |
23 | def get_test_transforms():
24 | ts = [
25 | data_transforms.EqualifyFromRight(),
26 | data_transforms.RGBSpatialCrop(input_size=224, is_random=False),
27 | data_transforms.TemporalCropAndOffset(
28 | crop_len_sec=5,
29 | max_off_sec=2, # https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
30 | max_wiggle_sec=0.0,
31 | do_offset=True,
32 | offset_type="grid",
33 | prob_oos="null",
34 | grid_size=21,
35 | segment_size_vframes=16,
36 | n_segments=14,
37 | step_size_seg=0.5,
38 | vfps=25,
39 | ),
40 | data_transforms.GenerateMultipleSegments(
41 | segment_size_vframes=16,
42 | n_segments=14,
43 | is_start_random=False,
44 | step_size_seg=0.5,
45 | ),
46 | data_transforms.RGBToHalfToZeroOne(),
47 | data_transforms.RGBNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # motionformer normalization
48 | data_transforms.AudioMelSpectrogram(
49 | sample_rate=16000,
50 | win_length=400, # 25 ms * 16 kHz
51 | hop_length=160, # 10 ms * 16 kHz
52 | n_fft=1024, # 2^(ceil(log2(window_size * sampling_rate)))
53 | n_mels=128, # as in AST
54 | ),
55 | data_transforms.AudioLog(),
56 | data_transforms.PadOrTruncate(max_spec_t=66),
57 | data_transforms.AudioNormalizeAST(mean=-4.2677393, std=4.5689974), # AST, pre-trained on AudioSet
58 | data_transforms.PermuteStreams(
59 | einops_order_audio="S F T -> S 1 F T", einops_order_rgb="S T C H W -> S T C H W" # same
60 | ),
61 | ]
62 | transforms = torchvision.transforms.Compose(ts)
63 |
64 | return transforms
65 |
66 |
67 | def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None):
68 | orig_path = path
69 | # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta)
70 | rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, "sec", output_format="TCHW")
71 | assert meta["video_fps"], f"No video fps for {orig_path}"
72 | # (Ta) <- (Ca, Ta)
73 | audio = audio.mean(dim=0)
74 | # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader.
75 | meta = {
76 | "video": {"fps": [meta["video_fps"]]},
77 | "audio": {"framerate": [meta["audio_fps"]]},
78 | }
79 | return rgb, audio, meta
80 |
81 |
82 | def reencode_video(path, vfps=25, afps=16000, in_size=256):
83 | assert which_ffmpeg() != "", "Is ffmpeg installed? Check if the conda environment is activated."
84 | new_path = Path.cwd() / "vis" / f"{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4"
85 | new_path.parent.mkdir(exist_ok=True)
86 | new_path = str(new_path)
87 | cmd = f"{which_ffmpeg()}"
88 | # no info/error printing
89 | cmd += " -hide_banner -loglevel panic"
90 | cmd += f" -y -i {path}"
91 | # 1) change fps, 2) resize: min(H,W)=MIN_SIDE (vertical vids are supported), 3) change audio framerate
92 | cmd += f" -vf fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2"
93 | cmd += f" -ar {afps}"
94 | cmd += f" {new_path}"
95 | subprocess.call(cmd.split())
96 | cmd = f"{which_ffmpeg()}"
97 | cmd += " -hide_banner -loglevel panic"
98 | cmd += f" -y -i {new_path}"
99 | cmd += f" -acodec pcm_s16le -ac 1"
100 | cmd += f' {new_path.replace(".mp4", ".wav")}'
101 | subprocess.call(cmd.split())
102 | return new_path
103 |
104 |
105 | def decode_single_video_prediction(off_logits, grid, item):
106 | label = item["targets"]["offset_label"].item()
107 | print("Ground Truth offset (sec):", f"{label:.2f} ({quantize_offset(grid, label)[-1].item()})")
108 | print()
109 | print("Prediction Results:")
110 | off_probs = torch.softmax(off_logits, dim=-1)
111 | k = min(off_probs.shape[-1], 5)
112 | topk_logits, topk_preds = torch.topk(off_logits, k)
113 | # remove batch dimension
114 | assert len(topk_logits) == 1, "batch is larger than 1"
115 | topk_logits = topk_logits[0]
116 | topk_preds = topk_preds[0]
117 | off_logits = off_logits[0]
118 | off_probs = off_probs[0]
119 | for target_hat in topk_preds:
120 | print(f'p={off_probs[target_hat]:.4f} ({off_logits[target_hat]:.4f}), "{grid[target_hat]:.2f}" ({target_hat})')
121 | return off_probs
122 |
123 |
124 | def main(args):
125 | vfps = 25
126 | afps = 16000
127 | in_size = 256
128 | # making the offset class grid similar to the one used in transforms,
129 | # refer to the used one: https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
130 | max_off_sec = 2
131 | num_cls = 21
132 |
133 | # checking if the provided video has the correct frame rates
134 | print(f"Using video: {args.vid_path}")
135 | v, _, info = torchvision.io.read_video(args.vid_path, pts_unit="sec")
136 | _, H, W, _ = v.shape
137 | if info["video_fps"] != vfps or info["audio_fps"] != afps or min(H, W) != in_size:
138 | print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=" ")
139 | print(f'afps: {info["audio_fps"]} -> {afps};', end=" ")
140 | print(f"{(H, W)} -> min(H, W)={in_size}")
141 | args.vid_path = reencode_video(args.vid_path, vfps, afps, in_size)
142 | else:
143 | print(f'Skipping reencoding. vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}')
144 |
145 | device = torch.device(args.device)
146 |
147 | # load visual and audio streams
148 |     # rgb: (Tv, 3, H, W) in [0, 255], audio: (Ta,) in [-1, 1]
149 | rgb, audio, meta = get_video_and_audio(args.vid_path, get_meta=True)
150 |
151 | # making an item (dict) to apply transformations
152 |     # NOTE: here is how it works:
153 |     # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3`,
154 |     # the transform will crop out a 5sec clip from 1.3 to 6.3 seconds and shift the start of the audio
155 |     # track by `args.offset_sec` seconds. This means that if `offset_sec` > 0, the audio will
156 |     # start `offset_sec` seconds earlier than the rgb track.
157 |     # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`).
158 | item = dict(
159 | video=rgb,
160 | audio=audio,
161 | meta=meta,
162 | path=args.vid_path,
163 | split="test",
164 | targets={
165 | "v_start_i_sec": args.v_start_i_sec,
166 | "offset_sec": args.offset_sec,
167 | },
168 | )
169 |
170 | grid = make_class_grid(-max_off_sec, max_off_sec, num_cls)
171 | if not (min(grid) <= item["targets"]["offset_sec"] <= max(grid)):
172 | print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}')
173 |
174 | # applying the test-time transform
175 | item = get_test_transforms()(item)
176 |
177 | # prepare inputs for inference
178 | batch = torch.utils.data.default_collate([item])
179 | aud, vid = prepare_inputs(batch, device)
180 |
181 | # TODO:
182 |     # sanity check: we will take the input to the `model` and reconstruct a video from it.
183 | # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified)
184 | # reconstruct_video_from_input(aud, vid, batch['meta'], args.vid_path, args.v_start_i_sec, args.offset_sec,
185 | # vfps, afps)
186 |
187 | # forward pass
188 | with torch.set_grad_enabled(False):
189 | with torch.autocast("cuda", enabled=True):
190 | _, logits = synchformer(vid, aud)
191 |
192 | # simply prints the results of the prediction
193 | decode_single_video_prediction(logits, grid, item)
194 |
195 |
196 | if __name__ == "__main__":
197 | parser = argparse.ArgumentParser()
198 | parser.add_argument("--exp_name", required=True, help="In a format: xx-xx-xxTxx-xx-xx")
199 | parser.add_argument("--vid_path", required=True, help="A path to .mp4 video")
200 | parser.add_argument("--offset_sec", type=float, default=0.0)
201 | parser.add_argument("--v_start_i_sec", type=float, default=0.0)
202 | parser.add_argument("--device", default="cuda:0")
203 | args = parser.parse_args()
204 |
205 | synchformer = Synchformer().cuda().eval()
206 | synchformer.load_state_dict(
207 | torch.load(
208 |             os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"),
209 | weights_only=True,
210 | map_location="cpu",
211 | )
212 | )
213 |
214 | main(args)
215 |
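216 | # Example invocation (a sketch only; weights default to weights/synchformer.pth unless the
217 | # SYNCHFORMER_WEIGHTS environment variable points elsewhere, and the video path is a placeholder):
218 | #   SYNCHFORMER_WEIGHTS=weights/synchformer.pth \
219 | #   python -m hunyuanvideo_foley.models.synchformer.compute_desync_score \
220 | #       --exp_name 24-01-04T16-39-21 --vid_path path/to/video.mp4 --offset_sec 0.0 --device cuda:0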
--------------------------------------------------------------------------------
/example_workflows/HunyuanVideo-Foley.json:
--------------------------------------------------------------------------------
1 | {
2 | "id": "b97ba031-d866-4f4e-baf5-3017bd4ef4b5",
3 | "revision": 0,
4 | "last_node_id": 36,
5 | "last_link_id": 82,
6 | "nodes": [
7 | {
8 | "id": 19,
9 | "type": "VHS_VideoCombine",
10 | "pos": [
11 | 768.0491943359375,
12 | -51.64681625366211
13 | ],
14 | "size": [
15 | 214.7587890625,
16 | 477.5454406738281
17 | ],
18 | "flags": {},
19 | "order": 3,
20 | "mode": 0,
21 | "inputs": [
22 | {
23 | "name": "images",
24 | "type": "IMAGE",
25 | "link": 79
26 | },
27 | {
28 | "name": "audio",
29 | "shape": 7,
30 | "type": "AUDIO",
31 | "link": 80
32 | },
33 | {
34 | "name": "meta_batch",
35 | "shape": 7,
36 | "type": "VHS_BatchManager",
37 | "link": null
38 | },
39 | {
40 | "name": "vae",
41 | "shape": 7,
42 | "type": "VAE",
43 | "link": null
44 | },
45 | {
46 | "name": "frame_rate",
47 | "type": "FLOAT",
48 | "widget": {
49 | "name": "frame_rate"
50 | },
51 | "link": 59
52 | }
53 | ],
54 | "outputs": [
55 | {
56 | "name": "Filenames",
57 | "type": "VHS_FILENAMES",
58 | "links": null
59 | }
60 | ],
61 | "properties": {
62 | "Node name for S&R": "VHS_VideoCombine"
63 | },
64 | "widgets_values": {
65 | "frame_rate": 24,
66 | "loop_count": 0,
67 | "filename_prefix": "AnimateDiff",
68 | "format": "video/h264-mp4",
69 | "pix_fmt": "yuv420p",
70 | "crf": 19,
71 | "save_metadata": true,
72 | "trim_to_audio": false,
73 | "pingpong": false,
74 | "save_output": true,
75 | "videopreview": {
76 | "hidden": false,
77 | "paused": false,
78 | "params": {
79 | "filename": "AnimateDiff_00044-audio.mp4",
80 | "subfolder": "",
81 | "type": "output",
82 | "format": "video/h264-mp4",
83 | "frame_rate": 24,
84 | "workflow": "AnimateDiff_00044.png",
85 | "fullpath": "/root/ComfyUI/output/AnimateDiff_00044-audio.mp4"
86 | }
87 | }
88 | }
89 | },
90 | {
91 | "id": 36,
92 | "type": "HunyuanVideoFoley",
93 | "pos": [
94 | 280.91168212890625,
95 | -55.041656494140625
96 | ],
97 | "size": [
98 | 400,
99 | 398
100 | ],
101 | "flags": {},
102 | "order": 2,
103 | "mode": 0,
104 | "inputs": [
105 | {
106 | "name": "images",
107 | "type": "IMAGE",
108 | "link": 81
109 | },
110 | {
111 | "name": "fps",
112 | "shape": 7,
113 | "type": "FLOAT",
114 | "widget": {
115 | "name": "fps"
116 | },
117 | "link": 82
118 | }
119 | ],
120 | "outputs": [
121 | {
122 | "name": "video_frames",
123 | "type": "IMAGE",
124 | "links": [
125 | 79
126 | ]
127 | },
128 | {
129 | "name": "audio",
130 | "type": "AUDIO",
131 | "links": [
132 | 80
133 | ]
134 | },
135 | {
136 | "name": "status_message",
137 | "type": "STRING",
138 | "links": null
139 | }
140 | ],
141 | "properties": {
142 | "Node name for S&R": "HunyuanVideoFoley"
143 | },
144 | "widgets_values": [
145 | "Woman holding a gun and fighting monster",
146 | 4.5,
147 | 50,
148 | 1,
149 | 1183713675,
150 | "randomize",
151 | "Sharp sounds, noise, messy sounds, and weak frame-related sounds",
152 | 24,
153 | "hunyuan_foley",
154 | "foley_",
155 | false
156 | ]
157 | },
158 | {
159 | "id": 16,
160 | "type": "VHS_LoadVideo",
161 | "pos": [
162 | -87.1226577758789,
163 | -49.90601348876953
164 | ],
165 | "size": [
166 | 252.2359619140625,
167 | 479.15765380859375
168 | ],
169 | "flags": {},
170 | "order": 0,
171 | "mode": 0,
172 | "inputs": [
173 | {
174 | "name": "meta_batch",
175 | "shape": 7,
176 | "type": "VHS_BatchManager",
177 | "link": null
178 | },
179 | {
180 | "name": "vae",
181 | "shape": 7,
182 | "type": "VAE",
183 | "link": null
184 | }
185 | ],
186 | "outputs": [
187 | {
188 | "name": "IMAGE",
189 | "type": "IMAGE",
190 | "links": [
191 | 81
192 | ]
193 | },
194 | {
195 | "name": "frame_count",
196 | "type": "INT",
197 | "links": []
198 | },
199 | {
200 | "name": "audio",
201 | "type": "AUDIO",
202 | "links": null
203 | },
204 | {
205 | "name": "video_info",
206 | "type": "VHS_VIDEOINFO",
207 | "links": [
208 | 51
209 | ]
210 | }
211 | ],
212 | "properties": {
213 | "Node name for S&R": "VHS_LoadVideo"
214 | },
215 | "widgets_values": {
216 | "video": "70973466-d29f-4789-a957-99b44aa36f0b.mp4",
217 | "force_rate": 0,
218 | "custom_width": 0,
219 | "custom_height": 0,
220 | "frame_load_cap": 0,
221 | "skip_first_frames": 0,
222 | "select_every_nth": 1,
223 | "format": "AnimateDiff",
224 | "choose video to upload": "image",
225 | "videopreview": {
226 | "hidden": false,
227 | "paused": false,
228 | "params": {
229 | "force_rate": 0,
230 | "custom_width": 0,
231 | "custom_height": 0,
232 | "frame_load_cap": 0,
233 | "skip_first_frames": 0,
234 | "select_every_nth": 1,
235 | "filename": "70973466-d29f-4789-a957-99b44aa36f0b.mp4",
236 | "type": "input",
237 | "format": "video/mp4"
238 | }
239 | }
240 | }
241 | },
242 | {
243 | "id": 26,
244 | "type": "VHS_VideoInfo",
245 | "pos": [
246 | 310.0038146972656,
247 | 443.6332092285156
248 | ],
249 | "size": [
250 | 218.12109375,
251 | 206
252 | ],
253 | "flags": {},
254 | "order": 1,
255 | "mode": 0,
256 | "inputs": [
257 | {
258 | "name": "video_info",
259 | "type": "VHS_VIDEOINFO",
260 | "link": 51
261 | }
262 | ],
263 | "outputs": [
264 | {
265 | "name": "source_fps🟨",
266 | "type": "FLOAT",
267 | "links": [
268 | 59,
269 | 82
270 | ]
271 | },
272 | {
273 | "name": "source_frame_count🟨",
274 | "type": "INT",
275 | "links": null
276 | },
277 | {
278 | "name": "source_duration🟨",
279 | "type": "FLOAT",
280 | "links": null
281 | },
282 | {
283 | "name": "source_width🟨",
284 | "type": "INT",
285 | "links": null
286 | },
287 | {
288 | "name": "source_height🟨",
289 | "type": "INT",
290 | "links": null
291 | },
292 | {
293 | "name": "loaded_fps🟦",
294 | "type": "FLOAT",
295 | "links": null
296 | },
297 | {
298 | "name": "loaded_frame_count🟦",
299 | "type": "INT",
300 | "links": null
301 | },
302 | {
303 | "name": "loaded_duration🟦",
304 | "type": "FLOAT",
305 | "links": null
306 | },
307 | {
308 | "name": "loaded_width🟦",
309 | "type": "INT",
310 | "links": null
311 | },
312 | {
313 | "name": "loaded_height🟦",
314 | "type": "INT",
315 | "links": null
316 | }
317 | ],
318 | "properties": {
319 | "Node name for S&R": "VHS_VideoInfo"
320 | },
321 | "widgets_values": {}
322 | }
323 | ],
324 | "links": [
325 | [
326 | 51,
327 | 16,
328 | 3,
329 | 26,
330 | 0,
331 | "VHS_VIDEOINFO"
332 | ],
333 | [
334 | 59,
335 | 26,
336 | 0,
337 | 19,
338 | 4,
339 | "FLOAT"
340 | ],
341 | [
342 | 79,
343 | 36,
344 | 0,
345 | 19,
346 | 0,
347 | "IMAGE"
348 | ],
349 | [
350 | 80,
351 | 36,
352 | 1,
353 | 19,
354 | 1,
355 | "AUDIO"
356 | ],
357 | [
358 | 81,
359 | 16,
360 | 0,
361 | 36,
362 | 0,
363 | "IMAGE"
364 | ],
365 | [
366 | 82,
367 | 26,
368 | 0,
369 | 36,
370 | 1,
371 | "FLOAT"
372 | ]
373 | ],
374 | "groups": [],
375 | "config": {},
376 | "extra": {
377 | "ds": {
378 | "scale": 0.9507547298682385,
379 | "offset": [
380 | 188.57921583361022,
381 | 197.62529662436822
382 | ]
383 | },
384 | "frontendVersion": "1.23.4",
385 | "VHS_latentpreview": false,
386 | "VHS_latentpreviewrate": 0,
387 | "VHS_MetadataImage": true,
388 | "VHS_KeepIntermediate": true
389 | },
390 | "version": 0.4
391 | }
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/nn/quantize.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 | from torch.nn.utils import weight_norm
9 |
10 | from .layers import WNConv1d
11 |
12 |
13 | class VectorQuantize(nn.Module):
14 | """
15 | Implementation of VQ similar to Karpathy's repo:
16 | https://github.com/karpathy/deep-vector-quantization
17 | Additionally uses following tricks from Improved VQGAN
18 | (https://arxiv.org/pdf/2110.04627.pdf):
19 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20 | for improved codebook usage
21 | 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22 | improves training stability
23 | """
24 |
25 | def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26 | super().__init__()
27 | self.codebook_size = codebook_size
28 | self.codebook_dim = codebook_dim
29 |
30 | self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31 | self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32 | self.codebook = nn.Embedding(codebook_size, codebook_dim)
33 |
34 | def forward(self, z):
35 | """Quantized the input tensor using a fixed codebook and returns
36 | the corresponding codebook vectors
37 |
38 | Parameters
39 | ----------
40 | z : Tensor[B x D x T]
41 |
42 | Returns
43 | -------
44 | Tensor[B x D x T]
45 | Quantized continuous representation of input
46 | Tensor[1]
47 | Commitment loss to train encoder to predict vectors closer to codebook
48 | entries
49 | Tensor[1]
50 | Codebook loss to update the codebook
51 | Tensor[B x T]
52 | Codebook indices (quantized discrete representation of input)
53 | Tensor[B x D x T]
54 | Projected latents (continuous representation of input before quantization)
55 | """
56 |
57 | # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58 | z_e = self.in_proj(z) # z_e : (B x D x T)
59 | z_q, indices = self.decode_latents(z_e)
60 |
61 | commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63 |
64 | z_q = (
65 | z_e + (z_q - z_e).detach()
66 | ) # noop in forward pass, straight-through gradient estimator in backward pass
67 |
68 | z_q = self.out_proj(z_q)
69 |
70 | return z_q, commitment_loss, codebook_loss, indices, z_e
71 |
72 | def embed_code(self, embed_id):
73 | return F.embedding(embed_id, self.codebook.weight)
74 |
75 | def decode_code(self, embed_id):
76 | return self.embed_code(embed_id).transpose(1, 2)
77 |
78 | def decode_latents(self, latents):
79 | encodings = rearrange(latents, "b d t -> (b t) d")
80 | codebook = self.codebook.weight # codebook: (N x D)
81 |
82 | # L2 normalize encodings and codebook (ViT-VQGAN)
83 | encodings = F.normalize(encodings)
84 | codebook = F.normalize(codebook)
85 |
86 | # Compute euclidean distance with codebook
87 | dist = (
88 | encodings.pow(2).sum(1, keepdim=True)
89 | - 2 * encodings @ codebook.t()
90 | + codebook.pow(2).sum(1, keepdim=True).t()
91 | )
92 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
93 | z_q = self.decode_code(indices)
94 | return z_q, indices
95 |
96 |
97 | class ResidualVectorQuantize(nn.Module):
98 | """
99 | Introduced in SoundStream: An end2end neural audio codec
100 | https://arxiv.org/abs/2107.03312
101 | """
102 |
103 | def __init__(
104 | self,
105 | input_dim: int = 512,
106 | n_codebooks: int = 9,
107 | codebook_size: int = 1024,
108 | codebook_dim: Union[int, list] = 8,
109 | quantizer_dropout: float = 0.0,
110 | ):
111 | super().__init__()
112 | if isinstance(codebook_dim, int):
113 | codebook_dim = [codebook_dim for _ in range(n_codebooks)]
114 |
115 | self.n_codebooks = n_codebooks
116 | self.codebook_dim = codebook_dim
117 | self.codebook_size = codebook_size
118 |
119 | self.quantizers = nn.ModuleList(
120 | [
121 | VectorQuantize(input_dim, codebook_size, codebook_dim[i])
122 | for i in range(n_codebooks)
123 | ]
124 | )
125 | self.quantizer_dropout = quantizer_dropout
126 |
127 | def forward(self, z, n_quantizers: int = None):
128 | """Quantized the input tensor using a fixed set of `n` codebooks and returns
129 | the corresponding codebook vectors
130 | Parameters
131 | ----------
132 | z : Tensor[B x D x T]
133 | n_quantizers : int, optional
134 | No. of quantizers to use
135 | (n_quantizers < self.n_codebooks ex: for quantizer dropout)
136 | Note: if `self.quantizer_dropout` is True, this argument is ignored
137 | when in training mode, and a random number of quantizers is used.
138 | Returns
139 | -------
140 | dict
141 | A dictionary with the following keys:
142 |
143 | "z" : Tensor[B x D x T]
144 | Quantized continuous representation of input
145 | "codes" : Tensor[B x N x T]
146 | Codebook indices for each codebook
147 | (quantized discrete representation of input)
148 | "latents" : Tensor[B x N*D x T]
149 | Projected latents (continuous representation of input before quantization)
150 | "vq/commitment_loss" : Tensor[1]
151 | Commitment loss to train encoder to predict vectors closer to codebook
152 | entries
153 | "vq/codebook_loss" : Tensor[1]
154 | Codebook loss to update the codebook
155 | """
156 | z_q = 0
157 | residual = z
158 | commitment_loss = 0
159 | codebook_loss = 0
160 |
161 | codebook_indices = []
162 | latents = []
163 |
164 | if n_quantizers is None:
165 | n_quantizers = self.n_codebooks
166 | if self.training:
167 | n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
168 | dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
169 | n_dropout = int(z.shape[0] * self.quantizer_dropout)
170 | n_quantizers[:n_dropout] = dropout[:n_dropout]
171 | n_quantizers = n_quantizers.to(z.device)
172 |
173 | for i, quantizer in enumerate(self.quantizers):
174 | if self.training is False and i >= n_quantizers:
175 | break
176 |
177 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
178 | residual
179 | )
180 |
181 | # Create mask to apply quantizer dropout
182 | mask = (
183 | torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
184 | )
185 | z_q = z_q + z_q_i * mask[:, None, None]
186 | residual = residual - z_q_i
187 |
188 | # Sum losses
189 | commitment_loss += (commitment_loss_i * mask).mean()
190 | codebook_loss += (codebook_loss_i * mask).mean()
191 |
192 | codebook_indices.append(indices_i)
193 | latents.append(z_e_i)
194 |
195 | codes = torch.stack(codebook_indices, dim=1)
196 | latents = torch.cat(latents, dim=1)
197 |
198 | return z_q, codes, latents, commitment_loss, codebook_loss
199 |
200 | def from_codes(self, codes: torch.Tensor):
201 | """Given the quantized codes, reconstruct the continuous representation
202 | Parameters
203 | ----------
204 | codes : Tensor[B x N x T]
205 | Quantized discrete representation of input
206 | Returns
207 | -------
208 | Tensor[B x D x T]
209 | Quantized continuous representation of input
210 | """
211 | z_q = 0.0
212 | z_p = []
213 | n_codebooks = codes.shape[1]
214 | for i in range(n_codebooks):
215 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
216 | z_p.append(z_p_i)
217 |
218 | z_q_i = self.quantizers[i].out_proj(z_p_i)
219 | z_q = z_q + z_q_i
220 | return z_q, torch.cat(z_p, dim=1), codes
221 |
222 | def from_latents(self, latents: torch.Tensor):
223 | """Given the unquantized latents, reconstruct the
224 | continuous representation after quantization.
225 |
226 | Parameters
227 | ----------
228 | latents : Tensor[B x N x T]
229 | Continuous representation of input after projection
230 |
231 | Returns
232 | -------
233 | Tensor[B x D x T]
234 | Quantized representation of full-projected space
235 | Tensor[B x D x T]
236 | Quantized representation of latent space
237 | """
238 | z_q = 0
239 | z_p = []
240 | codes = []
241 | dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
242 |
243 | n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
244 | 0
245 | ]
246 | for i in range(n_codebooks):
247 | j, k = dims[i], dims[i + 1]
248 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
249 | z_p.append(z_p_i)
250 | codes.append(codes_i)
251 |
252 | z_q_i = self.quantizers[i].out_proj(z_p_i)
253 | z_q = z_q + z_q_i
254 |
255 | return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
256 |
257 |
258 | if __name__ == "__main__":
259 | rvq = ResidualVectorQuantize(quantizer_dropout=True)
260 | x = torch.randn(16, 512, 80)
261 |     z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
262 |     print(latents.shape)
263 |
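264 | # Minimal round-trip sketch (a sketch only, using the default constructor arguments above):
265 | #   rvq = ResidualVectorQuantize()
266 | #   z = torch.randn(1, 512, 80)
267 | #   z_q, codes, latents, _, _ = rvq(z)        # codes: (B, n_codebooks, T)
268 | #   z_q_rec, z_p, _ = rvq.from_codes(codes)   # rebuild continuous latents from the discrete codes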
--------------------------------------------------------------------------------
/hunyuanvideo_foley/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from loguru import logger
4 | from torchvision import transforms
5 | from torchvision.transforms import v2
6 | from diffusers.utils.torch_utils import randn_tensor
7 | from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection
8 | from ..models.dac_vae.model.dac import DAC
9 | from ..models.synchformer import Synchformer
10 | from ..models.hifi_foley import HunyuanVideoFoley
11 | from .config_utils import load_yaml, AttributeDict
12 | from .schedulers import FlowMatchDiscreteScheduler
13 | from tqdm import tqdm
14 |
15 | def load_state_dict(model, model_path):
16 | logger.info(f"Loading model state dict from: {model_path}")
17 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
18 |
19 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
20 |
21 | if missing_keys:
22 | logger.warning(f"Missing keys in state dict ({len(missing_keys)} keys):")
23 | for key in missing_keys:
24 | logger.warning(f" - {key}")
25 | else:
26 | logger.info("No missing keys found")
27 |
28 | if unexpected_keys:
29 | logger.warning(f"Unexpected keys in state dict ({len(unexpected_keys)} keys):")
30 | for key in unexpected_keys:
31 | logger.warning(f" - {key}")
32 | else:
33 | logger.info("No unexpected keys found")
34 |
35 | logger.info("Model state dict loaded successfully")
36 | return model
37 |
38 | def load_model(model_path, config_path, device):
39 | logger.info("Starting model loading process...")
40 | logger.info(f"Configuration file: {config_path}")
41 | logger.info(f"Model weights dir: {model_path}")
42 | logger.info(f"Target device: {device}")
43 |
44 | cfg = load_yaml(config_path)
45 | logger.info("Configuration loaded successfully")
46 |
47 | # HunyuanVideoFoley
48 | logger.info("Loading HunyuanVideoFoley main model...")
49 | foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16)
50 | foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth"))
51 | foley_model.eval()
52 | logger.info("HunyuanVideoFoley model loaded and set to evaluation mode")
53 |
54 | # DAC-VAE
55 | dac_path = os.path.join(model_path, "vae_128d_48k.pth")
56 | logger.info(f"Loading DAC VAE model from: {dac_path}")
57 | dac_model = DAC.load(dac_path)
58 | dac_model = dac_model.to(device)
59 | dac_model.requires_grad_(False)
60 | dac_model.eval()
61 | logger.info("DAC VAE model loaded successfully")
62 |
63 | # Siglip2 visual-encoder
64 | logger.info("Loading SigLIP2 visual encoder...")
65 | siglip2_preprocess = transforms.Compose([
66 | transforms.Resize((512, 512)),
67 | transforms.ToTensor(),
68 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
69 | ])
70 | siglip2_model = AutoModel.from_pretrained("google/siglip2-base-patch16-512").to(device).eval()
71 | logger.info("SigLIP2 model and preprocessing pipeline loaded successfully")
72 |
73 | # clap text-encoder
74 | logger.info("Loading CLAP text encoder...")
75 | clap_tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general")
76 | clap_model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device)
77 | logger.info("CLAP tokenizer and model loaded successfully")
78 |
79 | # syncformer
80 | syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth")
81 | logger.info(f"Loading Synchformer model from: {syncformer_path}")
82 | syncformer_preprocess = v2.Compose(
83 | [
84 | v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
85 | v2.CenterCrop(224),
86 | v2.ToImage(),
87 | v2.ToDtype(torch.float32, scale=True),
88 | v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
89 | ]
90 | )
91 |
92 | syncformer_model = Synchformer()
93 | syncformer_model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu"))
94 | syncformer_model = syncformer_model.to(device).eval()
95 | logger.info("Synchformer model and preprocessing pipeline loaded successfully")
96 |
97 |
98 | logger.info("Creating model dictionary with attribute access...")
99 | model_dict = AttributeDict({
100 | 'foley_model': foley_model,
101 | 'dac_model': dac_model,
102 | 'siglip2_preprocess': siglip2_preprocess,
103 | 'siglip2_model': siglip2_model,
104 | 'clap_tokenizer': clap_tokenizer,
105 | 'clap_model': clap_model,
106 | 'syncformer_preprocess': syncformer_preprocess,
107 | 'syncformer_model': syncformer_model,
108 | 'device': device,
109 | })
110 |
111 | logger.info("All models loaded successfully!")
112 | logger.info("Available model components:")
113 | for key in model_dict.keys():
114 | logger.info(f" - {key}")
115 | logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)")
116 |
117 | return model_dict, cfg
118 |
119 | def retrieve_timesteps(
120 | scheduler,
121 | num_inference_steps,
122 | device,
123 | **kwargs,
124 | ):
125 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
126 | timesteps = scheduler.timesteps
127 | return timesteps, num_inference_steps
128 |
129 |
130 | def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device):
131 | shape = (batch_size, num_channels_latents, int(length))
132 | latents = randn_tensor(shape, device=device, dtype=dtype)
133 |
134 | # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
135 | if hasattr(scheduler, "init_noise_sigma"):
136 | # scale the initial noise by the standard deviation required by the scheduler
137 | latents = latents * scheduler.init_noise_sigma
138 |
139 | return latents
140 |
141 |
142 | @torch.no_grad()
143 | def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, batch_size=1):
144 |
145 | target_dtype = model_dict.foley_model.dtype
146 | autocast_enabled = target_dtype != torch.float32
147 | device = model_dict.device
148 |
149 | scheduler = FlowMatchDiscreteScheduler(
150 | shift=cfg.diffusion_config.sample_flow_shift,
151 | reverse=cfg.diffusion_config.flow_reverse,
152 | solver=cfg.diffusion_config.flow_solver,
153 | use_flux_shift=cfg.diffusion_config.sample_use_flux_shift,
154 | flux_base_shift=cfg.diffusion_config.flux_base_shift,
155 | flux_max_shift=cfg.diffusion_config.flux_max_shift,
156 | )
157 |
158 | timesteps, num_inference_steps = retrieve_timesteps(
159 | scheduler,
160 | num_inference_steps,
161 | device,
162 | )
163 |
164 | latents = prepare_latents(
165 | scheduler,
166 | batch_size=batch_size,
167 | num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim,
168 | length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate,
169 | dtype=target_dtype,
170 | device=device,
171 | )
172 |
173 | # Denoise loop
174 | for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Denoising steps"):
175 | # noise latents
176 | latent_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
177 | latent_input = scheduler.scale_model_input(latent_input, t)
178 |
179 | t_expand = t.repeat(latent_input.shape[0])
180 |
181 | # siglip2 features
182 | siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
183 | uncond_siglip2_feat = model_dict.foley_model.get_empty_clip_sequence(
184 | bs=batch_size, len=siglip2_feat.shape[1]
185 | ).to(device)
186 |
187 | if guidance_scale is not None and guidance_scale > 1.0:
188 | siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat], dim=0)
189 | else:
190 | siglip2_feat_input = siglip2_feat
191 |
192 | # syncformer features
193 | syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
194 | uncond_syncformer_feat = model_dict.foley_model.get_empty_sync_sequence(
195 | bs=batch_size, len=syncformer_feat.shape[1]
196 | ).to(device)
197 | if guidance_scale is not None and guidance_scale > 1.0:
198 | syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat], dim=0)
199 | else:
200 | syncformer_feat_input = syncformer_feat
201 |
202 | # text features
203 | text_feat_repeated = text_feats.text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
204 | uncond_text_feat_repeated = text_feats.uncond_text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
205 | if guidance_scale is not None and guidance_scale > 1.0:
206 | text_feat_input = torch.cat([uncond_text_feat_repeated, text_feat_repeated], dim=0)
207 | else:
208 | text_feat_input = text_feat_repeated
209 |
210 | with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype):
211 | # Predict the noise residual
212 | noise_pred = model_dict.foley_model(
213 | x=latent_input,
214 | t=t_expand,
215 | cond=text_feat_input,
216 | clip_feat=siglip2_feat_input,
217 | sync_feat=syncformer_feat_input,
218 | return_dict=True,
219 | )["x"]
220 |
221 | noise_pred = noise_pred.to(dtype=torch.float32)
222 |
223 | if guidance_scale is not None and guidance_scale > 1.0:
224 | # Perform classifier-free guidance
225 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
226 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
227 |
228 | # Compute the previous noisy sample x_t -> x_t-1
229 | latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
230 |
231 | # Post-process the latents to audio
232 |
233 | with torch.no_grad():
234 | audio = model_dict.dac_model.decode(latents)
235 | audio = audio.float().cpu()
236 |
237 | audio = audio[:, :int(audio_len_in_s*model_dict.dac_model.sample_rate)]
238 |
239 | return audio, model_dict.dac_model.sample_rate
240 |
241 |
242 |
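243 | # Minimal usage sketch (a sketch only; the paths are assumptions, and `visual_feats`/`text_feats`
244 | # are assumed to come from utils.feature_utils, which is not shown in this file):
245 | #   model_dict, cfg = load_model("weights", "configs/hunyuanvideo-foley-xxl.yaml", torch.device("cuda"))
246 | #   audio, sample_rate = denoise_process(visual_feats, text_feats, audio_len_in_s=8.0,
247 | #                                        model_dict=model_dict, cfg=cfg,
248 | #                                        guidance_scale=4.5, num_inference_steps=50)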
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/model/base.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from pathlib import Path
4 | from typing import Union
5 |
6 | import numpy as np
7 | import torch
8 | import tqdm
9 | from audiotools import AudioSignal
10 | from torch import nn
11 |
12 | SUPPORTED_VERSIONS = ["1.0.0"]
13 |
14 |
15 | @dataclass
16 | class DACFile:
17 | codes: torch.Tensor
18 |
19 | # Metadata
20 | chunk_length: int
21 | original_length: int
22 | input_db: float
23 | channels: int
24 | sample_rate: int
25 | padding: bool
26 | dac_version: str
27 |
28 | def save(self, path):
29 | artifacts = {
30 | "codes": self.codes.numpy().astype(np.uint16),
31 | "metadata": {
32 | "input_db": self.input_db.numpy().astype(np.float32),
33 | "original_length": self.original_length,
34 | "sample_rate": self.sample_rate,
35 | "chunk_length": self.chunk_length,
36 | "channels": self.channels,
37 | "padding": self.padding,
38 | "dac_version": SUPPORTED_VERSIONS[-1],
39 | },
40 | }
41 | path = Path(path).with_suffix(".dac")
42 | with open(path, "wb") as f:
43 | np.save(f, artifacts)
44 | return path
45 |
46 | @classmethod
47 | def load(cls, path):
48 | artifacts = np.load(path, allow_pickle=True)[()]
49 | codes = torch.from_numpy(artifacts["codes"].astype(int))
50 | if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51 | raise RuntimeError(
52 | f"Given file {path} can't be loaded with this version of descript-audio-codec."
53 | )
54 | return cls(codes=codes, **artifacts["metadata"])
55 |
56 |
57 | class CodecMixin:
58 | @property
59 | def padding(self):
60 | if not hasattr(self, "_padding"):
61 | self._padding = True
62 | return self._padding
63 |
64 | @padding.setter
65 | def padding(self, value):
66 | assert isinstance(value, bool)
67 |
68 | layers = [
69 | l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
70 | ]
71 |
72 | for layer in layers:
73 | if value:
74 | if hasattr(layer, "original_padding"):
75 | layer.padding = layer.original_padding
76 | else:
77 | layer.original_padding = layer.padding
78 | layer.padding = tuple(0 for _ in range(len(layer.padding)))
79 |
80 | self._padding = value
81 |
82 | def get_delay(self):
83 | # Any number works here, delay is invariant to input length
84 | l_out = self.get_output_length(0)
85 | L = l_out
86 |
87 | layers = []
88 | for layer in self.modules():
89 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
90 | layers.append(layer)
91 |
92 | for layer in reversed(layers):
93 | d = layer.dilation[0]
94 | k = layer.kernel_size[0]
95 | s = layer.stride[0]
96 |
97 | if isinstance(layer, nn.ConvTranspose1d):
98 | L = ((L - d * (k - 1) - 1) / s) + 1
99 | elif isinstance(layer, nn.Conv1d):
100 | L = (L - 1) * s + d * (k - 1) + 1
101 |
102 | L = math.ceil(L)
103 |
104 | l_in = L
105 |
106 | return (l_in - l_out) // 2
107 |
108 | def get_output_length(self, input_length):
109 | L = input_length
110 | # Calculate output length
111 | for layer in self.modules():
112 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113 | d = layer.dilation[0]
114 | k = layer.kernel_size[0]
115 | s = layer.stride[0]
116 |
117 | if isinstance(layer, nn.Conv1d):
118 | L = ((L - d * (k - 1) - 1) / s) + 1
119 | elif isinstance(layer, nn.ConvTranspose1d):
120 | L = (L - 1) * s + d * (k - 1) + 1
121 |
122 | L = math.floor(L)
123 | return L
124 |
125 | @torch.no_grad()
126 | def compress(
127 | self,
128 | audio_path_or_signal: Union[str, Path, AudioSignal],
129 | win_duration: float = 1.0,
130 | verbose: bool = False,
131 | normalize_db: float = -16,
132 | n_quantizers: int = None,
133 | ) -> DACFile:
134 | """Processes an audio signal from a file or AudioSignal object into
135 | discrete codes. This function processes the signal in short windows,
136 | using constant GPU memory.
137 |
138 | Parameters
139 | ----------
140 |         audio_path_or_signal : Union[str, Path, AudioSignal]
141 |             audio signal to compress
142 |         win_duration : float, optional
143 |             window duration in seconds, by default 1.0
144 | verbose : bool, optional
145 | by default False
146 | normalize_db : float, optional
147 | normalize db, by default -16
148 |
149 | Returns
150 | -------
151 | DACFile
152 | Object containing compressed codes and metadata
153 | required for decompression
154 | """
155 | audio_signal = audio_path_or_signal
156 | if isinstance(audio_signal, (str, Path)):
157 | audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
158 |
159 | self.eval()
160 | original_padding = self.padding
161 | original_device = audio_signal.device
162 |
163 | audio_signal = audio_signal.clone()
164 | audio_signal = audio_signal.to_mono()
165 | original_sr = audio_signal.sample_rate
166 |
167 | resample_fn = audio_signal.resample
168 | loudness_fn = audio_signal.loudness
169 |
170 |         # If audio is >= 10 hours long, use the ffmpeg versions
171 | if audio_signal.signal_duration >= 10 * 60 * 60:
172 | resample_fn = audio_signal.ffmpeg_resample
173 | loudness_fn = audio_signal.ffmpeg_loudness
174 |
175 | original_length = audio_signal.signal_length
176 | resample_fn(self.sample_rate)
177 | input_db = loudness_fn()
178 |
179 | if normalize_db is not None:
180 | audio_signal.normalize(normalize_db)
181 | audio_signal.ensure_max_of_audio()
182 |
183 | nb, nac, nt = audio_signal.audio_data.shape
184 | audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
185 | win_duration = (
186 | audio_signal.signal_duration if win_duration is None else win_duration
187 | )
188 |
189 | if audio_signal.signal_duration <= win_duration:
190 | # Unchunked compression (used if signal length < win duration)
191 | self.padding = True
192 | n_samples = nt
193 | hop = nt
194 | else:
195 | # Chunked inference
196 | self.padding = False
197 | # Zero-pad signal on either side by the delay
198 | audio_signal.zero_pad(self.delay, self.delay)
199 | n_samples = int(win_duration * self.sample_rate)
200 | # Round n_samples to nearest hop length multiple
201 | n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
202 | hop = self.get_output_length(n_samples)
203 |
204 | codes = []
205 | range_fn = range if not verbose else tqdm.trange
206 |
207 | for i in range_fn(0, nt, hop):
208 | x = audio_signal[..., i : i + n_samples]
209 | x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
210 |
211 | audio_data = x.audio_data.to(self.device)
212 | audio_data = self.preprocess(audio_data, self.sample_rate)
213 | _, c, _, _, _ = self.encode(audio_data, n_quantizers)
214 | codes.append(c.to(original_device))
215 | chunk_length = c.shape[-1]
216 |
217 | codes = torch.cat(codes, dim=-1)
218 |
219 | dac_file = DACFile(
220 | codes=codes,
221 | chunk_length=chunk_length,
222 | original_length=original_length,
223 | input_db=input_db,
224 | channels=nac,
225 | sample_rate=original_sr,
226 | padding=self.padding,
227 | dac_version=SUPPORTED_VERSIONS[-1],
228 | )
229 |
230 | if n_quantizers is not None:
231 | codes = codes[:, :n_quantizers, :]
232 |
233 | self.padding = original_padding
234 | return dac_file
235 |
236 | @torch.no_grad()
237 | def decompress(
238 | self,
239 | obj: Union[str, Path, DACFile],
240 | verbose: bool = False,
241 | ) -> AudioSignal:
242 | """Reconstruct audio from a given .dac file
243 |
244 | Parameters
245 | ----------
246 | obj : Union[str, Path, DACFile]
247 | .dac file location or corresponding DACFile object.
248 | verbose : bool, optional
249 | Prints progress if True, by default False
250 |
251 | Returns
252 | -------
253 | AudioSignal
254 | Object with the reconstructed audio
255 | """
256 | self.eval()
257 | if isinstance(obj, (str, Path)):
258 | obj = DACFile.load(obj)
259 |
260 | original_padding = self.padding
261 | self.padding = obj.padding
262 |
263 | range_fn = range if not verbose else tqdm.trange
264 | codes = obj.codes
265 | original_device = codes.device
266 | chunk_length = obj.chunk_length
267 | recons = []
268 |
269 | for i in range_fn(0, codes.shape[-1], chunk_length):
270 | c = codes[..., i : i + chunk_length].to(self.device)
271 | z = self.quantizer.from_codes(c)[0]
272 | r = self.decode(z)
273 | recons.append(r.to(original_device))
274 |
275 | recons = torch.cat(recons, dim=-1)
276 | recons = AudioSignal(recons, self.sample_rate)
277 |
278 | resample_fn = recons.resample
279 | loudness_fn = recons.loudness
280 |
281 |         # If audio is >= 10 hours long, use the ffmpeg versions
282 | if recons.signal_duration >= 10 * 60 * 60:
283 | resample_fn = recons.ffmpeg_resample
284 | loudness_fn = recons.ffmpeg_loudness
285 |
286 | if obj.input_db is not None:
287 | recons.normalize(obj.input_db)
288 |
289 | resample_fn(obj.sample_rate)
290 |
291 | if obj.original_length is not None:
292 | recons = recons[..., : obj.original_length]
293 | loudness_fn()
294 | recons.audio_data = recons.audio_data.reshape(
295 | -1, obj.channels, obj.original_length
296 | )
297 | else:
298 | loudness_fn()
299 |
300 | self.padding = original_padding
301 | return recons
302 |
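303 | # Round-trip sketch (a sketch only; DAC is defined in model/dac.py and the weights path is an assumption):
304 | #   model = DAC.load("weights/vae_128d_48k.pth").to("cuda")
305 | #   dac_file = model.compress("input.wav", win_duration=1.0)
306 | #   dac_file.save("input.dac")
307 | #   signal = model.decompress("input.dac")
308 | #   signal.write("reconstruction.wav")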
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/video_model_builder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 | # Copyright 2020 Ross Wightman
4 | # Modified Model definition
5 |
6 | from collections import OrderedDict
7 | from functools import partial
8 |
9 | import torch
10 | import torch.nn as nn
11 | from timm.layers import trunc_normal_
12 |
13 | from .vit_helper import PatchEmbed, PatchEmbed3D, DividedSpaceTimeBlock
14 |
15 |
16 | class VisionTransformer(nn.Module):
17 | """Vision Transformer with support for patch or hybrid CNN input stage"""
18 |
19 | def __init__(self, cfg):
20 | super().__init__()
21 | self.img_size = cfg.DATA.TRAIN_CROP_SIZE
22 | self.patch_size = cfg.VIT.PATCH_SIZE
23 | self.in_chans = cfg.VIT.CHANNELS
24 | if cfg.TRAIN.DATASET == "Epickitchens":
25 | self.num_classes = [97, 300]
26 | else:
27 | self.num_classes = cfg.MODEL.NUM_CLASSES
28 | self.embed_dim = cfg.VIT.EMBED_DIM
29 | self.depth = cfg.VIT.DEPTH
30 | self.num_heads = cfg.VIT.NUM_HEADS
31 | self.mlp_ratio = cfg.VIT.MLP_RATIO
32 | self.qkv_bias = cfg.VIT.QKV_BIAS
33 | self.drop_rate = cfg.VIT.DROP
34 | self.drop_path_rate = cfg.VIT.DROP_PATH
35 | self.head_dropout = cfg.VIT.HEAD_DROPOUT
36 | self.video_input = cfg.VIT.VIDEO_INPUT
37 | self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
38 | self.use_mlp = cfg.VIT.USE_MLP
39 | self.num_features = self.embed_dim
40 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
41 | self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
42 | self.head_act = cfg.VIT.HEAD_ACT
43 | self.cfg = cfg
44 |
45 | # Patch Embedding
46 | self.patch_embed = PatchEmbed(
47 | img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim
48 | )
49 |
50 | # 3D Patch Embedding
51 | self.patch_embed_3d = PatchEmbed3D(
52 | img_size=self.img_size,
53 | temporal_resolution=self.temporal_resolution,
54 | patch_size=self.patch_size,
55 | in_chans=self.in_chans,
56 | embed_dim=self.embed_dim,
57 | z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP,
58 | )
59 | self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data)
60 |
61 | # Number of patches
62 | if self.video_input:
63 | num_patches = self.patch_embed.num_patches * self.temporal_resolution
64 | else:
65 | num_patches = self.patch_embed.num_patches
66 | self.num_patches = num_patches
67 |
68 | # CLS token
69 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
70 | trunc_normal_(self.cls_token, std=0.02)
71 |
72 | # Positional embedding
73 | self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
74 | self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT)
75 | trunc_normal_(self.pos_embed, std=0.02)
76 |
77 | if self.cfg.VIT.POS_EMBED == "joint":
78 | self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
79 | trunc_normal_(self.st_embed, std=0.02)
80 | elif self.cfg.VIT.POS_EMBED == "separate":
81 | self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim))
82 |
83 | # Layer Blocks
84 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
85 | if self.cfg.VIT.ATTN_LAYER == "divided":
86 | self.blocks = nn.ModuleList(
87 | [
88 | DividedSpaceTimeBlock(
89 | attn_type=cfg.VIT.ATTN_LAYER,
90 | dim=self.embed_dim,
91 | num_heads=self.num_heads,
92 | mlp_ratio=self.mlp_ratio,
93 | qkv_bias=self.qkv_bias,
94 | drop=self.drop_rate,
95 | attn_drop=self.attn_drop_rate,
96 | drop_path=dpr[i],
97 | norm_layer=norm_layer,
98 | )
99 | for i in range(self.depth)
100 | ]
101 | )
102 |
103 | self.norm = norm_layer(self.embed_dim)
104 |
105 | # MLP head
106 | if self.use_mlp:
107 | hidden_dim = self.embed_dim
108 | if self.head_act == "tanh":
109 | # logging.info("Using TanH activation in MLP")
110 | act = nn.Tanh()
111 | elif self.head_act == "gelu":
112 | # logging.info("Using GELU activation in MLP")
113 | act = nn.GELU()
114 | else:
115 | # logging.info("Using ReLU activation in MLP")
116 | act = nn.ReLU()
117 | self.pre_logits = nn.Sequential(
118 | OrderedDict(
119 | [
120 | ("fc", nn.Linear(self.embed_dim, hidden_dim)),
121 | ("act", act),
122 | ]
123 | )
124 | )
125 | else:
126 | self.pre_logits = nn.Identity()
127 |
128 | # Classifier Head
129 | self.head_drop = nn.Dropout(p=self.head_dropout)
130 | if isinstance(self.num_classes, (list,)) and len(self.num_classes) > 1:
131 |             for a in range(len(self.num_classes)):
132 |                 setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[a]))
133 | else:
134 | self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
135 |
136 | # Initialize weights
137 | self.apply(self._init_weights)
138 |
139 | def _init_weights(self, m):
140 | if isinstance(m, nn.Linear):
141 | trunc_normal_(m.weight, std=0.02)
142 | if isinstance(m, nn.Linear) and m.bias is not None:
143 | nn.init.constant_(m.bias, 0)
144 | elif isinstance(m, nn.LayerNorm):
145 | nn.init.constant_(m.bias, 0)
146 | nn.init.constant_(m.weight, 1.0)
147 |
148 | @torch.jit.ignore
149 | def no_weight_decay(self):
150 | if self.cfg.VIT.POS_EMBED == "joint":
151 | return {"pos_embed", "cls_token", "st_embed"}
152 | else:
153 | return {"pos_embed", "cls_token", "temp_embed"}
154 |
155 | def get_classifier(self):
156 | return self.head
157 |
158 | def reset_classifier(self, num_classes, global_pool=""):
159 | self.num_classes = num_classes
160 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
161 |
162 | def forward_features(self, x):
163 | # if self.video_input:
164 | # x = x[0]
165 | B = x.shape[0]
166 |
167 | # Tokenize input
168 | # if self.cfg.VIT.PATCH_SIZE_TEMP > 1:
169 | # for simplicity of mapping between content dimensions (input x) and token dims (after patching)
170 | # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details):
171 |
172 | # apply patching on input
173 | x = self.patch_embed_3d(x)
174 | tok_mask = None
175 |
176 | # else:
177 | # tok_mask = None
178 | # # 2D tokenization
179 | # if self.video_input:
180 | # x = x.permute(0, 2, 1, 3, 4)
181 | # (B, T, C, H, W) = x.shape
182 | # x = x.reshape(B * T, C, H, W)
183 |
184 | # x = self.patch_embed(x)
185 |
186 | # if self.video_input:
187 | # (B2, T2, D2) = x.shape
188 | # x = x.reshape(B, T * T2, D2)
189 |
190 | # Append CLS token
191 | cls_tokens = self.cls_token.expand(B, -1, -1)
192 | x = torch.cat((cls_tokens, x), dim=1)
193 | # if tok_mask is not None:
194 | # # prepend 1(=keep) to the mask to account for the CLS token as well
195 | # tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1)
196 |
197 |         # Interpolate positional embeddings
198 | # if self.cfg.DATA.TRAIN_CROP_SIZE != 224:
199 | # pos_embed = self.pos_embed
200 | # N = pos_embed.shape[1] - 1
201 | # npatch = int((x.size(1) - 1) / self.temporal_resolution)
202 | # class_emb = pos_embed[:, 0]
203 | # pos_embed = pos_embed[:, 1:]
204 | # dim = x.shape[-1]
205 | # pos_embed = torch.nn.functional.interpolate(
206 | # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
207 | # scale_factor=math.sqrt(npatch / N),
208 | # mode='bicubic',
209 | # )
210 | # pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
211 | # new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
212 | # else:
213 | new_pos_embed = self.pos_embed
214 | npatch = self.patch_embed.num_patches
215 |
216 | # Add positional embeddings to input
217 | if self.video_input:
218 | if self.cfg.VIT.POS_EMBED == "separate":
219 | cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
220 | tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
221 | tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
222 | total_pos_embed = tile_pos_embed + tile_temporal_embed
223 | total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
224 | x = x + total_pos_embed
225 | elif self.cfg.VIT.POS_EMBED == "joint":
226 | x = x + self.st_embed
227 | else:
228 | # image input
229 | x = x + new_pos_embed
230 |
231 | # Apply positional dropout
232 | x = self.pos_drop(x)
233 |
234 | # Encoding using transformer layers
235 | for i, blk in enumerate(self.blocks):
236 | x = blk(
237 | x,
238 | seq_len=npatch,
239 | num_frames=self.temporal_resolution,
240 | approx=self.cfg.VIT.APPROX_ATTN_TYPE,
241 | num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM,
242 | tok_mask=tok_mask,
243 | )
244 |
245 | ### v-iashin: I moved it to the forward pass
246 | # x = self.norm(x)[:, 0]
247 | # x = self.pre_logits(x)
248 | ###
249 | return x, tok_mask
250 |
251 | # def forward(self, x):
252 | # x = self.forward_features(x)
253 | # ### v-iashin: here. This should leave the same forward output as before
254 | # x = self.norm(x)[:, 0]
255 | # x = self.pre_logits(x)
256 | # ###
257 | # x = self.head_drop(x)
258 | # if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
259 | # output = []
260 | # for head in range(len(self.num_classes)):
261 | # x_out = getattr(self, "head%d" % head)(x)
262 | # if not self.training:
263 | # x_out = torch.nn.functional.softmax(x_out, dim=-1)
264 | # output.append(x_out)
265 | # return output
266 | # else:
267 | # x = self.head(x)
268 | # if not self.training:
269 | # x = torch.nn.functional.softmax(x, dim=-1)
270 | # return x
271 |
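272 | # Usage sketch (a sketch only; `cfg` is assumed to be the config loaded from divided_224_16x4.yaml):
273 | #   vit = VisionTransformer(cfg)
274 | #   tokens, tok_mask = vit.forward_features(clip)   # tokens keep the prepended CLS token at index 0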
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/nn/loss.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from typing import List
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from audiotools import AudioSignal
7 | from audiotools import STFTParams
8 | from torch import nn
9 |
10 |
11 | class L1Loss(nn.L1Loss):
12 | """L1 Loss between AudioSignals. Defaults
13 | to comparing ``audio_data``, but any
14 | attribute of an AudioSignal can be used.
15 |
16 | Parameters
17 | ----------
18 | attribute : str, optional
19 | Attribute of signal to compare, defaults to ``audio_data``.
20 | weight : float, optional
21 | Weight of this loss, defaults to 1.0.
22 |
23 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
24 | """
25 |
26 | def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
27 | self.attribute = attribute
28 | self.weight = weight
29 | super().__init__(**kwargs)
30 |
31 | def forward(self, x: AudioSignal, y: AudioSignal):
32 | """
33 | Parameters
34 | ----------
35 | x : AudioSignal
36 | Estimate AudioSignal
37 | y : AudioSignal
38 | Reference AudioSignal
39 |
40 | Returns
41 | -------
42 | torch.Tensor
43 | L1 loss between AudioSignal attributes.
44 | """
45 | if isinstance(x, AudioSignal):
46 | x = getattr(x, self.attribute)
47 | y = getattr(y, self.attribute)
48 | return super().forward(x, y)
49 |
50 |
51 | class SISDRLoss(nn.Module):
52 | """
53 | Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
54 | of estimated and reference audio signals or aligned features.
55 |
56 | Parameters
57 | ----------
58 | scaling : int, optional
59 | Whether to use scale-invariant (True) or
60 | signal-to-noise ratio (False), by default True
61 | reduction : str, optional
62 |         How to reduce across the batch (either 'mean',
63 |         'sum', or 'none'), by default 'mean'
64 | zero_mean : int, optional
65 | Zero mean the references and estimates before
66 | computing the loss, by default True
67 | clip_min : int, optional
68 | The minimum possible loss value. Helps network
69 | to not focus on making already good examples better, by default None
70 | weight : float, optional
71 | Weight of this loss, defaults to 1.0.
72 |
73 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
74 | """
75 |
76 | def __init__(
77 | self,
78 | scaling: int = True,
79 | reduction: str = "mean",
80 | zero_mean: int = True,
81 | clip_min: int = None,
82 | weight: float = 1.0,
83 | ):
84 | self.scaling = scaling
85 | self.reduction = reduction
86 | self.zero_mean = zero_mean
87 | self.clip_min = clip_min
88 | self.weight = weight
89 | super().__init__()
90 |
91 | def forward(self, x: AudioSignal, y: AudioSignal):
92 | eps = 1e-8
93 | # nb, nc, nt
94 | if isinstance(x, AudioSignal):
95 | references = x.audio_data
96 | estimates = y.audio_data
97 | else:
98 | references = x
99 | estimates = y
100 |
101 | nb = references.shape[0]
102 | references = references.reshape(nb, 1, -1).permute(0, 2, 1)
103 | estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
104 |
105 | # samples now on axis 1
106 | if self.zero_mean:
107 | mean_reference = references.mean(dim=1, keepdim=True)
108 | mean_estimate = estimates.mean(dim=1, keepdim=True)
109 | else:
110 | mean_reference = 0
111 | mean_estimate = 0
112 |
113 | _references = references - mean_reference
114 | _estimates = estimates - mean_estimate
115 |
116 | references_projection = (_references**2).sum(dim=-2) + eps
117 | references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
118 |
119 | scale = (
120 | (references_on_estimates / references_projection).unsqueeze(1)
121 | if self.scaling
122 | else 1
123 | )
124 |
125 | e_true = scale * _references
126 | e_res = _estimates - e_true
127 |
128 | signal = (e_true**2).sum(dim=1)
129 | noise = (e_res**2).sum(dim=1)
130 | sdr = -10 * torch.log10(signal / noise + eps)
131 |
132 | if self.clip_min is not None:
133 | sdr = torch.clamp(sdr, min=self.clip_min)
134 |
135 | if self.reduction == "mean":
136 | sdr = sdr.mean()
137 | elif self.reduction == "sum":
138 | sdr = sdr.sum()
139 | return sdr
140 |
141 |
142 | class MultiScaleSTFTLoss(nn.Module):
143 | """Computes the multi-scale STFT loss from [1].
144 |
145 | Parameters
146 | ----------
147 | window_lengths : List[int], optional
148 | Length of each window of each STFT, by default [2048, 512]
149 | loss_fn : typing.Callable, optional
150 | How to compare each loss, by default nn.L1Loss()
151 | clamp_eps : float, optional
152 | Clamp on the log magnitude, below, by default 1e-5
153 | mag_weight : float, optional
154 | Weight of raw magnitude portion of loss, by default 1.0
155 | log_weight : float, optional
156 | Weight of log magnitude portion of loss, by default 1.0
157 | pow : float, optional
158 | Power to raise magnitude to before taking log, by default 2.0
159 | weight : float, optional
160 | Weight of this loss, by default 1.0
161 | match_stride : bool, optional
162 | Whether to match the stride of convolutional layers, by default False
163 |
164 | References
165 | ----------
166 |
167 | 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
168 | "DDSP: Differentiable Digital Signal Processing."
169 |        International Conference on Learning Representations. 2020.
170 |
171 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
172 | """
173 |
174 | def __init__(
175 | self,
176 | window_lengths: List[int] = [2048, 512],
177 | loss_fn: typing.Callable = nn.L1Loss(),
178 | clamp_eps: float = 1e-5,
179 | mag_weight: float = 1.0,
180 | log_weight: float = 1.0,
181 | pow: float = 2.0,
182 | weight: float = 1.0,
183 | match_stride: bool = False,
184 | window_type: str = None,
185 | ):
186 | super().__init__()
187 | self.stft_params = [
188 | STFTParams(
189 | window_length=w,
190 | hop_length=w // 4,
191 | match_stride=match_stride,
192 | window_type=window_type,
193 | )
194 | for w in window_lengths
195 | ]
196 | self.loss_fn = loss_fn
197 | self.log_weight = log_weight
198 | self.mag_weight = mag_weight
199 | self.clamp_eps = clamp_eps
200 | self.weight = weight
201 | self.pow = pow
202 |
203 | def forward(self, x: AudioSignal, y: AudioSignal):
204 | """Computes multi-scale STFT between an estimate and a reference
205 | signal.
206 |
207 | Parameters
208 | ----------
209 | x : AudioSignal
210 | Estimate signal
211 | y : AudioSignal
212 | Reference signal
213 |
214 | Returns
215 | -------
216 | torch.Tensor
217 | Multi-scale STFT loss.
218 | """
219 | loss = 0.0
220 | for s in self.stft_params:
221 | x.stft(s.window_length, s.hop_length, s.window_type)
222 | y.stft(s.window_length, s.hop_length, s.window_type)
223 | loss += self.log_weight * self.loss_fn(
224 | x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
225 | y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
226 | )
227 | loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
228 | return loss
229 |
230 |
231 | class MelSpectrogramLoss(nn.Module):
232 | """Compute distance between mel spectrograms. Can be used
233 | in a multi-scale way.
234 |
235 | Parameters
236 | ----------
237 | n_mels : List[int]
238 |         Number of mels per STFT, by default [150, 80]
239 | window_lengths : List[int], optional
240 | Length of each window of each STFT, by default [2048, 512]
241 | loss_fn : typing.Callable, optional
242 | How to compare each loss, by default nn.L1Loss()
243 | clamp_eps : float, optional
244 |         Lower bound applied to the magnitude before taking the log, by default 1e-5
245 | mag_weight : float, optional
246 | Weight of raw magnitude portion of loss, by default 1.0
247 | log_weight : float, optional
248 | Weight of log magnitude portion of loss, by default 1.0
249 | pow : float, optional
250 | Power to raise magnitude to before taking log, by default 2.0
251 | weight : float, optional
252 | Weight of this loss, by default 1.0
253 | match_stride : bool, optional
254 | Whether to match the stride of convolutional layers, by default False
255 |
256 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
257 | """
258 |
259 | def __init__(
260 | self,
261 | n_mels: List[int] = [150, 80],
262 | window_lengths: List[int] = [2048, 512],
263 | loss_fn: typing.Callable = nn.L1Loss(),
264 | clamp_eps: float = 1e-5,
265 | mag_weight: float = 1.0,
266 | log_weight: float = 1.0,
267 | pow: float = 2.0,
268 | weight: float = 1.0,
269 | match_stride: bool = False,
270 | mel_fmin: List[float] = [0.0, 0.0],
271 | mel_fmax: List[float] = [None, None],
272 | window_type: str = None,
273 | ):
274 | super().__init__()
275 | self.stft_params = [
276 | STFTParams(
277 | window_length=w,
278 | hop_length=w // 4,
279 | match_stride=match_stride,
280 | window_type=window_type,
281 | )
282 | for w in window_lengths
283 | ]
284 | self.n_mels = n_mels
285 | self.loss_fn = loss_fn
286 | self.clamp_eps = clamp_eps
287 | self.log_weight = log_weight
288 | self.mag_weight = mag_weight
289 | self.weight = weight
290 | self.mel_fmin = mel_fmin
291 | self.mel_fmax = mel_fmax
292 | self.pow = pow
293 |
294 | def forward(self, x: AudioSignal, y: AudioSignal):
295 | """Computes mel loss between an estimate and a reference
296 | signal.
297 |
298 | Parameters
299 | ----------
300 | x : AudioSignal
301 | Estimate signal
302 | y : AudioSignal
303 | Reference signal
304 |
305 | Returns
306 | -------
307 | torch.Tensor
308 | Mel loss.
309 | """
310 | loss = 0.0
311 | for n_mels, fmin, fmax, s in zip(
312 | self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
313 | ):
314 | kwargs = {
315 | "window_length": s.window_length,
316 | "hop_length": s.hop_length,
317 | "window_type": s.window_type,
318 | }
319 | x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
320 | y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
321 |
322 | loss += self.log_weight * self.loss_fn(
323 | x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
324 | y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
325 | )
326 | loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
327 | return loss
328 |
329 |
330 | class GANLoss(nn.Module):
331 | """
332 | Computes a discriminator loss, given a discriminator on
333 | generated waveforms/spectrograms compared to ground truth
334 | waveforms/spectrograms. Computes the loss for both the
335 | discriminator and the generator in separate functions.
336 | """
337 |
338 | def __init__(self, discriminator):
339 | super().__init__()
340 | self.discriminator = discriminator
341 |
342 | def forward(self, fake, real):
343 | d_fake = self.discriminator(fake.audio_data)
344 | d_real = self.discriminator(real.audio_data)
345 | return d_fake, d_real
346 |
347 | def discriminator_loss(self, fake, real):
348 | d_fake, d_real = self.forward(fake.clone().detach(), real)
349 |
350 | loss_d = 0
351 | for x_fake, x_real in zip(d_fake, d_real):
352 | loss_d += torch.mean(x_fake[-1] ** 2)
353 | loss_d += torch.mean((1 - x_real[-1]) ** 2)
354 | return loss_d
355 |
356 | def generator_loss(self, fake, real):
357 | d_fake, d_real = self.forward(fake, real)
358 |
359 | loss_g = 0
360 | for x_fake in d_fake:
361 | loss_g += torch.mean((1 - x_fake[-1]) ** 2)
362 |
363 | loss_feature = 0
364 |
365 | for i in range(len(d_fake)):
366 | for j in range(len(d_fake[i]) - 1):
367 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
368 | return loss_g, loss_feature
369 |
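370 | 
371 | if __name__ == "__main__":
372 |     # Minimal smoke-test sketch (not part of the original file): it checks that
373 |     # the spectral reconstruction losses defined above return finite scalars
374 |     # on random 1-second, 44.1 kHz signals.
375 |     est = AudioSignal(torch.randn(1, 1, 44100), 44100)
376 |     ref = AudioSignal(torch.randn(1, 1, 44100), 44100)
377 |     for loss_fn in (MultiScaleSTFTLoss(), MelSpectrogramLoss()):
378 |         print(type(loss_fn).__name__, float(loss_fn(est, ref)))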
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/synchformer/synchformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from typing import Any, Mapping
4 |
5 | import einops
6 | import numpy as np
7 | import torch
8 | import torchaudio
9 | from torch import nn
10 | from torch.nn import functional as F
11 |
12 | from .motionformer import MotionFormer
13 | from .ast_model import AST
14 | from .utils import Config
15 |
16 |
17 | class Synchformer(nn.Module):
18 |
19 | def __init__(self):
20 | super().__init__()
21 |
22 | self.vfeat_extractor = MotionFormer(
23 | extract_features=True,
24 | factorize_space_time=True,
25 | agg_space_module="TransformerEncoderLayer",
26 | agg_time_module="torch.nn.Identity",
27 | add_global_repr=False,
28 | )
29 | self.afeat_extractor = AST(
30 | extract_features=True,
31 | max_spec_t=66,
32 | factorize_freq_time=True,
33 | agg_freq_module="TransformerEncoderLayer",
34 | agg_time_module="torch.nn.Identity",
35 | add_global_repr=False,
36 | )
37 |
38 |         # bridging the extractor feature dim (768) into what is specified in the config
39 |         # to match e.g. the transformer dim
40 | self.vproj = nn.Linear(in_features=768, out_features=768)
41 | self.aproj = nn.Linear(in_features=768, out_features=768)
42 | self.transformer = GlobalTransformer(
43 | tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768
44 | )
45 |
46 | def forward(self, vis):
47 | B, S, Tv, C, H, W = vis.shape
48 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
49 | # feat extractors return a tuple of segment-level and global features (ignored for sync)
50 | # (B, S, tv, D), e.g. (B, 7, 8, 768)
51 | vis = self.vfeat_extractor(vis)
52 | return vis
53 |
54 | def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor):
55 | vis = self.vproj(vis)
56 | aud = self.aproj(aud)
57 |
58 | B, S, tv, D = vis.shape
59 | B, S, ta, D = aud.shape
60 | vis = vis.view(B, S * tv, D) # (B, S*tv, D)
61 | aud = aud.view(B, S * ta, D) # (B, S*ta, D)
62 | # print(vis.shape, aud.shape)
63 |
64 | # self.transformer will concatenate the vis and aud in one sequence with aux tokens,
65 | # ie `CvvvvMaaaaaa`, and will return the logits for the CLS tokens
66 | logits = self.transformer(vis, aud) # (B, cls); or (B, cls) and (B, 2) if DoubtingTransformer
67 |
68 | return logits
69 |
70 | def extract_vfeats(self, vis):
71 | B, S, Tv, C, H, W = vis.shape
72 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
73 | # feat extractors return a tuple of segment-level and global features (ignored for sync)
74 | # (B, S, tv, D), e.g. (B, 7, 8, 768)
75 | vis = self.vfeat_extractor(vis)
76 | return vis
77 |
78 | def extract_afeats(self, aud):
79 | B, S, _, Fa, Ta = aud.shape
80 | aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2) # (B, S, Ta, F)
81 | # (B, S, ta, D), e.g. (B, 7, 6, 768)
82 | aud, _ = self.afeat_extractor(aud)
83 | return aud
84 |
85 | def compute_loss(self, logits, targets, loss_fn: str = None):
86 | loss = None
87 | if targets is not None:
88 | if loss_fn is None or loss_fn == "cross_entropy":
89 | # logits: (B, cls) and targets: (B,)
90 | loss = F.cross_entropy(logits, targets)
91 | else:
92 | raise NotImplementedError(f"Loss {loss_fn} not implemented")
93 | return loss
94 |
95 | def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
96 | # discard all entries except vfeat_extractor
97 | # sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')}
98 |
99 | return super().load_state_dict(sd, strict)
100 |
101 |
102 | class RandInitPositionalEncoding(nn.Module):
103 |     """Randomly initialized trainable positional embedding. It is applied directly to the sequence, so it encodes no positional priors."""
104 |
105 | def __init__(self, block_shape: list, n_embd: int):
106 | super().__init__()
107 | self.block_shape = block_shape
108 | self.n_embd = n_embd
109 | self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd))
110 |
111 | def forward(self, token_embeddings):
112 | return token_embeddings + self.pos_emb
113 |
114 |
115 | class GlobalTransformer(torch.nn.Module):
116 | """Same as in SparseSync but without the selector transformers and the head"""
117 |
118 | def __init__(
119 | self,
120 | tok_pdrop=0.0,
121 | embd_pdrop=0.1,
122 | resid_pdrop=0.1,
123 | attn_pdrop=0.1,
124 | n_layer=3,
125 | n_head=8,
126 | n_embd=768,
127 | pos_emb_block_shape=[
128 | 198,
129 | ],
130 | n_off_head_out=21,
131 | ) -> None:
132 | super().__init__()
133 | self.config = Config(
134 | embd_pdrop=embd_pdrop,
135 | resid_pdrop=resid_pdrop,
136 | attn_pdrop=attn_pdrop,
137 | n_layer=n_layer,
138 | n_head=n_head,
139 | n_embd=n_embd,
140 | )
141 | # input norm
142 | self.vis_in_lnorm = torch.nn.LayerNorm(n_embd)
143 | self.aud_in_lnorm = torch.nn.LayerNorm(n_embd)
144 | # aux tokens
145 | self.OFF_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd))
146 | self.MOD_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd))
147 | # whole token dropout
148 | self.tok_pdrop = tok_pdrop
149 | self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop)
150 | self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop)
151 | # maybe add pos emb
152 | self.pos_emb_cfg = RandInitPositionalEncoding(
153 | block_shape=pos_emb_block_shape,
154 | n_embd=n_embd,
155 | )
156 | # the stem
157 | self.drop = torch.nn.Dropout(embd_pdrop)
158 | self.blocks = torch.nn.Sequential(*[Block(self.config) for _ in range(n_layer)])
159 | # pre-output norm
160 | self.ln_f = torch.nn.LayerNorm(n_embd)
161 | # maybe add a head
162 | self.off_head = torch.nn.Linear(in_features=n_embd, out_features=n_off_head_out)
163 |
164 | def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True):
165 | B, Sv, D = v.shape
166 | B, Sa, D = a.shape
167 | # broadcasting special tokens to the batch size
168 | off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B)
169 | mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B)
170 | # norm
171 | v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a)
172 | # maybe whole token dropout
173 | if self.tok_pdrop > 0:
174 | v, a = self.tok_drop_vis(v), self.tok_drop_aud(a)
175 | # (B, 1+Sv+1+Sa, D)
176 | x = torch.cat((off_tok, v, mod_tok, a), dim=1)
177 | # maybe add pos emb
178 | if hasattr(self, "pos_emb_cfg"):
179 | x = self.pos_emb_cfg(x)
180 | # dropout -> stem -> norm
181 | x = self.drop(x)
182 | x = self.blocks(x)
183 | x = self.ln_f(x)
184 | # maybe add heads
185 | if attempt_to_apply_heads and hasattr(self, "off_head"):
186 | x = self.off_head(x[:, 0, :])
187 | return x
188 |
189 |
190 | class SelfAttention(nn.Module):
191 | """
192 | A vanilla multi-head masked self-attention layer with a projection at the end.
193 | It is possible to use torch.nn.MultiheadAttention here but I am including an
194 | explicit implementation here to show that there is nothing too scary here.
195 | """
196 |
197 | def __init__(self, config):
198 | super().__init__()
199 | assert config.n_embd % config.n_head == 0
200 | # key, query, value projections for all heads
201 | self.key = nn.Linear(config.n_embd, config.n_embd)
202 | self.query = nn.Linear(config.n_embd, config.n_embd)
203 | self.value = nn.Linear(config.n_embd, config.n_embd)
204 | # regularization
205 | self.attn_drop = nn.Dropout(config.attn_pdrop)
206 | self.resid_drop = nn.Dropout(config.resid_pdrop)
207 | # output projection
208 | self.proj = nn.Linear(config.n_embd, config.n_embd)
209 | # # causal mask to ensure that attention is only applied to the left in the input sequence
210 | # mask = torch.tril(torch.ones(config.block_size,
211 | # config.block_size))
212 | # if hasattr(config, "n_unmasked"):
213 | # mask[:config.n_unmasked, :config.n_unmasked] = 1
214 | # self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
215 | self.n_head = config.n_head
216 |
217 | def forward(self, x):
218 | B, T, C = x.size()
219 |
220 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
221 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
222 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
223 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
224 |
225 | # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
226 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
227 | # att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
228 | att = F.softmax(att, dim=-1)
229 | y = self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
230 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
231 |
232 | # output projection
233 | y = self.resid_drop(self.proj(y))
234 |
235 | return y
236 |
237 |
238 | class Block(nn.Module):
239 | """an unassuming Transformer block"""
240 |
241 | def __init__(self, config):
242 | super().__init__()
243 | self.ln1 = nn.LayerNorm(config.n_embd)
244 | self.ln2 = nn.LayerNorm(config.n_embd)
245 | self.attn = SelfAttention(config)
246 | self.mlp = nn.Sequential(
247 | nn.Linear(config.n_embd, 4 * config.n_embd),
248 | nn.GELU(), # nice
249 | nn.Linear(4 * config.n_embd, config.n_embd),
250 | nn.Dropout(config.resid_pdrop),
251 | )
252 |
253 | def forward(self, x):
254 | x = x + self.attn(self.ln1(x))
255 | x = x + self.mlp(self.ln2(x))
256 | return x
257 |
258 |
259 | def make_class_grid(
260 | leftmost_val,
261 | rightmost_val,
262 | grid_size,
263 | add_extreme_offset: bool = False,
264 | seg_size_vframes: int = None,
265 | nseg: int = None,
266 | step_size_seg: float = None,
267 | vfps: float = None,
268 | ):
269 | assert grid_size >= 3, f"grid_size: {grid_size} doesnot make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()"
270 | grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float()
271 | if add_extreme_offset:
272 | assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}"
273 | seg_size_sec = seg_size_vframes / vfps
274 | trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1)
275 | extreme_value = trim_size_in_seg * seg_size_sec
276 | grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid
277 | return grid
278 |
279 |
280 | # from synchformer
281 | def pad_or_truncate(audio: torch.Tensor, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0):
282 | difference = max_spec_t - audio.shape[-1] # safe for batched input
283 | # pad or truncate, depending on difference
284 | if difference > 0:
285 | # pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input
286 | pad_dims = (0, difference)
287 | audio = torch.nn.functional.pad(audio, pad_dims, pad_mode, pad_value)
288 | elif difference < 0:
289 | print(f"Truncating spec ({audio.shape}) to max_spec_t ({max_spec_t}).")
290 | audio = audio[..., :max_spec_t] # safe for batched input
291 | return audio
292 |
293 |
294 | def encode_audio_with_sync(
295 | synchformer: Synchformer, x: torch.Tensor, mel: torchaudio.transforms.MelSpectrogram
296 | ) -> torch.Tensor:
297 | b, t = x.shape
298 |
299 |     # partition the audio waveform into overlapping segments
300 | segment_size = 10240
301 | step_size = 10240 // 2
302 | num_segments = (t - segment_size) // step_size + 1
303 | segments = []
304 | for i in range(num_segments):
305 | segments.append(x[:, i * step_size : i * step_size + segment_size])
306 |     x = torch.stack(segments, dim=1)  # (B, S, segment_size)
307 |
308 | x = mel(x)
309 | x = torch.log(x + 1e-6)
310 | x = pad_or_truncate(x, 66)
311 |
312 | mean = -4.2677393
313 | std = 4.5689974
314 | x = (x - mean) / (2 * std)
315 | # x: B * S * 128 * 66
316 | x = synchformer.extract_afeats(x.unsqueeze(2))
317 | return x
318 |
319 |
320 | def read_audio(filename, expected_length=int(16000 * 4)):
321 | waveform, sr = torchaudio.load(filename)
322 | waveform = waveform.mean(dim=0)
323 |
324 | if sr != 16000:
325 | resampler = torchaudio.transforms.Resample(sr, 16000)
326 |         waveform = resampler(waveform)
327 |
328 | waveform = waveform[:expected_length]
329 | if waveform.shape[0] != expected_length:
330 | raise ValueError(f"Audio {filename} is too short")
331 |
332 | waveform = waveform.squeeze()
333 |
334 | return waveform
335 |
336 |
337 | if __name__ == "__main__":
338 | synchformer = Synchformer().cuda().eval()
339 |
340 | # mmaudio provided synchformer ckpt
341 | synchformer.load_state_dict(
342 | torch.load(
343 |             os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"),
344 | weights_only=True,
345 | map_location="cpu",
346 | )
347 | )
348 |
349 | sync_mel_spectrogram = torchaudio.transforms.MelSpectrogram(
350 | sample_rate=16000,
351 | win_length=400,
352 | hop_length=160,
353 | n_fft=1024,
354 | n_mels=128,
355 | )
356 |
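357 |     # Minimal usage sketch (not part of the original script): it assumes a
358 |     # >= 4-second audio file at "example.wav" and reuses the helpers defined
359 |     # above to extract Synchformer audio features for the sync model.
360 |     waveform = read_audio("example.wav").unsqueeze(0).cuda()  # (1, 64000)
361 |     with torch.no_grad():
362 |         sync_feats = encode_audio_with_sync(
363 |             synchformer, waveform, sync_mel_spectrogram.cuda()
364 |         )
365 |     print("sync audio features:", sync_feats.shape)  # (B, S, ta, 768)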
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/model/dac.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List
3 | from typing import Union
4 |
5 | import numpy as np
6 | import torch
7 | from audiotools import AudioSignal
8 | from audiotools.ml import BaseModel
9 | from torch import nn
10 |
11 | from .base import CodecMixin
12 | from ..nn.layers import Snake1d
13 | from ..nn.layers import WNConv1d
14 | from ..nn.layers import WNConvTranspose1d
15 | from ..nn.quantize import ResidualVectorQuantize
16 | from ..nn.vae_utils import DiagonalGaussianDistribution
17 |
18 |
19 | def init_weights(m):
20 | if isinstance(m, nn.Conv1d):
21 | nn.init.trunc_normal_(m.weight, std=0.02)
22 | nn.init.constant_(m.bias, 0)
23 |
24 |
25 | class ResidualUnit(nn.Module):
26 | def __init__(self, dim: int = 16, dilation: int = 1):
27 | super().__init__()
28 | pad = ((7 - 1) * dilation) // 2
29 | self.block = nn.Sequential(
30 | Snake1d(dim),
31 | WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
32 | Snake1d(dim),
33 | WNConv1d(dim, dim, kernel_size=1),
34 | )
35 |
36 | def forward(self, x):
37 | y = self.block(x)
38 | pad = (x.shape[-1] - y.shape[-1]) // 2
39 | if pad > 0:
40 | x = x[..., pad:-pad]
41 | return x + y
42 |
43 |
44 | class EncoderBlock(nn.Module):
45 | def __init__(self, dim: int = 16, stride: int = 1):
46 | super().__init__()
47 | self.block = nn.Sequential(
48 | ResidualUnit(dim // 2, dilation=1),
49 | ResidualUnit(dim // 2, dilation=3),
50 | ResidualUnit(dim // 2, dilation=9),
51 | Snake1d(dim // 2),
52 | WNConv1d(
53 | dim // 2,
54 | dim,
55 | kernel_size=2 * stride,
56 | stride=stride,
57 | padding=math.ceil(stride / 2),
58 | ),
59 | )
60 |
61 | def forward(self, x):
62 | return self.block(x)
63 |
64 |
65 | class Encoder(nn.Module):
66 | def __init__(
67 | self,
68 | d_model: int = 64,
69 | strides: list = [2, 4, 8, 8],
70 | d_latent: int = 64,
71 | ):
72 | super().__init__()
73 | # Create first convolution
74 | self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
75 |
76 | # Create EncoderBlocks that double channels as they downsample by `stride`
77 | for stride in strides:
78 | d_model *= 2
79 | self.block += [EncoderBlock(d_model, stride=stride)]
80 |
81 | # Create last convolution
82 | self.block += [
83 | Snake1d(d_model),
84 | WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
85 | ]
86 |
 87 |         # Wrap block into nn.Sequential
88 | self.block = nn.Sequential(*self.block)
89 | self.enc_dim = d_model
90 |
91 | def forward(self, x):
92 | return self.block(x)
93 |
94 |
95 | class DecoderBlock(nn.Module):
96 | def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
97 | super().__init__()
98 | self.block = nn.Sequential(
99 | Snake1d(input_dim),
100 | WNConvTranspose1d(
101 | input_dim,
102 | output_dim,
103 | kernel_size=2 * stride,
104 | stride=stride,
105 | padding=math.ceil(stride / 2),
106 | output_padding=stride % 2,
107 | ),
108 | ResidualUnit(output_dim, dilation=1),
109 | ResidualUnit(output_dim, dilation=3),
110 | ResidualUnit(output_dim, dilation=9),
111 | )
112 |
113 | def forward(self, x):
114 | return self.block(x)
115 |
116 |
117 | class Decoder(nn.Module):
118 | def __init__(
119 | self,
120 | input_channel,
121 | channels,
122 | rates,
123 | d_out: int = 1,
124 | ):
125 | super().__init__()
126 |
127 | # Add first conv layer
128 | layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
129 |
130 | # Add upsampling + MRF blocks
131 | for i, stride in enumerate(rates):
132 | input_dim = channels // 2**i
133 | output_dim = channels // 2 ** (i + 1)
134 | layers += [DecoderBlock(input_dim, output_dim, stride)]
135 |
136 | # Add final conv layer
137 | layers += [
138 | Snake1d(output_dim),
139 | WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
140 | nn.Tanh(),
141 | ]
142 |
143 | self.model = nn.Sequential(*layers)
144 |
145 | def forward(self, x):
146 | return self.model(x)
147 |
148 |
149 | class DAC(BaseModel, CodecMixin):
150 | def __init__(
151 | self,
152 | encoder_dim: int = 64,
153 | encoder_rates: List[int] = [2, 4, 8, 8],
154 | latent_dim: int = None,
155 | decoder_dim: int = 1536,
156 | decoder_rates: List[int] = [8, 8, 4, 2],
157 | n_codebooks: int = 9,
158 | codebook_size: int = 1024,
159 | codebook_dim: Union[int, list] = 8,
160 | quantizer_dropout: bool = False,
161 | sample_rate: int = 44100,
162 | continuous: bool = False,
163 | ):
164 | super().__init__()
165 |
166 | self.encoder_dim = encoder_dim
167 | self.encoder_rates = encoder_rates
168 | self.decoder_dim = decoder_dim
169 | self.decoder_rates = decoder_rates
170 | self.sample_rate = sample_rate
171 | self.continuous = continuous
172 |
173 | if latent_dim is None:
174 | latent_dim = encoder_dim * (2 ** len(encoder_rates))
175 |
176 | self.latent_dim = latent_dim
177 |
178 | self.hop_length = np.prod(encoder_rates)
179 | self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
180 |
181 | if not continuous:
182 | self.n_codebooks = n_codebooks
183 | self.codebook_size = codebook_size
184 | self.codebook_dim = codebook_dim
185 | self.quantizer = ResidualVectorQuantize(
186 | input_dim=latent_dim,
187 | n_codebooks=n_codebooks,
188 | codebook_size=codebook_size,
189 | codebook_dim=codebook_dim,
190 | quantizer_dropout=quantizer_dropout,
191 | )
192 | else:
193 | self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
194 | self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
195 |
196 | self.decoder = Decoder(
197 | latent_dim,
198 | decoder_dim,
199 | decoder_rates,
200 | )
201 | self.sample_rate = sample_rate
202 | self.apply(init_weights)
203 |
204 | self.delay = self.get_delay()
205 |
206 | @property
207 | def dtype(self):
208 | """Get the dtype of the model parameters."""
209 | # Return the dtype of the first parameter found
210 | for param in self.parameters():
211 | return param.dtype
212 | return torch.float32 # fallback
213 |
214 | @property
215 | def device(self):
216 | """Get the device of the model parameters."""
217 | # Return the device of the first parameter found
218 | for param in self.parameters():
219 | return param.device
220 | return torch.device('cpu') # fallback
221 |
222 | def preprocess(self, audio_data, sample_rate):
223 | if sample_rate is None:
224 | sample_rate = self.sample_rate
225 | assert sample_rate == self.sample_rate
226 |
227 | length = audio_data.shape[-1]
228 | right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
229 | audio_data = nn.functional.pad(audio_data, (0, right_pad))
230 |
231 | return audio_data
232 |
233 | def encode(
234 | self,
235 | audio_data: torch.Tensor,
236 | n_quantizers: int = None,
237 | ):
238 | """Encode given audio data and return quantized latent codes
239 |
240 | Parameters
241 | ----------
242 | audio_data : Tensor[B x 1 x T]
243 | Audio data to encode
244 | n_quantizers : int, optional
245 | Number of quantizers to use, by default None
246 | If None, all quantizers are used.
247 |
248 | Returns
249 | -------
250 |         tuple
251 |             A tuple with the following elements:
252 | "z" : Tensor[B x D x T]
253 | Quantized continuous representation of input
254 | "codes" : Tensor[B x N x T]
255 | Codebook indices for each codebook
256 | (quantized discrete representation of input)
257 | "latents" : Tensor[B x N*D x T]
258 | Projected latents (continuous representation of input before quantization)
259 | "vq/commitment_loss" : Tensor[1]
260 | Commitment loss to train encoder to predict vectors closer to codebook
261 | entries
262 | "vq/codebook_loss" : Tensor[1]
263 | Codebook loss to update the codebook
264 |             In continuous mode, `z` is instead a DiagonalGaussianDistribution and
265 |             the remaining elements are None / 0 placeholders.
266 | """
267 | z = self.encoder(audio_data) # [B x D x T]
268 | if not self.continuous:
269 | z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
270 | else:
271 | z = self.quant_conv(z) # [B x 2D x T]
272 | z = DiagonalGaussianDistribution(z)
273 | codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
274 |
275 | return z, codes, latents, commitment_loss, codebook_loss
276 |
277 | def decode(self, z: torch.Tensor):
278 | """Decode given latent codes and return audio data
279 |
280 | Parameters
281 | ----------
282 | z : Tensor[B x D x T]
283 | Quantized continuous representation of input
284 | 
285 |         Returns
286 |         -------
287 |         torch.Tensor
288 |             Decoded audio data of shape [B x 1 x T].
289 | 
290 |         Notes
291 |         -----
292 |         In continuous mode, `z` is first passed through `post_quant_conv`.
293 | """
294 | if not self.continuous:
295 | audio = self.decoder(z)
296 | else:
297 | z = self.post_quant_conv(z)
298 | audio = self.decoder(z)
299 |
300 | return audio
301 |
302 | def forward(
303 | self,
304 | audio_data: torch.Tensor,
305 | sample_rate: int = None,
306 | n_quantizers: int = None,
307 | ):
308 | """Model forward pass
309 |
310 | Parameters
311 | ----------
312 | audio_data : Tensor[B x 1 x T]
313 | Audio data to encode
314 | sample_rate : int, optional
315 | Sample rate of audio data in Hz, by default None
316 | If None, defaults to `self.sample_rate`
317 | n_quantizers : int, optional
318 | Number of quantizers to use, by default None.
319 | If None, all quantizers are used.
320 |
321 | Returns
322 | -------
323 | dict
324 | A dictionary with the following keys:
325 | "z" : Tensor[B x D x T]
326 | Quantized continuous representation of input
327 | "codes" : Tensor[B x N x T]
328 | Codebook indices for each codebook
329 | (quantized discrete representation of input)
330 | "latents" : Tensor[B x N*D x T]
331 | Projected latents (continuous representation of input before quantization)
332 | "vq/commitment_loss" : Tensor[1]
333 | Commitment loss to train encoder to predict vectors closer to codebook
334 | entries
335 | "vq/codebook_loss" : Tensor[1]
336 | Codebook loss to update the codebook
337 |             "kl_loss" : Tensor[1]
338 |                 KL divergence term, returned instead of the VQ losses in continuous mode
339 | "audio" : Tensor[B x 1 x length]
340 | Decoded audio data.
341 | """
342 | length = audio_data.shape[-1]
343 | audio_data = self.preprocess(audio_data, sample_rate)
344 | if not self.continuous:
345 | z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
346 |
347 | x = self.decode(z)
348 | return {
349 | "audio": x[..., :length],
350 | "z": z,
351 | "codes": codes,
352 | "latents": latents,
353 | "vq/commitment_loss": commitment_loss,
354 | "vq/codebook_loss": codebook_loss,
355 | }
356 | else:
357 | posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
358 | z = posterior.sample()
359 | x = self.decode(z)
360 |
361 | kl_loss = posterior.kl()
362 | kl_loss = kl_loss.mean()
363 |
364 | return {
365 | "audio": x[..., :length],
366 | "z": z,
367 | "kl_loss": kl_loss,
368 | }
369 |
370 |
371 | if __name__ == "__main__":
372 | import numpy as np
373 | from functools import partial
374 |
375 | model = DAC().to("cpu")
376 |
377 | for n, m in model.named_modules():
378 | o = m.extra_repr()
379 | p = sum([np.prod(p.size()) for p in m.parameters()])
380 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
381 | setattr(m, "extra_repr", partial(fn, o=o, p=p))
382 | print(model)
383 | print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
384 |
385 | length = 88200 * 2
386 | x = torch.randn(1, 1, length).to(model.device)
387 | x.requires_grad_(True)
388 | x.retain_grad()
389 |
390 | # Make a forward pass
391 | out = model(x)["audio"]
392 | print("Input shape:", x.shape)
393 | print("Output shape:", out.shape)
394 |
395 | # Create gradient variable
396 | grad = torch.zeros_like(out)
397 | grad[:, :, grad.shape[-1] // 2] = 1
398 |
399 | # Make a backward pass
400 | out.backward(grad)
401 |
402 | # Check non-zero values
403 | gradmap = x.grad.squeeze(0)
404 | gradmap = (gradmap != 0).sum(0) # sum across features
405 | rf = (gradmap != 0).sum()
406 |
407 | print(f"Receptive field: {rf.item()}")
408 |
409 | x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
410 | model.decompress(model.compress(x, verbose=True), verbose=True)
411 |
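412 |     # Minimal sketch (not part of the original script): the same smoke test for
413 |     # the continuous / VAE variant, where encode() yields a
414 |     # DiagonalGaussianDistribution and forward() reports a KL term instead of
415 |     # the VQ losses.
416 |     vae = DAC(continuous=True).to("cpu")
417 |     out = vae(torch.randn(1, 1, 44100))
418 |     print("VAE audio:", out["audio"].shape, "kl_loss:", out["kl_loss"].item())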
--------------------------------------------------------------------------------