├── ChatTTS
│   ├── __init__.py
│   ├── utils
│   │   ├── io_utils.py
│   │   ├── gpu_utils.py
│   │   └── infer_utils.py
│   ├── experimental
│   │   └── llm.py
│   ├── infer
│   │   └── api.py
│   ├── model
│   │   ├── dvae.py
│   │   └── gpt.py
│   └── core.py
├── OpenVoice
│   ├── checkpoints
│   │   ├── converter
│   │   │   ├── config_json
│   │   │   └── checkpiont_pth
│   │   └── openvoice_v1_in_here
│   ├── text
│   │   ├── cleaners.py
│   │   ├── symbols.py
│   │   ├── __init__.py
│   │   ├── english.py
│   │   └── mandarin.py
│   ├── utils
│   │   ├── se_extractor.py
│   │   ├── commons.py
│   │   ├── utils.py
│   │   ├── mel_processing.py
│   │   ├── transforms.py
│   │   ├── attentions.py
│   │   ├── models.py
│   │   └── modules.py
│   └── api.py
├── requirements.txt
├── README.md
├── gitattributes
├── app.py
└── LICENSE
--------------------------------------------------------------------------------
/ChatTTS/__init__.py:
--------------------------------------------------------------------------------
from .core import Chat
--------------------------------------------------------------------------------
/OpenVoice/checkpoints/converter/config_json:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/OpenVoice/checkpoints/openvoice_v1_in_here:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/OpenVoice/checkpoints/converter/checkpiont_pth:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/ChatTTS/utils/io_utils.py:
--------------------------------------------------------------------------------

import os
import logging

def get_latest_modified_file(directory):
    logger = logging.getLogger(__name__)

    files = [os.path.join(directory, f) for f in os.listdir(directory)]
    if not files:
        logger.log(logging.WARNING, f'No files found in the directory: {directory}')
        return None
    latest_file = max(files, key=os.path.getmtime)

    return latest_file
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
# PyTorch and related libraries
torch
torchvision
torchaudio

# Hugging Face transformers library
transformers

# Configuration management with OmegaConf
omegaconf

# Interactive widgets for Jupyter Notebooks
ipywidgets

# Gradio for creating web UIs
gradio

# Vector quantization for PyTorch
vector_quantize_pytorch

# Hugging Face Hub client
huggingface_hub

vocos

# OpenVoice
librosa
faster-whisper
pydub
wavmark
numpy
eng_to_ipa
inflect
unidecode
whisper-timestamped
openai
python-dotenv
pypinyin
jieba
cn2an
edge_tts
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
__ChatTTS x OpenVoice__

Enhance the authenticity of generated speech by using ChatTTS for natural voice generation, combined with the voice timbre simulation module from OpenVoice for seamless tone transplantation.

Have a try on Hugging Face!
https://huggingface.co/spaces/Hilley/ChatTTS-OpenVoice
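Below is a minimal sketch of the two-stage pipeline this repo wires together, mirroring the flow in `app.py` (the reference clip `my_voice.mp3` and the output file names are placeholders, not files shipped with the repo):

```python
import numpy as np
import soundfile
import torch

import ChatTTS
from OpenVoice.utils import se_extractor
from OpenVoice.api import ToneColorConverter

# Stage 1: text -> speech with ChatTTS (24 kHz output)
chat = ChatTTS.Chat()
chat.load_models()
wav = chat.infer(["A quick demo sentence."], skip_refine_text=True)
soundfile.write("tmp.wav", np.array(wav[0]).flatten(), 24000)

# Stage 2: transplant the timbre with OpenVoice's tone color converter
device = "cuda:0" if torch.cuda.is_available() else "cpu"
converter = ToneColorConverter("OpenVoice/checkpoints/converter/config.json", device=device)
converter.load_ckpt("OpenVoice/checkpoints/converter/checkpoint.pth")

# Extract speaker embeddings for the generated audio and the reference voice
source_se, _ = se_extractor.get_se("tmp.wav", converter, target_dir="processed", vad=True)
target_se, _ = se_extractor.get_se("my_voice.mp3", converter, target_dir="processed", vad=True)

converter.convert(audio_src_path="tmp.wav", src_se=source_se,
                  tgt_se=target_se, output_path="output.wav")
```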
---
__Notice:__

We need to download the OpenVoice checkpoints and save them into the __./OpenVoice/checkpoints__ folder.

__OpenVoice Checkpoint:__ https://huggingface.co/myshell-ai/OpenVoice/tree/main/checkpoints
--------------------------------------------------------------------------------
/OpenVoice/text/cleaners.py:
--------------------------------------------------------------------------------
import re
from .english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
from .mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2

def cjke_cleaners2(text):
    text = re.sub(r'\[ZH\](.*?)\[ZH\]',
                  lambda x: chinese_to_ipa(x.group(1))+' ', text)
    # [JA]/[KO] tag handling is omitted here: japanese_to_ipa2 and korean_to_ipa
    # are not shipped in this repo's text package, so those branches would
    # raise NameError if kept.
    text = re.sub(r'\[EN\](.*?)\[EN\]',
                  lambda x: english_to_ipa2(x.group(1))+' ', text)
    text = re.sub(r'\s+$', '', text)
    text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
    return text
--------------------------------------------------------------------------------
/ChatTTS/utils/gpu_utils.py:
--------------------------------------------------------------------------------

import torch
import logging

def select_device(min_memory = 2048):
    logger = logging.getLogger(__name__)
    if torch.cuda.is_available():
        available_gpus = []
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            free_memory = props.total_memory - torch.cuda.memory_reserved(i)
            available_gpus.append((i, free_memory))
        selected_gpu, max_free_memory = max(available_gpus, key=lambda x: x[1])
        device = torch.device(f'cuda:{selected_gpu}')
        free_memory_mb = max_free_memory / (1024 * 1024)
        if free_memory_mb < min_memory:
            logger.log(logging.WARNING, f'GPU {selected_gpu} has only {round(free_memory_mb, 2)} MB of free memory left.')
            device = torch.device('cpu')
    else:
        logger.log(logging.WARNING, 'No GPU found, using CPU instead.')
        device = torch.device('cpu')

    return device
--------------------------------------------------------------------------------
/gitattributes:
--------------------------------------------------------------------------------
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs
merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | -------------------------------------------------------------------------------- /ChatTTS/utils/infer_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class CustomRepetitionPenaltyLogitsProcessorRepeat(): 7 | 8 | def __init__(self, penalty: float, max_input_ids, past_window): 9 | if not isinstance(penalty, float) or not (penalty > 0): 10 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") 11 | 12 | self.penalty = penalty 13 | self.max_input_ids = max_input_ids 14 | self.past_window = past_window 15 | 16 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 17 | 18 | input_ids = input_ids[:, -self.past_window:] 19 | freq = F.one_hot(input_ids, scores.size(1)).sum(1) 20 | freq[self.max_input_ids:] = 0 21 | alpha = self.penalty**freq 22 | scores = torch.where(scores < 0, scores*alpha, scores/alpha) 23 | 24 | return scores 25 | 26 | class CustomRepetitionPenaltyLogitsProcessor(): 27 | 28 | def __init__(self, penalty: float, max_input_ids, past_window): 29 | if not isinstance(penalty, float) or not (penalty > 0): 30 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") 31 | 32 | self.penalty = penalty 33 | self.max_input_ids = max_input_ids 34 | self.past_window = past_window 35 | 36 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 37 | 38 | input_ids = input_ids[:, -self.past_window:] 39 | score = torch.gather(scores, 1, input_ids) 40 | _score = score.detach().clone() 41 | score = torch.where(score < 0, score * self.penalty, score / self.penalty) 42 | score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids] 43 | scores.scatter_(1, input_ids, score) 44 | 45 | return scores -------------------------------------------------------------------------------- /ChatTTS/experimental/llm.py: -------------------------------------------------------------------------------- 1 | 2 | from openai import OpenAI 3 | 4 | prompt_dict = { 5 | 'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"}, 6 | {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"}, 7 | {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},], 8 | 'deepseek': [ 9 | {"role": "system", "content": "You are a helpful assistant"}, 10 | {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"}, 11 | {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},], 12 | 'deepseek_TN': [ 13 | {"role": 
"system", "content": "You are a helpful assistant"}, 14 | {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"}, 15 | {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"}, 16 | {"role": "user", "content": "We paid $123 for this desk."}, 17 | {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."}, 18 | {"role": "user", "content": "详询请拨打010-724654"}, 19 | {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"}, 20 | {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"}, 21 | {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"}, 22 | ], 23 | } 24 | 25 | class llm_api: 26 | def __init__(self, api_key, base_url, model): 27 | self.client = OpenAI( 28 | api_key = api_key, 29 | base_url = base_url, 30 | ) 31 | self.model = model 32 | def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs): 33 | 34 | completion = self.client.chat.completions.create( 35 | model = self.model, 36 | messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},], 37 | temperature = temperature, 38 | **kwargs 39 | ) 40 | return completion.choices[0].message.content 41 | -------------------------------------------------------------------------------- /OpenVoice/text/symbols.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Defines the set of symbols used in text input to the model. 3 | ''' 4 | 5 | # japanese_cleaners 6 | # _pad = '_' 7 | # _punctuation = ',.!?-' 8 | # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ ' 9 | 10 | 11 | '''# japanese_cleaners2 12 | _pad = '_' 13 | _punctuation = ',.!?-~…' 14 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ ' 15 | ''' 16 | 17 | 18 | '''# korean_cleaners 19 | _pad = '_' 20 | _punctuation = ',.!?…~' 21 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ ' 22 | ''' 23 | 24 | '''# chinese_cleaners 25 | _pad = '_' 26 | _punctuation = ',。!?—…' 27 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ ' 28 | ''' 29 | 30 | # # zh_ja_mixture_cleaners 31 | # _pad = '_' 32 | # _punctuation = ',.!?-~…' 33 | # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ ' 34 | 35 | 36 | '''# sanskrit_cleaners 37 | _pad = '_' 38 | _punctuation = '।' 39 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ ' 40 | ''' 41 | 42 | '''# cjks_cleaners 43 | _pad = '_' 44 | _punctuation = ',.!?-~…' 45 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ ' 46 | ''' 47 | 48 | '''# thai_cleaners 49 | _pad = '_' 50 | _punctuation = '.!? 
' 51 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์' 52 | ''' 53 | 54 | # # cjke_cleaners2 55 | _pad = '_' 56 | _punctuation = ',.!?-~…' 57 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ' 58 | 59 | 60 | '''# shanghainese_cleaners 61 | _pad = '_' 62 | _punctuation = ',.!?…' 63 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 ' 64 | ''' 65 | 66 | '''# chinese_dialect_cleaners 67 | _pad = '_' 68 | _punctuation = ',.!?~…─' 69 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ ' 70 | ''' 71 | 72 | # Export all symbols: 73 | symbols = [_pad] + list(_punctuation) + list(_letters) 74 | 75 | # Special symbol ids 76 | SPACE_ID = symbols.index(" ") 77 | 78 | num_ja_tones = 1 79 | num_kr_tones = 1 80 | num_zh_tones = 6 81 | num_en_tones = 4 82 | 83 | language_tone_start_map = { 84 | "ZH": 0, 85 | "JP": num_zh_tones, 86 | "EN": num_zh_tones + num_ja_tones, 87 | 'KR': num_zh_tones + num_ja_tones + num_en_tones, 88 | } -------------------------------------------------------------------------------- /OpenVoice/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from . import cleaners 3 | from .symbols import symbols, language_tone_start_map 4 | 5 | # Mappings from symbol to numeric ID and vice versa: 6 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 7 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 8 | 9 | 10 | def text_to_sequence(text, symbols, cleaner_names): 11 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 12 | Args: 13 | text: string to convert to a sequence 14 | cleaner_names: names of the cleaner functions to run the text through 15 | Returns: 16 | List of integers corresponding to the symbols in the text 17 | ''' 18 | sequence = [] 19 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 20 | clean_text = _clean_text(text, cleaner_names) 21 | print(clean_text) 22 | print(f" length:{len(clean_text)}") 23 | for symbol in clean_text: 24 | if symbol not in symbol_to_id.keys(): 25 | continue 26 | symbol_id = symbol_to_id[symbol] 27 | sequence += [symbol_id] 28 | print(f" length:{len(sequence)}") 29 | return sequence 30 | 31 | 32 | def cleaned_text_to_sequence(cleaned_text, symbols): 33 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 34 | Args: 35 | text: string to convert to a sequence 36 | Returns: 37 | List of integers corresponding to the symbols in the text 38 | ''' 39 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 40 | sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()] 41 | return sequence 42 | 43 | def cleaned_text_to_sequence_vits2(cleaned_text, tones, language, symbols, languages): 44 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
45 | Args: 46 | text: string to convert to a sequence 47 | Returns: 48 | List of integers corresponding to the symbols in the text 49 | """ 50 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 51 | language_id_map = {s: i for i, s in enumerate(languages)} 52 | phones = [symbol_to_id[symbol] for symbol in cleaned_text] 53 | tone_start = language_tone_start_map[language] 54 | tones = [i + tone_start for i in tones] 55 | lang_id = language_id_map[language] 56 | lang_ids = [lang_id for i in phones] 57 | return phones, tones, lang_ids 58 | 59 | 60 | def sequence_to_text(sequence): 61 | '''Converts a sequence of IDs back to a string''' 62 | result = '' 63 | for symbol_id in sequence: 64 | s = _id_to_symbol[symbol_id] 65 | result += s 66 | return result 67 | 68 | 69 | def _clean_text(text, cleaner_names): 70 | for name in cleaner_names: 71 | cleaner = getattr(cleaners, name) 72 | if not cleaner: 73 | raise Exception('Unknown cleaner: %s' % name) 74 | text = cleaner(text) 75 | return text 76 | -------------------------------------------------------------------------------- /ChatTTS/infer/api.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from transformers.generation import TopKLogitsWarper, TopPLogitsWarper 5 | from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat 6 | 7 | def infer_code( 8 | models, 9 | text, 10 | spk_emb = None, 11 | top_P = 0.7, 12 | top_K = 20, 13 | temperature = 0.3, 14 | repetition_penalty = 1.05, 15 | max_new_token = 2048, 16 | **kwargs 17 | ): 18 | 19 | device = next(models['gpt'].parameters()).device 20 | 21 | if not isinstance(text, list): 22 | text = [text] 23 | 24 | if not isinstance(temperature, list): 25 | temperature = [temperature] * models['gpt'].num_vq 26 | 27 | if spk_emb is not None: 28 | text = [f'[Stts][spk_emb]{i}[uv_break][Ptts]' for i in text] 29 | else: 30 | text = [f'[Stts][empty_spk]{i}[uv_break][Ptts]' for i in text] 31 | 32 | text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device) 33 | input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq) 34 | text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device) 35 | 36 | inputs = { 37 | 'input_ids': input_ids, 38 | 'text_mask': text_mask, 39 | 'attention_mask': text_token['attention_mask'], 40 | } 41 | 42 | emb = models['gpt'].get_emb(**inputs) 43 | if spk_emb is not None: 44 | emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \ 45 | F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12) 46 | 47 | num_code = models['gpt'].emb_code[0].num_embeddings - 1 48 | 49 | LogitsWarpers = [] 50 | if top_P is not None: 51 | LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) 52 | if top_K is not None: 53 | LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) 54 | 55 | LogitsProcessors = [] 56 | if repetition_penalty is not None and repetition_penalty != 1: 57 | LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\ 58 | repetition_penalty, num_code, 16)) 59 | 60 | result = models['gpt'].generate( 61 | emb, inputs['input_ids'], 62 | temperature = torch.tensor(temperature, device=device), 63 | attention_mask = inputs['attention_mask'], 64 | LogitsWarpers = LogitsWarpers, 65 | LogitsProcessors = LogitsProcessors, 66 | eos_token = num_code, 67 | max_new_token = 
max_new_token, 68 | infer_text = False, 69 | **kwargs 70 | ) 71 | 72 | return result 73 | 74 | 75 | def refine_text( 76 | models, 77 | text, 78 | top_P = 0.7, 79 | top_K = 20, 80 | temperature = 0.7, 81 | repetition_penalty = 1.0, 82 | max_new_token = 384, 83 | prompt = '', 84 | **kwargs 85 | ): 86 | 87 | device = next(models['gpt'].parameters()).device 88 | 89 | if not isinstance(text, list): 90 | text = [text] 91 | 92 | assert len(text), 'text should not be empty' 93 | 94 | text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text] 95 | text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device) 96 | text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device) 97 | 98 | inputs = { 99 | 'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq), 100 | 'text_mask': text_mask, 101 | 'attention_mask': text_token['attention_mask'], 102 | } 103 | 104 | LogitsWarpers = [] 105 | if top_P is not None: 106 | LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) 107 | if top_K is not None: 108 | LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) 109 | 110 | LogitsProcessors = [] 111 | if repetition_penalty is not None and repetition_penalty != 1: 112 | LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16)) 113 | 114 | result = models['gpt'].generate( 115 | models['gpt'].get_emb(**inputs), inputs['input_ids'], 116 | temperature = torch.tensor([temperature,], device=device), 117 | attention_mask = inputs['attention_mask'], 118 | LogitsWarpers = LogitsWarpers, 119 | LogitsProcessors = LogitsProcessors, 120 | eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None], 121 | max_new_token = max_new_token, 122 | infer_text = True, 123 | **kwargs 124 | ) 125 | return result -------------------------------------------------------------------------------- /OpenVoice/utils/se_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | from glob import glob 5 | import numpy as np 6 | from pydub import AudioSegment 7 | from faster_whisper import WhisperModel 8 | from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments 9 | 10 | model_size = "medium" 11 | # Run on GPU with FP16 12 | model = None 13 | def split_audio_whisper(audio_path, target_dir='processed'): 14 | global model 15 | if model is None: 16 | model = WhisperModel(model_size, device="cuda", compute_type="float16") 17 | audio = AudioSegment.from_file(audio_path) 18 | max_len = len(audio) 19 | 20 | audio_name = os.path.basename(audio_path).rsplit('.', 1)[0] 21 | target_folder = os.path.join(target_dir, audio_name) 22 | 23 | segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True) 24 | segments = list(segments) 25 | 26 | # create directory 27 | os.makedirs(target_folder, exist_ok=True) 28 | wavs_folder = os.path.join(target_folder, 'wavs') 29 | os.makedirs(wavs_folder, exist_ok=True) 30 | 31 | # segments 32 | s_ind = 0 33 | start_time = None 34 | 35 | for k, w in enumerate(segments): 36 | # process with the time 37 | if k == 0: 38 | start_time = max(0, w.start) 39 | 40 | end_time = w.end 41 | 42 | # calculate confidence 43 | if len(w.words) > 0: 44 | confidence = sum([s.probability for s in w.words]) / len(w.words) 45 | else: 46 | confidence = 0. 
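        # `confidence` above is the average word-level probability Whisper
        # assigned to this segment; note that the repo computes it but never
        # uses it afterwards. A hedged sketch of how it could gate unreliable
        # segments (the 0.5 threshold is an assumption, not part of the
        # original code):
        #
        #     if confidence < 0.5:
        #         s_ind = s_ind + 1
        #         continue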
47 | # clean text 48 | text = w.text.replace('...', '') 49 | 50 | # left 0.08s for each audios 51 | audio_seg = audio[int( start_time * 1000) : min(max_len, int(end_time * 1000) + 80)] 52 | 53 | # segment file name 54 | fname = f"{audio_name}_seg{s_ind}.wav" 55 | 56 | # filter out the segment shorter than 1.5s and longer than 20s 57 | save = audio_seg.duration_seconds > 1.5 and \ 58 | audio_seg.duration_seconds < 20. and \ 59 | len(text) >= 2 and len(text) < 200 60 | 61 | if save: 62 | output_file = os.path.join(wavs_folder, fname) 63 | audio_seg.export(output_file, format='wav') 64 | 65 | if k < len(segments) - 1: 66 | start_time = max(0, segments[k+1].start - 0.08) 67 | 68 | s_ind = s_ind + 1 69 | return wavs_folder 70 | 71 | 72 | def split_audio_vad(audio_path, target_dir, split_seconds=10.0): 73 | SAMPLE_RATE = 16000 74 | audio_vad = get_audio_tensor(audio_path) 75 | segments = get_vad_segments( 76 | audio_vad, 77 | output_sample=True, 78 | min_speech_duration=0.1, 79 | min_silence_duration=1, 80 | method="silero", 81 | ) 82 | segments = [(seg["start"], seg["end"]) for seg in segments] 83 | segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s,e in segments] 84 | print(segments) 85 | audio_active = AudioSegment.silent(duration=0) 86 | audio = AudioSegment.from_file(audio_path) 87 | 88 | for start_time, end_time in segments: 89 | audio_active += audio[int( start_time * 1000) : int(end_time * 1000)] 90 | 91 | audio_dur = audio_active.duration_seconds 92 | print(f'after vad: dur = {audio_dur}') 93 | audio_name = os.path.basename(audio_path).rsplit('.', 1)[0] 94 | target_folder = os.path.join(target_dir, audio_name) 95 | wavs_folder = os.path.join(target_folder, 'wavs') 96 | os.makedirs(wavs_folder, exist_ok=True) 97 | start_time = 0. 98 | count = 0 99 | num_splits = int(np.round(audio_dur / split_seconds)) 100 | assert num_splits > 0, 'input audio is too short' 101 | interval = audio_dur / num_splits 102 | 103 | for i in range(num_splits): 104 | end_time = min(start_time + interval, audio_dur) 105 | if i == num_splits - 1: 106 | end_time = audio_dur 107 | output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav" 108 | audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)] 109 | audio_seg.export(output_file, format='wav') 110 | start_time = end_time 111 | count += 1 112 | return wavs_folder 113 | 114 | 115 | 116 | 117 | 118 | def get_se(audio_path, vc_model, target_dir='processed', vad=True): 119 | device = vc_model.device 120 | 121 | audio_name = os.path.basename(audio_path).rsplit('.', 1)[0] 122 | se_path = os.path.join(target_dir, audio_name, 'se.pth') 123 | 124 | if os.path.isfile(se_path): 125 | se = torch.load(se_path).to(device) 126 | return se, audio_name 127 | if os.path.isdir(audio_path): 128 | wavs_folder = audio_path 129 | elif vad: 130 | wavs_folder = split_audio_vad(audio_path, target_dir) 131 | else: 132 | wavs_folder = split_audio_whisper(audio_path, target_dir) 133 | 134 | audio_segs = glob(f'{wavs_folder}/*.wav') 135 | if len(audio_segs) == 0: 136 | raise NotImplementedError('No audio segments found!') 137 | 138 | return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name 139 | 140 | -------------------------------------------------------------------------------- /ChatTTS/model/dvae.py: -------------------------------------------------------------------------------- 1 | import math 2 | from einops import rearrange 3 | from vector_quantize_pytorch import GroupedResidualFSQ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import 
torch.nn.functional as F 8 | 9 | class ConvNeXtBlock(nn.Module): 10 | def __init__( 11 | self, 12 | dim: int, 13 | intermediate_dim: int, 14 | kernel, dilation, 15 | layer_scale_init_value: float = 1e-6, 16 | ): 17 | # ConvNeXt Block copied from Vocos. 18 | super().__init__() 19 | self.dwconv = nn.Conv1d(dim, dim, 20 | kernel_size=kernel, padding=dilation*(kernel//2), 21 | dilation=dilation, groups=dim 22 | ) # depthwise conv 23 | 24 | self.norm = nn.LayerNorm(dim, eps=1e-6) 25 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers 26 | self.act = nn.GELU() 27 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 28 | self.gamma = ( 29 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 30 | if layer_scale_init_value > 0 31 | else None 32 | ) 33 | 34 | def forward(self, x: torch.Tensor, cond = None) -> torch.Tensor: 35 | residual = x 36 | x = self.dwconv(x) 37 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 38 | x = self.norm(x) 39 | x = self.pwconv1(x) 40 | x = self.act(x) 41 | x = self.pwconv2(x) 42 | if self.gamma is not None: 43 | x = self.gamma * x 44 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 45 | 46 | x = residual + x 47 | return x 48 | 49 | 50 | 51 | class GFSQ(nn.Module): 52 | 53 | def __init__(self, 54 | dim, levels, G, R, eps=1e-5, transpose = True 55 | ): 56 | super(GFSQ, self).__init__() 57 | self.quantizer = GroupedResidualFSQ( 58 | dim=dim, 59 | levels=levels, 60 | num_quantizers=R, 61 | groups=G, 62 | ) 63 | self.n_ind = math.prod(levels) 64 | self.eps = eps 65 | self.transpose = transpose 66 | self.G = G 67 | self.R = R 68 | 69 | def _embed(self, x): 70 | if self.transpose: 71 | x = x.transpose(1,2) 72 | x = rearrange( 73 | x, "b t (g r) -> g b t r", g = self.G, r = self.R, 74 | ) 75 | feat = self.quantizer.get_output_from_indices(x) 76 | return feat.transpose(1,2) if self.transpose else feat 77 | 78 | def forward(self, x,): 79 | if self.transpose: 80 | x = x.transpose(1,2) 81 | feat, ind = self.quantizer(x) 82 | ind = rearrange( 83 | ind, "g b t r ->b t (g r)", 84 | ) 85 | embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype) 86 | e_mean = torch.mean(embed_onehot, dim=[0,1]) 87 | e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1) 88 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1)) 89 | 90 | return ( 91 | torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device), 92 | feat.transpose(1,2) if self.transpose else feat, 93 | perplexity, 94 | None, 95 | ind.transpose(1,2) if self.transpose else ind, 96 | ) 97 | 98 | class DVAEDecoder(nn.Module): 99 | def __init__(self, idim, odim, 100 | n_layer = 12, bn_dim = 64, hidden = 256, 101 | kernel = 7, dilation = 2, up = False 102 | ): 103 | super().__init__() 104 | self.up = up 105 | self.conv_in = nn.Sequential( 106 | nn.Conv1d(idim, bn_dim, 3, 1, 1), nn.GELU(), 107 | nn.Conv1d(bn_dim, hidden, 3, 1, 1) 108 | ) 109 | self.decoder_block = nn.ModuleList([ 110 | ConvNeXtBlock(hidden, hidden* 4, kernel, dilation,) 111 | for _ in range(n_layer)]) 112 | self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) 113 | 114 | def forward(self, input, conditioning=None): 115 | # B, T, C 116 | x = input.transpose(1, 2) 117 | x = self.conv_in(x) 118 | for f in self.decoder_block: 119 | x = f(x, conditioning) 120 | 121 | x = self.conv_out(x) 122 | return x.transpose(1, 2) 123 | 124 | 125 | class DVAE(nn.Module): 126 | def __init__( 127 | self, decoder_config, vq_config, dim=512 128 | ): 129 | 
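        # Decoder-only DVAE: `coef` is a stored mel scale (the randn below is
        # only an initial value; the real values come from the checkpoint's
        # state_dict), `decoder` is the ConvNeXt stack defined above, and
        # `out_conv` projects the hidden channels down to the 100 mel bins
        # that the Vocos vocoder consumes.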
super().__init__() 130 | self.register_buffer('coef', torch.randn(1, 100, 1)) 131 | 132 | self.decoder = DVAEDecoder(**decoder_config) 133 | self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False) 134 | if vq_config is not None: 135 | self.vq_layer = GFSQ(**vq_config) 136 | else: 137 | self.vq_layer = None 138 | 139 | def forward(self, inp): 140 | 141 | if self.vq_layer is not None: 142 | vq_feats = self.vq_layer._embed(inp) 143 | else: 144 | vq_feats = inp.detach().clone() 145 | 146 | temp = torch.chunk(vq_feats, 2, dim=1) # flatten trick :) 147 | temp = torch.stack(temp, -1) 148 | vq_feats = temp.reshape(*temp.shape[:2], -1) 149 | 150 | vq_feats = vq_feats.transpose(1, 2) 151 | dec_out = self.decoder(input=vq_feats) 152 | dec_out = self.out_conv(dec_out.transpose(1, 2)) 153 | mel = dec_out * self.coef 154 | 155 | return mel 156 | -------------------------------------------------------------------------------- /OpenVoice/utils/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def init_weights(m, mean=0.0, std=0.01): 7 | classname = m.__class__.__name__ 8 | if classname.find("Conv") != -1: 9 | m.weight.data.normal_(mean, std) 10 | 11 | 12 | def get_padding(kernel_size, dilation=1): 13 | return int((kernel_size * dilation - dilation) / 2) 14 | 15 | 16 | def convert_pad_shape(pad_shape): 17 | layer = pad_shape[::-1] 18 | pad_shape = [item for sublist in layer for item in sublist] 19 | return pad_shape 20 | 21 | 22 | def intersperse(lst, item): 23 | result = [item] * (len(lst) * 2 + 1) 24 | result[1::2] = lst 25 | return result 26 | 27 | 28 | def kl_divergence(m_p, logs_p, m_q, logs_q): 29 | """KL(P||Q)""" 30 | kl = (logs_q - logs_p) - 0.5 31 | kl += ( 32 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 33 | ) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | 57 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 58 | b, d, t = x.size() 59 | if x_lengths is None: 60 | x_lengths = t 61 | ids_str_max = x_lengths - segment_size + 1 62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 63 | ret = slice_segments(x, ids_str, segment_size) 64 | return ret, ids_str 65 | 66 | 67 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 68 | position = torch.arange(length, dtype=torch.float) 69 | num_timescales = channels // 2 70 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 71 | num_timescales - 1 72 | ) 73 | inv_timescales = min_timescale * torch.exp( 74 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 75 | ) 76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 78 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 79 | signal = 
signal.view(1, channels, length) 80 | return signal 81 | 82 | 83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 84 | b, channels, length = x.size() 85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 86 | return x + signal.to(dtype=x.dtype, device=x.device) 87 | 88 | 89 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 90 | b, channels, length = x.size() 91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 93 | 94 | 95 | def subsequent_mask(length): 96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 97 | return mask 98 | 99 | 100 | @torch.jit.script 101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 102 | n_channels_int = n_channels[0] 103 | in_act = input_a + input_b 104 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 106 | acts = t_act * s_act 107 | return acts 108 | 109 | 110 | def convert_pad_shape(pad_shape): 111 | layer = pad_shape[::-1] 112 | pad_shape = [item for sublist in layer for item in sublist] 113 | return pad_shape 114 | 115 | 116 | def shift_1d(x): 117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 118 | return x 119 | 120 | 121 | def sequence_mask(length, max_length=None): 122 | if max_length is None: 123 | max_length = length.max() 124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 125 | return x.unsqueeze(0) < length.unsqueeze(1) 126 | 127 | 128 | def generate_path(duration, mask): 129 | """ 130 | duration: [b, 1, t_x] 131 | mask: [b, 1, t_y, t_x] 132 | """ 133 | 134 | b, _, t_y, t_x = mask.shape 135 | cum_duration = torch.cumsum(duration, -1) 136 | 137 | cum_duration_flat = cum_duration.view(b * t_x) 138 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 139 | path = path.view(b, t_x, t_y) 140 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 141 | path = path.unsqueeze(1).transpose(2, 3) * mask 142 | return path 143 | 144 | 145 | def clip_grad_value_(parameters, clip_value, norm_type=2): 146 | if isinstance(parameters, torch.Tensor): 147 | parameters = [parameters] 148 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 149 | norm_type = float(norm_type) 150 | if clip_value is not None: 151 | clip_value = float(clip_value) 152 | 153 | total_norm = 0 154 | for p in parameters: 155 | param_norm = p.grad.data.norm(norm_type) 156 | total_norm += param_norm.item() ** norm_type 157 | if clip_value is not None: 158 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 159 | total_norm = total_norm ** (1.0 / norm_type) 160 | return total_norm 161 | -------------------------------------------------------------------------------- /OpenVoice/text/english.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. 
"transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | 18 | 19 | import re 20 | import inflect 21 | from unidecode import unidecode 22 | import eng_to_ipa as ipa 23 | _inflect = inflect.engine() 24 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 25 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 26 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 27 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 28 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 29 | _number_re = re.compile(r'[0-9]+') 30 | 31 | # List of (regular expression, replacement) pairs for abbreviations: 32 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 33 | ('mrs', 'misess'), 34 | ('mr', 'mister'), 35 | ('dr', 'doctor'), 36 | ('st', 'saint'), 37 | ('co', 'company'), 38 | ('jr', 'junior'), 39 | ('maj', 'major'), 40 | ('gen', 'general'), 41 | ('drs', 'doctors'), 42 | ('rev', 'reverend'), 43 | ('lt', 'lieutenant'), 44 | ('hon', 'honorable'), 45 | ('sgt', 'sergeant'), 46 | ('capt', 'captain'), 47 | ('esq', 'esquire'), 48 | ('ltd', 'limited'), 49 | ('col', 'colonel'), 50 | ('ft', 'fort'), 51 | ]] 52 | 53 | 54 | # List of (ipa, lazy ipa) pairs: 55 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 56 | ('r', 'ɹ'), 57 | ('æ', 'e'), 58 | ('ɑ', 'a'), 59 | ('ɔ', 'o'), 60 | ('ð', 'z'), 61 | ('θ', 's'), 62 | ('ɛ', 'e'), 63 | ('ɪ', 'i'), 64 | ('ʊ', 'u'), 65 | ('ʒ', 'ʥ'), 66 | ('ʤ', 'ʥ'), 67 | ('ˈ', '↓'), 68 | ]] 69 | 70 | # List of (ipa, lazy ipa2) pairs: 71 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 72 | ('r', 'ɹ'), 73 | ('ð', 'z'), 74 | ('θ', 's'), 75 | ('ʒ', 'ʑ'), 76 | ('ʤ', 'dʑ'), 77 | ('ˈ', '↓'), 78 | ]] 79 | 80 | # List of (ipa, ipa2) pairs 81 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 82 | ('r', 'ɹ'), 83 | ('ʤ', 'dʒ'), 84 | ('ʧ', 'tʃ') 85 | ]] 86 | 87 | 88 | def expand_abbreviations(text): 89 | for regex, replacement in _abbreviations: 90 | text = re.sub(regex, replacement, text) 91 | return text 92 | 93 | 94 | def collapse_whitespace(text): 95 | return re.sub(r'\s+', ' ', text) 96 | 97 | 98 | def _remove_commas(m): 99 | return m.group(1).replace(',', '') 100 | 101 | 102 | def _expand_decimal_point(m): 103 | return m.group(1).replace('.', ' point ') 104 | 105 | 106 | def _expand_dollars(m): 107 | match = m.group(1) 108 | parts = match.split('.') 109 | if len(parts) > 2: 110 | return match + ' dollars' # Unexpected format 111 | dollars = int(parts[0]) if parts[0] else 0 112 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 113 | if dollars and cents: 114 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 115 | cent_unit = 'cent' if cents == 1 else 'cents' 116 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 117 | elif dollars: 118 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 119 | return '%s %s' % (dollars, dollar_unit) 120 | elif cents: 121 | cent_unit = 'cent' if cents == 1 else 'cents' 122 | return '%s %s' % (cents, cent_unit) 123 | else: 124 | return 'zero dollars' 125 | 126 | 127 | def _expand_ordinal(m): 128 | return _inflect.number_to_words(m.group(0)) 129 | 130 | 131 | def _expand_number(m): 132 | num = int(m.group(0)) 133 | if num > 1000 and num < 3000: 134 | if num 
== 2000: 135 | return 'two thousand' 136 | elif num > 2000 and num < 2010: 137 | return 'two thousand ' + _inflect.number_to_words(num % 100) 138 | elif num % 100 == 0: 139 | return _inflect.number_to_words(num // 100) + ' hundred' 140 | else: 141 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 142 | else: 143 | return _inflect.number_to_words(num, andword='') 144 | 145 | 146 | def normalize_numbers(text): 147 | text = re.sub(_comma_number_re, _remove_commas, text) 148 | text = re.sub(_pounds_re, r'\1 pounds', text) 149 | text = re.sub(_dollars_re, _expand_dollars, text) 150 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 151 | text = re.sub(_ordinal_re, _expand_ordinal, text) 152 | text = re.sub(_number_re, _expand_number, text) 153 | return text 154 | 155 | 156 | def mark_dark_l(text): 157 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text) 158 | 159 | 160 | def english_to_ipa(text): 161 | text = unidecode(text).lower() 162 | text = expand_abbreviations(text) 163 | text = normalize_numbers(text) 164 | phonemes = ipa.convert(text) 165 | phonemes = collapse_whitespace(phonemes) 166 | return phonemes 167 | 168 | 169 | def english_to_lazy_ipa(text): 170 | text = english_to_ipa(text) 171 | for regex, replacement in _lazy_ipa: 172 | text = re.sub(regex, replacement, text) 173 | return text 174 | 175 | 176 | def english_to_ipa2(text): 177 | text = english_to_ipa(text) 178 | text = mark_dark_l(text) 179 | for regex, replacement in _ipa_to_ipa2: 180 | text = re.sub(regex, replacement, text) 181 | return text.replace('...', '…') 182 | 183 | 184 | def english_to_lazy_ipa2(text): 185 | text = english_to_ipa(text) 186 | for regex, replacement in _lazy_ipa2: 187 | text = re.sub(regex, replacement, text) 188 | return text 189 | -------------------------------------------------------------------------------- /OpenVoice/utils/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import numpy as np 4 | 5 | 6 | def get_hparams_from_file(config_path): 7 | with open(config_path, "r", encoding="utf-8") as f: 8 | data = f.read() 9 | config = json.loads(data) 10 | 11 | hparams = HParams(**config) 12 | return hparams 13 | 14 | class HParams: 15 | def __init__(self, **kwargs): 16 | for k, v in kwargs.items(): 17 | if type(v) == dict: 18 | v = HParams(**v) 19 | self[k] = v 20 | 21 | def keys(self): 22 | return self.__dict__.keys() 23 | 24 | def items(self): 25 | return self.__dict__.items() 26 | 27 | def values(self): 28 | return self.__dict__.values() 29 | 30 | def __len__(self): 31 | return len(self.__dict__) 32 | 33 | def __getitem__(self, key): 34 | return getattr(self, key) 35 | 36 | def __setitem__(self, key, value): 37 | return setattr(self, key, value) 38 | 39 | def __contains__(self, key): 40 | return key in self.__dict__ 41 | 42 | def __repr__(self): 43 | return self.__dict__.__repr__() 44 | 45 | 46 | def string_to_bits(string, pad_len=8): 47 | # Convert each character to its ASCII value 48 | ascii_values = [ord(char) for char in string] 49 | 50 | # Convert ASCII values to binary representation 51 | binary_values = [bin(value)[2:].zfill(8) for value in ascii_values] 52 | 53 | # Convert binary strings to integer arrays 54 | bit_arrays = [[int(bit) for bit in binary] for binary in binary_values] 55 | 56 | # Convert list of arrays to NumPy array 57 | numpy_array = np.array(bit_arrays) 58 | numpy_array_full = np.zeros((pad_len, 8), 
dtype=numpy_array.dtype) 59 | numpy_array_full[:, 2] = 1 60 | max_len = min(pad_len, len(numpy_array)) 61 | numpy_array_full[:max_len] = numpy_array[:max_len] 62 | return numpy_array_full 63 | 64 | 65 | def bits_to_string(bits_array): 66 | # Convert each row of the array to a binary string 67 | binary_values = [''.join(str(bit) for bit in row) for row in bits_array] 68 | 69 | # Convert binary strings to ASCII values 70 | ascii_values = [int(binary, 2) for binary in binary_values] 71 | 72 | # Convert ASCII values to characters 73 | output_string = ''.join(chr(value) for value in ascii_values) 74 | 75 | return output_string 76 | 77 | 78 | def split_sentence(text, min_len=10, language_str='[EN]'): 79 | if language_str in ['EN']: 80 | sentences = split_sentences_latin(text, min_len=min_len) 81 | else: 82 | sentences = split_sentences_zh(text, min_len=min_len) 83 | return sentences 84 | 85 | def split_sentences_latin(text, min_len=10): 86 | """Split Long sentences into list of short ones 87 | 88 | Args: 89 | str: Input sentences. 90 | 91 | Returns: 92 | List[str]: list of output sentences. 93 | """ 94 | # deal with dirty sentences 95 | text = re.sub('[。!?;]', '.', text) 96 | text = re.sub('[,]', ',', text) 97 | text = re.sub('[“”]', '"', text) 98 | text = re.sub('[‘’]', "'", text) 99 | text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) 100 | text = re.sub('[\n\t ]+', ' ', text) 101 | text = re.sub('([,.!?;])', r'\1 $#!', text) 102 | # split 103 | sentences = [s.strip() for s in text.split('$#!')] 104 | if len(sentences[-1]) == 0: del sentences[-1] 105 | 106 | new_sentences = [] 107 | new_sent = [] 108 | count_len = 0 109 | for ind, sent in enumerate(sentences): 110 | # print(sent) 111 | new_sent.append(sent) 112 | count_len += len(sent.split(" ")) 113 | if count_len > min_len or ind == len(sentences) - 1: 114 | count_len = 0 115 | new_sentences.append(' '.join(new_sent)) 116 | new_sent = [] 117 | return merge_short_sentences_latin(new_sentences) 118 | 119 | 120 | def merge_short_sentences_latin(sens): 121 | """Avoid short sentences by merging them with the following sentence. 122 | 123 | Args: 124 | List[str]: list of input sentences. 125 | 126 | Returns: 127 | List[str]: list of output sentences. 128 | """ 129 | sens_out = [] 130 | for s in sens: 131 | # If the previous sentense is too short, merge them with 132 | # the current sentence. 
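        # Worked example (values are illustrative, not from the repo): with
        # sens = ["Oh.", "I was not expecting that."], the head "Oh." is two
        # words or fewer, so the loop yields ["Oh. I was not expecting that."];
        # the try-block below applies one more merge if the final sentence is
        # still too short.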
133 | if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2: 134 | sens_out[-1] = sens_out[-1] + " " + s 135 | else: 136 | sens_out.append(s) 137 | try: 138 | if len(sens_out[-1].split(" ")) <= 2: 139 | sens_out[-2] = sens_out[-2] + " " + sens_out[-1] 140 | sens_out.pop(-1) 141 | except: 142 | pass 143 | return sens_out 144 | 145 | def split_sentences_zh(text, min_len=10): 146 | text = re.sub('[。!?;]', '.', text) 147 | text = re.sub('[,]', ',', text) 148 | # 将文本中的换行符、空格和制表符替换为空格 149 | text = re.sub('[\n\t ]+', ' ', text) 150 | # 在标点符号后添加一个空格 151 | text = re.sub('([,.!?;])', r'\1 $#!', text) 152 | # 分隔句子并去除前后空格 153 | # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)] 154 | sentences = [s.strip() for s in text.split('$#!')] 155 | if len(sentences[-1]) == 0: del sentences[-1] 156 | 157 | new_sentences = [] 158 | new_sent = [] 159 | count_len = 0 160 | for ind, sent in enumerate(sentences): 161 | new_sent.append(sent) 162 | count_len += len(sent) 163 | if count_len > min_len or ind == len(sentences) - 1: 164 | count_len = 0 165 | new_sentences.append(' '.join(new_sent)) 166 | new_sent = [] 167 | return merge_short_sentences_zh(new_sentences) 168 | 169 | 170 | def merge_short_sentences_zh(sens): 171 | # return sens 172 | """Avoid short sentences by merging them with the following sentence. 173 | 174 | Args: 175 | List[str]: list of input sentences. 176 | 177 | Returns: 178 | List[str]: list of output sentences. 179 | """ 180 | sens_out = [] 181 | for s in sens: 182 | # If the previous sentense is too short, merge them with 183 | # the current sentence. 184 | if len(sens_out) > 0 and len(sens_out[-1]) <= 2: 185 | sens_out[-1] = sens_out[-1] + " " + s 186 | else: 187 | sens_out.append(s) 188 | try: 189 | if len(sens_out[-1]) <= 2: 190 | sens_out[-2] = sens_out[-2] + " " + sens_out[-1] 191 | sens_out.pop(-1) 192 | except: 193 | pass 194 | return sens_out -------------------------------------------------------------------------------- /ChatTTS/core.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import logging 4 | from omegaconf import OmegaConf 5 | 6 | import torch 7 | from vocos import Vocos 8 | from .model.dvae import DVAE 9 | from .model.gpt import GPT_warpper 10 | from .utils.gpu_utils import select_device 11 | from .utils.io_utils import get_latest_modified_file 12 | from .infer.api import refine_text, infer_code 13 | 14 | from huggingface_hub import snapshot_download 15 | 16 | logging.basicConfig(level = logging.INFO) 17 | 18 | 19 | class Chat: 20 | def __init__(self, ): 21 | self.pretrain_models = {} 22 | self.logger = logging.getLogger(__name__) 23 | 24 | def check_model(self, level = logging.INFO, use_decoder = False): 25 | not_finish = False 26 | check_list = ['vocos', 'gpt', 'tokenizer'] 27 | 28 | if use_decoder: 29 | check_list.append('decoder') 30 | else: 31 | check_list.append('dvae') 32 | 33 | for module in check_list: 34 | if module not in self.pretrain_models: 35 | self.logger.log(logging.WARNING, f'{module} not initialized.') 36 | not_finish = True 37 | 38 | if not not_finish: 39 | self.logger.log(level, f'All initialized.') 40 | 41 | return not not_finish 42 | 43 | def load_models(self, source='huggingface', force_redownload=False, local_path=''): 44 | if source == 'huggingface': 45 | hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface")) 46 | try: 47 | download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots')) 48 | except: 49 | 
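                # cache lookup failed (e.g. the snapshot folder does not exist
                # yet on a first run), so fall back to downloading from the Hub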
download_path = None 50 | if download_path is None or force_redownload: 51 | self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS') 52 | download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"]) 53 | else: 54 | self.logger.log(logging.INFO, f'Load from cache: {download_path}') 55 | self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}) 56 | elif source == 'local': 57 | self.logger.log(logging.INFO, f'Load from local: {local_path}') 58 | self._load(**{k: os.path.join(local_path, v) for k, v in OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()}) 59 | 60 | def _load( 61 | self, 62 | vocos_config_path: str = None, 63 | vocos_ckpt_path: str = None, 64 | dvae_config_path: str = None, 65 | dvae_ckpt_path: str = None, 66 | gpt_config_path: str = None, 67 | gpt_ckpt_path: str = None, 68 | decoder_config_path: str = None, 69 | decoder_ckpt_path: str = None, 70 | tokenizer_path: str = None, 71 | device: str = None 72 | ): 73 | if not device: 74 | device = select_device(4096) 75 | self.logger.log(logging.INFO, f'use {device}') 76 | 77 | if vocos_config_path: 78 | vocos = Vocos.from_hparams(vocos_config_path).to(device).eval() 79 | assert vocos_ckpt_path, 'vocos_ckpt_path should not be None' 80 | vocos.load_state_dict(torch.load(vocos_ckpt_path)) 81 | self.pretrain_models['vocos'] = vocos 82 | self.logger.log(logging.INFO, 'vocos loaded.') 83 | 84 | if dvae_config_path: 85 | cfg = OmegaConf.load(dvae_config_path) 86 | dvae = DVAE(**cfg).to(device).eval() 87 | assert dvae_ckpt_path, 'dvae_ckpt_path should not be None' 88 | dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu')) 89 | self.pretrain_models['dvae'] = dvae 90 | self.logger.log(logging.INFO, 'dvae loaded.') 91 | 92 | if gpt_config_path: 93 | cfg = OmegaConf.load(gpt_config_path) 94 | gpt = GPT_warpper(**cfg).to(device).eval() 95 | assert gpt_ckpt_path, 'gpt_ckpt_path should not be None' 96 | gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu')) 97 | self.pretrain_models['gpt'] = gpt 98 | self.logger.log(logging.INFO, 'gpt loaded.') 99 | 100 | if decoder_config_path: 101 | cfg = OmegaConf.load(decoder_config_path) 102 | decoder = DVAE(**cfg).to(device).eval() 103 | assert decoder_ckpt_path, 'decoder_ckpt_path should not be None' 104 | decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu')) 105 | self.pretrain_models['decoder'] = decoder 106 | self.logger.log(logging.INFO, 'decoder loaded.') 107 | 108 | if tokenizer_path: 109 | tokenizer = torch.load(tokenizer_path, map_location='cpu') 110 | tokenizer.padding_side = 'left' 111 | self.pretrain_models['tokenizer'] = tokenizer 112 | self.logger.log(logging.INFO, 'tokenizer loaded.') 113 | 114 | self.check_model() 115 | 116 | def infer( 117 | self, 118 | text, 119 | skip_refine_text=False, 120 | refine_text_only=False, 121 | params_refine_text={}, 122 | params_infer_code={}, 123 | use_decoder=False 124 | ): 125 | 126 | assert self.check_model(use_decoder=use_decoder) 127 | 128 | if not skip_refine_text: 129 | text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids'] 130 | text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens] 131 | text = self.pretrain_models['tokenizer'].batch_decode(text_tokens) 132 | if refine_text_only: 133 | return text 134 | 135 | text = [params_infer_code.get('prompt', '') + i 
-------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # import spaces 2 | import os 3 | import random 4 | import argparse 5 | 6 | import torch 7 | import gradio as gr 8 | import numpy as np 9 | 10 | import ChatTTS 11 | 12 | from OpenVoice.utils import se_extractor 13 | from OpenVoice.api import ToneColorConverter 14 | import soundfile 15 | 16 | print("loading ChatTTS model...") 17 | chat = ChatTTS.Chat() 18 | chat.load_models() 19 | 20 | 21 | def generate_seed(): 22 | new_seed = random.randint(1, 100000000) 23 | return { 24 | "__type__": "update", 25 | "value": new_seed 26 | } 27 | 28 | # @spaces.GPU 29 | def chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None): 30 | 31 | torch.manual_seed(audio_seed_input) 32 | rand_spk = torch.randn(768) 33 | params_infer_code = { 34 | 'spk_emb': rand_spk, 35 | 'temperature': temperature, 36 | 'top_P': top_P, 37 | 'top_K': top_K, 38 | } 39 | params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'} 40 | 41 | torch.manual_seed(text_seed_input) 42 | 43 | if refine_text_flag: 44 | if refine_text_input: 45 | params_refine_text['prompt'] = refine_text_input 46 | text = chat.infer(text, 47 | skip_refine_text=False, 48 | refine_text_only=True, 49 | params_refine_text=params_refine_text, 50 | params_infer_code=params_infer_code 51 | ) 52 | print("Text has been refined!") 53 | 54 | wav = chat.infer(text, 55 | skip_refine_text=True, 56 | params_refine_text=params_refine_text, 57 | params_infer_code=params_infer_code 58 | ) 59 | 60 | audio_data = np.array(wav[0]).flatten() 61 | sample_rate = 24000 62 | text_data = text[0] if isinstance(text, list) else text 63 | 64 | if output_path is None: 65 | return [(sample_rate, audio_data), text_data] 66 | else: 67 | soundfile.write(output_path, audio_data, sample_rate) 68 | return text_data 69 | 70 | # OpenVoice Clone 71 | ckpt_converter = 'OpenVoice/checkpoints/converter' 72 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 73 | 74 | tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device) 75 | tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth') 76 | 77 | def generate_audio(text, audio_ref, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input): 78 | save_path = "output.wav" 79 | 80 | if audio_ref:  # gr.Audio yields None (not "") when no reference clip is set 81 | # Run the base speaker tts 82 | src_path = "tmp.wav" 83 | text_data = chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, src_path) 84 | print("Ready for voice cloning!") 85 | 86 | source_se, audio_name = se_extractor.get_se(src_path, tone_color_converter, target_dir='processed', vad=True) 87 | reference_speaker = audio_ref 88 | target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True) 89 | 90 | print("Speaker embeddings extracted!") 91 | 92 | # Run the tone color converter 93 | # convert from file 94 | tone_color_converter.convert( 95 | audio_src_path=src_path, 96 | src_se=source_se, 97 | tgt_se=target_se, 98 | output_path=save_path) 99 | else: 100 | text_data = chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, save_path)  # capture the refined text so text_data is always bound for the return below 101 | 102 | print("Finished!") 103 | 104 | return [save_path, text_data] 105 | 106 |
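Since generate_audio() wires the whole pipeline together (ChatTTS synthesis, speaker-embedding extraction, tone-color conversion), a quick non-UI smoke test can be sketched as follows. Illustrative only: it assumes the converter checkpoints are in place and that a reference clip exists at Examples/speaker.mp3, the default used by the UI below.

    # Hypothetical smoke test; not part of app.py.
    out_path, refined_text = generate_audio(
        "Hello from ChatTTS and OpenVoice!",
        audio_ref="Examples/speaker.mp3",   # any short wav/mp3 of the target voice
        temperature=0.3, top_P=0.7, top_K=20,
        audio_seed_input=42, text_seed_input=42,
        refine_text_flag=True, refine_text_input="[oral_2][laugh_0][break_6]",
    )
    print(out_path, refined_text)           # -> output.wav plus the refined transcript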
107 | with gr.Blocks() as demo: 108 | gr.Markdown("# 🥳 ChatTTS x OpenVoice 🥳") 109 | gr.Markdown("## 🌟 Make it sound super natural and switch it up to any voice you want, nailing the mood and tone! 🌟
") 110 | 111 | default_text = "Today a man knocked on my door and asked for a small donation toward the local swimming pool. I gave him a glass of water." 112 | text_input = gr.Textbox(label="Input Text", lines=4, placeholder="Please Input Text...", value=default_text) 113 | 114 | 115 | default_refine_text = "[oral_2][laugh_0][break_6]" 116 | refine_text_checkbox = gr.Checkbox(label="Refine text", info="'oral' means add filler words, 'laugh' means add laughter, and 'break' means add a pause. (0-10) ", value=True) 117 | refine_text_input = gr.Textbox(label="Refine Prompt", lines=1, placeholder="Please Refine Prompt...", value=default_refine_text) 118 | with gr.Column(): 119 | voice_ref = gr.Audio(label="Reference Audio", type="filepath", value="Examples/speaker.mp3") 120 | 121 | with gr.Row(): 122 | temperature_slider = gr.Slider(minimum=0.00001, maximum=1.0, step=0.00001, value=0.3, label="Audio temperature") 123 | top_p_slider = gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.7, label="top_P") 124 | top_k_slider = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_K") 125 | 126 | with gr.Row(): 127 | audio_seed_input = gr.Number(value=42, label="Speaker Seed") 128 | generate_audio_seed = gr.Button("\U0001F3B2") 129 | text_seed_input = gr.Number(value=42, label="Text Seed") 130 | generate_text_seed = gr.Button("\U0001F3B2") 131 | 132 | generate_button = gr.Button("Generate") 133 | 134 | text_output = gr.Textbox(label="Refined Text", interactive=False) 135 | audio_output = gr.Audio(label="Output Audio") 136 | 137 | generate_audio_seed.click(generate_seed, 138 | inputs=[], 139 | outputs=audio_seed_input) 140 | 141 | generate_text_seed.click(generate_seed, 142 | inputs=[], 143 | outputs=text_seed_input) 144 | 145 | generate_button.click(generate_audio, 146 | inputs=[text_input, voice_ref, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input], 147 | outputs=[audio_output,text_output]) 148 | 149 | parser = argparse.ArgumentParser(description='ChatTTS-OpenVoice Launch') 150 | parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name') 151 | parser.add_argument('--server_port', type=int, default=8080, help='Server port') 152 | args = parser.parse_args() 153 | 154 | # demo.launch(server_name=args.server_name, server_port=args.server_port, inbrowser=True) 155 | 156 | if __name__ == '__main__': 157 | demo.launch() 158 | -------------------------------------------------------------------------------- /OpenVoice/utils/mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from librosa.filters import mel as librosa_mel_fn 4 | 5 | MAX_WAV_VALUE = 32768.0 6 | 7 | 8 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 9 | """ 10 | PARAMS 11 | ------ 12 | C: compression factor 13 | """ 14 | return torch.log(torch.clamp(x, min=clip_val) * C) 15 | 16 | 17 | def dynamic_range_decompression_torch(x, C=1): 18 | """ 19 | PARAMS 20 | ------ 21 | C: compression factor used to compress 22 | """ 23 | return torch.exp(x) / C 24 | 25 | 26 | def spectral_normalize_torch(magnitudes): 27 | output = dynamic_range_compression_torch(magnitudes) 28 | return output 29 | 30 | 31 | def spectral_de_normalize_torch(magnitudes): 32 | output = dynamic_range_decompression_torch(magnitudes) 33 | return output 34 | 35 | 36 | mel_basis = {} 37 | hann_window = {} 38 | 39 | 40 | def spectrogram_torch(y, n_fft, sampling_rate, 
-------------------------------------------------------------------------------- /OpenVoice/utils/mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import librosa  # replaces the unused 'import torch.utils.data'; spectrogram_torch_conv below uses librosa.util 3 | from librosa.filters import mel as librosa_mel_fn 4 | 5 | MAX_WAV_VALUE = 32768.0 6 | 7 | 8 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 9 | """ 10 | PARAMS 11 | ------ 12 | C: compression factor 13 | """ 14 | return torch.log(torch.clamp(x, min=clip_val) * C) 15 | 16 | 17 | def dynamic_range_decompression_torch(x, C=1): 18 | """ 19 | PARAMS 20 | ------ 21 | C: compression factor used to compress 22 | """ 23 | return torch.exp(x) / C 24 | 25 | 26 | def spectral_normalize_torch(magnitudes): 27 | output = dynamic_range_compression_torch(magnitudes) 28 | return output 29 | 30 | 31 | def spectral_de_normalize_torch(magnitudes): 32 | output = dynamic_range_decompression_torch(magnitudes) 33 | return output 34 | 35 | 36 | mel_basis = {} 37 | hann_window = {} 38 | 39 | 40 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 41 | if torch.min(y) < -1.1: 42 | print("min value is ", torch.min(y)) 43 | if torch.max(y) > 1.1: 44 | print("max value is ", torch.max(y)) 45 | 46 | global hann_window 47 | dtype_device = str(y.dtype) + "_" + str(y.device) 48 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 49 | if wnsize_dtype_device not in hann_window: 50 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 51 | dtype=y.dtype, device=y.device 52 | ) 53 | 54 | y = torch.nn.functional.pad( 55 | y.unsqueeze(1), 56 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 57 | mode="reflect", 58 | ) 59 | y = y.squeeze(1) 60 | 61 | spec = torch.stft( 62 | y, 63 | n_fft, 64 | hop_length=hop_size, 65 | win_length=win_size, 66 | window=hann_window[wnsize_dtype_device], 67 | center=center, 68 | pad_mode="reflect", 69 | normalized=False, 70 | onesided=True, 71 | return_complex=False, 72 | ) 73 | 74 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 75 | return spec 76 | 77 | 78 | def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False): 79 | # if torch.min(y) < -1.: 80 | # print('min value is ', torch.min(y)) 81 | # if torch.max(y) > 1.: 82 | # print('max value is ', torch.max(y)) 83 | 84 | global hann_window 85 | dtype_device = str(y.dtype) + '_' + str(y.device) 86 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 87 | if wnsize_dtype_device not in hann_window: 88 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 89 | 90 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 91 | 92 | # ******************** original ************************# 93 | # y = y.squeeze(1) 94 | # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 95 | # center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 96 | 97 | # ******************** ConvSTFT ************************# 98 | freq_cutoff = n_fft // 2 + 1 99 | fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft))) 100 | forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1]) 101 | forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float() 102 | 103 | import torch.nn.functional as F 104 | 105 | # if center: 106 | # signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1) 107 | assert center is False 108 | 109 | forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size) 110 | spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1) 111 | 112 | 113 | # ******************** Verification ************************# 114 | spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 115 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 116 | assert torch.allclose(spec1, spec2, atol=1e-4) 117 | 118 | spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6) 119 | return spec 120 | 121 | 122 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 123 | global mel_basis 124 | dtype_device = str(spec.dtype) + "_" + str(spec.device) 125 | fmax_dtype_device = str(fmax) + "_" + dtype_device 126 | if fmax_dtype_device not in mel_basis: 127 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)  # keyword arguments keep this working on librosa >= 0.10 128 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 129 | dtype=spec.dtype, device=spec.device 130 | ) 131 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 132 | spec = spectral_normalize_torch(spec) 133 | return spec 134 | 135 | 136 | def mel_spectrogram_torch( 137 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 138 | ): 139 | if torch.min(y) < -1.0: 140 | print("min value is ", torch.min(y)) 141 | if torch.max(y) > 1.0: 142 | print("max value is ", torch.max(y)) 143 | 144 | global mel_basis, hann_window 145 | dtype_device = str(y.dtype) + "_" + str(y.device) 146 | fmax_dtype_device = str(fmax) + "_" + dtype_device 147 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 148 | if fmax_dtype_device not in mel_basis: 149 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 150 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 151 | dtype=y.dtype, device=y.device 152 | ) 153 | if wnsize_dtype_device not in hann_window: 154 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 155 | dtype=y.dtype, device=y.device 156 | ) 157 | 158 | y = torch.nn.functional.pad( 159 | y.unsqueeze(1), 160 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 161 | mode="reflect", 162 | ) 163 | y = y.squeeze(1) 164 | 165 | spec = torch.stft( 166 | y, 167 | n_fft, 168 | hop_length=hop_size, 169 | win_length=win_size, 170 | window=hann_window[wnsize_dtype_device], 171 | center=center, 172 | pad_mode="reflect", 173 | normalized=False, 174 | onesided=True, 175 | return_complex=False, 176 | ) 177 | 178 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 179 | 180 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 181 | spec = spectral_normalize_torch(spec) 182 | 183 | return spec
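A short usage sketch for the helpers above. Illustrative only: the 22050/1024/256/80 values are assumed, typical VITS-style settings, not values read from this repository's (placeholder) converter config; the real calls in OpenVoice/api.py take them from its hparams.

    import torch

    y = torch.rand(1, 22050) * 2 - 1  # one second of batch-of-one audio in [-1, 1)
    spec = spectrogram_torch(y, n_fft=1024, sampling_rate=22050,
                             hop_size=256, win_size=1024, center=False)
    mel = spec_to_mel_torch(spec, n_fft=1024, num_mels=80,
                            sampling_rate=22050, fmin=0.0, fmax=None)
    print(spec.shape, mel.shape)  # (1, 513, frames) and (1, 80, frames)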
-------------------------------------------------------------------------------- /OpenVoice/utils/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform( 13 | inputs, 14 | unnormalized_widths, 15 | unnormalized_heights, 16 | unnormalized_derivatives, 17 | inverse=False, 18 | tails=None, 19 | tail_bound=1.0, 20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 21 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 22 | min_derivative=DEFAULT_MIN_DERIVATIVE, 23 | ): 24 | if tails is None: 25 | spline_fn = rational_quadratic_spline 26 | spline_kwargs = {} 27 | else: 28 | spline_fn = unconstrained_rational_quadratic_spline 29 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound} 30 | 31 | outputs, logabsdet = spline_fn( 32 | inputs=inputs, 33 | unnormalized_widths=unnormalized_widths, 34 | unnormalized_heights=unnormalized_heights, 35 | unnormalized_derivatives=unnormalized_derivatives, 36 | inverse=inverse, 37 | min_bin_width=min_bin_width, 38 | min_bin_height=min_bin_height, 39 | min_derivative=min_derivative, 40 | **spline_kwargs 41 | ) 42 | return outputs, logabsdet 43 | 44 | 45 | def searchsorted(bin_locations, inputs, eps=1e-6): 46 | bin_locations[..., -1] += eps 47 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 48 | 49 | 50 | def unconstrained_rational_quadratic_spline( 51 | inputs, 52 | unnormalized_widths, 53 | unnormalized_heights, 54 | unnormalized_derivatives, 55 |
inverse=False, 56 | tails="linear", 57 | tail_bound=1.0, 58 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 59 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 60 | min_derivative=DEFAULT_MIN_DERIVATIVE, 61 | ): 62 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 63 | outside_interval_mask = ~inside_interval_mask 64 | 65 | outputs = torch.zeros_like(inputs) 66 | logabsdet = torch.zeros_like(inputs) 67 | 68 | if tails == "linear": 69 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 70 | constant = np.log(np.exp(1 - min_derivative) - 1) 71 | unnormalized_derivatives[..., 0] = constant 72 | unnormalized_derivatives[..., -1] = constant 73 | 74 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 75 | logabsdet[outside_interval_mask] = 0 76 | else: 77 | raise RuntimeError("{} tails are not implemented.".format(tails)) 78 | 79 | ( 80 | outputs[inside_interval_mask], 81 | logabsdet[inside_interval_mask], 82 | ) = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, 89 | right=tail_bound, 90 | bottom=-tail_bound, 91 | top=tail_bound, 92 | min_bin_width=min_bin_width, 93 | min_bin_height=min_bin_height, 94 | min_derivative=min_derivative, 95 | ) 96 | 97 | return outputs, logabsdet 98 | 99 | 100 | def rational_quadratic_spline( 101 | inputs, 102 | unnormalized_widths, 103 | unnormalized_heights, 104 | unnormalized_derivatives, 105 | inverse=False, 106 | left=0.0, 107 | right=1.0, 108 | bottom=0.0, 109 | top=1.0, 110 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 111 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 112 | min_derivative=DEFAULT_MIN_DERIVATIVE, 113 | ): 114 | if torch.min(inputs) < left or torch.max(inputs) > right: 115 | raise ValueError("Input to a transform is not within its domain") 116 | 117 | num_bins = unnormalized_widths.shape[-1] 118 | 119 | if min_bin_width * num_bins > 1.0: 120 | raise ValueError("Minimal bin width too large for the number of bins") 121 | if min_bin_height * num_bins > 1.0: 122 | raise ValueError("Minimal bin height too large for the number of bins") 123 | 124 | widths = F.softmax(unnormalized_widths, dim=-1) 125 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 126 | cumwidths = torch.cumsum(widths, dim=-1) 127 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) 128 | cumwidths = (right - left) * cumwidths + left 129 | cumwidths[..., 0] = left 130 | cumwidths[..., -1] = right 131 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 132 | 133 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 134 | 135 | heights = F.softmax(unnormalized_heights, dim=-1) 136 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 137 | cumheights = torch.cumsum(heights, dim=-1) 138 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) 139 | cumheights = (top - bottom) * cumheights + bottom 140 | cumheights[..., 0] = bottom 141 | cumheights[..., -1] = top 142 | heights = cumheights[..., 1:] - cumheights[..., :-1] 143 | 144 | if inverse: 145 | bin_idx = searchsorted(cumheights, inputs)[..., None] 146 | else: 147 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 148 | 149 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 150 | input_bin_widths = widths.gather(-1, 
bin_idx)[..., 0] 151 | 152 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 153 | delta = heights / widths 154 | input_delta = delta.gather(-1, bin_idx)[..., 0] 155 | 156 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 157 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 158 | 159 | input_heights = heights.gather(-1, bin_idx)[..., 0] 160 | 161 | if inverse: 162 | a = (inputs - input_cumheights) * ( 163 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 164 | ) + input_heights * (input_delta - input_derivatives) 165 | b = input_heights * input_derivatives - (inputs - input_cumheights) * ( 166 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 167 | ) 168 | c = -input_delta * (inputs - input_cumheights) 169 | 170 | discriminant = b.pow(2) - 4 * a * c 171 | assert (discriminant >= 0).all() 172 | 173 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 174 | outputs = root * input_bin_widths + input_cumwidths 175 | 176 | theta_one_minus_theta = root * (1 - root) 177 | denominator = input_delta + ( 178 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 179 | * theta_one_minus_theta 180 | ) 181 | derivative_numerator = input_delta.pow(2) * ( 182 | input_derivatives_plus_one * root.pow(2) 183 | + 2 * input_delta * theta_one_minus_theta 184 | + input_derivatives * (1 - root).pow(2) 185 | ) 186 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 187 | 188 | return outputs, -logabsdet 189 | else: 190 | theta = (inputs - input_cumwidths) / input_bin_widths 191 | theta_one_minus_theta = theta * (1 - theta) 192 | 193 | numerator = input_heights * ( 194 | input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta 195 | ) 196 | denominator = input_delta + ( 197 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 198 | * theta_one_minus_theta 199 | ) 200 | outputs = input_cumheights + numerator / denominator 201 | 202 | derivative_numerator = input_delta.pow(2) * ( 203 | input_derivatives_plus_one * theta.pow(2) 204 | + 2 * input_delta * theta_one_minus_theta 205 | + input_derivatives * (1 - theta).pow(2) 206 | ) 207 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 208 | 209 | return outputs, logabsdet 210 | -------------------------------------------------------------------------------- /OpenVoice/text/mandarin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from pypinyin import lazy_pinyin, BOPOMOFO 5 | import jieba 6 | import cn2an 7 | import logging 8 | 9 | 10 | # List of (Latin alphabet, bopomofo) pairs: 11 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 12 | ('a', 'ㄟˉ'), 13 | ('b', 'ㄅㄧˋ'), 14 | ('c', 'ㄙㄧˉ'), 15 | ('d', 'ㄉㄧˋ'), 16 | ('e', 'ㄧˋ'), 17 | ('f', 'ㄝˊㄈㄨˋ'), 18 | ('g', 'ㄐㄧˋ'), 19 | ('h', 'ㄝˇㄑㄩˋ'), 20 | ('i', 'ㄞˋ'), 21 | ('j', 'ㄐㄟˋ'), 22 | ('k', 'ㄎㄟˋ'), 23 | ('l', 'ㄝˊㄛˋ'), 24 | ('m', 'ㄝˊㄇㄨˋ'), 25 | ('n', 'ㄣˉ'), 26 | ('o', 'ㄡˉ'), 27 | ('p', 'ㄆㄧˉ'), 28 | ('q', 'ㄎㄧㄡˉ'), 29 | ('r', 'ㄚˋ'), 30 | ('s', 'ㄝˊㄙˋ'), 31 | ('t', 'ㄊㄧˋ'), 32 | ('u', 'ㄧㄡˉ'), 33 | ('v', 'ㄨㄧˉ'), 34 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), 35 | ('x', 'ㄝˉㄎㄨˋㄙˋ'), 36 | ('y', 'ㄨㄞˋ'), 37 | ('z', 'ㄗㄟˋ') 38 | ]] 39 | 40 | # List of (bopomofo, romaji) pairs: 41 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [ 42 | ('ㄅㄛ', 'p⁼wo'), 43 | ('ㄆㄛ', 'pʰwo'), 44 | ('ㄇㄛ', 'mwo'), 45 | ('ㄈㄛ', 'fwo'), 46 | ('ㄅ', 'p⁼'), 47 | ('ㄆ', 'pʰ'), 48 | ('ㄇ', 'm'), 
49 | ('ㄈ', 'f'), 50 | ('ㄉ', 't⁼'), 51 | ('ㄊ', 'tʰ'), 52 | ('ㄋ', 'n'), 53 | ('ㄌ', 'l'), 54 | ('ㄍ', 'k⁼'), 55 | ('ㄎ', 'kʰ'), 56 | ('ㄏ', 'h'), 57 | ('ㄐ', 'ʧ⁼'), 58 | ('ㄑ', 'ʧʰ'), 59 | ('ㄒ', 'ʃ'), 60 | ('ㄓ', 'ʦ`⁼'), 61 | ('ㄔ', 'ʦ`ʰ'), 62 | ('ㄕ', 's`'), 63 | ('ㄖ', 'ɹ`'), 64 | ('ㄗ', 'ʦ⁼'), 65 | ('ㄘ', 'ʦʰ'), 66 | ('ㄙ', 's'), 67 | ('ㄚ', 'a'), 68 | ('ㄛ', 'o'), 69 | ('ㄜ', 'ə'), 70 | ('ㄝ', 'e'), 71 | ('ㄞ', 'ai'), 72 | ('ㄟ', 'ei'), 73 | ('ㄠ', 'au'), 74 | ('ㄡ', 'ou'), 75 | ('ㄧㄢ', 'yeNN'), 76 | ('ㄢ', 'aNN'), 77 | ('ㄧㄣ', 'iNN'), 78 | ('ㄣ', 'əNN'), 79 | ('ㄤ', 'aNg'), 80 | ('ㄧㄥ', 'iNg'), 81 | ('ㄨㄥ', 'uNg'), 82 | ('ㄩㄥ', 'yuNg'), 83 | ('ㄥ', 'əNg'), 84 | ('ㄦ', 'əɻ'), 85 | ('ㄧ', 'i'), 86 | ('ㄨ', 'u'), 87 | ('ㄩ', 'ɥ'), 88 | ('ˉ', '→'), 89 | ('ˊ', '↑'), 90 | ('ˇ', '↓↑'), 91 | ('ˋ', '↓'), 92 | ('˙', ''), 93 | (',', ','), 94 | ('。', '.'), 95 | ('!', '!'), 96 | ('?', '?'), 97 | ('—', '-') 98 | ]] 99 | 100 | # List of (romaji, ipa) pairs: 101 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 102 | ('ʃy', 'ʃ'), 103 | ('ʧʰy', 'ʧʰ'), 104 | ('ʧ⁼y', 'ʧ⁼'), 105 | ('NN', 'n'), 106 | ('Ng', 'ŋ'), 107 | ('y', 'j'), 108 | ('h', 'x') 109 | ]] 110 | 111 | # List of (bopomofo, ipa) pairs: 112 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 113 | ('ㄅㄛ', 'p⁼wo'), 114 | ('ㄆㄛ', 'pʰwo'), 115 | ('ㄇㄛ', 'mwo'), 116 | ('ㄈㄛ', 'fwo'), 117 | ('ㄅ', 'p⁼'), 118 | ('ㄆ', 'pʰ'), 119 | ('ㄇ', 'm'), 120 | ('ㄈ', 'f'), 121 | ('ㄉ', 't⁼'), 122 | ('ㄊ', 'tʰ'), 123 | ('ㄋ', 'n'), 124 | ('ㄌ', 'l'), 125 | ('ㄍ', 'k⁼'), 126 | ('ㄎ', 'kʰ'), 127 | ('ㄏ', 'x'), 128 | ('ㄐ', 'tʃ⁼'), 129 | ('ㄑ', 'tʃʰ'), 130 | ('ㄒ', 'ʃ'), 131 | ('ㄓ', 'ts`⁼'), 132 | ('ㄔ', 'ts`ʰ'), 133 | ('ㄕ', 's`'), 134 | ('ㄖ', 'ɹ`'), 135 | ('ㄗ', 'ts⁼'), 136 | ('ㄘ', 'tsʰ'), 137 | ('ㄙ', 's'), 138 | ('ㄚ', 'a'), 139 | ('ㄛ', 'o'), 140 | ('ㄜ', 'ə'), 141 | ('ㄝ', 'ɛ'), 142 | ('ㄞ', 'aɪ'), 143 | ('ㄟ', 'eɪ'), 144 | ('ㄠ', 'ɑʊ'), 145 | ('ㄡ', 'oʊ'), 146 | ('ㄧㄢ', 'jɛn'), 147 | ('ㄩㄢ', 'ɥæn'), 148 | ('ㄢ', 'an'), 149 | ('ㄧㄣ', 'in'), 150 | ('ㄩㄣ', 'ɥn'), 151 | ('ㄣ', 'ən'), 152 | ('ㄤ', 'ɑŋ'), 153 | ('ㄧㄥ', 'iŋ'), 154 | ('ㄨㄥ', 'ʊŋ'), 155 | ('ㄩㄥ', 'jʊŋ'), 156 | ('ㄥ', 'əŋ'), 157 | ('ㄦ', 'əɻ'), 158 | ('ㄧ', 'i'), 159 | ('ㄨ', 'u'), 160 | ('ㄩ', 'ɥ'), 161 | ('ˉ', '→'), 162 | ('ˊ', '↑'), 163 | ('ˇ', '↓↑'), 164 | ('ˋ', '↓'), 165 | ('˙', ''), 166 | (',', ','), 167 | ('。', '.'), 168 | ('!', '!'), 169 | ('?', '?'), 170 | ('—', '-') 171 | ]] 172 | 173 | # List of (bopomofo, ipa2) pairs: 174 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 175 | ('ㄅㄛ', 'pwo'), 176 | ('ㄆㄛ', 'pʰwo'), 177 | ('ㄇㄛ', 'mwo'), 178 | ('ㄈㄛ', 'fwo'), 179 | ('ㄅ', 'p'), 180 | ('ㄆ', 'pʰ'), 181 | ('ㄇ', 'm'), 182 | ('ㄈ', 'f'), 183 | ('ㄉ', 't'), 184 | ('ㄊ', 'tʰ'), 185 | ('ㄋ', 'n'), 186 | ('ㄌ', 'l'), 187 | ('ㄍ', 'k'), 188 | ('ㄎ', 'kʰ'), 189 | ('ㄏ', 'h'), 190 | ('ㄐ', 'tɕ'), 191 | ('ㄑ', 'tɕʰ'), 192 | ('ㄒ', 'ɕ'), 193 | ('ㄓ', 'tʂ'), 194 | ('ㄔ', 'tʂʰ'), 195 | ('ㄕ', 'ʂ'), 196 | ('ㄖ', 'ɻ'), 197 | ('ㄗ', 'ts'), 198 | ('ㄘ', 'tsʰ'), 199 | ('ㄙ', 's'), 200 | ('ㄚ', 'a'), 201 | ('ㄛ', 'o'), 202 | ('ㄜ', 'ɤ'), 203 | ('ㄝ', 'ɛ'), 204 | ('ㄞ', 'aɪ'), 205 | ('ㄟ', 'eɪ'), 206 | ('ㄠ', 'ɑʊ'), 207 | ('ㄡ', 'oʊ'), 208 | ('ㄧㄢ', 'jɛn'), 209 | ('ㄩㄢ', 'yæn'), 210 | ('ㄢ', 'an'), 211 | ('ㄧㄣ', 'in'), 212 | ('ㄩㄣ', 'yn'), 213 | ('ㄣ', 'ən'), 214 | ('ㄤ', 'ɑŋ'), 215 | ('ㄧㄥ', 'iŋ'), 216 | ('ㄨㄥ', 'ʊŋ'), 217 | ('ㄩㄥ', 'jʊŋ'), 218 | ('ㄥ', 'ɤŋ'), 219 | ('ㄦ', 'əɻ'), 220 | ('ㄧ', 'i'), 221 | ('ㄨ', 'u'), 222 | ('ㄩ', 'y'), 223 | ('ˉ', '˥'), 224 | ('ˊ', '˧˥'), 225 | ('ˇ', '˨˩˦'), 226 | ('ˋ', '˥˩'), 227 | ('˙', ''), 228 | (',', ','), 229 | ('。', '.'), 230 | ('!', 
'!'), 231 | ('?', '?'), 232 | ('—', '-') 233 | ]] 234 | 235 | 236 | def number_to_chinese(text): 237 | numbers = re.findall(r'\d+(?:\.?\d+)?', text) 238 | for number in numbers: 239 | text = text.replace(number, cn2an.an2cn(number), 1) 240 | return text 241 | 242 | 243 | def chinese_to_bopomofo(text): 244 | text = text.replace('、', ',').replace(';', ',').replace(':', ',') 245 | words = jieba.lcut(text, cut_all=False) 246 | text = '' 247 | for word in words: 248 | bopomofos = lazy_pinyin(word, BOPOMOFO) 249 | if not re.search('[\u4e00-\u9fff]', word): 250 | text += word 251 | continue 252 | for i in range(len(bopomofos)): 253 | bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i]) 254 | if text != '': 255 | text += ' ' 256 | text += ''.join(bopomofos) 257 | return text 258 | 259 | 260 | def latin_to_bopomofo(text): 261 | for regex, replacement in _latin_to_bopomofo: 262 | text = re.sub(regex, replacement, text) 263 | return text 264 | 265 | 266 | def bopomofo_to_romaji(text): 267 | for regex, replacement in _bopomofo_to_romaji: 268 | text = re.sub(regex, replacement, text) 269 | return text 270 | 271 | 272 | def bopomofo_to_ipa(text): 273 | for regex, replacement in _bopomofo_to_ipa: 274 | text = re.sub(regex, replacement, text) 275 | return text 276 | 277 | 278 | def bopomofo_to_ipa2(text): 279 | for regex, replacement in _bopomofo_to_ipa2: 280 | text = re.sub(regex, replacement, text) 281 | return text 282 | 283 | 284 | def chinese_to_romaji(text): 285 | text = number_to_chinese(text) 286 | text = chinese_to_bopomofo(text) 287 | text = latin_to_bopomofo(text) 288 | text = bopomofo_to_romaji(text) 289 | text = re.sub('i([aoe])', r'y\1', text) 290 | text = re.sub('u([aoəe])', r'w\1', text) 291 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 292 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 293 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 294 | return text 295 | 296 | 297 | def chinese_to_lazy_ipa(text): 298 | text = chinese_to_romaji(text) 299 | for regex, replacement in _romaji_to_ipa: 300 | text = re.sub(regex, replacement, text) 301 | return text 302 | 303 | 304 | def chinese_to_ipa(text): 305 | text = number_to_chinese(text) 306 | text = chinese_to_bopomofo(text) 307 | text = latin_to_bopomofo(text) 308 | text = bopomofo_to_ipa(text) 309 | text = re.sub('i([aoe])', r'j\1', text) 310 | text = re.sub('u([aoəe])', r'w\1', text) 311 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 312 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 313 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 314 | return text 315 | 316 | 317 | def chinese_to_ipa2(text): 318 | text = number_to_chinese(text) 319 | text = chinese_to_bopomofo(text) 320 | text = latin_to_bopomofo(text) 321 | text = bopomofo_to_ipa2(text) 322 | text = re.sub(r'i([aoe])', r'j\1', text) 323 | text = re.sub(r'u([aoəe])', r'w\1', text) 324 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text) 325 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text) 326 | return text 327 | -------------------------------------------------------------------------------- /OpenVoice/api.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import re 4 | import soundfile 5 | from .utils import utils 6 | from .utils import commons 7 | import os 8 | import librosa 9 | from .text import text_to_sequence 10 | from .utils.mel_processing import spectrogram_torch 11 | from .utils.models import SynthesizerTrn 12 | 13 | 14 | class OpenVoiceBaseClass(object): 15 | def __init__(self, 16 | 
config_path, 17 | #device='cuda:0'): 18 | device="cpu"): 19 | #if 'cuda' in device: 20 | # assert torch.cuda.is_available() 21 | 22 | hps = utils.get_hparams_from_file(config_path) 23 | 24 | model = SynthesizerTrn( 25 | len(getattr(hps, 'symbols', [])), 26 | hps.data.filter_length // 2 + 1, 27 | n_speakers=hps.data.n_speakers, 28 | **hps.model, 29 | ).to(device) 30 | 31 | model.eval() 32 | self.model = model 33 | self.hps = hps 34 | self.device = device 35 | 36 | def load_ckpt(self, ckpt_path): 37 | checkpoint_dict = torch.load(ckpt_path, map_location=torch.device('cpu')) 38 | a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False) 39 | print("Loaded checkpoint '{}'".format(ckpt_path)) 40 | print('missing/unexpected keys:', a, b) 41 | 42 | 43 | class BaseSpeakerTTS(OpenVoiceBaseClass): 44 | language_marks = { 45 | "english": "EN", 46 | "chinese": "ZH", 47 | } 48 | 49 | @staticmethod 50 | def get_text(text, hps, is_symbol): 51 | text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners) 52 | if hps.data.add_blank: 53 | text_norm = commons.intersperse(text_norm, 0) 54 | text_norm = torch.LongTensor(text_norm) 55 | return text_norm 56 | 57 | @staticmethod 58 | def audio_numpy_concat(segment_data_list, sr, speed=1.): 59 | audio_segments = [] 60 | for segment_data in segment_data_list: 61 | audio_segments += segment_data.reshape(-1).tolist() 62 | audio_segments += [0] * int((sr * 0.05)/speed) 63 | audio_segments = np.array(audio_segments).astype(np.float32) 64 | return audio_segments 65 | 66 | @staticmethod 67 | def split_sentences_into_pieces(text, language_str): 68 | texts = utils.split_sentence(text, language_str=language_str) 69 | print(" > Text splitted to sentences.") 70 | print('\n'.join(texts)) 71 | print(" > ===========================") 72 | return texts 73 | 74 | def tts(self, text, output_path, speaker, language='English', speed=1.0): 75 | mark = self.language_marks.get(language.lower(), None) 76 | assert mark is not None, f"language {language} is not supported" 77 | 78 | texts = self.split_sentences_into_pieces(text, mark) 79 | 80 | audio_list = [] 81 | for t in texts: 82 | t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t) 83 | t = f'[{mark}]{t}[{mark}]' 84 | stn_tst = self.get_text(t, self.hps, False) 85 | device = self.device 86 | speaker_id = self.hps.speakers[speaker] 87 | with torch.no_grad(): 88 | x_tst = stn_tst.unsqueeze(0).to(device) 89 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device) 90 | sid = torch.LongTensor([speaker_id]).to(device) 91 | audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6, 92 | length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy() 93 | audio_list.append(audio) 94 | audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed) 95 | 96 | if output_path is None: 97 | return audio 98 | else: 99 | soundfile.write(output_path, audio, self.hps.data.sampling_rate) 100 | 101 | 102 | class ToneColorConverter(OpenVoiceBaseClass): 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | 106 | if kwargs.get('enable_watermark', True): 107 | import wavmark 108 | self.watermark_model = wavmark.load_model().to(self.device) 109 | else: 110 | self.watermark_model = None 111 | 112 | 113 | 114 | def extract_se(self, ref_wav_list, se_save_path=None): 115 | if isinstance(ref_wav_list, str): 116 | ref_wav_list = [ref_wav_list] 117 | 118 | device = self.device 119 | hps = self.hps 120 | gs = [] 121 | 122 | for 
102 | class ToneColorConverter(OpenVoiceBaseClass): 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | 106 | if kwargs.get('enable_watermark', True): 107 | import wavmark 108 | self.watermark_model = wavmark.load_model().to(self.device) 109 | else: 110 | self.watermark_model = None 111 | 112 | 113 | 114 | def extract_se(self, ref_wav_list, se_save_path=None): 115 | if isinstance(ref_wav_list, str): 116 | ref_wav_list = [ref_wav_list] 117 | 118 | device = self.device 119 | hps = self.hps 120 | gs = [] 121 | 122 | for fname in ref_wav_list: 123 | audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate) 124 | y = torch.FloatTensor(audio_ref) 125 | y = y.to(device) 126 | y = y.unsqueeze(0) 127 | y = spectrogram_torch(y, hps.data.filter_length, 128 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 129 | center=False).to(device) 130 | with torch.no_grad(): 131 | g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1) 132 | gs.append(g.detach()) 133 | gs = torch.stack(gs).mean(0) 134 | 135 | if se_save_path is not None: 136 | os.makedirs(os.path.dirname(se_save_path), exist_ok=True) 137 | torch.save(gs.cpu(), se_save_path) 138 | 139 | return gs 140 | 141 | def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="@Hilley-MyShell"): 142 | hps = self.hps 143 | # load audio 144 | audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate) 145 | audio = torch.tensor(audio).float() 146 | 147 | with torch.no_grad(): 148 | y = torch.FloatTensor(audio).to(self.device) 149 | y = y.unsqueeze(0) 150 | spec = spectrogram_torch(y, hps.data.filter_length, 151 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 152 | center=False).to(self.device) 153 | spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device) 154 | audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][ 155 | 0, 0].data.cpu().float().numpy() 156 | audio = self.add_watermark(audio, message) 157 | if output_path is None: 158 | return audio 159 | else: 160 | soundfile.write(output_path, audio, hps.data.sampling_rate) 161 | 162 | def add_watermark(self, audio, message): 163 | if self.watermark_model is None: 164 | return audio 165 | device = self.device 166 | bits = utils.string_to_bits(message).reshape(-1) 167 | n_repeat = len(bits) // 32 168 | 169 | K = 16000 170 | coeff = 2 171 | for n in range(n_repeat): 172 | trunck = audio[(coeff * n) * K: (coeff * n + 1) * K] 173 | if len(trunck) != K: 174 | print('Audio too short, failed to add watermark') 175 | break 176 | message_npy = bits[n * 32: (n + 1) * 32] 177 | 178 | with torch.no_grad(): 179 | signal = torch.FloatTensor(trunck).to(device)[None] 180 | message_tensor = torch.FloatTensor(message_npy).to(device)[None] 181 | signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor) 182 | signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze() 183 | audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy 184 | return audio 185 | 186 | def detect_watermark(self, audio, n_repeat): 187 | bits = [] 188 | K = 16000 189 | coeff = 2 190 | for n in range(n_repeat): 191 | trunck = audio[(coeff * n) * K: (coeff * n + 1) * K] 192 | if len(trunck) != K: 193 | print('Audio too short, failed to detect watermark') 194 | return 'Fail' 195 | with torch.no_grad(): 196 | signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0) 197 | message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze() 198 | bits.append(message_decoded_npy) 199 | bits = np.stack(bits).reshape(-1, 8) 200 | message = utils.bits_to_string(bits) 201 | return message 202 | 203 | -------------------------------------------------------------------------------- /ChatTTS/model/gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 3 | 4 | import logging 5 | from tqdm import tqdm 6 | from einops import rearrange 7 | from transformers.cache_utils import Cache 8 | 9 |
import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.nn.utils.parametrize as P 13 | from torch.nn.utils.parametrizations import weight_norm 14 | from transformers import LlamaModel, LlamaConfig 15 | 16 | 17 | class LlamaMLP(nn.Module): 18 | def __init__(self, hidden_size, intermediate_size): 19 | super().__init__() 20 | self.hidden_size = hidden_size 21 | self.intermediate_size = intermediate_size 22 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 23 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 24 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 25 | self.act_fn = F.silu 26 | 27 | def forward(self, x): 28 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 29 | return down_proj 30 | 31 | 32 | class GPT_warpper(nn.Module): 33 | def __init__( 34 | self, 35 | gpt_config, 36 | num_audio_tokens, 37 | num_text_tokens, 38 | num_vq=4, 39 | **kwargs, 40 | ): 41 | super().__init__() 42 | 43 | self.logger = logging.getLogger(__name__) 44 | self.gpt = self.build_model(gpt_config) 45 | self.model_dim = self.gpt.config.hidden_size 46 | 47 | self.num_vq = num_vq 48 | self.emb_code = nn.ModuleList([nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)]) 49 | self.emb_text = nn.Embedding(num_text_tokens, self.model_dim) 50 | self.head_text = weight_norm(nn.Linear(self.model_dim, num_text_tokens, bias=False), name='weight') 51 | self.head_code = nn.ModuleList([weight_norm(nn.Linear(self.model_dim, num_audio_tokens, bias=False), name='weight') for i in range(self.num_vq)]) 52 | 53 | def build_model(self, config): 54 | 55 | configuration = LlamaConfig(**config) 56 | model = LlamaModel(configuration) 57 | del model.embed_tokens 58 | 59 | return model 60 | 61 | def get_emb(self, input_ids, text_mask, **kwargs): 62 | 63 | emb_text = self.emb_text(input_ids[text_mask][:, 0]) 64 | 65 | emb_code = [self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq)] 66 | emb_code = torch.stack(emb_code, 2).sum(2) 67 | 68 | emb = torch.zeros((input_ids.shape[:-1])+(emb_text.shape[-1],), device=emb_text.device, dtype=emb_text.dtype) 69 | emb[text_mask] = emb_text 70 | emb[~text_mask] = emb_code.to(emb.dtype) 71 | 72 | return emb 73 | 74 | def prepare_inputs_for_generation( 75 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs 76 | ): 77 | # With static cache, the `past_key_values` is None 78 | # TODO joao: standardize interface for the different Cache classes and remove of this if 79 | has_static_cache = False 80 | if past_key_values is None: 81 | past_key_values = getattr(self.gpt.layers[0].self_attn, "past_key_value", None) 82 | has_static_cache = past_key_values is not None 83 | 84 | past_length = 0 85 | if past_key_values is not None: 86 | if isinstance(past_key_values, Cache): 87 | past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() 88 | max_cache_length = ( 89 | torch.tensor(past_key_values.get_max_length(), device=input_ids.device) 90 | if past_key_values.get_max_length() is not None 91 | else None 92 | ) 93 | cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) 94 | # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects 95 | else: 96 | cache_length = past_length = past_key_values[0][0].shape[2] 97 | max_cache_length = None 98 | 99 | # 
Keep only the unprocessed tokens: 100 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 101 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 102 | # input) 103 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 104 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 105 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 106 | # input_ids based on the past_length. 107 | elif past_length < input_ids.shape[1]: 108 | input_ids = input_ids[:, past_length:] 109 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 110 | 111 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 112 | if ( 113 | max_cache_length is not None 114 | and attention_mask is not None 115 | and cache_length + input_ids.shape[1] > max_cache_length 116 | ): 117 | attention_mask = attention_mask[:, -max_cache_length:] 118 | 119 | position_ids = kwargs.get("position_ids", None) 120 | if attention_mask is not None and position_ids is None: 121 | # create position_ids on the fly for batch generation 122 | position_ids = attention_mask.long().cumsum(-1) - 1 123 | position_ids.masked_fill_(attention_mask == 0, 1) 124 | if past_key_values: 125 | position_ids = position_ids[:, -input_ids.shape[1] :] 126 | 127 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 128 | if inputs_embeds is not None and past_key_values is None: 129 | model_inputs = {"inputs_embeds": inputs_embeds} 130 | else: 131 | # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise 132 | # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 133 | # TODO: use `next_tokens` directly instead. 
134 | model_inputs = {"input_ids": input_ids.contiguous()} 135 | 136 | input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] 137 | if cache_position is None: 138 | cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) 139 | else: 140 | cache_position = cache_position[-input_length:] 141 | 142 | if has_static_cache: 143 | past_key_values = None 144 | 145 | model_inputs.update( 146 | { 147 | "position_ids": position_ids, 148 | "cache_position": cache_position, 149 | "past_key_values": past_key_values, 150 | "use_cache": kwargs.get("use_cache"), 151 | "attention_mask": attention_mask, 152 | } 153 | ) 154 | return model_inputs 155 | 156 | def generate( 157 | self, 158 | emb, 159 | inputs_ids, 160 | temperature, 161 | eos_token, 162 | attention_mask = None, 163 | max_new_token = 2048, 164 | min_new_token = 0, 165 | LogitsWarpers = [], 166 | LogitsProcessors = [], 167 | infer_text=False, 168 | return_attn=False, 169 | return_hidden=False, 170 | ): 171 | 172 | with torch.no_grad(): 173 | 174 | attentions = [] 175 | hiddens = [] 176 | 177 | start_idx, end_idx = inputs_ids.shape[1], torch.zeros(inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long) 178 | finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool() 179 | 180 | temperature = temperature[None].expand(inputs_ids.shape[0], -1) 181 | temperature = rearrange(temperature, "b n -> (b n) 1") 182 | 183 | attention_mask_cache = torch.ones((inputs_ids.shape[0], inputs_ids.shape[1]+max_new_token,), dtype=torch.bool, device=inputs_ids.device) 184 | if attention_mask is not None: 185 | attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask 186 | 187 | for i in tqdm(range(max_new_token)): 188 | 189 | model_input = self.prepare_inputs_for_generation(inputs_ids, 190 | outputs.past_key_values if i!=0 else None, 191 | attention_mask_cache[:, :inputs_ids.shape[1]], use_cache=True) 192 | 193 | if i == 0: 194 | model_input['inputs_embeds'] = emb 195 | else: 196 | if infer_text: 197 | model_input['inputs_embeds'] = self.emb_text(model_input['input_ids'][:,:,0]) 198 | else: 199 | code_emb = [self.emb_code[i](model_input['input_ids'][:,:,i]) for i in range(self.num_vq)] 200 | model_input['inputs_embeds'] = torch.stack(code_emb, 3).sum(3) 201 | 202 | model_input['input_ids'] = None 203 | outputs = self.gpt.forward(**model_input, output_attentions=return_attn) 204 | attentions.append(outputs.attentions) 205 | hidden_states = outputs[0] # 🐻 206 | if return_hidden: 207 | hiddens.append(hidden_states[:, -1]) 208 | 209 | with P.cached(): 210 | if infer_text: 211 | logits = self.head_text(hidden_states) 212 | else: 213 | logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3) 214 | 215 | logits = logits[:, -1].float() 216 | 217 | if not infer_text: 218 | logits = rearrange(logits, "b c n -> (b n) c") 219 | logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c") 220 | else: 221 | logits_token = inputs_ids[:, start_idx:, 0] 222 | 223 | logits = logits / temperature 224 | 225 | for logitsProcessors in LogitsProcessors: 226 | logits = logitsProcessors(logits_token, logits) 227 | 228 | for logitsWarpers in LogitsWarpers: 229 | logits = logitsWarpers(logits_token, logits) 230 | 231 | if i < min_new_token: 232 | logits[:, eos_token] = -torch.inf 233 | 234 | scores = F.softmax(logits, dim=-1) 235 | 236 | idx_next = torch.multinomial(scores, num_samples=1) 237 | 238 | if not infer_text: 239 | idx_next = 
rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) 240 | finish = finish | (idx_next == eos_token).any(1) 241 | inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1) 242 | else: 243 | finish = finish | (idx_next == eos_token).any(1) 244 | inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq)], 1) 245 | 246 | end_idx = end_idx + (~finish).int() 247 | 248 | if finish.all(): 249 | break 250 | 251 | inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())] 252 | inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids 253 | 254 | if return_hidden: 255 | hiddens = torch.stack(hiddens, 1) 256 | hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())] 257 | 258 | if not finish.all(): 259 | self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}') 260 | 261 | return { 262 | 'ids': inputs_ids, 263 | 'attentions': attentions, 264 | 'hiddens':hiddens, 265 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Attribution-NonCommercial 4.0 International Public 2 | License 3 | 4 | By exercising the Licensed Rights (defined below), You accept and agree 5 | to be bound by the terms and conditions of this Creative Commons 6 | Attribution-NonCommercial 4.0 International Public License ("Public 7 | License"). To the extent this Public License may be interpreted as a 8 | contract, You are granted the Licensed Rights in consideration of Your 9 | acceptance of these terms and conditions, and the Licensor grants You 10 | such rights in consideration of benefits the Licensor receives from 11 | making the Licensed Material available under these terms and 12 | conditions. 13 | 14 | 15 | Section 1 -- Definitions. 16 | 17 | a. Adapted Material means material subject to Copyright and Similar 18 | Rights that is derived from or based upon the Licensed Material 19 | and in which the Licensed Material is translated, altered, 20 | arranged, transformed, or otherwise modified in a manner requiring 21 | permission under the Copyright and Similar Rights held by the 22 | Licensor. For purposes of this Public License, where the Licensed 23 | Material is a musical work, performance, or sound recording, 24 | Adapted Material is always produced where the Licensed Material is 25 | synched in timed relation with a moving image. 26 | 27 | b. Adapter's License means the license You apply to Your Copyright 28 | and Similar Rights in Your contributions to Adapted Material in 29 | accordance with the terms and conditions of this Public License. 30 | 31 | c. Copyright and Similar Rights means copyright and/or similar rights 32 | closely related to copyright including, without limitation, 33 | performance, broadcast, sound recording, and Sui Generis Database 34 | Rights, without regard to how the rights are labeled or 35 | categorized. For purposes of this Public License, the rights 36 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 37 | Rights. 38 | d. Effective Technological Measures means those measures that, in the 39 | absence of proper authority, may not be circumvented under laws 40 | fulfilling obligations under Article 11 of the WIPO Copyright 41 | Treaty adopted on December 20, 1996, and/or similar international 42 | agreements. 43 | 44 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 45 | any other exception or limitation to Copyright and Similar Rights 46 | that applies to Your use of the Licensed Material. 47 | 48 | f. Licensed Material means the artistic or literary work, database, 49 | or other material to which the Licensor applied this Public 50 | License. 51 | 52 | g. Licensed Rights means the rights granted to You subject to the 53 | terms and conditions of this Public License, which are limited to 54 | all Copyright and Similar Rights that apply to Your use of the 55 | Licensed Material and that the Licensor has authority to license. 56 | 57 | h. Licensor means the individual(s) or entity(ies) granting rights 58 | under this Public License. 59 | 60 | i. NonCommercial means not primarily intended for or directed towards 61 | commercial advantage or monetary compensation. For purposes of 62 | this Public License, the exchange of the Licensed Material for 63 | other material subject to Copyright and Similar Rights by digital 64 | file-sharing or similar means is NonCommercial provided there is 65 | no payment of monetary compensation in connection with the 66 | exchange. 67 | 68 | j. Share means to provide material to the public by any means or 69 | process that requires permission under the Licensed Rights, such 70 | as reproduction, public display, public performance, distribution, 71 | dissemination, communication, or importation, and to make material 72 | available to the public including in ways that members of the 73 | public may access the material from a place and at a time 74 | individually chosen by them. 75 | 76 | k. Sui Generis Database Rights means rights other than copyright 77 | resulting from Directive 96/9/EC of the European Parliament and of 78 | the Council of 11 March 1996 on the legal protection of databases, 79 | as amended and/or succeeded, as well as other essentially 80 | equivalent rights anywhere in the world. 81 | 82 | l. You means the individual or entity exercising the Licensed Rights 83 | under this Public License. Your has a corresponding meaning. 84 | 85 | 86 | Section 2 -- Scope. 87 | 88 | a. License grant. 89 | 90 | 1. Subject to the terms and conditions of this Public License, 91 | the Licensor hereby grants You a worldwide, royalty-free, 92 | non-sublicensable, non-exclusive, irrevocable license to 93 | exercise the Licensed Rights in the Licensed Material to: 94 | 95 | a. reproduce and Share the Licensed Material, in whole or 96 | in part, for NonCommercial purposes only; and 97 | 98 | b. produce, reproduce, and Share Adapted Material for 99 | NonCommercial purposes only. 100 | 101 | 2. Exceptions and Limitations. For the avoidance of doubt, where 102 | Exceptions and Limitations apply to Your use, this Public 103 | License does not apply, and You do not need to comply with 104 | its terms and conditions. 105 | 106 | 3. Term. The term of this Public License is specified in Section 107 | 6(a). 108 | 109 | 4. Media and formats; technical modifications allowed. The 110 | Licensor authorizes You to exercise the Licensed Rights in 111 | all media and formats whether now known or hereafter created, 112 | and to make technical modifications necessary to do so. The 113 | Licensor waives and/or agrees not to assert any right or 114 | authority to forbid You from making technical modifications 115 | necessary to exercise the Licensed Rights, including 116 | technical modifications necessary to circumvent Effective 117 | Technological Measures. 
For purposes of this Public License, 118 | simply making modifications authorized by this Section 2(a) 119 | (4) never produces Adapted Material. 120 | 121 | 5. Downstream recipients. 122 | 123 | a. Offer from the Licensor -- Licensed Material. Every 124 | recipient of the Licensed Material automatically 125 | receives an offer from the Licensor to exercise the 126 | Licensed Rights under the terms and conditions of this 127 | Public License. 128 | 129 | b. No downstream restrictions. You may not offer or impose 130 | any additional or different terms or conditions on, or 131 | apply any Effective Technological Measures to, the 132 | Licensed Material if doing so restricts exercise of the 133 | Licensed Rights by any recipient of the Licensed 134 | Material. 135 | 136 | 6. No endorsement. Nothing in this Public License constitutes or 137 | may be construed as permission to assert or imply that You 138 | are, or that Your use of the Licensed Material is, connected 139 | with, or sponsored, endorsed, or granted official status by, 140 | the Licensor or others designated to receive attribution as 141 | provided in Section 3(a)(1)(A)(i). 142 | 143 | b. Other rights. 144 | 145 | 1. Moral rights, such as the right of integrity, are not 146 | licensed under this Public License, nor are publicity, 147 | privacy, and/or other similar personality rights; however, to 148 | the extent possible, the Licensor waives and/or agrees not to 149 | assert any such rights held by the Licensor to the limited 150 | extent necessary to allow You to exercise the Licensed 151 | Rights, but not otherwise. 152 | 153 | 2. Patent and trademark rights are not licensed under this 154 | Public License. 155 | 156 | 3. To the extent possible, the Licensor waives any right to 157 | collect royalties from You for the exercise of the Licensed 158 | Rights, whether directly or through a collecting society 159 | under any voluntary or waivable statutory or compulsory 160 | licensing scheme. In all other cases the Licensor expressly 161 | reserves any right to collect such royalties, including when 162 | the Licensed Material is used other than for NonCommercial 163 | purposes. 164 | 165 | 166 | Section 3 -- License Conditions. 167 | 168 | Your exercise of the Licensed Rights is expressly made subject to the 169 | following conditions. 170 | 171 | a. Attribution. 172 | 173 | 1. If You Share the Licensed Material (including in modified 174 | form), You must: 175 | 176 | a. retain the following if it is supplied by the Licensor 177 | with the Licensed Material: 178 | 179 | i. identification of the creator(s) of the Licensed 180 | Material and any others designated to receive 181 | attribution, in any reasonable manner requested by 182 | the Licensor (including by pseudonym if 183 | designated); 184 | 185 | ii. a copyright notice; 186 | 187 | iii. a notice that refers to this Public License; 188 | 189 | iv. a notice that refers to the disclaimer of 190 | warranties; 191 | 192 | v. a URI or hyperlink to the Licensed Material to the 193 | extent reasonably practicable; 194 | 195 | b. indicate if You modified the Licensed Material and 196 | retain an indication of any previous modifications; and 197 | 198 | c. indicate the Licensed Material is licensed under this 199 | Public License, and include the text of, or the URI or 200 | hyperlink to, this Public License. 201 | 202 | 2. 
You may satisfy the conditions in Section 3(a)(1) in any 203 | reasonable manner based on the medium, means, and context in 204 | which You Share the Licensed Material. For example, it may be 205 | reasonable to satisfy the conditions by providing a URI or 206 | hyperlink to a resource that includes the required 207 | information. 208 | 209 | 3. If requested by the Licensor, You must remove any of the 210 | information required by Section 3(a)(1)(A) to the extent 211 | reasonably practicable. 212 | 213 | 4. If You Share Adapted Material You produce, the Adapter's 214 | License You apply must not prevent recipients of the Adapted 215 | Material from complying with this Public License. 216 | 217 | 218 | Section 4 -- Sui Generis Database Rights. 219 | 220 | Where the Licensed Rights include Sui Generis Database Rights that 221 | apply to Your use of the Licensed Material: 222 | 223 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 224 | to extract, reuse, reproduce, and Share all or a substantial 225 | portion of the contents of the database for NonCommercial purposes 226 | only; 227 | 228 | b. if You include all or a substantial portion of the database 229 | contents in a database in which You have Sui Generis Database 230 | Rights, then the database in which You have Sui Generis Database 231 | Rights (but not its individual contents) is Adapted Material; and 232 | 233 | c. You must comply with the conditions in Section 3(a) if You Share 234 | all or a substantial portion of the contents of the database. 235 | 236 | For the avoidance of doubt, this Section 4 supplements and does not 237 | replace Your obligations under this Public License where the Licensed 238 | Rights include other Copyright and Similar Rights. 239 | 240 | 241 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 242 | 243 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 244 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 245 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 246 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 247 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 248 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 249 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 250 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 251 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 252 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 253 | 254 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 255 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 256 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 257 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 258 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 259 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 260 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 261 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 262 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 263 | 264 | c. The disclaimer of warranties and limitation of liability provided 265 | above shall be interpreted in a manner that, to the extent 266 | possible, most closely approximates an absolute disclaimer and 267 | waiver of all liability. 268 | 269 | 270 | Section 6 -- Term and Termination. 271 | 272 | a. 
This Public License applies for the term of the Copyright and 273 | Similar Rights licensed here. However, if You fail to comply with 274 | this Public License, then Your rights under this Public License 275 | terminate automatically. 276 | 277 | b. Where Your right to use the Licensed Material has terminated under 278 | Section 6(a), it reinstates: 279 | 280 | 1. automatically as of the date the violation is cured, provided 281 | it is cured within 30 days of Your discovery of the 282 | violation; or 283 | 284 | 2. upon express reinstatement by the Licensor. 285 | 286 | For the avoidance of doubt, this Section 6(b) does not affect any 287 | right the Licensor may have to seek remedies for Your violations 288 | of this Public License. 289 | 290 | c. For the avoidance of doubt, the Licensor may also offer the 291 | Licensed Material under separate terms or conditions or stop 292 | distributing the Licensed Material at any time; however, doing so 293 | will not terminate this Public License. 294 | 295 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 296 | License. 297 | 298 | 299 | Section 7 -- Other Terms and Conditions. 300 | 301 | a. The Licensor shall not be bound by any additional or different 302 | terms or conditions communicated by You unless expressly agreed. 303 | 304 | b. Any arrangements, understandings, or agreements regarding the 305 | Licensed Material not stated herein are separate from and 306 | independent of the terms and conditions of this Public License. 307 | 308 | 309 | Section 8 -- Interpretation. 310 | 311 | a. For the avoidance of doubt, this Public License does not, and 312 | shall not be interpreted to, reduce, limit, restrict, or impose 313 | conditions on any use of the Licensed Material that could lawfully 314 | be made without permission under this Public License. 315 | 316 | b. To the extent possible, if any provision of this Public License is 317 | deemed unenforceable, it shall be automatically reformed to the 318 | minimum extent necessary to make it enforceable. If the provision 319 | cannot be reformed, it shall be severed from this Public License 320 | without affecting the enforceability of the remaining terms and 321 | conditions. 322 | 323 | c. No term or condition of this Public License will be waived and no 324 | failure to comply consented to unless expressly agreed to by the 325 | Licensor. 326 | 327 | d. Nothing in this Public License constitutes or may be interpreted 328 | as a limitation upon, or waiver of, any privileges and immunities 329 | that apply to the Licensor or You, including from the legal 330 | processes of any jurisdiction or authority. 331 | 332 | ======================================================================= 333 | 334 | -------------------------------------------------------------------------------- /OpenVoice/utils/attentions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from . 
import commons 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class LayerNorm(nn.Module): 13 | def __init__(self, channels, eps=1e-5): 14 | super().__init__() 15 | self.channels = channels 16 | self.eps = eps 17 | 18 | self.gamma = nn.Parameter(torch.ones(channels)) 19 | self.beta = nn.Parameter(torch.zeros(channels)) 20 | 21 | def forward(self, x): 22 | x = x.transpose(1, -1) 23 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 24 | return x.transpose(1, -1) 25 | 26 | 27 | @torch.jit.script 28 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 29 | n_channels_int = n_channels[0] 30 | in_act = input_a + input_b 31 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 32 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 33 | acts = t_act * s_act 34 | return acts 35 | 36 | 37 | class Encoder(nn.Module): 38 | def __init__( 39 | self, 40 | hidden_channels, 41 | filter_channels, 42 | n_heads, 43 | n_layers, 44 | kernel_size=1, 45 | p_dropout=0.0, 46 | window_size=4, 47 | isflow=True, 48 | **kwargs 49 | ): 50 | super().__init__() 51 | self.hidden_channels = hidden_channels 52 | self.filter_channels = filter_channels 53 | self.n_heads = n_heads 54 | self.n_layers = n_layers 55 | self.kernel_size = kernel_size 56 | self.p_dropout = p_dropout 57 | self.window_size = window_size 58 | # if isflow: 59 | # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1) 60 | # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) 61 | # self.cond_layer = weight_norm(cond_layer, name='weight') 62 | # self.gin_channels = 256 63 | self.cond_layer_idx = self.n_layers 64 | if "gin_channels" in kwargs: 65 | self.gin_channels = kwargs["gin_channels"] 66 | if self.gin_channels != 0: 67 | self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) 68 | # vits2 says 3rd block, so idx is 2 by default 69 | self.cond_layer_idx = ( 70 | kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 71 | ) 72 | # logging.debug(self.gin_channels, self.cond_layer_idx) 73 | assert ( 74 | self.cond_layer_idx < self.n_layers 75 | ), "cond_layer_idx should be less than n_layers" 76 | self.drop = nn.Dropout(p_dropout) 77 | self.attn_layers = nn.ModuleList() 78 | self.norm_layers_1 = nn.ModuleList() 79 | self.ffn_layers = nn.ModuleList() 80 | self.norm_layers_2 = nn.ModuleList() 81 | 82 | for i in range(self.n_layers): 83 | self.attn_layers.append( 84 | MultiHeadAttention( 85 | hidden_channels, 86 | hidden_channels, 87 | n_heads, 88 | p_dropout=p_dropout, 89 | window_size=window_size, 90 | ) 91 | ) 92 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 93 | self.ffn_layers.append( 94 | FFN( 95 | hidden_channels, 96 | hidden_channels, 97 | filter_channels, 98 | kernel_size, 99 | p_dropout=p_dropout, 100 | ) 101 | ) 102 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 103 | 104 | def forward(self, x, x_mask, g=None): 105 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 106 | x = x * x_mask 107 | for i in range(self.n_layers): 108 | if i == self.cond_layer_idx and g is not None: 109 | g = self.spk_emb_linear(g.transpose(1, 2)) 110 | g = g.transpose(1, 2) 111 | x = x + g 112 | x = x * x_mask 113 | y = self.attn_layers[i](x, x, attn_mask) 114 | y = self.drop(y) 115 | x = self.norm_layers_1[i](x + y) 116 | 117 | y = self.ffn_layers[i](x, x_mask) 118 | y = self.drop(y) 119 | x = self.norm_layers_2[i](x + y) 120 | x = x * x_mask 121 | return x 122 | 123 | 124 | class Decoder(nn.Module): 125 | def __init__( 126 
| self, 127 | hidden_channels, 128 | filter_channels, 129 | n_heads, 130 | n_layers, 131 | kernel_size=1, 132 | p_dropout=0.0, 133 | proximal_bias=False, 134 | proximal_init=True, 135 | **kwargs 136 | ): 137 | super().__init__() 138 | self.hidden_channels = hidden_channels 139 | self.filter_channels = filter_channels 140 | self.n_heads = n_heads 141 | self.n_layers = n_layers 142 | self.kernel_size = kernel_size 143 | self.p_dropout = p_dropout 144 | self.proximal_bias = proximal_bias 145 | self.proximal_init = proximal_init 146 | 147 | self.drop = nn.Dropout(p_dropout) 148 | self.self_attn_layers = nn.ModuleList() 149 | self.norm_layers_0 = nn.ModuleList() 150 | self.encdec_attn_layers = nn.ModuleList() 151 | self.norm_layers_1 = nn.ModuleList() 152 | self.ffn_layers = nn.ModuleList() 153 | self.norm_layers_2 = nn.ModuleList() 154 | for i in range(self.n_layers): 155 | self.self_attn_layers.append( 156 | MultiHeadAttention( 157 | hidden_channels, 158 | hidden_channels, 159 | n_heads, 160 | p_dropout=p_dropout, 161 | proximal_bias=proximal_bias, 162 | proximal_init=proximal_init, 163 | ) 164 | ) 165 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 166 | self.encdec_attn_layers.append( 167 | MultiHeadAttention( 168 | hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout 169 | ) 170 | ) 171 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 172 | self.ffn_layers.append( 173 | FFN( 174 | hidden_channels, 175 | hidden_channels, 176 | filter_channels, 177 | kernel_size, 178 | p_dropout=p_dropout, 179 | causal=True, 180 | ) 181 | ) 182 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 183 | 184 | def forward(self, x, x_mask, h, h_mask): 185 | """ 186 | x: decoder input 187 | h: encoder output 188 | """ 189 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( 190 | device=x.device, dtype=x.dtype 191 | ) 192 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 193 | x = x * x_mask 194 | for i in range(self.n_layers): 195 | y = self.self_attn_layers[i](x, x, self_attn_mask) 196 | y = self.drop(y) 197 | x = self.norm_layers_0[i](x + y) 198 | 199 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 200 | y = self.drop(y) 201 | x = self.norm_layers_1[i](x + y) 202 | 203 | y = self.ffn_layers[i](x, x_mask) 204 | y = self.drop(y) 205 | x = self.norm_layers_2[i](x + y) 206 | x = x * x_mask 207 | return x 208 | 209 | 210 | class MultiHeadAttention(nn.Module): 211 | def __init__( 212 | self, 213 | channels, 214 | out_channels, 215 | n_heads, 216 | p_dropout=0.0, 217 | window_size=None, 218 | heads_share=True, 219 | block_length=None, 220 | proximal_bias=False, 221 | proximal_init=False, 222 | ): 223 | super().__init__() 224 | assert channels % n_heads == 0 225 | 226 | self.channels = channels 227 | self.out_channels = out_channels 228 | self.n_heads = n_heads 229 | self.p_dropout = p_dropout 230 | self.window_size = window_size 231 | self.heads_share = heads_share 232 | self.block_length = block_length 233 | self.proximal_bias = proximal_bias 234 | self.proximal_init = proximal_init 235 | self.attn = None 236 | 237 | self.k_channels = channels // n_heads 238 | self.conv_q = nn.Conv1d(channels, channels, 1) 239 | self.conv_k = nn.Conv1d(channels, channels, 1) 240 | self.conv_v = nn.Conv1d(channels, channels, 1) 241 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 242 | self.drop = nn.Dropout(p_dropout) 243 | 244 | if window_size is not None: 245 | n_heads_rel = 1 if heads_share else n_heads 246 | rel_stddev = self.k_channels**-0.5 247 | self.emb_rel_k 
= nn.Parameter( 248 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 249 | * rel_stddev 250 | ) 251 | self.emb_rel_v = nn.Parameter( 252 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 253 | * rel_stddev 254 | ) 255 | 256 | nn.init.xavier_uniform_(self.conv_q.weight) 257 | nn.init.xavier_uniform_(self.conv_k.weight) 258 | nn.init.xavier_uniform_(self.conv_v.weight) 259 | if proximal_init: 260 | with torch.no_grad(): 261 | self.conv_k.weight.copy_(self.conv_q.weight) 262 | self.conv_k.bias.copy_(self.conv_q.bias) 263 | 264 | def forward(self, x, c, attn_mask=None): 265 | q = self.conv_q(x) 266 | k = self.conv_k(c) 267 | v = self.conv_v(c) 268 | 269 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 270 | 271 | x = self.conv_o(x) 272 | return x 273 | 274 | def attention(self, query, key, value, mask=None): 275 | # reshape [b, d, t] -> [b, n_h, t, d_k] 276 | b, d, t_s, t_t = (*key.size(), query.size(2)) 277 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 278 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 279 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 280 | 281 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 282 | if self.window_size is not None: 283 | assert ( 284 | t_s == t_t 285 | ), "Relative attention is only available for self-attention." 286 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 287 | rel_logits = self._matmul_with_relative_keys( 288 | query / math.sqrt(self.k_channels), key_relative_embeddings 289 | ) 290 | scores_local = self._relative_position_to_absolute_position(rel_logits) 291 | scores = scores + scores_local 292 | if self.proximal_bias: 293 | assert t_s == t_t, "Proximal bias is only available for self-attention." 294 | scores = scores + self._attention_bias_proximal(t_s).to( 295 | device=scores.device, dtype=scores.dtype 296 | ) 297 | if mask is not None: 298 | scores = scores.masked_fill(mask == 0, -1e4) 299 | if self.block_length is not None: 300 | assert ( 301 | t_s == t_t 302 | ), "Local attention is only available for self-attention." 
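                # NOTE: block-local attention below keeps only a diagonal band of
                # width +/- block_length; entries outside the band are filled with
                # -1e4 so they vanish after the softmax. (A standalone sketch of
                # this attention module is appended at the end of this dump.)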
303 |                 block_mask = (
304 |                     torch.ones_like(scores)
305 |                     .triu(-self.block_length)
306 |                     .tril(self.block_length)
307 |                 )
308 |                 scores = scores.masked_fill(block_mask == 0, -1e4)
309 |         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
310 |         p_attn = self.drop(p_attn)
311 |         output = torch.matmul(p_attn, value)
312 |         if self.window_size is not None:
313 |             relative_weights = self._absolute_position_to_relative_position(p_attn)
314 |             value_relative_embeddings = self._get_relative_embeddings(
315 |                 self.emb_rel_v, t_s
316 |             )
317 |             output = output + self._matmul_with_relative_values(
318 |                 relative_weights, value_relative_embeddings
319 |             )
320 |         output = (
321 |             output.transpose(2, 3).contiguous().view(b, d, t_t)
322 |         )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
323 |         return output, p_attn
324 | 
325 |     def _matmul_with_relative_values(self, x, y):
326 |         """
327 |         x: [b, h, l, m]
328 |         y: [h or 1, m, d]
329 |         ret: [b, h, l, d]
330 |         """
331 |         ret = torch.matmul(x, y.unsqueeze(0))
332 |         return ret
333 | 
334 |     def _matmul_with_relative_keys(self, x, y):
335 |         """
336 |         x: [b, h, l, d]
337 |         y: [h or 1, m, d]
338 |         ret: [b, h, l, m]
339 |         """
340 |         ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
341 |         return ret
342 | 
343 |     def _get_relative_embeddings(self, relative_embeddings, length):
344 |         # The embedding table covers 2 * window_size + 1 relative positions.
345 |         # Pad first before slice to avoid using cond ops.
346 |         pad_length = max(length - (self.window_size + 1), 0)
347 |         slice_start_position = max((self.window_size + 1) - length, 0)
348 |         slice_end_position = slice_start_position + 2 * length - 1
349 |         if pad_length > 0:
350 |             padded_relative_embeddings = F.pad(
351 |                 relative_embeddings,
352 |                 commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
353 |             )
354 |         else:
355 |             padded_relative_embeddings = relative_embeddings
356 |         used_relative_embeddings = padded_relative_embeddings[
357 |             :, slice_start_position:slice_end_position
358 |         ]
359 |         return used_relative_embeddings
360 | 
361 |     def _relative_position_to_absolute_position(self, x):
362 |         """
363 |         x: [b, h, l, 2*l-1]
364 |         ret: [b, h, l, l]
365 |         """
366 |         batch, heads, length, _ = x.size()
367 |         # Concat columns of pad to shift from relative to absolute indexing.
368 |         x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
369 | 
370 |         # Concat extra elements so that the tensor adds up to shape (len+1, 2*len-1).
371 |         x_flat = x.view([batch, heads, length * 2 * length])
372 |         x_flat = F.pad(
373 |             x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
374 |         )
375 | 
376 |         # Reshape and slice out the padded elements.
377 |         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
378 |             :, :, :length, length - 1 :
379 |         ]
380 |         return x_final
381 | 
382 |     def _absolute_position_to_relative_position(self, x):
383 |         """
384 |         x: [b, h, l, l]
385 |         ret: [b, h, l, 2*l-1]
386 |         """
387 |         batch, heads, length, _ = x.size()
388 |         # pad along column
389 |         x = F.pad(
390 |             x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
391 |         )
392 |         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
393 |         # add 0's in the beginning that will skew the elements after reshape
394 |         x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
395 |         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
396 |         return x_final
397 | 
398 |     def _attention_bias_proximal(self, length):
399 |         """Bias for self-attention to encourage attention to close positions.
400 | Args: 401 | length: an integer scalar. 402 | Returns: 403 | a Tensor with shape [1, 1, length, length] 404 | """ 405 | r = torch.arange(length, dtype=torch.float32) 406 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 407 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 408 | 409 | 410 | class FFN(nn.Module): 411 | def __init__( 412 | self, 413 | in_channels, 414 | out_channels, 415 | filter_channels, 416 | kernel_size, 417 | p_dropout=0.0, 418 | activation=None, 419 | causal=False, 420 | ): 421 | super().__init__() 422 | self.in_channels = in_channels 423 | self.out_channels = out_channels 424 | self.filter_channels = filter_channels 425 | self.kernel_size = kernel_size 426 | self.p_dropout = p_dropout 427 | self.activation = activation 428 | self.causal = causal 429 | 430 | if causal: 431 | self.padding = self._causal_padding 432 | else: 433 | self.padding = self._same_padding 434 | 435 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 436 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 437 | self.drop = nn.Dropout(p_dropout) 438 | 439 | def forward(self, x, x_mask): 440 | x = self.conv_1(self.padding(x * x_mask)) 441 | if self.activation == "gelu": 442 | x = x * torch.sigmoid(1.702 * x) 443 | else: 444 | x = torch.relu(x) 445 | x = self.drop(x) 446 | x = self.conv_2(self.padding(x * x_mask)) 447 | return x * x_mask 448 | 449 | def _causal_padding(self, x): 450 | if self.kernel_size == 1: 451 | return x 452 | pad_l = self.kernel_size - 1 453 | pad_r = 0 454 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 455 | x = F.pad(x, commons.convert_pad_shape(padding)) 456 | return x 457 | 458 | def _same_padding(self, x): 459 | if self.kernel_size == 1: 460 | return x 461 | pad_l = (self.kernel_size - 1) // 2 462 | pad_r = self.kernel_size // 2 463 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 464 | x = F.pad(x, commons.convert_pad_shape(padding)) 465 | return x 466 | -------------------------------------------------------------------------------- /OpenVoice/utils/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from . import commons 7 | from . import modules 8 | from . 
import attentions
9 | 
10 | from torch.nn import Conv1d, ConvTranspose1d, Conv2d
11 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12 | 
13 | from .commons import init_weights, get_padding
14 | 
15 | 
16 | class TextEncoder(nn.Module):
17 |     def __init__(self,
18 |                  n_vocab,
19 |                  out_channels,
20 |                  hidden_channels,
21 |                  filter_channels,
22 |                  n_heads,
23 |                  n_layers,
24 |                  kernel_size,
25 |                  p_dropout):
26 |         super().__init__()
27 |         self.n_vocab = n_vocab
28 |         self.out_channels = out_channels
29 |         self.hidden_channels = hidden_channels
30 |         self.filter_channels = filter_channels
31 |         self.n_heads = n_heads
32 |         self.n_layers = n_layers
33 |         self.kernel_size = kernel_size
34 |         self.p_dropout = p_dropout
35 | 
36 |         self.emb = nn.Embedding(n_vocab, hidden_channels)
37 |         nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
38 | 
39 |         self.encoder = attentions.Encoder(
40 |             hidden_channels,
41 |             filter_channels,
42 |             n_heads,
43 |             n_layers,
44 |             kernel_size,
45 |             p_dropout)
46 |         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
47 | 
48 |     def forward(self, x, x_lengths):
49 |         x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
50 |         x = torch.transpose(x, 1, -1)  # [b, h, t]
51 |         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
52 | 
53 |         x = self.encoder(x * x_mask, x_mask)
54 |         stats = self.proj(x) * x_mask
55 | 
56 |         m, logs = torch.split(stats, self.out_channels, dim=1)
57 |         return x, m, logs, x_mask
58 | 
59 | 
60 | class DurationPredictor(nn.Module):
61 |     def __init__(
62 |         self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
63 |     ):
64 |         super().__init__()
65 | 
66 |         self.in_channels = in_channels
67 |         self.filter_channels = filter_channels
68 |         self.kernel_size = kernel_size
69 |         self.p_dropout = p_dropout
70 |         self.gin_channels = gin_channels
71 | 
72 |         self.drop = nn.Dropout(p_dropout)
73 |         self.conv_1 = nn.Conv1d(
74 |             in_channels, filter_channels, kernel_size, padding=kernel_size // 2
75 |         )
76 |         self.norm_1 = modules.LayerNorm(filter_channels)
77 |         self.conv_2 = nn.Conv1d(
78 |             filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
79 |         )
80 |         self.norm_2 = modules.LayerNorm(filter_channels)
81 |         self.proj = nn.Conv1d(filter_channels, 1, 1)
82 | 
83 |         if gin_channels != 0:
84 |             self.cond = nn.Conv1d(gin_channels, in_channels, 1)
85 | 
86 |     def forward(self, x, x_mask, g=None):
87 |         x = torch.detach(x)
88 |         if g is not None:
89 |             g = torch.detach(g)
90 |             x = x + self.cond(g)
91 |         x = self.conv_1(x * x_mask)
92 |         x = torch.relu(x)
93 |         x = self.norm_1(x)
94 |         x = self.drop(x)
95 |         x = self.conv_2(x * x_mask)
96 |         x = torch.relu(x)
97 |         x = self.norm_2(x)
98 |         x = self.drop(x)
99 |         x = self.proj(x * x_mask)
100 |         return x * x_mask
101 | 
102 | class StochasticDurationPredictor(nn.Module):
103 |     def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
104 |         super().__init__()
105 |         filter_channels = in_channels  # this should be removed in a future version
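        # NOTE: this is the VITS-style stochastic duration predictor: durations
        # are modeled with a normalizing flow (the ConvFlow/Flip stacks built
        # below) plus a posterior flow over auxiliary variables; at inference
        # the main flow is run in reverse to sample log-durations (see the
        # sampling sketch appended at the end of this dump).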
106 | self.in_channels = in_channels 107 | self.filter_channels = filter_channels 108 | self.kernel_size = kernel_size 109 | self.p_dropout = p_dropout 110 | self.n_flows = n_flows 111 | self.gin_channels = gin_channels 112 | 113 | self.log_flow = modules.Log() 114 | self.flows = nn.ModuleList() 115 | self.flows.append(modules.ElementwiseAffine(2)) 116 | for i in range(n_flows): 117 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 118 | self.flows.append(modules.Flip()) 119 | 120 | self.post_pre = nn.Conv1d(1, filter_channels, 1) 121 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) 122 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 123 | self.post_flows = nn.ModuleList() 124 | self.post_flows.append(modules.ElementwiseAffine(2)) 125 | for i in range(4): 126 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 127 | self.post_flows.append(modules.Flip()) 128 | 129 | self.pre = nn.Conv1d(in_channels, filter_channels, 1) 130 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1) 131 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 132 | if gin_channels != 0: 133 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1) 134 | 135 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): 136 | x = torch.detach(x) 137 | x = self.pre(x) 138 | if g is not None: 139 | g = torch.detach(g) 140 | x = x + self.cond(g) 141 | x = self.convs(x, x_mask) 142 | x = self.proj(x) * x_mask 143 | 144 | if not reverse: 145 | flows = self.flows 146 | assert w is not None 147 | 148 | logdet_tot_q = 0 149 | h_w = self.post_pre(w) 150 | h_w = self.post_convs(h_w, x_mask) 151 | h_w = self.post_proj(h_w) * x_mask 152 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 153 | z_q = e_q 154 | for flow in self.post_flows: 155 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 156 | logdet_tot_q += logdet_q 157 | z_u, z1 = torch.split(z_q, [1, 1], 1) 158 | u = torch.sigmoid(z_u) * x_mask 159 | z0 = (w - u) * x_mask 160 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) 161 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q 162 | 163 | logdet_tot = 0 164 | z0, logdet = self.log_flow(z0, x_mask) 165 | logdet_tot += logdet 166 | z = torch.cat([z0, z1], 1) 167 | for flow in flows: 168 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 169 | logdet_tot = logdet_tot + logdet 170 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot 171 | return nll + logq # [b] 172 | else: 173 | flows = list(reversed(self.flows)) 174 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 175 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 176 | for flow in flows: 177 | z = flow(z, x_mask, g=x, reverse=reverse) 178 | z0, z1 = torch.split(z, [1, 1], 1) 179 | logw = z0 180 | return logw 181 | 182 | class PosteriorEncoder(nn.Module): 183 | def __init__( 184 | self, 185 | in_channels, 186 | out_channels, 187 | hidden_channels, 188 | kernel_size, 189 | dilation_rate, 190 | n_layers, 191 | gin_channels=0, 192 | ): 193 | super().__init__() 194 | self.in_channels = in_channels 195 | self.out_channels = out_channels 196 | self.hidden_channels = hidden_channels 197 | self.kernel_size = kernel_size 198 | self.dilation_rate = dilation_rate 199 | self.n_layers = 
n_layers 200 | self.gin_channels = gin_channels 201 | 202 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 203 | self.enc = modules.WN( 204 | hidden_channels, 205 | kernel_size, 206 | dilation_rate, 207 | n_layers, 208 | gin_channels=gin_channels, 209 | ) 210 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 211 | 212 | def forward(self, x, x_lengths, g=None, tau=1.0): 213 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( 214 | x.dtype 215 | ) 216 | x = self.pre(x) * x_mask 217 | x = self.enc(x, x_mask, g=g) 218 | stats = self.proj(x) * x_mask 219 | m, logs = torch.split(stats, self.out_channels, dim=1) 220 | z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask 221 | return z, m, logs, x_mask 222 | 223 | 224 | class Generator(torch.nn.Module): 225 | def __init__( 226 | self, 227 | initial_channel, 228 | resblock, 229 | resblock_kernel_sizes, 230 | resblock_dilation_sizes, 231 | upsample_rates, 232 | upsample_initial_channel, 233 | upsample_kernel_sizes, 234 | gin_channels=0, 235 | ): 236 | super(Generator, self).__init__() 237 | self.num_kernels = len(resblock_kernel_sizes) 238 | self.num_upsamples = len(upsample_rates) 239 | self.conv_pre = Conv1d( 240 | initial_channel, upsample_initial_channel, 7, 1, padding=3 241 | ) 242 | resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 243 | 244 | self.ups = nn.ModuleList() 245 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 246 | self.ups.append( 247 | weight_norm( 248 | ConvTranspose1d( 249 | upsample_initial_channel // (2**i), 250 | upsample_initial_channel // (2 ** (i + 1)), 251 | k, 252 | u, 253 | padding=(k - u) // 2, 254 | ) 255 | ) 256 | ) 257 | 258 | self.resblocks = nn.ModuleList() 259 | for i in range(len(self.ups)): 260 | ch = upsample_initial_channel // (2 ** (i + 1)) 261 | for j, (k, d) in enumerate( 262 | zip(resblock_kernel_sizes, resblock_dilation_sizes) 263 | ): 264 | self.resblocks.append(resblock(ch, k, d)) 265 | 266 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) 267 | self.ups.apply(init_weights) 268 | 269 | if gin_channels != 0: 270 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) 271 | 272 | def forward(self, x, g=None): 273 | x = self.conv_pre(x) 274 | if g is not None: 275 | x = x + self.cond(g) 276 | 277 | for i in range(self.num_upsamples): 278 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 279 | x = self.ups[i](x) 280 | xs = None 281 | for j in range(self.num_kernels): 282 | if xs is None: 283 | xs = self.resblocks[i * self.num_kernels + j](x) 284 | else: 285 | xs += self.resblocks[i * self.num_kernels + j](x) 286 | x = xs / self.num_kernels 287 | x = F.leaky_relu(x) 288 | x = self.conv_post(x) 289 | x = torch.tanh(x) 290 | 291 | return x 292 | 293 | def remove_weight_norm(self): 294 | print("Removing weight norm...") 295 | for layer in self.ups: 296 | remove_weight_norm(layer) 297 | for layer in self.resblocks: 298 | layer.remove_weight_norm() 299 | 300 | 301 | class ReferenceEncoder(nn.Module): 302 | """ 303 | inputs --- [N, Ty/r, n_mels*r] mels 304 | outputs --- [N, ref_enc_gru_size] 305 | """ 306 | 307 | def __init__(self, spec_channels, gin_channels=0, layernorm=True): 308 | super().__init__() 309 | self.spec_channels = spec_channels 310 | ref_enc_filters = [32, 32, 64, 64, 128, 128] 311 | K = len(ref_enc_filters) 312 | filters = [1] + ref_enc_filters 313 | convs = [ 314 | weight_norm( 315 | nn.Conv2d( 316 | in_channels=filters[i], 317 | out_channels=filters[i + 1], 318 | 
kernel_size=(3, 3), 319 | stride=(2, 2), 320 | padding=(1, 1), 321 | ) 322 | ) 323 | for i in range(K) 324 | ] 325 | self.convs = nn.ModuleList(convs) 326 | 327 | out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) 328 | self.gru = nn.GRU( 329 | input_size=ref_enc_filters[-1] * out_channels, 330 | hidden_size=256 // 2, 331 | batch_first=True, 332 | ) 333 | self.proj = nn.Linear(128, gin_channels) 334 | if layernorm: 335 | self.layernorm = nn.LayerNorm(self.spec_channels) 336 | else: 337 | self.layernorm = None 338 | 339 | def forward(self, inputs, mask=None): 340 | N = inputs.size(0) 341 | 342 | out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] 343 | if self.layernorm is not None: 344 | out = self.layernorm(out) 345 | 346 | for conv in self.convs: 347 | out = conv(out) 348 | # out = wn(out) 349 | out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] 350 | 351 | out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] 352 | T = out.size(1) 353 | N = out.size(0) 354 | out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] 355 | 356 | self.gru.flatten_parameters() 357 | memory, out = self.gru(out) # out --- [1, N, 128] 358 | 359 | return self.proj(out.squeeze(0)) 360 | 361 | def calculate_channels(self, L, kernel_size, stride, pad, n_convs): 362 | for i in range(n_convs): 363 | L = (L - kernel_size + 2 * pad) // stride + 1 364 | return L 365 | 366 | 367 | class ResidualCouplingBlock(nn.Module): 368 | def __init__(self, 369 | channels, 370 | hidden_channels, 371 | kernel_size, 372 | dilation_rate, 373 | n_layers, 374 | n_flows=4, 375 | gin_channels=0): 376 | super().__init__() 377 | self.channels = channels 378 | self.hidden_channels = hidden_channels 379 | self.kernel_size = kernel_size 380 | self.dilation_rate = dilation_rate 381 | self.n_layers = n_layers 382 | self.n_flows = n_flows 383 | self.gin_channels = gin_channels 384 | 385 | self.flows = nn.ModuleList() 386 | for i in range(n_flows): 387 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) 388 | self.flows.append(modules.Flip()) 389 | 390 | def forward(self, x, x_mask, g=None, reverse=False): 391 | if not reverse: 392 | for flow in self.flows: 393 | x, _ = flow(x, x_mask, g=g, reverse=reverse) 394 | else: 395 | for flow in reversed(self.flows): 396 | x = flow(x, x_mask, g=g, reverse=reverse) 397 | return x 398 | 399 | class SynthesizerTrn(nn.Module): 400 | """ 401 | Synthesizer for Training 402 | """ 403 | 404 | def __init__( 405 | self, 406 | n_vocab, 407 | spec_channels, 408 | inter_channels, 409 | hidden_channels, 410 | filter_channels, 411 | n_heads, 412 | n_layers, 413 | kernel_size, 414 | p_dropout, 415 | resblock, 416 | resblock_kernel_sizes, 417 | resblock_dilation_sizes, 418 | upsample_rates, 419 | upsample_initial_channel, 420 | upsample_kernel_sizes, 421 | n_speakers=256, 422 | gin_channels=256, 423 | **kwargs 424 | ): 425 | super().__init__() 426 | 427 | self.dec = Generator( 428 | inter_channels, 429 | resblock, 430 | resblock_kernel_sizes, 431 | resblock_dilation_sizes, 432 | upsample_rates, 433 | upsample_initial_channel, 434 | upsample_kernel_sizes, 435 | gin_channels=gin_channels, 436 | ) 437 | self.enc_q = PosteriorEncoder( 438 | spec_channels, 439 | inter_channels, 440 | hidden_channels, 441 | 5, 442 | 1, 443 | 16, 444 | gin_channels=gin_channels, 445 | ) 446 | 447 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, 
gin_channels=gin_channels) 448 | 449 | self.n_speakers = n_speakers 450 | if n_speakers == 0: 451 | self.ref_enc = ReferenceEncoder(spec_channels, gin_channels) 452 | else: 453 | self.enc_p = TextEncoder(n_vocab, 454 | inter_channels, 455 | hidden_channels, 456 | filter_channels, 457 | n_heads, 458 | n_layers, 459 | kernel_size, 460 | p_dropout) 461 | self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) 462 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) 463 | self.emb_g = nn.Embedding(n_speakers, gin_channels) 464 | 465 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None): 466 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) 467 | if self.n_speakers > 0: 468 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 469 | else: 470 | g = None 471 | 472 | logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio \ 473 | + self.dp(x, x_mask, g=g) * (1 - sdp_ratio) 474 | 475 | w = torch.exp(logw) * x_mask * length_scale 476 | w_ceil = torch.ceil(w) 477 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 478 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) 479 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 480 | attn = commons.generate_path(w_ceil, attn_mask) 481 | 482 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 483 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 484 | 485 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale 486 | z = self.flow(z_p, y_mask, g=g, reverse=True) 487 | o = self.dec((z * y_mask)[:,:,:max_len], g=g) 488 | return o, attn, y_mask, (z, z_p, m_p, logs_p) 489 | 490 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0): 491 | g_src = sid_src 492 | g_tgt = sid_tgt 493 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau) 494 | z_p = self.flow(z, y_mask, g=g_src) 495 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) 496 | o_hat = self.dec(z_hat * y_mask, g=g_tgt) 497 | return o_hat, y_mask, (z, z_p, z_hat) 498 | -------------------------------------------------------------------------------- /OpenVoice/utils/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from torch.nn import Conv1d 7 | from torch.nn.utils import weight_norm, remove_weight_norm 8 | 9 | from . 
import commons 10 | from .commons import init_weights, get_padding 11 | from .transforms import piecewise_rational_quadratic_transform 12 | from .attentions import Encoder 13 | 14 | LRELU_SLOPE = 0.1 15 | 16 | 17 | class LayerNorm(nn.Module): 18 | def __init__(self, channels, eps=1e-5): 19 | super().__init__() 20 | self.channels = channels 21 | self.eps = eps 22 | 23 | self.gamma = nn.Parameter(torch.ones(channels)) 24 | self.beta = nn.Parameter(torch.zeros(channels)) 25 | 26 | def forward(self, x): 27 | x = x.transpose(1, -1) 28 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 29 | return x.transpose(1, -1) 30 | 31 | 32 | class ConvReluNorm(nn.Module): 33 | def __init__( 34 | self, 35 | in_channels, 36 | hidden_channels, 37 | out_channels, 38 | kernel_size, 39 | n_layers, 40 | p_dropout, 41 | ): 42 | super().__init__() 43 | self.in_channels = in_channels 44 | self.hidden_channels = hidden_channels 45 | self.out_channels = out_channels 46 | self.kernel_size = kernel_size 47 | self.n_layers = n_layers 48 | self.p_dropout = p_dropout 49 | assert n_layers > 1, "Number of layers should be larger than 0." 50 | 51 | self.conv_layers = nn.ModuleList() 52 | self.norm_layers = nn.ModuleList() 53 | self.conv_layers.append( 54 | nn.Conv1d( 55 | in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 56 | ) 57 | ) 58 | self.norm_layers.append(LayerNorm(hidden_channels)) 59 | self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) 60 | for _ in range(n_layers - 1): 61 | self.conv_layers.append( 62 | nn.Conv1d( 63 | hidden_channels, 64 | hidden_channels, 65 | kernel_size, 66 | padding=kernel_size // 2, 67 | ) 68 | ) 69 | self.norm_layers.append(LayerNorm(hidden_channels)) 70 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 71 | self.proj.weight.data.zero_() 72 | self.proj.bias.data.zero_() 73 | 74 | def forward(self, x, x_mask): 75 | x_org = x 76 | for i in range(self.n_layers): 77 | x = self.conv_layers[i](x * x_mask) 78 | x = self.norm_layers[i](x) 79 | x = self.relu_drop(x) 80 | x = x_org + self.proj(x) 81 | return x * x_mask 82 | 83 | 84 | class DDSConv(nn.Module): 85 | """ 86 | Dilated and Depth-Separable Convolution 87 | """ 88 | 89 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): 90 | super().__init__() 91 | self.channels = channels 92 | self.kernel_size = kernel_size 93 | self.n_layers = n_layers 94 | self.p_dropout = p_dropout 95 | 96 | self.drop = nn.Dropout(p_dropout) 97 | self.convs_sep = nn.ModuleList() 98 | self.convs_1x1 = nn.ModuleList() 99 | self.norms_1 = nn.ModuleList() 100 | self.norms_2 = nn.ModuleList() 101 | for i in range(n_layers): 102 | dilation = kernel_size**i 103 | padding = (kernel_size * dilation - dilation) // 2 104 | self.convs_sep.append( 105 | nn.Conv1d( 106 | channels, 107 | channels, 108 | kernel_size, 109 | groups=channels, 110 | dilation=dilation, 111 | padding=padding, 112 | ) 113 | ) 114 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 115 | self.norms_1.append(LayerNorm(channels)) 116 | self.norms_2.append(LayerNorm(channels)) 117 | 118 | def forward(self, x, x_mask, g=None): 119 | if g is not None: 120 | x = x + g 121 | for i in range(self.n_layers): 122 | y = self.convs_sep[i](x * x_mask) 123 | y = self.norms_1[i](y) 124 | y = F.gelu(y) 125 | y = self.convs_1x1[i](y) 126 | y = self.norms_2[i](y) 127 | y = F.gelu(y) 128 | y = self.drop(y) 129 | x = x + y 130 | return x * x_mask 131 | 132 | 133 | class WN(torch.nn.Module): 134 | def __init__( 135 | self, 136 | hidden_channels, 
137 | kernel_size, 138 | dilation_rate, 139 | n_layers, 140 | gin_channels=0, 141 | p_dropout=0, 142 | ): 143 | super(WN, self).__init__() 144 | assert kernel_size % 2 == 1 145 | self.hidden_channels = hidden_channels 146 | self.kernel_size = (kernel_size,) 147 | self.dilation_rate = dilation_rate 148 | self.n_layers = n_layers 149 | self.gin_channels = gin_channels 150 | self.p_dropout = p_dropout 151 | 152 | self.in_layers = torch.nn.ModuleList() 153 | self.res_skip_layers = torch.nn.ModuleList() 154 | self.drop = nn.Dropout(p_dropout) 155 | 156 | if gin_channels != 0: 157 | cond_layer = torch.nn.Conv1d( 158 | gin_channels, 2 * hidden_channels * n_layers, 1 159 | ) 160 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") 161 | 162 | for i in range(n_layers): 163 | dilation = dilation_rate**i 164 | padding = int((kernel_size * dilation - dilation) / 2) 165 | in_layer = torch.nn.Conv1d( 166 | hidden_channels, 167 | 2 * hidden_channels, 168 | kernel_size, 169 | dilation=dilation, 170 | padding=padding, 171 | ) 172 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") 173 | self.in_layers.append(in_layer) 174 | 175 | # last one is not necessary 176 | if i < n_layers - 1: 177 | res_skip_channels = 2 * hidden_channels 178 | else: 179 | res_skip_channels = hidden_channels 180 | 181 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 182 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") 183 | self.res_skip_layers.append(res_skip_layer) 184 | 185 | def forward(self, x, x_mask, g=None, **kwargs): 186 | output = torch.zeros_like(x) 187 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 188 | 189 | if g is not None: 190 | g = self.cond_layer(g) 191 | 192 | for i in range(self.n_layers): 193 | x_in = self.in_layers[i](x) 194 | if g is not None: 195 | cond_offset = i * 2 * self.hidden_channels 196 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 197 | else: 198 | g_l = torch.zeros_like(x_in) 199 | 200 | acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) 201 | acts = self.drop(acts) 202 | 203 | res_skip_acts = self.res_skip_layers[i](acts) 204 | if i < self.n_layers - 1: 205 | res_acts = res_skip_acts[:, : self.hidden_channels, :] 206 | x = (x + res_acts) * x_mask 207 | output = output + res_skip_acts[:, self.hidden_channels :, :] 208 | else: 209 | output = output + res_skip_acts 210 | return output * x_mask 211 | 212 | def remove_weight_norm(self): 213 | if self.gin_channels != 0: 214 | torch.nn.utils.remove_weight_norm(self.cond_layer) 215 | for l in self.in_layers: 216 | torch.nn.utils.remove_weight_norm(l) 217 | for l in self.res_skip_layers: 218 | torch.nn.utils.remove_weight_norm(l) 219 | 220 | 221 | class ResBlock1(torch.nn.Module): 222 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 223 | super(ResBlock1, self).__init__() 224 | self.convs1 = nn.ModuleList( 225 | [ 226 | weight_norm( 227 | Conv1d( 228 | channels, 229 | channels, 230 | kernel_size, 231 | 1, 232 | dilation=dilation[0], 233 | padding=get_padding(kernel_size, dilation[0]), 234 | ) 235 | ), 236 | weight_norm( 237 | Conv1d( 238 | channels, 239 | channels, 240 | kernel_size, 241 | 1, 242 | dilation=dilation[1], 243 | padding=get_padding(kernel_size, dilation[1]), 244 | ) 245 | ), 246 | weight_norm( 247 | Conv1d( 248 | channels, 249 | channels, 250 | kernel_size, 251 | 1, 252 | dilation=dilation[2], 253 | padding=get_padding(kernel_size, dilation[2]), 254 | ) 255 | ), 256 | ] 257 | 
) 258 | self.convs1.apply(init_weights) 259 | 260 | self.convs2 = nn.ModuleList( 261 | [ 262 | weight_norm( 263 | Conv1d( 264 | channels, 265 | channels, 266 | kernel_size, 267 | 1, 268 | dilation=1, 269 | padding=get_padding(kernel_size, 1), 270 | ) 271 | ), 272 | weight_norm( 273 | Conv1d( 274 | channels, 275 | channels, 276 | kernel_size, 277 | 1, 278 | dilation=1, 279 | padding=get_padding(kernel_size, 1), 280 | ) 281 | ), 282 | weight_norm( 283 | Conv1d( 284 | channels, 285 | channels, 286 | kernel_size, 287 | 1, 288 | dilation=1, 289 | padding=get_padding(kernel_size, 1), 290 | ) 291 | ), 292 | ] 293 | ) 294 | self.convs2.apply(init_weights) 295 | 296 | def forward(self, x, x_mask=None): 297 | for c1, c2 in zip(self.convs1, self.convs2): 298 | xt = F.leaky_relu(x, LRELU_SLOPE) 299 | if x_mask is not None: 300 | xt = xt * x_mask 301 | xt = c1(xt) 302 | xt = F.leaky_relu(xt, LRELU_SLOPE) 303 | if x_mask is not None: 304 | xt = xt * x_mask 305 | xt = c2(xt) 306 | x = xt + x 307 | if x_mask is not None: 308 | x = x * x_mask 309 | return x 310 | 311 | def remove_weight_norm(self): 312 | for l in self.convs1: 313 | remove_weight_norm(l) 314 | for l in self.convs2: 315 | remove_weight_norm(l) 316 | 317 | 318 | class ResBlock2(torch.nn.Module): 319 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 320 | super(ResBlock2, self).__init__() 321 | self.convs = nn.ModuleList( 322 | [ 323 | weight_norm( 324 | Conv1d( 325 | channels, 326 | channels, 327 | kernel_size, 328 | 1, 329 | dilation=dilation[0], 330 | padding=get_padding(kernel_size, dilation[0]), 331 | ) 332 | ), 333 | weight_norm( 334 | Conv1d( 335 | channels, 336 | channels, 337 | kernel_size, 338 | 1, 339 | dilation=dilation[1], 340 | padding=get_padding(kernel_size, dilation[1]), 341 | ) 342 | ), 343 | ] 344 | ) 345 | self.convs.apply(init_weights) 346 | 347 | def forward(self, x, x_mask=None): 348 | for c in self.convs: 349 | xt = F.leaky_relu(x, LRELU_SLOPE) 350 | if x_mask is not None: 351 | xt = xt * x_mask 352 | xt = c(xt) 353 | x = xt + x 354 | if x_mask is not None: 355 | x = x * x_mask 356 | return x 357 | 358 | def remove_weight_norm(self): 359 | for l in self.convs: 360 | remove_weight_norm(l) 361 | 362 | 363 | class Log(nn.Module): 364 | def forward(self, x, x_mask, reverse=False, **kwargs): 365 | if not reverse: 366 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 367 | logdet = torch.sum(-y, [1, 2]) 368 | return y, logdet 369 | else: 370 | x = torch.exp(x) * x_mask 371 | return x 372 | 373 | 374 | class Flip(nn.Module): 375 | def forward(self, x, *args, reverse=False, **kwargs): 376 | x = torch.flip(x, [1]) 377 | if not reverse: 378 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 379 | return x, logdet 380 | else: 381 | return x 382 | 383 | 384 | class ElementwiseAffine(nn.Module): 385 | def __init__(self, channels): 386 | super().__init__() 387 | self.channels = channels 388 | self.m = nn.Parameter(torch.zeros(channels, 1)) 389 | self.logs = nn.Parameter(torch.zeros(channels, 1)) 390 | 391 | def forward(self, x, x_mask, reverse=False, **kwargs): 392 | if not reverse: 393 | y = self.m + torch.exp(self.logs) * x 394 | y = y * x_mask 395 | logdet = torch.sum(self.logs * x_mask, [1, 2]) 396 | return y, logdet 397 | else: 398 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 399 | return x 400 | 401 | 402 | class ResidualCouplingLayer(nn.Module): 403 | def __init__( 404 | self, 405 | channels, 406 | hidden_channels, 407 | kernel_size, 408 | dilation_rate, 409 | n_layers, 410 | 
p_dropout=0, 411 | gin_channels=0, 412 | mean_only=False, 413 | ): 414 | assert channels % 2 == 0, "channels should be divisible by 2" 415 | super().__init__() 416 | self.channels = channels 417 | self.hidden_channels = hidden_channels 418 | self.kernel_size = kernel_size 419 | self.dilation_rate = dilation_rate 420 | self.n_layers = n_layers 421 | self.half_channels = channels // 2 422 | self.mean_only = mean_only 423 | 424 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 425 | self.enc = WN( 426 | hidden_channels, 427 | kernel_size, 428 | dilation_rate, 429 | n_layers, 430 | p_dropout=p_dropout, 431 | gin_channels=gin_channels, 432 | ) 433 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 434 | self.post.weight.data.zero_() 435 | self.post.bias.data.zero_() 436 | 437 | def forward(self, x, x_mask, g=None, reverse=False): 438 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 439 | h = self.pre(x0) * x_mask 440 | h = self.enc(h, x_mask, g=g) 441 | stats = self.post(h) * x_mask 442 | if not self.mean_only: 443 | m, logs = torch.split(stats, [self.half_channels] * 2, 1) 444 | else: 445 | m = stats 446 | logs = torch.zeros_like(m) 447 | 448 | if not reverse: 449 | x1 = m + x1 * torch.exp(logs) * x_mask 450 | x = torch.cat([x0, x1], 1) 451 | logdet = torch.sum(logs, [1, 2]) 452 | return x, logdet 453 | else: 454 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 455 | x = torch.cat([x0, x1], 1) 456 | return x 457 | 458 | 459 | class ConvFlow(nn.Module): 460 | def __init__( 461 | self, 462 | in_channels, 463 | filter_channels, 464 | kernel_size, 465 | n_layers, 466 | num_bins=10, 467 | tail_bound=5.0, 468 | ): 469 | super().__init__() 470 | self.in_channels = in_channels 471 | self.filter_channels = filter_channels 472 | self.kernel_size = kernel_size 473 | self.n_layers = n_layers 474 | self.num_bins = num_bins 475 | self.tail_bound = tail_bound 476 | self.half_channels = in_channels // 2 477 | 478 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 479 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) 480 | self.proj = nn.Conv1d( 481 | filter_channels, self.half_channels * (num_bins * 3 - 1), 1 482 | ) 483 | self.proj.weight.data.zero_() 484 | self.proj.bias.data.zero_() 485 | 486 | def forward(self, x, x_mask, g=None, reverse=False): 487 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 488 | h = self.pre(x0) 489 | h = self.convs(h, x_mask, g=g) 490 | h = self.proj(h) * x_mask 491 | 492 | b, c, t = x0.shape 493 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
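        # NOTE: h is split below into the unnormalized widths, heights, and
        # derivatives that parameterize the piecewise rational-quadratic spline
        # applied to the second half of the channels. (An invertibility check
        # of the coupling layers is appended at the end of this dump.)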
494 | 
495 |         unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
496 |         unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
497 |             self.filter_channels
498 |         )
499 |         unnormalized_derivatives = h[..., 2 * self.num_bins :]
500 | 
501 |         x1, logabsdet = piecewise_rational_quadratic_transform(
502 |             x1,
503 |             unnormalized_widths,
504 |             unnormalized_heights,
505 |             unnormalized_derivatives,
506 |             inverse=reverse,
507 |             tails="linear",
508 |             tail_bound=self.tail_bound,
509 |         )
510 | 
511 |         x = torch.cat([x0, x1], 1) * x_mask
512 |         logdet = torch.sum(logabsdet * x_mask, [1, 2])
513 |         if not reverse:
514 |             return x, logdet
515 |         else:
516 |             return x
517 | 
518 | 
519 | class TransformerCouplingLayer(nn.Module):
520 |     def __init__(
521 |         self,
522 |         channels,
523 |         hidden_channels,
524 |         kernel_size,
525 |         n_layers,
526 |         n_heads,
527 |         p_dropout=0,
528 |         filter_channels=0,
529 |         mean_only=False,
530 |         wn_sharing_parameter=None,
531 |         gin_channels=0,
532 |     ):
533 |         assert n_layers == 3, n_layers
534 |         assert channels % 2 == 0, "channels should be divisible by 2"
535 |         super().__init__()
536 |         self.channels = channels
537 |         self.hidden_channels = hidden_channels
538 |         self.kernel_size = kernel_size
539 |         self.n_layers = n_layers
540 |         self.half_channels = channels // 2
541 |         self.mean_only = mean_only
542 | 
543 |         self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
544 |         self.enc = (
545 |             Encoder(
546 |                 hidden_channels,
547 |                 filter_channels,
548 |                 n_heads,
549 |                 n_layers,
550 |                 kernel_size,
551 |                 p_dropout,
552 |                 isflow=True,
553 |                 gin_channels=gin_channels,
554 |             )
555 |             if wn_sharing_parameter is None
556 |             else wn_sharing_parameter
557 |         )
558 |         self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
559 |         self.post.weight.data.zero_()
560 |         self.post.bias.data.zero_()
561 | 
562 |     def forward(self, x, x_mask, g=None, reverse=False):
563 |         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
564 |         h = self.pre(x0) * x_mask
565 |         h = self.enc(h, x_mask, g=g)
566 |         stats = self.post(h) * x_mask
567 |         if not self.mean_only:
568 |             m, logs = torch.split(stats, [self.half_channels] * 2, 1)
569 |         else:
570 |             m = stats
571 |             logs = torch.zeros_like(m)
572 | 
573 |         if not reverse:
574 |             x1 = m + x1 * torch.exp(logs) * x_mask
575 |             x = torch.cat([x0, x1], 1)
576 |             logdet = torch.sum(logs, [1, 2])
577 |             return x, logdet
578 |         else:
579 |             x1 = (x1 - m) * torch.exp(-logs) * x_mask
580 |             x = torch.cat([x0, x1], 1)
581 |             return x
582 | 
--------------------------------------------------------------------------------
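--------------------------------------------------------------------------------
Usage sketches (illustrative):

The snippets below are minimal sketches, not shipped code. They assume the
modules in this dump import as OpenVoice.utils.*; adjust the paths to your
checkout. All hyperparameter values are typical VITS-style choices made for
illustration, not the values of the released converter config.

The windowed relative-position attention in attentions.py can be exercised on
its own; window_size=4 means each head learns biases for relative offsets in
[-4, 4], i.e. the 2 * window_size + 1 rows of emb_rel_k / emb_rel_v:

    import torch
    from OpenVoice.utils.attentions import MultiHeadAttention

    # Self-attention with learned relative-position embeddings (window +/-4).
    attn = MultiHeadAttention(channels=192, out_channels=192, n_heads=2,
                              window_size=4)
    x = torch.randn(2, 192, 50)      # [batch, channels, time]
    mask = torch.ones(2, 1, 50, 50)  # 1 = attend, 0 = masked out
    y = attn(x, x, attn_mask=mask)   # query and key/value share one sequence
    print(y.shape)                   # torch.Size([2, 192, 50])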
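The Encoder injects a speaker embedding g at a single layer (cond_layer_idx,
which defaults to 2 when gin_channels is passed), following the VITS2
conditioning scheme. A minimal sketch, assuming commons.sequence_mask behaves
as models.py already uses it:

    import torch
    from OpenVoice.utils import commons
    from OpenVoice.utils.attentions import Encoder

    enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2,
                  n_layers=6, kernel_size=3, p_dropout=0.1, gin_channels=256)
    x = torch.randn(2, 192, 50)
    x_mask = torch.unsqueeze(
        commons.sequence_mask(torch.LongTensor([50, 30]), 50), 1
    ).float()
    g = torch.randn(2, 256, 1)        # speaker embedding, injected at layer 2
    y = enc(x * x_mask, x_mask, g=g)  # [2, 192, 50]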
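At inference time the StochasticDurationPredictor runs its flow stack in
reverse: noise is pushed backwards through the ConvFlow/Flip steps,
conditioned on the text states, to sample log-durations; this is what
SynthesizerTrn.infer does with noise_scale_w. A sketch:

    import torch
    from OpenVoice.utils.models import StochasticDurationPredictor

    sdp = StochasticDurationPredictor(in_channels=192, filter_channels=192,
                                      kernel_size=3, p_dropout=0.5,
                                      n_flows=4).eval()
    x = torch.randn(2, 192, 30)   # text-encoder hidden states [b, h, t]
    x_mask = torch.ones(2, 1, 30)
    logw = sdp(x, x_mask, reverse=True, noise_scale=0.8)
    w = torch.exp(logw) * x_mask  # predicted frame count per input token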
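The flow modules in modules.py are invertible by construction: running a
coupling layer forward and then in reverse recovers its input up to floating
point. This is the property SynthesizerTrn.voice_conversion relies on to map
source audio into the speaker-independent latent z_p and back out under a
different speaker embedding. A quick check:

    import torch
    from OpenVoice.utils.modules import ResidualCouplingLayer

    layer = ResidualCouplingLayer(channels=192, hidden_channels=192,
                                  kernel_size=5, dilation_rate=1, n_layers=4,
                                  mean_only=True)
    x = torch.randn(2, 192, 40)
    x_mask = torch.ones(2, 1, 40)
    y, logdet = layer(x, x_mask)            # forward also returns log|det J|
    x_rec = layer(y, x_mask, reverse=True)  # inverse pass
    print(torch.allclose(x, x_rec, atol=1e-5))  # True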
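Finally, the tone-color conversion path used by this project is
enc_q -> flow -> flow(reverse) -> dec, with speaker embeddings produced by the
ReferenceEncoder (the n_speakers == 0 branch of SynthesizerTrn). A minimal
end-to-end sketch on random tensors; the hyperparameters here are assumptions
(typical VITS values), not the shipped converter config under
OpenVoice/checkpoints/converter:

    import torch
    from OpenVoice.utils.models import SynthesizerTrn

    model = SynthesizerTrn(
        n_vocab=0, spec_channels=513, inter_channels=192,
        hidden_channels=192, filter_channels=768, n_heads=2, n_layers=6,
        kernel_size=3, p_dropout=0.1,
        resblock="1", resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[8, 8, 2, 2], upsample_initial_channel=512,
        upsample_kernel_sizes=[16, 16, 4, 4],
        n_speakers=0, gin_channels=256,  # n_speakers=0: ReferenceEncoder path
    ).eval()

    spec_src = torch.randn(1, 513, 200)  # source linear spectrogram [b, f, t]
    spec_tgt = torch.randn(1, 513, 180)  # reference audio of target speaker
    lengths = torch.LongTensor([200])
    with torch.no_grad():
        g_src = model.ref_enc(spec_src.transpose(1, 2)).unsqueeze(-1)
        g_tgt = model.ref_enc(spec_tgt.transpose(1, 2)).unsqueeze(-1)
        audio, y_mask, _ = model.voice_conversion(
            spec_src, lengths, sid_src=g_src, sid_tgt=g_tgt, tau=0.3
        )
    # Output length is T * prod(upsample_rates): 200 * 256 = 51200 samples.
    print(audio.shape)  # torch.Size([1, 1, 51200])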