├── src └── chatterbox │ ├── models │ ├── __init__.py │ ├── s3gen │ │ ├── transformer │ │ │ ├── __init__.py │ │ │ ├── activation.py │ │ │ ├── positionwise_feed_forward.py │ │ │ ├── convolution.py │ │ │ ├── encoder_layer.py │ │ │ └── embedding.py │ │ ├── const.py │ │ ├── __init__.py │ │ ├── configs.py │ │ ├── f0_predictor.py │ │ ├── utils │ │ │ ├── class_utils.py │ │ │ ├── mel.py │ │ │ └── mask.py │ │ ├── matcha │ │ │ └── flow_matching.py │ │ ├── flow_matching.py │ │ ├── flow.py │ │ ├── decoder.py │ │ └── s3gen.py │ ├── t3 │ │ ├── __init__.py │ │ ├── modules │ │ │ ├── t3_config.py │ │ │ ├── learned_pos_emb.py │ │ │ ├── cond_enc.py │ │ │ └── perceiver.py │ │ ├── llama_configs.py │ │ └── inference │ │ │ ├── t3_hf_backend.py │ │ │ └── alignment_stream_analyzer.py │ ├── tokenizers │ │ ├── __init__.py │ │ └── tokenizer.py │ ├── voice_encoder │ │ ├── __init__.py │ │ ├── config.py │ │ ├── melspec.py │ │ └── voice_encoder.py │ ├── utils.py │ └── s3tokenizer │ │ ├── __init__.py │ │ └── s3tokenizer.py │ ├── __init__.py │ ├── vc.py │ └── tts.py ├── workflow-examples ├── ChatterboxTTS-workflow.png └── ChatterboxTTS-workflow.json ├── requirements.txt ├── pyproject.toml ├── .github └── workflows │ └── publish.yml ├── __init__.py ├── .gitignore ├── modules └── chatterbox_handler.py ├── README.md └── LICENSE /src/chatterbox/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/const.py: -------------------------------------------------------------------------------- 1 | S3GEN_SR = 24000 2 | -------------------------------------------------------------------------------- /src/chatterbox/models/t3/__init__.py: -------------------------------------------------------------------------------- 1 | from .t3 import T3 2 | -------------------------------------------------------------------------------- /src/chatterbox/models/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokenizer import EnTokenizer 2 | -------------------------------------------------------------------------------- /src/chatterbox/models/voice_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .voice_encoder import VoiceEncoder, VoiceEncConfig 2 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/__init__.py: -------------------------------------------------------------------------------- 1 | from .s3gen import S3Token2Wav as S3Gen 2 | from .const import S3GEN_SR 3 | -------------------------------------------------------------------------------- /workflow-examples/ChatterboxTTS-workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wildminder/ComfyUI-Chatterbox/HEAD/workflow-examples/ChatterboxTTS-workflow.png -------------------------------------------------------------------------------- /src/chatterbox/models/utils.py: -------------------------------------------------------------------------------- 1 | class AttrDict(dict): 2 | def __init__(self, *args, **kwargs): 3 | super(AttrDict, self).__init__(*args, 
**kwargs) 4 | self.__dict__ = self 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchaudio 3 | librosa 4 | numpy 5 | huggingface_hub 6 | einops 7 | scipy 8 | tokenizers 9 | soundfile 10 | s3tokenizer 11 | conformer 12 | safetensors 13 | transformers 14 | diffusers -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/configs.py: -------------------------------------------------------------------------------- 1 | from ..utils import AttrDict 2 | 3 | CFM_PARAMS = AttrDict({ 4 | "sigma_min": 1e-06, 5 | "solver": "euler", 6 | "t_scheduler": "cosine", 7 | "training_cfg_rate": 0.2, 8 | "inference_cfg_rate": 0.7, 9 | "reg_loss_type": "l1" 10 | }) 11 | -------------------------------------------------------------------------------- /src/chatterbox/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from importlib.metadata import version 3 | except ImportError: 4 | from importlib_metadata import version # For Python <3.8 5 | 6 | # __version__ = version("chatterbox-tts") 7 | 8 | from .tts import ChatterboxTTS 9 | from .vc import ChatterboxVC 10 | -------------------------------------------------------------------------------- /src/chatterbox/models/voice_encoder/config.py: -------------------------------------------------------------------------------- 1 | class VoiceEncConfig: 2 | num_mels = 40 3 | sample_rate = 16000 4 | speaker_embed_size = 256 5 | ve_hidden_size = 256 6 | flatten_lstm_params = False 7 | n_fft = 400 8 | hop_size = 160 9 | win_size = 400 10 | fmax = 8000 11 | fmin = 0 12 | preemphasis = 0. 
13 | mel_power = 2.0 14 | mel_type = "amp" 15 | normalized_mels = False 16 | ve_partial_frames = 160 17 | ve_final_relu = True 18 | stft_magnitude_min = 1e-4 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ComfyUI-ChatterboxTTS" 3 | description = "ComfyUI Chatterbox TTS & Voice Conversion Node" 4 | version = "1.2.1" 5 | license = {file = "LICENSE"} 6 | dependencies = ["torch", "torchaudio", "librosa", "numpy", "huggingface_hub", "einops", "scipy", "tokenizers", "soundfile", "s3tokenizer", "tqdm", "conformer", "safetensors", "transformers", "diffusers"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/wildminder/ComfyUI-Chatterbox" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "wildai" 14 | DisplayName = "ComfyUI-ChatterboxTTS" 15 | Icon = "" 16 | 17 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .s3tokenizer import ( 2 | S3_SR, 3 | S3_HOP, 4 | S3_TOKEN_HOP, 5 | S3_TOKEN_RATE, 6 | SPEECH_VOCAB_SIZE, 7 | S3Tokenizer, 8 | ) 9 | 10 | 11 | SOS = SPEECH_VOCAB_SIZE 12 | EOS = SPEECH_VOCAB_SIZE + 1 13 | 14 | 15 | 16 | def drop_invalid_tokens(x): 17 | """Drop SoS and EoS""" 18 | assert len(x.shape) == 1 or (len(x.shape) == 2 and x.shape[0] == 1), "only batch size of one allowed for now" 19 | if SOS in x: 20 | s = (x == SOS).nonzero(as_tuple=True)[0].squeeze(0) + 1 21 | else: 22 | s = 0 23 | 24 | if EOS in x: 25 | e = (x == EOS).nonzero(as_tuple=True)[0].squeeze(0) 26 | else: 27 | e = None 28 | 29 | x = x[s: e] 30 | return x 31 | -------------------------------------------------------------------------------- /src/chatterbox/models/t3/modules/t3_config.py: -------------------------------------------------------------------------------- 1 | from ..llama_configs import LLAMA_CONFIGS 2 | 3 | 4 | class T3Config: 5 | start_text_token = 255 6 | stop_text_token = 0 7 | text_tokens_dict_size = 704 8 | max_text_tokens = 2048 9 | 10 | start_speech_token = 6561 11 | stop_speech_token = 6562 12 | speech_tokens_dict_size = 8194 13 | max_speech_tokens = 4096 14 | 15 | llama_config_name = "Llama_520M" 16 | input_pos_emb = "learned" 17 | speech_cond_prompt_len = 150 18 | 19 | # For T3CondEnc 20 | encoder_type = "voice_encoder" 21 | speaker_embed_size = 256 22 | use_perceiver_resampler = True 23 | emotion_adv = True 24 | 25 | @property 26 | def n_channels(self): 27 | return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"] 28 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | permissions: 12 | issues: write 13 | 14 | jobs: 15 | publish-node: 16 | name: Publish Custom Node to registry 17 | runs-on: ubuntu-latest 18 | if: ${{ github.repository_owner == 'wildminder' }} 19 | steps: 20 | - name: Check out code 21 | uses: actions/checkout@v4 22 | with: 23 | submodules: true 24 | - name: Publish Custom Node 25 | uses: Comfy-Org/publish-node-action@v1 26 | with: 27 | ## Add your own personal access token to your Github Repository 
secrets and reference it here. 28 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 29 | -------------------------------------------------------------------------------- /src/chatterbox/models/t3/modules/learned_pos_emb.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch import nn, Tensor 5 | 6 | 7 | class LearnedPositionEmbeddings(nn.Module): 8 | def __init__(self, seq_len, model_dim, init=.02): 9 | super().__init__() 10 | self.emb = nn.Embedding(seq_len, model_dim) 11 | # Initializing this way is standard for GPT-2 12 | self.emb.weight.data.normal_(mean=0.0, std=init) 13 | 14 | def forward(self, x): 15 | """ 16 | Returns positional embeddings for index 0 up to the length of x 17 | """ 18 | sl = x.shape[1] 19 | return self.emb(torch.arange(0, sl, device=x.device)) 20 | 21 | def get_fixed_embedding(self, idx: 'Union[int, Tensor]'): 22 | """ 23 | Args: 24 | idx: scalar int or an integer tensor of shape (T,) or (B, T) 25 | Returns: 26 | positional embeddings for given indices, shape (B, T, dim), ie (1, 1, dim) for int input 27 | """ 28 | device = self.emb.weight.device 29 | idx = idx.to(device) if torch.is_tensor(idx) else torch.tensor(idx, device=device) 30 | idx = torch.atleast_2d(idx) 31 | assert idx.ndim == 2 32 | return self.emb(idx) # (B, T, dim) 33 | -------------------------------------------------------------------------------- /src/chatterbox/models/t3/llama_configs.py: -------------------------------------------------------------------------------- 1 | LLAMA_520M_CONFIG_DICT = dict( 2 | # Arbitrary small number that won't cause problems when loading. 3 | # These param are unused due to custom input layers. 4 | vocab_size=8, 5 | # default params needed for loading most pretrained 1B weights 6 | max_position_embeddings=131072, 7 | hidden_size=1024, 8 | intermediate_size=4096, 9 | num_hidden_layers=30, 10 | num_attention_heads=16, 11 | # This is required because AlignmentStreamAnalyzer needs to inspect attention weights, 12 | # which is not supported by optimized backends like FlashAttention (SDPA). 13 | # This informs transformers that the fallback to the slower path is intentional, silencing the warning. 
14 | attn_implementation="eager", 15 | #attn_implementation="sdpa", 16 | head_dim=64, 17 | tie_word_embeddings=False, 18 | hidden_act="silu", 19 | attention_bias=False, 20 | attention_dropout=0.0, 21 | initializer_range=0.02, 22 | mlp_bias=False, 23 | model_type="llama", 24 | num_key_value_heads=16, 25 | pretraining_tp=1, 26 | rms_norm_eps=1e-05, 27 | rope_scaling=dict( 28 | factor=8.0, 29 | high_freq_factor=4.0, 30 | low_freq_factor=1.0, 31 | original_max_position_embeddings=8192, 32 | rope_type="llama3" 33 | ), 34 | rope_theta=500000.0, 35 | torch_dtype="bfloat16", 36 | use_cache=True, 37 | ) 38 | 39 | LLAMA_CONFIGS = { 40 | "Llama_520M": LLAMA_520M_CONFIG_DICT, 41 | } 42 | -------------------------------------------------------------------------------- /src/chatterbox/models/tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from tokenizers import Tokenizer 5 | 6 | 7 | # Special tokens 8 | SOT = "[START]" 9 | EOT = "[STOP]" 10 | UNK = "[UNK]" 11 | SPACE = "[SPACE]" 12 | SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"] 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | class EnTokenizer: 17 | def __init__(self, vocab_file_path): 18 | self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path) 19 | self.check_vocabset_sot_eot() 20 | 21 | def check_vocabset_sot_eot(self): 22 | voc = self.tokenizer.get_vocab() 23 | assert SOT in voc 24 | assert EOT in voc 25 | 26 | def text_to_tokens(self, text: str): 27 | text_tokens = self.encode(text) 28 | text_tokens = torch.IntTensor(text_tokens).unsqueeze(0) 29 | return text_tokens 30 | 31 | def encode( self, txt: str, verbose=False): 32 | """ 33 | clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer 34 | """ 35 | txt = txt.replace(' ', SPACE) 36 | code = self.tokenizer.encode(txt) 37 | ids = code.ids 38 | return ids 39 | 40 | def decode(self, seq): 41 | if isinstance(seq, torch.Tensor): 42 | seq = seq.cpu().numpy() 43 | 44 | txt: str = self.tokenizer.decode(seq, 45 | skip_special_tokens=False) 46 | txt = txt.replace(' ', '') 47 | txt = txt.replace(SPACE, ' ') 48 | txt = txt.replace(EOT, '') 49 | txt = txt.replace(UNK, '') 50 | return txt 51 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/f0_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import torch 15 | import torch.nn as nn 16 | from torch.nn.utils.parametrizations import weight_norm 17 | 18 | 19 | class ConvRNNF0Predictor(nn.Module): 20 | def __init__(self, 21 | num_class: int = 1, 22 | in_channels: int = 80, 23 | cond_channels: int = 512 24 | ): 25 | super().__init__() 26 | 27 | self.num_class = num_class 28 | self.condnet = nn.Sequential( 29 | weight_norm( 30 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) 31 | ), 32 | nn.ELU(), 33 | weight_norm( 34 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 35 | ), 36 | nn.ELU(), 37 | weight_norm( 38 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 39 | ), 40 | nn.ELU(), 41 | weight_norm( 42 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 43 | ), 44 | nn.ELU(), 45 | weight_norm( 46 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 47 | ), 48 | nn.ELU(), 49 | ) 50 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | x = self.condnet(x) 54 | x = x.transpose(1, 2) 55 | return torch.abs(self.classifier(x).squeeze(-1)) 56 | -------------------------------------------------------------------------------- /src/chatterbox/models/voice_encoder/melspec.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | from scipy import signal 4 | import numpy as np 5 | import librosa 6 | 7 | 8 | @lru_cache() 9 | def mel_basis(hp): 10 | assert hp.fmax <= hp.sample_rate // 2 11 | return librosa.filters.mel( 12 | sr=hp.sample_rate, 13 | n_fft=hp.n_fft, 14 | n_mels=hp.num_mels, 15 | fmin=hp.fmin, 16 | fmax=hp.fmax) # -> (nmel, nfreq) 17 | 18 | 19 | def preemphasis(wav, hp): 20 | assert hp.preemphasis != 0 21 | wav = signal.lfilter([1, -hp.preemphasis], [1], wav) 22 | wav = np.clip(wav, -1, 1) 23 | return wav 24 | 25 | 26 | def melspectrogram(wav, hp, pad=True): 27 | # Run through pre-emphasis 28 | if hp.preemphasis > 0: 29 | wav = preemphasis(wav, hp) 30 | assert np.abs(wav).max() - 1 < 1e-07 31 | 32 | # Do the stft 33 | spec_complex = _stft(wav, hp, pad=pad) 34 | 35 | # Get the magnitudes 36 | spec_magnitudes = np.abs(spec_complex) 37 | 38 | if hp.mel_power != 1.0: 39 | spec_magnitudes **= hp.mel_power 40 | 41 | # Get the mel and convert magnitudes->db 42 | mel = np.dot(mel_basis(hp), spec_magnitudes) 43 | if hp.mel_type == "db": 44 | mel = _amp_to_db(mel, hp) 45 | 46 | # Normalise the mel from db to 0,1 47 | if hp.normalized_mels: 48 | mel = _normalize(mel, hp).astype(np.float32) 49 | 50 | assert not pad or mel.shape[1] == 1 + len(wav) // hp.hop_size # Sanity check 51 | return mel # (M, T) 52 | 53 | 54 | def _stft(y, hp, pad=True): 55 | # NOTE: after 0.8, pad mode defaults to constant, setting this to reflect for 56 | # historical consistency and streaming-version consistency 57 | return librosa.stft( 58 | y, 59 | n_fft=hp.n_fft, 60 | hop_length=hp.hop_size, 61 | win_length=hp.win_size, 62 | center=pad, 63 | pad_mode="reflect", 64 | ) 65 | 66 | 67 | def _amp_to_db(x, hp): 68 | return 20 * np.log10(np.maximum(hp.stft_magnitude_min, x)) 69 | 70 | 71 | def _db_to_amp(x): 72 | return np.power(10.0, x * 0.05) 73 | 74 | 75 | def _normalize(s, hp, headroom_db=15): 76 | min_level_db = 20 * np.log10(hp.stft_magnitude_min) 77 | s = (s - min_level_db) / (-min_level_db + headroom_db) 78 | return s 79 | -------------------------------------------------------------------------------- 
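A minimal usage sketch for the voice-encoder mel frontend defined above (voice_encoder/config.py plus voice_encoder/melspec.py). It assumes the vendored src/ tree is importable as chatterbox (the root __init__.py further down in this dump adds src/ to sys.path); the shape assertion mirrors the sanity check inside melspectrogram() itself.

    import numpy as np
    from chatterbox.models.voice_encoder.config import VoiceEncConfig
    from chatterbox.models.voice_encoder.melspec import melspectrogram

    hp = VoiceEncConfig()                                # 16 kHz, 40 mels, hop 160, n_fft/win 400
    wav = np.zeros(hp.sample_rate, dtype=np.float32)     # one second of silence
    mel = melspectrogram(wav, hp)                        # amplitude mels, power 2.0, no dB, no normalisation
    assert mel.shape == (hp.num_mels, 1 + len(wav) // hp.hop_size)   # (40, 101)

Note that this 16 kHz / 40-mel frontend belongs to the voice encoder; the 24 kHz / 80-mel extractor used by the s3gen package is the separate mel_spectrogram() in s3gen/utils/mel.py below.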
/src/chatterbox/models/s3gen/utils/class_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright [2023-11-28] 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import torch 16 | 17 | from ..transformer.activation import Swish 18 | from ..transformer.subsampling import ( 19 | LinearNoSubsampling, 20 | EmbedinigNoSubsampling, 21 | Conv1dSubsampling2, 22 | Conv2dSubsampling4, 23 | Conv2dSubsampling6, 24 | Conv2dSubsampling8, 25 | ) 26 | from ..transformer.embedding import ( 27 | PositionalEncoding, 28 | RelPositionalEncoding, 29 | WhisperPositionalEncoding, 30 | LearnablePositionalEncoding, 31 | NoPositionalEncoding) 32 | from ..transformer.attention import (MultiHeadedAttention, 33 | RelPositionMultiHeadedAttention) 34 | from ..transformer.embedding import EspnetRelPositionalEncoding 35 | from ..transformer.subsampling import LegacyLinearNoSubsampling 36 | 37 | 38 | COSYVOICE_ACTIVATION_CLASSES = { 39 | "hardtanh": torch.nn.Hardtanh, 40 | "tanh": torch.nn.Tanh, 41 | "relu": torch.nn.ReLU, 42 | "selu": torch.nn.SELU, 43 | "swish": getattr(torch.nn, "SiLU", Swish), 44 | "gelu": torch.nn.GELU, 45 | } 46 | 47 | COSYVOICE_SUBSAMPLE_CLASSES = { 48 | "linear": LinearNoSubsampling, 49 | "linear_legacy": LegacyLinearNoSubsampling, 50 | "embed": EmbedinigNoSubsampling, 51 | "conv1d2": Conv1dSubsampling2, 52 | "conv2d": Conv2dSubsampling4, 53 | "conv2d6": Conv2dSubsampling6, 54 | "conv2d8": Conv2dSubsampling8, 55 | 'paraformer_dummy': torch.nn.Identity 56 | } 57 | 58 | COSYVOICE_EMB_CLASSES = { 59 | "embed": PositionalEncoding, 60 | "abs_pos": PositionalEncoding, 61 | "rel_pos": RelPositionalEncoding, 62 | "rel_pos_espnet": EspnetRelPositionalEncoding, 63 | "no_pos": NoPositionalEncoding, 64 | "abs_pos_whisper": WhisperPositionalEncoding, 65 | "embed_learnable_pe": LearnablePositionalEncoding, 66 | } 67 | 68 | COSYVOICE_ATTENTION_CLASSES = { 69 | "selfattn": MultiHeadedAttention, 70 | "rel_selfattn": RelPositionMultiHeadedAttention, 71 | } 72 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/utils/mel.py: -------------------------------------------------------------------------------- 1 | """mel-spectrogram extraction in Matcha-TTS""" 2 | from librosa.filters import mel as librosa_mel_fn 3 | import torch 4 | import numpy as np 5 | 6 | 7 | # NOTE: they decalred these global vars 8 | mel_basis = {} 9 | hann_window = {} 10 | 11 | 12 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 13 | return torch.log(torch.clamp(x, min=clip_val) * C) 14 | 15 | 16 | def spectral_normalize_torch(magnitudes): 17 | output = dynamic_range_compression_torch(magnitudes) 18 | return output 19 | 20 | """ 21 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram 22 | n_fft: 1920 23 | num_mels: 80 24 | sampling_rate: 24000 25 | hop_size: 480 26 | win_size: 1920 27 | fmin: 0 28 | fmax: 8000 29 | center: False 30 | 31 | 
""" 32 | 33 | def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, 34 | fmin=0, fmax=8000, center=False): 35 | """Copied from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py 36 | Set default values according to Cosyvoice's config. 37 | """ 38 | 39 | if isinstance(y, np.ndarray): 40 | y = torch.tensor(y).float() 41 | 42 | if len(y.shape) == 1: 43 | y = y[None, ] 44 | 45 | if torch.min(y) < -1.0: 46 | print("min value is ", torch.min(y)) 47 | if torch.max(y) > 1.0: 48 | print("max value is ", torch.max(y)) 49 | 50 | global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned 51 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis: 52 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 53 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 54 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 55 | 56 | y = torch.nn.functional.pad( 57 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 58 | ) 59 | y = y.squeeze(1) 60 | 61 | spec = torch.view_as_real( 62 | torch.stft( 63 | y, 64 | n_fft, 65 | hop_length=hop_size, 66 | win_length=win_size, 67 | window=hann_window[str(y.device)], 68 | center=center, 69 | pad_mode="reflect", 70 | normalized=False, 71 | onesided=True, 72 | return_complex=True, 73 | ) 74 | ) 75 | 76 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 77 | 78 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 79 | spec = spectral_normalize_torch(spec) 80 | 81 | return spec 82 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import types 4 | import logging 5 | import folder_paths 6 | from importlib import metadata as importlib_metadata 7 | 8 | # Configure a logger for the entire custom node package 9 | logger = logging.getLogger(__name__) 10 | logger.setLevel(logging.WARNING) 11 | 12 | # Add a handler if none exist to avoid duplicate logs 13 | if not logger.hasHandlers(): 14 | # Use stdout for info, stderr for errors 15 | handler = logging.StreamHandler(sys.stdout) 16 | formatter = logging.Formatter(f"[%(name)s] %(message)s") 17 | handler.setFormatter(formatter) 18 | logger.addHandler(handler) 19 | 20 | 21 | # Monkey-Patch for 'chatterbox-tts' version 22 | # original_importlib_version_func = importlib_metadata.version 23 | # def patched_version_lookup(package_name): 24 | # if package_name == "chatterbox-tts": 25 | # return "local-vendored" 26 | # return original_importlib_version_func(package_name) 27 | # importlib_metadata.version = patched_version_lookup 28 | 29 | 30 | # easily check if the real 'perth' is being used. 31 | try: 32 | import perth 33 | # A simple check to ensure it's not our mock from a previous run 34 | if not hasattr(perth, '_is_mock'): 35 | logger.info("Found and using 'resemble-perth' library for watermarking.") 36 | except ImportError: 37 | logger.warning("'resemble-perth' not found. 
Watermarking will be unavailable.") 38 | class DummyPerthImplicitWatermarker: 39 | def apply_watermark(self, wav, sample_rate): 40 | logger.warning("Watermarking skipped: 'resemble-perth' is not installed.") 41 | return wav 42 | perth_mock = types.ModuleType('perth') 43 | perth_mock.PerthImplicitWatermarker = DummyPerthImplicitWatermarker 44 | # Flag to identify our mock module 45 | perth_mock._is_mock = True 46 | sys.modules['perth'] = perth_mock 47 | 48 | 49 | current_dir = os.path.dirname(os.path.abspath(__file__)) 50 | src_dir = os.path.join(current_dir, "src") 51 | if src_dir not in sys.path: 52 | sys.path.insert(0, src_dir) 53 | 54 | 55 | from .nodes import ChatterboxTTSNode, ChatterboxVCNode 56 | from .modules.chatterbox_handler import CHATTERBOX_MODEL_SUBDIR 57 | 58 | NODE_CLASS_MAPPINGS = { 59 | "ChatterboxTTS": ChatterboxTTSNode, 60 | "ChatterboxVC": ChatterboxVCNode, 61 | } 62 | 63 | NODE_DISPLAY_NAME_MAPPINGS = { 64 | "ChatterboxTTS": "Chatterbox TTS 📢", 65 | "ChatterboxVC": "Chatterbox Voice Conversion 🗣️", 66 | } 67 | 68 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 69 | 70 | # Model path setup for ComfyUI 71 | chatterbox_models_full_path = os.path.join(folder_paths.models_dir, CHATTERBOX_MODEL_SUBDIR) 72 | if not os.path.exists(chatterbox_models_full_path): 73 | try: 74 | os.makedirs(chatterbox_models_full_path, exist_ok=True) 75 | except OSError as e: 76 | logger.error(f"Error creating models directory {chatterbox_models_full_path}: {e}") 77 | 78 | # Register the tts/chatterbox path with ComfyUI 79 | tts_chatterbox_path = os.path.join(folder_paths.models_dir, "tts") 80 | if "tts" not in folder_paths.folder_names_and_paths: 81 | supported_exts = folder_paths.supported_pt_extensions.union({".safetensors"}) 82 | folder_paths.folder_names_and_paths["tts"] = ([tts_chatterbox_path], supported_exts) 83 | else: 84 | if tts_chatterbox_path not in folder_paths.folder_names_and_paths["tts"][0]: 85 | folder_paths.folder_names_and_paths["tts"][0].append(tts_chatterbox_path) -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/transformer/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) 2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo) 3 | # 2020 Mobvoi Inc (Binbin Zhang) 4 | # 2024 Alibaba Inc (Xiang Lyu) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Swish() activation function for Conformer.""" 18 | 19 | import torch 20 | from torch import nn, sin, pow 21 | from torch.nn import Parameter 22 | 23 | 24 | class Swish(torch.nn.Module): 25 | """Construct an Swish object.""" 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | """Return Swish activation function.""" 29 | return x * torch.sigmoid(x) 30 | 31 | 32 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 
33 | # LICENSE is in incl_licenses directory. 34 | class Snake(nn.Module): 35 | ''' 36 | Implementation of a sine-based periodic activation function 37 | Shape: 38 | - Input: (B, C, T) 39 | - Output: (B, C, T), same shape as the input 40 | Parameters: 41 | - alpha - trainable parameter 42 | References: 43 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 44 | https://arxiv.org/abs/2006.08195 45 | Examples: 46 | >>> a1 = snake(256) 47 | >>> x = torch.randn(256) 48 | >>> x = a1(x) 49 | ''' 50 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 51 | ''' 52 | Initialization. 53 | INPUT: 54 | - in_features: shape of the input 55 | - alpha: trainable parameter 56 | alpha is initialized to 1 by default, higher values = higher-frequency. 57 | alpha will be trained along with the rest of your model. 58 | ''' 59 | super(Snake, self).__init__() 60 | self.in_features = in_features 61 | 62 | # initialize alpha 63 | self.alpha_logscale = alpha_logscale 64 | if self.alpha_logscale: # log scale alphas initialized to zeros 65 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 66 | else: # linear scale alphas initialized to ones 67 | self.alpha = Parameter(torch.ones(in_features) * alpha) 68 | 69 | self.alpha.requires_grad = alpha_trainable 70 | 71 | self.no_div_by_zero = 0.000000001 72 | 73 | def forward(self, x): 74 | ''' 75 | Forward pass of the function. 76 | Applies the function to the input elementwise. 77 | Snake ∶= x + 1/a * sin^2 (xa) 78 | ''' 79 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 80 | if self.alpha_logscale: 81 | alpha = torch.exp(alpha) 82 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /src/chatterbox/models/t3/modules/cond_enc.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import nn, Tensor 6 | 7 | from .perceiver import Perceiver 8 | from .t3_config import T3Config 9 | 10 | 11 | @dataclass 12 | class T3Cond: 13 | """ 14 | Dataclass container for most / all conditioning info. 15 | TODO: serialization methods aren't used, keeping them around for convenience 16 | """ 17 | 18 | speaker_emb: Tensor 19 | clap_emb: Optional[Tensor] = None 20 | cond_prompt_speech_tokens: Optional[Tensor] = None 21 | cond_prompt_speech_emb: Optional[Tensor] = None 22 | emotion_adv: Optional[Tensor] = 0.5 23 | 24 | def to(self, *, device=None, dtype=None): 25 | "Cast to a device and dtype. Dtype casting is ignored for long/int tensors." 26 | for k, v in self.__dict__.items(): 27 | if torch.is_tensor(v): 28 | is_fp = type(v.view(-1)[0].item()) is not int 29 | setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None)) 30 | return self 31 | 32 | def save(self, fpath): 33 | torch.save(self.__dict__, fpath) 34 | 35 | @staticmethod 36 | def load(fpath, map_location="cpu"): 37 | kwargs = torch.load(fpath, map_location=map_location, weights_only=True) 38 | return T3Cond(**kwargs) 39 | 40 | 41 | class T3CondEnc(nn.Module): 42 | """ 43 | Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc. 
44 | """ 45 | 46 | def __init__(self, hp: T3Config): 47 | super().__init__() 48 | self.hp = hp 49 | if hp.encoder_type == "voice_encoder": 50 | self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels) 51 | else: 52 | raise NotImplementedError(str(hp.encoder_type)) 53 | 54 | # emotion adv 55 | self.emotion_adv_fc = None 56 | if hp.emotion_adv: 57 | self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False) 58 | 59 | # perceiver resampler 60 | self.perceiver = None 61 | if hp.use_perceiver_resampler: 62 | self.perceiver = Perceiver() 63 | 64 | def forward(self, cond: T3Cond): 65 | # Validate 66 | assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \ 67 | "no embeddings for cond_prompt_speech_tokens" 68 | 69 | # Speaker embedding projection 70 | cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None] # (B, 1, dim) 71 | empty = torch.zeros_like(cond_spkr[:, :0]) # (B, 0, dim) 72 | 73 | # TODO CLAP 74 | assert cond.clap_emb is None, "clap_embed not implemented" 75 | cond_clap = empty # (B, 0, dim) 76 | 77 | # Cond prompt 78 | cond_prompt_speech_emb = cond.cond_prompt_speech_emb 79 | if cond_prompt_speech_emb is None: 80 | cond_prompt_speech_emb = empty # (B, 0, dim) 81 | elif self.hp.use_perceiver_resampler: 82 | cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb) 83 | 84 | # Emotion Adv: must provide a value if this model uses emotion conditioning 85 | cond_emotion_adv = empty # (B, 0, dim) 86 | if self.hp.emotion_adv: 87 | assert cond.emotion_adv is not None 88 | cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1)) 89 | 90 | # Concat and return 91 | cond_embeds = torch.cat(( 92 | cond_spkr, 93 | cond_clap, 94 | cond_prompt_speech_emb, 95 | cond_emotion_adv, 96 | ), dim=1) 97 | return cond_embeds 98 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .github 11 | .idea 12 | .Python 13 | __pycache__ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | -------------------------------------------------------------------------------- /modules/chatterbox_handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import random 5 | import numpy as np 6 | import folder_paths 7 | from huggingface_hub import hf_hub_download 8 | 9 | from chatterbox.tts import ChatterboxTTS 10 | from chatterbox.vc import ChatterboxVC 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | CHATTERBOX_MODEL_SUBDIR = os.path.join("tts", "chatterbox") 15 | CHATTERBOX_REPO_ID = "ResembleAI/chatterbox" 16 | CHATTERBOX_FILES_TO_DOWNLOAD = ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"] 17 | DEFAULT_MODEL_PACK_NAME = "resembleai_default_voice" 18 | 19 | def get_chatterbox_model_pack_names(): 20 | chatterbox_models_base_path = os.path.join(folder_paths.models_dir, CHATTERBOX_MODEL_SUBDIR) 21 | if not os.path.isdir(chatterbox_models_base_path): 22 | os.makedirs(chatterbox_models_base_path, exist_ok=True) 23 | return [DEFAULT_MODEL_PACK_NAME] 24 | packs = [d for d in os.listdir(chatterbox_models_base_path) if os.path.isdir(os.path.join(chatterbox_models_base_path, d))] 25 | # Ensure default is first if it exists 26 | if DEFAULT_MODEL_PACK_NAME in packs: 27 | packs.insert(0, packs.pop(packs.index(DEFAULT_MODEL_PACK_NAME))) 28 | 29 | # Return default even if folder doesn't exist yet, to prompt the user to download it. 30 | return packs if packs else [DEFAULT_MODEL_PACK_NAME] 31 | 32 | def get_model_pack_path(model_pack_name): 33 | # Added check for None or empty string to prevent errors 34 | if not model_pack_name: 35 | return None 36 | return os.path.join(folder_paths.models_dir, CHATTERBOX_MODEL_SUBDIR, model_pack_name) 37 | 38 | def _download_file_from_hf(repo_id, filename, local_dir): 39 | destination = os.path.join(local_dir, filename) 40 | if not os.path.exists(destination): 41 | logger.info(f"Downloading '{filename}' from '{repo_id}'...") 42 | try: 43 | hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir, local_dir_use_symlinks=False, resume_download=True) 44 | logger.info(f"Successfully downloaded '{filename}'.") 45 | return True 46 | except Exception as e: 47 | logger.error(f"Failed to download '{filename}': {e}") 48 | if os.path.exists(destination + ".incomplete"): os.remove(destination + ".incomplete") 49 | return False 50 | return True 51 | 52 | def download_chatterbox_model_pack_if_missing(model_pack_name): 53 | ckpt_dir = get_model_pack_path(model_pack_name) 54 | if not ckpt_dir: 55 | logger.warning(f"Invalid model pack name '{model_pack_name}', cannot download.") 56 | return False 57 | os.makedirs(ckpt_dir, exist_ok=True) 58 | all_files_ok = all(_download_file_from_hf(CHATTERBOX_REPO_ID, f, ckpt_dir) for f in CHATTERBOX_FILES_TO_DOWNLOAD) 59 | if not all_files_ok: 60 | logger.error(f"Some files failed to download for model pack '{model_pack_name}'. Check logs.") 61 | return all_files_ok 62 | 63 | def load_chatterbox_models(model_pack_name, device): 64 | """Loads both TTS and VC models for a given pack onto a specified device.""" 65 | ckpt_dir = get_model_pack_path(model_pack_name) 66 | if not ckpt_dir: 67 | raise ValueError(f"Invalid model_pack_name: {model_pack_name}") 68 | 69 | if not download_chatterbox_model_pack_if_missing(model_pack_name): 70 | logger.warning(f"Not all model files could be verified for '{model_pack_name}'. 
Loading may fail.") 71 | 72 | if not os.path.isdir(ckpt_dir): 73 | raise FileNotFoundError(f"Model pack directory '{model_pack_name}' not found at '{ckpt_dir}'.") 74 | 75 | try: 76 | logger.info(f"Loading Chatterbox TTS model from {ckpt_dir} onto {device}") 77 | tts_model = ChatterboxTTS.from_local(ckpt_dir, device=device) 78 | except Exception as e: 79 | logger.error(f"Error loading ChatterboxTTS from '{ckpt_dir}': {e}", exc_info=True) 80 | raise 81 | 82 | try: 83 | logger.info(f"Loading Chatterbox VC model from {ckpt_dir} onto {device}") 84 | vc_model = ChatterboxVC.from_local(ckpt_dir, device=device) 85 | except Exception as e: 86 | logger.error(f"Error loading ChatterboxVC from '{ckpt_dir}': {e}", exc_info=True) 87 | raise 88 | 89 | return tts_model, vc_model 90 | 91 | def set_chatterbox_seed(seed: int): 92 | MAX_NUMPY_SEED = 2**32 - 1 93 | actual_seed_for_torch_random = random.randint(1, 0xffffffffffffffff) if seed == 0 else seed 94 | actual_seed_for_numpy = random.randint(1, MAX_NUMPY_SEED) if seed == 0 else (seed % MAX_NUMPY_SEED) 95 | torch.manual_seed(actual_seed_for_torch_random) 96 | if torch.cuda.is_available(): torch.cuda.manual_seed_all(actual_seed_for_torch_random) 97 | random.seed(actual_seed_for_torch_random) 98 | np.random.seed(actual_seed_for_numpy) -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Positionwise feed forward layer definition.""" 16 | 17 | import torch 18 | 19 | 20 | class PositionwiseFeedForward(torch.nn.Module): 21 | """Positionwise feed forward layer. 22 | 23 | FeedForward are appied on each position of the sequence. 24 | The output dim is same with the input dim. 25 | 26 | Args: 27 | idim (int): Input dimenstion. 28 | hidden_units (int): The number of hidden units. 29 | dropout_rate (float): Dropout rate. 30 | activation (torch.nn.Module): Activation function 31 | """ 32 | 33 | def __init__( 34 | self, 35 | idim: int, 36 | hidden_units: int, 37 | dropout_rate: float, 38 | activation: torch.nn.Module = torch.nn.ReLU(), 39 | ): 40 | """Construct a PositionwiseFeedForward object.""" 41 | super(PositionwiseFeedForward, self).__init__() 42 | self.w_1 = torch.nn.Linear(idim, hidden_units) 43 | self.activation = activation 44 | self.dropout = torch.nn.Dropout(dropout_rate) 45 | self.w_2 = torch.nn.Linear(hidden_units, idim) 46 | 47 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 48 | """Forward function. 
49 | 50 | Args: 51 | xs: input tensor (B, L, D) 52 | Returns: 53 | output tensor, (B, L, D) 54 | """ 55 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 56 | 57 | 58 | class MoEFFNLayer(torch.nn.Module): 59 | """ 60 | Mixture of expert with Positionwise feed forward layer 61 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf 62 | The output dim is same with the input dim. 63 | 64 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 65 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 66 | Args: 67 | n_expert: number of expert. 68 | n_expert_per_token: The actual number of experts used for each frame 69 | idim (int): Input dimenstion. 70 | hidden_units (int): The number of hidden units. 71 | dropout_rate (float): Dropout rate. 72 | activation (torch.nn.Module): Activation function 73 | """ 74 | 75 | def __init__( 76 | self, 77 | n_expert: int, 78 | n_expert_per_token: int, 79 | idim: int, 80 | hidden_units: int, 81 | dropout_rate: float, 82 | activation: torch.nn.Module = torch.nn.ReLU(), 83 | ): 84 | super(MoEFFNLayer, self).__init__() 85 | self.gate = torch.nn.Linear(idim, n_expert, bias=False) 86 | self.experts = torch.nn.ModuleList( 87 | PositionwiseFeedForward(idim, hidden_units, dropout_rate, 88 | activation) for _ in range(n_expert)) 89 | self.n_expert_per_token = n_expert_per_token 90 | 91 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 92 | """Foward function. 93 | Args: 94 | xs: input tensor (B, L, D) 95 | Returns: 96 | output tensor, (B, L, D) 97 | 98 | """ 99 | B, L, D = xs.size( 100 | ) # batch size, sequence length, embedding dimension (idim) 101 | xs = xs.view(-1, D) # (B*L, D) 102 | router = self.gate(xs) # (B*L, n_expert) 103 | logits, indices = torch.topk( 104 | router, self.n_expert_per_token 105 | ) # probs:(B*L, n_expert), indices: (B*L, n_expert) 106 | weights = torch.nn.functional.softmax( 107 | logits, dim=1, 108 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) 109 | output = torch.zeros_like(xs) # (B*L, D) 110 | for i, expert in enumerate(self.experts): 111 | mask = indices == i 112 | batch_idx, ith_expert = torch.where(mask) 113 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( 114 | xs[batch_idx]) 115 | return output.view(B, L, D) 116 | -------------------------------------------------------------------------------- /src/chatterbox/models/t3/inference/t3_hf_backend.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn as nn 5 | from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel, GenerationMixin 6 | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions 7 | 8 | 9 | class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin): 10 | """ 11 | Override some HuggingFace interface methods so we can use the standard `generate` method with our 12 | custom embedding / logit layers. 13 | 14 | NOTE: need to extend "*PreTrainedModel" to avoid re-initializing weights! 
15 | """ 16 | 17 | def __init__( 18 | self, 19 | config: LlamaConfig, 20 | llama: LlamaModel, 21 | *, 22 | speech_enc, 23 | speech_head, 24 | latents_queue=None, 25 | logits_queue=None, 26 | alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None, 27 | ): 28 | super().__init__(config) 29 | self.model = llama 30 | self.speech_enc = speech_enc 31 | self.speech_head = speech_head 32 | self._added_cond = False 33 | self.alignment_stream_analyzer = alignment_stream_analyzer 34 | 35 | @torch.inference_mode() 36 | def prepare_inputs_for_generation( 37 | self, input_ids: torch.Tensor, decoder_cond: torch.Tensor, use_cache: bool, past_key_values=None, 38 | # This argument was introduced in some recent version of transformers (>=4.29.1) 39 | cache_position=None 40 | ): 41 | """ 42 | This is a method used by huggingface's generate() method. 43 | Overridden here to apply our custom speech token embedding layer. 44 | 45 | :param input_ids: (B, S) int64 tensors of input tokens. 46 | :param decoder_cond: (B, T, C) float32 tensor of conditioning (prefixed to ) 47 | """ 48 | 49 | # Make use of the kv cache: only the last input ID is new, we trim away all the ones before 50 | if not use_cache: 51 | past_key_values = None 52 | if past_key_values is not None: 53 | input_ids = input_ids[:, -1:] 54 | 55 | # custom speech token embedding layer 56 | inputs_embeds = self.speech_enc(input_ids) 57 | 58 | # prefix decoder conditioning if applicable 59 | if not self._added_cond: 60 | assert past_key_values is not None # should be first step 61 | if decoder_cond.size(0) != inputs_embeds.size(0): 62 | decoder_cond = decoder_cond.expand(inputs_embeds.size(0), -1, -1) 63 | inputs_embeds = torch.cat([decoder_cond, inputs_embeds], dim=1) 64 | self._added_cond = True 65 | 66 | return { 67 | "inputs_embeds": inputs_embeds, 68 | "past_key_values": past_key_values, 69 | "use_cache": use_cache, 70 | } 71 | 72 | @torch.inference_mode() 73 | def forward( 74 | self, 75 | inputs_embeds: torch.Tensor, 76 | past_key_values: Optional[torch.Tensor]=None, 77 | use_cache=True, 78 | output_attentions=False, 79 | output_hidden_states=True, 80 | return_dict=True, 81 | ): 82 | """ 83 | This is a method used by huggingface's generate() method. 84 | Overridden here to apply our custom layer norm and speech logit projection layers. 85 | 86 | :param inputs_embeds: (B, S, C) float32 tensor of conditioning inputs. If past key values are given, 87 | S should be 1. 
88 | """ 89 | is_large_input = inputs_embeds.size(1) != 1 90 | has_cache = past_key_values is not None and len(past_key_values) > 0 91 | assert not (is_large_input and has_cache) 92 | assert return_dict 93 | assert output_hidden_states 94 | 95 | tfmr_out = self.model( 96 | inputs_embeds=inputs_embeds, 97 | past_key_values=past_key_values, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=True, 102 | ) 103 | hidden_states = tfmr_out.hidden_states[-1] # (B, seq, dim) 104 | 105 | logits = self.speech_head(hidden_states) 106 | # assert inputs_embeds.size(0) == 1 # (disabled for CFG) 107 | 108 | # NOTE: hallucination handler may modify logits to force emit an EOS token 109 | # logits = self.alignment_stream_analyzer.step(logits) 110 | 111 | return CausalLMOutputWithCrossAttentions( 112 | logits=logits, 113 | past_key_values=tfmr_out.past_key_values, 114 | hidden_states=tfmr_out.hidden_states, 115 | attentions=tfmr_out.attentions, 116 | ) 117 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/matcha/flow_matching.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .decoder import Decoder 7 | 8 | 9 | class BASECFM(torch.nn.Module, ABC): 10 | def __init__( 11 | self, 12 | n_feats, 13 | cfm_params, 14 | n_spks=1, 15 | spk_emb_dim=128, 16 | ): 17 | super().__init__() 18 | self.n_feats = n_feats 19 | self.n_spks = n_spks 20 | self.spk_emb_dim = spk_emb_dim 21 | self.solver = cfm_params.solver 22 | if hasattr(cfm_params, "sigma_min"): 23 | self.sigma_min = cfm_params.sigma_min 24 | else: 25 | self.sigma_min = 1e-4 26 | 27 | self.estimator = None 28 | 29 | @torch.inference_mode() 30 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): 31 | """Forward diffusion 32 | 33 | Args: 34 | mu (torch.Tensor): output of encoder 35 | shape: (batch_size, n_feats, mel_timesteps) 36 | mask (torch.Tensor): output_mask 37 | shape: (batch_size, 1, mel_timesteps) 38 | n_timesteps (int): number of diffusion steps 39 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 40 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 41 | shape: (batch_size, spk_emb_dim) 42 | cond: Not used but kept for future purposes 43 | 44 | Returns: 45 | sample: generated mel-spectrogram 46 | shape: (batch_size, n_feats, mel_timesteps) 47 | """ 48 | z = torch.randn_like(mu) * temperature 49 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) 50 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) 51 | 52 | def solve_euler(self, x, t_span, mu, mask, spks, cond): 53 | """ 54 | Fixed euler solver for ODEs. 55 | Args: 56 | x (torch.Tensor): random noise 57 | t_span (torch.Tensor): n_timesteps interpolated 58 | shape: (n_timesteps + 1,) 59 | mu (torch.Tensor): output of encoder 60 | shape: (batch_size, n_feats, mel_timesteps) 61 | mask (torch.Tensor): output_mask 62 | shape: (batch_size, 1, mel_timesteps) 63 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 
64 | shape: (batch_size, spk_emb_dim) 65 | cond: Not used but kept for future purposes 66 | """ 67 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 68 | 69 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 70 | # Or in future might add like a return_all_steps flag 71 | sol = [] 72 | 73 | for step in range(1, len(t_span)): 74 | dphi_dt = self.estimator(x, mask, mu, t, spks, cond) 75 | 76 | x = x + dt * dphi_dt 77 | t = t + dt 78 | sol.append(x) 79 | if step < len(t_span) - 1: 80 | dt = t_span[step + 1] - t 81 | 82 | return sol[-1] 83 | 84 | def compute_loss(self, x1, mask, mu, spks=None, cond=None): 85 | """Computes diffusion loss 86 | 87 | Args: 88 | x1 (torch.Tensor): Target 89 | shape: (batch_size, n_feats, mel_timesteps) 90 | mask (torch.Tensor): target mask 91 | shape: (batch_size, 1, mel_timesteps) 92 | mu (torch.Tensor): output of encoder 93 | shape: (batch_size, n_feats, mel_timesteps) 94 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 95 | shape: (batch_size, spk_emb_dim) 96 | 97 | Returns: 98 | loss: conditional flow matching loss 99 | y: conditional flow 100 | shape: (batch_size, n_feats, mel_timesteps) 101 | """ 102 | b, _, t = mu.shape 103 | 104 | # random timestep 105 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 106 | # sample noise p(x_0) 107 | z = torch.randn_like(x1) 108 | 109 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 110 | u = x1 - (1 - self.sigma_min) * z 111 | 112 | loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( 113 | torch.sum(mask) * u.shape[1] 114 | ) 115 | return loss, y 116 | 117 | 118 | class CFM(BASECFM): 119 | def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): 120 | super().__init__( 121 | n_feats=in_channels, 122 | cfm_params=cfm_params, 123 | n_spks=n_spks, 124 | spk_emb_dim=spk_emb_dim, 125 | ) 126 | 127 | in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) 128 | # Just change the architecture of the estimator here 129 | self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) 130 | -------------------------------------------------------------------------------- /src/chatterbox/vc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | import librosa 5 | import torch 6 | import perth 7 | import torch.nn.functional as F 8 | from huggingface_hub import hf_hub_download 9 | from safetensors.torch import load_file 10 | 11 | from .models.s3tokenizer import S3_SR, S3_TOKEN_RATE 12 | from .models.s3gen import S3GEN_SR, S3Gen 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | REPO_ID = "ResembleAI/chatterbox" 17 | 18 | def _pad_wav_to_40ms_multiple(wav: torch.Tensor, sr: int) -> torch.Tensor: 19 | """ 20 | Pads a waveform to be a multiple of 40ms to prevent rounding errors between 21 | the mel spectrogram (20ms hop) and the speech tokenizer (40ms hop). 
22 | """ 23 | S3_TOKEN_DURATION_S = 1 / S3_TOKEN_RATE # 0.04 seconds 24 | samples_per_token = int(sr * S3_TOKEN_DURATION_S) 25 | current_samples = wav.shape[-1] 26 | remainder = current_samples % samples_per_token 27 | if remainder != 0: 28 | padding_needed = samples_per_token - remainder 29 | padded_wav = F.pad(wav, (0, padding_needed)) 30 | return padded_wav 31 | return wav 32 | 33 | class ChatterboxVC: 34 | ENC_COND_LEN = 6 * S3_SR 35 | DEC_COND_LEN = 10 * S3GEN_SR 36 | 37 | def __init__( 38 | self, 39 | s3gen: S3Gen, 40 | device: str, 41 | ref_dict: dict=None, 42 | ): 43 | self.sr = S3GEN_SR 44 | self.s3gen = s3gen 45 | self.device = device 46 | self.watermarker = perth.PerthImplicitWatermarker() 47 | if ref_dict is None: 48 | self.ref_dict = None 49 | else: 50 | self.ref_dict = { 51 | k: v.to(device) if torch.is_tensor(v) else v 52 | for k, v in ref_dict.items() 53 | } 54 | 55 | @classmethod 56 | def from_local(cls, ckpt_dir, device) -> 'ChatterboxVC': 57 | ckpt_dir = Path(ckpt_dir) 58 | 59 | # Always load to CPU first for non-CUDA devices to handle CUDA-saved models 60 | if str(device) in ["cpu", "mps"]: 61 | map_location = torch.device('cpu') 62 | else: 63 | map_location = None 64 | 65 | ref_dict = None 66 | if (builtin_voice := ckpt_dir / "conds.pt").exists(): 67 | states = torch.load(builtin_voice, map_location=map_location) 68 | ref_dict = states['gen'] 69 | 70 | s3gen = S3Gen() 71 | s3gen.load_state_dict( 72 | load_file(ckpt_dir / "s3gen.safetensors"), strict=False 73 | ) 74 | s3gen.to(device).eval() 75 | 76 | return cls(s3gen, device, ref_dict=ref_dict) 77 | 78 | @classmethod 79 | def from_pretrained(cls, device) -> 'ChatterboxVC': 80 | # Check if MPS is available on macOS 81 | if device == "mps" and not torch.backends.mps.is_available(): 82 | if not torch.backends.mps.is_built(): 83 | logger.warning("MPS not available because the current PyTorch install was not built with MPS enabled.") 84 | else: 85 | logger.warning("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.") 86 | device = "cpu" 87 | 88 | for fpath in ["s3gen.safetensors", "conds.pt"]: 89 | local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath) 90 | 91 | return cls.from_local(Path(local_path).parent, device) 92 | 93 | def set_target_voice(self, wav_fpath): 94 | ## Load reference wav 95 | s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR) 96 | 97 | # Convert to tensor and pad to a 40ms boundary 98 | s3gen_ref_wav = torch.from_numpy(s3gen_ref_wav).float().unsqueeze(0) 99 | s3gen_ref_wav = _pad_wav_to_40ms_multiple(s3gen_ref_wav, S3GEN_SR) 100 | s3gen_ref_wav_np = s3gen_ref_wav.squeeze(0).numpy() 101 | 102 | s3gen_ref_wav_np = s3gen_ref_wav_np[:self.DEC_COND_LEN] 103 | self.ref_dict = self.s3gen.embed_ref(s3gen_ref_wav_np, S3GEN_SR, device=self.device) 104 | 105 | def generate( 106 | self, 107 | audio, 108 | target_voice_path=None, 109 | n_timesteps=10, 110 | pbar=None, 111 | temperature=1.0, 112 | flow_cfg_scale=0.7 113 | ): 114 | if target_voice_path: 115 | self.set_target_voice(target_voice_path) 116 | else: 117 | assert self.ref_dict is not None, "Please `set_target_voice` first or specify `target_voice_path`" 118 | 119 | with torch.inference_mode(): 120 | audio_16, _ = librosa.load(audio, sr=S3_SR) 121 | audio_16 = torch.from_numpy(audio_16).float().to(self.device)[None, ] 122 | 123 | s3_tokens, _ = self.s3gen.tokenizer(audio_16) 124 | wav, _ = self.s3gen.inference( 125 | speech_tokens=s3_tokens, 126 | ref_dict=self.ref_dict, 
127 | n_timesteps=n_timesteps, 128 | pbar=pbar, 129 | temperature=temperature, 130 | flow_cfg_scale=flow_cfg_scale 131 | ) 132 | wav = wav.squeeze(0).detach().cpu().numpy() 133 | watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr) 134 | return torch.from_numpy(watermarked_wav).unsqueeze(0) 135 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 2024 Alibaba Inc (Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """ConvolutionModule definition.""" 17 | 18 | from typing import Tuple 19 | 20 | import torch 21 | from torch import nn 22 | 23 | 24 | class ConvolutionModule(nn.Module): 25 | """ConvolutionModule in Conformer model.""" 26 | 27 | def __init__(self, 28 | channels: int, 29 | kernel_size: int = 15, 30 | activation: nn.Module = nn.ReLU(), 31 | norm: str = "batch_norm", 32 | causal: bool = False, 33 | bias: bool = True): 34 | """Construct an ConvolutionModule object. 35 | Args: 36 | channels (int): The number of channels of conv layers. 37 | kernel_size (int): Kernel size of conv layers. 38 | causal (int): Whether use causal convolution or not 39 | """ 40 | super().__init__() 41 | 42 | self.pointwise_conv1 = nn.Conv1d( 43 | channels, 44 | 2 * channels, 45 | kernel_size=1, 46 | stride=1, 47 | padding=0, 48 | bias=bias, 49 | ) 50 | # self.lorder is used to distinguish if it's a causal convolution, 51 | # if self.lorder > 0: it's a causal convolution, the input will be 52 | # padded with self.lorder frames on the left in forward. 53 | # else: it's a symmetrical convolution 54 | if causal: 55 | padding = 0 56 | self.lorder = kernel_size - 1 57 | else: 58 | # kernel_size should be an odd number for none causal convolution 59 | assert (kernel_size - 1) % 2 == 0 60 | padding = (kernel_size - 1) // 2 61 | self.lorder = 0 62 | self.depthwise_conv = nn.Conv1d( 63 | channels, 64 | channels, 65 | kernel_size, 66 | stride=1, 67 | padding=padding, 68 | groups=channels, 69 | bias=bias, 70 | ) 71 | 72 | assert norm in ['batch_norm', 'layer_norm'] 73 | if norm == "batch_norm": 74 | self.use_layer_norm = False 75 | self.norm = nn.BatchNorm1d(channels) 76 | else: 77 | self.use_layer_norm = True 78 | self.norm = nn.LayerNorm(channels) 79 | 80 | self.pointwise_conv2 = nn.Conv1d( 81 | channels, 82 | channels, 83 | kernel_size=1, 84 | stride=1, 85 | padding=0, 86 | bias=bias, 87 | ) 88 | self.activation = activation 89 | 90 | def forward( 91 | self, 92 | x: torch.Tensor, 93 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 94 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 95 | ) -> Tuple[torch.Tensor, torch.Tensor]: 96 | """Compute convolution module. 
97 | Args: 98 | x (torch.Tensor): Input tensor (#batch, time, channels). 99 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 100 | (0, 0, 0) means fake mask. 101 | cache (torch.Tensor): left context cache, it is only 102 | used in causal convolution (#batch, channels, cache_t), 103 | (0, 0, 0) meas fake cache. 104 | Returns: 105 | torch.Tensor: Output tensor (#batch, time, channels). 106 | """ 107 | # exchange the temporal dimension and the feature dimension 108 | x = x.transpose(1, 2) # (#batch, channels, time) 109 | 110 | # mask batch padding 111 | if mask_pad.size(2) > 0: # time > 0 112 | x.masked_fill_(~mask_pad, 0.0) 113 | 114 | if self.lorder > 0: 115 | if cache.size(2) == 0: # cache_t == 0 116 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 117 | else: 118 | assert cache.size(0) == x.size(0) # equal batch 119 | assert cache.size(1) == x.size(1) # equal channel 120 | x = torch.cat((cache, x), dim=2) 121 | assert (x.size(2) > self.lorder) 122 | new_cache = x[:, :, -self.lorder:] 123 | else: 124 | # It's better we just return None if no cache is required, 125 | # However, for JIT export, here we just fake one tensor instead of 126 | # None. 127 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 128 | 129 | # GLU mechanism 130 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 131 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 132 | 133 | # 1D Depthwise Conv 134 | x = self.depthwise_conv(x) 135 | if self.use_layer_norm: 136 | x = x.transpose(1, 2) 137 | x = self.activation(self.norm(x)) 138 | if self.use_layer_norm: 139 | x = x.transpose(1, 2) 140 | x = self.pointwise_conv2(x) 141 | # mask batch padding 142 | if mask_pad.size(2) > 0: # time > 0 143 | x.masked_fill_(~mask_pad, 0.0) 144 | 145 | return x.transpose(1, 2), new_cache 146 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3tokenizer/s3tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | import librosa 5 | import torch 6 | import torch.nn.functional as F 7 | from s3tokenizer.utils import padding 8 | from s3tokenizer.model_v2 import ( 9 | S3TokenizerV2, 10 | ModelConfig, 11 | ) 12 | 13 | 14 | # Sampling rate of the inputs to S3TokenizerV2 15 | S3_SR = 16_000 16 | S3_HOP = 160 # 100 frames/sec 17 | S3_TOKEN_HOP = 640 # 25 tokens/sec 18 | S3_TOKEN_RATE = 25 19 | SPEECH_VOCAB_SIZE = 6561 20 | 21 | 22 | class S3Tokenizer(S3TokenizerV2): 23 | """ 24 | s3tokenizer.S3TokenizerV2 with the following changes: 25 | - a more integrated `forward` 26 | - compute `log_mel_spectrogram` using `_mel_filters` and `window` in `register_buffers` 27 | """ 28 | 29 | ignore_state_dict_missing = ("_mel_filters", "window") 30 | 31 | def __init__( 32 | self, 33 | name: str="speech_tokenizer_v2_25hz", 34 | config: ModelConfig = ModelConfig() 35 | ): 36 | super().__init__(name) 37 | 38 | self.n_fft = 400 39 | _mel_filters = librosa.filters.mel( 40 | sr=S3_SR, 41 | n_fft=self.n_fft, 42 | n_mels=config.n_mels 43 | ) 44 | self.register_buffer( 45 | "_mel_filters", 46 | torch.FloatTensor(_mel_filters), 47 | ) 48 | 49 | self.register_buffer( 50 | "window", 51 | torch.hann_window(self.n_fft), 52 | ) 53 | 54 | def pad(self, wavs, sr) -> List[torch.Tensor]: 55 | """ 56 | Given a list of wavs with the same `sample_rate`, pad them so that the length is multiple of 40ms (S3 runs at 25 token/sec). 
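
        For example, with ``sr=16000`` and ``S3_TOKEN_RATE=25`` (640 samples per
        token), a 16160-sample wav corresponds to 25.25 tokens, which is rounded
        up to 26 tokens, so it is right-padded with zeros to 26 * 640 = 16640
        samples; a 16000-sample wav is already an exact multiple and is returned
        unchanged.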
57 | """ 58 | processed_wavs = [] 59 | for wav in wavs: 60 | if isinstance(wav, np.ndarray): 61 | wav = torch.from_numpy(wav) 62 | if wav.dim() == 1: 63 | wav = wav.unsqueeze(0) 64 | 65 | n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE 66 | n_tokens = np.ceil(n_tokens) 67 | intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE) 68 | intended_wav_len = int(intended_wav_len) 69 | wav = torch.nn.functional.pad( 70 | wav, 71 | (0, intended_wav_len - wav.shape[-1]), 72 | mode="constant", 73 | value=0 74 | ) 75 | processed_wavs.append(wav) 76 | return processed_wavs 77 | 78 | def _prepare_audio(self, wavs): 79 | """Prepare a list of audios for s3tokenizer processing.""" 80 | processed_wavs = [] 81 | for wav in wavs: 82 | if isinstance(wav, np.ndarray): 83 | wav = torch.from_numpy(wav) 84 | if wav.dim() == 1: 85 | wav = wav.unsqueeze(0) 86 | 87 | processed_wavs.append(wav) 88 | return processed_wavs 89 | 90 | @torch.no_grad() 91 | def forward( 92 | self, 93 | wavs: torch.Tensor, 94 | accelerator: 'Accelerator'=None, 95 | max_len: int=None, 96 | ) -> Tuple[torch.Tensor, torch.LongTensor]: 97 | """ 98 | NOTE: mel-spec has a hop size of 160 points (100 frame/sec). 99 | FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor` and handles a list of wavs one by one, which is unexpected. 100 | 101 | Args 102 | ---- 103 | - `wavs`: 16 kHz speech audio 104 | - `max_len` max length to truncate the output sequence to (25 token/sec). 105 | NOTE: please pad the waveform if longer sequence is needed. 106 | """ 107 | processed_wavs = self._prepare_audio(wavs) 108 | mels, mel_lens = [], [] 109 | for wav in processed_wavs: 110 | wav = wav.to(self.device) 111 | mel = self.log_mel_spectrogram(wav) # [B=1, F, T] 112 | if max_len is not None: 113 | mel = mel[..., :max_len * 4] # num_mel_frames = 4 * num_tokens 114 | mels.append(mel.squeeze(0)) 115 | 116 | mels, mel_lens = padding(mels) 117 | if accelerator is None: 118 | tokenizer = self 119 | else: 120 | tokenizer = accelerator.unwrap_model(self) 121 | 122 | speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device)) 123 | return ( 124 | speech_tokens.long().detach(), 125 | speech_token_lens.long().detach(), 126 | ) 127 | 128 | def log_mel_spectrogram( 129 | self, 130 | audio: torch.Tensor, 131 | padding: int = 0, 132 | ): 133 | """ 134 | Compute the log-Mel spectrogram of 135 | 136 | Parameters 137 | ---------- 138 | audio: torch.Tensor, shape = (*) 139 | The path to audio or either a NumPy array or Tensor containing the 140 | audio waveform in 16 kHz 141 | 142 | padding: int 143 | Number of zero samples to pad to the right 144 | 145 | Returns 146 | ------- 147 | torch.Tensor, shape = (128, n_frames) 148 | A Tensor that contains the Mel spectrogram 149 | """ 150 | if not torch.is_tensor(audio): 151 | audio = torch.from_numpy(audio) 152 | 153 | audio = audio.to(self.device) 154 | if padding > 0: 155 | audio = F.pad(audio, (0, padding)) 156 | stft = torch.stft( 157 | audio, self.n_fft, S3_HOP, 158 | window=self.window.to(self.device), 159 | return_complex=True 160 | ) 161 | magnitudes = stft[..., :-1].abs()**2 162 | 163 | mel_spec = self._mel_filters.to(self.device) @ magnitudes 164 | 165 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 166 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 167 | log_spec = (log_spec + 4.0) / 4.0 168 | return log_spec 169 | -------------------------------------------------------------------------------- /src/chatterbox/models/t3/inference/alignment_stream_analyzer.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Resemble AI 2 | # Author: John Meade, Jeremy Hsu 3 | # MIT License 4 | import logging 5 | import torch 6 | from dataclasses import dataclass 7 | from types import MethodType 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | @dataclass 14 | class AlignmentAnalysisResult: 15 | # was this frame detected as being part of a noisy beginning chunk with potential hallucinations? 16 | false_start: bool 17 | # was this frame detected as being part of a long tail with potential hallucinations? 18 | long_tail: bool 19 | # was this frame detected as repeating existing text content? 20 | repetition: bool 21 | # was the alignment position of this frame too far from the previous frame? 22 | discontinuity: bool 23 | # has inference reached the end of the text tokens? eg, this remains false if inference stops early 24 | complete: bool 25 | # approximate position in the text token sequence. Can be used for generating online timestamps. 26 | position: int 27 | 28 | 29 | class AlignmentStreamAnalyzer: 30 | def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0): 31 | """ 32 | Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention 33 | activation maps. This module exploits this to perform online integrity checks which streaming. 34 | A hook is injected into the specified attention layer, and heuristics are used to determine alignment 35 | position, repetition, etc. 36 | 37 | NOTE: currently requires no queues. 38 | """ 39 | # self.queue = queue 40 | self.text_tokens_slice = (i, j) = text_tokens_slice 41 | self.eos_idx = eos_idx 42 | self.alignment = torch.zeros(0, j-i) 43 | # self.alignment_bin = torch.zeros(0, j-i) 44 | self.curr_frame_pos = 0 45 | self.text_position = 0 46 | 47 | self.started = False 48 | self.started_at = None 49 | 50 | self.complete = False 51 | self.completed_at = None 52 | 53 | # Using `output_attentions=True` is incompatible with optimized attention kernels, so 54 | # using it for all layers slows things down too much. We can apply it to just one layer 55 | # by intercepting the kwargs and adding a forward hook (credit: jrm) 56 | self.last_aligned_attn = None 57 | self._add_attention_spy(tfmr, alignment_layer_idx) 58 | 59 | def _add_attention_spy(self, tfmr, alignment_layer_idx): 60 | """ 61 | Adds a forward hook to a specific attention layer to collect outputs. 62 | Using `output_attentions=True` is incompatible with optimized attention kernels, so 63 | using it for all layers slows things down too much. 64 | (credit: jrm) 65 | """ 66 | 67 | def attention_forward_hook(module, input, output): 68 | """ 69 | See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`. 70 | NOTE: 71 | - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`. 72 | - `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th. 
73 | """ 74 | step_attention = output[1].cpu() # (B, 16, N, N) 75 | self.last_aligned_attn = step_attention[0].mean(0) # (N, N) 76 | 77 | target_layer = tfmr.layers[alignment_layer_idx].self_attn 78 | hook_handle = target_layer.register_forward_hook(attention_forward_hook) 79 | 80 | # Backup original forward 81 | original_forward = target_layer.forward 82 | def patched_forward(self, *args, **kwargs): 83 | kwargs['output_attentions'] = True 84 | return original_forward(*args, **kwargs) 85 | 86 | # TODO: how to unpatch it? 87 | target_layer.forward = MethodType(patched_forward, target_layer) 88 | 89 | def step(self, logits): 90 | """ 91 | Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS. 92 | """ 93 | # extract approximate alignment matrix chunk (1 frame at a time after the first chunk) 94 | aligned_attn = self.last_aligned_attn # (N, N) 95 | i, j = self.text_tokens_slice 96 | if self.curr_frame_pos == 0: 97 | # first chunk has conditioning info, text tokens, and BOS token 98 | A_chunk = aligned_attn[j:, i:j].clone().cpu() # (T, S) 99 | else: 100 | # subsequent chunks have 1 frame due to KV-caching 101 | A_chunk = aligned_attn[:, i:j].clone().cpu() # (1, S) 102 | 103 | # TODO: monotonic masking; could have issue b/c spaces are often skipped. 104 | A_chunk[:, self.curr_frame_pos + 1:] = 0 105 | 106 | 107 | self.alignment = torch.cat((self.alignment, A_chunk), dim=0) 108 | 109 | A = self.alignment 110 | T, S = A.shape 111 | 112 | # update position 113 | cur_text_posn = A_chunk[-1].argmax() 114 | discontinuity = not(-4 < cur_text_posn - self.text_position < 7) # NOTE: very lenient! 115 | if not discontinuity: 116 | self.text_position = cur_text_posn 117 | 118 | # Hallucinations at the start of speech show up as activations at the bottom of the attention maps! 119 | # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens, 120 | # and there are some strong activations in the first few tokens. 121 | false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5) 122 | self.started = not false_start 123 | if self.started and self.started_at is None: 124 | self.started_at = T 125 | 126 | # Is generation likely complete? 127 | self.complete = self.complete or self.text_position >= S - 3 128 | if self.complete and self.completed_at is None: 129 | self.completed_at = T 130 | 131 | # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens. 132 | # NOTE: due to the false-start behaviour, we need to make sure we skip activations for the first few tokens. 133 | last_text_token_duration = A[15:, -3:].sum() 134 | 135 | # Activations for the final token that last too long are likely hallucinations. 136 | long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10) # 400ms 137 | 138 | # If there are activations in previous tokens after generation has completed, assume this is a repetition error. 139 | repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5) 140 | 141 | # If a bad ending is detected, force emit EOS by modifying logits 142 | # NOTE: this means logits may be inconsistent with latents! 
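        # (Illustration of the forcing below: every logit is set to -2**15 and the
        #  EOS logit to +2**15, so any softmax/argmax over the modified logits
        #  selects EOS and generation terminates on this frame.)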
143 | if long_tail or repetition: 144 | logger.warn(f"forcing EOS token, {long_tail=}, {repetition=}") 145 | # (±2**15 is safe for all dtypes >= 16bit) 146 | logits = -(2**15) * torch.ones_like(logits) 147 | logits[..., self.eos_idx] = 2**15 148 | 149 | # Suppress EoS to prevent early termination 150 | if cur_text_posn < S - 3: # FIXME: arbitrary 151 | logits[..., self.eos_idx] = -2**15 152 | 153 | self.curr_frame_pos += 1 154 | return logits 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 |
*(banner image: Chatterbox Nodes in ComfyUI)*

# ComfyUI Chatterbox

High-quality Text-to-Speech (TTS) and Voice Conversion (VC) nodes for ComfyUI, powered by Resemble AI's Chatterbox model.

Report Bug · Request Feature

[![Stargazers][stars-shield]][stars-url]
[![Issues][issues-shield]][issues-url]
[![Contributors][contributors-shield]][contributors-url]
[![Forks][forks-shield]][forks-url]
Table of Contents

  1. About The Project
  2. Getting Started
  3. Usage
  4. Roadmap
  5. Contributing
  6. Acknowledgments
57 | 58 | 59 | ## About The Project 60 | 61 | ComfyUI custom nodes for the powerful [Resemble AI Chatterbox](https://github.com/resemble-ai/chatterbox) library. It enables seamless in-workflow Text-to-Speech and Voice Conversion, complete with deep integration into ComfyUI's model management system for efficient VRAM usage. 62 | 63 | 64 |

(back to top)

65 | 66 | ### Major Update Notice 67 | 68 | > [!NOTE] 69 | > * **1.2.0**: This version has been deeply refactored for better performance, stability, and alignment with the ComfyUI codebase. All parameters have been unlocked. 70 | 71 | 72 |

(back to top)

73 | 74 | ### Features 75 | 76 | * **Long generation:** No longer limited to 40 seconds. 77 | * **Chatterbox TTS Node:** Synthesize speech from text with optional voice cloning from an audio prompt. 78 | * **Chatterbox Voice Conversion Node:** Convert the voice in a source audio file to a target voice. 79 | * **Automatic Model Downloading:** Models are automatically downloaded from Hugging Face on first use. 80 | * **Efficient VRAM Management:** Full integration with ComfyUI's model patcher system to load models to GPU only when needed and offload them afterward. 81 | * **Detailed Generation Control:** Fine-tune your audio output with parameters for speed, expressiveness, creativity, and quality. 82 | * **Accurate Progress Bars:** Both console and UI progress bars reflect the true step-by-step generation process. 83 | 84 |

(back to top)

85 | 86 | 87 | ## Getting Started 88 | 89 | ### Installation 90 | 91 | 1. **Install via ComfyUI Manager (Recommended):** 92 | * Search for `ComfyUI-Chatterbox` in the ComfyUI Manager and install it. 93 | 94 | 2. **Manual Installation:** 95 | * Clone this repository into your `ComfyUI/custom_nodes/` directory: 96 | ```bash 97 | git clone https://github.com/wildminder/ComfyUI-Chatterbox.git ComfyUI/custom_nodes/ComfyUI-Chatterbox 98 | ``` 99 | 100 | 3. **Install Dependencies:** 101 | * Navigate to the new directory and install the required packages: 102 | ```bash 103 | cd ComfyUI/custom_nodes/ComfyUI-Chatterbox 104 | pip install -r requirements.txt 105 | ``` 106 | 107 | 4. **Model Management:** 108 | > [!IMPORTANT] 109 | > **For users of previous versions:** This update changes the model directory. You **must manually delete** your old model folder to avoid conflicts: 110 | > 111 | > **Delete this folder:** `ComfyUI/models/chatterbox_tts/` 112 | > 113 | > The new version will automatically download models to the correct ComfyUI-standard directory: `ComfyUI/models/tts/chatterbox/`. 114 | 115 | 5. **Restart ComfyUI.** 116 | 117 |
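Optionally, you can pre-download the model files instead of relying on the automatic download on the first run. The snippet below is only a sketch: it assumes the `ResembleAI/chatterbox` Hugging Face repository and the `resembleai_default_voice` folder name used by the example workflow.

```python
# Pre-fetch the Chatterbox checkpoints into the folder the nodes expect.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="ResembleAI/chatterbox",
    local_dir="ComfyUI/models/tts/chatterbox/resembleai_default_voice",
)
```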

(back to top)

118 | 119 | 120 | ## Usage 121 | 122 | After installation, you will find two new nodes: 123 | * **Chatterbox TTS 📢** under the `audio/generation` category. 124 | * **Chatterbox Voice Conversion 🗣️** under the `audio/generation` category. 125 | 126 | Load an example workflow from the `workflow-examples/` directory in this repository to get started. 127 | 128 | ### Node Parameters Explained 129 | 130 | #### Chatterbox TTS 📢 Parameters 131 | 132 | * **`max_new_tokens`**: Maximum number of audio tokens to generate. Acts as a failsafe against run-on generations. 25 tokens is approximately 1 second of audio. The model's hard limit is 4096 tokens (≈ 163 seconds). 133 | * **`flow_cfg_scale`**: CFG scale for the mel spectrogram decoder. Higher values increase adherence to the text content and speaker timbre but may reduce naturalness. 134 | * **`exaggeration`**: Controls the expressiveness and emotional intensity. Higher values lead to more exaggerated prosody. 135 | * **`temperature`**: Controls the randomness of the *token sampling* process. Higher values produce more diverse and creative speech, while lower values are more deterministic. 136 | * **`cfg_weight`**: Classifier-Free Guidance (CFG) weight for the *token sampling* process. 137 | * **`repetition_penalty`**: Penalizes repeated tokens to discourage monotonous or repetitive speech. `1.0` means no penalty. 138 | * **`min_p` / `top_p`**: Parameters for nucleus sampling, controlling the pool of tokens the model can choose from at each step. 139 | 140 | 141 | #### Chatterbox Voice Conversion 🗣️ Parameters 142 | 143 | * **`n_timesteps`**: Number of diffusion steps for the flow matching process. Higher values can improve quality but will take longer to generate. 144 | * **`temperature`**: Controls the randomness of the initial noise for the diffusion process. `1.0` is standard. Lower values are more deterministic; higher values are more random. 145 | * **`flow_cfg_scale`**: CFG scale for the mel spectrogram decoder. Higher values increase adherence to the target voice's timbre but may reduce the naturalness of the speech prosody. 146 | * **`target_voice_audio`**: The audio file containing the target voice timbre. If not provided, the default voice from the selected model pack will be used. 147 | 148 | 149 |
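For reference, the Voice Conversion node's parameters map directly onto the bundled `ChatterboxVC` class in `src/chatterbox/vc.py`. The snippet below is an illustrative sketch of that underlying call, not something the nodes require you to run; the paths are placeholders and it assumes the packaged `chatterbox` module is importable.

```python
import torchaudio
from chatterbox.vc import ChatterboxVC

vc = ChatterboxVC.from_local("ComfyUI/models/tts/chatterbox/resembleai_default_voice", "cuda")
converted = vc.generate(
    "source.wav",                    # source_audio
    target_voice_path="target.wav",  # target_voice_audio (optional)
    n_timesteps=10,                  # diffusion steps
    temperature=1.0,                 # randomness of the initial noise
    flow_cfg_scale=0.7,              # CFG scale for the mel decoder
)
torchaudio.save("converted.wav", converted, vc.sr)  # vc.sr == 24000 Hz
```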

(back to top)

150 | 151 | 152 | 153 | ## Acknowledgments 154 | 155 | * This node would not be possible without the incredible [Chatterbox](https://github.com/resemble-ai/chatterbox) library by **Resemble AI**. 156 | * README template adapted from the [Best-README-Template](https://github.com/othneildrew/Best-README-Template). 157 | 158 |

(back to top)

159 | 160 | 161 | [contributors-shield]: https://img.shields.io/github/contributors/wildminder/ComfyUI-Chatterbox.svg?style=for-the-badge 162 | [contributors-url]: https://github.com/wildminder/ComfyUI-Chatterbox/graphs/contributors 163 | [forks-shield]: https://img.shields.io/github/forks/wildminder/ComfyUI-Chatterbox.svg?style=for-the-badge 164 | [forks-url]: https://github.com/wildminder/ComfyUI-Chatterbox/network/members 165 | [stars-shield]: https://img.shields.io/github/stars/wildminder/ComfyUI-Chatterbox.svg?style=for-the-badge 166 | [stars-url]: https://github.com/wildminder/ComfyUI-Chatterbox/stargazers 167 | [issues-shield]: https://img.shields.io/github/issues/wildminder/ComfyUI-Chatterbox.svg?style=for-the-badge 168 | [issues-url]: https://github.com/wildminder/ComfyUI-Chatterbox/issues 169 | -------------------------------------------------------------------------------- /workflow-examples/ChatterboxTTS-workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "dba002a2-22d7-4050-a716-4d7196a0d14b", 3 | "revision": 0, 4 | "last_node_id": 28, 5 | "last_link_id": 48, 6 | "nodes": [ 7 | { 8 | "id": 6, 9 | "type": "LoadAudio", 10 | "pos": [ 11 | 293.8790588378906, 12 | 219.2306365966797 13 | ], 14 | "size": [ 15 | 214.080078125, 16 | 136 17 | ], 18 | "flags": {}, 19 | "order": 0, 20 | "mode": 0, 21 | "inputs": [], 22 | "outputs": [ 23 | { 24 | "name": "AUDIO", 25 | "type": "AUDIO", 26 | "links": [ 27 | 44 28 | ] 29 | } 30 | ], 31 | "title": "Source Voice", 32 | "properties": { 33 | "cnr_id": "comfy-core", 34 | "ver": "0.3.38", 35 | "Node name for S&R": "LoadAudio" 36 | }, 37 | "widgets_values": [ 38 | "male_petergriffin.wav", 39 | null, 40 | null 41 | ], 42 | "color": "#223", 43 | "bgcolor": "#335" 44 | }, 45 | { 46 | "id": 4, 47 | "type": "LoadAudio", 48 | "pos": [ 49 | 293.8790588378906, 50 | 503.71502685546875 51 | ], 52 | "size": [ 53 | 214.080078125, 54 | 136 55 | ], 56 | "flags": {}, 57 | "order": 1, 58 | "mode": 0, 59 | "inputs": [], 60 | "outputs": [ 61 | { 62 | "name": "AUDIO", 63 | "type": "AUDIO", 64 | "links": [ 65 | 45, 66 | 47 67 | ] 68 | } 69 | ], 70 | "title": "Target Voice", 71 | "properties": { 72 | "cnr_id": "comfy-core", 73 | "ver": "0.3.38", 74 | "Node name for S&R": "LoadAudio" 75 | }, 76 | "widgets_values": [ 77 | "male_rickmorty.mp3", 78 | null, 79 | null 80 | ], 81 | "color": "#223", 82 | "bgcolor": "#335" 83 | }, 84 | { 85 | "id": 9, 86 | "type": "Note", 87 | "pos": [ 88 | 295.35626220703125, 89 | 703.06201171875 90 | ], 91 | "size": [ 92 | 216.72750854492188, 93 | 209.47329711914062 94 | ], 95 | "flags": {}, 96 | "order": 2, 97 | "mode": 0, 98 | "inputs": [], 99 | "outputs": [], 100 | "properties": {}, 101 | "widgets_values": [ 102 | "- Download models from:\nhttps://huggingface.co/ResembleAI/chatterbox\n\n- and place them in\nComfyUI\\models\\tts\\chatterbox\\resembleai_default_voice\n\n" 103 | ], 104 | "color": "#432", 105 | "bgcolor": "#653" 106 | }, 107 | { 108 | "id": 27, 109 | "type": "ChatterboxVC", 110 | "pos": [ 111 | 569.2343139648438, 112 | 218.49063110351562 113 | ], 114 | "size": [ 115 | 388.5516662597656, 116 | 174 117 | ], 118 | "flags": {}, 119 | "order": 3, 120 | "mode": 0, 121 | "inputs": [ 122 | { 123 | "name": "source_audio", 124 | "type": "AUDIO", 125 | "link": 44 126 | }, 127 | { 128 | "name": "target_voice_audio", 129 | "shape": 7, 130 | "type": "AUDIO", 131 | "link": 45 132 | } 133 | ], 134 | "outputs": [ 135 | { 136 | "name": "converted_audio", 137 | "type": "AUDIO", 138 | 
"links": [ 139 | 46 140 | ] 141 | } 142 | ], 143 | "properties": { 144 | "cnr_id": "ComfyUI-ChatterboxTTS", 145 | "ver": "8fc60925cbab607e0f08e0d418c4504e0fcd4842", 146 | "Node name for S&R": "ChatterboxVC" 147 | }, 148 | "widgets_values": [ 149 | "resembleai_default_voice", 150 | 10, 151 | 1, 152 | 0.7, 153 | false 154 | ], 155 | "color": "#232", 156 | "bgcolor": "#353" 157 | }, 158 | { 159 | "id": 28, 160 | "type": "ChatterboxTTS", 161 | "pos": [ 162 | 569.2343139648438, 163 | 502.9750061035156 164 | ], 165 | "size": [ 166 | 400, 167 | 409.4891052246094 168 | ], 169 | "flags": {}, 170 | "order": 4, 171 | "mode": 0, 172 | "inputs": [ 173 | { 174 | "name": "audio_prompt", 175 | "shape": 7, 176 | "type": "AUDIO", 177 | "link": 47 178 | } 179 | ], 180 | "outputs": [ 181 | { 182 | "name": "audio", 183 | "type": "AUDIO", 184 | "links": [ 185 | 48 186 | ] 187 | } 188 | ], 189 | "properties": { 190 | "cnr_id": "ComfyUI-ChatterboxTTS", 191 | "ver": "8fc60925cbab607e0f08e0d418c4504e0fcd4842", 192 | "Node name for S&R": "ChatterboxTTS" 193 | }, 194 | "widgets_values": [ 195 | "resembleai_default_voice", 196 | "Hello, this is a test of Chatterbox TTS in ComfyUI.", 197 | 1000, 198 | 0.7, 199 | 0.5, 200 | 0.8, 201 | 0.5, 202 | 1.2, 203 | 0.05, 204 | 1, 205 | 839706344665720, 206 | "randomize", 207 | false 208 | ], 209 | "color": "#232", 210 | "bgcolor": "#353" 211 | }, 212 | { 213 | "id": 7, 214 | "type": "SaveAudio", 215 | "pos": [ 216 | 1010.6351318359375, 217 | 503.71502685546875 218 | ], 219 | "size": [ 220 | 210, 221 | 112 222 | ], 223 | "flags": {}, 224 | "order": 6, 225 | "mode": 0, 226 | "inputs": [ 227 | { 228 | "name": "audio", 229 | "type": "AUDIO", 230 | "link": 48 231 | } 232 | ], 233 | "outputs": [], 234 | "title": "Output Text2Voice", 235 | "properties": { 236 | "cnr_id": "comfy-core", 237 | "ver": "0.3.38", 238 | "Node name for S&R": "SaveAudio" 239 | }, 240 | "widgets_values": [ 241 | "audio/ComfyUI" 242 | ] 243 | }, 244 | { 245 | "id": 3, 246 | "type": "SaveAudio", 247 | "pos": [ 248 | 1010.6351318359375, 249 | 219.23065185546875 250 | ], 251 | "size": [ 252 | 210, 253 | 112 254 | ], 255 | "flags": {}, 256 | "order": 5, 257 | "mode": 0, 258 | "inputs": [ 259 | { 260 | "name": "audio", 261 | "type": "AUDIO", 262 | "link": 46 263 | } 264 | ], 265 | "outputs": [], 266 | "title": "Output Conversion", 267 | "properties": { 268 | "cnr_id": "comfy-core", 269 | "ver": "0.3.38", 270 | "Node name for S&R": "SaveAudio" 271 | }, 272 | "widgets_values": [ 273 | "audio/ComfyUI" 274 | ] 275 | } 276 | ], 277 | "links": [ 278 | [ 279 | 44, 280 | 6, 281 | 0, 282 | 27, 283 | 0, 284 | "AUDIO" 285 | ], 286 | [ 287 | 45, 288 | 4, 289 | 0, 290 | 27, 291 | 1, 292 | "AUDIO" 293 | ], 294 | [ 295 | 46, 296 | 27, 297 | 0, 298 | 3, 299 | 0, 300 | "AUDIO" 301 | ], 302 | [ 303 | 47, 304 | 4, 305 | 0, 306 | 28, 307 | 0, 308 | "AUDIO" 309 | ], 310 | [ 311 | 48, 312 | 28, 313 | 0, 314 | 7, 315 | 0, 316 | "AUDIO" 317 | ] 318 | ], 319 | "groups": [ 320 | { 321 | "id": 1, 322 | "title": "Voice conversion", 323 | "bounding": [ 324 | 283.8790588378906, 325 | 144.890625, 326 | 947.7410888671875, 327 | 265.6488037109375 328 | ], 329 | "color": "#3f789e", 330 | "font_size": 24, 331 | "flags": {} 332 | }, 333 | { 334 | "id": 2, 335 | "title": "Text 2 Voice", 336 | "bounding": [ 337 | 283.8790588378906, 338 | 429.375, 339 | 946.756103515625, 340 | 503.6624450683594 341 | ], 342 | "color": "#3f789e", 343 | "font_size": 24, 344 | "flags": {} 345 | } 346 | ], 347 | "config": {}, 348 | "extra": { 349 | "ds": { 350 | "scale": 
1.0152559799477234, 351 | "offset": [ 352 | -194.86403210547138, 353 | -118.49063110351562 354 | ] 355 | }, 356 | "frontendVersion": "1.23.4", 357 | "VHS_latentpreview": false, 358 | "VHS_latentpreviewrate": 0, 359 | "VHS_MetadataImage": true, 360 | "VHS_KeepIntermediate": true 361 | }, 362 | "version": 0.4 363 | } -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/utils/mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 2024 Alibaba Inc (authors: Xiang Lyu) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | 19 | ''' 20 | def subsequent_mask( 21 | size: int, 22 | device: torch.device = torch.device("cpu"), 23 | ) -> torch.Tensor: 24 | """Create mask for subsequent steps (size, size). 25 | 26 | This mask is used only in decoder which works in an auto-regressive mode. 27 | This means the current step could only do attention with its left steps. 28 | 29 | In encoder, fully attention is used when streaming is not necessary and 30 | the sequence is not long. In this case, no attention mask is needed. 31 | 32 | When streaming is need, chunk-based attention is used in encoder. See 33 | subsequent_chunk_mask for the chunk-based attention mask. 
34 | 35 | Args: 36 | size (int): size of mask 37 | str device (str): "cpu" or "cuda" or torch.Tensor.device 38 | dtype (torch.device): result dtype 39 | 40 | Returns: 41 | torch.Tensor: mask 42 | 43 | Examples: 44 | >>> subsequent_mask(3) 45 | [[1, 0, 0], 46 | [1, 1, 0], 47 | [1, 1, 1]] 48 | """ 49 | ret = torch.ones(size, size, device=device, dtype=torch.bool) 50 | return torch.tril(ret) 51 | ''' 52 | 53 | 54 | def subsequent_chunk_mask( 55 | size: int, 56 | chunk_size: int, 57 | num_left_chunks: int = -1, 58 | device: torch.device = torch.device("cpu"), 59 | ) -> torch.Tensor: 60 | """Create mask for subsequent steps (size, size) with chunk size, 61 | this is for streaming encoder 62 | 63 | Args: 64 | size (int): size of mask 65 | chunk_size (int): size of chunk 66 | num_left_chunks (int): number of left chunks 67 | <0: use full chunk 68 | >=0: use num_left_chunks 69 | device (torch.device): "cpu" or "cuda" or torch.Tensor.device 70 | 71 | Returns: 72 | torch.Tensor: mask 73 | 74 | Examples: 75 | >>> subsequent_chunk_mask(4, 2) 76 | [[1, 1, 0, 0], 77 | [1, 1, 0, 0], 78 | [1, 1, 1, 1], 79 | [1, 1, 1, 1]] 80 | """ 81 | # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks 82 | # actually this is not needed after we have inference cache implemented, will remove it later 83 | pos_idx = torch.arange(size, device=device) 84 | block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size 85 | ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1) 86 | return ret 87 | 88 | 89 | def add_optional_chunk_mask(xs: torch.Tensor, 90 | masks: torch.Tensor, 91 | use_dynamic_chunk: bool, 92 | use_dynamic_left_chunk: bool, 93 | decoding_chunk_size: int, 94 | static_chunk_size: int, 95 | num_decoding_left_chunks: int, 96 | enable_full_context: bool = True): 97 | """ Apply optional mask for encoder. 98 | 99 | Args: 100 | xs (torch.Tensor): padded input, (B, L, D), L for max length 101 | mask (torch.Tensor): mask for xs, (B, 1, L) 102 | use_dynamic_chunk (bool): whether to use dynamic chunk or not 103 | use_dynamic_left_chunk (bool): whether to use dynamic left chunk for 104 | training. 105 | decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's 106 | 0: default for training, use random dynamic chunk. 107 | <0: for decoding, use full chunk. 108 | >0: for decoding, use fixed chunk size as set. 109 | static_chunk_size (int): chunk size for static chunk training/decoding 110 | if it's greater than 0, if use_dynamic_chunk is true, 111 | this parameter will be ignored 112 | num_decoding_left_chunks: number of left chunks, this is for decoding, 113 | the chunk size is decoding_chunk_size. 114 | >=0: use num_decoding_left_chunks 115 | <0: use all left chunks 116 | enable_full_context (bool): 117 | True: chunk size is either [1, 25] or full context(max_len) 118 | False: chunk size ~ U[1, 25] 119 | 120 | Returns: 121 | torch.Tensor: chunk mask of the input xs. 122 | """ 123 | # Whether to use chunk mask or not 124 | if use_dynamic_chunk: 125 | max_len = xs.size(1) 126 | if decoding_chunk_size < 0: 127 | chunk_size = max_len 128 | num_left_chunks = -1 129 | elif decoding_chunk_size > 0: 130 | chunk_size = decoding_chunk_size 131 | num_left_chunks = num_decoding_left_chunks 132 | else: 133 | # chunk size is either [1, 25] or full context(max_len). 134 | # Since we use 4 times subsampling and allow up to 1s(100 frames) 135 | # delay, the maximum frame is 100 / 4 = 25. 
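            # (Note on the sampling below: a chunk size is drawn uniformly from
            #  [1, max_len); draws above max_len // 2 use the full context when
            #  enable_full_context is set, otherwise the draw is folded into [1, 25].)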
136 | chunk_size = torch.randint(1, max_len, (1, )).item() 137 | num_left_chunks = -1 138 | if chunk_size > max_len // 2 and enable_full_context: 139 | chunk_size = max_len 140 | else: 141 | chunk_size = chunk_size % 25 + 1 142 | if use_dynamic_left_chunk: 143 | max_left_chunks = (max_len - 1) // chunk_size 144 | num_left_chunks = torch.randint(0, max_left_chunks, 145 | (1, )).item() 146 | chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, 147 | num_left_chunks, 148 | xs.device) # (L, L) 149 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) 150 | chunk_masks = masks & chunk_masks # (B, L, L) 151 | elif static_chunk_size > 0: 152 | num_left_chunks = num_decoding_left_chunks 153 | chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, 154 | num_left_chunks, 155 | xs.device) # (L, L) 156 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) 157 | chunk_masks = masks & chunk_masks # (B, L, L) 158 | else: 159 | chunk_masks = masks 160 | assert chunk_masks.dtype == torch.bool 161 | if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0: 162 | logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!') 163 | chunk_masks[chunk_masks.sum(dim=-1)==0] = True 164 | return chunk_masks 165 | 166 | 167 | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: 168 | """Make mask tensor containing indices of padded part. 169 | 170 | See description of make_non_pad_mask. 171 | 172 | Args: 173 | lengths (torch.Tensor): Batch of lengths (B,). 174 | Returns: 175 | torch.Tensor: Mask tensor containing indices of padded part. 176 | 177 | Examples: 178 | >>> lengths = [5, 3, 2] 179 | >>> make_pad_mask(lengths) 180 | masks = [[0, 0, 0, 0 ,0], 181 | [0, 0, 0, 1, 1], 182 | [0, 0, 1, 1, 1]] 183 | """ 184 | batch_size = lengths.size(0) 185 | max_len = max_len if max_len > 0 else lengths.max().item() 186 | seq_range = torch.arange(0, 187 | max_len, 188 | dtype=torch.int64, 189 | device=lengths.device) 190 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) 191 | seq_length_expand = lengths.unsqueeze(-1) 192 | mask = seq_range_expand >= seq_length_expand 193 | return mask 194 | -------------------------------------------------------------------------------- /src/chatterbox/models/t3/modules/perceiver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Resemble AI 2 | # Author: Manmay Nakhashi 3 | # MIT License 4 | import math 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from einops import rearrange 10 | 11 | try: 12 | from torch.nn.attention import SDPBackend, sdpa_kernel 13 | # Flag to indicate that the modern SDPA API is available. 14 | SUPPORTS_SDPA_KERNEL = True 15 | except (ImportError, AttributeError): 16 | # Fallback for older PyTorch versions. 
17 | SUPPORTS_SDPA_KERNEL = False 18 | 19 | 20 | class RelativePositionBias(nn.Module): 21 | def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8): 22 | super().__init__() 23 | self.scale = scale 24 | self.causal = causal 25 | self.num_buckets = num_buckets 26 | self.max_distance = max_distance 27 | self.relative_attention_bias = nn.Embedding(num_buckets, heads) 28 | 29 | @staticmethod 30 | def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): 31 | ret = 0 32 | n = -relative_position 33 | if not causal: 34 | num_buckets //= 2 35 | ret += (n < 0).long() * num_buckets 36 | n = torch.abs(n) 37 | else: 38 | n = torch.max(n, torch.zeros_like(n)) 39 | 40 | max_exact = num_buckets // 2 41 | is_small = n < max_exact 42 | 43 | val_if_large = max_exact + ( 44 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 45 | ).long() 46 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 47 | 48 | ret += torch.where(is_small, n, val_if_large) 49 | return ret 50 | 51 | def forward(self, qk_dots): 52 | i, j, device = *qk_dots.shape[-2:], qk_dots.device 53 | q_pos = torch.arange(i, dtype=torch.long, device=device) 54 | k_pos = torch.arange(j, dtype=torch.long, device=device) 55 | rel_pos = k_pos[None, :] - q_pos[:, None] 56 | rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, 57 | max_distance=self.max_distance) 58 | values = self.relative_attention_bias(rp_bucket) 59 | bias = rearrange(values, 'i j h -> () h i j') 60 | return qk_dots + (bias * self.scale) 61 | 62 | 63 | class AttentionQKV(nn.Module): 64 | def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False): 65 | super().__init__() 66 | self.n_heads = n_heads 67 | self.head_dim = head_dim 68 | self.scale = scale if scale is not None else head_dim ** -0.5 69 | self.flash = flash 70 | self.dropout_rate = dropout_rate 71 | self.dropout = nn.Dropout(dropout_rate) 72 | self.flash_config = self.setup_flash_config() if flash else None 73 | 74 | def setup_flash_config(self): 75 | # Setup flash attention configuration 76 | flash_config = { 77 | 'enable_flash': True, 78 | 'enable_math': True, 79 | 'enable_mem_efficient': True 80 | } 81 | return flash_config 82 | 83 | def forward(self, q, k, v, mask=None): 84 | q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]] 85 | if self.flash: 86 | out = self.flash_attention(q, k, v, mask=mask) 87 | else: 88 | out = self.scaled_dot_product_attention(q, k, v, mask=mask) 89 | 90 | return self.combine_heads(out) 91 | 92 | def scaled_dot_product_attention(self, q, k, v, mask=None): 93 | sim = torch.einsum("bhlt,bhls->bhts", q, k) * self.scale 94 | if mask is not None: 95 | sim = sim.masked_fill(mask == 0, float('-inf')) 96 | attn = torch.softmax(sim, dim=-1) 97 | attn = self.dropout(attn) 98 | return torch.einsum("bhts,bhls->bhlt", attn, v) 99 | 100 | def flash_attention(self, q, k, v, mask=None): 101 | if SUPPORTS_SDPA_KERNEL: 102 | # Modern PyTorch (>= 2.2) API for enabling specific backends. 103 | # This replaces the deprecated `torch.backends.cuda.sdp_kernel`. 
104 | backends = [] 105 | if self.flash_config.get('enable_flash', True): 106 | backends.append(SDPBackend.FLASH_ATTENTION) 107 | if self.flash_config.get('enable_mem_efficient', True): 108 | backends.append(SDPBackend.EFFICIENT_ATTENTION) 109 | if self.flash_config.get('enable_math', True): 110 | backends.append(SDPBackend.MATH) 111 | 112 | with sdpa_kernel(backends=backends): 113 | out = F.scaled_dot_product_attention( 114 | q, k, v, 115 | attn_mask=mask, 116 | dropout_p=self.dropout_rate if self.training else 0. 117 | ) 118 | else: 119 | # Fallback for older PyTorch versions. 120 | config = self.flash_config if self.flash_config else {} 121 | with torch.backends.cuda.sdp_kernel(**config): 122 | out = F.scaled_dot_product_attention( 123 | q, k, v, 124 | attn_mask=mask, 125 | dropout_p=self.dropout_rate if self.training else 0. 126 | ) 127 | return out 128 | 129 | def split_heads(self, x): 130 | bs, length, _ = x.shape 131 | x = x.view(bs, length, self.n_heads, self.head_dim) 132 | return x.permute(0, 2, 1, 3) 133 | 134 | def combine_heads(self, x): 135 | bs, _, length, _ = x.shape 136 | x = x.permute(0, 2, 1, 3).contiguous() 137 | return x.view(bs, length, -1) 138 | 139 | 140 | class AttentionBlock2(nn.Module): 141 | """ 142 | An attention block that allows spatial positions to attend to each other, 143 | using AttentionQKV and separate linear transformations for Q, K, and V. 144 | """ 145 | 146 | def __init__( 147 | self, 148 | channels, 149 | num_heads=1, 150 | num_head_channels=-1, 151 | relative_pos_embeddings=False, 152 | flash_attention=True, 153 | dropout_rate=0.2, 154 | scale=None 155 | ): 156 | super().__init__() 157 | self.channels = channels 158 | 159 | if num_head_channels == -1: 160 | self.num_heads = num_heads 161 | else: 162 | assert ( 163 | channels % num_head_channels == 0 164 | ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}" 165 | self.num_heads = channels // num_head_channels 166 | 167 | self.norm = nn.LayerNorm(channels) 168 | 169 | # Separate linear layers for Q, K, and V 170 | self.to_q = nn.Linear(channels, channels) 171 | self.to_k = nn.Linear(channels, channels) 172 | self.to_v = nn.Linear(channels, channels) 173 | 174 | self.attention = AttentionQKV(self.num_heads, channels // self.num_heads, dropout_rate=dropout_rate, flash=flash_attention, scale=scale) 175 | 176 | self.proj_out = nn.Linear(channels, channels) 177 | 178 | if relative_pos_embeddings: 179 | self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64) 180 | else: 181 | self.relative_pos_embeddings = None 182 | 183 | def forward(self, x1, x2, mask=None): 184 | b1, c1, *spatial1 = x1.shape 185 | b2, c2, *spatial2 = x2.shape 186 | 187 | x1_norm = self.norm(x1) 188 | x2_norm = self.norm(x2) 189 | 190 | q = self.to_q(x1_norm) 191 | k = self.to_k(x2_norm) 192 | v = self.to_v(x2_norm) 193 | 194 | h = self.attention(q, k, v, mask=mask) 195 | h = self.proj_out(h) 196 | 197 | return (x1 + h).reshape(b1, c1, *spatial1) 198 | 199 | 200 | class Perceiver(nn.Module): 201 | """Inspired by https://arxiv.org/abs/2103.03206""" 202 | def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4): 203 | """ 204 | Initialize the perceiver module. 
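
        (The module compresses a variable-length embedding sequence into
        `pre_attention_query_token` learned query vectors: the queries first
        cross-attend to the input and then self-attend to each other; see
        `forward` below.)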
205 | 206 | :param pre_attention_query_token: Number of query tokens for pre-attention 207 | :param pre_attention_query_size: Size of each query token 208 | :param embedding_dim: Dimension of the embedding space 209 | :param num_attn_heads: Number of attention heads 210 | """ 211 | super().__init__() 212 | 213 | # Initialize the pre-attention query parameter 214 | self.pre_attention_query = torch.nn.Parameter( 215 | torch.empty(1, pre_attention_query_token, pre_attention_query_size) 216 | ) 217 | 218 | # Calculate the variance for uniform initialization 219 | query_variance = math.sqrt(3.0) * math.sqrt(2.0 / (pre_attention_query_token + pre_attention_query_token)) 220 | 221 | # Initialize the pre-attention query with uniform distribution 222 | self.pre_attention_query.data.uniform_(-query_variance, query_variance) 223 | 224 | # Initialize the attention block 225 | self.attn = AttentionBlock2(embedding_dim, num_attn_heads) 226 | 227 | def forward(self, h): 228 | """ 229 | Forward pass of the perceiver module. 230 | :param h: Input tensor 231 | :return: Output after applying attention mechanisms 232 | """ 233 | # Expand the pre-attention query to match the batch size of the input 234 | query_ = self.pre_attention_query.expand(h.shape[0], -1, -1) 235 | # Apply the first attention mechanism (cross-attention) 236 | pre_att = self.attn(query_, h) 237 | # Apply the second attention mechanism (self-attention) 238 | attn = self.attn(pre_att, pre_att) 239 | return attn 240 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/transformer/encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) 2 | # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """Encoder self-attention layer definition.""" 17 | 18 | from typing import Optional, Tuple 19 | 20 | import torch 21 | from torch import nn 22 | 23 | 24 | class TransformerEncoderLayer(nn.Module): 25 | """Encoder layer module. 26 | 27 | Args: 28 | size (int): Input dimension. 29 | self_attn (torch.nn.Module): Self-attention module instance. 30 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` 31 | instance can be used as the argument. 32 | feed_forward (torch.nn.Module): Feed-forward module instance. 33 | `PositionwiseFeedForward`, instance can be used as the argument. 34 | dropout_rate (float): Dropout rate. 35 | normalize_before (bool): 36 | True: use layer_norm before each sub-block. 37 | False: to use layer_norm after each sub-block. 
38 | """ 39 | 40 | def __init__( 41 | self, 42 | size: int, 43 | self_attn: torch.nn.Module, 44 | feed_forward: torch.nn.Module, 45 | dropout_rate: float, 46 | normalize_before: bool = True, 47 | ): 48 | """Construct an EncoderLayer object.""" 49 | super().__init__() 50 | self.self_attn = self_attn 51 | self.feed_forward = feed_forward 52 | self.norm1 = nn.LayerNorm(size, eps=1e-12) 53 | self.norm2 = nn.LayerNorm(size, eps=1e-12) 54 | self.dropout = nn.Dropout(dropout_rate) 55 | self.size = size 56 | self.normalize_before = normalize_before 57 | 58 | def forward( 59 | self, 60 | x: torch.Tensor, 61 | mask: torch.Tensor, 62 | pos_emb: torch.Tensor, 63 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 64 | att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 65 | cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 66 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 67 | """Compute encoded features. 68 | 69 | Args: 70 | x (torch.Tensor): (#batch, time, size) 71 | mask (torch.Tensor): Mask tensor for the input (#batch, time,time), 72 | (0, 0, 0) means fake mask. 73 | pos_emb (torch.Tensor): just for interface compatibility 74 | to ConformerEncoderLayer 75 | mask_pad (torch.Tensor): does not used in transformer layer, 76 | just for unified api with conformer. 77 | att_cache (torch.Tensor): Cache tensor of the KEY & VALUE 78 | (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. 79 | cnn_cache (torch.Tensor): Convolution cache in conformer layer 80 | (#batch=1, size, cache_t2), not used here, it's for interface 81 | compatibility to ConformerEncoderLayer. 82 | Returns: 83 | torch.Tensor: Output tensor (#batch, time, size). 84 | torch.Tensor: Mask tensor (#batch, time, time). 85 | torch.Tensor: att_cache tensor, 86 | (#batch=1, head, cache_t1 + time, d_k * 2). 87 | torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2). 88 | 89 | """ 90 | residual = x 91 | if self.normalize_before: 92 | x = self.norm1(x) 93 | x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache) 94 | x = residual + self.dropout(x_att) 95 | if not self.normalize_before: 96 | x = self.norm1(x) 97 | 98 | residual = x 99 | if self.normalize_before: 100 | x = self.norm2(x) 101 | x = residual + self.dropout(self.feed_forward(x)) 102 | if not self.normalize_before: 103 | x = self.norm2(x) 104 | 105 | fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 106 | return x, mask, new_att_cache, fake_cnn_cache 107 | 108 | 109 | class ConformerEncoderLayer(nn.Module): 110 | """Encoder layer module. 111 | Args: 112 | size (int): Input dimension. 113 | self_attn (torch.nn.Module): Self-attention module instance. 114 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` 115 | instance can be used as the argument. 116 | feed_forward (torch.nn.Module): Feed-forward module instance. 117 | `PositionwiseFeedForward` instance can be used as the argument. 118 | feed_forward_macaron (torch.nn.Module): Additional feed-forward module 119 | instance. 120 | `PositionwiseFeedForward` instance can be used as the argument. 121 | conv_module (torch.nn.Module): Convolution module instance. 122 | `ConvlutionModule` instance can be used as the argument. 123 | dropout_rate (float): Dropout rate. 124 | normalize_before (bool): 125 | True: use layer_norm before each sub-block. 126 | False: use layer_norm after each sub-block. 
127 | """ 128 | 129 | def __init__( 130 | self, 131 | size: int, 132 | self_attn: torch.nn.Module, 133 | feed_forward: Optional[nn.Module] = None, 134 | feed_forward_macaron: Optional[nn.Module] = None, 135 | conv_module: Optional[nn.Module] = None, 136 | dropout_rate: float = 0.1, 137 | normalize_before: bool = True, 138 | ): 139 | """Construct an EncoderLayer object.""" 140 | super().__init__() 141 | self.self_attn = self_attn 142 | self.feed_forward = feed_forward 143 | self.feed_forward_macaron = feed_forward_macaron 144 | self.conv_module = conv_module 145 | self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module 146 | self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module 147 | if feed_forward_macaron is not None: 148 | self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12) 149 | self.ff_scale = 0.5 150 | else: 151 | self.ff_scale = 1.0 152 | if self.conv_module is not None: 153 | self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module 154 | self.norm_final = nn.LayerNorm( 155 | size, eps=1e-12) # for the final output of the block 156 | self.dropout = nn.Dropout(dropout_rate) 157 | self.size = size 158 | self.normalize_before = normalize_before 159 | 160 | def forward( 161 | self, 162 | x: torch.Tensor, 163 | mask: torch.Tensor, 164 | pos_emb: torch.Tensor, 165 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 166 | att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 167 | cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 168 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 169 | """Compute encoded features. 170 | 171 | Args: 172 | x (torch.Tensor): (#batch, time, size) 173 | mask (torch.Tensor): Mask tensor for the input (#batch, time,time), 174 | (0, 0, 0) means fake mask. 175 | pos_emb (torch.Tensor): positional encoding, must not be None 176 | for ConformerEncoderLayer. 177 | mask_pad (torch.Tensor): batch padding mask used for conv module. 178 | (#batch, 1,time), (0, 0, 0) means fake mask. 179 | att_cache (torch.Tensor): Cache tensor of the KEY & VALUE 180 | (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. 181 | cnn_cache (torch.Tensor): Convolution cache in conformer layer 182 | (#batch=1, size, cache_t2) 183 | Returns: 184 | torch.Tensor: Output tensor (#batch, time, size). 185 | torch.Tensor: Mask tensor (#batch, time, time). 186 | torch.Tensor: att_cache tensor, 187 | (#batch=1, head, cache_t1 + time, d_k * 2). 188 | torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). 
189 | """ 190 | 191 | # whether to use macaron style 192 | if self.feed_forward_macaron is not None: 193 | residual = x 194 | if self.normalize_before: 195 | x = self.norm_ff_macaron(x) 196 | x = residual + self.ff_scale * self.dropout( 197 | self.feed_forward_macaron(x)) 198 | if not self.normalize_before: 199 | x = self.norm_ff_macaron(x) 200 | 201 | # multi-headed self-attention module 202 | residual = x 203 | if self.normalize_before: 204 | x = self.norm_mha(x) 205 | x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, 206 | att_cache) 207 | x = residual + self.dropout(x_att) 208 | if not self.normalize_before: 209 | x = self.norm_mha(x) 210 | 211 | # convolution module 212 | # Fake new cnn cache here, and then change it in conv_module 213 | new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 214 | if self.conv_module is not None: 215 | residual = x 216 | if self.normalize_before: 217 | x = self.norm_conv(x) 218 | x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) 219 | x = residual + self.dropout(x) 220 | 221 | if not self.normalize_before: 222 | x = self.norm_conv(x) 223 | 224 | # feed forward module 225 | residual = x 226 | if self.normalize_before: 227 | x = self.norm_ff(x) 228 | 229 | x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) 230 | if not self.normalize_before: 231 | x = self.norm_ff(x) 232 | 233 | if self.conv_module is not None: 234 | x = self.norm_final(x) 235 | 236 | return x, mask, new_att_cache, new_cnn_cache 237 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/flow_matching.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import threading 15 | import torch 16 | import torch.nn.functional as F 17 | from tqdm import tqdm 18 | from .matcha.flow_matching import BASECFM 19 | from .configs import CFM_PARAMS 20 | 21 | 22 | class ConditionalCFM(BASECFM): 23 | def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): 24 | super().__init__( 25 | n_feats=in_channels, 26 | cfm_params=cfm_params, 27 | n_spks=n_spks, 28 | spk_emb_dim=spk_emb_dim, 29 | ) 30 | self.t_scheduler = cfm_params.t_scheduler 31 | self.training_cfg_rate = cfm_params.training_cfg_rate 32 | self.inference_cfg_rate = cfm_params.inference_cfg_rate 33 | in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) 34 | # Just change the architecture of the estimator here 35 | self.estimator = estimator 36 | self.lock = threading.Lock() 37 | 38 | @torch.inference_mode() 39 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2), pbar=None, flow_cfg_scale=None): 40 | """Forward diffusion 41 | 42 | Args: 43 | mu (torch.Tensor): output of encoder 44 | shape: (batch_size, n_feats, mel_timesteps) 45 | mask (torch.Tensor): output_mask 46 | shape: (batch_size, 1, mel_timesteps) 47 | n_timesteps (int): number of diffusion steps 48 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 49 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 50 | shape: (batch_size, spk_emb_dim) 51 | cond: Not used but kept for future purposes 52 | pbar: ComfyUI ProgressBar instance 53 | 54 | Returns: 55 | sample: generated mel-spectrogram 56 | shape: (batch_size, n_feats, mel_timesteps) 57 | """ 58 | 59 | if flow_cfg_scale is not None: 60 | self.inference_cfg_rate = flow_cfg_scale 61 | 62 | z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature 63 | cache_size = flow_cache.shape[2] 64 | # fix prompt and overlap part mu and z 65 | if cache_size != 0: 66 | z[:, :, :cache_size] = flow_cache[:, :, :, 0] 67 | mu[:, :, :cache_size] = flow_cache[:, :, :, 1] 68 | z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2) 69 | mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2) 70 | flow_cache = torch.stack([z_cache, mu_cache], dim=-1) 71 | 72 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) 73 | if self.t_scheduler == 'cosine': 74 | t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) 75 | 76 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, pbar=pbar), flow_cache 77 | 78 | def solve_euler(self, x, t_span, mu, mask, spks, cond, pbar=None): 79 | """ 80 | Fixed euler solver for ODEs. 81 | Args: 82 | x (torch.Tensor): random noise 83 | t_span (torch.Tensor): n_timesteps interpolated 84 | shape: (n_timesteps + 1,) 85 | mu (torch.Tensor): output of encoder 86 | shape: (batch_size, n_feats, mel_timesteps) 87 | mask (torch.Tensor): output_mask 88 | shape: (batch_size, 1, mel_timesteps) 89 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 
90 | shape: (batch_size, spk_emb_dim) 91 | cond: Not used but kept for future purposes 92 | pbar: ComfyUI ProgressBar instance 93 | """ 94 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 95 | t = t.unsqueeze(dim=0) 96 | 97 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 98 | # Or in future might add like a return_all_steps flag 99 | sol = [] 100 | iterator = tqdm(range(1, len(t_span)), desc="Voice Converting", dynamic_ncols=True) 101 | 102 | # Do not use concat, it may cause memory format changed and trt infer with wrong results! 103 | x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) 104 | mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype) 105 | mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) 106 | t_in = torch.zeros([2], device=x.device, dtype=x.dtype) 107 | spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype) 108 | cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) 109 | for step in iterator: 110 | # Classifier-Free Guidance inference introduced in VoiceBox 111 | x_in[:] = x 112 | mask_in[:] = mask 113 | mu_in[0] = mu 114 | t_in[:] = t.unsqueeze(0) 115 | spks_in[0] = spks 116 | cond_in[0] = cond 117 | dphi_dt = self.forward_estimator( 118 | x_in, mask_in, 119 | mu_in, t_in, 120 | spks_in, 121 | cond_in 122 | ) 123 | dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) 124 | dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) 125 | x = x + dt * dphi_dt 126 | t = t + dt 127 | sol.append(x) 128 | if step < len(t_span) - 1: 129 | dt = t_span[step + 1] - t 130 | 131 | if pbar: 132 | pbar.update(1) 133 | 134 | return sol[-1].float() 135 | 136 | def forward_estimator(self, x, mask, mu, t, spks, cond): 137 | if isinstance(self.estimator, torch.nn.Module): 138 | return self.estimator.forward(x, mask, mu, t, spks, cond) 139 | else: 140 | with self.lock: 141 | self.estimator.set_input_shape('x', (2, 80, x.size(2))) 142 | self.estimator.set_input_shape('mask', (2, 1, x.size(2))) 143 | self.estimator.set_input_shape('mu', (2, 80, x.size(2))) 144 | self.estimator.set_input_shape('t', (2,)) 145 | self.estimator.set_input_shape('spks', (2, 80)) 146 | self.estimator.set_input_shape('cond', (2, 80, x.size(2))) 147 | # run trt engine 148 | self.estimator.execute_v2([x.contiguous().data_ptr(), 149 | mask.contiguous().data_ptr(), 150 | mu.contiguous().data_ptr(), 151 | t.contiguous().data_ptr(), 152 | spks.contiguous().data_ptr(), 153 | cond.contiguous().data_ptr(), 154 | x.data_ptr()]) 155 | return x 156 | 157 | def compute_loss(self, x1, mask, mu, spks=None, cond=None): 158 | """Computes diffusion loss 159 | 160 | Args: 161 | x1 (torch.Tensor): Target 162 | shape: (batch_size, n_feats, mel_timesteps) 163 | mask (torch.Tensor): target mask 164 | shape: (batch_size, 1, mel_timesteps) 165 | mu (torch.Tensor): output of encoder 166 | shape: (batch_size, n_feats, mel_timesteps) 167 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
168 | shape: (batch_size, spk_emb_dim) 169 | 170 | Returns: 171 | loss: conditional flow matching loss 172 | y: conditional flow 173 | shape: (batch_size, n_feats, mel_timesteps) 174 | """ 175 | b, _, t = mu.shape 176 | 177 | # random timestep 178 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 179 | if self.t_scheduler == 'cosine': 180 | t = 1 - torch.cos(t * 0.5 * torch.pi) 181 | # sample noise p(x_0) 182 | z = torch.randn_like(x1) 183 | 184 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 185 | u = x1 - (1 - self.sigma_min) * z 186 | 187 | # during training, we randomly drop condition to trade off mode coverage and sample fidelity 188 | if self.training_cfg_rate > 0: 189 | cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate 190 | mu = mu * cfg_mask.view(-1, 1, 1) 191 | spks = spks * cfg_mask.view(-1, 1) 192 | cond = cond * cfg_mask.view(-1, 1, 1) 193 | 194 | pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) 195 | loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) 196 | return loss, y 197 | 198 | 199 | class CausalConditionalCFM(ConditionalCFM): 200 | def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None): 201 | super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator) 202 | self.rand_noise = torch.randn([1, 80, 50 * 300]) 203 | 204 | @torch.inference_mode() 205 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, pbar=None, flow_cfg_scale=None, **kwargs): 206 | """Forward diffusion 207 | 208 | Args: 209 | mu (torch.Tensor): output of encoder 210 | shape: (batch_size, n_feats, mel_timesteps) 211 | mask (torch.Tensor): output_mask 212 | shape: (batch_size, 1, mel_timesteps) 213 | n_timesteps (int): number of diffusion steps 214 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 215 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 
216 | shape: (batch_size, spk_emb_dim) 217 | cond: Not used but kept for future purposes 218 | 219 | Returns: 220 | sample: generated mel-spectrogram 221 | shape: (batch_size, n_feats, mel_timesteps) 222 | """ 223 | 224 | if flow_cfg_scale is not None: 225 | self.inference_cfg_rate = flow_cfg_scale 226 | 227 | z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature 228 | # fix prompt and overlap part mu and z 229 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) 230 | if self.t_scheduler == 'cosine': 231 | t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) 232 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, pbar=pbar), None 233 | -------------------------------------------------------------------------------- /src/chatterbox/tts.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | 5 | import librosa 6 | import torch 7 | import perth 8 | import torch.nn.functional as F 9 | from huggingface_hub import hf_hub_download 10 | from safetensors.torch import load_file 11 | 12 | from .models.t3 import T3 13 | from .models.s3tokenizer import S3_SR, S3_TOKEN_RATE, drop_invalid_tokens 14 | from .models.s3gen import S3GEN_SR, S3Gen 15 | from .models.tokenizers import EnTokenizer 16 | from .models.voice_encoder import VoiceEncoder 17 | from .models.t3.modules.cond_enc import T3Cond 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | REPO_ID = "ResembleAI/chatterbox" 22 | 23 | 24 | def punc_norm(text: str) -> str: 25 | """ 26 | Quick cleanup func for punctuation from LLMs or 27 | containing chars not seen often in the dataset 28 | """ 29 | if len(text) == 0: 30 | return "You need to add some text for me to talk." 31 | 32 | # Capitalise first letter 33 | if text[0].islower(): 34 | text = text[0].upper() + text[1:] 35 | 36 | # Remove multiple space chars 37 | text = " ".join(text.split()) 38 | 39 | # Replace uncommon/llm punc 40 | punc_to_replace = [ 41 | ("...", ", "), 42 | ("…", ", "), 43 | (":", ","), 44 | (" - ", ", "), 45 | (";", ", "), 46 | ("—", "-"), 47 | ("–", "-"), 48 | (" ,", ","), 49 | ("“", "\""), 50 | ("”", "\""), 51 | ("‘", "'"), 52 | ("’", "'"), 53 | ] 54 | for old_char_sequence, new_char in punc_to_replace: 55 | text = text.replace(old_char_sequence, new_char) 56 | 57 | # Add full stop if no ending punc 58 | text = text.rstrip(" ") 59 | sentence_enders = {".", "!", "?", "-", ","} 60 | if not any(text.endswith(p) for p in sentence_enders): 61 | text += "." 
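        # e.g. punc_norm("hello   world") -> "Hello world." (first letter capitalised,
        # runs of spaces collapsed, and the final full stop appended here)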
62 | 63 | return text 64 | 65 | 66 | @dataclass 67 | class Conditionals: 68 | """ 69 | Conditionals for T3 and S3Gen 70 | - T3 conditionals: 71 | - speaker_emb 72 | - clap_emb 73 | - cond_prompt_speech_tokens 74 | - cond_prompt_speech_emb 75 | - emotion_adv 76 | - S3Gen conditionals: 77 | - prompt_token 78 | - prompt_token_len 79 | - prompt_feat 80 | - prompt_feat_len 81 | - embedding 82 | """ 83 | t3: T3Cond 84 | gen: dict 85 | 86 | def to(self, device): 87 | self.t3 = self.t3.to(device=device) 88 | for k, v in self.gen.items(): 89 | if torch.is_tensor(v): 90 | self.gen[k] = v.to(device=device) 91 | return self 92 | 93 | def save(self, fpath: Path): 94 | arg_dict = dict( 95 | t3=self.t3.__dict__, 96 | gen=self.gen 97 | ) 98 | torch.save(arg_dict, fpath) 99 | 100 | @classmethod 101 | def load(cls, fpath, map_location="cpu"): 102 | if isinstance(map_location, str): 103 | map_location = torch.device(map_location) 104 | kwargs = torch.load(fpath, map_location=map_location, weights_only=True) 105 | return cls(T3Cond(**kwargs['t3']), kwargs['gen']) 106 | 107 | # helper function for padding 108 | def _pad_wav_to_40ms_multiple(wav: torch.Tensor, sr: int) -> torch.Tensor: 109 | """ 110 | Pads a waveform to be a multiple of 40ms to prevent rounding errors between 111 | the mel spectrogram (20ms hop) and the speech tokenizer (40ms hop). 112 | """ 113 | S3_TOKEN_DURATION_S = 1 / S3_TOKEN_RATE # 0.04 seconds 114 | samples_per_token = int(sr * S3_TOKEN_DURATION_S) 115 | current_samples = wav.shape[-1] 116 | remainder = current_samples % samples_per_token 117 | if remainder != 0: 118 | padding_needed = samples_per_token - remainder 119 | padded_wav = F.pad(wav, (0, padding_needed)) 120 | return padded_wav 121 | return wav 122 | 123 | 124 | class ChatterboxTTS: 125 | ENC_COND_LEN = 6 * S3_SR 126 | DEC_COND_LEN = 10 * S3GEN_SR 127 | 128 | def __init__( 129 | self, 130 | t3: T3, 131 | s3gen: S3Gen, 132 | ve: VoiceEncoder, 133 | tokenizer: EnTokenizer, 134 | device: str, 135 | conds: Conditionals = None, 136 | ): 137 | self.sr = S3GEN_SR # sample rate of synthesized audio 138 | self.t3 = t3 139 | self.s3gen = s3gen 140 | self.ve = ve 141 | self.tokenizer = tokenizer 142 | self.device = device 143 | self.conds = conds 144 | self.watermarker = perth.PerthImplicitWatermarker() 145 | 146 | @classmethod 147 | def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS': 148 | ckpt_dir = Path(ckpt_dir) 149 | 150 | # Always load to CPU first for non-CUDA devices to handle CUDA-saved models 151 | if str(device) in ["cpu", "mps"]: 152 | map_location = torch.device('cpu') 153 | else: 154 | map_location = None 155 | 156 | ve = VoiceEncoder() 157 | ve.load_state_dict( 158 | load_file(ckpt_dir / "ve.safetensors") 159 | ) 160 | ve.to(device).eval() 161 | 162 | t3 = T3() 163 | t3_state = load_file(ckpt_dir / "t3_cfg.safetensors") 164 | if "model" in t3_state.keys(): 165 | t3_state = t3_state["model"][0] 166 | t3.load_state_dict(t3_state) 167 | t3.to(device).eval() 168 | 169 | s3gen = S3Gen() 170 | s3gen.load_state_dict( 171 | load_file(ckpt_dir / "s3gen.safetensors"), strict=False 172 | ) 173 | s3gen.to(device).eval() 174 | 175 | tokenizer = EnTokenizer( 176 | str(ckpt_dir / "tokenizer.json") 177 | ) 178 | 179 | conds = None 180 | if (builtin_voice := ckpt_dir / "conds.pt").exists(): 181 | conds = Conditionals.load(builtin_voice, map_location=map_location).to(device) 182 | 183 | return cls(t3, s3gen, ve, tokenizer, device, conds=conds) 184 | 185 | @classmethod 186 | def from_pretrained(cls, device) -> 'ChatterboxTTS': 187 | # 
Check if MPS is available on macOS 188 | if device == "mps" and not torch.backends.mps.is_available(): 189 | if not torch.backends.mps.is_built(): 190 | logger.warning("MPS not available because the current PyTorch install was not built with MPS enabled.") 191 | else: 192 | logger.warning("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.") 193 | device = "cpu" 194 | 195 | for fpath in ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]: 196 | local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath) 197 | 198 | return cls.from_local(Path(local_path).parent, device) 199 | 200 | def prepare_conditionals(self, wav_fpath, exaggeration=0.5): 201 | ## Load reference wav 202 | s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR) 203 | 204 | # Convert to tensor and pad to a 40ms boundary 205 | s3gen_ref_wav = torch.from_numpy(s3gen_ref_wav).float().unsqueeze(0) 206 | s3gen_ref_wav = _pad_wav_to_40ms_multiple(s3gen_ref_wav, S3GEN_SR) 207 | 208 | # Now convert back to numpy for librosa, or use torch audio for resampling 209 | s3gen_ref_wav_np = s3gen_ref_wav.squeeze(0).numpy() 210 | 211 | ref_16k_wav = librosa.resample(s3gen_ref_wav_np, orig_sr=S3GEN_SR, target_sr=S3_SR) 212 | 213 | s3gen_ref_wav_np = s3gen_ref_wav_np[:self.DEC_COND_LEN] 214 | s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav_np, S3GEN_SR, device=self.device) 215 | 216 | # Speech cond prompt tokens 217 | if plen := self.t3.hp.speech_cond_prompt_len: 218 | s3_tokzr = self.s3gen.tokenizer 219 | t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen) 220 | t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device) 221 | 222 | # Voice-encoder speaker embedding 223 | ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR)) 224 | ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device) 225 | 226 | t3_cond = T3Cond( 227 | speaker_emb=ve_embed, 228 | cond_prompt_speech_tokens=t3_cond_prompt_tokens, 229 | emotion_adv=exaggeration * torch.ones(1, 1, 1), 230 | ).to(device=self.device) 231 | self.conds = Conditionals(t3_cond, s3gen_ref_dict) 232 | 233 | def generate( 234 | self, 235 | text, 236 | repetition_penalty=1.2, 237 | min_p=0.05, 238 | top_p=1.0, 239 | audio_prompt_path=None, 240 | exaggeration=0.5, 241 | cfg_weight=0.5, 242 | temperature=0.8, 243 | pbar=None, 244 | max_new_tokens=1000, 245 | flow_cfg_scale=0.7 246 | ): 247 | if audio_prompt_path: 248 | self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration) 249 | else: 250 | assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`" 251 | 252 | # Update exaggeration if needed 253 | if exaggeration != self.conds.t3.emotion_adv[0, 0, 0]: 254 | _cond: T3Cond = self.conds.t3 255 | self.conds.t3 = T3Cond( 256 | speaker_emb=_cond.speaker_emb, 257 | cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens, 258 | emotion_adv=exaggeration * torch.ones(1, 1, 1), 259 | ).to(device=self.device) 260 | 261 | # Norm and tokenize text 262 | text = punc_norm(text) 263 | text_tokens = self.tokenizer.text_to_tokens(text).to(self.device) 264 | 265 | if cfg_weight > 0.0: 266 | text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG 267 | 268 | sot = self.t3.hp.start_text_token 269 | eot = self.t3.hp.stop_text_token 270 | text_tokens = F.pad(text_tokens, (1, 0), value=sot) 271 | text_tokens = 
F.pad(text_tokens, (0, 1), value=eot) 272 | 273 | with torch.inference_mode(): 274 | speech_tokens = self.t3.inference( 275 | t3_cond=self.conds.t3, 276 | text_tokens=text_tokens, 277 | max_new_tokens=max_new_tokens, 278 | temperature=temperature, 279 | cfg_weight=cfg_weight, 280 | repetition_penalty=repetition_penalty, 281 | min_p=min_p, 282 | top_p=top_p, 283 | pbar=pbar 284 | ) 285 | # Extract only the conditional batch. 286 | speech_tokens = speech_tokens[0] 287 | 288 | # TODO: output becomes 1D 289 | speech_tokens = drop_invalid_tokens(speech_tokens) 290 | 291 | speech_tokens = speech_tokens[speech_tokens < 6561] 292 | 293 | speech_tokens = speech_tokens.to(self.device) 294 | 295 | wav, _ = self.s3gen.inference( 296 | speech_tokens=speech_tokens, 297 | ref_dict=self.conds.gen, 298 | flow_cfg_scale=flow_cfg_scale 299 | ) 300 | wav = wav.squeeze(0).detach().cpu().numpy() 301 | watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr) 302 | return torch.from_numpy(watermarked_wav).unsqueeze(0) 303 | -------------------------------------------------------------------------------- /src/chatterbox/models/voice_encoder/voice_encoder.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning 2 | # MIT License 3 | from typing import List, Union, Optional 4 | 5 | import numpy as np 6 | from numpy.lib.stride_tricks import as_strided 7 | import librosa 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn, Tensor 11 | 12 | from .config import VoiceEncConfig 13 | from .melspec import melspectrogram 14 | 15 | 16 | def pack(arrays, seq_len: int=None, pad_value=0): 17 | """ 18 | Given a list of length B of array-like objects of shapes (Ti, ...), packs them in a single tensor of 19 | shape (B, T, ...) by padding each individual array on the right. 20 | 21 | :param arrays: a list of array-like objects of matching shapes except for the first axis. 22 | :param seq_len: the value of T. It must be the maximum of the lengths Ti of the arrays at 23 | minimum. Will default to that value if None. 24 | :param pad_value: the value to pad the arrays with. 25 | :return: a (B, T, ...) 
tensor 26 | """ 27 | if seq_len is None: 28 | seq_len = max(len(array) for array in arrays) 29 | else: 30 | assert seq_len >= max(len(array) for array in arrays) 31 | 32 | # Convert lists to np.array 33 | if isinstance(arrays[0], list): 34 | arrays = [np.array(array) for array in arrays] 35 | 36 | # Convert to tensor and handle device 37 | device = None 38 | if isinstance(arrays[0], torch.Tensor): 39 | tensors = arrays 40 | device = tensors[0].device 41 | else: 42 | tensors = [torch.as_tensor(array) for array in arrays] 43 | 44 | # Fill the packed tensor with the array data 45 | packed_shape = (len(tensors), seq_len, *tensors[0].shape[1:]) 46 | packed_tensor = torch.full(packed_shape, pad_value, dtype=tensors[0].dtype, device=device) 47 | 48 | for i, tensor in enumerate(tensors): 49 | packed_tensor[i, :tensor.size(0)] = tensor 50 | 51 | return packed_tensor 52 | 53 | 54 | def get_num_wins( 55 | n_frames: int, 56 | step: int, 57 | min_coverage: float, 58 | hp: VoiceEncConfig, 59 | ): 60 | assert n_frames > 0 61 | win_size = hp.ve_partial_frames 62 | n_wins, remainder = divmod(max(n_frames - win_size + step, 0), step) 63 | if n_wins == 0 or (remainder + (win_size - step)) / win_size >= min_coverage: 64 | n_wins += 1 65 | target_n = win_size + step * (n_wins - 1) 66 | return n_wins, target_n 67 | 68 | 69 | def get_frame_step( 70 | overlap: float, 71 | rate: float, 72 | hp: VoiceEncConfig, 73 | ): 74 | # Compute how many frames separate two partial utterances 75 | assert 0 <= overlap < 1 76 | if rate is None: 77 | frame_step = int(np.round(hp.ve_partial_frames * (1 - overlap))) 78 | else: 79 | frame_step = int(np.round((hp.sample_rate / rate) / hp.ve_partial_frames)) 80 | assert 0 < frame_step <= hp.ve_partial_frames 81 | return frame_step 82 | 83 | 84 | def stride_as_partials( 85 | mel: np.ndarray, 86 | hp: VoiceEncConfig, 87 | overlap=0.5, 88 | rate: float=None, 89 | min_coverage=0.8, 90 | ): 91 | """ 92 | Takes unscaled mels in (T, M) format 93 | TODO: doc 94 | """ 95 | assert 0 < min_coverage <= 1 96 | frame_step = get_frame_step(overlap, rate, hp) 97 | 98 | # Compute how many partials can fit in the mel 99 | n_partials, target_len = get_num_wins(len(mel), frame_step, min_coverage, hp) 100 | 101 | # Trim or pad the mel spectrogram to match the number of partials 102 | if target_len > len(mel): 103 | mel = np.concatenate((mel, np.full((target_len - len(mel), hp.num_mels), 0))) 104 | elif target_len < len(mel): 105 | mel = mel[:target_len] 106 | 107 | # Ensure the numpy array data is float32 and contiguous in memory 108 | mel = mel.astype(np.float32, order="C") 109 | 110 | # Re-arrange the array in memory to be of shape (N, P, M) with partials overlapping eachother, 111 | # where N is the number of partials, P is the number of frames of each partial and M the 112 | # number of channels of the mel spectrograms. 
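    # Illustrative example (assumed values): with hp.ve_partial_frames = 160 and
    # frame_step = 80 (50% overlap), a 400-frame mel yields n_partials = 4 windows
    # covering frames [0:160], [80:240], [160:320] and [240:400]; as_strided returns
    # views into the same buffer, so no frame data is copied.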
113 | shape = (n_partials, hp.ve_partial_frames, hp.num_mels) 114 | strides = (mel.strides[0] * frame_step, mel.strides[0], mel.strides[1]) 115 | partials = as_strided(mel, shape, strides) 116 | return partials 117 | 118 | 119 | class VoiceEncoder(nn.Module): 120 | def __init__(self, hp=VoiceEncConfig()): 121 | super().__init__() 122 | 123 | self.hp = hp 124 | 125 | # Network definition 126 | self.lstm = nn.LSTM(self.hp.num_mels, self.hp.ve_hidden_size, num_layers=3, batch_first=True) 127 | if hp.flatten_lstm_params: 128 | self.lstm.flatten_parameters() 129 | self.proj = nn.Linear(self.hp.ve_hidden_size, self.hp.speaker_embed_size) 130 | 131 | # Cosine similarity scaling (fixed initial parameter values) 132 | self.similarity_weight = nn.Parameter(torch.tensor([10.]), requires_grad=True) 133 | self.similarity_bias = nn.Parameter(torch.tensor([-5.]), requires_grad=True) 134 | 135 | @property 136 | def device(self): 137 | return next(self.parameters()).device 138 | 139 | def forward(self, mels: torch.FloatTensor): 140 | """ 141 | Computes the embeddings of a batch of partial utterances. 142 | 143 | :param mels: a batch of unscaled mel spectrograms of same duration as a float32 tensor 144 | of shape (B, T, M) where T is hp.ve_partial_frames 145 | :return: the embeddings as a float32 tensor of shape (B, E) where E is 146 | hp.speaker_embed_size. Embeddings are L2-normed and thus lay in the range [-1, 1]. 147 | """ 148 | if self.hp.normalized_mels and (mels.min() < 0 or mels.max() > 1): 149 | raise Exception(f"Mels outside [0, 1]. Min={mels.min()}, Max={mels.max()}") 150 | 151 | # Pass the input through the LSTM layers 152 | _, (hidden, _) = self.lstm(mels) 153 | 154 | # Project the final hidden state 155 | raw_embeds = self.proj(hidden[-1]) 156 | if self.hp.ve_final_relu: 157 | raw_embeds = F.relu(raw_embeds) 158 | 159 | # L2 normalize the embeddings. 160 | return raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True) 161 | 162 | def inference(self, mels: torch.Tensor, mel_lens, overlap=0.5, rate: float=None, min_coverage=0.8, batch_size=None): 163 | """ 164 | Computes the embeddings of a batch of full utterances with gradients. 
165 | 166 | :param mels: (B, T, M) unscaled mels 167 | :return: (B, E) embeddings on CPU 168 | """ 169 | mel_lens = mel_lens.tolist() if torch.is_tensor(mel_lens) else mel_lens 170 | 171 | # Compute where to split the utterances into partials 172 | frame_step = get_frame_step(overlap, rate, self.hp) 173 | n_partials, target_lens = zip(*(get_num_wins(l, frame_step, min_coverage, self.hp) for l in mel_lens)) 174 | 175 | # Possibly pad the mels to reach the target lengths 176 | len_diff = max(target_lens) - mels.size(1) 177 | if len_diff > 0: 178 | pad = torch.full((mels.size(0), len_diff, self.hp.num_mels), 0, dtype=torch.float32) 179 | mels = torch.cat((mels, pad.to(mels.device)), dim=1) 180 | 181 | # Group all partials together so that we can batch them easily 182 | partials = [ 183 | mel[i * frame_step: i * frame_step + self.hp.ve_partial_frames] 184 | for mel, n_partial in zip(mels, n_partials) for i in range(n_partial) 185 | ] 186 | assert all(partials[0].shape == partial.shape for partial in partials) 187 | partials = torch.stack(partials) 188 | 189 | # Forward the partials 190 | n_chunks = int(np.ceil(len(partials) / (batch_size or len(partials)))) 191 | partial_embeds = torch.cat([self(batch) for batch in partials.chunk(n_chunks)], dim=0).cpu() 192 | 193 | # Reduce the partial embeds into full embeds and L2-normalize them 194 | slices = np.concatenate(([0], np.cumsum(n_partials))) 195 | raw_embeds = [torch.mean(partial_embeds[start:end], dim=0) for start, end in zip(slices[:-1], slices[1:])] 196 | raw_embeds = torch.stack(raw_embeds) 197 | embeds = raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True) 198 | 199 | return embeds 200 | 201 | @staticmethod 202 | def utt_to_spk_embed(utt_embeds: np.ndarray): 203 | """ 204 | Takes an array of L2-normalized utterance embeddings, computes the mean embedding and L2-normalize it to get a 205 | speaker embedding. 206 | """ 207 | assert utt_embeds.ndim == 2 208 | utt_embeds = np.mean(utt_embeds, axis=0) 209 | return utt_embeds / np.linalg.norm(utt_embeds, 2) 210 | 211 | @staticmethod 212 | def voice_similarity(embeds_x: np.ndarray, embeds_y: np.ndarray): 213 | """ 214 | Cosine similarity for L2-normalized utterance embeddings or speaker embeddings 215 | """ 216 | embeds_x = embeds_x if embeds_x.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_x) 217 | embeds_y = embeds_y if embeds_y.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_y) 218 | return embeds_x @ embeds_y 219 | 220 | def embeds_from_mels( 221 | self, mels: Union[Tensor, List[np.ndarray]], mel_lens=None, as_spk=False, batch_size=32, **kwargs 222 | ): 223 | """ 224 | Convenience function for deriving utterance or speaker embeddings from mel spectrograms. 225 | 226 | :param mels: unscaled mels strictly within [0, 1] as either a (B, T, M) tensor or a list of (Ti, M) arrays. 
227 | :param mel_lens: if passing mels as a tensor, individual mel lengths 228 | :param as_spk: whether to return utterance embeddings or a single speaker embedding 229 | :param kwargs: args for inference() 230 | 231 | :returns: embeds as a (B, E) float32 numpy array if is False, else as a (E,) array 232 | """ 233 | # Load mels in memory and pack them 234 | if isinstance(mels, List): 235 | mels = [np.asarray(mel) for mel in mels] 236 | assert all(m.shape[1] == mels[0].shape[1] for m in mels), "Mels aren't in (B, T, M) format" 237 | mel_lens = [mel.shape[0] for mel in mels] 238 | mels = pack(mels) 239 | 240 | # Embed them 241 | with torch.inference_mode(): 242 | utt_embeds = self.inference(mels.to(self.device), mel_lens, batch_size=batch_size, **kwargs).numpy() 243 | 244 | return self.utt_to_spk_embed(utt_embeds) if as_spk else utt_embeds 245 | 246 | def embeds_from_wavs( 247 | self, 248 | wavs: List[np.ndarray], 249 | sample_rate, 250 | as_spk=False, 251 | batch_size=32, 252 | trim_top_db: Optional[float]=20, 253 | **kwargs 254 | ): 255 | """ 256 | Wrapper around embeds_from_mels 257 | 258 | :param trim_top_db: this argument was only added for the sake of compatibility with metavoice's implementation 259 | """ 260 | if sample_rate != self.hp.sample_rate: 261 | wavs = [ 262 | librosa.resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_best") 263 | for wav in wavs 264 | ] 265 | 266 | if trim_top_db: 267 | wavs = [librosa.effects.trim(wav, top_db=trim_top_db)[0] for wav in wavs] 268 | 269 | if "rate" not in kwargs: 270 | kwargs["rate"] = 1.3 # Resemble's default value. 271 | 272 | mels = [melspectrogram(w, self.hp).T for w in wavs] 273 | 274 | return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs) 275 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022, Haofei Xu 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/flow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import logging 15 | import random 16 | from typing import Dict, Optional 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import functional as F 21 | from .utils.mask import make_pad_mask 22 | from .configs import CFM_PARAMS 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | class MaskedDiffWithXvec(torch.nn.Module): 27 | def __init__( 28 | self, 29 | input_size: int = 512, 30 | output_size: int = 80, 31 | spk_embed_dim: int = 192, 32 | output_type: str = "mel", 33 | vocab_size: int = 4096, 34 | input_frame_rate: int = 50, 35 | only_mask_loss: bool = True, 36 | encoder: torch.nn.Module = None, 37 | length_regulator: torch.nn.Module = None, 38 | decoder: torch.nn.Module = None, 39 | decoder_conf: Dict = { 40 | 'in_channels': 240, 41 | 'out_channel': 80, 42 | 'spk_emb_dim': 80, 43 | 'n_spks': 1, 44 | 'cfm_params': CFM_PARAMS, 45 | 'decoder_params': { 46 | 'channels': [256, 256], 47 | 'dropout': 0.0, 48 | 'attention_head_dim': 64, 49 | 'n_blocks': 4, 50 | 'num_mid_blocks': 12, 51 | 'num_heads': 8, 52 | 'act_fn': 'gelu', 53 | } 54 | }, 55 | mel_feat_conf: Dict = { 56 | 'n_fft': 1024, 57 | 'num_mels': 80, 58 | 'sampling_rate': 22050, 59 | 'hop_size': 256, 60 | 'win_size': 1024, 61 | 'fmin': 0, 62 | 'fmax': 8000 63 | } 64 | ): 65 | super().__init__() 66 | self.input_size = input_size 67 | self.output_size = output_size 68 | self.decoder_conf = decoder_conf 69 | self.mel_feat_conf = mel_feat_conf 70 | self.vocab_size = vocab_size 71 | self.output_type = output_type 72 | self.input_frame_rate = input_frame_rate 73 | # logger.info(f"input frame rate={self.input_frame_rate}") 74 | self.input_embedding = nn.Embedding(vocab_size, input_size) 75 | self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) 76 | self.encoder = encoder 77 | self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) 78 | self.decoder = decoder 79 | self.length_regulator = length_regulator 80 | self.only_mask_loss = only_mask_loss 81 | 82 | def forward( 83 | self, 84 | batch: dict, 85 | device: torch.device, 86 | ) -> Dict[str, Optional[torch.Tensor]]: 87 | token = batch['speech_token'].to(device) 88 | token_len = batch['speech_token_len'].to(device) 89 | feat = batch['speech_feat'].to(device) 90 | feat_len = batch['speech_feat_len'].to(device) 91 | embedding = batch['embedding'].to(device) 92 | 93 | # xvec projection 94 | embedding = F.normalize(embedding, dim=1) 95 | embedding = self.spk_embed_affine_layer(embedding) 96 | 97 | # concat text and prompt_text 98 | mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) 99 | token = self.input_embedding(torch.clamp(token, min=0)) * mask 100 | 101 | # text encode 102 | h, h_lengths = self.encoder(token, token_len) 103 | h = self.encoder_proj(h) 104 | h, h_lengths = self.length_regulator(h, feat_len) 105 | 106 | # get conditions 107 | conds = torch.zeros(feat.shape, device=token.device) 108 | for i, j in enumerate(feat_len): 109 | if random.random() < 0.5: 110 | continue 111 | index = random.randint(0, int(0.3 * j)) 112 | conds[i, :index] = feat[i, :index] 113 | 
conds = conds.transpose(1, 2) 114 | 115 | mask = (~make_pad_mask(feat_len)).to(h) 116 | feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) 117 | loss, _ = self.decoder.compute_loss( 118 | feat.transpose(1, 2).contiguous(), 119 | mask.unsqueeze(1), 120 | h.transpose(1, 2).contiguous(), 121 | embedding, 122 | cond=conds 123 | ) 124 | return {'loss': loss} 125 | 126 | @torch.inference_mode() 127 | def inference(self, 128 | token, 129 | token_len, 130 | prompt_token, 131 | prompt_token_len, 132 | prompt_feat, 133 | prompt_feat_len, 134 | embedding, 135 | flow_cache, 136 | n_timesteps=10, 137 | pbar=None, 138 | temperature=1.0, 139 | flow_cfg_scale=0.7 140 | ): 141 | if hasattr(self, 'fp16') and self.fp16 is True: 142 | prompt_feat = prompt_feat.half() 143 | embedding = embedding.half() 144 | 145 | assert token.shape[0] == 1 146 | # xvec projection 147 | embedding = F.normalize(embedding, dim=1) 148 | embedding = self.spk_embed_affine_layer(embedding) 149 | 150 | # concat text and prompt_text 151 | token_len1, token_len2 = prompt_token.shape[1], token.shape[1] 152 | token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len 153 | mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) 154 | token = self.input_embedding(torch.clamp(token, min=0)) * mask 155 | 156 | # text encode 157 | h, h_lengths = self.encoder(token, token_len) 158 | h = self.encoder_proj(h) 159 | mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256) 160 | h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate) 161 | 162 | # get conditions 163 | conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype) 164 | conds[:, :mel_len1] = prompt_feat 165 | conds = conds.transpose(1, 2) 166 | 167 | mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) 168 | feat, flow_cache = self.decoder( 169 | mu=h.transpose(1, 2).contiguous(), 170 | mask=mask.unsqueeze(1), 171 | spks=embedding, 172 | cond=conds, 173 | n_timesteps=n_timesteps, 174 | prompt_len=mel_len1, 175 | flow_cache=flow_cache, 176 | pbar=pbar, 177 | temperature=temperature, 178 | flow_cfg_scale=flow_cfg_scale 179 | ) 180 | feat = feat[:, :, mel_len1:] 181 | assert feat.shape[2] == mel_len2 182 | return feat.float(), flow_cache 183 | 184 | 185 | class CausalMaskedDiffWithXvec(torch.nn.Module): 186 | def __init__( 187 | self, 188 | input_size: int = 512, 189 | output_size: int = 80, 190 | spk_embed_dim: int = 192, 191 | output_type: str = "mel", 192 | vocab_size: int = 6561, 193 | input_frame_rate: int = 25, 194 | only_mask_loss: bool = True, 195 | token_mel_ratio: int = 2, 196 | pre_lookahead_len: int = 3, 197 | encoder: torch.nn.Module = None, 198 | decoder: torch.nn.Module = None, 199 | decoder_conf: Dict = { 200 | 'in_channels': 240, 201 | 'out_channel': 80, 202 | 'spk_emb_dim': 80, 203 | 'n_spks': 1, 204 | 'cfm_params': CFM_PARAMS, 205 | 'decoder_params': { 206 | 'channels': [256, 256], 207 | 'dropout': 0.0, 208 | 'attention_head_dim': 64, 209 | 'n_blocks': 4, 210 | 'num_mid_blocks': 12, 211 | 'num_heads': 8, 212 | 'act_fn': 'gelu', 213 | } 214 | }, 215 | mel_feat_conf: Dict = { 216 | 'n_fft': 1024, 217 | 'num_mels': 80, 218 | 'sampling_rate': 22050, 219 | 'hop_size': 256, 220 | 'win_size': 1024, 221 | 'fmin': 0, 222 | 'fmax': 8000 223 | } 224 | ): 225 | super().__init__() 226 | self.input_size = input_size 227 | self.output_size = 
output_size 228 | self.decoder_conf = decoder_conf 229 | self.mel_feat_conf = mel_feat_conf 230 | self.vocab_size = vocab_size 231 | self.output_type = output_type 232 | self.input_frame_rate = input_frame_rate 233 | # logger.info(f"input frame rate={self.input_frame_rate}") 234 | self.input_embedding = nn.Embedding(vocab_size, input_size) 235 | self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) 236 | self.encoder = encoder 237 | self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) 238 | self.decoder = decoder 239 | self.only_mask_loss = only_mask_loss 240 | self.token_mel_ratio = token_mel_ratio 241 | self.pre_lookahead_len = pre_lookahead_len 242 | 243 | # FIXME: this was missing - just putting it in as false 244 | self.fp16 = False 245 | 246 | @torch.inference_mode() 247 | def inference(self, 248 | token, 249 | token_len, 250 | prompt_token, 251 | prompt_token_len, 252 | prompt_feat, 253 | prompt_feat_len, 254 | embedding, 255 | finalize, 256 | n_timesteps=10, 257 | pbar=None, 258 | temperature=1.0, 259 | flow_cfg_scale=0.7 260 | ): 261 | if hasattr(self, 'fp16') and self.fp16 is True: 262 | prompt_feat = prompt_feat.half() 263 | embedding = embedding.half() 264 | 265 | assert token.shape[0] == 1 266 | # xvec projection 267 | embedding = F.normalize(embedding, dim=1) 268 | embedding = self.spk_embed_affine_layer(embedding) 269 | 270 | # concat text and prompt_text 271 | token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len 272 | mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) 273 | token = self.input_embedding(torch.clamp(token, min=0)) * mask 274 | 275 | # text encode 276 | h, h_lengths = self.encoder(token, token_len) 277 | if finalize is False: 278 | h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio] 279 | mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1] 280 | h = self.encoder_proj(h) 281 | 282 | # get conditions 283 | conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype) 284 | conds[:, :mel_len1] = prompt_feat 285 | conds = conds.transpose(1, 2) 286 | 287 | mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) 288 | feat, _ = self.decoder( 289 | mu=h.transpose(1, 2).contiguous(), 290 | mask=mask.unsqueeze(1), 291 | spks=embedding, 292 | cond=conds, 293 | n_timesteps=n_timesteps, 294 | pbar=pbar, 295 | temperature=temperature, 296 | flow_cfg_scale=flow_cfg_scale, 297 | ) 298 | feat = feat[:, :, mel_len1:] 299 | assert feat.shape[2] == mel_len2 300 | return feat.float(), None # NOTE jrm: why are they returning None here? 301 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/transformer/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 2024 Alibaba Inc (Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """Positional Encoding Module.""" 17 | 18 | import math 19 | from typing import Tuple, Union 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | import numpy as np 24 | 25 | 26 | class PositionalEncoding(torch.nn.Module): 27 | """Positional encoding. 28 | 29 | :param int d_model: embedding dim 30 | :param float dropout_rate: dropout rate 31 | :param int max_len: maximum input length 32 | 33 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) 34 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) 35 | """ 36 | 37 | def __init__(self, 38 | d_model: int, 39 | dropout_rate: float, 40 | max_len: int = 5000, 41 | reverse: bool = False): 42 | """Construct a PositionalEncoding object.""" 43 | super().__init__() 44 | self.d_model = d_model 45 | self.xscale = math.sqrt(self.d_model) 46 | self.dropout = torch.nn.Dropout(p=dropout_rate) 47 | self.max_len = max_len 48 | 49 | self.pe = torch.zeros(self.max_len, self.d_model) 50 | position = torch.arange(0, self.max_len, 51 | dtype=torch.float32).unsqueeze(1) 52 | div_term = torch.exp( 53 | torch.arange(0, self.d_model, 2, dtype=torch.float32) * 54 | -(math.log(10000.0) / self.d_model)) 55 | self.pe[:, 0::2] = torch.sin(position * div_term) 56 | self.pe[:, 1::2] = torch.cos(position * div_term) 57 | self.pe = self.pe.unsqueeze(0) 58 | 59 | def forward(self, 60 | x: torch.Tensor, 61 | offset: Union[int, torch.Tensor] = 0) \ 62 | -> Tuple[torch.Tensor, torch.Tensor]: 63 | """Add positional encoding. 64 | 65 | Args: 66 | x (torch.Tensor): Input. Its shape is (batch, time, ...) 67 | offset (int, torch.tensor): position offset 68 | 69 | Returns: 70 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) 71 | torch.Tensor: for compatibility to RelPositionalEncoding 72 | """ 73 | 74 | self.pe = self.pe.to(x.device) 75 | pos_emb = self.position_encoding(offset, x.size(1), False) 76 | x = x * self.xscale + pos_emb 77 | return self.dropout(x), self.dropout(pos_emb) 78 | 79 | def position_encoding(self, 80 | offset: Union[int, torch.Tensor], 81 | size: int, 82 | apply_dropout: bool = True) -> torch.Tensor: 83 | """ For getting encoding in a streaming fashion 84 | 85 | Attention!!!!! 86 | we apply dropout only once at the whole utterance level in a 87 | non-streaming way, but will call this function several times with 88 | increasing input size in a streaming scenario, so the dropout will 89 | be applied several times.
90 | 91 | Args: 92 | offset (int or torch.tensor): start offset 93 | size (int): required size of position encoding 94 | 95 | Returns: 96 | torch.Tensor: Corresponding encoding 97 | """ 98 | # How to subscript a Union type: 99 | # https://github.com/pytorch/pytorch/issues/69434 100 | if isinstance(offset, int): 101 | assert offset + size <= self.max_len 102 | pos_emb = self.pe[:, offset:offset + size] 103 | elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar 104 | assert offset + size <= self.max_len 105 | pos_emb = self.pe[:, offset:offset + size] 106 | else: # for batched streaming decoding on GPU 107 | assert torch.max(offset) + size <= self.max_len 108 | index = offset.unsqueeze(1) + \ 109 | torch.arange(0, size).to(offset.device) # B X T 110 | flag = index > 0 111 | # remove negative offset 112 | index = index * flag 113 | pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model 114 | 115 | if apply_dropout: 116 | pos_emb = self.dropout(pos_emb) 117 | return pos_emb 118 | 119 | 120 | class RelPositionalEncoding(PositionalEncoding): 121 | """Relative positional encoding module. 122 | See : Appendix B in https://arxiv.org/abs/1901.02860 123 | Args: 124 | d_model (int): Embedding dimension. 125 | dropout_rate (float): Dropout rate. 126 | max_len (int): Maximum input length. 127 | """ 128 | 129 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): 130 | """Initialize class.""" 131 | super().__init__(d_model, dropout_rate, max_len, reverse=True) 132 | 133 | def forward(self, 134 | x: torch.Tensor, 135 | offset: Union[int, torch.Tensor] = 0) \ 136 | -> Tuple[torch.Tensor, torch.Tensor]: 137 | """Compute positional encoding. 138 | Args: 139 | x (torch.Tensor): Input tensor (batch, time, `*`). 140 | Returns: 141 | torch.Tensor: Encoded tensor (batch, time, `*`). 142 | torch.Tensor: Positional embedding tensor (1, time, `*`). 
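        Example (illustrative)::

            >>> pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
            >>> x, pos_emb = pos_enc(torch.randn(2, 50, 256))
            >>> x.shape, pos_emb.shape
            (torch.Size([2, 50, 256]), torch.Size([1, 50, 256]))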
143 | """ 144 | self.pe = self.pe.to(x.device) 145 | x = x * self.xscale 146 | pos_emb = self.position_encoding(offset, x.size(1), False) 147 | return self.dropout(x), self.dropout(pos_emb) 148 | 149 | 150 | class WhisperPositionalEncoding(PositionalEncoding): 151 | """ Sinusoids position encoding used in openai-whisper.encoder 152 | """ 153 | 154 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500): 155 | super().__init__(d_model, dropout_rate, max_len) 156 | self.xscale = 1.0 157 | log_timescale_increment = np.log(10000) / (d_model // 2 - 1) 158 | inv_timescales = torch.exp(-log_timescale_increment * 159 | torch.arange(d_model // 2)) 160 | scaled_time = torch.arange(max_len)[:, np.newaxis] * \ 161 | inv_timescales[np.newaxis, :] 162 | pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) 163 | delattr(self, "pe") 164 | self.register_buffer("pe", pe.unsqueeze(0)) 165 | 166 | 167 | class LearnablePositionalEncoding(PositionalEncoding): 168 | """ Learnable position encoding used in openai-whisper.decoder 169 | """ 170 | 171 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448): 172 | super().__init__(d_model, dropout_rate, max_len) 173 | # NOTE(xcsong): overwrite self.pe & self.xscale 174 | self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model)) 175 | self.xscale = 1.0 176 | 177 | 178 | class NoPositionalEncoding(torch.nn.Module): 179 | """ No position encoding 180 | """ 181 | 182 | def __init__(self, d_model: int, dropout_rate: float): 183 | super().__init__() 184 | self.d_model = d_model 185 | self.dropout = torch.nn.Dropout(p=dropout_rate) 186 | 187 | def forward(self, 188 | x: torch.Tensor, 189 | offset: Union[int, torch.Tensor] = 0) \ 190 | -> Tuple[torch.Tensor, torch.Tensor]: 191 | """ Just return zero vector for interface compatibility 192 | """ 193 | pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) 194 | return self.dropout(x), pos_emb 195 | 196 | def position_encoding(self, offset: Union[int, torch.Tensor], 197 | size: int) -> torch.Tensor: 198 | return torch.zeros(1, size, self.d_model) 199 | 200 | 201 | class EspnetRelPositionalEncoding(torch.nn.Module): 202 | """Relative positional encoding module (new implementation). 203 | 204 | Details can be found in https://github.com/espnet/espnet/pull/2816. 205 | 206 | See : Appendix B in https://arxiv.org/abs/1901.02860 207 | 208 | Args: 209 | d_model (int): Embedding dimension. 210 | dropout_rate (float): Dropout rate. 211 | max_len (int): Maximum input length. 212 | 213 | """ 214 | 215 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): 216 | """Construct an PositionalEncoding object.""" 217 | super(EspnetRelPositionalEncoding, self).__init__() 218 | self.d_model = d_model 219 | self.xscale = math.sqrt(self.d_model) 220 | self.dropout = torch.nn.Dropout(p=dropout_rate) 221 | self.pe = None 222 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 223 | 224 | def extend_pe(self, x: torch.Tensor): 225 | """Reset the positional encodings.""" 226 | if self.pe is not None: 227 | # self.pe contains both positive and negative parts 228 | # the length of self.pe is 2 * input_len - 1 229 | if self.pe.size(1) >= x.size(1) * 2 - 1: 230 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 231 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 232 | return 233 | # Suppose `i` means to the position of query vecotr and `j` means the 234 | # position of key vector. 
We use positive relative positions when keys
235 |         # are to the left (i>j) and negative relative positions otherwise (i<j).
236 |         pe_positive = torch.zeros(x.size(1), self.d_model)
237 |         pe_negative = torch.zeros(x.size(1), self.d_model)
238 |         position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
239 |         div_term = torch.exp(
240 |             torch.arange(0, self.d_model, 2, dtype=torch.float32)
241 |             * -(math.log(10000.0) / self.d_model)
242 |         )
243 |         pe_positive[:, 0::2] = torch.sin(position * div_term)
244 |         pe_positive[:, 1::2] = torch.cos(position * div_term)
245 |         pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
246 |         pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
247 | 
248 |         # Reverse the order of positive indices and concat both positive and
249 |         # negative indices. This is used to support the shifting trick
250 |         # as in https://arxiv.org/abs/1901.02860
251 |         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
252 |         pe_negative = pe_negative[1:].unsqueeze(0)
253 |         pe = torch.cat([pe_positive, pe_negative], dim=1)
254 |         self.pe = pe.to(device=x.device, dtype=x.dtype)
255 | 
256 |     def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
257 |             -> Tuple[torch.Tensor, torch.Tensor]:
258 |         """Add positional encoding.
259 | 
260 |         Args:
261 |             x (torch.Tensor): Input tensor (batch, time, `*`).
262 | 
263 |         Returns:
264 |             torch.Tensor: Encoded tensor (batch, time, `*`).
265 | 
266 |         """
267 |         self.extend_pe(x)
268 |         x = x * self.xscale
269 |         pos_emb = self.position_encoding(size=x.size(1), offset=offset)
270 |         return self.dropout(x), self.dropout(pos_emb)
271 | 
272 |     def position_encoding(self,
273 |                           offset: Union[int, torch.Tensor],
274 |                           size: int) -> torch.Tensor:
275 |         """ For getting encoding in a streaming fashion
276 | 
277 |         Attention!
278 |         We apply dropout only once at the whole-utterance level in the non-
279 |         streaming case, but this function is called several times with
280 |         increasing input size in a streaming scenario, so the dropout will
281 |         be applied several times.
282 | 
283 |         Args:
284 |             offset (int or torch.tensor): start offset
285 |             size (int): required size of position encoding
286 | 
287 |         Returns:
288 |             torch.Tensor: Corresponding encoding
289 |         """
290 |         pos_emb = self.pe[
291 |             :,
292 |             self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
293 |         ]
294 |         return pos_emb
295 | 
--------------------------------------------------------------------------------
/src/chatterbox/models/s3gen/decoder.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
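# NOTE: overview of this module (see the definitions below): CausalConv1d,
# CausalBlock1D and CausalResnetBlock1D wrap their non-causal counterparts with
# left-only padding, and ConditionalDecoder assembles them into a 1-D U-Net with
# transformer blocks that serves as the estimator of the conditional
# flow-matching decoder. Left-only padding means a kernel of size k sees only
# the current and the k-1 previous frames (e.g. F.pad(x, (2, 0)) for k=3),
# so no future frames leak into the output.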
14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from einops import pack, rearrange, repeat 18 | 19 | from .utils.mask import add_optional_chunk_mask 20 | from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \ 21 | TimestepEmbedding, Upsample1D 22 | from .matcha.transformer import BasicTransformerBlock 23 | 24 | 25 | def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: 26 | assert mask.dtype == torch.bool 27 | assert dtype in [torch.float32, torch.bfloat16, torch.float16] 28 | mask = mask.to(dtype) 29 | # attention mask bias 30 | # NOTE(Mddct): torch.finfo jit issues 31 | # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min 32 | mask = (1.0 - mask) * -1.0e+10 33 | return mask 34 | 35 | 36 | 37 | class Transpose(torch.nn.Module): 38 | def __init__(self, dim0: int, dim1: int): 39 | super().__init__() 40 | self.dim0 = dim0 41 | self.dim1 = dim1 42 | 43 | def forward(self, x: torch.Tensor): 44 | x = torch.transpose(x, self.dim0, self.dim1) 45 | return x 46 | 47 | 48 | class CausalBlock1D(Block1D): 49 | def __init__(self, dim: int, dim_out: int): 50 | super(CausalBlock1D, self).__init__(dim, dim_out) 51 | self.block = torch.nn.Sequential( 52 | CausalConv1d(dim, dim_out, 3), 53 | Transpose(1, 2), 54 | nn.LayerNorm(dim_out), 55 | Transpose(1, 2), 56 | nn.Mish(), 57 | ) 58 | 59 | def forward(self, x: torch.Tensor, mask: torch.Tensor): 60 | output = self.block(x * mask) 61 | return output * mask 62 | 63 | 64 | class CausalResnetBlock1D(ResnetBlock1D): 65 | def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): 66 | super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) 67 | self.block1 = CausalBlock1D(dim, dim_out) 68 | self.block2 = CausalBlock1D(dim_out, dim_out) 69 | 70 | 71 | class CausalConv1d(torch.nn.Conv1d): 72 | def __init__( 73 | self, 74 | in_channels: int, 75 | out_channels: int, 76 | kernel_size: int, 77 | stride: int = 1, 78 | dilation: int = 1, 79 | groups: int = 1, 80 | bias: bool = True, 81 | padding_mode: str = 'zeros', 82 | device=None, 83 | dtype=None 84 | ) -> None: 85 | super(CausalConv1d, self).__init__(in_channels, out_channels, 86 | kernel_size, stride, 87 | padding=0, dilation=dilation, 88 | groups=groups, bias=bias, 89 | padding_mode=padding_mode, 90 | device=device, dtype=dtype) 91 | assert stride == 1 92 | self.causal_padding = (kernel_size - 1, 0) 93 | 94 | def forward(self, x: torch.Tensor): 95 | x = F.pad(x, self.causal_padding) 96 | x = super(CausalConv1d, self).forward(x) 97 | return x 98 | 99 | 100 | class ConditionalDecoder(nn.Module): 101 | def __init__( 102 | self, 103 | in_channels=320, 104 | out_channels=80, 105 | causal=True, 106 | channels=[256], 107 | dropout=0.0, 108 | attention_head_dim=64, 109 | n_blocks=4, 110 | num_mid_blocks=12, 111 | num_heads=8, 112 | act_fn="gelu", 113 | ): 114 | """ 115 | This decoder requires an input with the same shape of the target. So, if your text content 116 | is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. 
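        Note (a sketch inferred from `forward` below, not a formal spec): with
        the defaults above, `forward` concatenates the noisy mel `x`, the
        encoder output `mu`, the broadcast speaker embedding `spks` and `cond`
        along the channel axis to form the `in_channels=320` input, and
        predicts `out_channels=80` mel bins per frame; all of these tensors
        must share the same time length, hence the resampling advice above.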
117 | """ 118 | super().__init__() 119 | channels = tuple(channels) 120 | self.in_channels = in_channels 121 | self.out_channels = out_channels 122 | self.causal = causal 123 | self.time_embeddings = SinusoidalPosEmb(in_channels) 124 | time_embed_dim = channels[0] * 4 125 | self.time_mlp = TimestepEmbedding( 126 | in_channels=in_channels, 127 | time_embed_dim=time_embed_dim, 128 | act_fn="silu", 129 | ) 130 | self.down_blocks = nn.ModuleList([]) 131 | self.mid_blocks = nn.ModuleList([]) 132 | self.up_blocks = nn.ModuleList([]) 133 | 134 | # NOTE jrm: `static_chunk_size` is missing? 135 | self.static_chunk_size = 0 136 | 137 | output_channel = in_channels 138 | for i in range(len(channels)): # pylint: disable=consider-using-enumerate 139 | input_channel = output_channel 140 | output_channel = channels[i] 141 | is_last = i == len(channels) - 1 142 | resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ 143 | ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) 144 | transformer_blocks = nn.ModuleList( 145 | [ 146 | BasicTransformerBlock( 147 | dim=output_channel, 148 | num_attention_heads=num_heads, 149 | attention_head_dim=attention_head_dim, 150 | dropout=dropout, 151 | activation_fn=act_fn, 152 | ) 153 | for _ in range(n_blocks) 154 | ] 155 | ) 156 | downsample = ( 157 | Downsample1D(output_channel) if not is_last else 158 | CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) 159 | ) 160 | self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) 161 | 162 | for _ in range(num_mid_blocks): 163 | input_channel = channels[-1] 164 | out_channels = channels[-1] 165 | resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ 166 | ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) 167 | 168 | transformer_blocks = nn.ModuleList( 169 | [ 170 | BasicTransformerBlock( 171 | dim=output_channel, 172 | num_attention_heads=num_heads, 173 | attention_head_dim=attention_head_dim, 174 | dropout=dropout, 175 | activation_fn=act_fn, 176 | ) 177 | for _ in range(n_blocks) 178 | ] 179 | ) 180 | 181 | self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) 182 | 183 | channels = channels[::-1] + (channels[0],) 184 | for i in range(len(channels) - 1): 185 | input_channel = channels[i] * 2 186 | output_channel = channels[i + 1] 187 | is_last = i == len(channels) - 2 188 | resnet = CausalResnetBlock1D( 189 | dim=input_channel, 190 | dim_out=output_channel, 191 | time_emb_dim=time_embed_dim, 192 | ) if self.causal else ResnetBlock1D( 193 | dim=input_channel, 194 | dim_out=output_channel, 195 | time_emb_dim=time_embed_dim, 196 | ) 197 | transformer_blocks = nn.ModuleList( 198 | [ 199 | BasicTransformerBlock( 200 | dim=output_channel, 201 | num_attention_heads=num_heads, 202 | attention_head_dim=attention_head_dim, 203 | dropout=dropout, 204 | activation_fn=act_fn, 205 | ) 206 | for _ in range(n_blocks) 207 | ] 208 | ) 209 | upsample = ( 210 | Upsample1D(output_channel, use_conv_transpose=True) 211 | if not is_last 212 | else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) 213 | ) 214 | self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) 215 | self.final_block = CausalBlock1D(channels[-1], channels[-1]) if 
self.causal else Block1D(channels[-1], channels[-1]) 216 | self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) 217 | self.initialize_weights() 218 | 219 | def initialize_weights(self): 220 | for m in self.modules(): 221 | if isinstance(m, nn.Conv1d): 222 | nn.init.kaiming_normal_(m.weight, nonlinearity="relu") 223 | if m.bias is not None: 224 | nn.init.constant_(m.bias, 0) 225 | elif isinstance(m, nn.GroupNorm): 226 | nn.init.constant_(m.weight, 1) 227 | nn.init.constant_(m.bias, 0) 228 | elif isinstance(m, nn.Linear): 229 | nn.init.kaiming_normal_(m.weight, nonlinearity="relu") 230 | if m.bias is not None: 231 | nn.init.constant_(m.bias, 0) 232 | 233 | def forward(self, x, mask, mu, t, spks=None, cond=None): 234 | """Forward pass of the UNet1DConditional model. 235 | 236 | Args: 237 | x (torch.Tensor): shape (batch_size, in_channels, time) 238 | mask (_type_): shape (batch_size, 1, time) 239 | t (_type_): shape (batch_size) 240 | spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. 241 | cond (_type_, optional): placeholder for future use. Defaults to None. 242 | 243 | Raises: 244 | ValueError: _description_ 245 | ValueError: _description_ 246 | 247 | Returns: 248 | _type_: _description_ 249 | """ 250 | 251 | t = self.time_embeddings(t).to(t.dtype) 252 | t = self.time_mlp(t) 253 | 254 | x = pack([x, mu], "b * t")[0] 255 | 256 | if spks is not None: 257 | spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) 258 | x = pack([x, spks], "b * t")[0] 259 | if cond is not None: 260 | x = pack([x, cond], "b * t")[0] 261 | 262 | hiddens = [] 263 | masks = [mask] 264 | for resnet, transformer_blocks, downsample in self.down_blocks: 265 | mask_down = masks[-1] 266 | x = resnet(x, mask_down, t) 267 | x = rearrange(x, "b c t -> b t c").contiguous() 268 | # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) 269 | attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) 270 | attn_mask = mask_to_bias(attn_mask == 1, x.dtype) 271 | for transformer_block in transformer_blocks: 272 | x = transformer_block( 273 | hidden_states=x, 274 | attention_mask=attn_mask, 275 | timestep=t, 276 | ) 277 | x = rearrange(x, "b t c -> b c t").contiguous() 278 | hiddens.append(x) # Save hidden states for skip connections 279 | x = downsample(x * mask_down) 280 | masks.append(mask_down[:, :, ::2]) 281 | masks = masks[:-1] 282 | mask_mid = masks[-1] 283 | 284 | for resnet, transformer_blocks in self.mid_blocks: 285 | x = resnet(x, mask_mid, t) 286 | x = rearrange(x, "b c t -> b t c").contiguous() 287 | # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) 288 | attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) 289 | attn_mask = mask_to_bias(attn_mask == 1, x.dtype) 290 | for transformer_block in transformer_blocks: 291 | x = transformer_block( 292 | hidden_states=x, 293 | attention_mask=attn_mask, 294 | timestep=t, 295 | ) 296 | x = rearrange(x, "b t c -> b c t").contiguous() 297 | 298 | for resnet, transformer_blocks, upsample in self.up_blocks: 299 | mask_up = masks.pop() 300 | skip = hiddens.pop() 301 | x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] 302 | x = resnet(x, mask_up, t) 303 | x = rearrange(x, "b c t -> b t c").contiguous() 304 | # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) 305 | attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) 306 | attn_mask = 
mask_to_bias(attn_mask == 1, x.dtype) 307 | for transformer_block in transformer_blocks: 308 | x = transformer_block( 309 | hidden_states=x, 310 | attention_mask=attn_mask, 311 | timestep=t, 312 | ) 313 | x = rearrange(x, "b t c -> b c t").contiguous() 314 | x = upsample(x * mask_up) 315 | x = self.final_block(x, mask_up) 316 | output = self.final_proj(x * mask_up) 317 | return output * mask 318 | -------------------------------------------------------------------------------- /src/chatterbox/models/s3gen/s3gen.py: -------------------------------------------------------------------------------- 1 | # Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | 17 | import numpy as np 18 | import torch 19 | import torchaudio as ta 20 | from functools import lru_cache 21 | from typing import Optional 22 | 23 | from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer 24 | from .const import S3GEN_SR 25 | from .flow import CausalMaskedDiffWithXvec 26 | from .xvector import CAMPPlus 27 | from .utils.mel import mel_spectrogram 28 | from .f0_predictor import ConvRNNF0Predictor 29 | from .hifigan import HiFTGenerator 30 | from .transformer.upsample_encoder import UpsampleConformerEncoder 31 | from .flow_matching import CausalConditionalCFM 32 | from .decoder import ConditionalDecoder 33 | from .configs import CFM_PARAMS 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | def drop_invalid_tokens(x): 38 | assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now" 39 | return x[x < SPEECH_VOCAB_SIZE] 40 | 41 | 42 | # TODO: global resampler cache 43 | @lru_cache(100) 44 | def get_resampler(src_sr, dst_sr, device): 45 | return ta.transforms.Resample(src_sr, dst_sr).to(device) 46 | 47 | 48 | class S3Token2Mel(torch.nn.Module): 49 | """ 50 | CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms. 51 | 52 | TODO: make these modules configurable? 53 | """ 54 | def __init__(self): 55 | super().__init__() 56 | self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz") 57 | self.mel_extractor = mel_spectrogram # TODO: make it a torch module? 
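        # NOTE: rough data flow of the token-to-mel stack assembled below
        # (inferred from embed_ref()/forward(); the rates are a sketch, not a spec):
        #   reference wav @ 16 kHz -> self.tokenizer       -> discrete S3 speech tokens (~25 Hz)
        #   reference wav @ 16 kHz -> self.speaker_encoder -> global speaker x-vector (timbre)
        #   reference wav @ 24 kHz -> self.mel_extractor   -> prompt mels (~2 frames per token)
        #   tokens + prompt + x-vector -> self.flow (encoder + CFM decoder) -> 80-bin mels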
58 | self.speaker_encoder = CAMPPlus() # use default args 59 | 60 | encoder = UpsampleConformerEncoder( 61 | output_size=512, 62 | attention_heads=8, 63 | linear_units=2048, 64 | num_blocks=6, 65 | dropout_rate=0.1, 66 | positional_dropout_rate=0.1, 67 | attention_dropout_rate=0.1, 68 | normalize_before=True, 69 | input_layer='linear', 70 | pos_enc_layer_type='rel_pos_espnet', 71 | selfattention_layer_type='rel_selfattn', 72 | input_size=512, 73 | use_cnn_module=False, 74 | macaron_style=False, 75 | ) 76 | 77 | estimator = ConditionalDecoder( 78 | in_channels=320, 79 | out_channels=80, 80 | causal=True, 81 | channels=[256], 82 | dropout=0.0, 83 | attention_head_dim=64, 84 | n_blocks=4, 85 | num_mid_blocks=12, 86 | num_heads=8, 87 | act_fn='gelu', 88 | ) 89 | cfm_params = CFM_PARAMS 90 | decoder = CausalConditionalCFM( 91 | spk_emb_dim=80, 92 | cfm_params=cfm_params, 93 | estimator=estimator, 94 | ) 95 | 96 | self.flow = CausalMaskedDiffWithXvec( 97 | encoder=encoder, 98 | decoder=decoder 99 | ) 100 | 101 | self.resamplers = {} 102 | 103 | @property 104 | def device(self): 105 | params = self.tokenizer.parameters() 106 | return next(params).device 107 | 108 | def embed_ref( 109 | self, 110 | ref_wav: torch.Tensor, 111 | ref_sr: int, 112 | device="auto", 113 | ref_fade_out=True, 114 | ): 115 | device = self.device if device == "auto" else device 116 | if isinstance(ref_wav, np.ndarray): 117 | ref_wav = torch.from_numpy(ref_wav).float() 118 | 119 | if ref_wav.device != device: 120 | ref_wav = ref_wav.to(device) 121 | 122 | if len(ref_wav.shape) == 1: 123 | ref_wav = ref_wav.unsqueeze(0) # (B, L) 124 | 125 | if ref_wav.size(1) > 10 * ref_sr: 126 | logger.warning("WARNING: cosydec received ref longer than 10s") 127 | 128 | ref_wav_24 = ref_wav 129 | if ref_sr != S3GEN_SR: 130 | ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav) 131 | 132 | ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device) 133 | ref_mels_24_len = None 134 | 135 | # Resample to 16kHz 136 | ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device) 137 | 138 | # Speaker embedding 139 | ref_x_vector = self.speaker_encoder.inference(ref_wav_16) 140 | 141 | # Tokenize 16khz reference 142 | ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16) 143 | 144 | # Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms) 145 | if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]: 146 | logger.warning( 147 | "Reference mel length is not equal to 2 * reference token length.\n" 148 | ) 149 | ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2] 150 | ref_speech_token_lens[0] = ref_speech_tokens.shape[1] 151 | 152 | return dict( 153 | prompt_token=ref_speech_tokens.to(device), 154 | prompt_token_len=ref_speech_token_lens, 155 | prompt_feat=ref_mels_24, 156 | prompt_feat_len=ref_mels_24_len, 157 | embedding=ref_x_vector, 158 | ) 159 | 160 | def forward( 161 | self, 162 | speech_tokens: torch.LongTensor, 163 | # locally-computed ref embedding (mutex with ref_dict) 164 | ref_wav: Optional[torch.Tensor], 165 | ref_sr: Optional[int], 166 | # pre-computed ref embedding (prod API) 167 | ref_dict: Optional[dict] = None, 168 | finalize: bool = False, 169 | n_timesteps: int = 10, 170 | pbar=None 171 | ): 172 | """ 173 | Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from. 174 | 175 | NOTE: 176 | - The speaker encoder accepts 16 kHz waveform. 
177 | - S3TokenizerV2 accepts 16 kHz waveform. 178 | - The mel-spectrogram for the reference assumes 24 kHz input signal. 179 | - This function is designed for batch_size=1 only. 180 | 181 | Args 182 | ---- 183 | - `speech_tokens`: S3 speech tokens [B=1, T] 184 | - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T]) 185 | - `ref_sr`: reference sample rate 186 | - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored. 187 | """ 188 | assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})" 189 | 190 | if ref_dict is None: 191 | ref_dict = self.embed_ref(ref_wav, ref_sr) 192 | else: 193 | # type/device casting (all values will be numpy if it's from a prod API call) 194 | for rk in list(ref_dict): 195 | if isinstance(ref_dict[rk], np.ndarray): 196 | ref_dict[rk] = torch.from_numpy(ref_dict[rk]) 197 | if torch.is_tensor(ref_dict[rk]): 198 | ref_dict[rk] = ref_dict[rk].to(self.device) 199 | 200 | if len(speech_tokens.shape) == 1: 201 | speech_tokens = speech_tokens.unsqueeze(0) 202 | 203 | # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now" 204 | speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device) 205 | 206 | output_mels, _ = self.flow.inference( 207 | token=speech_tokens, 208 | token_len=speech_token_lens, 209 | finalize=finalize, 210 | n_timesteps=n_timesteps, 211 | pbar=pbar, 212 | **ref_dict, 213 | ) 214 | return output_mels 215 | 216 | 217 | 218 | 219 | class S3Token2Wav(S3Token2Mel): 220 | """ 221 | The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules. 222 | 223 | TODO: make these modules configurable? 224 | """ 225 | 226 | def __init__(self): 227 | super().__init__() 228 | 229 | f0_predictor = ConvRNNF0Predictor() 230 | self.mel2wav = HiFTGenerator( 231 | sampling_rate=S3GEN_SR, 232 | upsample_rates=[8, 5, 3], 233 | upsample_kernel_sizes=[16, 11, 7], 234 | source_resblock_kernel_sizes=[7, 7, 11], 235 | source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], 236 | f0_predictor=f0_predictor, 237 | ) 238 | 239 | # silence out a few ms and fade audio in to reduce artifacts 240 | n_trim = S3GEN_SR // 50 # 20ms = half of a frame 241 | trim_fade = torch.zeros(2 * n_trim) 242 | trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2 243 | self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting) 244 | 245 | def forward( 246 | self, 247 | speech_tokens, 248 | # locally-computed ref embedding (mutex with ref_dict) 249 | ref_wav: Optional[torch.Tensor], 250 | ref_sr: Optional[int], 251 | # pre-computed ref embedding (prod API) 252 | ref_dict: Optional[dict] = None, 253 | finalize: bool = False, 254 | n_timesteps: int = 10, 255 | pbar=None 256 | ): 257 | output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize, n_timesteps=n_timesteps, pbar=pbar) 258 | 259 | # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now. 260 | hift_cache_source = torch.zeros(1, 1, 0).to(self.device) 261 | 262 | output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source) 263 | 264 | if not self.training: 265 | # NOTE: ad-hoc method to reduce "spillover" from the reference clip. 
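            # trim_fade (built in __init__) is 2 * n_trim = 960 samples long
            # (40 ms at 24 kHz): the first 20 ms are zeros and the next 20 ms
            # ramp from 0 to 1 with a raised cosine, so multiplying here mutes
            # the very start of the waveform and fades it back in.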
266 | output_wavs[:, :len(self.trim_fade)] *= self.trim_fade 267 | 268 | return output_wavs 269 | 270 | @torch.inference_mode() 271 | def flow_inference( 272 | self, 273 | speech_tokens, 274 | # locally-computed ref embedding (mutex with ref_dict) 275 | ref_wav: Optional[torch.Tensor] = None, 276 | ref_sr: Optional[int] = None, 277 | # pre-computed ref embedding (prod API) 278 | ref_dict: Optional[dict] = None, 279 | finalize: bool = False, 280 | n_timesteps: int = 10, 281 | pbar=None, 282 | temperature: float = 1.0, 283 | flow_cfg_scale: float = 0.7 284 | ): 285 | # This method in the base class now needs to accept and pass the new params 286 | # (The base class `S3Token2Mel` needs this change) 287 | 288 | assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})" 289 | 290 | if ref_dict is None: 291 | ref_dict = self.embed_ref(ref_wav, ref_sr) 292 | else: 293 | for rk in list(ref_dict): 294 | if isinstance(ref_dict[rk], np.ndarray): 295 | ref_dict[rk] = torch.from_numpy(ref_dict[rk]) 296 | if torch.is_tensor(ref_dict[rk]): 297 | ref_dict[rk] = ref_dict[rk].to(self.device) 298 | 299 | if len(speech_tokens.shape) == 1: 300 | speech_tokens = speech_tokens.unsqueeze(0) 301 | 302 | speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device) 303 | 304 | output_mels, _ = self.flow.inference( 305 | token=speech_tokens, 306 | token_len=speech_token_lens, 307 | finalize=finalize, 308 | n_timesteps=n_timesteps, 309 | pbar=pbar, 310 | temperature=temperature, 311 | flow_cfg_scale=flow_cfg_scale, 312 | **ref_dict, 313 | ) 314 | return output_mels 315 | 316 | @torch.inference_mode() 317 | def hift_inference(self, speech_feat, cache_source: torch.Tensor = None): 318 | if cache_source is None: 319 | cache_source = torch.zeros(1, 1, 0).to(self.device) 320 | return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source) 321 | 322 | @torch.inference_mode() 323 | def inference( 324 | self, 325 | speech_tokens, 326 | # locally-computed ref embedding (mutex with ref_dict) 327 | ref_wav: Optional[torch.Tensor] = None, 328 | ref_sr: Optional[int] = None, 329 | # pre-computed ref embedding (prod API) 330 | ref_dict: Optional[dict] = None, 331 | cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here 332 | finalize: bool = True, 333 | n_timesteps: int = 10, 334 | pbar=None, 335 | temperature: float = 1.0, 336 | flow_cfg_scale: float = 0.7 337 | ): 338 | output_mels = self.flow_inference( 339 | speech_tokens, 340 | ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, 341 | finalize=finalize, n_timesteps=n_timesteps, pbar=pbar, 342 | temperature=temperature, flow_cfg_scale=flow_cfg_scale 343 | ) 344 | output_wavs, output_sources = self.hift_inference(output_mels, cache_source) 345 | 346 | # NOTE: ad-hoc method to reduce "spillover" from the reference clip. 347 | output_wavs[:, :len(self.trim_fade)] *= self.trim_fade 348 | 349 | return output_wavs, output_sources 350 | --------------------------------------------------------------------------------
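Usage sketch: the snippet below is a minimal, illustrative way to drive the S3Token2Wav stack defined above, assuming trained weights have already been loaded into the module (checkpoint loading lives elsewhere in this project). The file names, the import paths and the idea of re-tokenizing the reference clip (instead of taking speech tokens from the upstream text-to-token model) are assumptions made for this example only.

import torch
import torchaudio as ta

# Illustrative import paths; adjust to how this package is installed/vendored.
from chatterbox.models.s3gen.s3gen import S3Token2Wav
from chatterbox.models.s3gen.const import S3GEN_SR

device = "cuda" if torch.cuda.is_available() else "cpu"
model = S3Token2Wav().to(device).eval()  # weight loading omitted in this sketch

# Reference clip providing the target timbre (any sample rate; ideally < 10 s).
ref_wav, ref_sr = ta.load("reference.wav")              # (1, T)

with torch.inference_mode():
    # For illustration only: re-tokenize the reference at 16 kHz. In the full
    # TTS pipeline the speech tokens come from the text-to-token model instead.
    ref_16k = ta.functional.resample(ref_wav, ref_sr, 16_000).to(device)
    speech_tokens, _ = model.tokenizer(ref_16k)          # (1, N) discrete S3 tokens

    wav, _sources = model.inference(
        speech_tokens,
        ref_wav=ref_wav.to(device),
        ref_sr=ref_sr,
        n_timesteps=10,        # CFM solver steps
        temperature=1.0,
        flow_cfg_scale=0.7,    # classifier-free guidance weight for the flow
    )

ta.save("s3gen_out.wav", wav.cpu(), S3GEN_SR)            # 24 kHz output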