├── .env ├── .github └── FUNDING.yml ├── .gitignore ├── ChatTTS ├── __init__.py ├── config │ ├── __init__.py │ └── config.py ├── core.py ├── experimental │ └── llm.py ├── infer │ └── api.py ├── model │ ├── __init__.py │ ├── cuda │ │ ├── __init__.py │ │ ├── patch.py │ │ └── te_llama.py │ ├── dvae.py │ ├── gpt.py │ ├── processors.py │ └── tokenizer.py ├── norm.py ├── res │ ├── __init__.py │ ├── homophones_map.json │ └── sha256_map.json └── utils │ ├── __init__.py │ ├── dl.py │ ├── download.py │ ├── gpu.py │ ├── gpu_utils.py │ ├── infer_utils.py │ ├── io.py │ ├── io_utils.py │ └── log.py ├── Dockerfile.cpu ├── Dockerfile.gpu ├── LICENSE ├── README.md ├── README_EN.md ├── app.py ├── asset └── 模型下载说明.txt ├── cover-pt.py ├── docker-compose.cpu.yaml ├── docker-compose.gpu.yaml ├── faq.md ├── ffmpeg └── ffmpeg下载.txt ├── listen-speaker ├── 083806_use14.39s-audio0s-seed1983.pt-te0.1-tp0.701-tk20-textlen5-39593-merge.wav ├── 083900_use3.43s-audio0s-seed13.pt-te0.1-tp0.701-tk20-textlen5-09614-merge.wav ├── 083910_use3.22s-audio0s-seed7869.pt-te0.1-tp0.701-tk20-textlen5-19801-merge.wav ├── 083919_use3.42s-audio0s-seed6653.pt-te0.1-tp0.701-tk20-textlen5-10851-merge.wav ├── 083928_use3.3s-audio0s-seed4751.pt-te0.1-tp0.701-tk20-textlen5-69400-merge.wav ├── 083937_use3.11s-audio0s-seed1579.pt-te0.1-tp0.701-tk20-textlen5-27436-merge.wav ├── 083945_use3.13s-audio0s-seed14.pt-te0.1-tp0.701-tk20-textlen5-57598-merge.wav ├── 083955_use2.84s-audio0s-seed3333.pt-te0.1-tp0.701-tk20-textlen5-93133-merge.wav ├── 084004_use3.08s-audio0s-seed1111.pt-te0.1-tp0.701-tk20-textlen5-39727-merge.wav ├── 084014_use3.37s-audio0s-seed11.pt-te0.1-tp0.701-tk20-textlen5-27662-merge.wav ├── 084024_use3.3s-audio0s-seed1031.pt-te0.1-tp0.701-tk20-textlen5-19879-merge.wav ├── 084032_use3.36s-audio0s-seed2222.pt-te0.1-tp0.701-tk20-textlen5-48884-merge.wav ├── 084040_use3.12s-audio0s-seed12.pt-te0.1-tp0.701-tk20-textlen5-28377-merge.wav ├── 084048_use3.16s-audio0s-seed5555.pt-te0.1-tp0.701-tk20-textlen5-42929-merge.wav ├── 084056_use3.02s-audio0s-seed5099.pt-te0.1-tp0.701-tk20-textlen5-35891-merge.wav ├── 084454_use3.47s-audio0s-seed2345.pt-te0.1-tp0.701-tk20-textlen5-86669-merge.wav ├── 084503_use3.22s-audio0s-seed4785.pt-te0.1-tp0.701-tk20-textlen5-95898-merge.wav ├── 084511_use3.56s-audio0s-seed491.pt-te0.1-tp0.701-tk20-textlen5-66150-merge.wav ├── 084518_use3.15s-audio0s-seed4444.pt-te0.1-tp0.701-tk20-textlen5-77649-merge.wav ├── 084526_use3.38s-audio0s-seed1455.pt-te0.1-tp0.701-tk20-textlen5-54547-merge.wav ├── 084755_use3.18s-audio0s-seed2328.pt-te0.1-tp0.701-tk20-textlen5-85733-merge.wav ├── 084813_use3.47s-audio0s-seed8888.pt-te0.1-tp0.701-tk20-textlen5-96180-merge.wav ├── 084823_use3.33s-audio0s-seed16.pt-te0.1-tp0.701-tk20-textlen5-51038-merge.wav ├── 084832_use2.95s-audio0s-seed1234.pt-te0.1-tp0.701-tk20-textlen5-80959-merge.wav ├── 084842_use3.34s-audio0s-seed1518.pt-te0.1-tp0.701-tk20-textlen5-37066-merge.wav ├── 084851_use3.08s-audio0s-seed7777.pt-te0.1-tp0.701-tk20-textlen5-99477-merge.wav ├── 084901_use2.81s-audio0s-seed4099.pt-te0.1-tp0.701-tk20-textlen5-16898-merge.wav ├── 084910_use3.29s-audio0s-seed5600.pt-te0.1-tp0.701-tk20-textlen5-42899-merge.wav ├── 084919_use3.12s-audio0s-seed5400.pt-te0.1-tp0.701-tk20-textlen5-57496-merge.wav ├── 084929_use3.43s-audio0s-seed9999.pt-te0.1-tp0.701-tk20-textlen5-32652-merge.wav ├── 084945_use3.13s-audio0s-seed125.pt-te0.1-tp0.701-tk20-textlen5-93149-merge.wav ├── 084954_use3.25s-audio0s-seed2279.pt-te0.1-tp0.701-tk20-textlen5-62556-merge.wav ├── 
085002_use3.2s-audio0s-seed6666.pt-te0.1-tp0.701-tk20-textlen5-07948-merge.wav ├── 085011_use3.31s-audio0s-seed492.pt-te0.1-tp0.701-tk20-textlen5-17771-merge.wav └── 085020_use3.04s-audio0s-seed5.pt-te0.1-tp0.701-tk20-textlen5-82025-merge.wav ├── pyproject.toml ├── requirements.txt ├── run.bat ├── run.py ├── runtest.bat ├── speaker ├── 1031.pt ├── 11.pt ├── 1111.pt ├── 12.pt ├── 1234.pt ├── 125.pt ├── 13.pt ├── 14.pt ├── 1455.pt ├── 1518.pt ├── 1579.pt ├── 16.pt ├── 1983.pt ├── 2222.pt ├── 2279.pt ├── 2328.pt ├── 2345.pt ├── 3333.pt ├── 4099.pt ├── 4444.pt ├── 4751.pt ├── 4785.pt ├── 491.pt ├── 492.pt ├── 5.pt ├── 5099.pt ├── 5400.pt ├── 5555.pt ├── 5600.pt ├── 6653.pt ├── 6666.pt ├── 7777.pt ├── 7869.pt ├── 8888.pt └── 9999.pt ├── static └── js │ ├── bootstrap.bundle.min.js │ ├── bootstrap.min.css │ ├── jquery.min.js │ └── layer │ ├── layer.js │ ├── mobile │ ├── layer.js │ └── need │ │ └── layer.css │ └── theme │ └── default │ ├── icon-ext.png │ ├── icon.png │ ├── layer.css │ ├── loading-0.gif │ ├── loading-1.gif │ └── loading-2.gif ├── templates ├── index.html └── indexen.html ├── test.py ├── tools ├── __init__.py ├── audio │ ├── __init__.py │ └── np.py ├── checksum │ ├── main.go │ └── tmpl.go ├── llm │ ├── __init__.py │ └── llm.py ├── logger │ ├── __init__.py │ └── log.py ├── normalizer │ ├── __init__.py │ ├── en.py │ └── zh.py └── seeder │ ├── __init__.py │ └── ctx.py └── uilib ├── __init__.py ├── cfg.py ├── utils.py └── zh_normalization ├── README.md ├── __init__.py ├── char_convert.py ├── chronology.py ├── constants.py ├── num.py ├── phonecode.py ├── quantifier.py └── text_normlization.py /.env: -------------------------------------------------------------------------------- 1 | WEB_ADDRESS=127.0.0.1:9966 2 | compile=false 3 | device=default 4 | merge_size=6 -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | custom: https://ko-fi.com/jianchang512 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | *.srt 3 | .idea 4 | 5 | models/* 6 | dev 7 | venv 8 | dist 9 | source 10 | build 11 | __pycache__ 12 | *.spec 13 | *.ui 14 | *.bak 15 | *.aac 16 | *.pt 17 | # *.wav 18 | pack.bat 19 | gitcmd.bat 20 | logs 21 | poetry.lock 22 | docs 23 | static/wavs/* 24 | examples 25 | ffmpeg/ffmpeg.exe 26 | asset/*.pt -------------------------------------------------------------------------------- /ChatTTS/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import Chat 2 | -------------------------------------------------------------------------------- /ChatTTS/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import Config 2 | -------------------------------------------------------------------------------- /ChatTTS/config/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass(repr=False, eq=False) 5 | class Path: 6 | vocos_ckpt_path: str = "asset/Vocos.pt" 7 | dvae_ckpt_path: str = "asset/DVAE_full.pt" 8 | gpt_ckpt_path: str = "asset/GPT.pt" 9 | decoder_ckpt_path: str = "asset/Decoder.pt" 10 | tokenizer_path: str = "asset/tokenizer.pt" 11 | 12 | 13 | 
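# How these sections compose (a hypothetical usage sketch, not part of this
# module): the Config dataclass at the bottom of this file nests one instance
# of each section, so checkpoint hyper-parameters are reached by attribute:
#
#     cfg = Config()
#     cfg.path.gpt_ckpt_path   # "asset/GPT.pt"
#     cfg.gpt.num_vq           # 4 parallel audio codebooks per step
#     cfg.dvae.vq.levels       # (5, 5, 5, 5) -> 5*5*5*5 = 625 codes, which
#                              # lines up with GPT.num_audio_tokens = 626,
#                              # apparently 625 codes plus one stop token
#                              # (infer/api.py uses eos_token = num_code)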
@dataclass(repr=False, eq=False) 14 | class Decoder: 15 | idim: int = 384 16 | odim: int = 384 17 | hidden: int = 512 18 | n_layer: int = 12 19 | bn_dim: int = 128 20 | 21 | 22 | @dataclass(repr=False, eq=False) 23 | class VQ: 24 | dim: int = 1024 25 | levels: tuple = (5, 5, 5, 5) 26 | G: int = 2 27 | R: int = 2 28 | 29 | 30 | @dataclass(repr=False, eq=False) 31 | class DVAE: 32 | encoder: Decoder = Decoder( 33 | idim=512, 34 | odim=1024, 35 | hidden=256, 36 | n_layer=12, 37 | bn_dim=128, 38 | ) 39 | decoder: Decoder = Decoder( 40 | idim=512, 41 | odim=512, 42 | hidden=256, 43 | n_layer=12, 44 | bn_dim=128, 45 | ) 46 | vq: VQ = VQ() 47 | 48 | 49 | @dataclass(repr=False, eq=False) 50 | class GPT: 51 | hidden_size: int = 768 52 | intermediate_size: int = 3072 53 | num_attention_heads: int = 12 54 | num_hidden_layers: int = 20 55 | use_cache: bool = False 56 | max_position_embeddings: int = 4096 57 | 58 | spk_emb_dim: int = 192 59 | spk_KL: bool = False 60 | num_audio_tokens: int = 626 61 | num_vq: int = 4 62 | 63 | 64 | @dataclass(repr=False, eq=False) 65 | class FeatureExtractorInitArgs: 66 | sample_rate: int = 24000 67 | n_fft: int = 1024 68 | hop_length: int = 256 69 | n_mels: int = 100 70 | padding: str = "center" 71 | 72 | 73 | @dataclass(repr=False, eq=False) 74 | class FeatureExtractor: 75 | class_path: str = "vocos.feature_extractors.MelSpectrogramFeatures" 76 | init_args: FeatureExtractorInitArgs = FeatureExtractorInitArgs() 77 | 78 | 79 | @dataclass(repr=False, eq=False) 80 | class BackboneInitArgs: 81 | input_channels: int = 100 82 | dim: int = 512 83 | intermediate_dim: int = 1536 84 | num_layers: int = 8 85 | 86 | 87 | @dataclass(repr=False, eq=False) 88 | class Backbone: 89 | class_path: str = "vocos.models.VocosBackbone" 90 | init_args: BackboneInitArgs = BackboneInitArgs() 91 | 92 | 93 | @dataclass(repr=False, eq=False) 94 | class FourierHeadInitArgs: 95 | dim: int = 512 96 | n_fft: int = 1024 97 | hop_length: int = 256 98 | padding: str = "center" 99 | 100 | 101 | @dataclass(repr=False, eq=False) 102 | class FourierHead: 103 | class_path: str = "vocos.heads.ISTFTHead" 104 | init_args: FourierHeadInitArgs = FourierHeadInitArgs() 105 | 106 | 107 | @dataclass(repr=False, eq=False) 108 | class Vocos: 109 | feature_extractor: FeatureExtractor = FeatureExtractor() 110 | backbone: Backbone = Backbone() 111 | head: FourierHead = FourierHead() 112 | 113 | 114 | @dataclass(repr=False, eq=False) 115 | class Config: 116 | path: Path = Path() 117 | decoder: Decoder = Decoder() 118 | dvae: DVAE = DVAE() 119 | gpt: GPT = GPT() 120 | vocos: Vocos = Vocos() 121 | -------------------------------------------------------------------------------- /ChatTTS/experimental/llm.py: -------------------------------------------------------------------------------- 1 | 2 | from openai import OpenAI 3 | 4 | prompt_dict = { 5 | 'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"}, 6 | {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"}, 7 | {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},], 8 | 'deepseek': [ 9 | {"role": "system", "content": "You are a helpful assistant"}, 10 | {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"}, 11 | {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},], 12 | 'deepseek_TN': [ 13 
| {"role": "system", "content": "You are a helpful assistant"}, 14 | {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"}, 15 | {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"}, 16 | {"role": "user", "content": "We paid $123 for this desk."}, 17 | {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."}, 18 | {"role": "user", "content": "详询请拨打010-724654"}, 19 | {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"}, 20 | {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"}, 21 | {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"}, 22 | ], 23 | } 24 | 25 | class llm_api: 26 | def __init__(self, api_key, base_url, model): 27 | self.client = OpenAI( 28 | api_key = api_key, 29 | base_url = base_url, 30 | ) 31 | self.model = model 32 | def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs): 33 | 34 | completion = self.client.chat.completions.create( 35 | model = self.model, 36 | messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},], 37 | temperature = temperature, 38 | **kwargs 39 | ) 40 | return completion.choices[0].message.content 41 | -------------------------------------------------------------------------------- /ChatTTS/infer/api.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from transformers.generation import TopKLogitsWarper, TopPLogitsWarper 5 | from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat 6 | 7 | def infer_code( 8 | models, 9 | text, 10 | spk_emb = None, 11 | top_P = 0.7, 12 | top_K = 20, 13 | temperature = 0.3, 14 | repetition_penalty = 1.05, 15 | max_new_token = 2048, 16 | stream=False, 17 | **kwargs 18 | ): 19 | 20 | device = next(models['gpt'].parameters()).device 21 | 22 | if not isinstance(text, list): 23 | text = [text] 24 | 25 | if not isinstance(temperature, list): 26 | temperature = [temperature] * models['gpt'].num_vq 27 | 28 | if spk_emb is not None: 29 | text = [f'[Stts][spk_emb]{i}[Ptts]' for i in text] 30 | else: 31 | text = [f'[Stts][empty_spk]{i}[Ptts]' for i in text] 32 | 33 | text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device) 34 | input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq) 35 | text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device) 36 | 37 | inputs = { 38 | 'input_ids': input_ids, 39 | 'text_mask': text_mask, 40 | 'attention_mask': text_token['attention_mask'], 41 | } 42 | 43 | emb = models['gpt'].get_emb(**inputs) 44 | if spk_emb is not None: 45 | emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \ 46 | F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12) 47 | 48 | num_code = models['gpt'].emb_code[0].num_embeddings - 1 49 | 50 | LogitsWarpers = [] 51 | if top_P is not None: 52 | LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) 53 | if top_K is not None: 54 | LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) 55 | 56 | LogitsProcessors = [] 57 | if repetition_penalty is not None and repetition_penalty != 1: 58 | LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\ 59 | repetition_penalty, num_code, 16)) 60 | 61 | result = models['gpt'].generate( 62 | emb, inputs['input_ids'], 
63 | temperature = torch.tensor(temperature, device=device), 64 | attention_mask = inputs['attention_mask'], 65 | LogitsWarpers = LogitsWarpers, 66 | LogitsProcessors = LogitsProcessors, 67 | eos_token = num_code, 68 | max_new_token = max_new_token, 69 | infer_text = False, 70 | stream = stream, 71 | **kwargs 72 | ) 73 | 74 | return result 75 | 76 | 77 | def refine_text( 78 | models, 79 | text, 80 | top_P = 0.7, 81 | top_K = 20, 82 | temperature = 0.7, 83 | repetition_penalty = 1.0, 84 | max_new_token = 384, 85 | prompt = '', 86 | **kwargs 87 | ): 88 | 89 | device = next(models['gpt'].parameters()).device 90 | 91 | if not isinstance(text, list): 92 | text = [text] 93 | 94 | assert len(text), 'text should not be empty' 95 | 96 | text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text] 97 | text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device) 98 | text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device) 99 | 100 | inputs = { 101 | 'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq), 102 | 'text_mask': text_mask, 103 | 'attention_mask': text_token['attention_mask'], 104 | } 105 | 106 | LogitsWarpers = [] 107 | if top_P is not None: 108 | LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) 109 | if top_K is not None: 110 | LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) 111 | 112 | LogitsProcessors = [] 113 | if repetition_penalty is not None and repetition_penalty != 1: 114 | LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16)) 115 | 116 | result = models['gpt'].generate( 117 | models['gpt'].get_emb(**inputs), inputs['input_ids'], 118 | temperature = torch.tensor([temperature,], device=device), 119 | attention_mask = inputs['attention_mask'], 120 | LogitsWarpers = LogitsWarpers, 121 | LogitsProcessors = LogitsProcessors, 122 | eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None], 123 | max_new_token = max_new_token, 124 | infer_text = True, 125 | stream = False, 126 | **kwargs 127 | ) 128 | return next(result) 129 | -------------------------------------------------------------------------------- /ChatTTS/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .dvae import DVAE 2 | from .gpt import GPT 3 | from .processors import gen_logits 4 | from .tokenizer import Tokenizer 5 | -------------------------------------------------------------------------------- /ChatTTS/model/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | from .te_llama import TELlamaModel 2 | -------------------------------------------------------------------------------- /ChatTTS/model/cuda/patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LlamaRMSNorm(torch.nn.Module): 5 | def __init__(self, hidden_size, eps=1e-6): 6 | """ 7 | LlamaRMSNorm is equivalent to T5LayerNorm 8 | """ 9 | super().__init__() 10 | self.weight = torch.nn.Parameter(torch.ones(hidden_size)) 11 | self.variance_epsilon = eps 12 | 13 | def forward(self, hidden_states: torch.Tensor): 14 | input_dtype = hidden_states.dtype 15 | hidden_states = hidden_states.to(torch.float32) 16 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 17 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 18 | 
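# RMSNorm, for reference: y = w * x / sqrt(mean(x**2, dim=-1) + eps). Unlike
# LayerNorm there is no mean subtraction and no bias; the statistics are
# computed in float32 above and cast back to the input dtype on return.
# Equivalent one-line sketch (x standing for the float32 hidden_states):
#
#     y = weight * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps))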
return self.weight.to(hidden_states.device) * hidden_states.to(input_dtype) 19 | -------------------------------------------------------------------------------- /ChatTTS/model/cuda/te_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | # 5 | # From https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/te_llama.py 6 | # 7 | # Edited by fumiama. 8 | 9 | import re 10 | from contextlib import contextmanager 11 | from typing import Dict 12 | 13 | import transformer_engine as te 14 | from transformer_engine.pytorch.attention import RotaryPositionEmbedding 15 | 16 | import torch 17 | 18 | import transformers 19 | from transformers.models.llama.modeling_llama import ( 20 | LlamaModel, 21 | LlamaConfig, 22 | ) 23 | from transformers.modeling_utils import _load_state_dict_into_model 24 | 25 | from .patch import LlamaRMSNorm 26 | 27 | 28 | @contextmanager 29 | def replace_decoder(te_decoder_cls, llama_rms_norm_cls): 30 | """ 31 | Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`. 32 | """ 33 | original_llama_decoder_cls = ( 34 | transformers.models.llama.modeling_llama.LlamaDecoderLayer 35 | ) 36 | transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls 37 | original_llama_rms_norm_cls = transformers.models.llama.modeling_llama.LlamaRMSNorm 38 | transformers.models.llama.modeling_llama.LlamaRMSNorm = llama_rms_norm_cls 39 | try: 40 | yield 41 | finally: 42 | transformers.models.llama.modeling_llama.LlamaDecoderLayer = ( 43 | original_llama_decoder_cls 44 | ) 45 | transformers.models.llama.modeling_llama.LlamaRMSNorm = ( 46 | original_llama_rms_norm_cls 47 | ) 48 | 49 | 50 | class TELlamaDecoderLayer(te.pytorch.TransformerLayer): 51 | """ 52 | Wrapper class over TE's `TransformerLayer`. This makes the wrapper very 53 | similar to HF's `LlamaDecoderLayer` and easier to replace it in the code. 54 | 55 | Args: 56 | config: LlamaConfig 57 | args: positional args (for compatibility with `LlamaDecoderLayer`) 58 | kwargs: keyword args (for compatibility with `LlamaDecoderLayer`) 59 | """ 60 | 61 | def __init__(self, config, *args, **kwargs): 62 | super().__init__( 63 | hidden_size=config.hidden_size, 64 | ffn_hidden_size=config.intermediate_size, 65 | num_attention_heads=config.num_attention_heads, 66 | bias=False, 67 | layernorm_epsilon=config.rms_norm_eps, 68 | hidden_dropout=0, 69 | attention_dropout=0, 70 | fuse_qkv_params=False, 71 | normalization="RMSNorm", 72 | activation="swiglu", 73 | attn_input_format="bshd", 74 | num_gqa_groups=config.num_key_value_heads, 75 | ) 76 | te_rope = RotaryPositionEmbedding( 77 | config.hidden_size // config.num_attention_heads 78 | ) 79 | self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() 80 | 81 | def forward(self, hidden_states, *args, attention_mask, **kwargs): 82 | """ 83 | Custom forward to make sure we only pass relevant arguments to the 84 | forward pass of the `TransformerLayer`. Also, make sure the output 85 | format matches the output of the HF's `LlamaDecoderLayer`. 86 | """ 87 | return ( 88 | super().forward( 89 | hidden_states, 90 | attention_mask=attention_mask, 91 | rotary_pos_emb=self.te_rope_emb, 92 | ), 93 | ) 94 | 95 | 96 | class TELlamaModel: 97 | """ 98 | LM created with `LlamaModel`. 
The underlying `LlamaDecoderLayer` 99 | class is monkey-patched with `TELlamaDecoderLayer` class before 100 | initializing the causal LM with `LlamaModel`. 101 | 102 | Args: 103 | config: LlamaConfig 104 | """ 105 | 106 | def __new__(cls, config: LlamaConfig): 107 | with replace_decoder( 108 | te_decoder_cls=TELlamaDecoderLayer, llama_rms_norm_cls=LlamaRMSNorm 109 | ): 110 | model = LlamaModel(config) 111 | return model 112 | 113 | @classmethod 114 | def from_state_dict( 115 | cls, 116 | state_dict: Dict[str, torch.Tensor], 117 | config: LlamaConfig, 118 | ): 119 | """ 120 | Custom method adapted from `from_pretrained` method in HuggingFace 121 | Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 122 | """ 123 | 124 | vanilla_model = cls(config) 125 | 126 | # replace_params copies parameters relevant only to TransformerEngine 127 | _replace_params(state_dict, vanilla_model.state_dict(), config) 128 | # _load_state_dict_into_model copies parameters other than those in TransformerEngine 129 | _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") 130 | 131 | return vanilla_model 132 | 133 | 134 | def _replace_params(hf_state_dict, te_state_dict, config): 135 | # collect all layer prefixes to update 136 | all_layer_prefixes = set() 137 | for param_key in hf_state_dict.keys(): 138 | layer_prefix_pat = "model.layers.\d+." 139 | m = re.match(layer_prefix_pat, param_key) 140 | if m is not None: 141 | all_layer_prefixes.add(m.group()) 142 | 143 | for layer_prefix in all_layer_prefixes: 144 | # When loading weights into models with less number of layers, skip the 145 | # copy if the corresponding layer doesn't exist in HF model 146 | if layer_prefix + "input_layernorm.weight" in hf_state_dict: 147 | te_state_dict[ 148 | layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight" 149 | ].data[:] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:] 150 | 151 | if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict: 152 | te_state_dict[ 153 | layer_prefix + "self_attention.layernorm_qkv.query_weight" 154 | ].data[:] = hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:] 155 | 156 | if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict: 157 | te_state_dict[ 158 | layer_prefix + "self_attention.layernorm_qkv.key_weight" 159 | ].data[:] = hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:] 160 | 161 | if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict: 162 | te_state_dict[ 163 | layer_prefix + "self_attention.layernorm_qkv.value_weight" 164 | ].data[:] = hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:] 165 | 166 | if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict: 167 | te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = ( 168 | hf_state_dict[layer_prefix + "self_attn.o_proj.weight"].data[:] 169 | ) 170 | 171 | if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict: 172 | te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = ( 173 | hf_state_dict[layer_prefix + "post_attention_layernorm.weight"].data[:] 174 | ) 175 | 176 | # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to 177 | # load them separately. 
178 | if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict: 179 | te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[ 180 | : config.intermediate_size 181 | ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data 182 | 183 | if layer_prefix + "mlp.up_proj.weight" in hf_state_dict: 184 | te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[ 185 | config.intermediate_size : 186 | ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data 187 | 188 | if layer_prefix + "mlp.down_proj.weight" in hf_state_dict: 189 | te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = ( 190 | hf_state_dict[layer_prefix + "mlp.down_proj.weight"].data[:] 191 | ) 192 | return all_layer_prefixes 193 | -------------------------------------------------------------------------------- /ChatTTS/model/dvae.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Literal, Tuple 3 | 4 | import numpy as np 5 | import pybase16384 as b14 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchaudio 10 | from vector_quantize_pytorch import GroupedResidualFSQ 11 | 12 | 13 | class ConvNeXtBlock(nn.Module): 14 | def __init__( 15 | self, 16 | dim: int, 17 | intermediate_dim: int, 18 | kernel: int, 19 | dilation: int, 20 | layer_scale_init_value: float = 1e-6, 21 | ): 22 | # ConvNeXt Block copied from Vocos. 23 | super().__init__() 24 | self.dwconv = nn.Conv1d( 25 | dim, 26 | dim, 27 | kernel_size=kernel, 28 | padding=dilation * (kernel // 2), 29 | dilation=dilation, 30 | groups=dim, 31 | ) # depthwise conv 32 | 33 | self.norm = nn.LayerNorm(dim, eps=1e-6) 34 | self.pwconv1 = nn.Linear( 35 | dim, intermediate_dim 36 | ) # pointwise/1x1 convs, implemented with linear layers 37 | self.act = nn.GELU() 38 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 39 | self.gamma = ( 40 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 41 | if layer_scale_init_value > 0 42 | else None 43 | ) 44 | 45 | def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor: 46 | residual = x 47 | 48 | y = self.dwconv(x) 49 | y.transpose_(1, 2) # (B, C, T) -> (B, T, C) 50 | x = self.norm(y) 51 | del y 52 | y = self.pwconv1(x) 53 | del x 54 | x = self.act(y) 55 | del y 56 | y = self.pwconv2(x) 57 | del x 58 | if self.gamma is not None: 59 | y *= self.gamma 60 | y.transpose_(1, 2) # (B, T, C) -> (B, C, T) 61 | 62 | x = y + residual 63 | del y 64 | 65 | return x 66 | 67 | 68 | class GFSQ(nn.Module): 69 | 70 | def __init__( 71 | self, dim: int, levels: List[int], G: int, R: int, eps=1e-5, transpose=True 72 | ): 73 | super(GFSQ, self).__init__() 74 | self.quantizer = GroupedResidualFSQ( 75 | dim=dim, 76 | levels=list(levels), 77 | num_quantizers=R, 78 | groups=G, 79 | ) 80 | self.n_ind = math.prod(levels) 81 | self.eps = eps 82 | self.transpose = transpose 83 | self.G = G 84 | self.R = R 85 | 86 | def _embed(self, x: torch.Tensor): 87 | if self.transpose: 88 | x = x.transpose(1, 2) 89 | """ 90 | x = rearrange( 91 | x, "b t (g r) -> g b t r", g = self.G, r = self.R, 92 | ) 93 | """ 94 | x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3) 95 | feat = self.quantizer.get_output_from_indices(x) 96 | return feat.transpose_(1, 2) if self.transpose else feat 97 | 98 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 99 | return super().__call__(x) 100 | 101 | def forward(self, x: torch.Tensor) -> torch.Tensor: 102 | if self.transpose: 103 | x.transpose_(1, 
2) 104 | # feat, ind = self.quantizer(x) 105 | _, ind = self.quantizer(x) 106 | """ 107 | ind = rearrange( 108 | ind, "g b t r ->b t (g r)", 109 | ) 110 | """ 111 | ind = ind.permute(1, 2, 0, 3).contiguous() 112 | ind = ind.view(ind.size(0), ind.size(1), -1) 113 | """ 114 | embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind) 115 | embed_onehot = embed_onehot_tmp.to(x.dtype) 116 | del embed_onehot_tmp 117 | e_mean = torch.mean(embed_onehot, dim=[0, 1]) 118 | # e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1) 119 | torch.div(e_mean, (e_mean.sum(dim=1) + self.eps).unsqueeze(1), out=e_mean) 120 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1)) 121 | 122 | return 123 | torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device), 124 | feat.transpose_(1, 2) if self.transpose else feat, 125 | perplexity, 126 | """ 127 | return ind.transpose_(1, 2) if self.transpose else ind 128 | 129 | 130 | class DVAEDecoder(nn.Module): 131 | def __init__( 132 | self, 133 | idim: int, 134 | odim: int, 135 | n_layer=12, 136 | bn_dim=64, 137 | hidden=256, 138 | kernel=7, 139 | dilation=2, 140 | up=False, 141 | ): 142 | super().__init__() 143 | self.up = up 144 | self.conv_in = nn.Sequential( 145 | nn.Conv1d(idim, bn_dim, 3, 1, 1), 146 | nn.GELU(), 147 | nn.Conv1d(bn_dim, hidden, 3, 1, 1), 148 | ) 149 | self.decoder_block = nn.ModuleList( 150 | [ 151 | ConvNeXtBlock( 152 | hidden, 153 | hidden * 4, 154 | kernel, 155 | dilation, 156 | ) 157 | for _ in range(n_layer) 158 | ] 159 | ) 160 | self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) 161 | 162 | def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor: 163 | # B, C, T 164 | y = self.conv_in(x) 165 | del x 166 | for f in self.decoder_block: 167 | y = f(y, conditioning) 168 | 169 | x = self.conv_out(y) 170 | del y 171 | return x 172 | 173 | 174 | class MelSpectrogramFeatures(torch.nn.Module): 175 | def __init__( 176 | self, 177 | sample_rate=24000, 178 | n_fft=1024, 179 | hop_length=256, 180 | n_mels=100, 181 | padding: Literal["center", "same"] = "center", 182 | ): 183 | super().__init__() 184 | if padding not in ["center", "same"]: 185 | raise ValueError("Padding must be 'center' or 'same'.") 186 | self.padding = padding 187 | self.mel_spec = torchaudio.transforms.MelSpectrogram( 188 | sample_rate=sample_rate, 189 | n_fft=n_fft, 190 | hop_length=hop_length, 191 | n_mels=n_mels, 192 | center=padding == "center", 193 | power=1, 194 | ) 195 | 196 | def __call__(self, audio: torch.Tensor) -> torch.Tensor: 197 | return super().__call__(audio) 198 | 199 | def forward(self, audio: torch.Tensor) -> torch.Tensor: 200 | mel: torch.Tensor = self.mel_spec(audio) 201 | features = torch.log(torch.clip(mel, min=1e-5)) 202 | return features 203 | 204 | 205 | class DVAE(nn.Module): 206 | def __init__( 207 | self, 208 | decoder_config: dict, 209 | encoder_config: Optional[dict] = None, 210 | vq_config: Optional[dict] = None, 211 | dim=512, 212 | coef: Optional[str] = None, 213 | ): 214 | super().__init__() 215 | if coef is None: 216 | coef = torch.rand(100) 217 | else: 218 | coef = torch.from_numpy( 219 | np.copy(np.frombuffer(b14.decode_from_string(coef), dtype=np.float32)) 220 | ) 221 | self.register_buffer("coef", coef.unsqueeze(0).unsqueeze_(2)) 222 | 223 | if encoder_config is not None: 224 | self.downsample_conv = nn.Sequential( 225 | nn.Conv1d(100, dim, 3, 1, 1), 226 | nn.GELU(), 227 | nn.Conv1d(dim, dim, 4, 2, 1), 228 | nn.GELU(), 229 | ) 230 | self.preprocessor_mel = MelSpectrogramFeatures() 
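# Encode path sketch (see forward(mode="encode") below): waveform -> log-mel
# (100 bins) -> divide by self.coef -> strided Conv1d halving T -> encoder
# ConvNeXt stack -> GFSQ, which returns integer codebook indices. With the
# default VQ config (G=2 groups x R=2 residual quantizers) that is 4 index
# channels per frame, matching the GPT's num_vq = 4. A hypothetical round
# trip, assuming the DVAE was built with both encoder_config and vq_config:
#
#     ids = dvae(waveform, mode="encode")        # discrete codes
#     mel = dvae(ids.squeeze(0), mode="decode")  # mel reconstruction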
231 | self.encoder: Optional[DVAEDecoder] = DVAEDecoder(**encoder_config) 232 | 233 | self.decoder = DVAEDecoder(**decoder_config) 234 | self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False) 235 | if vq_config is not None: 236 | self.vq_layer = GFSQ(**vq_config) 237 | else: 238 | self.vq_layer = None 239 | 240 | def __repr__(self) -> str: 241 | return b14.encode_to_string( 242 | self.coef.cpu().numpy().astype(np.float32).tobytes() 243 | ) 244 | 245 | def __call__( 246 | self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode" 247 | ) -> torch.Tensor: 248 | return super().__call__(inp, mode) 249 | 250 | @torch.inference_mode() 251 | def forward( 252 | self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode" 253 | ) -> torch.Tensor: 254 | if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None: 255 | mel = self.preprocessor_mel(inp) 256 | x: torch.Tensor = self.downsample_conv( 257 | torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel), 258 | ).unsqueeze_(0) 259 | del mel 260 | x = self.encoder(x) 261 | ind = self.vq_layer(x) 262 | del x 263 | return ind 264 | 265 | if self.vq_layer is not None: 266 | vq_feats = self.vq_layer._embed(inp) 267 | else: 268 | vq_feats = inp 269 | 270 | vq_feats = ( 271 | vq_feats.view( 272 | (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)), 273 | ) 274 | .permute(0, 2, 3, 1) 275 | .flatten(2) 276 | ) 277 | 278 | dec_out = self.out_conv( 279 | self.decoder( 280 | x=vq_feats, 281 | ), 282 | ) 283 | 284 | del vq_feats 285 | 286 | return torch.mul(dec_out, self.coef, out=dec_out) 287 | -------------------------------------------------------------------------------- /ChatTTS/model/processors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from transformers.generation import TopKLogitsWarper, TopPLogitsWarper 4 | 5 | 6 | class CustomRepetitionPenaltyLogitsProcessorRepeat: 7 | 8 | def __init__(self, penalty: float, max_input_ids: int, past_window: int): 9 | if not isinstance(penalty, float) or not (penalty > 0): 10 | raise ValueError( 11 | f"`penalty` has to be a strictly positive float, but is {penalty}" 12 | ) 13 | 14 | self.penalty = penalty 15 | self.max_input_ids = max_input_ids 16 | self.past_window = past_window 17 | 18 | def __call__( 19 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 20 | ) -> torch.FloatTensor: 21 | if input_ids.size(1) > self.past_window: 22 | input_ids = input_ids.narrow(1, -self.past_window, self.past_window) 23 | freq = F.one_hot(input_ids, scores.size(1)).sum(1) 24 | if freq.size(0) > self.max_input_ids: 25 | freq.narrow( 26 | 0, self.max_input_ids, freq.size(0) - self.max_input_ids 27 | ).zero_() 28 | alpha = torch.pow(self.penalty, freq) 29 | scores = scores.contiguous() 30 | inp = scores.multiply(alpha) 31 | oth = scores.divide(alpha) 32 | con = scores < 0 33 | out = torch.where(con, inp, oth) 34 | del inp, oth, scores, con, alpha 35 | return out 36 | 37 | 38 | def gen_logits( 39 | num_code: int, 40 | top_P=0.7, 41 | top_K=20, 42 | repetition_penalty=1.0, 43 | ): 44 | logits_warpers = [] 45 | if top_P is not None: 46 | logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) 47 | if top_K is not None: 48 | logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) 49 | 50 | logits_processors = [] 51 | if repetition_penalty is not None and repetition_penalty != 1: 52 | logits_processors.append( 53 | 
CustomRepetitionPenaltyLogitsProcessorRepeat( 54 | repetition_penalty, num_code, 16 55 | ) 56 | ) 57 | 58 | return logits_warpers, logits_processors 59 | -------------------------------------------------------------------------------- /ChatTTS/model/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 4 | """ 5 | https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning 6 | """ 7 | 8 | from typing import List, Tuple, Optional 9 | import lzma 10 | 11 | import numpy as np 12 | import pybase16384 as b14 13 | import torch 14 | import torch.nn.functional as F 15 | from transformers import BertTokenizerFast 16 | 17 | from ..utils import del_all 18 | 19 | 20 | class Tokenizer: 21 | def __init__( 22 | self, tokenizer_path: torch.serialization.FILE_LIKE, device: torch.device 23 | ): 24 | tokenizer: BertTokenizerFast = torch.load( 25 | tokenizer_path, map_location=device, mmap=True 26 | ) 27 | self._tokenizer = tokenizer 28 | 29 | self.len = len(tokenizer) 30 | self.spk_emb_ids = tokenizer.convert_tokens_to_ids("[spk_emb]") 31 | self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]") 32 | self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]") 33 | 34 | self.decode = self._tokenizer.batch_decode 35 | 36 | @torch.inference_mode() 37 | def encode( 38 | self, 39 | text: List[str], 40 | num_vq: int, 41 | prompt_str: Optional[str] = None, 42 | device="cpu", 43 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 44 | 45 | input_ids_lst = [] 46 | attention_mask_lst = [] 47 | max_input_ids_len = -1 48 | max_attention_mask_len = -1 49 | prompt_size = 0 50 | 51 | prompt = self._decode_prompt(prompt_str) if prompt_str is not None else None 52 | 53 | if prompt is not None: 54 | assert prompt.size(0) == num_vq, "prompt dim 0 must equal to num_vq" 55 | prompt_size = prompt.size(1) 56 | 57 | # avoid random speaker embedding of tokenizer in the other dims 58 | for t in text: 59 | x = self._tokenizer.encode_plus( 60 | t, return_tensors="pt", add_special_tokens=False, padding=True 61 | ) 62 | input_ids_lst.append(x["input_ids"].squeeze_(0)) 63 | attention_mask_lst.append(x["attention_mask"].squeeze_(0)) 64 | del_all(x) 65 | ids_sz = input_ids_lst[-1].size(0) 66 | if ids_sz > max_input_ids_len: 67 | max_input_ids_len = ids_sz 68 | attn_sz = attention_mask_lst[-1].size(0) 69 | if attn_sz > max_attention_mask_len: 70 | max_attention_mask_len = attn_sz 71 | 72 | if prompt is not None: 73 | max_input_ids_len += prompt_size 74 | max_attention_mask_len += prompt_size 75 | 76 | input_ids = torch.zeros( 77 | len(input_ids_lst), 78 | max_input_ids_len, 79 | device=device, 80 | dtype=input_ids_lst[0].dtype, 81 | ) 82 | for i in range(len(input_ids_lst)): 83 | input_ids.narrow(0, i, 1).narrow( 84 | 1, 85 | max_input_ids_len - prompt_size - input_ids_lst[i].size(0), 86 | input_ids_lst[i].size(0), 87 | ).copy_( 88 | input_ids_lst[i] 89 | ) # left padding 90 | del_all(input_ids_lst) 91 | 92 | attention_mask = torch.zeros( 93 | len(attention_mask_lst), 94 | max_attention_mask_len, 95 | device=device, 96 | dtype=attention_mask_lst[0].dtype, 97 | ) 98 | for i in range(len(attention_mask_lst)): 99 | attn = attention_mask.narrow(0, i, 1) 100 | attn.narrow( 101 | 1, 102 | max_attention_mask_len - prompt_size - attention_mask_lst[i].size(0), 103 | attention_mask_lst[i].size(0), 104 | ).copy_( 105 | attention_mask_lst[i] 106 | ) # left padding 107 | if prompt_size > 0: 
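# Rows are left-padded, so a decoded prompt occupies the rightmost
# prompt_size positions of every sequence; the narrow().fill_(1) below
# marks exactly those trailing positions as attendable.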
108 | attn.narrow( 109 | 1, 110 | max_attention_mask_len - prompt_size, 111 | prompt_size, 112 | ).fill_(1) 113 | del_all(attention_mask_lst) 114 | 115 | text_mask = attention_mask.bool() 116 | new_input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq).clone() 117 | del input_ids 118 | 119 | if prompt_size > 0: 120 | text_mask.narrow(1, max_input_ids_len - prompt_size, prompt_size).fill_(0) 121 | prompt_t = prompt.t().unsqueeze_(0).expand(new_input_ids.size(0), -1, -1) 122 | new_input_ids.narrow( 123 | 1, 124 | max_input_ids_len - prompt_size, 125 | prompt_size, 126 | ).copy_(prompt_t) 127 | del prompt_t 128 | 129 | return new_input_ids, attention_mask, text_mask 130 | 131 | @staticmethod 132 | def _decode_spk_emb(spk_emb: str) -> np.ndarray: 133 | return np.frombuffer( 134 | lzma.decompress( 135 | b14.decode_from_string(spk_emb), 136 | format=lzma.FORMAT_RAW, 137 | filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], 138 | ), 139 | dtype=np.float16, 140 | ).copy() 141 | 142 | @torch.no_grad() 143 | def apply_spk_emb( 144 | self, 145 | emb: torch.Tensor, 146 | spk_emb: str, 147 | input_ids: torch.Tensor, 148 | device: torch.device, 149 | ): 150 | n = ( 151 | F.normalize( 152 | torch.from_numpy( 153 | self._decode_spk_emb(spk_emb), 154 | ), 155 | p=2.0, 156 | dim=0, 157 | eps=1e-12, 158 | ) 159 | .to(device) 160 | .unsqueeze_(0) 161 | .expand(emb.size(0), -1) 162 | .unsqueeze_(1) 163 | .expand(emb.shape) 164 | ) 165 | cond = input_ids.narrow(-1, 0, 1).eq(self.spk_emb_ids).expand(emb.shape) 166 | torch.where(cond, n, emb, out=emb) 167 | del cond, n 168 | 169 | @staticmethod 170 | @torch.no_grad() 171 | def _decode_prompt(prompt: str) -> torch.Tensor: 172 | dec = b14.decode_from_string(prompt) 173 | shp = np.frombuffer(dec[:4], dtype=" str: 188 | arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy() 189 | shp = arr.shape 190 | assert len(shp) == 2, "prompt must be a 2D tensor" 191 | s = b14.encode_to_string( 192 | np.array(shp, dtype=" str: 205 | arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy() 206 | s = b14.encode_to_string( 207 | lzma.compress( 208 | arr.tobytes(), 209 | format=lzma.FORMAT_RAW, 210 | filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], 211 | ), 212 | ) 213 | del arr 214 | return s 215 | -------------------------------------------------------------------------------- /ChatTTS/norm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import re 4 | from typing import Dict, Tuple, List, Literal, Callable, Optional 5 | import sys 6 | 7 | from numba import jit 8 | import numpy as np 9 | 10 | from .utils import del_all 11 | 12 | 13 | @jit 14 | def _find_index(table: np.ndarray, val: np.uint16): 15 | for i in range(table.size): 16 | if table[i] == val: 17 | return i 18 | return -1 19 | 20 | 21 | @jit 22 | def _fast_replace( 23 | table: np.ndarray, text: bytes 24 | ) -> Tuple[np.ndarray, List[Tuple[str, str]]]: 25 | result = np.frombuffer(text, dtype=np.uint16).copy() 26 | replaced_words = [] 27 | for i in range(result.size): 28 | ch = result[i] 29 | p = _find_index(table[0], ch) 30 | if p >= 0: 31 | repl_char = table[1][p] 32 | result[i] = repl_char 33 | replaced_words.append((chr(ch), chr(repl_char))) 34 | return result, replaced_words 35 | 36 | 37 | class Normalizer: 38 | def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)): 39 | self.logger = logger 40 | self.normalizers: Dict[str, Callable[[str], str]] = 
{} 41 | self.homophones_map = self._load_homophones_map(map_file_path) 42 | """ 43 | homophones_map 44 | 45 | Replace the mispronounced characters with correctly pronounced ones. 46 | 47 | Creation process of homophones_map.json: 48 | 49 | 1. Establish a word corpus using the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries. After cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text. 50 | 2. Record discrepancies between the inferred and input text, identifying about 180,000 misread words. 51 | 3. Create a pinyin to common characters mapping using correctly read characters by ChatTTS. 52 | 4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find homophones with the correct pronunciation from the mapping. 53 | 54 | Thanks to: 55 | [Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html) 56 | [python-pinyin](https://github.com/mozillazg/python-pinyin) 57 | 58 | """ 59 | self.coding = "utf-16-le" if sys.byteorder == "little" else "utf-16-be" 60 | self.reject_pattern = re.compile(r"[^\u4e00-\u9fffA-Za-z,。、,\. ]") 61 | self.sub_pattern = re.compile(r"\[uv_break\]|\[laugh\]|\[lbreak\]") 62 | self.chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]") 63 | self.english_word_pattern = re.compile(r"\b[A-Za-z]+\b") 64 | self.character_simplifier = str.maketrans( 65 | { 66 | ":": ",", 67 | ";": ",", 68 | "!": "。", 69 | "(": ",", 70 | ")": ",", 71 | "【": ",", 72 | "】": ",", 73 | "『": ",", 74 | "』": ",", 75 | "「": ",", 76 | "」": ",", 77 | "《": ",", 78 | "》": ",", 79 | "-": ",", 80 | ":": ",", 81 | ";": ",", 82 | "!": ".", 83 | "(": ",", 84 | ")": ",", 85 | #"[": ",", 86 | #"]": ",", 87 | ">": ",", 88 | "<": ",", 89 | "-": ",", 90 | } 91 | ) 92 | self.halfwidth_2_fullwidth = str.maketrans( 93 | { 94 | "!": "!", 95 | '"': "“", 96 | "'": "‘", 97 | "#": "#", 98 | "$": "$", 99 | "%": "%", 100 | "&": "&", 101 | "(": "(", 102 | ")": ")", 103 | ",": ",", 104 | "-": "-", 105 | "*": "*", 106 | "+": "+", 107 | ".": "。", 108 | "/": "/", 109 | ":": ":", 110 | ";": ";", 111 | "<": "<", 112 | "=": "=", 113 | ">": ">", 114 | "?": "?", 115 | "@": "@", 116 | # '[': '[', 117 | "\\": "＼", 118 | # ']': ']', 119 | "^": "^", 120 | # '_': '_', 121 | "`": "`", 122 | "{": "{", 123 | "|": "|", 124 | "}": "}", 125 | "~": "~", 126 | } 127 | ) 128 | 129 | def __call__( 130 | self, 131 | text: str, 132 | do_text_normalization=True, 133 | do_homophone_replacement=True, 134 | lang: Optional[Literal["zh", "en"]] = None, 135 | ) -> str: 136 | if do_text_normalization: 137 | _lang = self._detect_language(text) if lang is None else lang 138 | if _lang in self.normalizers: 139 | text = self.normalizers[_lang](text) 140 | if _lang == "zh": 141 | text = self._apply_half2full_map(text) 142 | invalid_characters = self._count_invalid_characters(text) 143 | if len(invalid_characters): 144 | self.logger.warning(f"found invalid characters: {invalid_characters}") 145 | text = self._apply_character_map(text) 146 | if do_homophone_replacement: 147 | arr, replaced_words = _fast_replace( 148 | self.homophones_map, 149 | text.encode(self.coding), 150 | ) 151 | if replaced_words: 152 | text = arr.tobytes().decode(self.coding) 153 | repl_res = ", ".join([f"{_[0]}->{_[1]}" for _ in replaced_words]) 154 | self.logger.info(f"replace homophones: {repl_res}") 155 | if len(invalid_characters): 156 | text = self.reject_pattern.sub("", text) 157 | return text 158 | 159 | def register(self, name: str, normalizer:
Callable[[str], str]) -> bool: 160 | if name in self.normalizers: 161 | self.logger.warning(f"name {name} has been registered") 162 | return False 163 | try: 164 | val = normalizer("test string 测试字符串") 165 | if not isinstance(val, str): 166 | self.logger.warning("normalizer must have caller type (str) -> str") 167 | return False 168 | except Exception as e: 169 | self.logger.warning(e) 170 | return False 171 | self.normalizers[name] = normalizer 172 | return True 173 | 174 | def unregister(self, name: str): 175 | if name in self.normalizers: 176 | del self.normalizers[name] 177 | 178 | def destroy(self): 179 | del_all(self.normalizers) 180 | del self.homophones_map 181 | 182 | def _load_homophones_map(self, map_file_path: str) -> np.ndarray: 183 | with open(map_file_path, "r", encoding="utf-8") as f: 184 | homophones_map: Dict[str, str] = json.load(f) 185 | map = np.empty((2, len(homophones_map)), dtype=np.uint32) 186 | for i, k in enumerate(homophones_map.keys()): 187 | map[:, i] = (ord(k), ord(homophones_map[k])) 188 | del homophones_map 189 | return map 190 | 191 | def _count_invalid_characters(self, s: str): 192 | s = self.sub_pattern.sub("", s) 193 | non_alphabetic_chinese_chars = self.reject_pattern.findall(s) 194 | return set(non_alphabetic_chinese_chars) 195 | 196 | def _apply_half2full_map(self, text: str) -> str: 197 | return text.translate(self.halfwidth_2_fullwidth) 198 | 199 | def _apply_character_map(self, text: str) -> str: 200 | return text.translate(self.character_simplifier) 201 | 202 | def _detect_language(self, sentence: str) -> Literal["zh", "en"]: 203 | chinese_chars = self.chinese_char_pattern.findall(sentence) 204 | english_words = self.english_word_pattern.findall(sentence) 205 | 206 | if len(chinese_chars) > len(english_words): 207 | return "zh" 208 | else: 209 | return "en" 210 | -------------------------------------------------------------------------------- /ChatTTS/res/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/ChatTTS/res/__init__.py -------------------------------------------------------------------------------- /ChatTTS/res/sha256_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "sha256_asset_Decoder_pt" : "9964e36e840f0e3a748c5f716fe6de6490d2135a5f5155f4a642d51860e2ec38", 3 | "sha256_asset_DVAE_full_pt" : "553eb75763511e23f3e5f86303e2163c5ca775489d637fb635d979c8ae58bbe5", 4 | "sha256_asset_GPT_pt" : "d7d4ee6461ea097a2be23eb40d73fb94ad3b3d39cb64fbb50cb3357fd466cadb", 5 | "sha256_asset_spk_stat_pt" : "3228d8a4cbbf349d107a1b76d2f47820865bd3c9928c4bdfe1cefd5c7071105f", 6 | "sha256_asset_tokenizer_pt" : "e911ae7c6a7c27953433f35c44227a67838fe229a1f428503bdb6cd3d1bcc69c", 7 | "sha256_asset_Vocos_pt" : "09a670eda1c08b740013679c7a90ebb7f1a97646ea7673069a6838e6b51d6c58" 8 | } 9 | -------------------------------------------------------------------------------- /ChatTTS/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dl import check_all_assets, download_all_assets 2 | from .gpu import select_device 3 | from .io import get_latest_modified_file, del_all 4 | from .log import logger 5 | -------------------------------------------------------------------------------- /ChatTTS/utils/dl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 
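# Asset-verification flow in this module (a sketch of intended use; the
# caller below is hypothetical): check_all_assets() hashes every checkpoint
# under asset/ with sha256(), which mmaps the whole file, and compares it
# against ChatTTS/res/sha256_map.json; with update=True a mismatched file is
# renamed to *.bak, after which download_all_assets() fetches fresh copies
# through the external rvcmd tool, falling back to a gitcode mirror when the
# GitHub release is unreachable.
#
#     from ChatTTS.utils import check_all_assets, download_all_assets
#     if not check_all_assets(Path("."), sha256_map, update=True):
#         with tempfile.TemporaryDirectory() as tmp:
#             download_all_assets(tmp)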
import hashlib 4 | import requests 5 | from io import BytesIO 6 | from typing import Dict 7 | from mmap import mmap, ACCESS_READ 8 | 9 | from .log import logger 10 | 11 | 12 | def sha256(fileno: int) -> str: 13 | data = mmap(fileno, 0, access=ACCESS_READ) 14 | h = hashlib.sha256(data).hexdigest() 15 | del data 16 | return h 17 | 18 | 19 | def check_model( 20 | dir_name: Path, model_name: str, hash: str, remove_incorrect=False 21 | ) -> bool: 22 | target = dir_name / model_name 23 | relname = target.as_posix() 24 | logger.get_logger().debug(f"checking {relname}...") 25 | if not os.path.exists(target): 26 | logger.get_logger().info(f"{target} not exist.") 27 | return False 28 | with open(target, "rb") as f: 29 | digest = sha256(f.fileno()) 30 | bakfile = f"{target}.bak" 31 | if digest != hash: 32 | logger.get_logger().warning(f"{target} sha256 hash mismatch.") 33 | logger.get_logger().info(f"expected: {hash}") 34 | logger.get_logger().info(f"real val: {digest}") 35 | if remove_incorrect: 36 | if not os.path.exists(bakfile): 37 | os.rename(str(target), bakfile) 38 | else: 39 | os.remove(str(target)) 40 | return False 41 | if remove_incorrect and os.path.exists(bakfile): 42 | os.remove(bakfile) 43 | return True 44 | 45 | 46 | def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=False) -> bool: 47 | logger.get_logger().info("checking assets...") 48 | current_dir = base_dir / "asset" 49 | names = [ 50 | "Decoder.pt", 51 | "DVAE_full.pt", 52 | "GPT.pt", 53 | "spk_stat.pt", 54 | "tokenizer.pt", 55 | "Vocos.pt", 56 | ] 57 | for model in names: 58 | menv = model.replace(".", "_") 59 | if not check_model( 60 | current_dir, model, sha256_map[f"sha256_asset_{menv}"], update 61 | ): 62 | return False 63 | 64 | logger.get_logger().info("all assets are already latest.") 65 | return True 66 | 67 | 68 | def download_and_extract_tar_gz(url: str, folder: str): 69 | import tarfile 70 | 71 | logger.get_logger().info(f"downloading {url}") 72 | response = requests.get(url, stream=True, timeout=(5, 10)) 73 | with BytesIO() as out_file: 74 | out_file.write(response.content) 75 | out_file.seek(0) 76 | logger.get_logger().info(f"downloaded.") 77 | with tarfile.open(fileobj=out_file, mode="r:gz") as tar: 78 | tar.extractall(folder) 79 | logger.get_logger().info(f"extracted into {folder}") 80 | 81 | 82 | def download_and_extract_zip(url: str, folder: str): 83 | import zipfile 84 | 85 | logger.get_logger().info(f"downloading {url}") 86 | response = requests.get(url, stream=True, timeout=(5, 10)) 87 | with BytesIO() as out_file: 88 | out_file.write(response.content) 89 | out_file.seek(0) 90 | logger.get_logger().info(f"downloaded.") 91 | with zipfile.ZipFile(out_file) as zip_ref: 92 | zip_ref.extractall(folder) 93 | logger.get_logger().info(f"extracted into {folder}") 94 | 95 | 96 | def download_dns_yaml(url: str, folder: str): 97 | logger.get_logger().info(f"downloading {url}") 98 | response = requests.get(url, stream=True, timeout=(5, 10)) 99 | with open(os.path.join(folder, "dns.yaml"), "wb") as out_file: 100 | out_file.write(response.content) 101 | logger.get_logger().info(f"downloaded into {folder}") 102 | 103 | 104 | def download_all_assets(tmpdir: str, version="0.2.6"): 105 | import subprocess 106 | import platform 107 | 108 | archs = { 109 | "aarch64": "arm64", 110 | "armv8l": "arm64", 111 | "arm64": "arm64", 112 | "x86": "386", 113 | "i386": "386", 114 | "i686": "386", 115 | "386": "386", 116 | "x86_64": "amd64", 117 | "x64": "amd64", 118 | "amd64": "amd64", 119 | } 120 | system_type = 
platform.system().lower() 121 | architecture = platform.machine().lower() 122 | is_win = system_type == "windows" 123 | 124 | architecture = archs.get(architecture, None) 125 | if not architecture: 126 | logger.get_logger().error(f"architecture {architecture} is not supported") 127 | exit(1) 128 | try: 129 | BASE_URL = "https://github.com/fumiama/RVC-Models-Downloader/releases/download/" 130 | suffix = "zip" if is_win else "tar.gz" 131 | RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}" 132 | cmdfile = os.path.join(tmpdir, "rvcmd") 133 | print(f'{RVCMD_URL=},{tmpdir=}') 134 | if is_win: 135 | download_and_extract_zip(RVCMD_URL, tmpdir) 136 | cmdfile += ".exe" 137 | else: 138 | download_and_extract_tar_gz(RVCMD_URL, tmpdir) 139 | os.chmod(cmdfile, 0o755) 140 | print(f'{cmdfile=}') 141 | subprocess.run([cmdfile, "-notui", "-w", "0", "assets/chtts"]) 142 | except Exception: 143 | BASE_URL = "https://raw.gitcode.com/u011570312/RVC-Models-Downloader/assets/" 144 | suffix = { 145 | "darwin_amd64": "987", 146 | "darwin_arm64": "988", 147 | "linux_386": "989", 148 | "linux_amd64": "990", 149 | "linux_arm64": "991", 150 | "windows_386": "992", 151 | "windows_amd64": "993", 152 | }[f"{system_type}_{architecture}"] 153 | RVCMD_URL = BASE_URL + suffix 154 | download_dns_yaml( 155 | "https://raw.gitcode.com/u011570312/RVC-Models-Downloader/raw/main/dns.yaml", 156 | tmpdir, 157 | ) 158 | if is_win: 159 | download_and_extract_zip(RVCMD_URL, tmpdir) 160 | cmdfile += ".exe" 161 | else: 162 | download_and_extract_tar_gz(RVCMD_URL, tmpdir) 163 | os.chmod(cmdfile, 0o755) 164 | subprocess.run( 165 | [ 166 | cmdfile, 167 | "-notui", 168 | "-w", 169 | "0", 170 | "-dns", 171 | os.path.join(tmpdir, "dns.yaml"), 172 | "assets/chtts", 173 | ] 174 | ) 175 | -------------------------------------------------------------------------------- /ChatTTS/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import hashlib 4 | import requests 5 | from io import BytesIO 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def sha256(f) -> str: 12 | sha256_hash = hashlib.sha256() 13 | # Read and update hash in chunks of 4M 14 | for byte_block in iter(lambda: f.read(4 * 1024 * 1024), b""): 15 | sha256_hash.update(byte_block) 16 | return sha256_hash.hexdigest() 17 | 18 | 19 | def check_model( 20 | dir_name: Path, model_name: str, hash: str, remove_incorrect=False 21 | ) -> bool: 22 | target = dir_name / model_name 23 | relname = target.as_posix() 24 | logger.debug(f"checking {relname}...") 25 | if not os.path.exists(target): 26 | logger.info(f"{target} not exist.") 27 | return False 28 | with open(target, "rb") as f: 29 | digest = sha256(f) 30 | bakfile = f"{target}.bak" 31 | if digest != hash: 32 | logger.warning(f"{target} sha256 hash mismatch.") 33 | logger.info(f"expected: {hash}") 34 | logger.info(f"real val: {digest}") 35 | logger.warning("please add parameter --update to download the latest assets.") 36 | if remove_incorrect: 37 | if not os.path.exists(bakfile): 38 | os.rename(str(target), bakfile) 39 | else: 40 | os.remove(str(target)) 41 | return False 42 | if remove_incorrect and os.path.exists(bakfile): 43 | os.remove(bakfile) 44 | return True 45 | 46 | 47 | def check_all_assets(update=False) -> bool: 48 | BASE_DIR = Path(__file__).resolve().parent.parent.parent 49 | 50 | logger.info("checking assets...") 51 | current_dir = BASE_DIR / "asset" 52 | names = [ 53 | "Decoder.pt", 54
| "DVAE.pt", 55 | "GPT.pt", 56 | "spk_stat.pt", 57 | "tokenizer.pt", 58 | "Vocos.pt", 59 | ] 60 | for model in names: 61 | menv = model.replace(".", "_") 62 | if not check_model( 63 | current_dir, model, os.environ[f"sha256_asset_{menv}"], update 64 | ): 65 | return False 66 | 67 | logger.info("checking configs...") 68 | current_dir = BASE_DIR / "config" 69 | names = [ 70 | "decoder.yaml", 71 | "dvae.yaml", 72 | "gpt.yaml", 73 | "path.yaml", 74 | "vocos.yaml", 75 | ] 76 | for model in names: 77 | menv = model.replace(".", "_") 78 | if not check_model( 79 | current_dir, model, os.environ[f"sha256_config_{menv}"], update 80 | ): 81 | return False 82 | 83 | logger.info("all assets are already latest.") 84 | return True 85 | 86 | 87 | def download_and_extract_tar_gz(url: str, folder: str): 88 | import tarfile 89 | 90 | logger.info(f"downloading {url}") 91 | response = requests.get(url, stream=True, timeout=(5, 10)) 92 | with BytesIO() as out_file: 93 | out_file.write(response.content) 94 | out_file.seek(0) 95 | logger.info(f"downloaded.") 96 | with tarfile.open(fileobj=out_file, mode="r:gz") as tar: 97 | tar.extractall(folder) 98 | logger.info(f"extracted into {folder}") 99 | 100 | 101 | def download_and_extract_zip(url: str, folder: str): 102 | import zipfile 103 | 104 | logger.info(f"downloading {url}") 105 | response = requests.get(url, stream=True, timeout=(5, 10)) 106 | with BytesIO() as out_file: 107 | out_file.write(response.content) 108 | out_file.seek(0) 109 | logger.info(f"downloaded.") 110 | with zipfile.ZipFile(out_file) as zip_ref: 111 | zip_ref.extractall(folder) 112 | logger.info(f"extracted into {folder}") 113 | 114 | 115 | def download_dns_yaml(url: str, folder: str): 116 | logger.info(f"downloading {url}") 117 | response = requests.get(url, stream=True, timeout=(5, 10)) 118 | with open(os.path.join(folder, "dns.yaml"), "wb") as out_file: 119 | out_file.write(response.content) 120 | logger.info(f"downloaded into {folder}") 121 | 122 | 123 | def download_all_assets(tmpdir: str, version="0.2.5"): 124 | import subprocess 125 | import platform 126 | 127 | archs = { 128 | "aarch64": "arm64", 129 | "armv8l": "arm64", 130 | "arm64": "arm64", 131 | "x86": "386", 132 | "i386": "386", 133 | "i686": "386", 134 | "386": "386", 135 | "x86_64": "amd64", 136 | "x64": "amd64", 137 | "amd64": "amd64", 138 | } 139 | system_type = platform.system().lower() 140 | architecture = platform.machine().lower() 141 | is_win = system_type == "windows" 142 | 143 | architecture = archs.get(architecture, None) 144 | if not architecture: 145 | logger.error(f"architecture {architecture} is not supported") 146 | exit(1) 147 | try: 148 | BASE_URL = "https://github.com/fumiama/RVC-Models-Downloader/releases/download/" 149 | suffix = "zip" if is_win else "tar.gz" 150 | RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}" 151 | cmdfile = os.path.join(tmpdir, "rvcmd") 152 | if is_win: 153 | download_and_extract_zip(RVCMD_URL, tmpdir) 154 | cmdfile += ".exe" 155 | else: 156 | download_and_extract_tar_gz(RVCMD_URL, tmpdir) 157 | os.chmod(cmdfile, 0o755) 158 | subprocess.run([cmdfile, "-notui", "-w", "0", "assets/chtts"]) 159 | except Exception: 160 | BASE_URL = "https://raw.gitcode.com/u011570312/RVC-Models-Downloader/assets/" 161 | suffix = { 162 | "darwin_amd64": "555", 163 | "darwin_arm64": "556", 164 | "linux_386": "557", 165 | "linux_amd64": "558", 166 | "linux_arm64": "559", 167 | "windows_386": "562", 168 | "windows_amd64": "563", 169 | }[f"{system_type}_{architecture}"] 170 | 
RVCMD_URL = BASE_URL + suffix 171 | download_dns_yaml( 172 | "https://raw.gitcode.com/u011570312/RVC-Models-Downloader/raw/main/dns.yaml", 173 | tmpdir, 174 | ) 175 | if is_win: 176 | download_and_extract_zip(RVCMD_URL, tmpdir) 177 | cmdfile += ".exe" 178 | else: 179 | download_and_extract_tar_gz(RVCMD_URL, tmpdir) 180 | os.chmod(cmdfile, 0o755) 181 | subprocess.run( 182 | [ 183 | cmdfile, 184 | "-notui", 185 | "-w", 186 | "0", 187 | "-dns", 188 | os.path.join(tmpdir, "dns.yaml"), 189 | "assets/chtts", 190 | ] 191 | ) 192 | -------------------------------------------------------------------------------- /ChatTTS/utils/gpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .log import logger 4 | 5 | 6 | def select_device(min_memory=2047, experimental=False): 7 | if torch.cuda.is_available(): 8 | selected_gpu = 0 9 | max_free_memory = -1 10 | for i in range(torch.cuda.device_count()): 11 | props = torch.cuda.get_device_properties(i) 12 | free_memory = props.total_memory - torch.cuda.memory_reserved(i) 13 | if max_free_memory < free_memory: 14 | selected_gpu = i 15 | max_free_memory = free_memory 16 | free_memory_mb = max_free_memory / (1024 * 1024) 17 | if free_memory_mb < min_memory: 18 | logger.get_logger().warning( 19 | f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU." 20 | ) 21 | device = torch.device("cpu") 22 | else: 23 | device = torch.device(f"cuda:{selected_gpu}") 24 | elif torch.backends.mps.is_available(): 25 | """ 26 | Currently MPS is slower than CPU while needing more memory and core utilization, 27 | so it is only enabled for experimental use. 28 | """ 29 | if experimental: 30 | # For Apple M1/M2 chips with Metal Performance Shaders 31 | logger.get_logger().warning("experimental: found apple GPU, using MPS.") 32 | device = torch.device("mps") 33 | else: 34 | logger.get_logger().info("found Apple GPU, but using CPU.") 35 | device = torch.device("cpu") 36 | else: 37 | logger.get_logger().warning("no GPU found, using CPU instead") 38 | device = torch.device("cpu") 39 | 40 | return device 41 | -------------------------------------------------------------------------------- /ChatTTS/utils/gpu_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import logging 4 | 5 | def select_device(min_memory=2048): 6 | logger = logging.getLogger(__name__) 7 | if torch.cuda.is_available(): 8 | available_gpus = [] 9 | for i in range(torch.cuda.device_count()): 10 | props = torch.cuda.get_device_properties(i) 11 | free_memory = props.total_memory - torch.cuda.memory_reserved(i) 12 | available_gpus.append((i, free_memory)) 13 | selected_gpu, max_free_memory = max(available_gpus, key=lambda x: x[1]) 14 | device = torch.device(f'cuda:{selected_gpu}') 15 | free_memory_mb = max_free_memory / (1024 * 1024) 16 | if free_memory_mb < min_memory: 17 | logger.warning(f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.
Switching to CPU.') 18 | device = torch.device('cpu') 19 | elif torch.backends.mps.is_available(): 20 | # For Apple M1/M2 chips with Metal Performance Shaders 21 | logger.info('Apple GPU found, using MPS.') 22 | device = torch.device('mps') 23 | else: 24 | logger.warning('No GPU found, use CPU instead') 25 | device = torch.device('cpu') 26 | 27 | return device 28 | -------------------------------------------------------------------------------- /ChatTTS/utils/infer_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import torch 4 | import torch.nn.functional as F 5 | import os 6 | import json 7 | 8 | 9 | class CustomRepetitionPenaltyLogitsProcessorRepeat(): 10 | 11 | def __init__(self, penalty: float, max_input_ids, past_window): 12 | if not isinstance(penalty, float) or not (penalty > 0): 13 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") 14 | 15 | self.penalty = penalty 16 | self.max_input_ids = max_input_ids 17 | self.past_window = past_window 18 | 19 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 20 | 21 | input_ids = input_ids[:, -self.past_window:] 22 | freq = F.one_hot(input_ids, scores.size(1)).sum(1) 23 | freq[self.max_input_ids:] = 0 24 | alpha = self.penalty**freq 25 | scores = scores.contiguous() 26 | scores = torch.where(scores < 0, scores*alpha, scores/alpha) 27 | 28 | return scores 29 | 30 | class CustomRepetitionPenaltyLogitsProcessor(): 31 | 32 | def __init__(self, penalty: float, max_input_ids, past_window): 33 | if not isinstance(penalty, float) or not (penalty > 0): 34 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") 35 | 36 | self.penalty = penalty 37 | self.max_input_ids = max_input_ids 38 | self.past_window = past_window 39 | 40 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 41 | 42 | input_ids = input_ids[:, -self.past_window:] 43 | score = torch.gather(scores, 1, input_ids) 44 | _score = score.detach().clone() 45 | score = torch.where(score < 0, score * self.penalty, score / self.penalty) 46 | score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids] 47 | scores.scatter_(1, input_ids, score) 48 | 49 | return scores 50 | 51 | class HomophonesReplacer: 52 | """ 53 | Homophones Replacer 54 | 55 | Replace the mispronounced characters with correctly pronounced ones. 56 | 57 | Creation process of homophones_map.json: 58 | 59 | 1. Establish a word corpus using the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries. After cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text. 60 | 2. Record discrepancies between the inferred and input text, identifying about 180,000 misread words. 61 | 3. Create a pinyin to common characters mapping using correctly read characters by ChatTTS. 62 | 4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find homophones with the correct pronunciation from the mapping. 
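    Usage sketch (a hypothetical call; the map file path is assumed from the
    repo layout, ChatTTS/res/homophones_map.json):

        replacer = HomophonesReplacer('ChatTTS/res/homophones_map.json')
        corrected = replacer.replace(text)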
63 | 64 | Thanks to: 65 | [Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html) 66 | [python-pinyin](https://github.com/mozillazg/python-pinyin) 67 | 68 | """ 69 | def __init__(self, map_file_path): 70 | self.homophones_map = self.load_homophones_map(map_file_path) 71 | 72 | def load_homophones_map(self, map_file_path): 73 | with open(map_file_path, 'r', encoding='utf-8') as f: 74 | homophones_map = json.load(f) 75 | return homophones_map 76 | 77 | def replace(self, text): 78 | result = [] 79 | for char in text: 80 | if char in self.homophones_map: 81 | result.append(self.homophones_map[char]) 82 | else: 83 | result.append(char) 84 | return ''.join(result) 85 | 86 | def count_invalid_characters(s): 87 | 88 | s = re.sub(r'\[uv_break\]|\[laugh\]|\[lbreak\]', '', s) 89 | pattern = re.compile(r'[^\u4e00-\u9fffA-Za-z,。、,\. ]') 90 | non_alphabetic_chinese_chars = pattern.findall(s) 91 | return set(non_alphabetic_chinese_chars) 92 | 93 | def detect_language(sentence): 94 | 95 | chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]') 96 | english_word_pattern = re.compile(r'\b[A-Za-z]+\b') 97 | 98 | chinese_chars = chinese_char_pattern.findall(sentence) 99 | english_words = english_word_pattern.findall(sentence) 100 | 101 | if len(chinese_chars) > len(english_words): 102 | return "zh" 103 | else: 104 | return "en" 105 | 106 | 107 | character_map = { 108 | ':': ',', 109 | ';': ',', 110 | '!': '。', 111 | '(': ',', 112 | ')': ',', 113 | '【': ',', 114 | '】': ',', 115 | '『': ',', 116 | '』': ',', 117 | '「': ',', 118 | '」': ',', 119 | '《': ',', 120 | '》': ',', 121 | '-': ',', 122 | '‘': '', 123 | '“': '', 124 | '’': '', 125 | '”': '', 126 | ':': ',', 127 | ';': ',', 128 | '!': '.', 129 | '(': ',', 130 | ')': ',', 131 | #'[': ',', 132 | #']': ',', 133 | '>': ',', 134 | '<': ',', 135 | '-': ',', 136 | } 137 | 138 | halfwidth_2_fullwidth_map = { 139 | '!': '!', 140 | '"': '“', 141 | "'": '‘', 142 | '#': '#', 143 | '$': '$', 144 | '%': '%', 145 | '&': '&', 146 | '(': '(', 147 | ')': ')', 148 | ',': ',', 149 | '-': '-', 150 | '*': '*', 151 | '+': '+', 152 | '.': '。', 153 | '/': '/', 154 | ':': ':', 155 | ';': ';', 156 | '<': '<', 157 | '=': '=', 158 | '>': '>', 159 | '?': '?', 160 | '@': '@', 161 | # '[': '[', 162 | '\\': '\', 163 | # ']': ']', 164 | '^': '^', 165 | # '_': '_', 166 | '`': '`', 167 | '{': '{', 168 | '|': '|', 169 | '}': '}', 170 | '~': '~' 171 | } 172 | 173 | def apply_half2full_map(text): 174 | translation_table = str.maketrans(halfwidth_2_fullwidth_map) 175 | return text.translate(translation_table) 176 | 177 | def apply_character_map(text): 178 | translation_table = str.maketrans(character_map) 179 | return text.translate(translation_table) 180 | -------------------------------------------------------------------------------- /ChatTTS/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Union 4 | from dataclasses import is_dataclass 5 | 6 | from .log import logger 7 | 8 | 9 | def get_latest_modified_file(directory): 10 | 11 | files = [os.path.join(directory, f) for f in os.listdir(directory)] 12 | if not files: 13 | logger.get_logger().log( 14 | logging.WARNING, f"no files found in the directory: {directory}" 15 | ) 16 | return None 17 | latest_file = max(files, key=os.path.getmtime) 18 | 19 | return latest_file 20 | 21 | 22 | def del_all(d: Union[dict, list]): 23 | if is_dataclass(d): 24 | for k in list(vars(d).keys()): 25
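            # the keys are snapshotted with list(...) in the loop header above,
            # because attributes are deleted from the dataclass while iterating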
| x = getattr(d, k) 26 | if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x): 27 | del_all(x) 28 | del x 29 | delattr(d, k) 30 | elif isinstance(d, dict): 31 | lst = list(d.keys()) 32 | for k in lst: 33 | x = d.pop(k) 34 | if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x): 35 | del_all(x) 36 | del x 37 | elif isinstance(d, list): 38 | while len(d): 39 | x = d.pop() 40 | if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x): 41 | del_all(x) 42 | del x 43 | else: 44 | del d 45 | -------------------------------------------------------------------------------- /ChatTTS/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import logging 4 | 5 | def get_latest_modified_file(directory): 6 | logger = logging.getLogger(__name__) 7 | 8 | files = [os.path.join(directory, f) for f in os.listdir(directory)] 9 | if not files: 10 | logger.log(logging.WARNING, f'No files found in the directory: {directory}') 11 | return None 12 | latest_file = max(files, key=os.path.getmtime) 13 | 14 | return latest_file -------------------------------------------------------------------------------- /ChatTTS/utils/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | 5 | class Logger: 6 | def __init__(self, logger=logging.getLogger(Path(__file__).parent.name)): 7 | self.logger = logger 8 | 9 | def set_logger(self, logger: logging.Logger): 10 | self.logger = logger 11 | 12 | def get_logger(self) -> logging.Logger: 13 | return self.logger 14 | 15 | 16 | logger = Logger() 17 | -------------------------------------------------------------------------------- /Dockerfile.cpu: -------------------------------------------------------------------------------- 1 | FROM pytorch/torchserve:0.11.0-cpu as builder 2 | 3 | USER root 4 | 5 | RUN apt-get update && apt-get install -y ffmpeg 6 | 7 | WORKDIR /app 8 | 9 | COPY . ./ 10 | 11 | RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple 12 | -------------------------------------------------------------------------------- /Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | FROM pytorch/torchserve:0.11.0-gpu as builder 2 | 3 | USER root 4 | 5 | RUN apt-get update && apt-get install -y ffmpeg 6 | 7 | WORKDIR /app 8 | 9 | COPY . ./ 10 | 11 | RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | [English README](README_EN.md) | [打赏项目](https://github.com/jianchang512/ChatTTS-ui/issues/122) | [Discord Discussion Group](https://discord.gg/y9gUweVCCJ) 3 | 4 | 5 | # ChatTTS webUI & API 6 | 7 | 一个简单的本地网页界面,在网页使用 ChatTTS 将文字合成为语音,支持中英文、数字混杂,并提供API接口. 8 | 9 | 10 | 原 [ChatTTS](https://github.com/2noise/chattts) 项目. 
0.96版起,源码部署必须先安装ffmpeg,之前的音色文件csv和pt已不可用,请填写音色值重新生成. [获取音色](?tab=readme-ov-file#音色获取) 11 | 12 | 13 | > **[赞助商]** 14 | > 15 | > [![](https://github.com/user-attachments/assets/5348c86e-2d5f-44c7-bc1b-3cc5f077e710)](https://gpt302.saaslink.net/teRK8Y) 16 | > [302.AI](https://gpt302.saaslink.net/teRK8Y)是一个按需付费的一站式AI应用平台,开放平台,开源生态, [302.AI开源地址](https://github.com/302ai) 17 | > 18 | > 集合了最新最全的AI模型和品牌/按需付费零月费/管理和使用分离/所有AI能力均提供API/每周推出2-3个新应用 19 | 20 | **界面预览** 21 | 22 | ![image](https://github.com/jianchang512/ChatTTS-ui/assets/3378335/669876cf-5061-4d7d-86c5-3333d0882ee8) 23 | 24 | 25 | 26 | 27 | 28 | 29 | 文字数字符号 控制符混杂效果 30 | 31 | https://github.com/jianchang512/ChatTTS-ui/assets/3378335/e2a08ea0-32af-4a30-8880-3a91f6cbea55 32 | 33 | 34 | ## Windows预打包版 35 | 36 | 1. 从 [Releases](https://github.com/jianchang512/chatTTS-ui/releases)中下载压缩包,解压后双击 app.exe 即可使用 37 | 2. 某些安全软件可能报毒,请退出或使用源码部署 38 | 3. 英伟达显卡大于4G显存,并安装了CUDA11.8+后,将启用GPU加速 39 | 40 | ## 手动下载模型 41 | 42 | 第一次将从huggingface.co或github下载模型到asset目录下,如果网络不稳,可能下载失败,若是失败,请单独下载 43 | 44 | 下载后解压后,会看到asset文件夹,该文件夹内有多个pt文件,将所有pt文件复制到asset目录下,然后重启软件 45 | 46 | GitHub下载地址: https://github.com/jianchang512/ChatTTS-ui/releases/download/v1.0/all-models.7z 47 | 48 | 百度网盘下载地址: https://pan.baidu.com/s/1yGDZM9YNN7kW9e7SFo8lLw?pwd=ct5x 49 | 50 | 51 | 52 | ## Linux 下容器部署 53 | 54 | ### 安装 55 | 56 | 1. 拉取项目仓库 57 | 58 | 在任意路径下克隆项目,例如: 59 | 60 | ```bash 61 | git clone https://github.com/jianchang512/ChatTTS-ui.git chat-tts-ui 62 | ``` 63 | 64 | 2. 启动 Runner 65 | 66 | 进入到项目目录: 67 | 68 | ```bash 69 | cd chat-tts-ui 70 | ``` 71 | 72 | 启动容器并查看初始化日志: 73 | 74 | ```bash 75 | gpu版本 76 | docker compose -f docker-compose.gpu.yaml up -d 77 | 78 | cpu版本 79 | docker compose -f docker-compose.cpu.yaml up -d 80 | 81 | docker compose logs -f --no-log-prefix 82 | ``` 83 | 3. 访问 ChatTTS WebUI 84 | 85 | `启动:['0.0.0.0', '9966']`,也即,访问部署设备的 `IP:9966` 即可,例如: 86 | 87 | - 本机:`http://127.0.0.1:9966` 88 | - 服务器: `http://192.168.1.100:9966` 89 | 90 | ### 更新 91 | 92 | 1. Get the latest code from the main branch: 93 | 94 | ```bash 95 | git checkout main 96 | git pull origin main 97 | ``` 98 | 99 | 2. Go to the next step and update to the latest image: 100 | 101 | ```bash 102 | docker compose down 103 | 104 | gpu版本 105 | docker compose -f docker-compose.gpu.yaml up -d --build 106 | 107 | cpu版本 108 | docker compose -f docker-compose.cpu.yaml up -d --build 109 | 110 | docker compose logs -f --no-log-prefix 111 | ``` 112 | 113 | ## Linux 下源码部署 114 | 115 | 1. 配置好 python3.9-3.11环境,安装 ffmpeg。 `yum install ffmpeg` 或 `apt-get install ffmpeg`等 116 | 2. 创建空目录 `/data/chattts` 执行命令 `cd /data/chattts && git clone https://github.com/jianchang512/chatTTS-ui .` 117 | 3. 创建虚拟环境 `python3 -m venv venv` 118 | 4. 激活虚拟环境 `source ./venv/bin/activate` 119 | 5. 安装依赖 `pip3 install -r requirements.txt` 120 | 6. 如果不需要CUDA加速,执行 121 | 122 | `pip3 install torch==2.2.0 torchaudio==2.2.0` 123 | 124 | 如果需要CUDA加速,执行 125 | 126 | ``` 127 | pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118 128 | 129 | pip install nvidia-cublas-cu11 nvidia-cudnn-cu11 130 | 131 | ``` 132 | 133 | 另需安装 CUDA11.8+ ToolKit,请自行搜索安装方法 或参考 https://juejin.cn/post/7318704408727519270 134 | 135 | 除CUDA外,也可以使用AMD GPU进行加速,这需要安装ROCm和PyTorch_ROCm版本。AMD GPU借助ROCm,在PyTorch开箱即用,无需额外修改代码。 136 | 1. 请参考https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html 来安装AMD GPU Driver及ROCm. 137 | 2. 
再通过https://pytorch.org/ 安装PyTorch_ROCm版本。 138 | 139 | 140 | `pip3 install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/rocm6.0` 141 | 142 | 安装完成后,可以通过rocm-smi命令来查看系统中的AMD GPU。也可以用以下Torch代码(query_gpu.py)来查询当前AMD GPU Device. 143 | 144 | ``` 145 | import torch 146 | 147 | print(torch.__version__) 148 | 149 | if torch.cuda.is_available(): 150 | device = torch.device("cuda") # a CUDA device object 151 | print('Using GPU:', torch.cuda.get_device_name(0)) 152 | else: 153 | device = torch.device("cpu") 154 | print('Using CPU') 155 | 156 | torch.cuda.get_device_properties(0) 157 | 158 | ``` 159 | 160 | 使用以上代码,以AMD Radeon Pro W7900为例,查询设备如下。 161 | 162 | ``` 163 | 164 | $ python ~/query_gpu.py 165 | 166 | 2.4.0.dev20240401+rocm6.0 167 | 168 | Using GPU: AMD Radeon PRO W7900 169 | 170 | ``` 171 | 172 | 173 | 174 | 7. 执行 `python3 app.py` 启动,将自动打开浏览器窗口,默认地址 `http://127.0.0.1:9966` (注意:默认从 modelscope 魔塔下载模型,不可使用代理下载,请关闭代理) 175 | 176 | 177 | ## MacOS 下源码部署 178 | 179 | 1. 配置好 python3.9-3.11 环境,安装git ,执行命令 `brew install libsndfile git python@3.10` 180 | 继续执行 181 | 182 | ``` 183 | brew install ffmpeg 184 | 185 | export PATH="/usr/local/opt/python@3.10/bin:$PATH" 186 | 187 | source ~/.bash_profile 188 | 189 | source ~/.zshrc 190 | 191 | ``` 192 | 193 | 2. 创建空目录 `/data/chattts` 执行命令 `cd /data/chattts && git clone https://github.com/jianchang512/chatTTS-ui .` 194 | 3. 创建虚拟环境 `python3 -m venv venv` 195 | 4. 激活虚拟环境 `source ./venv/bin/activate` 196 | 5. 安装依赖 `pip3 install -r requirements.txt` 197 | 6. 安装torch `pip3 install torch==2.2.0 torchaudio==2.2.0` 198 | 7. 执行 `python3 app.py` 启动,将自动打开浏览器窗口,默认地址 `http://127.0.0.1:9966` (注意:默认从 modelscope 魔塔下载模型,不可使用代理下载,请关闭代理) 199 | 200 | 201 | ## Windows源码部署 202 | 203 | 1. 下载python3.9-3.11,安装时注意选中`Add Python to environment variables` 204 | 2. 下载 ffmpeg.exe 放在 软件目录下的ffmpeg文件夹内 205 | 3. 下载并安装git,https://github.com/git-for-windows/git/releases/download/v2.45.1.windows.1/Git-2.45.1-64-bit.exe 206 | 4. 创建空文件夹 `D:/chattts` 并进入,地址栏输入 `cmd`回车,在弹出的cmd窗口中执行命令 `git clone https://github.com/jianchang512/chatTTS-ui .` 207 | 5. 创建虚拟环境,执行命令 `python -m venv venv` 208 | 6. 激活虚拟环境,执行 `.\venv\scripts\activate` 209 | 7. 安装依赖,执行 `pip install -r requirements.txt` 210 | 8. 如果不需要CUDA加速, 211 | 212 | 执行 `pip install torch==2.2.0 torchaudio==2.2.0` 213 | 214 | 如果需要CUDA加速,执行 215 | 216 | `pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118` 217 | 218 | 另需安装 CUDA11.8+ ToolKit,请自行搜索安装方法或参考 https://juejin.cn/post/7318704408727519270 219 | 220 | 9. 执行 `python app.py` 启动,将自动打开浏览器窗口,默认地址 `http://127.0.0.1:9966` (注意:默认从 modelscope 魔塔下载模型,不可使用代理下载,请关闭代理) 221 | 222 | 223 | ## 源码部署注意 0.96版本起,必须安装ffmpeg 224 | 225 | 1. 如果GPU显存低于4G,将强制使用CPU。 226 | 227 | 2. Windows或Linux下如果显存大于4G并且是英伟达显卡,但源码部署后仍使用CPU,可尝试先卸载torch再重装,卸载`pip uninstall -y torch torchaudio` , 重新安装cuda版torch。`pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118` 。必须已安装CUDA11.8+ 228 | 229 | 3. 
默认检测 modelscope 是否可连接,如果可以,则从modelscope下载模型,否则从 huggingface.co下载模型 230 | 231 | 232 | 233 | ## 音色获取 234 | 235 | 0.96版本后,因ChatTTS内核升级,已无法直接使用从该站点下载的pt文件(https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker) 236 | 237 | 因此增加转换脚本 cover-pt.py [Win整合包可以直接下载 cover-pt.exe 文件,和 app.exe 放在同一目录下双击执行](https://github.com/jianchang512/ChatTTS-ui/releases) 238 | 239 | 执行 `python cover-pt.py` 后将把 `speaker` 目录下的,以 `seed_` 开头,以 `_emb.pt` 结尾的文件,即下载后的默认文件名pt, 240 | 转换为可用的编码格式,转换后的pt将改名为以 `_emb-covert.pt` 结尾。 241 | 242 | 例: 243 | 244 | 假如 `speaker/seed_2155_restored_emb.pt` 存在这个文件,将被转换为 `speaker/seed_2155_restored_emb-covert.pt`, 然后删掉原pt文件,仅保留该转换后的文件即可 245 | 246 | 247 | 248 | 249 | 250 | ## [常见问题与报错解决方法](faq.md) 251 | 252 | 253 | 254 | 255 | ## 修改http地址 256 | 257 | 默认地址是 `http://127.0.0.1:9966`,如果想修改,可打开目录下的 `.env`文件,将 `WEB_ADDRESS=127.0.0.1:9966`改为合适的ip和端口,比如修改为`WEB_ADDRESS=192.168.0.10:9966`以便局域网可访问 258 | 259 | ## 使用API请求 v0.5+ 260 | 261 | **请求方法:** POST 262 | 263 | **请求地址:** http://127.0.0.1:9966/tts 264 | 265 | **请求参数:** 266 | 267 | text: str| 必须, 要合成语音的文字 268 | 269 | voice: 可选,默认 2222, 决定音色的数字, 2222 | 7869 | 6653 | 4099 | 5099,可选其一,或者任意传入将随机使用音色 270 | 271 | prompt: str| 可选,默认 空, 设定 笑声、停顿,例如 [oral_2][laugh_0][break_6] 272 | 273 | temperature: float| 可选, 默认 0.3 274 | 275 | top_p: float| 可选, 默认 0.7 276 | 277 | top_k: int| 可选, 默认 20 278 | 279 | skip_refine: int| 可选, 默认0, 1=跳过 refine text,0=不跳过 280 | 281 | custom_voice: int| 可选, 默认0,自定义获取音色值时的种子值,需要大于0的整数,如果设置了则以此为准,将忽略 `voice` 282 | 283 | 284 | **返回:json数据** 285 | 286 | 成功返回: 287 | {code:0,msg:ok,audio_files:[dict1,dict2]} 288 | 289 | 其中 audio_files 是字典数组,每个元素dict为 {filename:wav文件绝对路径,url:可下载的wav网址} 290 | 291 | 失败返回: 292 | 293 | {code:1,msg:错误原因} 294 | 295 | ``` 296 | 297 | # API调用代码 298 | 299 | import requests 300 | 301 | res = requests.post('http://127.0.0.1:9966/tts', data={ 302 | "text": "若不懂无需填写", 303 | "prompt": "", 304 | "voice": "3333", 305 | "temperature": 0.3, 306 | "top_p": 0.7, 307 | "top_k": 20, 308 | "skip_refine": 0, 309 | "custom_voice": 0 310 | }) 311 | print(res.json()) 312 | 313 | #ok 314 | {code:0, msg:'ok', audio_files:[{filename: E:/python/chattts/static/wavs/20240601-22_12_12-c7456293f7b5e4dfd3ff83bbd884a23e.wav, url: http://127.0.0.1:9966/static/wavs/20240601-22_12_12-c7456293f7b5e4dfd3ff83bbd884a23e.wav}]} 315 | 316 | #error 317 | {code:1, msg:"error"} 318 | 319 | 320 | ``` 321 | 322 | 323 | ## 在pyVideoTrans软件中使用 324 | 325 | > 升级 pyVideoTrans 到 1.82+ https://github.com/jianchang512/pyvideotrans 326 | 327 | 1. 点击菜单-设置-ChatTTS,填写请求地址,默认应该填写 http://127.0.0.1:9966 328 | 2. 测试无问题后,在主界面中选择`ChatTTS` 329 | 330 | ![image](https://github.com/jianchang512/ChatTTS-ui/assets/3378335/7118325f-2b9a-46ce-a584-1d5c6dc8e2da) 331 | 332 | -------------------------------------------------------------------------------- /README_EN.md: -------------------------------------------------------------------------------- 1 | 2 | [简体中文](README.md) | [Discord Discussion Group](https://discord.gg/y9gUweVCCJ) | [Support the Project](https://github.com/jianchang512/ChatTTS-ui/issues/122) 3 | 4 | # ChatTTS webUI & API 5 | 6 | A simple local web interface to use ChatTTS for text-to-speech synthesis on the web, supporting mixed Chinese and English text and numbers, and providing an API interface.
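> Since v0.96, source deployment requires ffmpeg to be installed first, and voice files (.csv/.pt) from earlier versions no longer work; regenerate them from a voice value (see the Deployment Notes below).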
7 | 8 | > The original [ChatTTS](https://github.com/2noise/chattts) project 9 | 10 | **Interface Preview** 11 | 12 | ![image](https://github.com/jianchang512/ChatTTS-ui/assets/3378335/8d9b36d4-29b9-4cd7-ae70-3e3bd3225108) 13 | 14 | 15 | Sample synthesized voice effects 16 | 17 | https://github.com/jianchang512/ChatTTS-ui/assets/3378335/bd6aaef9-a49a-4a81-803a-91e3320bf808 18 | 19 | Text and control symbols mixed effect 20 | 21 | https://github.com/jianchang512/ChatTTS-ui/assets/3378335/e2a08ea0-32af-4a30-8880-3a91f6cbea55 22 | 23 | 24 | ## Windows Pre-packaged Version 25 | 26 | 1. Download the compressed package from [Releases](https://github.com/jianchang512/chatTTS-ui/releases), unzip it, and double-click app.exe to use. 27 | 2. Some security software may flag it as a virus, please disable or deploy from source. 28 | 3. If you have an Nvidia graphics card with more than 4GB of memory and have installed CUDA11.8+, GPU acceleration will be enabled. 29 | 30 | ## Linux Container Deployment 31 | 32 | ### Installation 33 | 34 | 1. Clone the project repository 35 | 36 | Clone the project to any directory, for example: 37 | 38 | ```bash 39 | git clone https://github.com/jianchang512/ChatTTS-ui.git chat-tts-ui 40 | ``` 41 | 42 | 2. Start Runner 43 | 44 | Enter the project directory: 45 | 46 | ```bash 47 | cd chat-tts-ui 48 | ``` 49 | 50 | Start the container and view the initialization logs: 51 | 52 | ```bash 53 | For GPU version 54 | docker compose -f docker-compose.gpu.yaml up -d 55 | 56 | For CPU version 57 | docker compose -f docker-compose.cpu.yaml up -d 58 | 59 | docker compose logs -f --no-log-prefix 60 | ``` 61 | 62 | 3. Access ChatTTS WebUI 63 | 64 | `Started at:['0.0.0.0', '9966']`, meaning you can access it via `IP:9966` of the deployment device, for example: 65 | 66 | - Localhost: `http://127.0.0.1:9966` 67 | - Server: `http://192.168.1.100:9966` 68 | 69 | ### Update 70 | 71 | 1. Get the latest code from the main branch: 72 | 73 | ```bash 74 | git checkout main 75 | git pull origin main 76 | ``` 77 | 78 | 2. Go to the next step and update to the latest image: 79 | 80 | ```bash 81 | docker compose down 82 | 83 | For GPU version 84 | docker compose -f docker-compose.gpu.yaml up -d --build 85 | 86 | For CPU version 87 | docker compose -f docker-compose.cpu.yaml up -d --build 88 | 89 | docker compose logs -f --no-log-prefix 90 | ``` 91 | 92 | ## Linux Source Code Deployment 93 | 94 | 1. Prepare python3.9-3.11 environment. Install FFmpeg 95 | 2. Create an empty directory `/data/chattts` and execute `cd /data/chattts && git clone https://github.com/jianchang512/chatTTS-ui .`. 96 | 3. Create a virtual environment `python3 -m venv venv`. 97 | 4. Activate the virtual environment `source ./venv/bin/activate`. 98 | 5. Install dependencies `pip3 install -r requirements.txt`. 99 | 6. If CUDA acceleration is not needed, execute 100 | 101 | `pip3 install torch==2.2.0 torchaudio==2.2.0` 102 | 103 | If CUDA acceleration is needed, execute 104 | 105 | ``` 106 | pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118 107 | 108 | pip install nvidia-cublas-cu11 nvidia-cudnn-cu11 109 | ``` 110 | 111 | Additionally, install CUDA11.8+ ToolKit, search for installation methods or refer to https://juejin.cn/post/7318704408727519270 112 | 113 | Besides CUDA, AMD GPU acceleration can also be used by installing ROCm and PyTorch_ROCm version. For AMD GPU, with the help of ROCm, PyTorch works out of the box without further modifications. 114 | 1. 
Refer to https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html to install AMD GPU Driver and ROCm. 115 | 2. Then install PyTorch_ROCm version from https://pytorch.org/. 116 | 117 | `pip3 install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/rocm6.0` 118 | 119 | After installation, you can use the command `rocm-smi` to view the AMD GPUs in the system. The following Torch code (query_gpu.py) can also be used to query the current AMD GPU Device. 120 | 121 | ``` 122 | import torch 123 | 124 | print(torch.__version__) 125 | 126 | if torch.cuda.is_available(): 127 | device = torch.device("cuda") # a CUDA device object 128 | print('Using GPU:', torch.cuda.get_device_name(0)) 129 | else: 130 | device = torch.device("cpu") 131 | print('Using CPU') 132 | 133 | torch.cuda.get_device_properties(0) 134 | 135 | ``` 136 | 137 | Using the code above, for instance, with AMD Radeon Pro W7900, the device query is as follows. 138 | 139 | ``` 140 | 141 | $ python ~/query_gpu.py 142 | 143 | 2.4.0.dev20240401+rocm6.0 144 | 145 | Using GPU: AMD Radeon PRO W7900 146 | 147 | ``` 148 | 149 | 150 | 151 | 7. Execute `python3 app.py` to start. It will automatically open a browser window at `http://127.0.0.1:9966`. Note: Models are downloaded from modelscope by default without using a proxy, please disable the proxy. 152 | 153 | 154 | ## MacOS Source Code Deployment 155 | 156 | 1. Prepare the python3.9-3.11 environment and install git. Execute command `brew install libsndfile git python@3.10`. Then continue with 157 | 158 | ``` 159 | brew install ffmpeg 160 | 161 | export PATH="/usr/local/opt/python@3.10/bin:$PATH" 162 | 163 | source ~/.bash_profile 164 | 165 | source ~/.zshrc 166 | 167 | ``` 168 | 169 | 2. Create an empty directory `/data/chattts` and execute command `cd /data/chattts && git clone https://github.com/jianchang512/chatTTS-ui .`. 170 | 3. Create a virtual environment `python3 -m venv venv`. 171 | 4. Activate the virtual environment `source ./venv/bin/activate`. 172 | 5. Install dependencies `pip3 install -r requirements.txt`. 173 | 6. Install torch `pip3 install torch==2.2.0 torchaudio==2.2.0`. 174 | 7. Execute `python3 app.py` to start. It will automatically open a browser window at `http://127.0.0.1:9966`. Note: Models are downloaded from modelscope by default without using a proxy, please disable the proxy. 175 | 176 | 177 | ## Windows Source Code Deployment 178 | 179 | 1. Download python3.9-3.11, make sure to check `Add Python to environment variables` during installation. Download ffmpeg.exe and put it in the ffmpeg folder under the project directory. 180 | 2. Download and install git from https://github.com/git-for-windows/git/releases/download/v2.45.1.windows.1/Git-2.45.1-64-bit.exe. 181 | 3. Create an empty folder `D:/chattts` and enter it, type `cmd` in the address bar and press Enter. In the cmd window that pops up, execute command `git clone https://github.com/jianchang512/chatTTS-ui .`. 182 | 4. Create a virtual environment by executing command `python -m venv venv`. 183 | 5. Activate the virtual environment by executing `.\venv\scripts\activate`. 184 | 6. Install dependencies by executing `pip install -r requirements.txt`. 185 | 7. If CUDA acceleration is not needed, 186 | 187 | execute `pip install torch==2.2.0 torchaudio==2.2.0`. 188 | 189 | If CUDA acceleration is needed, execute 190 | 191 | `pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118`. 
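A quick way to verify that the CUDA build of torch is active (a minimal sketch in the spirit of the query_gpu.py snippet above; not part of the original steps):

```
import torch

print(torch.__version__)
# True means the CUDA build of torch is installed and an Nvidia GPU is visible
print(torch.cuda.is_available())
```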
192 | 193 | Additionally, install CUDA11.8+ ToolKit, search for installation methods or refer to https://juejin.cn/post/7318704408727519270. 194 | 195 | 8. Execute `python app.py` to start. It will automatically open a browser window at `http://127.0.0.1:9966`. Note: Models are downloaded from modelscope by default without using a proxy, please disable the proxy. 196 | 197 | 198 | ## Deployment Notes 199 | 200 | 0. Since v0.96, ffmpeg must be installed for source deployment. 201 | 202 | 1. If the GPU memory is below 4GB, it will forcefully use the CPU. 203 | 204 | 2. Under Windows or Linux, if the GPU memory is more than 4GB and it is an Nvidia graphics card, but the source code deployment still uses CPU, you may try uninstalling torch first and then reinstalling it. Uninstall with `pip uninstall -y torch torchaudio`, then reinstall the CUDA version of torch `pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118`. CUDA11.8+ must be installed. 205 | 206 | 3. By default, it checks whether modelscope can be connected. If so, models are downloaded from modelscope; otherwise, models are downloaded from huggingface.co. 207 | 208 | 209 | ## [FAQs and Troubleshooting](faq.md) 210 | 211 | 212 | 213 | 214 | ## Modify HTTP Address 215 | 216 | The default address is `http://127.0.0.1:9966`. If you want to modify it, open the `.env` file in the directory and change `WEB_ADDRESS=127.0.0.1:9966` to the appropriate IP and port, such as changing to `WEB_ADDRESS=192.168.0.10:9966` for LAN access. 217 | 218 | ## Using API Requests v0.5+ 219 | 220 | **Method:** POST 221 | 222 | **URL:** http://127.0.0.1:9966/tts 223 | 224 | **Parameters:** 225 | 226 | text: str| Required, text to synthesize. 227 | 228 | voice: int| Optional, default 2222. Determines the voice; choose one of 2222 | 7869 | 6653 | 4099 | 5099, or pass any other value to use a random voice. 229 | 230 | prompt: str| Optional, default empty. Sets laughter, pause, etc., like [oral_2][laugh_0][break_6]. 231 | 232 | temperature: float| Optional, default 0.3. 233 | 234 | top_p: float| Optional, default 0.7. 235 | 236 | top_k: int| Optional, default 20. 237 | 238 | skip_refine: int| Optional, default 0. 1=skip refine text, 0=do not skip. 239 | 240 | custom_voice: int| Optional, default 0. Sets a custom seed value for obtaining the voice, must be a positive integer. If set, it will take precedence over `voice`. 241 | 242 | 243 | **Response: JSON** 244 | 245 | Success: 246 | {code:0,msg:ok,audio_files:[dict1,dict2]} 247 | 248 | where audio_files is an array of dictionaries, each element dict is {filename:absolute path to wav file, url:downloadable wav URL} 249 | 250 | Failure: 251 | 252 | {code:1,msg:error reason} 253 | 254 | 255 | ``` 256 | 257 | # API Call Code 258 | 259 | import requests 260 | 261 | res = requests.post('http://127.0.0.1:9966/tts', data={ 262 | "text": "No need to fill if unsure", 263 | "prompt": "", 264 | "voice": "3333", 265 | "temperature": 0.3, 266 | "top_p": 0.7, 267 | "top_k": 20, 268 | "skip_refine": 0, 269 | "custom_voice": 0, 270 | }) 271 | print(res.json()) 272 | 273 | #ok 274 | {code:0, msg:'ok', audio_files:[{filename: E:/python/chattts/static/wavs/20240601-22_12_12-c7456293f7b5e4dfd3ff83bbd884a23e.wav, url: http://127.0.0.1:9966/static/wavs/20240601-22_12_12-c7456293f7b5e4dfd3ff83bbd884a23e.wav}]} 275 | 276 | #error 277 | {code:1, msg:"error"} 278 | 279 | 280 | ``` 281 | 282 | 283 | ## Using in pyVideoTrans software 284 | 285 | > Upgrade pyVideoTrans to 1.82+ https://github.com/jianchang512/pyvideotrans 286 | 287 | 1. 
Click Menu-Settings-ChatTTS and fill in the request address, which should by default be http://127.0.0.1:9966. 288 | 2. After ensuring there are no issues, select `ChatTTS` on the main interface. 289 | 290 | ![image](https://github.com/jianchang512/ChatTTS-ui/assets/3378335/7118325f-2b9a-46ce-a584-1d5c6dc8e2da) 291 | 292 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | if sys.platform == "darwin": 5 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 6 | import io 7 | import json 8 | import torchaudio 9 | import wave 10 | from pathlib import Path 11 | print('Starting...') 12 | import shutil 13 | import time 14 | 15 | import torch 16 | import torch._dynamo 17 | torch._dynamo.config.suppress_errors = True 18 | torch._dynamo.config.cache_size_limit = 64 19 | torch._dynamo.config.suppress_errors = True 20 | torch.set_float32_matmul_precision('high') 21 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 22 | import subprocess 23 | import soundfile as sf 24 | import ChatTTS 25 | import datetime 26 | from dotenv import load_dotenv 27 | load_dotenv() 28 | from flask import Flask, request, render_template, jsonify, send_from_directory,send_file,Response, stream_with_context 29 | import logging 30 | from logging.handlers import RotatingFileHandler 31 | from waitress import serve 32 | from random import random 33 | from modelscope import snapshot_download 34 | import numpy as np 35 | import threading 36 | from uilib.cfg import WEB_ADDRESS, SPEAKER_DIR, LOGS_DIR, WAVS_DIR, MODEL_DIR, ROOT_DIR 37 | from uilib import utils,VERSION 38 | from ChatTTS.utils import select_device 39 | from uilib.utils import is_chinese_os,modelscope_status 40 | merge_size=int(os.getenv('merge_size',10)) 41 | env_lang=os.getenv('lang','') 42 | if env_lang=='zh': 43 | is_cn= True 44 | elif env_lang=='en': 45 | is_cn=False 46 | else: 47 | is_cn=is_chinese_os() 48 | 49 | if not shutil.which("ffmpeg"): 50 | print('请先安装ffmpeg') 51 | time.sleep(60) 52 | exit() 53 | 54 | 55 | chat = ChatTTS.Chat() 56 | device_str=os.getenv('device','default') 57 | 58 | if device_str in ['default','mps']: 59 | device=select_device(min_memory=2047,experimental=True if device_str=='mps' else False) 60 | elif device_str =='cuda': 61 | device=select_device(min_memory=2047) 62 | elif device_str == 'cpu': 63 | device = torch.device("cpu") 64 | 65 | 66 | chat.load(source="local" if not os.path.exists(MODEL_DIR+"/DVAE_full.pt") else 'custom',custom_path=ROOT_DIR, device=device,compile=True if os.getenv('compile','true').lower()!='false' else False) 67 | 68 | 69 | # 配置日志 70 | # 禁用 Werkzeug 默认的日志处理器 71 | log = logging.getLogger('werkzeug') 72 | log.handlers[:] = [] 73 | log.setLevel(logging.WARNING) 74 | 75 | app = Flask(__name__, 76 | static_folder=ROOT_DIR+'/static', 77 | static_url_path='/static', 78 | template_folder=ROOT_DIR+'/templates') 79 | 80 | root_log = logging.getLogger() # Flask的根日志记录器 81 | root_log.handlers = [] 82 | root_log.setLevel(logging.WARNING) 83 | app.logger.setLevel(logging.WARNING) 84 | # 创建 RotatingFileHandler 对象,设置写入的文件路径和大小限制 85 | file_handler = RotatingFileHandler(LOGS_DIR+f'/{datetime.datetime.now().strftime("%Y%m%d")}.log', maxBytes=1024 * 1024, backupCount=5) 86 | # 创建日志的格式 87 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 88 | # 设置文件处理器的级别和格式 89 | file_handler.setLevel(logging.WARNING) 90 | file_handler.setFormatter(formatter) 
91 | # 将文件处理器添加到日志记录器中 92 | app.logger.addHandler(file_handler) 93 | app.jinja_env.globals.update(enumerate=enumerate) 94 | 95 | @app.route('/static/<path:filename>') 96 | def static_files(filename): 97 | return send_from_directory(app.config['STATIC_FOLDER'], filename) 98 | 99 | 100 | @app.route('/') 101 | def index(): 102 | speakers=utils.get_speakers() 103 | return render_template( 104 | f"index{'' if is_cn else 'en'}.html", 105 | weburl=WEB_ADDRESS, 106 | speakers=speakers, 107 | version=VERSION 108 | ) 109 | 110 | 111 | # 根据文本返回tts结果,返回 filename=文件名 url=可下载地址 112 | # 请求端根据需要自行选择使用哪个 113 | # params: 114 | # 115 | # text:待合成文字 116 | # prompt: 117 | # voice:音色 118 | # custom_voice:自定义音色值 119 | # skip_refine: 1=跳过refine_text阶段,0=不跳过 120 | # temperature 121 | # top_p 122 | # top_k 123 | # speed 124 | # text_seed 125 | # refine_max_new_token 126 | # infer_max_new_token 127 | # wav 128 | 129 | audio_queue=[] 130 | 131 | @app.route('/tts', methods=['GET', 'POST']) 132 | def tts(): 133 | global audio_queue 134 | # 原始字符串 135 | text = request.args.get("text","").strip() or request.form.get("text","").strip() 136 | prompt = request.args.get("prompt","").strip() or request.form.get("prompt",'') 137 | 138 | # 默认值 139 | defaults = { 140 | "custom_voice": 0, 141 | "voice": "2222", 142 | "temperature": 0.3, 143 | "top_p": 0.7, 144 | "top_k": 20, 145 | "skip_refine": 0, 146 | "speed":5, 147 | "text_seed":42, 148 | "refine_max_new_token": 384, 149 | "infer_max_new_token": 2048, 150 | "wav": 0, 151 | "is_stream":0 152 | } 153 | 154 | # 获取 155 | custom_voice = utils.get_parameter(request, "custom_voice", defaults["custom_voice"], int) 156 | voice = str(custom_voice) if custom_voice > 0 else utils.get_parameter(request, "voice", defaults["voice"], str) 157 | temperature = utils.get_parameter(request, "temperature", defaults["temperature"], float) 158 | top_p = utils.get_parameter(request, "top_p", defaults["top_p"], float) 159 | top_k = utils.get_parameter(request, "top_k", defaults["top_k"], int) 160 | skip_refine = utils.get_parameter(request, "skip_refine", defaults["skip_refine"], int) 161 | is_stream = utils.get_parameter(request, "is_stream", defaults["is_stream"], int) 162 | speed = utils.get_parameter(request, "speed", defaults["speed"], int) 163 | text_seed = utils.get_parameter(request, "text_seed", defaults["text_seed"], int) 164 | refine_max_new_token = utils.get_parameter(request, "refine_max_new_token", defaults["refine_max_new_token"], int) 165 | infer_max_new_token = utils.get_parameter(request, "infer_max_new_token", defaults["infer_max_new_token"], int) 166 | wav = utils.get_parameter(request, "wav", defaults["wav"], int) 167 | 168 | 169 | 170 | app.logger.info(f"[tts]{text=}\n{voice=},{skip_refine=}\n") 171 | if not text: 172 | return jsonify({"code": 1, "msg": "text params lost"}) 173 | # 固定音色 174 | rand_spk=None 175 | # voice可能是 {voice}.csv or {voice}.pt or number 176 | voice=voice.replace('.csv','.pt') 177 | seed_path=f'{SPEAKER_DIR}/{voice}' 178 | print(f'{voice=}') 179 | #if voice.endswith('.csv') and os.path.exists(seed_path): 180 | # rand_spk=utils.load_speaker(voice) 181 | # print(f'当前使用音色 {seed_path=}') 182 | #el 183 | 184 | if voice.endswith('.pt') and os.path.exists(seed_path): 185 | #如果.env中未指定设备,则使用 ChatTTS相同算法找设备,否则使用指定设备 186 | rand_spk=torch.load(seed_path, map_location=device) 187 | print(f'当前使用音色 {seed_path=}') 188 | # 否则 判断是否存在 {voice}.csv 189 | #elif os.path.exists(f'{SPEAKER_DIR}/{voice}.csv'): 190 | # rand_spk=utils.load_speaker(voice) 191 | # print(f'当前使用音色 
{SPEAKER_DIR}/{voice}.csv') 192 | 193 | if rand_spk is None: 194 | print(f'当前使用音色:根据seed={voice}获取随机音色') 195 | voice_int=re.findall(r'^(\d+)',voice) 196 | if len(voice_int)>0: 197 | voice=int(voice_int[0]) 198 | else: 199 | voice=2222 200 | torch.manual_seed(voice) 201 | #std, mean = chat.sample_random_speaker 202 | rand_spk = chat.sample_random_speaker() 203 | #rand_spk = torch.randn(768) * std + mean 204 | # 保存音色 205 | torch.save(rand_spk,f"{SPEAKER_DIR}/{voice}.pt") 206 | #utils.save_speaker(voice,rand_spk) 207 | 208 | 209 | audio_files = [] 210 | 211 | 212 | start_time = time.time() 213 | 214 | # 中英按语言分行 215 | text_list=[t.strip() for t in text.split("\n") if t.strip()] 216 | new_text=utils.split_text(text_list) 217 | if text_seed>0: 218 | torch.manual_seed(text_seed) 219 | 220 | 221 | params_infer_code = ChatTTS.Chat.InferCodeParams( 222 | spk_emb=rand_spk, 223 | prompt=f"[speed_{speed}]", 224 | top_P=top_p, 225 | top_K=top_k, 226 | temperature=temperature, 227 | max_new_token=infer_max_new_token 228 | ) 229 | params_refine_text = ChatTTS.Chat.RefineTextParams( 230 | prompt=prompt, 231 | top_P=top_p, 232 | top_K=top_k, 233 | temperature=temperature, 234 | max_new_token=refine_max_new_token 235 | ) 236 | print(f'{prompt=}') 237 | # 将少于30个字符的行同其他行拼接 238 | retext=[] 239 | short_text="" 240 | for it in new_text: 241 | if len(it)<30: 242 | short_text+=f"{it} [uv_break] " 243 | if len(short_text)>30: 244 | retext.append(short_text) 245 | short_text="" 246 | else: 247 | retext.append(short_text+it) 248 | short_text="" 249 | if len(short_text)>30 or len(retext)<1: 250 | retext.append(short_text) 251 | elif short_text: 252 | retext[-1]+=f" [uv_break] {short_text}" 253 | 254 | new_text=retext 255 | 256 | new_text_list=[new_text[i:i+merge_size] for i in range(0,len(new_text),merge_size)] 257 | filename_list=[] 258 | 259 | audio_time=0 260 | inter_time=0 261 | 262 | for i,te in enumerate(new_text_list): 263 | print(f'{te=}') 264 | wavs = chat.infer( 265 | te, 266 | #use_decoder=False, 267 | stream=True if is_stream==1 else False, 268 | skip_refine_text=skip_refine, 269 | do_text_normalization=False, 270 | do_homophone_replacement=True, 271 | params_refine_text=params_refine_text, 272 | params_infer_code=params_infer_code 273 | 274 | ) 275 | 276 | 277 | end_time = time.time() 278 | inference_time = end_time - start_time 279 | inference_time_rounded = round(inference_time, 2) 280 | inter_time+=inference_time_rounded 281 | print(f"推理时长: {inference_time_rounded} 秒") 282 | 283 | 284 | 285 | for j,w in enumerate(wavs): 286 | filename = datetime.datetime.now().strftime('%H%M%S_')+f"use{inference_time_rounded}s-seed{voice}-te{temperature}-tp{top_p}-tk{top_k}-textlen{len(text)}-{str(random())[2:7]}" + f"-{i}-{j}.wav" 287 | filename_list.append(filename) 288 | torchaudio.save(WAVS_DIR+'/'+filename, torch.from_numpy(w).unsqueeze(0), 24000) 289 | 290 | txt_tmp="\n".join([f"file '{WAVS_DIR}/{it}'" for it in filename_list]) 291 | txt_name=f'{time.time()}.txt' 292 | with open(f'{WAVS_DIR}/{txt_name}','w',encoding='utf-8') as f: 293 | f.write(txt_tmp) 294 | outname=datetime.datetime.now().strftime('%H%M%S_')+f"use{inter_time}s-audio{audio_time}s-seed{voice}-te{temperature}-tp{top_p}-tk{top_k}-textlen{len(text)}-{str(random())[2:7]}" + "-merge.wav" 295 | try: 296 | subprocess.run(["ffmpeg","-hide_banner", "-ignore_unknown","-y","-f","concat","-safe","0","-i",f'{WAVS_DIR}/{txt_name}',"-c:a","copy",WAVS_DIR + '/' + outname], 297 | stdout=subprocess.PIPE, 298 | stderr=subprocess.PIPE, 299 | encoding="utf-8", 300 | 
check=True, 301 | text=True, 302 | creationflags=0 if sys.platform != 'win32' else subprocess.CREATE_NO_WINDOW) 303 | except Exception as e: 304 | return jsonify({"code":1,"msg":str(e)}) 305 | 306 | 307 | 308 | audio_files.append({ 309 | "filename": WAVS_DIR + '/' + outname, 310 | "url": f"http://{request.host}/static/wavs/{outname}", 311 | "inference_time": round(inter_time,2), 312 | "audio_duration": -1 313 | }) 314 | result_dict={"code": 0, "msg": "ok", "audio_files": audio_files} 315 | try: 316 | if torch.cuda.is_available(): 317 | torch.cuda.empty_cache() 318 | except Exception: 319 | pass 320 | # 兼容pyVideoTrans接口调用 321 | if len(audio_files)==1: 322 | result_dict["filename"]=audio_files[0]['filename'] 323 | result_dict["url"]=audio_files[0]['url'] 324 | 325 | if wav>0: 326 | return send_file(audio_files[0]['filename'], mimetype='audio/x-wav') 327 | else: 328 | return jsonify(result_dict) 329 | 330 | 331 | 332 | @app.route('/clear_wavs', methods=['POST']) 333 | def clear_wavs(): 334 | dir_path = 'static/wavs' # wav音频文件存储目录 335 | success, message = utils.ClearWav(dir_path) 336 | if success: 337 | return jsonify({"code": 0, "msg": message}) 338 | else: 339 | return jsonify({"code": 1, "msg": message}) 340 | 341 | try: 342 | host = WEB_ADDRESS.split(':') 343 | print(f'Start:{WEB_ADDRESS}') 344 | threading.Thread(target=utils.openweb,args=(f'http://{WEB_ADDRESS}',)).start() 345 | serve(app,host=host[0], port=int(host[1])) 346 | except Exception as e: 347 | print(e) 348 | 349 | -------------------------------------------------------------------------------- /asset/模型下载说明.txt: -------------------------------------------------------------------------------- 1 | 如果无法下载模型,请去下载 https://github.com/jianchang512/ChatTTS-ui/releases/download/v1.0/all-models.7z 2 | 3 | 4 | 下载后解压后,会看到asset文件夹,该文件夹内有多个pt文件,将所有pt文件复制到本目录下,然后重启软件 -------------------------------------------------------------------------------- /cover-pt.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 0.96版本后,因ChatTTS内核升级,已无法直接使用从该站点下载的pt文件。 3 | 4 | https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker 5 | 6 | 因此增加该转换脚本。 7 | 8 | 执行 python cover-pt.py 后将把 `speaker` 目录下的,以 seed_ 开头, 9 | 以 _emb.pt 结尾的文件,即下载后的默认文件名, 10 | 转换为可用的编码格式,转换后的pt将改名为以 `_emb-covert.pt` 结尾。 11 | 12 | 例: 13 | 14 | 假如 speaker/seed_2155_restored_emb.pt 存在这个文件 15 | 16 | 将被转换为 speaker/seed_2155_restored_emb-covert.pt, 17 | 18 | 然后删掉原pt文件,仅保留该转换后的文件即可 19 | 20 | 21 | 22 | ''' 23 | 24 | import os 25 | import re 26 | import sys 27 | if sys.platform == "darwin": 28 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 29 | import io 30 | import json 31 | import torchaudio 32 | import wave 33 | from pathlib import Path 34 | print('Starting...') 35 | import shutil 36 | import time 37 | 38 | 39 | import torch 40 | import torch._dynamo 41 | torch._dynamo.config.suppress_errors = True 42 | torch._dynamo.config.cache_size_limit = 64 43 | torch._dynamo.config.suppress_errors = True 44 | torch.set_float32_matmul_precision('high') 45 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 46 | import subprocess 47 | import soundfile as sf 48 | import ChatTTS 49 | import datetime 50 | from dotenv import load_dotenv 51 | load_dotenv() 52 | 53 | import logging 54 | from logging.handlers import RotatingFileHandler 55 | 56 | from random import random 57 | from modelscope import snapshot_download 58 | import numpy as np 59 | import threading 60 | from uilib.cfg import WEB_ADDRESS, SPEAKER_DIR, LOGS_DIR, WAVS_DIR, MODEL_DIR, ROOT_DIR 61 | from uilib import 
utils,VERSION 62 | from ChatTTS.utils import select_device 63 | from uilib.utils import is_chinese_os,modelscope_status 64 | merge_size=int(os.getenv('merge_size',10)) 65 | env_lang=os.getenv('lang','') 66 | if env_lang=='zh': 67 | is_cn= True 68 | elif env_lang=='en': 69 | is_cn=False 70 | else: 71 | is_cn=is_chinese_os() 72 | 73 | 74 | 75 | chat = ChatTTS.Chat() 76 | device_str=os.getenv('device','default') 77 | 78 | if device_str in ['default','mps']: 79 | device=select_device(min_memory=2047,experimental=True if device_str=='mps' else False) 80 | elif device_str =='cuda': 81 | device=select_device(min_memory=2047) 82 | elif device_str == 'cpu': 83 | device = torch.device("cpu") 84 | 85 | 86 | chat.load(source="custom",custom_path=ROOT_DIR, device=device,compile=True if os.getenv('compile','true').lower()!='false' else False) 87 | n=0 88 | for it in os.listdir('./speaker'): 89 | if it.startswith('seed_') and not it.endswith('_emb-covert.pt'): 90 | print(f'开始转换 {it}') 91 | n+=1 92 | rand_spk=torch.load(f'./speaker/{it}', map_location=device) 93 | 94 | torch.save( chat._encode_spk_emb(rand_spk) ,f"{SPEAKER_DIR}/{it.replace('.pt','-covert.pt')}") 95 | if n==0: 96 | print('没有可转换的pt文件,仅转换以 seed_ 开头,并以 _emb.pt 结尾的文件') 97 | 98 | else: 99 | print(f'转换完成{n}个,可以删掉以 _emb.pt 结尾的文件了,注意保留 -covert.pt 结尾的文件') 100 | 101 | print(f'\n\n30s后本窗口自动关闭') 102 | time.sleep(30) -------------------------------------------------------------------------------- /docker-compose.cpu.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | chat-tts-ui: 3 | build: 4 | context: . 5 | dockerfile: Dockerfile.cpu 6 | container_name: chat-tts-ui 7 | restart: always 8 | volumes: 9 | - "./:/app" 10 | ports: 11 | - 9966:9966 12 | user: "${UID}:${GID}" 13 | environment: 14 | LOG_LEVEL: DEBUG 15 | WEB_ADDRESS: 0.0.0.0:9966 16 | command: python3 app.py 17 | -------------------------------------------------------------------------------- /docker-compose.gpu.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | chat-tts-ui: 3 | build: 4 | context: . 
5 | dockerfile: Dockerfile.gpu 6 | container_name: chat-tts-ui 7 | restart: always 8 | volumes: 9 | - "./:/app" 10 | ports: 11 | - 9966:9966 12 | user: "${UID}:${GID}" 13 | environment: 14 | LOG_LEVEL: DEBUG 15 | WEB_ADDRESS: 0.0.0.0:9966 16 | NVIDIA_VISIBLE_DEVICES: all 17 | command: python3 app.py 18 | deploy: 19 | resources: 20 | reservations: 21 | devices: 22 | - driver: nvidia 23 | capabilities: [gpu] 24 | -------------------------------------------------------------------------------- /faq.md: -------------------------------------------------------------------------------- 1 | # 常见问题与报错 2 | 3 | 4 | **注意:不同机器使用相同种子生成的音频音色可能不同,同一机器使用相同种子多次生成的音频音色也可能变化。** 5 | 6 | 7 | **升级到0.96版后报错** 8 | 9 | 答: 0.96版起,源码部署必须先安装ffmpeg 10 | 11 | 0.96版起,之前的音色文件csv和pt已不可用,请填写音色值重新生成,或到以下站点下载 12 | 13 | https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker 14 | 15 | 16 | **0.** 执行app.py报错 FileNotFoundError: [Errno 2] No such file or directory: '../ChatTTS-ui/models/pzc163/chatTTS/config/path.yaml' 17 | 18 | 答:模型不完整,重新下载模型或者 打开 https://www.modelscope.cn/models/pzc163/chatTTS/files 下载 path.yaml 、复制到报错里显示的文件夹内 ChatTTS-ui/models/pzc163/chatTTS/config/ 19 | 20 | 21 | 22 | **1.** MacOS 报错 `Initializing libomp.dylib, but found libiomp5.dylib already initialized` 23 | 24 | > 答:在app.py的 `import os` 的下一行,添加代码 25 | > 26 | > `os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'` 27 | 28 | 29 | **2.** MacOS 无报错但进度条一直停留在0%,卡住不动 30 | 31 | > 答:app.py 中 32 | > 33 | > `chat.load_models(source="local",local_path=CHATTTS_DIR)` 34 | > 35 | > 改为 36 | > 37 | > `chat.load_models(source="local",local_path=CHATTTS_DIR,compile=False)` 38 | 39 | **3.** MacOS 报 `libomp` 相关错误 40 | 41 | > 答:执行 `brew install libomp` 42 | 43 | **4.** 报https相关错误 `ProxyError: HTTPSConnectionPool(host='www.modelscope.cn', port=443)` 44 | 45 | > 答:从 modelscope 魔塔下载模型时不可使用代理,请关闭代理 46 | 47 | 48 | **5.** 报错丢失文件 `Missing spk_stat.pt` 49 | 50 | > 答:本项目(ChatTTS-ui)默认从 modelscope 即魔塔社区下载模型,但该库里的模型缺少 spk_stat.pt文件 51 | > 52 | > 请科学上网后从 53 | > 54 | > https://huggingface.co/2Noise/ChatTTS/blob/main/asset/spk_stat.pt 55 | > 56 | > 下载 spk_stat.pt, 然后复制 spk_stat.pt 到报错提示的目录下,以本项目为例,需要复制到 `models/pzc163/chatTTS/asset` 文件夹内 57 | 58 | 59 | **6.** 报错 `Dynamo is not supported on Python 3.12` 60 | 61 | > 答:不支持python3.12+版本,降级到 python3.10 62 | 63 | 64 | **7.** MacOS报错 `NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+` 65 | 66 | > 答:执行 `brew install openssl@1.1` 67 | > 68 | > 执行 `pip install urllib3==1.26.15` 69 | 70 | 71 | 72 | **8.** Windows上报错:`Windows not yet supported for torch.compile` 73 | 74 | > 答:`chat.load_models(compile=False)` 改为 `chat.load_models(compile=False,device="cpu")` 75 | 76 | 77 | **9.** Windows上可以运行有GPU,但很慢 78 | 79 | > 答:如果是英伟达显卡,请将cuda升级到11.8+ 80 | 81 | 82 | **10**. 下载模型时出现 proxy 类型错误 83 | 84 | 答:默认会从 modelscope 下载模型,但 modelscope 仅允许中国大陆ip下载,如果遇到 proxy 类错误,请关闭代理。如果你希望从 huggingface.co 下载模型,请打开 `app.py` 查看大约第50行-60行的代码注释。 85 | 86 | 87 | **11.** 中英分词是怎么回事 88 | 89 | 答:如果选中中英分词,那么将会把文字中的中文和英文分离出来单独合成,同时将对应的数字 转为相应语言的文字,比如 中文下123转为一二三,英文下123转为 one two three 90 | 91 | 92 | **12.** Runtime Error:cannot find a working triton installation 93 | 94 | 打开 .env 将 compile=true 改为 compile=false 95 | 96 | **13.** MacOS下无法安装 soundfile 97 | 98 | 答:打开终端,执行 `brew install libsndfile` 然后再安装 soundfile 99 | 100 | 101 | **14.** 如何离线使用 102 | 103 | 答: 104 | 105 | 1. 使用源码部署 106 | 2. 先运行一次,确保模型下载完毕 107 | 3. 打开 app.py 大约35行, `CHATTTS_DIR = snapshot_download('pzc163/chatTTS',cache_dir=MODEL_DIR)` 改为 `CHATTTS_DIR = MODEL_DIR+"/pzc163/chatTTS"`,示意改动见下方代码
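示意片段如下(仅为示意,行号以实际 app.py 为准):

```
# 修改前:启动时从 modelscope 在线下载模型
CHATTTS_DIR = snapshot_download('pzc163/chatTTS', cache_dir=MODEL_DIR)

# 修改后:直接使用已经下载到本地的模型目录
CHATTTS_DIR = MODEL_DIR + "/pzc163/chatTTS"
```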
108 | 109 | **15.** ChatTTS原始项目新版本有兼容问题,可能会报错 “Normalizer pynini WeTextProcessing nemo_text_processing” 110 | 111 | 问题原因: 112 | 新版使用了 nemo_text_processing 和 pynini 来处理中文,但遗憾的是,pynini压根无法在Windows平台安装和使用,要使用,也只能安装在WSL子系统上。 113 | 114 | 不管给出的是什么安装方式,比如 115 | 116 | ``` 117 | pip install pynini==2.1.5 Cython WeTextProcessing 118 | 119 | ``` 120 | 121 | 都是无法在Windows上正确安装的 122 | 123 | ![image](https://github.com/2noise/ChatTTS/assets/3378335/e32c50d1-492c-4b72-958b-78af0575e662) 124 | 125 | 126 | ---- 127 | 128 | 解决方法: 129 | 打开 ChatTTS/core.py, 大约143行,注释掉接下来的7行, 130 | 131 | ![image](https://github.com/2noise/ChatTTS/assets/3378335/5bdd3dc8-0c7c-485f-b5dc-613f14917319) 132 | 133 | 134 | 问题解决 135 | 136 | 或者 chat.infer() 添加参数 do_text_normalization=False, chat.infer(do_text_normalization=False) 137 | -------------------------------------------------------------------------------- /ffmpeg/ffmpeg下载.txt: -------------------------------------------------------------------------------- 1 | Windows源码部署需下载 ffmpeg.exe 放在这里 2 | https://github.com/BtbN/FFmpeg-Builds/releases/download/autobuild-2023-11-30-12-55/ffmpeg-n6.0.1-win64-gpl-6.0.zip 3 | 4 | Mac和Linux需命令行下载 5 | 6 | Mac: brew install ffmpeg 7 | 8 | Debian: apt-get install ffmpeg 9 | 10 | Centos: yum install ffmpeg -------------------------------------------------------------------------------- /listen-speaker/083806_use14.39s-audio0s-seed1983.pt-te0.1-tp0.701-tk20-textlen5-39593-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/083806_use14.39s-audio0s-seed1983.pt-te0.1-tp0.701-tk20-textlen5-39593-merge.wav -------------------------------------------------------------------------------- /listen-speaker/083900_use3.43s-audio0s-seed13.pt-te0.1-tp0.701-tk20-textlen5-09614-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/083900_use3.43s-audio0s-seed13.pt-te0.1-tp0.701-tk20-textlen5-09614-merge.wav -------------------------------------------------------------------------------- /listen-speaker/083910_use3.22s-audio0s-seed7869.pt-te0.1-tp0.701-tk20-textlen5-19801-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/083910_use3.22s-audio0s-seed7869.pt-te0.1-tp0.701-tk20-textlen5-19801-merge.wav -------------------------------------------------------------------------------- /listen-speaker/083919_use3.42s-audio0s-seed6653.pt-te0.1-tp0.701-tk20-textlen5-10851-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/083919_use3.42s-audio0s-seed6653.pt-te0.1-tp0.701-tk20-textlen5-10851-merge.wav -------------------------------------------------------------------------------- /listen-speaker/083928_use3.3s-audio0s-seed4751.pt-te0.1-tp0.701-tk20-textlen5-69400-merge.wav: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/083928_use3.3s-audio0s-seed4751.pt-te0.1-tp0.701-tk20-textlen5-69400-merge.wav -------------------------------------------------------------------------------- /listen-speaker/083937_use3.11s-audio0s-seed1579.pt-te0.1-tp0.701-tk20-textlen5-27436-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/083937_use3.11s-audio0s-seed1579.pt-te0.1-tp0.701-tk20-textlen5-27436-merge.wav -------------------------------------------------------------------------------- /listen-speaker/083945_use3.13s-audio0s-seed14.pt-te0.1-tp0.701-tk20-textlen5-57598-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/083945_use3.13s-audio0s-seed14.pt-te0.1-tp0.701-tk20-textlen5-57598-merge.wav -------------------------------------------------------------------------------- /listen-speaker/083955_use2.84s-audio0s-seed3333.pt-te0.1-tp0.701-tk20-textlen5-93133-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/083955_use2.84s-audio0s-seed3333.pt-te0.1-tp0.701-tk20-textlen5-93133-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084004_use3.08s-audio0s-seed1111.pt-te0.1-tp0.701-tk20-textlen5-39727-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084004_use3.08s-audio0s-seed1111.pt-te0.1-tp0.701-tk20-textlen5-39727-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084014_use3.37s-audio0s-seed11.pt-te0.1-tp0.701-tk20-textlen5-27662-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084014_use3.37s-audio0s-seed11.pt-te0.1-tp0.701-tk20-textlen5-27662-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084024_use3.3s-audio0s-seed1031.pt-te0.1-tp0.701-tk20-textlen5-19879-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084024_use3.3s-audio0s-seed1031.pt-te0.1-tp0.701-tk20-textlen5-19879-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084032_use3.36s-audio0s-seed2222.pt-te0.1-tp0.701-tk20-textlen5-48884-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084032_use3.36s-audio0s-seed2222.pt-te0.1-tp0.701-tk20-textlen5-48884-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084040_use3.12s-audio0s-seed12.pt-te0.1-tp0.701-tk20-textlen5-28377-merge.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084040_use3.12s-audio0s-seed12.pt-te0.1-tp0.701-tk20-textlen5-28377-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084048_use3.16s-audio0s-seed5555.pt-te0.1-tp0.701-tk20-textlen5-42929-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084048_use3.16s-audio0s-seed5555.pt-te0.1-tp0.701-tk20-textlen5-42929-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084056_use3.02s-audio0s-seed5099.pt-te0.1-tp0.701-tk20-textlen5-35891-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084056_use3.02s-audio0s-seed5099.pt-te0.1-tp0.701-tk20-textlen5-35891-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084454_use3.47s-audio0s-seed2345.pt-te0.1-tp0.701-tk20-textlen5-86669-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084454_use3.47s-audio0s-seed2345.pt-te0.1-tp0.701-tk20-textlen5-86669-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084503_use3.22s-audio0s-seed4785.pt-te0.1-tp0.701-tk20-textlen5-95898-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084503_use3.22s-audio0s-seed4785.pt-te0.1-tp0.701-tk20-textlen5-95898-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084511_use3.56s-audio0s-seed491.pt-te0.1-tp0.701-tk20-textlen5-66150-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084511_use3.56s-audio0s-seed491.pt-te0.1-tp0.701-tk20-textlen5-66150-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084518_use3.15s-audio0s-seed4444.pt-te0.1-tp0.701-tk20-textlen5-77649-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084518_use3.15s-audio0s-seed4444.pt-te0.1-tp0.701-tk20-textlen5-77649-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084526_use3.38s-audio0s-seed1455.pt-te0.1-tp0.701-tk20-textlen5-54547-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084526_use3.38s-audio0s-seed1455.pt-te0.1-tp0.701-tk20-textlen5-54547-merge.wav -------------------------------------------------------------------------------- 
/listen-speaker/084755_use3.18s-audio0s-seed2328.pt-te0.1-tp0.701-tk20-textlen5-85733-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084755_use3.18s-audio0s-seed2328.pt-te0.1-tp0.701-tk20-textlen5-85733-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084813_use3.47s-audio0s-seed8888.pt-te0.1-tp0.701-tk20-textlen5-96180-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084813_use3.47s-audio0s-seed8888.pt-te0.1-tp0.701-tk20-textlen5-96180-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084823_use3.33s-audio0s-seed16.pt-te0.1-tp0.701-tk20-textlen5-51038-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084823_use3.33s-audio0s-seed16.pt-te0.1-tp0.701-tk20-textlen5-51038-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084832_use2.95s-audio0s-seed1234.pt-te0.1-tp0.701-tk20-textlen5-80959-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084832_use2.95s-audio0s-seed1234.pt-te0.1-tp0.701-tk20-textlen5-80959-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084842_use3.34s-audio0s-seed1518.pt-te0.1-tp0.701-tk20-textlen5-37066-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084842_use3.34s-audio0s-seed1518.pt-te0.1-tp0.701-tk20-textlen5-37066-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084851_use3.08s-audio0s-seed7777.pt-te0.1-tp0.701-tk20-textlen5-99477-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084851_use3.08s-audio0s-seed7777.pt-te0.1-tp0.701-tk20-textlen5-99477-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084901_use2.81s-audio0s-seed4099.pt-te0.1-tp0.701-tk20-textlen5-16898-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084901_use2.81s-audio0s-seed4099.pt-te0.1-tp0.701-tk20-textlen5-16898-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084910_use3.29s-audio0s-seed5600.pt-te0.1-tp0.701-tk20-textlen5-42899-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084910_use3.29s-audio0s-seed5600.pt-te0.1-tp0.701-tk20-textlen5-42899-merge.wav 
-------------------------------------------------------------------------------- /listen-speaker/084919_use3.12s-audio0s-seed5400.pt-te0.1-tp0.701-tk20-textlen5-57496-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084919_use3.12s-audio0s-seed5400.pt-te0.1-tp0.701-tk20-textlen5-57496-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084929_use3.43s-audio0s-seed9999.pt-te0.1-tp0.701-tk20-textlen5-32652-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084929_use3.43s-audio0s-seed9999.pt-te0.1-tp0.701-tk20-textlen5-32652-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084945_use3.13s-audio0s-seed125.pt-te0.1-tp0.701-tk20-textlen5-93149-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084945_use3.13s-audio0s-seed125.pt-te0.1-tp0.701-tk20-textlen5-93149-merge.wav -------------------------------------------------------------------------------- /listen-speaker/084954_use3.25s-audio0s-seed2279.pt-te0.1-tp0.701-tk20-textlen5-62556-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/084954_use3.25s-audio0s-seed2279.pt-te0.1-tp0.701-tk20-textlen5-62556-merge.wav -------------------------------------------------------------------------------- /listen-speaker/085002_use3.2s-audio0s-seed6666.pt-te0.1-tp0.701-tk20-textlen5-07948-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/085002_use3.2s-audio0s-seed6666.pt-te0.1-tp0.701-tk20-textlen5-07948-merge.wav -------------------------------------------------------------------------------- /listen-speaker/085011_use3.31s-audio0s-seed492.pt-te0.1-tp0.701-tk20-textlen5-17771-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/085011_use3.31s-audio0s-seed492.pt-te0.1-tp0.701-tk20-textlen5-17771-merge.wav -------------------------------------------------------------------------------- /listen-speaker/085020_use3.04s-audio0s-seed5.pt-te0.1-tp0.701-tk20-textlen5-82025-merge.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/listen-speaker/085020_use3.04s-audio0s-seed5.pt-te0.1-tp0.701-tk20-textlen5-82025-merge.wav -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "chattts-ui" 3 | version = "0.1.0" 4 | description = "一个简单的本地网页界面,直接使用ChatTTS将文字合成为语音,同时支持对外提供API接口。" 5 | authors = ["jianchang512 "] 6 | readme = "README.md" 7 | homepage = 
"https://github.com/jianchang512/ChatTTS-ui" 8 | repository = "https://github.com/jianchang512/ChatTTS-ui" 9 | documentation = "https://github.com/jianchang512/ChatTTS-ui" 10 | package-mode = false 11 | 12 | [tool.poetry.dependencies] 13 | python = "^3.10" 14 | soundfile = "^0.12.1" 15 | python-dotenv = "^1.0.1" 16 | flask = "^3.0.3" 17 | waitress = "^3.0.0" 18 | modelscope = "^1.14.0" 19 | langsegment = "^0.3.3" 20 | omegaconf = "^2.3.0" 21 | tokenizers = "^0.19.1" 22 | transformers = "^4.41.2" 23 | torch = [ 24 | { version = "^2.3.0+cu118", source = "pytorch-gpu-src" }, 25 | { platform = "darwin", version = "^2" } 26 | ] 27 | torchaudio = [ 28 | { version = "^2.3.0+cu118", source = "pytorch-gpu-src" }, 29 | { platform = "darwin", version = "^2" } 30 | ] 31 | vocos = "^0.1.0" 32 | vector-quantize-pytorch = "^1.14.24" 33 | numpy = { version = "^1.26.4" } 34 | 35 | # Optional dependencies that are not downloaded by default 36 | fastapi = { version = "*", extras = ["all"], optional = true } 37 | pydantic = { version = "^2", optional = true } 38 | 39 | [[tool.poetry.source]] 40 | name = "pytorch-gpu-src" 41 | url = "https://download.pytorch.org/whl/cu118" 42 | priority = "explicit" 43 | 44 | [tool.poetry.group.dev.dependencies] 45 | black = "*" 46 | ruff = "*" 47 | 48 | [tool.poetry.group.test.dependencies] 49 | # https://docs.pytest.org/en/stable/reference/plugin_list.html#plugin-list 50 | # https://docs.pytest.org/en/stable/contents.html 51 | pytest = "*" 52 | # https://pytest-asyncio.readthedocs.io/en/latest/ 53 | pytest-asyncio = "*" 54 | 55 | [tool.black] 56 | line-length = 100 57 | target-version = ["py310", "py311", "py312"] 58 | skip-magic-trailing-comma = true 59 | exclude = ''' 60 | /( 61 | ChatTTS 62 | | .* 63 | | build 64 | | dist 65 | | migrations 66 | | __pycache__ 67 | )/ 68 | ''' 69 | 70 | [tool.pytest.ini_options] 71 | # https://docs.pytest.org/en/stable/reference/reference.html#configuration-options 72 | testpaths = ["tests", "examples"] 73 | asyncio_mode = "auto" 74 | filterwarnings = "ignore::DeprecationWarning" 75 | 76 | [build-system] 77 | requires = ["poetry-core"] 78 | build-backend = "poetry.core.masonry.api" 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Flask 2 | ipython 3 | modelscope 4 | numpy==1.26.4 5 | numba 6 | einops 7 | tqdm 8 | omegaconf>=2.3.0 9 | torch>=2.1.0 10 | python-dotenv 11 | requests 12 | soundfile 13 | tokenizers 14 | transformers==4.41.1 15 | vector-quantize-pytorch 16 | vocos 17 | waitress 18 | pybase16384 19 | pynini==2.1.5; sys_platform == 'linux' 20 | WeTextProcessing; sys_platform == 'linux' 21 | nemo_text_processing; sys_platform == 'linux' 22 | av 23 | pydub 24 | pandas 25 | -------------------------------------------------------------------------------- /run.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | .\venv\scripts\python.exe app.py 4 | pause -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | if sys.platform == "darwin": 4 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 5 | 6 | now_dir = os.getcwd() 7 | sys.path.append(now_dir) 8 | 9 | from dotenv import load_dotenv 10 | load_dotenv("sha256.env") 11 | 12 | import wave 13 | import ChatTTS 14 | from IPython.display import Audio 15 | import 
numpy as np 16 | 17 | def save_wav_file(wav, index): 18 | wav_filename = f"output_audio_{index}.wav" 19 | # Convert numpy array to bytes and write to WAV file 20 | wav_bytes = (wav * 32768).astype('int16').tobytes() 21 | with wave.open(wav_filename, "wb") as wf: 22 | wf.setnchannels(1) # Mono channel 23 | wf.setsampwidth(2) # Sample width in bytes 24 | wf.setframerate(24000) # Sample rate in Hz 25 | wf.writeframes(wav_bytes) 26 | print(f"Audio saved to {wav_filename}") 27 | 28 | def main(): 29 | # Retrieve text from command line argument 30 | try: 31 | sys.argv.remove('--stream') 32 | stream = True 33 | except: 34 | stream = False 35 | text_input = sys.argv[1] if len(sys.argv) > 1 else "" 36 | print("Received text input:", text_input) 37 | 38 | chat = ChatTTS.Chat() 39 | print("Initializing ChatTTS...") 40 | # if using macbook(M1), I suggest you set `device='cpu', compile=False` 41 | #chat.load_models() 42 | chat.load_models(source="custom",custom_path='./models/pzc163/chattts') 43 | print("Models loaded successfully.") 44 | 45 | texts = [text_input] 46 | print("Text prepared for inference:", texts) 47 | 48 | wavs_gen = chat.infer(texts, use_decoder=True, stream=stream) 49 | print("Inference completed. Audio generation successful.") 50 | # Save each generated wav file to a local file 51 | 52 | if stream: 53 | print('generate with stream mode ..') 54 | wavs = [np.array([[]])] 55 | for gen in wavs_gen: 56 | print('got new chunk', gen) 57 | tmp=[np.array([[]])] 58 | tmp[0]=np.hstack([tmp[0], np.array(gen[0])]) 59 | save_wav_file(tmp[0], 11) 60 | wavs[0] = np.hstack([wavs[0], np.array(gen[0])]) 61 | else: 62 | print('generate without stream mode ..') 63 | wavs = wavs_gen 64 | 65 | for index, wav in enumerate(wavs): 66 | save_wav_file(wav, index) 67 | 68 | return Audio(wavs[0], rate=24_000, autoplay=True) 69 | 70 | if __name__ == "__main__": 71 | print("Starting the TTS application...") 72 | main() 73 | print("TTS application finished.") 74 | -------------------------------------------------------------------------------- /runtest.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | .\venv\scripts\python.exe test.py 4 | pause -------------------------------------------------------------------------------- /speaker/1031.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/1031.pt -------------------------------------------------------------------------------- /speaker/11.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/11.pt -------------------------------------------------------------------------------- /speaker/1111.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/1111.pt -------------------------------------------------------------------------------- /speaker/12.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/12.pt -------------------------------------------------------------------------------- /speaker/1234.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/1234.pt -------------------------------------------------------------------------------- /speaker/125.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/125.pt -------------------------------------------------------------------------------- /speaker/13.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/13.pt -------------------------------------------------------------------------------- /speaker/14.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/14.pt -------------------------------------------------------------------------------- /speaker/1455.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/1455.pt -------------------------------------------------------------------------------- /speaker/1518.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/1518.pt -------------------------------------------------------------------------------- /speaker/1579.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/1579.pt -------------------------------------------------------------------------------- /speaker/16.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/16.pt -------------------------------------------------------------------------------- /speaker/1983.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/1983.pt -------------------------------------------------------------------------------- /speaker/2222.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/2222.pt -------------------------------------------------------------------------------- /speaker/2279.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/2279.pt -------------------------------------------------------------------------------- /speaker/2328.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/2328.pt -------------------------------------------------------------------------------- /speaker/2345.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/2345.pt -------------------------------------------------------------------------------- /speaker/3333.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/3333.pt -------------------------------------------------------------------------------- /speaker/4099.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/4099.pt -------------------------------------------------------------------------------- /speaker/4444.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/4444.pt -------------------------------------------------------------------------------- /speaker/4751.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/4751.pt -------------------------------------------------------------------------------- /speaker/4785.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/4785.pt -------------------------------------------------------------------------------- /speaker/491.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/491.pt -------------------------------------------------------------------------------- /speaker/492.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/492.pt -------------------------------------------------------------------------------- /speaker/5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/5.pt -------------------------------------------------------------------------------- /speaker/5099.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/5099.pt -------------------------------------------------------------------------------- /speaker/5400.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/5400.pt -------------------------------------------------------------------------------- /speaker/5555.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/5555.pt -------------------------------------------------------------------------------- /speaker/5600.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/5600.pt -------------------------------------------------------------------------------- /speaker/6653.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/6653.pt -------------------------------------------------------------------------------- /speaker/6666.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/6666.pt -------------------------------------------------------------------------------- /speaker/7777.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/7777.pt -------------------------------------------------------------------------------- /speaker/7869.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/7869.pt -------------------------------------------------------------------------------- /speaker/8888.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/8888.pt -------------------------------------------------------------------------------- /speaker/9999.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/speaker/9999.pt -------------------------------------------------------------------------------- /static/js/layer/mobile/layer.js: -------------------------------------------------------------------------------- 1 | /*! layer mobile-v2.0.0 Web 通用弹出层组件 MIT License */ 2 | ;!function(e){"use strict";var t=document,n="querySelectorAll",i="getElementsByClassName",a=function(e){return t[n](e)},s={type:0,shade:!0,shadeClose:!0,fixed:!0,anim:"scale"},l={extend:function(e){var t=JSON.parse(JSON.stringify(s));for(var n in e)t[n]=e[n];return t},timer:{},end:{}};l.touch=function(e,t){e.addEventListener("click",function(e){t.call(this,e)},!1)};var r=0,o=["layui-m-layer"],c=function(e){var t=this;t.config=l.extend(e),t.view()};c.prototype.view=function(){var e=this,n=e.config,s=t.createElement("div");e.id=s.id=o[0]+r,s.setAttribute("class",o[0]+" "+o[0]+(n.type||0)),s.setAttribute("index",r);var l=function(){var e="object"==typeof n.title;return n.title?'

'+(e?n.title[0]:n.title)+"

":""}(),c=function(){"string"==typeof n.btn&&(n.btn=[n.btn]);var e,t=(n.btn||[]).length;return 0!==t&&n.btn?(e=''+n.btn[0]+"",2===t&&(e=''+n.btn[1]+""+e),'
'+e+"
"):""}();if(n.fixed||(n.top=n.hasOwnProperty("top")?n.top:100,n.style=n.style||"",n.style+=" top:"+(t.body.scrollTop+n.top)+"px"),2===n.type&&(n.content='

'+(n.content||"")+"

"),n.skin&&(n.anim="up"),"msg"===n.skin&&(n.shade=!1),s.innerHTML=(n.shade?"
':"")+'
"+l+'
'+n.content+"
"+c+"
",!n.type||2===n.type){var d=t[i](o[0]+n.type),y=d.length;y>=1&&layer.close(d[0].getAttribute("index"))}document.body.appendChild(s);var u=e.elem=a("#"+e.id)[0];n.success&&n.success(u),e.index=r++,e.action(n,u)},c.prototype.action=function(e,t){var n=this;e.time&&(l.timer[n.index]=setTimeout(function(){layer.close(n.index)},1e3*e.time));var a=function(){var t=this.getAttribute("type");0==t?(e.no&&e.no(),layer.close(n.index)):e.yes?e.yes(n.index):layer.close(n.index)};if(e.btn)for(var s=t[i]("layui-m-layerbtn")[0].children,r=s.length,o=0;odiv{line-height:22px;padding-top:7px;margin-bottom:20px;font-size:14px}.layui-m-layerbtn{display:box;display:-moz-box;display:-webkit-box;width:100%;height:50px;line-height:50px;font-size:0;border-top:1px solid #D0D0D0;background-color:#F2F2F2}.layui-m-layerbtn span{display:block;-moz-box-flex:1;box-flex:1;-webkit-box-flex:1;font-size:14px;cursor:pointer}.layui-m-layerbtn span[yes]{color:#40AFFE}.layui-m-layerbtn span[no]{border-right:1px solid #D0D0D0;border-radius:0 0 0 5px}.layui-m-layerbtn span:active{background-color:#F6F6F6}.layui-m-layerend{position:absolute;right:7px;top:10px;width:30px;height:30px;border:0;font-weight:400;background:0 0;cursor:pointer;-webkit-appearance:none;font-size:30px}.layui-m-layerend::after,.layui-m-layerend::before{position:absolute;left:5px;top:15px;content:'';width:18px;height:1px;background-color:#999;transform:rotate(45deg);-webkit-transform:rotate(45deg);border-radius:3px}.layui-m-layerend::after{transform:rotate(-45deg);-webkit-transform:rotate(-45deg)}body .layui-m-layer .layui-m-layer-footer{position:fixed;width:95%;max-width:100%;margin:0 auto;left:0;right:0;bottom:10px;background:0 0}.layui-m-layer-footer .layui-m-layercont{padding:20px;border-radius:5px 5px 0 0;background-color:rgba(255,255,255,.8)}.layui-m-layer-footer .layui-m-layerbtn{display:block;height:auto;background:0 0;border-top:none}.layui-m-layer-footer .layui-m-layerbtn span{background-color:rgba(255,255,255,.8)}.layui-m-layer-footer .layui-m-layerbtn span[no]{color:#FD482C;border-top:1px solid #c2c2c2;border-radius:0 0 5px 5px}.layui-m-layer-footer .layui-m-layerbtn span[yes]{margin-top:10px;border-radius:5px}body .layui-m-layer .layui-m-layer-msg{width:auto;max-width:90%;margin:0 auto;bottom:-150px;background-color:rgba(0,0,0,.7);color:#fff}.layui-m-layer-msg .layui-m-layercont{padding:10px 20px} -------------------------------------------------------------------------------- /static/js/layer/theme/default/icon-ext.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/static/js/layer/theme/default/icon-ext.png -------------------------------------------------------------------------------- /static/js/layer/theme/default/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/static/js/layer/theme/default/icon.png -------------------------------------------------------------------------------- /static/js/layer/theme/default/loading-0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/static/js/layer/theme/default/loading-0.gif -------------------------------------------------------------------------------- /static/js/layer/theme/default/loading-1.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/static/js/layer/theme/default/loading-1.gif -------------------------------------------------------------------------------- /static/js/layer/theme/default/loading-2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/static/js/layer/theme/default/loading-2.gif -------------------------------------------------------------------------------- /templates/indexen.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | ChatTTS WebUI & API - v{{version}} 7 | 8 | 9 | 46 | 47 | 48 | 49 |
50 |
51 |

ChatTTS WebUI & API (v{{version}})

52 |
53 |
54 |
55 |
56 | 57 | Only the text box is required; all other fields are optional. If you are not sure what a field does, you can leave it as is. 58 |
59 |
60 |
61 |
62 | Select Voice 63 | 68 |
69 |
70 | Voice Value 71 | 72 |
73 |
74 | Text Seed 75 | 76 |
77 |
78 | Prompt 79 | 80 |
81 | 82 |
83 |
84 | 85 | 86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 | Infer Token 94 | 95 |
96 |
97 | Refine Token 98 | 99 |
100 |
101 | 102 | 103 | 5 104 |
105 | 106 |
107 | 108 | 109 | 0.1 110 |
111 | 112 |
113 | 114 | 115 | 0.05 116 | 117 |
118 | 119 |
120 | 121 | 122 | 20 123 | 124 |
125 |
126 |
127 |
128 | 129 | 132 | 133 |
134 |
135 |
136 |
137 |
138 |
139 | 140 |
141 | GitHub ChatTTS-UI 142 | GitHub ChatTTS 143 |
144 |
145 | 146 | 147 | 148 | 149 | 304 | 305 | 306 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | if sys.platform == "darwin": 5 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 6 | import io 7 | import json 8 | import torchaudio 9 | import wave 10 | from pathlib import Path 11 | print('Starting...') 12 | import shutil 13 | import time 14 | 15 | 16 | import torch 17 | import torch._dynamo 18 | torch._dynamo.config.suppress_errors = True 19 | torch._dynamo.config.cache_size_limit = 64 20 | torch._dynamo.config.suppress_errors = True 21 | torch.set_float32_matmul_precision('high') 22 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 23 | import subprocess 24 | import soundfile as sf 25 | import ChatTTS 26 | import datetime 27 | from dotenv import load_dotenv 28 | load_dotenv() 29 | 30 | import logging 31 | from logging.handlers import RotatingFileHandler 32 | 33 | from random import random 34 | from modelscope import snapshot_download 35 | import numpy as np 36 | import threading 37 | from uilib.cfg import WEB_ADDRESS, SPEAKER_DIR, LOGS_DIR, WAVS_DIR, MODEL_DIR, ROOT_DIR 38 | from uilib import utils,VERSION 39 | from ChatTTS.utils.gpu_utils import select_device 40 | from uilib.utils import is_chinese_os,modelscope_status 41 | merge_size=int(os.getenv('merge_size',10)) 42 | env_lang=os.getenv('lang','') 43 | if env_lang=='zh': 44 | is_cn= True 45 | elif env_lang=='en': 46 | is_cn=False 47 | else: 48 | is_cn=is_chinese_os() 49 | 50 | CHATTTS_DIR= MODEL_DIR+'/pzc163/chatTTS' 51 | 52 | 53 | chat = ChatTTS.Chat() 54 | device=os.getenv('device','default') 55 | chat.load(source="custom",custom_path=CHATTTS_DIR, device=None if device=='default' else device,compile=True if os.getenv('compile','true').lower()!='false' else False) 56 | 57 | for it in os.listdir('./speaker'): 58 | if it.startswith('seed_') and not it.endswith('_emb-covert.pt'): 59 | 60 | 61 | rand_spk=torch.load(f'./speaker/{it}', map_location=select_device(4096) if device=='default' else torch.device(device)) 62 | 63 | torch.save( chat._encode_spk_emb(rand_spk) ,f"{SPEAKER_DIR}/{it.replace('.pt','-covert.pt')}") -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/ChatTTS-ui/d136c389f3588c6319aee7dcf296154555d72912/tools/__init__.py -------------------------------------------------------------------------------- /tools/audio/__init__.py: -------------------------------------------------------------------------------- 1 | from .np import unsafe_float_to_int16 2 | -------------------------------------------------------------------------------- /tools/audio/np.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numba import jit 3 | 4 | 5 | @jit 6 | def unsafe_float_to_int16(audio: np.ndarray) -> np.ndarray: 7 | """ 8 | This function will destroy audio, use only once. 
9 | """ 10 | am = np.abs(audio).max() * 32768 11 | am = 32767 * 32768 / am 12 | np.multiply(audio, am, audio) 13 | audio16 = audio.astype(np.int16) 14 | return audio16 15 | -------------------------------------------------------------------------------- /tools/checksum/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/sha256" 5 | "encoding/hex" 6 | "fmt" 7 | "io" 8 | "os" 9 | ) 10 | 11 | func main() { 12 | var buf [32]byte 13 | h := sha256.New() 14 | lst := make([]any, 0, 64) 15 | for _, fname := range files { 16 | f, err := os.Open(fname) 17 | if err != nil { 18 | panic(err) 19 | } 20 | _, err = io.Copy(h, f) 21 | if err != nil { 22 | panic(err) 23 | } 24 | s := hex.EncodeToString(h.Sum(buf[:0])) 25 | fmt.Println("sha256 of", fname, "=", s) 26 | lst = append(lst, s) 27 | h.Reset() 28 | f.Close() 29 | } 30 | f, err := os.Create("ChatTTS/res/sha256_map.json") 31 | if err != nil { 32 | panic(err) 33 | } 34 | _, err = fmt.Fprintf(f, jsontmpl, lst...) 35 | if err != nil { 36 | panic(err) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /tools/checksum/tmpl.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | var files = [...]string{ 4 | "asset/Decoder.pt", 5 | "asset/DVAE.pt", 6 | "asset/GPT.pt", 7 | "asset/spk_stat.pt", 8 | "asset/tokenizer.pt", 9 | "asset/Vocos.pt", 10 | 11 | "config/decoder.yaml", 12 | "config/dvae.yaml", 13 | "config/gpt.yaml", 14 | "config/path.yaml", 15 | "config/vocos.yaml", 16 | } 17 | 18 | const jsontmpl = `{ 19 | "sha256_asset_Decoder_pt" : "%s", 20 | "sha256_asset_DVAE_pt" : "%s", 21 | "sha256_asset_GPT_pt" : "%s", 22 | "sha256_asset_spk_stat_pt" : "%s", 23 | "sha256_asset_tokenizer_pt" : "%s", 24 | "sha256_asset_Vocos_pt" : "%s", 25 | 26 | "sha256_config_decoder_yaml": "%s", 27 | "sha256_config_dvae_yaml" : "%s", 28 | "sha256_config_gpt_yaml" : "%s", 29 | "sha256_config_path_yaml" : "%s", 30 | "sha256_config_vocos_yaml" : "%s" 31 | } 32 | ` 33 | -------------------------------------------------------------------------------- /tools/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .llm import ChatOpenAI 2 | -------------------------------------------------------------------------------- /tools/llm/llm.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | prompt_dict = { 4 | "kimi": [ 5 | { 6 | "role": "system", 7 | "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。", 8 | }, 9 | { 10 | "role": "user", 11 | "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。", 12 | }, 13 | { 14 | "role": "assistant", 15 | "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。", 16 | }, 17 | ], 18 | "deepseek": [ 19 | {"role": "system", "content": "You are a helpful assistant"}, 20 | { 21 | "role": "user", 22 | "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。", 23 | }, 24 | { 25 | "role": "assistant", 26 | "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。", 27 | }, 28 | ], 29 | "deepseek_TN": [ 30 | {"role": "system", "content": "You are a helpful assistant"}, 31 | { 32 | "role": "user", 33 | "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号", 34 | }, 35 
| { 36 | "role": "assistant", 37 | "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入", 38 | }, 39 | {"role": "user", "content": "We paid $123 for this desk."}, 40 | { 41 | "role": "assistant", 42 | "content": "We paid one hundred and twenty three dollars for this desk.", 43 | }, 44 | {"role": "user", "content": "详询请拨打010-724654"}, 45 | {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"}, 46 | {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"}, 47 | { 48 | "role": "assistant", 49 | "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。", 50 | }, 51 | ], 52 | } 53 | 54 | 55 | class ChatOpenAI: 56 | def __init__(self, api_key, base_url, model): 57 | self.client = OpenAI( 58 | api_key=api_key, 59 | base_url=base_url, 60 | ) 61 | self.model = model 62 | 63 | def call(self, user_question, temperature=0.3, prompt_version="kimi", **kwargs): 64 | 65 | completion = self.client.chat.completions.create( 66 | model=self.model, 67 | messages=prompt_dict[prompt_version] 68 | + [ 69 | {"role": "user", "content": user_question}, 70 | ], 71 | temperature=temperature, 72 | **kwargs 73 | ) 74 | return completion.choices[0].message.content 75 | -------------------------------------------------------------------------------- /tools/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .log import get_logger 2 | -------------------------------------------------------------------------------- /tools/logger/log.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import logging 3 | from datetime import datetime, timezone 4 | 5 | logging.getLogger("numba").setLevel(logging.WARNING) 6 | logging.getLogger("httpx").setLevel(logging.WARNING) 7 | logging.getLogger("wetext-zh_normalizer").setLevel(logging.WARNING) 8 | logging.getLogger("NeMo-text-processing").setLevel(logging.WARNING) 9 | 10 | # from https://github.com/FloatTech/ZeroBot-Plugin/blob/c70766a989698452e60e5e48fb2f802a2444330d/console/console_windows.go#L89-L96 11 | colorCodePanic = "\x1b[1;31m" 12 | colorCodeFatal = "\x1b[1;31m" 13 | colorCodeError = "\x1b[31m" 14 | colorCodeWarn = "\x1b[33m" 15 | colorCodeInfo = "\x1b[37m" 16 | colorCodeDebug = "\x1b[32m" 17 | colorCodeTrace = "\x1b[36m" 18 | colorReset = "\x1b[0m" 19 | 20 | log_level_color_code = { 21 | logging.DEBUG: colorCodeDebug, 22 | logging.INFO: colorCodeInfo, 23 | logging.WARN: colorCodeWarn, 24 | logging.ERROR: colorCodeError, 25 | logging.FATAL: colorCodeFatal, 26 | } 27 | 28 | log_level_msg_str = { 29 | logging.DEBUG: "DEBU", 30 | logging.INFO: "INFO", 31 | logging.WARN: "WARN", 32 | logging.ERROR: "ERRO", 33 | logging.FATAL: "FATL", 34 | } 35 | 36 | 37 | class Formatter(logging.Formatter): 38 | def __init__(self, color=platform.system().lower() != "windows"): 39 | # https://stackoverflow.com/questions/2720319/python-figure-out-local-timezone 40 | self.tz = datetime.now(timezone.utc).astimezone().tzinfo 41 | self.color = color 42 | 43 | def format(self, record: logging.LogRecord): 44 | logstr = "[" + datetime.now(self.tz).strftime("%z %Y%m%d %H:%M:%S") + "] [" 45 | if self.color: 46 | logstr += log_level_color_code.get(record.levelno, colorCodeInfo) 47 | logstr += log_level_msg_str.get(record.levelno, record.levelname) 48 | if self.color: 49 | logstr += colorReset 50 | fn = record.filename.removesuffix(".py") 51 | logstr += f"] {str(record.name)} | {fn} | {str(record.msg)%record.args}" 52 | return logstr 53 | 54 | 55 | def get_logger(name: str, lv=logging.INFO, remove_exist=False, 
format_root=False): 56 | logger = logging.getLogger(name) 57 | logger.setLevel(lv) 58 | if remove_exist and logger.hasHandlers(): 59 | logger.handlers.clear() 60 | if not logger.hasHandlers(): 61 | syslog = logging.StreamHandler() 62 | syslog.setFormatter(Formatter()) 63 | logger.addHandler(syslog) 64 | else: 65 | for h in logger.handlers: 66 | h.setFormatter(Formatter()) 67 | if format_root: 68 | for h in logger.root.handlers: 69 | h.setFormatter(Formatter()) 70 | return logger 71 | -------------------------------------------------------------------------------- /tools/normalizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .en import normalizer_en_nemo_text 2 | from .zh import normalizer_zh_tn 3 | -------------------------------------------------------------------------------- /tools/normalizer/en.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from functools import partial 3 | 4 | 5 | def normalizer_en_nemo_text() -> Callable[[str], str]: 6 | from nemo_text_processing.text_normalization.normalize import Normalizer 7 | 8 | return partial( 9 | Normalizer(input_case="cased", lang="en").normalize, 10 | verbose=False, 11 | punct_post_process=True, 12 | ) 13 | -------------------------------------------------------------------------------- /tools/normalizer/zh.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | 4 | def normalizer_zh_tn() -> Callable[[str], str]: 5 | from tn.chinese.normalizer import Normalizer 6 | 7 | return Normalizer().normalize 8 | -------------------------------------------------------------------------------- /tools/seeder/__init__.py: -------------------------------------------------------------------------------- 1 | from .ctx import TorchSeedContext 2 | -------------------------------------------------------------------------------- /tools/seeder/ctx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TorchSeedContext: 5 | def __init__(self, seed): 6 | self.seed = seed 7 | self.state = None 8 | 9 | def __enter__(self): 10 | self.state = torch.random.get_rng_state() 11 | torch.manual_seed(self.seed) 12 | 13 | def __exit__(self, type, value, traceback): 14 | torch.random.set_rng_state(self.state) 15 | -------------------------------------------------------------------------------- /uilib/__init__.py: -------------------------------------------------------------------------------- 1 | VERSION='1.0' -------------------------------------------------------------------------------- /uilib/cfg.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | import os 4 | 5 | def get_executable_path(): 6 | # 这个函数会返回可执行文件所在的目录 7 | if getattr(sys, 'frozen', False): 8 | # 如果程序是被“冻结”打包的,使用这个路径 9 | return Path(sys.executable).parent.as_posix() 10 | else: 11 | return Path.cwd().as_posix() 12 | 13 | ROOT_DIR=get_executable_path() 14 | 15 | MODEL_DIR_PATH=Path(ROOT_DIR+"/asset") 16 | MODEL_DIR_PATH.mkdir(parents=True, exist_ok=True) 17 | MODEL_DIR=MODEL_DIR_PATH.as_posix() 18 | 19 | WAVS_DIR_PATH=Path(ROOT_DIR+"/static/wavs") 20 | WAVS_DIR_PATH.mkdir(parents=True, exist_ok=True) 21 | WAVS_DIR=WAVS_DIR_PATH.as_posix() 22 | 23 | LOGS_DIR_PATH=Path(ROOT_DIR+"/logs") 24 | LOGS_DIR_PATH.mkdir(parents=True, exist_ok=True) 25 | 
LOGS_DIR=LOGS_DIR_PATH.as_posix() 26 | 27 | SPEAKER_DIR_PATH=Path(ROOT_DIR+"/speaker") 28 | SPEAKER_DIR_PATH.mkdir(parents=True, exist_ok=True) 29 | SPEAKER_DIR=SPEAKER_DIR_PATH.as_posix() 30 | 31 | # ffmpeg 32 | if sys.platform == 'win32': 33 | os.environ['PATH'] = ROOT_DIR + f';{ROOT_DIR}/ffmpeg;' + os.environ['PATH'] 34 | 35 | else: 36 | os.environ['PATH'] = ROOT_DIR + f':{ROOT_DIR}/ffmpeg:' + os.environ['PATH'] 37 | 38 | 39 | # 读取 .env 变量 40 | WEB_ADDRESS = os.getenv('WEB_ADDRESS', '127.0.0.1:9966') 41 | 42 | -------------------------------------------------------------------------------- /uilib/utils.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import requests 3 | import time 4 | import re 5 | import webbrowser 6 | from pathlib import Path 7 | import pandas as pd 8 | # ref: https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization 9 | from .zh_normalization import TextNormalizer 10 | from .cfg import SPEAKER_DIR 11 | from functools import partial 12 | 13 | def openweb(url): 14 | time.sleep(3) 15 | try: 16 | webbrowser.open(url) 17 | except Exception: 18 | pass 19 | 20 | def get_parameter(request, param, default, cast_type): 21 | # 先request.args 后request.form 然后转换cast_type=int|float类型。 22 | for method in [request.args.get, request.form.get]: 23 | value = method(param, "").strip() 24 | if value: 25 | try: 26 | return cast_type(value) 27 | except ValueError: 28 | break # args转换失败,退出尝试form 29 | return default # 失败,返回默认值。 30 | 31 | 32 | # 数字转为英文读法 33 | def num_to_english(num): 34 | 35 | num_str = str(num) 36 | # English representations for numbers 0-9 37 | english_digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"] 38 | units = ["", "ten", "hundred", "thousand"] 39 | big_units = ["", "thousand", "million", "billion", "trillion"] 40 | result = "" 41 | need_and = False # Indicates whether 'and' needs to be added 42 | part = [] # Stores each group of 4 digits 43 | is_first_part = True # Indicates if it is the first part for not adding 'and' at the beginning 44 | 45 | # Split the number into 3-digit groups 46 | while num_str: 47 | part.append(num_str[-3:]) 48 | num_str = num_str[:-3] 49 | 50 | part.reverse() 51 | 52 | for i, p in enumerate(part): 53 | p_str = "" 54 | digit_len = len(p) 55 | if int(p) == 0 and i < len(part) - 1: 56 | continue 57 | 58 | hundreds_digit = int(p) // 100 if digit_len == 3 else None 59 | tens_digit = int(p) % 100 if digit_len >= 2 else int(p[0] if digit_len == 1 else p[1]) 60 | 61 | # Process hundreds 62 | if hundreds_digit is not None and hundreds_digit != 0: 63 | p_str += english_digits[hundreds_digit] + " hundred" 64 | if tens_digit != 0: 65 | p_str += " and " 66 | 67 | # Process tens and ones 68 | if 10 < tens_digit < 20: # Teens exception 69 | teen_map = { 70 | 11: "eleven", 12: "twelve", 13: "thirteen", 14: "fourteen", 15: "fifteen", 71 | 16: "sixteen", 17: "seventeen", 18: "eighteen", 19: "nineteen" 72 | } 73 | p_str += teen_map[tens_digit] 74 | else: 75 | tens_map = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"] 76 | tens_val = tens_digit // 10 77 | ones_val = tens_digit % 10 78 | if tens_val >= 2: 79 | p_str += tens_map[tens_val] + (" " + english_digits[ones_val] if ones_val != 0 else "") 80 | elif tens_digit != 0 and tens_val < 2: # When tens_digit is in [1, 9] 81 | p_str += english_digits[tens_digit] 82 | 83 | if p_str and not is_first_part and need_and: 84 | result += " and " 
85 | result += p_str 86 | if i < len(part) - 1 and int(p) != 0: 87 | result += " " + big_units[len(part) - i - 1] + ", " 88 | 89 | is_first_part = False 90 | if int(p) != 0: 91 | need_and = True 92 | 93 | return result.capitalize() 94 | 95 | 96 | def get_lang(text): 97 | # 定义中文标点符号的模式 98 | chinese_punctuation = "[。?!,、;:‘’“”()《》【】…—\u3000]" 99 | # 使用正则表达式替换所有中文标点为"" 100 | cleaned_text = re.sub(chinese_punctuation, "", text) 101 | # 使用正则表达式来匹配中文字符范围 102 | return "zh" if re.search('[\u4e00-\u9fff]', cleaned_text) is not None else "en" 103 | 104 | def fraction_to_words(match): 105 | numerator, denominator = match.groups() 106 | # 这里只是把数字直接拼接成了英文分数的形式, 实际上应该使用某种方式将数字转换为英文单词 107 | # 例如: "1/2" -> "one half", 这里仅为展示目的而直接返回了 "numerator/denominator" 108 | return numerator + " over " + denominator 109 | 110 | 111 | 112 | # 数字转为英文读法 113 | def num2text(text): 114 | numtext=[' zero ',' one ',' two ',' three ',' four ',' five ',' six ',' seven ',' eight ',' nine '] 115 | point=' point ' 116 | text = re.sub(r'(\d)\,(\d)', r'\1\2', text) 117 | text = re.sub(r'(\d+)\s*\+', r'\1 plus ', text) 118 | text = re.sub(r'(\d+)\s*\-', r'\1 minus ', text) 119 | text = re.sub(r'(\d+)\s*[\*x]', r'\1 times ', text) 120 | text = re.sub(r'((?:\d+\.)?\d+)\s*/\s*(\d+)', fraction_to_words, text) 121 | 122 | # 取出数字 number_list= [('1000200030004000.123', '1000200030004000', '123'), ('23425', '23425', '')] 123 | number_list=re.findall(r'((\d+)(?:\.(\d+))?%?)', text) 124 | if len(number_list)>0: 125 | #dc= ('1000200030004000.123', '1000200030004000', '123','') 126 | for m,dc in enumerate(number_list): 127 | if len(dc[1])>16: 128 | continue 129 | int_text= num_to_english(dc[1]) 130 | if len(dc)>2 and dc[2]: 131 | int_text+=point+"".join([numtext[int(i)] for i in dc[2]]) 132 | if dc[0][-1]=='%': 133 | int_text=f' the pronunciation of {int_text}' 134 | text=text.replace(dc[0],int_text) 135 | 136 | 137 | return text.replace('1',' one ').replace('2',' two ').replace('3',' three ').replace('4',' four ').replace('5',' five ').replace('6',' six ').replace('7','seven').replace('8',' eight ').replace('9',' nine ').replace('0',' zero ').replace('=',' equals ') 138 | 139 | 140 | 141 | def remove_brackets(text): 142 | # 正则表达式 143 | text=re.sub(r'\[(uv_break|laugh|lbreak|break)\]',r' \1 ',text,re.I|re.S|re.M) 144 | 145 | # 使用 re.sub 替换掉 [ ] 对 146 | newt=re.sub(r'\[|\]|!|:|{|}', '', text) 147 | return re.sub(r'\s(uv_break|laugh|lbreak|break)(?=\s|$)', r' [\1] ', newt) 148 | 149 | 150 | # 中英文数字转换为文字,特殊符号处理 151 | def split_text(text_list): 152 | 153 | tx = TextNormalizer() 154 | haserror=False 155 | result=[] 156 | for i,text in enumerate(text_list): 157 | text=remove_brackets(text) 158 | if get_lang(text)=='zh': 159 | tmp="".join(tx.normalize(text)) 160 | elif haserror: 161 | tmp=num2text(text) 162 | else: 163 | try: 164 | # 先尝试使用 nemo_text_processing 处理英文 165 | from nemo_text_processing.text_normalization.normalize import Normalizer 166 | fun = partial(Normalizer(input_case='cased', lang="en").normalize, verbose=False, punct_post_process=True) 167 | tmp=fun(text) 168 | print(f'使用nemo处理英文ok') 169 | except Exception as e: 170 | print(f"nemo处理英文失败,改用自定义预处理") 171 | print(e) 172 | haserror=True 173 | tmp=num2text(text) 174 | 175 | if len(tmp)>200: 176 | tmp_res=split_text_by_punctuation(tmp) 177 | result=result+tmp_res 178 | else: 179 | result.append(tmp) 180 | return result 181 | 182 | # 切分长行 200 150 183 | def split_text_by_punctuation(text): 184 | # 定义长度限制 185 | min_length = 150 186 | punctuation_marks = "。?!,、;:”’》」』)】…—" 187 | 
english_punctuation = ".?!,:;)}…" 188 | 189 | # 结果列表 190 | result = [] 191 | # 起始位置 192 | pos = 0 193 | 194 | # 遍历文本中的每个字符 195 | text_length=len(text) 196 | for i, char in enumerate(text): 197 | if char in punctuation_marks or char in english_punctuation: 198 | if char=='.' and i< text_length-1 and re.match(r'\d',text[i+1]): 199 | continue 200 | # 当遇到标点时,判断当前分段长度是否超过120 201 | if i - pos > min_length: 202 | # 如果长度超过120,将当前分段添加到结果列表中 203 | result.append(text[pos:i+1]) 204 | # 更新起始位置到当前标点的下一个字符 205 | pos = i+1 206 | #print(f'{pos=},{len(text)=}') 207 | 208 | # 如果剩余文本长度超过120或没有更多标点符号可以进行分割,将剩余的文本作为一个分段添加到结果列表 209 | if pos < len(text): 210 | result.append(text[pos:]) 211 | 212 | return result 213 | 214 | 215 | # 获取../static/wavs目录中的所有文件和目录并清理wav 216 | def ClearWav(directory): 217 | files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))] 218 | 219 | if not files: 220 | return False, "wavs目录内无wav文件" 221 | 222 | for filename in os.listdir(directory): 223 | file_path = os.path.join(directory, filename) 224 | try: 225 | if os.path.isfile(file_path) or os.path.islink(file_path): 226 | os.unlink(file_path) 227 | print(f"已删除文件: {file_path}") 228 | elif os.path.isdir(file_path): 229 | print(f"跳过文件夹: {file_path}") 230 | except Exception as e: 231 | print(f"文件删除错误 {file_path}, 报错信息: {e}") 232 | return False, str(e) 233 | return True, "所有wav文件已被删除." 234 | 235 | 236 | 237 | # 加载音色 238 | # 参考 https://github.com/craii/ChatTTS_WebUI/blob/main/utils.py 239 | def load_speaker(name): 240 | speaker_path = f"{SPEAKER_DIR}/{name}.csv" if not name.endswith('.csv') else f"{SPEAKER_DIR}/{name}" 241 | if not os.path.exists(speaker_path): 242 | return None 243 | try: 244 | import torch 245 | d_s = pd.read_csv(speaker_path, header=None).iloc[:, 0] 246 | tensor = torch.tensor(d_s.values) 247 | except Exception as e: 248 | print(e) 249 | return None 250 | return tensor 251 | 252 | 253 | # 获取 speaker_dir下的所有csv pt文件 254 | def get_speakers(): 255 | result=[] 256 | for it in os.listdir(SPEAKER_DIR): 257 | if it.endswith('.pt'): 258 | result.append(it) 259 | return result 260 | 261 | # 判断是否可以连接外网 262 | def is_network(): 263 | try: 264 | import requests 265 | requests.head('https://baidu.com') 266 | except Exception: 267 | return False 268 | else: 269 | return True 270 | return False 271 | 272 | 273 | 274 | def is_chinese_os(): 275 | import subprocess 276 | try: 277 | import locale 278 | # Windows系统 279 | if sys.platform.startswith('win'): 280 | lang = locale.getdefaultlocale()[0] 281 | return lang.startswith('zh_CN') or lang.startswith('zh_TW') or lang.startswith('zh_HK') 282 | # macOS系统 283 | elif sys.platform == 'darwin': 284 | process = subprocess.Popen(['defaults', 'read', '-g', 'AppleLocale'], stdout=subprocess.PIPE) 285 | output, error = process.communicate() 286 | if error: 287 | # 若默认方法出错,则尝试环境变量 288 | return os.getenv('LANG', '').startswith('zh_') 289 | locale = output.decode().strip() 290 | return locale.startswith('zh_') 291 | # 类Unix系统 292 | elif sys.platform.startswith('linux') or sys.platform.startswith('cygwin'): 293 | return os.getenv('LANG', '').startswith('zh_') 294 | # 其他系统 295 | else: 296 | return False 297 | 298 | except Exception as e: 299 | # 输出异常到控制台,实际应用中应该使用日志记录异常 300 | print(e) 301 | return False 302 | 303 | 304 | 305 | def modelscope_status(): 306 | #return False 307 | try: 308 | res=requests.head("https://www.modelscope.cn/") 309 | print(res) 310 | if res.status_code!=200: 311 | return False 312 | except Exception as e: 313 | return False 314 | return True 315 | 
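
A minimal usage sketch for the text helpers defined in uilib/utils.py above (an illustration, not a file in the repository): it assumes the script is run from the project root so that `uilib` is importable, and the sample strings are invented for the example.

```
# Sketch only: exercises get_lang(), num2text() and split_text_by_punctuation()
# from uilib/utils.py. Assumes the project root is on sys.path.
from uilib.utils import get_lang, num2text, split_text_by_punctuation

# get_lang() returns "zh" when CJK characters remain after stripping Chinese punctuation.
print(get_lang("你好,世界"))     # -> zh
print(get_lang("Hello, world"))  # -> en

# num2text() spells out digits and simple operators for English TTS input.
print(num2text("We paid 123 dollars."))  # -> "We paid One hundred and twenty three dollars."

# split_text_by_punctuation() cuts a long string at punctuation once a segment
# grows past roughly 150 characters, keeping each chunk short enough for TTS.
for seg in split_text_by_punctuation("This is a long sentence. " * 20):
    print(len(seg), seg[:30])
```

The split_text() function above wires these together: Chinese input goes through TextNormalizer, while English input is first tried with nemo_text_processing and falls back to num2text() whenever that import fails.
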
--------------------------------------------------------------------------------
/uilib/zh_normalization/README.md:
--------------------------------------------------------------------------------
1 | ## Supported NSW (Non-Standard-Word) Normalization
2 | 
3 | |NSW type|raw|normalized|
4 | |:--|:-|:-|
5 | |serial number|电影中梁朝伟扮演的陈永仁的编号27149|电影中梁朝伟扮演的陈永仁的编号二七一四九|
6 | |cardinal|这块黄金重达324.75克<br>我们班的最高总分为583分|这块黄金重达三百二十四点七五克<br>我们班的最高总分为五百八十三分|
7 | |numeric range |12\~23<br>-1.5\~2|十二到二十三<br>负一点五到二|
8 | |date|她出生于86年8月18日,她弟弟出生于1995年3月1日|她出生于八六年八月十八日, 她弟弟出生于一九九五年三月一日|
9 | |time|等会请在12:05请通知我|等会请在十二点零五分请通知我|
10 | |temperature|今天的最低气温达到-10°C|今天的最低气温达到零下十度|
11 | |fraction|现场有7/12的观众投出了赞成票|现场有十二分之七的观众投出了赞成票|
12 | |percentage|明天有62%的概率降雨|明天有百分之六十二的概率降雨|
13 | |money|随便来几个价格12块5,34.5元,20.1万|随便来几个价格十二块五,三十四点五元,二十点一万|
14 | |telephone|这是固话0421-33441122<br>这是手机+86 18544139121|这是固话零四二一三三四四一一二二<br>这是手机八六一八五四四一三九一二一|
15 | ## References
16 | [Pull requests #658 of DeepSpeech](https://github.com/PaddlePaddle/DeepSpeech/pull/658/files)
--------------------------------------------------------------------------------
/uilib/zh_normalization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from .text_normlization import *
--------------------------------------------------------------------------------
/uilib/zh_normalization/chronology.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import re
15 | 
16 | from .num import DIGITS
17 | from .num import num2str
18 | from .num import verbalize_cardinal
19 | from .num import verbalize_digit
20 | 
21 | 
22 | def _time_num2str(num_string: str) -> str:
23 |     """A special case for verbalizing number in time."""
24 |     result = num2str(num_string.lstrip('0'))
25 |     if num_string.startswith('0'):
26 |         result = DIGITS['0'] + result
27 |     return result
28 | 
29 | 
30 | # Time-of-day expressions
31 | RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])'
32 |                      r':([0-9][0-9]?)'
33 |                      r'(:([0-9][0-9]?))?')
34 | 
35 | # Time ranges, e.g. 8:30-12:30
36 | RE_TIME_RANGE = re.compile(r'([0-1]?[0-9]|2[0-3])'
37 |                            r':([0-9][0-9]?)'
38 |                            r'(:([0-9][0-9]?))?'
39 |                            r'(~|-)'
40 |                            r'([0-1]?[0-9]|2[0-3])'
41 |                            r':([0-9][0-9]?)'
42 |                            r'(:([0-9][0-9]?))?')
43 | 
44 | 
45 | def replace_time(match) -> str:
46 |     """
47 |     Args:
48 |         match (re.Match)
49 |     Returns:
50 |         str
51 |     """
52 | 
53 |     is_range = len(match.groups()) > 5
54 | 
55 |     hour = match.group(1)
56 |     minute = match.group(2)
57 |     second = match.group(4)
58 | 
59 |     if is_range:
60 |         hour_2 = match.group(6)
61 |         minute_2 = match.group(7)
62 |         second_2 = match.group(9)
63 | 
64 |     result = f"{num2str(hour)}点"
65 |     if minute.lstrip('0'):
66 |         if int(minute) == 30:
67 |             result += "半"
68 |         else:
69 |             result += f"{_time_num2str(minute)}分"
70 |     if second and second.lstrip('0'):
71 |         result += f"{_time_num2str(second)}秒"
72 | 
73 |     if is_range:
74 |         result += "至"
75 |         result += f"{num2str(hour_2)}点"
76 |         if minute_2.lstrip('0'):
77 |             if int(minute_2) == 30:
78 |                 result += "半"
79 |             else:
80 |                 result += f"{_time_num2str(minute_2)}分"
81 |         if second_2 and second_2.lstrip('0'):
82 |             result += f"{_time_num2str(second_2)}秒"
83 | 
84 |     return result
85 | 
86 | 
87 | RE_DATE = re.compile(r'(\d{4}|\d{2})年'
88 |                      r'((0?[1-9]|1[0-2])月)?'
89 |                      r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?')
90 | 
91 | 
92 | def replace_date(match) -> str:
93 |     """
94 |     Args:
95 |         match (re.Match)
96 |     Returns:
97 |         str
98 |     """
99 |     year = match.group(1)
100 |     month = match.group(3)
101 |     day = match.group(5)
102 |     result = ""
103 |     if year:
104 |         result += f"{verbalize_digit(year)}年"
105 |     if month:
106 |         result += f"{verbalize_cardinal(month)}月"
107 |     if day:
108 |         result += f"{verbalize_cardinal(day)}{match.group(9)}"
109 |     return result
110 | 
111 | 
112 | # YY/MM/DD or YY-MM-DD dates separated by / or -
113 | RE_DATE2 = re.compile(
114 |     r'(\d{4})([- /.])(0?[1-9]|1[012])\2(0?[1-9]|[12][0-9]|3[01])')
115 | 
116 | 
117 | def replace_date2(match) -> str:
118 |     """
119 |     Args:
120 |         match (re.Match)
121 |     Returns:
122 |         str
123 |     """
124 |     year = match.group(1)
125 |     month = match.group(3)
126 |     day = match.group(4)
127 |     result = ""
128 |     if year:
129 |         result += f"{verbalize_digit(year)}年"
130 |     if month:
131 |         result += f"{verbalize_cardinal(month)}月"
132 |     if day:
133 |         result += f"{verbalize_cardinal(day)}日"
134 |     return result
135 | 
--------------------------------------------------------------------------------
/uilib/zh_normalization/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
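An aside before constants.py continues: each regex in chronology.py pairs with a `replace_*` callback, and range patterns must be substituted before plain times so `RE_TIME` cannot consume half of a range. A minimal sketch, assuming the repo root is on `sys.path` so the package imports; the sample sentence is made up:

```python
# Sketch: applying the chronology rules in the same order normalize_sentence
# uses (dates, then time ranges, then any remaining plain times).
from uilib.zh_normalization.chronology import (RE_DATE, RE_TIME,
                                               RE_TIME_RANGE, replace_date,
                                               replace_time)

s = "会议定于2024年5月3日 8:30-12:30举行"
s = RE_DATE.sub(replace_date, s)        # 2024年5月3日 -> 二零二四年五月三日
s = RE_TIME_RANGE.sub(replace_time, s)  # 8:30-12:30 -> 八点半至十二点半
s = RE_TIME.sub(replace_time, s)        # any plain times that remain
print(s)  # 会议定于二零二四年五月三日 八点半至十二点半举行
```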
14 | import re 15 | import string 16 | 17 | #from pypinyin.constants import SUPPORT_UCS4 18 | 19 | # 全角半角转换 20 | # 英文字符全角 -> 半角映射表 (num: 52) 21 | F2H_ASCII_LETTERS = { 22 | ord(char) + 65248: ord(char) 23 | for char in string.ascii_letters 24 | } 25 | 26 | # 英文字符半角 -> 全角映射表 27 | H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()} 28 | 29 | # 数字字符全角 -> 半角映射表 (num: 10) 30 | F2H_DIGITS = {ord(char) + 65248: ord(char) for char in string.digits} 31 | # 数字字符半角 -> 全角映射表 32 | H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()} 33 | 34 | # 标点符号全角 -> 半角映射表 (num: 32) 35 | F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation} 36 | # 标点符号半角 -> 全角映射表 37 | H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()} 38 | 39 | # 空格 (num: 1) 40 | F2H_SPACE = {'\u3000': ' '} 41 | H2F_SPACE = {' ': '\u3000'} 42 | 43 | # 非"有拼音的汉字"的字符串,可用于NSW提取 44 | ''' 45 | if SUPPORT_UCS4: 46 | RE_NSW = re.compile(r'(?:[^' 47 | r'\u3007' # 〇 48 | r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] 49 | r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] 50 | r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] 51 | r'\U00020000-\U0002A6DF' # CJK扩展B:[20000-2A6DF] 52 | r'\U0002A703-\U0002B73F' # CJK扩展C:[2A700-2B73F] 53 | r'\U0002B740-\U0002B81D' # CJK扩展D:[2B740-2B81D] 54 | r'\U0002F80A-\U0002FA1F' # CJK兼容扩展:[2F800-2FA1F] 55 | r'])+') 56 | else: 57 | ''' 58 | RE_NSW = re.compile( # pragma: no cover 59 | r'(?:[^' 60 | r'\u3007' # 〇 61 | r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] 62 | r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] 63 | r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] 64 | r'])+') 65 | -------------------------------------------------------------------------------- /uilib/zh_normalization/num.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Rules to verbalize numbers into Chinese characters. 
16 | https://zh.wikipedia.org/wiki/中文数字#現代中文 17 | """ 18 | import re 19 | from collections import OrderedDict 20 | from typing import List 21 | 22 | DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')} 23 | UNITS = OrderedDict({ 24 | 1: '十', 25 | 2: '百', 26 | 3: '千', 27 | 4: '万', 28 | 8: '亿', 29 | }) 30 | 31 | COM_QUANTIFIERS = '(封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)' 32 | 33 | # 分数表达式 34 | RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)') 35 | 36 | 37 | def replace_frac(match) -> str: 38 | """ 39 | Args: 40 | match (re.Match) 41 | Returns: 42 | str 43 | """ 44 | sign = match.group(1) 45 | nominator = match.group(2) 46 | denominator = match.group(3) 47 | sign: str = "负" if sign else "" 48 | nominator: str = num2str(nominator) 49 | denominator: str = num2str(denominator) 50 | result = f"{sign}{denominator}分之{nominator}" 51 | return result 52 | 53 | 54 | # 百分数表达式 55 | RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%') 56 | 57 | 58 | def replace_percentage(match) -> str: 59 | """ 60 | Args: 61 | match (re.Match) 62 | Returns: 63 | str 64 | """ 65 | sign = match.group(1) 66 | percent = match.group(2) 67 | sign: str = "负" if sign else "" 68 | percent: str = num2str(percent) 69 | result = f"{sign}百分之{percent}" 70 | return result 71 | 72 | 73 | # 整数表达式 74 | # 带负号的整数 -10 75 | RE_INTEGER = re.compile(r'(-)' r'(\d+)') 76 | 77 | 78 | def replace_negative_num(match) -> str: 79 | """ 80 | Args: 81 | match (re.Match) 82 | Returns: 83 | str 84 | """ 85 | sign = match.group(1) 86 | number = match.group(2) 87 | sign: str = "负" if sign else "" 88 | number: str = num2str(number) 89 | result = f"{sign}{number}" 90 | return result 91 | 92 | 93 | # 编号-无符号整形 94 | # 00078 95 | RE_DEFAULT_NUM = re.compile(r'\d{3}\d*') 96 | 97 | 98 | def replace_default_num(match): 99 | """ 100 | Args: 101 | match (re.Match) 102 | Returns: 103 | str 104 | """ 105 | number = match.group(0) 106 | return verbalize_digit(number, alt_one=False) 107 | 108 | 109 | # 数字表达式 110 | # 纯小数 111 | RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))') 112 | # 正整数 + 量词 113 | RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" 
+ COM_QUANTIFIERS) 114 | RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))') 115 | 116 | 117 | def replace_positive_quantifier(match) -> str: 118 | """ 119 | Args: 120 | match (re.Match) 121 | Returns: 122 | str 123 | """ 124 | number = match.group(1) 125 | match_2 = match.group(2) 126 | if match_2 == "+": 127 | match_2 = "多" 128 | match_2: str = match_2 if match_2 else "" 129 | quantifiers: str = match.group(3) 130 | number: str = num2str(number) 131 | result = f"{number}{match_2}{quantifiers}" 132 | return result 133 | 134 | 135 | def replace_number(match) -> str: 136 | """ 137 | Args: 138 | match (re.Match) 139 | Returns: 140 | str 141 | """ 142 | sign = match.group(1) 143 | number = match.group(2) 144 | pure_decimal = match.group(5) 145 | if pure_decimal: 146 | result = num2str(pure_decimal) 147 | else: 148 | sign: str = "负" if sign else "" 149 | number: str = num2str(number) 150 | result = f"{sign}{number}" 151 | return result 152 | 153 | 154 | # 范围表达式 155 | # match.group(1) and match.group(8) are copy from RE_NUMBER 156 | 157 | RE_RANGE = re.compile( 158 | r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))[-~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))') 159 | 160 | 161 | def replace_range(match) -> str: 162 | """ 163 | Args: 164 | match (re.Match) 165 | Returns: 166 | str 167 | """ 168 | first, second = match.group(1), match.group(8) 169 | first = RE_NUMBER.sub(replace_number, first) 170 | second = RE_NUMBER.sub(replace_number, second) 171 | result = f"{first}到{second}" 172 | return result 173 | 174 | 175 | def _get_value(value_string: str, use_zero: bool=True) -> List[str]: 176 | stripped = value_string.lstrip('0') 177 | if len(stripped) == 0: 178 | return [] 179 | elif len(stripped) == 1: 180 | if use_zero and len(stripped) < len(value_string): 181 | return [DIGITS['0'], DIGITS[stripped]] 182 | else: 183 | return [DIGITS[stripped]] 184 | else: 185 | largest_unit = next( 186 | power for power in reversed(UNITS.keys()) if power < len(stripped)) 187 | first_part = value_string[:-largest_unit] 188 | second_part = value_string[-largest_unit:] 189 | return _get_value(first_part) + [UNITS[largest_unit]] + _get_value( 190 | second_part) 191 | 192 | 193 | def verbalize_cardinal(value_string: str) -> str: 194 | if not value_string: 195 | return '' 196 | 197 | # 000 -> '零' , 0 -> '零' 198 | value_string = value_string.lstrip('0') 199 | if len(value_string) == 0: 200 | return DIGITS['0'] 201 | 202 | result_symbols = _get_value(value_string) 203 | # verbalized number starting with '一十*' is abbreviated as `十*` 204 | if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[ 205 | '1'] and result_symbols[1] == UNITS[1]: 206 | result_symbols = result_symbols[1:] 207 | return ''.join(result_symbols) 208 | 209 | 210 | def verbalize_digit(value_string: str, alt_one=False) -> str: 211 | result_symbols = [DIGITS[digit] for digit in value_string] 212 | result = ''.join(result_symbols) 213 | if alt_one: 214 | result = result.replace("一", "幺") 215 | return result 216 | 217 | 218 | def num2str(value_string: str) -> str: 219 | integer_decimal = value_string.split('.') 220 | if len(integer_decimal) == 1: 221 | integer = integer_decimal[0] 222 | decimal = '' 223 | elif len(integer_decimal) == 2: 224 | integer, decimal = integer_decimal 225 | else: 226 | raise ValueError( 227 | f"The value string: '${value_string}' has more than one point in it." 
228 |         )
229 | 
230 |     result = verbalize_cardinal(integer)
231 | 
232 |     decimal = decimal.rstrip('0')
233 |     if decimal:
234 |         # '.22' is verbalized as '零点二二'
235 |         # '3.20' is verbalized as '三点二'
236 |         result = result if result else "零"
237 |         result += '点' + verbalize_digit(decimal)
238 |     return result
239 | 
--------------------------------------------------------------------------------
/uilib/zh_normalization/phonecode.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import re
15 | 
16 | from .num import verbalize_digit
17 | 
18 | # Normalize landline / mobile phone numbers
19 | # Mobile prefixes:
20 | # http://www.jihaoba.com/news/show/13680
21 | # China Mobile: 139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
22 | # China Unicom: 130、131、132、156、155、186、185、176
23 | # China Telecom: 133、153、189、180、181、177
24 | RE_MOBILE_PHONE = re.compile(
25 |     r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
26 | RE_TELEPHONE = re.compile(
27 |     r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
28 | 
29 | # Nationwide uniform service numbers start with 400
30 | RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"400(-)?\d{3}(-)?\d{4}")
31 | 
32 | 
33 | def phone2str(phone_string: str, mobile=True) -> str:
34 |     if mobile:
35 |         sp_parts = phone_string.strip('+').split()
36 |         result = ','.join(
37 |             [verbalize_digit(part, alt_one=True) for part in sp_parts])
38 |         return result
39 |     else:
40 |         sil_parts = phone_string.split('-')
41 |         result = ','.join(
42 |             [verbalize_digit(part, alt_one=True) for part in sil_parts])
43 |         return result
44 | 
45 | 
46 | def replace_phone(match) -> str:
47 |     """
48 |     Args:
49 |         match (re.Match)
50 |     Returns:
51 |         str
52 |     """
53 |     return phone2str(match.group(0), mobile=False)
54 | 
55 | 
56 | def replace_mobile(match) -> str:
57 |     """
58 |     Args:
59 |         match (re.Match)
60 |     Returns:
61 |         str
62 |     """
63 |     return phone2str(match.group(0))
64 | 
--------------------------------------------------------------------------------
/uilib/zh_normalization/quantifier.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
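An aside before quantifier.py continues: num.py produces cardinal readings with units, while phonecode.py deliberately reads digit by digit, verbalizing 1 as 幺 as is conventional for phone numbers. A minimal sketch under the same import assumption; the sample numbers are made up:

```python
# Sketch: cardinal reading (num2str) versus digit-by-digit phone reading.
from uilib.zh_normalization.num import num2str
from uilib.zh_normalization.phonecode import RE_MOBILE_PHONE, replace_mobile

print(num2str("324.75"))  # 三百二十四点七五
print(num2str("10010"))   # 一万零一十

s = "这是手机+86 18544139121"
print(RE_MOBILE_PHONE.sub(replace_mobile, s))
# 这是手机八六,幺八五四四幺三九幺二幺
```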
14 | import re
15 | 
16 | from .num import num2str
17 | 
18 | # Temperature expressions; the sign affects how the negative is read
19 | # -3°C -> 零下三度
20 | RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)')
21 | measure_dict = {
22 |     "cm2": "平方厘米",
23 |     "cm²": "平方厘米",
24 |     "cm3": "立方厘米",
25 |     "cm³": "立方厘米",
26 |     "cm": "厘米",
27 |     "db": "分贝",
28 |     "ds": "毫秒",
29 |     "kg": "千克",
30 |     "km": "千米",
31 |     "m2": "平方米",
32 |     "m²": "平方米",
33 |     "m³": "立方米",
34 |     "m3": "立方米",
35 |     "ml": "毫升",
36 |     "m": "米",
37 |     "mm": "毫米",
38 |     "s": "秒"
39 | }
40 | 
41 | 
42 | def replace_temperature(match) -> str:
43 |     """
44 |     Args:
45 |         match (re.Match)
46 |     Returns:
47 |         str
48 |     """
49 |     sign = match.group(1)
50 |     temperature = match.group(2)
51 |     unit = match.group(3)
52 |     sign: str = "零下" if sign else ""
53 |     temperature: str = num2str(temperature)
54 |     unit: str = "摄氏度" if unit == "摄氏度" else "度"
55 |     result = f"{sign}{temperature}{unit}"
56 |     return result
57 | 
58 | 
59 | def replace_measure(sentence) -> str:
60 |     for q_notation in measure_dict:
61 |         if q_notation in sentence and re.search(rf'\d{q_notation}', sentence):
62 |             sentence = sentence.replace(q_notation, measure_dict[q_notation])
63 |     return sentence
64 | 
--------------------------------------------------------------------------------
/uilib/zh_normalization/text_normlization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
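One more aside before text_normlization.py continues: temperatures read a leading minus as 零下, and `replace_measure` only expands a unit when a digit precedes it somewhere in the sentence. A minimal sketch under the same import assumption; the sample text is made up:

```python
# Sketch: temperature sign handling plus unit expansion from measure_dict.
from uilib.zh_normalization.quantifier import (RE_TEMPERATURE,
                                               replace_measure,
                                               replace_temperature)

s = "今天-10°C,跑了5km"
s = RE_TEMPERATURE.sub(replace_temperature, s)  # -10°C -> 零下十度
s = replace_measure(s)                          # 5km -> 5千米
print(s)  # 今天零下十度,跑了5千米
```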
14 | import re 15 | from typing import List 16 | 17 | from .char_convert import tranditional_to_simplified 18 | from .chronology import RE_DATE 19 | from .chronology import RE_DATE2 20 | from .chronology import RE_TIME 21 | from .chronology import RE_TIME_RANGE 22 | from .chronology import replace_date 23 | from .chronology import replace_date2 24 | from .chronology import replace_time 25 | from .constants import F2H_ASCII_LETTERS 26 | from .constants import F2H_DIGITS 27 | from .constants import F2H_SPACE 28 | from .num import RE_DECIMAL_NUM 29 | from .num import RE_DEFAULT_NUM 30 | from .num import RE_FRAC 31 | from .num import RE_INTEGER 32 | from .num import RE_NUMBER 33 | from .num import RE_PERCENTAGE 34 | from .num import RE_POSITIVE_QUANTIFIERS 35 | from .num import RE_RANGE 36 | from .num import replace_default_num 37 | from .num import replace_frac 38 | from .num import replace_negative_num 39 | from .num import replace_number 40 | from .num import replace_percentage 41 | from .num import replace_positive_quantifier 42 | from .num import replace_range 43 | from .phonecode import RE_MOBILE_PHONE 44 | from .phonecode import RE_NATIONAL_UNIFORM_NUMBER 45 | from .phonecode import RE_TELEPHONE 46 | from .phonecode import replace_mobile 47 | from .phonecode import replace_phone 48 | from .quantifier import RE_TEMPERATURE 49 | from .quantifier import replace_measure 50 | from .quantifier import replace_temperature 51 | 52 | 53 | class TextNormalizer(): 54 | def __init__(self): 55 | self.SENTENCE_SPLITOR = re.compile(r'([:、,;。?!,;?!][”’]?)') 56 | 57 | def _split(self, text: str, lang="zh") -> List[str]: 58 | """Split long text into sentences with sentence-splitting punctuations. 59 | Args: 60 | text (str): The input text. 61 | Returns: 62 | List[str]: Sentences. 
63 | 64 | character_map = { 65 | ":": ",", 66 | ";": ",", 67 | "!": "。", 68 | "(": ",", 69 | ")": ",", 70 | "【": ",", 71 | "】": ",", 72 | "『": ",", 73 | "』": ",", 74 | "「": ",", 75 | "」": ",", 76 | "《": ",", 77 | "》": ",", 78 | "-": ",", 79 | "‘": " ", 80 | "“": " ", 81 | "’": " ", 82 | "”": " ", 83 | '"': " ", 84 | "'": " ", 85 | ":": ",", 86 | ";": ",", 87 | "!": ".", 88 | "(": ",", 89 | ")": ",", 90 | "[": ",", 91 | "]": ",", 92 | ">": ",", 93 | "<": ",", 94 | "-": ",", 95 | } 96 | """ 97 | # Only for pure Chinese here 98 | if lang == "zh": 99 | #text = text.replace(" ", "") 100 | # 过滤掉特殊字符 101 | text = re.sub(r'[——《》【】<>{}()()#&@“”^|…\\]', '', text) 102 | text = self.SENTENCE_SPLITOR.sub(r'\1\n', text) 103 | text = text.strip() 104 | sentences = [sentence.strip() for sentence in re.split(r'\n+', text)] 105 | return sentences 106 | 107 | def _post_replace(self, sentence: str) -> str: 108 | 109 | 110 | #sentence = sentence.replace('/', '每') 111 | sentence = sentence.replace('~', '至') 112 | sentence = sentence.replace('~', '至') 113 | sentence = sentence.replace('①', '一') 114 | sentence = sentence.replace('②', '二') 115 | sentence = sentence.replace('③', '三') 116 | sentence = sentence.replace('④', '四') 117 | sentence = sentence.replace('⑤', '五') 118 | sentence = sentence.replace('⑥', '六') 119 | sentence = sentence.replace('⑦', '七') 120 | sentence = sentence.replace('⑧', '八') 121 | sentence = sentence.replace('⑨', '九') 122 | sentence = sentence.replace('⑩', '十') 123 | sentence = sentence.replace('α', '阿尔法') 124 | sentence = sentence.replace('β', '贝塔') 125 | sentence = sentence.replace('γ', '伽玛').replace('Γ', '伽玛') 126 | sentence = sentence.replace('δ', '德尔塔').replace('Δ', '德尔塔') 127 | sentence = sentence.replace('ε', '艾普西龙') 128 | sentence = sentence.replace('ζ', '捷塔') 129 | sentence = sentence.replace('η', '依塔') 130 | sentence = sentence.replace('θ', '西塔').replace('Θ', '西塔') 131 | sentence = sentence.replace('ι', '艾欧塔') 132 | sentence = sentence.replace('κ', '喀帕') 133 | sentence = sentence.replace('λ', '拉姆达').replace('Λ', '拉姆达') 134 | sentence = sentence.replace('μ', '缪') 135 | sentence = sentence.replace('ν', '拗') 136 | sentence = sentence.replace('ξ', '克西').replace('Ξ', '克西') 137 | sentence = sentence.replace('ο', '欧米克伦') 138 | sentence = sentence.replace('π', '派').replace('Π', '派') 139 | sentence = sentence.replace('ρ', '肉') 140 | sentence = sentence.replace('ς', '西格玛').replace('Σ', '西格玛').replace( 141 | 'σ', '西格玛') 142 | sentence = sentence.replace('τ', '套') 143 | sentence = sentence.replace('υ', '宇普西龙') 144 | sentence = sentence.replace('φ', '服艾').replace('Φ', '服艾') 145 | sentence = sentence.replace('χ', '器') 146 | sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛') 147 | sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽') 148 | sentence = sentence.replace('+', '加') 149 | 150 | 151 | # re filter special characters, have one more character "-" than line 68 152 | sentence = re.sub(r'[-——《》【】<=>{}()()#&@“”^|…\\]', '', sentence) 153 | return sentence 154 | 155 | # 数字转为中文读法 156 | def num_to_chinese(self,num): 157 | num_str = str(num) 158 | chinese_digits = "零一二三四五六七八九" 159 | units = ["", "十", "百", "千"] 160 | big_units = ["", "万", "亿", "兆"] 161 | result = "" 162 | zero_flag = False # 标记是否需要加'零' 163 | part = [] # 存储每4位的数字 164 | 165 | # 将数字按每4位分组 166 | while num_str: 167 | part.append(num_str[-4:]) 168 | num_str = num_str[:-4] 169 | 170 | for i in range(len(part)): 171 | part_str = "" 172 | part_zero_flag = False 173 | for j in range(len(part[i])): 174 | digit = int(part[i][j]) 
175 |                 if digit == 0:
176 |                     part_zero_flag = True
177 |                 else:
178 |                     if part_zero_flag or (zero_flag and i > 0 and not result.startswith(chinese_digits[0])):
179 |                         part_str += chinese_digits[0]
180 |                         zero_flag = False
181 |                         part_zero_flag = False
182 |                     part_str += chinese_digits[digit] + units[len(part[i]) - j - 1]
183 |             if part_str.endswith("零"):
184 |                 part_str = part_str[:-1]  # Drop the trailing '零'
185 |             if part_str:
186 |                 zero_flag = True
187 | 
188 |             if i > 0 and not set(part[i]) <= {'0'}:  # Attach the big unit when this 4-digit group is not all zeros
189 |                 result = part_str + big_units[i] + result
190 |             else:
191 |                 result = part_str + result
192 | 
193 |         # Handle an input of 0, and strip any leading '零'
194 |         result = result.lstrip(chinese_digits[0])
195 |         if not result:
196 |             return chinese_digits[0]
197 | 
198 |         return result
199 | 
200 |     def normalize_sentence(self, sentence: str) -> str:
201 | 
202 |         # basic character conversions
203 |         # add
204 |         sentence = re.sub(r'(\d+)\s*[\*xX]\s*(\d+)', r'\1 乘 \2', sentence, flags=re.I)
205 |         # Area code - number - extension for phone numbers
206 |         sentence = re.sub(r'(0\d+)\-(\d{3,})\-(\d{3,})', r'\1杠\2杠\3', sentence, flags=re.I)
207 |         sentence = re.sub(r'(0\d+)\-(\d{3,})', r'\1杠\2', sentence, flags=re.I)
208 |         sentence = sentence.replace('=', '等于')
209 |         sentence = sentence.replace('÷','除以')
210 | 
211 |         #sentence = re.sub(r'(\d+)\s*\-', r'\1 减', sentence)
212 |         sentence = re.sub(r'((?:\d+\.)?\d+)\s*/\s*(\d+)', r'\2分之\1', sentence)
213 | 
214 |         # Extract the numbers, e.g. number_list= [('1000200030004000.123', '1000200030004000', '123'), ('23425', '23425', '')]
215 |         number_list=re.findall(r'((\d+)(?:\.(\d+))?%?)', sentence)
216 |         numtext=['零','一','二','三','四','五','六','七','八','九']
217 |         if len(number_list)>0:
218 |             #dc= ('1000200030004000.123', '1000200030004000', '123','')
219 |             for m,dc in enumerate(number_list):
220 |                 n_len=len(dc[1])
221 |                 # Skip mobile/landline numbers, oversized numbers, numbers under 9 digits, and numbers with a leading 0
222 |                 if n_len>16 or n_len<9 or (n_len==11 and str(dc[1])[0]=='1') or str(dc[1])[0]=='0':
223 |                     continue
224 |                 int_text=self.num_to_chinese(dc[1])
225 |                 if len(dc)>2 and dc[2]:
226 |                     int_text+="点"+"".join([numtext[int(i)] for i in dc[2]])
227 |                 if dc[0][-1]=='%':
228 |                     int_text=f'百分之{int_text}'
229 |                 sentence=sentence.replace(dc[0],int_text)
230 | 
231 |         sentence = tranditional_to_simplified(sentence)
232 |         sentence = sentence.translate(F2H_ASCII_LETTERS).translate(
233 |             F2H_DIGITS).translate(F2H_SPACE)
234 | 
235 |         # number related NSW verbalization
236 |         sentence = RE_DATE.sub(replace_date, sentence)
237 |         sentence = RE_DATE2.sub(replace_date2, sentence)
238 | 
239 |         # range first
240 |         sentence = RE_TIME_RANGE.sub(replace_time, sentence)
241 |         sentence = RE_TIME.sub(replace_time, sentence)
242 | 
243 |         sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
244 |         sentence = replace_measure(sentence)
245 |         sentence = RE_FRAC.sub(replace_frac, sentence)
246 |         sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
247 |         sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
248 | 
249 |         sentence = RE_TELEPHONE.sub(replace_phone, sentence)
250 |         sentence = RE_NATIONAL_UNIFORM_NUMBER.sub(replace_phone, sentence)
251 | 
252 |         sentence = RE_RANGE.sub(replace_range, sentence)
253 |         sentence = RE_INTEGER.sub(replace_negative_num, sentence)
254 |         sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
255 |         sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
256 |                                                sentence)
257 |         sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)
258 |         sentence = RE_NUMBER.sub(replace_number, sentence)
259 |         sentence = self._post_replace(sentence)
260 | 
261 |         sentence = sentence.replace('[一break]','[1break]')  # undo digit conversion inside a break tag
262 | 
263 |         return sentence
264 | 
265 |     def normalize(self, text: str) -> List[str]:
266 |         sentences = self._split(text)
267 |         sentences = [self.normalize_sentence(sent) for sent in sentences]
268 |         return sentences
269 | 
--------------------------------------------------------------------------------
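To close the section, a minimal end-to-end sketch of the TextNormalizer pipeline defined above: `_split` first breaks the input on sentence punctuation, then each piece runs through `normalize_sentence`. It assumes the repo root is on `sys.path`; the sample sentence is made up and the expected output is approximate:

```python
# Sketch: the full zh_normalization pipeline on one mixed sentence.
from uilib.zh_normalization import TextNormalizer

tx = TextNormalizer()
for sent in tx.normalize("她出生于86年8月18日,今天气温-3°C,进度62%。"):
    print(sent)
# Roughly:
# 她出生于八六年八月十八日,
# 今天气温零下三度,
# 进度百分之六十二。
```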