├── modules ├── __init__.py ├── model_info.py ├── patcher.py ├── utils.py └── loader.py ├── vibevoice ├── __init__.py ├── modular │ ├── __init__.py │ ├── sage_attention_patch.py │ ├── modular_vibevoice_text_tokenizer.py │ ├── modular_vibevoice_diffusion_head.py │ ├── streamer.py │ ├── configuration_vibevoice.py │ └── modeling_vibevoice.py ├── processor │ ├── __init__.py │ ├── vibevoice_tokenizer_processor.py │ └── vibevoice_processor.py ├── schedule │ ├── __init__.py │ └── timestep_sampler.py ├── scripts │ ├── __init__.py │ └── convert_nnscaler_checkpoint_to_transformers.py └── configs │ ├── qwen2.5_1.5b_64k.json │ ├── qwen2.5_7b_32k.json │ ├── default_VibeVoice-1.5B_config.json │ └── default_VibeVoice-Large_config.json ├── example_workflows ├── VibeVoice_example.png └── VibeVoice_example.json ├── requirements.txt ├── pyproject.toml ├── .github └── workflows │ └── publish.yml ├── LICENSE ├── .gitignore ├── __init__.py ├── vibevoice_nodes.py └── README.md /modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vibevoice/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vibevoice/modular/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vibevoice/processor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vibevoice/schedule/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vibevoice/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /example_workflows/VibeVoice_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wildminder/ComfyUI-VibeVoice/HEAD/example_workflows/VibeVoice_example.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | accelerate 3 | torchaudio 4 | librosa 5 | numpy 6 | huggingface_hub 7 | einops 8 | scipy 9 | tokenizers 10 | soundfile 11 | s3tokenizer 12 | conformer 13 | safetensors 14 | transformers>=4.51.3 15 | diffusers 16 | tqdm 17 | bitsandbytes 18 | -------------------------------------------------------------------------------- /modules/model_info.py: -------------------------------------------------------------------------------- 1 | # This dictionary contains the configurations for official, downloadable models. 2 | MODEL_CONFIGS = { 3 | "VibeVoice-1.5B": { 4 | "repo_id": "microsoft/VibeVoice-1.5B", 5 | "size_gb": 3.0, 6 | }, 7 | "VibeVoice-Large": { 8 | "repo_id": "aoi-ot/VibeVoice-Large", 9 | "size_gb": 17.4, 10 | } 11 | } 12 | 13 | AVAILABLE_VIBEVOICE_MODELS = {} -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ComfyUI-VibeVoice" 3 | description = "VibeVoice TTS. Expressive, long-form, multi-speaker conversational audio" 4 | version = "1.5.1" 5 | license = {file = "LICENSE"} 6 | dependencies = ["torch", "torchaudio", "librosa", "numpy", "huggingface_hub", "einops", "scipy", "tokenizers", "soundfile", "s3tokenizer", "tqdm", "conformer", "safetensors", "transformers", "diffusers", "bitsandbytes"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/wildminder/ComfyUI-VibeVoice" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "wildai" 14 | DisplayName = "ComfyUI-VibeVoice" 15 | Icon = "" 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /vibevoice/schedule/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | class UniformSampler: 6 | def __init__(self, timesteps = 1000): 7 | self.timesteps = timesteps 8 | def sample(self, batch_size, device): 9 | return torch.randint(0, self.timesteps, (batch_size,), device=device) 10 | 11 | class LogitNormalSampler: 12 | def __init__(self, timesteps = 1000, m = 0, s = 1): 13 | self.timesteps = timesteps 14 | timesteps = torch.linspace(0, 1, timesteps) 15 | logit = torch.log(timesteps / (1 - timesteps)) 16 | self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi)) 17 | def sample(self, batch_size, device): 18 | return torch.multinomial(self.prob, batch_size, replacement=True).to(device) 19 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | permissions: 12 | issues: write 13 | 14 | jobs: 15 | publish-node: 16 | name: Publish Custom Node to registry 17 | runs-on: ubuntu-latest 18 | if: ${{ github.repository_owner == 'wildminder' }} 19 | steps: 20 | - name: Check out code 21 | uses: actions/checkout@v4 22 | with: 23 | submodules: true 24 | - name: Publish Custom Node 25 | uses: Comfy-Org/publish-node-action@v1 26 | with: 27 | ## Add your own personal access token to your Github Repository secrets and reference it here. 28 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 WildAi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /modules/patcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | import logging 4 | import comfy.model_patcher 5 | import comfy.model_management as model_management 6 | 7 | from .loader import LOADED_MODELS, logger 8 | 9 | class VibeVoicePatcher(comfy.model_patcher.ModelPatcher): 10 | """Custom ModelPatcher for managing VibeVoice models in ComfyUI.""" 11 | def __init__(self, model, attention_mode="eager", *args, **kwargs): 12 | super().__init__(model, *args, **kwargs) 13 | self.attention_mode = attention_mode 14 | self.cache_key = model.cache_key 15 | 16 | @property 17 | def is_loaded(self): 18 | """Check if the model is currently loaded in memory.""" 19 | return hasattr(self, 'model') and self.model is not None and hasattr(self.model, 'model') and self.model.model is not None 20 | 21 | def patch_model(self, device_to=None, *args, **kwargs): 22 | target_device = self.load_device 23 | if self.model.model is None: 24 | logger.info(f"Loading VibeVoice models for '{self.model.model_pack_name}' to {target_device}...") 25 | mode_names = { 26 | "eager": "Eager (Most Compatible)", 27 | "sdpa": "SDPA (Balanced Speed/Compatibility)", 28 | "flash_attention_2": "Flash Attention 2 (Fastest)", 29 | "sage": "SageAttention (Quantized High-Performance)", 30 | } 31 | logger.info(f"Attention Mode: {mode_names.get(self.attention_mode, self.attention_mode)}") 32 | self.model.load_model(target_device, self.attention_mode) 33 | self.model.model.to(target_device) 34 | return super().patch_model(device_to=target_device, *args, **kwargs) 35 | 36 | def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs): 37 | if unpatch_weights: 38 | logger.info(f"Offloading VibeVoice models for '{self.model.model_pack_name}' ({self.attention_mode}) to {device_to}...") 39 | self.model.model = None 40 | self.model.processor = None 41 | 42 | if self.cache_key in LOADED_MODELS: 43 | del LOADED_MODELS[self.cache_key] 44 | logger.info(f"Cleared LOADED_MODELS cache for: {self.cache_key}") 45 | 46 | gc.collect() 47 | model_management.soft_empty_cache() 48 | 49 | return super().unpatch_model(device_to, unpatch_weights, *args, **kwargs) -------------------------------------------------------------------------------- /vibevoice/configs/qwen2.5_1.5b_64k.json: -------------------------------------------------------------------------------- 1 | { 2 | "_attn_implementation_autoset": true, 3 | "acoustic_vae_dim": 64, 4 | "acoustic_tokenizer_config": { 5 | "causal": true, 6 | "channels": 1, 7 | "conv_bias": true, 8 | "conv_norm": "none", 9 | "corpus_normalize": 0.0, 10 | "decoder_depths": null, 11 | "decoder_n_filters": 32, 12 | "decoder_ratios": [ 13 | 8, 14 | 5, 15 | 5, 16 | 4, 17 | 2, 18 | 2 19 | ], 20 | "disable_last_norm": true, 21 | "encoder_depths": "3-3-3-3-3-3-8", 22 | "encoder_n_filters": 32, 23 | "encoder_ratios": [ 24 | 8, 25 | 5, 26 | 5, 27 | 4, 28 | 2, 29 | 2 30 | ], 31 | "fix_std": 0.5, 32 | "layer_scale_init_value": 1e-06, 33 | "layernorm": "RMSNorm", 34 | "layernorm_elementwise_affine": true, 35 | "layernorm_eps": 1e-05, 36 | "mixer_layer": "depthwise_conv", 37 | "model_type": "vibepod_acoustic_tokenizer", 38 | "pad_mode": "constant", 39 | "std_dist_type": "gaussian", 40 | "vae_dim": 64, 41 | "weight_init_value": 0.01 42 | }, 43 | "decoder_config": { 44 | "attention_dropout": 0.0, 45 | "hidden_act": "silu", 46 | "hidden_size": 1536, 47 | "initializer_range": 0.02, 48 | "intermediate_size": 8960, 49 | "max_position_embeddings": 65536, 50 | "max_window_layers": 28, 51 | "model_type": "qwen2", 52 | "num_attention_heads": 12, 53 | "num_hidden_layers": 28, 54 | "num_key_value_heads": 2, 55 | "rms_norm_eps": 1e-06, 56 | "rope_scaling": null, 57 | "rope_theta": 1000000.0, 58 | "sliding_window": 4096, 59 | "tie_word_embeddings": true, 60 | "torch_dtype": "bfloat16", 61 | "use_cache": true, 62 | "use_sliding_window": false, 63 | "vocab_size": 151936 64 | }, 65 | "diffusion_head_config": { 66 | "ddpm_batch_mul": 4, 67 | "ddpm_beta_schedule": "cosine", 68 | "ddpm_num_inference_steps": 20, 69 | "ddpm_num_steps": 1000, 70 | "diffusion_type": "ddpm", 71 | "head_ffn_ratio": 3.0, 72 | "head_layers": 4, 73 | "hidden_size": 1536, 74 | "latent_size": 64, 75 | "model_type": "vibepod_diffusion_head", 76 | "prediction_type": "v_prediction", 77 | "rms_norm_eps": 1e-05, 78 | "speech_vae_dim": 64 79 | }, 80 | "model_type": "vibepod", 81 | "semantic_tokenizer_config": { 82 | "causal": true, 83 | "channels": 1, 84 | "conv_bias": true, 85 | "conv_norm": "none", 86 | "corpus_normalize": 0.0, 87 | "disable_last_norm": true, 88 | "encoder_depths": "3-3-3-3-3-3-8", 89 | "encoder_n_filters": 32, 90 | "encoder_ratios": [ 91 | 8, 92 | 5, 93 | 5, 94 | 4, 95 | 2, 96 | 2 97 | ], 98 | "fix_std": 0, 99 | "layer_scale_init_value": 1e-06, 100 | "layernorm": "RMSNorm", 101 | "layernorm_elementwise_affine": true, 102 | "layernorm_eps": 1e-05, 103 | "mixer_layer": "depthwise_conv", 104 | "model_type": "vibepod_semantic_tokenizer", 105 | "pad_mode": "constant", 106 | "std_dist_type": "none", 107 | "vae_dim": 128, 108 | "weight_init_value": 0.01 109 | }, 110 | "semantic_vae_dim": 128, 111 | "torch_dtype": "bfloat16" 112 | } 113 | -------------------------------------------------------------------------------- /vibevoice/configs/qwen2.5_7b_32k.json: -------------------------------------------------------------------------------- 1 | { 2 | "_attn_implementation_autoset": true, 3 | "acoustic_vae_dim": 64, 4 | "acoustic_tokenizer_config": { 5 | "causal": true, 6 | "channels": 1, 7 | "conv_bias": true, 8 | "conv_norm": "none", 9 | "corpus_normalize": 0.0, 10 | "decoder_depths": null, 11 | "decoder_n_filters": 32, 12 | "decoder_ratios": [ 13 | 8, 14 | 5, 15 | 5, 16 | 4, 17 | 2, 18 | 2 19 | ], 20 | "disable_last_norm": true, 21 | "encoder_depths": "3-3-3-3-3-3-8", 22 | "encoder_n_filters": 32, 23 | "encoder_ratios": [ 24 | 8, 25 | 5, 26 | 5, 27 | 4, 28 | 2, 29 | 2 30 | ], 31 | "fix_std": 0.5, 32 | "layer_scale_init_value": 1e-06, 33 | "layernorm": "RMSNorm", 34 | "layernorm_elementwise_affine": true, 35 | "layernorm_eps": 1e-05, 36 | "mixer_layer": "depthwise_conv", 37 | "model_type": "vibepod_acoustic_tokenizer", 38 | "pad_mode": "constant", 39 | "std_dist_type": "gaussian", 40 | "vae_dim": 64, 41 | "weight_init_value": 0.01 42 | }, 43 | "decoder_config": { 44 | "attention_dropout": 0.0, 45 | "hidden_act": "silu", 46 | "hidden_size": 3584, 47 | "initializer_range": 0.02, 48 | "intermediate_size": 18944, 49 | "max_position_embeddings": 32768, 50 | "max_window_layers": 28, 51 | "model_type": "qwen2", 52 | "num_attention_heads": 28, 53 | "num_hidden_layers": 28, 54 | "num_key_value_heads": 4, 55 | "rms_norm_eps": 1e-06, 56 | "rope_theta": 1000000.0, 57 | "sliding_window": 4096, 58 | "tie_word_embeddings": false, 59 | "torch_dtype": "bfloat16", 60 | "transformers_version": "4.40.1", 61 | "use_cache": true, 62 | "use_mrope": false, 63 | "use_sliding_window": false, 64 | "vocab_size": 152064 65 | }, 66 | "diffusion_head_config": { 67 | "ddpm_batch_mul": 4, 68 | "ddpm_beta_schedule": "cosine", 69 | "ddpm_num_inference_steps": 20, 70 | "ddpm_num_steps": 1000, 71 | "diffusion_type": "ddpm", 72 | "head_ffn_ratio": 3.0, 73 | "head_layers": 4, 74 | "hidden_size": 3584, 75 | "latent_size": 64, 76 | "model_type": "vibepod_diffusion_head", 77 | "prediction_type": "v_prediction", 78 | "rms_norm_eps": 1e-05, 79 | "speech_vae_dim": 64 80 | }, 81 | "model_type": "vibepod", 82 | "semantic_tokenizer_config": { 83 | "causal": true, 84 | "channels": 1, 85 | "conv_bias": true, 86 | "conv_norm": "none", 87 | "corpus_normalize": 0.0, 88 | "disable_last_norm": true, 89 | "encoder_depths": "3-3-3-3-3-3-8", 90 | "encoder_n_filters": 32, 91 | "encoder_ratios": [ 92 | 8, 93 | 5, 94 | 5, 95 | 4, 96 | 2, 97 | 2 98 | ], 99 | "fix_std": 0, 100 | "layer_scale_init_value": 1e-06, 101 | "layernorm": "RMSNorm", 102 | "layernorm_elementwise_affine": true, 103 | "layernorm_eps": 1e-05, 104 | "mixer_layer": "depthwise_conv", 105 | "model_type": "vibepod_semantic_tokenizer", 106 | "pad_mode": "constant", 107 | "std_dist_type": "none", 108 | "vae_dim": 128, 109 | "weight_init_value": 0.01 110 | }, 111 | "semantic_vae_dim": 128, 112 | "torch_dtype": "bfloat16" 113 | } 114 | -------------------------------------------------------------------------------- /vibevoice/configs/default_VibeVoice-1.5B_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "acoustic_vae_dim": 64, 3 | "acoustic_tokenizer_config": { 4 | "causal": true, 5 | "channels": 1, 6 | "conv_bias": true, 7 | "conv_norm": "none", 8 | "corpus_normalize": 0.0, 9 | "decoder_depths": null, 10 | "decoder_n_filters": 32, 11 | "decoder_ratios": [ 12 | 8, 13 | 5, 14 | 5, 15 | 4, 16 | 2, 17 | 2 18 | ], 19 | "disable_last_norm": true, 20 | "encoder_depths": "3-3-3-3-3-3-8", 21 | "encoder_n_filters": 32, 22 | "encoder_ratios": [ 23 | 8, 24 | 5, 25 | 5, 26 | 4, 27 | 2, 28 | 2 29 | ], 30 | "fix_std": 0.5, 31 | "layer_scale_init_value": 1e-06, 32 | "layernorm": "RMSNorm", 33 | "layernorm_elementwise_affine": true, 34 | "layernorm_eps": 1e-05, 35 | "mixer_layer": "depthwise_conv", 36 | "model_type": "vibevoice_acoustic_tokenizer", 37 | "pad_mode": "constant", 38 | "std_dist_type": "gaussian", 39 | "vae_dim": 64, 40 | "weight_init_value": 0.01 41 | }, 42 | "architectures": [ 43 | "VibeVoiceForConditionalGeneration" 44 | ], 45 | "decoder_config": { 46 | "attention_dropout": 0.0, 47 | "hidden_act": "silu", 48 | "hidden_size": 1536, 49 | "initializer_range": 0.02, 50 | "intermediate_size": 8960, 51 | "max_position_embeddings": 65536, 52 | "max_window_layers": 28, 53 | "model_type": "qwen2", 54 | "num_attention_heads": 12, 55 | "num_hidden_layers": 28, 56 | "num_key_value_heads": 2, 57 | "rms_norm_eps": 1e-06, 58 | "rope_scaling": null, 59 | "rope_theta": 1000000.0, 60 | "sliding_window": null, 61 | "tie_word_embeddings": true, 62 | "torch_dtype": "bfloat16", 63 | "use_cache": true, 64 | "use_sliding_window": false, 65 | "vocab_size": 151936 66 | }, 67 | "diffusion_head_config": { 68 | "ddpm_batch_mul": 4, 69 | "ddpm_beta_schedule": "cosine", 70 | "ddpm_num_inference_steps": 20, 71 | "ddpm_num_steps": 1000, 72 | "diffusion_type": "ddpm", 73 | "head_ffn_ratio": 3.0, 74 | "head_layers": 4, 75 | "hidden_size": 1536, 76 | "latent_size": 64, 77 | "model_type": "vibevoice_diffusion_head", 78 | "prediction_type": "v_prediction", 79 | "rms_norm_eps": 1e-05, 80 | "speech_vae_dim": 64 81 | }, 82 | "model_type": "vibevoice", 83 | "semantic_tokenizer_config": { 84 | "causal": true, 85 | "channels": 1, 86 | "conv_bias": true, 87 | "conv_norm": "none", 88 | "corpus_normalize": 0.0, 89 | "disable_last_norm": true, 90 | "encoder_depths": "3-3-3-3-3-3-8", 91 | "encoder_n_filters": 32, 92 | "encoder_ratios": [ 93 | 8, 94 | 5, 95 | 5, 96 | 4, 97 | 2, 98 | 2 99 | ], 100 | "fix_std": 0, 101 | "layer_scale_init_value": 1e-06, 102 | "layernorm": "RMSNorm", 103 | "layernorm_elementwise_affine": true, 104 | "layernorm_eps": 1e-05, 105 | "mixer_layer": "depthwise_conv", 106 | "model_type": "vibevoice_semantic_tokenizer", 107 | "pad_mode": "constant", 108 | "std_dist_type": "none", 109 | "vae_dim": 128, 110 | "weight_init_value": 0.01 111 | }, 112 | "semantic_vae_dim": 128, 113 | "torch_dtype": "bfloat16", 114 | "transformers_version": "4.51.3" 115 | } 116 | -------------------------------------------------------------------------------- /vibevoice/configs/default_VibeVoice-Large_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "acostic_vae_dim": 64, 3 | "acoustic_tokenizer_config": { 4 | "causal": true, 5 | "channels": 1, 6 | "conv_bias": true, 7 | "conv_norm": "none", 8 | "corpus_normalize": 0.0, 9 | "decoder_depths": null, 10 | "decoder_n_filters": 32, 11 | "decoder_ratios": [ 12 | 8, 13 | 5, 14 | 5, 15 | 4, 16 | 2, 17 | 2 18 | ], 19 | "disable_last_norm": true, 20 | "encoder_depths": "3-3-3-3-3-3-8", 21 | "encoder_n_filters": 32, 22 | "encoder_ratios": [ 23 | 8, 24 | 5, 25 | 5, 26 | 4, 27 | 2, 28 | 2 29 | ], 30 | "fix_std": 0.5, 31 | "layer_scale_init_value": 1e-06, 32 | "layernorm": "RMSNorm", 33 | "layernorm_elementwise_affine": true, 34 | "layernorm_eps": 1e-05, 35 | "mixer_layer": "depthwise_conv", 36 | "model_type": "vibevoice_acoustic_tokenizer", 37 | "pad_mode": "constant", 38 | "std_dist_type": "gaussian", 39 | "vae_dim": 64, 40 | "weight_init_value": 0.01 41 | }, 42 | "architectures": [ 43 | "VibeVoiceForConditionalGeneration" 44 | ], 45 | "decoder_config": { 46 | "attention_dropout": 0.0, 47 | "hidden_act": "silu", 48 | "hidden_size": 3584, 49 | "initializer_range": 0.02, 50 | "intermediate_size": 18944, 51 | "max_position_embeddings": 32768, 52 | "max_window_layers": 28, 53 | "model_type": "qwen2", 54 | "num_attention_heads": 28, 55 | "num_hidden_layers": 28, 56 | "num_key_value_heads": 4, 57 | "rms_norm_eps": 1e-06, 58 | "rope_scaling": null, 59 | "rope_theta": 1000000.0, 60 | "sliding_window": null, 61 | "torch_dtype": "bfloat16", 62 | "use_cache": true, 63 | "use_mrope": false, 64 | "use_sliding_window": false, 65 | "vocab_size": 152064 66 | }, 67 | "diffusion_head_config": { 68 | "ddpm_batch_mul": 4, 69 | "ddpm_beta_schedule": "cosine", 70 | "ddpm_num_inference_steps": 20, 71 | "ddpm_num_steps": 1000, 72 | "diffusion_type": "ddpm", 73 | "head_ffn_ratio": 3.0, 74 | "head_layers": 4, 75 | "hidden_size": 3584, 76 | "latent_size": 64, 77 | "model_type": "vibevoice_diffusion_head", 78 | "prediction_type": "v_prediction", 79 | "rms_norm_eps": 1e-05, 80 | "speech_vae_dim": 64 81 | }, 82 | "model_type": "vibevoice", 83 | "semantic_tokenizer_config": { 84 | "causal": true, 85 | "channels": 1, 86 | "conv_bias": true, 87 | "conv_norm": "none", 88 | "corpus_normalize": 0.0, 89 | "disable_last_norm": true, 90 | "encoder_depths": "3-3-3-3-3-3-8", 91 | "encoder_n_filters": 32, 92 | "encoder_ratios": [ 93 | 8, 94 | 5, 95 | 5, 96 | 4, 97 | 2, 98 | 2 99 | ], 100 | "fix_std": 0, 101 | "layer_scale_init_value": 1e-06, 102 | "layernorm": "RMSNorm", 103 | "layernorm_elementwise_affine": true, 104 | "layernorm_eps": 1e-05, 105 | "mixer_layer": "depthwise_conv", 106 | "model_type": "vibevoice_semantic_tokenizer", 107 | "pad_mode": "constant", 108 | "std_dist_type": "none", 109 | "vae_dim": 128, 110 | "weight_init_value": 0.01 111 | }, 112 | "semantic_vae_dim": 128, 113 | "tie_word_embeddings": false, 114 | "torch_dtype": "bfloat16", 115 | "transformers_version": "4.51.3" 116 | } 117 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .github 11 | .idea 12 | .Python 13 | __pycache__ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import folder_paths 5 | import json 6 | 7 | try: 8 | import sageattention 9 | SAGE_ATTENTION_AVAILABLE = True 10 | except ImportError: 11 | SAGE_ATTENTION_AVAILABLE = False 12 | 13 | current_dir = os.path.dirname(os.path.abspath(__file__)) 14 | if current_dir not in sys.path: 15 | sys.path.append(current_dir) 16 | 17 | from .modules.model_info import AVAILABLE_VIBEVOICE_MODELS, MODEL_CONFIGS 18 | 19 | # Configure a logger 20 | logger = logging.getLogger(__name__) 21 | logger.setLevel(logging.INFO) 22 | logger.propagate = False 23 | if not logger.hasHandlers(): 24 | handler = logging.StreamHandler() 25 | formatter = logging.Formatter(f"[ComfyUI-VibeVoice] %(message)s") 26 | handler.setFormatter(formatter) 27 | logger.addHandler(handler) 28 | 29 | # This is just the *name* of the subdirectory, not the full path. 30 | VIBEVOICE_SUBDIR_NAME = "VibeVoice" 31 | 32 | # This is the *primary* path where official models will be downloaded. 33 | primary_vibevoice_models_path = os.path.join(folder_paths.models_dir, "tts", VIBEVOICE_SUBDIR_NAME) 34 | os.makedirs(primary_vibevoice_models_path, exist_ok=True) 35 | 36 | # Register the tts path type with ComfyUI so get_folder_paths works 37 | tts_path = os.path.join(folder_paths.models_dir, "tts") 38 | if "tts" not in folder_paths.folder_names_and_paths: 39 | supported_exts = folder_paths.supported_pt_extensions.union({".safetensors", ".json"}) 40 | folder_paths.folder_names_and_paths["tts"] = ([tts_path], supported_exts) 41 | else: 42 | # Ensure the default path is in the list if it's not already 43 | if tts_path not in folder_paths.folder_names_and_paths["tts"][0]: 44 | folder_paths.folder_names_and_paths["tts"][0].append(tts_path) 45 | 46 | # The logic for dynamic model discovery 47 | # ToDo: optimize finding 48 | 49 | # official models that can be auto-downloaded 50 | for model_name, config in MODEL_CONFIGS.items(): 51 | AVAILABLE_VIBEVOICE_MODELS[model_name] = { 52 | "type": "official", 53 | "repo_id": config["repo_id"], 54 | "tokenizer_repo": "Qwen/Qwen2.5-7B" if "Large" in model_name else "Qwen/Qwen2.5-1.5B" 55 | } 56 | 57 | # just workaround, default + custom 58 | vibevoice_search_paths = [] 59 | # Use ComfyUI's API to get all registered 'tts' folders 60 | for tts_folder in folder_paths.get_folder_paths("tts"): 61 | potential_path = os.path.join(tts_folder, VIBEVOICE_SUBDIR_NAME) 62 | if os.path.isdir(potential_path) and potential_path not in vibevoice_search_paths: 63 | vibevoice_search_paths.append(potential_path) 64 | 65 | # Add the primary path just in case it wasn't registered for some reason 66 | if primary_vibevoice_models_path not in vibevoice_search_paths: 67 | vibevoice_search_paths.insert(0, primary_vibevoice_models_path) 68 | 69 | # Messy... Discover all local models in the search paths 70 | for search_path in vibevoice_search_paths: 71 | logger.info(f"Scanning for VibeVoice models in: {search_path}") 72 | if not os.path.exists(search_path): continue 73 | for item in os.listdir(search_path): 74 | item_path = os.path.join(search_path, item) 75 | 76 | # Case 1: we have a standard HF directory 77 | if os.path.isdir(item_path): 78 | model_name = item 79 | if model_name in AVAILABLE_VIBEVOICE_MODELS: continue 80 | 81 | config_exists = os.path.exists(os.path.join(item_path, "config.json")) 82 | weights_exist = os.path.exists(os.path.join(item_path, "model.safetensors.index.json")) or any(f.endswith(('.safetensors', '.bin')) for f in os.listdir(item_path)) 83 | 84 | if config_exists and weights_exist: 85 | tokenizer_repo = "Qwen/Qwen2.5-7B" if "large" in model_name.lower() else "Qwen/Qwen2.5-1.5B" 86 | AVAILABLE_VIBEVOICE_MODELS[model_name] = { 87 | "type": "local_dir", 88 | "path": item_path, 89 | "tokenizer_repo": tokenizer_repo 90 | } 91 | 92 | # Case 2: Item is a standalone file 93 | elif os.path.isfile(item_path) and any(item.endswith(ext) for ext in folder_paths.supported_pt_extensions): 94 | model_name = os.path.splitext(item)[0] 95 | if model_name in AVAILABLE_VIBEVOICE_MODELS: continue 96 | 97 | tokenizer_repo = "Qwen/Qwen2.5-7B" if "large" in model_name.lower() else "Qwen/Qwen2.5-1.5B" 98 | AVAILABLE_VIBEVOICE_MODELS[model_name] = { 99 | "type": "standalone", 100 | "path": item_path, 101 | "tokenizer_repo": tokenizer_repo 102 | } 103 | 104 | logger.info(f"Discovered VibeVoice models: {sorted(list(AVAILABLE_VIBEVOICE_MODELS.keys()))}") 105 | 106 | from .vibevoice_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 107 | 108 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import numpy as np 4 | import random 5 | import logging 6 | 7 | from comfy.utils import ProgressBar 8 | from comfy.model_management import throw_exception_if_processing_interrupted 9 | 10 | try: 11 | import librosa 12 | except ImportError: 13 | print("VibeVoice Node: `librosa` is not installed. Resampling of reference audio will not be available.") 14 | librosa = None 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | def set_vibevoice_seed(seed: int): 19 | """Sets the seed for torch, numpy, and random, handling large seeds for numpy.""" 20 | if seed == 0: 21 | seed = random.randint(1, 0xffffffffffffffff) 22 | 23 | MAX_NUMPY_SEED = 2**32 - 1 24 | numpy_seed = seed % MAX_NUMPY_SEED 25 | 26 | torch.manual_seed(seed) 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed_all(seed) 29 | np.random.seed(numpy_seed) 30 | random.seed(seed) 31 | 32 | def parse_script_1_based(script: str) -> tuple[list[tuple[int, str]], list[int]]: 33 | """ 34 | Parses a 1-based speaker script into a list of (speaker_id, text) tuples 35 | and a list of unique speaker IDs in the order of their first appearance. 36 | Internally, it converts speaker IDs to 0-based for the model. 37 | 38 | Supports two formats: 39 | 1. Speaker 1: Some text... 40 | 2. [1] Some text... 41 | 42 | If no speaker markers are found, the entire script is assigned to Speaker 1. 43 | """ 44 | parsed_lines = [] 45 | speaker_ids_in_script = [] # This will store the 1-based IDs from the script 46 | 47 | line_format_regex = re.compile(r'^(?:Speaker\s+(\d+)\s*:|\[(\d+)\])\s*(.*)$', re.IGNORECASE) 48 | 49 | for line in script.strip().split("\n"): 50 | if not (line := line.strip()): continue 51 | 52 | match = line_format_regex.match(line) 53 | if match: 54 | speaker_id_str = match.group(1) or match.group(2) 55 | speaker_id = int(speaker_id_str) 56 | text_content = match.group(3) 57 | 58 | if match.group(1) is None and text_content.lstrip().startswith(':'): 59 | colon_index = text_content.find(':') 60 | text_content = text_content[colon_index + 1:] 61 | 62 | if speaker_id < 1: 63 | logger.warning(f"Speaker ID must be 1 or greater. Skipping line: '{line}'") 64 | continue 65 | 66 | text = text_content.strip() # REMOVED the prepended space ' ' + 67 | internal_speaker_id = speaker_id - 1 68 | parsed_lines.append((internal_speaker_id, text)) 69 | 70 | if speaker_id not in speaker_ids_in_script: 71 | speaker_ids_in_script.append(speaker_id) 72 | else: 73 | logger.warning(f"Could not parse speaker marker, treating as part of previous line if any, or ignoring: '{line}'") 74 | 75 | if not parsed_lines and script.strip(): 76 | logger.info("No speaker markers found. Treating entire text as a single utterance for Speaker 1.") 77 | parsed_lines.append((0, ' ' + script.strip())) 78 | speaker_ids_in_script.append(1) 79 | 80 | return parsed_lines, sorted(list(set(speaker_ids_in_script))) 81 | 82 | 83 | def preprocess_comfy_audio(audio_dict: dict, target_sr: int = 24000) -> np.ndarray: 84 | """ 85 | Converts a ComfyUI AUDIO dict to a mono NumPy array, resampling if necessary. 86 | """ 87 | if not audio_dict: return None 88 | waveform_tensor = audio_dict.get('waveform') 89 | if waveform_tensor is None or waveform_tensor.numel() == 0: return None 90 | 91 | waveform = waveform_tensor[0].cpu().numpy() 92 | original_sr = audio_dict['sample_rate'] 93 | 94 | if waveform.ndim > 1: 95 | waveform = np.mean(waveform, axis=0) 96 | 97 | # Check for invalid values 98 | if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): 99 | logger.error("Audio contains NaN or Inf values, replacing with zeros") 100 | waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) 101 | 102 | # Ensure audio is not completely silent or has extreme values 103 | if np.all(waveform == 0): 104 | logger.warning("Audio waveform is completely silent") 105 | 106 | # Normalize extreme values 107 | max_val = np.abs(waveform).max() 108 | if max_val > 10.0: 109 | logger.warning(f"Audio values are very large (max: {max_val}), normalizing") 110 | waveform = waveform / max_val 111 | 112 | if original_sr != target_sr: 113 | if librosa is None: 114 | raise ImportError("`librosa` package is required for audio resampling. Please install it with `pip install librosa`.") 115 | logger.warning(f"Resampling reference audio from {original_sr}Hz to {target_sr}Hz.") 116 | waveform = librosa.resample(y=waveform, orig_sr=original_sr, target_sr=target_sr) 117 | 118 | # Final check after resampling 119 | if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): 120 | logger.error("Audio contains NaN or Inf after resampling, replacing with zeros") 121 | waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) 122 | 123 | return waveform.astype(np.float32) 124 | 125 | def check_for_interrupt(): 126 | try: 127 | throw_exception_if_processing_interrupted() 128 | return False 129 | except: 130 | return True -------------------------------------------------------------------------------- /vibevoice/modular/sage_attention_patch.py: -------------------------------------------------------------------------------- 1 | # Author: Wildminder 2 | # Desc: SageAttention and patcher 3 | # License: Apache 2.0 4 | 5 | import torch 6 | from typing import Optional, Tuple 7 | 8 | from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv 9 | from transformers.cache_utils import Cache 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | try: 15 | from sageattention.core import ( 16 | sageattn_qk_int8_pv_fp16_cuda, 17 | sageattn_qk_int8_pv_fp8_cuda, 18 | sageattn_qk_int8_pv_fp8_cuda_sm90, 19 | ) 20 | SAGE_ATTENTION_AVAILABLE = True 21 | except ImportError: 22 | SAGE_ATTENTION_AVAILABLE = False 23 | 24 | 25 | def get_sage_attention_function_and_params(): 26 | """ 27 | Selects the best available SageAttention CUDA kernel and its parameters 28 | based on the current GPU architecture. 29 | """ 30 | if not SAGE_ATTENTION_AVAILABLE or not torch.cuda.is_available(): 31 | return None, None, None 32 | 33 | major, minor = torch.cuda.get_device_capability() 34 | arch_code = major * 10 + minor 35 | 36 | attn_func = None 37 | pv_accum_dtype = "fp32" 38 | 39 | if arch_code >= 120: # Blackwell 40 | pv_accum_dtype = "fp32+fp32" 41 | attn_func = sageattn_qk_int8_pv_fp8_cuda 42 | logger.info(f"SageAttention: Using SM120 (Blackwell) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") 43 | elif arch_code >= 90: # Hopper 44 | pv_accum_dtype = "fp32+fp32" 45 | attn_func = sageattn_qk_int8_pv_fp8_cuda_sm90 46 | logger.info(f"SageAttention: Using SM90 (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") 47 | elif arch_code == 89: # Ada Lovelace 48 | pv_accum_dtype = "fp32+fp32" 49 | attn_func = sageattn_qk_int8_pv_fp8_cuda 50 | logger.info(f"SageAttention: Using SM89 (Ada) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") 51 | elif arch_code >= 80: # Ampere 52 | pv_accum_dtype = "fp32" 53 | attn_func = sageattn_qk_int8_pv_fp16_cuda 54 | logger.info(f"SageAttention: Using SM80+ (Ampere) FP16 kernel with pv_accum_dtype='{pv_accum_dtype}'.") 55 | else: 56 | logger.warning(f"SageAttention not supported on current GPU architecture (SM{arch_code}).") 57 | return None, None, None 58 | 59 | return attn_func, "per_warp", pv_accum_dtype 60 | 61 | SAGE_ATTENTION_FUNCTION, QK_QUANT_GRAN, PV_ACCUM_DTYPE = get_sage_attention_function_and_params() 62 | 63 | 64 | def sage_attention_forward( 65 | self, 66 | hidden_states: torch.Tensor, 67 | position_embeddings: tuple[torch.Tensor, torch.Tensor], 68 | attention_mask: Optional[torch.Tensor] = None, 69 | past_key_values: Optional[Cache] = None, 70 | cache_position: Optional[torch.LongTensor] = None, 71 | **kwargs, 72 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 73 | 74 | if SAGE_ATTENTION_FUNCTION is None: 75 | raise RuntimeError("SageAttention was selected but no compatible kernel was found for this GPU.") 76 | 77 | original_dtype = hidden_states.dtype 78 | 79 | is_4bit = hasattr(self.q_proj, 'quant_state') 80 | if is_4bit: 81 | target_dtype = torch.bfloat16 82 | else: 83 | target_dtype = self.q_proj.weight.dtype 84 | 85 | if hidden_states.dtype != target_dtype: 86 | hidden_states = hidden_states.to(target_dtype) 87 | 88 | bsz, q_len, _ = hidden_states.size() 89 | 90 | query_states = self.q_proj(hidden_states) 91 | key_states = self.k_proj(hidden_states) 92 | value_states = self.v_proj(hidden_states) 93 | 94 | query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2) 95 | key_states = key_states.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2) 96 | value_states = value_states.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2) 97 | 98 | cos, sin = position_embeddings 99 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids=None) 100 | 101 | if past_key_values is not None: 102 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 103 | key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) 104 | 105 | # !! DO NOT repeat K and V heads here. The SageAttention kernel is optimized 106 | # to handle the broadcasting internally. 107 | 108 | is_causal = attention_mask is None and q_len > 1 109 | 110 | attn_output = SAGE_ATTENTION_FUNCTION( 111 | query_states.to(target_dtype), 112 | key_states.to(target_dtype), 113 | value_states.to(target_dtype), 114 | tensor_layout="HND", 115 | is_causal=is_causal, 116 | qk_quant_gran=QK_QUANT_GRAN, 117 | pv_accum_dtype=PV_ACCUM_DTYPE, 118 | ) 119 | 120 | if isinstance(attn_output, tuple): 121 | attn_output = attn_output[0] 122 | 123 | attn_output = attn_output.transpose(1, 2).contiguous() 124 | attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size) 125 | 126 | attn_output = self.o_proj(attn_output) 127 | 128 | if attn_output.dtype != original_dtype: 129 | attn_output = attn_output.to(original_dtype) 130 | 131 | attn_weights = None 132 | 133 | return attn_output, attn_weights 134 | 135 | 136 | def set_sage_attention(model): 137 | """ 138 | Recursively iterates through the model's modules and monkey-patches the 139 | forward method of each Qwen2Attention layer. 140 | """ 141 | if not SAGE_ATTENTION_AVAILABLE: 142 | raise ImportError("SageAttention library is not installed or failed to load.") 143 | 144 | if SAGE_ATTENTION_FUNCTION is None: 145 | return 146 | 147 | for module in model.modules(): 148 | if isinstance(module, Qwen2Attention): 149 | module.forward = sage_attention_forward.__get__(module, Qwen2Attention) -------------------------------------------------------------------------------- /vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import argparse 5 | import json 6 | import os 7 | from pathlib import Path 8 | import re 9 | import torch 10 | from typing import Dict, List, Tuple 11 | 12 | from vibevoice.modular.configuration_vibevoice import ( 13 | VibeVoiceConfig 14 | ) 15 | from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration 16 | from transformers.utils import logging 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | def convert_vibevoice_nnscaler_checkpoint_to_hf( 21 | checkpoint_path: str, 22 | pytorch_dump_folder_path: str, 23 | config_path: str = None, 24 | ): 25 | """ 26 | Convert a nnscaler VibeVoice checkpoint to HuggingFace format. 27 | Supports both regular checkpoints and tensor parallel checkpoints. 28 | """ 29 | 30 | # Load regular checkpoint 31 | logger.info(f"Loading regular checkpoint from {checkpoint_path}") 32 | checkpoint = torch.load(checkpoint_path, map_location="cpu") # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader'] 33 | 34 | # config = checkpoint['train_args'] 35 | init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path'] 36 | pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path'] 37 | 38 | init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1] 39 | if init_config_path.exists(): 40 | logger.info(f"Loading initial config from {init_config_path}") 41 | with open(init_config_path, 'r') as f: 42 | init_config = json.load(f) 43 | else: 44 | raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.") 45 | 46 | tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True) 47 | logger.info(f"Tie word embeddings: {tie_word_embeddings}") 48 | 49 | init_config['decoder_config']['use_cache'] = True 50 | config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings) 51 | 52 | # # Extract the model state dict 53 | model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')} 54 | if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys(): 55 | # If not tying weights, we need to add the lm_head weight separately 56 | model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight'] 57 | 58 | # Override with provided config if available 59 | if config_path: 60 | logger.info(f"Loading config from {config_path}") 61 | with open(config_path, 'r') as f: 62 | config_dict = json.load(f) 63 | config = VibeVoiceConfig.from_dict(config_dict) 64 | 65 | # Set the default dtype to bfloat16 before creating the model 66 | original_dtype = torch.get_default_dtype() 67 | torch.set_default_dtype(torch.bfloat16) 68 | 69 | # Create the HuggingFace model 70 | logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model") 71 | model = VibeVoiceForConditionalGeneration(config) 72 | 73 | # Restore original dtype 74 | torch.set_default_dtype(original_dtype) 75 | 76 | # Load the state dict 77 | logger.info("Loading weights into model") 78 | missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False) 79 | 80 | if missing_keys: 81 | logger.warning(f"Missing keys: {missing_keys}") 82 | if unexpected_keys: 83 | logger.warning(f"Unexpected keys: {unexpected_keys}") 84 | 85 | # Create output directory 86 | os.makedirs(pytorch_dump_folder_path, exist_ok=True) 87 | 88 | # Save the model and config 89 | logger.info(f"Saving model to {pytorch_dump_folder_path}") 90 | 91 | # Save config 92 | config.save_pretrained(pytorch_dump_folder_path) 93 | 94 | # Save VibeVoiceProcessor configuration 95 | logger.info("Saving VibeVoiceProcessor configuration") 96 | processor_config = { 97 | "processor_class": "VibeVoiceProcessor", 98 | "speech_tok_compress_ratio": 3200, 99 | "db_normalize": True, 100 | # Audio processor configuration 101 | "audio_processor": { 102 | "feature_extractor_type": "VibeVoiceTokenizerProcessor", 103 | "sampling_rate": 24000, 104 | "normalize_audio": True, 105 | "target_dB_FS": -25, 106 | "eps": 1e-6, 107 | }, 108 | "language_model_pretrained_name": pretrained_name, 109 | } 110 | 111 | processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json") 112 | with open(processor_config_path, 'w') as f: 113 | json.dump(processor_config, f, indent=2) 114 | logger.info(f"Saved processor config to {processor_config_path}") 115 | 116 | # Save model with sharding 117 | # save_pretrained handles tied weights automatically 118 | logger.info("Saving model weights with sharding...") 119 | model.save_pretrained( 120 | pytorch_dump_folder_path, 121 | max_shard_size="2GB", # Set maximum size for each shard 122 | safe_serialization=True # Ensure saving in .safetensors format 123 | ) 124 | logger.info(f"Model weights saved to {pytorch_dump_folder_path}") 125 | 126 | logger.info("Conversion complete!") 127 | 128 | # Verify the saved model can be loaded 129 | logger.info("Verifying saved model...") 130 | loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path) 131 | logger.info("Model successfully loaded from saved checkpoint!") 132 | 133 | def main(): 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument( 136 | "--nnscaler_checkpoint_path", 137 | type=str, 138 | required=True, 139 | help="Path to the fairseq checkpoint (.pt file). For tensor parallel checkpoints, " 140 | "provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), " 141 | "and the script will automatically detect and merge all parts.", 142 | ) 143 | parser.add_argument( 144 | "--pytorch_dump_folder_path", 145 | type=str, 146 | required=True, 147 | help="Path to the output PyTorch model directory", 148 | ) 149 | parser.add_argument( 150 | "--config_path", 151 | type=str, 152 | default=None, 153 | help="Optional path to a config JSON file to override extracted config", 154 | ) 155 | 156 | args = parser.parse_args() 157 | 158 | convert_vibevoice_nnscaler_checkpoint_to_hf( 159 | args.nnscaler_checkpoint_path, 160 | args.pytorch_dump_folder_path, 161 | args.config_path, 162 | ) 163 | 164 | 165 | if __name__ == "__main__": 166 | main() -------------------------------------------------------------------------------- /vibevoice/modular/modular_vibevoice_text_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenization classes for vibevoice.""" 2 | 3 | from typing import List, Optional, Union 4 | 5 | from transformers.utils import logging 6 | from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer 7 | from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | 12 | class VibeVoiceTextTokenizer(Qwen2Tokenizer): 13 | """ 14 | Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech. 15 | 16 | Args: 17 | vocab_file (`str`): 18 | Path to the vocabulary file. 19 | merges_file (`str`): 20 | Path to the merges file. 21 | errors (`str`, *optional*, defaults to `"replace"`): 22 | Paradigm to follow when decoding bytes to UTF-8. 23 | unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 24 | The unknown token. 25 | bos_token (`str`, *optional*): 26 | The beginning of sequence token. Not used for vibevoice. 27 | eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 28 | The end of sequence token. 29 | pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 30 | The token used for padding. 31 | add_special_tokens (`bool`, *optional*, defaults to `True`): 32 | Whether or not to add special tokens when encoding. 33 | """ 34 | 35 | model_input_names = ["input_ids", "attention_mask"] 36 | 37 | def __init__( 38 | self, 39 | vocab_file, 40 | merges_file, 41 | errors="replace", 42 | unk_token="<|endoftext|>", 43 | bos_token=None, 44 | eos_token="<|endoftext|>", 45 | pad_token="<|endoftext|>", 46 | add_prefix_space=False, 47 | add_special_tokens=True, 48 | **kwargs, 49 | ): 50 | super().__init__( 51 | vocab_file=vocab_file, 52 | merges_file=merges_file, 53 | errors=errors, 54 | unk_token=unk_token, 55 | bos_token=bos_token, 56 | eos_token=eos_token, 57 | pad_token=pad_token, 58 | add_prefix_space=add_prefix_space, 59 | add_special_tokens=add_special_tokens, 60 | **kwargs, 61 | ) 62 | 63 | # Add VibeVoice-specific special tokens 64 | self._add_vibevoice_special_tokens() 65 | 66 | def _add_vibevoice_special_tokens(self): 67 | """Add VibeVoice-specific special tokens.""" 68 | special_tokens = { 69 | "additional_special_tokens": [ 70 | "<|vision_start|>", # Speech start (reusing vision tokens) 71 | "<|vision_end|>", # Speech end 72 | "<|vision_pad|>", # Speech diffusion pad 73 | ] 74 | } 75 | num_added = self.add_special_tokens(special_tokens) 76 | 77 | # Cache special token IDs 78 | self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>") 79 | self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>") 80 | self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>") 81 | 82 | self._eos_id = self.convert_tokens_to_ids('<|endoftext|>') 83 | 84 | return num_added 85 | 86 | @property 87 | def eos_id(self) -> int: 88 | """Id of the end of sequence token.""" 89 | return self._eos_id 90 | 91 | @property 92 | def speech_start_id(self) -> int: 93 | """Id of the speech start token.""" 94 | return self._speech_start_id 95 | 96 | @property 97 | def speech_end_id(self) -> int: 98 | """Id of the speech end token.""" 99 | return self._speech_end_id 100 | 101 | @property 102 | def speech_diffusion_id(self) -> int: 103 | """Id of the speech diffusion token.""" 104 | return self._speech_diffusion_id 105 | 106 | @property 107 | def pad_id(self) -> int: 108 | """Id used for padding (returns -100 for loss masking).""" 109 | return -100 110 | 111 | 112 | class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast): 113 | """ 114 | Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library). 115 | Based on the Qwen2 tokenizer with additional special tokens for speech. 116 | 117 | Args: 118 | vocab_file (`str`, *optional*): 119 | Path to the vocabulary file. 120 | merges_file (`str`, *optional*): 121 | Path to the merges file. 122 | tokenizer_file (`str`, *optional*): 123 | Path to [tokenizers](https://github.com/huggingface/tokenizers) file. 124 | unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 125 | The unknown token. 126 | bos_token (`str`, *optional*): 127 | The beginning of sequence token. Not used for vibevoice. 128 | eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 129 | The end of sequence token. 130 | pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 131 | The token used for padding. 132 | """ 133 | 134 | model_input_names = ["input_ids", "attention_mask"] 135 | 136 | def __init__( 137 | self, 138 | vocab_file=None, 139 | merges_file=None, 140 | tokenizer_file=None, 141 | unk_token="<|endoftext|>", 142 | bos_token=None, 143 | eos_token="<|endoftext|>", 144 | pad_token="<|endoftext|>", 145 | add_prefix_space=False, 146 | **kwargs, 147 | ): 148 | super().__init__( 149 | vocab_file=vocab_file, 150 | merges_file=merges_file, 151 | tokenizer_file=tokenizer_file, 152 | unk_token=unk_token, 153 | bos_token=bos_token, 154 | eos_token=eos_token, 155 | pad_token=pad_token, 156 | add_prefix_space=add_prefix_space, 157 | **kwargs, 158 | ) 159 | 160 | # Add VibeVoice-specific special tokens 161 | self._add_vibevoice_special_tokens() 162 | 163 | def _add_vibevoice_special_tokens(self): 164 | """Add VibeVoice-specific special tokens.""" 165 | special_tokens = { 166 | "additional_special_tokens": [ 167 | "<|vision_start|>", # Speech start (reusing vision tokens) 168 | "<|vision_end|>", # Speech end 169 | "<|vision_pad|>", # Speech diffusion pad 170 | ] 171 | } 172 | num_added = self.add_special_tokens(special_tokens) 173 | 174 | # Cache special token IDs 175 | self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>") 176 | self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>") 177 | self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>") 178 | 179 | # self._eos_id = self.convert_tokens_to_ids('<|endoftext|>') 180 | self._eos_id = self.eos_token_id # qwen2 / qwen3 181 | self._pad_id = self.convert_tokens_to_ids('<|image_pad|>') 182 | 183 | return num_added 184 | 185 | @property 186 | def eos_id(self) -> int: 187 | """Id of the end of sequence token.""" 188 | return self._eos_id 189 | 190 | @property 191 | def speech_start_id(self) -> int: 192 | """Id of the speech start token.""" 193 | return self._speech_start_id 194 | 195 | @property 196 | def speech_end_id(self) -> int: 197 | """Id of the speech end token.""" 198 | return self._speech_end_id 199 | 200 | @property 201 | def speech_diffusion_id(self) -> int: 202 | """Id of the speech diffusion token.""" 203 | return self._speech_diffusion_id 204 | 205 | @property 206 | def pad_id(self) -> int: 207 | """Id used for padding (returns -100 for loss masking).""" 208 | return self._pad_id 209 | 210 | 211 | __all__ = [ 212 | "VibeVoiceTextTokenizer", 213 | "VibeVoiceTextTokenizerFast", 214 | ] -------------------------------------------------------------------------------- /vibevoice/modular/modular_vibevoice_diffusion_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from transformers.models.auto import AutoModel 9 | from transformers.modeling_utils import PreTrainedModel 10 | # from transformers.modeling_layers import GradientCheckpointingLayer 11 | from transformers.activations import ACT2FN 12 | from transformers.utils import logging 13 | 14 | from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig 15 | 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | 20 | class RMSNorm(nn.Module): 21 | def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False): 22 | super().__init__() 23 | self.dim = dim 24 | self.eps = eps 25 | self.elementwise_affine = elementwise_affine 26 | if self.elementwise_affine: 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | else: 29 | self.register_parameter('weight', None) 30 | 31 | def _norm(self, x): 32 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 33 | 34 | def forward(self, x): 35 | output = self._norm(x.float()).type_as(x) 36 | if self.weight is not None: 37 | output = output * self.weight 38 | return output 39 | 40 | def extra_repr(self) -> str: 41 | return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}' 42 | 43 | def modulate(x, shift, scale): 44 | """Apply modulation to input tensor.""" 45 | return x * (1 + scale) + shift 46 | 47 | 48 | class TimestepEmbedder(nn.Module): 49 | """ 50 | Embeds scalar timesteps into vector representations. 51 | 52 | Args: 53 | hidden_size (`int`): Size of the output embedding 54 | frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding 55 | """ 56 | def __init__(self, hidden_size, frequency_embedding_size=256): 57 | super().__init__() 58 | self.mlp = nn.Sequential( 59 | nn.Linear(frequency_embedding_size, hidden_size, bias=False), 60 | # nn.SiLU(), 61 | ACT2FN['silu'], 62 | nn.Linear(hidden_size, hidden_size, bias=False), 63 | ) 64 | self.frequency_embedding_size = frequency_embedding_size 65 | 66 | @staticmethod 67 | def timestep_embedding(t, dim, max_period=10000): 68 | """ 69 | Create sinusoidal timestep embeddings. 70 | 71 | Args: 72 | t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element. 73 | These may be fractional. 74 | dim (`int`): The dimension of the output. 75 | max_period (`int`, optional): Controls the minimum frequency of the embeddings. 76 | 77 | Returns: 78 | `torch.Tensor`: An [N, D] Tensor of positional embeddings. 79 | """ 80 | half = dim // 2 81 | freqs = torch.exp( 82 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 83 | ).to(t.device) 84 | args = t[:, None].float() * freqs[None] 85 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 86 | if dim % 2: 87 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 88 | return embedding.to(t.dtype) 89 | 90 | def forward(self, t): 91 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 92 | t_emb = self.mlp(t_freq) 93 | return t_emb 94 | 95 | 96 | class FeedForwardNetwork(nn.Module): 97 | """ 98 | Standard feed-forward network with SwiGLU activation. 99 | 100 | Args: 101 | embed_dim (`int`): Input dimension 102 | ffn_dim (`int`): Hidden dimension 103 | """ 104 | def __init__( 105 | self, 106 | embed_dim, 107 | ffn_dim, 108 | ): 109 | super().__init__() 110 | self.embed_dim = embed_dim 111 | self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False) 112 | self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False) 113 | self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False) 114 | self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function 115 | 116 | def forward(self, x): 117 | gate = self.gate_proj(x) 118 | up = self.up_proj(x) 119 | 120 | # SwiGLU activation 121 | # gate = F.silu(gate) 122 | gate = self.act_fn(gate) 123 | return self.down_proj(gate * up) 124 | 125 | 126 | class HeadLayer(nn.Module): 127 | """ 128 | A layer in the diffusion head. 129 | 130 | Args: 131 | embed_dim (`int`): Input dimension 132 | ffn_dim (`int`): Hidden dimension 133 | cond_dim (`int`): Condition embedding dimension 134 | norm_eps (`float`, optional): Epsilon for normalization 135 | """ 136 | def __init__( 137 | self, 138 | embed_dim, 139 | ffn_dim, 140 | cond_dim, 141 | norm_eps=1e-5, 142 | ): 143 | super().__init__() 144 | self.embed_dim = embed_dim 145 | self.cond_dim = cond_dim 146 | self.ffn_dim = ffn_dim 147 | self.ffn = FeedForwardNetwork( 148 | self.embed_dim, 149 | self.ffn_dim, 150 | ) 151 | self.norm = RMSNorm(self.embed_dim, eps=norm_eps) 152 | self.adaLN_modulation = nn.Sequential( 153 | # nn.SiLU(), 154 | ACT2FN['silu'], 155 | nn.Linear(cond_dim, 3 * self.embed_dim, bias=False) 156 | ) 157 | 158 | def forward(self, x, c): 159 | shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1) 160 | x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn)) 161 | return x 162 | 163 | 164 | class FinalLayer(nn.Module): 165 | """ 166 | Final layer in the diffusion head. 167 | 168 | Args: 169 | hidden_size (`int`): Input dimension 170 | output_size (`int`): Output dimension 171 | cond_size (`int`): Condition embedding dimension 172 | norm_eps (`float`, optional): Epsilon for normalization 173 | """ 174 | def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5): 175 | super().__init__() 176 | self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False) 177 | self.linear = nn.Linear(hidden_size, output_size, bias=False) 178 | self.adaLN_modulation = nn.Sequential( 179 | # nn.SiLU(), 180 | ACT2FN['silu'], 181 | nn.Linear(cond_size, 2 * hidden_size, bias=False) 182 | ) 183 | 184 | def forward(self, x, c): 185 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 186 | x = modulate(self.norm_final(x), shift, scale) 187 | x = self.linear(x) 188 | return x 189 | 190 | 191 | class VibeVoiceDiffusionHead(PreTrainedModel): 192 | """ 193 | Diffusion head model for vibevoice. 194 | 195 | Args: 196 | config (`VibeVoiceDiffusionHeadConfig`): Model configuration 197 | latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`. 198 | """ 199 | config_class = VibeVoiceDiffusionHeadConfig 200 | supports_gradient_checkpointing = True 201 | _supports_flash_attn_2 = True 202 | _supports_sdpa = True 203 | 204 | def __init__( 205 | self, 206 | config, 207 | ): 208 | super().__init__(config) 209 | self.config = config 210 | self.cond_dim = config.hidden_size 211 | latent_size = config.latent_size 212 | 213 | self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False) 214 | self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False) 215 | self.t_embedder = TimestepEmbedder(self.cond_dim) 216 | 217 | ffn_dim = int(config.hidden_size * config.head_ffn_ratio) 218 | 219 | # Create the intermediate layers 220 | self.layers = nn.ModuleList([ 221 | HeadLayer( 222 | embed_dim=config.hidden_size, 223 | ffn_dim=ffn_dim, 224 | cond_dim=self.cond_dim, 225 | norm_eps=config.rms_norm_eps 226 | ) 227 | for _ in range(config.head_layers) 228 | ]) 229 | 230 | # Final layer for output 231 | self.final_layer = FinalLayer( 232 | hidden_size=config.hidden_size, 233 | output_size=latent_size, 234 | cond_size=self.cond_dim, 235 | norm_eps=config.rms_norm_eps 236 | ) 237 | 238 | self.initialize_weights() 239 | 240 | def initialize_weights(self): 241 | """Initialize the weights of the model.""" 242 | # Initialize timestep embedder 243 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 244 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 245 | 246 | # Zero-out adaLN modulation layers 247 | for layer in self.layers: 248 | nn.init.constant_(layer.adaLN_modulation[-1].weight, 0) 249 | 250 | # Zero-out output layers 251 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 252 | nn.init.constant_(self.final_layer.linear.weight, 0) 253 | 254 | def forward( 255 | self, 256 | noisy_images, 257 | timesteps, 258 | condition, 259 | ): 260 | """ 261 | Forward pass of the prediction head. 262 | 263 | Args: 264 | noisy_images (`torch.Tensor`): Noisy images/latents to denoise 265 | timesteps (`torch.Tensor`): Timesteps for diffusion 266 | condition (`torch.Tensor`): Conditioning information 267 | 268 | Returns: 269 | `torch.Tensor`: The predicted noise/velocity 270 | """ 271 | x = self.noisy_images_proj(noisy_images) 272 | t = self.t_embedder(timesteps) 273 | condition = self.cond_proj(condition) 274 | c = condition + t 275 | 276 | for layer in self.layers: 277 | x = layer(x, c) 278 | 279 | x = self.final_layer(x, c) 280 | return x 281 | 282 | 283 | AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead) 284 | 285 | __all__ = [ 286 | "VibeVoiceDiffusionHead", 287 | ] -------------------------------------------------------------------------------- /vibevoice/modular/streamer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | import asyncio 6 | from queue import Queue 7 | from typing import TYPE_CHECKING, Optional 8 | 9 | 10 | from transformers.generation import BaseStreamer 11 | 12 | 13 | class AudioStreamer(BaseStreamer): 14 | """ 15 | Audio streamer that stores audio chunks in queues for each sample in the batch. 16 | This allows streaming audio generation for multiple samples simultaneously. 17 | 18 | Parameters: 19 | batch_size (`int`): 20 | The batch size for generation 21 | stop_signal (`any`, *optional*): 22 | The signal to put in the queue when generation ends. Defaults to None. 23 | timeout (`float`, *optional*): 24 | The timeout for the audio queue. If `None`, the queue will block indefinitely. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | batch_size: int, 30 | stop_signal: Optional[any] = None, 31 | timeout: Optional[float] = None, 32 | ): 33 | self.batch_size = batch_size 34 | self.stop_signal = stop_signal 35 | self.timeout = timeout 36 | 37 | # Create a queue for each sample in the batch 38 | self.audio_queues = [Queue() for _ in range(batch_size)] 39 | self.finished_flags = [False for _ in range(batch_size)] 40 | self.sample_indices_map = {} # Maps from sample index to queue index 41 | 42 | def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor): 43 | """ 44 | Receives audio chunks and puts them in the appropriate queues. 45 | 46 | Args: 47 | audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks 48 | sample_indices: Tensor indicating which samples these chunks belong to 49 | """ 50 | for i, sample_idx in enumerate(sample_indices): 51 | idx = sample_idx.item() 52 | if idx < self.batch_size and not self.finished_flags[idx]: 53 | # Convert to numpy or keep as tensor based on preference 54 | audio_chunk = audio_chunks[i].detach().cpu() 55 | self.audio_queues[idx].put(audio_chunk, timeout=self.timeout) 56 | 57 | def end(self, sample_indices: Optional[torch.Tensor] = None): 58 | """ 59 | Signals the end of generation for specified samples or all samples. 60 | 61 | Args: 62 | sample_indices: Optional tensor of sample indices to end. If None, ends all. 63 | """ 64 | if sample_indices is None: 65 | # End all samples 66 | for idx in range(self.batch_size): 67 | if not self.finished_flags[idx]: 68 | self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout) 69 | self.finished_flags[idx] = True 70 | else: 71 | # End specific samples 72 | for sample_idx in sample_indices: 73 | idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx 74 | if idx < self.batch_size and not self.finished_flags[idx]: 75 | self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout) 76 | self.finished_flags[idx] = True 77 | 78 | def __iter__(self): 79 | """Returns an iterator over the batch of audio streams.""" 80 | return AudioBatchIterator(self) 81 | 82 | def get_stream(self, sample_idx: int): 83 | """Get the audio stream for a specific sample.""" 84 | if sample_idx >= self.batch_size: 85 | raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}") 86 | return AudioSampleIterator(self, sample_idx) 87 | 88 | 89 | class AudioSampleIterator: 90 | """Iterator for a single audio stream from the batch.""" 91 | 92 | def __init__(self, streamer: AudioStreamer, sample_idx: int): 93 | self.streamer = streamer 94 | self.sample_idx = sample_idx 95 | 96 | def __iter__(self): 97 | return self 98 | 99 | def __next__(self): 100 | value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout) 101 | if value == self.streamer.stop_signal: 102 | raise StopIteration() 103 | return value 104 | 105 | 106 | class AudioBatchIterator: 107 | """Iterator that yields audio chunks for all samples in the batch.""" 108 | 109 | def __init__(self, streamer: AudioStreamer): 110 | self.streamer = streamer 111 | self.active_samples = set(range(streamer.batch_size)) 112 | 113 | def __iter__(self): 114 | return self 115 | 116 | def __next__(self): 117 | if not self.active_samples: 118 | raise StopIteration() 119 | 120 | batch_chunks = {} 121 | samples_to_remove = set() 122 | 123 | # Try to get chunks from all active samples 124 | for idx in self.active_samples: 125 | try: 126 | value = self.streamer.audio_queues[idx].get(block=False) 127 | if value == self.streamer.stop_signal: 128 | samples_to_remove.add(idx) 129 | else: 130 | batch_chunks[idx] = value 131 | except: 132 | # Queue is empty for this sample, skip it this iteration 133 | pass 134 | 135 | # Remove finished samples 136 | self.active_samples -= samples_to_remove 137 | 138 | if batch_chunks: 139 | return batch_chunks 140 | elif self.active_samples: 141 | # If no chunks were ready but we still have active samples, 142 | # wait a bit and try again 143 | import time 144 | time.sleep(0.01) 145 | return self.__next__() 146 | else: 147 | raise StopIteration() 148 | 149 | 150 | class AsyncAudioStreamer(AudioStreamer): 151 | """ 152 | Async version of AudioStreamer for use in async contexts. 153 | """ 154 | 155 | def __init__( 156 | self, 157 | batch_size: int, 158 | stop_signal: Optional[any] = None, 159 | timeout: Optional[float] = None, 160 | ): 161 | super().__init__(batch_size, stop_signal, timeout) 162 | # Replace regular queues with async queues 163 | self.audio_queues = [asyncio.Queue() for _ in range(batch_size)] 164 | self.loop = asyncio.get_running_loop() 165 | 166 | def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor): 167 | """Put audio chunks in the appropriate async queues.""" 168 | for i, sample_idx in enumerate(sample_indices): 169 | idx = sample_idx.item() 170 | if idx < self.batch_size and not self.finished_flags[idx]: 171 | audio_chunk = audio_chunks[i].detach().cpu() 172 | self.loop.call_soon_threadsafe( 173 | self.audio_queues[idx].put_nowait, audio_chunk 174 | ) 175 | 176 | def end(self, sample_indices: Optional[torch.Tensor] = None): 177 | """Signal the end of generation for specified samples.""" 178 | if sample_indices is None: 179 | indices_to_end = range(self.batch_size) 180 | else: 181 | indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices] 182 | 183 | for idx in indices_to_end: 184 | if idx < self.batch_size and not self.finished_flags[idx]: 185 | self.loop.call_soon_threadsafe( 186 | self.audio_queues[idx].put_nowait, self.stop_signal 187 | ) 188 | self.finished_flags[idx] = True 189 | 190 | async def get_stream(self, sample_idx: int): 191 | """Get async iterator for a specific sample's audio stream.""" 192 | if sample_idx >= self.batch_size: 193 | raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}") 194 | 195 | while True: 196 | value = await self.audio_queues[sample_idx].get() 197 | if value == self.stop_signal: 198 | break 199 | yield value 200 | 201 | def __aiter__(self): 202 | """Returns an async iterator over all audio streams.""" 203 | return AsyncAudioBatchIterator(self) 204 | 205 | 206 | class AsyncAudioBatchIterator: 207 | """Async iterator for batch audio streaming.""" 208 | 209 | def __init__(self, streamer: AsyncAudioStreamer): 210 | self.streamer = streamer 211 | self.active_samples = set(range(streamer.batch_size)) 212 | 213 | def __aiter__(self): 214 | return self 215 | 216 | async def __anext__(self): 217 | if not self.active_samples: 218 | raise StopAsyncIteration() 219 | 220 | batch_chunks = {} 221 | samples_to_remove = set() 222 | 223 | # Create tasks for all active samples 224 | tasks = { 225 | idx: asyncio.create_task(self._get_chunk(idx)) 226 | for idx in self.active_samples 227 | } 228 | 229 | # Wait for at least one chunk to be ready 230 | done, pending = await asyncio.wait( 231 | tasks.values(), 232 | return_when=asyncio.FIRST_COMPLETED, 233 | timeout=self.streamer.timeout 234 | ) 235 | 236 | # Cancel pending tasks 237 | for task in pending: 238 | task.cancel() 239 | 240 | # Process completed tasks 241 | for idx, task in tasks.items(): 242 | if task in done: 243 | try: 244 | value = await task 245 | if value == self.streamer.stop_signal: 246 | samples_to_remove.add(idx) 247 | else: 248 | batch_chunks[idx] = value 249 | except asyncio.CancelledError: 250 | pass 251 | 252 | self.active_samples -= samples_to_remove 253 | 254 | if batch_chunks: 255 | return batch_chunks 256 | elif self.active_samples: 257 | # Try again if we still have active samples 258 | return await self.__anext__() 259 | else: 260 | raise StopAsyncIteration() 261 | 262 | async def _get_chunk(self, idx): 263 | """Helper to get a chunk from a specific queue.""" 264 | return await self.streamer.audio_queues[idx].get() -------------------------------------------------------------------------------- /example_workflows/VibeVoice_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "b91265e5-1b03-4b63-8dc3-4abd9a030e08", 3 | "revision": 0, 4 | "last_node_id": 14, 5 | "last_link_id": 44, 6 | "nodes": [ 7 | { 8 | "id": 3, 9 | "type": "SaveAudio", 10 | "pos": [ 11 | -1040, 12 | -1130 13 | ], 14 | "size": [ 15 | 270, 16 | 112 17 | ], 18 | "flags": {}, 19 | "order": 6, 20 | "mode": 0, 21 | "inputs": [ 22 | { 23 | "name": "audio", 24 | "type": "AUDIO", 25 | "link": 27 26 | } 27 | ], 28 | "outputs": [], 29 | "properties": { 30 | "Node name for S&R": "SaveAudio", 31 | "cnr_id": "comfy-core", 32 | "ver": "0.3.52", 33 | "ue_properties": { 34 | "widget_ue_connectable": { 35 | "filename_prefix": true, 36 | "audioUI": true 37 | }, 38 | "version": "7.0.1" 39 | } 40 | }, 41 | "widgets_values": [ 42 | "audio/VibeVoice" 43 | ] 44 | }, 45 | { 46 | "id": 13, 47 | "type": "MarkdownNote", 48 | "pos": [ 49 | -1898.1748046875, 50 | -1409.22314453125 51 | ], 52 | "size": [ 53 | 1035.619873046875, 54 | 211.96694946289062 55 | ], 56 | "flags": {}, 57 | "order": 0, 58 | "mode": 0, 59 | "inputs": [], 60 | "outputs": [], 61 | "title": "Note", 62 | "properties": {}, 63 | "widgets_values": [ 64 | "# ComfyUI-VibeVoice\n\nVibeVoice is a novel framework by Microsoft for generating expressive, long-form, multi-speaker conversational audio. It excels at creating natural-sounding dialogue, podcasts, and more, with consistent voices for up to 4 speakers.\n\n**✨ Key Features:**\n* **Multi-Speaker TTS:** Generate conversations with up to 4 distinct voices in a single audio output.\n* **High-Fidelity Voice Cloning:** Use any audio file (`.wav`, `.mp3`) as a reference for a speaker's voice.\n* **Hybrid Generation Mode:** Mix and match cloned voices with high-quality, zero-shot generated voices in the same script.\n* **Flexible Scripting:** Use simple `[1]` tags or the classic `Speaker 1:` format to write your dialogue.\n* **Advanced Attention Mechanisms:** Choose between `eager`, `sdpa`, `flash_attention_2`, and the high-performance `sage` attention for fine-tuned control over speed and compatibility.\n* **Robust 4-Bit Quantization:** Run the large language model component in 4-bit mode to significantly reduce VRAM usage.\n* **Automatic Model Management:** Models are downloaded automatically and managed efficiently by ComfyUI to save VRAM." 65 | ], 66 | "color": "#233", 67 | "bgcolor": "#355" 68 | }, 69 | { 70 | "id": 4, 71 | "type": "LoadAudio", 72 | "pos": [ 73 | -1900, 74 | -1130 75 | ], 76 | "size": [ 77 | 272.9800720214844, 78 | 136 79 | ], 80 | "flags": {}, 81 | "order": 1, 82 | "mode": 0, 83 | "inputs": [], 84 | "outputs": [ 85 | { 86 | "name": "AUDIO", 87 | "type": "AUDIO", 88 | "links": [] 89 | } 90 | ], 91 | "properties": { 92 | "Node name for S&R": "LoadAudio", 93 | "cnr_id": "comfy-core", 94 | "ver": "0.3.52", 95 | "ue_properties": { 96 | "widget_ue_connectable": { 97 | "audio": true, 98 | "audioUI": true, 99 | "upload": true 100 | }, 101 | "version": "7.0.1" 102 | } 103 | }, 104 | "widgets_values": [ 105 | "male_rickmorty.mp3", 106 | null, 107 | null 108 | ] 109 | }, 110 | { 111 | "id": 8, 112 | "type": "LoadAudio", 113 | "pos": [ 114 | -1901.10009765625, 115 | -948.7998046875 116 | ], 117 | "size": [ 118 | 274.080078125, 119 | 136 120 | ], 121 | "flags": {}, 122 | "order": 2, 123 | "mode": 0, 124 | "inputs": [], 125 | "outputs": [ 126 | { 127 | "name": "AUDIO", 128 | "type": "AUDIO", 129 | "links": [] 130 | } 131 | ], 132 | "properties": { 133 | "Node name for S&R": "LoadAudio", 134 | "cnr_id": "comfy-core", 135 | "ver": "0.3.52", 136 | "ue_properties": { 137 | "widget_ue_connectable": { 138 | "audio": true, 139 | "audioUI": true, 140 | "upload": true 141 | }, 142 | "version": "7.0.1" 143 | } 144 | }, 145 | "widgets_values": [ 146 | "male_stewie.mp3", 147 | null, 148 | null 149 | ] 150 | }, 151 | { 152 | "id": 12, 153 | "type": "MarkdownNote", 154 | "pos": [ 155 | -1915.701904296875, 156 | -762.380126953125 157 | ], 158 | "size": [ 159 | 312.85455322265625, 160 | 292.8734130859375 161 | ], 162 | "flags": {}, 163 | "order": 3, 164 | "mode": 0, 165 | "inputs": [], 166 | "outputs": [], 167 | "title": "Note", 168 | "properties": {}, 169 | "widgets_values": [ 170 | "### Scripting and Voice Modes\n\n#### Speaker Tagging\nYou can assign lines to speakers in two ways. Both are treated identically.\n\n* **Modern Format (Recommended):** `[1] This is the first speaker.`\n* **Classic Format:** `Speaker 1: This is the first speaker.`\n\nYou can also add an optional colon to the modern format (e.g., `[1]: ...`). The node handles all variations consistently.\n\n#### Hybrid Voice Generation\nThis is a powerful feature that lets you mix cloned voices and generated (zero-shot) voices.\n\n* **To Clone a Voice:** Connect a `Load Audio` node to the speaker's input (e.g., `speaker_1_voice`).\n* **To Generate a Voice:** Leave the speaker's input empty. The model will create a unique, high-quality voice for that speaker." 171 | ], 172 | "color": "#233", 173 | "bgcolor": "#355" 174 | }, 175 | { 176 | "id": 14, 177 | "type": "MarkdownNote", 178 | "pos": [ 179 | -1048.3660888671875, 180 | -960.8771362304688 181 | ], 182 | "size": [ 183 | 280.797607421875, 184 | 487.02728271484375 185 | ], 186 | "flags": {}, 187 | "order": 4, 188 | "mode": 0, 189 | "inputs": [], 190 | "outputs": [], 191 | "title": "Note", 192 | "properties": {}, 193 | "widgets_values": [ 194 | "## Models\n\nWill be downloaded on the first run, or download them manually and place them into the directory: /models/tts/VibeVoice\n\n| Model | Context Length | Generation Length | Weight |\n|-------|----------------|----------|----------|\n| VibeVoice-1.5B | 64K | ~90 min | [HF link](https://huggingface.co/microsoft/VibeVoice-1.5B) |\n| VibeVoice-Large| 32K | ~45 min | [HF link](https://huggingface.co/microsoft/VibeVoice-Large) |\n\n## Support \n\n- Don't know how to update PyTorch?\n- Need help with ComfyUI?\n- Need technical support?\n\n### Or do you just have questions? Then join the [@TokenDiffusion Hub](https://t.me/TokenDiff_hub) group\n\n### AI news [TokenDiffusion](https://t.me/TokenDiff)" 195 | ], 196 | "color": "#233", 197 | "bgcolor": "#355" 198 | }, 199 | { 200 | "id": 11, 201 | "type": "VibeVoiceTTS", 202 | "pos": [ 203 | -1570, 204 | -1130 205 | ], 206 | "size": [ 207 | 475.3999938964844, 208 | 662.9000244140625 209 | ], 210 | "flags": {}, 211 | "order": 5, 212 | "mode": 0, 213 | "inputs": [ 214 | { 215 | "name": "speaker_1_voice", 216 | "shape": 7, 217 | "type": "AUDIO", 218 | "link": null 219 | }, 220 | { 221 | "name": "speaker_2_voice", 222 | "shape": 7, 223 | "type": "AUDIO", 224 | "link": null 225 | }, 226 | { 227 | "name": "speaker_3_voice", 228 | "shape": 7, 229 | "type": "AUDIO", 230 | "link": null 231 | }, 232 | { 233 | "name": "speaker_4_voice", 234 | "shape": 7, 235 | "type": "AUDIO", 236 | "link": null 237 | } 238 | ], 239 | "outputs": [ 240 | { 241 | "name": "AUDIO", 242 | "type": "AUDIO", 243 | "links": [ 244 | 27 245 | ] 246 | } 247 | ], 248 | "properties": { 249 | "Node name for S&R": "VibeVoiceTTS", 250 | "cnr_id": "ComfyUI-VibeVoice", 251 | "ver": "37803a884fb8f9b43c38286f6d654c7f97181a73", 252 | "ue_properties": { 253 | "widget_ue_connectable": { 254 | "model_name": true, 255 | "text": true, 256 | "quantize_llm_4bit": true, 257 | "attention_mode": true, 258 | "cfg_scale": true, 259 | "inference_steps": true, 260 | "seed": true, 261 | "do_sample": true, 262 | "temperature": true, 263 | "top_p": true, 264 | "top_k": true 265 | }, 266 | "version": "7.0.1" 267 | } 268 | }, 269 | "widgets_values": [ 270 | "VibeVoice-1.5B", 271 | "[1] I can't believe you did it again. I waited for two hours. Two hours! Not a single call, not a text. Do you have any idea how embarrassing that was, just sitting there alone?\n[2] Look, I know, I'm sorry, alright? Work was a complete nightmare. My boss dropped a critical deadline on me at the last minute. I didn't even have a second to breathe, let alone check my phone.\n", 272 | false, 273 | "flash_attention_2", 274 | 1.3, 275 | 10, 276 | 471935335072093, 277 | "fixed", 278 | true, 279 | 0.95, 280 | 0.95, 281 | 0, 282 | false 283 | ], 284 | "color": "#232", 285 | "bgcolor": "#353" 286 | } 287 | ], 288 | "links": [ 289 | [ 290 | 27, 291 | 11, 292 | 0, 293 | 3, 294 | 0, 295 | "AUDIO" 296 | ] 297 | ], 298 | "groups": [], 299 | "config": {}, 300 | "extra": { 301 | "ds": { 302 | "scale": 0.8264462809917354, 303 | "offset": [ 304 | 2015.701904296875, 305 | 1509.22314453125 306 | ] 307 | }, 308 | "ue_links": [], 309 | "links_added_by_ue": [], 310 | "frontendVersion": "1.26.11", 311 | "VHS_latentpreview": false, 312 | "VHS_latentpreviewrate": 0, 313 | "VHS_MetadataImage": true, 314 | "VHS_KeepIntermediate": true 315 | }, 316 | "version": 0.4 317 | } -------------------------------------------------------------------------------- /vibevoice/modular/configuration_vibevoice.py: -------------------------------------------------------------------------------- 1 | """ VibeVoice_AcousticTokenizer model configuration""" 2 | 3 | from typing import Dict, List, Optional, Tuple 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.utils import logging 7 | 8 | from transformers.models.qwen2.configuration_qwen2 import Qwen2Config 9 | 10 | logger = logging.get_logger(__name__) 11 | 12 | 13 | class VibeVoiceAcousticTokenizerConfig(PretrainedConfig): 14 | model_type = "vibevoice_acoustic_tokenizer" 15 | 16 | def __init__( 17 | self, 18 | channels: int = 1, 19 | corpus_normalize: float = 0.0, 20 | causal: bool = True, 21 | vae_dim: int = 64, 22 | fix_std: float = 0.5, 23 | std_dist_type: str = 'gaussian', 24 | # common 25 | mixer_layer: str = 'depthwise_conv', 26 | conv_norm: str = 'none', 27 | pad_mode: str = 'constant', 28 | disable_last_norm: bool = True, 29 | layernorm: str = 'RMSNorm', 30 | layernorm_eps: float = 1e-5, 31 | layernorm_elementwise_affine: bool = True, 32 | conv_bias: bool = True, 33 | layer_scale_init_value: float = 1e-6, 34 | weight_init_value: float = 1e-2, 35 | # encoder specific 36 | encoder_n_filters: int = 32, 37 | encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2], 38 | encoder_depths: str = "3-3-3-3-3-3-8", 39 | # decoder specific 40 | decoder_n_filters: int = 32, 41 | decoder_ratios: Optional[List[int]] = None, # if None, same as encoder 42 | decoder_depths: Optional[str] = None, 43 | **kwargs 44 | ): 45 | super().__init__(**kwargs) 46 | self.channels = channels 47 | self.corpus_normalize = corpus_normalize 48 | self.causal = causal 49 | self.vae_dim = vae_dim 50 | self.fix_std = fix_std 51 | self.std_dist_type = std_dist_type 52 | 53 | # common parameters 54 | self.conv_norm = conv_norm 55 | self.pad_mode = pad_mode 56 | self.layernorm_eps = layernorm_eps 57 | self.disable_last_norm = disable_last_norm 58 | self.layernorm = layernorm 59 | self.layernorm_elementwise_affine = layernorm_elementwise_affine 60 | self.conv_bias = conv_bias 61 | self.layer_scale_init_value = layer_scale_init_value 62 | self.weight_init_value = weight_init_value 63 | self.mixer_layer = mixer_layer 64 | 65 | # encoder specific parameters 66 | self.encoder_n_filters = encoder_n_filters 67 | self.encoder_ratios = encoder_ratios 68 | self.encoder_depths = encoder_depths 69 | 70 | # decoder specific parameters 71 | self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios 72 | self.decoder_n_filters = decoder_n_filters 73 | self.decoder_depths = decoder_depths 74 | 75 | 76 | class VibeVoiceSemanticTokenizerConfig(PretrainedConfig): 77 | model_type = "vibevoice_semantic_tokenizer" 78 | 79 | def __init__( 80 | self, 81 | channels: int = 1, 82 | corpus_normalize: float = 0.0, 83 | causal: bool = True, 84 | vae_dim: int = 64, 85 | fix_std: float = 0, 86 | std_dist_type: str = 'none', 87 | # common 88 | mixer_layer: str = 'depthwise_conv', 89 | conv_norm: str = 'none', 90 | pad_mode: str = 'constant', 91 | disable_last_norm: bool = True, 92 | layernorm: str = 'RMSNorm', 93 | layernorm_eps: float = 1e-5, 94 | layernorm_elementwise_affine: bool = True, 95 | conv_bias: bool = True, 96 | layer_scale_init_value: float = 1e-6, 97 | weight_init_value: float = 1e-2, 98 | # encoder specific 99 | encoder_n_filters: int = 32, 100 | encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2], 101 | encoder_depths: str = "3-3-3-3-3-3-8", 102 | **kwargs 103 | ): 104 | super().__init__(**kwargs) 105 | self.channels = channels 106 | self.corpus_normalize = corpus_normalize 107 | self.causal = causal 108 | self.vae_dim = vae_dim 109 | self.fix_std = fix_std 110 | self.std_dist_type = std_dist_type 111 | 112 | # common parameters 113 | self.conv_norm = conv_norm 114 | self.pad_mode = pad_mode 115 | self.layernorm_eps = layernorm_eps 116 | self.disable_last_norm = disable_last_norm 117 | self.layernorm = layernorm 118 | self.layernorm_elementwise_affine = layernorm_elementwise_affine 119 | self.conv_bias = conv_bias 120 | self.layer_scale_init_value = layer_scale_init_value 121 | self.weight_init_value = weight_init_value 122 | self.mixer_layer = mixer_layer 123 | 124 | # encoder specific parameters 125 | self.encoder_n_filters = encoder_n_filters 126 | self.encoder_ratios = encoder_ratios 127 | self.encoder_depths = encoder_depths 128 | 129 | 130 | class VibeVoiceDiffusionHeadConfig(PretrainedConfig): 131 | model_type = "vibevoice_diffusion_head" 132 | 133 | def __init__( 134 | self, 135 | hidden_size=768, 136 | head_layers=4, 137 | head_ffn_ratio=3.0, 138 | rms_norm_eps=1e-5, 139 | latent_size=64, 140 | speech_vae_dim=None, 141 | prediction_type="v_prediction", 142 | diffusion_type="ddpm", 143 | ddpm_num_steps=1000, 144 | ddpm_num_inference_steps=20, 145 | ddpm_beta_schedule="cosine", 146 | ddpm_batch_mul=4, 147 | **kwargs 148 | ): 149 | self.hidden_size = hidden_size 150 | self.head_layers = head_layers 151 | self.head_ffn_ratio = head_ffn_ratio 152 | self.rms_norm_eps = rms_norm_eps 153 | self.latent_size = latent_size 154 | self.speech_vae_dim = speech_vae_dim 155 | self.prediction_type = prediction_type 156 | self.diffusion_type = diffusion_type 157 | self.ddpm_num_steps = ddpm_num_steps 158 | self.ddpm_num_inference_steps = ddpm_num_inference_steps 159 | self.ddpm_beta_schedule = ddpm_beta_schedule 160 | self.ddpm_batch_mul = ddpm_batch_mul 161 | 162 | super().__init__(**kwargs) 163 | 164 | class VibeVoiceConfig(PretrainedConfig): 165 | model_type = "vibevoice" 166 | is_composition = True 167 | sub_configs = { 168 | "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig, 169 | "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig, 170 | "decoder_config": Qwen2Config, 171 | "diffusion_head_config": VibeVoiceDiffusionHeadConfig, 172 | } 173 | # keys_to_ignore_at_inference = ["past_key_values"] 174 | # Default tensor parallel plan for base model `Qwen2` 175 | base_model_tp_plan = { 176 | "layers.*.self_attn.q_proj": "colwise", 177 | "layers.*.self_attn.k_proj": "colwise", 178 | "layers.*.self_attn.v_proj": "colwise", 179 | "layers.*.self_attn.o_proj": "rowwise", 180 | "layers.*.mlp.gate_proj": "colwise", 181 | "layers.*.mlp.up_proj": "colwise", 182 | "layers.*.mlp.down_proj": "rowwise", 183 | } 184 | 185 | def __init__( 186 | self, 187 | acoustic_tokenizer_config=None, 188 | semantic_tokenizer_config=None, 189 | decoder_config=None, 190 | diffusion_head_config=None, 191 | **kwargs 192 | ): 193 | 194 | # kwargs["_attn_implementation"] = "flash_attention_2" 195 | kwargs["_attn_implementation_autoset"] = False 196 | 197 | if acoustic_tokenizer_config is None: 198 | self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]() 199 | elif isinstance(acoustic_tokenizer_config, dict): 200 | acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer" 201 | self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config) 202 | elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig): 203 | # If an instance of the config class is provided 204 | self.acoustic_tokenizer_config = acoustic_tokenizer_config 205 | 206 | if semantic_tokenizer_config is None: 207 | self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]() 208 | elif isinstance(semantic_tokenizer_config, dict): 209 | semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer" 210 | self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config) 211 | elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig): 212 | # If an instance of the config class is provided 213 | self.semantic_tokenizer_config = semantic_tokenizer_config 214 | 215 | if decoder_config is None: 216 | self.decoder_config = self.sub_configs["decoder_config"]() 217 | elif isinstance(decoder_config, dict): 218 | # If a dictionary is provided, instantiate the config class with it 219 | # self.decoder_config = self.sub_configs["decoder_config"](**decoder_config) 220 | if decoder_config.get("model_type", '') == "qwen2": 221 | self.decoder_config = Qwen2Config(**decoder_config) 222 | else: 223 | raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}") 224 | elif isinstance(decoder_config, (Qwen2Config,)): 225 | # If an instance of the config class is provided 226 | self.decoder_config = decoder_config 227 | 228 | if diffusion_head_config is None: 229 | self.diffusion_head_config = self.sub_configs["diffusion_head_config"]() 230 | elif isinstance(diffusion_head_config, dict): 231 | diffusion_head_config["model_type"] = "vibevoice_diffusion_head" 232 | self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config) 233 | elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig): 234 | # If an instance of the config class is provided 235 | self.diffusion_head_config = diffusion_head_config 236 | 237 | # other parameters 238 | self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64) 239 | self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128) 240 | 241 | self.num_hidden_layers = self.decoder_config.num_hidden_layers 242 | super().__init__(**kwargs) 243 | 244 | __all__ = [ 245 | "VibeVoiceAcousticTokenizerConfig", 246 | "VibeVoiceSemanticTokenizerConfig", 247 | "VibeVoiceDiffusionHeadConfig", 248 | "VibeVoiceConfig" 249 | ] -------------------------------------------------------------------------------- /vibevoice_nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | import logging 4 | 5 | import comfy.model_management as model_management 6 | from comfy.utils import ProgressBar 7 | 8 | # Import from the dedicated model_info module 9 | from .modules.model_info import AVAILABLE_VIBEVOICE_MODELS 10 | from .modules.loader import VibeVoiceModelHandler, ATTENTION_MODES, VIBEVOICE_PATCHER_CACHE, cleanup_old_models 11 | from .modules.patcher import VibeVoicePatcher 12 | from .modules.utils import parse_script_1_based, preprocess_comfy_audio, set_vibevoice_seed, check_for_interrupt 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | class VibeVoiceTTSNode: 17 | @classmethod 18 | def INPUT_TYPES(cls): 19 | model_names = list(AVAILABLE_VIBEVOICE_MODELS.keys()) 20 | if not model_names: 21 | model_names.append("No models found in models/tts/VibeVoice") 22 | 23 | return { 24 | "required": { 25 | "model_name": (model_names, { 26 | "tooltip": "Select the VibeVoice model to use. Official models will be downloaded automatically." 27 | }), 28 | "text": ("STRING", { 29 | "multiline": True, 30 | "default": "[1] Hello, this is a cloned voice.\n[2] And this is a generated voice, how cool is that?", 31 | "tooltip": "The script for generation. Use '[1]' or 'Speaker 1:' for speakers. If a speaker in the script lacks a reference voice, it will be generated via zero-shot TTS." 32 | }), 33 | "quantize_llm_4bit": ("BOOLEAN", { 34 | "default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision", 35 | "tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32." 36 | }), 37 | "attention_mode": (ATTENTION_MODES, { 38 | "default": "sdpa", 39 | "tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest), Sage (quantized)" 40 | }), 41 | "cfg_scale": ("FLOAT", { 42 | "default": 1.3, "min": 0.1, "max": 50.0, "step": 0.05, 43 | "tooltip": "Classifier-Free Guidance scale. Higher values increase adherence to the voice prompt but may reduce naturalness. Recommended: 1.3" 44 | }), 45 | "inference_steps": ("INT", { 46 | "default": 10, "min": 1, "max": 500, 47 | "tooltip": "Number of diffusion steps for audio generation. More steps can improve quality but take longer. Recommended: 10" 48 | }), 49 | "seed": ("INT", { 50 | "default": 42, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "control_after_generate": True, 51 | "tooltip": "Seed for reproducibility. Set to 0 for a random seed on each run." 52 | }), 53 | "do_sample": ("BOOLEAN", { 54 | "default": True, "label_on": "Enabled (Sampling)", "label_off": "Disabled (Greedy)", 55 | "tooltip": "Enable to use sampling methods (like temperature and top_p) for more varied output. Disable for deterministic (greedy) decoding." 56 | }), 57 | "temperature": ("FLOAT", { 58 | "default": 0.95, "min": 0.0, "max": 2.0, "step": 0.01, 59 | "tooltip": "Controls randomness. Higher values make the output more random and creative, while lower values make it more focused and deterministic. Active only if 'do_sample' is enabled." 60 | }), 61 | "top_p": ("FLOAT", { 62 | "default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01, 63 | "tooltip": "Nucleus sampling (Top-P). The model samples from the smallest set of tokens whose cumulative probability exceeds this value. Active only if 'do_sample' is enabled." 64 | }), 65 | "top_k": ("INT", { 66 | "default": 0, "min": 0, "max": 500, "step": 1, 67 | "tooltip": "Top-K sampling. Restricts sampling to the K most likely next tokens. Set to 0 to disable. Active only if 'do_sample' is enabled." 68 | }), 69 | "force_offload": ("BOOLEAN", { 70 | "default": False, "label_on": "Force Offload", "label_off": "Keep in VRAM", 71 | "tooltip": "Force model to be offloaded from VRAM after generation. Useful to free up memory between generations but may slow down subsequent runs." 72 | }), 73 | }, 74 | "optional": { 75 | "speaker_1_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 1' or '[1]' in the script."}), 76 | "speaker_2_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 2' or '[2]' in the script."}), 77 | "speaker_3_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 3' or '[3]' in the script."}), 78 | "speaker_4_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 4' or '[4]' in the script."}), 79 | } 80 | } 81 | 82 | RETURN_TYPES = ("AUDIO",) 83 | FUNCTION = "generate_audio" 84 | CATEGORY = "audio/tts" 85 | 86 | def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, force_offload, **kwargs): 87 | actual_attention_mode = attention_mode 88 | if quantize_llm_4bit and attention_mode in ["eager", "flash_attention_2"]: 89 | actual_attention_mode = "sdpa" 90 | 91 | cache_key = f"{model_name}_attn_{actual_attention_mode}_q4_{int(quantize_llm_4bit)}" 92 | 93 | if cache_key not in VIBEVOICE_PATCHER_CACHE: 94 | cleanup_old_models(keep_cache_key=cache_key) 95 | 96 | model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit) 97 | patcher = VibeVoicePatcher( 98 | model_handler, 99 | attention_mode=attention_mode, 100 | load_device=model_management.get_torch_device(), 101 | offload_device=model_management.unet_offload_device(), 102 | size=model_handler.size 103 | ) 104 | VIBEVOICE_PATCHER_CACHE[cache_key] = patcher 105 | 106 | patcher = VIBEVOICE_PATCHER_CACHE[cache_key] 107 | model_management.load_model_gpu(patcher) 108 | model = patcher.model.model 109 | processor = patcher.model.processor 110 | 111 | if model is None or processor is None: 112 | raise RuntimeError("VibeVoice model and processor could not be loaded. Check logs for errors.") 113 | 114 | parsed_lines_0_based, speaker_ids_1_based = parse_script_1_based(text) 115 | if not parsed_lines_0_based: 116 | raise ValueError("Script is empty or invalid. Please provide text to generate.") 117 | 118 | # full_script = "\n".join([f"Speaker {spk+1}: {txt}" for spk, txt in parsed_lines_0_based]) # <-- REMOVED: This was the cause of the bug. 119 | 120 | speaker_inputs = {i: kwargs.get(f"speaker_{i}_voice") for i in range(1, 5)} 121 | voice_samples_np = [preprocess_comfy_audio(speaker_inputs.get(sid)) for sid in speaker_ids_1_based] 122 | 123 | set_vibevoice_seed(seed) 124 | 125 | try: 126 | inputs = processor( 127 | parsed_scripts=[parsed_lines_0_based], 128 | voice_samples=[voice_samples_np], 129 | speaker_ids_for_prompt=[speaker_ids_1_based], 130 | padding=True, 131 | return_tensors="pt", 132 | return_attention_mask=True 133 | ) 134 | 135 | for key, value in inputs.items(): 136 | if isinstance(value, torch.Tensor): 137 | if torch.any(torch.isnan(value)) or torch.any(torch.isinf(value)): 138 | logger.error(f"Input tensor '{key}' contains NaN or Inf values") 139 | raise ValueError(f"Invalid values in input tensor: {key}") 140 | 141 | inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} 142 | 143 | model.set_ddpm_inference_steps(num_steps=inference_steps) 144 | 145 | generation_config = {'do_sample': do_sample} 146 | if do_sample: 147 | generation_config['temperature'] = temperature 148 | generation_config['top_p'] = top_p 149 | if top_k > 0: 150 | generation_config['top_k'] = top_k 151 | 152 | with torch.no_grad(): 153 | pbar = ProgressBar(inference_steps) 154 | 155 | def progress_callback(step, total_steps): 156 | pbar.update(1) 157 | if model_management.interrupt_current_processing: 158 | raise model_management.InterruptProcessingException() 159 | 160 | try: 161 | outputs = model.generate( 162 | **inputs, max_new_tokens=None, cfg_scale=cfg_scale, 163 | tokenizer=processor.tokenizer, generation_config=generation_config, 164 | verbose=False, stop_check_fn=check_for_interrupt 165 | ) 166 | pbar.update(inference_steps - pbar.current) 167 | 168 | except RuntimeError as e: 169 | error_msg = str(e).lower() 170 | if "assertion" in error_msg or "cuda" in error_msg: 171 | logger.error(f"CUDA assertion failed with {attention_mode} attention: {e}") 172 | logger.error("This might be due to invalid input data, GPU memory issues, or incompatible attention mode.") 173 | logger.error("Try restarting ComfyUI, using different audio files, or switching to 'eager' attention mode.") 174 | raise e 175 | except model_management.InterruptProcessingException: 176 | logger.info("VibeVoice generation interrupted by user") 177 | raise 178 | finally: 179 | pbar.update_absolute(inference_steps) 180 | 181 | except model_management.InterruptProcessingException: 182 | logger.info("VibeVoice TTS generation was cancelled") 183 | return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},) 184 | 185 | except Exception as e: 186 | logger.error(f"Error during VibeVoice generation with {attention_mode} attention: {e}") 187 | if "interrupt" in str(e).lower() or "cancel" in str(e).lower(): 188 | logger.info("Generation was interrupted") 189 | return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},) 190 | raise 191 | 192 | output_waveform = outputs.speech_outputs[0] 193 | if output_waveform.ndim == 1: output_waveform = output_waveform.unsqueeze(0) 194 | if output_waveform.ndim == 2: output_waveform = output_waveform.unsqueeze(0) 195 | 196 | if force_offload: 197 | logger.info(f"Force offloading VibeVoice model '{model_name}' from VRAM...") 198 | if patcher.is_loaded: 199 | patcher.unpatch_model(unpatch_weights=True) 200 | model_management.unload_all_models() 201 | gc.collect() 202 | model_management.soft_empty_cache() 203 | logger.info("Model force offload completed") 204 | 205 | return ({"waveform": output_waveform.detach().cpu(), "sample_rate": 24000},) 206 | 207 | NODE_CLASS_MAPPINGS = {"VibeVoiceTTS": VibeVoiceTTSNode} 208 | NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"} -------------------------------------------------------------------------------- /modules/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import gc 4 | import json 5 | import logging 6 | from huggingface_hub import hf_hub_download, snapshot_download 7 | 8 | import comfy.utils 9 | import folder_paths 10 | import comfy.model_management as model_management 11 | 12 | import transformers 13 | from packaging import version 14 | 15 | _transformers_version = version.parse(transformers.__version__) 16 | _DTYPE_ARG_SUPPORTED = _transformers_version >= version.parse("4.56.0") 17 | 18 | from transformers import BitsAndBytesConfig 19 | from ..vibevoice.modular.configuration_vibevoice import VibeVoiceConfig 20 | from ..vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference 21 | from ..vibevoice.processor.vibevoice_processor import VibeVoiceProcessor 22 | from ..vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor 23 | from ..vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast 24 | 25 | from .model_info import AVAILABLE_VIBEVOICE_MODELS, MODEL_CONFIGS 26 | from .. import SAGE_ATTENTION_AVAILABLE 27 | if SAGE_ATTENTION_AVAILABLE: 28 | from ..vibevoice.modular.sage_attention_patch import set_sage_attention 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | LOADED_MODELS = {} 33 | VIBEVOICE_PATCHER_CACHE = {} 34 | 35 | ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"] 36 | if SAGE_ATTENTION_AVAILABLE: 37 | ATTENTION_MODES.append("sage") 38 | 39 | def cleanup_old_models(keep_cache_key=None): 40 | global LOADED_MODELS, VIBEVOICE_PATCHER_CACHE 41 | keys_to_remove = [] 42 | for key in list(LOADED_MODELS.keys()): 43 | if key != keep_cache_key: 44 | keys_to_remove.append(key) 45 | del LOADED_MODELS[key] 46 | for key in list(VIBEVOICE_PATCHER_CACHE.keys()): 47 | if key != keep_cache_key: 48 | try: 49 | patcher = VIBEVOICE_PATCHER_CACHE[key] 50 | if hasattr(patcher, 'model') and patcher.model: 51 | patcher.model.model = None 52 | patcher.model.processor = None 53 | del VIBEVOICE_PATCHER_CACHE[key] 54 | except Exception as e: 55 | logger.warning(f"Error cleaning up patcher {key}: {e}") 56 | if keys_to_remove: 57 | logger.info(f"Cleaned up cached models: {keys_to_remove}") 58 | gc.collect() 59 | model_management.soft_empty_cache() 60 | 61 | 62 | class VibeVoiceModelHandler(torch.nn.Module): 63 | def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False): 64 | super().__init__() 65 | self.model_pack_name = model_pack_name 66 | self.attention_mode = attention_mode 67 | self.use_llm_4bit = use_llm_4bit 68 | self.cache_key = f"{self.model_pack_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}" 69 | self.model = None 70 | self.processor = None 71 | info = AVAILABLE_VIBEVOICE_MODELS.get(model_pack_name, {}) 72 | size_gb = MODEL_CONFIGS.get(model_pack_name, {}).get("size_gb", 4.0) 73 | self.size = int(size_gb * (1024**3)) 74 | def load_model(self, device, attention_mode="eager"): 75 | self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit) 76 | if self.model.device != device: 77 | self.model.to(device) 78 | 79 | class VibeVoiceLoader: 80 | @staticmethod 81 | def _check_gpu_for_sage_attention(): 82 | if not SAGE_ATTENTION_AVAILABLE: return False 83 | if not torch.cuda.is_available(): return False 84 | major, _ = torch.cuda.get_device_capability() 85 | if major < 8: 86 | logger.warning(f"Your GPU (compute capability {major}.x) does not support SageAttention, which requires CC 8.0+. Sage option will be disabled.") 87 | return False 88 | return True 89 | 90 | @staticmethod 91 | def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False): 92 | if model_name not in AVAILABLE_VIBEVOICE_MODELS: 93 | raise ValueError(f"Unknown VibeVoice model: {model_name}. Available models: {list(AVAILABLE_VIBEVOICE_MODELS.keys())}") 94 | 95 | if use_llm_4bit and attention_mode in ["eager", "flash_attention_2"]: 96 | logger.warning(f"Attention mode '{attention_mode}' is not recommended with 4-bit quantization. Falling back to 'sdpa' for stability and performance.") 97 | attention_mode = "sdpa" 98 | if attention_mode not in ATTENTION_MODES: 99 | logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager") 100 | attention_mode = "eager" 101 | 102 | cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}" 103 | if cache_key in LOADED_MODELS: 104 | logger.info(f"Using cached model with {attention_mode} attention and q4={use_llm_4bit}") 105 | return LOADED_MODELS[cache_key] 106 | 107 | model_info = AVAILABLE_VIBEVOICE_MODELS[model_name] 108 | model_type = model_info["type"] 109 | vibevoice_base_path = os.path.join(folder_paths.get_folder_paths("tts")[0], "VibeVoice") 110 | 111 | model_path_or_none = None 112 | config_path = None 113 | preprocessor_config_path = None 114 | tokenizer_dir = None 115 | 116 | if model_type == "official": 117 | model_path_or_none = os.path.join(vibevoice_base_path, model_name) 118 | if not os.path.exists(os.path.join(model_path_or_none, "model.safetensors.index.json")): 119 | logger.info(f"Downloading official VibeVoice model: {model_name}...") 120 | snapshot_download(repo_id=model_info["repo_id"], local_dir=model_path_or_none, local_dir_use_symlinks=False) 121 | config_path = os.path.join(model_path_or_none, "config.json") 122 | preprocessor_config_path = os.path.join(model_path_or_none, "preprocessor_config.json") 123 | tokenizer_dir = model_path_or_none 124 | elif model_type == "local_dir": 125 | model_path_or_none = model_info["path"] 126 | config_path = os.path.join(model_path_or_none, "config.json") 127 | preprocessor_config_path = os.path.join(model_path_or_none, "preprocessor_config.json") 128 | tokenizer_dir = model_path_or_none 129 | elif model_type == "standalone": 130 | model_path_or_none = None # IMPORTANT: This must be None when loading from state_dict 131 | config_path = os.path.splitext(model_info["path"])[0] + ".config.json" 132 | preprocessor_config_path = os.path.splitext(model_info["path"])[0] + ".preprocessor.json" 133 | tokenizer_dir = os.path.dirname(model_info["path"]) 134 | 135 | if os.path.exists(config_path): 136 | config = VibeVoiceConfig.from_pretrained(config_path) 137 | else: 138 | fallback_name = "default_VibeVoice-Large_config.json" if "large" in model_name.lower() else "default_VibeVoice-1.5B_config.json" 139 | fallback_path = os.path.join(os.path.dirname(__file__), "..", "vibevoice", "configs", fallback_name) 140 | logger.warning(f"Config not found for '{model_name}'. Using fallback: {fallback_name}") 141 | config = VibeVoiceConfig.from_pretrained(fallback_path) 142 | 143 | # Processor & Tokenizer setup 144 | tokenizer_file_path = os.path.join(tokenizer_dir, "tokenizer.json") 145 | 146 | if not os.path.exists(tokenizer_file_path): 147 | logger.info(f"'tokenizer.json' not found in model directory: {tokenizer_dir}") 148 | 149 | packaged_configs_dir = os.path.join(os.path.dirname(__file__), "..", "vibevoice", "configs") 150 | packaged_tokenizer_path = os.path.join(packaged_configs_dir, "tokenizer.json") 151 | 152 | if os.path.exists(packaged_tokenizer_path): 153 | try: 154 | import shutil 155 | logger.info("Found pre-packaged tokenizer. Copying it to model directory...") 156 | shutil.copyfile(packaged_tokenizer_path, tokenizer_file_path) 157 | except Exception as e: 158 | logger.warning(f"Failed to copy pre-packaged tokenizer: {e}. Will attempt to download.") 159 | 160 | if not os.path.exists(tokenizer_file_path): 161 | repos_to_try = ["Qwen/Qwen2.5-1.5B", "Qwen/Qwen2.5-7B"] 162 | download_successful = False 163 | last_error = None 164 | 165 | for repo_id in repos_to_try: 166 | logger.info(f"Attempting to download 'tokenizer.json' from Hugging Face repo '{repo_id}'...") 167 | try: 168 | hf_hub_download( 169 | repo_id=repo_id, 170 | filename="tokenizer.json", 171 | local_dir=tokenizer_dir 172 | ) 173 | download_successful = True 174 | logger.info("Download successful.") 175 | break # Exit the loop on success 176 | except Exception as e: 177 | logger.warning(f"Failed to download from '{repo_id}': {e}") 178 | last_error = e 179 | 180 | # Final Failure 181 | if not download_successful: 182 | error_message = ( 183 | f"FATAL: Could not get 'tokenizer.json'. All download attempts failed.\n" 184 | f"Last error: {last_error}\n\n" 185 | f"ACTION REQUIRED:\n" 186 | f"1. Manually download 'tokenizer.json' from https://huggingface.co/{repos_to_try[0]}/blob/main/tokenizer.json\n" 187 | f"2. Place the downloaded file in the following directory:\n '{tokenizer_dir}'" 188 | ) 189 | raise RuntimeError(error_message) 190 | 191 | vibevoice_tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file_path) 192 | 193 | processor_config_data = {} 194 | if os.path.exists(preprocessor_config_path): 195 | with open(preprocessor_config_path, 'r', encoding='utf-8') as f: processor_config_data = json.load(f) 196 | 197 | audio_processor = VibeVoiceTokenizerProcessor() 198 | processor = VibeVoiceProcessor(tokenizer=vibevoice_tokenizer, audio_processor=audio_processor, speech_tok_compress_ratio=processor_config_data.get("speech_tok_compress_ratio", 3200), db_normalize=processor_config_data.get("db_normalize", True)) 199 | 200 | # Model Loading Prep 201 | if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): model_dtype = torch.bfloat16 202 | else: model_dtype = torch.float16 203 | quant_config = None 204 | final_load_dtype = model_dtype 205 | 206 | if use_llm_4bit: 207 | bnb_compute_dtype = model_dtype 208 | if attention_mode == 'sage': bnb_compute_dtype, final_load_dtype = torch.float32, torch.float32 209 | quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=bnb_compute_dtype) 210 | 211 | attn_implementation_for_load = "sdpa" if attention_mode == "sage" else attention_mode 212 | 213 | try: 214 | logger.info(f"Loading model '{model_name}' with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'") 215 | 216 | # UNIFIED MODEL LOADING LOGIC 217 | from_pretrained_kwargs = { 218 | "config": config, 219 | "attn_implementation": attn_implementation_for_load, 220 | "device_map": "auto" if quant_config else device, 221 | "quantization_config": quant_config, 222 | } 223 | if _DTYPE_ARG_SUPPORTED: 224 | from_pretrained_kwargs['dtype'] = final_load_dtype 225 | else: 226 | from_pretrained_kwargs['torch_dtype'] = final_load_dtype 227 | 228 | if model_type == "standalone": 229 | logger.info(f"Loading standalone model state_dict directly to device: {device}") 230 | # loading the state dict directly to the target device 231 | state_dict = comfy.utils.load_torch_file(model_info["path"], device=device) 232 | from_pretrained_kwargs["state_dict"] = state_dict 233 | 234 | model = VibeVoiceForConditionalGenerationInference.from_pretrained(model_path_or_none, **from_pretrained_kwargs) 235 | 236 | if attention_mode == "sage": 237 | if VibeVoiceLoader._check_gpu_for_sage_attention(): 238 | set_sage_attention(model) 239 | else: 240 | raise RuntimeError("Incompatible hardware/setup for SageAttention.") 241 | 242 | model.eval() 243 | setattr(model, "_llm_4bit", bool(quant_config)) 244 | LOADED_MODELS[cache_key] = (model, processor) 245 | logger.info(f"Successfully configured model '{model_name}' with {attention_mode} attention") 246 | return model, processor 247 | 248 | except Exception as e: 249 | # It's not ideal to automatically reload the model. Let the user decide what to do in case of an error. 250 | logger.error(f"Failed to load model '{model_name}' with {attention_mode} attention: {e}") 251 | # if attention_mode in ["sage", "flash_attention_2"]: return VibeVoiceLoader.load_model(model_name, device, "sdpa", use_llm_4bit) 252 | # elif attention_mode == "sdpa": return VibeVoiceLoader.load_model(model_name, device, "eager", use_llm_4bit) 253 | # else: 254 | raise RuntimeError(f"Failed to load model even with eager attention: {e}") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 |
10 | A custom node for ComfyUI that integrates Microsoft's VibeVoice, a frontier model for generating expressive, long-form, multi-speaker conversational audio.
11 |
12 |
13 | Report Bug
14 | ·
15 | Request Feature
16 |
17 |
18 | [![Stargazers][stars-shield]][stars-url]
19 | [![Issues][issues-shield]][issues-url]
20 | [![Contributors][contributors-shield]][contributors-url]
21 | [![Forks][forks-shield]][forks-url]
22 |
33 |