├── audiox ├── stable_audio_tools │ ├── data │ │ ├── __init__.py │ │ └── utils.py │ ├── inference │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── sampling.py │ │ └── generation.py │ ├── interface │ │ └── __init__.py │ ├── training │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── losses.cpython-312.pyc │ │ │ │ └── __init__.cpython-312.pyc │ │ │ └── losses.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── factory.cpython-312.pyc │ │ │ └── __init__.cpython-312.pyc │ │ ├── utils.py │ │ ├── lm.py │ │ └── factory.py │ ├── models │ │ ├── __init__.py │ │ ├── pretrained.py │ │ ├── wavelets.py │ │ ├── utils.py │ │ ├── factory.py │ │ ├── temptransformer.py │ │ ├── local_attention.py │ │ ├── pretransforms.py │ │ ├── bottleneck.py │ │ └── blocks.py │ └── __init__.py ├── setup.py └── README.md ├── web ├── appearance.js └── audiox.js ├── LICENSE ├── requirements.txt ├── debug_comfyui_loading.py ├── install_dependencies.py ├── ENHANCED_VIDEO_TO_AUDIO.md ├── examples └── audiox_txt2music+audio.json ├── __init__.py └── README.md /audiox/stable_audio_tools/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiox/stable_audio_tools/inference/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiox/stable_audio_tools/interface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiox/stable_audio_tools/training/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * -------------------------------------------------------------------------------- /audiox/stable_audio_tools/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import create_model_from_config, create_model_from_config_path -------------------------------------------------------------------------------- /audiox/stable_audio_tools/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import create_training_wrapper_from_config, create_demo_callback_from_config 2 | -------------------------------------------------------------------------------- /audiox/stable_audio_tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .models.factory import create_model_from_config, create_model_from_config_path 2 | from .models.pretrained import get_pretrained_model -------------------------------------------------------------------------------- /audiox/stable_audio_tools/training/__pycache__/factory.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lum3on/ComfyUI-StableAudioX/HEAD/audiox/stable_audio_tools/training/__pycache__/factory.cpython-312.pyc -------------------------------------------------------------------------------- /audiox/stable_audio_tools/training/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lum3on/ComfyUI-StableAudioX/HEAD/audiox/stable_audio_tools/training/__pycache__/__init__.cpython-312.pyc 
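The `__init__.py` files above re-export the factory helpers, so the usual entry point is a single import from `stable_audio_tools` (inside this ComfyUI extension the package sits under `audiox/`, so the import prefix may differ). A minimal sketch of that path, using a hypothetical local config file:

```python
# Illustrative only: the config path is a placeholder, not a file shipped with this repo.
from stable_audio_tools import create_model_from_config_path

# Reads the JSON config and dispatches on its "model_type" field
# (see models/factory.py later in this listing).
model = create_model_from_config_path("configs/example_model_config.json")
model.eval()
```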
-------------------------------------------------------------------------------- /audiox/stable_audio_tools/training/losses/__pycache__/losses.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lum3on/ComfyUI-StableAudioX/HEAD/audiox/stable_audio_tools/training/losses/__pycache__/losses.cpython-312.pyc -------------------------------------------------------------------------------- /audiox/stable_audio_tools/training/losses/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lum3on/ComfyUI-StableAudioX/HEAD/audiox/stable_audio_tools/training/losses/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /audiox/stable_audio_tools/models/pretrained.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from .factory import create_model_from_config 4 | from .utils import load_ckpt_state_dict 5 | 6 | from huggingface_hub import hf_hub_download 7 | 8 | def get_pretrained_model(name: str): 9 | 10 | model_config_path = hf_hub_download(name, filename="config.json", repo_type='model') 11 | 12 | with open(model_config_path) as f: 13 | model_config = json.load(f) 14 | 15 | model = create_model_from_config(model_config) 16 | 17 | # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file 18 | try: 19 | model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model') 20 | except Exception as e: 21 | model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model') 22 | 23 | model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) 24 | 25 | return model, model_config -------------------------------------------------------------------------------- /audiox/stable_audio_tools/inference/utils.py: -------------------------------------------------------------------------------- 1 | from ..data.utils import PadCrop 2 | 3 | from torchaudio import transforms as T 4 | 5 | def set_audio_channels(audio, target_channels): 6 | if target_channels == 1: 7 | # Convert to mono 8 | audio = audio.mean(1, keepdim=True) 9 | elif target_channels == 2: 10 | # Convert to stereo 11 | if audio.shape[1] == 1: 12 | audio = audio.repeat(1, 2, 1) 13 | elif audio.shape[1] > 2: 14 | audio = audio[:, :2, :] 15 | return audio 16 | 17 | def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): 18 | 19 | audio = audio.to(device) 20 | 21 | if in_sr != target_sr: 22 | resample_tf = T.Resample(in_sr, target_sr).to(device) 23 | audio = resample_tf(audio) 24 | 25 | audio = PadCrop(target_length, randomize=False)(audio) 26 | 27 | # Add batch dimension 28 | if audio.dim() == 1: 29 | audio = audio.unsqueeze(0).unsqueeze(0) 30 | elif audio.dim() == 2: 31 | audio = audio.unsqueeze(0) 32 | 33 | audio = set_audio_channels(audio, target_channels) 34 | 35 | return audio -------------------------------------------------------------------------------- /web/appearance.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../scripts/app.js"; 2 | 3 | app.registerExtension({ 4 | name: "ComfyUI-AudioX.appearance", 5 | async nodeCreated(node) { 6 | // AudioX nodes styling - Apply styling 7 | if (node.comfyClass === "AudioXModelLoader" || 8 | node.comfyClass === "AudioXTextToAudio" || 9 | 
node.comfyClass === "AudioXEnhancedTextToAudio" || 10 | node.comfyClass === "AudioXTextToMusic" || 11 | node.comfyClass === "AudioXEnhancedTextToMusic" || 12 | node.comfyClass === "AudioXVideoToAudio" || 13 | node.comfyClass === "AudioXEnhancedVideoToAudio" || 14 | node.comfyClass === "AudioXVideoToMusic" || 15 | node.comfyClass === "AudioXMultiModalGeneration" || 16 | node.comfyClass === "AudioXAudioProcessor" || 17 | node.comfyClass === "AudioXVolumeControl" || 18 | node.comfyClass === "AudioXAdvancedVolumeControl" || 19 | node.comfyClass === "AudioXVideoMuter" || 20 | node.comfyClass === "AudioXVideoAudioCombiner" || 21 | node.comfyClass === "AudioXPromptHelper") { 22 | node.color = "#ddaeff"; 23 | node.bgcolor = "#a1cfa9"; 24 | } 25 | } 26 | }); -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 ComfyUI-AudioX Contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | --- 24 | 25 | Note: This license applies to the ComfyUI integration code. The underlying AudioX models 26 | and related components may have different licensing terms. Please refer to the original 27 | AudioX repository and model documentation for their specific license requirements. 
28 | -------------------------------------------------------------------------------- /audiox/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='AudioX', 5 | version='0.1.0', 6 | url='https://github.com/ZeyueT/AudioX.git', 7 | author='AudioX, HKUST', 8 | description='Training and inference tools for generative audio models from AudioX', 9 | packages=find_packages(), 10 | install_requires=[ 11 | 'aeiou', 12 | 'alias-free-torch==0.0.6', 13 | 'auraloss==0.4.0', 14 | 'descript-audio-codec==1.0.0', 15 | 'decord==0.6.0', 16 | 'einops', 17 | 'einops_exts', 18 | 'ema-pytorch==0.2.3', 19 | 'encodec==0.1.1', 20 | 'gradio==4.44.1', 21 | 'gradio_client==1.3.0', 22 | 'huggingface_hub', 23 | 'importlib-resources==5.12.0', 24 | 'k-diffusion==0.1.1', 25 | 'laion-clap==1.1.6', 26 | 'local-attention==1.8.6', 27 | 'pandas==2.0.2', 28 | 'pedalboard==0.9.14', 29 | 'prefigure==0.0.9', 30 | 'pytorch_lightning==2.4.0', 31 | 'PyWavelets==1.4.1', 32 | 'safetensors', 33 | 'sentencepiece==0.1.99', 34 | 'torch>=2.0.1', 35 | 'torchaudio>=2.0.2', 36 | 'torchmetrics==0.11.4', 37 | 'tqdm', 38 | 'transformers', 39 | 'v-diffusion-pytorch==0.0.2', 40 | 'vector-quantize-pytorch==1.9.14', 41 | 'wandb', 42 | 'webdataset==0.2.48', 43 | 'x-transformers<1.27.0', 44 | ], 45 | 46 | ) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Essential dependencies for AudioX 2 | torch>=2.0.0 3 | torchaudio>=2.0.0 4 | transformers>=4.30.0 5 | huggingface_hub>=0.16.0 6 | einops>=0.6.0 7 | einops-exts>=0.0.4 8 | numpy>=1.21.0 9 | safetensors>=0.3.0 10 | 11 | # Audio processing and utilities 12 | librosa>=0.9.0 13 | soundfile>=0.10.0 14 | descript-audio-codec>=1.0.0 15 | pyloudnorm>=0.1.0 16 | 17 | # ComfyUI compatibility 18 | Pillow>=8.0.0 19 | 20 | # EMA utilities for model training 21 | ema_pytorch>=0.3.0 22 | 23 | # Transformer and attention mechanisms 24 | x-transformers>=1.0.0 25 | local-attention>=1.8.0 26 | 27 | # Audio-specific neural network components 28 | alias-free-torch>=0.0.6 29 | vector-quantize-pytorch>=1.0.0 30 | 31 | # Diffusion and generation 32 | k-diffusion>=0.1.0 33 | v-diffusion-pytorch>=0.0.2 34 | 35 | # Audio codecs and processing 36 | encodec>=0.1.1 37 | auraloss>=0.4.0 38 | 39 | # CLAP for audio-text understanding 40 | laion-clap>=1.1.0 41 | 42 | # Utilities 43 | aeiou 44 | prefigure>=0.0.9 45 | 46 | 47 | 48 | # Optional dependencies (install manually if needed): 49 | # Video processing (for advanced video workflows) 50 | # decord>=0.6.0 51 | # torchvision>=0.15.0 52 | 53 | # Advanced audio processing (may be slow to install) 54 | # pedalboard>=0.7.0 55 | 56 | # Phoneme conditioning (for advanced text processing) 57 | # g2p_en>=2.1.0 58 | 59 | # Dataset handling (for training workflows) 60 | # webdataset>=0.2.0 61 | 62 | # Flash attention (optional - often fails on Windows, may improve performance) 63 | # flash-attn>=2.0.0 64 | 65 | # Note: All core dependencies above are required for basic AudioX functionality. 66 | # The enhanced conditioning features, volume controls, and professional audio 67 | # processing capabilities are included in the core dependencies. 
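# Example (illustrative): code that needs the optional video path can gate on the
# manually-installed extras at runtime instead of importing them unconditionally:
#     try:
#         import decord  # optional dependency listed above
#         HAS_VIDEO_SUPPORT = True
#     except ImportError:
#         HAS_VIDEO_SUPPORT = False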
68 | -------------------------------------------------------------------------------- /audiox/stable_audio_tools/training/losses/losses.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | from torch.nn import functional as F 4 | from torch import nn 5 | import torch 6 | class LossModule(nn.Module): 7 | def __init__(self, name: str, weight: float = 1.0): 8 | super().__init__() 9 | 10 | self.name = name 11 | self.weight = weight 12 | 13 | def forward(self, info, *args, **kwargs): 14 | raise NotImplementedError 15 | 16 | class ValueLoss(LossModule): 17 | def __init__(self, key: str, name, weight: float = 1.0): 18 | super().__init__(name=name, weight=weight) 19 | 20 | self.key = key 21 | 22 | def forward(self, info): 23 | return self.weight * info[self.key] 24 | 25 | class L1Loss(LossModule): 26 | def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'): 27 | super().__init__(name=name, weight=weight) 28 | 29 | self.key_a = key_a 30 | self.key_b = key_b 31 | 32 | self.mask_key = mask_key 33 | 34 | def forward(self, info): 35 | mse_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none') 36 | 37 | if self.mask_key is not None and self.mask_key in info: 38 | mse_loss = mse_loss[info[self.mask_key]] 39 | 40 | mse_loss = mse_loss.mean() 41 | 42 | return self.weight * mse_loss 43 | 44 | class MSELoss(LossModule): 45 | def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'): 46 | super().__init__(name=name, weight=weight) 47 | 48 | self.key_a = key_a 49 | self.key_b = key_b 50 | 51 | self.mask_key = mask_key 52 | 53 | def forward(self, info): 54 | mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none') 55 | 56 | if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None: 57 | mask = info[self.mask_key] 58 | 59 | if mask.ndim == 2 and mse_loss.ndim == 3: 60 | mask = mask.unsqueeze(1) 61 | 62 | if mask.shape[1] != mse_loss.shape[1]: 63 | mask = mask.repeat(1, mse_loss.shape[1], 1) 64 | 65 | mse_loss = mse_loss[mask] 66 | 67 | mse_loss = mse_loss.mean() 68 | 69 | return self.weight * mse_loss 70 | 71 | class AuralossLoss(LossModule): 72 | def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1): 73 | super().__init__(name, weight) 74 | 75 | self.auraloss_module = auraloss_module 76 | 77 | self.input_key = input_key 78 | self.target_key = target_key 79 | 80 | def forward(self, info): 81 | loss = self.auraloss_module(info[self.input_key], info[self.target_key]) 82 | 83 | return self.weight * loss 84 | 85 | class MultiLoss(nn.Module): 86 | def __init__(self, losses: tp.List[LossModule]): 87 | super().__init__() 88 | 89 | self.losses = nn.ModuleList(losses) 90 | 91 | def forward(self, info): 92 | total_loss = 0 93 | 94 | losses = {} 95 | 96 | for loss_module in self.losses: 97 | module_loss = loss_module(info) 98 | total_loss += module_loss 99 | losses[loss_module.name] = module_loss 100 | 101 | return total_loss, losses -------------------------------------------------------------------------------- /audiox/stable_audio_tools/models/wavelets.py: -------------------------------------------------------------------------------- 1 | """The 1D discrete wavelet transform for PyTorch.""" 2 | 3 | from einops import rearrange 4 | import pywt 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from typing 
import Literal 9 | 10 | 11 | def get_filter_bank(wavelet): 12 | filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank) 13 | if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0): 14 | filt = filt[:, 1:] 15 | return filt 16 | 17 | class WaveletEncode1d(nn.Module): 18 | def __init__(self, 19 | channels, 20 | levels, 21 | wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): 22 | super().__init__() 23 | self.wavelet = wavelet 24 | self.channels = channels 25 | self.levels = levels 26 | filt = get_filter_bank(wavelet) 27 | assert filt.shape[-1] % 2 == 1 28 | kernel = filt[:2, None] 29 | kernel = torch.flip(kernel, dims=(-1,)) 30 | index_i = torch.repeat_interleave(torch.arange(2), channels) 31 | index_j = torch.tile(torch.arange(channels), (2,)) 32 | kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) 33 | kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] 34 | self.register_buffer("kernel", kernel_final) 35 | 36 | def forward(self, x): 37 | for i in range(self.levels): 38 | low, rest = x[:, : self.channels], x[:, self.channels :] 39 | pad = self.kernel.shape[-1] // 2 40 | low = F.pad(low, (pad, pad), "reflect") 41 | low = F.conv1d(low, self.kernel, stride=2) 42 | rest = rearrange( 43 | rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels 44 | ) 45 | x = torch.cat([low, rest], dim=1) 46 | return x 47 | 48 | 49 | class WaveletDecode1d(nn.Module): 50 | def __init__(self, 51 | channels, 52 | levels, 53 | wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): 54 | super().__init__() 55 | self.wavelet = wavelet 56 | self.channels = channels 57 | self.levels = levels 58 | filt = get_filter_bank(wavelet) 59 | assert filt.shape[-1] % 2 == 1 60 | kernel = filt[2:, None] 61 | index_i = torch.repeat_interleave(torch.arange(2), channels) 62 | index_j = torch.tile(torch.arange(channels), (2,)) 63 | kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) 64 | kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] 65 | self.register_buffer("kernel", kernel_final) 66 | 67 | def forward(self, x): 68 | for i in range(self.levels): 69 | low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :] 70 | pad = self.kernel.shape[-1] // 2 + 2 71 | low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2) 72 | low = F.pad(low, (pad, pad), "reflect") 73 | low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2) 74 | low = F.conv_transpose1d( 75 | low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2 76 | ) 77 | low = low[..., pad - 1 : -pad] 78 | rest = rearrange( 79 | rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels 80 | ) 81 | x = torch.cat([low, rest], dim=1) 82 | return x -------------------------------------------------------------------------------- /audiox/stable_audio_tools/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import load_file 3 | 4 | from torch.nn.utils import remove_weight_norm 5 | import warnings 6 | warnings.simplefilter(action='ignore', category=FutureWarning) 7 | 8 | 9 | def load_ckpt_state_dict(ckpt_path): 10 | if ckpt_path.endswith(".safetensors"): 11 | state_dict = load_file(ckpt_path) 12 | else: 13 | state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] 14 | 15 | return state_dict 16 | 17 | def remove_weight_norm_from_model(model): 18 | for module in model.modules(): 19 | if hasattr(module, 
"weight"): 20 | print(f"Removing weight norm from {module}") 21 | remove_weight_norm(module) 22 | 23 | return model 24 | 25 | # Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license 26 | # License can be found in LICENSES/LICENSE_META.txt 27 | 28 | def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): 29 | """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. 30 | 31 | Args: 32 | input (torch.Tensor): The input tensor containing probabilities. 33 | num_samples (int): Number of samples to draw. 34 | replacement (bool): Whether to draw with replacement or not. 35 | Keywords args: 36 | generator (torch.Generator): A pseudorandom number generator for sampling. 37 | Returns: 38 | torch.Tensor: Last dimension contains num_samples indices 39 | sampled from the multinomial probability distribution 40 | located in the last dimension of tensor input. 41 | """ 42 | 43 | if num_samples == 1: 44 | q = torch.empty_like(input).exponential_(1, generator=generator) 45 | return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64) 46 | 47 | input_ = input.reshape(-1, input.shape[-1]) 48 | output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) 49 | output = output_.reshape(*list(input.shape[:-1]), -1) 50 | return output 51 | 52 | 53 | def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: 54 | """Sample next token from top K values along the last dimension of the input probs tensor. 55 | 56 | Args: 57 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension. 58 | k (int): The k in “top-k”. 59 | Returns: 60 | torch.Tensor: Sampled tokens. 61 | """ 62 | top_k_value, _ = torch.topk(probs, k, dim=-1) 63 | min_value_top_k = top_k_value[..., [-1]] 64 | probs *= (probs >= min_value_top_k).float() 65 | probs.div_(probs.sum(dim=-1, keepdim=True)) 66 | next_token = multinomial(probs, num_samples=1) 67 | return next_token 68 | 69 | 70 | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: 71 | """Sample next token from top P probabilities along the last dimension of the input probs tensor. 72 | 73 | Args: 74 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension. 75 | p (int): The p in “top-p”. 76 | Returns: 77 | torch.Tensor: Sampled tokens. 
78 | """ 79 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 80 | probs_sum = torch.cumsum(probs_sort, dim=-1) 81 | mask = probs_sum - probs_sort > p 82 | probs_sort *= (~mask).float() 83 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 84 | next_token = multinomial(probs_sort, num_samples=1) 85 | next_token = torch.gather(probs_idx, -1, next_token) 86 | return next_token 87 | 88 | def next_power_of_two(n): 89 | return 2 ** (n - 1).bit_length() 90 | 91 | def next_multiple_of_64(n): 92 | return ((n + 63) // 64) * 64 -------------------------------------------------------------------------------- /audiox/stable_audio_tools/training/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | def get_rank(): 5 | """Get rank of current process.""" 6 | 7 | print(os.environ.keys()) 8 | 9 | if "SLURM_PROCID" in os.environ: 10 | return int(os.environ["SLURM_PROCID"]) 11 | 12 | if not torch.distributed.is_available() or not torch.distributed.is_initialized(): 13 | return 0 14 | 15 | return torch.distributed.get_rank() 16 | 17 | class InverseLR(torch.optim.lr_scheduler._LRScheduler): 18 | """Implements an inverse decay learning rate schedule with an optional exponential 19 | warmup. When last_epoch=-1, sets initial lr as lr. 20 | inv_gamma is the number of steps/epochs required for the learning rate to decay to 21 | (1 / 2)**power of its original value. 22 | Args: 23 | optimizer (Optimizer): Wrapped optimizer. 24 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. 25 | power (float): Exponential factor of learning rate decay. Default: 1. 26 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 27 | Default: 0. 28 | final_lr (float): The final learning rate. Default: 0. 29 | last_epoch (int): The index of last epoch. Default: -1. 30 | verbose (bool): If ``True``, prints a message to stdout for 31 | each update. Default: ``False``. 32 | """ 33 | 34 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., 35 | last_epoch=-1, verbose=False): 36 | self.inv_gamma = inv_gamma 37 | self.power = power 38 | if not 0. <= warmup < 1: 39 | raise ValueError('Invalid value for warmup') 40 | self.warmup = warmup 41 | self.final_lr = final_lr 42 | super().__init__(optimizer, last_epoch, verbose) 43 | 44 | def get_lr(self): 45 | if not self._get_lr_called_within_step: 46 | import warnings 47 | warnings.warn("To get the last learning rate computed by the scheduler, " 48 | "please use `get_last_lr()`.") 49 | 50 | return self._get_closed_form_lr() 51 | 52 | def _get_closed_form_lr(self): 53 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 54 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power 55 | return [warmup * max(self.final_lr, base_lr * lr_mult) 56 | for base_lr in self.base_lrs] 57 | 58 | def copy_state_dict(model, state_dict): 59 | """Load state_dict to model, but only for keys that match exactly. 60 | 61 | Args: 62 | model (nn.Module): model to load state_dict. 63 | state_dict (OrderedDict): state_dict to load. 
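    Example (illustrative; the checkpoint path is a placeholder):
        ckpt = torch.load("finetuned.ckpt", map_location="cpu")["state_dict"]
        copy_state_dict(model, ckpt)  # only keys with matching names and shapes are copied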
64 | """ 65 | model_state_dict = model.state_dict() 66 | for key in state_dict: 67 | if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape: 68 | if isinstance(state_dict[key], torch.nn.Parameter): 69 | # backwards compatibility for serialized parameters 70 | state_dict[key] = state_dict[key].data 71 | model_state_dict[key] = state_dict[key] 72 | 73 | model.load_state_dict(model_state_dict, strict=False) 74 | 75 | def create_optimizer_from_config(optimizer_config, parameters): 76 | """Create optimizer from config. 77 | 78 | Args: 79 | parameters (iterable): parameters to optimize. 80 | optimizer_config (dict): optimizer config. 81 | 82 | Returns: 83 | torch.optim.Optimizer: optimizer. 84 | """ 85 | 86 | optimizer_type = optimizer_config["type"] 87 | 88 | if optimizer_type == "FusedAdam": 89 | from deepspeed.ops.adam import FusedAdam 90 | optimizer = FusedAdam(parameters, **optimizer_config["config"]) 91 | else: 92 | optimizer_fn = getattr(torch.optim, optimizer_type) 93 | optimizer = optimizer_fn(parameters, **optimizer_config["config"]) 94 | return optimizer 95 | 96 | def create_scheduler_from_config(scheduler_config, optimizer): 97 | """Create scheduler from config. 98 | 99 | Args: 100 | scheduler_config (dict): scheduler config. 101 | optimizer (torch.optim.Optimizer): optimizer. 102 | 103 | Returns: 104 | torch.optim.lr_scheduler._LRScheduler: scheduler. 105 | """ 106 | if scheduler_config["type"] == "InverseLR": 107 | scheduler_fn = InverseLR 108 | else: 109 | scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"]) 110 | scheduler = scheduler_fn(optimizer, **scheduler_config["config"]) 111 | return scheduler -------------------------------------------------------------------------------- /debug_comfyui_loading.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Debug script to check ComfyUI node loading issues. 4 | This simulates how ComfyUI loads the nodes. 5 | """ 6 | 7 | import sys 8 | import os 9 | 10 | # Add the current directory to the path 11 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) 12 | 13 | def simulate_comfyui_loading(): 14 | """Simulate how ComfyUI loads the AudioX nodes.""" 15 | print("Simulating ComfyUI node loading process...") 16 | 17 | try: 18 | print("1. 
Testing __init__.py import...") 19 | 20 | # This simulates how ComfyUI imports the extension 21 | import __init__ as audiox_init 22 | 23 | print(" ✅ __init__.py imported successfully") 24 | 25 | # Check if NODE_CLASS_MAPPINGS exists 26 | if hasattr(audiox_init, 'NODE_CLASS_MAPPINGS'): 27 | mappings = audiox_init.NODE_CLASS_MAPPINGS 28 | print(f" ✅ NODE_CLASS_MAPPINGS found with {len(mappings)} nodes") 29 | 30 | # Check for our volume control nodes 31 | basic_node = mappings.get("AudioXVolumeControl") 32 | advanced_node = mappings.get("AudioXAdvancedVolumeControl") 33 | 34 | if basic_node: 35 | print(" ✅ AudioXVolumeControl found in mappings") 36 | try: 37 | # Test instantiation 38 | instance = basic_node() 39 | print(" ✅ AudioXVolumeControl can be instantiated") 40 | 41 | # Test INPUT_TYPES 42 | input_types = instance.INPUT_TYPES() 43 | print(f" ✅ AudioXVolumeControl INPUT_TYPES: {list(input_types.keys())}") 44 | except Exception as e: 45 | print(f" ❌ AudioXVolumeControl instantiation failed: {e}") 46 | else: 47 | print(" ❌ AudioXVolumeControl NOT found in mappings") 48 | 49 | if advanced_node: 50 | print(" ✅ AudioXAdvancedVolumeControl found in mappings") 51 | try: 52 | # Test instantiation 53 | instance = advanced_node() 54 | print(" ✅ AudioXAdvancedVolumeControl can be instantiated") 55 | 56 | # Test INPUT_TYPES 57 | input_types = instance.INPUT_TYPES() 58 | required = input_types.get("required", {}) 59 | optional = input_types.get("optional", {}) 60 | print(f" ✅ AudioXAdvancedVolumeControl required: {list(required.keys())}") 61 | print(f" ✅ AudioXAdvancedVolumeControl optional: {list(optional.keys())}") 62 | except Exception as e: 63 | print(f" ❌ AudioXAdvancedVolumeControl instantiation failed: {e}") 64 | import traceback 65 | traceback.print_exc() 66 | else: 67 | print(" ❌ AudioXAdvancedVolumeControl NOT found in mappings") 68 | 69 | # List all available nodes 70 | print(f"\n 📋 All available nodes:") 71 | for node_name in sorted(mappings.keys()): 72 | print(f" - {node_name}") 73 | 74 | else: 75 | print(" ❌ NODE_CLASS_MAPPINGS not found in __init__.py") 76 | 77 | except Exception as e: 78 | print(f" ❌ Failed to import __init__.py: {e}") 79 | import traceback 80 | traceback.print_exc() 81 | return False 82 | 83 | print("\n✅ ComfyUI loading simulation completed!") 84 | return True 85 | 86 | def check_node_class_directly(): 87 | """Try to import the node class directly from the module.""" 88 | print("\n" + "="*60) 89 | print("Testing direct node class import...") 90 | 91 | try: 92 | # Mock the relative imports to make them work 93 | import audiox_utils 94 | sys.modules['ComfyUI-AudioX.audiox_utils'] = audiox_utils 95 | 96 | # Now try to import the nodes module 97 | import nodes as audiox_nodes 98 | 99 | print(" ✅ nodes.py imported successfully") 100 | 101 | # Check if the classes exist 102 | if hasattr(audiox_nodes, 'AudioXVolumeControl'): 103 | print(" ✅ AudioXVolumeControl class found") 104 | else: 105 | print(" ❌ AudioXVolumeControl class NOT found") 106 | 107 | if hasattr(audiox_nodes, 'AudioXAdvancedVolumeControl'): 108 | print(" ✅ AudioXAdvancedVolumeControl class found") 109 | else: 110 | print(" ❌ AudioXAdvancedVolumeControl class NOT found") 111 | 112 | return True 113 | 114 | except Exception as e: 115 | print(f" ❌ Direct import failed: {e}") 116 | import traceback 117 | traceback.print_exc() 118 | return False 119 | 120 | if __name__ == "__main__": 121 | print("🔍 AudioX Node Loading Debugger") 122 | print("="*60) 123 | 124 | success1 = simulate_comfyui_loading() 125 | success2 
= check_node_class_directly() 126 | 127 | if success1 and success2: 128 | print("\n🎉 All tests passed! The nodes should be loading correctly in ComfyUI.") 129 | else: 130 | print("\n❌ Some tests failed. Check the output above for details.") 131 | print("\nPossible solutions:") 132 | print("1. Restart ComfyUI completely") 133 | print("2. Check ComfyUI console for error messages") 134 | print("3. Verify all dependencies are installed") 135 | 136 | sys.exit(0 if (success1 and success2) else 1) 137 | -------------------------------------------------------------------------------- /install_dependencies.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Installation script for AudioX dependencies. 4 | Run this script to ensure all required dependencies are installed. 5 | """ 6 | 7 | import sys 8 | import subprocess 9 | import importlib.util 10 | 11 | def check_package(package_name, import_name=None): 12 | """Check if a package is installed and can be imported.""" 13 | if import_name is None: 14 | import_name = package_name.replace('-', '_') 15 | 16 | try: 17 | spec = importlib.util.find_spec(import_name) 18 | if spec is not None: 19 | return True 20 | except ImportError: 21 | pass 22 | return False 23 | 24 | def install_package(package_name): 25 | """Install a package using pip.""" 26 | try: 27 | print(f"Installing {package_name}...") 28 | result = subprocess.run([ 29 | sys.executable, "-m", "pip", "install", package_name 30 | ], capture_output=True, text=True, timeout=300) 31 | 32 | if result.returncode == 0: 33 | print(f"✓ {package_name} installed successfully") 34 | return True 35 | else: 36 | print(f"✗ Failed to install {package_name}: {result.stderr}") 37 | return False 38 | except Exception as e: 39 | print(f"✗ Error installing {package_name}: {e}") 40 | return False 41 | 42 | def main(): 43 | """Main installation function.""" 44 | print("AudioX Dependency Installer") 45 | print("=" * 40) 46 | print(f"Python executable: {sys.executable}") 47 | print(f"Python version: {sys.version}") 48 | print() 49 | 50 | # List of required packages 51 | required_packages = [ 52 | ("torch", "torch"), 53 | ("torchaudio", "torchaudio"), 54 | ("transformers", "transformers"), 55 | ("huggingface_hub", "huggingface_hub"), 56 | ("einops", "einops"), 57 | ("einops-exts", "einops_exts"), 58 | ("numpy", "numpy"), 59 | ("safetensors", "safetensors"), 60 | ("librosa", "librosa"), 61 | ("soundfile", "soundfile"), 62 | ("descript-audio-codec", "dac"), 63 | ("Pillow", "PIL"), 64 | ("ema_pytorch", "ema_pytorch"), 65 | ("x-transformers", "x_transformers"), 66 | ("alias-free-torch", "alias_free_torch"), 67 | ("vector-quantize-pytorch", "vector_quantize_pytorch"), 68 | ("local-attention", "local_attention"), 69 | ("k-diffusion", "k_diffusion"), 70 | ("aeiou", "aeiou"), 71 | ("auraloss", "auraloss"), 72 | ("encodec", "encodec"), 73 | ("laion-clap", "laion_clap"), 74 | ("prefigure", "prefigure"), 75 | ("v-diffusion-pytorch", "diffusion"), 76 | ] 77 | 78 | missing_packages = [] 79 | 80 | # Check which packages are missing 81 | print("Checking installed packages...") 82 | for package_name, import_name in required_packages: 83 | if check_package(package_name, import_name): 84 | print(f"✓ {package_name}") 85 | else: 86 | print(f"✗ {package_name} (missing)") 87 | missing_packages.append(package_name) 88 | 89 | print() 90 | 91 | if not missing_packages: 92 | print("✓ All required packages are already installed!") 93 | return True 94 | 95 | # Install missing 
packages 96 | print(f"Installing {len(missing_packages)} missing packages...") 97 | print() 98 | 99 | success_count = 0 100 | for package_name in missing_packages: 101 | if install_package(package_name): 102 | success_count += 1 103 | 104 | print() 105 | print("=" * 40) 106 | print(f"Installation complete: {success_count}/{len(missing_packages)} packages installed successfully") 107 | 108 | if success_count == len(missing_packages): 109 | print("✓ All dependencies installed successfully!") 110 | 111 | # Test critical imports 112 | print("\nTesting critical imports...") 113 | try: 114 | from x_transformers import ContinuousTransformerWrapper, Encoder 115 | print("✓ x_transformers import successful") 116 | except ImportError as e: 117 | print(f"✗ x_transformers import failed: {e}") 118 | return False 119 | 120 | try: 121 | import einops_exts 122 | print("✓ einops_exts import successful") 123 | except ImportError as e: 124 | print(f"✗ einops_exts import failed: {e}") 125 | return False 126 | 127 | try: 128 | import dac 129 | print("✓ dac import successful") 130 | except ImportError as e: 131 | print(f"✗ dac import failed: {e}") 132 | return False 133 | 134 | try: 135 | import alias_free_torch 136 | print("✓ alias_free_torch import successful") 137 | except ImportError as e: 138 | print(f"✗ alias_free_torch import failed: {e}") 139 | return False 140 | 141 | try: 142 | import vector_quantize_pytorch 143 | print("✓ vector_quantize_pytorch import successful") 144 | except ImportError as e: 145 | print(f"✗ vector_quantize_pytorch import failed: {e}") 146 | return False 147 | 148 | print("\n✓ All critical imports working!") 149 | return True 150 | else: 151 | print("✗ Some packages failed to install. Please check the error messages above.") 152 | return False 153 | 154 | if __name__ == "__main__": 155 | success = main() 156 | sys.exit(0 if success else 1) 157 | -------------------------------------------------------------------------------- /audiox/stable_audio_tools/models/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def create_model_from_config(model_config): 4 | model_type = model_config.get('model_type', None) 5 | 6 | assert model_type is not None, 'model_type must be specified in model config' 7 | 8 | if model_type == 'autoencoder': 9 | from .autoencoders import create_autoencoder_from_config 10 | return create_autoencoder_from_config(model_config) 11 | elif model_type == 'diffusion_uncond': 12 | from .diffusion import create_diffusion_uncond_from_config 13 | return create_diffusion_uncond_from_config(model_config) 14 | elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior": 15 | from .diffusion import create_diffusion_cond_from_config 16 | return create_diffusion_cond_from_config(model_config) 17 | elif model_type == 'diffusion_autoencoder': 18 | from .autoencoders import create_diffAE_from_config 19 | return create_diffAE_from_config(model_config) 20 | elif model_type == 'lm': 21 | from .lm import create_audio_lm_from_config 22 | return create_audio_lm_from_config(model_config) 23 | else: 24 | raise NotImplementedError(f'Unknown model type: {model_type}') 25 | 26 | def create_model_from_config_path(model_config_path): 27 | with open(model_config_path) as f: 28 | model_config = json.load(f) 29 | 30 | return create_model_from_config(model_config) 31 | 32 | def create_pretransform_from_config(pretransform_config, sample_rate): 33 | pretransform_type = 
pretransform_config.get('type', None) 34 | 35 | assert pretransform_type is not None, 'type must be specified in pretransform config' 36 | 37 | if pretransform_type == 'autoencoder': 38 | from .autoencoders import create_autoencoder_from_config 39 | from .pretransforms import AutoencoderPretransform 40 | 41 | # Create fake top-level config to pass sample rate to autoencoder constructor 42 | # This is a bit of a hack but it keeps us from re-defining the sample rate in the config 43 | autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} 44 | autoencoder = create_autoencoder_from_config(autoencoder_config) 45 | 46 | scale = pretransform_config.get("scale", 1.0) 47 | model_half = pretransform_config.get("model_half", False) 48 | iterate_batch = pretransform_config.get("iterate_batch", False) 49 | chunked = pretransform_config.get("chunked", False) 50 | 51 | pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) 52 | elif pretransform_type == 'wavelet': 53 | from .pretransforms import WaveletPretransform 54 | 55 | wavelet_config = pretransform_config["config"] 56 | channels = wavelet_config["channels"] 57 | levels = wavelet_config["levels"] 58 | wavelet = wavelet_config["wavelet"] 59 | 60 | pretransform = WaveletPretransform(channels, levels, wavelet) 61 | elif pretransform_type == 'pqmf': 62 | from .pretransforms import PQMFPretransform 63 | pqmf_config = pretransform_config["config"] 64 | pretransform = PQMFPretransform(**pqmf_config) 65 | elif pretransform_type == 'dac_pretrained': 66 | from .pretransforms import PretrainedDACPretransform 67 | pretrained_dac_config = pretransform_config["config"] 68 | pretransform = PretrainedDACPretransform(**pretrained_dac_config) 69 | elif pretransform_type == "audiocraft_pretrained": 70 | from .pretransforms import AudiocraftCompressionPretransform 71 | 72 | audiocraft_config = pretransform_config["config"] 73 | pretransform = AudiocraftCompressionPretransform(**audiocraft_config) 74 | else: 75 | raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') 76 | 77 | enable_grad = pretransform_config.get('enable_grad', False) 78 | pretransform.enable_grad = enable_grad 79 | 80 | pretransform.eval().requires_grad_(pretransform.enable_grad) 81 | 82 | return pretransform 83 | 84 | def create_bottleneck_from_config(bottleneck_config): 85 | bottleneck_type = bottleneck_config.get('type', None) 86 | 87 | assert bottleneck_type is not None, 'type must be specified in bottleneck config' 88 | 89 | if bottleneck_type == 'tanh': 90 | from .bottleneck import TanhBottleneck 91 | bottleneck = TanhBottleneck() 92 | elif bottleneck_type == 'vae': 93 | from .bottleneck import VAEBottleneck 94 | bottleneck = VAEBottleneck() 95 | elif bottleneck_type == 'rvq': 96 | from .bottleneck import RVQBottleneck 97 | 98 | quantizer_params = { 99 | "dim": 128, 100 | "codebook_size": 1024, 101 | "num_quantizers": 8, 102 | "decay": 0.99, 103 | "kmeans_init": True, 104 | "kmeans_iters": 50, 105 | "threshold_ema_dead_code": 2, 106 | } 107 | 108 | quantizer_params.update(bottleneck_config["config"]) 109 | 110 | bottleneck = RVQBottleneck(**quantizer_params) 111 | elif bottleneck_type == "dac_rvq": 112 | from .bottleneck import DACRVQBottleneck 113 | 114 | bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) 115 | 116 | elif bottleneck_type == 'rvq_vae': 117 | from .bottleneck import RVQVAEBottleneck 118 | 119 | quantizer_params = { 120 | "dim": 128, 
121 | "codebook_size": 1024, 122 | "num_quantizers": 8, 123 | "decay": 0.99, 124 | "kmeans_init": True, 125 | "kmeans_iters": 50, 126 | "threshold_ema_dead_code": 2, 127 | } 128 | 129 | quantizer_params.update(bottleneck_config["config"]) 130 | 131 | bottleneck = RVQVAEBottleneck(**quantizer_params) 132 | 133 | elif bottleneck_type == 'dac_rvq_vae': 134 | from .bottleneck import DACRVQVAEBottleneck 135 | bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) 136 | elif bottleneck_type == 'l2_norm': 137 | from .bottleneck import L2Bottleneck 138 | bottleneck = L2Bottleneck() 139 | elif bottleneck_type == "wasserstein": 140 | from .bottleneck import WassersteinBottleneck 141 | bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) 142 | elif bottleneck_type == "fsq": 143 | from .bottleneck import FSQBottleneck 144 | bottleneck = FSQBottleneck(**bottleneck_config["config"]) 145 | else: 146 | raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') 147 | 148 | requires_grad = bottleneck_config.get('requires_grad', True) 149 | if not requires_grad: 150 | for param in bottleneck.parameters(): 151 | param.requires_grad = False 152 | 153 | return bottleneck 154 | -------------------------------------------------------------------------------- /audiox/stable_audio_tools/models/temptransformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | class Residual(nn.Module): 9 | def __init__(self, fn): 10 | super().__init__() 11 | self.fn = fn 12 | def forward(self, x, **kwargs): 13 | return self.fn(x, **kwargs) + x 14 | 15 | class SA_PreNorm(nn.Module): 16 | def __init__(self, dim, fn): 17 | super().__init__() 18 | self.norm = nn.LayerNorm(dim) 19 | self.fn = fn 20 | def forward(self, x, **kwargs): 21 | return self.fn(self.norm(x), **kwargs) 22 | 23 | class SA_FeedForward(nn.Module): 24 | def __init__(self, dim, hidden_dim, dropout = 0.): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.Linear(dim, hidden_dim), 28 | nn.GELU(), 29 | nn.Dropout(dropout), 30 | nn.Linear(hidden_dim, dim), 31 | nn.Dropout(dropout) 32 | ) 33 | def forward(self, x): 34 | return self.net(x) 35 | 36 | class SA_Attention(nn.Module): 37 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 38 | super().__init__() 39 | inner_dim = dim_head * heads 40 | project_out = not (heads == 1 and dim_head == dim) 41 | 42 | self.heads = heads 43 | self.scale = dim_head ** -0.5 44 | 45 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 46 | 47 | self.to_out = nn.Sequential( 48 | nn.Linear(inner_dim, dim), 49 | nn.Dropout(dropout) 50 | ) if project_out else nn.Identity() 51 | 52 | def forward(self, x): 53 | b, n, _, h = *x.shape, self.heads 54 | qkv = self.to_qkv(x).chunk(3, dim = -1) 55 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 56 | 57 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 58 | 59 | attn = dots.softmax(dim=-1) 60 | 61 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 62 | out = rearrange(out, 'b h n d -> b n (h d)') 63 | out = self.to_out(out) 64 | return out 65 | 66 | 67 | class ReAttention(nn.Module): 68 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 69 | super().__init__() 70 | inner_dim = dim_head * heads 71 | self.heads = heads 72 | self.scale = dim_head ** 
-0.5 73 | 74 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 75 | 76 | self.reattn_weights = nn.Parameter(torch.randn(heads, heads)) 77 | 78 | self.reattn_norm = nn.Sequential( 79 | Rearrange('b h i j -> b i j h'), 80 | nn.LayerNorm(heads), 81 | Rearrange('b i j h -> b h i j') 82 | ) 83 | 84 | self.to_out = nn.Sequential( 85 | nn.Linear(inner_dim, dim), 86 | nn.Dropout(dropout) 87 | ) 88 | 89 | def forward(self, x): 90 | b, n, _, h = *x.shape, self.heads 91 | qkv = self.to_qkv(x).chunk(3, dim = -1) 92 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 93 | 94 | # attention 95 | 96 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 97 | attn = dots.softmax(dim=-1) 98 | 99 | # re-attention 100 | 101 | attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights) 102 | attn = self.reattn_norm(attn) 103 | 104 | # aggregate and out 105 | 106 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 107 | out = rearrange(out, 'b h n d -> b n (h d)') 108 | out = self.to_out(out) 109 | return out 110 | 111 | class LeFF(nn.Module): 112 | 113 | def __init__(self, dim = 192, scale = 4, depth_kernel = 3): 114 | super().__init__() 115 | 116 | scale_dim = dim*scale 117 | self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim), 118 | Rearrange('b n c -> b c n'), 119 | nn.BatchNorm1d(scale_dim), 120 | nn.GELU(), 121 | Rearrange('b c (h w) -> b c h w', h=14, w=14) 122 | ) 123 | 124 | self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False), 125 | nn.BatchNorm2d(scale_dim), 126 | nn.GELU(), 127 | Rearrange('b c h w -> b (h w) c', h=14, w=14) 128 | ) 129 | 130 | self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim), 131 | Rearrange('b n c -> b c n'), 132 | nn.BatchNorm1d(dim), 133 | nn.GELU(), 134 | Rearrange('b c n -> b n c') 135 | ) 136 | 137 | def forward(self, x): 138 | x = self.up_proj(x) 139 | x = self.depth_conv(x) 140 | x = self.down_proj(x) 141 | return x 142 | 143 | 144 | class LCAttention(nn.Module): 145 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 146 | super().__init__() 147 | inner_dim = dim_head * heads 148 | project_out = not (heads == 1 and dim_head == dim) 149 | 150 | self.heads = heads 151 | self.scale = dim_head ** -0.5 152 | 153 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 154 | 155 | self.to_out = nn.Sequential( 156 | nn.Linear(inner_dim, dim), 157 | nn.Dropout(dropout) 158 | ) if project_out else nn.Identity() 159 | 160 | def forward(self, x): 161 | b, n, _, h = *x.shape, self.heads 162 | qkv = self.to_qkv(x).chunk(3, dim = -1) 163 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 164 | q = q[:, :, -1, :].unsqueeze(2) # Only Lth element use as query 165 | 166 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 167 | 168 | attn = dots.softmax(dim=-1) 169 | 170 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 171 | out = rearrange(out, 'b h n d -> b n (h d)') 172 | out = self.to_out(out) 173 | return out 174 | 175 | class SA_Transformer(nn.Module): 176 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 177 | super().__init__() 178 | self.layers = nn.ModuleList([]) 179 | self.norm = nn.LayerNorm(dim) 180 | for _ in range(depth): 181 | self.layers.append(nn.ModuleList([ 182 | SA_PreNorm(dim, SA_Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 183 | SA_PreNorm(dim, SA_FeedForward(dim, mlp_dim, dropout = dropout)) 184 | ])) 185 | 186 | 
def forward(self, x): 187 | for attn, ff in self.layers: 188 | x = attn(x) + x 189 | x = ff(x) + x 190 | return self.norm(x) -------------------------------------------------------------------------------- /web/audiox.js: -------------------------------------------------------------------------------- 1 | // AudioX ComfyUI Web Components 2 | // This file contains web-side functionality for AudioX nodes 3 | 4 | import { app } from "../../scripts/app.js"; 5 | 6 | // Extension registration with post-execution patching 7 | app.registerExtension({ 8 | name: "AudioX.WebComponents", 9 | 10 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 11 | // Add beforeQueued method if it doesn't exist 12 | if (!nodeType.prototype.beforeQueued) { 13 | nodeType.prototype.beforeQueued = function() { 14 | console.log("AudioX: beforeQueued called on", nodeData.name); 15 | return true; 16 | }; 17 | } 18 | }, 19 | 20 | async nodeCreated(node) { 21 | // Ensure node has beforeQueued method 22 | if (!node.beforeQueued) { 23 | node.beforeQueued = function() { 24 | console.log("AudioX: beforeQueued called on node", node.type); 25 | return true; 26 | }; 27 | } 28 | }, 29 | 30 | // CRITICAL: Re-patch after workflow execution 31 | async beforeQueued(details) { 32 | console.log("AudioX: Extension beforeQueued called"); 33 | 34 | // Re-patch all nodes after execution to prevent second-run errors 35 | setTimeout(() => { 36 | this.repatchAllNodes(); 37 | }, 100); 38 | 39 | return true; 40 | }, 41 | 42 | // Re-patch nodes after execution 43 | repatchAllNodes() { 44 | try { 45 | console.log("AudioX: Re-patching all nodes after execution..."); 46 | 47 | if (window.app?.graph?._nodes) { 48 | window.app.graph._nodes.forEach((node, index) => { 49 | if (node && (!node.beforeQueued || typeof node.beforeQueued !== 'function')) { 50 | node.beforeQueued = function() { 51 | console.log(`AudioX: Re-patched beforeQueued called on node ${index} (${node.type})`); 52 | return true; 53 | }; 54 | console.log(`AudioX: Re-patched node ${index}: ${node.type}`); 55 | } 56 | }); 57 | } 58 | 59 | // Also re-patch prototypes in case they got reset 60 | if (window.LiteGraph?.LGraphNode?.prototype && !window.LiteGraph.LGraphNode.prototype.beforeQueued) { 61 | window.LiteGraph.LGraphNode.prototype.beforeQueued = function() { 62 | console.log("AudioX: Re-patched LiteGraph beforeQueued called"); 63 | return true; 64 | }; 65 | console.log("AudioX: Re-patched LiteGraph prototype"); 66 | } 67 | 68 | } catch (error) { 69 | console.warn("AudioX: Error re-patching nodes:", error); 70 | } 71 | } 72 | }); 73 | 74 | // Monitor for workflow execution completion and re-patch 75 | function setupPostExecutionMonitoring() { 76 | try { 77 | // Monitor for execution completion via app events 78 | if (window.app) { 79 | // Hook into the app's execution system 80 | const originalQueuePrompt = window.app.queuePrompt; 81 | if (originalQueuePrompt) { 82 | window.app.queuePrompt = async function(...args) { 83 | console.log("AudioX: Workflow execution starting..."); 84 | 85 | try { 86 | const result = await originalQueuePrompt.apply(this, args); 87 | 88 | // Re-patch after execution completes 89 | setTimeout(() => { 90 | console.log("AudioX: Post-execution re-patching..."); 91 | repatchAfterExecution(); 92 | }, 500); 93 | 94 | return result; 95 | } catch (error) { 96 | console.warn("AudioX: Error in queuePrompt:", error); 97 | // Still re-patch even if execution failed 98 | setTimeout(() => { 99 | repatchAfterExecution(); 100 | }, 500); 101 | throw error; 102 | } 103 | }; 
104 | console.log("AudioX: Hooked into queuePrompt for post-execution patching"); 105 | } 106 | } 107 | 108 | // Also monitor for WebSocket messages that indicate execution completion 109 | if (window.WebSocket) { 110 | const originalWebSocket = window.WebSocket; 111 | window.WebSocket = function(...args) { 112 | const ws = new originalWebSocket(...args); 113 | 114 | const originalOnMessage = ws.onmessage; 115 | ws.onmessage = function(event) { 116 | try { 117 | const data = JSON.parse(event.data); 118 | 119 | // Check for execution completion messages 120 | if (data.type === 'executed' || data.type === 'execution_cached') { 121 | console.log("AudioX: Detected execution completion via WebSocket"); 122 | setTimeout(() => { 123 | repatchAfterExecution(); 124 | }, 200); 125 | } 126 | } catch (e) { 127 | // Ignore JSON parse errors 128 | } 129 | 130 | if (originalOnMessage) { 131 | return originalOnMessage.apply(this, arguments); 132 | } 133 | }; 134 | 135 | return ws; 136 | }; 137 | console.log("AudioX: Set up WebSocket monitoring for execution completion"); 138 | } 139 | 140 | } catch (error) { 141 | console.warn("AudioX: Error setting up post-execution monitoring:", error); 142 | } 143 | } 144 | 145 | // Function to re-patch everything after execution 146 | function repatchAfterExecution() { 147 | try { 148 | console.log("AudioX: Re-patching after execution..."); 149 | 150 | // Re-patch all nodes 151 | if (window.app?.graph?._nodes) { 152 | window.app.graph._nodes.forEach((node, index) => { 153 | if (node && (!node.beforeQueued || typeof node.beforeQueued !== 'function')) { 154 | node.beforeQueued = function() { 155 | console.log(`AudioX: Post-exec beforeQueued on node ${index} (${node.type})`); 156 | return true; 157 | }; 158 | } 159 | }); 160 | } 161 | 162 | // Re-patch prototypes 163 | if (window.LiteGraph?.LGraphNode?.prototype && !window.LiteGraph.LGraphNode.prototype.beforeQueued) { 164 | window.LiteGraph.LGraphNode.prototype.beforeQueued = function() { 165 | console.log("AudioX: Post-exec LiteGraph beforeQueued"); 166 | return true; 167 | }; 168 | } 169 | 170 | // Re-patch app object 171 | if (window.app && !window.app.beforeQueued) { 172 | window.app.beforeQueued = function() { 173 | console.log("AudioX: Post-exec app beforeQueued"); 174 | return true; 175 | }; 176 | } 177 | 178 | console.log("AudioX: Post-execution re-patching completed"); 179 | 180 | } catch (error) { 181 | console.warn("AudioX: Error in post-execution re-patching:", error); 182 | } 183 | } 184 | 185 | // Set up monitoring after a short delay 186 | setTimeout(setupPostExecutionMonitoring, 1000); 187 | 188 | console.log("AudioX: Extension with post-execution monitoring loaded"); 189 | -------------------------------------------------------------------------------- /audiox/stable_audio_tools/data/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | 5 | from torch import nn 6 | from typing import Tuple 7 | import os 8 | import subprocess as sp 9 | from PIL import Image 10 | from torchvision import transforms 11 | from decord import VideoReader, cpu 12 | 13 | class PadCrop(nn.Module): 14 | def __init__(self, n_samples, randomize=True): 15 | super().__init__() 16 | self.n_samples = n_samples 17 | self.randomize = randomize 18 | 19 | def __call__(self, signal): 20 | n, s = signal.shape 21 | start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() 22 | end = start + self.n_samples 23 | 
output = signal.new_zeros([n, self.n_samples]) 24 | output[:, :min(s, self.n_samples)] = signal[:, start:end] 25 | return output 26 | 27 | 28 | class PadCrop_Normalized_T(nn.Module): 29 | 30 | def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): 31 | super().__init__() 32 | self.n_samples = n_samples 33 | self.sample_rate = sample_rate 34 | self.randomize = randomize 35 | 36 | def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]: 37 | n_channels, n_samples = source.shape 38 | 39 | # Calculate the duration of the audio in seconds 40 | total_duration = n_samples // self.sample_rate 41 | 42 | # If the audio is shorter than the desired length, pad it 43 | upper_bound = max(0, n_samples - self.n_samples) 44 | 45 | # If randomize is False, always start at the beginning of the audio 46 | offset = 0 47 | 48 | if self.randomize and n_samples > self.n_samples: 49 | valid_offsets = [ 50 | i * self.sample_rate for i in range(0, total_duration, 10) 51 | if i * self.sample_rate + self.n_samples <= n_samples and 52 | (total_duration <= 20 or total_duration - i >= 15) 53 | ] 54 | if valid_offsets: 55 | offset = random.choice(valid_offsets) 56 | 57 | # Calculate the start and end times of the chunk 58 | t_start = offset / (upper_bound + self.n_samples) 59 | t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) 60 | 61 | # Create the chunk 62 | chunk = source.new_zeros([n_channels, self.n_samples]) 63 | 64 | # Copy the audio into the chunk 65 | chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] 66 | 67 | # Calculate the start and end times of the chunk in seconds 68 | seconds_start = math.floor(offset / self.sample_rate) 69 | seconds_total = math.ceil(n_samples / self.sample_rate) 70 | 71 | # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't 72 | padding_mask = torch.zeros([self.n_samples]) 73 | padding_mask[:min(n_samples, self.n_samples)] = 1 74 | 75 | return ( 76 | chunk, 77 | t_start, 78 | t_end, 79 | seconds_start, 80 | seconds_total, 81 | padding_mask 82 | ) 83 | 84 | 85 | class PhaseFlipper(nn.Module): 86 | "Randomly invert the phase of a signal" 87 | def __init__(self, p=0.5): 88 | super().__init__() 89 | self.p = p 90 | def __call__(self, signal): 91 | return -signal if (random.random() < self.p) else signal 92 | 93 | class Mono(nn.Module): 94 | def __call__(self, signal): 95 | return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal 96 | 97 | class Stereo(nn.Module): 98 | def __call__(self, signal): 99 | signal_shape = signal.shape 100 | # Check if it's mono 101 | if len(signal_shape) == 1: # s -> 2, s 102 | signal = signal.unsqueeze(0).repeat(2, 1) 103 | elif len(signal_shape) == 2: 104 | if signal_shape[0] == 1: #1, s -> 2, s 105 | signal = signal.repeat(2, 1) 106 | elif signal_shape[0] > 2: #?, s -> 2,s 107 | signal = signal[:2, :] 108 | 109 | return signal 110 | 111 | 112 | def adjust_video_duration(video_tensor, duration, target_fps): 113 | current_duration = video_tensor.shape[0] 114 | target_duration = duration * target_fps 115 | if current_duration > target_duration: 116 | video_tensor = video_tensor[:target_duration] 117 | elif current_duration < target_duration: 118 | last_frame = video_tensor[-1:] 119 | repeat_times = target_duration - current_duration 120 | video_tensor = torch.cat((video_tensor, last_frame.repeat(repeat_times, 1, 1, 1)), dim=0) 121 | return video_tensor 122 | 123 | def 
read_video(filepath, seek_time=0., duration=-1, target_fps=2): 124 | if filepath is None: 125 | return torch.zeros((int(duration * target_fps), 3, 224, 224)) 126 | 127 | ext = os.path.splitext(filepath)[1].lower() 128 | if ext in ['.jpg', '.jpeg', '.png']: 129 | resize_transform = transforms.Resize((224, 224)) 130 | image = Image.open(filepath).convert("RGB") 131 | frame = transforms.ToTensor()(image).unsqueeze(0) 132 | frame = resize_transform(frame) 133 | target_frames = int(duration * target_fps) 134 | frame = frame.repeat(int(math.ceil(target_frames / frame.shape[0])), 1, 1, 1)[:target_frames] 135 | assert frame.shape[0] == target_frames, f"The shape of frame is {frame.shape}" 136 | return frame 137 | 138 | vr = VideoReader(filepath, ctx=cpu(0)) 139 | fps = vr.get_avg_fps() 140 | total_frames = len(vr) 141 | 142 | seek_frame = int(seek_time * fps) 143 | if duration > 0: 144 | total_frames_to_read = int(target_fps * duration) 145 | frame_interval = int(math.ceil(fps / target_fps)) 146 | end_frame = min(seek_frame + total_frames_to_read * frame_interval, total_frames) 147 | frame_ids = list(range(seek_frame, end_frame, frame_interval)) 148 | else: 149 | frame_interval = int(math.ceil(fps / target_fps)) 150 | frame_ids = list(range(0, total_frames, frame_interval)) 151 | 152 | frames = vr.get_batch(frame_ids).asnumpy() 153 | frames = torch.from_numpy(frames).permute(0, 3, 1, 2) 154 | 155 | if frames.shape[2] != 224 or frames.shape[3] != 224: 156 | resize_transform = transforms.Resize((224, 224)) 157 | frames = resize_transform(frames) 158 | 159 | video_tensor = adjust_video_duration(frames, duration, target_fps) 160 | assert video_tensor.shape[0] == duration * target_fps, f"The shape of video_tensor is {video_tensor.shape}" 161 | return video_tensor 162 | 163 | def merge_video_audio(video_path, audio_path, output_path, start_time, duration): 164 | command = [ 165 | 'ffmpeg', 166 | '-y', 167 | '-ss', str(start_time), 168 | '-t', str(duration), 169 | '-i', video_path, 170 | '-i', audio_path, 171 | '-c:v', 'copy', 172 | '-c:a', 'aac', 173 | '-map', '0:v:0', 174 | '-map', '1:a:0', 175 | '-shortest', 176 | '-strict', 'experimental', 177 | output_path 178 | ] 179 | 180 | try: 181 | sp.run(command, check=True) 182 | print(f"Successfully merged audio and video into {output_path}") 183 | return output_path 184 | except sp.CalledProcessError as e: 185 | print(f"Error merging audio and video: {e}") 186 | return None 187 | 188 | def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total): 189 | if audio_path is None: 190 | return torch.zeros((2, int(sample_rate * seconds_total))) 191 | audio_tensor, sr = torchaudio.load(audio_path) 192 | start_index = int(sample_rate * seconds_start) 193 | target_length = int(sample_rate * seconds_total) 194 | end_index = start_index + target_length 195 | audio_tensor = audio_tensor[:, start_index:end_index] 196 | if audio_tensor.shape[1] < target_length: 197 | pad_length = target_length - audio_tensor.shape[1] 198 | audio_tensor = F.pad(audio_tensor, (pad_length, 0)) 199 | return audio_tensor -------------------------------------------------------------------------------- /audiox/README.md: -------------------------------------------------------------------------------- 1 | # 🎧 AudioX: Diffusion Transformer for Anything-to-Audio Generation 2 | 3 | [](https://arxiv.org/abs/2503.10522) 4 | [](https://zeyuet.github.io/AudioX/) 5 | [](https://huggingface.co/HKUSTAudio/AudioX) 6 | [](https://huggingface.co/spaces/Zeyue7/AudioX) 7 | 8 | --- 9 | 10 | 
**This is the official repository for "[AudioX: Diffusion Transformer for Anything-to-Audio Generation](https://arxiv.org/pdf/2503.10522)".** 11 | 12 | 13 | ## 📺 Demo Video 14 | 15 | https://github.com/user-attachments/assets/0d8dd927-ff0f-4b35-ab1f-b3c3915017be 16 | 17 | --- 18 | 19 | 20 | ## ✨ Abstract 21 | 22 | Audio and music generation have emerged as crucial tasks in many applications, yet existing approaches face significant limitations: they operate in isolation without unified capabilities across modalities, suffer from scarce high-quality, multi-modal training data, and struggle to effectively integrate diverse inputs. In this work, we propose AudioX, a unified Diffusion Transformer model for Anything-to-Audio and Music Generation. Unlike previous domain-specific models, AudioX can generate both general audio and music with high quality, while offering flexible natural language control and seamless processing of various modalities including text, video, image, music, and audio. Its key innovation is a multi-modal masked training strategy that masks inputs across modalities and forces the model to learn from masked inputs, yielding robust and unified cross-modal representations. To address data scarcity, we curate two comprehensive datasets: vggsound-caps with 190K audio captions based on the VGGSound dataset, and V2M-caps with 6 million music captions derived from the V2M dataset. Extensive experiments demonstrate that AudioX not only matches or outperforms state-of-the-art specialized models, but also offers remarkable versatility in handling diverse input modalities and generation tasks within a unified architecture. 23 | 24 | 25 | ## ✨ Teaser 26 | 27 |
28 | *Teaser figure: (a) Overview of AudioX, illustrating its capabilities across various tasks. (b) Radar chart comparing the performance of different methods across multiple benchmarks. AudioX demonstrates superior Inception Scores (IS) across a diverse set of datasets in audio and music generation tasks.*
31 | 32 | 33 | ## ✨ Method 34 | 35 |
36 | *Method figure: Overview of the AudioX Framework.*
39 | 40 | 41 | 42 | ## Code 43 | 44 | ### 🆕 **Enhanced ComfyUI-AudioX Features** 45 | 46 | **Latest Improvements (2025-06-16):** 47 | - ✅ **Fixed torchdata dependency issues** - Now works without optional dependencies 48 | - ✅ **Enhanced audio generation** - Improved conditioning and quality controls 49 | - ✅ **Better error handling** - Graceful degradation and clear error messages 50 | 51 | ### 🛠️ Environment Setup 52 | 53 | ```bash 54 | git clone https://github.com/ZeyueT/AudioX.git 55 | cd AudioX 56 | conda create -n AudioX python=3.8.20 57 | conda activate AudioX 58 | pip install git+https://github.com/ZeyueT/AudioX.git 59 | conda install -c conda-forge ffmpeg libsndfile 60 | 61 | # Optional: Install torchdata for enhanced performance (not required) 62 | pip install torchdata 63 | ``` 64 | 65 | ### 🔧 **ComfyUI-AudioX Training Features** 66 | 67 | #### **Enhanced Audio Generation** 68 | - **🎯 Advanced Conditioning** - Improved text-to-audio and video-to-audio generation 69 | - **📊 Quality Controls**: Enhanced CFG scales, conditioning weights, negative prompting 70 | - **⏱️ Flexible Duration** - Generate audio of various lengths 71 | - **🔧 Smart Configuration** - Optimized settings for different audio types 72 | 73 | 74 | 75 | ## 🪄 Pretrained Checkpoints 76 | 77 | Download the pretrained model from 🤗 [AudioX on Hugging Face](https://huggingface.co/HKUSTAudio/AudioX): 78 | 79 | ```bash 80 | mkdir -p model 81 | wget https://huggingface.co/HKUSTAudio/AudioX/resolve/main/model.ckpt -O model/model.ckpt 82 | wget https://huggingface.co/HKUSTAudio/AudioX/resolve/main/config.json -O model/config.json 83 | ``` 84 | 85 | ### 🤗 Gradio Demo 86 | 87 | To launch the Gradio demo locally, run: 88 | 89 | ```bash 90 | python3 run_gradio.py \ 91 | --model-config model/config.json \ 92 | --share 93 | ``` 94 | 95 | 96 | ### 🎯 Prompt Configuration Examples 97 | 98 | | Task | `video_path` | `text_prompt` | `audio_path` | 99 | |:---------------------|:-------------------|:----------------------------------------------|:-------------| 100 | | Text-to-Audio (T2A) | `None` | `"Typing on a keyboard"` | `None` | 101 | | Text-to-Music (T2M) | `None` | `"A music with piano and violin"` | `None` | 102 | | Video-to-Audio (V2A) | `"video_path.mp4"` | `"Generate general audio for the video"` | `None` | 103 | | Video-to-Music (V2M) | `"video_path.mp4"` | `"Generate music for the video"` | `None` | 104 | | TV-to-Audio (TV2A) | `"video_path.mp4"` | `"Ocean waves crashing with people laughing"` | `None` | 105 | | TV-to-Music (TV2M) | `"video_path.mp4"` | `"Generate music with piano instrument"` | `None` | 106 | 107 | ### 🖥️ Script Inference 108 | 109 | ```python 110 | import torch 111 | import torchaudio 112 | from einops import rearrange 113 | from stable_audio_tools import get_pretrained_model 114 | from stable_audio_tools.inference.generation import generate_diffusion_cond 115 | from stable_audio_tools.data.utils import read_video, merge_video_audio 116 | from stable_audio_tools.data.utils import load_and_process_audio 117 | import os 118 | 119 | device = "cuda" if torch.cuda.is_available() else "cpu" 120 | 121 | # Download model 122 | model, model_config = get_pretrained_model("HKUSTAudio/AudioX") 123 | sample_rate = model_config["sample_rate"] 124 | sample_size = model_config["sample_size"] 125 | target_fps = model_config["video_fps"] 126 | seconds_start = 0 127 | seconds_total = 10 128 | 129 | model = model.to(device) 130 | 131 | # for video-to-music generation 132 | video_path = "example/V2M_sample-1.mp4" 133 | 
text_prompt = "Generate music for the video" 134 | audio_path = None 135 | 136 | video_tensor = read_video(video_path, seek_time=0, duration=seconds_total, target_fps=target_fps) 137 | audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total) 138 | 139 | conditioning = [{ 140 | "video_prompt": [video_tensor.unsqueeze(0)], 141 | "text_prompt": text_prompt, 142 | "audio_prompt": audio_tensor.unsqueeze(0), 143 | "seconds_start": seconds_start, 144 | "seconds_total": seconds_total 145 | }] 146 | 147 | # Generate stereo audio 148 | output = generate_diffusion_cond( 149 | model, 150 | steps=250, 151 | cfg_scale=7, 152 | conditioning=conditioning, 153 | sample_size=sample_size, 154 | sigma_min=0.3, 155 | sigma_max=500, 156 | sampler_type="dpmpp-3m-sde", 157 | device=device 158 | ) 159 | 160 | # Rearrange audio batch to a single sequence 161 | output = rearrange(output, "b d n -> d (b n)") 162 | 163 | # Peak normalize, clip, convert to int16, and save to file 164 | output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() 165 | torchaudio.save("output.wav", output, sample_rate) 166 | 167 | if video_path is not None and os.path.exists(video_path): 168 | merge_video_audio(video_path, "output.wav", "output.mp4", 0, seconds_total) 169 | 170 | ``` 171 | 172 | 173 | ## 🚀 Citation 174 | 175 | If you find our work useful, please consider citing: 176 | 177 | ``` 178 | @article{tian2025audiox, 179 | title={AudioX: Diffusion Transformer for Anything-to-Audio Generation}, 180 | author={Tian, Zeyue and Jin, Yizhu and Liu, Zhaoyang and Yuan, Ruibin and Tan, Xu and Chen, Qifeng and Xue, Wei and Guo, Yike}, 181 | journal={arXiv preprint arXiv:2503.10522}, 182 | year={2025} 183 | } 184 | ``` 185 | 186 | ## 📭 Contact 187 | 188 | If you have any comments or questions, feel free to contact Zeyue Tian(ztianad@connect.ust.hk). 189 | 190 | ## License 191 | 192 | Please follow [CC-BY-NC](./LICENSE). 
193 | -------------------------------------------------------------------------------- /audiox/stable_audio_tools/models/local_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | from .blocks import AdaRMSNorm 7 | from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm 8 | 9 | def checkpoint(function, *args, **kwargs): 10 | kwargs.setdefault("use_reentrant", False) 11 | return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) 12 | 13 | # Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py 14 | class ContinuousLocalTransformer(nn.Module): 15 | def __init__( 16 | self, 17 | *, 18 | dim, 19 | depth, 20 | dim_in = None, 21 | dim_out = None, 22 | causal = False, 23 | local_attn_window_size = 64, 24 | heads = 8, 25 | ff_mult = 2, 26 | cond_dim = 0, 27 | cross_attn_cond_dim = 0, 28 | **kwargs 29 | ): 30 | super().__init__() 31 | 32 | dim_head = dim//heads 33 | 34 | self.layers = nn.ModuleList([]) 35 | 36 | self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity() 37 | 38 | self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity() 39 | 40 | self.local_attn_window_size = local_attn_window_size 41 | 42 | self.cond_dim = cond_dim 43 | 44 | self.cross_attn_cond_dim = cross_attn_cond_dim 45 | 46 | self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32)) 47 | 48 | for _ in range(depth): 49 | 50 | self.layers.append(nn.ModuleList([ 51 | AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), 52 | Attention( 53 | dim=dim, 54 | dim_heads=dim_head, 55 | causal=causal, 56 | zero_init_output=True, 57 | natten_kernel_size=local_attn_window_size, 58 | ), 59 | Attention( 60 | dim=dim, 61 | dim_heads=dim_head, 62 | dim_context = cross_attn_cond_dim, 63 | zero_init_output=True 64 | ) if self.cross_attn_cond_dim > 0 else nn.Identity(), 65 | AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), 66 | FeedForward(dim = dim, mult = ff_mult, no_bias=True) 67 | ])) 68 | 69 | def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None): 70 | 71 | x = checkpoint(self.project_in, x) 72 | 73 | if prepend_cond is not None: 74 | x = torch.cat([prepend_cond, x], dim=1) 75 | 76 | pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) 77 | 78 | for attn_norm, attn, xattn, ff_norm, ff in self.layers: 79 | 80 | residual = x 81 | if cond is not None: 82 | x = checkpoint(attn_norm, x, cond) 83 | else: 84 | x = checkpoint(attn_norm, x) 85 | 86 | x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual 87 | 88 | if cross_attn_cond is not None: 89 | x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x 90 | 91 | residual = x 92 | 93 | if cond is not None: 94 | x = checkpoint(ff_norm, x, cond) 95 | else: 96 | x = checkpoint(ff_norm, x) 97 | 98 | x = checkpoint(ff, x) + residual 99 | 100 | return checkpoint(self.project_out, x) 101 | 102 | class TransformerDownsampleBlock1D(nn.Module): 103 | def __init__( 104 | self, 105 | in_channels, 106 | embed_dim = 768, 107 | depth = 3, 108 | heads = 12, 109 | downsample_ratio = 2, 110 | local_attn_window_size = 64, 111 | **kwargs 112 | ): 113 | super().__init__() 114 | 115 | self.downsample_ratio = downsample_ratio 116 | 117 | self.transformer = ContinuousLocalTransformer( 118 | 
dim=embed_dim, 119 | depth=depth, 120 | heads=heads, 121 | local_attn_window_size=local_attn_window_size, 122 | **kwargs 123 | ) 124 | 125 | self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() 126 | 127 | self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False) 128 | 129 | 130 | def forward(self, x): 131 | 132 | x = checkpoint(self.project_in, x) 133 | 134 | # Compute 135 | x = self.transformer(x) 136 | 137 | # Trade sequence length for channels 138 | x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio) 139 | 140 | # Project back to embed dim 141 | x = checkpoint(self.project_down, x) 142 | 143 | return x 144 | 145 | class TransformerUpsampleBlock1D(nn.Module): 146 | def __init__( 147 | self, 148 | in_channels, 149 | embed_dim, 150 | depth = 3, 151 | heads = 12, 152 | upsample_ratio = 2, 153 | local_attn_window_size = 64, 154 | **kwargs 155 | ): 156 | super().__init__() 157 | 158 | self.upsample_ratio = upsample_ratio 159 | 160 | self.transformer = ContinuousLocalTransformer( 161 | dim=embed_dim, 162 | depth=depth, 163 | heads=heads, 164 | local_attn_window_size = local_attn_window_size, 165 | **kwargs 166 | ) 167 | 168 | self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() 169 | 170 | self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False) 171 | 172 | def forward(self, x): 173 | 174 | # Project to embed dim 175 | x = checkpoint(self.project_in, x) 176 | 177 | # Project to increase channel dim 178 | x = checkpoint(self.project_up, x) 179 | 180 | # Trade channels for sequence length 181 | x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio) 182 | 183 | # Compute 184 | x = self.transformer(x) 185 | 186 | return x 187 | 188 | 189 | class TransformerEncoder1D(nn.Module): 190 | def __init__( 191 | self, 192 | in_channels, 193 | out_channels, 194 | embed_dims = [96, 192, 384, 768], 195 | heads = [12, 12, 12, 12], 196 | depths = [3, 3, 3, 3], 197 | ratios = [2, 2, 2, 2], 198 | local_attn_window_size = 64, 199 | **kwargs 200 | ): 201 | super().__init__() 202 | 203 | layers = [] 204 | 205 | for layer in range(len(depths)): 206 | prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] 207 | 208 | layers.append( 209 | TransformerDownsampleBlock1D( 210 | in_channels = prev_dim, 211 | embed_dim = embed_dims[layer], 212 | heads = heads[layer], 213 | depth = depths[layer], 214 | downsample_ratio = ratios[layer], 215 | local_attn_window_size = local_attn_window_size, 216 | **kwargs 217 | ) 218 | ) 219 | 220 | self.layers = nn.Sequential(*layers) 221 | 222 | self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) 223 | self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) 224 | 225 | def forward(self, x): 226 | x = rearrange(x, "b c n -> b n c") 227 | x = checkpoint(self.project_in, x) 228 | x = self.layers(x) 229 | x = checkpoint(self.project_out, x) 230 | x = rearrange(x, "b n c -> b c n") 231 | 232 | return x 233 | 234 | 235 | class TransformerDecoder1D(nn.Module): 236 | def __init__( 237 | self, 238 | in_channels, 239 | out_channels, 240 | embed_dims = [768, 384, 192, 96], 241 | heads = [12, 12, 12, 12], 242 | depths = [3, 3, 3, 3], 243 | ratios = [2, 2, 2, 2], 244 | local_attn_window_size = 64, 245 | **kwargs 246 | ): 247 | 248 | super().__init__() 249 | 250 | layers = [] 251 | 252 | for layer in range(len(depths)): 253 | prev_dim = embed_dims[layer - 1] if layer 
> 0 else embed_dims[0] 254 | 255 | layers.append( 256 | TransformerUpsampleBlock1D( 257 | in_channels = prev_dim, 258 | embed_dim = embed_dims[layer], 259 | heads = heads[layer], 260 | depth = depths[layer], 261 | upsample_ratio = ratios[layer], 262 | local_attn_window_size = local_attn_window_size, 263 | **kwargs 264 | ) 265 | ) 266 | 267 | self.layers = nn.Sequential(*layers) 268 | 269 | self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) 270 | self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) 271 | 272 | def forward(self, x): 273 | x = rearrange(x, "b c n -> b n c") 274 | x = checkpoint(self.project_in, x) 275 | x = self.layers(x) 276 | x = checkpoint(self.project_out, x) 277 | x = rearrange(x, "b n c -> b c n") 278 | return x -------------------------------------------------------------------------------- /audiox/stable_audio_tools/models/pretransforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn 4 | 5 | class Pretransform(nn.Module): 6 | def __init__(self, enable_grad, io_channels, is_discrete): 7 | super().__init__() 8 | 9 | self.is_discrete = is_discrete 10 | self.io_channels = io_channels 11 | self.encoded_channels = None 12 | self.downsampling_ratio = None 13 | 14 | self.enable_grad = enable_grad 15 | 16 | def encode(self, x): 17 | raise NotImplementedError 18 | 19 | def decode(self, z): 20 | raise NotImplementedError 21 | 22 | def tokenize(self, x): 23 | raise NotImplementedError 24 | 25 | def decode_tokens(self, tokens): 26 | raise NotImplementedError 27 | 28 | class AutoencoderPretransform(Pretransform): 29 | def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): 30 | super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) 31 | self.model = model 32 | self.model.requires_grad_(False).eval() 33 | self.scale=scale 34 | self.downsampling_ratio = model.downsampling_ratio 35 | self.io_channels = model.io_channels 36 | self.sample_rate = model.sample_rate 37 | 38 | self.model_half = model_half 39 | self.iterate_batch = iterate_batch 40 | 41 | self.encoded_channels = model.latent_dim 42 | 43 | self.chunked = chunked 44 | self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None 45 | self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None 46 | 47 | if self.model_half: 48 | self.model.half() 49 | 50 | def encode(self, x, **kwargs): 51 | 52 | if self.model_half: 53 | x = x.half() 54 | self.model.to(torch.float16) 55 | 56 | encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) 57 | 58 | if self.model_half: 59 | encoded = encoded.float() 60 | 61 | return encoded / self.scale 62 | 63 | def decode(self, z, **kwargs): 64 | z = z * self.scale 65 | 66 | if self.model_half: 67 | z = z.half() 68 | self.model.to(torch.float16) 69 | 70 | decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) 71 | 72 | if self.model_half: 73 | decoded = decoded.float() 74 | 75 | return decoded 76 | 77 | def tokenize(self, x, **kwargs): 78 | assert self.model.is_discrete, "Cannot tokenize with a continuous model" 79 | 80 | _, info = self.model.encode(x, return_info = True, **kwargs) 81 | 82 | return 
info[self.model.bottleneck.tokens_id] 83 | 84 | def decode_tokens(self, tokens, **kwargs): 85 | assert self.model.is_discrete, "Cannot decode tokens with a continuous model" 86 | 87 | return self.model.decode_tokens(tokens, **kwargs) 88 | 89 | def load_state_dict(self, state_dict, strict=True): 90 | self.model.load_state_dict(state_dict, strict=strict) 91 | 92 | class WaveletPretransform(Pretransform): 93 | def __init__(self, channels, levels, wavelet): 94 | super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) 95 | 96 | from .wavelets import WaveletEncode1d, WaveletDecode1d 97 | 98 | self.encoder = WaveletEncode1d(channels, levels, wavelet) 99 | self.decoder = WaveletDecode1d(channels, levels, wavelet) 100 | 101 | self.downsampling_ratio = 2 ** levels 102 | self.io_channels = channels 103 | self.encoded_channels = channels * self.downsampling_ratio 104 | 105 | def encode(self, x): 106 | return self.encoder(x) 107 | 108 | def decode(self, z): 109 | return self.decoder(z) 110 | 111 | class PQMFPretransform(Pretransform): 112 | def __init__(self, attenuation=100, num_bands=16): 113 | # TODO: Fix PQMF to take in in-channels 114 | super().__init__(enable_grad=False, io_channels=1, is_discrete=False) 115 | from .pqmf import PQMF 116 | self.pqmf = PQMF(attenuation, num_bands) 117 | 118 | 119 | def encode(self, x): 120 | # x is (Batch x Channels x Time) 121 | x = self.pqmf.forward(x) 122 | # pqmf.forward returns (Batch x Channels x Bands x Time) 123 | # but Pretransform needs Batch x Channels x Time 124 | # so concatenate channels and bands into one axis 125 | return rearrange(x, "b c n t -> b (c n) t") 126 | 127 | def decode(self, x): 128 | # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) 129 | x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) 130 | # returns (Batch x Channels x Time) 131 | return self.pqmf.inverse(x) 132 | 133 | class PretrainedDACPretransform(Pretransform): 134 | def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): 135 | super().__init__(enable_grad=False, io_channels=1, is_discrete=True) 136 | 137 | import dac 138 | 139 | model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) 140 | 141 | self.model = dac.DAC.load(model_path) 142 | 143 | self.quantize_on_decode = quantize_on_decode 144 | 145 | if model_type == "44khz": 146 | self.downsampling_ratio = 512 147 | else: 148 | self.downsampling_ratio = 320 149 | 150 | self.io_channels = 1 151 | 152 | self.scale = scale 153 | 154 | self.chunked = chunked 155 | 156 | self.encoded_channels = self.model.latent_dim 157 | 158 | self.num_quantizers = self.model.n_codebooks 159 | 160 | self.codebook_size = self.model.codebook_size 161 | 162 | def encode(self, x): 163 | 164 | latents = self.model.encoder(x) 165 | 166 | if self.quantize_on_decode: 167 | output = latents 168 | else: 169 | z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) 170 | output = z 171 | 172 | if self.scale != 1.0: 173 | output = output / self.scale 174 | 175 | return output 176 | 177 | def decode(self, z): 178 | 179 | if self.scale != 1.0: 180 | z = z * self.scale 181 | 182 | if self.quantize_on_decode: 183 | z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) 184 | 185 | return self.model.decode(z) 186 | 187 | def tokenize(self, x): 188 | return self.model.encode(x)[1] 189 | 190 | def decode_tokens(self, tokens): 191 | latents = 
self.model.quantizer.from_codes(tokens) 192 | return self.model.decode(latents) 193 | 194 | class AudiocraftCompressionPretransform(Pretransform): 195 | def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): 196 | super().__init__(enable_grad=False, io_channels=1, is_discrete=True) 197 | 198 | try: 199 | from audiocraft.models import CompressionModel 200 | except ImportError: 201 | raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.") 202 | 203 | self.model = CompressionModel.get_pretrained(model_type) 204 | 205 | self.quantize_on_decode = quantize_on_decode 206 | 207 | self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) 208 | 209 | self.sample_rate = self.model.sample_rate 210 | 211 | self.io_channels = self.model.channels 212 | 213 | self.scale = scale 214 | 215 | #self.encoded_channels = self.model.latent_dim 216 | 217 | self.num_quantizers = self.model.num_codebooks 218 | 219 | self.codebook_size = self.model.cardinality 220 | 221 | self.model.to(torch.float16).eval().requires_grad_(False) 222 | 223 | def encode(self, x): 224 | 225 | assert False, "Audiocraft compression models do not support continuous encoding" 226 | 227 | # latents = self.model.encoder(x) 228 | 229 | # if self.quantize_on_decode: 230 | # output = latents 231 | # else: 232 | # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) 233 | # output = z 234 | 235 | # if self.scale != 1.0: 236 | # output = output / self.scale 237 | 238 | # return output 239 | 240 | def decode(self, z): 241 | 242 | assert False, "Audiocraft compression models do not support continuous decoding" 243 | 244 | # if self.scale != 1.0: 245 | # z = z * self.scale 246 | 247 | # if self.quantize_on_decode: 248 | # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) 249 | 250 | # return self.model.decode(z) 251 | 252 | def tokenize(self, x): 253 | with torch.cuda.amp.autocast(enabled=False): 254 | return self.model.encode(x.to(torch.float16))[0] 255 | 256 | def decode_tokens(self, tokens): 257 | with torch.cuda.amp.autocast(enabled=False): 258 | return self.model.decode(tokens) 259 | -------------------------------------------------------------------------------- /ENHANCED_VIDEO_TO_AUDIO.md: -------------------------------------------------------------------------------- 1 | # Enhanced AudioX Generation System 2 | 3 | This document describes the comprehensive improvements made to the ComfyUI-AudioX generation system to enhance prompt adherence and provide better control over text-to-audio, text-to-music, and video-to-audio generation. 4 | 5 | ## Issues Identified 6 | 7 | The original video-to-audio implementation had several limitations: 8 | 9 | 1. **Weak Text-Video Balance**: Text and video conditioning were processed separately without proper balancing 10 | 2. **Fixed CFG Scale**: Single CFG scale didn't allow fine-tuning text vs video influence 11 | 3. **No Negative Prompting**: Missing negative conditioning capabilities 12 | 4. **Basic Text Processing**: No audio-domain specific text preprocessing 13 | 5. **Limited Conditioning Control**: No user control over conditioning weights 14 | 15 | ## New Features 16 | 17 | ### 1. 
AudioXEnhancedTextToAudio Node 18 | 19 | An advanced text-to-audio generation node with enhanced prompt processing: 20 | 21 | #### Advanced Prompting Features 22 | - **Negative Prompting**: Specify audio characteristics to avoid (e.g., "muffled, distorted, low quality") 23 | - **Prompt Templates**: Pre-defined templates for common audio scenarios 24 | - **Style Modifiers**: cinematic, realistic, ambient, dramatic, peaceful, energetic 25 | - **Auto Enhancement**: Automatically adds audio-specific keywords 26 | 27 | #### Enhanced Controls 28 | - **Higher CFG Scale Range**: Up to 20.0 for stronger prompt adherence 29 | - **Template Integration**: Seamlessly combines base prompts with templates 30 | - **Quality Enhancement**: Automatic addition of quality terms 31 | 32 | #### Advanced Conditioning Features 33 | - **Conditioning Modes**: 34 | - `standard`: Basic conditioning 35 | - `enhanced`: Audio-specific keyword enhancement 36 | - `super_enhanced`: All enhancement techniques combined 37 | - `multi_aspect`: Multiple conditioning vectors for robust generation 38 | - **Adaptive CFG**: Automatically adjusts CFG scale based on prompt specificity 39 | - **Keyword Expansion**: Adds related audio terms and synonyms 40 | - **Term Emphasis**: Strategic repetition and emphasis of key terms 41 | - **Context-Aware Processing**: Understands prompt intent and enhances accordingly 42 | 43 | ### 2. AudioXEnhancedTextToMusic Node 44 | 45 | A specialized music generation node with musical attributes: 46 | 47 | #### Musical Style Controls 48 | - **Music Styles**: classical, jazz, electronic, ambient, rock, folk, cinematic 49 | - **Tempo Control**: slow, moderate, fast, very_fast 50 | - **Mood Settings**: happy, sad, peaceful, energetic, mysterious, dramatic 51 | - **Negative Prompting**: Avoid discordant, harsh, or atonal characteristics 52 | 53 | #### Smart Music Enhancement 54 | - **Automatic Music Context**: Ensures prompts are interpreted as musical 55 | - **Style Integration**: Combines multiple musical attributes intelligently 56 | - **Enhanced Descriptions**: Adds appropriate musical terminology 57 | 58 | ### 3. AudioXEnhancedVideoToAudio Node 59 | 60 | An advanced video-to-audio generation node with the following improvements: 61 | 62 | #### Separate CFG Controls 63 | - **Text CFG Scale**: Independent control over text conditioning strength (default: 7.0) 64 | - **Video CFG Scale**: Independent control over video conditioning strength (default: 7.0) 65 | - **Effective CFG**: Automatically calculated weighted average based on conditioning weights 66 | 67 | #### Conditioning Weight Controls 68 | - **Text Weight**: Control the influence of text conditioning (0.0-2.0, default: 1.0) 69 | - **Video Weight**: Control the influence of video conditioning (0.0-2.0, default: 1.0) 70 | - **Balanced Generation**: Allows fine-tuning the balance between following text prompts vs video content 71 | 72 | #### Advanced Prompting 73 | - **Negative Prompting**: Specify what audio characteristics to avoid 74 | - **Prompt Templates**: Pre-defined templates for common audio scenarios 75 | - **Auto Enhancement**: Automatically add audio-specific keywords to improve generation 76 | 77 | ### 4. 
AudioXPromptHelper Node 78 | 79 | A utility node for creating better audio prompts: 80 | 81 | #### Template Categories 82 | - **Music**: ambient, upbeat, dramatic, peaceful 83 | - **Nature**: forest, ocean, rain, wind 84 | - **Urban**: traffic, crowd, construction, cafe 85 | - **Action**: footsteps, running, impact, mechanical 86 | 87 | #### Enhancement Features 88 | - **Auto Enhancement**: Adds audio-specific context to prompts 89 | - **Quality Terms**: Adds quality enhancement terms like "high quality, clear" 90 | - **Style Modifiers**: cinematic, realistic, ambient, dramatic, peaceful, energetic 91 | - **Negative Prompts**: Suggests appropriate negative terms 92 | 93 | ### 5. Enhanced Conditioning Pipeline 94 | 95 | #### Audio-Specific Text Processing 96 | ```python 97 | def enhance_audio_prompt(text_prompt: str) -> str: 98 | """Enhance text prompt for better audio generation""" 99 | # Adds audio context if missing 100 | # Emphasizes audio-specific keywords 101 | # Ensures proper audio terminology 102 | ``` 103 | 104 | #### Better Text-Video Fusion 105 | ```python 106 | def create_enhanced_video_conditioning(video_tensor, text_prompt, 107 | text_weight=1.0, video_weight=1.0, 108 | negative_prompt=""): 109 | """Create enhanced conditioning with better text-video balance""" 110 | # Processes text with audio-specific enhancements 111 | # Applies conditioning weights 112 | # Includes negative prompting support 113 | ``` 114 | 115 | ## Usage Examples 116 | 117 | ### Enhanced Text-to-Audio Generation 118 | ```json 119 | { 120 | "text_prompt": "footsteps on wooden floor", 121 | "cfg_scale": 8.0, 122 | "negative_prompt": "muffled, distorted, low quality", 123 | "prompt_template": "action_footsteps", 124 | "style_modifier": "realistic", 125 | "enhance_prompt": true 126 | } 127 | ``` 128 | 129 | ### Enhanced Text-to-Music Generation 130 | ```json 131 | { 132 | "text_prompt": "peaceful piano melody", 133 | "cfg_scale": 7.0, 134 | "negative_prompt": "discordant, harsh, atonal", 135 | "music_style": "classical", 136 | "tempo": "slow", 137 | "mood": "peaceful", 138 | "enhance_prompt": true 139 | } 140 | ``` 141 | 142 | ### Enhanced Video-to-Audio Generation 143 | ```json 144 | { 145 | "text_prompt": "footsteps on wooden floor", 146 | "text_cfg_scale": 8.0, 147 | "video_cfg_scale": 6.0, 148 | "text_weight": 1.2, 149 | "video_weight": 0.8, 150 | "negative_prompt": "muffled, distorted, low quality" 151 | } 152 | ``` 153 | 154 | ### Using Prompt Templates 155 | ```json 156 | { 157 | "base_prompt": "person walking", 158 | "template": "action_footsteps", 159 | "style_modifier": "realistic", 160 | "enhance_automatically": true 161 | } 162 | ``` 163 | 164 | ## Best Practices 165 | 166 | ### For Better Prompt Adherence 167 | 168 | 1. **Use Specific Audio Descriptions** 169 | - Instead of: "person walking" 170 | - Use: "clear footsteps on wooden floor, steady rhythm" 171 | 172 | 2. **Balance Text and Video Weights** 173 | - High text weight (1.5-2.0) for specific audio requirements 174 | - High video weight (1.5-2.0) for video-synchronized audio 175 | - Balanced weights (1.0 each) for general video-to-audio 176 | 177 | 3. **Leverage Negative Prompting** 178 | - Always include: "muffled, distorted, low quality" 179 | - Add specific negatives: "echo, reverb" for dry sounds 180 | - Use "silence, quiet" to avoid empty audio 181 | 182 | 4. 
**Adjust CFG Scales** 183 | - Higher text CFG (8.0-12.0) for strong prompt adherence 184 | - Lower video CFG (4.0-6.0) when text is very specific 185 | - Balanced CFG (7.0 each) for general use 186 | 187 | ### Prompt Templates Guide 188 | 189 | #### Music Generation 190 | - **ambient**: "ambient atmospheric music, soft melodic tones" 191 | - **upbeat**: "upbeat energetic music, rhythmic and lively" 192 | - **dramatic**: "dramatic cinematic music, intense and emotional" 193 | 194 | #### Nature Sounds 195 | - **forest**: "natural forest sounds, birds chirping, leaves rustling" 196 | - **ocean**: "ocean waves, water sounds, peaceful seaside ambience" 197 | - **rain**: "gentle rain sounds, water droplets, calming precipitation" 198 | 199 | #### Action Sounds 200 | - **footsteps**: "footsteps walking, movement sounds, human activity" 201 | - **impact**: "impact sounds, hitting, collision effects" 202 | - **mechanical**: "mechanical sounds, machine operation, industrial audio" 203 | 204 | ## Technical Implementation 205 | 206 | ### Conditioning Flow 207 | 1. **Text Enhancement**: Audio-specific keyword processing 208 | 2. **Template Application**: Pre-defined prompt templates 209 | 3. **Weight Application**: Balanced text-video conditioning 210 | 4. **CFG Calculation**: Weighted average of separate CFG scales 211 | 5. **Generation**: Enhanced conditioning pipeline 212 | 213 | ### Key Functions 214 | - `enhance_audio_prompt()`: Audio-specific text processing 215 | - `create_enhanced_video_conditioning()`: Advanced conditioning creation 216 | - `get_audio_prompt_templates()`: Template management 217 | 218 | ## Workflow Examples 219 | 220 | See the included workflow files: 221 | - `examples/enhanced_video_to_audio_workflow.json`: Complete enhanced workflow 222 | - `examples/simple_video_to_audio_workflow.json`: Basic workflow for comparison 223 | 224 | ## Performance Notes 225 | 226 | - Enhanced conditioning adds minimal computational overhead 227 | - Text processing is lightweight and fast 228 | - CFG calculation is optimized for real-time adjustment 229 | - Template system provides instant prompt improvements 230 | 231 | ## Future Improvements 232 | 233 | Potential areas for further enhancement: 234 | 1. **Cross-Modal Attention**: Direct attention between text and video features 235 | 2. **Audio-Aware Text Encoding**: CLAP-based text encoding for better audio alignment 236 | 3. **Hierarchical Conditioning**: Separate global and local audio descriptions 237 | 4. **Adaptive CFG**: Dynamic CFG adjustment based on prompt specificity 238 | 5. 
**Real-time Preview**: Quick audio previews during parameter adjustment 239 | -------------------------------------------------------------------------------- /audiox/stable_audio_tools/training/lm.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import sys, gc 3 | import random 4 | import torch 5 | import torchaudio 6 | import typing as tp 7 | import wandb 8 | 9 | from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image 10 | from ema_pytorch import EMA 11 | from einops import rearrange 12 | from safetensors.torch import save_file 13 | from torch import optim 14 | from torch.nn import functional as F 15 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 16 | 17 | from ..models.lm import AudioLanguageModelWrapper 18 | from .utils import create_optimizer_from_config, create_scheduler_from_config 19 | 20 | class AudioLanguageModelTrainingWrapper(pl.LightningModule): 21 | def __init__( 22 | self, 23 | model: AudioLanguageModelWrapper, 24 | lr = 1e-4, 25 | use_ema=False, 26 | ema_copy=None, 27 | optimizer_configs: dict = None, 28 | pre_encoded=False 29 | ): 30 | super().__init__() 31 | 32 | self.model = model 33 | 34 | self.model.pretransform.requires_grad_(False) 35 | 36 | self.model_ema = None 37 | if use_ema: 38 | self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) 39 | 40 | assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" 41 | 42 | if optimizer_configs is None: 43 | optimizer_configs = { 44 | "lm": { 45 | "optimizer": { 46 | "type": "AdamW", 47 | "config": { 48 | "lr": lr, 49 | "betas": (0.9, 0.95), 50 | "weight_decay": 0.1 51 | } 52 | } 53 | } 54 | } 55 | else: 56 | if lr is not None: 57 | print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") 58 | 59 | self.optimizer_configs = optimizer_configs 60 | 61 | self.pre_encoded = pre_encoded 62 | 63 | def configure_optimizers(self): 64 | lm_opt_config = self.optimizer_configs['lm'] 65 | opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) 66 | 67 | if "scheduler" in lm_opt_config: 68 | sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) 69 | sched_lm_config = { 70 | "scheduler": sched_lm, 71 | "interval": "step" 72 | } 73 | return [opt_lm], [sched_lm_config] 74 | 75 | return [opt_lm] 76 | 77 | # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license 78 | # License can be found in LICENSES/LICENSE_META.txt 79 | 80 | def _compute_cross_entropy( 81 | self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor 82 | ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: 83 | """Compute cross entropy between multi-codebook targets and model's logits. 84 | The cross entropy is computed per codebook to provide codebook-level cross entropy. 85 | Valid timesteps for each of the codebook are pulled from the mask, where invalid 86 | timesteps are set to 0. 87 | 88 | Args: 89 | logits (torch.Tensor): Model's logits of shape [B, K, T, card]. 90 | targets (torch.Tensor): Target codes, of shape [B, K, T]. 91 | mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. 
92 | Returns: 93 | ce (torch.Tensor): Cross entropy averaged over the codebooks 94 | ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). 95 | """ 96 | B, K, T = targets.shape 97 | assert logits.shape[:-1] == targets.shape 98 | assert mask.shape == targets.shape 99 | ce = torch.zeros([], device=targets.device) 100 | ce_per_codebook: tp.List[torch.Tensor] = [] 101 | for k in range(K): 102 | logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] 103 | targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] 104 | mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] 105 | ce_targets = targets_k[mask_k] 106 | ce_logits = logits_k[mask_k] 107 | q_ce = F.cross_entropy(ce_logits, ce_targets) 108 | ce += q_ce 109 | ce_per_codebook.append(q_ce.detach()) 110 | # average cross entropy across codebooks 111 | ce = ce / K 112 | return ce, ce_per_codebook 113 | 114 | def training_step(self, batch, batch_idx): 115 | reals, metadata = batch 116 | 117 | if reals.ndim == 4 and reals.shape[0] == 1: 118 | reals = reals[0] 119 | 120 | if not self.pre_encoded: 121 | codes = self.model.pretransform.tokenize(reals) 122 | else: 123 | codes = reals 124 | 125 | padding_masks = [] 126 | for md in metadata: 127 | if md["padding_mask"].ndim == 1: 128 | padding_masks.append(md["padding_mask"]) 129 | else: 130 | padding_masks.append(md["padding_mask"][0]) 131 | 132 | padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) 133 | 134 | # Interpolate padding masks to the same length as the codes 135 | padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool() 136 | 137 | condition_tensors = None 138 | 139 | # If the model is conditioned, get the conditioning tensors 140 | if self.model.conditioner is not None: 141 | condition_tensors = self.model.conditioner(metadata, self.device) 142 | 143 | lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) 144 | 145 | logits = lm_output.logits # [b, k, t, c] 146 | logits_mask = lm_output.mask # [b, k, t] 147 | 148 | logits_mask = logits_mask & padding_masks 149 | 150 | cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask) 151 | 152 | loss = cross_entropy 153 | 154 | log_dict = { 155 | 'train/loss': loss.detach(), 156 | 'train/cross_entropy': cross_entropy.detach(), 157 | 'train/perplexity': torch.exp(cross_entropy).detach(), 158 | 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] 159 | } 160 | 161 | for k, ce_q in enumerate(cross_entropy_per_codebook): 162 | log_dict[f'cross_entropy_q{k + 1}'] = ce_q 163 | log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q) 164 | 165 | self.log_dict(log_dict, prog_bar=True, on_step=True) 166 | return loss 167 | 168 | def on_before_zero_grad(self, *args, **kwargs): 169 | if self.model_ema is not None: 170 | self.model_ema.update() 171 | 172 | def export_model(self, path, use_safetensors=False): 173 | 174 | model = self.model_ema.ema_model if self.model_ema is not None else self.model 175 | 176 | if use_safetensors: 177 | save_file(model.state_dict(), path) 178 | else: 179 | torch.save({"state_dict": model.state_dict()}, path) 180 | 181 | 182 | class AudioLanguageModelDemoCallback(pl.Callback): 183 | def __init__(self, 184 | demo_every=2000, 185 | num_demos=8, 186 | sample_size=65536, 187 | sample_rate=48000, 188 | demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, 189 | 
demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], 190 | **kwargs 191 | ): 192 | super().__init__() 193 | 194 | self.demo_every = demo_every 195 | self.num_demos = num_demos 196 | self.demo_samples = sample_size 197 | self.sample_rate = sample_rate 198 | self.last_demo_step = -1 199 | self.demo_conditioning = demo_conditioning 200 | self.demo_cfg_scales = demo_cfg_scales 201 | 202 | @rank_zero_only 203 | @torch.no_grad() 204 | def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): 205 | 206 | if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: 207 | return 208 | 209 | module.eval() 210 | 211 | print(f"Generating demo") 212 | self.last_demo_step = trainer.global_step 213 | 214 | demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio 215 | 216 | #demo_reals = batch[0][:self.num_demos] 217 | 218 | # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: 219 | # demo_reals = demo_reals[0] 220 | 221 | #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) 222 | 223 | ##Limit to first 50 tokens 224 | #demo_reals_tokens = demo_reals_tokens[:, :, :50] 225 | 226 | try: 227 | print("Getting conditioning") 228 | 229 | for cfg_scale in self.demo_cfg_scales: 230 | 231 | model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model 232 | 233 | print(f"Generating demo for cfg scale {cfg_scale}") 234 | fakes = model.generate_audio( 235 | batch_size=self.num_demos, 236 | max_gen_len=demo_length_tokens, 237 | conditioning=self.demo_conditioning, 238 | #init_data = demo_reals_tokens, 239 | cfg_scale=cfg_scale, 240 | temp=1.0, 241 | top_p=0.95 242 | ) 243 | 244 | # Put the demos together 245 | fakes = rearrange(fakes, 'b d n -> d (b n)') 246 | 247 | log_dict = {} 248 | 249 | filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' 250 | fakes = fakes / fakes.abs().max() 251 | fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu() 252 | torchaudio.save(filename, fakes, self.sample_rate) 253 | 254 | log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, 255 | sample_rate=self.sample_rate, 256 | caption=f'Reconstructed') 257 | 258 | log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) 259 | 260 | trainer.logger.experiment.log(log_dict) 261 | 262 | except Exception as e: 263 | raise e 264 | finally: 265 | gc.collect() 266 | torch.cuda.empty_cache() 267 | module.train() -------------------------------------------------------------------------------- /audiox/stable_audio_tools/inference/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from tqdm import trange, tqdm 4 | 5 | # Import k_diffusion with error handling 6 | try: 7 | import k_diffusion as K 8 | # Verify that external module is available 9 | if not hasattr(K, 'external'): 10 | raise ImportError("k_diffusion.external not found") 11 | if not hasattr(K.external, 'VDenoiser'): 12 | raise ImportError("k_diffusion.external.VDenoiser not found") 13 | except ImportError as e: 14 | print(f"k_diffusion import error: {e}") 15 | # Try alternative import 16 | try: 17 | import importlib 18 | K = importlib.import_module('k_diffusion') 19 | if not hasattr(K, 'external') or not hasattr(K.external, 'VDenoiser'): 20 | raise ImportError("k_diffusion.external.VDenoiser not available") 21 | except Exception as e2: 22 | raise 
ImportError(f"Could not import k_diffusion: {e}, {e2}") 23 | 24 | # Define the noise schedule and sampling loop 25 | def get_alphas_sigmas(t): 26 | """Returns the scaling factors for the clean image (alpha) and for the 27 | noise (sigma), given a timestep.""" 28 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) 29 | 30 | def alpha_sigma_to_t(alpha, sigma): 31 | """Returns a timestep, given the scaling factors for the clean image and for 32 | the noise.""" 33 | return torch.atan2(sigma, alpha) / math.pi * 2 34 | 35 | def t_to_alpha_sigma(t): 36 | """Returns the scaling factors for the clean image and for the noise, given 37 | a timestep.""" 38 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) 39 | 40 | 41 | @torch.no_grad() 42 | def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args): 43 | """Draws samples from a model given starting noise. Euler method""" 44 | 45 | # Make tensor of ones to broadcast the single t values 46 | ts = x.new_ones([x.shape[0]]) 47 | 48 | # Create the noise schedule 49 | t = torch.linspace(sigma_max, 0, steps + 1) 50 | 51 | #alphas, sigmas = 1-t, t 52 | 53 | for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])): 54 | # Broadcast the current timestep to the correct shape 55 | t_curr_tensor = t_curr * torch.ones( 56 | (x.shape[0],), dtype=x.dtype, device=x.device 57 | ) 58 | dt = t_prev - t_curr # we solve backwards in our formulation 59 | x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc) 60 | 61 | # If we are on the last timestep, output the denoised image 62 | return x 63 | 64 | @torch.no_grad() 65 | def sample(model, x, steps, eta, **extra_args): 66 | """Draws samples from a model given starting noise. v-diffusion""" 67 | ts = x.new_ones([x.shape[0]]) 68 | 69 | # Create the noise schedule 70 | t = torch.linspace(1, 0, steps + 1)[:-1] 71 | 72 | alphas, sigmas = get_alphas_sigmas(t) 73 | 74 | # The sampling loop 75 | for i in trange(steps): 76 | 77 | # Get the model output (v, the predicted velocity) 78 | with torch.cuda.amp.autocast(): 79 | v = model(x, ts * t[i], **extra_args).float() 80 | 81 | # Predict the noise and the denoised image 82 | pred = x * alphas[i] - v * sigmas[i] 83 | eps = x * sigmas[i] + v * alphas[i] 84 | 85 | # If we are not on the last timestep, compute the noisy image for the 86 | # next timestep. 
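# Note: with eta == 0 this reduces to deterministic DDIM: ddim_sigma below is 0,
# adjusted_sigma equals sigmas[i + 1], and the fresh-noise branch is skipped.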
87 | if i < steps - 1: 88 | # If eta > 0, adjust the scaling factor for the predicted noise 89 | # downward according to the amount of additional noise to add 90 | ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ 91 | (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() 92 | adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() 93 | 94 | # Recombine the predicted noise and predicted denoised image in the 95 | # correct proportions for the next step 96 | x = pred * alphas[i + 1] + eps * adjusted_sigma 97 | 98 | # Add the correct amount of fresh noise 99 | if eta: 100 | x += torch.randn_like(x) * ddim_sigma 101 | 102 | # If we are on the last timestep, output the denoised image 103 | return pred 104 | 105 | # Soft mask inpainting is just shrinking hard (binary) mask inpainting 106 | # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step 107 | def get_bmask(i, steps, mask): 108 | strength = (i+1)/(steps) 109 | # convert to binary mask 110 | bmask = torch.where(mask<=strength,1,0) 111 | return bmask 112 | 113 | def make_cond_model_fn(model, cond_fn): 114 | def cond_model_fn(x, sigma, **kwargs): 115 | with torch.enable_grad(): 116 | x = x.detach().requires_grad_() 117 | denoised = model(x, sigma, **kwargs) 118 | cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach() 119 | cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim) 120 | return cond_denoised 121 | return cond_model_fn 122 | 123 | # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion 124 | # init_data is init_audio as latents (if this is latent diffusion) 125 | # For sampling, set both init_data and mask to None 126 | # For variations, set init_data 127 | # For inpainting, set both init_data & mask 128 | def sample_k( 129 | model_fn, 130 | noise, 131 | init_data=None, 132 | mask=None, 133 | steps=100, 134 | sampler_type="dpmpp-2m-sde", 135 | sigma_min=0.5, 136 | sigma_max=50, 137 | rho=1.0, device="cuda", 138 | callback=None, 139 | cond_fn=None, 140 | **extra_args 141 | ): 142 | 143 | # Create VDenoiser 144 | try: 145 | denoiser = K.external.VDenoiser(model_fn) 146 | except Exception as e: 147 | raise RuntimeError(f"Failed to create K.external.VDenoiser: {e}. This often indicates an incompatible k-diffusion version. ComfyUI-AudioX expects a version like 0.0.14. Please check your installed k-diffusion version.") 148 | 149 | if cond_fn is not None: 150 | denoiser = make_cond_model_fn(denoiser, cond_fn) 151 | 152 | # Make the list of sigmas. 
Sigma values are scalars related to the amount of noise each denoising step has 153 | sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device) 154 | # Scale the initial noise by sigma 155 | noise = noise * sigmas[0] 156 | 157 | wrapped_callback = callback 158 | 159 | 160 | if mask is None and init_data is not None: 161 | # VARIATION (no inpainting) 162 | # set the initial latent to the init_data, and noise it with initial sigma 163 | 164 | x = init_data + noise 165 | 166 | elif mask is not None and init_data is not None: 167 | # INPAINTING 168 | bmask = get_bmask(0, steps, mask) 169 | # initial noising 170 | input_noised = init_data + noise 171 | # set the initial latent to a mix of init_data and noise, based on step 0's binary mask 172 | x = input_noised * bmask + noise * (1-bmask) 173 | # define the inpainting callback function (Note: side effects, it mutates x) 174 | # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105 175 | # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 176 | # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)` 177 | def inpainting_callback(args): 178 | i = args["i"] 179 | x = args["x"] 180 | sigma = args["sigma"] 181 | #denoised = args["denoised"] 182 | # noise the init_data input with this step's appropriate amount of noise 183 | input_noised = init_data + torch.randn_like(init_data) * sigma 184 | # shrinking hard mask 185 | bmask = get_bmask(i, steps, mask) 186 | # mix input_noise with x, using binary mask 187 | new_x = input_noised * bmask + x * (1-bmask) 188 | # mutate x 189 | x[:,:,:] = new_x[:,:,:] 190 | # wrap together the inpainting callback and the user-submitted callback. 
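# Note: both callbacks then run on every step; k-diffusion's samplers ignore the
# callback's return value, so the tuple produced by the combined lambda is discarded.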
191 | if callback is None: 192 | wrapped_callback = inpainting_callback 193 | else: 194 | wrapped_callback = lambda args: (inpainting_callback(args), callback(args)) 195 | else: 196 | # SAMPLING 197 | # set the initial latent to noise 198 | x = noise 199 | # x = noise 200 | 201 | with torch.cuda.amp.autocast(): 202 | if sampler_type == "k-heun": 203 | return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 204 | elif sampler_type == "k-lms": 205 | return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 206 | elif sampler_type == "k-dpmpp-2s-ancestral": 207 | return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 208 | elif sampler_type == "k-dpm-2": 209 | return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 210 | elif sampler_type == "k-dpm-fast": 211 | return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args) 212 | elif sampler_type == "k-dpm-adaptive": 213 | return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args) 214 | elif sampler_type == "dpmpp-2m-sde": 215 | return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 216 | elif sampler_type == "dpmpp-3m-sde": 217 | return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 218 | 219 | # Uses discrete Euler sampling for rectified flow models 220 | # init_data is init_audio as latents (if this is latent diffusion) 221 | # For sampling, set both init_data and mask to None 222 | # For variations, set init_data 223 | # For inpainting, set both init_data & mask 224 | def sample_rf( 225 | model_fn, 226 | noise, 227 | init_data=None, 228 | steps=100, 229 | sigma_max=1, 230 | device="cuda", 231 | callback=None, 232 | cond_fn=None, 233 | **extra_args 234 | ): 235 | 236 | if sigma_max > 1: 237 | sigma_max = 1 238 | 239 | if cond_fn is not None: 240 | denoiser = make_cond_model_fn(denoiser, cond_fn) 241 | 242 | wrapped_callback = callback 243 | 244 | if init_data is not None: 245 | # VARIATION (no inpainting) 246 | # Interpolate the init data and the noise for init audio 247 | x = init_data * (1 - sigma_max) + noise * sigma_max 248 | else: 249 | # SAMPLING 250 | # set the initial latent to noise 251 | x = noise 252 | 253 | with torch.cuda.amp.autocast(): 254 | # TODO: Add callback support 255 | #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args) 256 | return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args) -------------------------------------------------------------------------------- /audiox/stable_audio_tools/training/factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from ..models.factory import create_model_from_config 4 | 5 | def create_training_wrapper_from_config(model_config, model): 6 | model_type = model_config.get('model_type', None) 7 | assert model_type is not None, 'model_type must be specified in model config' 8 | 9 | training_config = model_config.get('training', None) 10 | assert training_config is not 
None, 'training config must be specified in model config' 11 | 12 | if model_type == 'autoencoder': 13 | from .autoencoders import AutoencoderTrainingWrapper 14 | 15 | ema_copy = None 16 | 17 | if training_config.get("use_ema", False): 18 | ema_copy = create_model_from_config(model_config) 19 | ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once 20 | # Copy each weight to the ema copy 21 | for name, param in model.state_dict().items(): 22 | if isinstance(param, Parameter): 23 | # backwards compatibility for serialized parameters 24 | param = param.data 25 | ema_copy.state_dict()[name].copy_(param) 26 | 27 | use_ema = training_config.get("use_ema", False) 28 | 29 | latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0) 30 | 31 | teacher_model = training_config.get("teacher_model", None) 32 | if teacher_model is not None: 33 | teacher_model = create_model_from_config(teacher_model) 34 | teacher_model = teacher_model.eval().requires_grad_(False) 35 | 36 | teacher_model_ckpt = training_config.get("teacher_model_ckpt", None) 37 | if teacher_model_ckpt is not None: 38 | teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"]) 39 | else: 40 | raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified") 41 | 42 | return AutoencoderTrainingWrapper( 43 | model, 44 | lr=training_config["learning_rate"], 45 | warmup_steps=training_config.get("warmup_steps", 0), 46 | encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False), 47 | sample_rate=model_config["sample_rate"], 48 | loss_config=training_config.get("loss_configs", None), 49 | optimizer_configs=training_config.get("optimizer_configs", None), 50 | use_ema=use_ema, 51 | ema_copy=ema_copy if use_ema else None, 52 | force_input_mono=training_config.get("force_input_mono", False), 53 | latent_mask_ratio=latent_mask_ratio, 54 | teacher_model=teacher_model 55 | ) 56 | elif model_type == 'diffusion_uncond': 57 | from .diffusion import DiffusionUncondTrainingWrapper 58 | return DiffusionUncondTrainingWrapper( 59 | model, 60 | lr=training_config["learning_rate"], 61 | pre_encoded=training_config.get("pre_encoded", False), 62 | ) 63 | elif model_type == 'diffusion_cond': 64 | from .diffusion import DiffusionCondTrainingWrapper 65 | return DiffusionCondTrainingWrapper( 66 | model, 67 | lr=training_config.get("learning_rate", None), 68 | mask_padding=training_config.get("mask_padding", False), 69 | mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0), 70 | use_ema = training_config.get("use_ema", True), 71 | log_loss_info=training_config.get("log_loss_info", False), 72 | optimizer_configs=training_config.get("optimizer_configs", None), 73 | pre_encoded=training_config.get("pre_encoded", False), 74 | cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), 75 | timestep_sampler = training_config.get("timestep_sampler", "uniform") 76 | ) 77 | elif model_type == 'diffusion_prior': 78 | from .diffusion import DiffusionPriorTrainingWrapper 79 | from ..models.diffusion_prior import PriorType 80 | 81 | ema_copy = create_model_from_config(model_config) 82 | 83 | # Copy each weight to the ema copy 84 | for name, param in model.state_dict().items(): 85 | if isinstance(param, Parameter): 86 | # backwards compatibility for serialized parameters 87 | param = param.data 88 | ema_copy.state_dict()[name].copy_(param) 89 | 90 | prior_type = training_config.get("prior_type", "mono_stereo") 91 | 
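# Illustrative sketch (placeholder values, not from this repo): the config dict consumed by this
# factory is expected to carry roughly the keys read in this branch and in the demo-callback
# factory below; model-definition keys needed by create_model_from_config are omitted here.
#
#   model_config = {
#       "model_type": "diffusion_prior",
#       "sample_rate": 44100,
#       "sample_size": 65536,
#       "training": {
#           "learning_rate": 1e-4,
#           "prior_type": "mono_stereo",
#           "log_loss_info": False,
#           "demo": {"demo_every": 2000, "demo_steps": 250},
#       },
#   }
#   wrapper = create_training_wrapper_from_config(model_config, model)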
92 | if prior_type == "mono_stereo": 93 | prior_type_enum = PriorType.MonoToStereo 94 | else: 95 | raise ValueError(f"Unknown prior type: {prior_type}") 96 | 97 | return DiffusionPriorTrainingWrapper( 98 | model, 99 | lr=training_config["learning_rate"], 100 | ema_copy=ema_copy, 101 | prior_type=prior_type_enum, 102 | log_loss_info=training_config.get("log_loss_info", False), 103 | use_reconstruction_loss=training_config.get("use_reconstruction_loss", False), 104 | ) 105 | elif model_type == 'diffusion_cond_inpaint': 106 | from .diffusion import DiffusionCondInpaintTrainingWrapper 107 | return DiffusionCondInpaintTrainingWrapper( 108 | model, 109 | lr=training_config.get("learning_rate", None), 110 | max_mask_segments = training_config.get("max_mask_segments", 10), 111 | log_loss_info=training_config.get("log_loss_info", False), 112 | optimizer_configs=training_config.get("optimizer_configs", None), 113 | use_ema=training_config.get("use_ema", True), 114 | pre_encoded=training_config.get("pre_encoded", False), 115 | cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), 116 | timestep_sampler = training_config.get("timestep_sampler", "uniform") 117 | ) 118 | elif model_type == 'diffusion_autoencoder': 119 | from .diffusion import DiffusionAutoencoderTrainingWrapper 120 | 121 | ema_copy = create_model_from_config(model_config) 122 | 123 | # Copy each weight to the ema copy 124 | for name, param in model.state_dict().items(): 125 | if isinstance(param, Parameter): 126 | # backwards compatibility for serialized parameters 127 | param = param.data 128 | ema_copy.state_dict()[name].copy_(param) 129 | 130 | return DiffusionAutoencoderTrainingWrapper( 131 | model, 132 | ema_copy=ema_copy, 133 | lr=training_config["learning_rate"], 134 | use_reconstruction_loss=training_config.get("use_reconstruction_loss", False) 135 | ) 136 | elif model_type == 'lm': 137 | from .lm import AudioLanguageModelTrainingWrapper 138 | 139 | ema_copy = create_model_from_config(model_config) 140 | 141 | for name, param in model.state_dict().items(): 142 | if isinstance(param, Parameter): 143 | # backwards compatibility for serialized parameters 144 | param = param.data 145 | ema_copy.state_dict()[name].copy_(param) 146 | 147 | return AudioLanguageModelTrainingWrapper( 148 | model, 149 | ema_copy=ema_copy, 150 | lr=training_config.get("learning_rate", None), 151 | use_ema=training_config.get("use_ema", False), 152 | optimizer_configs=training_config.get("optimizer_configs", None), 153 | pre_encoded=training_config.get("pre_encoded", False), 154 | ) 155 | 156 | else: 157 | raise NotImplementedError(f'Unknown model type: {model_type}') 158 | 159 | def create_demo_callback_from_config(model_config, **kwargs): 160 | model_type = model_config.get('model_type', None) 161 | assert model_type is not None, 'model_type must be specified in model config' 162 | 163 | training_config = model_config.get('training', None) 164 | assert training_config is not None, 'training config must be specified in model config' 165 | 166 | demo_config = training_config.get("demo", {}) 167 | 168 | if model_type == 'autoencoder': 169 | from .autoencoders import AutoencoderDemoCallback 170 | return AutoencoderDemoCallback( 171 | demo_every=demo_config.get("demo_every", 2000), 172 | sample_size=model_config["sample_size"], 173 | sample_rate=model_config["sample_rate"], 174 | **kwargs 175 | ) 176 | elif model_type == 'diffusion_uncond': 177 | from .diffusion import DiffusionUncondDemoCallback 178 | return DiffusionUncondDemoCallback( 179 | 
demo_every=demo_config.get("demo_every", 2000), 180 | demo_steps=demo_config.get("demo_steps", 250), 181 | sample_rate=model_config["sample_rate"] 182 | ) 183 | elif model_type == "diffusion_autoencoder": 184 | from .diffusion import DiffusionAutoencoderDemoCallback 185 | return DiffusionAutoencoderDemoCallback( 186 | demo_every=demo_config.get("demo_every", 2000), 187 | demo_steps=demo_config.get("demo_steps", 250), 188 | sample_size=model_config["sample_size"], 189 | sample_rate=model_config["sample_rate"], 190 | **kwargs 191 | ) 192 | elif model_type == "diffusion_prior": 193 | from .diffusion import DiffusionPriorDemoCallback 194 | return DiffusionPriorDemoCallback( 195 | demo_every=demo_config.get("demo_every", 2000), 196 | demo_steps=demo_config.get("demo_steps", 250), 197 | sample_size=model_config["sample_size"], 198 | sample_rate=model_config["sample_rate"], 199 | **kwargs 200 | ) 201 | elif model_type == "diffusion_cond": 202 | from .diffusion import DiffusionCondDemoCallback 203 | 204 | return DiffusionCondDemoCallback( 205 | demo_every=demo_config.get("demo_every", 2000), 206 | sample_size=model_config["sample_size"], 207 | sample_rate=model_config["sample_rate"], 208 | demo_steps=demo_config.get("demo_steps", 250), 209 | num_demos=demo_config["num_demos"], 210 | demo_cfg_scales=demo_config["demo_cfg_scales"], 211 | demo_conditioning=demo_config.get("demo_cond", {}), 212 | demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False), 213 | display_audio_cond=demo_config.get("display_audio_cond", False), 214 | ) 215 | elif model_type == "diffusion_cond_inpaint": 216 | from .diffusion import DiffusionCondInpaintDemoCallback 217 | 218 | return DiffusionCondInpaintDemoCallback( 219 | demo_every=demo_config.get("demo_every", 2000), 220 | sample_size=model_config["sample_size"], 221 | sample_rate=model_config["sample_rate"], 222 | demo_steps=demo_config.get("demo_steps", 250), 223 | demo_cfg_scales=demo_config["demo_cfg_scales"], 224 | **kwargs 225 | ) 226 | 227 | elif model_type == "lm": 228 | from .lm import AudioLanguageModelDemoCallback 229 | 230 | return AudioLanguageModelDemoCallback( 231 | demo_every=demo_config.get("demo_every", 2000), 232 | sample_size=model_config["sample_size"], 233 | sample_rate=model_config["sample_rate"], 234 | demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]), 235 | demo_conditioning=demo_config.get("demo_cond", None), 236 | num_demos=demo_config.get("num_demos", 8), 237 | **kwargs 238 | ) 239 | else: 240 | raise NotImplementedError(f'Unknown model type: {model_type}') -------------------------------------------------------------------------------- /audiox/stable_audio_tools/models/bottleneck.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from einops import rearrange 7 | from vector_quantize_pytorch import ResidualVQ, FSQ 8 | from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ 9 | 10 | class Bottleneck(nn.Module): 11 | def __init__(self, is_discrete: bool = False): 12 | super().__init__() 13 | 14 | self.is_discrete = is_discrete 15 | 16 | def encode(self, x, return_info=False, **kwargs): 17 | raise NotImplementedError 18 | 19 | def decode(self, x): 20 | raise NotImplementedError 21 | 22 | class DiscreteBottleneck(Bottleneck): 23 | def __init__(self, num_quantizers, codebook_size, tokens_id): 24 | super().__init__(is_discrete=True) 25 | 26 | self.num_quantizers = 
num_quantizers 27 | self.codebook_size = codebook_size 28 | self.tokens_id = tokens_id 29 | 30 | def decode_tokens(self, codes, **kwargs): 31 | raise NotImplementedError 32 | 33 | class TanhBottleneck(Bottleneck): 34 | def __init__(self): 35 | super().__init__(is_discrete=False) 36 | self.tanh = nn.Tanh() 37 | 38 | def encode(self, x, return_info=False): 39 | info = {} 40 | 41 | x = torch.tanh(x) 42 | 43 | if return_info: 44 | return x, info 45 | else: 46 | return x 47 | 48 | def decode(self, x): 49 | return x 50 | 51 | def vae_sample(mean, scale): 52 | stdev = nn.functional.softplus(scale) + 1e-4 53 | var = stdev * stdev 54 | logvar = torch.log(var) 55 | latents = torch.randn_like(mean) * stdev + mean 56 | 57 | kl = (mean * mean + var - logvar - 1).sum(1).mean() 58 | 59 | return latents, kl 60 | 61 | class VAEBottleneck(Bottleneck): 62 | def __init__(self): 63 | super().__init__(is_discrete=False) 64 | 65 | def encode(self, x, return_info=False, **kwargs): 66 | info = {} 67 | 68 | mean, scale = x.chunk(2, dim=1) 69 | 70 | x, kl = vae_sample(mean, scale) 71 | 72 | info["kl"] = kl 73 | 74 | if return_info: 75 | return x, info 76 | else: 77 | return x 78 | 79 | def decode(self, x): 80 | return x 81 | 82 | def compute_mean_kernel(x, y): 83 | kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] 84 | return torch.exp(-kernel_input).mean() 85 | 86 | def compute_mmd(latents): 87 | latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) 88 | noise = torch.randn_like(latents_reshaped) 89 | 90 | latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) 91 | noise_kernel = compute_mean_kernel(noise, noise) 92 | latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) 93 | 94 | mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel 95 | return mmd.mean() 96 | 97 | class WassersteinBottleneck(Bottleneck): 98 | def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False): 99 | super().__init__(is_discrete=False) 100 | 101 | self.noise_augment_dim = noise_augment_dim 102 | self.bypass_mmd = bypass_mmd 103 | 104 | def encode(self, x, return_info=False): 105 | info = {} 106 | 107 | if self.training and return_info: 108 | if self.bypass_mmd: 109 | mmd = torch.tensor(0.0) 110 | else: 111 | mmd = compute_mmd(x) 112 | 113 | info["mmd"] = mmd 114 | 115 | if return_info: 116 | return x, info 117 | 118 | return x 119 | 120 | def decode(self, x): 121 | 122 | if self.noise_augment_dim > 0: 123 | noise = torch.randn(x.shape[0], self.noise_augment_dim, 124 | x.shape[-1]).type_as(x) 125 | x = torch.cat([x, noise], dim=1) 126 | 127 | return x 128 | 129 | class L2Bottleneck(Bottleneck): 130 | def __init__(self): 131 | super().__init__(is_discrete=False) 132 | 133 | def encode(self, x, return_info=False): 134 | info = {} 135 | 136 | x = F.normalize(x, dim=1) 137 | 138 | if return_info: 139 | return x, info 140 | else: 141 | return x 142 | 143 | def decode(self, x): 144 | return F.normalize(x, dim=1) 145 | 146 | class RVQBottleneck(DiscreteBottleneck): 147 | def __init__(self, **quantizer_kwargs): 148 | super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") 149 | self.quantizer = ResidualVQ(**quantizer_kwargs) 150 | self.num_quantizers = quantizer_kwargs["num_quantizers"] 151 | 152 | def encode(self, x, return_info=False, **kwargs): 153 | info = {} 154 | 155 | x = rearrange(x, "b c n -> b n c") 156 | x, indices, loss = self.quantizer(x) 157 | x = 
rearrange(x, "b n c -> b c n") 158 | 159 | info["quantizer_indices"] = indices 160 | info["quantizer_loss"] = loss.mean() 161 | 162 | if return_info: 163 | return x, info 164 | else: 165 | return x 166 | 167 | def decode(self, x): 168 | return x 169 | 170 | def decode_tokens(self, codes, **kwargs): 171 | latents = self.quantizer.get_outputs_from_indices(codes) 172 | 173 | return self.decode(latents, **kwargs) 174 | 175 | class RVQVAEBottleneck(DiscreteBottleneck): 176 | def __init__(self, **quantizer_kwargs): 177 | super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") 178 | self.quantizer = ResidualVQ(**quantizer_kwargs) 179 | self.num_quantizers = quantizer_kwargs["num_quantizers"] 180 | 181 | def encode(self, x, return_info=False): 182 | info = {} 183 | 184 | x, kl = vae_sample(*x.chunk(2, dim=1)) 185 | 186 | info["kl"] = kl 187 | 188 | x = rearrange(x, "b c n -> b n c") 189 | x, indices, loss = self.quantizer(x) 190 | x = rearrange(x, "b n c -> b c n") 191 | 192 | info["quantizer_indices"] = indices 193 | info["quantizer_loss"] = loss.mean() 194 | 195 | if return_info: 196 | return x, info 197 | else: 198 | return x 199 | 200 | def decode(self, x): 201 | return x 202 | 203 | def decode_tokens(self, codes, **kwargs): 204 | latents = self.quantizer.get_outputs_from_indices(codes) 205 | 206 | return self.decode(latents, **kwargs) 207 | 208 | class DACRVQBottleneck(DiscreteBottleneck): 209 | def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs): 210 | super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") 211 | self.quantizer = DACResidualVQ(**quantizer_kwargs) 212 | self.num_quantizers = quantizer_kwargs["n_codebooks"] 213 | self.quantize_on_decode = quantize_on_decode 214 | self.noise_augment_dim = noise_augment_dim 215 | 216 | def encode(self, x, return_info=False, **kwargs): 217 | info = {} 218 | 219 | info["pre_quantizer"] = x 220 | 221 | if self.quantize_on_decode: 222 | return x, info if return_info else x 223 | 224 | z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) 225 | 226 | output = { 227 | "z": z, 228 | "codes": codes, 229 | "latents": latents, 230 | "vq/commitment_loss": commitment_loss, 231 | "vq/codebook_loss": codebook_loss, 232 | } 233 | 234 | output["vq/commitment_loss"] /= self.num_quantizers 235 | output["vq/codebook_loss"] /= self.num_quantizers 236 | 237 | info.update(output) 238 | 239 | if return_info: 240 | return output["z"], info 241 | 242 | return output["z"] 243 | 244 | def decode(self, x): 245 | 246 | if self.quantize_on_decode: 247 | x = self.quantizer(x)[0] 248 | 249 | if self.noise_augment_dim > 0: 250 | noise = torch.randn(x.shape[0], self.noise_augment_dim, 251 | x.shape[-1]).type_as(x) 252 | x = torch.cat([x, noise], dim=1) 253 | 254 | return x 255 | 256 | def decode_tokens(self, codes, **kwargs): 257 | latents, _, _ = self.quantizer.from_codes(codes) 258 | 259 | return self.decode(latents, **kwargs) 260 | 261 | class DACRVQVAEBottleneck(DiscreteBottleneck): 262 | def __init__(self, quantize_on_decode=False, **quantizer_kwargs): 263 | super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") 264 | self.quantizer = DACResidualVQ(**quantizer_kwargs) 265 | self.num_quantizers = quantizer_kwargs["n_codebooks"] 266 | 
self.quantize_on_decode = quantize_on_decode 267 | 268 | def encode(self, x, return_info=False, n_quantizers: int = None): 269 | info = {} 270 | 271 | mean, scale = x.chunk(2, dim=1) 272 | 273 | x, kl = vae_sample(mean, scale) 274 | 275 | info["pre_quantizer"] = x 276 | info["kl"] = kl 277 | 278 | if self.quantize_on_decode: 279 | return x, info if return_info else x 280 | 281 | z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) 282 | 283 | output = { 284 | "z": z, 285 | "codes": codes, 286 | "latents": latents, 287 | "vq/commitment_loss": commitment_loss, 288 | "vq/codebook_loss": codebook_loss, 289 | } 290 | 291 | output["vq/commitment_loss"] /= self.num_quantizers 292 | output["vq/codebook_loss"] /= self.num_quantizers 293 | 294 | info.update(output) 295 | 296 | if return_info: 297 | return output["z"], info 298 | 299 | return output["z"] 300 | 301 | def decode(self, x): 302 | 303 | if self.quantize_on_decode: 304 | x = self.quantizer(x)[0] 305 | 306 | return x 307 | 308 | def decode_tokens(self, codes, **kwargs): 309 | latents, _, _ = self.quantizer.from_codes(codes) 310 | 311 | return self.decode(latents, **kwargs) 312 | 313 | class FSQBottleneck(DiscreteBottleneck): 314 | def __init__(self, noise_augment_dim=0, **kwargs): 315 | super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices") 316 | 317 | self.noise_augment_dim = noise_augment_dim 318 | 319 | self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64]) 320 | 321 | def encode(self, x, return_info=False): 322 | info = {} 323 | 324 | orig_dtype = x.dtype 325 | x = x.float() 326 | 327 | x = rearrange(x, "b c n -> b n c") 328 | x, indices = self.quantizer(x) 329 | x = rearrange(x, "b n c -> b c n") 330 | 331 | x = x.to(orig_dtype) 332 | 333 | # Reorder indices to match the expected format 334 | indices = rearrange(indices, "b n q -> b q n") 335 | 336 | info["quantizer_indices"] = indices 337 | 338 | if return_info: 339 | return x, info 340 | else: 341 | return x 342 | 343 | def decode(self, x): 344 | 345 | if self.noise_augment_dim > 0: 346 | noise = torch.randn(x.shape[0], self.noise_augment_dim, 347 | x.shape[-1]).type_as(x) 348 | x = torch.cat([x, noise], dim=1) 349 | 350 | return x 351 | 352 | def decode_tokens(self, tokens, **kwargs): 353 | latents = self.quantizer.indices_to_codes(tokens) 354 | 355 | return self.decode(latents, **kwargs) -------------------------------------------------------------------------------- /audiox/stable_audio_tools/models/blocks.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from torch.backends.cuda import sdp_kernel 9 | from packaging import version 10 | 11 | from dac.nn.layers import Snake1d 12 | 13 | class ResidualBlock(nn.Module): 14 | def __init__(self, main, skip=None): 15 | super().__init__() 16 | self.main = nn.Sequential(*main) 17 | self.skip = skip if skip else nn.Identity() 18 | 19 | def forward(self, input): 20 | return self.main(input) + self.skip(input) 21 | 22 | class ResConvBlock(ResidualBlock): 23 | def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): 24 | skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) 25 | super().__init__([ 26 | nn.Conv1d(c_in, 
c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), 27 | nn.GroupNorm(1, c_mid), 28 | Snake1d(c_mid) if use_snake else nn.GELU(), 29 | nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), 30 | nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), 31 | (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), 32 | ], skip) 33 | 34 | class SelfAttention1d(nn.Module): 35 | def __init__(self, c_in, n_head=1, dropout_rate=0.): 36 | super().__init__() 37 | assert c_in % n_head == 0 38 | self.norm = nn.GroupNorm(1, c_in) 39 | self.n_head = n_head 40 | self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) 41 | self.out_proj = nn.Conv1d(c_in, c_in, 1) 42 | self.dropout = nn.Dropout(dropout_rate, inplace=True) 43 | 44 | self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') 45 | 46 | if not self.use_flash: 47 | return 48 | 49 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 50 | 51 | if device_properties.major == 8 and device_properties.minor == 0: 52 | # Use flash attention for A100 GPUs 53 | self.sdp_kernel_config = (True, False, False) 54 | else: 55 | # Don't use flash attention for other GPUs 56 | self.sdp_kernel_config = (False, True, True) 57 | 58 | def forward(self, input): 59 | n, c, s = input.shape 60 | qkv = self.qkv_proj(self.norm(input)) 61 | qkv = qkv.view( 62 | [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) 63 | q, k, v = qkv.chunk(3, dim=1) 64 | scale = k.shape[3]**-0.25 65 | 66 | if self.use_flash: 67 | with sdp_kernel(*self.sdp_kernel_config): 68 | y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) 69 | else: 70 | att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) 71 | y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) 72 | 73 | 74 | return input + self.dropout(self.out_proj(y)) 75 | 76 | class SkipBlock(nn.Module): 77 | def __init__(self, *main): 78 | super().__init__() 79 | self.main = nn.Sequential(*main) 80 | 81 | def forward(self, input): 82 | return torch.cat([self.main(input), input], dim=1) 83 | 84 | class FourierFeatures(nn.Module): 85 | def __init__(self, in_features, out_features, std=1.): 86 | super().__init__() 87 | assert out_features % 2 == 0 88 | self.weight = nn.Parameter(torch.randn( 89 | [out_features // 2, in_features]) * std) 90 | 91 | def forward(self, input): 92 | f = 2 * math.pi * input @ self.weight.T 93 | return torch.cat([f.cos(), f.sin()], dim=-1) 94 | 95 | def expand_to_planes(input, shape): 96 | return input[..., None].repeat([1, 1, shape[2]]) 97 | 98 | _kernels = { 99 | 'linear': 100 | [1 / 8, 3 / 8, 3 / 8, 1 / 8], 101 | 'cubic': 102 | [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 103 | 0.43359375, 0.11328125, -0.03515625, -0.01171875], 104 | 'lanczos3': 105 | [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, 106 | -0.066637322306633, 0.13550527393817902, 0.44638532400131226, 107 | 0.44638532400131226, 0.13550527393817902, -0.066637322306633, 108 | -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] 109 | } 110 | 111 | class Downsample1d(nn.Module): 112 | def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): 113 | super().__init__() 114 | self.pad_mode = pad_mode 115 | kernel_1d = torch.tensor(_kernels[kernel]) 116 | self.pad = kernel_1d.shape[0] // 2 - 1 117 | self.register_buffer('kernel', kernel_1d) 118 | self.channels_last = channels_last 119 | 120 | def forward(self, x): 121 | if 
self.channels_last: 122 | x = x.permute(0, 2, 1) 123 | x = F.pad(x, (self.pad,) * 2, self.pad_mode) 124 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) 125 | indices = torch.arange(x.shape[1], device=x.device) 126 | weight[indices, indices] = self.kernel.to(weight) 127 | x = F.conv1d(x, weight, stride=2) 128 | if self.channels_last: 129 | x = x.permute(0, 2, 1) 130 | return x 131 | 132 | 133 | class Upsample1d(nn.Module): 134 | def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): 135 | super().__init__() 136 | self.pad_mode = pad_mode 137 | kernel_1d = torch.tensor(_kernels[kernel]) * 2 138 | self.pad = kernel_1d.shape[0] // 2 - 1 139 | self.register_buffer('kernel', kernel_1d) 140 | self.channels_last = channels_last 141 | 142 | def forward(self, x): 143 | if self.channels_last: 144 | x = x.permute(0, 2, 1) 145 | x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) 146 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) 147 | indices = torch.arange(x.shape[1], device=x.device) 148 | weight[indices, indices] = self.kernel.to(weight) 149 | x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) 150 | if self.channels_last: 151 | x = x.permute(0, 2, 1) 152 | return x 153 | 154 | def Downsample1d_2( 155 | in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 156 | ) -> nn.Module: 157 | assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" 158 | 159 | return nn.Conv1d( 160 | in_channels=in_channels, 161 | out_channels=out_channels, 162 | kernel_size=factor * kernel_multiplier + 1, 163 | stride=factor, 164 | padding=factor * (kernel_multiplier // 2), 165 | ) 166 | 167 | 168 | def Upsample1d_2( 169 | in_channels: int, out_channels: int, factor: int, use_nearest: bool = False 170 | ) -> nn.Module: 171 | 172 | if factor == 1: 173 | return nn.Conv1d( 174 | in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 175 | ) 176 | 177 | if use_nearest: 178 | return nn.Sequential( 179 | nn.Upsample(scale_factor=factor, mode="nearest"), 180 | nn.Conv1d( 181 | in_channels=in_channels, 182 | out_channels=out_channels, 183 | kernel_size=3, 184 | padding=1, 185 | ), 186 | ) 187 | else: 188 | return nn.ConvTranspose1d( 189 | in_channels=in_channels, 190 | out_channels=out_channels, 191 | kernel_size=factor * 2, 192 | stride=factor, 193 | padding=factor // 2 + factor % 2, 194 | output_padding=factor % 2, 195 | ) 196 | 197 | def zero_init(layer): 198 | nn.init.zeros_(layer.weight) 199 | if layer.bias is not None: 200 | nn.init.zeros_(layer.bias) 201 | return layer 202 | 203 | def rms_norm(x, scale, eps): 204 | dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) 205 | mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) 206 | scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) 207 | return x * scale.to(x.dtype) 208 | 209 | #rms_norm = torch.compile(rms_norm) 210 | 211 | class AdaRMSNorm(nn.Module): 212 | def __init__(self, features, cond_features, eps=1e-6): 213 | super().__init__() 214 | self.eps = eps 215 | self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) 216 | 217 | def extra_repr(self): 218 | return f"eps={self.eps}," 219 | 220 | def forward(self, x, cond): 221 | return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) 222 | 223 | def normalize(x, eps=1e-4): 224 | dim = list(range(1, x.ndim)) 225 | n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) 226 | alpha = np.sqrt(n.numel() / x.numel()) 227 | return x 
/ torch.add(eps, n, alpha=alpha) 228 | 229 | class ForcedWNConv1d(nn.Module): 230 | def __init__(self, in_channels, out_channels, kernel_size=1): 231 | super().__init__() 232 | self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) 233 | 234 | def forward(self, x): 235 | if self.training: 236 | with torch.no_grad(): 237 | self.weight.copy_(normalize(self.weight)) 238 | 239 | fan_in = self.weight[0].numel() 240 | 241 | w = normalize(self.weight) / math.sqrt(fan_in) 242 | 243 | return F.conv1d(x, w, padding='same') 244 | 245 | # Kernels 246 | 247 | use_compile = True 248 | 249 | def compile(function, *args, **kwargs): 250 | if not use_compile: 251 | return function 252 | try: 253 | return torch.compile(function, *args, **kwargs) 254 | except RuntimeError: 255 | return function 256 | 257 | 258 | @compile 259 | def linear_geglu(x, weight, bias=None): 260 | x = x @ weight.mT 261 | if bias is not None: 262 | x = x + bias 263 | x, gate = x.chunk(2, dim=-1) 264 | return x * F.gelu(gate) 265 | 266 | 267 | @compile 268 | def rms_norm(x, scale, eps): 269 | dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) 270 | mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) 271 | scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) 272 | return x * scale.to(x.dtype) 273 | 274 | # Layers 275 | 276 | class LinearGEGLU(nn.Linear): 277 | def __init__(self, in_features, out_features, bias=True): 278 | super().__init__(in_features, out_features * 2, bias=bias) 279 | self.out_features = out_features 280 | 281 | def forward(self, x): 282 | return linear_geglu(x, self.weight, self.bias) 283 | 284 | 285 | class RMSNorm(nn.Module): 286 | def __init__(self, shape, fix_scale = False, eps=1e-6): 287 | super().__init__() 288 | self.eps = eps 289 | 290 | if fix_scale: 291 | self.register_buffer("scale", torch.ones(shape)) 292 | else: 293 | self.scale = nn.Parameter(torch.ones(shape)) 294 | 295 | def extra_repr(self): 296 | return f"shape={tuple(self.scale.shape)}, eps={self.eps}" 297 | 298 | def forward(self, x): 299 | return rms_norm(x, self.scale, self.eps) 300 | 301 | def snake_beta(x, alpha, beta): 302 | return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) 303 | 304 | # try: 305 | # snake_beta = torch.compile(snake_beta) 306 | # except RuntimeError: 307 | # pass 308 | 309 | # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license 310 | # License available in LICENSES/LICENSE_NVIDIA.txt 311 | class SnakeBeta(nn.Module): 312 | 313 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): 314 | super(SnakeBeta, self).__init__() 315 | self.in_features = in_features 316 | 317 | # initialize alpha 318 | self.alpha_logscale = alpha_logscale 319 | if self.alpha_logscale: # log scale alphas initialized to zeros 320 | self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) 321 | self.beta = nn.Parameter(torch.zeros(in_features) * alpha) 322 | else: # linear scale alphas initialized to ones 323 | self.alpha = nn.Parameter(torch.ones(in_features) * alpha) 324 | self.beta = nn.Parameter(torch.ones(in_features) * alpha) 325 | 326 | self.alpha.requires_grad = alpha_trainable 327 | self.beta.requires_grad = alpha_trainable 328 | 329 | self.no_div_by_zero = 0.000000001 330 | 331 | def forward(self, x): 332 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 333 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 334 | if self.alpha_logscale: 335 | alpha = torch.exp(alpha) 336 
| beta = torch.exp(beta) 337 | x = snake_beta(x, alpha, beta) 338 | 339 | return x -------------------------------------------------------------------------------- /audiox/stable_audio_tools/inference/generation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import typing as tp 4 | import math 5 | from torchaudio import transforms as T 6 | 7 | from .utils import prepare_audio 8 | from .sampling import sample, sample_k, sample_rf 9 | from ..data.utils import PadCrop 10 | 11 | def generate_diffusion_uncond( 12 | model, 13 | steps: int = 250, 14 | batch_size: int = 1, 15 | sample_size: int = 2097152, 16 | seed: int = -1, 17 | device: str = "cuda", 18 | init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, 19 | init_noise_level: float = 1.0, 20 | return_latents = False, 21 | **sampler_kwargs 22 | ) -> torch.Tensor: 23 | 24 | # The length of the output in audio samples 25 | audio_sample_size = sample_size 26 | 27 | # If this is latent diffusion, change sample_size instead to the downsampled latent size 28 | if model.pretransform is not None: 29 | sample_size = sample_size // model.pretransform.downsampling_ratio 30 | 31 | # Seed 32 | # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. 33 | seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) 34 | # seed = 777 35 | print(seed) 36 | torch.manual_seed(seed) 37 | # Define the initial noise immediately after setting the seed 38 | noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) 39 | 40 | if init_audio is not None: 41 | # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. 42 | in_sr, init_audio = init_audio 43 | 44 | io_channels = model.io_channels 45 | 46 | # For latent models, set the io_channels to the autoencoder's io_channels 47 | if model.pretransform is not None: 48 | io_channels = model.pretransform.io_channels 49 | 50 | # Prepare the initial audio for use by the model 51 | init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) 52 | 53 | # For latent models, encode the initial audio into latents 54 | if model.pretransform is not None: 55 | init_audio = model.pretransform.encode(init_audio) 56 | 57 | init_audio = init_audio.repeat(batch_size, 1, 1) 58 | else: 59 | # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. 60 | init_audio = None 61 | init_noise_level = None 62 | 63 | # Inpainting mask 64 | 65 | if init_audio is not None: 66 | # variations 67 | sampler_kwargs["sigma_max"] = init_noise_level 68 | mask = None 69 | else: 70 | mask = None 71 | 72 | # Now the generative AI part: 73 | 74 | diff_objective = model.diffusion_objective 75 | 76 | if diff_objective == "v": 77 | # k-diffusion denoising process go! 78 | sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device) 79 | elif diff_objective == "rectified_flow": 80 | sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device) 81 | 82 | # Denoising process done. 
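# Example call (illustrative, based only on the signature above):
#
#   audio = generate_diffusion_uncond(model, steps=100, batch_size=1,
#                                     sample_size=2097152, seed=-1, device="cuda")
#
# Supplying init_audio=(sample_rate, tensor) together with a lower init_noise_level
# produces a variation of the input audio instead of a fresh sample.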
83 | # If this is latent diffusion, decode latents back into audio 84 | if model.pretransform is not None and not return_latents: 85 | sampled = model.pretransform.decode(sampled) 86 | 87 | # Return audio 88 | return sampled 89 | 90 | 91 | def generate_diffusion_cond( 92 | model, 93 | steps: int = 250, 94 | cfg_scale=6, 95 | conditioning: dict = None, 96 | conditioning_tensors: tp.Optional[dict] = None, 97 | negative_conditioning: dict = None, 98 | negative_conditioning_tensors: tp.Optional[dict] = None, 99 | batch_size: int = 1, 100 | sample_size: int = 2097152, 101 | sample_rate: int = 48000, 102 | seed: int = -1, 103 | device: str = "cuda", 104 | init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, 105 | init_noise_level: float = 1.0, 106 | mask_args: dict = None, 107 | return_latents = False, 108 | **sampler_kwargs 109 | ) -> torch.Tensor: 110 | """ 111 | Generate audio from a prompt using a diffusion model. 112 | 113 | Args: 114 | model: The diffusion model to use for generation. 115 | steps: The number of diffusion steps to use. 116 | cfg_scale: Classifier-free guidance scale 117 | conditioning: A dictionary of conditioning parameters to use for generation. 118 | conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation. 119 | batch_size: The batch size to use for generation. 120 | sample_size: The length of the audio to generate, in samples. 121 | sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly) 122 | seed: The random seed to use for generation, or -1 to use a random seed. 123 | device: The device to use for generation. 124 | init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation. 125 | init_noise_level: The noise level to use when generating from an initial audio sample. 126 | return_latents: Whether to return the latents used for generation instead of the decoded audio. 127 | **sampler_kwargs: Additional keyword arguments to pass to the sampler. 128 | """ 129 | 130 | # The length of the output in audio samples 131 | audio_sample_size = sample_size 132 | 133 | # If this is latent diffusion, change sample_size instead to the downsampled latent size 134 | if model.pretransform is not None: 135 | sample_size = sample_size // model.pretransform.downsampling_ratio 136 | 137 | # Seed 138 | # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. 
139 | seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) 140 | # seed = 777 141 | # print(seed) 142 | torch.manual_seed(seed) 143 | # Define the initial noise immediately after setting the seed 144 | noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) 145 | 146 | torch.backends.cuda.matmul.allow_tf32 = False 147 | torch.backends.cudnn.allow_tf32 = False 148 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 149 | torch.backends.cudnn.benchmark = False 150 | 151 | # Conditioning 152 | assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors" 153 | if conditioning_tensors is None: 154 | conditioning_tensors = model.conditioner(conditioning, device) 155 | conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors) 156 | 157 | if negative_conditioning is not None or negative_conditioning_tensors is not None: 158 | 159 | if negative_conditioning_tensors is None: 160 | negative_conditioning_tensors = model.conditioner(negative_conditioning, device) 161 | 162 | negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True) 163 | else: 164 | negative_conditioning_tensors = {} 165 | 166 | if init_audio is not None: 167 | # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. 168 | in_sr, init_audio = init_audio 169 | 170 | io_channels = model.io_channels 171 | 172 | # For latent models, set the io_channels to the autoencoder's io_channels 173 | if model.pretransform is not None: 174 | io_channels = model.pretransform.io_channels 175 | 176 | # Prepare the initial audio for use by the model 177 | init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) 178 | 179 | # For latent models, encode the initial audio into latents 180 | if model.pretransform is not None: 181 | init_audio = model.pretransform.encode(init_audio) 182 | 183 | init_audio = init_audio.repeat(batch_size, 1, 1) 184 | else: 185 | # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. 
186 | init_audio = None 187 | init_noise_level = None 188 | mask_args = None 189 | 190 | # Inpainting mask 191 | if init_audio is not None and mask_args is not None: 192 | # Cut and paste init_audio according to cropfrom, pastefrom, pasteto 193 | # This is helpful for forward and reverse outpainting 194 | cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size) 195 | pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size) 196 | pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size) 197 | assert pastefrom < pasteto, "Paste From should be less than Paste To" 198 | croplen = pasteto - pastefrom 199 | if cropfrom + croplen > sample_size: 200 | croplen = sample_size - cropfrom 201 | cropto = cropfrom + croplen 202 | pasteto = pastefrom + croplen 203 | cutpaste = init_audio.new_zeros(init_audio.shape) 204 | cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto] 205 | #print(cropfrom, cropto, pastefrom, pasteto) 206 | init_audio = cutpaste 207 | # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args 208 | mask = build_mask(sample_size, mask_args) 209 | mask = mask.to(device) 210 | elif init_audio is not None and mask_args is None: 211 | # variations 212 | sampler_kwargs["sigma_max"] = init_noise_level 213 | mask = None 214 | else: 215 | mask = None 216 | 217 | model_dtype = next(model.model.parameters()).dtype 218 | noise = noise.type(model_dtype) 219 | conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()} 220 | # Now the generative AI part: 221 | # k-diffusion denoising process go! 222 | 223 | diff_objective = model.diffusion_objective 224 | 225 | if diff_objective == "v": 226 | # k-diffusion denoising process go! 227 | sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) 228 | 229 | elif diff_objective == "rectified_flow": 230 | 231 | if "sigma_min" in sampler_kwargs: 232 | del sampler_kwargs["sigma_min"] 233 | 234 | if "sampler_type" in sampler_kwargs: 235 | del sampler_kwargs["sampler_type"] 236 | 237 | sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) 238 | 239 | # v-diffusion: 240 | del noise 241 | del conditioning_tensors 242 | del conditioning_inputs 243 | torch.cuda.empty_cache() 244 | # Denoising process done. 
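# Example call (illustrative; the exact conditioning keys depend on the model's
# conditioner config, so the ones below are placeholders):
#
#   audio = generate_diffusion_cond(
#       model,
#       steps=250,
#       cfg_scale=7.0,
#       conditioning={"text_prompt": "typing on a keyboard", "seconds_total": 10},
#       batch_size=1,
#       sample_size=2097152,
#       seed=-1,
#       device="cuda",
#   )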
245 | # If this is latent diffusion, decode latents back into audio 246 | 247 | if model.pretransform is not None and not return_latents: 248 | #cast sampled latents to pretransform dtype 249 | sampled = sampled.to(next(model.pretransform.parameters()).dtype) 250 | sampled = model.pretransform.decode(sampled) 251 | 252 | return sampled 253 | 254 | # builds a softmask given the parameters 255 | # returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio, 256 | # and anything between is a mixture of old/new 257 | # ideally 0.5 is half/half mixture but i haven't figured this out yet 258 | def build_mask(sample_size, mask_args): 259 | maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size) 260 | maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size) 261 | softnessL = round(mask_args["softnessL"]/100.0 * sample_size) 262 | softnessR = round(mask_args["softnessR"]/100.0 * sample_size) 263 | marination = mask_args["marination"] 264 | # use hann windows for softening the transition (i don't know if this is correct) 265 | hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL] 266 | hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:] 267 | # build the mask. 268 | mask = torch.zeros((sample_size)) 269 | mask[maskstart:maskend] = 1 270 | mask[maskstart:maskstart+softnessL] = hannL 271 | mask[maskend-softnessR:maskend] = hannR 272 | # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds 273 | if marination > 0: 274 | mask = mask * (1-marination) 275 | return mask 276 | -------------------------------------------------------------------------------- /examples/audiox_txt2music+audio.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "3e7aad6e-5c73-40a9-9022-f75e15e2d354", 3 | "revision": 0, 4 | "last_node_id": 9, 5 | "last_link_id": 8, 6 | "nodes": [ 7 | { 8 | "id": 1, 9 | "type": "AudioXEnhancedTextToAudio", 10 | "pos": [ 11 | -617.2890014648438, 12 | -264.3356018066406 13 | ], 14 | "size": [ 15 | 400, 16 | 382 17 | ], 18 | "flags": {}, 19 | "order": 4, 20 | "mode": 4, 21 | "inputs": [ 22 | { 23 | "name": "model", 24 | "type": "AUDIOX_MODEL", 25 | "link": 4 26 | }, 27 | { 28 | "name": "text_prompt", 29 | "type": "STRING", 30 | "widget": { 31 | "name": "text_prompt" 32 | }, 33 | "link": 1 34 | }, 35 | { 36 | "name": "negative_prompt", 37 | "shape": 7, 38 | "type": "STRING", 39 | "widget": { 40 | "name": "negative_prompt" 41 | }, 42 | "link": 2 43 | } 44 | ], 45 | "outputs": [ 46 | { 47 | "name": "audio", 48 | "type": "AUDIO", 49 | "links": [ 50 | 3 51 | ] 52 | } 53 | ], 54 | "properties": { 55 | "aux_id": "lum3on/ComfyUI-AudioX", 56 | "ver": "68529ed407aad26b38565b51cfbd50d495ace0ec", 57 | "Node name for S&R": "AudioXEnhancedTextToAudio", 58 | "enableTabs": false, 59 | "tabWidth": 65, 60 | "tabXOffset": 10, 61 | "hasSecondTab": false, 62 | "secondTabText": "Send Back", 63 | "secondTabOffset": 80, 64 | "secondTabWidth": 65, 65 | "widget_ue_connectable": {} 66 | }, 67 | "widgets_values": [ 68 | "Typing on a keyboard", 69 | 250, 70 | 9, 71 | 1584129233, 72 | "randomize", 73 | 10, 74 | "muffled, distorted, low quality, noise, silence", 75 | "none", 76 | true, 77 | "realistic", 78 | "multi_aspect", 79 | true 80 | ] 81 | }, 82 | { 83 | "id": 6, 84 | "type": "AudioXEnhancedTextToMusic", 85 | "pos": [ 86 | -610.9913330078125, 87 | 301.61346435546875 88 | ], 89 | "size": [ 90 | 400, 91 | 406 92 | 
], 93 | "flags": {}, 94 | "order": 5, 95 | "mode": 0, 96 | "inputs": [ 97 | { 98 | "name": "model", 99 | "type": "AUDIOX_MODEL", 100 | "link": 5 101 | }, 102 | { 103 | "name": "text_prompt", 104 | "type": "STRING", 105 | "widget": { 106 | "name": "text_prompt" 107 | }, 108 | "link": 6 109 | }, 110 | { 111 | "name": "negative_prompt", 112 | "shape": 7, 113 | "type": "STRING", 114 | "widget": { 115 | "name": "negative_prompt" 116 | }, 117 | "link": 7 118 | } 119 | ], 120 | "outputs": [ 121 | { 122 | "name": "audio", 123 | "type": "AUDIO", 124 | "links": [ 125 | 8 126 | ] 127 | } 128 | ], 129 | "properties": { 130 | "aux_id": "lum3on/ComfyUI-AudioX", 131 | "ver": "68529ed407aad26b38565b51cfbd50d495ace0ec", 132 | "Node name for S&R": "AudioXEnhancedTextToMusic", 133 | "enableTabs": false, 134 | "tabWidth": 65, 135 | "tabXOffset": 10, 136 | "hasSecondTab": false, 137 | "secondTabText": "Send Back", 138 | "secondTabOffset": 80, 139 | "secondTabWidth": 65, 140 | "widget_ue_connectable": {} 141 | }, 142 | "widgets_values": [ 143 | "A peaceful piano melody", 144 | 250, 145 | 7, 146 | 3213582614, 147 | "randomize", 148 | 10, 149 | "discordant, harsh, atonal, noise, distorted", 150 | "none", 151 | "none", 152 | "none", 153 | true, 154 | "multi_aspect", 155 | true 156 | ] 157 | }, 158 | { 159 | "id": 8, 160 | "type": "SaveAudio", 161 | "pos": [ 162 | -186.57098388671875, 163 | 303.60675048828125 164 | ], 165 | "size": [ 166 | 270, 167 | 112 168 | ], 169 | "flags": {}, 170 | "order": 7, 171 | "mode": 0, 172 | "inputs": [ 173 | { 174 | "name": "audio", 175 | "type": "AUDIO", 176 | "link": 8 177 | } 178 | ], 179 | "outputs": [], 180 | "properties": { 181 | "cnr_id": "comfy-core", 182 | "ver": "0.3.40", 183 | "Node name for S&R": "SaveAudio", 184 | "enableTabs": false, 185 | "tabWidth": 65, 186 | "tabXOffset": 10, 187 | "hasSecondTab": false, 188 | "secondTabText": "Send Back", 189 | "secondTabOffset": 80, 190 | "secondTabWidth": 65, 191 | "widget_ue_connectable": {} 192 | }, 193 | "widgets_values": [ 194 | "audio/ComfyUI" 195 | ] 196 | }, 197 | { 198 | "id": 7, 199 | "type": "AudioXPromptHelper", 200 | "pos": [ 201 | -1031.8333740234375, 202 | 306.123046875 203 | ], 204 | "size": [ 205 | 400, 206 | 278 207 | ], 208 | "flags": {}, 209 | "order": 0, 210 | "mode": 0, 211 | "inputs": [], 212 | "outputs": [ 213 | { 214 | "name": "enhanced_prompt", 215 | "type": "STRING", 216 | "links": [ 217 | 6 218 | ] 219 | }, 220 | { 221 | "name": "negative_prompt", 222 | "type": "STRING", 223 | "links": [ 224 | 7 225 | ] 226 | }, 227 | { 228 | "name": "prompt_info", 229 | "type": "STRING", 230 | "links": null 231 | } 232 | ], 233 | "properties": { 234 | "aux_id": "lum3on/ComfyUI-AudioX", 235 | "ver": "68529ed407aad26b38565b51cfbd50d495ace0ec", 236 | "Node name for S&R": "AudioXPromptHelper", 237 | "enableTabs": false, 238 | "tabWidth": 65, 239 | "tabXOffset": 10, 240 | "hasSecondTab": false, 241 | "secondTabText": "Send Back", 242 | "secondTabOffset": 80, 243 | "secondTabWidth": 65, 244 | "widget_ue_connectable": {} 245 | }, 246 | "widgets_values": [ 247 | "cinematic orchestra, beautiful strings", 248 | "none", 249 | true, 250 | true, 251 | "noise, muffled, distorted, low quality, noise", 252 | "realistic" 253 | ] 254 | }, 255 | { 256 | "id": 2, 257 | "type": "AudioXPromptHelper", 258 | "pos": [ 259 | -1039.6156005859375, 260 | -249.17440795898438 261 | ], 262 | "size": [ 263 | 400, 264 | 278 265 | ], 266 | "flags": {}, 267 | "order": 1, 268 | "mode": 4, 269 | "inputs": [], 270 | "outputs": [ 271 | { 272 | "name": 
"enhanced_prompt", 273 | "type": "STRING", 274 | "links": [ 275 | 1 276 | ] 277 | }, 278 | { 279 | "name": "negative_prompt", 280 | "type": "STRING", 281 | "links": [ 282 | 2 283 | ] 284 | }, 285 | { 286 | "name": "prompt_info", 287 | "type": "STRING", 288 | "links": null 289 | } 290 | ], 291 | "properties": { 292 | "aux_id": "lum3on/ComfyUI-AudioX", 293 | "ver": "68529ed407aad26b38565b51cfbd50d495ace0ec", 294 | "Node name for S&R": "AudioXPromptHelper", 295 | "enableTabs": false, 296 | "tabWidth": 65, 297 | "tabXOffset": 10, 298 | "hasSecondTab": false, 299 | "secondTabText": "Send Back", 300 | "secondTabOffset": 80, 301 | "secondTabWidth": 65, 302 | "widget_ue_connectable": {} 303 | }, 304 | "widgets_values": [ 305 | "boat horn, rainy water droplets, bird song", 306 | "none", 307 | true, 308 | true, 309 | "muffled, distorted, low quality, noise", 310 | "realistic" 311 | ] 312 | }, 313 | { 314 | "id": 3, 315 | "type": "SaveAudio", 316 | "pos": [ 317 | -189.4632568359375, 318 | -269.09698486328125 319 | ], 320 | "size": [ 321 | 270, 322 | 112 323 | ], 324 | "flags": {}, 325 | "order": 6, 326 | "mode": 4, 327 | "inputs": [ 328 | { 329 | "name": "audio", 330 | "type": "AUDIO", 331 | "link": 3 332 | } 333 | ], 334 | "outputs": [], 335 | "properties": { 336 | "cnr_id": "comfy-core", 337 | "ver": "0.3.40", 338 | "Node name for S&R": "SaveAudio", 339 | "enableTabs": false, 340 | "tabWidth": 65, 341 | "tabXOffset": 10, 342 | "hasSecondTab": false, 343 | "secondTabText": "Send Back", 344 | "secondTabOffset": 80, 345 | "secondTabWidth": 65, 346 | "widget_ue_connectable": {} 347 | }, 348 | "widgets_values": [ 349 | "audio/ComfyUI" 350 | ] 351 | }, 352 | { 353 | "id": 4, 354 | "type": "AudioXModelLoader", 355 | "pos": [ 356 | -1480, 357 | 160 358 | ], 359 | "size": [ 360 | 398.2261962890625, 361 | 110.46002197265625 362 | ], 363 | "flags": {}, 364 | "order": 2, 365 | "mode": 0, 366 | "inputs": [], 367 | "outputs": [ 368 | { 369 | "name": "model", 370 | "type": "AUDIOX_MODEL", 371 | "links": [ 372 | 4, 373 | 5 374 | ] 375 | } 376 | ], 377 | "properties": { 378 | "aux_id": "lum3on/ComfyUI-AudioX", 379 | "ver": "68529ed407aad26b38565b51cfbd50d495ace0ec", 380 | "Node name for S&R": "AudioXModelLoader", 381 | "enableTabs": false, 382 | "tabWidth": 65, 383 | "tabXOffset": 10, 384 | "hasSecondTab": false, 385 | "secondTabText": "Send Back", 386 | "secondTabOffset": 80, 387 | "secondTabWidth": 65, 388 | "widget_ue_connectable": {} 389 | }, 390 | "widgets_values": [ 391 | "AudioX.ckpt", 392 | "auto", 393 | "auto" 394 | ] 395 | }, 396 | { 397 | "id": 9, 398 | "type": "Fast Groups Bypasser (rgthree)", 399 | "pos": [ 400 | -1480, 401 | 320 402 | ], 403 | "size": [ 404 | 390.63751220703125, 405 | 130 406 | ], 407 | "flags": {}, 408 | "order": 3, 409 | "mode": 0, 410 | "inputs": [], 411 | "outputs": [ 412 | { 413 | "name": "OPT_CONNECTION", 414 | "type": "*", 415 | "links": null 416 | } 417 | ], 418 | "properties": { 419 | "matchColors": "", 420 | "matchTitle": "", 421 | "showNav": true, 422 | "sort": "position", 423 | "customSortAlphabet": "", 424 | "toggleRestriction": "default", 425 | "widget_ue_connectable": {} 426 | } 427 | } 428 | ], 429 | "links": [ 430 | [ 431 | 1, 432 | 2, 433 | 0, 434 | 1, 435 | 1, 436 | "STRING" 437 | ], 438 | [ 439 | 2, 440 | 2, 441 | 1, 442 | 1, 443 | 2, 444 | "STRING" 445 | ], 446 | [ 447 | 3, 448 | 1, 449 | 0, 450 | 3, 451 | 0, 452 | "AUDIO" 453 | ], 454 | [ 455 | 4, 456 | 4, 457 | 0, 458 | 1, 459 | 0, 460 | "AUDIOX_MODEL" 461 | ], 462 | [ 463 | 5, 464 | 4, 465 | 0, 466 | 6, 467 | 
0, 468 | "AUDIOX_MODEL" 469 | ], 470 | [ 471 | 6, 472 | 7, 473 | 0, 474 | 6, 475 | 1, 476 | "STRING" 477 | ], 478 | [ 479 | 7, 480 | 7, 481 | 1, 482 | 6, 483 | 2, 484 | "STRING" 485 | ], 486 | [ 487 | 8, 488 | 6, 489 | 0, 490 | 8, 491 | 0, 492 | "AUDIO" 493 | ] 494 | ], 495 | "groups": [ 496 | { 497 | "id": 1, 498 | "title": "text2audio", 499 | "bounding": [ 500 | -1052.966552734375, 501 | -342.6968994140625, 502 | 1143.5032958984375, 503 | 470.36138916015625 504 | ], 505 | "color": "#3f789e", 506 | "font_size": 24, 507 | "flags": {} 508 | }, 509 | { 510 | "id": 3, 511 | "title": "text2music", 512 | "bounding": [ 513 | -1041.8333740234375, 514 | 228.0134735107422, 515 | 1135.2623291015625, 516 | 489.6000061035156 517 | ], 518 | "color": "#3f789e", 519 | "font_size": 24, 520 | "flags": {} 521 | } 522 | ], 523 | "config": {}, 524 | "extra": { 525 | "ue_links": [], 526 | "links_added_by_ue": [], 527 | "ds": { 528 | "scale": 0.8769226950000016, 529 | "offset": [ 530 | 1256.3334285420383, 531 | 195.14090640797156 532 | ] 533 | }, 534 | "frontendVersion": "1.21.7", 535 | "VHS_latentpreview": false, 536 | "VHS_latentpreviewrate": 0, 537 | "VHS_MetadataImage": true, 538 | "VHS_KeepIntermediate": true 539 | }, 540 | "version": 0.4 541 | } 542 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # Safe import of ComfyUI folder_paths 5 | try: 6 | import folder_paths as comfy_paths 7 | COMFYUI_AVAILABLE = True 8 | except ImportError: 9 | print("AudioX: ComfyUI folder_paths not available (running outside ComfyUI)") 10 | COMFYUI_AVAILABLE = False 11 | # Create a dummy folder_paths for testing 12 | class DummyFolderPaths: 13 | @staticmethod 14 | def get_filename_list(folder_type): 15 | return [] 16 | @staticmethod 17 | def get_full_path(folder_type, filename): 18 | return None 19 | comfy_paths = DummyFolderPaths() 20 | 21 | # Version identifier to force reload 22 | __version__ = "1.0.9" 23 | 24 | # Startup mode flag to prevent heavy operations during ComfyUI initialization 25 | STARTUP_MODE = True 26 | 27 | def set_runtime_mode(): 28 | """Switch from startup mode to runtime mode.""" 29 | global STARTUP_MODE 30 | STARTUP_MODE = False 31 | print("AudioX: Switched to runtime mode") 32 | 33 | # Add the audiox directory to Python path 34 | AUDIOX_ROOT = os.path.join(os.path.dirname(__file__), "audiox") 35 | sys.path.insert(0, AUDIOX_ROOT) 36 | 37 | # Debug environment information 38 | print(f"AudioX: Python executable: {sys.executable}") 39 | print(f"AudioX: Python version: {sys.version.split()[0]}") 40 | print(f"AudioX: AudioX root: {AUDIOX_ROOT}") 41 | print(f"AudioX: Current working directory: {os.getcwd()}") 42 | 43 | # Ensure critical paths are in sys.path 44 | current_dir = os.path.dirname(os.path.abspath(__file__)) 45 | if current_dir not in sys.path: 46 | sys.path.insert(0, current_dir) 47 | print(f"AudioX: Added to path: {current_dir}") 48 | 49 | # Ensure all required dependencies are available 50 | def ensure_dependencies(): 51 | """Ensure all AudioX dependencies are available.""" 52 | # Complete list of required dependencies 53 | required_deps = [ 54 | "descript-audio-codec", 55 | "einops-exts", 56 | "x-transformers", 57 | "alias-free-torch", 58 | "vector-quantize-pytorch", 59 | "local-attention", 60 | "k-diffusion", 61 | "aeiou", 62 | "auraloss", 63 | "encodec", 64 | "laion-clap", 65 | "prefigure", 66 | "v-diffusion-pytorch" 67 | ] 68 | 
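    # Manual fallback (illustrative): the same packages checked and auto-installed
    # below can also be installed up front with pip, e.g.
    #
    #   python -m pip install descript-audio-codec einops-exts x-transformers \
    #       alias-free-torch vector-quantize-pytorch local-attention k-diffusion \
    #       aeiou auraloss encodec laion-clap prefigure v-diffusion-pytorch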
69 | missing_deps = [] 70 | 71 | # Check all dependencies 72 | dependency_checks = [ 73 | ("dac", "descript-audio-codec"), 74 | ("einops_exts", "einops-exts"), 75 | ("x_transformers", "x-transformers"), 76 | ("alias_free_torch", "alias-free-torch"), 77 | ("vector_quantize_pytorch", "vector-quantize-pytorch"), 78 | ("local_attention", "local-attention"), 79 | ("k_diffusion", "k-diffusion"), 80 | ("aeiou", "aeiou"), 81 | ("auraloss", "auraloss"), 82 | ("encodec", "encodec"), 83 | ("laion_clap", "laion-clap"), 84 | ("prefigure", "prefigure"), 85 | ("diffusion", "v-diffusion-pytorch") 86 | ] 87 | 88 | for import_name, package_name in dependency_checks: 89 | try: 90 | __import__(import_name) 91 | print(f"AudioX: {package_name} available") 92 | except ImportError as e: 93 | print(f"AudioX: {package_name} not available: {e}") 94 | missing_deps.append(package_name) 95 | 96 | # Install missing dependencies immediately 97 | if missing_deps: 98 | print(f"AudioX: Installing {len(missing_deps)} missing dependencies...") 99 | print(f"AudioX: Using Python: {sys.executable}") 100 | 101 | try: 102 | import subprocess 103 | 104 | # Install all missing dependencies in one command for efficiency 105 | install_cmd = [sys.executable, "-m", "pip", "install"] + missing_deps 106 | print(f"AudioX: Running: {' '.join(install_cmd)}") 107 | 108 | result = subprocess.run(install_cmd, capture_output=True, text=True, timeout=300) 109 | 110 | if result.returncode == 0: 111 | print("AudioX: All dependencies installed successfully!") 112 | print("AudioX: Reloading modules...") 113 | 114 | # Clear import cache to force reload 115 | import importlib 116 | modules_to_reload = [ 117 | 'dac', 'einops_exts', 'x_transformers', 'alias_free_torch', 118 | 'vector_quantize_pytorch', 'local_attention', 'k_diffusion', 119 | 'aeiou', 'auraloss', 'encodec', 'laion_clap', 'prefigure', 120 | 'diffusion' 121 | ] 122 | 123 | for module_name in modules_to_reload: 124 | if module_name in sys.modules: 125 | importlib.reload(sys.modules[module_name]) 126 | 127 | # Re-check critical dependencies 128 | missing_deps.clear() 129 | for import_name, package_name in dependency_checks: 130 | try: 131 | __import__(import_name) 132 | print(f"AudioX: ✓ {package_name} now available") 133 | except ImportError: 134 | missing_deps.append(package_name) 135 | print(f"AudioX: ✗ {package_name} still missing") 136 | 137 | else: 138 | print(f"AudioX: Installation failed: {result.stderr}") 139 | print("AudioX: Trying individual installations...") 140 | 141 | # Try installing each dependency individually 142 | for dep in missing_deps[:]: 143 | try: 144 | result = subprocess.run([ 145 | sys.executable, "-m", "pip", "install", dep 146 | ], capture_output=True, text=True, timeout=120) 147 | 148 | if result.returncode == 0: 149 | print(f"AudioX: ✓ {dep} installed") 150 | missing_deps.remove(dep) 151 | else: 152 | print(f"AudioX: ✗ Failed to install {dep}") 153 | except Exception as e: 154 | print(f"AudioX: Error installing {dep}: {e}") 155 | 156 | except Exception as install_error: 157 | print(f"AudioX: Installation error: {install_error}") 158 | 159 | if missing_deps: 160 | print(f"AudioX: ⚠️ Still missing: {missing_deps}") 161 | print("AudioX: Some features may not work correctly") 162 | else: 163 | print("AudioX: ✅ All dependencies available!") 164 | 165 | return len(missing_deps) == 0 166 | 167 | # EMERGENCY: Force install critical dependencies immediately 168 | print("AudioX: EMERGENCY DEPENDENCY CHECK...") 169 | try: 170 | import vector_quantize_pytorch 171 | 
print("AudioX: vector_quantize_pytorch already available") 172 | except ImportError: 173 | print("AudioX: INSTALLING vector_quantize_pytorch NOW...") 174 | import subprocess 175 | try: 176 | result = subprocess.run([ 177 | sys.executable, "-m", "pip", "install", "vector-quantize-pytorch" 178 | ], capture_output=True, text=True, timeout=120) 179 | if result.returncode == 0: 180 | print("AudioX: ✅ vector_quantize_pytorch installed successfully!") 181 | # Force reload the module 182 | try: 183 | import importlib 184 | import vector_quantize_pytorch 185 | importlib.reload(vector_quantize_pytorch) 186 | except: 187 | pass 188 | else: 189 | print(f"AudioX: ❌ Failed to install vector_quantize_pytorch: {result.stderr}") 190 | except Exception as e: 191 | print(f"AudioX: ❌ Installation error: {e}") 192 | 193 | # Quick dependency check without heavy operations during startup 194 | print("AudioX: EMERGENCY DEPENDENCY CHECK...") 195 | try: 196 | # Only check critical dependencies that are fast to import 197 | import dac 198 | import einops_exts 199 | print("AudioX: Critical dependencies available") 200 | deps_ok = True 201 | except ImportError as e: 202 | print(f"AudioX: Critical dependency missing: {e}") 203 | print("AudioX: Will attempt full dependency check later...") 204 | deps_ok = ensure_dependencies() 205 | 206 | # Import our nodes with error handling and lazy loading 207 | try: 208 | print("AudioX: Importing core nodes...") 209 | 210 | # Import nodes with timeout protection 211 | import concurrent.futures 212 | import sys 213 | 214 | def import_core_nodes(): 215 | """Import core nodes in a separate thread.""" 216 | from .nodes import ( 217 | AudioXModelLoader, 218 | AudioXTextToAudio, 219 | AudioXEnhancedTextToAudio, 220 | AudioXTextToMusic, 221 | AudioXEnhancedTextToMusic, 222 | AudioXVideoToAudio, 223 | AudioXEnhancedVideoToAudio, 224 | AudioXVideoToMusic, 225 | AudioXMultiModalGeneration, 226 | AudioXAudioProcessor, 227 | AudioXVolumeControl, 228 | AudioXAdvancedVolumeControl, 229 | AudioXVideoMuter, 230 | AudioXVideoAudioCombiner, 231 | AudioXPromptHelper 232 | ) 233 | return ( 234 | AudioXModelLoader, AudioXTextToAudio, AudioXEnhancedTextToAudio, 235 | AudioXTextToMusic, AudioXEnhancedTextToMusic, AudioXVideoToAudio, 236 | AudioXEnhancedVideoToAudio, AudioXVideoToMusic, AudioXMultiModalGeneration, 237 | AudioXAudioProcessor, AudioXVolumeControl, AudioXAdvancedVolumeControl, 238 | AudioXVideoMuter, AudioXVideoAudioCombiner, AudioXPromptHelper 239 | ) 240 | 241 | # Use ThreadPoolExecutor with timeout for import 242 | with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: 243 | future = executor.submit(import_core_nodes) 244 | try: 245 | # 30 second timeout for core node import 246 | core_nodes = future.result(timeout=30) 247 | (AudioXModelLoader, AudioXTextToAudio, AudioXEnhancedTextToAudio, 248 | AudioXTextToMusic, AudioXEnhancedTextToMusic, AudioXVideoToAudio, 249 | AudioXEnhancedVideoToAudio, AudioXVideoToMusic, AudioXMultiModalGeneration, 250 | AudioXAudioProcessor, AudioXVolumeControl, AudioXAdvancedVolumeControl, 251 | AudioXVideoMuter, AudioXVideoAudioCombiner, AudioXPromptHelper) = core_nodes 252 | print("AudioX: ✅ Core nodes imported successfully!") 253 | except concurrent.futures.TimeoutError: 254 | print("AudioX: ⚠️ Core node import timed out, using placeholder nodes") 255 | raise ImportError("Core node import timeout") 256 | 257 | 258 | 259 | except Exception as e: 260 | error_message = str(e) 261 | print(f"AudioX: ❌ Failed to import core nodes: {error_message}") 
262 | print("AudioX: This might be due to missing dependencies or network issues.") 263 | print("AudioX: Creating placeholder nodes...") 264 | 265 | # Create placeholder nodes that will show error messages 266 | class PlaceholderNode: 267 | @classmethod 268 | def INPUT_TYPES(cls): 269 | return {"required": {"error_info": ("STRING", {"default": f"AudioX import failed: {error_message}"})}} 270 | 271 | RETURN_TYPES = ("STRING",) 272 | FUNCTION = "show_error" 273 | CATEGORY = "AudioX/Error" 274 | 275 | def show_error(self, error_info): 276 | raise RuntimeError(f"AudioX nodes are not available: {error_info}") 277 | 278 | # Use placeholder for all nodes 279 | AudioXModelLoader = PlaceholderNode 280 | AudioXTextToAudio = PlaceholderNode 281 | AudioXEnhancedTextToAudio = PlaceholderNode 282 | AudioXTextToMusic = PlaceholderNode 283 | AudioXEnhancedTextToMusic = PlaceholderNode 284 | AudioXVideoToAudio = PlaceholderNode 285 | AudioXEnhancedVideoToAudio = PlaceholderNode 286 | AudioXVideoToMusic = PlaceholderNode 287 | AudioXMultiModalGeneration = PlaceholderNode 288 | AudioXAudioProcessor = PlaceholderNode 289 | AudioXVolumeControl = PlaceholderNode 290 | AudioXAdvancedVolumeControl = PlaceholderNode 291 | AudioXVideoMuter = PlaceholderNode 292 | AudioXVideoAudioCombiner = PlaceholderNode 293 | AudioXPromptHelper = PlaceholderNode 294 | 295 | 296 | # Node mappings for ComfyUI 297 | NODE_CLASS_MAPPINGS = { 298 | "AudioXModelLoader": AudioXModelLoader, 299 | "AudioXTextToAudio": AudioXTextToAudio, 300 | "AudioXEnhancedTextToAudio": AudioXEnhancedTextToAudio, 301 | "AudioXTextToMusic": AudioXTextToMusic, 302 | "AudioXEnhancedTextToMusic": AudioXEnhancedTextToMusic, 303 | "AudioXVideoToAudio": AudioXVideoToAudio, 304 | "AudioXEnhancedVideoToAudio": AudioXEnhancedVideoToAudio, 305 | "AudioXVideoToMusic": AudioXVideoToMusic, 306 | "AudioXMultiModalGeneration": AudioXMultiModalGeneration, 307 | "AudioXAudioProcessor": AudioXAudioProcessor, 308 | "AudioXVolumeControl": AudioXVolumeControl, 309 | "AudioXAdvancedVolumeControl": AudioXAdvancedVolumeControl, 310 | "AudioXVideoMuter": AudioXVideoMuter, 311 | "AudioXVideoAudioCombiner": AudioXVideoAudioCombiner, 312 | "AudioXPromptHelper": AudioXPromptHelper, 313 | 314 | } 315 | 316 | # Display names for the nodes 317 | NODE_DISPLAY_NAME_MAPPINGS = { 318 | "AudioXModelLoader": "AudioX Model Loader", 319 | "AudioXTextToAudio": "AudioX Text to Audio", 320 | "AudioXEnhancedTextToAudio": "AudioX Enhanced Text to Audio", 321 | "AudioXTextToMusic": "AudioX Text to Music", 322 | "AudioXEnhancedTextToMusic": "AudioX Enhanced Text to Music", 323 | "AudioXVideoToAudio": "AudioX Video to Audio", 324 | "AudioXEnhancedVideoToAudio": "AudioX Enhanced Video to Audio", 325 | "AudioXVideoToMusic": "AudioX Video to Music", 326 | "AudioXMultiModalGeneration": "AudioX Multi-Modal Generation", 327 | "AudioXAudioProcessor": "AudioX Audio Processor", 328 | "AudioXVolumeControl": "AudioX Volume Control", 329 | "AudioXAdvancedVolumeControl": "AudioX Advanced Volume Control", 330 | "AudioXVideoMuter": "AudioX Video Muter", 331 | "AudioXVideoAudioCombiner": "AudioX Video Audio Combiner", 332 | "AudioXPromptHelper": "AudioX Prompt Helper", 333 | } 334 | 335 | 336 | 337 | # Web directory for any web components 338 | WEB_DIRECTORY = "./web" 339 | 340 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', 'WEB_DIRECTORY'] 341 | -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
1 | # ComfyUI-AudioX
2 | 
3 | A powerful audio generation extension for ComfyUI that integrates AudioX models, a finetuned version of Stable Audio Tools, for high-quality audio synthesis from text and video inputs. Currently requires a minimum of 16 GB VRAM; tested on a single RTX 4090.
4 | 
5 | ## 🎵 Features
6 | 
7 | - **Text to Audio**: Generate high-quality audio from text descriptions with enhanced conditioning
8 | - **Text to Music**: Create musical compositions with style, tempo, and mood controls
9 | - **Video to Audio**: Extract and generate audio from video content with advanced conditioning
10 | - **Enhanced Conditioning**: Separate CFG scales, conditioning weights, negative prompting, and prompt enhancement
11 | - **Professional Audio Processing**: Volume control with LUFS normalization, limiting, and precise gain staging
12 | - **Video Processing**: Mute videos and combine with generated audio
13 | 
14 | ## 🚀 Installation
15 | 
16 | ### 1. System Dependencies (Required)
17 | **Install these system dependencies first:**
18 | 
19 | **Windows:**
20 | ```bash
21 | # Install ffmpeg (required for video processing)
22 | # Download from: https://ffmpeg.org/download.html
23 | # Or use chocolatey: choco install ffmpeg
24 | 
25 | # Install Microsoft Visual C++ Build Tools (if not already installed)
26 | # Download from: https://visualstudio.microsoft.com/visual-cpp-build-tools/
27 | ```
28 | 
29 | **Linux/Ubuntu:**
30 | ```bash
31 | sudo apt update
32 | sudo apt install ffmpeg libsndfile1-dev build-essential
33 | ```
34 | 
35 | **macOS:**
36 | ```bash
37 | brew install ffmpeg libsndfile
38 | ```
39 | 
40 | ### 2. Clone Repository and Install Python Dependencies
41 | ```bash
42 | cd ComfyUI/custom_nodes
43 | git clone https://github.com/lum3on/ComfyUI-StableAudioX.git
44 | cd ComfyUI-StableAudioX
45 | 
46 | # Install Python dependencies
47 | pip install -r requirements.txt
48 | 
49 | # Optional: Run dependency checker to verify installation
50 | python install_dependencies.py
51 | ```
52 | 
53 | ### Model Setup – AudioX
54 | 
55 | 1. **Model File**: Download from [Hugging Face - model.ckpt](https://huggingface.co/HKUSTAudio/AudioX/resolve/main/model.ckpt)
56 | 2. **Config File**: Download from [Hugging Face - config.json](https://huggingface.co/HKUSTAudio/AudioX/resolve/main/config.json)
57 | 3. **Place both files** in:
58 | `ComfyUI/models/diffusion_models/`
59 | 
60 | Rename the `model.ckpt` file to `AudioX.ckpt`.
61 | 
62 | #### Alternative Download via Hugging Face CLI
63 | 
64 | ```bash
65 | # Install huggingface-hub if not already installed
66 | pip install huggingface-hub
67 | 
68 | # Download AudioX model files
69 | huggingface-cli download HKUSTAudio/AudioX model.ckpt --local-dir ComfyUI/models/diffusion_models/
70 | huggingface-cli download HKUSTAudio/AudioX config.json --local-dir ComfyUI/models/diffusion_models/
71 | ```
72 | 
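If you prefer to script the download, the same two files can also be fetched from Python with `huggingface_hub` (a minimal sketch; it assumes `huggingface-hub` is installed and that the relative path below matches your ComfyUI install):

```python
# Illustrative download script; the CLI commands above achieve the same thing.
from huggingface_hub import hf_hub_download

target_dir = "ComfyUI/models/diffusion_models"  # adjust to your ComfyUI location

for filename in ("model.ckpt", "config.json"):
    path = hf_hub_download(
        repo_id="HKUSTAudio/AudioX",
        filename=filename,
        local_dir=target_dir,  # place the files directly in the models folder
    )
    print(f"Downloaded {filename} -> {path}")

# Remember to rename model.ckpt to AudioX.ckpt afterwards, as described above.
```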
73 | **Model Directory Structure:**
74 | ```
75 | ComfyUI/models/diffusion_models/
76 | ├── model.safetensors        # AudioX model
77 | └── model_config.json        # Model configuration file
78 | ```
79 | ### System Requirements
80 | - **VRAM**: 6GB+ recommended for optimal performance
81 | - **RAM**: 16GB+ recommended
82 | - **Storage**: ~5GB for model files
83 | - **GPU**: CUDA-compatible GPU recommended (CPU supported but slower)
84 | 
85 | ## 📋 Available Nodes
86 | 
87 | ### Core Generation Nodes
88 | - **AudioX Model Loader**: Load AudioX models with device configuration and automatic config file detection
89 | - **AudioX Text to Audio**: Basic text-to-audio generation with automatic prompt enhancement
90 | - **AudioX Text to Music**: Basic text-to-music generation with automatic prompt enhancement
91 | - **AudioX Video to Audio**: Basic video-to-audio generation with automatic prompt enhancement
92 | - **AudioX Video to Music**: Generate musical soundtracks for videos
93 | 
94 | ### Enhanced Generation Nodes ⭐
95 | - **AudioX Enhanced Text to Audio**: Advanced text-to-audio with negative prompting, templates, style modifiers, and conditioning modes
96 | - **AudioX Enhanced Text to Music**: Advanced music generation with style, tempo, and mood controls, plus musical enhancement
97 | - **AudioX Enhanced Video to Audio**: Advanced video-to-audio with separate CFG scales, conditioning weights, and enhanced prompting
98 | 
99 | ### Processing & Utility Nodes
100 | - **AudioX Audio Processor**: Process and enhance audio
101 | - **AudioX Volume Control**: Basic volume control with precise dB control and configurable step size
102 | - **AudioX Advanced Volume Control**: Professional volume control with LUFS normalization, soft limiting, and fade controls
103 | - **AudioX Video Muter**: Remove audio from video files
104 | - **AudioX Video Audio Combiner**: Combine video with generated audio
105 | - **AudioX Multi-Modal Generation**: Advanced multi-modal audio generation
106 | - **AudioX Prompt Helper**: Utility for creating better audio prompts with templates
107 | 
108 | ## 🎯 Quick Start
109 | 
110 | ### Basic Text to Audio
111 | 1. Add an **AudioX Model Loader** node and select your model from `diffusion_models/`
112 | 2. Add an **AudioX Text to Audio** node
113 | 3. Connect the model output to the audio generation node
114 | 4. Enter your text prompt (automatic enhancement applied)
115 | 5. Execute the workflow
116 | 
117 | ### Enhanced Text to Audio with Advanced Controls ⭐
118 | 1. Add an **AudioX Model Loader** node
119 | 2. Add an **AudioX Enhanced Text to Audio** node
120 | 3. Configure advanced options:
121 |    - **Negative Prompt**: Specify what to avoid (e.g., "muffled, distorted")
122 |    - **Prompt Template**: Choose from predefined templates (action, nature, music, etc.)
123 |    - **Style Modifier**: cinematic, realistic, ambient, dramatic, peaceful, energetic
124 |    - **Conditioning Mode**: standard, enhanced, super_enhanced, multi_aspect
125 |    - **Adaptive CFG**: Automatically adjusts CFG based on prompt specificity
126 | 4. Execute for enhanced audio generation
127 | 
128 | ### Enhanced Video to Audio with Separate Controls ⭐
129 | 1. Add an **AudioX Model Loader** node
130 | 2. Add an **AudioX Enhanced Video to Audio** node
131 | 3. Configure separate conditioning:
132 |    - **Text CFG Scale**: Control text conditioning strength (0.1-20.0)
133 |    - **Video CFG Scale**: Control video conditioning strength (0.1-20.0)
134 |    - **Text Weight**: Influence of text conditioning (0.0-2.0)
135 |    - **Video Weight**: Influence of video conditioning (0.0-2.0)
136 |    - **Negative Prompt**: Avoid unwanted audio characteristics
137 | 4. Fine-tune the balance between text prompts and video content (see the guidance sketch below)
138 | 
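The separate CFG scales and conditioning weights above can be thought of as two independent guidance terms pushed away from the unconditional prediction. The sketch below only illustrates that idea with plain NumPy and hypothetical variable names; it is not the node's actual implementation:

```python
# Illustrative dual classifier-free guidance: one guidance term per modality.
import numpy as np

def dual_cfg(eps_uncond, eps_text, eps_video,
             text_cfg=7.0, video_cfg=5.0,
             text_weight=1.0, video_weight=1.0):
    """Combine denoiser outputs from the unconditional, text, and video branches.

    The weights pre-scale each branch's guidance term; the CFG scales control
    how far the result is pushed from the unconditional prediction toward
    each condition.
    """
    text_guidance = text_weight * (eps_text - eps_uncond)
    video_guidance = video_weight * (eps_video - eps_uncond)
    return eps_uncond + text_cfg * text_guidance + video_cfg * video_guidance

# Toy vectors standing in for denoiser outputs at one sampling step.
uncond = np.zeros(4)
text_cond = np.full(4, 0.2)
video_cond = np.full(4, 0.1)
print(dual_cfg(uncond, text_cond, video_cond, text_cfg=7.0, video_cfg=5.0))
```

Under this picture, raising the video CFG scale (or weight) while lowering the text weight shifts the result toward sounds implied by the frames rather than by the prompt, which is how the balance controls are meant to be used.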
139 | ### Professional Audio Workflow with Volume Control
140 | 1. Generate audio using any AudioX generation node
141 | 2. Add **AudioX Advanced Volume Control** for professional features:
142 |    - **LUFS Normalization**: Auto-normalize to broadcast standards (-23 LUFS)
143 |    - **Soft Limiting**: Prevent clipping with a configurable threshold
144 |    - **Fade In/Out**: Add smooth fades to audio
145 |    - **Precise Step Control**: Ultra-fine volume adjustments (0.001 dB steps)
146 | 3. Enable `auto_normalize_lufs` for automatic loudness normalization
147 | 4. Set `limiter_threshold_db` to prevent clipping (default: -1.0 dB)
148 | 5. Add `fade_in_ms`/`fade_out_ms` for smooth transitions (a numeric sketch of these operations follows below)
149 | 
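For intuition, here is a rough, standalone sketch of what LUFS normalization, dB gain staging, and soft limiting amount to numerically. It is illustrative only, built on `pyloudnorm` and NumPy; the Advanced Volume Control node's internals may differ:

```python
# Standalone illustration of LUFS normalization, dB-to-gain conversion, and soft limiting.
import numpy as np
import pyloudnorm as pyln

def db_to_gain(db: float) -> float:
    # Convert a dB change to a linear gain factor.
    return 10.0 ** (db / 20.0)

def normalize_to_lufs(audio: np.ndarray, sample_rate: int, target_lufs: float = -23.0) -> np.ndarray:
    meter = pyln.Meter(sample_rate)              # ITU-R BS.1770 loudness meter
    loudness = meter.integrated_loudness(audio)  # measured integrated loudness in LUFS
    return pyln.normalize.loudness(audio, loudness, target_lufs)

def soft_limit(audio: np.ndarray, threshold_db: float = -1.0) -> np.ndarray:
    threshold = db_to_gain(threshold_db)
    # tanh limiting: nearly transparent below the threshold, saturates gently above it.
    return threshold * np.tanh(audio / threshold)

# Three seconds of noise at roughly -12 dBFS as stand-in audio.
sr = 44100
audio = np.random.randn(3 * sr) * db_to_gain(-12.0)
audio = soft_limit(normalize_to_lufs(audio, sr), threshold_db=-1.0)
```

The `db_to_gain` conversion is also why a 0.001 dB step is meaningful: it corresponds to a linear gain factor of roughly 1.000115, far finer than coarse 1.0 dB steps.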
150 | ### Enhanced Music Generation ⭐
151 | 1. Add an **AudioX Enhanced Text to Music** node
152 | 2. Configure musical attributes:
153 |    - **Music Style**: classical, jazz, electronic, ambient, rock, folk, cinematic
154 |    - **Tempo**: slow, moderate, fast, very_fast
155 |    - **Mood**: happy, sad, peaceful, energetic, mysterious, dramatic
156 |    - **Negative Prompt**: Avoid discordant, harsh, or atonal characteristics
157 | 3. Use automatic music context enhancement for better results
158 | 
159 | ## 📁 Example Workflows
160 | 
161 | The repository includes example workflows:
162 | - `example_workflow.json` - Basic text to audio
163 | - `audiox_video_to_audio_workflow.json` - Video processing
164 | - `simple_video_to_audio_workflow.json` - Simplified video to audio
165 | 
166 | ## ⚙️ Requirements
167 | 
168 | - ComfyUI (latest version recommended)
169 | - Python 3.8+
170 | - CUDA-compatible GPU (recommended) or CPU
171 | - Sufficient disk space for model downloads (models can be several GB)
172 | - AudioX model files and config.json (must be downloaded separately)
173 | 
174 | ## 🔧 Configuration
175 | 
176 | ### Model Storage
177 | **Important**: Models must be manually placed in the correct directory:
178 | - **Required Location**: `ComfyUI/models/diffusion_models/`
179 | - **Required Files**:
180 |   - AudioX model file (`.safetensors` or `.ckpt`)
181 |   - `config.json` configuration file
182 | - **Auto-Detection**: The AudioX Model Loader automatically detects config files
183 | 
184 | ### Device Selection
185 | - Automatic device detection (CUDA/MPS/CPU)
186 | - Manual device specification available in Model Loader
187 | - Memory-efficient processing options
188 | 
189 | ### Node Appearance
190 | - AudioX nodes feature a distinctive light purple color (#ddaeff) for easy identification
191 | - All nodes are categorized under "AudioX/" in the node browser
192 | 
193 | ## ✨ Enhanced Features
194 | 
195 | ### Advanced Conditioning Controls
196 | - **Separate CFG Scales**: Independent control over text and video conditioning strength
197 | - **Conditioning Weights**: Fine-tune the balance between text prompts and video content
198 | - **Negative Prompting**: Specify audio characteristics to avoid for better results
199 | - **Prompt Enhancement**: Automatic addition of audio-specific keywords and context
200 | 
201 | ### Professional Audio Processing
202 | - **Volume Control with Step Size**: Configurable precision from coarse (1.0 dB) to ultra-fine (0.001 dB)
203 | - **LUFS Normalization**: Automatic loudness normalization to broadcast standards
204 | - **Soft Limiting**: Intelligent limiting to prevent clipping while preserving dynamics
205 | - **Fade Controls**: Smooth fade-in and fade-out with millisecond precision
206 | 
207 | ### Intelligent Prompt Processing
208 | - **Template System**: Pre-defined templates for common audio scenarios (action, nature, music, urban)
209 | - **Style Modifiers**: Cinematic, realistic, ambient, dramatic, peaceful, energetic
210 | - **Conditioning Modes**: Standard, enhanced, super_enhanced, and multi_aspect processing
211 | - **Adaptive CFG**: Automatically adjusts CFG scale based on prompt specificity
212 | 
213 | ## 🐛 Troubleshooting
214 | 
215 | ### Common Issues
216 | 
217 | **Installation Problems**:
218 | - **Missing ffmpeg**: Install the ffmpeg system dependency (see installation steps above)
219 | - **Build errors on Windows**: Install Microsoft Visual C++ Build Tools
220 | - **Package conflicts**: Use a fresh virtual environment: `python -m venv audiox_env && audiox_env\Scripts\activate`
221 | - **Dependency failures**: Run `python install_dependencies.py` to check and install missing packages
222 | 
223 | **Model Not Found**: If the AudioX Model Loader shows no models:
224 | - Ensure model files are in `ComfyUI/models/diffusion_models/`
225 | - Verify both the model file and `model_config.json` are present
226 | - Check file permissions and naming
227 | - Accept the license agreement on Hugging Face
before downloading 228 | 229 | **Frontend Errors**: If you encounter "beforeQueued" errors: 230 | - Refresh browser (Ctrl+R) 231 | - Clear browser cache 232 | - Restart ComfyUI 233 | - Check ComfyUI console for dependency errors 234 | 235 | **Memory Issues**: For VRAM/RAM problems: 236 | - Reduce batch sizes and duration_seconds 237 | - Use CPU mode for large models 238 | - Close other applications 239 | - Try lower CFG scales (3.0-5.0) 240 | - Ensure you have at least 6GB VRAM for optimal performance 241 | 242 | **Audio Processing Errors**: 243 | - Verify ffmpeg is properly installed and in PATH 244 | - Check that libsndfile is installed (Linux/macOS) 245 | - For LUFS normalization issues, ensure `pyloudnorm` is installed 246 | 247 | ## 🤝 Contributing 248 | 249 | Contributions welcome! Please: 250 | 1. Fork the repository 251 | 2. Create a feature branch 252 | 3. Submit a pull request 253 | 254 | ## 📄 License 255 | 256 | MIT License - see LICENSE file for details. 257 | 258 | ## 🙏 Acknowledgments 259 | 260 | - AudioX team for original models and research 261 | - ComfyUI community for the excellent framework 262 | - All contributors and testers 263 | 264 | ## 📈 Version History 265 | 266 | **Current Version**: v1.1.0 267 | - ✅ **Enhanced Conditioning**: Added separate CFG scales, conditioning weights, and negative prompting 268 | - ✅ **Advanced Volume Control**: LUFS normalization, soft limiting, and configurable step precision 269 | - ✅ **Enhanced Generation Nodes**: Advanced text-to-audio, text-to-music, and video-to-audio nodes 270 | - ✅ **Intelligent Prompting**: Template system, style modifiers, and adaptive CFG 271 | - ✅ **Professional Audio Processing**: Fade controls, precise gain staging, and broadcast-standard normalization 272 | - ✅ **Improved UI**: Distinctive node appearance with light purple color scheme 273 | - ✅ **Better Model Management**: Auto-detection of config files and improved error handling 274 | 275 | **Previous Version**: v1.0.9 276 | - ✅ Fixed beforeQueued frontend errors 277 | - ✅ Improved workflow execution stability 278 | - ✅ Enhanced video processing capabilities 279 | - ✅ Better error handling and user experience 280 | 281 | ## 🎵 Audio Quality Features 282 | 283 | ### Enhanced Conditioning 284 | - **Better Prompt Adherence**: Enhanced conditioning modes ensure generated audio closely matches your descriptions 285 | - **Negative Prompting**: Avoid unwanted audio characteristics like "muffled", "distorted", or "low quality" 286 | - **Balanced Generation**: Fine-tune the balance between text prompts and video content for optimal results 287 | 288 | ### Professional Audio Standards 289 | - **LUFS Normalization**: Automatic loudness normalization to -23 LUFS (broadcast standard) 290 | - **Dynamic Range Preservation**: Soft limiting maintains audio dynamics while preventing clipping 291 | - **Precise Control**: Volume adjustments from coarse (1.0 dB) to ultra-fine (0.001 dB) steps 292 | 293 | ## 🚀 Roadmap 294 | 295 | ### Upcoming Features 296 | - **🎨 Audio Inpainting**: Fill gaps or replace sections in existing audio with AI-generated content 297 | - **🔧 LoRA Training**: Lightweight fine-tuning for custom audio styles and characteristics 298 | - **🎓 Full Fine-tune Training**: Complete model training pipeline for custom datasets and specialized audio domains 299 | - **� Extended Model Support**: Integration with additional AudioX model variants and architectures 300 | 301 | ### Development Timeline 302 | - **Phase 1** (Current): Enhanced conditioning and 
professional audio processing ✅ 303 | - **Phase 2** (Next): Audio inpainting capabilities and LoRA training infrastructure 304 | - **Phase 3** (Future): Full fine-tuning pipeline and extended model support 305 | 306 | We welcome community feedback and contributions to help prioritize these features! 307 | 308 | --- 309 | 310 | For support and updates, visit the [GitHub repository](https://github.com/lum3on/ComfyUI-StableAudioX). 311 | --------------------------------------------------------------------------------