├── modules ├── __init__.py ├── model_info.py ├── patcher.py ├── utils.py └── loader.py ├── vibevoice ├── __init__.py ├── modular │ ├── __init__.py │ ├── sage_attention_patch.py │ ├── modular_vibevoice_text_tokenizer.py │ ├── modular_vibevoice_diffusion_head.py │ ├── streamer.py │ ├── configuration_vibevoice.py │ └── modeling_vibevoice.py ├── processor │ ├── __init__.py │ ├── vibevoice_tokenizer_processor.py │ └── vibevoice_processor.py ├── schedule │ ├── __init__.py │ └── timestep_sampler.py ├── scripts │ ├── __init__.py │ └── convert_nnscaler_checkpoint_to_transformers.py └── configs │ ├── qwen2.5_1.5b_64k.json │ ├── qwen2.5_7b_32k.json │ ├── default_VibeVoice-1.5B_config.json │ └── default_VibeVoice-Large_config.json ├── example_workflows ├── VibeVoice_example.png └── VibeVoice_example.json ├── requirements.txt ├── pyproject.toml ├── .github └── workflows │ └── publish.yml ├── LICENSE ├── .gitignore ├── __init__.py ├── vibevoice_nodes.py └── README.md /modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vibevoice/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vibevoice/modular/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vibevoice/processor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vibevoice/schedule/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vibevoice/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /example_workflows/VibeVoice_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wildminder/ComfyUI-VibeVoice/HEAD/example_workflows/VibeVoice_example.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | accelerate 3 | torchaudio 4 | librosa 5 | numpy 6 | huggingface_hub 7 | einops 8 | scipy 9 | tokenizers 10 | soundfile 11 | s3tokenizer 12 | conformer 13 | safetensors 14 | transformers>=4.51.3 15 | diffusers 16 | tqdm 17 | bitsandbytes 18 | -------------------------------------------------------------------------------- /modules/model_info.py: -------------------------------------------------------------------------------- 1 | # This dictionary contains the configurations for official, downloadable models. 
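# Each entry maps a display name to its Hugging Face repo_id and an approximate
# download size in GB. AVAILABLE_VIBEVOICE_MODELS starts empty here and is populated
# at import time by the package __init__.py: official entries are added as
# {"type": "official", "repo_id": ..., "tokenizer_repo": ...}, and locally discovered
# models are added as {"type": "local_dir"} or {"type": "standalone"} entries with a "path".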
2 | MODEL_CONFIGS = { 3 | "VibeVoice-1.5B": { 4 | "repo_id": "microsoft/VibeVoice-1.5B", 5 | "size_gb": 3.0, 6 | }, 7 | "VibeVoice-Large": { 8 | "repo_id": "aoi-ot/VibeVoice-Large", 9 | "size_gb": 17.4, 10 | } 11 | } 12 | 13 | AVAILABLE_VIBEVOICE_MODELS = {} -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ComfyUI-VibeVoice" 3 | description = "VibeVoice TTS. Expressive, long-form, multi-speaker conversational audio" 4 | version = "1.5.1" 5 | license = {file = "LICENSE"} 6 | dependencies = ["torch", "torchaudio", "librosa", "numpy", "huggingface_hub", "einops", "scipy", "tokenizers", "soundfile", "s3tokenizer", "tqdm", "conformer", "safetensors", "transformers", "diffusers", "bitsandbytes"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/wildminder/ComfyUI-VibeVoice" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "wildai" 14 | DisplayName = "ComfyUI-VibeVoice" 15 | Icon = "" 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /vibevoice/schedule/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | class UniformSampler: 6 | def __init__(self, timesteps = 1000): 7 | self.timesteps = timesteps 8 | def sample(self, batch_size, device): 9 | return torch.randint(0, self.timesteps, (batch_size,), device=device) 10 | 11 | class LogitNormalSampler: 12 | def __init__(self, timesteps = 1000, m = 0, s = 1): 13 | self.timesteps = timesteps 14 | timesteps = torch.linspace(0, 1, timesteps) 15 | logit = torch.log(timesteps / (1 - timesteps)) 16 | self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi)) 17 | def sample(self, batch_size, device): 18 | return torch.multinomial(self.prob, batch_size, replacement=True).to(device) 19 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | permissions: 12 | issues: write 13 | 14 | jobs: 15 | publish-node: 16 | name: Publish Custom Node to registry 17 | runs-on: ubuntu-latest 18 | if: ${{ github.repository_owner == 'wildminder' }} 19 | steps: 20 | - name: Check out code 21 | uses: actions/checkout@v4 22 | with: 23 | submodules: true 24 | - name: Publish Custom Node 25 | uses: Comfy-Org/publish-node-action@v1 26 | with: 27 | ## Add your own personal access token to your Github Repository secrets and reference it here. 
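## The step below expects the secret to be named REGISTRY_ACCESS_TOKEN; create it in the
## repository under Settings -> Secrets and variables -> Actions before triggering a publish.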
28 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 WildAi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /modules/patcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | import logging 4 | import comfy.model_patcher 5 | import comfy.model_management as model_management 6 | 7 | from .loader import LOADED_MODELS, logger 8 | 9 | class VibeVoicePatcher(comfy.model_patcher.ModelPatcher): 10 | """Custom ModelPatcher for managing VibeVoice models in ComfyUI.""" 11 | def __init__(self, model, attention_mode="eager", *args, **kwargs): 12 | super().__init__(model, *args, **kwargs) 13 | self.attention_mode = attention_mode 14 | self.cache_key = model.cache_key 15 | 16 | @property 17 | def is_loaded(self): 18 | """Check if the model is currently loaded in memory.""" 19 | return hasattr(self, 'model') and self.model is not None and hasattr(self.model, 'model') and self.model.model is not None 20 | 21 | def patch_model(self, device_to=None, *args, **kwargs): 22 | target_device = self.load_device 23 | if self.model.model is None: 24 | logger.info(f"Loading VibeVoice models for '{self.model.model_pack_name}' to {target_device}...") 25 | mode_names = { 26 | "eager": "Eager (Most Compatible)", 27 | "sdpa": "SDPA (Balanced Speed/Compatibility)", 28 | "flash_attention_2": "Flash Attention 2 (Fastest)", 29 | "sage": "SageAttention (Quantized High-Performance)", 30 | } 31 | logger.info(f"Attention Mode: {mode_names.get(self.attention_mode, self.attention_mode)}") 32 | self.model.load_model(target_device, self.attention_mode) 33 | self.model.model.to(target_device) 34 | return super().patch_model(device_to=target_device, *args, **kwargs) 35 | 36 | def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs): 37 | if unpatch_weights: 38 | logger.info(f"Offloading VibeVoice models for '{self.model.model_pack_name}' ({self.attention_mode}) to {device_to}...") 39 | self.model.model = None 40 | self.model.processor = None 41 | 42 | if self.cache_key in LOADED_MODELS: 43 | del LOADED_MODELS[self.cache_key] 44 | logger.info(f"Cleared LOADED_MODELS cache for: {self.cache_key}") 45 
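        # With the model, processor, and LOADED_MODELS references dropped above, the
        # garbage-collection pass and ComfyUI's soft_empty_cache() below let the weights
        # actually be released from RAM/VRAM.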
| 46 | gc.collect() 47 | model_management.soft_empty_cache() 48 | 49 | return super().unpatch_model(device_to, unpatch_weights, *args, **kwargs) -------------------------------------------------------------------------------- /vibevoice/configs/qwen2.5_1.5b_64k.json: -------------------------------------------------------------------------------- 1 | { 2 | "_attn_implementation_autoset": true, 3 | "acoustic_vae_dim": 64, 4 | "acoustic_tokenizer_config": { 5 | "causal": true, 6 | "channels": 1, 7 | "conv_bias": true, 8 | "conv_norm": "none", 9 | "corpus_normalize": 0.0, 10 | "decoder_depths": null, 11 | "decoder_n_filters": 32, 12 | "decoder_ratios": [ 13 | 8, 14 | 5, 15 | 5, 16 | 4, 17 | 2, 18 | 2 19 | ], 20 | "disable_last_norm": true, 21 | "encoder_depths": "3-3-3-3-3-3-8", 22 | "encoder_n_filters": 32, 23 | "encoder_ratios": [ 24 | 8, 25 | 5, 26 | 5, 27 | 4, 28 | 2, 29 | 2 30 | ], 31 | "fix_std": 0.5, 32 | "layer_scale_init_value": 1e-06, 33 | "layernorm": "RMSNorm", 34 | "layernorm_elementwise_affine": true, 35 | "layernorm_eps": 1e-05, 36 | "mixer_layer": "depthwise_conv", 37 | "model_type": "vibepod_acoustic_tokenizer", 38 | "pad_mode": "constant", 39 | "std_dist_type": "gaussian", 40 | "vae_dim": 64, 41 | "weight_init_value": 0.01 42 | }, 43 | "decoder_config": { 44 | "attention_dropout": 0.0, 45 | "hidden_act": "silu", 46 | "hidden_size": 1536, 47 | "initializer_range": 0.02, 48 | "intermediate_size": 8960, 49 | "max_position_embeddings": 65536, 50 | "max_window_layers": 28, 51 | "model_type": "qwen2", 52 | "num_attention_heads": 12, 53 | "num_hidden_layers": 28, 54 | "num_key_value_heads": 2, 55 | "rms_norm_eps": 1e-06, 56 | "rope_scaling": null, 57 | "rope_theta": 1000000.0, 58 | "sliding_window": 4096, 59 | "tie_word_embeddings": true, 60 | "torch_dtype": "bfloat16", 61 | "use_cache": true, 62 | "use_sliding_window": false, 63 | "vocab_size": 151936 64 | }, 65 | "diffusion_head_config": { 66 | "ddpm_batch_mul": 4, 67 | "ddpm_beta_schedule": "cosine", 68 | "ddpm_num_inference_steps": 20, 69 | "ddpm_num_steps": 1000, 70 | "diffusion_type": "ddpm", 71 | "head_ffn_ratio": 3.0, 72 | "head_layers": 4, 73 | "hidden_size": 1536, 74 | "latent_size": 64, 75 | "model_type": "vibepod_diffusion_head", 76 | "prediction_type": "v_prediction", 77 | "rms_norm_eps": 1e-05, 78 | "speech_vae_dim": 64 79 | }, 80 | "model_type": "vibepod", 81 | "semantic_tokenizer_config": { 82 | "causal": true, 83 | "channels": 1, 84 | "conv_bias": true, 85 | "conv_norm": "none", 86 | "corpus_normalize": 0.0, 87 | "disable_last_norm": true, 88 | "encoder_depths": "3-3-3-3-3-3-8", 89 | "encoder_n_filters": 32, 90 | "encoder_ratios": [ 91 | 8, 92 | 5, 93 | 5, 94 | 4, 95 | 2, 96 | 2 97 | ], 98 | "fix_std": 0, 99 | "layer_scale_init_value": 1e-06, 100 | "layernorm": "RMSNorm", 101 | "layernorm_elementwise_affine": true, 102 | "layernorm_eps": 1e-05, 103 | "mixer_layer": "depthwise_conv", 104 | "model_type": "vibepod_semantic_tokenizer", 105 | "pad_mode": "constant", 106 | "std_dist_type": "none", 107 | "vae_dim": 128, 108 | "weight_init_value": 0.01 109 | }, 110 | "semantic_vae_dim": 128, 111 | "torch_dtype": "bfloat16" 112 | } 113 | -------------------------------------------------------------------------------- /vibevoice/configs/qwen2.5_7b_32k.json: -------------------------------------------------------------------------------- 1 | { 2 | "_attn_implementation_autoset": true, 3 | "acoustic_vae_dim": 64, 4 | "acoustic_tokenizer_config": { 5 | "causal": true, 6 | "channels": 1, 7 | "conv_bias": true, 8 | "conv_norm": 
"none", 9 | "corpus_normalize": 0.0, 10 | "decoder_depths": null, 11 | "decoder_n_filters": 32, 12 | "decoder_ratios": [ 13 | 8, 14 | 5, 15 | 5, 16 | 4, 17 | 2, 18 | 2 19 | ], 20 | "disable_last_norm": true, 21 | "encoder_depths": "3-3-3-3-3-3-8", 22 | "encoder_n_filters": 32, 23 | "encoder_ratios": [ 24 | 8, 25 | 5, 26 | 5, 27 | 4, 28 | 2, 29 | 2 30 | ], 31 | "fix_std": 0.5, 32 | "layer_scale_init_value": 1e-06, 33 | "layernorm": "RMSNorm", 34 | "layernorm_elementwise_affine": true, 35 | "layernorm_eps": 1e-05, 36 | "mixer_layer": "depthwise_conv", 37 | "model_type": "vibepod_acoustic_tokenizer", 38 | "pad_mode": "constant", 39 | "std_dist_type": "gaussian", 40 | "vae_dim": 64, 41 | "weight_init_value": 0.01 42 | }, 43 | "decoder_config": { 44 | "attention_dropout": 0.0, 45 | "hidden_act": "silu", 46 | "hidden_size": 3584, 47 | "initializer_range": 0.02, 48 | "intermediate_size": 18944, 49 | "max_position_embeddings": 32768, 50 | "max_window_layers": 28, 51 | "model_type": "qwen2", 52 | "num_attention_heads": 28, 53 | "num_hidden_layers": 28, 54 | "num_key_value_heads": 4, 55 | "rms_norm_eps": 1e-06, 56 | "rope_theta": 1000000.0, 57 | "sliding_window": 4096, 58 | "tie_word_embeddings": false, 59 | "torch_dtype": "bfloat16", 60 | "transformers_version": "4.40.1", 61 | "use_cache": true, 62 | "use_mrope": false, 63 | "use_sliding_window": false, 64 | "vocab_size": 152064 65 | }, 66 | "diffusion_head_config": { 67 | "ddpm_batch_mul": 4, 68 | "ddpm_beta_schedule": "cosine", 69 | "ddpm_num_inference_steps": 20, 70 | "ddpm_num_steps": 1000, 71 | "diffusion_type": "ddpm", 72 | "head_ffn_ratio": 3.0, 73 | "head_layers": 4, 74 | "hidden_size": 3584, 75 | "latent_size": 64, 76 | "model_type": "vibepod_diffusion_head", 77 | "prediction_type": "v_prediction", 78 | "rms_norm_eps": 1e-05, 79 | "speech_vae_dim": 64 80 | }, 81 | "model_type": "vibepod", 82 | "semantic_tokenizer_config": { 83 | "causal": true, 84 | "channels": 1, 85 | "conv_bias": true, 86 | "conv_norm": "none", 87 | "corpus_normalize": 0.0, 88 | "disable_last_norm": true, 89 | "encoder_depths": "3-3-3-3-3-3-8", 90 | "encoder_n_filters": 32, 91 | "encoder_ratios": [ 92 | 8, 93 | 5, 94 | 5, 95 | 4, 96 | 2, 97 | 2 98 | ], 99 | "fix_std": 0, 100 | "layer_scale_init_value": 1e-06, 101 | "layernorm": "RMSNorm", 102 | "layernorm_elementwise_affine": true, 103 | "layernorm_eps": 1e-05, 104 | "mixer_layer": "depthwise_conv", 105 | "model_type": "vibepod_semantic_tokenizer", 106 | "pad_mode": "constant", 107 | "std_dist_type": "none", 108 | "vae_dim": 128, 109 | "weight_init_value": 0.01 110 | }, 111 | "semantic_vae_dim": 128, 112 | "torch_dtype": "bfloat16" 113 | } 114 | -------------------------------------------------------------------------------- /vibevoice/configs/default_VibeVoice-1.5B_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "acoustic_vae_dim": 64, 3 | "acoustic_tokenizer_config": { 4 | "causal": true, 5 | "channels": 1, 6 | "conv_bias": true, 7 | "conv_norm": "none", 8 | "corpus_normalize": 0.0, 9 | "decoder_depths": null, 10 | "decoder_n_filters": 32, 11 | "decoder_ratios": [ 12 | 8, 13 | 5, 14 | 5, 15 | 4, 16 | 2, 17 | 2 18 | ], 19 | "disable_last_norm": true, 20 | "encoder_depths": "3-3-3-3-3-3-8", 21 | "encoder_n_filters": 32, 22 | "encoder_ratios": [ 23 | 8, 24 | 5, 25 | 5, 26 | 4, 27 | 2, 28 | 2 29 | ], 30 | "fix_std": 0.5, 31 | "layer_scale_init_value": 1e-06, 32 | "layernorm": "RMSNorm", 33 | "layernorm_elementwise_affine": true, 34 | "layernorm_eps": 1e-05, 35 | 
"mixer_layer": "depthwise_conv", 36 | "model_type": "vibevoice_acoustic_tokenizer", 37 | "pad_mode": "constant", 38 | "std_dist_type": "gaussian", 39 | "vae_dim": 64, 40 | "weight_init_value": 0.01 41 | }, 42 | "architectures": [ 43 | "VibeVoiceForConditionalGeneration" 44 | ], 45 | "decoder_config": { 46 | "attention_dropout": 0.0, 47 | "hidden_act": "silu", 48 | "hidden_size": 1536, 49 | "initializer_range": 0.02, 50 | "intermediate_size": 8960, 51 | "max_position_embeddings": 65536, 52 | "max_window_layers": 28, 53 | "model_type": "qwen2", 54 | "num_attention_heads": 12, 55 | "num_hidden_layers": 28, 56 | "num_key_value_heads": 2, 57 | "rms_norm_eps": 1e-06, 58 | "rope_scaling": null, 59 | "rope_theta": 1000000.0, 60 | "sliding_window": null, 61 | "tie_word_embeddings": true, 62 | "torch_dtype": "bfloat16", 63 | "use_cache": true, 64 | "use_sliding_window": false, 65 | "vocab_size": 151936 66 | }, 67 | "diffusion_head_config": { 68 | "ddpm_batch_mul": 4, 69 | "ddpm_beta_schedule": "cosine", 70 | "ddpm_num_inference_steps": 20, 71 | "ddpm_num_steps": 1000, 72 | "diffusion_type": "ddpm", 73 | "head_ffn_ratio": 3.0, 74 | "head_layers": 4, 75 | "hidden_size": 1536, 76 | "latent_size": 64, 77 | "model_type": "vibevoice_diffusion_head", 78 | "prediction_type": "v_prediction", 79 | "rms_norm_eps": 1e-05, 80 | "speech_vae_dim": 64 81 | }, 82 | "model_type": "vibevoice", 83 | "semantic_tokenizer_config": { 84 | "causal": true, 85 | "channels": 1, 86 | "conv_bias": true, 87 | "conv_norm": "none", 88 | "corpus_normalize": 0.0, 89 | "disable_last_norm": true, 90 | "encoder_depths": "3-3-3-3-3-3-8", 91 | "encoder_n_filters": 32, 92 | "encoder_ratios": [ 93 | 8, 94 | 5, 95 | 5, 96 | 4, 97 | 2, 98 | 2 99 | ], 100 | "fix_std": 0, 101 | "layer_scale_init_value": 1e-06, 102 | "layernorm": "RMSNorm", 103 | "layernorm_elementwise_affine": true, 104 | "layernorm_eps": 1e-05, 105 | "mixer_layer": "depthwise_conv", 106 | "model_type": "vibevoice_semantic_tokenizer", 107 | "pad_mode": "constant", 108 | "std_dist_type": "none", 109 | "vae_dim": 128, 110 | "weight_init_value": 0.01 111 | }, 112 | "semantic_vae_dim": 128, 113 | "torch_dtype": "bfloat16", 114 | "transformers_version": "4.51.3" 115 | } 116 | -------------------------------------------------------------------------------- /vibevoice/configs/default_VibeVoice-Large_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "acostic_vae_dim": 64, 3 | "acoustic_tokenizer_config": { 4 | "causal": true, 5 | "channels": 1, 6 | "conv_bias": true, 7 | "conv_norm": "none", 8 | "corpus_normalize": 0.0, 9 | "decoder_depths": null, 10 | "decoder_n_filters": 32, 11 | "decoder_ratios": [ 12 | 8, 13 | 5, 14 | 5, 15 | 4, 16 | 2, 17 | 2 18 | ], 19 | "disable_last_norm": true, 20 | "encoder_depths": "3-3-3-3-3-3-8", 21 | "encoder_n_filters": 32, 22 | "encoder_ratios": [ 23 | 8, 24 | 5, 25 | 5, 26 | 4, 27 | 2, 28 | 2 29 | ], 30 | "fix_std": 0.5, 31 | "layer_scale_init_value": 1e-06, 32 | "layernorm": "RMSNorm", 33 | "layernorm_elementwise_affine": true, 34 | "layernorm_eps": 1e-05, 35 | "mixer_layer": "depthwise_conv", 36 | "model_type": "vibevoice_acoustic_tokenizer", 37 | "pad_mode": "constant", 38 | "std_dist_type": "gaussian", 39 | "vae_dim": 64, 40 | "weight_init_value": 0.01 41 | }, 42 | "architectures": [ 43 | "VibeVoiceForConditionalGeneration" 44 | ], 45 | "decoder_config": { 46 | "attention_dropout": 0.0, 47 | "hidden_act": "silu", 48 | "hidden_size": 3584, 49 | "initializer_range": 0.02, 50 | "intermediate_size": 
18944, 51 | "max_position_embeddings": 32768, 52 | "max_window_layers": 28, 53 | "model_type": "qwen2", 54 | "num_attention_heads": 28, 55 | "num_hidden_layers": 28, 56 | "num_key_value_heads": 4, 57 | "rms_norm_eps": 1e-06, 58 | "rope_scaling": null, 59 | "rope_theta": 1000000.0, 60 | "sliding_window": null, 61 | "torch_dtype": "bfloat16", 62 | "use_cache": true, 63 | "use_mrope": false, 64 | "use_sliding_window": false, 65 | "vocab_size": 152064 66 | }, 67 | "diffusion_head_config": { 68 | "ddpm_batch_mul": 4, 69 | "ddpm_beta_schedule": "cosine", 70 | "ddpm_num_inference_steps": 20, 71 | "ddpm_num_steps": 1000, 72 | "diffusion_type": "ddpm", 73 | "head_ffn_ratio": 3.0, 74 | "head_layers": 4, 75 | "hidden_size": 3584, 76 | "latent_size": 64, 77 | "model_type": "vibevoice_diffusion_head", 78 | "prediction_type": "v_prediction", 79 | "rms_norm_eps": 1e-05, 80 | "speech_vae_dim": 64 81 | }, 82 | "model_type": "vibevoice", 83 | "semantic_tokenizer_config": { 84 | "causal": true, 85 | "channels": 1, 86 | "conv_bias": true, 87 | "conv_norm": "none", 88 | "corpus_normalize": 0.0, 89 | "disable_last_norm": true, 90 | "encoder_depths": "3-3-3-3-3-3-8", 91 | "encoder_n_filters": 32, 92 | "encoder_ratios": [ 93 | 8, 94 | 5, 95 | 5, 96 | 4, 97 | 2, 98 | 2 99 | ], 100 | "fix_std": 0, 101 | "layer_scale_init_value": 1e-06, 102 | "layernorm": "RMSNorm", 103 | "layernorm_elementwise_affine": true, 104 | "layernorm_eps": 1e-05, 105 | "mixer_layer": "depthwise_conv", 106 | "model_type": "vibevoice_semantic_tokenizer", 107 | "pad_mode": "constant", 108 | "std_dist_type": "none", 109 | "vae_dim": 128, 110 | "weight_init_value": 0.01 111 | }, 112 | "semantic_vae_dim": 128, 113 | "tie_word_embeddings": false, 114 | "torch_dtype": "bfloat16", 115 | "transformers_version": "4.51.3" 116 | } 117 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .github 11 | .idea 12 | .Python 13 | __pycache__ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import folder_paths 5 | import json 6 | 7 | try: 8 | import sageattention 9 | SAGE_ATTENTION_AVAILABLE = True 10 | except ImportError: 11 | SAGE_ATTENTION_AVAILABLE = False 12 | 13 | current_dir = os.path.dirname(os.path.abspath(__file__)) 14 | if current_dir not in sys.path: 15 | sys.path.append(current_dir) 16 | 17 | from .modules.model_info import AVAILABLE_VIBEVOICE_MODELS, MODEL_CONFIGS 18 | 19 | # Configure a logger 20 | logger = logging.getLogger(__name__) 21 | logger.setLevel(logging.INFO) 22 | logger.propagate = False 23 | if not logger.hasHandlers(): 24 | handler = logging.StreamHandler() 25 | formatter = logging.Formatter(f"[ComfyUI-VibeVoice] %(message)s") 26 | handler.setFormatter(formatter) 27 | logger.addHandler(handler) 28 | 29 | # This is just the *name* of the subdirectory, not the full path. 30 | VIBEVOICE_SUBDIR_NAME = "VibeVoice" 31 | 32 | # This is the *primary* path where official models will be downloaded. 33 | primary_vibevoice_models_path = os.path.join(folder_paths.models_dir, "tts", VIBEVOICE_SUBDIR_NAME) 34 | os.makedirs(primary_vibevoice_models_path, exist_ok=True) 35 | 36 | # Register the tts path type with ComfyUI so get_folder_paths works 37 | tts_path = os.path.join(folder_paths.models_dir, "tts") 38 | if "tts" not in folder_paths.folder_names_and_paths: 39 | supported_exts = folder_paths.supported_pt_extensions.union({".safetensors", ".json"}) 40 | folder_paths.folder_names_and_paths["tts"] = ([tts_path], supported_exts) 41 | else: 42 | # Ensure the default path is in the list if it's not already 43 | if tts_path not in folder_paths.folder_names_and_paths["tts"][0]: 44 | folder_paths.folder_names_and_paths["tts"][0].append(tts_path) 45 | 46 | # The logic for dynamic model discovery 47 | # ToDo: optimize finding 48 | 49 | # official models that can be auto-downloaded 50 | for model_name, config in MODEL_CONFIGS.items(): 51 | AVAILABLE_VIBEVOICE_MODELS[model_name] = { 52 | "type": "official", 53 | "repo_id": config["repo_id"], 54 | "tokenizer_repo": "Qwen/Qwen2.5-7B" if "Large" in model_name else "Qwen/Qwen2.5-1.5B" 55 | } 56 | 57 | # just workaround, default + custom 58 | vibevoice_search_paths = [] 59 | # Use ComfyUI's API to get all registered 'tts' folders 60 | for tts_folder in folder_paths.get_folder_paths("tts"): 61 | potential_path = os.path.join(tts_folder, VIBEVOICE_SUBDIR_NAME) 62 | if os.path.isdir(potential_path) and potential_path not in vibevoice_search_paths: 63 | vibevoice_search_paths.append(potential_path) 64 | 65 | # Add the primary path just in case it wasn't registered for some reason 66 | if primary_vibevoice_models_path not in vibevoice_search_paths: 67 | vibevoice_search_paths.insert(0, primary_vibevoice_models_path) 68 | 69 | # Messy... 
Discover all local models in the search paths 70 | for search_path in vibevoice_search_paths: 71 | logger.info(f"Scanning for VibeVoice models in: {search_path}") 72 | if not os.path.exists(search_path): continue 73 | for item in os.listdir(search_path): 74 | item_path = os.path.join(search_path, item) 75 | 76 | # Case 1: we have a standard HF directory 77 | if os.path.isdir(item_path): 78 | model_name = item 79 | if model_name in AVAILABLE_VIBEVOICE_MODELS: continue 80 | 81 | config_exists = os.path.exists(os.path.join(item_path, "config.json")) 82 | weights_exist = os.path.exists(os.path.join(item_path, "model.safetensors.index.json")) or any(f.endswith(('.safetensors', '.bin')) for f in os.listdir(item_path)) 83 | 84 | if config_exists and weights_exist: 85 | tokenizer_repo = "Qwen/Qwen2.5-7B" if "large" in model_name.lower() else "Qwen/Qwen2.5-1.5B" 86 | AVAILABLE_VIBEVOICE_MODELS[model_name] = { 87 | "type": "local_dir", 88 | "path": item_path, 89 | "tokenizer_repo": tokenizer_repo 90 | } 91 | 92 | # Case 2: Item is a standalone file 93 | elif os.path.isfile(item_path) and any(item.endswith(ext) for ext in folder_paths.supported_pt_extensions): 94 | model_name = os.path.splitext(item)[0] 95 | if model_name in AVAILABLE_VIBEVOICE_MODELS: continue 96 | 97 | tokenizer_repo = "Qwen/Qwen2.5-7B" if "large" in model_name.lower() else "Qwen/Qwen2.5-1.5B" 98 | AVAILABLE_VIBEVOICE_MODELS[model_name] = { 99 | "type": "standalone", 100 | "path": item_path, 101 | "tokenizer_repo": tokenizer_repo 102 | } 103 | 104 | logger.info(f"Discovered VibeVoice models: {sorted(list(AVAILABLE_VIBEVOICE_MODELS.keys()))}") 105 | 106 | from .vibevoice_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 107 | 108 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import numpy as np 4 | import random 5 | import logging 6 | 7 | from comfy.utils import ProgressBar 8 | from comfy.model_management import throw_exception_if_processing_interrupted 9 | 10 | try: 11 | import librosa 12 | except ImportError: 13 | print("VibeVoice Node: `librosa` is not installed. Resampling of reference audio will not be available.") 14 | librosa = None 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | def set_vibevoice_seed(seed: int): 19 | """Sets the seed for torch, numpy, and random, handling large seeds for numpy.""" 20 | if seed == 0: 21 | seed = random.randint(1, 0xffffffffffffffff) 22 | 23 | MAX_NUMPY_SEED = 2**32 - 1 24 | numpy_seed = seed % MAX_NUMPY_SEED 25 | 26 | torch.manual_seed(seed) 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed_all(seed) 29 | np.random.seed(numpy_seed) 30 | random.seed(seed) 31 | 32 | def parse_script_1_based(script: str) -> tuple[list[tuple[int, str]], list[int]]: 33 | """ 34 | Parses a 1-based speaker script into a list of (speaker_id, text) tuples 35 | and a list of unique speaker IDs in the order of their first appearance. 36 | Internally, it converts speaker IDs to 0-based for the model. 37 | 38 | Supports two formats: 39 | 1. Speaker 1: Some text... 40 | 2. [1] Some text... 41 | 42 | If no speaker markers are found, the entire script is assigned to Speaker 1. 
43 | """ 44 | parsed_lines = [] 45 | speaker_ids_in_script = [] # This will store the 1-based IDs from the script 46 | 47 | line_format_regex = re.compile(r'^(?:Speaker\s+(\d+)\s*:|\[(\d+)\])\s*(.*)$', re.IGNORECASE) 48 | 49 | for line in script.strip().split("\n"): 50 | if not (line := line.strip()): continue 51 | 52 | match = line_format_regex.match(line) 53 | if match: 54 | speaker_id_str = match.group(1) or match.group(2) 55 | speaker_id = int(speaker_id_str) 56 | text_content = match.group(3) 57 | 58 | if match.group(1) is None and text_content.lstrip().startswith(':'): 59 | colon_index = text_content.find(':') 60 | text_content = text_content[colon_index + 1:] 61 | 62 | if speaker_id < 1: 63 | logger.warning(f"Speaker ID must be 1 or greater. Skipping line: '{line}'") 64 | continue 65 | 66 | text = text_content.strip() # REMOVED the prepended space ' ' + 67 | internal_speaker_id = speaker_id - 1 68 | parsed_lines.append((internal_speaker_id, text)) 69 | 70 | if speaker_id not in speaker_ids_in_script: 71 | speaker_ids_in_script.append(speaker_id) 72 | else: 73 | logger.warning(f"Could not parse speaker marker, treating as part of previous line if any, or ignoring: '{line}'") 74 | 75 | if not parsed_lines and script.strip(): 76 | logger.info("No speaker markers found. Treating entire text as a single utterance for Speaker 1.") 77 | parsed_lines.append((0, ' ' + script.strip())) 78 | speaker_ids_in_script.append(1) 79 | 80 | return parsed_lines, sorted(list(set(speaker_ids_in_script))) 81 | 82 | 83 | def preprocess_comfy_audio(audio_dict: dict, target_sr: int = 24000) -> np.ndarray: 84 | """ 85 | Converts a ComfyUI AUDIO dict to a mono NumPy array, resampling if necessary. 86 | """ 87 | if not audio_dict: return None 88 | waveform_tensor = audio_dict.get('waveform') 89 | if waveform_tensor is None or waveform_tensor.numel() == 0: return None 90 | 91 | waveform = waveform_tensor[0].cpu().numpy() 92 | original_sr = audio_dict['sample_rate'] 93 | 94 | if waveform.ndim > 1: 95 | waveform = np.mean(waveform, axis=0) 96 | 97 | # Check for invalid values 98 | if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): 99 | logger.error("Audio contains NaN or Inf values, replacing with zeros") 100 | waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) 101 | 102 | # Ensure audio is not completely silent or has extreme values 103 | if np.all(waveform == 0): 104 | logger.warning("Audio waveform is completely silent") 105 | 106 | # Normalize extreme values 107 | max_val = np.abs(waveform).max() 108 | if max_val > 10.0: 109 | logger.warning(f"Audio values are very large (max: {max_val}), normalizing") 110 | waveform = waveform / max_val 111 | 112 | if original_sr != target_sr: 113 | if librosa is None: 114 | raise ImportError("`librosa` package is required for audio resampling. 
Please install it with `pip install librosa`.") 115 | logger.warning(f"Resampling reference audio from {original_sr}Hz to {target_sr}Hz.") 116 | waveform = librosa.resample(y=waveform, orig_sr=original_sr, target_sr=target_sr) 117 | 118 | # Final check after resampling 119 | if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): 120 | logger.error("Audio contains NaN or Inf after resampling, replacing with zeros") 121 | waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) 122 | 123 | return waveform.astype(np.float32) 124 | 125 | def check_for_interrupt(): 126 | try: 127 | throw_exception_if_processing_interrupted() 128 | return False 129 | except: 130 | return True -------------------------------------------------------------------------------- /vibevoice/modular/sage_attention_patch.py: -------------------------------------------------------------------------------- 1 | # Author: Wildminder 2 | # Desc: SageAttention and patcher 3 | # License: Apache 2.0 4 | 5 | import torch 6 | from typing import Optional, Tuple 7 | 8 | from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv 9 | from transformers.cache_utils import Cache 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | try: 15 | from sageattention.core import ( 16 | sageattn_qk_int8_pv_fp16_cuda, 17 | sageattn_qk_int8_pv_fp8_cuda, 18 | sageattn_qk_int8_pv_fp8_cuda_sm90, 19 | ) 20 | SAGE_ATTENTION_AVAILABLE = True 21 | except ImportError: 22 | SAGE_ATTENTION_AVAILABLE = False 23 | 24 | 25 | def get_sage_attention_function_and_params(): 26 | """ 27 | Selects the best available SageAttention CUDA kernel and its parameters 28 | based on the current GPU architecture. 29 | """ 30 | if not SAGE_ATTENTION_AVAILABLE or not torch.cuda.is_available(): 31 | return None, None, None 32 | 33 | major, minor = torch.cuda.get_device_capability() 34 | arch_code = major * 10 + minor 35 | 36 | attn_func = None 37 | pv_accum_dtype = "fp32" 38 | 39 | if arch_code >= 120: # Blackwell 40 | pv_accum_dtype = "fp32+fp32" 41 | attn_func = sageattn_qk_int8_pv_fp8_cuda 42 | logger.info(f"SageAttention: Using SM120 (Blackwell) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") 43 | elif arch_code >= 90: # Hopper 44 | pv_accum_dtype = "fp32+fp32" 45 | attn_func = sageattn_qk_int8_pv_fp8_cuda_sm90 46 | logger.info(f"SageAttention: Using SM90 (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") 47 | elif arch_code == 89: # Ada Lovelace 48 | pv_accum_dtype = "fp32+fp32" 49 | attn_func = sageattn_qk_int8_pv_fp8_cuda 50 | logger.info(f"SageAttention: Using SM89 (Ada) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") 51 | elif arch_code >= 80: # Ampere 52 | pv_accum_dtype = "fp32" 53 | attn_func = sageattn_qk_int8_pv_fp16_cuda 54 | logger.info(f"SageAttention: Using SM80+ (Ampere) FP16 kernel with pv_accum_dtype='{pv_accum_dtype}'.") 55 | else: 56 | logger.warning(f"SageAttention not supported on current GPU architecture (SM{arch_code}).") 57 | return None, None, None 58 | 59 | return attn_func, "per_warp", pv_accum_dtype 60 | 61 | SAGE_ATTENTION_FUNCTION, QK_QUANT_GRAN, PV_ACCUM_DTYPE = get_sage_attention_function_and_params() 62 | 63 | 64 | def sage_attention_forward( 65 | self, 66 | hidden_states: torch.Tensor, 67 | position_embeddings: tuple[torch.Tensor, torch.Tensor], 68 | attention_mask: Optional[torch.Tensor] = None, 69 | past_key_values: Optional[Cache] = None, 70 | cache_position: Optional[torch.LongTensor] = None, 71 | **kwargs, 72 | ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor]]: 73 | 74 | if SAGE_ATTENTION_FUNCTION is None: 75 | raise RuntimeError("SageAttention was selected but no compatible kernel was found for this GPU.") 76 | 77 | original_dtype = hidden_states.dtype 78 | 79 | is_4bit = hasattr(self.q_proj, 'quant_state') 80 | if is_4bit: 81 | target_dtype = torch.bfloat16 82 | else: 83 | target_dtype = self.q_proj.weight.dtype 84 | 85 | if hidden_states.dtype != target_dtype: 86 | hidden_states = hidden_states.to(target_dtype) 87 | 88 | bsz, q_len, _ = hidden_states.size() 89 | 90 | query_states = self.q_proj(hidden_states) 91 | key_states = self.k_proj(hidden_states) 92 | value_states = self.v_proj(hidden_states) 93 | 94 | query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2) 95 | key_states = key_states.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2) 96 | value_states = value_states.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2) 97 | 98 | cos, sin = position_embeddings 99 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids=None) 100 | 101 | if past_key_values is not None: 102 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 103 | key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) 104 | 105 | # !! DO NOT repeat K and V heads here. The SageAttention kernel is optimized 106 | # to handle the broadcasting internally. 107 | 108 | is_causal = attention_mask is None and q_len > 1 109 | 110 | attn_output = SAGE_ATTENTION_FUNCTION( 111 | query_states.to(target_dtype), 112 | key_states.to(target_dtype), 113 | value_states.to(target_dtype), 114 | tensor_layout="HND", 115 | is_causal=is_causal, 116 | qk_quant_gran=QK_QUANT_GRAN, 117 | pv_accum_dtype=PV_ACCUM_DTYPE, 118 | ) 119 | 120 | if isinstance(attn_output, tuple): 121 | attn_output = attn_output[0] 122 | 123 | attn_output = attn_output.transpose(1, 2).contiguous() 124 | attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size) 125 | 126 | attn_output = self.o_proj(attn_output) 127 | 128 | if attn_output.dtype != original_dtype: 129 | attn_output = attn_output.to(original_dtype) 130 | 131 | attn_weights = None 132 | 133 | return attn_output, attn_weights 134 | 135 | 136 | def set_sage_attention(model): 137 | """ 138 | Recursively iterates through the model's modules and monkey-patches the 139 | forward method of each Qwen2Attention layer. 
140 | """ 141 | if not SAGE_ATTENTION_AVAILABLE: 142 | raise ImportError("SageAttention library is not installed or failed to load.") 143 | 144 | if SAGE_ATTENTION_FUNCTION is None: 145 | return 146 | 147 | for module in model.modules(): 148 | if isinstance(module, Qwen2Attention): 149 | module.forward = sage_attention_forward.__get__(module, Qwen2Attention) -------------------------------------------------------------------------------- /vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import argparse 5 | import json 6 | import os 7 | from pathlib import Path 8 | import re 9 | import torch 10 | from typing import Dict, List, Tuple 11 | 12 | from vibevoice.modular.configuration_vibevoice import ( 13 | VibeVoiceConfig 14 | ) 15 | from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration 16 | from transformers.utils import logging 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | def convert_vibevoice_nnscaler_checkpoint_to_hf( 21 | checkpoint_path: str, 22 | pytorch_dump_folder_path: str, 23 | config_path: str = None, 24 | ): 25 | """ 26 | Convert a nnscaler VibeVoice checkpoint to HuggingFace format. 27 | Supports both regular checkpoints and tensor parallel checkpoints. 28 | """ 29 | 30 | # Load regular checkpoint 31 | logger.info(f"Loading regular checkpoint from {checkpoint_path}") 32 | checkpoint = torch.load(checkpoint_path, map_location="cpu") # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader'] 33 | 34 | # config = checkpoint['train_args'] 35 | init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path'] 36 | pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path'] 37 | 38 | init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1] 39 | if init_config_path.exists(): 40 | logger.info(f"Loading initial config from {init_config_path}") 41 | with open(init_config_path, 'r') as f: 42 | init_config = json.load(f) 43 | else: 44 | raise FileNotFoundError(f"Initial config file {init_config_path} not found. 
Please provide a valid path.") 45 | 46 | tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True) 47 | logger.info(f"Tie word embeddings: {tie_word_embeddings}") 48 | 49 | init_config['decoder_config']['use_cache'] = True 50 | config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings) 51 | 52 | # # Extract the model state dict 53 | model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')} 54 | if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys(): 55 | # If not tying weights, we need to add the lm_head weight separately 56 | model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight'] 57 | 58 | # Override with provided config if available 59 | if config_path: 60 | logger.info(f"Loading config from {config_path}") 61 | with open(config_path, 'r') as f: 62 | config_dict = json.load(f) 63 | config = VibeVoiceConfig.from_dict(config_dict) 64 | 65 | # Set the default dtype to bfloat16 before creating the model 66 | original_dtype = torch.get_default_dtype() 67 | torch.set_default_dtype(torch.bfloat16) 68 | 69 | # Create the HuggingFace model 70 | logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model") 71 | model = VibeVoiceForConditionalGeneration(config) 72 | 73 | # Restore original dtype 74 | torch.set_default_dtype(original_dtype) 75 | 76 | # Load the state dict 77 | logger.info("Loading weights into model") 78 | missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False) 79 | 80 | if missing_keys: 81 | logger.warning(f"Missing keys: {missing_keys}") 82 | if unexpected_keys: 83 | logger.warning(f"Unexpected keys: {unexpected_keys}") 84 | 85 | # Create output directory 86 | os.makedirs(pytorch_dump_folder_path, exist_ok=True) 87 | 88 | # Save the model and config 89 | logger.info(f"Saving model to {pytorch_dump_folder_path}") 90 | 91 | # Save config 92 | config.save_pretrained(pytorch_dump_folder_path) 93 | 94 | # Save VibeVoiceProcessor configuration 95 | logger.info("Saving VibeVoiceProcessor configuration") 96 | processor_config = { 97 | "processor_class": "VibeVoiceProcessor", 98 | "speech_tok_compress_ratio": 3200, 99 | "db_normalize": True, 100 | # Audio processor configuration 101 | "audio_processor": { 102 | "feature_extractor_type": "VibeVoiceTokenizerProcessor", 103 | "sampling_rate": 24000, 104 | "normalize_audio": True, 105 | "target_dB_FS": -25, 106 | "eps": 1e-6, 107 | }, 108 | "language_model_pretrained_name": pretrained_name, 109 | } 110 | 111 | processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json") 112 | with open(processor_config_path, 'w') as f: 113 | json.dump(processor_config, f, indent=2) 114 | logger.info(f"Saved processor config to {processor_config_path}") 115 | 116 | # Save model with sharding 117 | # save_pretrained handles tied weights automatically 118 | logger.info("Saving model weights with sharding...") 119 | model.save_pretrained( 120 | pytorch_dump_folder_path, 121 | max_shard_size="2GB", # Set maximum size for each shard 122 | safe_serialization=True # Ensure saving in .safetensors format 123 | ) 124 | logger.info(f"Model weights saved to {pytorch_dump_folder_path}") 125 | 126 | logger.info("Conversion complete!") 127 | 128 | # Verify the saved model can be loaded 129 | logger.info("Verifying saved model...") 130 | loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path) 
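    # Reloading from disk is a smoke test that the config and sharded safetensors files
    # round-trip correctly through from_pretrained().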
131 | logger.info("Model successfully loaded from saved checkpoint!") 132 | 133 | def main(): 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument( 136 | "--nnscaler_checkpoint_path", 137 | type=str, 138 | required=True, 139 | help="Path to the fairseq checkpoint (.pt file). For tensor parallel checkpoints, " 140 | "provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), " 141 | "and the script will automatically detect and merge all parts.", 142 | ) 143 | parser.add_argument( 144 | "--pytorch_dump_folder_path", 145 | type=str, 146 | required=True, 147 | help="Path to the output PyTorch model directory", 148 | ) 149 | parser.add_argument( 150 | "--config_path", 151 | type=str, 152 | default=None, 153 | help="Optional path to a config JSON file to override extracted config", 154 | ) 155 | 156 | args = parser.parse_args() 157 | 158 | convert_vibevoice_nnscaler_checkpoint_to_hf( 159 | args.nnscaler_checkpoint_path, 160 | args.pytorch_dump_folder_path, 161 | args.config_path, 162 | ) 163 | 164 | 165 | if __name__ == "__main__": 166 | main() -------------------------------------------------------------------------------- /vibevoice/modular/modular_vibevoice_text_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenization classes for vibevoice.""" 2 | 3 | from typing import List, Optional, Union 4 | 5 | from transformers.utils import logging 6 | from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer 7 | from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | 12 | class VibeVoiceTextTokenizer(Qwen2Tokenizer): 13 | """ 14 | Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech. 15 | 16 | Args: 17 | vocab_file (`str`): 18 | Path to the vocabulary file. 19 | merges_file (`str`): 20 | Path to the merges file. 21 | errors (`str`, *optional*, defaults to `"replace"`): 22 | Paradigm to follow when decoding bytes to UTF-8. 23 | unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 24 | The unknown token. 25 | bos_token (`str`, *optional*): 26 | The beginning of sequence token. Not used for vibevoice. 27 | eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 28 | The end of sequence token. 29 | pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 30 | The token used for padding. 31 | add_special_tokens (`bool`, *optional*, defaults to `True`): 32 | Whether or not to add special tokens when encoding. 
33 | """ 34 | 35 | model_input_names = ["input_ids", "attention_mask"] 36 | 37 | def __init__( 38 | self, 39 | vocab_file, 40 | merges_file, 41 | errors="replace", 42 | unk_token="<|endoftext|>", 43 | bos_token=None, 44 | eos_token="<|endoftext|>", 45 | pad_token="<|endoftext|>", 46 | add_prefix_space=False, 47 | add_special_tokens=True, 48 | **kwargs, 49 | ): 50 | super().__init__( 51 | vocab_file=vocab_file, 52 | merges_file=merges_file, 53 | errors=errors, 54 | unk_token=unk_token, 55 | bos_token=bos_token, 56 | eos_token=eos_token, 57 | pad_token=pad_token, 58 | add_prefix_space=add_prefix_space, 59 | add_special_tokens=add_special_tokens, 60 | **kwargs, 61 | ) 62 | 63 | # Add VibeVoice-specific special tokens 64 | self._add_vibevoice_special_tokens() 65 | 66 | def _add_vibevoice_special_tokens(self): 67 | """Add VibeVoice-specific special tokens.""" 68 | special_tokens = { 69 | "additional_special_tokens": [ 70 | "<|vision_start|>", # Speech start (reusing vision tokens) 71 | "<|vision_end|>", # Speech end 72 | "<|vision_pad|>", # Speech diffusion pad 73 | ] 74 | } 75 | num_added = self.add_special_tokens(special_tokens) 76 | 77 | # Cache special token IDs 78 | self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>") 79 | self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>") 80 | self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>") 81 | 82 | self._eos_id = self.convert_tokens_to_ids('<|endoftext|>') 83 | 84 | return num_added 85 | 86 | @property 87 | def eos_id(self) -> int: 88 | """Id of the end of sequence token.""" 89 | return self._eos_id 90 | 91 | @property 92 | def speech_start_id(self) -> int: 93 | """Id of the speech start token.""" 94 | return self._speech_start_id 95 | 96 | @property 97 | def speech_end_id(self) -> int: 98 | """Id of the speech end token.""" 99 | return self._speech_end_id 100 | 101 | @property 102 | def speech_diffusion_id(self) -> int: 103 | """Id of the speech diffusion token.""" 104 | return self._speech_diffusion_id 105 | 106 | @property 107 | def pad_id(self) -> int: 108 | """Id used for padding (returns -100 for loss masking).""" 109 | return -100 110 | 111 | 112 | class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast): 113 | """ 114 | Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library). 115 | Based on the Qwen2 tokenizer with additional special tokens for speech. 116 | 117 | Args: 118 | vocab_file (`str`, *optional*): 119 | Path to the vocabulary file. 120 | merges_file (`str`, *optional*): 121 | Path to the merges file. 122 | tokenizer_file (`str`, *optional*): 123 | Path to [tokenizers](https://github.com/huggingface/tokenizers) file. 124 | unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 125 | The unknown token. 126 | bos_token (`str`, *optional*): 127 | The beginning of sequence token. Not used for vibevoice. 128 | eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 129 | The end of sequence token. 130 | pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 131 | The token used for padding. 
132 | """ 133 | 134 | model_input_names = ["input_ids", "attention_mask"] 135 | 136 | def __init__( 137 | self, 138 | vocab_file=None, 139 | merges_file=None, 140 | tokenizer_file=None, 141 | unk_token="<|endoftext|>", 142 | bos_token=None, 143 | eos_token="<|endoftext|>", 144 | pad_token="<|endoftext|>", 145 | add_prefix_space=False, 146 | **kwargs, 147 | ): 148 | super().__init__( 149 | vocab_file=vocab_file, 150 | merges_file=merges_file, 151 | tokenizer_file=tokenizer_file, 152 | unk_token=unk_token, 153 | bos_token=bos_token, 154 | eos_token=eos_token, 155 | pad_token=pad_token, 156 | add_prefix_space=add_prefix_space, 157 | **kwargs, 158 | ) 159 | 160 | # Add VibeVoice-specific special tokens 161 | self._add_vibevoice_special_tokens() 162 | 163 | def _add_vibevoice_special_tokens(self): 164 | """Add VibeVoice-specific special tokens.""" 165 | special_tokens = { 166 | "additional_special_tokens": [ 167 | "<|vision_start|>", # Speech start (reusing vision tokens) 168 | "<|vision_end|>", # Speech end 169 | "<|vision_pad|>", # Speech diffusion pad 170 | ] 171 | } 172 | num_added = self.add_special_tokens(special_tokens) 173 | 174 | # Cache special token IDs 175 | self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>") 176 | self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>") 177 | self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>") 178 | 179 | # self._eos_id = self.convert_tokens_to_ids('<|endoftext|>') 180 | self._eos_id = self.eos_token_id # qwen2 / qwen3 181 | self._pad_id = self.convert_tokens_to_ids('<|image_pad|>') 182 | 183 | return num_added 184 | 185 | @property 186 | def eos_id(self) -> int: 187 | """Id of the end of sequence token.""" 188 | return self._eos_id 189 | 190 | @property 191 | def speech_start_id(self) -> int: 192 | """Id of the speech start token.""" 193 | return self._speech_start_id 194 | 195 | @property 196 | def speech_end_id(self) -> int: 197 | """Id of the speech end token.""" 198 | return self._speech_end_id 199 | 200 | @property 201 | def speech_diffusion_id(self) -> int: 202 | """Id of the speech diffusion token.""" 203 | return self._speech_diffusion_id 204 | 205 | @property 206 | def pad_id(self) -> int: 207 | """Id used for padding (returns -100 for loss masking).""" 208 | return self._pad_id 209 | 210 | 211 | __all__ = [ 212 | "VibeVoiceTextTokenizer", 213 | "VibeVoiceTextTokenizerFast", 214 | ] -------------------------------------------------------------------------------- /vibevoice/modular/modular_vibevoice_diffusion_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from transformers.models.auto import AutoModel 9 | from transformers.modeling_utils import PreTrainedModel 10 | # from transformers.modeling_layers import GradientCheckpointingLayer 11 | from transformers.activations import ACT2FN 12 | from transformers.utils import logging 13 | 14 | from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig 15 | 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | 20 | class RMSNorm(nn.Module): 21 | def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False): 22 | super().__init__() 23 | self.dim = dim 24 | self.eps = eps 25 | self.elementwise_affine = elementwise_affine 26 | if self.elementwise_affine: 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | else: 29 
| self.register_parameter('weight', None) 30 | 31 | def _norm(self, x): 32 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 33 | 34 | def forward(self, x): 35 | output = self._norm(x.float()).type_as(x) 36 | if self.weight is not None: 37 | output = output * self.weight 38 | return output 39 | 40 | def extra_repr(self) -> str: 41 | return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}' 42 | 43 | def modulate(x, shift, scale): 44 | """Apply modulation to input tensor.""" 45 | return x * (1 + scale) + shift 46 | 47 | 48 | class TimestepEmbedder(nn.Module): 49 | """ 50 | Embeds scalar timesteps into vector representations. 51 | 52 | Args: 53 | hidden_size (`int`): Size of the output embedding 54 | frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding 55 | """ 56 | def __init__(self, hidden_size, frequency_embedding_size=256): 57 | super().__init__() 58 | self.mlp = nn.Sequential( 59 | nn.Linear(frequency_embedding_size, hidden_size, bias=False), 60 | # nn.SiLU(), 61 | ACT2FN['silu'], 62 | nn.Linear(hidden_size, hidden_size, bias=False), 63 | ) 64 | self.frequency_embedding_size = frequency_embedding_size 65 | 66 | @staticmethod 67 | def timestep_embedding(t, dim, max_period=10000): 68 | """ 69 | Create sinusoidal timestep embeddings. 70 | 71 | Args: 72 | t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element. 73 | These may be fractional. 74 | dim (`int`): The dimension of the output. 75 | max_period (`int`, optional): Controls the minimum frequency of the embeddings. 76 | 77 | Returns: 78 | `torch.Tensor`: An [N, D] Tensor of positional embeddings. 79 | """ 80 | half = dim // 2 81 | freqs = torch.exp( 82 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 83 | ).to(t.device) 84 | args = t[:, None].float() * freqs[None] 85 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 86 | if dim % 2: 87 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 88 | return embedding.to(t.dtype) 89 | 90 | def forward(self, t): 91 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 92 | t_emb = self.mlp(t_freq) 93 | return t_emb 94 | 95 | 96 | class FeedForwardNetwork(nn.Module): 97 | """ 98 | Standard feed-forward network with SwiGLU activation. 99 | 100 | Args: 101 | embed_dim (`int`): Input dimension 102 | ffn_dim (`int`): Hidden dimension 103 | """ 104 | def __init__( 105 | self, 106 | embed_dim, 107 | ffn_dim, 108 | ): 109 | super().__init__() 110 | self.embed_dim = embed_dim 111 | self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False) 112 | self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False) 113 | self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False) 114 | self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function 115 | 116 | def forward(self, x): 117 | gate = self.gate_proj(x) 118 | up = self.up_proj(x) 119 | 120 | # SwiGLU activation 121 | # gate = F.silu(gate) 122 | gate = self.act_fn(gate) 123 | return self.down_proj(gate * up) 124 | 125 | 126 | class HeadLayer(nn.Module): 127 | """ 128 | A layer in the diffusion head. 
129 | 130 | Args: 131 | embed_dim (`int`): Input dimension 132 | ffn_dim (`int`): Hidden dimension 133 | cond_dim (`int`): Condition embedding dimension 134 | norm_eps (`float`, optional): Epsilon for normalization 135 | """ 136 | def __init__( 137 | self, 138 | embed_dim, 139 | ffn_dim, 140 | cond_dim, 141 | norm_eps=1e-5, 142 | ): 143 | super().__init__() 144 | self.embed_dim = embed_dim 145 | self.cond_dim = cond_dim 146 | self.ffn_dim = ffn_dim 147 | self.ffn = FeedForwardNetwork( 148 | self.embed_dim, 149 | self.ffn_dim, 150 | ) 151 | self.norm = RMSNorm(self.embed_dim, eps=norm_eps) 152 | self.adaLN_modulation = nn.Sequential( 153 | # nn.SiLU(), 154 | ACT2FN['silu'], 155 | nn.Linear(cond_dim, 3 * self.embed_dim, bias=False) 156 | ) 157 | 158 | def forward(self, x, c): 159 | shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1) 160 | x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn)) 161 | return x 162 | 163 | 164 | class FinalLayer(nn.Module): 165 | """ 166 | Final layer in the diffusion head. 167 | 168 | Args: 169 | hidden_size (`int`): Input dimension 170 | output_size (`int`): Output dimension 171 | cond_size (`int`): Condition embedding dimension 172 | norm_eps (`float`, optional): Epsilon for normalization 173 | """ 174 | def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5): 175 | super().__init__() 176 | self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False) 177 | self.linear = nn.Linear(hidden_size, output_size, bias=False) 178 | self.adaLN_modulation = nn.Sequential( 179 | # nn.SiLU(), 180 | ACT2FN['silu'], 181 | nn.Linear(cond_size, 2 * hidden_size, bias=False) 182 | ) 183 | 184 | def forward(self, x, c): 185 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 186 | x = modulate(self.norm_final(x), shift, scale) 187 | x = self.linear(x) 188 | return x 189 | 190 | 191 | class VibeVoiceDiffusionHead(PreTrainedModel): 192 | """ 193 | Diffusion head model for vibevoice. 194 | 195 | Args: 196 | config (`VibeVoiceDiffusionHeadConfig`): Model configuration 197 | latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`. 
198 | """ 199 | config_class = VibeVoiceDiffusionHeadConfig 200 | supports_gradient_checkpointing = True 201 | _supports_flash_attn_2 = True 202 | _supports_sdpa = True 203 | 204 | def __init__( 205 | self, 206 | config, 207 | ): 208 | super().__init__(config) 209 | self.config = config 210 | self.cond_dim = config.hidden_size 211 | latent_size = config.latent_size 212 | 213 | self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False) 214 | self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False) 215 | self.t_embedder = TimestepEmbedder(self.cond_dim) 216 | 217 | ffn_dim = int(config.hidden_size * config.head_ffn_ratio) 218 | 219 | # Create the intermediate layers 220 | self.layers = nn.ModuleList([ 221 | HeadLayer( 222 | embed_dim=config.hidden_size, 223 | ffn_dim=ffn_dim, 224 | cond_dim=self.cond_dim, 225 | norm_eps=config.rms_norm_eps 226 | ) 227 | for _ in range(config.head_layers) 228 | ]) 229 | 230 | # Final layer for output 231 | self.final_layer = FinalLayer( 232 | hidden_size=config.hidden_size, 233 | output_size=latent_size, 234 | cond_size=self.cond_dim, 235 | norm_eps=config.rms_norm_eps 236 | ) 237 | 238 | self.initialize_weights() 239 | 240 | def initialize_weights(self): 241 | """Initialize the weights of the model.""" 242 | # Initialize timestep embedder 243 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 244 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 245 | 246 | # Zero-out adaLN modulation layers 247 | for layer in self.layers: 248 | nn.init.constant_(layer.adaLN_modulation[-1].weight, 0) 249 | 250 | # Zero-out output layers 251 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 252 | nn.init.constant_(self.final_layer.linear.weight, 0) 253 | 254 | def forward( 255 | self, 256 | noisy_images, 257 | timesteps, 258 | condition, 259 | ): 260 | """ 261 | Forward pass of the prediction head. 262 | 263 | Args: 264 | noisy_images (`torch.Tensor`): Noisy images/latents to denoise 265 | timesteps (`torch.Tensor`): Timesteps for diffusion 266 | condition (`torch.Tensor`): Conditioning information 267 | 268 | Returns: 269 | `torch.Tensor`: The predicted noise/velocity 270 | """ 271 | x = self.noisy_images_proj(noisy_images) 272 | t = self.t_embedder(timesteps) 273 | condition = self.cond_proj(condition) 274 | c = condition + t 275 | 276 | for layer in self.layers: 277 | x = layer(x, c) 278 | 279 | x = self.final_layer(x, c) 280 | return x 281 | 282 | 283 | AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead) 284 | 285 | __all__ = [ 286 | "VibeVoiceDiffusionHead", 287 | ] -------------------------------------------------------------------------------- /vibevoice/modular/streamer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | import asyncio 6 | from queue import Queue 7 | from typing import TYPE_CHECKING, Optional 8 | 9 | 10 | from transformers.generation import BaseStreamer 11 | 12 | 13 | class AudioStreamer(BaseStreamer): 14 | """ 15 | Audio streamer that stores audio chunks in queues for each sample in the batch. 16 | This allows streaming audio generation for multiple samples simultaneously. 17 | 18 | Parameters: 19 | batch_size (`int`): 20 | The batch size for generation 21 | stop_signal (`any`, *optional*): 22 | The signal to put in the queue when generation ends. Defaults to None. 23 | timeout (`float`, *optional*): 24 | The timeout for the audio queue. 
If `None`, the queue will block indefinitely. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | batch_size: int, 30 | stop_signal: Optional[any] = None, 31 | timeout: Optional[float] = None, 32 | ): 33 | self.batch_size = batch_size 34 | self.stop_signal = stop_signal 35 | self.timeout = timeout 36 | 37 | # Create a queue for each sample in the batch 38 | self.audio_queues = [Queue() for _ in range(batch_size)] 39 | self.finished_flags = [False for _ in range(batch_size)] 40 | self.sample_indices_map = {} # Maps from sample index to queue index 41 | 42 | def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor): 43 | """ 44 | Receives audio chunks and puts them in the appropriate queues. 45 | 46 | Args: 47 | audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks 48 | sample_indices: Tensor indicating which samples these chunks belong to 49 | """ 50 | for i, sample_idx in enumerate(sample_indices): 51 | idx = sample_idx.item() 52 | if idx < self.batch_size and not self.finished_flags[idx]: 53 | # Convert to numpy or keep as tensor based on preference 54 | audio_chunk = audio_chunks[i].detach().cpu() 55 | self.audio_queues[idx].put(audio_chunk, timeout=self.timeout) 56 | 57 | def end(self, sample_indices: Optional[torch.Tensor] = None): 58 | """ 59 | Signals the end of generation for specified samples or all samples. 60 | 61 | Args: 62 | sample_indices: Optional tensor of sample indices to end. If None, ends all. 63 | """ 64 | if sample_indices is None: 65 | # End all samples 66 | for idx in range(self.batch_size): 67 | if not self.finished_flags[idx]: 68 | self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout) 69 | self.finished_flags[idx] = True 70 | else: 71 | # End specific samples 72 | for sample_idx in sample_indices: 73 | idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx 74 | if idx < self.batch_size and not self.finished_flags[idx]: 75 | self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout) 76 | self.finished_flags[idx] = True 77 | 78 | def __iter__(self): 79 | """Returns an iterator over the batch of audio streams.""" 80 | return AudioBatchIterator(self) 81 | 82 | def get_stream(self, sample_idx: int): 83 | """Get the audio stream for a specific sample.""" 84 | if sample_idx >= self.batch_size: 85 | raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}") 86 | return AudioSampleIterator(self, sample_idx) 87 | 88 | 89 | class AudioSampleIterator: 90 | """Iterator for a single audio stream from the batch.""" 91 | 92 | def __init__(self, streamer: AudioStreamer, sample_idx: int): 93 | self.streamer = streamer 94 | self.sample_idx = sample_idx 95 | 96 | def __iter__(self): 97 | return self 98 | 99 | def __next__(self): 100 | value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout) 101 | if value == self.streamer.stop_signal: 102 | raise StopIteration() 103 | return value 104 | 105 | 106 | class AudioBatchIterator: 107 | """Iterator that yields audio chunks for all samples in the batch.""" 108 | 109 | def __init__(self, streamer: AudioStreamer): 110 | self.streamer = streamer 111 | self.active_samples = set(range(streamer.batch_size)) 112 | 113 | def __iter__(self): 114 | return self 115 | 116 | def __next__(self): 117 | if not self.active_samples: 118 | raise StopIteration() 119 | 120 | batch_chunks = {} 121 | samples_to_remove = set() 122 | 123 | # Try to get chunks from all active samples 124 | for idx in self.active_samples: 125 | try: 126 | value = 
self.streamer.audio_queues[idx].get(block=False) 127 | if value == self.streamer.stop_signal: 128 | samples_to_remove.add(idx) 129 | else: 130 | batch_chunks[idx] = value 131 | except: 132 | # Queue is empty for this sample, skip it this iteration 133 | pass 134 | 135 | # Remove finished samples 136 | self.active_samples -= samples_to_remove 137 | 138 | if batch_chunks: 139 | return batch_chunks 140 | elif self.active_samples: 141 | # If no chunks were ready but we still have active samples, 142 | # wait a bit and try again 143 | import time 144 | time.sleep(0.01) 145 | return self.__next__() 146 | else: 147 | raise StopIteration() 148 | 149 | 150 | class AsyncAudioStreamer(AudioStreamer): 151 | """ 152 | Async version of AudioStreamer for use in async contexts. 153 | """ 154 | 155 | def __init__( 156 | self, 157 | batch_size: int, 158 | stop_signal: Optional[any] = None, 159 | timeout: Optional[float] = None, 160 | ): 161 | super().__init__(batch_size, stop_signal, timeout) 162 | # Replace regular queues with async queues 163 | self.audio_queues = [asyncio.Queue() for _ in range(batch_size)] 164 | self.loop = asyncio.get_running_loop() 165 | 166 | def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor): 167 | """Put audio chunks in the appropriate async queues.""" 168 | for i, sample_idx in enumerate(sample_indices): 169 | idx = sample_idx.item() 170 | if idx < self.batch_size and not self.finished_flags[idx]: 171 | audio_chunk = audio_chunks[i].detach().cpu() 172 | self.loop.call_soon_threadsafe( 173 | self.audio_queues[idx].put_nowait, audio_chunk 174 | ) 175 | 176 | def end(self, sample_indices: Optional[torch.Tensor] = None): 177 | """Signal the end of generation for specified samples.""" 178 | if sample_indices is None: 179 | indices_to_end = range(self.batch_size) 180 | else: 181 | indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices] 182 | 183 | for idx in indices_to_end: 184 | if idx < self.batch_size and not self.finished_flags[idx]: 185 | self.loop.call_soon_threadsafe( 186 | self.audio_queues[idx].put_nowait, self.stop_signal 187 | ) 188 | self.finished_flags[idx] = True 189 | 190 | async def get_stream(self, sample_idx: int): 191 | """Get async iterator for a specific sample's audio stream.""" 192 | if sample_idx >= self.batch_size: 193 | raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}") 194 | 195 | while True: 196 | value = await self.audio_queues[sample_idx].get() 197 | if value == self.stop_signal: 198 | break 199 | yield value 200 | 201 | def __aiter__(self): 202 | """Returns an async iterator over all audio streams.""" 203 | return AsyncAudioBatchIterator(self) 204 | 205 | 206 | class AsyncAudioBatchIterator: 207 | """Async iterator for batch audio streaming.""" 208 | 209 | def __init__(self, streamer: AsyncAudioStreamer): 210 | self.streamer = streamer 211 | self.active_samples = set(range(streamer.batch_size)) 212 | 213 | def __aiter__(self): 214 | return self 215 | 216 | async def __anext__(self): 217 | if not self.active_samples: 218 | raise StopAsyncIteration() 219 | 220 | batch_chunks = {} 221 | samples_to_remove = set() 222 | 223 | # Create tasks for all active samples 224 | tasks = { 225 | idx: asyncio.create_task(self._get_chunk(idx)) 226 | for idx in self.active_samples 227 | } 228 | 229 | # Wait for at least one chunk to be ready 230 | done, pending = await asyncio.wait( 231 | tasks.values(), 232 | return_when=asyncio.FIRST_COMPLETED, 233 | timeout=self.streamer.timeout 234 | ) 235 | 
236 | # Cancel pending tasks 237 | for task in pending: 238 | task.cancel() 239 | 240 | # Process completed tasks 241 | for idx, task in tasks.items(): 242 | if task in done: 243 | try: 244 | value = await task 245 | if value == self.streamer.stop_signal: 246 | samples_to_remove.add(idx) 247 | else: 248 | batch_chunks[idx] = value 249 | except asyncio.CancelledError: 250 | pass 251 | 252 | self.active_samples -= samples_to_remove 253 | 254 | if batch_chunks: 255 | return batch_chunks 256 | elif self.active_samples: 257 | # Try again if we still have active samples 258 | return await self.__anext__() 259 | else: 260 | raise StopAsyncIteration() 261 | 262 | async def _get_chunk(self, idx): 263 | """Helper to get a chunk from a specific queue.""" 264 | return await self.streamer.audio_queues[idx].get() -------------------------------------------------------------------------------- /example_workflows/VibeVoice_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "b91265e5-1b03-4b63-8dc3-4abd9a030e08", 3 | "revision": 0, 4 | "last_node_id": 14, 5 | "last_link_id": 44, 6 | "nodes": [ 7 | { 8 | "id": 3, 9 | "type": "SaveAudio", 10 | "pos": [ 11 | -1040, 12 | -1130 13 | ], 14 | "size": [ 15 | 270, 16 | 112 17 | ], 18 | "flags": {}, 19 | "order": 6, 20 | "mode": 0, 21 | "inputs": [ 22 | { 23 | "name": "audio", 24 | "type": "AUDIO", 25 | "link": 27 26 | } 27 | ], 28 | "outputs": [], 29 | "properties": { 30 | "Node name for S&R": "SaveAudio", 31 | "cnr_id": "comfy-core", 32 | "ver": "0.3.52", 33 | "ue_properties": { 34 | "widget_ue_connectable": { 35 | "filename_prefix": true, 36 | "audioUI": true 37 | }, 38 | "version": "7.0.1" 39 | } 40 | }, 41 | "widgets_values": [ 42 | "audio/VibeVoice" 43 | ] 44 | }, 45 | { 46 | "id": 13, 47 | "type": "MarkdownNote", 48 | "pos": [ 49 | -1898.1748046875, 50 | -1409.22314453125 51 | ], 52 | "size": [ 53 | 1035.619873046875, 54 | 211.96694946289062 55 | ], 56 | "flags": {}, 57 | "order": 0, 58 | "mode": 0, 59 | "inputs": [], 60 | "outputs": [], 61 | "title": "Note", 62 | "properties": {}, 63 | "widgets_values": [ 64 | "# ComfyUI-VibeVoice\n\nVibeVoice is a novel framework by Microsoft for generating expressive, long-form, multi-speaker conversational audio. It excels at creating natural-sounding dialogue, podcasts, and more, with consistent voices for up to 4 speakers.\n\n**✨ Key Features:**\n* **Multi-Speaker TTS:** Generate conversations with up to 4 distinct voices in a single audio output.\n* **High-Fidelity Voice Cloning:** Use any audio file (`.wav`, `.mp3`) as a reference for a speaker's voice.\n* **Hybrid Generation Mode:** Mix and match cloned voices with high-quality, zero-shot generated voices in the same script.\n* **Flexible Scripting:** Use simple `[1]` tags or the classic `Speaker 1:` format to write your dialogue.\n* **Advanced Attention Mechanisms:** Choose between `eager`, `sdpa`, `flash_attention_2`, and the high-performance `sage` attention for fine-tuned control over speed and compatibility.\n* **Robust 4-Bit Quantization:** Run the large language model component in 4-bit mode to significantly reduce VRAM usage.\n* **Automatic Model Management:** Models are downloaded automatically and managed efficiently by ComfyUI to save VRAM." 
65 | ], 66 | "color": "#233", 67 | "bgcolor": "#355" 68 | }, 69 | { 70 | "id": 4, 71 | "type": "LoadAudio", 72 | "pos": [ 73 | -1900, 74 | -1130 75 | ], 76 | "size": [ 77 | 272.9800720214844, 78 | 136 79 | ], 80 | "flags": {}, 81 | "order": 1, 82 | "mode": 0, 83 | "inputs": [], 84 | "outputs": [ 85 | { 86 | "name": "AUDIO", 87 | "type": "AUDIO", 88 | "links": [] 89 | } 90 | ], 91 | "properties": { 92 | "Node name for S&R": "LoadAudio", 93 | "cnr_id": "comfy-core", 94 | "ver": "0.3.52", 95 | "ue_properties": { 96 | "widget_ue_connectable": { 97 | "audio": true, 98 | "audioUI": true, 99 | "upload": true 100 | }, 101 | "version": "7.0.1" 102 | } 103 | }, 104 | "widgets_values": [ 105 | "male_rickmorty.mp3", 106 | null, 107 | null 108 | ] 109 | }, 110 | { 111 | "id": 8, 112 | "type": "LoadAudio", 113 | "pos": [ 114 | -1901.10009765625, 115 | -948.7998046875 116 | ], 117 | "size": [ 118 | 274.080078125, 119 | 136 120 | ], 121 | "flags": {}, 122 | "order": 2, 123 | "mode": 0, 124 | "inputs": [], 125 | "outputs": [ 126 | { 127 | "name": "AUDIO", 128 | "type": "AUDIO", 129 | "links": [] 130 | } 131 | ], 132 | "properties": { 133 | "Node name for S&R": "LoadAudio", 134 | "cnr_id": "comfy-core", 135 | "ver": "0.3.52", 136 | "ue_properties": { 137 | "widget_ue_connectable": { 138 | "audio": true, 139 | "audioUI": true, 140 | "upload": true 141 | }, 142 | "version": "7.0.1" 143 | } 144 | }, 145 | "widgets_values": [ 146 | "male_stewie.mp3", 147 | null, 148 | null 149 | ] 150 | }, 151 | { 152 | "id": 12, 153 | "type": "MarkdownNote", 154 | "pos": [ 155 | -1915.701904296875, 156 | -762.380126953125 157 | ], 158 | "size": [ 159 | 312.85455322265625, 160 | 292.8734130859375 161 | ], 162 | "flags": {}, 163 | "order": 3, 164 | "mode": 0, 165 | "inputs": [], 166 | "outputs": [], 167 | "title": "Note", 168 | "properties": {}, 169 | "widgets_values": [ 170 | "### Scripting and Voice Modes\n\n#### Speaker Tagging\nYou can assign lines to speakers in two ways. Both are treated identically.\n\n* **Modern Format (Recommended):** `[1] This is the first speaker.`\n* **Classic Format:** `Speaker 1: This is the first speaker.`\n\nYou can also add an optional colon to the modern format (e.g., `[1]: ...`). The node handles all variations consistently.\n\n#### Hybrid Voice Generation\nThis is a powerful feature that lets you mix cloned voices and generated (zero-shot) voices.\n\n* **To Clone a Voice:** Connect a `Load Audio` node to the speaker's input (e.g., `speaker_1_voice`).\n* **To Generate a Voice:** Leave the speaker's input empty. The model will create a unique, high-quality voice for that speaker." 
171 | ], 172 | "color": "#233", 173 | "bgcolor": "#355" 174 | }, 175 | { 176 | "id": 14, 177 | "type": "MarkdownNote", 178 | "pos": [ 179 | -1048.3660888671875, 180 | -960.8771362304688 181 | ], 182 | "size": [ 183 | 280.797607421875, 184 | 487.02728271484375 185 | ], 186 | "flags": {}, 187 | "order": 4, 188 | "mode": 0, 189 | "inputs": [], 190 | "outputs": [], 191 | "title": "Note", 192 | "properties": {}, 193 | "widgets_values": [ 194 | "## Models\n\nWill be downloaded on the first run, or download them manually and place them into the directory: /models/tts/VibeVoice\n\n| Model | Context Length | Generation Length | Weight |\n|-------|----------------|----------|----------|\n| VibeVoice-1.5B | 64K | ~90 min | [HF link](https://huggingface.co/microsoft/VibeVoice-1.5B) |\n| VibeVoice-Large| 32K | ~45 min | [HF link](https://huggingface.co/microsoft/VibeVoice-Large) |\n\n## Support \n\n- Don't know how to update PyTorch?\n- Need help with ComfyUI?\n- Need technical support?\n\n### Or do you just have questions? Then join the [@TokenDiffusion Hub](https://t.me/TokenDiff_hub) group\n\n### AI news [TokenDiffusion](https://t.me/TokenDiff)" 195 | ], 196 | "color": "#233", 197 | "bgcolor": "#355" 198 | }, 199 | { 200 | "id": 11, 201 | "type": "VibeVoiceTTS", 202 | "pos": [ 203 | -1570, 204 | -1130 205 | ], 206 | "size": [ 207 | 475.3999938964844, 208 | 662.9000244140625 209 | ], 210 | "flags": {}, 211 | "order": 5, 212 | "mode": 0, 213 | "inputs": [ 214 | { 215 | "name": "speaker_1_voice", 216 | "shape": 7, 217 | "type": "AUDIO", 218 | "link": null 219 | }, 220 | { 221 | "name": "speaker_2_voice", 222 | "shape": 7, 223 | "type": "AUDIO", 224 | "link": null 225 | }, 226 | { 227 | "name": "speaker_3_voice", 228 | "shape": 7, 229 | "type": "AUDIO", 230 | "link": null 231 | }, 232 | { 233 | "name": "speaker_4_voice", 234 | "shape": 7, 235 | "type": "AUDIO", 236 | "link": null 237 | } 238 | ], 239 | "outputs": [ 240 | { 241 | "name": "AUDIO", 242 | "type": "AUDIO", 243 | "links": [ 244 | 27 245 | ] 246 | } 247 | ], 248 | "properties": { 249 | "Node name for S&R": "VibeVoiceTTS", 250 | "cnr_id": "ComfyUI-VibeVoice", 251 | "ver": "37803a884fb8f9b43c38286f6d654c7f97181a73", 252 | "ue_properties": { 253 | "widget_ue_connectable": { 254 | "model_name": true, 255 | "text": true, 256 | "quantize_llm_4bit": true, 257 | "attention_mode": true, 258 | "cfg_scale": true, 259 | "inference_steps": true, 260 | "seed": true, 261 | "do_sample": true, 262 | "temperature": true, 263 | "top_p": true, 264 | "top_k": true 265 | }, 266 | "version": "7.0.1" 267 | } 268 | }, 269 | "widgets_values": [ 270 | "VibeVoice-1.5B", 271 | "[1] I can't believe you did it again. I waited for two hours. Two hours! Not a single call, not a text. Do you have any idea how embarrassing that was, just sitting there alone?\n[2] Look, I know, I'm sorry, alright? Work was a complete nightmare. My boss dropped a critical deadline on me at the last minute. 
I didn't even have a second to breathe, let alone check my phone.\n", 272 | false, 273 | "flash_attention_2", 274 | 1.3, 275 | 10, 276 | 471935335072093, 277 | "fixed", 278 | true, 279 | 0.95, 280 | 0.95, 281 | 0, 282 | false 283 | ], 284 | "color": "#232", 285 | "bgcolor": "#353" 286 | } 287 | ], 288 | "links": [ 289 | [ 290 | 27, 291 | 11, 292 | 0, 293 | 3, 294 | 0, 295 | "AUDIO" 296 | ] 297 | ], 298 | "groups": [], 299 | "config": {}, 300 | "extra": { 301 | "ds": { 302 | "scale": 0.8264462809917354, 303 | "offset": [ 304 | 2015.701904296875, 305 | 1509.22314453125 306 | ] 307 | }, 308 | "ue_links": [], 309 | "links_added_by_ue": [], 310 | "frontendVersion": "1.26.11", 311 | "VHS_latentpreview": false, 312 | "VHS_latentpreviewrate": 0, 313 | "VHS_MetadataImage": true, 314 | "VHS_KeepIntermediate": true 315 | }, 316 | "version": 0.4 317 | } -------------------------------------------------------------------------------- /vibevoice/modular/configuration_vibevoice.py: -------------------------------------------------------------------------------- 1 | """ VibeVoice_AcousticTokenizer model configuration""" 2 | 3 | from typing import Dict, List, Optional, Tuple 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.utils import logging 7 | 8 | from transformers.models.qwen2.configuration_qwen2 import Qwen2Config 9 | 10 | logger = logging.get_logger(__name__) 11 | 12 | 13 | class VibeVoiceAcousticTokenizerConfig(PretrainedConfig): 14 | model_type = "vibevoice_acoustic_tokenizer" 15 | 16 | def __init__( 17 | self, 18 | channels: int = 1, 19 | corpus_normalize: float = 0.0, 20 | causal: bool = True, 21 | vae_dim: int = 64, 22 | fix_std: float = 0.5, 23 | std_dist_type: str = 'gaussian', 24 | # common 25 | mixer_layer: str = 'depthwise_conv', 26 | conv_norm: str = 'none', 27 | pad_mode: str = 'constant', 28 | disable_last_norm: bool = True, 29 | layernorm: str = 'RMSNorm', 30 | layernorm_eps: float = 1e-5, 31 | layernorm_elementwise_affine: bool = True, 32 | conv_bias: bool = True, 33 | layer_scale_init_value: float = 1e-6, 34 | weight_init_value: float = 1e-2, 35 | # encoder specific 36 | encoder_n_filters: int = 32, 37 | encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2], 38 | encoder_depths: str = "3-3-3-3-3-3-8", 39 | # decoder specific 40 | decoder_n_filters: int = 32, 41 | decoder_ratios: Optional[List[int]] = None, # if None, same as encoder 42 | decoder_depths: Optional[str] = None, 43 | **kwargs 44 | ): 45 | super().__init__(**kwargs) 46 | self.channels = channels 47 | self.corpus_normalize = corpus_normalize 48 | self.causal = causal 49 | self.vae_dim = vae_dim 50 | self.fix_std = fix_std 51 | self.std_dist_type = std_dist_type 52 | 53 | # common parameters 54 | self.conv_norm = conv_norm 55 | self.pad_mode = pad_mode 56 | self.layernorm_eps = layernorm_eps 57 | self.disable_last_norm = disable_last_norm 58 | self.layernorm = layernorm 59 | self.layernorm_elementwise_affine = layernorm_elementwise_affine 60 | self.conv_bias = conv_bias 61 | self.layer_scale_init_value = layer_scale_init_value 62 | self.weight_init_value = weight_init_value 63 | self.mixer_layer = mixer_layer 64 | 65 | # encoder specific parameters 66 | self.encoder_n_filters = encoder_n_filters 67 | self.encoder_ratios = encoder_ratios 68 | self.encoder_depths = encoder_depths 69 | 70 | # decoder specific parameters 71 | self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios 72 | self.decoder_n_filters = decoder_n_filters 73 | self.decoder_depths = 
decoder_depths 74 | 75 | 76 | class VibeVoiceSemanticTokenizerConfig(PretrainedConfig): 77 | model_type = "vibevoice_semantic_tokenizer" 78 | 79 | def __init__( 80 | self, 81 | channels: int = 1, 82 | corpus_normalize: float = 0.0, 83 | causal: bool = True, 84 | vae_dim: int = 64, 85 | fix_std: float = 0, 86 | std_dist_type: str = 'none', 87 | # common 88 | mixer_layer: str = 'depthwise_conv', 89 | conv_norm: str = 'none', 90 | pad_mode: str = 'constant', 91 | disable_last_norm: bool = True, 92 | layernorm: str = 'RMSNorm', 93 | layernorm_eps: float = 1e-5, 94 | layernorm_elementwise_affine: bool = True, 95 | conv_bias: bool = True, 96 | layer_scale_init_value: float = 1e-6, 97 | weight_init_value: float = 1e-2, 98 | # encoder specific 99 | encoder_n_filters: int = 32, 100 | encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2], 101 | encoder_depths: str = "3-3-3-3-3-3-8", 102 | **kwargs 103 | ): 104 | super().__init__(**kwargs) 105 | self.channels = channels 106 | self.corpus_normalize = corpus_normalize 107 | self.causal = causal 108 | self.vae_dim = vae_dim 109 | self.fix_std = fix_std 110 | self.std_dist_type = std_dist_type 111 | 112 | # common parameters 113 | self.conv_norm = conv_norm 114 | self.pad_mode = pad_mode 115 | self.layernorm_eps = layernorm_eps 116 | self.disable_last_norm = disable_last_norm 117 | self.layernorm = layernorm 118 | self.layernorm_elementwise_affine = layernorm_elementwise_affine 119 | self.conv_bias = conv_bias 120 | self.layer_scale_init_value = layer_scale_init_value 121 | self.weight_init_value = weight_init_value 122 | self.mixer_layer = mixer_layer 123 | 124 | # encoder specific parameters 125 | self.encoder_n_filters = encoder_n_filters 126 | self.encoder_ratios = encoder_ratios 127 | self.encoder_depths = encoder_depths 128 | 129 | 130 | class VibeVoiceDiffusionHeadConfig(PretrainedConfig): 131 | model_type = "vibevoice_diffusion_head" 132 | 133 | def __init__( 134 | self, 135 | hidden_size=768, 136 | head_layers=4, 137 | head_ffn_ratio=3.0, 138 | rms_norm_eps=1e-5, 139 | latent_size=64, 140 | speech_vae_dim=None, 141 | prediction_type="v_prediction", 142 | diffusion_type="ddpm", 143 | ddpm_num_steps=1000, 144 | ddpm_num_inference_steps=20, 145 | ddpm_beta_schedule="cosine", 146 | ddpm_batch_mul=4, 147 | **kwargs 148 | ): 149 | self.hidden_size = hidden_size 150 | self.head_layers = head_layers 151 | self.head_ffn_ratio = head_ffn_ratio 152 | self.rms_norm_eps = rms_norm_eps 153 | self.latent_size = latent_size 154 | self.speech_vae_dim = speech_vae_dim 155 | self.prediction_type = prediction_type 156 | self.diffusion_type = diffusion_type 157 | self.ddpm_num_steps = ddpm_num_steps 158 | self.ddpm_num_inference_steps = ddpm_num_inference_steps 159 | self.ddpm_beta_schedule = ddpm_beta_schedule 160 | self.ddpm_batch_mul = ddpm_batch_mul 161 | 162 | super().__init__(**kwargs) 163 | 164 | class VibeVoiceConfig(PretrainedConfig): 165 | model_type = "vibevoice" 166 | is_composition = True 167 | sub_configs = { 168 | "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig, 169 | "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig, 170 | "decoder_config": Qwen2Config, 171 | "diffusion_head_config": VibeVoiceDiffusionHeadConfig, 172 | } 173 | # keys_to_ignore_at_inference = ["past_key_values"] 174 | # Default tensor parallel plan for base model `Qwen2` 175 | base_model_tp_plan = { 176 | "layers.*.self_attn.q_proj": "colwise", 177 | "layers.*.self_attn.k_proj": "colwise", 178 | "layers.*.self_attn.v_proj": "colwise", 179 | 
"layers.*.self_attn.o_proj": "rowwise", 180 | "layers.*.mlp.gate_proj": "colwise", 181 | "layers.*.mlp.up_proj": "colwise", 182 | "layers.*.mlp.down_proj": "rowwise", 183 | } 184 | 185 | def __init__( 186 | self, 187 | acoustic_tokenizer_config=None, 188 | semantic_tokenizer_config=None, 189 | decoder_config=None, 190 | diffusion_head_config=None, 191 | **kwargs 192 | ): 193 | 194 | # kwargs["_attn_implementation"] = "flash_attention_2" 195 | kwargs["_attn_implementation_autoset"] = False 196 | 197 | if acoustic_tokenizer_config is None: 198 | self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]() 199 | elif isinstance(acoustic_tokenizer_config, dict): 200 | acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer" 201 | self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config) 202 | elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig): 203 | # If an instance of the config class is provided 204 | self.acoustic_tokenizer_config = acoustic_tokenizer_config 205 | 206 | if semantic_tokenizer_config is None: 207 | self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]() 208 | elif isinstance(semantic_tokenizer_config, dict): 209 | semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer" 210 | self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config) 211 | elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig): 212 | # If an instance of the config class is provided 213 | self.semantic_tokenizer_config = semantic_tokenizer_config 214 | 215 | if decoder_config is None: 216 | self.decoder_config = self.sub_configs["decoder_config"]() 217 | elif isinstance(decoder_config, dict): 218 | # If a dictionary is provided, instantiate the config class with it 219 | # self.decoder_config = self.sub_configs["decoder_config"](**decoder_config) 220 | if decoder_config.get("model_type", '') == "qwen2": 221 | self.decoder_config = Qwen2Config(**decoder_config) 222 | else: 223 | raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}") 224 | elif isinstance(decoder_config, (Qwen2Config,)): 225 | # If an instance of the config class is provided 226 | self.decoder_config = decoder_config 227 | 228 | if diffusion_head_config is None: 229 | self.diffusion_head_config = self.sub_configs["diffusion_head_config"]() 230 | elif isinstance(diffusion_head_config, dict): 231 | diffusion_head_config["model_type"] = "vibevoice_diffusion_head" 232 | self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config) 233 | elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig): 234 | # If an instance of the config class is provided 235 | self.diffusion_head_config = diffusion_head_config 236 | 237 | # other parameters 238 | self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64) 239 | self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128) 240 | 241 | self.num_hidden_layers = self.decoder_config.num_hidden_layers 242 | super().__init__(**kwargs) 243 | 244 | __all__ = [ 245 | "VibeVoiceAcousticTokenizerConfig", 246 | "VibeVoiceSemanticTokenizerConfig", 247 | "VibeVoiceDiffusionHeadConfig", 248 | "VibeVoiceConfig" 249 | ] -------------------------------------------------------------------------------- /vibevoice_nodes.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | import logging 4 | 5 | import comfy.model_management as model_management 6 | from comfy.utils import ProgressBar 7 | 8 | # Import from the dedicated model_info module 9 | from .modules.model_info import AVAILABLE_VIBEVOICE_MODELS 10 | from .modules.loader import VibeVoiceModelHandler, ATTENTION_MODES, VIBEVOICE_PATCHER_CACHE, cleanup_old_models 11 | from .modules.patcher import VibeVoicePatcher 12 | from .modules.utils import parse_script_1_based, preprocess_comfy_audio, set_vibevoice_seed, check_for_interrupt 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | class VibeVoiceTTSNode: 17 | @classmethod 18 | def INPUT_TYPES(cls): 19 | model_names = list(AVAILABLE_VIBEVOICE_MODELS.keys()) 20 | if not model_names: 21 | model_names.append("No models found in models/tts/VibeVoice") 22 | 23 | return { 24 | "required": { 25 | "model_name": (model_names, { 26 | "tooltip": "Select the VibeVoice model to use. Official models will be downloaded automatically." 27 | }), 28 | "text": ("STRING", { 29 | "multiline": True, 30 | "default": "[1] Hello, this is a cloned voice.\n[2] And this is a generated voice, how cool is that?", 31 | "tooltip": "The script for generation. Use '[1]' or 'Speaker 1:' for speakers. If a speaker in the script lacks a reference voice, it will be generated via zero-shot TTS." 32 | }), 33 | "quantize_llm_4bit": ("BOOLEAN", { 34 | "default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision", 35 | "tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32." 36 | }), 37 | "attention_mode": (ATTENTION_MODES, { 38 | "default": "sdpa", 39 | "tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest), Sage (quantized)" 40 | }), 41 | "cfg_scale": ("FLOAT", { 42 | "default": 1.3, "min": 0.1, "max": 50.0, "step": 0.05, 43 | "tooltip": "Classifier-Free Guidance scale. Higher values increase adherence to the voice prompt but may reduce naturalness. Recommended: 1.3" 44 | }), 45 | "inference_steps": ("INT", { 46 | "default": 10, "min": 1, "max": 500, 47 | "tooltip": "Number of diffusion steps for audio generation. More steps can improve quality but take longer. Recommended: 10" 48 | }), 49 | "seed": ("INT", { 50 | "default": 42, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "control_after_generate": True, 51 | "tooltip": "Seed for reproducibility. Set to 0 for a random seed on each run." 52 | }), 53 | "do_sample": ("BOOLEAN", { 54 | "default": True, "label_on": "Enabled (Sampling)", "label_off": "Disabled (Greedy)", 55 | "tooltip": "Enable to use sampling methods (like temperature and top_p) for more varied output. Disable for deterministic (greedy) decoding." 56 | }), 57 | "temperature": ("FLOAT", { 58 | "default": 0.95, "min": 0.0, "max": 2.0, "step": 0.01, 59 | "tooltip": "Controls randomness. Higher values make the output more random and creative, while lower values make it more focused and deterministic. Active only if 'do_sample' is enabled." 60 | }), 61 | "top_p": ("FLOAT", { 62 | "default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01, 63 | "tooltip": "Nucleus sampling (Top-P). The model samples from the smallest set of tokens whose cumulative probability exceeds this value. Active only if 'do_sample' is enabled." 64 | }), 65 | "top_k": ("INT", { 66 | "default": 0, "min": 0, "max": 500, "step": 1, 67 | "tooltip": "Top-K sampling. Restricts sampling to the K most likely next tokens. 
Set to 0 to disable. Active only if 'do_sample' is enabled." 68 | }), 69 | "force_offload": ("BOOLEAN", { 70 | "default": False, "label_on": "Force Offload", "label_off": "Keep in VRAM", 71 | "tooltip": "Force model to be offloaded from VRAM after generation. Useful to free up memory between generations but may slow down subsequent runs." 72 | }), 73 | }, 74 | "optional": { 75 | "speaker_1_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 1' or '[1]' in the script."}), 76 | "speaker_2_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 2' or '[2]' in the script."}), 77 | "speaker_3_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 3' or '[3]' in the script."}), 78 | "speaker_4_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 4' or '[4]' in the script."}), 79 | } 80 | } 81 | 82 | RETURN_TYPES = ("AUDIO",) 83 | FUNCTION = "generate_audio" 84 | CATEGORY = "audio/tts" 85 | 86 | def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, force_offload, **kwargs): 87 | actual_attention_mode = attention_mode 88 | if quantize_llm_4bit and attention_mode in ["eager", "flash_attention_2"]: 89 | actual_attention_mode = "sdpa" 90 | 91 | cache_key = f"{model_name}_attn_{actual_attention_mode}_q4_{int(quantize_llm_4bit)}" 92 | 93 | if cache_key not in VIBEVOICE_PATCHER_CACHE: 94 | cleanup_old_models(keep_cache_key=cache_key) 95 | 96 | model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit) 97 | patcher = VibeVoicePatcher( 98 | model_handler, 99 | attention_mode=attention_mode, 100 | load_device=model_management.get_torch_device(), 101 | offload_device=model_management.unet_offload_device(), 102 | size=model_handler.size 103 | ) 104 | VIBEVOICE_PATCHER_CACHE[cache_key] = patcher 105 | 106 | patcher = VIBEVOICE_PATCHER_CACHE[cache_key] 107 | model_management.load_model_gpu(patcher) 108 | model = patcher.model.model 109 | processor = patcher.model.processor 110 | 111 | if model is None or processor is None: 112 | raise RuntimeError("VibeVoice model and processor could not be loaded. Check logs for errors.") 113 | 114 | parsed_lines_0_based, speaker_ids_1_based = parse_script_1_based(text) 115 | if not parsed_lines_0_based: 116 | raise ValueError("Script is empty or invalid. Please provide text to generate.") 117 | 118 | # full_script = "\n".join([f"Speaker {spk+1}: {txt}" for spk, txt in parsed_lines_0_based]) # <-- REMOVED: This was the cause of the bug. 
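# Note: the processor call below consumes the parsed (speaker, text) pairs together with
# the per-speaker reference audio (parsed_scripts / voice_samples / speaker_ids_for_prompt),
# so no re-joined "Speaker N: ..." script string is needed at this point.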
119 | 120 | speaker_inputs = {i: kwargs.get(f"speaker_{i}_voice") for i in range(1, 5)} 121 | voice_samples_np = [preprocess_comfy_audio(speaker_inputs.get(sid)) for sid in speaker_ids_1_based] 122 | 123 | set_vibevoice_seed(seed) 124 | 125 | try: 126 | inputs = processor( 127 | parsed_scripts=[parsed_lines_0_based], 128 | voice_samples=[voice_samples_np], 129 | speaker_ids_for_prompt=[speaker_ids_1_based], 130 | padding=True, 131 | return_tensors="pt", 132 | return_attention_mask=True 133 | ) 134 | 135 | for key, value in inputs.items(): 136 | if isinstance(value, torch.Tensor): 137 | if torch.any(torch.isnan(value)) or torch.any(torch.isinf(value)): 138 | logger.error(f"Input tensor '{key}' contains NaN or Inf values") 139 | raise ValueError(f"Invalid values in input tensor: {key}") 140 | 141 | inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} 142 | 143 | model.set_ddpm_inference_steps(num_steps=inference_steps) 144 | 145 | generation_config = {'do_sample': do_sample} 146 | if do_sample: 147 | generation_config['temperature'] = temperature 148 | generation_config['top_p'] = top_p 149 | if top_k > 0: 150 | generation_config['top_k'] = top_k 151 | 152 | with torch.no_grad(): 153 | pbar = ProgressBar(inference_steps) 154 | 155 | def progress_callback(step, total_steps): 156 | pbar.update(1) 157 | if model_management.interrupt_current_processing: 158 | raise model_management.InterruptProcessingException() 159 | 160 | try: 161 | outputs = model.generate( 162 | **inputs, max_new_tokens=None, cfg_scale=cfg_scale, 163 | tokenizer=processor.tokenizer, generation_config=generation_config, 164 | verbose=False, stop_check_fn=check_for_interrupt 165 | ) 166 | pbar.update(inference_steps - pbar.current) 167 | 168 | except RuntimeError as e: 169 | error_msg = str(e).lower() 170 | if "assertion" in error_msg or "cuda" in error_msg: 171 | logger.error(f"CUDA assertion failed with {attention_mode} attention: {e}") 172 | logger.error("This might be due to invalid input data, GPU memory issues, or incompatible attention mode.") 173 | logger.error("Try restarting ComfyUI, using different audio files, or switching to 'eager' attention mode.") 174 | raise e 175 | except model_management.InterruptProcessingException: 176 | logger.info("VibeVoice generation interrupted by user") 177 | raise 178 | finally: 179 | pbar.update_absolute(inference_steps) 180 | 181 | except model_management.InterruptProcessingException: 182 | logger.info("VibeVoice TTS generation was cancelled") 183 | return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},) 184 | 185 | except Exception as e: 186 | logger.error(f"Error during VibeVoice generation with {attention_mode} attention: {e}") 187 | if "interrupt" in str(e).lower() or "cancel" in str(e).lower(): 188 | logger.info("Generation was interrupted") 189 | return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},) 190 | raise 191 | 192 | output_waveform = outputs.speech_outputs[0] 193 | if output_waveform.ndim == 1: output_waveform = output_waveform.unsqueeze(0) 194 | if output_waveform.ndim == 2: output_waveform = output_waveform.unsqueeze(0) 195 | 196 | if force_offload: 197 | logger.info(f"Force offloading VibeVoice model '{model_name}' from VRAM...") 198 | if patcher.is_loaded: 199 | patcher.unpatch_model(unpatch_weights=True) 200 | model_management.unload_all_models() 201 | gc.collect() 202 | model_management.soft_empty_cache() 203 | logger.info("Model force 
offload completed") 204 | 205 | return ({"waveform": output_waveform.detach().cpu(), "sample_rate": 24000},) 206 | 207 | NODE_CLASS_MAPPINGS = {"VibeVoiceTTS": VibeVoiceTTSNode} 208 | NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"} -------------------------------------------------------------------------------- /modules/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import gc 4 | import json 5 | import logging 6 | from huggingface_hub import hf_hub_download, snapshot_download 7 | 8 | import comfy.utils 9 | import folder_paths 10 | import comfy.model_management as model_management 11 | 12 | import transformers 13 | from packaging import version 14 | 15 | _transformers_version = version.parse(transformers.__version__) 16 | _DTYPE_ARG_SUPPORTED = _transformers_version >= version.parse("4.56.0") 17 | 18 | from transformers import BitsAndBytesConfig 19 | from ..vibevoice.modular.configuration_vibevoice import VibeVoiceConfig 20 | from ..vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference 21 | from ..vibevoice.processor.vibevoice_processor import VibeVoiceProcessor 22 | from ..vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor 23 | from ..vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast 24 | 25 | from .model_info import AVAILABLE_VIBEVOICE_MODELS, MODEL_CONFIGS 26 | from .. import SAGE_ATTENTION_AVAILABLE 27 | if SAGE_ATTENTION_AVAILABLE: 28 | from ..vibevoice.modular.sage_attention_patch import set_sage_attention 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | LOADED_MODELS = {} 33 | VIBEVOICE_PATCHER_CACHE = {} 34 | 35 | ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"] 36 | if SAGE_ATTENTION_AVAILABLE: 37 | ATTENTION_MODES.append("sage") 38 | 39 | def cleanup_old_models(keep_cache_key=None): 40 | global LOADED_MODELS, VIBEVOICE_PATCHER_CACHE 41 | keys_to_remove = [] 42 | for key in list(LOADED_MODELS.keys()): 43 | if key != keep_cache_key: 44 | keys_to_remove.append(key) 45 | del LOADED_MODELS[key] 46 | for key in list(VIBEVOICE_PATCHER_CACHE.keys()): 47 | if key != keep_cache_key: 48 | try: 49 | patcher = VIBEVOICE_PATCHER_CACHE[key] 50 | if hasattr(patcher, 'model') and patcher.model: 51 | patcher.model.model = None 52 | patcher.model.processor = None 53 | del VIBEVOICE_PATCHER_CACHE[key] 54 | except Exception as e: 55 | logger.warning(f"Error cleaning up patcher {key}: {e}") 56 | if keys_to_remove: 57 | logger.info(f"Cleaned up cached models: {keys_to_remove}") 58 | gc.collect() 59 | model_management.soft_empty_cache() 60 | 61 | 62 | class VibeVoiceModelHandler(torch.nn.Module): 63 | def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False): 64 | super().__init__() 65 | self.model_pack_name = model_pack_name 66 | self.attention_mode = attention_mode 67 | self.use_llm_4bit = use_llm_4bit 68 | self.cache_key = f"{self.model_pack_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}" 69 | self.model = None 70 | self.processor = None 71 | info = AVAILABLE_VIBEVOICE_MODELS.get(model_pack_name, {}) 72 | size_gb = MODEL_CONFIGS.get(model_pack_name, {}).get("size_gb", 4.0) 73 | self.size = int(size_gb * (1024**3)) 74 | def load_model(self, device, attention_mode="eager"): 75 | self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit) 76 | if self.model.device != device: 77 | 
self.model.to(device) 78 | 79 | class VibeVoiceLoader: 80 | @staticmethod 81 | def _check_gpu_for_sage_attention(): 82 | if not SAGE_ATTENTION_AVAILABLE: return False 83 | if not torch.cuda.is_available(): return False 84 | major, _ = torch.cuda.get_device_capability() 85 | if major < 8: 86 | logger.warning(f"Your GPU (compute capability {major}.x) does not support SageAttention, which requires CC 8.0+. Sage option will be disabled.") 87 | return False 88 | return True 89 | 90 | @staticmethod 91 | def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False): 92 | if model_name not in AVAILABLE_VIBEVOICE_MODELS: 93 | raise ValueError(f"Unknown VibeVoice model: {model_name}. Available models: {list(AVAILABLE_VIBEVOICE_MODELS.keys())}") 94 | 95 | if use_llm_4bit and attention_mode in ["eager", "flash_attention_2"]: 96 | logger.warning(f"Attention mode '{attention_mode}' is not recommended with 4-bit quantization. Falling back to 'sdpa' for stability and performance.") 97 | attention_mode = "sdpa" 98 | if attention_mode not in ATTENTION_MODES: 99 | logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager") 100 | attention_mode = "eager" 101 | 102 | cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}" 103 | if cache_key in LOADED_MODELS: 104 | logger.info(f"Using cached model with {attention_mode} attention and q4={use_llm_4bit}") 105 | return LOADED_MODELS[cache_key] 106 | 107 | model_info = AVAILABLE_VIBEVOICE_MODELS[model_name] 108 | model_type = model_info["type"] 109 | vibevoice_base_path = os.path.join(folder_paths.get_folder_paths("tts")[0], "VibeVoice") 110 | 111 | model_path_or_none = None 112 | config_path = None 113 | preprocessor_config_path = None 114 | tokenizer_dir = None 115 | 116 | if model_type == "official": 117 | model_path_or_none = os.path.join(vibevoice_base_path, model_name) 118 | if not os.path.exists(os.path.join(model_path_or_none, "model.safetensors.index.json")): 119 | logger.info(f"Downloading official VibeVoice model: {model_name}...") 120 | snapshot_download(repo_id=model_info["repo_id"], local_dir=model_path_or_none, local_dir_use_symlinks=False) 121 | config_path = os.path.join(model_path_or_none, "config.json") 122 | preprocessor_config_path = os.path.join(model_path_or_none, "preprocessor_config.json") 123 | tokenizer_dir = model_path_or_none 124 | elif model_type == "local_dir": 125 | model_path_or_none = model_info["path"] 126 | config_path = os.path.join(model_path_or_none, "config.json") 127 | preprocessor_config_path = os.path.join(model_path_or_none, "preprocessor_config.json") 128 | tokenizer_dir = model_path_or_none 129 | elif model_type == "standalone": 130 | model_path_or_none = None # IMPORTANT: This must be None when loading from state_dict 131 | config_path = os.path.splitext(model_info["path"])[0] + ".config.json" 132 | preprocessor_config_path = os.path.splitext(model_info["path"])[0] + ".preprocessor.json" 133 | tokenizer_dir = os.path.dirname(model_info["path"]) 134 | 135 | if os.path.exists(config_path): 136 | config = VibeVoiceConfig.from_pretrained(config_path) 137 | else: 138 | fallback_name = "default_VibeVoice-Large_config.json" if "large" in model_name.lower() else "default_VibeVoice-1.5B_config.json" 139 | fallback_path = os.path.join(os.path.dirname(__file__), "..", "vibevoice", "configs", fallback_name) 140 | logger.warning(f"Config not found for '{model_name}'. 
Using fallback: {fallback_name}") 141 | config = VibeVoiceConfig.from_pretrained(fallback_path) 142 | 143 | # Processor & Tokenizer setup 144 | tokenizer_file_path = os.path.join(tokenizer_dir, "tokenizer.json") 145 | 146 | if not os.path.exists(tokenizer_file_path): 147 | logger.info(f"'tokenizer.json' not found in model directory: {tokenizer_dir}") 148 | 149 | packaged_configs_dir = os.path.join(os.path.dirname(__file__), "..", "vibevoice", "configs") 150 | packaged_tokenizer_path = os.path.join(packaged_configs_dir, "tokenizer.json") 151 | 152 | if os.path.exists(packaged_tokenizer_path): 153 | try: 154 | import shutil 155 | logger.info("Found pre-packaged tokenizer. Copying it to model directory...") 156 | shutil.copyfile(packaged_tokenizer_path, tokenizer_file_path) 157 | except Exception as e: 158 | logger.warning(f"Failed to copy pre-packaged tokenizer: {e}. Will attempt to download.") 159 | 160 | if not os.path.exists(tokenizer_file_path): 161 | repos_to_try = ["Qwen/Qwen2.5-1.5B", "Qwen/Qwen2.5-7B"] 162 | download_successful = False 163 | last_error = None 164 | 165 | for repo_id in repos_to_try: 166 | logger.info(f"Attempting to download 'tokenizer.json' from Hugging Face repo '{repo_id}'...") 167 | try: 168 | hf_hub_download( 169 | repo_id=repo_id, 170 | filename="tokenizer.json", 171 | local_dir=tokenizer_dir 172 | ) 173 | download_successful = True 174 | logger.info("Download successful.") 175 | break # Exit the loop on success 176 | except Exception as e: 177 | logger.warning(f"Failed to download from '{repo_id}': {e}") 178 | last_error = e 179 | 180 | # Final Failure 181 | if not download_successful: 182 | error_message = ( 183 | f"FATAL: Could not get 'tokenizer.json'. All download attempts failed.\n" 184 | f"Last error: {last_error}\n\n" 185 | f"ACTION REQUIRED:\n" 186 | f"1. Manually download 'tokenizer.json' from https://huggingface.co/{repos_to_try[0]}/blob/main/tokenizer.json\n" 187 | f"2. 
Place the downloaded file in the following directory:\n '{tokenizer_dir}'" 188 | ) 189 | raise RuntimeError(error_message) 190 | 191 | vibevoice_tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file_path) 192 | 193 | processor_config_data = {} 194 | if os.path.exists(preprocessor_config_path): 195 | with open(preprocessor_config_path, 'r', encoding='utf-8') as f: processor_config_data = json.load(f) 196 | 197 | audio_processor = VibeVoiceTokenizerProcessor() 198 | processor = VibeVoiceProcessor(tokenizer=vibevoice_tokenizer, audio_processor=audio_processor, speech_tok_compress_ratio=processor_config_data.get("speech_tok_compress_ratio", 3200), db_normalize=processor_config_data.get("db_normalize", True)) 199 | 200 | # Model Loading Prep 201 | if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): model_dtype = torch.bfloat16 202 | else: model_dtype = torch.float16 203 | quant_config = None 204 | final_load_dtype = model_dtype 205 | 206 | if use_llm_4bit: 207 | bnb_compute_dtype = model_dtype 208 | if attention_mode == 'sage': bnb_compute_dtype, final_load_dtype = torch.float32, torch.float32 209 | quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=bnb_compute_dtype) 210 | 211 | attn_implementation_for_load = "sdpa" if attention_mode == "sage" else attention_mode 212 | 213 | try: 214 | logger.info(f"Loading model '{model_name}' with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'") 215 | 216 | # UNIFIED MODEL LOADING LOGIC 217 | from_pretrained_kwargs = { 218 | "config": config, 219 | "attn_implementation": attn_implementation_for_load, 220 | "device_map": "auto" if quant_config else device, 221 | "quantization_config": quant_config, 222 | } 223 | if _DTYPE_ARG_SUPPORTED: 224 | from_pretrained_kwargs['dtype'] = final_load_dtype 225 | else: 226 | from_pretrained_kwargs['torch_dtype'] = final_load_dtype 227 | 228 | if model_type == "standalone": 229 | logger.info(f"Loading standalone model state_dict directly to device: {device}") 230 | # loading the state dict directly to the target device 231 | state_dict = comfy.utils.load_torch_file(model_info["path"], device=device) 232 | from_pretrained_kwargs["state_dict"] = state_dict 233 | 234 | model = VibeVoiceForConditionalGenerationInference.from_pretrained(model_path_or_none, **from_pretrained_kwargs) 235 | 236 | if attention_mode == "sage": 237 | if VibeVoiceLoader._check_gpu_for_sage_attention(): 238 | set_sage_attention(model) 239 | else: 240 | raise RuntimeError("Incompatible hardware/setup for SageAttention.") 241 | 242 | model.eval() 243 | setattr(model, "_llm_4bit", bool(quant_config)) 244 | LOADED_MODELS[cache_key] = (model, processor) 245 | logger.info(f"Successfully configured model '{model_name}' with {attention_mode} attention") 246 | return model, processor 247 | 248 | except Exception as e: 249 | # It's not ideal to automatically reload the model. Let the user decide what to do in case of an error. 
250 | logger.error(f"Failed to load model '{model_name}' with {attention_mode} attention: {e}") 251 | # if attention_mode in ["sage", "flash_attention_2"]: return VibeVoiceLoader.load_model(model_name, device, "sdpa", use_llm_4bit) 252 | # elif attention_mode == "sdpa": return VibeVoiceLoader.load_model(model_name, device, "eager", use_llm_4bit) 253 | # else: 254 | raise RuntimeError(f"Failed to load model even with eager attention: {e}") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 |
5 | ComfyUI-VibeVoice
6 | 7 | # ComfyUI-VibeVoice Nodes 8 | 9 |
10 | A custom node for ComfyUI that integrates Microsoft's VibeVoice, a frontier model for generating expressive, long-form, multi-speaker conversational audio. 11 |
12 |
13 | Report Bug 14 | · 15 | Request Feature 16 | 17 | 18 | [![Stargazers][stars-shield]][stars-url] 19 | [![Issues][issues-shield]][issues-url] 20 | [![Contributors][contributors-shield]][contributors-url] 21 | [![Forks][forks-shield]][forks-url] 22 |

23 |
24 | 25 | 26 | 27 | ## About The Project 28 | 29 | VibeVoice is a novel framework by Microsoft for generating expressive, long-form, multi-speaker conversational audio. It excels at creating natural-sounding dialogue, podcasts, and more, with consistent voices for up to 4 speakers. 30 | 31 |
32 | ComfyUI-VibeVoice example workflow 33 |
34 | 35 | The custom node handles everything from model downloading and memory management to audio processing, allowing you to generate high-quality speech directly from a text script and reference audio files. 36 | 37 | **✨ Key Features:** 38 | * **Multi-Speaker TTS:** Generate conversations with up to 4 distinct voices in a single audio output. 39 | * **High-Fidelity Voice Cloning:** Use any audio file (`.wav`, `.mp3`) as a reference for a speaker's voice. 40 | * **Hybrid Generation Mode:** Mix and match cloned voices with high-quality, zero-shot generated voices in the same script. 41 | * **Flexible Scripting:** Use simple `[1]` tags or the classic `Speaker 1:` format to write your dialogue. 42 | * **Advanced Attention Mechanisms:** Choose between `eager`, `sdpa`, `flash_attention_2`, and the high-performance `sage` attention for fine-tuned control over speed and compatibility. 43 | * **Robust 4-Bit Quantization:** Run the large language model component in 4-bit mode to significantly reduce VRAM usage. 44 | * **Automatic Model Management:** Models are downloaded automatically and managed efficiently by ComfyUI to save VRAM. 45 | 46 |
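To make the "Flexible Scripting" feature above concrete, the sketch below shows one way lines in either tag style can be normalized into `(speaker, text)` pairs. It is an illustration only: the node's actual parser is `parse_script_1_based` in `modules/utils.py` (not shown here), and its exact rules may differ.

```python
import re

# Illustrative sketch only; the node's real parser is parse_script_1_based in
# modules/utils.py. Accepts "[1] text", "[1]: text", and "Speaker 1: text".
LINE_RE = re.compile(r"^\s*(?:\[(\d+)\]:?|Speaker\s+(\d+)\s*:)\s*(.*)$", re.IGNORECASE)

def parse_script(text: str):
    pairs = []
    for raw in text.splitlines():
        m = LINE_RE.match(raw)
        if m and m.group(3).strip():
            speaker = int(m.group(1) or m.group(2))  # 1-based speaker id
            pairs.append((speaker, m.group(3).strip()))
    return pairs

print(parse_script("[1] Hello there.\nSpeaker 2: Hi, nice to meet you."))
# -> [(1, 'Hello there.'), (2, 'Hi, nice to meet you.')]
```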

(back to top)

47 | 48 | 49 | ## 🚀 Getting Started 50 | 51 | The easiest way to install is through the **ComfyUI Manager:** 52 | 1. Go to `Manager` -> `Install Custom Nodes`. 53 | 2. Search for `ComfyUI-VibeVoice` and click "Install". 54 | 3. Restart ComfyUI. 55 | 56 | Alternatively, to install manually: 57 | 58 | 1. **Clone the Repository:** 59 | Navigate to your `ComfyUI/custom_nodes/` directory and clone this repository: 60 | ```sh 61 | git clone https://github.com/wildminder/ComfyUI-VibeVoice.git 62 | ``` 63 | 64 | 2. **Install Dependencies:** 65 | Open a terminal or command prompt, navigate into the cloned directory, and install the required Python packages. **For quantization support, you must install `bitsandbytes`**. 66 | ```sh 67 | cd ComfyUI-VibeVoice 68 | pip install -r requirements.txt 69 | ``` 70 | 71 | 3. **Optional: Install SageAttention** 72 | To enable the `sage` attention mode, you must install the `sageattention` library. For Windows users, a pre-compiled wheel is available at [AI-windows-whl](https://github.com/wildminder/AI-windows-whl). 73 | > **Note:** This is only required if you intend to use the `sage` attention mode. 74 | 75 | 4. **Start/Restart ComfyUI:** 76 | Launch ComfyUI. The "VibeVoice TTS" node will appear under the `audio/tts` category. The first time you use the node, it will automatically download the selected model to your `ComfyUI/models/tts/VibeVoice/` folder. 77 | 78 | ## Models 79 | | Model | Context Length | Generation Length | Weight | 80 | |-------|----------------|----------|----------| 81 | | VibeVoice-1.5B | 64K | ~90 min | [HF link](https://huggingface.co/microsoft/VibeVoice-1.5B) | 82 | | VibeVoice-Large| 32K | ~45 min | [HF link](https://huggingface.co/aoi-ot/VibeVoice-Large) | 83 | 84 |
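If you would rather fetch the weights yourself and place them under `ComfyUI/models/tts/VibeVoice/`, the minimal sketch below mirrors what the node's loader does automatically on first run. The target path assumes a default ComfyUI layout; adjust it to your installation.

```python
# Minimal manual-download sketch, mirroring the loader's automatic
# snapshot_download call. Path assumes a default ComfyUI layout.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="microsoft/VibeVoice-1.5B",  # or "aoi-ot/VibeVoice-Large"
    local_dir="ComfyUI/models/tts/VibeVoice/VibeVoice-1.5B",
    local_dir_use_symlinks=False,
)
```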

(back to top)

85 | 86 | 87 | ## 🛠️ Usage 88 | 89 | The node is designed for maximum flexibility within your ComfyUI workflow. 90 | 91 | 1. **Add Nodes:** Add the `VibeVoice TTS` node to your graph. Use ComfyUI's built-in `Load Audio` node to load your reference voice files. 92 | 2. **Connect Voices (Optional):** Connect the `AUDIO` output from each `Load Audio` node to the corresponding `speaker_*_voice` input. 93 | 3. **Write Your Script:** In the `text` input, write your dialogue using one of the supported formats. 94 | 4. **Generate:** Queue the prompt. The node will process the script and generate a single audio file containing the full conversation. 95 | 96 | > **Tip:** For a complete workflow, you can drag the example image from the `example_workflows` folder onto your ComfyUI canvas. 97 | 98 | ### Scripting and Voice Modes 99 | 100 | #### Speaker Tagging 101 | You can assign lines to speakers in two ways. Both are treated identically. 102 | 103 | * **Modern Format (Recommended):** `[1] This is the first speaker.` 104 | * **Classic Format:** `Speaker 1: This is the first speaker.` 105 | 106 | You can also add an optional colon to the modern format (e.g., `[1]: ...`). The node handles all variations consistently. 107 | 108 | #### Hybrid Voice Generation 109 | This is a powerful feature that lets you mix cloned voices and generated (zero-shot) voices. 110 | 111 | * **To Clone a Voice:** Connect a `Load Audio` node to the speaker's input (e.g., `speaker_1_voice`). 112 | * **To Generate a Voice:** Leave the speaker's input empty. The model will create a unique, high-quality voice for that speaker. 113 | 114 | **Example Hybrid Script:** 115 | ``` 116 | [1] This line will use the audio from speaker_1_voice. 117 | [2] This line will have a new, unique voice generated for it. 118 | [1] I'm back with my cloned voice. 119 | ``` 120 | In this example, you would only connect an audio source to `speaker_1_voice`. 121 | 122 | ### Node Inputs 123 | 124 | * **`model_name`**: Select the VibeVoice model to use (`1.5B` or `Large`). 125 | * **`text`**: The conversational script. See "Scripting and Voice Modes" above for formatting. 126 | * **`quantize_llm_4bit`**: Enable to run the LLM component in 4-bit (NF4) mode, dramatically reducing VRAM usage. 127 | * **`attention_mode`**: Select the attention implementation: `eager` (safest), `sdpa` (balanced), `flash_attention_2` (fastest), or `sage` (quantized high-performance). 128 | * **`cfg_scale`**: Controls how strongly the model adheres to the reference voice's timbre. Higher values are stricter. Recommended: `1.3`. 129 | * **`inference_steps`**: Number of diffusion steps for audio generation. Recommended: `10`. 130 | * **`seed`**: A seed for reproducibility. Set to 0 for a random seed on each run. 131 | * **`do_sample`, `temperature`, `top_p`, `top_k`**: Standard sampling parameters for controlling the creativity and determinism of the speech generation. 132 | * **`force_offload`**: Forces the model to be completely offloaded from VRAM after generation. 133 | 134 | 135 | ## ⚙️ Performance & Advanced Features 136 | 137 | This node features a sophisticated system for managing performance, memory, and stability. 
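As a point of reference for the matrix below, the `quantize_llm_4bit` option corresponds to the standard NF4 recipe from `bitsandbytes`. The sketch that follows is not the node's loader code; it only illustrates, under that assumption, what a 4-bit NF4 load with `bfloat16` compute and `sdpa` attention typically looks like in Hugging Face Transformers. The model repo named here is purely a runnable placeholder.

```python
# Illustrative sketch only, not the node's actual loader. Shows the usual
# bitsandbytes NF4 configuration that the "Quantize LLM ON + sdpa" row refers to.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NF4 weight quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # bf16 compute, matching the "ON + sdpa" row
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B",          # placeholder repo for the LLM component
    quantization_config=bnb_config,
    attn_implementation="sdpa",   # quantized runs fall back to SDPA per the matrix
    device_map="auto",
)
```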
138 | 139 | ### Feature Compatibility & VRAM Matrix 140 | 141 | | Quantize LLM | Attention Mode | Behavior / Notes | Relative VRAM | 142 | | :----------- | :------------------ | :---------------------------------------------------------------------------------------------------------------------------------------------- | :------------ | 143 | | **OFF** | `eager` | Full Precision. Most compatible baseline. | High | 144 | | **OFF** | `sdpa` | Full Precision. Recommended for balanced performance. | High | 145 | | **OFF** | `flash_attention_2` | Full Precision. High performance on compatible GPUs. | High | 146 | | **OFF** | `sage` | Full Precision. Uses high-performance mixed-precision kernels. | High | 147 | | **ON** | `eager` | **Falls back to `sdpa`** with `bfloat16` compute. Warns user. | **Low** | 148 | | **ON** | `sdpa` | **Recommended for memory savings.** Uses `bfloat16` compute. | **Low** | 149 | | **ON** | `flash_attention_2` | **Falls back to `sdpa`** with `bfloat16` compute. Warns user. | **Low** | 150 | | **ON** | `sage` | **Recommended for stability.** Uses `fp32` compute to ensure numerical stability with quantization, resulting in slightly higher VRAM usage. | **Medium** | 151 | 152 | 153 | 154 | ## Changelog 155 | 156 |
157 | v1.5.0 - Stability and Prompting 158 | 159 | ### ✨ New Features & Improvements 160 | * **Total Generation Stability:** Fixed the bug where a speaker's voice could unintentionally change or blend with another reference voice mid-sentence. 161 | * **Improved Voice Cloning Fidelity** 162 | * **Consistent Speaker Tagging:** The node now intelligently handles multiple script formats (`[1]`, `[1]:`, and `Speaker 1:`) to produce identical, high-quality results, removing all previous inconsistencies. 163 | * **Hybrid Voice Generation:** Mix and match cloned voices with high-quality, zero-shot generated voices in the same script. If a speaker's voice input is empty, a unique voice will be generated for them automatically. 164 |
165 | 166 |
167 | v1.3.0 - SageAttention & Quantization Overhaul 168 | 169 | * **SageAttention Support:** Full integration with the `sageattention` library for a high-performance, mixed-precision attention option. 170 | * **Robust 4-Bit LLM Quantization:** The "Quantize LLM (4-bit)" option is now highly stable and delivers significant VRAM savings. 171 | * **Smart Configuration & Fallbacks:** The node now automatically handles incompatible settings (e.g., 4-bit with `flash_attention_2`) by gracefully falling back to a stable alternative (`sdpa`) and notifying the user. 172 |
173 | 174 |
175 | v1.2.0 - Compatibility Update 176 | 177 | * **Transformers Library:** Includes automatic detection and compatibility for both older and newer versions of the Transformers library (pre- and post-4.56). 178 | * **Bug Fixes:** Resolved issues with `Force Offload` and multi-speaker generation on newer Transformers versions. 179 |
180 | 181 |

(back to top)

182 | 183 | ### Tips from the Original Authors 184 | 185 | * **Punctuation:** For Chinese text, using English punctuation (commas and periods) can improve stability. 186 | * **Model Choice:** The 7B model variant (`VibeVoice-Large`) is generally more stable. 187 | * **Spontaneous Sounds/Music:** The model may spontaneously generate background music, especially if the reference audio contains it or if the text includes introductory phrases like "Welcome to...". This is an emergent capability and cannot be directly controlled. 188 | * **Singing:** The model was not trained on singing data, but it may attempt to sing as an emergent behavior. Results may vary. 189 | 190 |

(back to top)

191 | 192 | 193 | ## License 194 | 195 | This project is distributed under the MIT License. See `LICENSE` for more information. The VibeVoice model and its components are subject to the licenses provided by Microsoft. Please use responsibly. 196 | 197 |

(back to top)

198 | 199 | 200 | ## Acknowledgments 201 | 202 | * **Microsoft** for creating and open-sourcing the [VibeVoice](https://github.com/microsoft/VibeVoice) project. 203 | * **The ComfyUI team** for their incredible and extensible platform. 204 | 205 |

(back to top)

206 | 207 | ## Star History 208 | 209 | [![Star History Chart](https://api.star-history.com/svg?repos=wildminder/ComfyUI-VibeVoice&type=Timeline)](https://www.star-history.com/#wildminder/ComfyUI-VibeVoice&Timeline) 210 | 211 | 212 | 213 | [contributors-shield]: https://img.shields.io/github/contributors/wildminder/ComfyUI-VibeVoice.svg?style=for-the-badge 214 | [contributors-url]: https://github.com/wildminder/ComfyUI-VibeVoice/graphs/contributors 215 | [forks-shield]: https://img.shields.io/github/forks/wildminder/ComfyUI-VibeVoice.svg?style=for-the-badge 216 | [forks-url]: https://github.com/wildminder/ComfyUI-VibeVoice/network/members 217 | [stars-shield]: https://img.shields.io/github/stars/wildminder/ComfyUI-VibeVoice.svg?style=for-the-badge 218 | [stars-url]: https://github.com/wildminder/ComfyUI-VibeVoice/stargazers 219 | [issues-shield]: https://img.shields.io/github/issues/wildminder/ComfyUI-VibeVoice.svg?style=for-the-badge 220 | [issues-url]: https://github.com/wildminder/ComfyUI-VibeVoice/issues 221 | -------------------------------------------------------------------------------- /vibevoice/processor/vibevoice_tokenizer_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Processor class for VibeVoice models. 3 | """ 4 | 5 | import os 6 | import json 7 | import warnings 8 | from typing import List, Optional, Union, Dict, Any 9 | 10 | import numpy as np 11 | import torch 12 | 13 | from transformers.feature_extraction_utils import FeatureExtractionMixin 14 | from transformers.utils import logging 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | class AudioNormalizer: 20 | """ 21 | Audio normalization class for VibeVoice tokenizer. 22 | 23 | This class provides audio normalization to ensure consistent input levels 24 | for the VibeVoice tokenizer while maintaining audio quality. 25 | """ 26 | 27 | def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6): 28 | """ 29 | Initialize the audio normalizer. 30 | 31 | Args: 32 | target_dB_FS (float): Target dB FS level for the audio. Default: -25 33 | eps (float): Small value to avoid division by zero. Default: 1e-6 34 | """ 35 | self.target_dB_FS = target_dB_FS 36 | self.eps = eps 37 | 38 | def tailor_dB_FS(self, audio: np.ndarray) -> tuple: 39 | """ 40 | Adjust the audio to the target dB FS level. 41 | 42 | Args: 43 | audio (np.ndarray): Input audio signal 44 | 45 | Returns: 46 | tuple: (normalized_audio, rms, scalar) 47 | """ 48 | rms = np.sqrt(np.mean(audio**2)) 49 | scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps) 50 | normalized_audio = audio * scalar 51 | return normalized_audio, rms, scalar 52 | 53 | def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple: 54 | """ 55 | Avoid clipping by scaling down if necessary. 56 | 57 | Args: 58 | audio (np.ndarray): Input audio signal 59 | scalar (float, optional): Explicit scaling factor 60 | 61 | Returns: 62 | tuple: (normalized_audio, scalar) 63 | """ 64 | if scalar is None: 65 | max_val = np.max(np.abs(audio)) 66 | if max_val > 1.0: 67 | scalar = max_val + self.eps 68 | else: 69 | scalar = 1.0 70 | 71 | return audio / scalar, scalar 72 | 73 | def __call__(self, audio: np.ndarray) -> np.ndarray: 74 | """ 75 | Normalize the audio by adjusting to target dB FS and avoiding clipping. 
76 | 77 | Args: 78 | audio (np.ndarray): Input audio signal 79 | 80 | Returns: 81 | np.ndarray: Normalized audio signal 82 | """ 83 | # First adjust to target dB FS 84 | audio, _, _ = self.tailor_dB_FS(audio) 85 | # Then avoid clipping 86 | audio, _ = self.avoid_clipping(audio) 87 | return audio 88 | 89 | 90 | # Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components 91 | class VibeVoiceTokenizerProcessor(FeatureExtractionMixin): 92 | """ 93 | Processor for VibeVoice acoustic tokenizer models. 94 | 95 | This processor handles audio preprocessing for VibeVoice models, including: 96 | - Audio format conversion (stereo to mono) 97 | - Optional audio normalization 98 | - Streaming support for infinite-length audio 99 | 100 | Args: 101 | sampling_rate (int, optional): Expected sampling rate. Defaults to 24000. 102 | normalize_audio (bool, optional): Whether to normalize audio. Defaults to True. 103 | target_dB_FS (float, optional): Target dB FS for normalization. Defaults to -25. 104 | eps (float, optional): Small value for numerical stability. Defaults to 1e-6. 105 | """ 106 | model_input_names = ["input_features"] 107 | 108 | def __init__( 109 | self, 110 | sampling_rate: int = 24000, 111 | normalize_audio: bool = True, 112 | target_dB_FS: float = -25, 113 | eps: float = 1e-6, 114 | **kwargs, 115 | ): 116 | super().__init__(**kwargs) 117 | 118 | self.sampling_rate = sampling_rate 119 | self.normalize_audio = normalize_audio 120 | 121 | # Initialize audio normalizer if needed 122 | if self.normalize_audio: 123 | self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps) 124 | else: 125 | self.normalizer = None 126 | 127 | # Save config 128 | self.feature_extractor_dict = { 129 | "sampling_rate": sampling_rate, 130 | "normalize_audio": normalize_audio, 131 | "target_dB_FS": target_dB_FS, 132 | "eps": eps, 133 | } 134 | 135 | def _ensure_mono(self, audio: np.ndarray) -> np.ndarray: 136 | """ 137 | Convert stereo audio to mono if needed. 138 | 139 | Args: 140 | audio (np.ndarray): Input audio array 141 | 142 | Returns: 143 | np.ndarray: Mono audio array 144 | """ 145 | if len(audio.shape) == 1: 146 | return audio 147 | elif len(audio.shape) == 2: 148 | if audio.shape[0] == 2: # (2, time) 149 | return np.mean(audio, axis=0) 150 | elif audio.shape[1] == 2: # (time, 2) 151 | return np.mean(audio, axis=1) 152 | else: 153 | # If one dimension is 1, squeeze it 154 | if audio.shape[0] == 1: 155 | return audio.squeeze(0) 156 | elif audio.shape[1] == 1: 157 | return audio.squeeze(1) 158 | else: 159 | raise ValueError(f"Unexpected audio shape: {audio.shape}") 160 | else: 161 | raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}") 162 | 163 | def _process_single_audio(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray: 164 | """ 165 | Process a single audio array. 
166 | 167 | Args: 168 | audio: Single audio input 169 | 170 | Returns: 171 | np.ndarray: Processed audio 172 | """ 173 | # Convert to numpy array 174 | if not isinstance(audio, np.ndarray): 175 | audio = np.array(audio, dtype=np.float32) 176 | else: 177 | audio = audio.astype(np.float32) 178 | 179 | # Ensure mono 180 | audio = self._ensure_mono(audio) 181 | 182 | # Normalize if requested 183 | if self.normalize_audio and self.normalizer is not None: 184 | audio = self.normalizer(audio) 185 | 186 | return audio 187 | 188 | def __call__( 189 | self, 190 | audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]] = None, 191 | sampling_rate: Optional[int] = None, 192 | return_tensors: Optional[str] = None, 193 | **kwargs, 194 | ): 195 | """ 196 | Process audio for VibeVoice models. 197 | 198 | Args: 199 | audio: Audio input(s) to process. Can be: 200 | - str: Path to audio file 201 | - np.ndarray: Audio array 202 | - List[float]: Audio as list of floats 203 | - List[np.ndarray]: Batch of audio arrays 204 | - List[str]: Batch of audio file paths 205 | sampling_rate (int, optional): Sampling rate of the input audio 206 | return_tensors (str, optional): Return format ('pt' for PyTorch, 'np' for NumPy) 207 | 208 | Returns: 209 | dict: Processed audio inputs with keys: 210 | - input_features: Audio tensor(s) ready for the model 211 | """ 212 | if audio is None: 213 | raise ValueError("Audio input is required") 214 | 215 | # Validate sampling rate 216 | if sampling_rate is not None and sampling_rate != self.sampling_rate: 217 | logger.warning( 218 | f"Input sampling rate ({sampling_rate}) differs from expected " 219 | f"sampling rate ({self.sampling_rate}). Please resample your audio." 220 | ) 221 | 222 | # Handle different input types 223 | if isinstance(audio, str): 224 | # Single audio file path 225 | audio = self._load_audio_from_path(audio) 226 | is_batched = False 227 | elif isinstance(audio, list): 228 | if len(audio) == 0: 229 | raise ValueError("Empty audio list provided") 230 | 231 | # Check if it's a list of file paths 232 | if all(isinstance(item, str) for item in audio): 233 | # Batch of audio file paths 234 | audio = [self._load_audio_from_path(path) for path in audio] 235 | is_batched = True 236 | else: 237 | # Check if it's batched audio arrays 238 | is_batched = isinstance(audio[0], (np.ndarray, list)) 239 | else: 240 | # Single audio array or list 241 | is_batched = False 242 | 243 | # Process audio 244 | if is_batched: 245 | processed_audio = [self._process_single_audio(a) for a in audio] 246 | else: 247 | processed_audio = [self._process_single_audio(audio)] 248 | 249 | # Convert to tensors if requested 250 | if return_tensors == "pt": 251 | if len(processed_audio) == 1: 252 | # Create a proper batch dimension (B, T) 253 | input_features = torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1) 254 | else: 255 | # For batched input with different lengths, create a batch properly 256 | input_features = torch.stack([torch.from_numpy(a) for a in processed_audio]).unsqueeze(1) 257 | elif return_tensors == "np": 258 | if len(processed_audio) == 1: 259 | input_features = processed_audio[0][np.newaxis, np.newaxis, :] 260 | else: 261 | input_features = np.stack(processed_audio)[:, np.newaxis, :] 262 | else: 263 | input_features = processed_audio[0] if len(processed_audio) == 1 else processed_audio 264 | 265 | outputs = { 266 | "audio": input_features, # Use "audio" instead of "input_features" 267 | } 268 | 269 | return outputs 270 | 271 | def 
_load_audio_from_path(self, audio_path: str) -> np.ndarray: 272 | """ 273 | Load audio from file path. 274 | 275 | Args: 276 | audio_path (str): Path to audio file 277 | 278 | Returns: 279 | np.ndarray: Loaded audio array 280 | """ 281 | # Get file extension to determine loading method 282 | file_ext = os.path.splitext(audio_path)[1].lower() 283 | 284 | if file_ext in ['.wav', '.mp3', '.flac', '.m4a', '.ogg']: 285 | # Audio file - use librosa 286 | import librosa 287 | audio_array, sr = librosa.load( 288 | audio_path, 289 | sr=self.sampling_rate, 290 | mono=True 291 | ) 292 | return audio_array 293 | elif file_ext == '.pt': 294 | # PyTorch tensor file 295 | audio_tensor = torch.load(audio_path, map_location='cpu').squeeze() 296 | if isinstance(audio_tensor, torch.Tensor): 297 | audio_array = audio_tensor.numpy() 298 | else: 299 | audio_array = np.array(audio_tensor) 300 | return audio_array.astype(np.float32) 301 | elif file_ext == '.npy': 302 | # NumPy file 303 | audio_array = np.load(audio_path) 304 | return audio_array.astype(np.float32) 305 | else: 306 | raise ValueError( 307 | f"Unsupported file format: {file_ext}. " 308 | f"Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz" 309 | ) 310 | 311 | def preprocess_audio( 312 | self, 313 | audio_path_or_array: Union[str, np.ndarray], 314 | normalize: Optional[bool] = None, 315 | ) -> np.ndarray: 316 | """ 317 | Convenience method to preprocess audio from file path or array. 318 | This method is kept for backward compatibility but __call__ is recommended. 319 | 320 | Args: 321 | audio_path_or_array: Path to audio file or numpy array 322 | normalize: Whether to normalize (overrides default setting) 323 | 324 | Returns: 325 | np.ndarray: Preprocessed audio array 326 | """ 327 | if isinstance(audio_path_or_array, str): 328 | audio_array = self._load_audio_from_path(audio_path_or_array) 329 | else: 330 | audio_array = np.array(audio_path_or_array, dtype=np.float32) 331 | 332 | # Override normalization setting if specified 333 | original_normalize = self.normalize_audio 334 | if normalize is not None: 335 | self.normalize_audio = normalize 336 | 337 | try: 338 | processed = self._process_single_audio(audio_array) 339 | finally: 340 | # Restore original setting 341 | self.normalize_audio = original_normalize 342 | 343 | return processed 344 | 345 | # Override to_dict method for configuration saving 346 | def to_dict(self) -> Dict[str, Any]: 347 | """ 348 | Convert the object to a dict containing all attributes needed for serialization. 349 | """ 350 | return self.feature_extractor_dict 351 | 352 | def save_audio( 353 | self, 354 | audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]], 355 | output_path: str = "output.wav", 356 | sampling_rate: Optional[int] = None, 357 | normalize: bool = False, 358 | batch_prefix: str = "audio_", 359 | ): 360 | """ 361 | Save audio data to WAV file(s). 362 | 363 | Args: 364 | audio: Audio data to save. Can be: 365 | - torch.Tensor: PyTorch tensor with shape (B, C, T) or (B, T) or (T) 366 | - np.ndarray: NumPy array with shape (B, C, T) or (B, T) or (T) 367 | - List of tensors or arrays 368 | output_path: Path where to save the audio. If saving multiple files, 369 | this is treated as a directory and individual files will be saved inside. 370 | sampling_rate: Sampling rate for the saved audio. Defaults to the processor's rate. 371 | normalize: Whether to normalize audio before saving. 372 | batch_prefix: Prefix for batch files when saving multiple audios. 
373 | 374 | Returns: 375 | List[str]: Paths to the saved audio files. 376 | """ 377 | if sampling_rate is None: 378 | sampling_rate = self.sampling_rate 379 | 380 | try: 381 | import soundfile as sf 382 | except ImportError: 383 | raise ImportError( 384 | "soundfile is required to save audio files. " 385 | "Install it with: pip install soundfile" 386 | ) 387 | 388 | # Ensure audio is in the right format 389 | if isinstance(audio, torch.Tensor): 390 | # Convert PyTorch tensor to numpy 391 | audio_np = audio.float().detach().cpu().numpy() 392 | elif isinstance(audio, np.ndarray): 393 | audio_np = audio 394 | elif isinstance(audio, list): 395 | # Handle list of tensors or arrays 396 | if all(isinstance(a, torch.Tensor) for a in audio): 397 | audio_np = [a.float().detach().cpu().numpy() for a in audio] 398 | else: 399 | audio_np = audio 400 | else: 401 | raise ValueError(f"Unsupported audio type: {type(audio)}") 402 | 403 | saved_paths = [] 404 | 405 | # Handle based on shape or type 406 | if isinstance(audio_np, list): 407 | # Multiple separate audios to save 408 | output_dir = output_path 409 | 410 | # Ensure output directory exists 411 | os.makedirs(output_dir, exist_ok=True) 412 | 413 | # Save each audio 414 | for i, audio_item in enumerate(audio_np): 415 | audio_item = self._prepare_audio_for_save(audio_item, normalize) 416 | file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav") 417 | sf.write(file_path, audio_item, sampling_rate) 418 | saved_paths.append(file_path) 419 | 420 | else: 421 | # Handle different dimensions 422 | if len(audio_np.shape) >= 3: # (B, C, T) or similar 423 | # Get batch size 424 | batch_size = audio_np.shape[0] 425 | 426 | if batch_size > 1: 427 | # Multiple audios in a batch 428 | output_dir = output_path 429 | 430 | # Ensure output directory exists 431 | os.makedirs(output_dir, exist_ok=True) 432 | 433 | # Save each audio in the batch 434 | for i in range(batch_size): 435 | # Extract single audio and remove channel dim if present 436 | single_audio = audio_np[i] 437 | if len(single_audio.shape) > 1: 438 | if single_audio.shape[0] == 1: # (1, T) 439 | single_audio = single_audio.squeeze(0) 440 | 441 | single_audio = self._prepare_audio_for_save(single_audio, normalize) 442 | file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav") 443 | sf.write(file_path, single_audio, sampling_rate) 444 | saved_paths.append(file_path) 445 | else: 446 | # Single audio with batch and channel dims 447 | audio_item = audio_np.squeeze() # Remove batch and channel dimensions 448 | audio_item = self._prepare_audio_for_save(audio_item, normalize) 449 | sf.write(output_path, audio_item, sampling_rate) 450 | saved_paths.append(output_path) 451 | else: 452 | # Single audio without batch dimension 453 | audio_item = self._prepare_audio_for_save(audio_np, normalize) 454 | sf.write(output_path, audio_item, sampling_rate) 455 | saved_paths.append(output_path) 456 | 457 | return saved_paths 458 | 459 | def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray: 460 | """ 461 | Prepare audio for saving by ensuring it's the right shape and optionally normalizing. 
462 | 463 | Args: 464 | audio: Audio data as numpy array 465 | normalize: Whether to normalize audio 466 | 467 | Returns: 468 | np.ndarray: Processed audio ready for saving 469 | """ 470 | # Ensure right dimensionality 471 | if len(audio.shape) > 1 and audio.shape[0] == 1: # (1, T) 472 | audio = audio.squeeze(0) 473 | 474 | # Normalize if requested 475 | if normalize: 476 | max_val = np.abs(audio).max() 477 | if max_val > 0: 478 | audio = audio / max_val 479 | 480 | return audio 481 | 482 | 483 | __all__ = ["VibeVoiceTokenizerProcessor", "AudioNormalizer"] -------------------------------------------------------------------------------- /vibevoice/modular/modeling_vibevoice.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, List, Optional, Tuple, Union, Callable 3 | from tqdm import tqdm 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.distributed as dist 8 | 9 | from transformers.models.auto import AutoModel, AutoModelForCausalLM 10 | 11 | from transformers.activations import ACT2FN 12 | from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput 13 | from transformers.models.llama.modeling_llama import LlamaRMSNorm 14 | from transformers import modeling_utils 15 | from transformers.modeling_utils import PreTrainedModel 16 | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs 17 | from transformers.utils import logging 18 | 19 | 20 | from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel 21 | from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead 22 | from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler 23 | 24 | from .configuration_vibevoice import VibeVoiceConfig 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None: 30 | modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] 31 | 32 | @dataclass 33 | class VibeVoiceCausalLMOutputWithPast(ModelOutput): 34 | loss: Optional[torch.FloatTensor] = None 35 | diffusion_loss: Optional[torch.FloatTensor] = None 36 | speech_token_num: Optional[int] = None 37 | logits: torch.FloatTensor = None 38 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 39 | hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 40 | attentions: Optional[Tuple[torch.FloatTensor, ...]] = None 41 | 42 | 43 | @dataclass 44 | class VibeVoiceGenerationOutput(ModelOutput): 45 | """ 46 | Output type for VibeVoice generation. 47 | 48 | Args: 49 | sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 50 | The generated sequences. 51 | speech_outputs (`List[torch.FloatTensor]`, *optional*): 52 | List of generated speech waveforms or latents for each speech segment. 
53 | """ 54 | sequences: torch.LongTensor = None 55 | speech_outputs: Optional[List[torch.FloatTensor]] = None 56 | 57 | 58 | class SpeechConnector(nn.Module): 59 | def __init__(self, input_dim, output_dim): 60 | super().__init__() 61 | self.fc1 = nn.Linear(input_dim, output_dim) 62 | self.norm = LlamaRMSNorm(output_dim, eps=1e-6) 63 | self.fc2 = nn.Linear(output_dim, output_dim) 64 | 65 | def forward(self, features, **kwargs): 66 | x = self.fc1(features) 67 | x = self.norm(x) 68 | x = self.fc2(x) 69 | return x 70 | 71 | 72 | # @auto_docstring 73 | class VibeVoicePreTrainedModel(PreTrainedModel): 74 | config_class = VibeVoiceConfig 75 | base_model_prefix = "model" 76 | supports_gradient_checkpointing = True 77 | _skip_keys_device_placement = "past_key_values" 78 | _supports_cache_class = True 79 | _supports_flash_attn_2 = True 80 | _supports_sdpa = True 81 | _supports_quantized_cache = True 82 | _supports_static_cache = True 83 | _supports_attention_backend = True 84 | 85 | def _init_weights(self, module): 86 | if isinstance(module, VibeVoiceDiffusionHead): 87 | module.initialize_weights() 88 | return 89 | 90 | # Use the language model's initializer_range if available 91 | if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'): 92 | std = self.config.language_model_config.initializer_range 93 | elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'): 94 | std = self.config.decoder_config.initializer_range 95 | else: 96 | std = 0.02 # Default value 97 | 98 | if isinstance(module, nn.Linear): 99 | module.weight.data.normal_(mean=0.0, std=std) 100 | if module.bias is not None: 101 | module.bias.data.zero_() 102 | elif isinstance(module, nn.LayerNorm): 103 | module.weight.data.fill_(1.0) 104 | module.bias.data.zero_() 105 | 106 | # @auto_docstring 107 | class VibeVoiceModel(VibeVoicePreTrainedModel): 108 | def __init__(self, config): 109 | super().__init__(config) 110 | 111 | if hasattr(config, 'torch_dtype') and config.torch_dtype is not None: 112 | if isinstance(config.torch_dtype, str): 113 | dtype = getattr(torch, config.torch_dtype) 114 | else: 115 | dtype = config.torch_dtype 116 | else: 117 | dtype = torch.float32 118 | 119 | # Initialize Qwen2 model for language modeling 120 | lm_config = config.decoder_config 121 | self.language_model = AutoModel.from_config(lm_config) 122 | 123 | # Initialize speech components if needed 124 | self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype) 125 | self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype) 126 | 127 | self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype) 128 | self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype) 129 | 130 | # Register scaling factors as buffers - use 1D tensors for FSDP compatibility 131 | self.register_buffer('speech_scaling_factor', torch.tensor(float('nan'))) 132 | self.register_buffer('speech_bias_factor', torch.tensor(float('nan'))) 133 | 134 | # Initialize prediction head for speech generation 135 | self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype) 136 | 137 | # Initialize noise scheduler 138 | self.noise_scheduler = DPMSolverMultistepScheduler( 139 | num_train_timesteps=config.diffusion_head_config.ddpm_num_steps, 140 | beta_schedule=config.diffusion_head_config.ddpm_beta_schedule, 141 | 
prediction_type=config.diffusion_head_config.prediction_type 142 | ) 143 | 144 | def get_input_embeddings(self): 145 | if hasattr(self.language_model, 'embed_tokens'): 146 | # If the language model has an embed_tokens attribute, return it 147 | return self.language_model.embed_tokens 148 | 149 | for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed 150 | if attr.orig_name == 'embed_tokens.weight': 151 | return getattr(self.language_model, name) 152 | assert False, 'should not arrive here' 153 | 154 | def set_input_embeddings(self, value): 155 | self.language_model.embed_tokens = value 156 | 157 | def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None): 158 | """Set the speech tokenizers used for encoding and decoding speech.""" 159 | self.acoustic_tokenizer = acoustic_tokenizer 160 | self.semantic_tokenizer = semantic_tokenizer 161 | 162 | # Reset the encoder to evaluation mode 163 | if self.acoustic_tokenizer is not None: 164 | self.acoustic_tokenizer.eval() 165 | 166 | if self.semantic_tokenizer is not None: 167 | self.semantic_tokenizer.eval() 168 | 169 | def forward( 170 | self, 171 | input_ids: torch.LongTensor = None, 172 | attention_mask: Optional[torch.Tensor] = None, 173 | position_ids: Optional[torch.LongTensor] = None, 174 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 175 | inputs_embeds: Optional[torch.FloatTensor] = None, 176 | use_cache: Optional[bool] = None, 177 | output_attentions: Optional[bool] = None, 178 | output_hidden_states: Optional[bool] = None, 179 | return_dict: Optional[bool] = None, 180 | cache_position: Optional[torch.LongTensor] = None, 181 | **kwargs, 182 | ) -> Union[Tuple, BaseModelOutputWithPast]: 183 | 184 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 185 | 186 | # Forward through language model 187 | outputs = self.language_model( 188 | input_ids=input_ids, 189 | attention_mask=attention_mask, 190 | position_ids=position_ids, 191 | past_key_values=past_key_values, 192 | inputs_embeds=inputs_embeds, 193 | use_cache=use_cache, 194 | output_attentions=output_attentions, 195 | output_hidden_states=output_hidden_states, 196 | return_dict=return_dict, 197 | cache_position=cache_position, 198 | **kwargs, 199 | ) 200 | 201 | if not return_dict: 202 | return outputs 203 | 204 | return BaseModelOutputWithPast( 205 | last_hidden_state=outputs.last_hidden_state, 206 | past_key_values=outputs.past_key_values, 207 | hidden_states=outputs.hidden_states, 208 | attentions=outputs.attentions, 209 | ) 210 | 211 | 212 | class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel): 213 | _tied_weights_keys = ["lm_head.weight"] 214 | _tp_plan = {"lm_head": "colwise_rep"} 215 | 216 | def __init__(self, config): 217 | super().__init__(config) 218 | self.model = VibeVoiceModel(config) 219 | self.vocab_size = config.decoder_config.vocab_size 220 | self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False) 221 | 222 | self.post_init() 223 | 224 | def get_input_embeddings(self): 225 | return self.model.get_input_embeddings() 226 | 227 | def set_input_embeddings(self, value): 228 | self.model.set_input_embeddings(value) 229 | 230 | def get_output_embeddings(self): 231 | return self.lm_head 232 | 233 | def set_decoder(self, decoder): 234 | self.model.language_model = decoder 235 | 236 | def get_decoder(self): 237 | return self.model.language_model 238 | 239 | def tie_weights(self): 240 | """ 241 | Tie the weights 
between the input embeddings and the output embeddings. 242 | """ 243 | if getattr(self.config.decoder_config, 'tie_word_embeddings', False): 244 | # The standard PreTrainedModel method will handle the tying. 245 | # It typically does a simple parameter object assignment, which is 246 | # CORRECT to do BEFORE FSDP wraps the model. 247 | output_embeddings = self.get_output_embeddings() 248 | input_embeddings = self.get_input_embeddings() 249 | if hasattr(input_embeddings, 'weight'): 250 | output_embeddings.weight = input_embeddings.weight 251 | else: 252 | # maybe returned input_embeddings a tensor directly 253 | output_embeddings.weight = input_embeddings 254 | 255 | if getattr(output_embeddings, "bias", None) is not None: 256 | output_embeddings.bias.data = nn.functional.pad( 257 | output_embeddings.bias.data, 258 | (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]), 259 | "constant", 260 | 0, 261 | ) 262 | print("✅ Tied input and output embeddings using standard assignment.") 263 | else: 264 | print("ℹ️ tie_word_embeddings is False, not tying weights.") 265 | 266 | # Also, ensure set_output_embeddings is safe, though your implementation looks okay. 267 | # The key is to avoid calling it after accelerator.prepare(). 268 | def set_output_embeddings(self, new_embeddings): 269 | # Your current implementation using data.copy_ is good practice, 270 | # but the best way is to not call this after prepare(). 271 | self.lm_head = new_embeddings 272 | 273 | def forward_speech_features( 274 | self, 275 | speech_tensors=None, 276 | speech_masks=None, 277 | speech_type="audio", 278 | return_unmask=False 279 | ): 280 | if speech_tensors is None: 281 | # Use config to get vae_dim instead of non-existent self.args 282 | vae_dim = self.config.acoustic_tokenizer_config.vae_dim 283 | audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight) 284 | connect_features = self.model.acoustic_connector(audio_features) 285 | return audio_features, connect_features 286 | else: 287 | with torch.no_grad(): 288 | if speech_type == "audio": 289 | with torch.no_grad(): 290 | frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0] 291 | audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0] 292 | 293 | elif speech_type == "vae": 294 | # Use config to get vae_dim instead of non-existent self.args 295 | vae_dim = self.config.acoustic_tokenizer_config.vae_dim 296 | speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim) 297 | 298 | # gaussian sample from the speech_mode 299 | batch_size = speech_mode.size(0) 300 | value = self.model.acoustic_tokenizer.fix_std / 0.8 301 | std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value 302 | std = std.view(-1, *[1] * (speech_mode.dim() - 1)) 303 | audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode) 304 | else: 305 | raise NotImplementedError(f"Speech type {speech_type} not implemented") 306 | 307 | if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor): 308 | scaling_factor = 1. 
/ audio_tokens[speech_masks].flatten().std() 309 | bias_factor = -audio_tokens[speech_masks].flatten().mean() 310 | 311 | # Only use distributed operations if the process group is initialized 312 | if dist.is_available() and dist.is_initialized(): 313 | dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM) 314 | dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM) 315 | world_size = dist.get_world_size() 316 | self.model.speech_scaling_factor.copy_(scaling_factor / world_size) 317 | self.model.speech_bias_factor.copy_(bias_factor / world_size) 318 | print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True) 319 | else: 320 | # Single process case 321 | self.model.speech_scaling_factor.copy_(scaling_factor) 322 | self.model.speech_bias_factor.copy_(bias_factor) 323 | print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True) 324 | 325 | audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor 326 | 327 | connect_features = self.model.acoustic_connector(audio_features) 328 | if return_unmask: 329 | return audio_features, connect_features 330 | return audio_features[speech_masks], connect_features[speech_masks] 331 | 332 | def forward( 333 | self, 334 | input_ids: torch.LongTensor = None, 335 | attention_mask: Optional[torch.Tensor] = None, 336 | position_ids: Optional[torch.LongTensor] = None, 337 | past_key_values: Optional[List[torch.FloatTensor]] = None, 338 | inputs_embeds: Optional[torch.FloatTensor] = None, 339 | labels: Optional[torch.LongTensor] = None, 340 | use_cache: Optional[bool] = False, 341 | output_attentions: Optional[bool] = None, 342 | output_hidden_states: Optional[bool] = None, 343 | return_dict: Optional[bool] = None, 344 | cache_position: Optional[torch.LongTensor] = None, 345 | # New arguments for speech processing and loss calculation 346 | speech_tensors: Optional[torch.FloatTensor] = None, 347 | speech_masks: Optional[torch.BoolTensor] = None, 348 | speeches_loss_input: Optional[torch.FloatTensor] = None, 349 | speech_semantic_tensors: Optional[torch.FloatTensor] = None, 350 | acoustic_input_mask: Optional[torch.BoolTensor] = None, 351 | acoustic_loss_mask: Optional[torch.BoolTensor] = None, 352 | ddpm_batch_mul: int = 1, 353 | **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]], 354 | ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]: 355 | 356 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 357 | 358 | x = self.get_input_embeddings()(input_ids) 359 | 360 | semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors) 361 | if speeches_loss_input is not None: 362 | # only part audio need diffuse 363 | speech_all_features, speech_all_connect_features = self.forward_speech_features( 364 | speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None, 365 | speech_masks=speech_masks, 366 | speech_type=kwargs.get("speech_type", "audio"), 367 | return_unmask=True 368 | ) 369 | if speech_tensors is not None: 370 | if semantic_speech_all_connect_features is not None: 371 | x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks] 372 | else: 373 | x[acoustic_input_mask] = speech_all_connect_features[speech_masks] 374 | speech_features = speech_all_features[speeches_loss_input.unsqueeze(-1) & speech_masks] # only 
part audio need diffuse 375 | speech_connect_features = speech_all_connect_features[speeches_loss_input.unsqueeze(-1) & speech_masks] 376 | else: 377 | speech_features, speech_connect_features = self.forward_speech_features( 378 | speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None, 379 | speech_masks=speech_masks, 380 | speech_type=kwargs.get("speech_type", "audio"), 381 | ) 382 | if speech_tensors is not None: 383 | x[acoustic_input_mask] = speech_connect_features 384 | 385 | outputs = self.model( 386 | input_ids=None, 387 | attention_mask=attention_mask, 388 | position_ids=position_ids, 389 | past_key_values=past_key_values, 390 | inputs_embeds=x, 391 | use_cache=use_cache, 392 | output_attentions=output_attentions, 393 | output_hidden_states=False, 394 | return_dict=return_dict, 395 | cache_position=cache_position, 396 | ) 397 | 398 | hidden_states = outputs.last_hidden_state 399 | logits = self.lm_head(hidden_states) 400 | # logits = logits.float() 401 | 402 | loss = None 403 | if labels is not None: 404 | # The custom CE loss with masking is calculated in the training script. 405 | # We leave the standard loss calculation here as None. 406 | pass 407 | 408 | # --- Diffusion Loss Calculation --- 409 | diffusion_loss = None 410 | # This block is executed only if we are in a context that involves speech. 411 | if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0: 412 | condition_features = hidden_states[acoustic_loss_mask] 413 | 414 | speech_len, latent_size = speech_features.shape 415 | 416 | noise = torch.randn( 417 | (speech_len * ddpm_batch_mul, latent_size), 418 | device=hidden_states.device, 419 | dtype=hidden_states.dtype 420 | ) 421 | 422 | timesteps = torch.multinomial( 423 | torch.ones(self.config.diffusion_head_config.ddpm_num_steps), 424 | speech_len * ddpm_batch_mul, 425 | replacement=True, 426 | ).to(hidden_states.device) 427 | 428 | speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0) 429 | condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0) 430 | 431 | noisy_speech_features = self.model.noise_scheduler.add_noise( 432 | speech_features_repeated, noise, timesteps 433 | ) 434 | 435 | model_output = self.model.prediction_head( 436 | noisy_speech_features, 437 | timesteps.type_as(x), 438 | condition_features_repeated 439 | ) 440 | 441 | prediction_type = self.config.diffusion_head_config.prediction_type 442 | if prediction_type == "epsilon": 443 | target_for_loss = noise 444 | elif prediction_type == "v_prediction": 445 | target_for_loss = self.model.noise_scheduler.get_velocity( 446 | speech_features_repeated, noise, timesteps 447 | ) 448 | else: 449 | raise NotImplementedError(f"Prediction type {prediction_type} not implemented") 450 | 451 | diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum') 452 | if latent_size > 0 and ddpm_batch_mul > 0: 453 | diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul 454 | else: 455 | diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device) 456 | 457 | else: 458 | # Dummy loss for DDP to work when there are no speech samples in a batch, 459 | # but we are in a speech context. 
460 | diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0 461 | diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0 462 | diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0 463 | # --- End Diffusion Loss Calculation --- 464 | 465 | if not return_dict: 466 | output = (logits, speech_len) + outputs.to_tuple()[1:] 467 | return (loss, diffusion_loss) + output 468 | 469 | return VibeVoiceCausalLMOutputWithPast( 470 | loss=loss, 471 | diffusion_loss=diffusion_loss, 472 | speech_token_num=speech_len if speech_tensors is not None else 0, 473 | logits=logits, 474 | past_key_values=outputs.past_key_values, 475 | hidden_states=outputs.hidden_states, 476 | attentions=outputs.attentions, 477 | ) 478 | 479 | AutoModel.register(VibeVoiceConfig, VibeVoiceModel) 480 | AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration) 481 | 482 | __all__ = [ 483 | "VibeVoiceModel", 484 | "VibeVoicePreTrainedModel", 485 | "VibeVoiceForConditionalGeneration", 486 | "VibeVoiceCausalLMOutputWithPast", 487 | "VibeVoiceGenerationOutput", 488 | ] -------------------------------------------------------------------------------- /vibevoice/processor/vibevoice_processor.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from typing import List, Optional, Union, Dict, Any, Tuple 4 | import os 5 | import re 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy 11 | from transformers.utils import TensorType, logging 12 | from .vibevoice_tokenizer_processor import AudioNormalizer 13 | 14 | logger = logging.get_logger(__name__) 15 | 16 | 17 | class VibeVoiceProcessor: 18 | r""" 19 | Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor. 20 | 21 | [`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`]. 22 | See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information. 23 | 24 | Args: 25 | tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`): 26 | The tokenizer for text processing. 27 | audio_processor (`VibeVoiceTokenizerProcessor`): 28 | The audio processor for speech processing. 29 | speech_tok_compress_ratio (`int`, *optional*, defaults to 3200): 30 | The compression ratio for speech tokenization. 31 | db_normalize (`bool`, *optional*, defaults to True): 32 | Whether to apply decibel normalization to audio inputs. 33 | """ 34 | 35 | def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs): 36 | self.tokenizer = tokenizer 37 | self.audio_processor = audio_processor 38 | self.speech_tok_compress_ratio = speech_tok_compress_ratio 39 | self.db_normalize = db_normalize 40 | self.audio_normalizer = AudioNormalizer() if db_normalize else None 41 | self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n" 42 | 43 | @classmethod 44 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 45 | """ 46 | Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor. 
47 | 48 | Args: 49 | pretrained_model_name_or_path (`str` or `os.PathLike`): 50 | This can be either: 51 | - a string, the *model id* of a pretrained model 52 | - a path to a *directory* containing processor config 53 | 54 | Returns: 55 | [`VibeVoiceProcessor`]: The processor object instantiated from pretrained model. 56 | """ 57 | import os 58 | import json 59 | from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor 60 | from vibevoice.modular.modular_vibevoice_text_tokenizer import ( 61 | VibeVoiceTextTokenizer, 62 | VibeVoiceTextTokenizerFast 63 | ) 64 | 65 | # Load processor configuration 66 | config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json") 67 | if os.path.exists(config_path): 68 | with open(config_path, 'r') as f: 69 | config = json.load(f) 70 | else: 71 | logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults") 72 | config = { 73 | "speech_tok_compress_ratio": 3200, 74 | "db_normalize": True, 75 | } 76 | 77 | # Extract main processor parameters 78 | speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200) 79 | db_normalize = config.get("db_normalize", True) 80 | 81 | # Load tokenizer - try from model path first, then fallback to Qwen 82 | language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B") 83 | logger.info(f"Loading tokenizer from {language_model_pretrained_name}") 84 | if 'qwen' in language_model_pretrained_name.lower(): 85 | tokenizer = VibeVoiceTextTokenizerFast.from_pretrained( 86 | language_model_pretrained_name, 87 | **kwargs 88 | ) 89 | else: 90 | raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.") 91 | 92 | # Load audio processor 93 | if "audio_processor" in config: 94 | # Create audio processor from config 95 | audio_config = config["audio_processor"] 96 | audio_processor = VibeVoiceTokenizerProcessor( 97 | sampling_rate=audio_config.get("sampling_rate", 24000), 98 | normalize_audio=audio_config.get("normalize_audio", True), 99 | target_dB_FS=audio_config.get("target_dB_FS", -25), 100 | eps=audio_config.get("eps", 1e-6), 101 | ) 102 | else: 103 | # Create default audio processor 104 | audio_processor = VibeVoiceTokenizerProcessor() 105 | 106 | # Create and return the processor 107 | return cls( 108 | tokenizer=tokenizer, 109 | audio_processor=audio_processor, 110 | speech_tok_compress_ratio=speech_tok_compress_ratio, 111 | db_normalize=db_normalize, 112 | ) 113 | 114 | def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): 115 | """ 116 | Save a processor to a directory, so that it can be re-loaded using the 117 | [`~VibeVoiceProcessor.from_pretrained`] class method. 118 | 119 | Args: 120 | save_directory (`str` or `os.PathLike`): 121 | Directory where the processor will be saved. 
122 | """ 123 | import os 124 | import json 125 | 126 | os.makedirs(save_directory, exist_ok=True) 127 | 128 | # Save processor configuration 129 | processor_config = { 130 | "processor_class": "VibeVoiceProcessor", 131 | "speech_tok_compress_ratio": self.speech_tok_compress_ratio, 132 | "db_normalize": self.db_normalize, 133 | "audio_processor": { 134 | "feature_extractor_type": "VibeVoiceTokenizerProcessor", 135 | "sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000), 136 | "normalize_audio": getattr(self.audio_processor, 'normalize_audio', True), 137 | "target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25), 138 | "eps": getattr(self.audio_processor, 'eps', 1e-6), 139 | } 140 | } 141 | 142 | config_path = os.path.join(save_directory, "preprocessor_config.json") 143 | with open(config_path, 'w') as f: 144 | json.dump(processor_config, f, indent=2) 145 | 146 | logger.info(f"Processor configuration saved in {config_path}") 147 | 148 | def __call__( 149 | self, 150 | text: Optional[List[str]] = None, 151 | parsed_scripts: Optional[List[List[Tuple[int, str]]]] = None, # <-- ADDED 152 | voice_samples: Optional[List[List[Optional[Union[str, np.ndarray]]]]] = None, 153 | speaker_ids_for_prompt: Optional[List[List[int]]] = None, 154 | padding: Union[bool, str, PaddingStrategy] = True, 155 | truncation: Union[bool, str, TruncationStrategy] = False, 156 | max_length: Optional[int] = None, 157 | return_tensors: Optional[Union[str, TensorType]] = None, 158 | return_attention_mask: bool = True, 159 | **kwargs, 160 | ) -> BatchEncoding: 161 | """ 162 | Main method to process one or more podcast scripts with optional voice samples. 163 | 164 | Args: 165 | text (`str`, `List[str]`): 166 | The input text(s) to process. Can be: 167 | - A single script string 168 | - A list of script strings for batch processing 169 | - A path to a .json or .txt file 170 | - A list of paths 171 | voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*): 172 | Voice samples for each script. 
Can be: 173 | - A list of samples for a single script 174 | - A list of lists for batch processing 175 | padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`): 176 | Whether to pad sequences to the same length 177 | truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`): 178 | Whether to truncate sequences 179 | max_length (`int`, *optional*): 180 | Maximum length of the returned sequences 181 | return_tensors (`str` or `TensorType`, *optional*): 182 | If set, will return tensors of a particular framework 183 | return_attention_mask (`bool`, defaults to `True`): 184 | Whether to return the attention mask 185 | 186 | Returns: 187 | `BatchEncoding`: A BatchEncoding with the following fields: 188 | - **input_ids** -- List of token id sequences or tensor 189 | - **attention_mask** -- List of attention masks or tensor 190 | - **speech_tensors** -- Padded speech inputs (if voice_samples provided) 191 | - **speech_masks** -- Speech masks (if voice_samples provided) 192 | - **speech_input_mask** -- Boolean masks indicating speech token positions 193 | """ 194 | 195 | if parsed_scripts is None: 196 | if text is None: 197 | raise ValueError("Either 'text' or 'parsed_scripts' must be provided.") 198 | # Fallback for raw text input (though the node won't use this path) 199 | from ..modules.utils import parse_script_1_based 200 | parsed_scripts = [parse_script_1_based(t)[0] for t in text] 201 | 202 | num_scripts = len(parsed_scripts) 203 | voice_samples_list = voice_samples if voice_samples is not None else [[] for _ in range(num_scripts)] 204 | speaker_ids_list = speaker_ids_for_prompt if speaker_ids_for_prompt is not None else [[] for _ in range(num_scripts)] 205 | 206 | all_encodings = [] 207 | for i in range(num_scripts): 208 | # Pass all three corresponding items to _process_single 209 | encoding = self._process_single( 210 | parsed_scripts[i], 211 | voice_samples_list[i], 212 | speaker_ids_list[i] 213 | ) 214 | all_encodings.append(encoding) 215 | 216 | # Combine batch 217 | batch_encoding = self._batch_encode( 218 | all_encodings, 219 | padding=padding, 220 | truncation=truncation, 221 | max_length=max_length, 222 | return_tensors=return_tensors, 223 | return_attention_mask=return_attention_mask, 224 | ) 225 | 226 | return batch_encoding 227 | 228 | def _process_single( 229 | self, 230 | parsed_script: List[Tuple[int, str]], 231 | voice_samples: List[Optional[Union[str, np.ndarray]]], 232 | speaker_ids: List[int], 233 | ) -> Dict[str, Any]: 234 | 235 | system_tokens = self.tokenizer.encode(self.system_prompt) 236 | 237 | voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt( 238 | voice_samples, speaker_ids 239 | ) 240 | 241 | full_tokens = system_tokens + voice_tokens 242 | speech_input_mask = [False] * len(system_tokens) + voice_speech_masks 243 | 244 | dialogue_lines = [] 245 | for speaker_id_0_based, text_chunk in parsed_script: 246 | speaker_id_1_based = speaker_id_0_based + 1 247 | dialogue_lines.append(f"Speaker {speaker_id_1_based}: : {text_chunk}") 248 | 249 | full_dialogue_script = "\n".join(dialogue_lines) 250 | 251 | final_prompt_text = f" Text input:\n{full_dialogue_script}\n Speech output:\n" 252 | 253 | prompt_tokens = self.tokenizer.encode(final_prompt_text, add_special_tokens=False) 254 | 255 | full_tokens += prompt_tokens + [self.tokenizer.speech_start_id] 256 | speech_input_mask += [False] * (len(prompt_tokens) + 1) 257 | 258 | return { 259 | "input_ids": full_tokens, 260 | "speech_inputs": voice_speech_inputs if 
voice_speech_inputs else None, 261 | "speech_input_mask": speech_input_mask, 262 | } 263 | 264 | def _batch_encode( 265 | self, 266 | encodings: List[Dict[str, Any]], 267 | padding: Union[bool, str, PaddingStrategy] = True, 268 | truncation: Union[bool, str, TruncationStrategy] = False, 269 | max_length: Optional[int] = None, 270 | return_tensors: Optional[Union[str, TensorType]] = None, 271 | return_attention_mask: bool = True, 272 | ) -> BatchEncoding: 273 | """Combine multiple encodings into a batch with padding.""" 274 | input_ids_list = [enc["input_ids"] for enc in encodings] 275 | speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings] 276 | 277 | if isinstance(padding, bool): 278 | padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD 279 | elif isinstance(padding, str): 280 | padding_strategy = PaddingStrategy(padding) 281 | else: 282 | padding_strategy = padding 283 | 284 | # Apply padding to input_ids 285 | if padding_strategy != PaddingStrategy.DO_NOT_PAD: 286 | if padding_strategy == PaddingStrategy.LONGEST: 287 | max_len = max(len(ids) for ids in input_ids_list) 288 | elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None: 289 | max_len = max_length 290 | else: 291 | max_len = max(len(ids) for ids in input_ids_list) 292 | 293 | # Pad sequences 294 | padded_input_ids = [] 295 | attention_masks = [] 296 | padded_speech_input_masks = [] 297 | 298 | for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list): 299 | # Truncate if needed 300 | if truncation and len(input_ids) > max_len: 301 | input_ids = input_ids[:max_len] 302 | speech_mask = speech_mask[:max_len] 303 | 304 | # Pad 305 | padding_length = max_len - len(input_ids) 306 | # padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids 307 | padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids 308 | attention_mask = [0] * padding_length + [1] * len(input_ids) 309 | padded_speech_mask = [False] * padding_length + speech_mask 310 | 311 | padded_input_ids.append(padded_ids) 312 | attention_masks.append(attention_mask) 313 | padded_speech_input_masks.append(padded_speech_mask) 314 | 315 | input_ids_list = padded_input_ids 316 | speech_input_masks_list = padded_speech_input_masks 317 | else: 318 | # No padding, just create attention masks 319 | attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None 320 | 321 | all_speech_inputs = [] 322 | for enc in encodings: 323 | if enc.get("speech_inputs"): 324 | all_speech_inputs.extend(enc["speech_inputs"]) 325 | 326 | batch_encoding = BatchEncoding() 327 | 328 | # Handle tensor conversion 329 | if return_tensors is not None: 330 | batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long) 331 | if return_attention_mask and attention_masks is not None: 332 | batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long) 333 | batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool) 334 | else: 335 | batch_encoding["input_ids"] = input_ids_list 336 | if return_attention_mask and attention_masks is not None: 337 | batch_encoding["attention_mask"] = attention_masks 338 | batch_encoding["speech_input_mask"] = speech_input_masks_list 339 | 340 | if all_speech_inputs: 341 | speech_dict = self.prepare_speech_inputs(all_speech_inputs, return_tensors=return_tensors) 342 | batch_encoding["speech_tensors"] = speech_dict["padded_speeches"] 343 | 
batch_encoding["speech_masks"] = speech_dict["speech_masks"] 344 | else: 345 | batch_encoding["speech_tensors"] = None 346 | batch_encoding["speech_masks"] = None 347 | 348 | return batch_encoding 349 | 350 | def _create_voice_prompt( 351 | self, 352 | speaker_samples: List[Optional[Union[str, np.ndarray]]], 353 | speaker_ids: List[int] 354 | ) -> Tuple[List[int], List[np.ndarray], List[bool]]: 355 | """ 356 | Create voice prompt tokens and process audio samples. 357 | This function now handles `None` in the speaker_samples list for zero-shot speakers. 358 | 359 | Returns: 360 | tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks) 361 | """ 362 | if not any(s is not None for s in speaker_samples): 363 | return [], [], [] 364 | 365 | vae_token_id = self.tokenizer.speech_diffusion_id 366 | 367 | voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False) 368 | voice_speech_inputs = [] 369 | voice_speech_masks = [False] * len(voice_full_tokens) 370 | 371 | for speaker_id, speaker_audio in zip(speaker_ids, speaker_samples): 372 | 373 | if speaker_audio is not None: 374 | logger.info(f"Creating voice prompt for Speaker {speaker_id} from reference audio.") 375 | prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False) 376 | newline_tokens = self.tokenizer.encode('\n', add_special_tokens=False) 377 | 378 | if isinstance(speaker_audio, str): 379 | wav = self.audio_processor._load_audio_from_path(speaker_audio) 380 | else: 381 | wav = np.array(speaker_audio, dtype=np.float32) 382 | 383 | if self.db_normalize and self.audio_normalizer: 384 | wav = self.audio_normalizer(wav) 385 | 386 | 387 | vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio) 388 | speaker_tokens = ( 389 | prefix_tokens + 390 | [self.tokenizer.speech_start_id] + 391 | [vae_token_id] * vae_tok_len + 392 | [self.tokenizer.speech_end_id] + 393 | newline_tokens 394 | ) 395 | 396 | vae_input_mask = ( 397 | [False] * len(prefix_tokens) + 398 | [False] + # for speech_start_id 399 | [True] * vae_tok_len + 400 | [False] + # for speech_end_id 401 | [False] * len(newline_tokens) 402 | ) 403 | voice_speech_inputs.append(wav) 404 | voice_full_tokens.extend(speaker_tokens) 405 | voice_speech_masks.extend(vae_input_mask) 406 | else: 407 | logger.info(f"Skipping voice prompt for Speaker {speaker_id} (zero-shot).") 408 | 409 | 410 | return voice_full_tokens, voice_speech_inputs, voice_speech_masks 411 | 412 | 413 | def prepare_speech_inputs( 414 | self, 415 | speech_inputs: List[np.ndarray], 416 | return_tensors: Optional[Union[str, TensorType]] = None, 417 | device: Optional[Union[str, torch.device]] = None, 418 | dtype: Optional[torch.dtype] = None, 419 | ) -> Dict[str, Any]: 420 | """ 421 | Prepare speech inputs for model consumption. 
422 | 423 | Args: 424 | speech_inputs: List of speech arrays 425 | return_tensors: Output tensor type 426 | device: Device to place tensors on 427 | dtype: Data type for tensors 428 | 429 | Returns: 430 | Dictionary with padded_speeches and speech_masks 431 | """ 432 | if not speech_inputs: 433 | return {"padded_speeches": None, "speech_masks": None} 434 | 435 | # Calculate sequence lengths 436 | vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs] 437 | # vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs] 438 | max_speech_length = max(s.shape[0] for s in speech_inputs) 439 | 440 | # Pad speeches 441 | if speech_inputs[0].ndim == 1: 442 | padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32) 443 | else: 444 | padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32) 445 | speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_) 446 | 447 | for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)): 448 | padded_speeches[i, :len(speech)] = speech 449 | speech_masks[i, :vae_tok_length] = True 450 | 451 | result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks} 452 | 453 | # Convert to tensors if requested 454 | if return_tensors == "pt": 455 | result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32) 456 | result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool) 457 | 458 | return result 459 | 460 | def _convert_json_to_script(self, json_file: str) -> str: 461 | """ 462 | Convert JSON format to script format. 463 | Expected JSON format: 464 | [ 465 | {"speaker": "1", "text": "Hello everyone..."}, 466 | {"speaker": "2", "text": "Great to be here..."} 467 | ] 468 | """ 469 | import json 470 | 471 | with open(json_file, 'r', encoding='utf-8') as f: 472 | data = json.load(f) 473 | 474 | if not isinstance(data, list): 475 | raise ValueError("JSON file must contain a list of speaker entries") 476 | 477 | script_lines = [] 478 | for item in data: 479 | if not isinstance(item, dict): 480 | logger.warning(f"Skipping non-dict entry: {item}") 481 | continue 482 | 483 | speaker = item.get('speaker') 484 | text = item.get('text') 485 | 486 | if speaker is None or text is None: 487 | logger.warning(f"Skipping entry missing speaker or text: {item}") 488 | continue 489 | 490 | # Ensure speaker ID is valid 491 | try: 492 | speaker_id = int(speaker) 493 | except (ValueError, TypeError): 494 | logger.warning(f"Invalid speaker ID: {speaker}, skipping entry") 495 | continue 496 | 497 | # Clean up text 498 | text = text.strip() 499 | if text: 500 | script_lines.append(f"Speaker {speaker_id}: {text}") 501 | 502 | if not script_lines: 503 | raise ValueError("No valid entries found in JSON file") 504 | 505 | return "\n".join(script_lines) 506 | 507 | def _convert_text_to_script(self, text_file: str) -> str: 508 | """ 509 | Convert text file to script format. 510 | Handles multiple formats: 511 | 1. Already formatted as "Speaker X: text" 512 | 2. Plain text (assigns to Speaker 1) 513 | 514 | Handles edge cases like multiple colons in a line. 
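        Example (illustrative):
            "Speaker 2: Welcome back."     -> "Speaker 2: Welcome back."
            "Note: recording starts now."  -> "Speaker 1: Note: recording starts now."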
515 | """ 516 | with open(text_file, 'r', encoding='utf-8') as f: 517 | lines = f.readlines() 518 | 519 | script_lines = [] 520 | current_speaker = 1 521 | 522 | for line in lines: 523 | line = line.strip() 524 | if not line: 525 | continue 526 | 527 | # Try to parse as "Speaker X: text" format 528 | # Use regex to be more robust 529 | speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE) 530 | 531 | if speaker_match: 532 | speaker_id = int(speaker_match.group(1)) 533 | text = speaker_match.group(2).strip() 534 | if text: 535 | script_lines.append(f"Speaker {speaker_id}: {text}") 536 | else: 537 | # Treat as plain text - assign to current speaker 538 | script_lines.append(f"Speaker {current_speaker}: {line}") 539 | 540 | if not script_lines: 541 | raise ValueError("No valid content found in text file") 542 | 543 | return "\n".join(script_lines) 544 | 545 | def _parse_script(self, script: str) -> List[Tuple[int, str]]: 546 | """Parse script into list of (speaker_id, text) tuples.""" 547 | lines = script.strip().split("\n") 548 | parsed_lines = [] 549 | speaker_ids = [] 550 | 551 | for line in lines: 552 | if not line.strip(): 553 | continue 554 | 555 | match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE) 556 | 557 | if match: 558 | speaker_id = int(match.group(1)) 559 | text = ' ' + match.group(2).strip() 560 | parsed_lines.append((speaker_id, text)) 561 | speaker_ids.append(speaker_id) 562 | else: 563 | logger.warning(f"Could not parse line: '{line}'") 564 | 565 | if not parsed_lines: 566 | raise ValueError("No valid speaker lines found in script") 567 | 568 | # Check if we need to normalize speaker IDs (only if all are > 0) 569 | min_speaker_id = min(speaker_ids) 570 | if min_speaker_id > 0: 571 | # Normalize to start from 0 572 | normalized_lines = [] 573 | for speaker_id, text in parsed_lines: 574 | normalized_lines.append((speaker_id - 1, text)) 575 | return normalized_lines 576 | else: 577 | # Keep original IDs 578 | return parsed_lines 579 | 580 | def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding: 581 | """Merge text and audio inputs into a single BatchEncoding.""" 582 | # Start with text inputs 583 | merged = BatchEncoding(text_inputs) 584 | 585 | # Add audio-specific fields 586 | if "audio" in audio_inputs: 587 | merged["speech_inputs"] = audio_inputs["audio"] 588 | if "streaming" in audio_inputs: 589 | merged["streaming"] = audio_inputs["streaming"] 590 | 591 | return merged 592 | 593 | def batch_decode(self, *args, **kwargs): 594 | """ 595 | This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`]. 596 | Please refer to the docstring of this method for more information. 597 | """ 598 | return self.tokenizer.batch_decode(*args, **kwargs) 599 | 600 | def decode(self, *args, **kwargs): 601 | """ 602 | This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`]. 603 | Please refer to the docstring of this method for more information. 604 | """ 605 | return self.tokenizer.decode(*args, **kwargs) 606 | 607 | @property 608 | def model_input_names(self): 609 | """ 610 | Return the list of inputs accepted by the model. 
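        Combines the tokenizer's and audio processor's input names with "speech_inputs" and "speech_input_mask", removing duplicates while preserving order (descriptive note; see the return expression below).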
611 | """ 612 | tokenizer_input_names = self.tokenizer.model_input_names 613 | audio_processor_input_names = self.audio_processor.model_input_names 614 | return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"])) 615 | 616 | def save_audio(self, 617 | audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]], 618 | output_path: str = "output.wav", 619 | sampling_rate: Optional[int] = None, 620 | normalize: bool = False, 621 | batch_prefix: str = "audio_", 622 | ) -> str: 623 | """ 624 | Save audio data to a file. 625 | Args: 626 | audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]): 627 | The audio data to save. Can be a single tensor/array or a list of them. 628 | output_path (str, optional): Path to save the audio file. Defaults to "output.wav". 629 | sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default. 630 | normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False. 631 | batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_". 632 | Returns: 633 | str: The path to the saved audio file. 634 | """ 635 | return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix) 636 | 637 | __all__ = [ 638 | "VibeVoiceProcessor", 639 | ] --------------------------------------------------------------------------------
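Usage note (illustrative, not part of the repository): a minimal sketch of how the `__call__` path documented above might be driven. It assumes a `VibeVoiceProcessor` instance named `processor` has already been constructed (construction and audio loading are defined earlier in this file and in the nodes); the keyword names are taken from the method body above, and `np.zeros(...)` is a placeholder for a real reference waveform.

import numpy as np

# One script, two speakers. Speaker ids inside parsed_scripts are 0-based
# (the processor adds 1 when rendering the dialogue); speaker_ids_for_prompt
# carries the 1-based ids used in the rendered voice prompt text.
parsed_scripts = [[(0, " Hello and welcome to the show."), (1, " Thanks, great to be here.")]]
voice_samples = [[np.zeros(48000, dtype=np.float32), None]]  # placeholder waveform; None = zero-shot speaker
speaker_ids_for_prompt = [[1, 2]]

batch = processor(  # `processor` is an already-constructed VibeVoiceProcessor (assumed)
    parsed_scripts=parsed_scripts,
    voice_samples=voice_samples,
    speaker_ids_for_prompt=speaker_ids_for_prompt,
    return_tensors="pt",
)
# batch["input_ids"], batch["attention_mask"], batch["speech_input_mask"]: padded per-script tensors.
# batch["speech_tensors"] / batch["speech_masks"]: padded reference audio and its VAE-token mask.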