├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── __pycache__ └── __init__.cpython-311.pyc ├── assets ├── 03dd6465a900e81a6e1812302efc2b4.png ├── 1718816711480.png ├── 1718851026553.png └── 1719392506548.jpg ├── install.bat ├── nodes ├── ChatTTS │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ └── core.cpython-311.pyc │ ├── core.py │ ├── experimental │ │ └── llm.py │ ├── infer │ │ ├── __pycache__ │ │ │ └── api.cpython-311.pyc │ │ └── api.py │ ├── model │ │ ├── __pycache__ │ │ │ ├── dvae.cpython-311.pyc │ │ │ └── gpt.cpython-311.pyc │ │ ├── dvae.py │ │ └── gpt.py │ ├── res │ │ └── homophones_map.json │ └── utils │ │ ├── __pycache__ │ │ ├── gpu_utils.cpython-311.pyc │ │ ├── infer_utils.cpython-311.pyc │ │ └── io_utils.cpython-311.pyc │ │ ├── dl.py │ │ ├── gpu.py │ │ ├── gpu_utils.py │ │ ├── infer.py │ │ ├── infer_utils.py │ │ ├── io.py │ │ ├── io_utils.py │ │ └── log.py ├── __pycache__ │ ├── chat_tts.cpython-311.pyc │ └── chat_tts_run.cpython-311.pyc ├── chat_tts.py ├── chat_tts_run.py ├── openvoice │ ├── __init__.py │ ├── api.py │ ├── attentions.py │ ├── commons.py │ ├── mel_processing.py │ ├── models.py │ ├── modules.py │ ├── openvoice_app.py │ ├── se_extractor.py │ ├── text │ │ ├── __init__.py │ │ ├── cleaners.py │ │ ├── english.py │ │ ├── mandarin.py │ │ └── symbols.py │ ├── transforms.py │ └── utils.py ├── openvoice_run.py └── zh_normalization │ ├── README.md │ ├── __init__.py │ ├── char_convert.py │ ├── chronology.py │ ├── constants.py │ ├── num.py │ ├── phonecode.py │ ├── quantifier.py │ └── text_normlization.py ├── requirements.txt └── web └── loadSpeaker.js /.gitignore: -------------------------------------------------------------------------------- 1 | nodes/__pycache__ 2 | __pycache__ 3 | /nodes/__pycache__ 4 | /nodes/ChatTTS/__pycache__ 5 | *.pyc 6 | /__pycache__ 7 | /nodes/__pycache__ 8 | /nodes/__pycache__ 9 | /nodes/__pycache__ 10 | *.pyc 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 shadow 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Comfyui-ChatTTS
2 | > [Get help: Mixlab nodes discord](https://discord.gg/cXs9vZSqeK)
3 | 
4 | > [Recommended: mixlab-nodes](https://github.com/shadowcz007/comfyui-mixlab-nodes)
5 | 
6 | 
7 | You can currently create voice timbres, reuse them, and generate multi-speaker dialogue. For help, join the [discord](https://discord.gg/cXs9vZSqeK). Note that the input text does not need manual control tags such as [speed_3] or [laugh_2].
8 | 
9 | 
10 | > Example: multi-speaker dialogue x stand-up comedy
11 | 
12 | [![alt text](assets/1718816711480.png)](https://www.youtube.com/embed/s6O9aKrr3pM?si=--mwIX1rR0axEQFn)
13 | 
14 | 
15 | ![alt text](assets/1718851026553.png)
16 | 
17 | 
18 | Nodes:
19 | 
20 | ChatTTS
21 | 
22 | Multi Person Podcast
23 | 
24 | CreateSpeakers
25 | 
26 | SaveSpeaker, LoadSpeaker: save and load voice timbres; supports loading timbres from [ChatTTS_Speaker/summary](https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker/summary)
27 | 
28 | 
29 | Load Whisper Model, Whisper Transcribe: export the subtitle file that matches the generated audio
30 | 
31 | 
32 | OpenVoiceClone: transfer a voice timbre for better control over character voices
33 | 
34 | ![alt text](assets/03dd6465a900e81a6e1812302efc2b4.png)
35 | 
36 | 
37 | 
38 | Models:
39 | 
40 | Download and place in ```models/chat_tts```
41 | 
42 | https://huggingface.co/2Noise/ChatTTS
43 | 
44 | Place speaker .pt files in ```models/chat_tts_speaker```
45 | 
46 | Place the [openvoice model](https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip) in ```models/open_voice```
47 | 
48 | 
49 | Place the [whisper model](https://github.com/SYSTRAN/faster-whisper/tree/master) in ```models/whisper/large-v3```
50 | 
51 | ![alt text](assets/1719392506548.jpg)
52 | 
53 | 
54 | > The branch is the sample code for a course: using ChatTTS as an example, it adds speech synthesis to ComfyUI. A custom node needs to cover:
55 | the Python runtime (backend) - how to write the backend Python
56 | the GUI - how to modify the node interface
57 | 
58 | 
59 | 
60 | ### Recommended related plugins
61 | 
62 | [comfyui-liveportrait](https://github.com/shadowcz007/comfyui-liveportrait)
63 | 
64 | [Comfyui-ChatTTS](https://github.com/shadowcz007/Comfyui-ChatTTS)
65 | 
66 | [comfyui-sound-lab](https://github.com/shadowcz007/comfyui-sound-lab)
67 | 
68 | [comfyui-Image-reward](https://github.com/shadowcz007/comfyui-Image-reward)
69 | 
70 | [comfyui-ultralytics-yolo](https://github.com/shadowcz007/comfyui-ultralytics-yolo)
71 | 
72 | [comfyui-moondream](https://github.com/shadowcz007/comfyui-moondream)
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes.chat_tts import ChatTTSNode,multiPersonPodcast,CreateSpeakers,SaveSpeaker,LoadSpeaker,MergeSpeaker,RenameSpeaker,OpenVoiceClone,LoadWhisperModel,WhisperTranscribe,OpenVoiceCloneBySpeaker
2 | 
3 | 
4 | NODE_CLASS_MAPPINGS = {
5 |     "ChatTTS_": ChatTTSNode,
6 |     "CreateSpeakers":CreateSpeakers,
7 |     "MultiPersonPodcast":multiPersonPodcast,
8 |     "OpenVoiceClone":OpenVoiceClone,
9 |     "OpenVoiceCloneBySpeaker":OpenVoiceCloneBySpeaker,
10 |     "SaveSpeaker":SaveSpeaker,
11 |     "LoadSpeaker":LoadSpeaker,
12 |     "MergeSpeaker":MergeSpeaker,
13 |     "RenameSpeaker":RenameSpeaker,
14 |     "LoadWhisperModel":LoadWhisperModel,
15 |     "WhisperTranscribe":WhisperTranscribe
16 | }
17 | 
18 | # dict = { "key":value }
19 | 
20 | NODE_DISPLAY_NAME_MAPPINGS = {
21 |     "ChatTTS_": "ChatTTS",
22 |     "MultiPersonPodcast":"Multi Person Podcast",
23 |     "CreateSpeakers":"Create Speakers",
24 |     "OpenVoiceClone":"OpenVoice Clone",
25 |     "OpenVoiceCloneBySpeaker":"OpenVoice Clone By Speaker",
26 |     "SaveSpeaker":"Save Speaker",
27 |     "LoadSpeaker":"Load Speaker",
28 |     "MergeSpeaker":"Merge Speaker",
29 |     "RenameSpeaker":"Rename Speaker",
30 |     "LoadWhisperModel":"Load Whisper Model",
31 | 
"WhisperTranscribe":"Whisper Transcribe" 32 | } 33 | 34 | # web ui的节点功能 35 | WEB_DIRECTORY = "./web" 36 | -------------------------------------------------------------------------------- /__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /assets/03dd6465a900e81a6e1812302efc2b4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/assets/03dd6465a900e81a6e1812302efc2b4.png -------------------------------------------------------------------------------- /assets/1718816711480.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/assets/1718816711480.png -------------------------------------------------------------------------------- /assets/1718851026553.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/assets/1718851026553.png -------------------------------------------------------------------------------- /assets/1719392506548.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/assets/1719392506548.jpg -------------------------------------------------------------------------------- /install.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | set "requirements_txt=%~dp0\requirements.txt" 4 | set "python_exec=..\..\..\python_embeded\python.exe" 5 | 6 | echo Installing ComfyUI's Comfyui-ChatTTS Nodes.. 
7 | 8 | if exist "%python_exec%" ( 9 | echo Installing with ComfyUI Portable 10 | for /f "delims=" %%i in (%requirements_txt%) do ( 11 | %python_exec% -s -m pip install "%%i" -i https://pypi.tuna.tsinghua.edu.cn/simple 12 | ) 13 | ) else ( 14 | echo Installing with system Python 15 | for /f "delims=" %%i in (%requirements_txt%) do ( 16 | pip install "%%i" -i https://pypi.tuna.tsinghua.edu.cn/simple 17 | ) 18 | ) 19 | 20 | pause -------------------------------------------------------------------------------- /nodes/ChatTTS/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import Chat -------------------------------------------------------------------------------- /nodes/ChatTTS/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/nodes/ChatTTS/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /nodes/ChatTTS/__pycache__/core.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/nodes/ChatTTS/__pycache__/core.cpython-311.pyc -------------------------------------------------------------------------------- /nodes/ChatTTS/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | import tempfile 5 | from functools import partial 6 | from typing import Literal, Optional, List, Callable 7 | 8 | import numpy as np 9 | import torch 10 | from omegaconf import OmegaConf 11 | from vocos import Vocos 12 | from huggingface_hub import snapshot_download 13 | 14 | from .model.dvae import DVAE 15 | from .model.gpt import GPT 16 | from .utils.gpu import select_device 17 | from .utils.infer import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map, HomophonesReplacer 18 | from .utils.io import get_latest_modified_file, del_all 19 | from .infer.api import refine_text, infer_code 20 | from .utils.dl import check_all_assets, download_all_assets 21 | from .utils.log import logger as utils_logger 22 | 23 | 24 | class Chat: 25 | def __init__(self, logger=logging.getLogger(__name__)): 26 | self.pretrain_models = {} 27 | self.normalizer = {} 28 | self.homophones_replacer = None 29 | self.logger = logger 30 | utils_logger.set_logger(logger) 31 | 32 | def has_loaded(self, use_decoder = False): 33 | not_finish = False 34 | check_list = ['gpt', 'tokenizer'] 35 | 36 | if use_decoder: 37 | check_list.append('decoder') 38 | else: 39 | check_list.append('dvae') 40 | 41 | for module in check_list: 42 | if module not in self.pretrain_models: 43 | self.logger.warn(f'{module} not initialized.') 44 | not_finish = True 45 | 46 | if not hasattr(self, "_vocos_decode") or not hasattr(self, "vocos"): 47 | self.logger.warn('vocos not initialized.') 48 | not_finish = True 49 | 50 | if not not_finish: 51 | self.logger.info('all models has been initialized.') 52 | 53 | return not not_finish 54 | 55 | def download_models( 56 | self, 57 | source: Literal['huggingface', 'local', 'custom']='local', 58 | force_redownload=False, 59 | custom_path: Optional[torch.serialization.FILE_LIKE]=None, 60 | ) -> Optional[str]: 61 | if source == 'local': 62 | download_path = os.getcwd() 63 | if not 
check_all_assets(update=True) or force_redownload: 64 | with tempfile.TemporaryDirectory() as tmp: 65 | download_all_assets(tmpdir=tmp) 66 | if not check_all_assets(update=False): 67 | self.logger.error("download to local path %s failed.", download_path) 68 | return None 69 | elif source == 'huggingface': 70 | hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface")) 71 | try: 72 | download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots')) 73 | except: 74 | download_path = None 75 | if download_path is None or force_redownload: 76 | self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS') 77 | try: 78 | download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"]) 79 | except: 80 | download_path = None 81 | else: 82 | self.logger.log(logging.INFO, f'load latest snapshot from cache: {download_path}') 83 | if download_path is None: 84 | self.logger.error("download from huggingface failed.") 85 | return None 86 | elif source == 'custom': 87 | self.logger.log(logging.INFO, f'try to load from local: {custom_path}') 88 | download_path = custom_path 89 | 90 | return download_path 91 | 92 | def load_models( 93 | self, 94 | source: Literal['huggingface', 'local', 'custom']='local', 95 | force_redownload=False, 96 | compile: bool = True, 97 | custom_path: Optional[torch.serialization.FILE_LIKE]=None, 98 | device: Optional[torch.device] = None, 99 | coef: Optional[torch.Tensor] = None, 100 | ) -> bool: 101 | download_path = self.download_models(source, force_redownload, custom_path) 102 | if download_path is None: 103 | return False 104 | return self._load( 105 | device=device, compile=compile, coef=coef, 106 | **{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, 107 | ) 108 | 109 | def _load( 110 | self, 111 | vocos_config_path: str = None, 112 | vocos_ckpt_path: str = None, 113 | dvae_config_path: str = None, 114 | dvae_ckpt_path: str = None, 115 | gpt_config_path: str = None, 116 | gpt_ckpt_path: str = None, 117 | decoder_config_path: str = None, 118 | decoder_ckpt_path: str = None, 119 | tokenizer_path: str = None, 120 | device: Optional[torch.device] = None, 121 | compile: bool = True, 122 | coef: Optional[str] = None 123 | ): 124 | if device is None: 125 | device = select_device(4096) 126 | self.logger.log(logging.INFO, f'use {device}') 127 | self.device = device 128 | 129 | if vocos_config_path: 130 | vocos = Vocos.from_hparams(vocos_config_path).to( 131 | # vocos on mps will crash, use cpu fallback 132 | "cpu" if "mps" in str(device) else device 133 | ).eval() 134 | assert vocos_ckpt_path, 'vocos_ckpt_path should not be None' 135 | vocos.load_state_dict(torch.load(vocos_ckpt_path)) 136 | self.vocos = vocos 137 | if "mps" in str(self.device): 138 | self._vocos_decode: Callable[[torch.Tensor], np.ndarray] = lambda spec: self.vocos.decode( 139 | spec.cpu() 140 | ).cpu().numpy() 141 | else: 142 | self._vocos_decode: Callable[[torch.Tensor], np.ndarray] = lambda spec: self.vocos.decode( 143 | spec 144 | ).cpu().numpy() 145 | self.logger.log(logging.INFO, 'vocos loaded.') 146 | 147 | if dvae_config_path: 148 | cfg = OmegaConf.load(dvae_config_path) 149 | dvae = DVAE(**cfg, coef=coef).to(device).eval() 150 | coef = str(dvae) 151 | assert dvae_ckpt_path, 'dvae_ckpt_path should not be None' 152 | dvae.load_state_dict(torch.load(dvae_ckpt_path)) 153 | self.pretrain_models['dvae'] = dvae 154 | 
self.logger.log(logging.INFO, 'dvae loaded.') 155 | 156 | if gpt_config_path: 157 | cfg = OmegaConf.load(gpt_config_path) 158 | gpt = GPT(**cfg, device=device, logger=self.logger).eval() 159 | assert gpt_ckpt_path, 'gpt_ckpt_path should not be None' 160 | gpt.load_state_dict(torch.load(gpt_ckpt_path)) 161 | if compile and 'cuda' in str(device): 162 | try: 163 | gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True) 164 | except RuntimeError as e: 165 | self.logger.warning(f'Compile failed,{e}. fallback to normal mode.') 166 | self.pretrain_models['gpt'] = gpt 167 | spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt') 168 | assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}' 169 | self.pretrain_models['spk_stat'] = torch.load(spk_stat_path).to(device) 170 | self.logger.log(logging.INFO, 'gpt loaded.') 171 | 172 | if decoder_config_path: 173 | cfg = OmegaConf.load(decoder_config_path) 174 | decoder = DVAE(**cfg, coef=coef).to(device).eval() 175 | coef = str(decoder) 176 | assert decoder_ckpt_path, 'decoder_ckpt_path should not be None' 177 | decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu')) 178 | self.pretrain_models['decoder'] = decoder 179 | self.logger.log(logging.INFO, 'decoder loaded.') 180 | 181 | if tokenizer_path: 182 | tokenizer = torch.load(tokenizer_path, map_location='cpu') 183 | tokenizer.padding_side = 'left' 184 | self.pretrain_models['tokenizer'] = tokenizer 185 | self.logger.log(logging.INFO, 'tokenizer loaded.') 186 | 187 | self.coef = coef 188 | 189 | return self.has_loaded() 190 | 191 | def unload(self): 192 | logger = self.logger 193 | del_all(self) 194 | self.__init__(logger) 195 | 196 | def _infer( 197 | self, 198 | text, 199 | skip_refine_text=False, 200 | refine_text_only=False, 201 | params_refine_text={}, 202 | params_infer_code={'prompt':'[speed_5]'}, 203 | use_decoder=True, 204 | do_text_normalization=True, 205 | lang=None, 206 | stream=False, 207 | do_homophone_replacement=True 208 | ): 209 | 210 | assert self.has_loaded(use_decoder=use_decoder) 211 | 212 | if not isinstance(text, list): 213 | text = [text] 214 | if do_text_normalization: 215 | for i, t in enumerate(text): 216 | _lang = detect_language(t) if lang is None else lang 217 | if self._init_normalizer(_lang): 218 | text[i] = self.normalizer[_lang](t) 219 | if _lang == 'zh': 220 | text[i] = apply_half2full_map(text[i]) 221 | for i, t in enumerate(text): 222 | invalid_characters = count_invalid_characters(t) 223 | if len(invalid_characters): 224 | self.logger.warn(f'Invalid characters found! 
: {invalid_characters}') 225 | text[i] = apply_character_map(t) 226 | if do_homophone_replacement and self._init_homophones_replacer(): 227 | text[i], replaced_words = self.homophones_replacer.replace(text[i]) 228 | if replaced_words: 229 | repl_res = ', '.join([f'{_[0]}->{_[1]}' for _ in replaced_words]) 230 | self.logger.log(logging.INFO, f'Homophones replace: {repl_res}') 231 | 232 | if not skip_refine_text: 233 | refined = refine_text( 234 | self.pretrain_models, 235 | text, 236 | device=self.device, 237 | **params_refine_text, 238 | ) 239 | text_tokens = refined.ids 240 | text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens] 241 | text = self.pretrain_models['tokenizer'].batch_decode(text_tokens) 242 | refined.destroy() 243 | if refine_text_only: 244 | yield text 245 | return 246 | 247 | text = [params_infer_code.get('prompt', '') + i for i in text] 248 | print('\033[93m' + '#infer text:' + '\033[0m', text) 249 | params_infer_code.pop('prompt', '') 250 | 251 | length = [0 for _ in range(len(text))] 252 | for result in infer_code( 253 | self.pretrain_models, 254 | text, 255 | device=self.device, 256 | **params_infer_code, 257 | return_hidden=use_decoder, 258 | stream=stream, 259 | ): 260 | wav = self.decode_to_wavs(result, length, use_decoder) 261 | yield wav 262 | 263 | def infer( 264 | self, 265 | text, 266 | skip_refine_text=False, 267 | refine_text_only=False, 268 | params_refine_text={}, 269 | params_infer_code={'prompt':'[speed_5]'}, 270 | use_decoder=True, 271 | do_text_normalization=True, 272 | lang=None, 273 | stream=False, 274 | do_homophone_replacement=True, 275 | ): 276 | res_gen = self._infer( 277 | text, 278 | skip_refine_text, 279 | refine_text_only, 280 | params_refine_text, 281 | params_infer_code, 282 | use_decoder, 283 | do_text_normalization, 284 | lang, 285 | stream, 286 | do_homophone_replacement, 287 | ) 288 | if stream: 289 | return res_gen 290 | else: 291 | return next(res_gen) 292 | 293 | def sample_random_speaker(self): 294 | dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features 295 | std, mean = self.pretrain_models['spk_stat'].chunk(2) 296 | return torch.randn(dim, device=std.device) * std + mean 297 | 298 | def decode_to_wavs(self, result: GPT.GenerationOutputs, start_seeks: List[int], use_decoder: bool): 299 | x = result.hiddens if use_decoder else result.ids 300 | wavs: List[np.ndarray] = [] 301 | for i, chunk_data in enumerate(x): 302 | start_seek = start_seeks[i] 303 | length = len(chunk_data) 304 | if length <= start_seek: 305 | wavs.append(None) 306 | continue 307 | start_seeks[i] = length 308 | chunk_data = chunk_data[start_seek:] 309 | if use_decoder: 310 | decoder = self.pretrain_models['decoder'] 311 | else: 312 | decoder = self.pretrain_models['dvae'] 313 | mel_spec = decoder(chunk_data[None].permute(0,2,1).to(self.device)) 314 | del chunk_data 315 | wavs.append(self._vocos_decode(mel_spec)) 316 | del_all(mel_spec) 317 | result.destroy() 318 | del_all(x) 319 | return wavs 320 | 321 | def _init_normalizer(self, lang) -> bool: 322 | 323 | if lang in self.normalizer: 324 | return True 325 | 326 | if lang == 'zh': 327 | try: 328 | from tn.chinese.normalizer import Normalizer 329 | self.normalizer[lang] = Normalizer().normalize 330 | return True 331 | except: 332 | self.logger.log( 333 | logging.WARNING, 334 | 'Package WeTextProcessing not found!', 335 | ) 336 | self.logger.log( 337 | logging.WARNING, 338 | 'Run: conda install -c conda-forge pynini=2.1.5 && pip install 
WeTextProcessing', 339 | ) 340 | else: 341 | try: 342 | from nemo_text_processing.text_normalization.normalize import Normalizer 343 | self.normalizer[lang] = partial(Normalizer(input_case='cased', lang=lang).normalize, verbose=False, punct_post_process=True) 344 | return True 345 | except: 346 | self.logger.log( 347 | logging.WARNING, 348 | 'Package nemo_text_processing not found!', 349 | ) 350 | self.logger.log( 351 | logging.WARNING, 352 | 'Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing', 353 | ) 354 | return False 355 | 356 | def _init_homophones_replacer(self): 357 | if self.homophones_replacer: 358 | return True 359 | else: 360 | try: 361 | self.homophones_replacer = HomophonesReplacer(os.path.join(os.path.dirname(__file__), 'res', 'homophones_map.json')) 362 | self.logger.log(logging.INFO, 'successfully loaded HomophonesReplacer.') 363 | return True 364 | except (IOError, json.JSONDecodeError) as e: 365 | self.logger.log(logging.WARNING, f'error loading homophones map: {e}') 366 | except Exception as e: 367 | self.logger.log(logging.WARNING, f'error loading HomophonesReplacer: {e}') 368 | return False 369 | -------------------------------------------------------------------------------- /nodes/ChatTTS/experimental/llm.py: -------------------------------------------------------------------------------- 1 | 2 | from openai import OpenAI 3 | 4 | prompt_dict = { 5 | 'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"}, 6 | {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"}, 7 | {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},], 8 | 'deepseek': [ 9 | {"role": "system", "content": "You are a helpful assistant"}, 10 | {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"}, 11 | {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},], 12 | 'deepseek_TN': [ 13 | {"role": "system", "content": "You are a helpful assistant"}, 14 | {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"}, 15 | {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"}, 16 | {"role": "user", "content": "We paid $123 for this desk."}, 17 | {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."}, 18 | {"role": "user", "content": "详询请拨打010-724654"}, 19 | {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"}, 20 | {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"}, 21 | {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"}, 22 | ], 23 | } 24 | 25 | class llm_api: 26 | def __init__(self, api_key, base_url, model): 27 | self.client = OpenAI( 28 | api_key = api_key, 29 | base_url = base_url, 30 | ) 31 | self.model = model 32 | def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs): 33 | 34 | completion = self.client.chat.completions.create( 35 | model = self.model, 36 | messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},], 37 | temperature = temperature, 38 | **kwargs 39 | ) 40 | return completion.choices[0].message.content 41 | -------------------------------------------------------------------------------- /nodes/ChatTTS/infer/__pycache__/api.cpython-311.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/nodes/ChatTTS/infer/__pycache__/api.cpython-311.pyc -------------------------------------------------------------------------------- /nodes/ChatTTS/infer/api.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from transformers.generation import TopKLogitsWarper, TopPLogitsWarper 5 | 6 | from ..utils.infer import CustomRepetitionPenaltyLogitsProcessorRepeat 7 | from ..utils.io import del_all 8 | from ..model.gpt import GPT 9 | 10 | def infer_code( 11 | models, 12 | text, 13 | spk_emb = None, 14 | top_P = 0.7, 15 | top_K = 20, 16 | temperature = 0.3, 17 | repetition_penalty = 1.05, 18 | max_new_token = 2048, 19 | stream=False, 20 | device="cpu", 21 | **kwargs 22 | ): 23 | 24 | gpt: GPT = models['gpt'] 25 | 26 | if not isinstance(text, list): 27 | text = [text] 28 | 29 | if not isinstance(temperature, list): 30 | temperature = [temperature] * gpt.num_vq 31 | 32 | if spk_emb is not None: 33 | text = [f'[Stts][spk_emb]{i}[Ptts]' for i in text] 34 | else: 35 | text = [f'[Stts][empty_spk]{i}[Ptts]' for i in text] 36 | 37 | text_token_tmp = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True) 38 | text_token = text_token_tmp.to(device) 39 | del text_token_tmp 40 | input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq).to(gpt.device_gpt) 41 | text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=gpt.device_gpt) 42 | 43 | emb = gpt(input_ids, text_mask) 44 | del text_mask 45 | 46 | if spk_emb is not None: 47 | n = F.normalize(spk_emb.to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12).to(gpt.device_gpt) 48 | emb[input_ids[..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = n 49 | del n 50 | 51 | num_code = int(gpt.emb_code[0].num_embeddings - 1) 52 | 53 | LogitsWarpers = [] 54 | if top_P is not None: 55 | LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) 56 | if top_K is not None: 57 | LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) 58 | 59 | LogitsProcessors = [] 60 | if repetition_penalty is not None and repetition_penalty != 1: 61 | LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\ 62 | repetition_penalty, num_code, 16)) 63 | 64 | result = gpt.generate( 65 | emb, input_ids, 66 | temperature = torch.tensor(temperature, device=device), 67 | attention_mask = text_token['attention_mask'], 68 | LogitsWarpers = LogitsWarpers, 69 | LogitsProcessors = LogitsProcessors, 70 | eos_token = num_code, 71 | max_new_token = max_new_token, 72 | infer_text = False, 73 | stream = stream, 74 | **kwargs 75 | ) 76 | 77 | del_all(text_token) 78 | del emb, text_token, input_ids 79 | del_all(LogitsWarpers) 80 | del_all(LogitsProcessors) 81 | 82 | return result 83 | 84 | 85 | def refine_text( 86 | models, 87 | text, 88 | top_P = 0.7, 89 | top_K = 20, 90 | temperature = 0.7, 91 | repetition_penalty = 1.0, 92 | max_new_token = 384, 93 | prompt = '', 94 | device="cpu", 95 | **kwargs 96 | ): 97 | 98 | gpt: GPT = models['gpt'] 99 | 100 | if not isinstance(text, list): 101 | text = [text] 102 | 103 | assert len(text), 'text should not be empty' 104 | 105 | text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text] 106 | text_token = models['tokenizer'](text, return_tensors='pt', 
add_special_tokens=False, padding=True).to(device) 107 | text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device) 108 | 109 | input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq) 110 | 111 | LogitsWarpers = [] 112 | if top_P is not None: 113 | LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) 114 | if top_K is not None: 115 | LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) 116 | 117 | LogitsProcessors = [] 118 | if repetition_penalty is not None and repetition_penalty != 1: 119 | LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16)) 120 | 121 | emb = gpt(input_ids,text_mask) 122 | del text_mask 123 | 124 | result = gpt.generate( 125 | emb, input_ids, 126 | temperature = torch.tensor([temperature,], device=device), 127 | attention_mask = text_token['attention_mask'], 128 | LogitsWarpers = LogitsWarpers, 129 | LogitsProcessors = LogitsProcessors, 130 | eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None], 131 | max_new_token = max_new_token, 132 | infer_text = True, 133 | stream = False, 134 | **kwargs 135 | ) 136 | 137 | del_all(text_token) 138 | del emb, text_token, input_ids 139 | del_all(LogitsWarpers) 140 | del_all(LogitsProcessors) 141 | 142 | return next(result) 143 | -------------------------------------------------------------------------------- /nodes/ChatTTS/model/__pycache__/dvae.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/nodes/ChatTTS/model/__pycache__/dvae.cpython-311.pyc -------------------------------------------------------------------------------- /nodes/ChatTTS/model/__pycache__/gpt.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/nodes/ChatTTS/model/__pycache__/gpt.cpython-311.pyc -------------------------------------------------------------------------------- /nodes/ChatTTS/model/dvae.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional 3 | 4 | import numpy as np 5 | import pybase16384 as b14 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from vector_quantize_pytorch import GroupedResidualFSQ 10 | 11 | class ConvNeXtBlock(nn.Module): 12 | def __init__( 13 | self, 14 | dim: int, 15 | intermediate_dim: int, 16 | kernel: int, dilation: int, 17 | layer_scale_init_value: float = 1e-6, 18 | ): 19 | # ConvNeXt Block copied from Vocos. 
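        # Structure: depthwise Conv1d -> LayerNorm -> pointwise Linear (expand)
        # -> GELU -> pointwise Linear (project back) -> optional per-channel
        # layer scale (gamma) -> residual connection back onto the input.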
20 | super().__init__() 21 | self.dwconv = nn.Conv1d(dim, dim, 22 | kernel_size=kernel, padding=dilation*(kernel//2), 23 | dilation=dilation, groups=dim 24 | ) # depthwise conv 25 | 26 | self.norm = nn.LayerNorm(dim, eps=1e-6) 27 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers 28 | self.act = nn.GELU() 29 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 30 | self.gamma = ( 31 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 32 | if layer_scale_init_value > 0 33 | else None 34 | ) 35 | 36 | def forward(self, x: torch.Tensor, cond = None) -> torch.Tensor: 37 | residual = x 38 | 39 | y = self.dwconv(x) 40 | y.transpose_(1, 2) # (B, C, T) -> (B, T, C) 41 | x = self.norm(y) 42 | del y 43 | y = self.pwconv1(x) 44 | del x 45 | x = self.act(y) 46 | del y 47 | y = self.pwconv2(x) 48 | del x 49 | if self.gamma is not None: 50 | y *= self.gamma 51 | y.transpose_(1, 2) # (B, T, C) -> (B, C, T) 52 | 53 | x = y + residual 54 | del y 55 | 56 | return x 57 | 58 | 59 | class GFSQ(nn.Module): 60 | 61 | def __init__(self, 62 | dim: int, levels: List[int], G: int, R: int, eps=1e-5, transpose = True 63 | ): 64 | super(GFSQ, self).__init__() 65 | self.quantizer = GroupedResidualFSQ( 66 | dim=dim, 67 | levels=levels, 68 | num_quantizers=R, 69 | groups=G, 70 | ) 71 | self.n_ind = math.prod(levels) 72 | self.eps = eps 73 | self.transpose = transpose 74 | self.G = G 75 | self.R = R 76 | 77 | def _embed(self, x: torch.Tensor): 78 | if self.transpose: 79 | x = x.transpose(1, 2) 80 | """ 81 | x = rearrange( 82 | x, "b t (g r) -> g b t r", g = self.G, r = self.R, 83 | ) 84 | """ 85 | x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3) 86 | feat = self.quantizer.get_output_from_indices(x) 87 | return feat.transpose_(1,2) if self.transpose else feat 88 | 89 | def forward(self, x): 90 | if self.transpose: 91 | x = x.transpose(1, 2) 92 | feat, ind = self.quantizer(x) 93 | """ 94 | ind = rearrange( 95 | ind, "g b t r ->b t (g r)", 96 | ) 97 | """ 98 | ind = ind.permute(1, 2, 0, 3).contiguous() 99 | ind = ind.view(ind.size(0), ind.size(1), -1) 100 | embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind) 101 | embed_onehot = embed_onehot_tmp.to(x.dtype) 102 | del embed_onehot_tmp 103 | e_mean = torch.mean(embed_onehot, dim=[0,1]) 104 | # e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1) 105 | torch.div(e_mean, (e_mean.sum(dim=1) + self.eps).unsqueeze(1), out=e_mean) 106 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1)) 107 | 108 | return ( 109 | torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device), 110 | feat.transpose_(1,2) if self.transpose else feat, 111 | perplexity, 112 | None, 113 | ind.transpose_(1,2) if self.transpose else ind, 114 | ) 115 | 116 | class DVAEDecoder(nn.Module): 117 | def __init__(self, idim: int, odim: int, 118 | n_layer = 12, bn_dim = 64, hidden = 256, 119 | kernel = 7, dilation = 2, up = False 120 | ): 121 | super().__init__() 122 | self.up = up 123 | self.conv_in = nn.Sequential( 124 | nn.Conv1d(idim, bn_dim, 3, 1, 1), nn.GELU(), 125 | nn.Conv1d(bn_dim, hidden, 3, 1, 1) 126 | ) 127 | self.decoder_block = nn.ModuleList([ 128 | ConvNeXtBlock(hidden, hidden* 4, kernel, dilation,) 129 | for _ in range(n_layer)]) 130 | self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) 131 | 132 | def forward(self, input: torch.Tensor, conditioning=None) -> torch.Tensor: 133 | # B, T, C 134 | x = input.transpose_(1, 2) 135 | y = self.conv_in(x) 136 | del 
x 137 | for f in self.decoder_block: 138 | y = f(y, conditioning) 139 | 140 | x = self.conv_out(y) 141 | del y 142 | return x.transpose_(1, 2) 143 | 144 | 145 | class DVAE(nn.Module): 146 | def __init__( 147 | self, decoder_config, vq_config, dim=512, coef: Optional[str] = None, 148 | ): 149 | super().__init__() 150 | if coef is None: 151 | coef = torch.rand(100) 152 | else: 153 | coef = torch.from_numpy(np.copy(np.frombuffer(b14.decode_from_string(coef), dtype=np.float32))) 154 | self.register_buffer('coef', coef.unsqueeze(0).unsqueeze_(2)) 155 | 156 | self.decoder = DVAEDecoder(**decoder_config) 157 | self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False) 158 | if vq_config is not None: 159 | self.vq_layer = GFSQ(**vq_config) 160 | else: 161 | self.vq_layer = None 162 | 163 | def __repr__(self) -> str: 164 | return b14.encode_to_string(self.coef.cpu().numpy().astype(np.float32).tobytes()) 165 | 166 | def forward(self, inp: torch.Tensor) -> torch.Tensor: 167 | with torch.no_grad(): 168 | 169 | if self.vq_layer is not None: 170 | vq_feats = self.vq_layer._embed(inp) 171 | else: 172 | vq_feats = inp.detach().clone() 173 | 174 | vq_feats = vq_feats.view( 175 | (vq_feats.size(0), 2, vq_feats.size(1)//2, vq_feats.size(2)), 176 | ).permute(0, 2, 3, 1).flatten(2) 177 | 178 | dec_out = self.out_conv( 179 | self.decoder( 180 | input=vq_feats.transpose_(1, 2), 181 | ).transpose_(1, 2), 182 | ) 183 | 184 | return torch.mul(dec_out, self.coef, out=dec_out) 185 | -------------------------------------------------------------------------------- /nodes/ChatTTS/utils/__pycache__/gpu_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/nodes/ChatTTS/utils/__pycache__/gpu_utils.cpython-311.pyc -------------------------------------------------------------------------------- /nodes/ChatTTS/utils/__pycache__/infer_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/nodes/ChatTTS/utils/__pycache__/infer_utils.cpython-311.pyc -------------------------------------------------------------------------------- /nodes/ChatTTS/utils/__pycache__/io_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/nodes/ChatTTS/utils/__pycache__/io_utils.cpython-311.pyc -------------------------------------------------------------------------------- /nodes/ChatTTS/utils/dl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import hashlib 4 | import requests 5 | from io import BytesIO 6 | from mmap import mmap, ACCESS_READ 7 | 8 | from .log import logger 9 | 10 | def sha256(fileno: int) -> str: 11 | data = mmap(fileno, 0, access=ACCESS_READ) 12 | h = hashlib.sha256(data).hexdigest() 13 | del data 14 | return h 15 | 16 | 17 | def check_model( 18 | dir_name: Path, model_name: str, hash: str, remove_incorrect=False 19 | ) -> bool: 20 | target = dir_name / model_name 21 | relname = target.as_posix() 22 | logger.get_logger().debug(f"checking {relname}...") 23 | if not os.path.exists(target): 24 | logger.get_logger().info(f"{target} not exist.") 25 | return False 26 | with open(target, "rb") as f: 27 | 
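        # sha256() mmaps the open file descriptor, so the file is hashed
        # without being read fully into memory.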
digest = sha256(f.fileno()) 28 | bakfile = f"{target}.bak" 29 | if digest != hash: 30 | logger.get_logger().warn(f"{target} sha256 hash mismatch.") 31 | logger.get_logger().info(f"expected: {hash}") 32 | logger.get_logger().info(f"real val: {digest}") 33 | logger.get_logger().warn("please add parameter --update to download the latest assets.") 34 | if remove_incorrect: 35 | if not os.path.exists(bakfile): 36 | os.rename(str(target), bakfile) 37 | else: 38 | os.remove(str(target)) 39 | return False 40 | if remove_incorrect and os.path.exists(bakfile): 41 | os.remove(bakfile) 42 | return True 43 | 44 | 45 | def check_all_assets(update=False) -> bool: 46 | BASE_DIR = Path(os.getcwd()) 47 | 48 | logger.get_logger().info("checking assets...") 49 | current_dir = BASE_DIR / "asset" 50 | names = [ 51 | "Decoder.pt", 52 | "DVAE.pt", 53 | "GPT.pt", 54 | "spk_stat.pt", 55 | "tokenizer.pt", 56 | "Vocos.pt", 57 | ] 58 | for model in names: 59 | menv = model.replace(".", "_") 60 | if not check_model( 61 | current_dir, model, os.environ[f"sha256_asset_{menv}"], update 62 | ): 63 | return False 64 | 65 | logger.get_logger().info("checking configs...") 66 | current_dir = BASE_DIR / "config" 67 | names = [ 68 | "decoder.yaml", 69 | "dvae.yaml", 70 | "gpt.yaml", 71 | "path.yaml", 72 | "vocos.yaml", 73 | ] 74 | for model in names: 75 | menv = model.replace(".", "_") 76 | if not check_model( 77 | current_dir, model, os.environ[f"sha256_config_{menv}"], update 78 | ): 79 | return False 80 | 81 | logger.get_logger().info("all assets are already latest.") 82 | return True 83 | 84 | 85 | def download_and_extract_tar_gz(url: str, folder: str): 86 | import tarfile 87 | 88 | logger.get_logger().info(f"downloading {url}") 89 | response = requests.get(url, stream=True, timeout=(5, 10)) 90 | with BytesIO() as out_file: 91 | out_file.write(response.content) 92 | out_file.seek(0) 93 | logger.get_logger().info(f"downloaded.") 94 | with tarfile.open(fileobj=out_file, mode="r:gz") as tar: 95 | tar.extractall(folder) 96 | logger.get_logger().info(f"extracted into {folder}") 97 | 98 | 99 | def download_and_extract_zip(url: str, folder: str): 100 | import zipfile 101 | 102 | logger.get_logger().info(f"downloading {url}") 103 | response = requests.get(url, stream=True, timeout=(5, 10)) 104 | with BytesIO() as out_file: 105 | out_file.write(response.content) 106 | out_file.seek(0) 107 | logger.get_logger().info(f"downloaded.") 108 | with zipfile.ZipFile(out_file) as zip_ref: 109 | zip_ref.extractall(folder) 110 | logger.get_logger().info(f"extracted into {folder}") 111 | 112 | 113 | def download_dns_yaml(url: str, folder: str): 114 | logger.get_logger().info(f"downloading {url}") 115 | response = requests.get(url, stream=True, timeout=(5, 10)) 116 | with open(os.path.join(folder, "dns.yaml"), "wb") as out_file: 117 | out_file.write(response.content) 118 | logger.get_logger().info(f"downloaded into {folder}") 119 | 120 | 121 | def download_all_assets(tmpdir: str, version="0.2.5"): 122 | import subprocess 123 | import platform 124 | 125 | archs = { 126 | "aarch64": "arm64", 127 | "armv8l": "arm64", 128 | "arm64": "arm64", 129 | "x86": "386", 130 | "i386": "386", 131 | "i686": "386", 132 | "386": "386", 133 | "x86_64": "amd64", 134 | "x64": "amd64", 135 | "amd64": "amd64", 136 | } 137 | system_type = platform.system().lower() 138 | architecture = platform.machine().lower() 139 | is_win = system_type == "windows" 140 | 141 | architecture = archs.get(architecture, None) 142 | if not architecture: 143 | 
logger.get_logger().error(f"architecture {architecture} is not supported") 144 | exit(1) 145 | try: 146 | BASE_URL = "https://github.com/fumiama/RVC-Models-Downloader/releases/download/" 147 | suffix = "zip" if is_win else "tar.gz" 148 | RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}" 149 | cmdfile = os.path.join(tmpdir, "rvcmd") 150 | if is_win: 151 | download_and_extract_zip(RVCMD_URL, tmpdir) 152 | cmdfile += ".exe" 153 | else: 154 | download_and_extract_tar_gz(RVCMD_URL, tmpdir) 155 | os.chmod(cmdfile, 0o755) 156 | subprocess.run([cmdfile, "-notui", "-w", "0", "assets/chtts"]) 157 | except Exception: 158 | BASE_URL = "https://raw.gitcode.com/u011570312/RVC-Models-Downloader/assets/" 159 | suffix = { 160 | "darwin_amd64": "555", 161 | "darwin_arm64": "556", 162 | "linux_386": "557", 163 | "linux_amd64": "558", 164 | "linux_arm64": "559", 165 | "windows_386": "562", 166 | "windows_amd64": "563", 167 | }[f"{system_type}_{architecture}"] 168 | RVCMD_URL = BASE_URL + suffix 169 | download_dns_yaml( 170 | "https://raw.gitcode.com/u011570312/RVC-Models-Downloader/raw/main/dns.yaml", 171 | tmpdir, 172 | ) 173 | if is_win: 174 | download_and_extract_zip(RVCMD_URL, tmpdir) 175 | cmdfile += ".exe" 176 | else: 177 | download_and_extract_tar_gz(RVCMD_URL, tmpdir) 178 | os.chmod(cmdfile, 0o755) 179 | subprocess.run( 180 | [ 181 | cmdfile, 182 | "-notui", 183 | "-w", 184 | "0", 185 | "-dns", 186 | os.path.join(tmpdir, "dns.yaml"), 187 | "assets/chtts", 188 | ] 189 | ) 190 | -------------------------------------------------------------------------------- /nodes/ChatTTS/utils/gpu.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from .log import logger 5 | 6 | def select_device(min_memory=2048): 7 | if torch.cuda.is_available(): 8 | available_gpus = [] 9 | for i in range(torch.cuda.device_count()): 10 | props = torch.cuda.get_device_properties(i) 11 | free_memory = props.total_memory - torch.cuda.memory_reserved(i) 12 | available_gpus.append((i, free_memory)) 13 | selected_gpu, max_free_memory = max(available_gpus, key=lambda x: x[1]) 14 | device = torch.device(f'cuda:{selected_gpu}') 15 | free_memory_mb = max_free_memory / (1024 * 1024) 16 | if free_memory_mb < min_memory: 17 | logger.get_logger().warning(f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. 
Switching to CPU.') 18 | device = torch.device('cpu') 19 | elif torch.backends.mps.is_available(): 20 | # For Apple M1/M2 chips with Metal Performance Shaders 21 | logger.get_logger().info('Apple GPU found, using MPS.') 22 | device = torch.device('mps') 23 | else: 24 | logger.get_logger().warning('No GPU found, use CPU instead') 25 | device = torch.device('cpu') 26 | 27 | return device 28 | -------------------------------------------------------------------------------- /nodes/ChatTTS/utils/gpu_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import logging 4 | 5 | def select_device(min_memory = 2048): 6 | logger = logging.getLogger(__name__) 7 | if torch.cuda.is_available(): 8 | available_gpus = [] 9 | for i in range(torch.cuda.device_count()): 10 | props = torch.cuda.get_device_properties(i) 11 | free_memory = props.total_memory - torch.cuda.memory_reserved(i) 12 | available_gpus.append((i, free_memory)) 13 | selected_gpu, max_free_memory = max(available_gpus, key=lambda x: x[1]) 14 | device = torch.device(f'cuda:{selected_gpu}') 15 | free_memory_mb = max_free_memory / (1024 * 1024) 16 | if free_memory_mb < min_memory: 17 | logger.log(logging.WARNING, f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.') 18 | device = torch.device('cpu') 19 | else: 20 | logger.log(logging.WARNING, f'No GPU found, use CPU instead') 21 | device = torch.device('cpu') 22 | 23 | return device 24 | -------------------------------------------------------------------------------- /nodes/ChatTTS/utils/infer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import Dict, Tuple, List 4 | import sys 5 | 6 | from numba import jit 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | class CustomRepetitionPenaltyLogitsProcessorRepeat(): 13 | 14 | def __init__(self, penalty: float, max_input_ids, past_window): 15 | if not isinstance(penalty, float) or not (penalty > 0): 16 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") 17 | 18 | self.penalty = penalty 19 | self.max_input_ids = max_input_ids 20 | self.past_window = past_window 21 | 22 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 23 | 24 | input_ids = input_ids[:, -self.past_window:] 25 | freq = F.one_hot(input_ids, scores.size(1)).sum(1) 26 | freq[self.max_input_ids:] = 0 27 | alpha = self.penalty**freq 28 | scores = scores.contiguous() 29 | scores = torch.where(scores < 0, scores*alpha, scores/alpha) 30 | 31 | return scores 32 | 33 | class CustomRepetitionPenaltyLogitsProcessor(): 34 | 35 | def __init__(self, penalty: float, max_input_ids, past_window): 36 | if not isinstance(penalty, float) or not (penalty > 0): 37 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") 38 | 39 | self.penalty = penalty 40 | self.max_input_ids = max_input_ids 41 | self.past_window = past_window 42 | 43 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 44 | 45 | input_ids = input_ids[:, -self.past_window:] 46 | score = torch.gather(scores, 1, input_ids) 47 | _score = score.detach().clone() 48 | score = torch.where(score < 0, score * self.penalty, score / self.penalty) 49 | score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids] 50 | scores.scatter_(1, input_ids, score) 51 | 52 | return scores 53 | 
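# The two numba-jitted helpers below work on UTF-16 code units: _find_index
# linearly scans the first row of the homophones table for a character, and
# _fast_replace rewrites every matched character with its counterpart from the
# second row, returning the new buffer plus the (original, replacement) pairs.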
54 | @jit 55 | def _find_index(table: np.ndarray, val: np.uint16): 56 | for i in range(table.size): 57 | if table[i] == val: 58 | return i 59 | return -1 60 | 61 | @jit 62 | def _fast_replace(table: np.ndarray, text: bytes) -> Tuple[np.ndarray, List[Tuple[str, str]]]: 63 | result = np.frombuffer(text, dtype=np.uint16).copy() 64 | replaced_words = [] 65 | for i in range(result.size): 66 | ch = result[i] 67 | p = _find_index(table[0], ch) 68 | if p >= 0: 69 | repl_char = table[1][p] 70 | result[i] = repl_char 71 | replaced_words.append((chr(ch), chr(repl_char))) 72 | return result, replaced_words 73 | 74 | class HomophonesReplacer: 75 | """ 76 | Homophones Replacer 77 | 78 | Replace the mispronounced characters with correctly pronounced ones. 79 | 80 | Creation process of homophones_map.json: 81 | 82 | 1. Establish a word corpus using the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries. After cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text. 83 | 2. Record discrepancies between the inferred and input text, identifying about 180,000 misread words. 84 | 3. Create a pinyin to common characters mapping using correctly read characters by ChatTTS. 85 | 4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find homophones with the correct pronunciation from the mapping. 86 | 87 | Thanks to: 88 | [Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html) 89 | [python-pinyin](https://github.com/mozillazg/python-pinyin) 90 | 91 | """ 92 | def __init__(self, map_file_path: str): 93 | self.homophones_map = self._load_homophones_map(map_file_path) 94 | self.coding = "utf-16-le" if sys.byteorder == "little" else "utf-16-be" 95 | 96 | def _load_homophones_map(self, map_file_path: str) -> np.ndarray: 97 | with open(map_file_path, 'r', encoding='utf-8') as f: 98 | homophones_map: Dict[str, str] = json.load(f) 99 | map = np.empty((2, len(homophones_map)), dtype=np.uint32) 100 | for i, k in enumerate(homophones_map.keys()): 101 | map[:, i] = (ord(k), ord(homophones_map[k])) 102 | del homophones_map 103 | return map 104 | 105 | def replace(self, text: str): 106 | arr, lst = _fast_replace( 107 | self.homophones_map, 108 | text.encode(self.coding), 109 | ) 110 | return arr.tobytes().decode(self.coding), lst 111 | 112 | accept_pattern = re.compile(r'[^\u4e00-\u9fffA-Za-z,。、,\. 
]') 113 | sub_pattern = re.compile(r'\[uv_break\]|\[laugh\]|\[lbreak\]') 114 | 115 | def count_invalid_characters(s: str): 116 | global accept_pattern, sub_pattern 117 | s = sub_pattern.sub('', s) 118 | non_alphabetic_chinese_chars = accept_pattern.findall(s) 119 | return set(non_alphabetic_chinese_chars) 120 | 121 | chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]') 122 | english_word_pattern = re.compile(r'\b[A-Za-z]+\b') 123 | 124 | def detect_language(sentence): 125 | global chinese_char_pattern, english_word_pattern 126 | 127 | chinese_chars = chinese_char_pattern.findall(sentence) 128 | english_words = english_word_pattern.findall(sentence) 129 | 130 | if len(chinese_chars) > len(english_words): 131 | return "zh" 132 | else: 133 | return "en" 134 | 135 | 136 | character_simplifier = str.maketrans({ 137 | ':': ',', 138 | ';': ',', 139 | '!': '。', 140 | '(': ',', 141 | ')': ',', 142 | '【': ',', 143 | '】': ',', 144 | '『': ',', 145 | '』': ',', 146 | '「': ',', 147 | '」': ',', 148 | '《': ',', 149 | '》': ',', 150 | '-': ',', 151 | '‘': '', 152 | '“': '', 153 | '’': '', 154 | '”': '', 155 | ':': ',', 156 | ';': ',', 157 | '!': '.', 158 | '(': ',', 159 | ')': ',', 160 | '[': ',', 161 | ']': ',', 162 | '>': ',', 163 | '<': ',', 164 | '-': ',', 165 | }) 166 | 167 | halfwidth_2_fullwidth = str.maketrans({ 168 | '!': '!', 169 | '"': '“', 170 | "'": '‘', 171 | '#': '#', 172 | '$': '$', 173 | '%': '%', 174 | '&': '&', 175 | '(': '(', 176 | ')': ')', 177 | ',': ',', 178 | '-': '-', 179 | '*': '*', 180 | '+': '+', 181 | '.': '。', 182 | '/': '/', 183 | ':': ':', 184 | ';': ';', 185 | '<': '<', 186 | '=': '=', 187 | '>': '>', 188 | '?': '?', 189 | '@': '@', 190 | # '[': '[', 191 | '\\': '\', 192 | # ']': ']', 193 | '^': '^', 194 | # '_': '_', 195 | '`': '`', 196 | '{': '{', 197 | '|': '|', 198 | '}': '}', 199 | '~': '~' 200 | }) 201 | 202 | def apply_half2full_map(text: str) -> str: 203 | return text.translate(halfwidth_2_fullwidth) 204 | 205 | def apply_character_map(text: str) -> str: 206 | return text.translate(character_simplifier) 207 | -------------------------------------------------------------------------------- /nodes/ChatTTS/utils/infer_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class CustomRepetitionPenaltyLogitsProcessorRepeat(): 8 | 9 | def __init__(self, penalty: float, max_input_ids, past_window): 10 | if not isinstance(penalty, float) or not (penalty > 0): 11 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") 12 | 13 | self.penalty = penalty 14 | self.max_input_ids = max_input_ids 15 | self.past_window = past_window 16 | 17 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 18 | 19 | input_ids = input_ids[:, -self.past_window:] 20 | freq = F.one_hot(input_ids, scores.size(1)).sum(1) 21 | freq[self.max_input_ids:] = 0 22 | alpha = self.penalty**freq 23 | scores = torch.where(scores < 0, scores*alpha, scores/alpha) 24 | 25 | return scores 26 | 27 | class CustomRepetitionPenaltyLogitsProcessor(): 28 | 29 | def __init__(self, penalty: float, max_input_ids, past_window): 30 | if not isinstance(penalty, float) or not (penalty > 0): 31 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") 32 | 33 | self.penalty = penalty 34 | self.max_input_ids = max_input_ids 35 | self.past_window = past_window 36 | 37 | def __call__(self, input_ids: 
torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 38 | 39 | input_ids = input_ids[:, -self.past_window:] 40 | score = torch.gather(scores, 1, input_ids) 41 | _score = score.detach().clone() 42 | score = torch.where(score < 0, score * self.penalty, score / self.penalty) 43 | score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids] 44 | scores.scatter_(1, input_ids, score) 45 | 46 | return scores 47 | 48 | def count_invalid_characters(s): 49 | 50 | s = re.sub(r'\[uv_break\]|\[laugh\]|\[lbreak\]', '', s) 51 | pattern = re.compile(r'[^\u4e00-\u9fffA-Za-z,。、,\. ]') 52 | non_alphabetic_chinese_chars = pattern.findall(s) 53 | return set(non_alphabetic_chinese_chars) 54 | 55 | def detect_language(sentence): 56 | 57 | chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]') 58 | english_word_pattern = re.compile(r'\b[A-Za-z]+\b') 59 | 60 | chinese_chars = chinese_char_pattern.findall(sentence) 61 | english_words = english_word_pattern.findall(sentence) 62 | 63 | if len(chinese_chars) > len(english_words): 64 | return "zh" 65 | else: 66 | return "en" 67 | 68 | 69 | character_map = { 70 | ':': ',', 71 | ';': ',', 72 | '!': '。', 73 | '(': ',', 74 | ')': ',', 75 | '【': ',', 76 | '】': ',', 77 | '『': ',', 78 | '』': ',', 79 | '「': ',', 80 | '」': ',', 81 | '《': ',', 82 | '》': ',', 83 | '-': ',', 84 | '‘': '', 85 | '“': '', 86 | '’': '', 87 | '”': '', 88 | ':': ',', 89 | ';': ',', 90 | '!': '.', 91 | '(': ',', 92 | ')': ',', 93 | '[': ',', 94 | ']': ',', 95 | '>': ',', 96 | '<': ',', 97 | '-': ',', 98 | '…': '', 99 | '—': ',', 100 | '_': ',', 101 | '?': ',', 102 | } 103 | 104 | halfwidth_2_fullwidth_map = { 105 | '!': '!', 106 | '"': '“', 107 | "'": '‘', 108 | '#': '#', 109 | '$': '$', 110 | '%': '%', 111 | '&': '&', 112 | '(': '(', 113 | ')': ')', 114 | ',': ',', 115 | '-': '-', 116 | '*': '*', 117 | '+': '+', 118 | '.': '。', 119 | '/': '/', 120 | ':': ':', 121 | ';': ';', 122 | '<': '<', 123 | '=': '=', 124 | '>': '>', 125 | '?': '?', 126 | '@': '@', 127 | # '[': '[', 128 | '\\': '\', 129 | # ']': ']', 130 | '^': '^', 131 | # '_': '_', 132 | '`': '`', 133 | '{': '{', 134 | '|': '|', 135 | '}': '}', 136 | '~': '~' 137 | } 138 | 139 | def apply_half2full_map(text): 140 | translation_table = str.maketrans(halfwidth_2_fullwidth_map) 141 | return text.translate(translation_table) 142 | 143 | def apply_character_map(text): 144 | translation_table = str.maketrans(character_map) 145 | return text.translate(translation_table) -------------------------------------------------------------------------------- /nodes/ChatTTS/utils/io.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import logging 4 | from typing import Union 5 | from dataclasses import is_dataclass 6 | 7 | from .log import logger 8 | 9 | def get_latest_modified_file(directory): 10 | 11 | files = [os.path.join(directory, f) for f in os.listdir(directory)] 12 | if not files: 13 | logger.get_logger().log(logging.WARNING, f'no files found in the directory: {directory}') 14 | return None 15 | latest_file = max(files, key=os.path.getmtime) 16 | 17 | return latest_file 18 | 19 | def del_all(d: Union[dict, list]): 20 | if is_dataclass(d): 21 | for k in list(vars(d).keys()): 22 | x = getattr(d, k) 23 | if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x): 24 | del_all(x) 25 | del x 26 | delattr(d, k) 27 | elif isinstance(d, dict): 28 | lst = list(d.keys()) 29 | for k in lst: 30 | x = d.pop(k) 31 | if isinstance(x, dict) or isinstance(x, list) or 
is_dataclass(x): 32 | del_all(x) 33 | del x 34 | elif isinstance(d, list): 35 | while len(d): 36 | x = d.pop() 37 | if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x): 38 | del_all(x) 39 | del x 40 | else: 41 | del d 42 | 43 | -------------------------------------------------------------------------------- /nodes/ChatTTS/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import logging 4 | 5 | def get_latest_modified_file(directory): 6 | logger = logging.getLogger(__name__) 7 | 8 | files = [os.path.join(directory, f) for f in os.listdir(directory)] 9 | if not files: 10 | logger.log(logging.WARNING, f'No files found in the directory: {directory}') 11 | return None 12 | latest_file = max(files, key=os.path.getmtime) 13 | 14 | return latest_file -------------------------------------------------------------------------------- /nodes/ChatTTS/utils/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | class Logger(): 5 | def __init__(self, logger=logging.getLogger(Path(__file__).parent.name)): 6 | self.logger = logger 7 | 8 | def set_logger(self, logger: logging.Logger): 9 | self.logger = logger 10 | 11 | def get_logger(self) -> logging.Logger: 12 | return self.logger 13 | 14 | logger = Logger() 15 | -------------------------------------------------------------------------------- /nodes/__pycache__/chat_tts.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/nodes/__pycache__/chat_tts.cpython-311.pyc -------------------------------------------------------------------------------- /nodes/__pycache__/chat_tts_run.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/nodes/__pycache__/chat_tts_run.cpython-311.pyc -------------------------------------------------------------------------------- /nodes/chat_tts_run.py: -------------------------------------------------------------------------------- 1 | import ChatTTS 2 | 3 | import torchaudio,torch 4 | 5 | import folder_paths 6 | 7 | import os 8 | 9 | # 修改模型的本地缓存地址 10 | # os.environ['HF_HOME'] = os.path.join(folder_paths.models_dir,'chat_tts') 11 | 12 | def get_model_dir(m): 13 | try: 14 | return folder_paths.get_folder_paths(m)[0] 15 | except: 16 | return os.path.join(folder_paths.models_dir, m) 17 | 18 | model_local_path=get_model_dir('chat_tts') 19 | 20 | # 写一个python文件,用来 判断文件夹内命名为 所有chat_tts开头的文件数量(chat_tts_00001),并输出新的编号 21 | def get_new_counter(full_output_folder, filename_prefix): 22 | # 获取目录中的所有文件 23 | files = os.listdir(full_output_folder) 24 | 25 | # 过滤出以 filename_prefix 开头并且后续部分为数字的文件 26 | filtered_files = [] 27 | for f in files: 28 | if f.startswith(filename_prefix): 29 | # 去掉文件名中的前缀和后缀,只保留中间的数字部分 30 | base_name = f[len(filename_prefix)+1:] 31 | number_part = base_name.split('.')[0] # 假设文件名中只有一个点,即扩展名 32 | if number_part.isdigit(): 33 | filtered_files.append(int(number_part)) 34 | 35 | if not filtered_files: 36 | return 1 37 | 38 | # 获取最大的编号 39 | max_number = max(filtered_files) 40 | 41 | # 新的编号 42 | return max_number + 1 43 | 44 | 45 | def run(audio_file,texts, 46 | rand_spk, 47 | uv_speed=None, 48 | uv_oral=None, 49 | uv_laugh=None, 50 | uv_break=None, 51 | 
skip_refine_text=False): 52 | # 需要运行chat tts 的代码 53 | 54 | output_dir = folder_paths.get_output_directory() 55 | 56 | counter=get_new_counter(output_dir,audio_file) 57 | # print('#audio_path',folder_paths, ) 58 | # 添加文件名后缀 59 | audio_file = f"{audio_file}_{counter:05}.wav" 60 | 61 | audio_path=os.path.join(output_dir, audio_file) 62 | 63 | # from IPython.display import Audio 64 | # print('#audio_path',audio_path) 65 | chat = ChatTTS.Chat() 66 | chat.load_models(source="custom",custom_path=model_local_path,compile=False) # 设置为True以获得更快速度 67 | 68 | # texts = [text,] 69 | 70 | params_refine_text = { 71 | 'prompt': f'' 72 | } 73 | 74 | if uv_oral: 75 | params_refine_text['prompt']+=f'[oral_{uv_oral}]' 76 | 77 | if uv_laugh: 78 | params_refine_text['prompt']+=f'[laugh_{uv_laugh}]' 79 | 80 | if uv_break: 81 | params_refine_text['prompt']+=f'[break_{uv_break}]' 82 | 83 | if uv_speed: 84 | params_refine_text['prompt']+=f'[speed_{uv_speed}]' 85 | 86 | if rand_spk is None: 87 | rand_spk = chat.sample_random_speaker() 88 | 89 | print('params_refine_text',params_refine_text,texts) 90 | 91 | params_infer_code = { 92 | 'spk_emb': rand_spk, # add sampled speaker 93 | 'temperature': .3, # using custom temperature 94 | 'top_P': 0.7, # top P decode 95 | 'top_K': 20, # top K decode 96 | } 97 | 98 | 99 | # ChatTTS使用pynini对中英文进行处理,目前在window上安装报错,需要编译环境, 100 | # 暂时把do_text_normalization关掉 101 | wavs = chat.infer(texts, 102 | use_decoder=True, 103 | do_text_normalization=False, 104 | params_refine_text=params_refine_text, 105 | params_infer_code=params_infer_code, 106 | # progress_callback=progress_callback, 107 | skip_refine_text=skip_refine_text, 108 | ) 109 | 110 | wavs = [torch.tensor(wav) for wav in wavs] 111 | combined_waveform = torch.cat(wavs, dim=1) 112 | 113 | torchaudio.save(audio_path, combined_waveform, 24000) 114 | 115 | return ({ 116 | "filename": audio_file, 117 | "subfolder": "", 118 | "type": "output", 119 | "prompt":"".join(texts), 120 | "audio_path":audio_path 121 | },rand_spk) -------------------------------------------------------------------------------- /nodes/openvoice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shadowcz007/Comfyui-ChatTTS/3dbed17f2f858b1eef3d7415b5cb718d52ec3842/nodes/openvoice/__init__.py -------------------------------------------------------------------------------- /nodes/openvoice/api.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import re 4 | import soundfile 5 | from openvoice import utils 6 | from openvoice import commons 7 | import os 8 | import librosa 9 | from openvoice.text import text_to_sequence 10 | from openvoice.mel_processing import spectrogram_torch 11 | from openvoice.models import SynthesizerTrn 12 | 13 | 14 | class OpenVoiceBaseClass(object): 15 | def __init__(self, 16 | config_path, 17 | device='cuda:0'): 18 | if 'cuda' in device: 19 | assert torch.cuda.is_available() 20 | 21 | hps = utils.get_hparams_from_file(config_path) 22 | 23 | model = SynthesizerTrn( 24 | len(getattr(hps, 'symbols', [])), 25 | hps.data.filter_length // 2 + 1, 26 | n_speakers=hps.data.n_speakers, 27 | **hps.model, 28 | ).to(device) 29 | 30 | model.eval() 31 | self.model = model 32 | self.hps = hps 33 | self.device = device 34 | 35 | def load_ckpt(self, ckpt_path): 36 | checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device)) 37 | a, b = self.model.load_state_dict(checkpoint_dict['model'], 
strict=False) 38 | print("Loaded checkpoint '{}'".format(ckpt_path)) 39 | print('missing/unexpected keys:', a, b) 40 | 41 | 42 | class BaseSpeakerTTS(OpenVoiceBaseClass): 43 | language_marks = { 44 | "english": "EN", 45 | "chinese": "ZH", 46 | } 47 | 48 | @staticmethod 49 | def get_text(text, hps, is_symbol): 50 | text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners) 51 | if hps.data.add_blank: 52 | text_norm = commons.intersperse(text_norm, 0) 53 | text_norm = torch.LongTensor(text_norm) 54 | return text_norm 55 | 56 | @staticmethod 57 | def audio_numpy_concat(segment_data_list, sr, speed=1.): 58 | audio_segments = [] 59 | for segment_data in segment_data_list: 60 | audio_segments += segment_data.reshape(-1).tolist() 61 | audio_segments += [0] * int((sr * 0.05)/speed) 62 | audio_segments = np.array(audio_segments).astype(np.float32) 63 | return audio_segments 64 | 65 | @staticmethod 66 | def split_sentences_into_pieces(text, language_str): 67 | texts = utils.split_sentence(text, language_str=language_str) 68 | print(" > Text splitted to sentences.") 69 | print('\n'.join(texts)) 70 | print(" > ===========================") 71 | return texts 72 | 73 | def tts(self, text, output_path, speaker, language='English', speed=1.0): 74 | mark = self.language_marks.get(language.lower(), None) 75 | assert mark is not None, f"language {language} is not supported" 76 | 77 | texts = self.split_sentences_into_pieces(text, mark) 78 | 79 | audio_list = [] 80 | for t in texts: 81 | t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t) 82 | t = f'[{mark}]{t}[{mark}]' 83 | stn_tst = self.get_text(t, self.hps, False) 84 | device = self.device 85 | speaker_id = self.hps.speakers[speaker] 86 | with torch.no_grad(): 87 | x_tst = stn_tst.unsqueeze(0).to(device) 88 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device) 89 | sid = torch.LongTensor([speaker_id]).to(device) 90 | audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6, 91 | length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy() 92 | audio_list.append(audio) 93 | audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed) 94 | 95 | if output_path is None: 96 | return audio 97 | else: 98 | soundfile.write(output_path, audio, self.hps.data.sampling_rate) 99 | 100 | 101 | class ToneColorConverter(OpenVoiceBaseClass): 102 | def __init__(self, *args, **kwargs): 103 | super().__init__(*args, **kwargs) 104 | 105 | # if kwargs.get('enable_watermark', True): 106 | # import wavmark 107 | # self.watermark_model = wavmark.load_model().to(self.device) 108 | # else: 109 | self.watermark_model = None 110 | self.version = getattr(self.hps, '_version_', "v1") 111 | 112 | 113 | 114 | def extract_se(self, ref_wav_list, se_save_path=None): 115 | if isinstance(ref_wav_list, str): 116 | ref_wav_list = [ref_wav_list] 117 | 118 | device = self.device 119 | hps = self.hps 120 | gs = [] 121 | 122 | for fname in ref_wav_list: 123 | audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate) 124 | y = torch.FloatTensor(audio_ref) 125 | y = y.to(device) 126 | y = y.unsqueeze(0) 127 | y = spectrogram_torch(y, hps.data.filter_length, 128 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 129 | center=False).to(device) 130 | with torch.no_grad(): 131 | g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1) 132 | gs.append(g.detach()) 133 | gs = torch.stack(gs).mean(0) 134 | 135 | if se_save_path is not None: 136 | os.makedirs(os.path.dirname(se_save_path), 
exist_ok=True) 137 | torch.save(gs.cpu(), se_save_path) 138 | 139 | return gs 140 | 141 | def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="default"): 142 | hps = self.hps 143 | # load audio 144 | audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate) 145 | audio = torch.tensor(audio).float() 146 | 147 | with torch.no_grad(): 148 | y = torch.FloatTensor(audio).to(self.device) 149 | y = y.unsqueeze(0) 150 | spec = spectrogram_torch(y, hps.data.filter_length, 151 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 152 | center=False).to(self.device) 153 | spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device) 154 | audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][ 155 | 0, 0].data.cpu().float().numpy() 156 | audio = self.add_watermark(audio, message) 157 | if output_path is None: 158 | return audio 159 | else: 160 | soundfile.write(output_path, audio, hps.data.sampling_rate) 161 | 162 | def add_watermark(self, audio, message): 163 | if self.watermark_model is None: 164 | return audio 165 | device = self.device 166 | bits = utils.string_to_bits(message).reshape(-1) 167 | n_repeat = len(bits) // 32 168 | 169 | K = 16000 170 | coeff = 2 171 | for n in range(n_repeat): 172 | trunck = audio[(coeff * n) * K: (coeff * n + 1) * K] 173 | if len(trunck) != K: 174 | print('Audio too short, fail to add watermark') 175 | break 176 | message_npy = bits[n * 32: (n + 1) * 32] 177 | 178 | with torch.no_grad(): 179 | signal = torch.FloatTensor(trunck).to(device)[None] 180 | message_tensor = torch.FloatTensor(message_npy).to(device)[None] 181 | signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor) 182 | signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze() 183 | audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy 184 | return audio 185 | 186 | def detect_watermark(self, audio, n_repeat): 187 | bits = [] 188 | K = 16000 189 | coeff = 2 190 | for n in range(n_repeat): 191 | trunck = audio[(coeff * n) * K: (coeff * n + 1) * K] 192 | if len(trunck) != K: 193 | print('Audio too short, fail to detect watermark') 194 | return 'Fail' 195 | with torch.no_grad(): 196 | signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0) 197 | message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze() 198 | bits.append(message_decoded_npy) 199 | bits = np.stack(bits).reshape(-1, 8) 200 | message = utils.bits_to_string(bits) 201 | return message 202 | 203 | -------------------------------------------------------------------------------- /nodes/openvoice/attentions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from openvoice import commons 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class LayerNorm(nn.Module): 13 | def __init__(self, channels, eps=1e-5): 14 | super().__init__() 15 | self.channels = channels 16 | self.eps = eps 17 | 18 | self.gamma = nn.Parameter(torch.ones(channels)) 19 | self.beta = nn.Parameter(torch.zeros(channels)) 20 | 21 | def forward(self, x): 22 | x = x.transpose(1, -1) 23 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 24 | return x.transpose(1, -1) 25 | 26 | 27 | @torch.jit.script 28 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 29 | n_channels_int = n_channels[0] 
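# WaveNet-style gated activation: the summed input is split along the channel
# axis, the first n_channels_int channels pass through tanh and the remaining
# channels through sigmoid, and the two halves are multiplied to gate the output.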
30 | in_act = input_a + input_b 31 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 32 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 33 | acts = t_act * s_act 34 | return acts 35 | 36 | 37 | class Encoder(nn.Module): 38 | def __init__( 39 | self, 40 | hidden_channels, 41 | filter_channels, 42 | n_heads, 43 | n_layers, 44 | kernel_size=1, 45 | p_dropout=0.0, 46 | window_size=4, 47 | isflow=True, 48 | **kwargs 49 | ): 50 | super().__init__() 51 | self.hidden_channels = hidden_channels 52 | self.filter_channels = filter_channels 53 | self.n_heads = n_heads 54 | self.n_layers = n_layers 55 | self.kernel_size = kernel_size 56 | self.p_dropout = p_dropout 57 | self.window_size = window_size 58 | # if isflow: 59 | # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1) 60 | # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) 61 | # self.cond_layer = weight_norm(cond_layer, name='weight') 62 | # self.gin_channels = 256 63 | self.cond_layer_idx = self.n_layers 64 | if "gin_channels" in kwargs: 65 | self.gin_channels = kwargs["gin_channels"] 66 | if self.gin_channels != 0: 67 | self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) 68 | # vits2 says 3rd block, so idx is 2 by default 69 | self.cond_layer_idx = ( 70 | kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 71 | ) 72 | # logging.debug(self.gin_channels, self.cond_layer_idx) 73 | assert ( 74 | self.cond_layer_idx < self.n_layers 75 | ), "cond_layer_idx should be less than n_layers" 76 | self.drop = nn.Dropout(p_dropout) 77 | self.attn_layers = nn.ModuleList() 78 | self.norm_layers_1 = nn.ModuleList() 79 | self.ffn_layers = nn.ModuleList() 80 | self.norm_layers_2 = nn.ModuleList() 81 | 82 | for i in range(self.n_layers): 83 | self.attn_layers.append( 84 | MultiHeadAttention( 85 | hidden_channels, 86 | hidden_channels, 87 | n_heads, 88 | p_dropout=p_dropout, 89 | window_size=window_size, 90 | ) 91 | ) 92 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 93 | self.ffn_layers.append( 94 | FFN( 95 | hidden_channels, 96 | hidden_channels, 97 | filter_channels, 98 | kernel_size, 99 | p_dropout=p_dropout, 100 | ) 101 | ) 102 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 103 | 104 | def forward(self, x, x_mask, g=None): 105 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 106 | x = x * x_mask 107 | for i in range(self.n_layers): 108 | if i == self.cond_layer_idx and g is not None: 109 | g = self.spk_emb_linear(g.transpose(1, 2)) 110 | g = g.transpose(1, 2) 111 | x = x + g 112 | x = x * x_mask 113 | y = self.attn_layers[i](x, x, attn_mask) 114 | y = self.drop(y) 115 | x = self.norm_layers_1[i](x + y) 116 | 117 | y = self.ffn_layers[i](x, x_mask) 118 | y = self.drop(y) 119 | x = self.norm_layers_2[i](x + y) 120 | x = x * x_mask 121 | return x 122 | 123 | 124 | class Decoder(nn.Module): 125 | def __init__( 126 | self, 127 | hidden_channels, 128 | filter_channels, 129 | n_heads, 130 | n_layers, 131 | kernel_size=1, 132 | p_dropout=0.0, 133 | proximal_bias=False, 134 | proximal_init=True, 135 | **kwargs 136 | ): 137 | super().__init__() 138 | self.hidden_channels = hidden_channels 139 | self.filter_channels = filter_channels 140 | self.n_heads = n_heads 141 | self.n_layers = n_layers 142 | self.kernel_size = kernel_size 143 | self.p_dropout = p_dropout 144 | self.proximal_bias = proximal_bias 145 | self.proximal_init = proximal_init 146 | 147 | self.drop = nn.Dropout(p_dropout) 148 | self.self_attn_layers = nn.ModuleList() 149 | self.norm_layers_0 = 
nn.ModuleList() 150 | self.encdec_attn_layers = nn.ModuleList() 151 | self.norm_layers_1 = nn.ModuleList() 152 | self.ffn_layers = nn.ModuleList() 153 | self.norm_layers_2 = nn.ModuleList() 154 | for i in range(self.n_layers): 155 | self.self_attn_layers.append( 156 | MultiHeadAttention( 157 | hidden_channels, 158 | hidden_channels, 159 | n_heads, 160 | p_dropout=p_dropout, 161 | proximal_bias=proximal_bias, 162 | proximal_init=proximal_init, 163 | ) 164 | ) 165 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 166 | self.encdec_attn_layers.append( 167 | MultiHeadAttention( 168 | hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout 169 | ) 170 | ) 171 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 172 | self.ffn_layers.append( 173 | FFN( 174 | hidden_channels, 175 | hidden_channels, 176 | filter_channels, 177 | kernel_size, 178 | p_dropout=p_dropout, 179 | causal=True, 180 | ) 181 | ) 182 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 183 | 184 | def forward(self, x, x_mask, h, h_mask): 185 | """ 186 | x: decoder input 187 | h: encoder output 188 | """ 189 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( 190 | device=x.device, dtype=x.dtype 191 | ) 192 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 193 | x = x * x_mask 194 | for i in range(self.n_layers): 195 | y = self.self_attn_layers[i](x, x, self_attn_mask) 196 | y = self.drop(y) 197 | x = self.norm_layers_0[i](x + y) 198 | 199 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 200 | y = self.drop(y) 201 | x = self.norm_layers_1[i](x + y) 202 | 203 | y = self.ffn_layers[i](x, x_mask) 204 | y = self.drop(y) 205 | x = self.norm_layers_2[i](x + y) 206 | x = x * x_mask 207 | return x 208 | 209 | 210 | class MultiHeadAttention(nn.Module): 211 | def __init__( 212 | self, 213 | channels, 214 | out_channels, 215 | n_heads, 216 | p_dropout=0.0, 217 | window_size=None, 218 | heads_share=True, 219 | block_length=None, 220 | proximal_bias=False, 221 | proximal_init=False, 222 | ): 223 | super().__init__() 224 | assert channels % n_heads == 0 225 | 226 | self.channels = channels 227 | self.out_channels = out_channels 228 | self.n_heads = n_heads 229 | self.p_dropout = p_dropout 230 | self.window_size = window_size 231 | self.heads_share = heads_share 232 | self.block_length = block_length 233 | self.proximal_bias = proximal_bias 234 | self.proximal_init = proximal_init 235 | self.attn = None 236 | 237 | self.k_channels = channels // n_heads 238 | self.conv_q = nn.Conv1d(channels, channels, 1) 239 | self.conv_k = nn.Conv1d(channels, channels, 1) 240 | self.conv_v = nn.Conv1d(channels, channels, 1) 241 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 242 | self.drop = nn.Dropout(p_dropout) 243 | 244 | if window_size is not None: 245 | n_heads_rel = 1 if heads_share else n_heads 246 | rel_stddev = self.k_channels**-0.5 247 | self.emb_rel_k = nn.Parameter( 248 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 249 | * rel_stddev 250 | ) 251 | self.emb_rel_v = nn.Parameter( 252 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 253 | * rel_stddev 254 | ) 255 | 256 | nn.init.xavier_uniform_(self.conv_q.weight) 257 | nn.init.xavier_uniform_(self.conv_k.weight) 258 | nn.init.xavier_uniform_(self.conv_v.weight) 259 | if proximal_init: 260 | with torch.no_grad(): 261 | self.conv_k.weight.copy_(self.conv_q.weight) 262 | self.conv_k.bias.copy_(self.conv_q.bias) 263 | 264 | def forward(self, x, c, attn_mask=None): 265 | q = self.conv_q(x) 266 | k = 
self.conv_k(c) 267 | v = self.conv_v(c) 268 | 269 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 270 | 271 | x = self.conv_o(x) 272 | return x 273 | 274 | def attention(self, query, key, value, mask=None): 275 | # reshape [b, d, t] -> [b, n_h, t, d_k] 276 | b, d, t_s, t_t = (*key.size(), query.size(2)) 277 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 278 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 279 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 280 | 281 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 282 | if self.window_size is not None: 283 | assert ( 284 | t_s == t_t 285 | ), "Relative attention is only available for self-attention." 286 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 287 | rel_logits = self._matmul_with_relative_keys( 288 | query / math.sqrt(self.k_channels), key_relative_embeddings 289 | ) 290 | scores_local = self._relative_position_to_absolute_position(rel_logits) 291 | scores = scores + scores_local 292 | if self.proximal_bias: 293 | assert t_s == t_t, "Proximal bias is only available for self-attention." 294 | scores = scores + self._attention_bias_proximal(t_s).to( 295 | device=scores.device, dtype=scores.dtype 296 | ) 297 | if mask is not None: 298 | scores = scores.masked_fill(mask == 0, -1e4) 299 | if self.block_length is not None: 300 | assert ( 301 | t_s == t_t 302 | ), "Local attention is only available for self-attention." 303 | block_mask = ( 304 | torch.ones_like(scores) 305 | .triu(-self.block_length) 306 | .tril(self.block_length) 307 | ) 308 | scores = scores.masked_fill(block_mask == 0, -1e4) 309 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 310 | p_attn = self.drop(p_attn) 311 | output = torch.matmul(p_attn, value) 312 | if self.window_size is not None: 313 | relative_weights = self._absolute_position_to_relative_position(p_attn) 314 | value_relative_embeddings = self._get_relative_embeddings( 315 | self.emb_rel_v, t_s 316 | ) 317 | output = output + self._matmul_with_relative_values( 318 | relative_weights, value_relative_embeddings 319 | ) 320 | output = ( 321 | output.transpose(2, 3).contiguous().view(b, d, t_t) 322 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t] 323 | return output, p_attn 324 | 325 | def _matmul_with_relative_values(self, x, y): 326 | """ 327 | x: [b, h, l, m] 328 | y: [h or 1, m, d] 329 | ret: [b, h, l, d] 330 | """ 331 | ret = torch.matmul(x, y.unsqueeze(0)) 332 | return ret 333 | 334 | def _matmul_with_relative_keys(self, x, y): 335 | """ 336 | x: [b, h, l, d] 337 | y: [h or 1, m, d] 338 | ret: [b, h, l, m] 339 | """ 340 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 341 | return ret 342 | 343 | def _get_relative_embeddings(self, relative_embeddings, length): 344 | 2 * self.window_size + 1 345 | # Pad first before slice to avoid using cond ops. 
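# The relative-position embedding table holds 2*window_size + 1 vectors. When the
# sequence is longer than window_size + 1 the table is padded on both sides;
# otherwise a centered slice of 2*length - 1 entries is taken, so there is always
# exactly one embedding per relative offset that can occur at this length.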
346 | pad_length = max(length - (self.window_size + 1), 0) 347 | slice_start_position = max((self.window_size + 1) - length, 0) 348 | slice_end_position = slice_start_position + 2 * length - 1 349 | if pad_length > 0: 350 | padded_relative_embeddings = F.pad( 351 | relative_embeddings, 352 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), 353 | ) 354 | else: 355 | padded_relative_embeddings = relative_embeddings 356 | used_relative_embeddings = padded_relative_embeddings[ 357 | :, slice_start_position:slice_end_position 358 | ] 359 | return used_relative_embeddings 360 | 361 | def _relative_position_to_absolute_position(self, x): 362 | """ 363 | x: [b, h, l, 2*l-1] 364 | ret: [b, h, l, l] 365 | """ 366 | batch, heads, length, _ = x.size() 367 | # Concat columns of pad to shift from relative to absolute indexing. 368 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 369 | 370 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 371 | x_flat = x.view([batch, heads, length * 2 * length]) 372 | x_flat = F.pad( 373 | x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) 374 | ) 375 | 376 | # Reshape and slice out the padded elements. 377 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ 378 | :, :, :length, length - 1 : 379 | ] 380 | return x_final 381 | 382 | def _absolute_position_to_relative_position(self, x): 383 | """ 384 | x: [b, h, l, l] 385 | ret: [b, h, l, 2*l-1] 386 | """ 387 | batch, heads, length, _ = x.size() 388 | # pad along column 389 | x = F.pad( 390 | x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) 391 | ) 392 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) 393 | # add 0's in the beginning that will skew the elements after reshape 394 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 395 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 396 | return x_final 397 | 398 | def _attention_bias_proximal(self, length): 399 | """Bias for self-attention to encourage attention to close positions. 400 | Args: 401 | length: an integer scalar. 
402 | Returns: 403 | a Tensor with shape [1, 1, length, length] 404 | """ 405 | r = torch.arange(length, dtype=torch.float32) 406 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 407 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 408 | 409 | 410 | class FFN(nn.Module): 411 | def __init__( 412 | self, 413 | in_channels, 414 | out_channels, 415 | filter_channels, 416 | kernel_size, 417 | p_dropout=0.0, 418 | activation=None, 419 | causal=False, 420 | ): 421 | super().__init__() 422 | self.in_channels = in_channels 423 | self.out_channels = out_channels 424 | self.filter_channels = filter_channels 425 | self.kernel_size = kernel_size 426 | self.p_dropout = p_dropout 427 | self.activation = activation 428 | self.causal = causal 429 | 430 | if causal: 431 | self.padding = self._causal_padding 432 | else: 433 | self.padding = self._same_padding 434 | 435 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 436 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 437 | self.drop = nn.Dropout(p_dropout) 438 | 439 | def forward(self, x, x_mask): 440 | x = self.conv_1(self.padding(x * x_mask)) 441 | if self.activation == "gelu": 442 | x = x * torch.sigmoid(1.702 * x) 443 | else: 444 | x = torch.relu(x) 445 | x = self.drop(x) 446 | x = self.conv_2(self.padding(x * x_mask)) 447 | return x * x_mask 448 | 449 | def _causal_padding(self, x): 450 | if self.kernel_size == 1: 451 | return x 452 | pad_l = self.kernel_size - 1 453 | pad_r = 0 454 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 455 | x = F.pad(x, commons.convert_pad_shape(padding)) 456 | return x 457 | 458 | def _same_padding(self, x): 459 | if self.kernel_size == 1: 460 | return x 461 | pad_l = (self.kernel_size - 1) // 2 462 | pad_r = self.kernel_size // 2 463 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 464 | x = F.pad(x, commons.convert_pad_shape(padding)) 465 | return x 466 | -------------------------------------------------------------------------------- /nodes/openvoice/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def init_weights(m, mean=0.0, std=0.01): 7 | classname = m.__class__.__name__ 8 | if classname.find("Conv") != -1: 9 | m.weight.data.normal_(mean, std) 10 | 11 | 12 | def get_padding(kernel_size, dilation=1): 13 | return int((kernel_size * dilation - dilation) / 2) 14 | 15 | 16 | def convert_pad_shape(pad_shape): 17 | layer = pad_shape[::-1] 18 | pad_shape = [item for sublist in layer for item in sublist] 19 | return pad_shape 20 | 21 | 22 | def intersperse(lst, item): 23 | result = [item] * (len(lst) * 2 + 1) 24 | result[1::2] = lst 25 | return result 26 | 27 | 28 | def kl_divergence(m_p, logs_p, m_q, logs_q): 29 | """KL(P||Q)""" 30 | kl = (logs_q - logs_p) - 0.5 31 | kl += ( 32 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 33 | ) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 
| ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | 57 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 58 | b, d, t = x.size() 59 | if x_lengths is None: 60 | x_lengths = t 61 | ids_str_max = x_lengths - segment_size + 1 62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 63 | ret = slice_segments(x, ids_str, segment_size) 64 | return ret, ids_str 65 | 66 | 67 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 68 | position = torch.arange(length, dtype=torch.float) 69 | num_timescales = channels // 2 70 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 71 | num_timescales - 1 72 | ) 73 | inv_timescales = min_timescale * torch.exp( 74 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 75 | ) 76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 78 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 79 | signal = signal.view(1, channels, length) 80 | return signal 81 | 82 | 83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 84 | b, channels, length = x.size() 85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 86 | return x + signal.to(dtype=x.dtype, device=x.device) 87 | 88 | 89 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 90 | b, channels, length = x.size() 91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 93 | 94 | 95 | def subsequent_mask(length): 96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 97 | return mask 98 | 99 | 100 | @torch.jit.script 101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 102 | n_channels_int = n_channels[0] 103 | in_act = input_a + input_b 104 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 106 | acts = t_act * s_act 107 | return acts 108 | 109 | 110 | def convert_pad_shape(pad_shape): 111 | layer = pad_shape[::-1] 112 | pad_shape = [item for sublist in layer for item in sublist] 113 | return pad_shape 114 | 115 | 116 | def shift_1d(x): 117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 118 | return x 119 | 120 | 121 | def sequence_mask(length, max_length=None): 122 | if max_length is None: 123 | max_length = length.max() 124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 125 | return x.unsqueeze(0) < length.unsqueeze(1) 126 | 127 | 128 | def generate_path(duration, mask): 129 | """ 130 | duration: [b, 1, t_x] 131 | mask: [b, 1, t_y, t_x] 132 | """ 133 | 134 | b, _, t_y, t_x = mask.shape 135 | cum_duration = torch.cumsum(duration, -1) 136 | 137 | cum_duration_flat = cum_duration.view(b * t_x) 138 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 139 | path = path.view(b, t_x, t_y) 140 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 141 | path = path.unsqueeze(1).transpose(2, 3) * mask 142 | return path 143 | 144 | 145 | def clip_grad_value_(parameters, clip_value, norm_type=2): 146 | if isinstance(parameters, torch.Tensor): 147 | parameters = [parameters] 148 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 149 | norm_type = float(norm_type) 150 | if clip_value is not None: 151 | clip_value = 
float(clip_value) 152 | 153 | total_norm = 0 154 | for p in parameters: 155 | param_norm = p.grad.data.norm(norm_type) 156 | total_norm += param_norm.item() ** norm_type 157 | if clip_value is not None: 158 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 159 | total_norm = total_norm ** (1.0 / norm_type) 160 | return total_norm 161 | -------------------------------------------------------------------------------- /nodes/openvoice/mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from librosa.filters import mel as librosa_mel_fn 4 | 5 | MAX_WAV_VALUE = 32768.0 6 | 7 | 8 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 9 | """ 10 | PARAMS 11 | ------ 12 | C: compression factor 13 | """ 14 | return torch.log(torch.clamp(x, min=clip_val) * C) 15 | 16 | 17 | def dynamic_range_decompression_torch(x, C=1): 18 | """ 19 | PARAMS 20 | ------ 21 | C: compression factor used to compress 22 | """ 23 | return torch.exp(x) / C 24 | 25 | 26 | def spectral_normalize_torch(magnitudes): 27 | output = dynamic_range_compression_torch(magnitudes) 28 | return output 29 | 30 | 31 | def spectral_de_normalize_torch(magnitudes): 32 | output = dynamic_range_decompression_torch(magnitudes) 33 | return output 34 | 35 | 36 | mel_basis = {} 37 | hann_window = {} 38 | 39 | 40 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 41 | if torch.min(y) < -1.1: 42 | print("min value is ", torch.min(y)) 43 | if torch.max(y) > 1.1: 44 | print("max value is ", torch.max(y)) 45 | 46 | global hann_window 47 | dtype_device = str(y.dtype) + "_" + str(y.device) 48 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 49 | if wnsize_dtype_device not in hann_window: 50 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 51 | dtype=y.dtype, device=y.device 52 | ) 53 | 54 | y = torch.nn.functional.pad( 55 | y.unsqueeze(1), 56 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 57 | mode="reflect", 58 | ) 59 | y = y.squeeze(1) 60 | 61 | spec = torch.stft( 62 | y, 63 | n_fft, 64 | hop_length=hop_size, 65 | win_length=win_size, 66 | window=hann_window[wnsize_dtype_device], 67 | center=center, 68 | pad_mode="reflect", 69 | normalized=False, 70 | onesided=True, 71 | return_complex=False, 72 | ) 73 | 74 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 75 | return spec 76 | 77 | 78 | def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False): 79 | # if torch.min(y) < -1.: 80 | # print('min value is ', torch.min(y)) 81 | # if torch.max(y) > 1.: 82 | # print('max value is ', torch.max(y)) 83 | 84 | global hann_window 85 | dtype_device = str(y.dtype) + '_' + str(y.device) 86 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 87 | if wnsize_dtype_device not in hann_window: 88 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 89 | 90 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 91 | 92 | # ******************** original ************************# 93 | # y = y.squeeze(1) 94 | # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 95 | # center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 96 | 97 | # ******************** ConvSTFT ************************# 98 | freq_cutoff = n_fft // 2 + 1 99 | fourier_basis = 
torch.view_as_real(torch.fft.fft(torch.eye(n_fft))) 100 | forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1]) 101 | forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float() 102 | 103 | import torch.nn.functional as F 104 | 105 | # if center: 106 | # signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1) 107 | assert center is False 108 | 109 | forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size) 110 | spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1) 111 | 112 | 113 | # ******************** Verification ************************# 114 | spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 115 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 116 | assert torch.allclose(spec1, spec2, atol=1e-4) 117 | 118 | spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6) 119 | return spec 120 | 121 | 122 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 123 | global mel_basis 124 | dtype_device = str(spec.dtype) + "_" + str(spec.device) 125 | fmax_dtype_device = str(fmax) + "_" + dtype_device 126 | if fmax_dtype_device not in mel_basis: 127 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 128 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 129 | dtype=spec.dtype, device=spec.device 130 | ) 131 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 132 | spec = spectral_normalize_torch(spec) 133 | return spec 134 | 135 | 136 | def mel_spectrogram_torch( 137 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 138 | ): 139 | if torch.min(y) < -1.0: 140 | print("min value is ", torch.min(y)) 141 | if torch.max(y) > 1.0: 142 | print("max value is ", torch.max(y)) 143 | 144 | global mel_basis, hann_window 145 | dtype_device = str(y.dtype) + "_" + str(y.device) 146 | fmax_dtype_device = str(fmax) + "_" + dtype_device 147 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 148 | if fmax_dtype_device not in mel_basis: 149 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 150 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 151 | dtype=y.dtype, device=y.device 152 | ) 153 | if wnsize_dtype_device not in hann_window: 154 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 155 | dtype=y.dtype, device=y.device 156 | ) 157 | 158 | y = torch.nn.functional.pad( 159 | y.unsqueeze(1), 160 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 161 | mode="reflect", 162 | ) 163 | y = y.squeeze(1) 164 | 165 | spec = torch.stft( 166 | y, 167 | n_fft, 168 | hop_length=hop_size, 169 | win_length=win_size, 170 | window=hann_window[wnsize_dtype_device], 171 | center=center, 172 | pad_mode="reflect", 173 | normalized=False, 174 | onesided=True, 175 | return_complex=False, 176 | ) 177 | 178 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 179 | 180 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 181 | spec = spectral_normalize_torch(spec) 182 | 183 | return spec -------------------------------------------------------------------------------- /nodes/openvoice/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from 
torch.nn import functional as F 5 | 6 | from openvoice import commons 7 | from openvoice import modules 8 | from openvoice import attentions 9 | 10 | from torch.nn import Conv1d, ConvTranspose1d, Conv2d 11 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 12 | 13 | from openvoice.commons import init_weights, get_padding 14 | 15 | 16 | class TextEncoder(nn.Module): 17 | def __init__(self, 18 | n_vocab, 19 | out_channels, 20 | hidden_channels, 21 | filter_channels, 22 | n_heads, 23 | n_layers, 24 | kernel_size, 25 | p_dropout): 26 | super().__init__() 27 | self.n_vocab = n_vocab 28 | self.out_channels = out_channels 29 | self.hidden_channels = hidden_channels 30 | self.filter_channels = filter_channels 31 | self.n_heads = n_heads 32 | self.n_layers = n_layers 33 | self.kernel_size = kernel_size 34 | self.p_dropout = p_dropout 35 | 36 | self.emb = nn.Embedding(n_vocab, hidden_channels) 37 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) 38 | 39 | self.encoder = attentions.Encoder( 40 | hidden_channels, 41 | filter_channels, 42 | n_heads, 43 | n_layers, 44 | kernel_size, 45 | p_dropout) 46 | self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) 47 | 48 | def forward(self, x, x_lengths): 49 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] 50 | x = torch.transpose(x, 1, -1) # [b, h, t] 51 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 52 | 53 | x = self.encoder(x * x_mask, x_mask) 54 | stats = self.proj(x) * x_mask 55 | 56 | m, logs = torch.split(stats, self.out_channels, dim=1) 57 | return x, m, logs, x_mask 58 | 59 | 60 | class DurationPredictor(nn.Module): 61 | def __init__( 62 | self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 63 | ): 64 | super().__init__() 65 | 66 | self.in_channels = in_channels 67 | self.filter_channels = filter_channels 68 | self.kernel_size = kernel_size 69 | self.p_dropout = p_dropout 70 | self.gin_channels = gin_channels 71 | 72 | self.drop = nn.Dropout(p_dropout) 73 | self.conv_1 = nn.Conv1d( 74 | in_channels, filter_channels, kernel_size, padding=kernel_size // 2 75 | ) 76 | self.norm_1 = modules.LayerNorm(filter_channels) 77 | self.conv_2 = nn.Conv1d( 78 | filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 79 | ) 80 | self.norm_2 = modules.LayerNorm(filter_channels) 81 | self.proj = nn.Conv1d(filter_channels, 1, 1) 82 | 83 | if gin_channels != 0: 84 | self.cond = nn.Conv1d(gin_channels, in_channels, 1) 85 | 86 | def forward(self, x, x_mask, g=None): 87 | x = torch.detach(x) 88 | if g is not None: 89 | g = torch.detach(g) 90 | x = x + self.cond(g) 91 | x = self.conv_1(x * x_mask) 92 | x = torch.relu(x) 93 | x = self.norm_1(x) 94 | x = self.drop(x) 95 | x = self.conv_2(x * x_mask) 96 | x = torch.relu(x) 97 | x = self.norm_2(x) 98 | x = self.drop(x) 99 | x = self.proj(x * x_mask) 100 | return x * x_mask 101 | 102 | class StochasticDurationPredictor(nn.Module): 103 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): 104 | super().__init__() 105 | filter_channels = in_channels # it needs to be removed from future version. 
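# Stochastic duration predictor (VITS-style): stacks of normalizing flows
# (ElementwiseAffine -> ConvFlow -> Flip) model the distribution of log-durations.
# The posterior flows defined below are only used in the forward (training) branch;
# in reverse mode the flows are inverted to sample log-durations directly.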
106 | self.in_channels = in_channels 107 | self.filter_channels = filter_channels 108 | self.kernel_size = kernel_size 109 | self.p_dropout = p_dropout 110 | self.n_flows = n_flows 111 | self.gin_channels = gin_channels 112 | 113 | self.log_flow = modules.Log() 114 | self.flows = nn.ModuleList() 115 | self.flows.append(modules.ElementwiseAffine(2)) 116 | for i in range(n_flows): 117 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 118 | self.flows.append(modules.Flip()) 119 | 120 | self.post_pre = nn.Conv1d(1, filter_channels, 1) 121 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) 122 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 123 | self.post_flows = nn.ModuleList() 124 | self.post_flows.append(modules.ElementwiseAffine(2)) 125 | for i in range(4): 126 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 127 | self.post_flows.append(modules.Flip()) 128 | 129 | self.pre = nn.Conv1d(in_channels, filter_channels, 1) 130 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1) 131 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 132 | if gin_channels != 0: 133 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1) 134 | 135 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): 136 | x = torch.detach(x) 137 | x = self.pre(x) 138 | if g is not None: 139 | g = torch.detach(g) 140 | x = x + self.cond(g) 141 | x = self.convs(x, x_mask) 142 | x = self.proj(x) * x_mask 143 | 144 | if not reverse: 145 | flows = self.flows 146 | assert w is not None 147 | 148 | logdet_tot_q = 0 149 | h_w = self.post_pre(w) 150 | h_w = self.post_convs(h_w, x_mask) 151 | h_w = self.post_proj(h_w) * x_mask 152 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 153 | z_q = e_q 154 | for flow in self.post_flows: 155 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 156 | logdet_tot_q += logdet_q 157 | z_u, z1 = torch.split(z_q, [1, 1], 1) 158 | u = torch.sigmoid(z_u) * x_mask 159 | z0 = (w - u) * x_mask 160 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) 161 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q 162 | 163 | logdet_tot = 0 164 | z0, logdet = self.log_flow(z0, x_mask) 165 | logdet_tot += logdet 166 | z = torch.cat([z0, z1], 1) 167 | for flow in flows: 168 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 169 | logdet_tot = logdet_tot + logdet 170 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot 171 | return nll + logq # [b] 172 | else: 173 | flows = list(reversed(self.flows)) 174 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 175 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 176 | for flow in flows: 177 | z = flow(z, x_mask, g=x, reverse=reverse) 178 | z0, z1 = torch.split(z, [1, 1], 1) 179 | logw = z0 180 | return logw 181 | 182 | class PosteriorEncoder(nn.Module): 183 | def __init__( 184 | self, 185 | in_channels, 186 | out_channels, 187 | hidden_channels, 188 | kernel_size, 189 | dilation_rate, 190 | n_layers, 191 | gin_channels=0, 192 | ): 193 | super().__init__() 194 | self.in_channels = in_channels 195 | self.out_channels = out_channels 196 | self.hidden_channels = hidden_channels 197 | self.kernel_size = kernel_size 198 | self.dilation_rate = dilation_rate 199 | self.n_layers = 
n_layers 200 | self.gin_channels = gin_channels 201 | 202 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 203 | self.enc = modules.WN( 204 | hidden_channels, 205 | kernel_size, 206 | dilation_rate, 207 | n_layers, 208 | gin_channels=gin_channels, 209 | ) 210 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 211 | 212 | def forward(self, x, x_lengths, g=None, tau=1.0): 213 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( 214 | x.dtype 215 | ) 216 | x = self.pre(x) * x_mask 217 | x = self.enc(x, x_mask, g=g) 218 | stats = self.proj(x) * x_mask 219 | m, logs = torch.split(stats, self.out_channels, dim=1) 220 | z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask 221 | return z, m, logs, x_mask 222 | 223 | 224 | class Generator(torch.nn.Module): 225 | def __init__( 226 | self, 227 | initial_channel, 228 | resblock, 229 | resblock_kernel_sizes, 230 | resblock_dilation_sizes, 231 | upsample_rates, 232 | upsample_initial_channel, 233 | upsample_kernel_sizes, 234 | gin_channels=0, 235 | ): 236 | super(Generator, self).__init__() 237 | self.num_kernels = len(resblock_kernel_sizes) 238 | self.num_upsamples = len(upsample_rates) 239 | self.conv_pre = Conv1d( 240 | initial_channel, upsample_initial_channel, 7, 1, padding=3 241 | ) 242 | resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 243 | 244 | self.ups = nn.ModuleList() 245 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 246 | self.ups.append( 247 | weight_norm( 248 | ConvTranspose1d( 249 | upsample_initial_channel // (2**i), 250 | upsample_initial_channel // (2 ** (i + 1)), 251 | k, 252 | u, 253 | padding=(k - u) // 2, 254 | ) 255 | ) 256 | ) 257 | 258 | self.resblocks = nn.ModuleList() 259 | for i in range(len(self.ups)): 260 | ch = upsample_initial_channel // (2 ** (i + 1)) 261 | for j, (k, d) in enumerate( 262 | zip(resblock_kernel_sizes, resblock_dilation_sizes) 263 | ): 264 | self.resblocks.append(resblock(ch, k, d)) 265 | 266 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) 267 | self.ups.apply(init_weights) 268 | 269 | if gin_channels != 0: 270 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) 271 | 272 | def forward(self, x, g=None): 273 | x = self.conv_pre(x) 274 | if g is not None: 275 | x = x + self.cond(g) 276 | 277 | for i in range(self.num_upsamples): 278 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 279 | x = self.ups[i](x) 280 | xs = None 281 | for j in range(self.num_kernels): 282 | if xs is None: 283 | xs = self.resblocks[i * self.num_kernels + j](x) 284 | else: 285 | xs += self.resblocks[i * self.num_kernels + j](x) 286 | x = xs / self.num_kernels 287 | x = F.leaky_relu(x) 288 | x = self.conv_post(x) 289 | x = torch.tanh(x) 290 | 291 | return x 292 | 293 | def remove_weight_norm(self): 294 | print("Removing weight norm...") 295 | for layer in self.ups: 296 | remove_weight_norm(layer) 297 | for layer in self.resblocks: 298 | layer.remove_weight_norm() 299 | 300 | 301 | class ReferenceEncoder(nn.Module): 302 | """ 303 | inputs --- [N, Ty/r, n_mels*r] mels 304 | outputs --- [N, ref_enc_gru_size] 305 | """ 306 | 307 | def __init__(self, spec_channels, gin_channels=0, layernorm=True): 308 | super().__init__() 309 | self.spec_channels = spec_channels 310 | ref_enc_filters = [32, 32, 64, 64, 128, 128] 311 | K = len(ref_enc_filters) 312 | filters = [1] + ref_enc_filters 313 | convs = [ 314 | weight_norm( 315 | nn.Conv2d( 316 | in_channels=filters[i], 317 | out_channels=filters[i + 1], 318 | 
kernel_size=(3, 3), 319 | stride=(2, 2), 320 | padding=(1, 1), 321 | ) 322 | ) 323 | for i in range(K) 324 | ] 325 | self.convs = nn.ModuleList(convs) 326 | 327 | out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) 328 | self.gru = nn.GRU( 329 | input_size=ref_enc_filters[-1] * out_channels, 330 | hidden_size=256 // 2, 331 | batch_first=True, 332 | ) 333 | self.proj = nn.Linear(128, gin_channels) 334 | if layernorm: 335 | self.layernorm = nn.LayerNorm(self.spec_channels) 336 | else: 337 | self.layernorm = None 338 | 339 | def forward(self, inputs, mask=None): 340 | N = inputs.size(0) 341 | 342 | out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] 343 | if self.layernorm is not None: 344 | out = self.layernorm(out) 345 | 346 | for conv in self.convs: 347 | out = conv(out) 348 | # out = wn(out) 349 | out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] 350 | 351 | out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] 352 | T = out.size(1) 353 | N = out.size(0) 354 | out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] 355 | 356 | self.gru.flatten_parameters() 357 | memory, out = self.gru(out) # out --- [1, N, 128] 358 | 359 | return self.proj(out.squeeze(0)) 360 | 361 | def calculate_channels(self, L, kernel_size, stride, pad, n_convs): 362 | for i in range(n_convs): 363 | L = (L - kernel_size + 2 * pad) // stride + 1 364 | return L 365 | 366 | 367 | class ResidualCouplingBlock(nn.Module): 368 | def __init__(self, 369 | channels, 370 | hidden_channels, 371 | kernel_size, 372 | dilation_rate, 373 | n_layers, 374 | n_flows=4, 375 | gin_channels=0): 376 | super().__init__() 377 | self.channels = channels 378 | self.hidden_channels = hidden_channels 379 | self.kernel_size = kernel_size 380 | self.dilation_rate = dilation_rate 381 | self.n_layers = n_layers 382 | self.n_flows = n_flows 383 | self.gin_channels = gin_channels 384 | 385 | self.flows = nn.ModuleList() 386 | for i in range(n_flows): 387 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) 388 | self.flows.append(modules.Flip()) 389 | 390 | def forward(self, x, x_mask, g=None, reverse=False): 391 | if not reverse: 392 | for flow in self.flows: 393 | x, _ = flow(x, x_mask, g=g, reverse=reverse) 394 | else: 395 | for flow in reversed(self.flows): 396 | x = flow(x, x_mask, g=g, reverse=reverse) 397 | return x 398 | 399 | class SynthesizerTrn(nn.Module): 400 | """ 401 | Synthesizer for Training 402 | """ 403 | 404 | def __init__( 405 | self, 406 | n_vocab, 407 | spec_channels, 408 | inter_channels, 409 | hidden_channels, 410 | filter_channels, 411 | n_heads, 412 | n_layers, 413 | kernel_size, 414 | p_dropout, 415 | resblock, 416 | resblock_kernel_sizes, 417 | resblock_dilation_sizes, 418 | upsample_rates, 419 | upsample_initial_channel, 420 | upsample_kernel_sizes, 421 | n_speakers=256, 422 | gin_channels=256, 423 | zero_g=False, 424 | **kwargs 425 | ): 426 | super().__init__() 427 | 428 | self.dec = Generator( 429 | inter_channels, 430 | resblock, 431 | resblock_kernel_sizes, 432 | resblock_dilation_sizes, 433 | upsample_rates, 434 | upsample_initial_channel, 435 | upsample_kernel_sizes, 436 | gin_channels=gin_channels, 437 | ) 438 | self.enc_q = PosteriorEncoder( 439 | spec_channels, 440 | inter_channels, 441 | hidden_channels, 442 | 5, 443 | 1, 444 | 16, 445 | gin_channels=gin_channels, 446 | ) 447 | 448 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, 
gin_channels=gin_channels) 449 | 450 | self.n_speakers = n_speakers 451 | if n_speakers == 0: 452 | self.ref_enc = ReferenceEncoder(spec_channels, gin_channels) 453 | else: 454 | self.enc_p = TextEncoder(n_vocab, 455 | inter_channels, 456 | hidden_channels, 457 | filter_channels, 458 | n_heads, 459 | n_layers, 460 | kernel_size, 461 | p_dropout) 462 | self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) 463 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) 464 | self.emb_g = nn.Embedding(n_speakers, gin_channels) 465 | self.zero_g = zero_g 466 | 467 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None): 468 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) 469 | if self.n_speakers > 0: 470 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 471 | else: 472 | g = None 473 | 474 | logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio \ 475 | + self.dp(x, x_mask, g=g) * (1 - sdp_ratio) 476 | 477 | w = torch.exp(logw) * x_mask * length_scale 478 | w_ceil = torch.ceil(w) 479 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 480 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) 481 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 482 | attn = commons.generate_path(w_ceil, attn_mask) 483 | 484 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 485 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 486 | 487 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale 488 | z = self.flow(z_p, y_mask, g=g, reverse=True) 489 | o = self.dec((z * y_mask)[:,:,:max_len], g=g) 490 | return o, attn, y_mask, (z, z_p, m_p, logs_p) 491 | 492 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0): 493 | g_src = sid_src 494 | g_tgt = sid_tgt 495 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau) 496 | z_p = self.flow(z, y_mask, g=g_src) 497 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) 498 | o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt)) 499 | return o_hat, y_mask, (z, z_p, z_hat) 500 | -------------------------------------------------------------------------------- /nodes/openvoice/openvoice_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import gradio as gr 5 | from zipfile import ZipFile 6 | import langid 7 | from openvoice import se_extractor 8 | from openvoice.api import BaseSpeakerTTS, ToneColorConverter 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--share", action='store_true', default=False, help="make link public") 12 | args = parser.parse_args() 13 | 14 | en_ckpt_base = 'checkpoints/base_speakers/EN' 15 | zh_ckpt_base = 'checkpoints/base_speakers/ZH' 16 | ckpt_converter = 'checkpoints/converter' 17 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 18 | output_dir = 'outputs' 19 | os.makedirs(output_dir, exist_ok=True) 20 | 21 | # load models 22 | en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device) 23 | en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth') 24 | zh_base_speaker_tts = 
BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device) 25 | zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth') 26 | tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device) 27 | tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth') 28 | 29 | # load speaker embeddings 30 | en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device) 31 | en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device) 32 | zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device) 33 | 34 | # This online demo mainly supports English and Chinese 35 | supported_languages = ['zh', 'en'] 36 | 37 | def predict(prompt, style, audio_file_pth, agree): 38 | # initialize a empty info 39 | text_hint = '' 40 | # agree with the terms 41 | if agree == False: 42 | text_hint += '[ERROR] Please accept the Terms & Condition!\n' 43 | gr.Warning("Please accept the Terms & Condition!") 44 | return ( 45 | text_hint, 46 | None, 47 | None, 48 | ) 49 | 50 | # first detect the input language 51 | language_predicted = langid.classify(prompt)[0].strip() 52 | print(f"Detected language:{language_predicted}") 53 | 54 | if language_predicted not in supported_languages: 55 | text_hint += f"[ERROR] The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}\n" 56 | gr.Warning( 57 | f"The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}" 58 | ) 59 | 60 | return ( 61 | text_hint, 62 | None, 63 | None, 64 | ) 65 | 66 | if language_predicted == "zh": 67 | tts_model = zh_base_speaker_tts 68 | source_se = zh_source_se 69 | language = 'Chinese' 70 | if style not in ['default']: 71 | text_hint += f"[ERROR] The style {style} is not supported for Chinese, which should be in ['default']\n" 72 | gr.Warning(f"The style {style} is not supported for Chinese, which should be in ['default']") 73 | return ( 74 | text_hint, 75 | None, 76 | None, 77 | ) 78 | 79 | else: 80 | tts_model = en_base_speaker_tts 81 | if style == 'default': 82 | source_se = en_source_default_se 83 | else: 84 | source_se = en_source_style_se 85 | language = 'English' 86 | if style not in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']: 87 | text_hint += f"[ERROR] The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']\n" 88 | gr.Warning(f"The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']") 89 | return ( 90 | text_hint, 91 | None, 92 | None, 93 | ) 94 | 95 | speaker_wav = audio_file_pth 96 | 97 | if len(prompt) < 2: 98 | text_hint += f"[ERROR] Please give a longer prompt text \n" 99 | gr.Warning("Please give a longer prompt text") 100 | return ( 101 | text_hint, 102 | None, 103 | None, 104 | ) 105 | if len(prompt) > 200: 106 | text_hint += f"[ERROR] Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo and try for your usage \n" 107 | gr.Warning( 108 | "Text length limited to 200 characters for this demo, please try shorter text. 
You can clone our open-source repo for your usage" 109 | ) 110 | return ( 111 | text_hint, 112 | None, 113 | None, 114 | ) 115 | 116 | # note diffusion_conditioning not used on hifigan (default mode), it will be empty but need to pass it to model.inference 117 | try: 118 | target_se, audio_name = se_extractor.get_se(speaker_wav, tone_color_converter, target_dir='processed', vad=True) 119 | except Exception as e: 120 | text_hint += f"[ERROR] Get target tone color error {str(e)} \n" 121 | gr.Warning( 122 | "[ERROR] Get target tone color error {str(e)} \n" 123 | ) 124 | return ( 125 | text_hint, 126 | None, 127 | None, 128 | ) 129 | 130 | src_path = f'{output_dir}/tmp.wav' 131 | tts_model.tts(prompt, src_path, speaker=style, language=language) 132 | 133 | save_path = f'{output_dir}/output.wav' 134 | # Run the tone color converter 135 | encode_message = "@MyShell" 136 | tone_color_converter.convert( 137 | audio_src_path=src_path, 138 | src_se=source_se, 139 | tgt_se=target_se, 140 | output_path=save_path, 141 | message=encode_message) 142 | 143 | text_hint += f'''Get response successfully \n''' 144 | 145 | return ( 146 | text_hint, 147 | save_path, 148 | speaker_wav, 149 | ) 150 | 151 | 152 | 153 | title = "MyShell OpenVoice" 154 | 155 | description = """ 156 | We introduce OpenVoice, a versatile instant voice cloning approach that requires only a short audio clip from the reference speaker to replicate their voice and generate speech in multiple languages. OpenVoice enables granular control over voice styles, including emotion, accent, rhythm, pauses, and intonation, in addition to replicating the tone color of the reference speaker. OpenVoice also achieves zero-shot cross-lingual voice cloning for languages not included in the massive-speaker training set. 157 | """ 158 | 159 | markdown_table = """ 160 |
161 | 162 | | | | | 163 | | :-----------: | :-----------: | :-----------: | 164 | | **OpenSource Repo** | **Project Page** | **Join the Community** | 165 | |
| [OpenVoice](https://research.myshell.ai/open-voice) | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) | 166 | 167 |
168 | """ 169 | 170 | markdown_table_v2 = """ 171 |
172 | 173 | | | | | | 174 | | :-----------: | :-----------: | :-----------: | :-----------: | 175 | | **OpenSource Repo** |
| **Project Page** | [OpenVoice](https://research.myshell.ai/open-voice) | 176 | 177 | | | | 178 | | :-----------: | :-----------: | 179 | **Join the Community** | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) | 180 | 181 |
182 | """ 183 | content = """ 184 |
185 | If the generated voice does not sound like the reference voice, please refer to this QnA. For multi-lingual & cross-lingual examples, please refer to this jupyter notebook. 186 | This online demo mainly supports English. The default style also supports Chinese. But OpenVoice can adapt to any other language as long as a base speaker is provided. 187 |
188 | """ 189 | wrapped_markdown_content = f"
{content}
" 190 | 191 | 192 | examples = [ 193 | [ 194 | "今天天气真好,我们一起出去吃饭吧。", 195 | 'default', 196 | "resources/demo_speaker1.mp3", 197 | True, 198 | ],[ 199 | "This audio is generated by open voice with a half-performance model.", 200 | 'whispering', 201 | "resources/demo_speaker2.mp3", 202 | True, 203 | ], 204 | [ 205 | "He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.", 206 | 'sad', 207 | "resources/demo_speaker0.mp3", 208 | True, 209 | ], 210 | ] 211 | 212 | with gr.Blocks(analytics_enabled=False) as demo: 213 | 214 | with gr.Row(): 215 | with gr.Column(): 216 | with gr.Row(): 217 | gr.Markdown( 218 | """ 219 | ## 220 | """ 221 | ) 222 | with gr.Row(): 223 | gr.Markdown(markdown_table_v2) 224 | with gr.Row(): 225 | gr.Markdown(description) 226 | with gr.Column(): 227 | gr.Video('https://github.com/myshell-ai/OpenVoice/assets/40556743/3cba936f-82bf-476c-9e52-09f0f417bb2f', autoplay=True) 228 | 229 | with gr.Row(): 230 | gr.HTML(wrapped_markdown_content) 231 | 232 | with gr.Row(): 233 | with gr.Column(): 234 | input_text_gr = gr.Textbox( 235 | label="Text Prompt", 236 | info="One or two sentences at a time is better. Up to 200 text characters.", 237 | value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.", 238 | ) 239 | style_gr = gr.Dropdown( 240 | label="Style", 241 | info="Select a style of output audio for the synthesised speech. (Chinese only support 'default' now)", 242 | choices=['default', 'whispering', 'cheerful', 'terrified', 'angry', 'sad', 'friendly'], 243 | max_choices=1, 244 | value="default", 245 | ) 246 | ref_gr = gr.Audio( 247 | label="Reference Audio", 248 | info="Click on the ✎ button to upload your own target speaker audio", 249 | type="filepath", 250 | value="resources/demo_speaker2.mp3", 251 | ) 252 | tos_gr = gr.Checkbox( 253 | label="Agree", 254 | value=False, 255 | info="I agree to the terms of the cc-by-nc-4.0 license-: https://github.com/myshell-ai/OpenVoice/blob/main/LICENSE", 256 | ) 257 | 258 | tts_button = gr.Button("Send", elem_id="send-btn", visible=True) 259 | 260 | 261 | with gr.Column(): 262 | out_text_gr = gr.Text(label="Info") 263 | audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True) 264 | ref_audio_gr = gr.Audio(label="Reference Audio Used") 265 | 266 | gr.Examples(examples, 267 | label="Examples", 268 | inputs=[input_text_gr, style_gr, ref_gr, tos_gr], 269 | outputs=[out_text_gr, audio_gr, ref_audio_gr], 270 | fn=predict, 271 | cache_examples=False,) 272 | tts_button.click(predict, [input_text_gr, style_gr, ref_gr, tos_gr], outputs=[out_text_gr, audio_gr, ref_audio_gr]) 273 | 274 | demo.queue() 275 | demo.launch(debug=True, show_api=True, share=args.share) 276 | -------------------------------------------------------------------------------- /nodes/openvoice/se_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import hashlib 5 | import librosa 6 | import base64 7 | from glob import glob 8 | import numpy as np 9 | from pydub import AudioSegment 10 | from faster_whisper import WhisperModel 11 | import hashlib 12 | import base64 13 | import librosa 14 | from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments 15 | import shutil 16 | 17 | # Run on GPU with FP16 18 | # model = None 19 | def 
split_audio_whisper(audio_path, audio_name, target_dir='processed',model=None): 20 | # global model 21 | if model is None: 22 | model_size = "medium" 23 | model = WhisperModel(model_size, device="cuda", compute_type="float16") 24 | audio = AudioSegment.from_file(audio_path) 25 | max_len = len(audio) 26 | 27 | target_folder = os.path.join(target_dir, audio_name) 28 | 29 | segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True) 30 | segments = list(segments) 31 | 32 | # create directory 33 | os.makedirs(target_folder, exist_ok=True) 34 | wavs_folder = os.path.join(target_folder, 'wavs') 35 | os.makedirs(wavs_folder, exist_ok=True) 36 | 37 | # segments 38 | s_ind = 0 39 | start_time = None 40 | 41 | for k, w in enumerate(segments): 42 | # process with the time 43 | if k == 0: 44 | start_time = max(0, w.start) 45 | 46 | end_time = w.end 47 | 48 | # calculate confidence 49 | if len(w.words) > 0: 50 | confidence = sum([s.probability for s in w.words]) / len(w.words) 51 | else: 52 | confidence = 0. 53 | # clean text 54 | text = w.text.replace('...', '') 55 | 56 | # left 0.08s for each audios 57 | audio_seg = audio[int( start_time * 1000) : min(max_len, int(end_time * 1000) + 80)] 58 | 59 | # segment file name 60 | fname = f"{audio_name}_seg{s_ind}.wav" 61 | 62 | # filter out the segment shorter than 1.5s and longer than 20s 63 | save = audio_seg.duration_seconds > 1.5 and \ 64 | audio_seg.duration_seconds < 20. and \ 65 | len(text) >= 2 and len(text) < 200 66 | 67 | if save: 68 | output_file = os.path.join(wavs_folder, fname) 69 | audio_seg.export(output_file, format='wav') 70 | 71 | if k < len(segments) - 1: 72 | start_time = max(0, segments[k+1].start - 0.08) 73 | 74 | s_ind = s_ind + 1 75 | return wavs_folder 76 | 77 | 78 | def split_audio_vad(audio_path, audio_name, target_dir, split_seconds=10.0): 79 | SAMPLE_RATE = 16000 80 | audio_vad = get_audio_tensor(audio_path) 81 | segments = get_vad_segments( 82 | audio_vad, 83 | output_sample=True, 84 | min_speech_duration=0.1, 85 | min_silence_duration=1, 86 | method="silero", 87 | ) 88 | segments = [(seg["start"], seg["end"]) for seg in segments] 89 | segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s,e in segments] 90 | print(segments) 91 | audio_active = AudioSegment.silent(duration=0) 92 | audio = AudioSegment.from_file(audio_path) 93 | 94 | for start_time, end_time in segments: 95 | audio_active += audio[int(start_time * 1000):int(end_time * 1000)] 96 | 97 | audio_dur = audio_active.duration_seconds 98 | print(f'after vad: dur = {audio_dur}') 99 | target_folder = os.path.join(target_dir, audio_name) 100 | wavs_folder = os.path.join(target_folder, 'wavs') 101 | os.makedirs(wavs_folder, exist_ok=True) 102 | start_time = 0. 
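    # Illustrative note (sketch, not authoritative): the splitting below aims for
    # roughly equal chunks of about split_seconds each. For example, 25 s of voiced
    # audio with split_seconds=10 gives num_splits = int(np.round(2.5)) = 2, so two
    # ~12.5 s segments are exported.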
103 | count = 0 104 | num_splits = int(np.round(audio_dur / split_seconds)) 105 | 106 | if num_splits <= 0: 107 | # Adjust split_seconds to the audio duration to ensure at least one segment 108 | split_seconds = audio_dur 109 | num_splits = 1 110 | 111 | interval = audio_dur / num_splits 112 | 113 | for i in range(num_splits): 114 | end_time = min(start_time + interval, audio_dur) 115 | if i == num_splits - 1: 116 | end_time = audio_dur 117 | output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav" 118 | audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)] 119 | audio_seg.export(output_file, format='wav') 120 | start_time = end_time 121 | count += 1 122 | return wavs_folder 123 | 124 | def hash_numpy_array(audio_path): 125 | array, _ = librosa.load(audio_path, sr=None, mono=True) 126 | # Convert the array to bytes 127 | array_bytes = array.tobytes() 128 | # Calculate the hash of the array bytes 129 | hash_object = hashlib.sha256(array_bytes) 130 | hash_value = hash_object.digest() 131 | # Convert the hash value to base64 132 | base64_value = base64.b64encode(hash_value) 133 | return base64_value.decode('utf-8')[:16].replace('/', '_^') 134 | 135 | def get_se(audio_path, vc_model, target_dir='processed', whisper_model=None): 136 | device = vc_model.device 137 | version = vc_model.version 138 | print("OpenVoice version:", version) 139 | 140 | audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{version}_{hash_numpy_array(audio_path)}" 141 | se_path = os.path.join(target_dir, audio_name, 'se.pth') 142 | 143 | # if os.path.isfile(se_path): 144 | # se = torch.load(se_path).to(device) 145 | # return se, audio_name 146 | # if os.path.isdir(audio_path): 147 | # wavs_folder = audio_path 148 | 149 | vad= whisper_model==None 150 | 151 | if vad: 152 | wavs_folder = split_audio_vad(audio_path, target_dir=target_dir, audio_name=audio_name) 153 | else: 154 | wavs_folder = split_audio_whisper(audio_path, target_dir=target_dir, audio_name=audio_name,model=whisper_model) 155 | 156 | audio_segs = glob(f'{wavs_folder}/*.wav') 157 | if len(audio_segs) == 0 and os.path.exists(audio_path) and audio_path.lower().endswith('.wav'): 158 | audio_segs = [audio_path] 159 | 160 | if len(audio_segs) == 0: 161 | raise NotImplementedError('No audio segments found!') 162 | 163 | return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name 164 | 165 | -------------------------------------------------------------------------------- /nodes/openvoice/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from openvoice.text import cleaners 3 | from openvoice.text.symbols import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | 11 | def text_to_sequence(text, symbols, cleaner_names): 12 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
13 | Args: 14 | text: string to convert to a sequence 15 | cleaner_names: names of the cleaner functions to run the text through 16 | Returns: 17 | List of integers corresponding to the symbols in the text 18 | ''' 19 | sequence = [] 20 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 21 | clean_text = _clean_text(text, cleaner_names) 22 | print(clean_text) 23 | print(f" length:{len(clean_text)}") 24 | for symbol in clean_text: 25 | if symbol not in symbol_to_id.keys(): 26 | continue 27 | symbol_id = symbol_to_id[symbol] 28 | sequence += [symbol_id] 29 | print(f" length:{len(sequence)}") 30 | return sequence 31 | 32 | 33 | def cleaned_text_to_sequence(cleaned_text, symbols): 34 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 35 | Args: 36 | text: string to convert to a sequence 37 | Returns: 38 | List of integers corresponding to the symbols in the text 39 | ''' 40 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 41 | sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()] 42 | return sequence 43 | 44 | 45 | 46 | from openvoice.text.symbols import language_tone_start_map 47 | def cleaned_text_to_sequence_vits2(cleaned_text, tones, language, symbols, languages): 48 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 49 | Args: 50 | text: string to convert to a sequence 51 | Returns: 52 | List of integers corresponding to the symbols in the text 53 | """ 54 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 55 | language_id_map = {s: i for i, s in enumerate(languages)} 56 | phones = [symbol_to_id[symbol] for symbol in cleaned_text] 57 | tone_start = language_tone_start_map[language] 58 | tones = [i + tone_start for i in tones] 59 | lang_id = language_id_map[language] 60 | lang_ids = [lang_id for i in phones] 61 | return phones, tones, lang_ids 62 | 63 | 64 | def sequence_to_text(sequence): 65 | '''Converts a sequence of IDs back to a string''' 66 | result = '' 67 | for symbol_id in sequence: 68 | s = _id_to_symbol[symbol_id] 69 | result += s 70 | return result 71 | 72 | 73 | def _clean_text(text, cleaner_names): 74 | for name in cleaner_names: 75 | cleaner = getattr(cleaners, name) 76 | if not cleaner: 77 | raise Exception('Unknown cleaner: %s' % name) 78 | text = cleaner(text) 79 | return text 80 | -------------------------------------------------------------------------------- /nodes/openvoice/text/cleaners.py: -------------------------------------------------------------------------------- 1 | import re 2 | from openvoice.text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2 3 | from openvoice.text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2 4 | 5 | def cjke_cleaners2(text): 6 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 7 | lambda x: chinese_to_ipa(x.group(1))+' ', text) 8 | text = re.sub(r'\[JA\](.*?)\[JA\]', 9 | lambda x: japanese_to_ipa2(x.group(1))+' ', text) 10 | text = re.sub(r'\[KO\](.*?)\[KO\]', 11 | lambda x: korean_to_ipa(x.group(1))+' ', text) 12 | text = re.sub(r'\[EN\](.*?)\[EN\]', 13 | lambda x: english_to_ipa2(x.group(1))+' ', text) 14 | text = re.sub(r'\s+$', '', text) 15 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 16 | return text -------------------------------------------------------------------------------- /nodes/openvoice/text/english.py: 
-------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | 18 | 19 | import re 20 | import inflect 21 | from unidecode import unidecode 22 | import eng_to_ipa as ipa 23 | _inflect = inflect.engine() 24 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 25 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 26 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 27 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 28 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 29 | _number_re = re.compile(r'[0-9]+') 30 | 31 | # List of (regular expression, replacement) pairs for abbreviations: 32 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 33 | ('mrs', 'misess'), 34 | ('mr', 'mister'), 35 | ('dr', 'doctor'), 36 | ('st', 'saint'), 37 | ('co', 'company'), 38 | ('jr', 'junior'), 39 | ('maj', 'major'), 40 | ('gen', 'general'), 41 | ('drs', 'doctors'), 42 | ('rev', 'reverend'), 43 | ('lt', 'lieutenant'), 44 | ('hon', 'honorable'), 45 | ('sgt', 'sergeant'), 46 | ('capt', 'captain'), 47 | ('esq', 'esquire'), 48 | ('ltd', 'limited'), 49 | ('col', 'colonel'), 50 | ('ft', 'fort'), 51 | ]] 52 | 53 | 54 | # List of (ipa, lazy ipa) pairs: 55 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 56 | ('r', 'ɹ'), 57 | ('æ', 'e'), 58 | ('ɑ', 'a'), 59 | ('ɔ', 'o'), 60 | ('ð', 'z'), 61 | ('θ', 's'), 62 | ('ɛ', 'e'), 63 | ('ɪ', 'i'), 64 | ('ʊ', 'u'), 65 | ('ʒ', 'ʥ'), 66 | ('ʤ', 'ʥ'), 67 | ('ˈ', '↓'), 68 | ]] 69 | 70 | # List of (ipa, lazy ipa2) pairs: 71 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 72 | ('r', 'ɹ'), 73 | ('ð', 'z'), 74 | ('θ', 's'), 75 | ('ʒ', 'ʑ'), 76 | ('ʤ', 'dʑ'), 77 | ('ˈ', '↓'), 78 | ]] 79 | 80 | # List of (ipa, ipa2) pairs 81 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 82 | ('r', 'ɹ'), 83 | ('ʤ', 'dʒ'), 84 | ('ʧ', 'tʃ') 85 | ]] 86 | 87 | 88 | def expand_abbreviations(text): 89 | for regex, replacement in _abbreviations: 90 | text = re.sub(regex, replacement, text) 91 | return text 92 | 93 | 94 | def collapse_whitespace(text): 95 | return re.sub(r'\s+', ' ', text) 96 | 97 | 98 | def _remove_commas(m): 99 | return m.group(1).replace(',', '') 100 | 101 | 102 | def _expand_decimal_point(m): 103 | return m.group(1).replace('.', ' point ') 104 | 105 | 106 | def _expand_dollars(m): 107 | match = m.group(1) 108 | parts = match.split('.') 109 | if len(parts) > 2: 110 | return match + ' dollars' # Unexpected format 111 | dollars = int(parts[0]) if parts[0] else 0 112 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 113 | if dollars and cents: 114 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 115 | cent_unit = 'cent' if cents == 1 else 'cents' 116 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 117 | 
elif dollars: 118 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 119 | return '%s %s' % (dollars, dollar_unit) 120 | elif cents: 121 | cent_unit = 'cent' if cents == 1 else 'cents' 122 | return '%s %s' % (cents, cent_unit) 123 | else: 124 | return 'zero dollars' 125 | 126 | 127 | def _expand_ordinal(m): 128 | return _inflect.number_to_words(m.group(0)) 129 | 130 | 131 | def _expand_number(m): 132 | num = int(m.group(0)) 133 | if num > 1000 and num < 3000: 134 | if num == 2000: 135 | return 'two thousand' 136 | elif num > 2000 and num < 2010: 137 | return 'two thousand ' + _inflect.number_to_words(num % 100) 138 | elif num % 100 == 0: 139 | return _inflect.number_to_words(num // 100) + ' hundred' 140 | else: 141 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 142 | else: 143 | return _inflect.number_to_words(num, andword='') 144 | 145 | 146 | def normalize_numbers(text): 147 | text = re.sub(_comma_number_re, _remove_commas, text) 148 | text = re.sub(_pounds_re, r'\1 pounds', text) 149 | text = re.sub(_dollars_re, _expand_dollars, text) 150 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 151 | text = re.sub(_ordinal_re, _expand_ordinal, text) 152 | text = re.sub(_number_re, _expand_number, text) 153 | return text 154 | 155 | 156 | def mark_dark_l(text): 157 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text) 158 | 159 | 160 | def english_to_ipa(text): 161 | text = unidecode(text).lower() 162 | text = expand_abbreviations(text) 163 | text = normalize_numbers(text) 164 | phonemes = ipa.convert(text) 165 | phonemes = collapse_whitespace(phonemes) 166 | return phonemes 167 | 168 | 169 | def english_to_lazy_ipa(text): 170 | text = english_to_ipa(text) 171 | for regex, replacement in _lazy_ipa: 172 | text = re.sub(regex, replacement, text) 173 | return text 174 | 175 | 176 | def english_to_ipa2(text): 177 | text = english_to_ipa(text) 178 | text = mark_dark_l(text) 179 | for regex, replacement in _ipa_to_ipa2: 180 | text = re.sub(regex, replacement, text) 181 | return text.replace('...', '…') 182 | 183 | 184 | def english_to_lazy_ipa2(text): 185 | text = english_to_ipa(text) 186 | for regex, replacement in _lazy_ipa2: 187 | text = re.sub(regex, replacement, text) 188 | return text 189 | -------------------------------------------------------------------------------- /nodes/openvoice/text/mandarin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | # from pypinyin import lazy_pinyin, BOPOMOFO 5 | import jieba 6 | import cn2an 7 | import logging 8 | 9 | 10 | # List of (Latin alphabet, bopomofo) pairs: 11 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 12 | ('a', 'ㄟˉ'), 13 | ('b', 'ㄅㄧˋ'), 14 | ('c', 'ㄙㄧˉ'), 15 | ('d', 'ㄉㄧˋ'), 16 | ('e', 'ㄧˋ'), 17 | ('f', 'ㄝˊㄈㄨˋ'), 18 | ('g', 'ㄐㄧˋ'), 19 | ('h', 'ㄝˇㄑㄩˋ'), 20 | ('i', 'ㄞˋ'), 21 | ('j', 'ㄐㄟˋ'), 22 | ('k', 'ㄎㄟˋ'), 23 | ('l', 'ㄝˊㄛˋ'), 24 | ('m', 'ㄝˊㄇㄨˋ'), 25 | ('n', 'ㄣˉ'), 26 | ('o', 'ㄡˉ'), 27 | ('p', 'ㄆㄧˉ'), 28 | ('q', 'ㄎㄧㄡˉ'), 29 | ('r', 'ㄚˋ'), 30 | ('s', 'ㄝˊㄙˋ'), 31 | ('t', 'ㄊㄧˋ'), 32 | ('u', 'ㄧㄡˉ'), 33 | ('v', 'ㄨㄧˉ'), 34 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), 35 | ('x', 'ㄝˉㄎㄨˋㄙˋ'), 36 | ('y', 'ㄨㄞˋ'), 37 | ('z', 'ㄗㄟˋ') 38 | ]] 39 | 40 | # List of (bopomofo, romaji) pairs: 41 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [ 42 | ('ㄅㄛ', 'p⁼wo'), 43 | ('ㄆㄛ', 'pʰwo'), 44 | ('ㄇㄛ', 'mwo'), 45 | ('ㄈㄛ', 'fwo'), 46 | ('ㄅ', 'p⁼'), 47 | ('ㄆ', 'pʰ'), 48 | 
('ㄇ', 'm'), 49 | ('ㄈ', 'f'), 50 | ('ㄉ', 't⁼'), 51 | ('ㄊ', 'tʰ'), 52 | ('ㄋ', 'n'), 53 | ('ㄌ', 'l'), 54 | ('ㄍ', 'k⁼'), 55 | ('ㄎ', 'kʰ'), 56 | ('ㄏ', 'h'), 57 | ('ㄐ', 'ʧ⁼'), 58 | ('ㄑ', 'ʧʰ'), 59 | ('ㄒ', 'ʃ'), 60 | ('ㄓ', 'ʦ`⁼'), 61 | ('ㄔ', 'ʦ`ʰ'), 62 | ('ㄕ', 's`'), 63 | ('ㄖ', 'ɹ`'), 64 | ('ㄗ', 'ʦ⁼'), 65 | ('ㄘ', 'ʦʰ'), 66 | ('ㄙ', 's'), 67 | ('ㄚ', 'a'), 68 | ('ㄛ', 'o'), 69 | ('ㄜ', 'ə'), 70 | ('ㄝ', 'e'), 71 | ('ㄞ', 'ai'), 72 | ('ㄟ', 'ei'), 73 | ('ㄠ', 'au'), 74 | ('ㄡ', 'ou'), 75 | ('ㄧㄢ', 'yeNN'), 76 | ('ㄢ', 'aNN'), 77 | ('ㄧㄣ', 'iNN'), 78 | ('ㄣ', 'əNN'), 79 | ('ㄤ', 'aNg'), 80 | ('ㄧㄥ', 'iNg'), 81 | ('ㄨㄥ', 'uNg'), 82 | ('ㄩㄥ', 'yuNg'), 83 | ('ㄥ', 'əNg'), 84 | ('ㄦ', 'əɻ'), 85 | ('ㄧ', 'i'), 86 | ('ㄨ', 'u'), 87 | ('ㄩ', 'ɥ'), 88 | ('ˉ', '→'), 89 | ('ˊ', '↑'), 90 | ('ˇ', '↓↑'), 91 | ('ˋ', '↓'), 92 | ('˙', ''), 93 | (',', ','), 94 | ('。', '.'), 95 | ('!', '!'), 96 | ('?', '?'), 97 | ('—', '-') 98 | ]] 99 | 100 | # List of (romaji, ipa) pairs: 101 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 102 | ('ʃy', 'ʃ'), 103 | ('ʧʰy', 'ʧʰ'), 104 | ('ʧ⁼y', 'ʧ⁼'), 105 | ('NN', 'n'), 106 | ('Ng', 'ŋ'), 107 | ('y', 'j'), 108 | ('h', 'x') 109 | ]] 110 | 111 | # List of (bopomofo, ipa) pairs: 112 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 113 | ('ㄅㄛ', 'p⁼wo'), 114 | ('ㄆㄛ', 'pʰwo'), 115 | ('ㄇㄛ', 'mwo'), 116 | ('ㄈㄛ', 'fwo'), 117 | ('ㄅ', 'p⁼'), 118 | ('ㄆ', 'pʰ'), 119 | ('ㄇ', 'm'), 120 | ('ㄈ', 'f'), 121 | ('ㄉ', 't⁼'), 122 | ('ㄊ', 'tʰ'), 123 | ('ㄋ', 'n'), 124 | ('ㄌ', 'l'), 125 | ('ㄍ', 'k⁼'), 126 | ('ㄎ', 'kʰ'), 127 | ('ㄏ', 'x'), 128 | ('ㄐ', 'tʃ⁼'), 129 | ('ㄑ', 'tʃʰ'), 130 | ('ㄒ', 'ʃ'), 131 | ('ㄓ', 'ts`⁼'), 132 | ('ㄔ', 'ts`ʰ'), 133 | ('ㄕ', 's`'), 134 | ('ㄖ', 'ɹ`'), 135 | ('ㄗ', 'ts⁼'), 136 | ('ㄘ', 'tsʰ'), 137 | ('ㄙ', 's'), 138 | ('ㄚ', 'a'), 139 | ('ㄛ', 'o'), 140 | ('ㄜ', 'ə'), 141 | ('ㄝ', 'ɛ'), 142 | ('ㄞ', 'aɪ'), 143 | ('ㄟ', 'eɪ'), 144 | ('ㄠ', 'ɑʊ'), 145 | ('ㄡ', 'oʊ'), 146 | ('ㄧㄢ', 'jɛn'), 147 | ('ㄩㄢ', 'ɥæn'), 148 | ('ㄢ', 'an'), 149 | ('ㄧㄣ', 'in'), 150 | ('ㄩㄣ', 'ɥn'), 151 | ('ㄣ', 'ən'), 152 | ('ㄤ', 'ɑŋ'), 153 | ('ㄧㄥ', 'iŋ'), 154 | ('ㄨㄥ', 'ʊŋ'), 155 | ('ㄩㄥ', 'jʊŋ'), 156 | ('ㄥ', 'əŋ'), 157 | ('ㄦ', 'əɻ'), 158 | ('ㄧ', 'i'), 159 | ('ㄨ', 'u'), 160 | ('ㄩ', 'ɥ'), 161 | ('ˉ', '→'), 162 | ('ˊ', '↑'), 163 | ('ˇ', '↓↑'), 164 | ('ˋ', '↓'), 165 | ('˙', ''), 166 | (',', ','), 167 | ('。', '.'), 168 | ('!', '!'), 169 | ('?', '?'), 170 | ('—', '-') 171 | ]] 172 | 173 | # List of (bopomofo, ipa2) pairs: 174 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 175 | ('ㄅㄛ', 'pwo'), 176 | ('ㄆㄛ', 'pʰwo'), 177 | ('ㄇㄛ', 'mwo'), 178 | ('ㄈㄛ', 'fwo'), 179 | ('ㄅ', 'p'), 180 | ('ㄆ', 'pʰ'), 181 | ('ㄇ', 'm'), 182 | ('ㄈ', 'f'), 183 | ('ㄉ', 't'), 184 | ('ㄊ', 'tʰ'), 185 | ('ㄋ', 'n'), 186 | ('ㄌ', 'l'), 187 | ('ㄍ', 'k'), 188 | ('ㄎ', 'kʰ'), 189 | ('ㄏ', 'h'), 190 | ('ㄐ', 'tɕ'), 191 | ('ㄑ', 'tɕʰ'), 192 | ('ㄒ', 'ɕ'), 193 | ('ㄓ', 'tʂ'), 194 | ('ㄔ', 'tʂʰ'), 195 | ('ㄕ', 'ʂ'), 196 | ('ㄖ', 'ɻ'), 197 | ('ㄗ', 'ts'), 198 | ('ㄘ', 'tsʰ'), 199 | ('ㄙ', 's'), 200 | ('ㄚ', 'a'), 201 | ('ㄛ', 'o'), 202 | ('ㄜ', 'ɤ'), 203 | ('ㄝ', 'ɛ'), 204 | ('ㄞ', 'aɪ'), 205 | ('ㄟ', 'eɪ'), 206 | ('ㄠ', 'ɑʊ'), 207 | ('ㄡ', 'oʊ'), 208 | ('ㄧㄢ', 'jɛn'), 209 | ('ㄩㄢ', 'yæn'), 210 | ('ㄢ', 'an'), 211 | ('ㄧㄣ', 'in'), 212 | ('ㄩㄣ', 'yn'), 213 | ('ㄣ', 'ən'), 214 | ('ㄤ', 'ɑŋ'), 215 | ('ㄧㄥ', 'iŋ'), 216 | ('ㄨㄥ', 'ʊŋ'), 217 | ('ㄩㄥ', 'jʊŋ'), 218 | ('ㄥ', 'ɤŋ'), 219 | ('ㄦ', 'əɻ'), 220 | ('ㄧ', 'i'), 221 | ('ㄨ', 'u'), 222 | ('ㄩ', 'y'), 223 | ('ˉ', '˥'), 224 | ('ˊ', '˧˥'), 225 | ('ˇ', '˨˩˦'), 226 | ('ˋ', '˥˩'), 227 | ('˙', ''), 228 | (',', ','), 229 | ('。', '.'), 230 
| ('!', '!'), 231 | ('?', '?'), 232 | ('—', '-') 233 | ]] 234 | 235 | 236 | def number_to_chinese(text): 237 | numbers = re.findall(r'\d+(?:\.?\d+)?', text) 238 | for number in numbers: 239 | text = text.replace(number, cn2an.an2cn(number), 1) 240 | return text 241 | 242 | 243 | def chinese_to_bopomofo(text): 244 | text = text.replace('、', ',').replace(';', ',').replace(':', ',') 245 | # words = jieba.lcut(text, cut_all=False) 246 | # text = '' 247 | # for word in words: 248 | # bopomofos = lazy_pinyin(word, BOPOMOFO) 249 | # if not re.search('[\u4e00-\u9fff]', word): 250 | # text += word 251 | # continue 252 | # for i in range(len(bopomofos)): 253 | # bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i]) 254 | # if text != '': 255 | # text += ' ' 256 | # text += ''.join(bopomofos) 257 | return text 258 | 259 | 260 | def latin_to_bopomofo(text): 261 | for regex, replacement in _latin_to_bopomofo: 262 | text = re.sub(regex, replacement, text) 263 | return text 264 | 265 | 266 | def bopomofo_to_romaji(text): 267 | for regex, replacement in _bopomofo_to_romaji: 268 | text = re.sub(regex, replacement, text) 269 | return text 270 | 271 | 272 | def bopomofo_to_ipa(text): 273 | for regex, replacement in _bopomofo_to_ipa: 274 | text = re.sub(regex, replacement, text) 275 | return text 276 | 277 | 278 | def bopomofo_to_ipa2(text): 279 | for regex, replacement in _bopomofo_to_ipa2: 280 | text = re.sub(regex, replacement, text) 281 | return text 282 | 283 | 284 | def chinese_to_romaji(text): 285 | text = number_to_chinese(text) 286 | text = chinese_to_bopomofo(text) 287 | text = latin_to_bopomofo(text) 288 | text = bopomofo_to_romaji(text) 289 | text = re.sub('i([aoe])', r'y\1', text) 290 | text = re.sub('u([aoəe])', r'w\1', text) 291 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 292 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 293 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 294 | return text 295 | 296 | 297 | def chinese_to_lazy_ipa(text): 298 | text = chinese_to_romaji(text) 299 | for regex, replacement in _romaji_to_ipa: 300 | text = re.sub(regex, replacement, text) 301 | return text 302 | 303 | 304 | def chinese_to_ipa(text): 305 | text = number_to_chinese(text) 306 | text = chinese_to_bopomofo(text) 307 | text = latin_to_bopomofo(text) 308 | text = bopomofo_to_ipa(text) 309 | text = re.sub('i([aoe])', r'j\1', text) 310 | text = re.sub('u([aoəe])', r'w\1', text) 311 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 312 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 313 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 314 | return text 315 | 316 | 317 | def chinese_to_ipa2(text): 318 | text = number_to_chinese(text) 319 | text = chinese_to_bopomofo(text) 320 | text = latin_to_bopomofo(text) 321 | text = bopomofo_to_ipa2(text) 322 | text = re.sub(r'i([aoe])', r'j\1', text) 323 | text = re.sub(r'u([aoəe])', r'w\1', text) 324 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text) 325 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text) 326 | return text 327 | -------------------------------------------------------------------------------- /nodes/openvoice/text/symbols.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Defines the set of symbols used in text input to the model. 
3 | ''' 4 | 5 | # japanese_cleaners 6 | # _pad = '_' 7 | # _punctuation = ',.!?-' 8 | # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ ' 9 | 10 | 11 | '''# japanese_cleaners2 12 | _pad = '_' 13 | _punctuation = ',.!?-~…' 14 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ ' 15 | ''' 16 | 17 | 18 | '''# korean_cleaners 19 | _pad = '_' 20 | _punctuation = ',.!?…~' 21 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ ' 22 | ''' 23 | 24 | '''# chinese_cleaners 25 | _pad = '_' 26 | _punctuation = ',。!?—…' 27 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ ' 28 | ''' 29 | 30 | # # zh_ja_mixture_cleaners 31 | # _pad = '_' 32 | # _punctuation = ',.!?-~…' 33 | # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ ' 34 | 35 | 36 | '''# sanskrit_cleaners 37 | _pad = '_' 38 | _punctuation = '।' 39 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ ' 40 | ''' 41 | 42 | '''# cjks_cleaners 43 | _pad = '_' 44 | _punctuation = ',.!?-~…' 45 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ ' 46 | ''' 47 | 48 | '''# thai_cleaners 49 | _pad = '_' 50 | _punctuation = '.!? ' 51 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์' 52 | ''' 53 | 54 | # # cjke_cleaners2 55 | _pad = '_' 56 | _punctuation = ',.!?-~…' 57 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ' 58 | 59 | 60 | '''# shanghainese_cleaners 61 | _pad = '_' 62 | _punctuation = ',.!?…' 63 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 ' 64 | ''' 65 | 66 | '''# chinese_dialect_cleaners 67 | _pad = '_' 68 | _punctuation = ',.!?~…─' 69 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ ' 70 | ''' 71 | 72 | # Export all symbols: 73 | symbols = [_pad] + list(_punctuation) + list(_letters) 74 | 75 | # Special symbol ids 76 | SPACE_ID = symbols.index(" ") 77 | 78 | num_ja_tones = 1 79 | num_kr_tones = 1 80 | num_zh_tones = 6 81 | num_en_tones = 4 82 | 83 | language_tone_start_map = { 84 | "ZH": 0, 85 | "JP": num_zh_tones, 86 | "EN": num_zh_tones + num_ja_tones, 87 | 'KR': num_zh_tones + num_ja_tones + num_en_tones, 88 | } -------------------------------------------------------------------------------- /nodes/openvoice/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform( 13 | inputs, 14 | unnormalized_widths, 15 | unnormalized_heights, 16 | unnormalized_derivatives, 17 | inverse=False, 18 | tails=None, 19 | tail_bound=1.0, 20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 21 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 22 | min_derivative=DEFAULT_MIN_DERIVATIVE, 23 | ): 24 | if tails is None: 25 | spline_fn = rational_quadratic_spline 26 | spline_kwargs = {} 27 | else: 28 | spline_fn = unconstrained_rational_quadratic_spline 29 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound} 30 | 31 | outputs, logabsdet = spline_fn( 32 | inputs=inputs, 33 | unnormalized_widths=unnormalized_widths, 34 | unnormalized_heights=unnormalized_heights, 35 | unnormalized_derivatives=unnormalized_derivatives, 36 | inverse=inverse, 37 | min_bin_width=min_bin_width, 38 | min_bin_height=min_bin_height, 39 | min_derivative=min_derivative, 40 | **spline_kwargs 41 | ) 42 | return outputs, logabsdet 43 | 44 | 45 | def 
searchsorted(bin_locations, inputs, eps=1e-6): 46 | bin_locations[..., -1] += eps 47 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 48 | 49 | 50 | def unconstrained_rational_quadratic_spline( 51 | inputs, 52 | unnormalized_widths, 53 | unnormalized_heights, 54 | unnormalized_derivatives, 55 | inverse=False, 56 | tails="linear", 57 | tail_bound=1.0, 58 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 59 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 60 | min_derivative=DEFAULT_MIN_DERIVATIVE, 61 | ): 62 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 63 | outside_interval_mask = ~inside_interval_mask 64 | 65 | outputs = torch.zeros_like(inputs) 66 | logabsdet = torch.zeros_like(inputs) 67 | 68 | if tails == "linear": 69 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 70 | constant = np.log(np.exp(1 - min_derivative) - 1) 71 | unnormalized_derivatives[..., 0] = constant 72 | unnormalized_derivatives[..., -1] = constant 73 | 74 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 75 | logabsdet[outside_interval_mask] = 0 76 | else: 77 | raise RuntimeError("{} tails are not implemented.".format(tails)) 78 | 79 | ( 80 | outputs[inside_interval_mask], 81 | logabsdet[inside_interval_mask], 82 | ) = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, 89 | right=tail_bound, 90 | bottom=-tail_bound, 91 | top=tail_bound, 92 | min_bin_width=min_bin_width, 93 | min_bin_height=min_bin_height, 94 | min_derivative=min_derivative, 95 | ) 96 | 97 | return outputs, logabsdet 98 | 99 | 100 | def rational_quadratic_spline( 101 | inputs, 102 | unnormalized_widths, 103 | unnormalized_heights, 104 | unnormalized_derivatives, 105 | inverse=False, 106 | left=0.0, 107 | right=1.0, 108 | bottom=0.0, 109 | top=1.0, 110 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 111 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 112 | min_derivative=DEFAULT_MIN_DERIVATIVE, 113 | ): 114 | if torch.min(inputs) < left or torch.max(inputs) > right: 115 | raise ValueError("Input to a transform is not within its domain") 116 | 117 | num_bins = unnormalized_widths.shape[-1] 118 | 119 | if min_bin_width * num_bins > 1.0: 120 | raise ValueError("Minimal bin width too large for the number of bins") 121 | if min_bin_height * num_bins > 1.0: 122 | raise ValueError("Minimal bin height too large for the number of bins") 123 | 124 | widths = F.softmax(unnormalized_widths, dim=-1) 125 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 126 | cumwidths = torch.cumsum(widths, dim=-1) 127 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) 128 | cumwidths = (right - left) * cumwidths + left 129 | cumwidths[..., 0] = left 130 | cumwidths[..., -1] = right 131 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 132 | 133 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 134 | 135 | heights = F.softmax(unnormalized_heights, dim=-1) 136 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 137 | cumheights = torch.cumsum(heights, dim=-1) 138 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) 139 | cumheights = (top - bottom) * cumheights + bottom 140 | cumheights[..., 0] = bottom 141 | cumheights[..., -1] = top 142 | heights = 
cumheights[..., 1:] - cumheights[..., :-1] 143 | 144 | if inverse: 145 | bin_idx = searchsorted(cumheights, inputs)[..., None] 146 | else: 147 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 148 | 149 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 150 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 151 | 152 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 153 | delta = heights / widths 154 | input_delta = delta.gather(-1, bin_idx)[..., 0] 155 | 156 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 157 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 158 | 159 | input_heights = heights.gather(-1, bin_idx)[..., 0] 160 | 161 | if inverse: 162 | a = (inputs - input_cumheights) * ( 163 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 164 | ) + input_heights * (input_delta - input_derivatives) 165 | b = input_heights * input_derivatives - (inputs - input_cumheights) * ( 166 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 167 | ) 168 | c = -input_delta * (inputs - input_cumheights) 169 | 170 | discriminant = b.pow(2) - 4 * a * c 171 | assert (discriminant >= 0).all() 172 | 173 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 174 | outputs = root * input_bin_widths + input_cumwidths 175 | 176 | theta_one_minus_theta = root * (1 - root) 177 | denominator = input_delta + ( 178 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 179 | * theta_one_minus_theta 180 | ) 181 | derivative_numerator = input_delta.pow(2) * ( 182 | input_derivatives_plus_one * root.pow(2) 183 | + 2 * input_delta * theta_one_minus_theta 184 | + input_derivatives * (1 - root).pow(2) 185 | ) 186 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 187 | 188 | return outputs, -logabsdet 189 | else: 190 | theta = (inputs - input_cumwidths) / input_bin_widths 191 | theta_one_minus_theta = theta * (1 - theta) 192 | 193 | numerator = input_heights * ( 194 | input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta 195 | ) 196 | denominator = input_delta + ( 197 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 198 | * theta_one_minus_theta 199 | ) 200 | outputs = input_cumheights + numerator / denominator 201 | 202 | derivative_numerator = input_delta.pow(2) * ( 203 | input_derivatives_plus_one * theta.pow(2) 204 | + 2 * input_delta * theta_one_minus_theta 205 | + input_derivatives * (1 - theta).pow(2) 206 | ) 207 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 208 | 209 | return outputs, logabsdet 210 | -------------------------------------------------------------------------------- /nodes/openvoice/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import numpy as np 4 | 5 | 6 | def get_hparams_from_file(config_path): 7 | with open(config_path, "r", encoding="utf-8") as f: 8 | data = f.read() 9 | config = json.loads(data) 10 | 11 | hparams = HParams(**config) 12 | return hparams 13 | 14 | class HParams: 15 | def __init__(self, **kwargs): 16 | for k, v in kwargs.items(): 17 | if type(v) == dict: 18 | v = HParams(**v) 19 | self[k] = v 20 | 21 | def keys(self): 22 | return self.__dict__.keys() 23 | 24 | def items(self): 25 | return self.__dict__.items() 26 | 27 | def values(self): 28 | return self.__dict__.values() 29 | 30 | def __len__(self): 31 | return len(self.__dict__) 32 | 33 | def __getitem__(self, key): 34 | return getattr(self, key) 35 | 
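    # Illustrative usage (sketch): HParams wraps nested dicts recursively in
    # __init__ above, so HParams(**{"train": {"lr": 1e-4}}).train.lr == 1e-4,
    # and membership tests like "train" in hparams go through __contains__ below.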
36 | def __setitem__(self, key, value): 37 | return setattr(self, key, value) 38 | 39 | def __contains__(self, key): 40 | return key in self.__dict__ 41 | 42 | def __repr__(self): 43 | return self.__dict__.__repr__() 44 | 45 | 46 | def string_to_bits(string, pad_len=8): 47 | # Convert each character to its ASCII value 48 | ascii_values = [ord(char) for char in string] 49 | 50 | # Convert ASCII values to binary representation 51 | binary_values = [bin(value)[2:].zfill(8) for value in ascii_values] 52 | 53 | # Convert binary strings to integer arrays 54 | bit_arrays = [[int(bit) for bit in binary] for binary in binary_values] 55 | 56 | # Convert list of arrays to NumPy array 57 | numpy_array = np.array(bit_arrays) 58 | numpy_array_full = np.zeros((pad_len, 8), dtype=numpy_array.dtype) 59 | numpy_array_full[:, 2] = 1 60 | max_len = min(pad_len, len(numpy_array)) 61 | numpy_array_full[:max_len] = numpy_array[:max_len] 62 | return numpy_array_full 63 | 64 | 65 | def bits_to_string(bits_array): 66 | # Convert each row of the array to a binary string 67 | binary_values = [''.join(str(bit) for bit in row) for row in bits_array] 68 | 69 | # Convert binary strings to ASCII values 70 | ascii_values = [int(binary, 2) for binary in binary_values] 71 | 72 | # Convert ASCII values to characters 73 | output_string = ''.join(chr(value) for value in ascii_values) 74 | 75 | return output_string 76 | 77 | 78 | def split_sentence(text, min_len=10, language_str='[EN]'): 79 | if language_str in ['EN']: 80 | sentences = split_sentences_latin(text, min_len=min_len) 81 | else: 82 | sentences = split_sentences_zh(text, min_len=min_len) 83 | return sentences 84 | 85 | def split_sentences_latin(text, min_len=10): 86 | """Split Long sentences into list of short ones 87 | 88 | Args: 89 | str: Input sentences. 90 | 91 | Returns: 92 | List[str]: list of output sentences. 93 | """ 94 | # deal with dirty sentences 95 | text = re.sub('[。!?;]', '.', text) 96 | text = re.sub('[,]', ',', text) 97 | text = re.sub('[“”]', '"', text) 98 | text = re.sub('[‘’]', "'", text) 99 | text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) 100 | text = re.sub('[\n\t ]+', ' ', text) 101 | text = re.sub('([,.!?;])', r'\1 $#!', text) 102 | # split 103 | sentences = [s.strip() for s in text.split('$#!')] 104 | if len(sentences[-1]) == 0: del sentences[-1] 105 | 106 | new_sentences = [] 107 | new_sent = [] 108 | count_len = 0 109 | for ind, sent in enumerate(sentences): 110 | # print(sent) 111 | new_sent.append(sent) 112 | count_len += len(sent.split(" ")) 113 | if count_len > min_len or ind == len(sentences) - 1: 114 | count_len = 0 115 | new_sentences.append(' '.join(new_sent)) 116 | new_sent = [] 117 | return merge_short_sentences_latin(new_sentences) 118 | 119 | 120 | def merge_short_sentences_latin(sens): 121 | """Avoid short sentences by merging them with the following sentence. 122 | 123 | Args: 124 | List[str]: list of input sentences. 125 | 126 | Returns: 127 | List[str]: list of output sentences. 128 | """ 129 | sens_out = [] 130 | for s in sens: 131 | # If the previous sentence is too short, merge them with 132 | # the current sentence. 
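        # Illustrative example (sketch): ["Hi.", "How are you doing today my friend?"]
        # becomes ["Hi. How are you doing today my friend?"], since "Hi." splits into
        # fewer than three words and is merged into the sentence that follows.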
133 | if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2: 134 | sens_out[-1] = sens_out[-1] + " " + s 135 | else: 136 | sens_out.append(s) 137 | try: 138 | if len(sens_out[-1].split(" ")) <= 2: 139 | sens_out[-2] = sens_out[-2] + " " + sens_out[-1] 140 | sens_out.pop(-1) 141 | except: 142 | pass 143 | return sens_out 144 | 145 | def split_sentences_zh(text, min_len=10): 146 | text = re.sub('[。!?;]', '.', text) 147 | text = re.sub('[,]', ',', text) 148 | # 将文本中的换行符、空格和制表符替换为空格 149 | text = re.sub('[\n\t ]+', ' ', text) 150 | # 在标点符号后添加一个空格 151 | text = re.sub('([,.!?;])', r'\1 $#!', text) 152 | # 分隔句子并去除前后空格 153 | # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)] 154 | sentences = [s.strip() for s in text.split('$#!')] 155 | if len(sentences[-1]) == 0: del sentences[-1] 156 | 157 | new_sentences = [] 158 | new_sent = [] 159 | count_len = 0 160 | for ind, sent in enumerate(sentences): 161 | new_sent.append(sent) 162 | count_len += len(sent) 163 | if count_len > min_len or ind == len(sentences) - 1: 164 | count_len = 0 165 | new_sentences.append(' '.join(new_sent)) 166 | new_sent = [] 167 | return merge_short_sentences_zh(new_sentences) 168 | 169 | 170 | def merge_short_sentences_zh(sens): 171 | # return sens 172 | """Avoid short sentences by merging them with the following sentence. 173 | 174 | Args: 175 | List[str]: list of input sentences. 176 | 177 | Returns: 178 | List[str]: list of output sentences. 179 | """ 180 | sens_out = [] 181 | for s in sens: 182 | # If the previous sentense is too short, merge them with 183 | # the current sentence. 184 | if len(sens_out) > 0 and len(sens_out[-1]) <= 2: 185 | sens_out[-1] = sens_out[-1] + " " + s 186 | else: 187 | sens_out.append(s) 188 | try: 189 | if len(sens_out[-1]) <= 2: 190 | sens_out[-2] = sens_out[-2] + " " + sens_out[-1] 191 | sens_out.pop(-1) 192 | except: 193 | pass 194 | return sens_out -------------------------------------------------------------------------------- /nodes/openvoice_run.py: -------------------------------------------------------------------------------- 1 | from openvoice import se_extractor 2 | from openvoice.api import ToneColorConverter 3 | 4 | import comfy.model_management as mm 5 | 6 | import folder_paths 7 | 8 | import os,torch 9 | 10 | # 修改模型的本地缓存地址 11 | # os.environ['HF_HOME'] = os.path.join(folder_paths.models_dir,'chat_tts') 12 | 13 | def get_model_dir(m): 14 | try: 15 | return folder_paths.get_folder_paths(m)[0] 16 | except: 17 | return os.path.join(folder_paths.models_dir, m) 18 | 19 | ckpt_converter=get_model_dir('open_voice') 20 | 21 | 22 | # device = mm.get_torch_device() 23 | device="cuda:0" if torch.cuda.is_available() else "cpu" 24 | 25 | tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device) 26 | tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth') 27 | 28 | def run(reference_speaker="",src_path="",save_path="",whisper=None): 29 | 30 | if reference_speaker != "" and src_path!="": 31 | # Run the base speaker tts 32 | print("Ready for voice cloning!") 33 | 34 | temp=folder_paths.get_temp_directory() 35 | target_dir=os.path.join(temp,'processed') 36 | 37 | source_se, audio_name = se_extractor.get_se(src_path, tone_color_converter, target_dir=target_dir,whisper_model=whisper) 38 | 39 | target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir=target_dir,whisper_model= whisper) 40 | 41 | # Run the tone color converter 42 | # convert from file 43 | tone_color_converter.convert( 44 | 
audio_src_path=src_path, 45 | src_se=source_se, 46 | tgt_se=target_se, 47 | output_path=save_path) 48 | -------------------------------------------------------------------------------- /nodes/zh_normalization/README.md: -------------------------------------------------------------------------------- 1 | ## Supported NSW (Non-Standard-Word) Normalization 2 | 3 | |NSW type|raw|normalized| 4 | |:--|:-|:-| 5 | |serial number|电影中梁朝伟扮演的陈永仁的编号27149|电影中梁朝伟扮演的陈永仁的编号二七一四九| 6 | |cardinal|这块黄金重达324.75克
我们班的最高总分为583分|这块黄金重达三百二十四点七五克<br>我们班的最高总分为五百八十三分| 7 | |numeric range |12\~23<br>-1.5\~2|十二到二十三<br>负一点五到二| 8 | |date|她出生于86年8月18日,她弟弟出生于1995年3月1日|她出生于八六年八月十八日, 她弟弟出生于一九九五年三月一日| 9 | |time|等会请在12:05请通知我|等会请在十二点零五分请通知我 10 | |temperature|今天的最低气温达到-10°C|今天的最低气温达到零下十度 11 | |fraction|现场有7/12的观众投出了赞成票|现场有十二分之七的观众投出了赞成票| 12 | |percentage|明天有62%的概率降雨|明天有百分之六十二的概率降雨| 13 | |money|随便来几个价格12块5,34.5元,20.1万|随便来几个价格十二块五,三十四点五元,二十点一万| 14 | |telephone|这是固话0421-33441122<br>这是手机+86 18544139121|这是固话零四二一三三四四一一二二
这是手机八六一八五四四一三九一二一| 15 | ## References 16 | [Pull requests #658 of DeepSpeech](https://github.com/PaddlePaddle/DeepSpeech/pull/658/files) 17 | -------------------------------------------------------------------------------- /nodes/zh_normalization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from .text_normlization import * 15 | -------------------------------------------------------------------------------- /nodes/zh_normalization/chronology.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import re 15 | 16 | from .num import DIGITS 17 | from .num import num2str 18 | from .num import verbalize_cardinal 19 | from .num import verbalize_digit 20 | 21 | 22 | def _time_num2str(num_string: str) -> str: 23 | """A special case for verbalizing number in time.""" 24 | result = num2str(num_string.lstrip('0')) 25 | if num_string.startswith('0'): 26 | result = DIGITS['0'] + result 27 | return result 28 | 29 | 30 | # 时刻表达式 31 | RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])' 32 | r':([0-9][0-9]?)' 33 | r'(:([0-9][0-9]?))?') 34 | 35 | # 时间范围,如8:30-12:30 36 | RE_TIME_RANGE = re.compile(r'([0-1]?[0-9]|2[0-3])' 37 | r':([0-9][0-9]?)' 38 | r'(:([0-9][0-9]?))?' 
39 | r'(~|-)' 40 | r'([0-1]?[0-9]|2[0-3])' 41 | r':([0-9][0-9]?)' 42 | r'(:([0-9][0-9]?))?') 43 | 44 | 45 | def replace_time(match) -> str: 46 | """ 47 | Args: 48 | match (re.Match) 49 | Returns: 50 | str 51 | """ 52 | 53 | is_range = len(match.groups()) > 5 54 | 55 | hour = match.group(1) 56 | minute = match.group(2) 57 | second = match.group(4) 58 | 59 | if is_range: 60 | hour_2 = match.group(6) 61 | minute_2 = match.group(7) 62 | second_2 = match.group(9) 63 | 64 | result = f"{num2str(hour)}点" 65 | if minute.lstrip('0'): 66 | if int(minute) == 30: 67 | result += "半" 68 | else: 69 | result += f"{_time_num2str(minute)}分" 70 | if second and second.lstrip('0'): 71 | result += f"{_time_num2str(second)}秒" 72 | 73 | if is_range: 74 | result += "至" 75 | result += f"{num2str(hour_2)}点" 76 | if minute_2.lstrip('0'): 77 | if int(minute_2) == 30: 78 | result += "半" 79 | else: 80 | result += f"{_time_num2str(minute_2)}分" 81 | if second_2 and second_2.lstrip('0'): 82 | result += f"{_time_num2str(second_2)}秒" 83 | 84 | return result 85 | 86 | 87 | RE_DATE = re.compile(r'(\d{4}|\d{2})年' 88 | r'((0?[1-9]|1[0-2])月)?' 89 | r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?') 90 | 91 | 92 | def replace_date(match) -> str: 93 | """ 94 | Args: 95 | match (re.Match) 96 | Returns: 97 | str 98 | """ 99 | year = match.group(1) 100 | month = match.group(3) 101 | day = match.group(5) 102 | result = "" 103 | if year: 104 | result += f"{verbalize_digit(year)}年" 105 | if month: 106 | result += f"{verbalize_cardinal(month)}月" 107 | if day: 108 | result += f"{verbalize_cardinal(day)}{match.group(9)}" 109 | return result 110 | 111 | 112 | # 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期 113 | RE_DATE2 = re.compile( 114 | r'(\d{4})([- /.])(0?[1-9]|1[012])\2(0?[1-9]|[12][0-9]|3[01])') 115 | 116 | 117 | def replace_date2(match) -> str: 118 | """ 119 | Args: 120 | match (re.Match) 121 | Returns: 122 | str 123 | """ 124 | year = match.group(1) 125 | month = match.group(3) 126 | day = match.group(4) 127 | result = "" 128 | if year: 129 | result += f"{verbalize_digit(year)}年" 130 | if month: 131 | result += f"{verbalize_cardinal(month)}月" 132 | if day: 133 | result += f"{verbalize_cardinal(day)}日" 134 | return result 135 | -------------------------------------------------------------------------------- /nodes/zh_normalization/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
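# Illustrative usage (sketch): the F2H_* / H2F_* tables defined below are
# code-point mappings intended for str.translate, e.g.
#   "ABC123".translate(F2H_ASCII_LETTERS).translate(F2H_DIGITS) == "ABC123"
# converts full-width ASCII letters and digits to their half-width forms.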
14 | import re 15 | import string 16 | 17 | #from pypinyin.constants import SUPPORT_UCS4 18 | 19 | # 全角半角转换 20 | # 英文字符全角 -> 半角映射表 (num: 52) 21 | F2H_ASCII_LETTERS = { 22 | ord(char) + 65248: ord(char) 23 | for char in string.ascii_letters 24 | } 25 | 26 | # 英文字符半角 -> 全角映射表 27 | H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()} 28 | 29 | # 数字字符全角 -> 半角映射表 (num: 10) 30 | F2H_DIGITS = {ord(char) + 65248: ord(char) for char in string.digits} 31 | # 数字字符半角 -> 全角映射表 32 | H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()} 33 | 34 | # 标点符号全角 -> 半角映射表 (num: 32) 35 | F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation} 36 | # 标点符号半角 -> 全角映射表 37 | H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()} 38 | 39 | # 空格 (num: 1) 40 | F2H_SPACE = {'\u3000': ' '} 41 | H2F_SPACE = {' ': '\u3000'} 42 | 43 | # 非"有拼音的汉字"的字符串,可用于NSW提取 44 | ''' 45 | if SUPPORT_UCS4: 46 | RE_NSW = re.compile(r'(?:[^' 47 | r'\u3007' # 〇 48 | r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] 49 | r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] 50 | r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] 51 | r'\U00020000-\U0002A6DF' # CJK扩展B:[20000-2A6DF] 52 | r'\U0002A703-\U0002B73F' # CJK扩展C:[2A700-2B73F] 53 | r'\U0002B740-\U0002B81D' # CJK扩展D:[2B740-2B81D] 54 | r'\U0002F80A-\U0002FA1F' # CJK兼容扩展:[2F800-2FA1F] 55 | r'])+') 56 | else: 57 | ''' 58 | RE_NSW = re.compile( # pragma: no cover 59 | r'(?:[^' 60 | r'\u3007' # 〇 61 | r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] 62 | r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] 63 | r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] 64 | r'])+') 65 | -------------------------------------------------------------------------------- /nodes/zh_normalization/num.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Rules to verbalize numbers into Chinese characters. 
16 | https://zh.wikipedia.org/wiki/中文数字#現代中文 17 | """ 18 | import re 19 | from collections import OrderedDict 20 | from typing import List 21 | 22 | DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')} 23 | UNITS = OrderedDict({ 24 | 1: '十', 25 | 2: '百', 26 | 3: '千', 27 | 4: '万', 28 | 8: '亿', 29 | }) 30 | 31 | COM_QUANTIFIERS = '(封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)' 32 | 33 | # 分数表达式 34 | RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)') 35 | 36 | 37 | def replace_frac(match) -> str: 38 | """ 39 | Args: 40 | match (re.Match) 41 | Returns: 42 | str 43 | """ 44 | sign = match.group(1) 45 | nominator = match.group(2) 46 | denominator = match.group(3) 47 | sign: str = "负" if sign else "" 48 | nominator: str = num2str(nominator) 49 | denominator: str = num2str(denominator) 50 | result = f"{sign}{denominator}分之{nominator}" 51 | return result 52 | 53 | 54 | # 百分数表达式 55 | RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%') 56 | 57 | 58 | def replace_percentage(match) -> str: 59 | """ 60 | Args: 61 | match (re.Match) 62 | Returns: 63 | str 64 | """ 65 | sign = match.group(1) 66 | percent = match.group(2) 67 | sign: str = "负" if sign else "" 68 | percent: str = num2str(percent) 69 | result = f"{sign}百分之{percent}" 70 | return result 71 | 72 | 73 | # 整数表达式 74 | # 带负号的整数 -10 75 | RE_INTEGER = re.compile(r'(-)' r'(\d+)') 76 | 77 | 78 | def replace_negative_num(match) -> str: 79 | """ 80 | Args: 81 | match (re.Match) 82 | Returns: 83 | str 84 | """ 85 | sign = match.group(1) 86 | number = match.group(2) 87 | sign: str = "负" if sign else "" 88 | number: str = num2str(number) 89 | result = f"{sign}{number}" 90 | return result 91 | 92 | 93 | # 编号-无符号整形 94 | # 00078 95 | RE_DEFAULT_NUM = re.compile(r'\d{3}\d*') 96 | 97 | 98 | def replace_default_num(match): 99 | """ 100 | Args: 101 | match (re.Match) 102 | Returns: 103 | str 104 | """ 105 | number = match.group(0) 106 | return verbalize_digit(number, alt_one=False) 107 | 108 | 109 | # 数字表达式 110 | # 纯小数 111 | RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))') 112 | # 正整数 + 量词 113 | RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" 
+ COM_QUANTIFIERS) 114 | RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))') 115 | 116 | 117 | def replace_positive_quantifier(match) -> str: 118 | """ 119 | Args: 120 | match (re.Match) 121 | Returns: 122 | str 123 | """ 124 | number = match.group(1) 125 | match_2 = match.group(2) 126 | if match_2 == "+": 127 | match_2 = "多" 128 | match_2: str = match_2 if match_2 else "" 129 | quantifiers: str = match.group(3) 130 | number: str = num2str(number) 131 | result = f"{number}{match_2}{quantifiers}" 132 | return result 133 | 134 | 135 | def replace_number(match) -> str: 136 | """ 137 | Args: 138 | match (re.Match) 139 | Returns: 140 | str 141 | """ 142 | sign = match.group(1) 143 | number = match.group(2) 144 | pure_decimal = match.group(5) 145 | if pure_decimal: 146 | result = num2str(pure_decimal) 147 | else: 148 | sign: str = "负" if sign else "" 149 | number: str = num2str(number) 150 | result = f"{sign}{number}" 151 | return result 152 | 153 | 154 | # 范围表达式 155 | # match.group(1) and match.group(8) are copy from RE_NUMBER 156 | 157 | RE_RANGE = re.compile( 158 | r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))[-~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))') 159 | 160 | 161 | def replace_range(match) -> str: 162 | """ 163 | Args: 164 | match (re.Match) 165 | Returns: 166 | str 167 | """ 168 | first, second = match.group(1), match.group(8) 169 | first = RE_NUMBER.sub(replace_number, first) 170 | second = RE_NUMBER.sub(replace_number, second) 171 | result = f"{first}到{second}" 172 | return result 173 | 174 | 175 | def _get_value(value_string: str, use_zero: bool=True) -> List[str]: 176 | stripped = value_string.lstrip('0') 177 | if len(stripped) == 0: 178 | return [] 179 | elif len(stripped) == 1: 180 | if use_zero and len(stripped) < len(value_string): 181 | return [DIGITS['0'], DIGITS[stripped]] 182 | else: 183 | return [DIGITS[stripped]] 184 | else: 185 | largest_unit = next( 186 | power for power in reversed(UNITS.keys()) if power < len(stripped)) 187 | first_part = value_string[:-largest_unit] 188 | second_part = value_string[-largest_unit:] 189 | return _get_value(first_part) + [UNITS[largest_unit]] + _get_value( 190 | second_part) 191 | 192 | 193 | def verbalize_cardinal(value_string: str) -> str: 194 | if not value_string: 195 | return '' 196 | 197 | # 000 -> '零' , 0 -> '零' 198 | value_string = value_string.lstrip('0') 199 | if len(value_string) == 0: 200 | return DIGITS['0'] 201 | 202 | result_symbols = _get_value(value_string) 203 | # verbalized number starting with '一十*' is abbreviated as `十*` 204 | if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[ 205 | '1'] and result_symbols[1] == UNITS[1]: 206 | result_symbols = result_symbols[1:] 207 | return ''.join(result_symbols) 208 | 209 | 210 | def verbalize_digit(value_string: str, alt_one=False) -> str: 211 | result_symbols = [DIGITS[digit] for digit in value_string] 212 | result = ''.join(result_symbols) 213 | if alt_one: 214 | result = result.replace("一", "幺") 215 | return result 216 | 217 | 218 | def num2str(value_string: str) -> str: 219 | integer_decimal = value_string.split('.') 220 | if len(integer_decimal) == 1: 221 | integer = integer_decimal[0] 222 | decimal = '' 223 | elif len(integer_decimal) == 2: 224 | integer, decimal = integer_decimal 225 | else: 226 | raise ValueError( 227 | f"The value string: '${value_string}' has more than one point in it." 
228 |             )
229 | 
230 |     result = verbalize_cardinal(integer)
231 | 
232 |     decimal = decimal.rstrip('0')
233 |     if decimal:
234 |         # '.22' is verbalized as '零点二二'
235 |         # '3.20' is verbalized as '三点二'
236 |         result = result if result else "零"
237 |         result += '点' + verbalize_digit(decimal)
238 |     return result
239 | 
--------------------------------------------------------------------------------
/nodes/zh_normalization/phonecode.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import re
15 | 
16 | from .num import verbalize_digit
17 | 
18 | # 规范化固话/手机号码
19 | # 手机
20 | # http://www.jihaoba.com/news/show/13680
21 | # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
22 | # 联通:130、131、132、156、155、186、185、176
23 | # 电信:133、153、189、180、181、177
24 | RE_MOBILE_PHONE = re.compile(
25 |     r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
26 | RE_TELEPHONE = re.compile(
27 |     r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
28 | 
29 | # 全国统一的号码400开头
30 | RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")
31 | 
32 | 
33 | def phone2str(phone_string: str, mobile=True) -> str:
34 |     if mobile:
35 |         sp_parts = phone_string.strip('+').split()
36 |         result = ','.join(
37 |             [verbalize_digit(part, alt_one=True) for part in sp_parts])
38 |         return result
39 |     else:
40 |         sil_parts = phone_string.split('-')
41 |         result = ','.join(
42 |             [verbalize_digit(part, alt_one=True) for part in sil_parts])
43 |         return result
44 | 
45 | 
46 | def replace_phone(match) -> str:
47 |     """
48 |     Args:
49 |         match (re.Match)
50 |     Returns:
51 |         str
52 |     """
53 |     return phone2str(match.group(0), mobile=False)
54 | 
55 | 
56 | def replace_mobile(match) -> str:
57 |     """
58 |     Args:
59 |         match (re.Match)
60 |     Returns:
61 |         str
62 |     """
63 |     return phone2str(match.group(0))
64 | 
--------------------------------------------------------------------------------
/nodes/zh_normalization/quantifier.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import re
15 | 
16 | from .num import num2str
17 | 
18 | # 温度表达式,温度会影响负号的读法
19 | # -3°C 零下三度
20 | RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)')
21 | measure_dict = {
22 |     "cm2": "平方厘米",
23 |     "cm²": "平方厘米",
24 |     "cm3": "立方厘米",
25 |     "cm³": "立方厘米",
26 |     "cm": "厘米",
27 |     "db": "分贝",
28 |     "ds": "毫秒",
29 |     "kg": "千克",
30 |     "km": "千米",
31 |     "m2": "平方米",
32 |     "m²": "平方米",
33 |     "m³": "立方米",
34 |     "m3": "立方米",
35 |     "ml": "毫升",
36 |     "m": "米",
37 |     "mm": "毫米",
38 |     "s": "秒"
39 | }
40 | 
41 | 
42 | def replace_temperature(match) -> str:
43 |     """
44 |     Args:
45 |         match (re.Match)
46 |     Returns:
47 |         str
48 |     """
49 |     sign = match.group(1)
50 |     temperature = match.group(2)
51 |     unit = match.group(4)
52 |     sign: str = "零下" if sign else ""
53 |     temperature: str = num2str(temperature)
54 |     unit: str = "摄氏度" if unit == "摄氏度" else "度"
55 |     result = f"{sign}{temperature}{unit}"
56 |     return result
57 | 
58 | 
59 | def replace_measure(sentence) -> str:
60 |     for q_notation in measure_dict:
61 |         if q_notation in sentence and re.search(rf'\d{q_notation}', sentence):
62 | 
63 |             sentence = sentence.replace(q_notation, measure_dict[q_notation])
64 |     return sentence
65 | 
--------------------------------------------------------------------------------
/nodes/zh_normalization/text_normlization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import re 15 | from typing import List 16 | 17 | from .char_convert import tranditional_to_simplified 18 | from .chronology import RE_DATE 19 | from .chronology import RE_DATE2 20 | from .chronology import RE_TIME 21 | from .chronology import RE_TIME_RANGE 22 | from .chronology import replace_date 23 | from .chronology import replace_date2 24 | from .chronology import replace_time 25 | from .constants import F2H_ASCII_LETTERS 26 | from .constants import F2H_DIGITS 27 | from .constants import F2H_SPACE 28 | from .num import RE_DECIMAL_NUM 29 | from .num import RE_DEFAULT_NUM 30 | from .num import RE_FRAC 31 | from .num import RE_INTEGER 32 | from .num import RE_NUMBER 33 | from .num import RE_PERCENTAGE 34 | from .num import RE_POSITIVE_QUANTIFIERS 35 | from .num import RE_RANGE 36 | from .num import replace_default_num 37 | from .num import replace_frac 38 | from .num import replace_negative_num 39 | from .num import replace_number 40 | from .num import replace_percentage 41 | from .num import replace_positive_quantifier 42 | from .num import replace_range 43 | from .phonecode import RE_MOBILE_PHONE 44 | from .phonecode import RE_NATIONAL_UNIFORM_NUMBER 45 | from .phonecode import RE_TELEPHONE 46 | from .phonecode import replace_mobile 47 | from .phonecode import replace_phone 48 | from .quantifier import RE_TEMPERATURE 49 | from .quantifier import replace_measure 50 | from .quantifier import replace_temperature 51 | 52 | 53 | class TextNormalizer(): 54 | def __init__(self): 55 | self.SENTENCE_SPLITOR = re.compile(r'([:、,;。?!,;?!][”’]?)') 56 | 57 | def _split(self, text: str, lang="zh") -> List[str]: 58 | """Split long text into sentences with sentence-splitting punctuations. 59 | Args: 60 | text (str): The input text. 61 | Returns: 62 | List[str]: Sentences. 
63 | 64 | character_map = { 65 | ":": ",", 66 | ";": ",", 67 | "!": "。", 68 | "(": ",", 69 | ")": ",", 70 | "【": ",", 71 | "】": ",", 72 | "『": ",", 73 | "』": ",", 74 | "「": ",", 75 | "」": ",", 76 | "《": ",", 77 | "》": ",", 78 | "-": ",", 79 | "‘": " ", 80 | "“": " ", 81 | "’": " ", 82 | "”": " ", 83 | '"': " ", 84 | "'": " ", 85 | ":": ",", 86 | ";": ",", 87 | "!": ".", 88 | "(": ",", 89 | ")": ",", 90 | "[": ",", 91 | "]": ",", 92 | ">": ",", 93 | "<": ",", 94 | "-": ",", 95 | } 96 | """ 97 | # Only for pure Chinese here 98 | if lang == "zh": 99 | #text = text.replace(" ", "") 100 | # 过滤掉特殊字符 101 | text = re.sub(r'[——《》【】<>{}()()#&@“”^|…\\]', '', text) 102 | text = self.SENTENCE_SPLITOR.sub(r'\1\n', text) 103 | text = text.strip() 104 | sentences = [sentence.strip() for sentence in re.split(r'\n+', text)] 105 | return sentences 106 | 107 | def _post_replace(self, sentence: str) -> str: 108 | 109 | 110 | #sentence = sentence.replace('/', '每') 111 | sentence = sentence.replace('~', '至') 112 | sentence = sentence.replace('~', '至') 113 | sentence = sentence.replace('①', '一') 114 | sentence = sentence.replace('②', '二') 115 | sentence = sentence.replace('③', '三') 116 | sentence = sentence.replace('④', '四') 117 | sentence = sentence.replace('⑤', '五') 118 | sentence = sentence.replace('⑥', '六') 119 | sentence = sentence.replace('⑦', '七') 120 | sentence = sentence.replace('⑧', '八') 121 | sentence = sentence.replace('⑨', '九') 122 | sentence = sentence.replace('⑩', '十') 123 | sentence = sentence.replace('α', '阿尔法') 124 | sentence = sentence.replace('β', '贝塔') 125 | sentence = sentence.replace('γ', '伽玛').replace('Γ', '伽玛') 126 | sentence = sentence.replace('δ', '德尔塔').replace('Δ', '德尔塔') 127 | sentence = sentence.replace('ε', '艾普西龙') 128 | sentence = sentence.replace('ζ', '捷塔') 129 | sentence = sentence.replace('η', '依塔') 130 | sentence = sentence.replace('θ', '西塔').replace('Θ', '西塔') 131 | sentence = sentence.replace('ι', '艾欧塔') 132 | sentence = sentence.replace('κ', '喀帕') 133 | sentence = sentence.replace('λ', '拉姆达').replace('Λ', '拉姆达') 134 | sentence = sentence.replace('μ', '缪') 135 | sentence = sentence.replace('ν', '拗') 136 | sentence = sentence.replace('ξ', '克西').replace('Ξ', '克西') 137 | sentence = sentence.replace('ο', '欧米克伦') 138 | sentence = sentence.replace('π', '派').replace('Π', '派') 139 | sentence = sentence.replace('ρ', '肉') 140 | sentence = sentence.replace('ς', '西格玛').replace('Σ', '西格玛').replace( 141 | 'σ', '西格玛') 142 | sentence = sentence.replace('τ', '套') 143 | sentence = sentence.replace('υ', '宇普西龙') 144 | sentence = sentence.replace('φ', '服艾').replace('Φ', '服艾') 145 | sentence = sentence.replace('χ', '器') 146 | sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛') 147 | sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽') 148 | sentence = sentence.replace('+', '加') 149 | 150 | 151 | # re filter special characters, have one more character "-" than line 68 152 | sentence = re.sub(r'[-——《》【】<=>{}()()#&@“”^|…\\]', '', sentence) 153 | return sentence 154 | 155 | # 数字转为中文读法 156 | def num_to_chinese(self,num): 157 | num_str = str(num) 158 | chinese_digits = "零一二三四五六七八九" 159 | units = ["", "十", "百", "千"] 160 | big_units = ["", "万", "亿", "兆"] 161 | result = "" 162 | zero_flag = False # 标记是否需要加'零' 163 | part = [] # 存储每4位的数字 164 | 165 | # 将数字按每4位分组 166 | while num_str: 167 | part.append(num_str[-4:]) 168 | num_str = num_str[:-4] 169 | 170 | for i in range(len(part)): 171 | part_str = "" 172 | part_zero_flag = False 173 | for j in range(len(part[i])): 174 | digit = int(part[i][j]) 
175 |                 if digit == 0:
176 |                     part_zero_flag = True
177 |                 else:
178 |                     if part_zero_flag or (zero_flag and i > 0 and not result.startswith(chinese_digits[0])):
179 |                         part_str += chinese_digits[0]
180 |                         zero_flag = False
181 |                         part_zero_flag = False
182 |                     part_str += chinese_digits[digit] + units[len(part[i]) - j - 1]
183 |             if part_str.endswith("零"):
184 |                 part_str = part_str[:-1]  # 去除尾部的'零'
185 |             if part_str:
186 |                 zero_flag = True
187 | 
188 |             if i > 0 and not set(part[i]) <= {'0'}:  # 如果当前部分不全是0,则加上相应的大单位
189 |                 result = part_str + big_units[i] + result
190 |             else:
191 |                 result = part_str + result
192 | 
193 |         # 处理输入为0的情况或者去掉开头的零
194 |         result = result.lstrip(chinese_digits[0])
195 |         if not result:
196 |             return chinese_digits[0]
197 | 
198 |         return result
199 | 
200 |     def normalize_sentence(self, sentence: str) -> str:
201 | 
202 |         # basic character conversions
203 |         # add
204 |         sentence = re.sub(r'(\d+)\s*[\*xX]\s*(\d+)', r'\1 乘 \2', sentence, flags=re.I)
205 |         # 区号 电话 分机
206 |         sentence = re.sub(r'(0\d+)\-(\d{3,})\-(\d{3,})', r'\1杠\2杠\3', sentence, flags=re.I)
207 |         sentence = re.sub(r'(0\d+)\-(\d{3,})', r'\1杠\2', sentence, flags=re.I)
208 |         sentence = sentence.replace('=', '等于')
209 |         sentence = sentence.replace('÷', '除以')
210 | 
211 |         #sentence = re.sub(r'(\d+)\s*\-', r'\1 减', sentence)
212 |         sentence = re.sub(r'((?:\d+\.)?\d+)\s*/\s*(\d+)', r'\2分之\1', sentence)
213 | 
214 |         # 取出数字 number_list= [('1000200030004000.123', '1000200030004000', '123'), ('23425', '23425', '')]
215 |         number_list = re.findall(r'((\d+)(?:\.(\d+))?%?)', sentence)
216 |         numtext = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九']
217 |         if len(number_list) > 0:
218 |             #dc= ('1000200030004000.123', '1000200030004000', '123','')
219 |             for m, dc in enumerate(number_list):
220 |                 n_len = len(dc[1])
221 |                 #手机号/座机号 超大数 亿内的数 0开头的数,不做处理
222 |                 if n_len > 16 or n_len < 9 or (n_len == 11 and str(dc[1])[0] == '1') or str(dc[1])[0] == '0':
223 |                     continue
224 |                 int_text = self.num_to_chinese(dc[1])
225 |                 if len(dc) > 2 and dc[2]:
226 |                     int_text += "点" + "".join([numtext[int(i)] for i in dc[2]])
227 |                 if dc[0][-1] == '%':
228 |                     int_text = f'百分之{int_text}'
229 |                 sentence = sentence.replace(dc[0], int_text)
230 | 
231 | 
232 |         sentence = tranditional_to_simplified(sentence)
233 |         sentence = sentence.translate(F2H_ASCII_LETTERS).translate(
234 |             F2H_DIGITS).translate(F2H_SPACE)
235 | 
236 |         # number related NSW verbalization
237 |         sentence = RE_DATE.sub(replace_date, sentence)
238 |         sentence = RE_DATE2.sub(replace_date2, sentence)
239 | 
240 |         # range first
241 |         sentence = RE_TIME_RANGE.sub(replace_time, sentence)
242 |         sentence = RE_TIME.sub(replace_time, sentence)
243 | 
244 |         sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
245 |         sentence = replace_measure(sentence)
246 |         sentence = RE_FRAC.sub(replace_frac, sentence)
247 |         sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
248 |         sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
249 | 
250 |         sentence = RE_TELEPHONE.sub(replace_phone, sentence)
251 |         sentence = RE_NATIONAL_UNIFORM_NUMBER.sub(replace_phone, sentence)
252 | 
253 |         sentence = RE_RANGE.sub(replace_range, sentence)
254 |         sentence = RE_INTEGER.sub(replace_negative_num, sentence)
255 |         sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
256 |         sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
257 |                                                sentence)
258 |         sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)
259 |         sentence = RE_NUMBER.sub(replace_number, sentence)
260 |         sentence = self._post_replace(sentence)
261 | 
262 |         sentence = sentence.replace('[一break]', '[1break]')
263 | 
264 |         return 
sentence 265 | 266 | def normalize(self, text: str) -> List[str]: 267 | sentences = self._split(text) 268 | sentences = [self.normalize_sentence(sent) for sent in sentences] 269 | return sentences 270 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | omegaconf~=2.3.0 2 | tqdm 3 | einops 4 | vector_quantize_pytorch 5 | vocos 6 | pydub 7 | faster_whisper 8 | whisper_timestamped 9 | inflect 10 | unidecode 11 | eng_to_ipa 12 | cn2an 13 | jieba 14 | pybase16384 -------------------------------------------------------------------------------- /web/loadSpeaker.js: -------------------------------------------------------------------------------- 1 | import { app } from '../../../scripts/app.js' 2 | 3 | app.registerExtension({ 4 | name: 'Mixlab.Chattts.LoadSpeaker', 5 | 6 | async beforeRegisterNodeDef (nodeType, nodeData, app) { 7 | if ( 8 | nodeType.comfyClass == 'LoadSpeaker' || 9 | nodeType.comfyClass == 'MergeSpeaker' 10 | ) { 11 | const onExecuted = nodeType.prototype.onExecuted 12 | nodeType.prototype.onExecuted = function (message) { 13 | onExecuted?.apply(this, arguments) 14 | if (message.text && message.text[0]) { 15 | this.title = message.text.join(",") 16 | } 17 | } 18 | } 19 | } 20 | }) 21 | --------------------------------------------------------------------------------