├── hunyuanvideo_foley ├── __init__.py ├── models │ ├── __init__.py │ ├── nn │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-313.pyc │ │ │ ├── attn_layers.cpython-313.pyc │ │ │ ├── mlp_layers.cpython-313.pyc │ │ │ ├── norm_layers.cpython-313.pyc │ │ │ ├── embed_layers.cpython-313.pyc │ │ │ ├── posemb_layers.cpython-313.pyc │ │ │ ├── modulate_layers.cpython-313.pyc │ │ │ └── activation_layers.cpython-313.pyc │ │ ├── activation_layers.py │ │ ├── modulate_layers.py │ │ ├── norm_layers.py │ │ ├── embed_layers.py │ │ ├── mlp_layers.py │ │ └── posemb_layers.py │ ├── synchformer │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── utils.cpython-313.pyc │ │ │ ├── __init__.cpython-313.pyc │ │ │ ├── ast_model.cpython-313.pyc │ │ │ ├── modeling_ast.cpython-313.pyc │ │ │ ├── motionformer.cpython-313.pyc │ │ │ ├── synchformer.cpython-313.pyc │ │ │ ├── vit_helper.cpython-313.pyc │ │ │ └── video_model_builder.cpython-313.pyc │ │ ├── divided_224_16x4.yaml │ │ ├── utils.py │ │ ├── compute_desync_score.py │ │ ├── video_model_builder.py │ │ ├── synchformer.py │ │ └── vit_helper.py │ ├── dac_vae │ │ ├── nn │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── loss.cpython-313.pyc │ │ │ │ ├── __init__.cpython-313.pyc │ │ │ │ ├── layers.cpython-313.pyc │ │ │ │ ├── quantize.cpython-313.pyc │ │ │ │ └── vae_utils.cpython-313.pyc │ │ │ ├── layers.py │ │ │ ├── vae_utils.py │ │ │ ├── quantize.py │ │ │ └── loss.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── base.cpython-313.pyc │ │ │ │ ├── dac.cpython-313.pyc │ │ │ │ ├── __init__.cpython-313.pyc │ │ │ │ └── discriminator.cpython-313.pyc │ │ │ ├── discriminator.py │ │ │ ├── base.py │ │ │ └── dac.py │ │ ├── __pycache__ │ │ │ └── __init__.cpython-313.pyc │ │ ├── utils │ │ │ ├── __pycache__ │ │ │ │ └── __init__.cpython-313.pyc │ │ │ ├── decode.py │ │ │ ├── encode.py │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── __main__.py │ └── __pycache__ │ │ ├── __init__.cpython-313.pyc │ │ └── hifi_foley.cpython-313.pyc ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── helper.cpython-313.pyc │ │ ├── __init__.cpython-313.pyc │ │ ├── config_utils.cpython-313.pyc │ │ ├── model_utils.cpython-313.pyc │ │ └── feature_utils.cpython-313.pyc │ ├── schedulers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-313.pyc │ │ │ └── scheduling_flow_match_discrete.cpython-313.pyc │ │ └── scheduling_flow_match_discrete.py │ ├── media_utils.py │ ├── config_utils.py │ ├── helper.py │ ├── feature_utils.py │ └── model_utils.py ├── __pycache__ │ ├── __init__.cpython-313.pyc │ └── constants.cpython-313.pyc └── constants.py ├── utils.py ├── __init__.py ├── node_list.json ├── pyproject.toml ├── requirements.txt ├── .github └── workflows │ └── publish_action.yml ├── README.md ├── LICENSE ├── configs └── hunyuanvideo-foley-xxl.yaml └── download_models_manual.py /hunyuanvideo_foley/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .synchformer import Synchformer 2 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from . import layers 2 | from . import loss 3 | from . import quantize 4 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/__pycache__/constants.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/__pycache__/constants.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import CodecMixin 2 | from .base import DACFile 3 | from .dac import DAC 4 | from .discriminator import Discriminator 5 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/__pycache__/helper.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/utils/__pycache__/helper.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from diffusers.schedulers import DDPMScheduler, EulerDiscreteScheduler 2 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-313.pyc 
-------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-313.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/schedulers/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/utils/schedulers/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-313.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/__pycache__/vit_helper.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/synchformer/__pycache__/vit_helper.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-313.pyc -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aistudynow/Comfyui-HunyuanFoley/HEAD/hunyuanvideo_foley/utils/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-313.pyc -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | def ensure_in_path(): 6 | """Add this package's directory to Python path.""" 7 | current_dir = os.path.dirname(os.path.abspath(__file__)) 8 | if current_dir not in sys.path: 9 | sys.path.insert(0, current_dir) -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # Ensure package directory is on the Python path for relative imports 5 | current_dir = os.path.dirname(os.path.abspath(__file__)) 6 | if current_dir not in sys.path: 7 | sys.path.insert(0, current_dir) 8 | 9 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 10 | 11 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | # preserved here for legacy reasons 4 | __model_version__ = "latest" 5 | 6 | import audiotools 7 | 8 | audiotools.ml.BaseModel.INTERN += ["dac.**"] 9 | audiotools.ml.BaseModel.EXTERN += ["einops"] 10 | 11 | 12 | from . import nn 13 | from . import model 14 | from . 
import utils
15 | from .model import DAC
16 | from .model import DACFile
17 | 
--------------------------------------------------------------------------------
/node_list.json:
--------------------------------------------------------------------------------
1 | {
2 |     "HunyuanModelLoader": "Load HunyuanVideo-Foley main diffusion model with precision/FP8 options",
3 |     "HunyuanDependenciesLoader": "Load DAC-VAE, SigLIP2, Synchformer, and CLAP dependencies",
4 |     "HunyuanFoleySampler": "Generate Foley audio (ping-pong memory, fast CFG, optional FP8)",
5 |     "HunyuanFoleyTorchCompile": "Optional torch.compile accelerator (~30% faster after first compile)",
6 |     "SelectAudioFromBatch": "Pick one waveform from an AUDIO batch"
7 | }
8 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "Comfyui-HunyuanFoley"
3 | description = "Generate audio from any video and/or text"
4 | version = "2.1.0"
5 | license = {file = "LICENSE"}
6 | 
7 | [project.urls]
8 | Repository = "https://github.com/aistudynow/Comfyui-HunyuanFoley"
9 | # Used by Comfy Registry https://comfyregistry.org
10 | 
11 | [tool.comfy]
12 | PublisherId = "aistudynow"
13 | DisplayName = "Comfyui-HunyuanFoley"
14 | Icon = ""
15 | 
16 | 
17 | 
18 | 
19 | 
20 | 
21 | 
22 | 
23 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Core ML dependencies (numpy and scipy are often pulled in by torch)
2 | numpy
3 | scipy
4 | 
5 | # Deep Learning frameworks
6 | diffusers
7 | timm
8 | accelerate
9 | 
10 | # Transformers for loading SigLIP2 and CLAP models from Hugging Face
11 | transformers
12 | sentencepiece
13 | 
14 | # Audio VAE model (requires direct git install)
15 | git+https://github.com/descriptinc/audiotools
16 | 
17 | # Image & Tensor manipulation
18 | pillow
19 | einops
20 | 
21 | # Configuration and Utilities
22 | pyyaml
23 | omegaconf
24 | loguru
25 | tqdm
26 | 
--------------------------------------------------------------------------------
/.github/workflows/publish_action.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 |   workflow_dispatch:
4 |   push:
5 |     branches:
6 |       - main
7 |     paths:
8 |       - "pyproject.toml"
9 | 
10 | jobs:
11 |   publish-node:
12 |     name: Publish Custom Node to registry
13 |     runs-on: ubuntu-latest
14 |     steps:
15 |       - name: Check out code
16 |         uses: actions/checkout@v4
17 |       - name: Publish Custom Node
18 |         uses: Comfy-Org/publish-node-action@main
19 |         with:
20 |           personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your GitHub repository secrets and reference it here.
21 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI_HunyuanFoley
2 | 
3 | Update: now works in low-VRAM mode, and a new workflow is included.
4 | New workflow from aistudynow: **[https://aistudynow.com/hunyuanvideo-foley-comfyui-workflow-turn-quiet-video-into-sound/](https://aistudynow.com/hunyuanvideo-foley-comfyui-workflow-turn-quiet-video-into-sound/)**
5 | 
6 | ComfyUI wrapper for **Tencent HunyuanVideo-Foley**.
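As an alternative to downloading the weights by hand (see **Models** below), the repo also ships `download_models_manual.py`, which fetches the three `.pth` weight files from the official Hugging Face release. A minimal sketch of calling it from Python, assuming you run it from this node's folder (or have that folder on `PYTHONPATH`); the destination path is only an example, and note that `config.yaml` is not fetched by the script and still has to be placed manually:

```python
# Sketch: use the bundled helper to fetch the three .pth weights.
# Point the destination at your own ComfyUI/models/hunyuanfoley folder.
from download_models_manual import download_all

paths = download_all("ComfyUI/models/hunyuanfoley")
for p in paths:
    print("ready:", p)
```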
7 | 
8 | 
9 | ---
10 | 
11 | ## Models
12 | 
13 | Get the files from the official release: **[HunyuanVideo-Foley on Hugging Face](https://huggingface.co/tencent/HunyuanVideo-Foley/tree/main)**
14 | 
15 | Place all HunyuanVideo-Foley weights and `config.yaml` in one folder named `hunyuanfoley`:
16 | 
17 | ```text
18 | ComfyUI/models/hunyuanfoley/
19 | ├─ hunyuanvideo_foley.pth
20 | ├─ vae_128d_48k.pth
21 | ├─ synchformer_state_dict.pth
22 | └─ config.yaml
23 | ```
24 | 
25 | 
26 | 
27 | 
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/__main__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | 
3 | import argbind
4 | 
5 | from .utils import download
6 | from .utils.decode import decode
7 | from .utils.encode import encode
8 | 
9 | STAGES = ["encode", "decode", "download"]
10 | 
11 | 
12 | def run(stage: str):
13 |     """Run stages.
14 | 
15 |     Parameters
16 |     ----------
17 |     stage : str
18 |         Stage to run
19 |     """
20 |     if stage not in STAGES:
21 |         raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22 |     stage_fn = globals()[stage]
23 | 
24 |     if stage == "download":
25 |         stage_fn()
26 |         return
27 | 
28 |     stage_fn()
29 | 
30 | 
31 | if __name__ == "__main__":
32 |     group = sys.argv.pop(1)
33 |     args = argbind.parse_args(group=group)
34 | 
35 |     with argbind.scope(args):
36 |         run(group)
37 | 
--------------------------------------------------------------------------------
/hunyuanvideo_foley/models/dac_vae/nn/layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 | from torch.nn.utils import weight_norm
7 | 
8 | 
9 | def WNConv1d(*args, **kwargs):
10 |     return weight_norm(nn.Conv1d(*args, **kwargs))
11 | 
12 | 
13 | def WNConvTranspose1d(*args, **kwargs):
14 |     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15 | 
16 | 
17 | # Scripting this brings model speed up 1.4x
18 | @torch.jit.script
19 | def snake(x, alpha):
20 |     shape = x.shape
21 |     x = x.reshape(shape[0], shape[1], -1)
22 |     x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23 |     x = x.reshape(shape)
24 |     return x
25 | 
26 | 
27 | class Snake1d(nn.Module):
28 |     def __init__(self, channels):
29 |         super().__init__()
30 |         self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31 | 
32 |     def forward(self, x):
33 |         return snake(x, self.alpha)
34 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 HunyuanFoley contributors
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/activation_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | def get_activation_layer(act_type): 5 | if act_type == "gelu": 6 | return lambda: nn.GELU() 7 | elif act_type == "gelu_tanh": 8 | # Approximate `tanh` requires torch >= 1.13 9 | return lambda: nn.GELU(approximate="tanh") 10 | elif act_type == "relu": 11 | return nn.ReLU 12 | elif act_type == "silu": 13 | return nn.SiLU 14 | else: 15 | raise ValueError(f"Unknown activation type: {act_type}") 16 | 17 | class SwiGLU(nn.Module): 18 | def __init__( 19 | self, 20 | dim: int, 21 | hidden_dim: int, 22 | out_dim: int, 23 | ): 24 | """ 25 | Initialize the SwiGLU FeedForward module. 26 | 27 | Args: 28 | dim (int): Input dimension. 29 | hidden_dim (int): Hidden dimension of the feedforward layer. 30 | 31 | Attributes: 32 | w1: Linear transformation for the first layer. 33 | w2: Linear transformation for the second layer. 34 | w3: Linear transformation for the third layer. 35 | 36 | """ 37 | super().__init__() 38 | 39 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 40 | self.w2 = nn.Linear(hidden_dim, out_dim, bias=False) 41 | self.w3 = nn.Linear(dim, hidden_dim, bias=False) 42 | 43 | def forward(self, x): 44 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 45 | -------------------------------------------------------------------------------- /configs/hunyuanvideo-foley-xxl.yaml: -------------------------------------------------------------------------------- 1 | model_config: 2 | model_name: HunyuanVideo-Foley-XXL 3 | model_type: 1d 4 | model_precision: bf16 5 | model_kwargs: 6 | depth_triple_blocks: 18 7 | depth_single_blocks: 36 8 | hidden_size: 1536 9 | num_heads: 12 10 | mlp_ratio: 4 11 | mlp_act_type: "gelu_tanh" 12 | qkv_bias: True 13 | qk_norm: True 14 | qk_norm_type: "rms" 15 | attn_mode: "torch" 16 | embedder_type: "default" 17 | interleaved_audio_visual_rope: True 18 | enable_learnable_empty_visual_feat: True 19 | sync_modulation: False 20 | add_sync_feat_to_audio: True 21 | cross_attention: True 22 | use_attention_mask: False 23 | condition_projection: "linear" 24 | sync_feat_dim: 768 # syncformer 768 dim 25 | condition_dim: 768 # clap 768 text condition dim (clip-text) 26 | clip_dim: 768 # siglip2 visual dim 27 | audio_vae_latent_dim: 128 28 | audio_frame_rate: 50 29 | patch_size: 1 30 | rope_dim_list: null 31 | rope_theta: 10000 32 | text_length: 77 33 | clip_length: 64 34 | sync_length: 192 35 | use_mmaudio_singleblock: True 36 | depth_triple_ssl_encoder: null 37 | depth_single_ssl_encoder: 8 38 | use_repa_with_audiossl: True 39 | 40 | diffusion_config: 41 | denoise_type: "flow" 42 | flow_path_type: "linear" 43 | flow_predict_type: "velocity" 44 | flow_reverse: True 45 | flow_solver: "euler" 46 | sample_flow_shift: 1.0 47 | sample_use_flux_shift: False 48 | flux_base_shift: 0.5 49 | flux_max_shift: 1.15 50 | 
-------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/modulate_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | import torch.nn as nn 4 | 5 | class ModulateDiT(nn.Module): 6 | def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None): 7 | factory_kwargs = {"dtype": dtype, "device": device} 8 | super().__init__() 9 | self.act = act_layer() 10 | self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) 11 | # Zero-initialize the modulation 12 | nn.init.zeros_(self.linear.weight) 13 | nn.init.zeros_(self.linear.bias) 14 | 15 | def forward(self, x: torch.Tensor) -> torch.Tensor: 16 | return self.linear(self.act(x)) 17 | 18 | 19 | def modulate(x, shift=None, scale=None): 20 | if x.ndim == 3: 21 | shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None 22 | scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None 23 | if scale is None and shift is None: 24 | return x 25 | elif shift is None: 26 | return x * (1 + scale) 27 | elif scale is None: 28 | return x + shift 29 | else: 30 | return x * (1 + scale) + shift 31 | 32 | 33 | def apply_gate(x, gate=None, tanh=False): 34 | if gate is None: 35 | return x 36 | if gate.ndim == 2 and x.ndim == 3: 37 | gate = gate.unsqueeze(1) 38 | if tanh: 39 | return x * gate.tanh() 40 | else: 41 | return x * gate 42 | 43 | 44 | def ckpt_wrapper(module): 45 | def ckpt_forward(*inputs): 46 | outputs = module(*inputs) 47 | return outputs 48 | 49 | return ckpt_forward 50 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/constants.py: -------------------------------------------------------------------------------- 1 | """Constants used throughout the HunyuanVideo-Foley project.""" 2 | 3 | from typing import Dict, List 4 | 5 | # Model configuration 6 | DEFAULT_AUDIO_SAMPLE_RATE = 48000 7 | DEFAULT_VIDEO_FPS = 25 8 | DEFAULT_AUDIO_CHANNELS = 2 9 | 10 | # Video processing 11 | MAX_VIDEO_DURATION_SECONDS = 15.0 12 | MIN_VIDEO_DURATION_SECONDS = 1.0 13 | 14 | # Audio processing 15 | AUDIO_VAE_LATENT_DIM = 128 16 | AUDIO_FRAME_RATE = 75 # frames per second in latent space 17 | 18 | # Visual features 19 | FPS_VISUAL: Dict[str, int] = { 20 | "siglip2": 8, 21 | "synchformer": 25 22 | } 23 | 24 | # Model paths (can be overridden by environment variables) 25 | DEFAULT_MODEL_PATH = "./pretrained_models/" 26 | DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml" 27 | 28 | # Inference parameters 29 | DEFAULT_GUIDANCE_SCALE = 4.5 30 | DEFAULT_NUM_INFERENCE_STEPS = 50 31 | MIN_GUIDANCE_SCALE = 1.0 32 | MAX_GUIDANCE_SCALE = 10.0 33 | MIN_INFERENCE_STEPS = 10 34 | MAX_INFERENCE_STEPS = 100 35 | 36 | # Text processing 37 | MAX_TEXT_LENGTH = 100 38 | DEFAULT_NEGATIVE_PROMPT = "noisy, harsh" 39 | 40 | # File extensions 41 | SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"] 42 | SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"] 43 | 44 | # Quality settings 45 | AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = { 46 | "high": ["-b:a", "192k"], 47 | "medium": ["-b:a", "128k"], 48 | "low": ["-b:a", "96k"] 49 | } 50 | 51 | # Error messages 52 | ERROR_MESSAGES: Dict[str, str] = { 53 | "model_not_loaded": "Model is not loaded. Please load the model first.", 54 | "invalid_video_format": "Unsupported video format. 
Supported formats: {formats}", 55 | "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds", 56 | "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html" 57 | } -------------------------------------------------------------------------------- /download_models_manual.py: -------------------------------------------------------------------------------- 1 | """Utility to manually download HunyuanVideo-Foley model files. 2 | 3 | Run this script directly to place the required model weights in the 4 | `ComfyUI/models/hunyuanfoley` directory. Files are fetched from the official 5 | HuggingFace repository using simple HTTP requests. 6 | """ 7 | import os 8 | import pathlib 9 | import requests 10 | from typing import Dict, List 11 | 12 | # Mapping of model filenames to their download URLs 13 | MODEL_URLS: Dict[str, str] = { 14 | "hunyuanvideo_foley.pth": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/hunyuanvideo_foley.pth", 15 | "synchformer_state_dict.pth": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/synchformer_state_dict.pth", 16 | "vae_128d_48k.pth": "https://huggingface.co/tencent/HunyuanVideo-Foley/resolve/main/vae_128d_48k.pth", 17 | } 18 | 19 | 20 | def download_file(url: str, dest: pathlib.Path) -> None: 21 | """Download a URL to a local path with streaming.""" 22 | dest.parent.mkdir(parents=True, exist_ok=True) 23 | with requests.get(url, stream=True) as r: 24 | r.raise_for_status() 25 | with open(dest, "wb") as f: 26 | for chunk in r.iter_content(chunk_size=8192): 27 | if chunk: 28 | f.write(chunk) 29 | 30 | 31 | def download_all(model_dir: str) -> List[pathlib.Path]: 32 | model_paths = [] 33 | for name, url in MODEL_URLS.items(): 34 | target = pathlib.Path(model_dir) / name 35 | if target.exists(): 36 | model_paths.append(target) 37 | continue 38 | download_file(url, target) 39 | model_paths.append(target) 40 | return model_paths 41 | 42 | 43 | if __name__ == "__main__": 44 | # Determine default model directory relative to ComfyUI 45 | from folder_paths import models_dir 46 | dest_dir = os.path.join(models_dir, "hunyuanfoley") 47 | paths = download_all(dest_dir) 48 | for p in paths: 49 | print(f"Downloaded: {p}") -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: Ssv2 4 | BATCH_SIZE: 32 5 | EVAL_PERIOD: 5 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | CHECKPOINT_EPOCH_RESET: True 9 | CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth 10 | DATA: 11 | NUM_FRAMES: 16 12 | SAMPLING_RATE: 4 13 | TRAIN_JITTER_SCALES: [256, 320] 14 | TRAIN_CROP_SIZE: 224 15 | TEST_CROP_SIZE: 224 16 | INPUT_CHANNEL_NUM: [3] 17 | MEAN: [0.5, 0.5, 0.5] 18 | STD: [0.5, 0.5, 0.5] 19 | PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2 20 | PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames 21 | INV_UNIFORM_SAMPLE: True 22 | RANDOM_FLIP: False 23 | REVERSE_INPUT_CHANNEL: True 24 | USE_RAND_AUGMENT: True 25 | RE_PROB: 0.0 26 | USE_REPEATED_AUG: False 27 | USE_RANDOM_RESIZE_CROPS: False 28 | COLORJITTER: False 29 | GRAYSCALE: False 30 | GAUSSIAN: False 31 | SOLVER: 32 | BASE_LR: 1e-4 33 | LR_POLICY: steps_with_relative_lrs 34 | LRS: [1, 0.1, 0.01] 35 | STEPS: [0, 20, 30] 36 | MAX_EPOCH: 35 37 | 
MOMENTUM: 0.9 38 | WEIGHT_DECAY: 5e-2 39 | WARMUP_EPOCHS: 0.0 40 | OPTIMIZING_METHOD: adamw 41 | USE_MIXED_PRECISION: True 42 | SMOOTHING: 0.2 43 | SLOWFAST: 44 | ALPHA: 8 45 | VIT: 46 | PATCH_SIZE: 16 47 | PATCH_SIZE_TEMP: 2 48 | CHANNELS: 3 49 | EMBED_DIM: 768 50 | DEPTH: 12 51 | NUM_HEADS: 12 52 | MLP_RATIO: 4 53 | QKV_BIAS: True 54 | VIDEO_INPUT: True 55 | TEMPORAL_RESOLUTION: 8 56 | USE_MLP: True 57 | DROP: 0.0 58 | POS_DROPOUT: 0.0 59 | DROP_PATH: 0.2 60 | IM_PRETRAINED: True 61 | HEAD_DROPOUT: 0.0 62 | HEAD_ACT: tanh 63 | PRETRAINED_WEIGHTS: vit_1k 64 | ATTN_LAYER: divided 65 | MODEL: 66 | NUM_CLASSES: 174 67 | ARCH: slow 68 | MODEL_NAME: VisionTransformer 69 | LOSS_FUNC: cross_entropy 70 | TEST: 71 | ENABLE: True 72 | DATASET: Ssv2 73 | BATCH_SIZE: 64 74 | NUM_ENSEMBLE_VIEWS: 1 75 | NUM_SPATIAL_CROPS: 3 76 | DATA_LOADER: 77 | NUM_WORKERS: 4 78 | PIN_MEMORY: True 79 | NUM_GPUS: 8 80 | NUM_SHARDS: 4 81 | RNG_SEED: 0 82 | OUTPUT_DIR: . 83 | TENSORBOARD: 84 | ENABLE: True 85 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/norm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class RMSNorm(nn.Module): 5 | def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, 6 | device=None, dtype=None): 7 | """ 8 | Initialize the RMSNorm normalization layer. 9 | 10 | Args: 11 | dim (int): The dimension of the input tensor. 12 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 13 | 14 | Attributes: 15 | eps (float): A small value added to the denominator for numerical stability. 16 | weight (nn.Parameter): Learnable scaling parameter. 17 | 18 | """ 19 | factory_kwargs = {'device': device, 'dtype': dtype} 20 | super().__init__() 21 | self.eps = eps 22 | if elementwise_affine: 23 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) 24 | 25 | def _norm(self, x): 26 | """ 27 | Apply the RMSNorm normalization to the input tensor. 28 | 29 | Args: 30 | x (torch.Tensor): The input tensor. 31 | 32 | Returns: 33 | torch.Tensor: The normalized tensor. 34 | 35 | """ 36 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 37 | 38 | def forward(self, x): 39 | """ 40 | Forward pass through the RMSNorm layer. 41 | 42 | Args: 43 | x (torch.Tensor): The input tensor. 44 | 45 | Returns: 46 | torch.Tensor: The output tensor after applying RMSNorm. 47 | 48 | """ 49 | output = self._norm(x.float()).type_as(x) 50 | if hasattr(self, "weight"): 51 | output = output * self.weight 52 | return output 53 | 54 | 55 | def get_norm_layer(norm_layer): 56 | """ 57 | Get the normalization layer. 58 | 59 | Args: 60 | norm_layer (str): The type of normalization layer. 61 | 62 | Returns: 63 | norm_layer (nn.Module): The normalization layer. 
64 | """ 65 | if norm_layer == "layer": 66 | return nn.LayerNorm 67 | elif norm_layer == "rms": 68 | return RMSNorm 69 | else: 70 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") 71 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.0]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.mean( 45 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 46 | dim=[1, 2], 47 | ) 48 | else: 49 | return 0.5 * torch.mean( 50 | torch.pow(self.mean - other.mean, 2) / other.var 51 | + self.var / other.var 52 | - 1.0 53 | - self.logvar 54 | + other.logvar, 55 | dim=[1, 2], 56 | ) 57 | 58 | def nll(self, sample, dims=[1, 2]): 59 | if self.deterministic: 60 | return torch.Tensor([0.0]) 61 | logtwopi = np.log(2.0 * np.pi) 62 | return 0.5 * torch.sum( 63 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 64 | dim=dims, 65 | ) 66 | 67 | def mode(self): 68 | return self.mean 69 | 70 | 71 | def normal_kl(mean1, logvar1, mean2, logvar2): 72 | """ 73 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 74 | Compute the KL divergence between two gaussians. 75 | Shapes are automatically broadcasted, so batches can be compared to 76 | scalars, among other use cases. 77 | """ 78 | tensor = None 79 | for obj in (mean1, logvar1, mean2, logvar2): 80 | if isinstance(obj, torch.Tensor): 81 | tensor = obj 82 | break 83 | assert tensor is not None, "at least one argument must be a Tensor" 84 | 85 | # Force variances to be Tensors. Broadcasting helps convert scalars to 86 | # Tensors, but it does not work for torch.exp(). 
87 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] 88 | 89 | return 0.5 * ( 90 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 91 | ) 92 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/decode.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | 4 | import argbind 5 | import numpy as np 6 | import torch 7 | from audiotools import AudioSignal 8 | from tqdm import tqdm 9 | 10 | from ..model import DACFile 11 | from . import load_model 12 | 13 | warnings.filterwarnings("ignore", category=UserWarning) 14 | 15 | 16 | @argbind.bind(group="decode", positional=True, without_prefix=True) 17 | @torch.inference_mode() 18 | @torch.no_grad() 19 | def decode( 20 | input: str, 21 | output: str = "", 22 | weights_path: str = "", 23 | model_tag: str = "latest", 24 | model_bitrate: str = "8kbps", 25 | device: str = "cuda", 26 | model_type: str = "44khz", 27 | verbose: bool = False, 28 | ): 29 | """Decode audio from codes. 30 | 31 | Parameters 32 | ---------- 33 | input : str 34 | Path to input directory or file 35 | output : str, optional 36 | Path to output directory, by default "". 37 | If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 38 | weights_path : str, optional 39 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 40 | model_tag and model_type. 41 | model_tag : str, optional 42 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 43 | model_bitrate: str 44 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 45 | device : str, optional 46 | Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. 47 | model_type : str, optional 48 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
49 | """ 50 | generator = load_model( 51 | model_type=model_type, 52 | model_bitrate=model_bitrate, 53 | tag=model_tag, 54 | load_path=weights_path, 55 | ) 56 | generator.to(device) 57 | generator.eval() 58 | 59 | # Find all .dac files in input directory 60 | _input = Path(input) 61 | input_files = list(_input.glob("**/*.dac")) 62 | 63 | # If input is a .dac file, add it to the list 64 | if _input.suffix == ".dac": 65 | input_files.append(_input) 66 | 67 | # Create output directory 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(input_files)), desc=f"Decoding files"): 72 | # Load file 73 | artifact = DACFile.load(input_files[i]) 74 | 75 | # Reconstruct audio from codes 76 | recons = generator.decompress(artifact, verbose=verbose) 77 | 78 | # Compute output path 79 | relative_path = input_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = input_files[i] 84 | output_name = relative_path.with_suffix(".wav").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | # Write to file 89 | recons.write(output_path) 90 | 91 | 92 | if __name__ == "__main__": 93 | args = argbind.parse_args() 94 | with argbind.scope(args): 95 | decode() 96 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/media_utils.py: -------------------------------------------------------------------------------- 1 | """Media utilities for audio/video processing.""" 2 | 3 | import os 4 | import subprocess 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | from loguru import logger 9 | 10 | 11 | class MediaProcessingError(Exception): 12 | """Exception raised for media processing errors.""" 13 | pass 14 | 15 | 16 | def merge_audio_video( 17 | audio_path: str, 18 | video_path: str, 19 | output_path: str, 20 | overwrite: bool = True, 21 | quality: str = "high" 22 | ) -> str: 23 | """ 24 | Merge audio and video files using ffmpeg. 
25 | 26 | Args: 27 | audio_path: Path to input audio file 28 | video_path: Path to input video file 29 | output_path: Path for output video file 30 | overwrite: Whether to overwrite existing output file 31 | quality: Quality setting ('high', 'medium', 'low') 32 | 33 | Returns: 34 | Path to the output file 35 | 36 | Raises: 37 | MediaProcessingError: If input files don't exist or ffmpeg fails 38 | FileNotFoundError: If ffmpeg is not installed 39 | """ 40 | # Validate input files 41 | if not os.path.exists(audio_path): 42 | raise MediaProcessingError(f"Audio file not found: {audio_path}") 43 | if not os.path.exists(video_path): 44 | raise MediaProcessingError(f"Video file not found: {video_path}") 45 | 46 | # Create output directory if needed 47 | output_dir = Path(output_path).parent 48 | output_dir.mkdir(parents=True, exist_ok=True) 49 | 50 | # Quality settings 51 | quality_settings = { 52 | "high": ["-b:a", "192k"], 53 | "medium": ["-b:a", "128k"], 54 | "low": ["-b:a", "96k"] 55 | } 56 | 57 | # Build ffmpeg command 58 | ffmpeg_command = [ 59 | "ffmpeg", 60 | "-i", video_path, 61 | "-i", audio_path, 62 | "-c:v", "copy", 63 | "-c:a", "aac", 64 | "-ac", "2", 65 | "-af", "pan=stereo|c0=c0|c1=c0", 66 | "-map", "0:v:0", 67 | "-map", "1:a:0", 68 | *quality_settings.get(quality, quality_settings["high"]), 69 | ] 70 | 71 | if overwrite: 72 | ffmpeg_command.append("-y") 73 | 74 | ffmpeg_command.append(output_path) 75 | 76 | try: 77 | logger.info(f"Merging audio '{audio_path}' with video '{video_path}'") 78 | process = subprocess.Popen( 79 | ffmpeg_command, 80 | stdout=subprocess.PIPE, 81 | stderr=subprocess.PIPE, 82 | text=True 83 | ) 84 | stdout, stderr = process.communicate() 85 | 86 | if process.returncode != 0: 87 | error_msg = f"FFmpeg failed with return code {process.returncode}: {stderr}" 88 | logger.error(error_msg) 89 | raise MediaProcessingError(error_msg) 90 | else: 91 | logger.info(f"Successfully merged video saved to: {output_path}") 92 | 93 | except FileNotFoundError: 94 | raise FileNotFoundError( 95 | "ffmpeg not found. Please install ffmpeg: " 96 | "https://ffmpeg.org/download.html" 97 | ) 98 | except Exception as e: 99 | raise MediaProcessingError(f"Unexpected error during media processing: {e}") 100 | 101 | return output_path 102 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/encode.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from pathlib import Path 4 | 5 | import argbind 6 | import numpy as np 7 | import torch 8 | from audiotools import AudioSignal 9 | from audiotools.core import util 10 | from tqdm import tqdm 11 | 12 | from . import load_model 13 | 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | 16 | 17 | @argbind.bind(group="encode", positional=True, without_prefix=True) 18 | @torch.inference_mode() 19 | @torch.no_grad() 20 | def encode( 21 | input: str, 22 | output: str = "", 23 | weights_path: str = "", 24 | model_tag: str = "latest", 25 | model_bitrate: str = "8kbps", 26 | n_quantizers: int = None, 27 | device: str = "cuda", 28 | model_type: str = "44khz", 29 | win_duration: float = 5.0, 30 | verbose: bool = False, 31 | ): 32 | """Encode audio files in input path to .dac format. 33 | 34 | Parameters 35 | ---------- 36 | input : str 37 | Path to input audio file or directory 38 | output : str, optional 39 | Path to output directory, by default "". 
If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 40 | weights_path : str, optional 41 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 42 | model_tag and model_type. 43 | model_tag : str, optional 44 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 45 | model_bitrate: str 46 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 47 | n_quantizers : int, optional 48 | Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. 49 | device : str, optional 50 | Device to use, by default "cuda" 51 | model_type : str, optional 52 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 53 | """ 54 | generator = load_model( 55 | model_type=model_type, 56 | model_bitrate=model_bitrate, 57 | tag=model_tag, 58 | load_path=weights_path, 59 | ) 60 | generator.to(device) 61 | generator.eval() 62 | kwargs = {"n_quantizers": n_quantizers} 63 | 64 | # Find all audio files in input path 65 | input = Path(input) 66 | audio_files = util.find_audio(input) 67 | 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(audio_files)), desc="Encoding files"): 72 | # Load file 73 | signal = AudioSignal(audio_files[i]) 74 | 75 | # Encode audio to .dac format 76 | artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) 77 | 78 | # Compute output path 79 | relative_path = audio_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = audio_files[i] 84 | output_name = relative_path.with_suffix(".dac").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | artifact.save(output_path) 89 | 90 | 91 | if __name__ == "__main__": 92 | args = argbind.parse_args() 93 | with argbind.scope(args): 94 | encode() 95 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import argbind 4 | from audiotools import ml 5 | 6 | from ..model import DAC 7 | Accelerator = ml.Accelerator 8 | 9 | __MODEL_LATEST_TAGS__ = { 10 | ("44khz", "8kbps"): "0.0.1", 11 | ("24khz", "8kbps"): "0.0.4", 12 | ("16khz", "8kbps"): "0.0.5", 13 | ("44khz", "16kbps"): "1.0.0", 14 | } 15 | 16 | __MODEL_URLS__ = { 17 | ( 18 | "44khz", 19 | "0.0.1", 20 | "8kbps", 21 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", 22 | ( 23 | "24khz", 24 | "0.0.4", 25 | "8kbps", 26 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", 27 | ( 28 | "16khz", 29 | "0.0.5", 30 | "8kbps", 31 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", 32 | ( 33 | "44khz", 34 | "1.0.0", 35 | "16kbps", 36 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", 37 | } 38 | 39 | 40 | @argbind.bind(group="download", positional=True, without_prefix=True) 41 | def download( 42 | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str 
= "latest" 43 | ): 44 | """ 45 | Function that downloads the weights file from URL if a local cache is not found. 46 | 47 | Parameters 48 | ---------- 49 | model_type : str 50 | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". 51 | model_bitrate: str 52 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 53 | Only 44khz model supports 16kbps. 54 | tag : str 55 | The tag of the model to download. Defaults to "latest". 56 | 57 | Returns 58 | ------- 59 | Path 60 | Directory path required to load model via audiotools. 61 | """ 62 | model_type = model_type.lower() 63 | tag = tag.lower() 64 | 65 | assert model_type in [ 66 | "44khz", 67 | "24khz", 68 | "16khz", 69 | ], "model_type must be one of '44khz', '24khz', or '16khz'" 70 | 71 | assert model_bitrate in [ 72 | "8kbps", 73 | "16kbps", 74 | ], "model_bitrate must be one of '8kbps', or '16kbps'" 75 | 76 | if tag == "latest": 77 | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] 78 | 79 | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) 80 | 81 | if download_link is None: 82 | raise ValueError( 83 | f"Could not find model with tag {tag} and model type {model_type}" 84 | ) 85 | 86 | local_path = ( 87 | Path.home() 88 | / ".cache" 89 | / "descript" 90 | / "dac" 91 | / f"weights_{model_type}_{model_bitrate}_{tag}.pth" 92 | ) 93 | if not local_path.exists(): 94 | local_path.parent.mkdir(parents=True, exist_ok=True) 95 | 96 | # Download the model 97 | import requests 98 | 99 | response = requests.get(download_link) 100 | 101 | if response.status_code != 200: 102 | raise ValueError( 103 | f"Could not download model. Received response code {response.status_code}" 104 | ) 105 | local_path.write_bytes(response.content) 106 | 107 | return local_path 108 | 109 | 110 | def load_model( 111 | model_type: str = "44khz", 112 | model_bitrate: str = "8kbps", 113 | tag: str = "latest", 114 | load_path: str = None, 115 | ): 116 | if not load_path: 117 | load_path = download( 118 | model_type=model_type, model_bitrate=model_bitrate, tag=tag 119 | ) 120 | generator = DAC.load(load_path) 121 | return generator 122 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | """Configuration utilities for the HunyuanVideo-Foley project.""" 2 | 3 | import yaml 4 | from pathlib import Path 5 | from typing import Any, Dict, List, Union 6 | 7 | class AttributeDict: 8 | 9 | def __init__(self, data: Union[Dict, List, Any]): 10 | if isinstance(data, dict): 11 | for key, value in data.items(): 12 | if isinstance(value, (dict, list)): 13 | value = AttributeDict(value) 14 | setattr(self, self._sanitize_key(key), value) 15 | elif isinstance(data, list): 16 | self._list = [AttributeDict(item) if isinstance(item, (dict, list)) else item 17 | for item in data] 18 | else: 19 | self._value = data 20 | 21 | def _sanitize_key(self, key: str) -> str: 22 | import re 23 | sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', str(key)) 24 | if sanitized[0].isdigit(): 25 | sanitized = f'_{sanitized}' 26 | return sanitized 27 | 28 | def __getitem__(self, key): 29 | if hasattr(self, '_list'): 30 | return self._list[key] 31 | return getattr(self, self._sanitize_key(key)) 32 | 33 | def __setitem__(self, key, value): 34 | if hasattr(self, '_list'): 35 | self._list[key] = value 36 | else: 37 | setattr(self, self._sanitize_key(key), value) 38 | 39 | 
def __iter__(self): 40 | if hasattr(self, '_list'): 41 | return iter(self._list) 42 | return iter(self.__dict__.keys()) 43 | 44 | def __len__(self): 45 | if hasattr(self, '_list'): 46 | return len(self._list) 47 | return len(self.__dict__) 48 | 49 | def get(self, key, default=None): 50 | try: 51 | return self[key] 52 | except (KeyError, AttributeError, IndexError): 53 | return default 54 | 55 | def keys(self): 56 | if hasattr(self, '_list'): 57 | return range(len(self._list)) 58 | elif hasattr(self, '_value'): 59 | return [] 60 | else: 61 | return [key for key in self.__dict__.keys() if not key.startswith('_')] 62 | 63 | def values(self): 64 | if hasattr(self, '_list'): 65 | return self._list 66 | elif hasattr(self, '_value'): 67 | return [self._value] 68 | else: 69 | return [value for key, value in self.__dict__.items() if not key.startswith('_')] 70 | 71 | def items(self): 72 | if hasattr(self, '_list'): 73 | return enumerate(self._list) 74 | elif hasattr(self, '_value'): 75 | return [] 76 | else: 77 | return [(key, value) for key, value in self.__dict__.items() if not key.startswith('_')] 78 | 79 | def __repr__(self): 80 | if hasattr(self, '_list'): 81 | return f"AttributeDict({self._list})" 82 | elif hasattr(self, '_value'): 83 | return f"AttributeDict({self._value})" 84 | return f"AttributeDict({dict(self.__dict__)})" 85 | 86 | def to_dict(self) -> Union[Dict, List, Any]: 87 | if hasattr(self, '_list'): 88 | return [item.to_dict() if isinstance(item, AttributeDict) else item 89 | for item in self._list] 90 | elif hasattr(self, '_value'): 91 | return self._value 92 | else: 93 | result = {} 94 | for key, value in self.__dict__.items(): 95 | if isinstance(value, AttributeDict): 96 | result[key] = value.to_dict() 97 | else: 98 | result[key] = value 99 | return result 100 | 101 | def load_yaml(file_path: str, encoding: str = 'utf-8') -> AttributeDict: 102 | try: 103 | with open(file_path, 'r', encoding=encoding) as file: 104 | data = yaml.safe_load(file) 105 | return AttributeDict(data) 106 | except FileNotFoundError: 107 | raise FileNotFoundError(f"YAML file not found: {file_path}") 108 | except yaml.YAMLError as e: 109 | raise yaml.YAMLError(f"YAML format error: {e}") 110 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/helper.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | from itertools import repeat 3 | import importlib 4 | import yaml 5 | import time 6 | 7 | def default(value, default_val): 8 | return default_val if value is None else value 9 | 10 | 11 | def default_dtype(value, default_val): 12 | if value is not None: 13 | assert isinstance(value, type(default_val)), f"Expect {type(default_val)}, got {type(value)}." 
14 | return value 15 | return default_val 16 | 17 | 18 | def repeat_interleave(lst, num_repeats): 19 | return [item for item in lst for _ in range(num_repeats)] 20 | 21 | 22 | def _ntuple(n): 23 | def parse(x): 24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 25 | x = tuple(x) 26 | if len(x) == 1: 27 | x = tuple(repeat(x[0], n)) 28 | return x 29 | return tuple(repeat(x, n)) 30 | 31 | return parse 32 | 33 | 34 | to_1tuple = _ntuple(1) 35 | to_2tuple = _ntuple(2) 36 | to_3tuple = _ntuple(3) 37 | to_4tuple = _ntuple(4) 38 | 39 | 40 | def as_tuple(x): 41 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 42 | return tuple(x) 43 | if x is None or isinstance(x, (int, float, str)): 44 | return (x,) 45 | else: 46 | raise ValueError(f"Unknown type {type(x)}") 47 | 48 | 49 | def as_list_of_2tuple(x): 50 | x = as_tuple(x) 51 | if len(x) == 1: 52 | x = (x[0], x[0]) 53 | assert len(x) % 2 == 0, f"Expect even length, got {len(x)}." 54 | lst = [] 55 | for i in range(0, len(x), 2): 56 | lst.append((x[i], x[i + 1])) 57 | return lst 58 | 59 | 60 | def find_multiple(n: int, k: int) -> int: 61 | assert k > 0 62 | if n % k == 0: 63 | return n 64 | return n - (n % k) + k 65 | 66 | 67 | def merge_dicts(dict1, dict2): 68 | for key, value in dict2.items(): 69 | if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict): 70 | merge_dicts(dict1[key], value) 71 | else: 72 | dict1[key] = value 73 | return dict1 74 | 75 | 76 | def merge_yaml_files(file_list): 77 | merged_config = {} 78 | 79 | for file in file_list: 80 | with open(file, "r", encoding="utf-8") as f: 81 | config = yaml.safe_load(f) 82 | if config: 83 | # Remove the first level 84 | for key, value in config.items(): 85 | if isinstance(value, dict): 86 | merged_config = merge_dicts(merged_config, value) 87 | else: 88 | merged_config[key] = value 89 | 90 | return merged_config 91 | 92 | 93 | def merge_dict(file_list): 94 | merged_config = {} 95 | 96 | for file in file_list: 97 | with open(file, "r", encoding="utf-8") as f: 98 | config = yaml.safe_load(f) 99 | if config: 100 | merged_config = merge_dicts(merged_config, config) 101 | 102 | return merged_config 103 | 104 | 105 | def get_obj_from_str(string, reload=False): 106 | module, cls = string.rsplit(".", 1) 107 | if reload: 108 | module_imp = importlib.import_module(module) 109 | importlib.reload(module_imp) 110 | return getattr(importlib.import_module(module, package=None), cls) 111 | 112 | 113 | def readable_time(seconds): 114 | """ Convert time seconds to a readable format: DD Days, HH Hours, MM Minutes, SS Seconds """ 115 | seconds = int(seconds) 116 | days, seconds = divmod(seconds, 86400) 117 | hours, seconds = divmod(seconds, 3600) 118 | minutes, seconds = divmod(seconds, 60) 119 | if days > 0: 120 | return f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds" 121 | if hours > 0: 122 | return f"{hours} Hours, {minutes} Minutes, {seconds} Seconds" 123 | if minutes > 0: 124 | return f"{minutes} Minutes, {seconds} Seconds" 125 | return f"{seconds} Seconds" 126 | 127 | 128 | def get_obj_from_cfg(cfg, reload=False): 129 | if isinstance(cfg, str): 130 | return get_obj_from_str(cfg, reload) 131 | elif isinstance(cfg, (list, tuple,)): 132 | return tuple([get_obj_from_str(c, reload) for c in cfg]) 133 | else: 134 | raise NotImplementedError(f"Not implemented for {type(cfg)}.") 135 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/utils.py: 
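# A few concrete evaluations of the helpers defined in helper.py above
# (illustrative; results follow directly from the code, assuming the repository
# root is importable).

from hunyuanvideo_foley.utils.helper import (
    to_2tuple, as_list_of_2tuple, find_multiple, readable_time,
)

assert to_2tuple(3) == (3, 3)                     # scalars are repeated n times
assert to_2tuple((4, 5)) == (4, 5)                # iterables pass through as tuples
assert as_list_of_2tuple((1, 2, 3, 4)) == [(1, 2), (3, 4)]
assert find_multiple(10, 8) == 16                 # smallest multiple of k that is >= n
assert readable_time(3725) == "1 Hours, 2 Minutes, 5 Seconds"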
-------------------------------------------------------------------------------- 1 | from hashlib import md5 2 | from pathlib import Path 3 | import subprocess 4 | 5 | import requests 6 | from tqdm import tqdm 7 | 8 | PARENT_LINK = "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a" 9 | FNAME2LINK = { 10 | # S3: Synchability: AudioSet (run 2) 11 | "24-01-22T20-34-52.pt": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt", 12 | "cfg-24-01-22T20-34-52.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml", 13 | # S2: Synchformer: AudioSet (run 2) 14 | "24-01-04T16-39-21.pt": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt", 15 | "cfg-24-01-04T16-39-21.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml", 16 | # S2: Synchformer: AudioSet (run 1) 17 | "23-08-28T11-23-23.pt": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt", 18 | "cfg-23-08-28T11-23-23.yaml": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml", 19 | # S2: Synchformer: LRS3 (run 2) 20 | "23-12-23T18-33-57.pt": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt", 21 | "cfg-23-12-23T18-33-57.yaml": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml", 22 | # S2: Synchformer: VGS (run 2) 23 | "24-01-02T10-00-53.pt": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt", 24 | "cfg-24-01-02T10-00-53.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml", 25 | # SparseSync: ft VGGSound-Full 26 | "22-09-21T21-00-52.pt": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt", 27 | "cfg-22-09-21T21-00-52.yaml": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml", 28 | # SparseSync: ft VGGSound-Sparse 29 | "22-07-28T15-49-45.pt": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt", 30 | "cfg-22-07-28T15-49-45.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml", 31 | # SparseSync: only pt on LRS3 32 | "22-07-13T22-25-49.pt": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt", 33 | "cfg-22-07-13T22-25-49.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml", 34 | # SparseSync: feature extractors 35 | "ResNetAudio-22-08-04T09-51-04.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt", # 2s 36 | "ResNetAudio-22-08-03T23-14-49.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt", # 3s 37 | "ResNetAudio-22-08-03T23-14-28.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt", # 4s 38 | "ResNetAudio-22-06-24T08-10-33.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt", # 5s 39 | "ResNetAudio-22-06-24T17-31-07.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt", # 6s 40 | "ResNetAudio-22-06-24T23-57-11.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt", # 7s 41 | "ResNetAudio-22-06-25T04-35-42.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt", # 8s 42 | } 43 | 44 | 45 | def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024): 46 | """Checks if file exists, if not downloads it from the link to the path""" 47 | path = Path(path) 48 | if not path.exists(): 49 | path.parent.mkdir(exist_ok=True, parents=True) 50 | link = fname2link.get(path.name, None) 51 | if link is None: 52 | raise ValueError( 53 | f"Cant find the checkpoint file: {path}.", f"Please 
download it manually and ensure the path exists." 54 | ) 55 | with requests.get(fname2link[path.name], stream=True) as r: 56 | total_size = int(r.headers.get("content-length", 0)) 57 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 58 | with open(path, "wb") as f: 59 | for data in r.iter_content(chunk_size=chunk_size): 60 | if data: 61 | f.write(data) 62 | pbar.update(chunk_size) 63 | 64 | 65 | def which_ffmpeg() -> str: 66 | """Determines the path to ffmpeg library 67 | Returns: 68 | str -- path to the library 69 | """ 70 | result = subprocess.run(["which", "ffmpeg"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 71 | ffmpeg_path = result.stdout.decode("utf-8").replace("\n", "") 72 | return ffmpeg_path 73 | 74 | 75 | def get_md5sum(path): 76 | hash_md5 = md5() 77 | with open(path, "rb") as f: 78 | for chunk in iter(lambda: f.read(4096 * 8), b""): 79 | hash_md5.update(chunk) 80 | md5sum = hash_md5.hexdigest() 81 | return md5sum 82 | 83 | 84 | class Config: 85 | def __init__(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | setattr(self, k, v) 88 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/embed_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ...utils.helper import to_2tuple, to_1tuple 6 | 7 | class PatchEmbed1D(nn.Module): 8 | """1D Audio to Patch Embedding 9 | 10 | A convolution based approach to patchifying a 1D audio w/ embedding projection. 11 | 12 | Based on the impl in https://github.com/google-research/vision_transformer 13 | 14 | Hacked together by / Copyright 2020 Ross Wightman 15 | """ 16 | 17 | def __init__( 18 | self, 19 | patch_size=1, 20 | in_chans=768, 21 | embed_dim=768, 22 | norm_layer=None, 23 | flatten=True, 24 | bias=True, 25 | dtype=None, 26 | device=None, 27 | ): 28 | factory_kwargs = {"dtype": dtype, "device": device} 29 | super().__init__() 30 | patch_size = to_1tuple(patch_size) 31 | self.patch_size = patch_size 32 | self.flatten = flatten 33 | 34 | self.proj = nn.Conv1d( 35 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs 36 | ) 37 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) 38 | if bias: 39 | nn.init.zeros_(self.proj.bias) 40 | 41 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 42 | 43 | def forward(self, x): 44 | assert ( 45 | x.shape[2] % self.patch_size[0] == 0 46 | ), f"The patch_size of {self.patch_size[0]} must be divisible by the token number ({x.shape[2]}) of x." 47 | 48 | x = self.proj(x) 49 | if self.flatten: 50 | x = x.transpose(1, 2) # BCN -> BNC 51 | x = self.norm(x) 52 | return x 53 | 54 | 55 | class ConditionProjection(nn.Module): 56 | """ 57 | Projects condition embeddings. Also handles dropout for classifier-free guidance. 
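# Shape sketch for PatchEmbed1D above (illustrative): a 1D feature sequence of
# shape (B, C, N) is patchified by a strided Conv1d and returned channel-last
# as (B, N // patch_size, embed_dim), ready for a transformer.

import torch
from hunyuanvideo_foley.models.nn.embed_layers import PatchEmbed1D

patchify = PatchEmbed1D(patch_size=1, in_chans=768, embed_dim=768)
tokens = patchify(torch.randn(2, 768, 120))   # (B=2, C=768, N=120)
assert tokens.shape == (2, 120, 768)          # flatten=True -> BNC layout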
58 | 59 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py 60 | """ 61 | 62 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): 63 | factory_kwargs = {'dtype': dtype, 'device': device} 64 | super().__init__() 65 | self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs) 66 | self.act_1 = act_layer() 67 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs) 68 | 69 | def forward(self, caption): 70 | hidden_states = self.linear_1(caption) 71 | hidden_states = self.act_1(hidden_states) 72 | hidden_states = self.linear_2(hidden_states) 73 | return hidden_states 74 | 75 | 76 | def timestep_embedding(t, dim, max_period=10000): 77 | """ 78 | Create sinusoidal timestep embeddings. 79 | 80 | Args: 81 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. 82 | dim (int): the dimension of the output. 83 | max_period (int): controls the minimum frequency of the embeddings. 84 | 85 | Returns: 86 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. 87 | 88 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 89 | """ 90 | half = dim // 2 91 | freqs = torch.exp( 92 | -math.log(max_period) 93 | * torch.arange(start=0, end=half, dtype=torch.float32) 94 | / half 95 | ).to(device=t.device) 96 | args = t[:, None].float() * freqs[None] 97 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 98 | if dim % 2: 99 | embedding = torch.cat( 100 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 101 | ) 102 | return embedding 103 | 104 | 105 | class TimestepEmbedder(nn.Module): 106 | """ 107 | Embeds scalar timesteps into vector representations. 
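# Worked shape example for timestep_embedding above (illustrative): for N
# (possibly fractional) timesteps and an even dim, the output concatenates
# dim/2 cosine terms and dim/2 sine terms along the last axis.

import torch
from hunyuanvideo_foley.models.nn.embed_layers import timestep_embedding

t = torch.tensor([0.0, 17.5, 999.0])          # N = 3 timesteps
emb = timestep_embedding(t, dim=256)          # freqs decay from 1 towards 1/max_period
assert emb.shape == (3, 256)                  # [cos(t * freqs) | sin(t * freqs)]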
108 | """ 109 | def __init__(self, 110 | hidden_size, 111 | act_layer, 112 | frequency_embedding_size=256, 113 | max_period=10000, 114 | out_size=None, 115 | dtype=None, 116 | device=None 117 | ): 118 | factory_kwargs = {'dtype': dtype, 'device': device} 119 | super().__init__() 120 | self.frequency_embedding_size = frequency_embedding_size 121 | self.max_period = max_period 122 | if out_size is None: 123 | out_size = hidden_size 124 | 125 | self.mlp = nn.Sequential( 126 | nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), 127 | act_layer(), 128 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), 129 | ) 130 | nn.init.normal_(self.mlp[0].weight, std=0.02) 131 | nn.init.normal_(self.mlp[2].weight, std=0.02) 132 | 133 | def forward(self, t): 134 | t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) 135 | t_emb = self.mlp(t_freq) 136 | return t_emb 137 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/mlp_layers.py: -------------------------------------------------------------------------------- 1 | # Modified from timm library: 2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .modulate_layers import modulate 11 | from ...utils.helper import to_2tuple 12 | 13 | class MLP(nn.Module): 14 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 15 | 16 | def __init__( 17 | self, 18 | in_channels, 19 | hidden_channels=None, 20 | out_features=None, 21 | act_layer=nn.GELU, 22 | norm_layer=None, 23 | bias=True, 24 | drop=0.0, 25 | use_conv=False, 26 | device=None, 27 | dtype=None, 28 | ): 29 | factory_kwargs = {"device": device, "dtype": dtype} 30 | super().__init__() 31 | out_features = out_features or in_channels 32 | hidden_channels = hidden_channels or in_channels 33 | bias = to_2tuple(bias) 34 | drop_probs = to_2tuple(drop) 35 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 36 | 37 | self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs) 38 | self.act = act_layer() 39 | self.drop1 = nn.Dropout(drop_probs[0]) 40 | self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity() 41 | self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs) 42 | self.drop2 = nn.Dropout(drop_probs[1]) 43 | 44 | def forward(self, x): 45 | x = self.fc1(x) 46 | x = self.act(x) 47 | x = self.drop1(x) 48 | x = self.norm(x) 49 | x = self.fc2(x) 50 | x = self.drop2(x) 51 | return x 52 | 53 | 54 | # copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py 55 | # only used when use_vanilla is True 56 | class MLPEmbedder(nn.Module): 57 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): 58 | factory_kwargs = {"device": device, "dtype": dtype} 59 | super().__init__() 60 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) 61 | self.silu = nn.SiLU() 62 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) 63 | 64 | def forward(self, x: torch.Tensor) -> torch.Tensor: 65 | return self.out_layer(self.silu(self.in_layer(x))) 66 | 67 | 68 | class LinearWarpforSingle(nn.Module): 69 | def __init__(self, in_dim: 
int, out_dim: int, bias=True, device=None, dtype=None): 70 | factory_kwargs = {"device": device, "dtype": dtype} 71 | super().__init__() 72 | self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs) 73 | 74 | def forward(self, x, y): 75 | z = torch.cat([x, y], dim=2) 76 | return self.fc(z) 77 | 78 | class FinalLayer1D(nn.Module): 79 | def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None): 80 | factory_kwargs = {"device": device, "dtype": dtype} 81 | super().__init__() 82 | 83 | # Just use LayerNorm for the final layer 84 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) 85 | self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs) 86 | nn.init.zeros_(self.linear.weight) 87 | nn.init.zeros_(self.linear.bias) 88 | 89 | # Here we don't distinguish between the modulate types. Just use the simple one. 90 | self.adaLN_modulation = nn.Sequential( 91 | act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs) 92 | ) 93 | # Zero-initialize the modulation 94 | nn.init.zeros_(self.adaLN_modulation[1].weight) 95 | nn.init.zeros_(self.adaLN_modulation[1].bias) 96 | 97 | def forward(self, x, c): 98 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 99 | x = modulate(self.norm_final(x), shift=shift, scale=scale) 100 | x = self.linear(x) 101 | return x 102 | 103 | 104 | class ChannelLastConv1d(nn.Conv1d): 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = x.permute(0, 2, 1) 108 | x = super().forward(x) 109 | x = x.permute(0, 2, 1) 110 | return x 111 | 112 | 113 | class ConvMLP(nn.Module): 114 | 115 | def __init__( 116 | self, 117 | dim: int, 118 | hidden_dim: int, 119 | multiple_of: int = 256, 120 | kernel_size: int = 3, 121 | padding: int = 1, 122 | device=None, 123 | dtype=None, 124 | ): 125 | """ 126 | Convolutional MLP module. 127 | 128 | Args: 129 | dim (int): Input dimension. 130 | hidden_dim (int): Hidden dimension of the feedforward layer. 131 | multiple_of (int): Value to ensure hidden dimension is a multiple of this value. 132 | 133 | Attributes: 134 | w1: Linear transformation for the first layer. 135 | w2: Linear transformation for the second layer. 136 | w3: Linear transformation for the third layer. 
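# Hidden-width arithmetic for the ConvMLP defined here (illustrative): the
# requested hidden_dim is scaled to 2/3 (the SwiGLU-style gate uses two input
# projections, w1 and w3) and then rounded up to a multiple of `multiple_of`.

import torch
from hunyuanvideo_foley.models.nn.mlp_layers import ConvMLP

mlp = ConvMLP(dim=1024, hidden_dim=4096, multiple_of=256)
# int(2 * 4096 / 3) = 2730, rounded up to 11 * 256 = 2816 internal channels.
assert mlp.w1.out_channels == 2816

x = torch.randn(2, 120, 1024)                 # ChannelLastConv1d expects (B, T, C)
assert mlp(x).shape == (2, 120, 1024)         # w2(silu(w1(x)) * w3(x)) keeps (B, T, dim)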
137 | 138 | """ 139 | factory_kwargs = {"device": device, "dtype": dtype} 140 | super().__init__() 141 | hidden_dim = int(2 * hidden_dim / 3) 142 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 143 | 144 | self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 145 | self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 146 | self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) 147 | 148 | def forward(self, x): 149 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 150 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/feature_utils.py: -------------------------------------------------------------------------------- 1 | """Feature extraction utilities for video and text processing.""" 2 | 3 | import os 4 | import numpy as np 5 | import torch 6 | import av 7 | from PIL import Image 8 | from einops import rearrange 9 | from typing import Any, Dict, List, Union, Tuple 10 | from loguru import logger 11 | 12 | from .config_utils import AttributeDict 13 | from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS 14 | 15 | 16 | class FeatureExtractionError(Exception): 17 | """Exception raised for feature extraction errors.""" 18 | pass 19 | 20 | def get_frames_av( 21 | video_path: str, 22 | fps: float, 23 | max_length: float = None, 24 | ) -> Tuple[np.ndarray, float]: 25 | end_sec = max_length if max_length is not None else 15 26 | next_frame_time_for_each_fps = 0.0 27 | time_delta_for_each_fps = 1 / fps 28 | 29 | all_frames = [] 30 | output_frames = [] 31 | 32 | with av.open(video_path) as container: 33 | stream = container.streams.video[0] 34 | ori_fps = stream.guessed_rate 35 | stream.thread_type = "AUTO" 36 | for packet in container.demux(stream): 37 | for frame in packet.decode(): 38 | frame_time = frame.time 39 | if frame_time < 0: 40 | continue 41 | if frame_time > end_sec: 42 | break 43 | 44 | frame_np = None 45 | 46 | this_time = frame_time 47 | while this_time >= next_frame_time_for_each_fps: 48 | if frame_np is None: 49 | frame_np = frame.to_ndarray(format="rgb24") 50 | 51 | output_frames.append(frame_np) 52 | next_frame_time_for_each_fps += time_delta_for_each_fps 53 | 54 | output_frames = np.stack(output_frames) 55 | 56 | vid_len_in_s = len(output_frames) / fps 57 | if max_length is not None and len(output_frames) > int(max_length * fps): 58 | output_frames = output_frames[: int(max_length * fps)] 59 | vid_len_in_s = max_length 60 | 61 | return output_frames, vid_len_in_s 62 | 63 | @torch.inference_mode() 64 | def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1): 65 | b, t, c, h, w = x.shape 66 | if batch_size < 0: 67 | batch_size = b * t 68 | x = rearrange(x, "b t c h w -> (b t) c h w") 69 | outputs = [] 70 | for i in range(0, b * t, batch_size): 71 | outputs.append(model_dict.siglip2_model.get_image_features(pixel_values=x[i : i + batch_size])) 72 | res = torch.cat(outputs, dim=0) 73 | res = rearrange(res, "(b t) d -> b t d", b=b) 74 | return res 75 | 76 | @torch.inference_mode() 77 | def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1): 78 | """ 79 | The input video of x is best to be in fps of 24 of greater than 24. 
80 | Input: 81 | x: tensor in shape of [B, T, C, H, W] 82 | batch_size: the batch_size for synchformer inference 83 | """ 84 | b, t, c, h, w = x.shape 85 | assert c == 3 and h == 224 and w == 224 86 | 87 | segment_size = 16 88 | step_size = 8 89 | num_segments = (t - segment_size) // step_size + 1 90 | segments = [] 91 | for i in range(num_segments): 92 | segments.append(x[:, i * step_size : i * step_size + segment_size]) 93 | x = torch.stack(segments, dim=1).cuda() # (B, num_segments, segment_size, 3, 224, 224) 94 | 95 | outputs = [] 96 | if batch_size < 0: 97 | batch_size = b * num_segments 98 | x = rearrange(x, "b s t c h w -> (b s) 1 t c h w") 99 | for i in range(0, b * num_segments, batch_size): 100 | with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half): 101 | outputs.append(model_dict.syncformer_model(x[i : i + batch_size])) 102 | x = torch.cat(outputs, dim=0) # [b * num_segments, 1, 8, 768] 103 | x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b) 104 | return x 105 | 106 | 107 | @torch.inference_mode() 108 | def encode_video_features(video_path, model_dict): 109 | visual_features = {} 110 | # siglip2 visual features 111 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["siglip2"]) 112 | images = [Image.fromarray(frame).convert('RGB') for frame in frames] 113 | images = [model_dict.siglip2_preprocess(image) for image in images] # [T, C, H, W] 114 | clip_frames = torch.stack(images).to(model_dict.device).unsqueeze(0) 115 | visual_features['siglip2_feat'] = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device) 116 | 117 | # synchformer visual features 118 | frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["synchformer"]) 119 | images = torch.from_numpy(frames).permute(0, 3, 1, 2) # [T, C, H, W] 120 | sync_frames = model_dict.syncformer_preprocess(images).unsqueeze(0) # [1, T, 3, 224, 224] 121 | # [1, num_segments * 8, channel_dim], e.g. 
[1, 240, 768] for 10s video 122 | visual_features['syncformer_feat'] = encode_video_with_sync(sync_frames, model_dict) 123 | 124 | vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"] 125 | visual_features = AttributeDict(visual_features) 126 | 127 | return visual_features, vid_len_in_s 128 | 129 | @torch.inference_mode() 130 | def encode_text_feat(text: List[str], model_dict): 131 | # x: (B, L) 132 | inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device) 133 | outputs = model_dict.clap_model(**inputs, output_hidden_states=True, return_dict=True) 134 | return outputs.last_hidden_state, outputs.attentions 135 | 136 | 137 | def feature_process(video_path, prompt, model_dict, cfg): 138 | visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict) 139 | neg_prompt = "noisy, harsh" 140 | prompts = [neg_prompt, prompt] 141 | text_feat_res, text_feat_mask = encode_text_feat(prompts, model_dict) 142 | 143 | text_feat = text_feat_res[1:] 144 | uncond_text_feat = text_feat_res[:1] 145 | 146 | if cfg.model_config.model_kwargs.text_length < text_feat.shape[1]: 147 | text_seq_length = cfg.model_config.model_kwargs.text_length 148 | text_feat = text_feat[:, :text_seq_length] 149 | uncond_text_feat = uncond_text_feat[:, :text_seq_length] 150 | 151 | text_feats = AttributeDict({ 152 | 'text_feat': text_feat, 153 | 'uncond_text_feat': uncond_text_feat, 154 | }) 155 | 156 | return visual_feats, text_feats, audio_len_in_s 157 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/nn/posemb_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Tuple 3 | 4 | 5 | def _to_tuple(x, dim=2): 6 | if isinstance(x, int): 7 | return (x,) * dim 8 | elif len(x) == dim: 9 | return x 10 | else: 11 | raise ValueError(f"Expected length {dim} or int, but got {x}") 12 | 13 | 14 | def get_meshgrid_nd(start, *args, dim=2): 15 | """ 16 | Get n-D meshgrid with start, stop and num. 17 | 18 | Args: 19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, 20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num 21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in 22 | n-tuples. 23 | *args: See above. 24 | dim (int): Dimension of the meshgrid. Defaults to 2. 25 | 26 | Returns: 27 | grid (np.ndarray): [dim, ...] 
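# Quick check of the calling conventions of get_meshgrid_nd (illustrative):
# a single positional argument is the grid size, two arguments are (start, stop)
# with unit step, and three are (start, stop, num) with the endpoint excluded.

import torch
from hunyuanvideo_foley.models.nn.posemb_layers import get_meshgrid_nd

g = get_meshgrid_nd((2, 3), dim=2)            # grid_size only
assert g.shape == (2, 2, 3)                   # [dim, *grid_size]
assert g[0, 1, 0] == 1 and g[1, 0, 2] == 2    # axis 0 indexes rows, axis 1 columns

g = get_meshgrid_nd(4, 8, dim=1)              # start=4, stop=8, step 1
assert torch.equal(g[0], torch.tensor([4.0, 5.0, 6.0, 7.0]))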
28 | """ 29 | if len(args) == 0: 30 | # start is grid_size 31 | num = _to_tuple(start, dim=dim) 32 | start = (0,) * dim 33 | stop = num 34 | elif len(args) == 1: 35 | # start is start, args[0] is stop, step is 1 36 | start = _to_tuple(start, dim=dim) 37 | stop = _to_tuple(args[0], dim=dim) 38 | num = [stop[i] - start[i] for i in range(dim)] 39 | elif len(args) == 2: 40 | # start is start, args[0] is stop, args[1] is num 41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 44 | else: 45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") 46 | 47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) 48 | axis_grid = [] 49 | for i in range(dim): 50 | a, b, n = start[i], stop[i], num[i] 51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] 52 | axis_grid.append(g) 53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] 54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D] 55 | 56 | return grid 57 | 58 | 59 | ################################################################################# 60 | # Rotary Positional Embedding Functions # 61 | ################################################################################# 62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 63 | 64 | 65 | def get_nd_rotary_pos_embed( 66 | rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0 67 | ): 68 | """ 69 | This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. 70 | 71 | Args: 72 | rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. 73 | sum(rope_dim_list) should equal to head_dim of attention layer. 74 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, 75 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. 76 | *args: See above. 77 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0. 78 | use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. 79 | Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real 80 | part and an imaginary part separately. 81 | theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. 82 | freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. 
83 | 84 | Returns: 85 | pos_embed (torch.Tensor): [HW, D/2] 86 | """ 87 | 88 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] 89 | 90 | # use 1/ndim of dimensions to encode grid_axis 91 | embs = [] 92 | for i in range(len(rope_dim_list)): 93 | emb = get_1d_rotary_pos_embed( 94 | rope_dim_list[i], 95 | grid[i].reshape(-1), 96 | theta, 97 | use_real=use_real, 98 | theta_rescale_factor=theta_rescale_factor, 99 | freq_scaling=freq_scaling, 100 | ) # 2 x [WHD, rope_dim_list[i]] 101 | embs.append(emb) 102 | 103 | if use_real: 104 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) 105 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) 106 | return cos, sin 107 | else: 108 | emb = torch.cat(embs, dim=1) # (WHD, D/2) 109 | return emb 110 | 111 | 112 | def get_1d_rotary_pos_embed( 113 | dim: int, 114 | pos: Union[torch.FloatTensor, int], 115 | theta: float = 10000.0, 116 | use_real: bool = False, 117 | theta_rescale_factor: float = 1.0, 118 | freq_scaling: float = 1.0, 119 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 120 | """ 121 | Precompute the frequency tensor for complex exponential (cis) with given dimensions. 122 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) 123 | 124 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim' 125 | and the end index 'end'. The 'theta' parameter scales the frequencies. 126 | The returned tensor contains complex values in complex64 data type. 127 | 128 | Args: 129 | dim (int): Dimension of the frequency tensor. 130 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar 131 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 132 | use_real (bool, optional): If True, return real part and imaginary part separately. 133 | Otherwise, return complex numbers. 134 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. 135 | freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. 136 | 137 | Returns: 138 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] 139 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] 140 | """ 141 | if isinstance(pos, int): 142 | pos = torch.arange(pos).float() 143 | 144 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 145 | # has some connection to NTK literature 146 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ 147 | if theta_rescale_factor != 1.0: 148 | theta *= theta_rescale_factor ** (dim / (dim - 1)) 149 | 150 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] 151 | freqs *= freq_scaling 152 | freqs = torch.outer(pos, freqs) # [S, D/2] 153 | if use_real: 154 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] 155 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] 156 | return freqs_cos, freqs_sin 157 | else: 158 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] 159 | return freqs_cis 160 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from audiotools import AudioSignal 5 | from audiotools import ml 6 | from audiotools import STFTParams 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | 11 | def WNConv1d(*args, **kwargs): 12 | act = kwargs.pop("act", True) 13 | conv = weight_norm(nn.Conv1d(*args, **kwargs)) 14 | if not act: 15 | return conv 16 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 17 | 18 | 19 | def WNConv2d(*args, **kwargs): 20 | act = kwargs.pop("act", True) 21 | conv = weight_norm(nn.Conv2d(*args, **kwargs)) 22 | if not act: 23 | return conv 24 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 25 | 26 | 27 | class MPD(nn.Module): 28 | def __init__(self, period): 29 | super().__init__() 30 | self.period = period 31 | self.convs = nn.ModuleList( 32 | [ 33 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), 34 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), 35 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), 36 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), 37 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), 38 | ] 39 | ) 40 | self.conv_post = WNConv2d( 41 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False 42 | ) 43 | 44 | def pad_to_period(self, x): 45 | t = x.shape[-1] 46 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect") 47 | return x 48 | 49 | def forward(self, x): 50 | fmap = [] 51 | 52 | x = self.pad_to_period(x) 53 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period) 54 | 55 | for layer in self.convs: 56 | x = layer(x) 57 | fmap.append(x) 58 | 59 | x = self.conv_post(x) 60 | fmap.append(x) 61 | 62 | return fmap 63 | 64 | 65 | class MSD(nn.Module): 66 | def __init__(self, rate: int = 1, sample_rate: int = 44100): 67 | super().__init__() 68 | self.convs = nn.ModuleList( 69 | [ 70 | WNConv1d(1, 16, 15, 1, padding=7), 71 | WNConv1d(16, 64, 41, 4, groups=4, padding=20), 72 | WNConv1d(64, 256, 41, 4, groups=16, padding=20), 73 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20), 74 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), 75 | WNConv1d(1024, 1024, 5, 1, padding=2), 76 | ] 77 | ) 78 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) 79 | self.sample_rate = sample_rate 80 | self.rate = rate 81 | 82 | def forward(self, x): 83 | x = AudioSignal(x, self.sample_rate) 84 | 
x.resample(self.sample_rate // self.rate) 85 | x = x.audio_data 86 | 87 | fmap = [] 88 | 89 | for l in self.convs: 90 | x = l(x) 91 | fmap.append(x) 92 | x = self.conv_post(x) 93 | fmap.append(x) 94 | 95 | return fmap 96 | 97 | 98 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] 99 | 100 | 101 | class MRD(nn.Module): 102 | def __init__( 103 | self, 104 | window_length: int, 105 | hop_factor: float = 0.25, 106 | sample_rate: int = 44100, 107 | bands: list = BANDS, 108 | ): 109 | """Complex multi-band spectrogram discriminator. 110 | Parameters 111 | ---------- 112 | window_length : int 113 | Window length of STFT. 114 | hop_factor : float, optional 115 | Hop factor of the STFT, defaults to ``0.25 * window_length``. 116 | sample_rate : int, optional 117 | Sampling rate of audio in Hz, by default 44100 118 | bands : list, optional 119 | Bands to run discriminator over. 120 | """ 121 | super().__init__() 122 | 123 | self.window_length = window_length 124 | self.hop_factor = hop_factor 125 | self.sample_rate = sample_rate 126 | self.stft_params = STFTParams( 127 | window_length=window_length, 128 | hop_length=int(window_length * hop_factor), 129 | match_stride=True, 130 | ) 131 | 132 | n_fft = window_length // 2 + 1 133 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 134 | self.bands = bands 135 | 136 | ch = 32 137 | convs = lambda: nn.ModuleList( 138 | [ 139 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), 140 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 141 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 142 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 143 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), 144 | ] 145 | ) 146 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 147 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) 148 | 149 | def spectrogram(self, x): 150 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) 151 | x = torch.view_as_real(x.stft()) 152 | x = rearrange(x, "b 1 f t c -> (b 1) c t f") 153 | # Split into bands 154 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 155 | return x_bands 156 | 157 | def forward(self, x): 158 | x_bands = self.spectrogram(x) 159 | fmap = [] 160 | 161 | x = [] 162 | for band, stack in zip(x_bands, self.band_convs): 163 | for layer in stack: 164 | band = layer(band) 165 | fmap.append(band) 166 | x.append(band) 167 | 168 | x = torch.cat(x, dim=-1) 169 | x = self.conv_post(x) 170 | fmap.append(x) 171 | 172 | return fmap 173 | 174 | 175 | class Discriminator(ml.BaseModel): 176 | def __init__( 177 | self, 178 | rates: list = [], 179 | periods: list = [2, 3, 5, 7, 11], 180 | fft_sizes: list = [2048, 1024, 512], 181 | sample_rate: int = 44100, 182 | bands: list = BANDS, 183 | ): 184 | """Discriminator that combines multiple discriminators. 185 | 186 | Parameters 187 | ---------- 188 | rates : list, optional 189 | sampling rates (in Hz) to run MSD at, by default [] 190 | If empty, MSD is not used. 
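# Band bookkeeping for MRD above (illustrative): each (lo, hi) fraction in BANDS
# is mapped onto STFT bin indices, so for window_length=2048 (the largest default
# FFT size) there are n_fft = 2048 // 2 + 1 = 1025 bins split into five sub-bands.

window_length = 2048
n_fft = window_length // 2 + 1
BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
bins = [(int(lo * n_fft), int(hi * n_fft)) for lo, hi in BANDS]
assert bins == [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]
# Each band gets its own conv stack; the per-band outputs are concatenated along
# the frequency axis before conv_post produces the final score map.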
191 | periods : list, optional 192 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] 193 | fft_sizes : list, optional 194 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] 195 | sample_rate : int, optional 196 | Sampling rate of audio in Hz, by default 44100 197 | bands : list, optional 198 | Bands to run MRD at, by default `BANDS` 199 | """ 200 | super().__init__() 201 | discs = [] 202 | discs += [MPD(p) for p in periods] 203 | discs += [MSD(r, sample_rate=sample_rate) for r in rates] 204 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] 205 | self.discriminators = nn.ModuleList(discs) 206 | 207 | def preprocess(self, y): 208 | # Remove DC offset 209 | y = y - y.mean(dim=-1, keepdims=True) 210 | # Peak normalize the volume of input audio 211 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 212 | return y 213 | 214 | def forward(self, x): 215 | x = self.preprocess(x) 216 | fmaps = [d(x) for d in self.discriminators] 217 | return fmaps 218 | 219 | 220 | if __name__ == "__main__": 221 | disc = Discriminator() 222 | x = torch.zeros(1, 1, 44100) 223 | results = disc(x) 224 | for i, result in enumerate(results): 225 | print(f"disc{i}") 226 | for i, r in enumerate(result): 227 | print(r.shape, r.mean(), r.min(), r.max()) 228 | print() 229 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/compute_desync_score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | from pathlib import Path 4 | 5 | import torch 6 | import torchaudio 7 | import torchvision 8 | from omegaconf import OmegaConf 9 | 10 | import data_transforms 11 | from .synchformer import Synchformer 12 | from .data_transforms import make_class_grid, quantize_offset 13 | from .utils import check_if_file_exists_else_download, which_ffmpeg 14 | 15 | 16 | def prepare_inputs(batch, device): 17 | aud = batch["audio"].to(device) 18 | vid = batch["video"].to(device) 19 | 20 | return aud, vid 21 | 22 | 23 | def get_test_transforms(): 24 | ts = [ 25 | data_transforms.EqualifyFromRight(), 26 | data_transforms.RGBSpatialCrop(input_size=224, is_random=False), 27 | data_transforms.TemporalCropAndOffset( 28 | crop_len_sec=5, 29 | max_off_sec=2, # https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml 30 | max_wiggle_sec=0.0, 31 | do_offset=True, 32 | offset_type="grid", 33 | prob_oos="null", 34 | grid_size=21, 35 | segment_size_vframes=16, 36 | n_segments=14, 37 | step_size_seg=0.5, 38 | vfps=25, 39 | ), 40 | data_transforms.GenerateMultipleSegments( 41 | segment_size_vframes=16, 42 | n_segments=14, 43 | is_start_random=False, 44 | step_size_seg=0.5, 45 | ), 46 | data_transforms.RGBToHalfToZeroOne(), 47 | data_transforms.RGBNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # motionformer normalization 48 | data_transforms.AudioMelSpectrogram( 49 | sample_rate=16000, 50 | win_length=400, # 25 ms * 16 kHz 51 | hop_length=160, # 10 ms * 16 kHz 52 | n_fft=1024, # 2^(ceil(log2(window_size * sampling_rate))) 53 | n_mels=128, # as in AST 54 | ), 55 | data_transforms.AudioLog(), 56 | data_transforms.PadOrTruncate(max_spec_t=66), 57 | data_transforms.AudioNormalizeAST(mean=-4.2677393, std=4.5689974), # AST, pre-trained on AudioSet 58 | data_transforms.PermuteStreams( 59 | einops_order_audio="S F T -> S 1 F T", einops_order_rgb="S T C H W -> S T C H W" # same 60 | ), 
61 | ] 62 | transforms = torchvision.transforms.Compose(ts) 63 | 64 | return transforms 65 | 66 | 67 | def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None): 68 | orig_path = path 69 | # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta) 70 | rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, "sec", output_format="TCHW") 71 | assert meta["video_fps"], f"No video fps for {orig_path}" 72 | # (Ta) <- (Ca, Ta) 73 | audio = audio.mean(dim=0) 74 | # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader. 75 | meta = { 76 | "video": {"fps": [meta["video_fps"]]}, 77 | "audio": {"framerate": [meta["audio_fps"]]}, 78 | } 79 | return rgb, audio, meta 80 | 81 | 82 | def reencode_video(path, vfps=25, afps=16000, in_size=256): 83 | assert which_ffmpeg() != "", "Is ffmpeg installed? Check if the conda environment is activated." 84 | new_path = Path.cwd() / "vis" / f"{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4" 85 | new_path.parent.mkdir(exist_ok=True) 86 | new_path = str(new_path) 87 | cmd = f"{which_ffmpeg()}" 88 | # no info/error printing 89 | cmd += " -hide_banner -loglevel panic" 90 | cmd += f" -y -i {path}" 91 | # 1) change fps, 2) resize: min(H,W)=MIN_SIDE (vertical vids are supported), 3) change audio framerate 92 | cmd += f" -vf fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2" 93 | cmd += f" -ar {afps}" 94 | cmd += f" {new_path}" 95 | subprocess.call(cmd.split()) 96 | cmd = f"{which_ffmpeg()}" 97 | cmd += " -hide_banner -loglevel panic" 98 | cmd += f" -y -i {new_path}" 99 | cmd += f" -acodec pcm_s16le -ac 1" 100 | cmd += f' {new_path.replace(".mp4", ".wav")}' 101 | subprocess.call(cmd.split()) 102 | return new_path 103 | 104 | 105 | def decode_single_video_prediction(off_logits, grid, item): 106 | label = item["targets"]["offset_label"].item() 107 | print("Ground Truth offset (sec):", f"{label:.2f} ({quantize_offset(grid, label)[-1].item()})") 108 | print() 109 | print("Prediction Results:") 110 | off_probs = torch.softmax(off_logits, dim=-1) 111 | k = min(off_probs.shape[-1], 5) 112 | topk_logits, topk_preds = torch.topk(off_logits, k) 113 | # remove batch dimension 114 | assert len(topk_logits) == 1, "batch is larger than 1" 115 | topk_logits = topk_logits[0] 116 | topk_preds = topk_preds[0] 117 | off_logits = off_logits[0] 118 | off_probs = off_probs[0] 119 | for target_hat in topk_preds: 120 | print(f'p={off_probs[target_hat]:.4f} ({off_logits[target_hat]:.4f}), "{grid[target_hat]:.2f}" ({target_hat})') 121 | return off_probs 122 | 123 | 124 | def main(args): 125 | vfps = 25 126 | afps = 16000 127 | in_size = 256 128 | # making the offset class grid similar to the one used in transforms, 129 | # refer to the used one: https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml 130 | max_off_sec = 2 131 | num_cls = 21 132 | 133 | # checking if the provided video has the correct frame rates 134 | print(f"Using video: {args.vid_path}") 135 | v, _, info = torchvision.io.read_video(args.vid_path, pts_unit="sec") 136 | _, H, W, _ = v.shape 137 | if info["video_fps"] != vfps or info["audio_fps"] != afps or min(H, W) != in_size: 138 | print(f'Reencoding. 
vfps: {info["video_fps"]} -> {vfps};', end=" ") 139 | print(f'afps: {info["audio_fps"]} -> {afps};', end=" ") 140 | print(f"{(H, W)} -> min(H, W)={in_size}") 141 | args.vid_path = reencode_video(args.vid_path, vfps, afps, in_size) 142 | else: 143 | print(f'Skipping reencoding. vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}') 144 | 145 | device = torch.device(args.device) 146 | 147 | # load visual and audio streams 148 | # rgb: (Tv, 3, H, W) in [0, 225], audio: (Ta,) in [-1, 1] 149 | rgb, audio, meta = get_video_and_audio(args.vid_path, get_meta=True) 150 | 151 | # making an item (dict) to apply transformations 152 | # NOTE: here is how it works: 153 | # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3` 154 | # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio 155 | # track by `args.offset_sec` seconds. It means that if `offset_sec` > 0, the audio will 156 | # start by `offset_sec` earlier than the rgb track. 157 | # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`) 158 | item = dict( 159 | video=rgb, 160 | audio=audio, 161 | meta=meta, 162 | path=args.vid_path, 163 | split="test", 164 | targets={ 165 | "v_start_i_sec": args.v_start_i_sec, 166 | "offset_sec": args.offset_sec, 167 | }, 168 | ) 169 | 170 | grid = make_class_grid(-max_off_sec, max_off_sec, num_cls) 171 | if not (min(grid) <= item["targets"]["offset_sec"] <= max(grid)): 172 | print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}') 173 | 174 | # applying the test-time transform 175 | item = get_test_transforms()(item) 176 | 177 | # prepare inputs for inference 178 | batch = torch.utils.data.default_collate([item]) 179 | aud, vid = prepare_inputs(batch, device) 180 | 181 | # TODO: 182 | # sanity check: we will take the input to the `model` and recontruct make a video from it. 
183 | # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified) 184 | # reconstruct_video_from_input(aud, vid, batch['meta'], args.vid_path, args.v_start_i_sec, args.offset_sec, 185 | # vfps, afps) 186 | 187 | # forward pass 188 | with torch.set_grad_enabled(False): 189 | with torch.autocast("cuda", enabled=True): 190 | _, logits = synchformer(vid, aud) 191 | 192 | # simply prints the results of the prediction 193 | decode_single_video_prediction(logits, grid, item) 194 | 195 | 196 | if __name__ == "__main__": 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument("--exp_name", required=True, help="In a format: xx-xx-xxTxx-xx-xx") 199 | parser.add_argument("--vid_path", required=True, help="A path to .mp4 video") 200 | parser.add_argument("--offset_sec", type=float, default=0.0) 201 | parser.add_argument("--v_start_i_sec", type=float, default=0.0) 202 | parser.add_argument("--device", default="cuda:0") 203 | args = parser.parse_args() 204 | 205 | synchformer = Synchformer().cuda().eval() 206 | synchformer.load_state_dict( 207 | torch.load( 208 | os.environ.get("SYNCHFORMER_WEIGHTS", f"weights/synchformer.pth"), 209 | weights_only=True, 210 | map_location="cpu", 211 | ) 212 | ) 213 | 214 | main(args) 215 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/quantize.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | from .layers import WNConv1d 11 | 12 | 13 | class VectorQuantize(nn.Module): 14 | """ 15 | Implementation of VQ similar to Karpathy's repo: 16 | https://github.com/karpathy/deep-vector-quantization 17 | Additionally uses following tricks from Improved VQGAN 18 | (https://arxiv.org/pdf/2110.04627.pdf): 19 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space 20 | for improved codebook usage 21 | 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which 22 | improves training stability 23 | """ 24 | 25 | def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): 26 | super().__init__() 27 | self.codebook_size = codebook_size 28 | self.codebook_dim = codebook_dim 29 | 30 | self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) 31 | self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) 32 | self.codebook = nn.Embedding(codebook_size, codebook_dim) 33 | 34 | def forward(self, z): 35 | """Quantized the input tensor using a fixed codebook and returns 36 | the corresponding codebook vectors 37 | 38 | Parameters 39 | ---------- 40 | z : Tensor[B x D x T] 41 | 42 | Returns 43 | ------- 44 | Tensor[B x D x T] 45 | Quantized continuous representation of input 46 | Tensor[1] 47 | Commitment loss to train encoder to predict vectors closer to codebook 48 | entries 49 | Tensor[1] 50 | Codebook loss to update the codebook 51 | Tensor[B x T] 52 | Codebook indices (quantized discrete representation of input) 53 | Tensor[B x D x T] 54 | Projected latents (continuous representation of input before quantization) 55 | """ 56 | 57 | # Factorized codes (ViT-VQGAN) Project input into low-dimensional space 58 | z_e = self.in_proj(z) # z_e : (B x D x T) 59 | z_q, indices = self.decode_latents(z_e) 60 | 61 | commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) 62 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) 63 | 64 | z_q = ( 65 | z_e + (z_q - z_e).detach() 66 | ) # noop in forward pass, straight-through gradient estimator in backward pass 67 | 68 | z_q = self.out_proj(z_q) 69 | 70 | return z_q, commitment_loss, codebook_loss, indices, z_e 71 | 72 | def embed_code(self, embed_id): 73 | return F.embedding(embed_id, self.codebook.weight) 74 | 75 | def decode_code(self, embed_id): 76 | return self.embed_code(embed_id).transpose(1, 2) 77 | 78 | def decode_latents(self, latents): 79 | encodings = rearrange(latents, "b d t -> (b t) d") 80 | codebook = self.codebook.weight # codebook: (N x D) 81 | 82 | # L2 normalize encodings and codebook (ViT-VQGAN) 83 | encodings = F.normalize(encodings) 84 | codebook = F.normalize(codebook) 85 | 86 | # Compute euclidean distance with codebook 87 | dist = ( 88 | encodings.pow(2).sum(1, keepdim=True) 89 | - 2 * encodings @ codebook.t() 90 | + codebook.pow(2).sum(1, keepdim=True).t() 91 | ) 92 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) 93 | z_q = self.decode_code(indices) 94 | return z_q, indices 95 | 96 | 97 | class ResidualVectorQuantize(nn.Module): 98 | """ 99 | Introduced in SoundStream: An end2end neural audio codec 100 | https://arxiv.org/abs/2107.03312 101 | """ 102 | 103 | def __init__( 104 | self, 105 | input_dim: int = 512, 106 | n_codebooks: int = 9, 107 | codebook_size: int = 1024, 108 | codebook_dim: Union[int, list] = 8, 109 | quantizer_dropout: float = 0.0, 110 | ): 111 | super().__init__() 112 | if isinstance(codebook_dim, int): 113 | codebook_dim = [codebook_dim for _ in range(n_codebooks)] 114 | 115 | self.n_codebooks = n_codebooks 116 | self.codebook_dim = codebook_dim 117 | self.codebook_size = codebook_size 118 | 119 | self.quantizers = nn.ModuleList( 120 | [ 121 | VectorQuantize(input_dim, codebook_size, codebook_dim[i]) 122 | for i in range(n_codebooks) 123 | ] 124 | ) 125 | self.quantizer_dropout = quantizer_dropout 126 | 127 | def forward(self, z, n_quantizers: int = None): 128 | """Quantized the input tensor 
using a fixed set of `n` codebooks and returns 129 | the corresponding codebook vectors 130 | Parameters 131 | ---------- 132 | z : Tensor[B x D x T] 133 | n_quantizers : int, optional 134 | No. of quantizers to use 135 | (n_quantizers < self.n_codebooks ex: for quantizer dropout) 136 | Note: if `self.quantizer_dropout` is True, this argument is ignored 137 | when in training mode, and a random number of quantizers is used. 138 | Returns 139 | ------- 140 | dict 141 | A dictionary with the following keys: 142 | 143 | "z" : Tensor[B x D x T] 144 | Quantized continuous representation of input 145 | "codes" : Tensor[B x N x T] 146 | Codebook indices for each codebook 147 | (quantized discrete representation of input) 148 | "latents" : Tensor[B x N*D x T] 149 | Projected latents (continuous representation of input before quantization) 150 | "vq/commitment_loss" : Tensor[1] 151 | Commitment loss to train encoder to predict vectors closer to codebook 152 | entries 153 | "vq/codebook_loss" : Tensor[1] 154 | Codebook loss to update the codebook 155 | """ 156 | z_q = 0 157 | residual = z 158 | commitment_loss = 0 159 | codebook_loss = 0 160 | 161 | codebook_indices = [] 162 | latents = [] 163 | 164 | if n_quantizers is None: 165 | n_quantizers = self.n_codebooks 166 | if self.training: 167 | n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 168 | dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) 169 | n_dropout = int(z.shape[0] * self.quantizer_dropout) 170 | n_quantizers[:n_dropout] = dropout[:n_dropout] 171 | n_quantizers = n_quantizers.to(z.device) 172 | 173 | for i, quantizer in enumerate(self.quantizers): 174 | if self.training is False and i >= n_quantizers: 175 | break 176 | 177 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( 178 | residual 179 | ) 180 | 181 | # Create mask to apply quantizer dropout 182 | mask = ( 183 | torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers 184 | ) 185 | z_q = z_q + z_q_i * mask[:, None, None] 186 | residual = residual - z_q_i 187 | 188 | # Sum losses 189 | commitment_loss += (commitment_loss_i * mask).mean() 190 | codebook_loss += (codebook_loss_i * mask).mean() 191 | 192 | codebook_indices.append(indices_i) 193 | latents.append(z_e_i) 194 | 195 | codes = torch.stack(codebook_indices, dim=1) 196 | latents = torch.cat(latents, dim=1) 197 | 198 | return z_q, codes, latents, commitment_loss, codebook_loss 199 | 200 | def from_codes(self, codes: torch.Tensor): 201 | """Given the quantized codes, reconstruct the continuous representation 202 | Parameters 203 | ---------- 204 | codes : Tensor[B x N x T] 205 | Quantized discrete representation of input 206 | Returns 207 | ------- 208 | Tensor[B x D x T] 209 | Quantized continuous representation of input 210 | """ 211 | z_q = 0.0 212 | z_p = [] 213 | n_codebooks = codes.shape[1] 214 | for i in range(n_codebooks): 215 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) 216 | z_p.append(z_p_i) 217 | 218 | z_q_i = self.quantizers[i].out_proj(z_p_i) 219 | z_q = z_q + z_q_i 220 | return z_q, torch.cat(z_p, dim=1), codes 221 | 222 | def from_latents(self, latents: torch.Tensor): 223 | """Given the unquantized latents, reconstruct the 224 | continuous representation after quantization. 
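# Minimal end-to-end sketch for ResidualVectorQuantize (illustrative; assumes the
# repository root is importable). Note that forward() returns the tuple
# (z_q, codes, latents, commitment_loss, codebook_loss) rather than the dict
# described in its docstring, so unpack it positionally.

import torch
from hunyuanvideo_foley.models.dac_vae.nn.quantize import ResidualVectorQuantize

rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024, codebook_dim=8)
x = torch.randn(2, 512, 80)                               # (B, D, T)
z_q, codes, latents, commit_loss, cb_loss = rvq(x)
assert z_q.shape == (2, 512, 80)                          # quantized continuous features
assert codes.shape == (2, 9, 80)                          # one code index per codebook per step
assert latents.shape == (2, 9 * 8, 80)                    # concatenated projected latents

# Both reconstruction paths map back to the input dimensionality:
assert rvq.from_codes(codes)[0].shape == (2, 512, 80)
assert rvq.from_latents(latents)[0].shape == (2, 512, 80)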
225 | 226 | Parameters 227 | ---------- 228 | latents : Tensor[B x N x T] 229 | Continuous representation of input after projection 230 | 231 | Returns 232 | ------- 233 | Tensor[B x D x T] 234 | Quantized representation of full-projected space 235 | Tensor[B x D x T] 236 | Quantized representation of latent space 237 | """ 238 | z_q = 0 239 | z_p = [] 240 | codes = [] 241 | dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) 242 | 243 | n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ 244 | 0 245 | ] 246 | for i in range(n_codebooks): 247 | j, k = dims[i], dims[i + 1] 248 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) 249 | z_p.append(z_p_i) 250 | codes.append(codes_i) 251 | 252 | z_q_i = self.quantizers[i].out_proj(z_p_i) 253 | z_q = z_q + z_q_i 254 | 255 | return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) 256 | 257 | 258 | if __name__ == "__main__": 259 | rvq = ResidualVectorQuantize(quantizer_dropout=True) 260 | x = torch.randn(16, 512, 80) 261 | y = rvq(x) 262 | print(y["latents"].shape) 263 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from loguru import logger 4 | from torchvision import transforms 5 | from torchvision.transforms import v2 6 | from diffusers.utils.torch_utils import randn_tensor 7 | from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection 8 | from ..models.dac_vae.model.dac import DAC 9 | from ..models.synchformer import Synchformer 10 | from ..models.hifi_foley import HunyuanVideoFoley 11 | from .config_utils import load_yaml, AttributeDict 12 | from .schedulers import FlowMatchDiscreteScheduler 13 | from tqdm import tqdm 14 | 15 | def load_state_dict(model, model_path): 16 | logger.info(f"Loading model state dict from: {model_path}") 17 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False) 18 | 19 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 20 | 21 | if missing_keys: 22 | logger.warning(f"Missing keys in state dict ({len(missing_keys)} keys):") 23 | for key in missing_keys: 24 | logger.warning(f" - {key}") 25 | else: 26 | logger.info("No missing keys found") 27 | 28 | if unexpected_keys: 29 | logger.warning(f"Unexpected keys in state dict ({len(unexpected_keys)} keys):") 30 | for key in unexpected_keys: 31 | logger.warning(f" - {key}") 32 | else: 33 | logger.info("No unexpected keys found") 34 | 35 | logger.info("Model state dict loaded successfully") 36 | return model 37 | 38 | def load_model(model_path, config_path, device): 39 | logger.info("Starting model loading process...") 40 | logger.info(f"Configuration file: {config_path}") 41 | logger.info(f"Model weights dir: {model_path}") 42 | logger.info(f"Target device: {device}") 43 | 44 | cfg = load_yaml(config_path) 45 | logger.info("Configuration loaded successfully") 46 | 47 | # HunyuanVideoFoley 48 | logger.info("Loading HunyuanVideoFoley main model...") 49 | foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16) 50 | foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth")) 51 | foley_model.eval() 52 | logger.info("HunyuanVideoFoley model loaded and set to evaluation mode") 53 | 54 | # DAC-VAE 55 | dac_path = os.path.join(model_path, 
"vae_128d_48k.pth") 56 | logger.info(f"Loading DAC VAE model from: {dac_path}") 57 | dac_model = DAC.load(dac_path) 58 | dac_model = dac_model.to(device) 59 | dac_model.requires_grad_(False) 60 | dac_model.eval() 61 | logger.info("DAC VAE model loaded successfully") 62 | 63 | # Siglip2 visual-encoder 64 | logger.info("Loading SigLIP2 visual encoder...") 65 | siglip2_preprocess = transforms.Compose([ 66 | transforms.Resize((512, 512)), 67 | transforms.ToTensor(), 68 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 69 | ]) 70 | siglip2_model = AutoModel.from_pretrained("google/siglip2-base-patch16-512").to(device).eval() 71 | logger.info("SigLIP2 model and preprocessing pipeline loaded successfully") 72 | 73 | # clap text-encoder 74 | logger.info("Loading CLAP text encoder...") 75 | clap_tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general") 76 | clap_model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device) 77 | logger.info("CLAP tokenizer and model loaded successfully") 78 | 79 | # syncformer 80 | syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth") 81 | logger.info(f"Loading Synchformer model from: {syncformer_path}") 82 | syncformer_preprocess = v2.Compose( 83 | [ 84 | v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC), 85 | v2.CenterCrop(224), 86 | v2.ToImage(), 87 | v2.ToDtype(torch.float32, scale=True), 88 | v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 89 | ] 90 | ) 91 | 92 | syncformer_model = Synchformer() 93 | syncformer_model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu")) 94 | syncformer_model = syncformer_model.to(device).eval() 95 | logger.info("Synchformer model and preprocessing pipeline loaded successfully") 96 | 97 | 98 | logger.info("Creating model dictionary with attribute access...") 99 | model_dict = AttributeDict({ 100 | 'foley_model': foley_model, 101 | 'dac_model': dac_model, 102 | 'siglip2_preprocess': siglip2_preprocess, 103 | 'siglip2_model': siglip2_model, 104 | 'clap_tokenizer': clap_tokenizer, 105 | 'clap_model': clap_model, 106 | 'syncformer_preprocess': syncformer_preprocess, 107 | 'syncformer_model': syncformer_model, 108 | 'device': device, 109 | }) 110 | 111 | logger.info("All models loaded successfully!") 112 | logger.info("Available model components:") 113 | for key in model_dict.keys(): 114 | logger.info(f" - {key}") 115 | logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)") 116 | 117 | return model_dict, cfg 118 | 119 | def retrieve_timesteps( 120 | scheduler, 121 | num_inference_steps, 122 | device, 123 | **kwargs, 124 | ): 125 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 126 | timesteps = scheduler.timesteps 127 | return timesteps, num_inference_steps 128 | 129 | 130 | def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device): 131 | shape = (batch_size, num_channels_latents, int(length)) 132 | latents = randn_tensor(shape, device=device, dtype=dtype) 133 | 134 | # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler 135 | if hasattr(scheduler, "init_noise_sigma"): 136 | # scale the initial noise by the standard deviation required by the scheduler 137 | latents = latents * scheduler.init_noise_sigma 138 | 139 | return latents 140 | 141 | 142 | @torch.no_grad() 143 | def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, 
batch_size=1): 144 | 145 | target_dtype = model_dict.foley_model.dtype 146 | autocast_enabled = target_dtype != torch.float32 147 | device = model_dict.device 148 | 149 | scheduler = FlowMatchDiscreteScheduler( 150 | shift=cfg.diffusion_config.sample_flow_shift, 151 | reverse=cfg.diffusion_config.flow_reverse, 152 | solver=cfg.diffusion_config.flow_solver, 153 | use_flux_shift=cfg.diffusion_config.sample_use_flux_shift, 154 | flux_base_shift=cfg.diffusion_config.flux_base_shift, 155 | flux_max_shift=cfg.diffusion_config.flux_max_shift, 156 | ) 157 | 158 | timesteps, num_inference_steps = retrieve_timesteps( 159 | scheduler, 160 | num_inference_steps, 161 | device, 162 | ) 163 | 164 | latents = prepare_latents( 165 | scheduler, 166 | batch_size=batch_size, 167 | num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim, 168 | length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate, 169 | dtype=target_dtype, 170 | device=device, 171 | ) 172 | 173 | # Denoise loop 174 | for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Denoising steps"): 175 | # noise latents 176 | latent_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents 177 | latent_input = scheduler.scale_model_input(latent_input, t) 178 | 179 | t_expand = t.repeat(latent_input.shape[0]) 180 | 181 | # siglip2 features 182 | siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 183 | uncond_siglip2_feat = model_dict.foley_model.get_empty_clip_sequence( 184 | bs=batch_size, len=siglip2_feat.shape[1] 185 | ).to(device) 186 | 187 | if guidance_scale is not None and guidance_scale > 1.0: 188 | siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat], dim=0) 189 | else: 190 | siglip2_feat_input = siglip2_feat 191 | 192 | # syncformer features 193 | syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 194 | uncond_syncformer_feat = model_dict.foley_model.get_empty_sync_sequence( 195 | bs=batch_size, len=syncformer_feat.shape[1] 196 | ).to(device) 197 | if guidance_scale is not None and guidance_scale > 1.0: 198 | syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat], dim=0) 199 | else: 200 | syncformer_feat_input = syncformer_feat 201 | 202 | # text features 203 | text_feat_repeated = text_feats.text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 204 | uncond_text_feat_repeated = text_feats.uncond_text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size 205 | if guidance_scale is not None and guidance_scale > 1.0: 206 | text_feat_input = torch.cat([uncond_text_feat_repeated, text_feat_repeated], dim=0) 207 | else: 208 | text_feat_input = text_feat_repeated 209 | 210 | with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype): 211 | # Predict the noise residual 212 | noise_pred = model_dict.foley_model( 213 | x=latent_input, 214 | t=t_expand, 215 | cond=text_feat_input, 216 | clip_feat=siglip2_feat_input, 217 | sync_feat=syncformer_feat_input, 218 | return_dict=True, 219 | )["x"] 220 | 221 | noise_pred = noise_pred.to(dtype=torch.float32) 222 | 223 | if guidance_scale is not None and guidance_scale > 1.0: 224 | # Perform classifier-free guidance 225 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 226 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 227 | 228 | # Compute the previous noisy sample x_t -> x_t-1 229 | latents = scheduler.step(noise_pred, t, latents, 
return_dict=False)[0] 230 | 231 | # Post-process the latents to audio 232 | 233 | with torch.no_grad(): 234 | audio = model_dict.dac_model.decode(latents) 235 | audio = audio.float().cpu() 236 | 237 | audio = audio[:, :int(audio_len_in_s*model_dict.dac_model.sample_rate)] 238 | 239 | return audio, model_dict.dac_model.sample_rate 240 | 241 | 242 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | from audiotools import AudioSignal 10 | from torch import nn 11 | 12 | SUPPORTED_VERSIONS = ["1.0.0"] 13 | 14 | 15 | @dataclass 16 | class DACFile: 17 | codes: torch.Tensor 18 | 19 | # Metadata 20 | chunk_length: int 21 | original_length: int 22 | input_db: float 23 | channels: int 24 | sample_rate: int 25 | padding: bool 26 | dac_version: str 27 | 28 | def save(self, path): 29 | artifacts = { 30 | "codes": self.codes.numpy().astype(np.uint16), 31 | "metadata": { 32 | "input_db": self.input_db.numpy().astype(np.float32), 33 | "original_length": self.original_length, 34 | "sample_rate": self.sample_rate, 35 | "chunk_length": self.chunk_length, 36 | "channels": self.channels, 37 | "padding": self.padding, 38 | "dac_version": SUPPORTED_VERSIONS[-1], 39 | }, 40 | } 41 | path = Path(path).with_suffix(".dac") 42 | with open(path, "wb") as f: 43 | np.save(f, artifacts) 44 | return path 45 | 46 | @classmethod 47 | def load(cls, path): 48 | artifacts = np.load(path, allow_pickle=True)[()] 49 | codes = torch.from_numpy(artifacts["codes"].astype(int)) 50 | if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: 51 | raise RuntimeError( 52 | f"Given file {path} can't be loaded with this version of descript-audio-codec." 
53 | ) 54 | return cls(codes=codes, **artifacts["metadata"]) 55 | 56 | 57 | class CodecMixin: 58 | @property 59 | def padding(self): 60 | if not hasattr(self, "_padding"): 61 | self._padding = True 62 | return self._padding 63 | 64 | @padding.setter 65 | def padding(self, value): 66 | assert isinstance(value, bool) 67 | 68 | layers = [ 69 | l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) 70 | ] 71 | 72 | for layer in layers: 73 | if value: 74 | if hasattr(layer, "original_padding"): 75 | layer.padding = layer.original_padding 76 | else: 77 | layer.original_padding = layer.padding 78 | layer.padding = tuple(0 for _ in range(len(layer.padding))) 79 | 80 | self._padding = value 81 | 82 | def get_delay(self): 83 | # Any number works here, delay is invariant to input length 84 | l_out = self.get_output_length(0) 85 | L = l_out 86 | 87 | layers = [] 88 | for layer in self.modules(): 89 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): 90 | layers.append(layer) 91 | 92 | for layer in reversed(layers): 93 | d = layer.dilation[0] 94 | k = layer.kernel_size[0] 95 | s = layer.stride[0] 96 | 97 | if isinstance(layer, nn.ConvTranspose1d): 98 | L = ((L - d * (k - 1) - 1) / s) + 1 99 | elif isinstance(layer, nn.Conv1d): 100 | L = (L - 1) * s + d * (k - 1) + 1 101 | 102 | L = math.ceil(L) 103 | 104 | l_in = L 105 | 106 | return (l_in - l_out) // 2 107 | 108 | def get_output_length(self, input_length): 109 | L = input_length 110 | # Calculate output length 111 | for layer in self.modules(): 112 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): 113 | d = layer.dilation[0] 114 | k = layer.kernel_size[0] 115 | s = layer.stride[0] 116 | 117 | if isinstance(layer, nn.Conv1d): 118 | L = ((L - d * (k - 1) - 1) / s) + 1 119 | elif isinstance(layer, nn.ConvTranspose1d): 120 | L = (L - 1) * s + d * (k - 1) + 1 121 | 122 | L = math.floor(L) 123 | return L 124 | 125 | @torch.no_grad() 126 | def compress( 127 | self, 128 | audio_path_or_signal: Union[str, Path, AudioSignal], 129 | win_duration: float = 1.0, 130 | verbose: bool = False, 131 | normalize_db: float = -16, 132 | n_quantizers: int = None, 133 | ) -> DACFile: 134 | """Processes an audio signal from a file or AudioSignal object into 135 | discrete codes. This function processes the signal in short windows, 136 | using constant GPU memory. 
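        A minimal usage sketch (paths are illustrative; this assumes a discrete/RVQ
        checkpoint, since the continuous VAE configuration returns codes=None from
        encode() and is not expected to work here):

            model = DAC.load("weights/dac_rvq.pth")
            dac_file = model.compress("clip.wav", win_duration=1.0)
            dac_file.save("clip.dac")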
137 | 138 | Parameters 139 | ---------- 140 | audio_path_or_signal : Union[str, Path, AudioSignal] 141 | audio signal to reconstruct 142 | win_duration : float, optional 143 | window duration in seconds, by default 5.0 144 | verbose : bool, optional 145 | by default False 146 | normalize_db : float, optional 147 | normalize db, by default -16 148 | 149 | Returns 150 | ------- 151 | DACFile 152 | Object containing compressed codes and metadata 153 | required for decompression 154 | """ 155 | audio_signal = audio_path_or_signal 156 | if isinstance(audio_signal, (str, Path)): 157 | audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) 158 | 159 | self.eval() 160 | original_padding = self.padding 161 | original_device = audio_signal.device 162 | 163 | audio_signal = audio_signal.clone() 164 | audio_signal = audio_signal.to_mono() 165 | original_sr = audio_signal.sample_rate 166 | 167 | resample_fn = audio_signal.resample 168 | loudness_fn = audio_signal.loudness 169 | 170 | # If audio is > 10 minutes long, use the ffmpeg versions 171 | if audio_signal.signal_duration >= 10 * 60 * 60: 172 | resample_fn = audio_signal.ffmpeg_resample 173 | loudness_fn = audio_signal.ffmpeg_loudness 174 | 175 | original_length = audio_signal.signal_length 176 | resample_fn(self.sample_rate) 177 | input_db = loudness_fn() 178 | 179 | if normalize_db is not None: 180 | audio_signal.normalize(normalize_db) 181 | audio_signal.ensure_max_of_audio() 182 | 183 | nb, nac, nt = audio_signal.audio_data.shape 184 | audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) 185 | win_duration = ( 186 | audio_signal.signal_duration if win_duration is None else win_duration 187 | ) 188 | 189 | if audio_signal.signal_duration <= win_duration: 190 | # Unchunked compression (used if signal length < win duration) 191 | self.padding = True 192 | n_samples = nt 193 | hop = nt 194 | else: 195 | # Chunked inference 196 | self.padding = False 197 | # Zero-pad signal on either side by the delay 198 | audio_signal.zero_pad(self.delay, self.delay) 199 | n_samples = int(win_duration * self.sample_rate) 200 | # Round n_samples to nearest hop length multiple 201 | n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) 202 | hop = self.get_output_length(n_samples) 203 | 204 | codes = [] 205 | range_fn = range if not verbose else tqdm.trange 206 | 207 | for i in range_fn(0, nt, hop): 208 | x = audio_signal[..., i : i + n_samples] 209 | x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) 210 | 211 | audio_data = x.audio_data.to(self.device) 212 | audio_data = self.preprocess(audio_data, self.sample_rate) 213 | _, c, _, _, _ = self.encode(audio_data, n_quantizers) 214 | codes.append(c.to(original_device)) 215 | chunk_length = c.shape[-1] 216 | 217 | codes = torch.cat(codes, dim=-1) 218 | 219 | dac_file = DACFile( 220 | codes=codes, 221 | chunk_length=chunk_length, 222 | original_length=original_length, 223 | input_db=input_db, 224 | channels=nac, 225 | sample_rate=original_sr, 226 | padding=self.padding, 227 | dac_version=SUPPORTED_VERSIONS[-1], 228 | ) 229 | 230 | if n_quantizers is not None: 231 | codes = codes[:, :n_quantizers, :] 232 | 233 | self.padding = original_padding 234 | return dac_file 235 | 236 | @torch.no_grad() 237 | def decompress( 238 | self, 239 | obj: Union[str, Path, DACFile], 240 | verbose: bool = False, 241 | ) -> AudioSignal: 242 | """Reconstruct audio from a given .dac file 243 | 244 | Parameters 245 | ---------- 246 | obj : Union[str, Path, DACFile] 247 | .dac file 
location or corresponding DACFile object. 248 | verbose : bool, optional 249 | Prints progress if True, by default False 250 | 251 | Returns 252 | ------- 253 | AudioSignal 254 | Object with the reconstructed audio 255 | """ 256 | self.eval() 257 | if isinstance(obj, (str, Path)): 258 | obj = DACFile.load(obj) 259 | 260 | original_padding = self.padding 261 | self.padding = obj.padding 262 | 263 | range_fn = range if not verbose else tqdm.trange 264 | codes = obj.codes 265 | original_device = codes.device 266 | chunk_length = obj.chunk_length 267 | recons = [] 268 | 269 | for i in range_fn(0, codes.shape[-1], chunk_length): 270 | c = codes[..., i : i + chunk_length].to(self.device) 271 | z = self.quantizer.from_codes(c)[0] 272 | r = self.decode(z) 273 | recons.append(r.to(original_device)) 274 | 275 | recons = torch.cat(recons, dim=-1) 276 | recons = AudioSignal(recons, self.sample_rate) 277 | 278 | resample_fn = recons.resample 279 | loudness_fn = recons.loudness 280 | 281 | # If audio is > 10 minutes long, use the ffmpeg versions 282 | if recons.signal_duration >= 10 * 60 * 60: 283 | resample_fn = recons.ffmpeg_resample 284 | loudness_fn = recons.ffmpeg_loudness 285 | 286 | if obj.input_db is not None: 287 | recons.normalize(obj.input_db) 288 | 289 | resample_fn(obj.sample_rate) 290 | 291 | if obj.original_length is not None: 292 | recons = recons[..., : obj.original_length] 293 | loudness_fn() 294 | recons.audio_data = recons.audio_data.reshape( 295 | -1, obj.channels, obj.original_length 296 | ) 297 | else: 298 | loudness_fn() 299 | 300 | self.padding = original_padding 301 | return recons 302 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/video_model_builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
3 | # Copyright 2020 Ross Wightman 4 | # Modified Model definition 5 | 6 | from collections import OrderedDict 7 | from functools import partial 8 | 9 | import torch 10 | import torch.nn as nn 11 | from timm.layers import trunc_normal_ 12 | 13 | from .vit_helper import PatchEmbed, PatchEmbed3D, DividedSpaceTimeBlock 14 | 15 | 16 | class VisionTransformer(nn.Module): 17 | """Vision Transformer with support for patch or hybrid CNN input stage""" 18 | 19 | def __init__(self, cfg): 20 | super().__init__() 21 | self.img_size = cfg.DATA.TRAIN_CROP_SIZE 22 | self.patch_size = cfg.VIT.PATCH_SIZE 23 | self.in_chans = cfg.VIT.CHANNELS 24 | if cfg.TRAIN.DATASET == "Epickitchens": 25 | self.num_classes = [97, 300] 26 | else: 27 | self.num_classes = cfg.MODEL.NUM_CLASSES 28 | self.embed_dim = cfg.VIT.EMBED_DIM 29 | self.depth = cfg.VIT.DEPTH 30 | self.num_heads = cfg.VIT.NUM_HEADS 31 | self.mlp_ratio = cfg.VIT.MLP_RATIO 32 | self.qkv_bias = cfg.VIT.QKV_BIAS 33 | self.drop_rate = cfg.VIT.DROP 34 | self.drop_path_rate = cfg.VIT.DROP_PATH 35 | self.head_dropout = cfg.VIT.HEAD_DROPOUT 36 | self.video_input = cfg.VIT.VIDEO_INPUT 37 | self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION 38 | self.use_mlp = cfg.VIT.USE_MLP 39 | self.num_features = self.embed_dim 40 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 41 | self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT 42 | self.head_act = cfg.VIT.HEAD_ACT 43 | self.cfg = cfg 44 | 45 | # Patch Embedding 46 | self.patch_embed = PatchEmbed( 47 | img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim 48 | ) 49 | 50 | # 3D Patch Embedding 51 | self.patch_embed_3d = PatchEmbed3D( 52 | img_size=self.img_size, 53 | temporal_resolution=self.temporal_resolution, 54 | patch_size=self.patch_size, 55 | in_chans=self.in_chans, 56 | embed_dim=self.embed_dim, 57 | z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP, 58 | ) 59 | self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data) 60 | 61 | # Number of patches 62 | if self.video_input: 63 | num_patches = self.patch_embed.num_patches * self.temporal_resolution 64 | else: 65 | num_patches = self.patch_embed.num_patches 66 | self.num_patches = num_patches 67 | 68 | # CLS token 69 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 70 | trunc_normal_(self.cls_token, std=0.02) 71 | 72 | # Positional embedding 73 | self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim)) 74 | self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT) 75 | trunc_normal_(self.pos_embed, std=0.02) 76 | 77 | if self.cfg.VIT.POS_EMBED == "joint": 78 | self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim)) 79 | trunc_normal_(self.st_embed, std=0.02) 80 | elif self.cfg.VIT.POS_EMBED == "separate": 81 | self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim)) 82 | 83 | # Layer Blocks 84 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 85 | if self.cfg.VIT.ATTN_LAYER == "divided": 86 | self.blocks = nn.ModuleList( 87 | [ 88 | DividedSpaceTimeBlock( 89 | attn_type=cfg.VIT.ATTN_LAYER, 90 | dim=self.embed_dim, 91 | num_heads=self.num_heads, 92 | mlp_ratio=self.mlp_ratio, 93 | qkv_bias=self.qkv_bias, 94 | drop=self.drop_rate, 95 | attn_drop=self.attn_drop_rate, 96 | drop_path=dpr[i], 97 | norm_layer=norm_layer, 98 | ) 99 | for i in range(self.depth) 100 | ] 101 | ) 102 | 103 | self.norm = norm_layer(self.embed_dim) 104 | 105 | # MLP head 106 | if self.use_mlp: 107 | 
hidden_dim = self.embed_dim 108 | if self.head_act == "tanh": 109 | # logging.info("Using TanH activation in MLP") 110 | act = nn.Tanh() 111 | elif self.head_act == "gelu": 112 | # logging.info("Using GELU activation in MLP") 113 | act = nn.GELU() 114 | else: 115 | # logging.info("Using ReLU activation in MLP") 116 | act = nn.ReLU() 117 | self.pre_logits = nn.Sequential( 118 | OrderedDict( 119 | [ 120 | ("fc", nn.Linear(self.embed_dim, hidden_dim)), 121 | ("act", act), 122 | ] 123 | ) 124 | ) 125 | else: 126 | self.pre_logits = nn.Identity() 127 | 128 | # Classifier Head 129 | self.head_drop = nn.Dropout(p=self.head_dropout) 130 | if isinstance(self.num_classes, (list,)) and len(self.num_classes) > 1: 131 | for a, i in enumerate(range(len(self.num_classes))): 132 | setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i])) 133 | else: 134 | self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 135 | 136 | # Initialize weights 137 | self.apply(self._init_weights) 138 | 139 | def _init_weights(self, m): 140 | if isinstance(m, nn.Linear): 141 | trunc_normal_(m.weight, std=0.02) 142 | if isinstance(m, nn.Linear) and m.bias is not None: 143 | nn.init.constant_(m.bias, 0) 144 | elif isinstance(m, nn.LayerNorm): 145 | nn.init.constant_(m.bias, 0) 146 | nn.init.constant_(m.weight, 1.0) 147 | 148 | @torch.jit.ignore 149 | def no_weight_decay(self): 150 | if self.cfg.VIT.POS_EMBED == "joint": 151 | return {"pos_embed", "cls_token", "st_embed"} 152 | else: 153 | return {"pos_embed", "cls_token", "temp_embed"} 154 | 155 | def get_classifier(self): 156 | return self.head 157 | 158 | def reset_classifier(self, num_classes, global_pool=""): 159 | self.num_classes = num_classes 160 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 161 | 162 | def forward_features(self, x): 163 | # if self.video_input: 164 | # x = x[0] 165 | B = x.shape[0] 166 | 167 | # Tokenize input 168 | # if self.cfg.VIT.PATCH_SIZE_TEMP > 1: 169 | # for simplicity of mapping between content dimensions (input x) and token dims (after patching) 170 | # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details): 171 | 172 | # apply patching on input 173 | x = self.patch_embed_3d(x) 174 | tok_mask = None 175 | 176 | # else: 177 | # tok_mask = None 178 | # # 2D tokenization 179 | # if self.video_input: 180 | # x = x.permute(0, 2, 1, 3, 4) 181 | # (B, T, C, H, W) = x.shape 182 | # x = x.reshape(B * T, C, H, W) 183 | 184 | # x = self.patch_embed(x) 185 | 186 | # if self.video_input: 187 | # (B2, T2, D2) = x.shape 188 | # x = x.reshape(B, T * T2, D2) 189 | 190 | # Append CLS token 191 | cls_tokens = self.cls_token.expand(B, -1, -1) 192 | x = torch.cat((cls_tokens, x), dim=1) 193 | # if tok_mask is not None: 194 | # # prepend 1(=keep) to the mask to account for the CLS token as well 195 | # tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1) 196 | 197 | # Interpolate positinoal embeddings 198 | # if self.cfg.DATA.TRAIN_CROP_SIZE != 224: 199 | # pos_embed = self.pos_embed 200 | # N = pos_embed.shape[1] - 1 201 | # npatch = int((x.size(1) - 1) / self.temporal_resolution) 202 | # class_emb = pos_embed[:, 0] 203 | # pos_embed = pos_embed[:, 1:] 204 | # dim = x.shape[-1] 205 | # pos_embed = torch.nn.functional.interpolate( 206 | # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 207 | # scale_factor=math.sqrt(npatch / N), 208 | # mode='bicubic', 209 | # ) 210 | # 
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 211 | # new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) 212 | # else: 213 | new_pos_embed = self.pos_embed 214 | npatch = self.patch_embed.num_patches 215 | 216 | # Add positional embeddings to input 217 | if self.video_input: 218 | if self.cfg.VIT.POS_EMBED == "separate": 219 | cls_embed = self.pos_embed[:, 0, :].unsqueeze(1) 220 | tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1) 221 | tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1) 222 | total_pos_embed = tile_pos_embed + tile_temporal_embed 223 | total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) 224 | x = x + total_pos_embed 225 | elif self.cfg.VIT.POS_EMBED == "joint": 226 | x = x + self.st_embed 227 | else: 228 | # image input 229 | x = x + new_pos_embed 230 | 231 | # Apply positional dropout 232 | x = self.pos_drop(x) 233 | 234 | # Encoding using transformer layers 235 | for i, blk in enumerate(self.blocks): 236 | x = blk( 237 | x, 238 | seq_len=npatch, 239 | num_frames=self.temporal_resolution, 240 | approx=self.cfg.VIT.APPROX_ATTN_TYPE, 241 | num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM, 242 | tok_mask=tok_mask, 243 | ) 244 | 245 | ### v-iashin: I moved it to the forward pass 246 | # x = self.norm(x)[:, 0] 247 | # x = self.pre_logits(x) 248 | ### 249 | return x, tok_mask 250 | 251 | # def forward(self, x): 252 | # x = self.forward_features(x) 253 | # ### v-iashin: here. This should leave the same forward output as before 254 | # x = self.norm(x)[:, 0] 255 | # x = self.pre_logits(x) 256 | # ### 257 | # x = self.head_drop(x) 258 | # if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1: 259 | # output = [] 260 | # for head in range(len(self.num_classes)): 261 | # x_out = getattr(self, "head%d" % head)(x) 262 | # if not self.training: 263 | # x_out = torch.nn.functional.softmax(x_out, dim=-1) 264 | # output.append(x_out) 265 | # return output 266 | # else: 267 | # x = self.head(x) 268 | # if not self.training: 269 | # x = torch.nn.functional.softmax(x, dim=-1) 270 | # return x 271 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/nn/loss.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from audiotools import AudioSignal 7 | from audiotools import STFTParams 8 | from torch import nn 9 | 10 | 11 | class L1Loss(nn.L1Loss): 12 | """L1 Loss between AudioSignals. Defaults 13 | to comparing ``audio_data``, but any 14 | attribute of an AudioSignal can be used. 15 | 16 | Parameters 17 | ---------- 18 | attribute : str, optional 19 | Attribute of signal to compare, defaults to ``audio_data``. 20 | weight : float, optional 21 | Weight of this loss, defaults to 1.0. 
22 | 23 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py 24 | """ 25 | 26 | def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): 27 | self.attribute = attribute 28 | self.weight = weight 29 | super().__init__(**kwargs) 30 | 31 | def forward(self, x: AudioSignal, y: AudioSignal): 32 | """ 33 | Parameters 34 | ---------- 35 | x : AudioSignal 36 | Estimate AudioSignal 37 | y : AudioSignal 38 | Reference AudioSignal 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | L1 loss between AudioSignal attributes. 44 | """ 45 | if isinstance(x, AudioSignal): 46 | x = getattr(x, self.attribute) 47 | y = getattr(y, self.attribute) 48 | return super().forward(x, y) 49 | 50 | 51 | class SISDRLoss(nn.Module): 52 | """ 53 | Computes the Scale-Invariant Source-to-Distortion Ratio between a batch 54 | of estimated and reference audio signals or aligned features. 55 | 56 | Parameters 57 | ---------- 58 | scaling : int, optional 59 | Whether to use scale-invariant (True) or 60 | signal-to-noise ratio (False), by default True 61 | reduction : str, optional 62 | How to reduce across the batch (either 'mean', 63 | 'sum', or none).], by default ' mean' 64 | zero_mean : int, optional 65 | Zero mean the references and estimates before 66 | computing the loss, by default True 67 | clip_min : int, optional 68 | The minimum possible loss value. Helps network 69 | to not focus on making already good examples better, by default None 70 | weight : float, optional 71 | Weight of this loss, defaults to 1.0. 72 | 73 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py 74 | """ 75 | 76 | def __init__( 77 | self, 78 | scaling: int = True, 79 | reduction: str = "mean", 80 | zero_mean: int = True, 81 | clip_min: int = None, 82 | weight: float = 1.0, 83 | ): 84 | self.scaling = scaling 85 | self.reduction = reduction 86 | self.zero_mean = zero_mean 87 | self.clip_min = clip_min 88 | self.weight = weight 89 | super().__init__() 90 | 91 | def forward(self, x: AudioSignal, y: AudioSignal): 92 | eps = 1e-8 93 | # nb, nc, nt 94 | if isinstance(x, AudioSignal): 95 | references = x.audio_data 96 | estimates = y.audio_data 97 | else: 98 | references = x 99 | estimates = y 100 | 101 | nb = references.shape[0] 102 | references = references.reshape(nb, 1, -1).permute(0, 2, 1) 103 | estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) 104 | 105 | # samples now on axis 1 106 | if self.zero_mean: 107 | mean_reference = references.mean(dim=1, keepdim=True) 108 | mean_estimate = estimates.mean(dim=1, keepdim=True) 109 | else: 110 | mean_reference = 0 111 | mean_estimate = 0 112 | 113 | _references = references - mean_reference 114 | _estimates = estimates - mean_estimate 115 | 116 | references_projection = (_references**2).sum(dim=-2) + eps 117 | references_on_estimates = (_estimates * _references).sum(dim=-2) + eps 118 | 119 | scale = ( 120 | (references_on_estimates / references_projection).unsqueeze(1) 121 | if self.scaling 122 | else 1 123 | ) 124 | 125 | e_true = scale * _references 126 | e_res = _estimates - e_true 127 | 128 | signal = (e_true**2).sum(dim=1) 129 | noise = (e_res**2).sum(dim=1) 130 | sdr = -10 * torch.log10(signal / noise + eps) 131 | 132 | if self.clip_min is not None: 133 | sdr = torch.clamp(sdr, min=self.clip_min) 134 | 135 | if self.reduction == "mean": 136 | sdr = sdr.mean() 
137 | elif self.reduction == "sum": 138 | sdr = sdr.sum() 139 | return sdr 140 | 141 | 142 | class MultiScaleSTFTLoss(nn.Module): 143 | """Computes the multi-scale STFT loss from [1]. 144 | 145 | Parameters 146 | ---------- 147 | window_lengths : List[int], optional 148 | Length of each window of each STFT, by default [2048, 512] 149 | loss_fn : typing.Callable, optional 150 | How to compare each loss, by default nn.L1Loss() 151 | clamp_eps : float, optional 152 | Clamp on the log magnitude, below, by default 1e-5 153 | mag_weight : float, optional 154 | Weight of raw magnitude portion of loss, by default 1.0 155 | log_weight : float, optional 156 | Weight of log magnitude portion of loss, by default 1.0 157 | pow : float, optional 158 | Power to raise magnitude to before taking log, by default 2.0 159 | weight : float, optional 160 | Weight of this loss, by default 1.0 161 | match_stride : bool, optional 162 | Whether to match the stride of convolutional layers, by default False 163 | 164 | References 165 | ---------- 166 | 167 | 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. 168 | "DDSP: Differentiable Digital Signal Processing." 169 | International Conference on Learning Representations. 2019. 170 | 171 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 172 | """ 173 | 174 | def __init__( 175 | self, 176 | window_lengths: List[int] = [2048, 512], 177 | loss_fn: typing.Callable = nn.L1Loss(), 178 | clamp_eps: float = 1e-5, 179 | mag_weight: float = 1.0, 180 | log_weight: float = 1.0, 181 | pow: float = 2.0, 182 | weight: float = 1.0, 183 | match_stride: bool = False, 184 | window_type: str = None, 185 | ): 186 | super().__init__() 187 | self.stft_params = [ 188 | STFTParams( 189 | window_length=w, 190 | hop_length=w // 4, 191 | match_stride=match_stride, 192 | window_type=window_type, 193 | ) 194 | for w in window_lengths 195 | ] 196 | self.loss_fn = loss_fn 197 | self.log_weight = log_weight 198 | self.mag_weight = mag_weight 199 | self.clamp_eps = clamp_eps 200 | self.weight = weight 201 | self.pow = pow 202 | 203 | def forward(self, x: AudioSignal, y: AudioSignal): 204 | """Computes multi-scale STFT between an estimate and a reference 205 | signal. 206 | 207 | Parameters 208 | ---------- 209 | x : AudioSignal 210 | Estimate signal 211 | y : AudioSignal 212 | Reference signal 213 | 214 | Returns 215 | ------- 216 | torch.Tensor 217 | Multi-scale STFT loss. 218 | """ 219 | loss = 0.0 220 | for s in self.stft_params: 221 | x.stft(s.window_length, s.hop_length, s.window_type) 222 | y.stft(s.window_length, s.hop_length, s.window_type) 223 | loss += self.log_weight * self.loss_fn( 224 | x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), 225 | y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), 226 | ) 227 | loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) 228 | return loss 229 | 230 | 231 | class MelSpectrogramLoss(nn.Module): 232 | """Compute distance between mel spectrograms. Can be used 233 | in a multi-scale way. 
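        A minimal sketch of the expected call pattern (tensor names and the sample
        rate are illustrative):

            mel_loss = MelSpectrogramLoss()
            loss = mel_loss(AudioSignal(recon_wav, 44100), AudioSignal(ref_wav, 44100))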
234 | 235 | Parameters 236 | ---------- 237 | n_mels : List[int] 238 | Number of mels per STFT, by default [150, 80], 239 | window_lengths : List[int], optional 240 | Length of each window of each STFT, by default [2048, 512] 241 | loss_fn : typing.Callable, optional 242 | How to compare each loss, by default nn.L1Loss() 243 | clamp_eps : float, optional 244 | Clamp on the log magnitude, below, by default 1e-5 245 | mag_weight : float, optional 246 | Weight of raw magnitude portion of loss, by default 1.0 247 | log_weight : float, optional 248 | Weight of log magnitude portion of loss, by default 1.0 249 | pow : float, optional 250 | Power to raise magnitude to before taking log, by default 2.0 251 | weight : float, optional 252 | Weight of this loss, by default 1.0 253 | match_stride : bool, optional 254 | Whether to match the stride of convolutional layers, by default False 255 | 256 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 257 | """ 258 | 259 | def __init__( 260 | self, 261 | n_mels: List[int] = [150, 80], 262 | window_lengths: List[int] = [2048, 512], 263 | loss_fn: typing.Callable = nn.L1Loss(), 264 | clamp_eps: float = 1e-5, 265 | mag_weight: float = 1.0, 266 | log_weight: float = 1.0, 267 | pow: float = 2.0, 268 | weight: float = 1.0, 269 | match_stride: bool = False, 270 | mel_fmin: List[float] = [0.0, 0.0], 271 | mel_fmax: List[float] = [None, None], 272 | window_type: str = None, 273 | ): 274 | super().__init__() 275 | self.stft_params = [ 276 | STFTParams( 277 | window_length=w, 278 | hop_length=w // 4, 279 | match_stride=match_stride, 280 | window_type=window_type, 281 | ) 282 | for w in window_lengths 283 | ] 284 | self.n_mels = n_mels 285 | self.loss_fn = loss_fn 286 | self.clamp_eps = clamp_eps 287 | self.log_weight = log_weight 288 | self.mag_weight = mag_weight 289 | self.weight = weight 290 | self.mel_fmin = mel_fmin 291 | self.mel_fmax = mel_fmax 292 | self.pow = pow 293 | 294 | def forward(self, x: AudioSignal, y: AudioSignal): 295 | """Computes mel loss between an estimate and a reference 296 | signal. 297 | 298 | Parameters 299 | ---------- 300 | x : AudioSignal 301 | Estimate signal 302 | y : AudioSignal 303 | Reference signal 304 | 305 | Returns 306 | ------- 307 | torch.Tensor 308 | Mel loss. 309 | """ 310 | loss = 0.0 311 | for n_mels, fmin, fmax, s in zip( 312 | self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params 313 | ): 314 | kwargs = { 315 | "window_length": s.window_length, 316 | "hop_length": s.hop_length, 317 | "window_type": s.window_type, 318 | } 319 | x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) 320 | y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) 321 | 322 | loss += self.log_weight * self.loss_fn( 323 | x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), 324 | y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), 325 | ) 326 | loss += self.mag_weight * self.loss_fn(x_mels, y_mels) 327 | return loss 328 | 329 | 330 | class GANLoss(nn.Module): 331 | """ 332 | Computes a discriminator loss, given a discriminator on 333 | generated waveforms/spectrograms compared to ground truth 334 | waveforms/spectrograms. Computes the loss for both the 335 | discriminator and the generator in separate functions. 
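    A rough sketch of how the two objectives are typically split across optimizer
    steps (fake and real are AudioSignals; names are illustrative):

        gan = GANLoss(discriminator)
        loss_d = gan.discriminator_loss(fake, real)          # discriminator update
        loss_g, loss_feat = gan.generator_loss(fake, real)   # generator update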
336 | """ 337 | 338 | def __init__(self, discriminator): 339 | super().__init__() 340 | self.discriminator = discriminator 341 | 342 | def forward(self, fake, real): 343 | d_fake = self.discriminator(fake.audio_data) 344 | d_real = self.discriminator(real.audio_data) 345 | return d_fake, d_real 346 | 347 | def discriminator_loss(self, fake, real): 348 | d_fake, d_real = self.forward(fake.clone().detach(), real) 349 | 350 | loss_d = 0 351 | for x_fake, x_real in zip(d_fake, d_real): 352 | loss_d += torch.mean(x_fake[-1] ** 2) 353 | loss_d += torch.mean((1 - x_real[-1]) ** 2) 354 | return loss_d 355 | 356 | def generator_loss(self, fake, real): 357 | d_fake, d_real = self.forward(fake, real) 358 | 359 | loss_g = 0 360 | for x_fake in d_fake: 361 | loss_g += torch.mean((1 - x_fake[-1]) ** 2) 362 | 363 | loss_feature = 0 364 | 365 | for i in range(len(d_fake)): 366 | for j in range(len(d_fake[i]) - 1): 367 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) 368 | return loss_g, loss_feature 369 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/synchformer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Any, Mapping 4 | 5 | import einops 6 | import numpy as np 7 | import torch 8 | import torchaudio 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from .motionformer import MotionFormer 13 | from .ast_model import AST 14 | from .utils import Config 15 | 16 | 17 | class Synchformer(nn.Module): 18 | 19 | def __init__(self): 20 | super().__init__() 21 | 22 | self.vfeat_extractor = MotionFormer( 23 | extract_features=True, 24 | factorize_space_time=True, 25 | agg_space_module="TransformerEncoderLayer", 26 | agg_time_module="torch.nn.Identity", 27 | add_global_repr=False, 28 | ) 29 | self.afeat_extractor = AST( 30 | extract_features=True, 31 | max_spec_t=66, 32 | factorize_freq_time=True, 33 | agg_freq_module="TransformerEncoderLayer", 34 | agg_time_module="torch.nn.Identity", 35 | add_global_repr=False, 36 | ) 37 | 38 | # # bridging the s3d latent dim (1024) into what is specified in the config 39 | # # to match e.g. the transformer dim 40 | self.vproj = nn.Linear(in_features=768, out_features=768) 41 | self.aproj = nn.Linear(in_features=768, out_features=768) 42 | self.transformer = GlobalTransformer( 43 | tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768 44 | ) 45 | 46 | def forward(self, vis): 47 | B, S, Tv, C, H, W = vis.shape 48 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) 49 | # feat extractors return a tuple of segment-level and global features (ignored for sync) 50 | # (B, S, tv, D), e.g. 
(B, 7, 8, 768) 51 | vis = self.vfeat_extractor(vis) 52 | return vis 53 | 54 | def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor): 55 | vis = self.vproj(vis) 56 | aud = self.aproj(aud) 57 | 58 | B, S, tv, D = vis.shape 59 | B, S, ta, D = aud.shape 60 | vis = vis.view(B, S * tv, D) # (B, S*tv, D) 61 | aud = aud.view(B, S * ta, D) # (B, S*ta, D) 62 | # print(vis.shape, aud.shape) 63 | 64 | # self.transformer will concatenate the vis and aud in one sequence with aux tokens, 65 | # ie `CvvvvMaaaaaa`, and will return the logits for the CLS tokens 66 | logits = self.transformer(vis, aud) # (B, cls); or (B, cls) and (B, 2) if DoubtingTransformer 67 | 68 | return logits 69 | 70 | def extract_vfeats(self, vis): 71 | B, S, Tv, C, H, W = vis.shape 72 | vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) 73 | # feat extractors return a tuple of segment-level and global features (ignored for sync) 74 | # (B, S, tv, D), e.g. (B, 7, 8, 768) 75 | vis = self.vfeat_extractor(vis) 76 | return vis 77 | 78 | def extract_afeats(self, aud): 79 | B, S, _, Fa, Ta = aud.shape 80 | aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2) # (B, S, Ta, F) 81 | # (B, S, ta, D), e.g. (B, 7, 6, 768) 82 | aud, _ = self.afeat_extractor(aud) 83 | return aud 84 | 85 | def compute_loss(self, logits, targets, loss_fn: str = None): 86 | loss = None 87 | if targets is not None: 88 | if loss_fn is None or loss_fn == "cross_entropy": 89 | # logits: (B, cls) and targets: (B,) 90 | loss = F.cross_entropy(logits, targets) 91 | else: 92 | raise NotImplementedError(f"Loss {loss_fn} not implemented") 93 | return loss 94 | 95 | def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True): 96 | # discard all entries except vfeat_extractor 97 | # sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')} 98 | 99 | return super().load_state_dict(sd, strict) 100 | 101 | 102 | class RandInitPositionalEncoding(nn.Module): 103 | """Random inited trainable pos embedding. 
It is just applied on the sequence, thus respects no priors.""" 104 | 105 | def __init__(self, block_shape: list, n_embd: int): 106 | super().__init__() 107 | self.block_shape = block_shape 108 | self.n_embd = n_embd 109 | self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd)) 110 | 111 | def forward(self, token_embeddings): 112 | return token_embeddings + self.pos_emb 113 | 114 | 115 | class GlobalTransformer(torch.nn.Module): 116 | """Same as in SparseSync but without the selector transformers and the head""" 117 | 118 | def __init__( 119 | self, 120 | tok_pdrop=0.0, 121 | embd_pdrop=0.1, 122 | resid_pdrop=0.1, 123 | attn_pdrop=0.1, 124 | n_layer=3, 125 | n_head=8, 126 | n_embd=768, 127 | pos_emb_block_shape=[ 128 | 198, 129 | ], 130 | n_off_head_out=21, 131 | ) -> None: 132 | super().__init__() 133 | self.config = Config( 134 | embd_pdrop=embd_pdrop, 135 | resid_pdrop=resid_pdrop, 136 | attn_pdrop=attn_pdrop, 137 | n_layer=n_layer, 138 | n_head=n_head, 139 | n_embd=n_embd, 140 | ) 141 | # input norm 142 | self.vis_in_lnorm = torch.nn.LayerNorm(n_embd) 143 | self.aud_in_lnorm = torch.nn.LayerNorm(n_embd) 144 | # aux tokens 145 | self.OFF_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) 146 | self.MOD_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) 147 | # whole token dropout 148 | self.tok_pdrop = tok_pdrop 149 | self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop) 150 | self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop) 151 | # maybe add pos emb 152 | self.pos_emb_cfg = RandInitPositionalEncoding( 153 | block_shape=pos_emb_block_shape, 154 | n_embd=n_embd, 155 | ) 156 | # the stem 157 | self.drop = torch.nn.Dropout(embd_pdrop) 158 | self.blocks = torch.nn.Sequential(*[Block(self.config) for _ in range(n_layer)]) 159 | # pre-output norm 160 | self.ln_f = torch.nn.LayerNorm(n_embd) 161 | # maybe add a head 162 | self.off_head = torch.nn.Linear(in_features=n_embd, out_features=n_off_head_out) 163 | 164 | def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True): 165 | B, Sv, D = v.shape 166 | B, Sa, D = a.shape 167 | # broadcasting special tokens to the batch size 168 | off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B) 169 | mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B) 170 | # norm 171 | v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a) 172 | # maybe whole token dropout 173 | if self.tok_pdrop > 0: 174 | v, a = self.tok_drop_vis(v), self.tok_drop_aud(a) 175 | # (B, 1+Sv+1+Sa, D) 176 | x = torch.cat((off_tok, v, mod_tok, a), dim=1) 177 | # maybe add pos emb 178 | if hasattr(self, "pos_emb_cfg"): 179 | x = self.pos_emb_cfg(x) 180 | # dropout -> stem -> norm 181 | x = self.drop(x) 182 | x = self.blocks(x) 183 | x = self.ln_f(x) 184 | # maybe add heads 185 | if attempt_to_apply_heads and hasattr(self, "off_head"): 186 | x = self.off_head(x[:, 0, :]) 187 | return x 188 | 189 | 190 | class SelfAttention(nn.Module): 191 | """ 192 | A vanilla multi-head masked self-attention layer with a projection at the end. 193 | It is possible to use torch.nn.MultiheadAttention here but I am including an 194 | explicit implementation here to show that there is nothing too scary here. 
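    With the GlobalTransformer defaults (n_embd=768, n_head=8) each head attends over
    96-dimensional slices, and the sequence it sees is the concatenation
    [OFF token, visual tokens, MOD token, audio tokens] assembled in GlobalTransformer.forward.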
195 | """ 196 | 197 | def __init__(self, config): 198 | super().__init__() 199 | assert config.n_embd % config.n_head == 0 200 | # key, query, value projections for all heads 201 | self.key = nn.Linear(config.n_embd, config.n_embd) 202 | self.query = nn.Linear(config.n_embd, config.n_embd) 203 | self.value = nn.Linear(config.n_embd, config.n_embd) 204 | # regularization 205 | self.attn_drop = nn.Dropout(config.attn_pdrop) 206 | self.resid_drop = nn.Dropout(config.resid_pdrop) 207 | # output projection 208 | self.proj = nn.Linear(config.n_embd, config.n_embd) 209 | # # causal mask to ensure that attention is only applied to the left in the input sequence 210 | # mask = torch.tril(torch.ones(config.block_size, 211 | # config.block_size)) 212 | # if hasattr(config, "n_unmasked"): 213 | # mask[:config.n_unmasked, :config.n_unmasked] = 1 214 | # self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) 215 | self.n_head = config.n_head 216 | 217 | def forward(self, x): 218 | B, T, C = x.size() 219 | 220 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 221 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 222 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 223 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 224 | 225 | # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 226 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 227 | # att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) 228 | att = F.softmax(att, dim=-1) 229 | y = self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 230 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 231 | 232 | # output projection 233 | y = self.resid_drop(self.proj(y)) 234 | 235 | return y 236 | 237 | 238 | class Block(nn.Module): 239 | """an unassuming Transformer block""" 240 | 241 | def __init__(self, config): 242 | super().__init__() 243 | self.ln1 = nn.LayerNorm(config.n_embd) 244 | self.ln2 = nn.LayerNorm(config.n_embd) 245 | self.attn = SelfAttention(config) 246 | self.mlp = nn.Sequential( 247 | nn.Linear(config.n_embd, 4 * config.n_embd), 248 | nn.GELU(), # nice 249 | nn.Linear(4 * config.n_embd, config.n_embd), 250 | nn.Dropout(config.resid_pdrop), 251 | ) 252 | 253 | def forward(self, x): 254 | x = x + self.attn(self.ln1(x)) 255 | x = x + self.mlp(self.ln2(x)) 256 | return x 257 | 258 | 259 | def make_class_grid( 260 | leftmost_val, 261 | rightmost_val, 262 | grid_size, 263 | add_extreme_offset: bool = False, 264 | seg_size_vframes: int = None, 265 | nseg: int = None, 266 | step_size_seg: float = None, 267 | vfps: float = None, 268 | ): 269 | assert grid_size >= 3, f"grid_size: {grid_size} doesnot make sense. 
If =2 -> (-1,1); =1 -> (-1); =0 -> ()" 270 | grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float() 271 | if add_extreme_offset: 272 | assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}" 273 | seg_size_sec = seg_size_vframes / vfps 274 | trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1) 275 | extreme_value = trim_size_in_seg * seg_size_sec 276 | grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid 277 | return grid 278 | 279 | 280 | # from synchformer 281 | def pad_or_truncate(audio: torch.Tensor, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0): 282 | difference = max_spec_t - audio.shape[-1] # safe for batched input 283 | # pad or truncate, depending on difference 284 | if difference > 0: 285 | # pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input 286 | pad_dims = (0, difference) 287 | audio = torch.nn.functional.pad(audio, pad_dims, pad_mode, pad_value) 288 | elif difference < 0: 289 | print(f"Truncating spec ({audio.shape}) to max_spec_t ({max_spec_t}).") 290 | audio = audio[..., :max_spec_t] # safe for batched input 291 | return audio 292 | 293 | 294 | def encode_audio_with_sync( 295 | synchformer: Synchformer, x: torch.Tensor, mel: torchaudio.transforms.MelSpectrogram 296 | ) -> torch.Tensor: 297 | b, t = x.shape 298 | 299 | # partition the audio waveform into overlapping segments 300 | segment_size = 10240 301 | step_size = 10240 // 2 302 | num_segments = (t - segment_size) // step_size + 1 303 | segments = [] 304 | for i in range(num_segments): 305 | segments.append(x[:, i * step_size : i * step_size + segment_size]) 306 | x = torch.stack(segments, dim=1) # (B, S, T) 307 | 308 | x = mel(x) 309 | x = torch.log(x + 1e-6) 310 | x = pad_or_truncate(x, 66) 311 | 312 | mean = -4.2677393 313 | std = 4.5689974 314 | x = (x - mean) / (2 * std) 315 | # x: B * S * 128 * 66 316 | x = synchformer.extract_afeats(x.unsqueeze(2)) 317 | return x 318 | 319 | 320 | def read_audio(filename, expected_length=int(16000 * 4)): 321 | waveform, sr = torchaudio.load(filename) 322 | waveform = waveform.mean(dim=0) 323 | 324 | if sr != 16000: 325 | resampler = torchaudio.transforms.Resample(sr, 16000) 326 | waveform = resampler(waveform) 327 | 328 | waveform = waveform[:expected_length] 329 | if waveform.shape[0] != expected_length: 330 | raise ValueError(f"Audio {filename} is too short") 331 | 332 | waveform = waveform.squeeze() 333 | 334 | return waveform 335 | 336 | 337 | if __name__ == "__main__": 338 | import os 339 | synchformer = Synchformer().cuda().eval() 340 | # mmaudio provided synchformer ckpt 341 | synchformer.load_state_dict( 342 | torch.load( 343 | os.environ.get("SYNCHFORMER_WEIGHTS", f"weights/synchformer.pth"), 344 | weights_only=True, 345 | map_location="cpu", 346 | ) 347 | ) 348 | 349 | sync_mel_spectrogram = torchaudio.transforms.MelSpectrogram( 350 | sample_rate=16000, 351 | win_length=400, 352 | hop_length=160, 353 | n_fft=1024, 354 | n_mels=128, 355 | ) 356 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/dac_vae/model/dac.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | from typing import Union 4 | 5 | import numpy as np 6 | import torch 7 | from audiotools import AudioSignal 8 | from audiotools.ml import BaseModel 9 | from torch import nn 10 | 11 | from .base import CodecMixin 12 | from
..nn.layers import Snake1d 13 | from ..nn.layers import WNConv1d 14 | from ..nn.layers import WNConvTranspose1d 15 | from ..nn.quantize import ResidualVectorQuantize 16 | from ..nn.vae_utils import DiagonalGaussianDistribution 17 | 18 | 19 | def init_weights(m): 20 | if isinstance(m, nn.Conv1d): 21 | nn.init.trunc_normal_(m.weight, std=0.02) 22 | nn.init.constant_(m.bias, 0) 23 | 24 | 25 | class ResidualUnit(nn.Module): 26 | def __init__(self, dim: int = 16, dilation: int = 1): 27 | super().__init__() 28 | pad = ((7 - 1) * dilation) // 2 29 | self.block = nn.Sequential( 30 | Snake1d(dim), 31 | WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), 32 | Snake1d(dim), 33 | WNConv1d(dim, dim, kernel_size=1), 34 | ) 35 | 36 | def forward(self, x): 37 | y = self.block(x) 38 | pad = (x.shape[-1] - y.shape[-1]) // 2 39 | if pad > 0: 40 | x = x[..., pad:-pad] 41 | return x + y 42 | 43 | 44 | class EncoderBlock(nn.Module): 45 | def __init__(self, dim: int = 16, stride: int = 1): 46 | super().__init__() 47 | self.block = nn.Sequential( 48 | ResidualUnit(dim // 2, dilation=1), 49 | ResidualUnit(dim // 2, dilation=3), 50 | ResidualUnit(dim // 2, dilation=9), 51 | Snake1d(dim // 2), 52 | WNConv1d( 53 | dim // 2, 54 | dim, 55 | kernel_size=2 * stride, 56 | stride=stride, 57 | padding=math.ceil(stride / 2), 58 | ), 59 | ) 60 | 61 | def forward(self, x): 62 | return self.block(x) 63 | 64 | 65 | class Encoder(nn.Module): 66 | def __init__( 67 | self, 68 | d_model: int = 64, 69 | strides: list = [2, 4, 8, 8], 70 | d_latent: int = 64, 71 | ): 72 | super().__init__() 73 | # Create first convolution 74 | self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] 75 | 76 | # Create EncoderBlocks that double channels as they downsample by `stride` 77 | for stride in strides: 78 | d_model *= 2 79 | self.block += [EncoderBlock(d_model, stride=stride)] 80 | 81 | # Create last convolution 82 | self.block += [ 83 | Snake1d(d_model), 84 | WNConv1d(d_model, d_latent, kernel_size=3, padding=1), 85 | ] 86 | 87 | # Wrap black into nn.Sequential 88 | self.block = nn.Sequential(*self.block) 89 | self.enc_dim = d_model 90 | 91 | def forward(self, x): 92 | return self.block(x) 93 | 94 | 95 | class DecoderBlock(nn.Module): 96 | def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): 97 | super().__init__() 98 | self.block = nn.Sequential( 99 | Snake1d(input_dim), 100 | WNConvTranspose1d( 101 | input_dim, 102 | output_dim, 103 | kernel_size=2 * stride, 104 | stride=stride, 105 | padding=math.ceil(stride / 2), 106 | output_padding=stride % 2, 107 | ), 108 | ResidualUnit(output_dim, dilation=1), 109 | ResidualUnit(output_dim, dilation=3), 110 | ResidualUnit(output_dim, dilation=9), 111 | ) 112 | 113 | def forward(self, x): 114 | return self.block(x) 115 | 116 | 117 | class Decoder(nn.Module): 118 | def __init__( 119 | self, 120 | input_channel, 121 | channels, 122 | rates, 123 | d_out: int = 1, 124 | ): 125 | super().__init__() 126 | 127 | # Add first conv layer 128 | layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] 129 | 130 | # Add upsampling + MRF blocks 131 | for i, stride in enumerate(rates): 132 | input_dim = channels // 2**i 133 | output_dim = channels // 2 ** (i + 1) 134 | layers += [DecoderBlock(input_dim, output_dim, stride)] 135 | 136 | # Add final conv layer 137 | layers += [ 138 | Snake1d(output_dim), 139 | WNConv1d(output_dim, d_out, kernel_size=7, padding=3), 140 | nn.Tanh(), 141 | ] 142 | 143 | self.model = nn.Sequential(*layers) 144 | 145 | def 
forward(self, x): 146 | return self.model(x) 147 | 148 | 149 | class DAC(BaseModel, CodecMixin): 150 | def __init__( 151 | self, 152 | encoder_dim: int = 64, 153 | encoder_rates: List[int] = [2, 4, 8, 8], 154 | latent_dim: int = None, 155 | decoder_dim: int = 1536, 156 | decoder_rates: List[int] = [8, 8, 4, 2], 157 | n_codebooks: int = 9, 158 | codebook_size: int = 1024, 159 | codebook_dim: Union[int, list] = 8, 160 | quantizer_dropout: bool = False, 161 | sample_rate: int = 44100, 162 | continuous: bool = False, 163 | ): 164 | super().__init__() 165 | 166 | self.encoder_dim = encoder_dim 167 | self.encoder_rates = encoder_rates 168 | self.decoder_dim = decoder_dim 169 | self.decoder_rates = decoder_rates 170 | self.sample_rate = sample_rate 171 | self.continuous = continuous 172 | 173 | if latent_dim is None: 174 | latent_dim = encoder_dim * (2 ** len(encoder_rates)) 175 | 176 | self.latent_dim = latent_dim 177 | 178 | self.hop_length = np.prod(encoder_rates) 179 | self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) 180 | 181 | if not continuous: 182 | self.n_codebooks = n_codebooks 183 | self.codebook_size = codebook_size 184 | self.codebook_dim = codebook_dim 185 | self.quantizer = ResidualVectorQuantize( 186 | input_dim=latent_dim, 187 | n_codebooks=n_codebooks, 188 | codebook_size=codebook_size, 189 | codebook_dim=codebook_dim, 190 | quantizer_dropout=quantizer_dropout, 191 | ) 192 | else: 193 | self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1) 194 | self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1) 195 | 196 | self.decoder = Decoder( 197 | latent_dim, 198 | decoder_dim, 199 | decoder_rates, 200 | ) 201 | self.sample_rate = sample_rate 202 | self.apply(init_weights) 203 | 204 | self.delay = self.get_delay() 205 | 206 | @property 207 | def dtype(self): 208 | """Get the dtype of the model parameters.""" 209 | # Return the dtype of the first parameter found 210 | for param in self.parameters(): 211 | return param.dtype 212 | return torch.float32 # fallback 213 | 214 | @property 215 | def device(self): 216 | """Get the device of the model parameters.""" 217 | # Return the device of the first parameter found 218 | for param in self.parameters(): 219 | return param.device 220 | return torch.device('cpu') # fallback 221 | 222 | def preprocess(self, audio_data, sample_rate): 223 | if sample_rate is None: 224 | sample_rate = self.sample_rate 225 | assert sample_rate == self.sample_rate 226 | 227 | length = audio_data.shape[-1] 228 | right_pad = math.ceil(length / self.hop_length) * self.hop_length - length 229 | audio_data = nn.functional.pad(audio_data, (0, right_pad)) 230 | 231 | return audio_data 232 | 233 | def encode( 234 | self, 235 | audio_data: torch.Tensor, 236 | n_quantizers: int = None, 237 | ): 238 | """Encode given audio data and return quantized latent codes 239 | 240 | Parameters 241 | ---------- 242 | audio_data : Tensor[B x 1 x T] 243 | Audio data to encode 244 | n_quantizers : int, optional 245 | Number of quantizers to use, by default None 246 | If None, all quantizers are used. 
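            Ignored when continuous=True: the VAE path has no quantizer.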
247 | 248 | Returns 249 | ------- 250 | tuple 251 | A tuple (z, codes, latents, commitment_loss, codebook_loss): 252 | z : Tensor[B x D x T] 253 | Quantized continuous representation of input 254 | codes : Tensor[B x N x T] 255 | Codebook indices for each codebook 256 | (quantized discrete representation of input) 257 | latents : Tensor[B x N*D x T] 258 | Projected latents (continuous representation of input before quantization) 259 | commitment_loss : Tensor[1] 260 | Commitment loss to train encoder to predict vectors closer to codebook 261 | entries 262 | codebook_loss : Tensor[1] 263 | Codebook loss to update the codebook 264 | When continuous=True, z is a DiagonalGaussianDistribution posterior, 265 | codes/latents are None and both losses are 0. 266 | """ 267 | z = self.encoder(audio_data) # [B x D x T] 268 | if not self.continuous: 269 | z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers) 270 | else: 271 | z = self.quant_conv(z) # [B x 2D x T] 272 | z = DiagonalGaussianDistribution(z) 273 | codes, latents, commitment_loss, codebook_loss = None, None, 0, 0 274 | 275 | return z, codes, latents, commitment_loss, codebook_loss 276 | 277 | def decode(self, z: torch.Tensor): 278 | """Decode given latent codes and return audio data 279 | 280 | Parameters 281 | ---------- 282 | z : Tensor[B x D x T] 283 | Quantized continuous representation of input 284 | When continuous=True, a latent sampled from the encoder posterior; 285 | it is passed through post_quant_conv before decoding 286 | 287 | Returns 288 | ------- 289 | audio : Tensor[B x 1 x length] 290 | Decoded audio data, returned directly as a tensor 291 | rather than wrapped in a dictionary. 292 | 293 | """ 294 | if not self.continuous: 295 | audio = self.decoder(z) 296 | else: 297 | z = self.post_quant_conv(z) 298 | audio = self.decoder(z) 299 | 300 | return audio 301 | 302 | def forward( 303 | self, 304 | audio_data: torch.Tensor, 305 | sample_rate: int = None, 306 | n_quantizers: int = None, 307 | ): 308 | """Model forward pass 309 | 310 | Parameters 311 | ---------- 312 | audio_data : Tensor[B x 1 x T] 313 | Audio data to encode 314 | sample_rate : int, optional 315 | Sample rate of audio data in Hz, by default None 316 | If None, defaults to `self.sample_rate` 317 | n_quantizers : int, optional 318 | Number of quantizers to use, by default None. 319 | If None, all quantizers are used. 320 | 321 | Returns 322 | ------- 323 | dict 324 | A dictionary with the following keys: 325 | "z" : Tensor[B x D x T] 326 | Quantized continuous representation of input 327 | "codes" : Tensor[B x N x T] 328 | Codebook indices for each codebook 329 | (quantized discrete representation of input) 330 | "latents" : Tensor[B x N*D x T] 331 | Projected latents (continuous representation of input before quantization) 332 | "vq/commitment_loss" : Tensor[1] 333 | Commitment loss to train encoder to predict vectors closer to codebook 334 | entries 335 | "vq/codebook_loss" : Tensor[1] 336 | Codebook loss to update the codebook 337 | "kl_loss" : Tensor[1] 338 | KL loss of the posterior (only when continuous=True; the codes/latents/vq entries are omitted in that case) 339 | "audio" : Tensor[B x 1 x length] 340 | Decoded audio data.
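        Example (illustrative sketch, assuming the default discrete 44.1 kHz configuration;
        not taken from the repository's own tests)::

            >>> model = DAC()                   # default RVQ codec (continuous=False)
            >>> wav = torch.randn(1, 1, 44100)  # ~1 s of audio at 44.1 kHz
            >>> out = model(wav)
            >>> out["audio"].shape              # trimmed back to the input length
            torch.Size([1, 1, 44100])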
341 | """ 342 | length = audio_data.shape[-1] 343 | audio_data = self.preprocess(audio_data, sample_rate) 344 | if not self.continuous: 345 | z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers) 346 | 347 | x = self.decode(z) 348 | return { 349 | "audio": x[..., :length], 350 | "z": z, 351 | "codes": codes, 352 | "latents": latents, 353 | "vq/commitment_loss": commitment_loss, 354 | "vq/codebook_loss": codebook_loss, 355 | } 356 | else: 357 | posterior, _, _, _, _ = self.encode(audio_data, n_quantizers) 358 | z = posterior.sample() 359 | x = self.decode(z) 360 | 361 | kl_loss = posterior.kl() 362 | kl_loss = kl_loss.mean() 363 | 364 | return { 365 | "audio": x[..., :length], 366 | "z": z, 367 | "kl_loss": kl_loss, 368 | } 369 | 370 | 371 | if __name__ == "__main__": 372 | import numpy as np 373 | from functools import partial 374 | 375 | model = DAC().to("cpu") 376 | 377 | for n, m in model.named_modules(): 378 | o = m.extra_repr() 379 | p = sum([np.prod(p.size()) for p in m.parameters()]) 380 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 381 | setattr(m, "extra_repr", partial(fn, o=o, p=p)) 382 | print(model) 383 | print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) 384 | 385 | length = 88200 * 2 386 | x = torch.randn(1, 1, length).to(model.device) 387 | x.requires_grad_(True) 388 | x.retain_grad() 389 | 390 | # Make a forward pass 391 | out = model(x)["audio"] 392 | print("Input shape:", x.shape) 393 | print("Output shape:", out.shape) 394 | 395 | # Create gradient variable 396 | grad = torch.zeros_like(out) 397 | grad[:, :, grad.shape[-1] // 2] = 1 398 | 399 | # Make a backward pass 400 | out.backward(grad) 401 | 402 | # Check non-zero values 403 | gradmap = x.grad.squeeze(0) 404 | gradmap = (gradmap != 0).sum(0) # sum across features 405 | rf = (gradmap != 0).sum() 406 | 407 | print(f"Receptive field: {rf.item()}") 408 | 409 | x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) 410 | model.decompress(model.compress(x, verbose=True), verbose=True) 411 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/models/synchformer/vit_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
3 | # Copyright 2020 Ross Wightman 4 | # Modified Model definition 5 | """Video models.""" 6 | 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | from einops import rearrange, repeat 12 | from timm.layers import to_2tuple 13 | from torch import einsum 14 | from torch.nn import functional as F 15 | 16 | default_cfgs = { 17 | "vit_1k": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth", 18 | "vit_1k_large": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth", 19 | } 20 | 21 | 22 | def qkv_attn(q, k, v, tok_mask: torch.Tensor = None): 23 | sim = einsum("b i d, b j d -> b i j", q, k) 24 | # apply masking if provided, tok_mask is (B*S*H, N): 1s - keep; sim is (B*S*H, H, N, N) 25 | if tok_mask is not None: 26 | BSH, N = tok_mask.shape 27 | sim = sim.masked_fill(tok_mask.view(BSH, 1, N) == 0, float("-inf")) # 1 - broadcasts across N 28 | attn = sim.softmax(dim=-1) 29 | out = einsum("b i j, b j d -> b i d", attn, v) 30 | return out 31 | 32 | 33 | class DividedAttention(nn.Module): 34 | 35 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0): 36 | super().__init__() 37 | self.num_heads = num_heads 38 | head_dim = dim // num_heads 39 | self.scale = head_dim**-0.5 40 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 41 | self.proj = nn.Linear(dim, dim) 42 | 43 | # init to zeros 44 | self.qkv.weight.data.fill_(0) 45 | self.qkv.bias.data.fill_(0) 46 | self.proj.weight.data.fill_(1) 47 | self.proj.bias.data.fill_(0) 48 | 49 | self.attn_drop = nn.Dropout(attn_drop) 50 | self.proj_drop = nn.Dropout(proj_drop) 51 | 52 | def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims): 53 | # num of heads variable 54 | h = self.num_heads 55 | 56 | # project x to q, k, v vaalues 57 | q, k, v = self.qkv(x).chunk(3, dim=-1) 58 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 59 | if tok_mask is not None: 60 | # replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d 61 | assert len(tok_mask.shape) == 2 62 | tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1]) 63 | 64 | # Scale q 65 | q *= self.scale 66 | 67 | # Take out cls_q, cls_k, cls_v 68 | (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v)) 69 | # the same for masking 70 | if tok_mask is not None: 71 | cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:] 72 | else: 73 | cls_mask, mask_ = None, None 74 | 75 | # let CLS token attend to key / values of all patches across time and space 76 | cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask) 77 | 78 | # rearrange across time or space 79 | q_, k_, v_ = map(lambda t: rearrange(t, f"{einops_from} -> {einops_to}", **einops_dims), (q_, k_, v_)) 80 | 81 | # expand CLS token keys and values across time or space and concat 82 | r = q_.shape[0] // cls_k.shape[0] 83 | cls_k, cls_v = map(lambda t: repeat(t, "b () d -> (b r) () d", r=r), (cls_k, cls_v)) 84 | 85 | k_ = torch.cat((cls_k, k_), dim=1) 86 | v_ = torch.cat((cls_v, v_), dim=1) 87 | 88 | # the same for masking (if provided) 89 | if tok_mask is not None: 90 | # since mask does not have the latent dim (d), we need to remove it from einops dims 91 | mask_ = rearrange(mask_, f"{einops_from} -> {einops_to}".replace(" d", ""), **einops_dims) 92 | cls_mask = repeat(cls_mask, "b () -> (b r) ()", r=r) # expand cls_mask across time or space 
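            # prepend the CLS-token mask entry so mask_ stays aligned with k_/v_, which had cls_k/cls_v concatenated above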
93 | mask_ = torch.cat((cls_mask, mask_), dim=1) 94 | 95 | # attention 96 | out = qkv_attn(q_, k_, v_, tok_mask=mask_) 97 | 98 | # merge back time or space 99 | out = rearrange(out, f"{einops_to} -> {einops_from}", **einops_dims) 100 | 101 | # concat back the cls token 102 | out = torch.cat((cls_out, out), dim=1) 103 | 104 | # merge back the heads 105 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 106 | 107 | ## to out 108 | x = self.proj(out) 109 | x = self.proj_drop(x) 110 | return x 111 | 112 | 113 | class DividedSpaceTimeBlock(nn.Module): 114 | 115 | def __init__( 116 | self, 117 | dim=768, 118 | num_heads=12, 119 | attn_type="divided", 120 | mlp_ratio=4.0, 121 | qkv_bias=False, 122 | drop=0.0, 123 | attn_drop=0.0, 124 | drop_path=0.0, 125 | act_layer=nn.GELU, 126 | norm_layer=nn.LayerNorm, 127 | ): 128 | super().__init__() 129 | 130 | self.einops_from_space = "b (f n) d" 131 | self.einops_to_space = "(b f) n d" 132 | self.einops_from_time = "b (f n) d" 133 | self.einops_to_time = "(b n) f d" 134 | 135 | self.norm1 = norm_layer(dim) 136 | 137 | self.attn = DividedAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 138 | 139 | self.timeattn = DividedAttention( 140 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop 141 | ) 142 | 143 | # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 144 | self.drop_path = nn.Identity() 145 | self.norm2 = norm_layer(dim) 146 | mlp_hidden_dim = int(dim * mlp_ratio) 147 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 148 | self.norm3 = norm_layer(dim) 149 | 150 | def forward(self, x, seq_len=196, num_frames=8, approx="none", num_landmarks=128, tok_mask: torch.Tensor = None): 151 | time_output = self.timeattn( 152 | self.norm3(x), self.einops_from_time, self.einops_to_time, n=seq_len, tok_mask=tok_mask 153 | ) 154 | time_residual = x + time_output 155 | 156 | space_output = self.attn( 157 | self.norm1(time_residual), self.einops_from_space, self.einops_to_space, f=num_frames, tok_mask=tok_mask 158 | ) 159 | space_residual = time_residual + self.drop_path(space_output) 160 | 161 | x = space_residual 162 | x = x + self.drop_path(self.mlp(self.norm2(x))) 163 | return x 164 | 165 | 166 | class Mlp(nn.Module): 167 | 168 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): 169 | super().__init__() 170 | out_features = out_features or in_features 171 | hidden_features = hidden_features or in_features 172 | self.fc1 = nn.Linear(in_features, hidden_features) 173 | self.act = act_layer() 174 | self.fc2 = nn.Linear(hidden_features, out_features) 175 | self.drop = nn.Dropout(drop) 176 | 177 | def forward(self, x): 178 | x = self.fc1(x) 179 | x = self.act(x) 180 | x = self.drop(x) 181 | x = self.fc2(x) 182 | x = self.drop(x) 183 | return x 184 | 185 | 186 | class PatchEmbed(nn.Module): 187 | """Image to Patch Embedding""" 188 | 189 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 190 | super().__init__() 191 | img_size = img_size if type(img_size) is tuple else to_2tuple(img_size) 192 | patch_size = img_size if type(patch_size) is tuple else to_2tuple(patch_size) 193 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 194 | self.img_size = img_size 195 | self.patch_size = patch_size 196 | self.num_patches = num_patches 197 | 198 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, 
stride=patch_size) 199 | 200 | def forward(self, x): 201 | B, C, H, W = x.shape 202 | x = self.proj(x).flatten(2).transpose(1, 2) 203 | return x 204 | 205 | 206 | class PatchEmbed3D(nn.Module): 207 | """Image to Patch Embedding""" 208 | 209 | def __init__( 210 | self, 211 | img_size=224, 212 | temporal_resolution=4, 213 | in_chans=3, 214 | patch_size=16, 215 | z_block_size=2, 216 | embed_dim=768, 217 | flatten=True, 218 | ): 219 | super().__init__() 220 | self.height = img_size // patch_size 221 | self.width = img_size // patch_size 222 | ### v-iashin: these two are incorrect 223 | # self.frames = (temporal_resolution // z_block_size) 224 | # self.num_patches = self.height * self.width * self.frames 225 | self.z_block_size = z_block_size 226 | ### 227 | self.proj = nn.Conv3d( 228 | in_chans, 229 | embed_dim, 230 | kernel_size=(z_block_size, patch_size, patch_size), 231 | stride=(z_block_size, patch_size, patch_size), 232 | ) 233 | self.flatten = flatten 234 | 235 | def forward(self, x): 236 | B, C, T, H, W = x.shape 237 | x = self.proj(x) 238 | if self.flatten: 239 | x = x.flatten(2).transpose(1, 2) 240 | return x 241 | 242 | 243 | class HeadMLP(nn.Module): 244 | 245 | def __init__(self, n_input, n_classes, n_hidden=512, p=0.1): 246 | super(HeadMLP, self).__init__() 247 | self.n_input = n_input 248 | self.n_classes = n_classes 249 | self.n_hidden = n_hidden 250 | if n_hidden is None: 251 | # use linear classifier 252 | self.block_forward = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_input, n_classes, bias=True)) 253 | else: 254 | # use simple MLP classifier 255 | self.block_forward = nn.Sequential( 256 | nn.Dropout(p=p), 257 | nn.Linear(n_input, n_hidden, bias=True), 258 | nn.BatchNorm1d(n_hidden), 259 | nn.ReLU(inplace=True), 260 | nn.Dropout(p=p), 261 | nn.Linear(n_hidden, n_classes, bias=True), 262 | ) 263 | print(f"Dropout-NLP: {p}") 264 | 265 | def forward(self, x): 266 | return self.block_forward(x) 267 | 268 | 269 | def _conv_filter(state_dict, patch_size=16): 270 | """convert patch embedding weight from manual patchify + linear proj to conv""" 271 | out_dict = {} 272 | for k, v in state_dict.items(): 273 | if "patch_embed.proj.weight" in k: 274 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 275 | out_dict[k] = v 276 | return out_dict 277 | 278 | 279 | def adapt_input_conv(in_chans, conv_weight, agg="sum"): 280 | conv_type = conv_weight.dtype 281 | conv_weight = conv_weight.float() 282 | O, I, J, K = conv_weight.shape 283 | if in_chans == 1: 284 | if I > 3: 285 | assert conv_weight.shape[1] % 3 == 0 286 | # For models with space2depth stems 287 | conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) 288 | conv_weight = conv_weight.sum(dim=2, keepdim=False) 289 | else: 290 | if agg == "sum": 291 | print("Summing conv1 weights") 292 | conv_weight = conv_weight.sum(dim=1, keepdim=True) 293 | else: 294 | print("Averaging conv1 weights") 295 | conv_weight = conv_weight.mean(dim=1, keepdim=True) 296 | elif in_chans != 3: 297 | if I != 3: 298 | raise NotImplementedError("Weight format not supported by conversion.") 299 | else: 300 | if agg == "sum": 301 | print("Summing conv1 weights") 302 | repeat = int(math.ceil(in_chans / 3)) 303 | conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] 304 | conv_weight *= 3 / float(in_chans) 305 | else: 306 | print("Averaging conv1 weights") 307 | conv_weight = conv_weight.mean(dim=1, keepdim=True) 308 | conv_weight = conv_weight.repeat(1, in_chans, 1, 1) 309 | conv_weight = conv_weight.to(conv_type) 310 | return conv_weight 311 
| 312 | 313 | def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): 314 | # Load state dict 315 | assert cfg.VIT.PRETRAINED_WEIGHTS in default_cfgs, f"{cfg.VIT.PRETRAINED_WEIGHTS} not in [vit_1k, vit_1k_large]" 316 | state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS]) 317 | 318 | if filter_fn is not None: 319 | state_dict = filter_fn(state_dict) 320 | 321 | input_convs = "patch_embed.proj" 322 | if input_convs is not None and in_chans != 3: 323 | if isinstance(input_convs, str): 324 | input_convs = (input_convs,) 325 | for input_conv_name in input_convs: 326 | weight_name = input_conv_name + ".weight" 327 | try: 328 | state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name], agg="avg") 329 | print(f"Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)") 330 | except NotImplementedError as e: 331 | del state_dict[weight_name] 332 | strict = False 333 | print(f"Unable to convert pretrained {input_conv_name} weights, using random init for this layer.") 334 | 335 | classifier_name = "head" 336 | label_offset = cfg.get("label_offset", 0) 337 | pretrain_classes = 1000 338 | if num_classes != pretrain_classes: 339 | # completely discard fully connected if model num_classes doesn't match pretrained weights 340 | del state_dict[classifier_name + ".weight"] 341 | del state_dict[classifier_name + ".bias"] 342 | strict = False 343 | elif label_offset > 0: 344 | # special case for pretrained weights with an extra background class in pretrained weights 345 | classifier_weight = state_dict[classifier_name + ".weight"] 346 | state_dict[classifier_name + ".weight"] = classifier_weight[label_offset:] 347 | classifier_bias = state_dict[classifier_name + ".bias"] 348 | state_dict[classifier_name + ".bias"] = classifier_bias[label_offset:] 349 | 350 | loaded_state = state_dict 351 | self_state = model.state_dict() 352 | all_names = set(self_state.keys()) 353 | saved_names = set([]) 354 | for name, param in loaded_state.items(): 355 | param = param 356 | if "module." in name: 357 | name = name.replace("module.", "") 358 | if name in self_state.keys() and param.shape == self_state[name].shape: 359 | saved_names.add(name) 360 | self_state[name].copy_(param) 361 | else: 362 | print(f"didn't load: {name} of shape: {param.shape}") 363 | print("Missing Keys:") 364 | print(all_names - saved_names) 365 | -------------------------------------------------------------------------------- /hunyuanvideo_foley/utils/schedulers/scheduling_flow_match_discrete.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.utils import BaseOutput, logging 10 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 11 | 12 | 13 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 14 | 15 | 16 | @dataclass 17 | class FlowMatchDiscreteSchedulerOutput(BaseOutput): 18 | """ 19 | Output class for the scheduler's `step` function output. 20 | 21 | Args: 22 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 23 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 24 | denoising loop. 
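            In this repository the sample is an audio latent rather than an image, but the semantics are the same.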
25 | """ 26 | 27 | prev_sample: torch.FloatTensor 28 | 29 | 30 | class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin): 31 | """ 32 | Euler scheduler. 33 | 34 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 35 | methods the library implements for all schedulers such as loading and saving. 36 | 37 | Args: 38 | num_train_timesteps (`int`, defaults to 1000): 39 | The number of diffusion steps to train the model. 40 | timestep_spacing (`str`, defaults to `"linspace"`): 41 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and 42 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 43 | shift (`float`, defaults to 1.0): 44 | The shift value for the timestep schedule. 45 | reverse (`bool`, defaults to `True`): 46 | Whether to reverse the timestep schedule. 47 | """ 48 | 49 | _compatibles = [] 50 | order = 1 51 | 52 | @register_to_config 53 | def __init__( 54 | self, 55 | num_train_timesteps: int = 1000, 56 | shift: float = 1.0, 57 | reverse: bool = True, 58 | solver: str = "euler", 59 | use_flux_shift: bool = False, 60 | flux_base_shift: float = 0.5, 61 | flux_max_shift: float = 1.15, 62 | n_tokens: Optional[int] = None, 63 | ): 64 | sigmas = torch.linspace(1, 0, num_train_timesteps + 1) 65 | 66 | if not reverse: 67 | sigmas = sigmas.flip(0) 68 | 69 | self.sigmas = sigmas 70 | # the value fed to model 71 | self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) 72 | self.timesteps_full = (sigmas * num_train_timesteps).to(dtype=torch.float32) 73 | 74 | self._step_index = None 75 | self._begin_index = None 76 | 77 | self.supported_solver = [ 78 | "euler", 79 | "heun-2", "midpoint-2", 80 | "kutta-4", 81 | ] 82 | if solver not in self.supported_solver: 83 | raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}") 84 | 85 | # empty dt and derivative (for heun) 86 | self.derivative_1 = None 87 | self.derivative_2 = None 88 | self.derivative_3 = None 89 | self.dt = None 90 | 91 | @property 92 | def step_index(self): 93 | """ 94 | The index counter for current timestep. It will increase 1 after each scheduler step. 95 | """ 96 | return self._step_index 97 | 98 | @property 99 | def begin_index(self): 100 | """ 101 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method. 102 | """ 103 | return self._begin_index 104 | 105 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 106 | def set_begin_index(self, begin_index: int = 0): 107 | """ 108 | Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 109 | 110 | Args: 111 | begin_index (`int`): 112 | The begin index for the scheduler. 
113 | """ 114 | self._begin_index = begin_index 115 | 116 | def _sigma_to_t(self, sigma): 117 | return sigma * self.config.num_train_timesteps 118 | 119 | @property 120 | def state_in_first_order(self): 121 | return self.derivative_1 is None 122 | 123 | @property 124 | def state_in_second_order(self): 125 | return self.derivative_2 is None 126 | 127 | @property 128 | def state_in_third_order(self): 129 | return self.derivative_3 is None 130 | 131 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, 132 | n_tokens: int = None): 133 | """ 134 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 135 | 136 | Args: 137 | num_inference_steps (`int`): 138 | The number of diffusion steps used when generating samples with a pre-trained model. 139 | device (`str` or `torch.device`, *optional*): 140 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 141 | n_tokens (`int`, *optional*): 142 | Number of tokens in the input sequence. 143 | """ 144 | self.num_inference_steps = num_inference_steps 145 | 146 | sigmas = torch.linspace(1, 0, num_inference_steps + 1) 147 | 148 | # Apply timestep shift 149 | if self.config.use_flux_shift: 150 | assert isinstance(n_tokens, int), "n_tokens should be provided for flux shift" 151 | mu = self.get_lin_function(y1=self.config.flux_base_shift, y2=self.config.flux_max_shift)(n_tokens) 152 | sigmas = self.flux_time_shift(mu, 1.0, sigmas) 153 | elif self.config.shift != 1.: 154 | sigmas = self.sd3_time_shift(sigmas) 155 | 156 | if not self.config.reverse: 157 | sigmas = 1 - sigmas 158 | 159 | self.sigmas = sigmas 160 | self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device) 161 | self.timesteps_full = (sigmas * self.config.num_train_timesteps).to(dtype=torch.float32, device=device) 162 | 163 | # empty dt and derivative (for kutta) 164 | self.derivative_1 = None 165 | self.derivative_2 = None 166 | self.derivative_3 = None 167 | self.dt = None 168 | 169 | # Reset step index 170 | self._step_index = None 171 | 172 | def index_for_timestep(self, timestep, schedule_timesteps=None): 173 | if schedule_timesteps is None: 174 | schedule_timesteps = self.timesteps 175 | 176 | indices = (schedule_timesteps == timestep).nonzero() 177 | 178 | # The sigma index that is taken for the **very** first `step` 179 | # is always the second index (or the last index if there is only 1) 180 | # This way we can ensure we don't accidentally skip a sigma in 181 | # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) 182 | pos = 1 if len(indices) > 1 else 0 183 | 184 | return indices[pos].item() 185 | 186 | def _init_step_index(self, timestep): 187 | if self.begin_index is None: 188 | if isinstance(timestep, torch.Tensor): 189 | timestep = timestep.to(self.timesteps.device) 190 | self._step_index = self.index_for_timestep(timestep) 191 | else: 192 | self._step_index = self._begin_index 193 | 194 | def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: 195 | return sample 196 | 197 | @staticmethod 198 | def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15): 199 | m = (y2 - y1) / (x2 - x1) 200 | b = y1 - m * x1 201 | return lambda x: m * x + b 202 | 203 | @staticmethod 204 | def flux_time_shift(mu: float, sigma: float, t: torch.Tensor): 205 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 206 | 207 | def sd3_time_shift(self, t: torch.Tensor): 208 | return (self.config.shift * t) / (1 + (self.config.shift - 1) * t) 209 | 210 | def step( 211 | self, 212 | model_output: torch.FloatTensor, 213 | timestep: Union[float, torch.FloatTensor], 214 | sample: torch.FloatTensor, 215 | pred_uncond: torch.FloatTensor = None, 216 | generator: Optional[torch.Generator] = None, 217 | n_tokens: Optional[int] = None, 218 | return_dict: bool = True, 219 | ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]: 220 | """ 221 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 222 | process from the learned model outputs (most often the predicted noise). 223 | 224 | Args: 225 | model_output (`torch.FloatTensor`): 226 | The direct output from learned diffusion model. 227 | timestep (`float`): 228 | The current discrete timestep in the diffusion chain. 229 | sample (`torch.FloatTensor`): 230 | A current instance of a sample created by the diffusion process. 231 | generator (`torch.Generator`, *optional*): 232 | A random number generator. 233 | n_tokens (`int`, *optional*): 234 | Number of tokens in the input sequence. 235 | return_dict (`bool`): 236 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 237 | tuple. 238 | 239 | Returns: 240 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 241 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 242 | returned, otherwise a tuple is returned where the first element is the sample tensor. 243 | """ 244 | 245 | if ( 246 | isinstance(timestep, int) 247 | or isinstance(timestep, torch.IntTensor) 248 | or isinstance(timestep, torch.LongTensor) 249 | ): 250 | raise ValueError( 251 | ( 252 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 253 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 254 | " one of the `scheduler.timesteps` as a timestep." 
255 | ), 256 | ) 257 | 258 | if self.step_index is None: 259 | self._init_step_index(timestep) 260 | 261 | # Upcast to avoid precision issues when computing prev_sample 262 | sample = sample.to(torch.float32) 263 | model_output = model_output.to(torch.float32) 264 | pred_uncond = pred_uncond.to(torch.float32) if pred_uncond is not None else None 265 | 266 | # dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] 267 | sigma = self.sigmas[self.step_index] 268 | sigma_next = self.sigmas[self.step_index + 1] 269 | 270 | last_inner_step = True 271 | if self.config.solver == "euler": 272 | derivative, dt, sample, last_inner_step = self.first_order_method(model_output, sigma, sigma_next, sample) 273 | elif self.config.solver in ["heun-2", "midpoint-2"]: 274 | derivative, dt, sample, last_inner_step = self.second_order_method(model_output, sigma, sigma_next, sample) 275 | elif self.config.solver == "kutta-4": 276 | derivative, dt, sample, last_inner_step = self.fourth_order_method(model_output, sigma, sigma_next, sample) 277 | else: 278 | raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}") 279 | 280 | prev_sample = sample + derivative * dt 281 | 282 | # Cast sample back to model compatible dtype 283 | # prev_sample = prev_sample.to(model_output.dtype) 284 | 285 | # upon completion increase step index by one 286 | if last_inner_step: 287 | self._step_index += 1 288 | 289 | if not return_dict: 290 | return (prev_sample,) 291 | 292 | return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) 293 | 294 | def first_order_method(self, model_output, sigma, sigma_next, sample): 295 | derivative = model_output.float() 296 | dt = sigma_next - sigma 297 | return derivative, dt, sample, True 298 | 299 | def second_order_method(self, model_output, sigma, sigma_next, sample): 300 | if self.state_in_first_order: 301 | # store for 2nd order step 302 | self.derivative_1 = model_output 303 | self.dt = sigma_next - sigma 304 | self.sample = sample 305 | 306 | derivative = model_output 307 | if self.config.solver == 'heun-2': 308 | dt = self.dt 309 | elif self.config.solver == 'midpoint-2': 310 | dt = self.dt / 2 311 | else: 312 | raise NotImplementedError(f"Solver {self.config.solver} not supported.") 313 | last_inner_step = False 314 | 315 | else: 316 | if self.config.solver == 'heun-2': 317 | derivative = 0.5 * (self.derivative_1 + model_output) 318 | elif self.config.solver == 'midpoint-2': 319 | derivative = model_output 320 | else: 321 | raise NotImplementedError(f"Solver {self.config.solver} not supported.") 322 | 323 | # 3. 
take prev timestep & sample 324 | dt = self.dt 325 | sample = self.sample 326 | last_inner_step = True 327 | 328 | # free dt and derivative 329 | # Note, this puts the scheduler in "first order mode" 330 | self.derivative_1 = None 331 | self.dt = None 332 | self.sample = None 333 | 334 | return derivative, dt, sample, last_inner_step 335 | 336 | def fourth_order_method(self, model_output, sigma, sigma_next, sample): 337 | if self.state_in_first_order: 338 | self.derivative_1 = model_output 339 | self.dt = sigma_next - sigma 340 | self.sample = sample 341 | derivative = model_output 342 | dt = self.dt / 2 343 | last_inner_step = False 344 | 345 | elif self.state_in_second_order: 346 | self.derivative_2 = model_output 347 | derivative = model_output 348 | dt = self.dt / 2 349 | last_inner_step = False 350 | 351 | elif self.state_in_third_order: 352 | self.derivative_3 = model_output 353 | derivative = model_output 354 | dt = self.dt 355 | last_inner_step = False 356 | 357 | else: 358 | derivative = 1/6 * self.derivative_1 + 1/3 * self.derivative_2 + 1/3 * self.derivative_3 + 1/6 * model_output 359 | 360 | # 3. take prev timestep & sample 361 | dt = self.dt 362 | sample = self.sample 363 | last_inner_step = True 364 | 365 | # free dt and derivative 366 | # Note, this puts the scheduler in "first order mode" 367 | self.derivative_1 = None 368 | self.derivative_2 = None 369 | self.derivative_3 = None 370 | self.dt = None 371 | self.sample = None 372 | 373 | return derivative, dt, sample, last_inner_step 374 | 375 | def __len__(self): 376 | return self.config.num_train_timesteps 377 | --------------------------------------------------------------------------------
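A minimal usage sketch for FlowMatchDiscreteScheduler (illustrative only; it assumes the hunyuanvideo_foley package is importable, and the denoiser callable and latent shape below are placeholders rather than parts of this repository):

    import torch
    from hunyuanvideo_foley.utils.schedulers import FlowMatchDiscreteScheduler

    scheduler = FlowMatchDiscreteScheduler(shift=1.0, solver="euler")
    scheduler.set_timesteps(num_inference_steps=50, device="cpu")

    denoiser = lambda x, t: torch.zeros_like(x)  # dummy stand-in for the flow-matching model
    latents = torch.randn(1, 128, 64)            # hypothetical audio-latent shape

    for t in scheduler.timesteps:
        model_output = denoiser(latents, t)      # model prediction at timestep t
        latents = scheduler.step(model_output, t, latents).prev_sample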