├── README.md ├── __init__.py ├── client.py ├── data └── 动画.gif ├── docker-compose.yml ├── fastapi ├── ChatTTS │ ├── __init__.py │ ├── __main__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── core.cpython-38.pyc │ ├── cli.py │ ├── core.py │ ├── experimental │ │ ├── __init__.py │ │ └── llm.py │ ├── infer │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── api.cpython-38.pyc │ │ └── api.py │ ├── model │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── dvae.cpython-38.pyc │ │ │ └── gpt.cpython-38.pyc │ │ ├── dvae.py │ │ └── gpt.py │ └── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── gpu_utils.cpython-38.pyc │ │ ├── infer_utils.cpython-38.pyc │ │ └── io_utils.cpython-38.pyc │ │ ├── gpu_utils.py │ │ ├── infer_utils.py │ │ └── io_utils.py ├── Dockerfile ├── __init__.py ├── __pycache__ │ └── server.cpython-38.pyc ├── output.wav ├── requirements.txt └── server.py └── streamlit ├── Dockerfile ├── requirements.txt └── ui.py /README.md: -------------------------------------------------------------------------------- 1 | ## **1. Project Overview** 2 | Deploy the ChatTTS text-to-speech model locally with FastAPI and Streamlit, and containerize the deployment with Docker Compose. 3 | 4 | **Workflow demo:** 5 | 6 | ![Speech synthesis](data/动画.gif) 7 | 8 | ## **2. Local Installation and Usage** 9 | 10 | **Environment requirements:** 11 | 12 | ```bash 13 | CUDA 12.1 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | **How to run:** 18 | 19 | - Start FastAPI: provides the API 20 | 21 | ```bash 22 | cd fastapi 23 | uvicorn server:app --host "0.0.0.0" --port 8000 24 | ``` 25 | 26 | - Start Streamlit: provides the web UI 27 | 28 | ```bash 29 | cd streamlit 30 | streamlit run ui.py 31 | ``` 32 | 33 | - Open the web UI: http://localhost:8501 34 | - Local usage example 35 | 36 | ```bash 37 | curl -X POST -H 'content-type: application/json' -d\ 38 | '{"text":"朋友你好啊,今天天气怎么样 ?", "output_path": "abc.wav", "seed":232}' \ 39 | http://localhost:8000/tts 40 | ``` 41 | 42 | - Parameters: 43 | 44 | text: the text to synthesize 45 | 46 | output_path: path where the synthesized audio is saved 47 | 48 | seed: voice seed; different seeds produce different voices. The default is 697 (a voice that sounded good in testing). 49 | 50 | - Run the client 51 | 52 | ```bash 53 | python client.py 54 | ``` 55 | 56 | ## **3. Docker Deployment** 57 | 58 | docker compose build 59 | docker compose up 60 | 61 | These commands will: 62 | 63 | Build the Docker images for the FastAPI and Streamlit services. 64 | 65 | Start both services, exposing FastAPI on port 8000 and Streamlit on port 8501. 66 | 67 | To view the FastAPI docs for the running service, open http://localhost:8000/docs in a web browser. 68 | 69 | To access the UI, visit http://localhost:8501 70 | 71 | Check the logs with: 72 | 73 | docker compose logs 74 | 75 | ## **4. References** 76 | - https://github.com/ultrasev/ChatTTS 77 | - https://github.com/2noise/ChatTTS 78 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/__init__.py -------------------------------------------------------------------------------- /client.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | 4 | def synthesize_speech(text, output_path, seed, url="http://localhost:8000/tts"): 5 | # Define the payload for the POST request 6 | payload = { 7 | "text": text, 8 | "output_path": output_path, 9 | "seed": seed 10 | } 11 | 12 | # Define the headers for the POST request 13 | headers = { 14 | "content-type": "application/json" 15 | } 16 | 17 | try: 18 | # Make the POST request to the FastAPI TTS endpoint 19 | response = requests.post(url,
headers=headers, data=json.dumps(payload)) 20 | 21 | # Check if the request was successful 22 | if response.status_code == 200: 23 | print(f"Speech synthesis succeeded. Output saved to {output_path}.") 24 | else: 25 | print(f"Failed to synthesize speech. Status code: {response.status_code}") 26 | print(f"Response: {response.text}") 27 | 28 | except Exception as e: 29 | print(f"An error occurred: {e}") 30 | 31 | # Example usage 32 | synthesize_speech("朋友你好啊,今天天气怎么样 ?", "output.wav", 232) 33 | 34 | -------------------------------------------------------------------------------- /data/动画.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/data/动画.gif -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | fastapi: 5 | build: fastapi/ 6 | ports: 7 | - 8001:8001 8 | networks: 9 | - deploy_network 10 | container_name: fastapi 11 | volumes: 12 | - ./fastapi:/fastapi # 将 fastapi 的目录映射到 本地 13 | 14 | streamlit: 15 | build: streamlit/ 16 | depends_on: 17 | - fastapi 18 | ports: 19 | - 8501:8501 20 | networks: 21 | - deploy_network 22 | container_name: streamlit 23 | volumes: 24 | - ./streamlit:/streamlit # 将 streamlit 的目录映射到 本地 25 | - ./fastapi:/fastapi # 将 fastapi 的目录映射到 streamlit 中 26 | 27 | networks: 28 | deploy_network: 29 | driver: bridge 30 | -------------------------------------------------------------------------------- /fastapi/ChatTTS/__init__.py: -------------------------------------------------------------------------------- 1 | from ChatTTS.core import Chat 2 | -------------------------------------------------------------------------------- /fastapi/ChatTTS/__main__.py: -------------------------------------------------------------------------------- 1 | from ChatTTS.cli import main 2 | 3 | if __name__ == "__main__": 4 | main() 5 | -------------------------------------------------------------------------------- /fastapi/ChatTTS/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/ChatTTS/__pycache__/core.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/__pycache__/core.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/ChatTTS/cli.py: -------------------------------------------------------------------------------- 1 | from .core import Chat 2 | import argparse 3 | import numpy as np 4 | import wave 5 | 6 | 7 | def main(): 8 | # cli args 9 | ap = argparse.ArgumentParser(description="Your text to tts") 10 | ap.add_argument("text", type=str, help="Your text") 11 | ap.add_argument( 12 | "-o", "--out-file", help="out file name", default="tts.wav", dest="out_file" 13 | ) 14 | ap.add_argument( 15 | "-s", "--seed", help="out file name", type=int, default=None, dest="seed" 16 | ) 17 | 18 
| args = ap.parse_args() 19 | out_file = args.out_file 20 | text = args.text 21 | if not text: 22 | raise ValueError("text is empty") 23 | 24 | chat = Chat() 25 | try: 26 | chat.load_models() 27 | except Exception as e: 28 | # this is a tricky for most newbies do not now the args for cli 29 | print("The model maybe broke will load again") 30 | chat.load_models(force_redownload=True) 31 | texts = [ 32 | text, 33 | ] 34 | if args.seed: 35 | r = chat.sample_random_speaker(seed=args.seed) 36 | params_infer_code = { 37 | "spk_emb": r, # add sampled speaker 38 | "temperature": 0.3, # using custom temperature 39 | "top_P": 0.7, # top P decode 40 | "top_K": 20, # top K decode 41 | } 42 | wavs = chat.infer(texts, use_decoder=True, params_infer_code=params_infer_code) 43 | else: 44 | wavs = chat.infer(texts, use_decoder=True) 45 | 46 | audio_data = np.array(wavs[0], dtype=np.float32) 47 | sample_rate = 24000 48 | audio_data = (audio_data * 32767).astype(np.int16) 49 | 50 | with wave.open(out_file, "w") as wf: 51 | wf.setnchannels(1) # Mono channel 52 | wf.setsampwidth(2) # 2 bytes per sample 53 | wf.setframerate(sample_rate) 54 | wf.writeframes(audio_data.tobytes()) 55 | if args.seed: 56 | print(f"Generate Done for file {out_file} with seed {args.seed}") 57 | else: 58 | print(f"Generate Done for file {out_file}") 59 | -------------------------------------------------------------------------------- /fastapi/ChatTTS/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import platform 4 | import logging 5 | from omegaconf import OmegaConf 6 | import random 7 | import numpy as np 8 | 9 | import torch 10 | from vocos import Vocos 11 | from ChatTTS.model.dvae import DVAE 12 | from ChatTTS.model.gpt import GPT_warpper 13 | from ChatTTS.utils.gpu_utils import select_device 14 | from ChatTTS.utils.io_utils import get_latest_modified_file 15 | from ChatTTS.infer.api import refine_text, infer_code 16 | 17 | from huggingface_hub import snapshot_download 18 | 19 | logging.basicConfig(level=logging.INFO) 20 | 21 | 22 | class Chat: 23 | def __init__( 24 | self, 25 | ): 26 | self.pretrain_models = {} 27 | self.logger = logging.getLogger(__name__) 28 | 29 | def check_model(self, level=logging.INFO, use_decoder=False): 30 | not_finish = False 31 | check_list = ["vocos", "gpt", "tokenizer"] 32 | 33 | if use_decoder: 34 | check_list.append("decoder") 35 | else: 36 | check_list.append("dvae") 37 | 38 | for module in check_list: 39 | if module not in self.pretrain_models: 40 | self.logger.log(logging.WARNING, f"{module} not initialized.") 41 | not_finish = True 42 | 43 | if not not_finish: 44 | self.logger.log(level, f"All initialized.") 45 | 46 | return not not_finish 47 | 48 | def load_models( 49 | self, source="huggingface", force_redownload=False, local_path="" 50 | ): 51 | if source == "huggingface": 52 | hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface")) 53 | try: 54 | download_path = get_latest_modified_file( 55 | os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots") 56 | ) 57 | except: 58 | download_path = None 59 | if download_path is None or force_redownload: 60 | self.logger.log( 61 | logging.INFO, 62 | f"Download from HF: https://huggingface.co/2Noise/ChatTTS", 63 | ) 64 | download_path = snapshot_download( 65 | repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"] 66 | ) 67 | else: 68 | self.logger.log(logging.INFO, f"Load from cache: {download_path}") 69 | self._load( 70 | **{ 71 | k: 
os.path.join(download_path, v) 72 | for k, v in OmegaConf.load( 73 | os.path.join(download_path, "config", "path.yaml") 74 | ).items() 75 | } 76 | ) 77 | elif source == "local": 78 | self.logger.log(logging.INFO, f"Load from local: {local_path}") 79 | self._load( 80 | **{ 81 | k: os.path.join(local_path, v) 82 | for k, v in OmegaConf.load( 83 | os.path.join(local_path, "config", "path.yaml") 84 | ).items() 85 | } 86 | ) 87 | 88 | def _load( 89 | self, 90 | vocos_config_path: str = None, 91 | vocos_ckpt_path: str = None, 92 | dvae_config_path: str = None, 93 | dvae_ckpt_path: str = None, 94 | gpt_config_path: str = None, 95 | gpt_ckpt_path: str = None, 96 | decoder_config_path: str = None, 97 | decoder_ckpt_path: str = None, 98 | tokenizer_path: str = None, 99 | device: str = None, 100 | ): 101 | if not device: 102 | device = select_device(4096) 103 | self.logger.log(logging.INFO, f"use {device}") 104 | 105 | if vocos_config_path: 106 | vocos = Vocos.from_hparams(vocos_config_path).to(device).eval() 107 | assert vocos_ckpt_path, "vocos_ckpt_path should not be None" 108 | vocos.load_state_dict(torch.load(vocos_ckpt_path, map_location=device)) 109 | self.pretrain_models["vocos"] = vocos 110 | self.logger.log(logging.INFO, "vocos loaded.") 111 | 112 | if dvae_config_path: 113 | cfg = OmegaConf.load(dvae_config_path) 114 | dvae = DVAE(**cfg).to(device).eval() 115 | assert dvae_ckpt_path, "dvae_ckpt_path should not be None" 116 | dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=device)) 117 | self.pretrain_models["dvae"] = dvae 118 | self.logger.log(logging.INFO, "dvae loaded.") 119 | 120 | if gpt_config_path: 121 | cfg = OmegaConf.load(gpt_config_path) 122 | gpt = GPT_warpper(**cfg).to(device).eval() 123 | assert gpt_ckpt_path, "gpt_ckpt_path should not be None" 124 | gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device)) 125 | self.pretrain_models["gpt"] = gpt 126 | spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") 127 | assert os.path.exists( 128 | spk_stat_path 129 | ), f"Missing spk_stat.pt: {spk_stat_path}" 130 | self.pretrain_models["spk_stat"] = torch.load( 131 | spk_stat_path, map_location=device 132 | ).to(device) 133 | self.logger.log(logging.INFO, "gpt loaded.") 134 | 135 | if decoder_config_path: 136 | cfg = OmegaConf.load(decoder_config_path) 137 | decoder = DVAE(**cfg).to(device).eval() 138 | assert decoder_ckpt_path, "decoder_ckpt_path should not be None" 139 | decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location=device)) 140 | self.pretrain_models["decoder"] = decoder 141 | self.logger.log(logging.INFO, "decoder loaded.") 142 | 143 | if tokenizer_path: 144 | tokenizer = torch.load(tokenizer_path, map_location=device) 145 | tokenizer.padding_side = "left" 146 | self.pretrain_models["tokenizer"] = tokenizer 147 | self.logger.log(logging.INFO, "tokenizer loaded.") 148 | 149 | self.check_model() 150 | 151 | def infer( 152 | self, 153 | text, 154 | skip_refine_text=False, 155 | refine_text_only=False, 156 | params_refine_text={}, 157 | params_infer_code={}, 158 | use_decoder=False, 159 | ): 160 | 161 | assert self.check_model(use_decoder=use_decoder) 162 | 163 | if not skip_refine_text: 164 | text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)[ 165 | "ids" 166 | ] 167 | text_tokens = [ 168 | i[ 169 | i 170 | < self.pretrain_models["tokenizer"].convert_tokens_to_ids( 171 | "[break_0]" 172 | ) 173 | ] 174 | for i in text_tokens 175 | ] 176 | text = 
self.pretrain_models["tokenizer"].batch_decode(text_tokens) 177 | if refine_text_only: 178 | return text 179 | 180 | text = [params_infer_code.get("prompt", "") + i for i in text] 181 | params_infer_code.pop("prompt", "") 182 | result = infer_code( 183 | self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder 184 | ) 185 | 186 | if use_decoder: 187 | mel_spec = [ 188 | self.pretrain_models["decoder"](i[None].permute(0, 2, 1)) 189 | for i in result["hiddens"] 190 | ] 191 | else: 192 | mel_spec = [ 193 | self.pretrain_models["dvae"](i[None].permute(0, 2, 1)) 194 | for i in result["ids"] 195 | ] 196 | wav = [self.pretrain_models["vocos"].decode(i).cpu().numpy() for i in mel_spec] 197 | 198 | return wav 199 | 200 | def sample_random_speaker(self, seed): 201 | torch.manual_seed(seed) 202 | dim = self.pretrain_models["gpt"].gpt.layers[0].mlp.gate_proj.in_features 203 | std, mean = self.pretrain_models["spk_stat"].chunk(2) 204 | return torch.randn(dim, device=std.device) * std + mean 205 | -------------------------------------------------------------------------------- /fastapi/ChatTTS/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/experimental/__init__.py -------------------------------------------------------------------------------- /fastapi/ChatTTS/experimental/llm.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | prompt_dict = { 4 | "kimi": [ 5 | { 6 | "role": "system", 7 | "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。", 8 | }, 9 | { 10 | "role": "user", 11 | "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。", 12 | }, 13 | { 14 | "role": "assistant", 15 | "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。", 16 | }, 17 | ], 18 | "deepseek": [ 19 | {"role": "system", "content": "You are a helpful assistant"}, 20 | { 21 | "role": "user", 22 | "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。", 23 | }, 24 | { 25 | "role": "assistant", 26 | "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。", 27 | }, 28 | ], 29 | "deepseek_TN": [ 30 | {"role": "system", "content": "You are a helpful assistant"}, 31 | { 32 | "role": "user", 33 | "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号", 34 | }, 35 | { 36 | "role": "assistant", 37 | "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入", 38 | }, 39 | {"role": "user", "content": "We paid $123 for this desk."}, 40 | { 41 | "role": "assistant", 42 | "content": "We paid one hundred and twenty three dollars for this desk.", 43 | }, 44 | {"role": "user", "content": "详询请拨打010-724654"}, 45 | {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"}, 46 | {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"}, 47 | { 48 | "role": "assistant", 49 | "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。", 50 | }, 51 | ], 52 | } 53 | 54 | 55 | class llm_api: 56 | def __init__(self, api_key, base_url, model): 57 | self.client = OpenAI( 58 | api_key=api_key, 59 | base_url=base_url, 60 | ) 61 | self.model = model 62 | 63 | def call(self, user_question, temperature=0.3, prompt_version="kimi", **kwargs): 64 | 65 | completion = self.client.chat.completions.create( 
66 | model=self.model, 67 | messages=prompt_dict[prompt_version] 68 | + [ 69 | {"role": "user", "content": user_question}, 70 | ], 71 | temperature=temperature, 72 | **kwargs 73 | ) 74 | return completion.choices[0].message.content 75 | -------------------------------------------------------------------------------- /fastapi/ChatTTS/infer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/infer/__init__.py -------------------------------------------------------------------------------- /fastapi/ChatTTS/infer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/infer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/ChatTTS/infer/__pycache__/api.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/infer/__pycache__/api.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/ChatTTS/infer/api.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from transformers.generation import TopKLogitsWarper, TopPLogitsWarper 4 | from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat 5 | 6 | 7 | def infer_code( 8 | models, 9 | text, 10 | spk_emb=None, 11 | top_P=0.7, 12 | top_K=20, 13 | temperature=0.3, 14 | repetition_penalty=1.05, 15 | max_new_token=2048, 16 | **kwargs, 17 | ): 18 | 19 | device = next(models["gpt"].parameters()).device 20 | 21 | if not isinstance(text, list): 22 | text = [text] 23 | 24 | if not isinstance(temperature, list): 25 | temperature = [temperature] * models["gpt"].num_vq 26 | 27 | if spk_emb is not None: 28 | text = [f"[Stts][spk_emb]{i}[uv_break][Ptts]" for i in text] 29 | else: 30 | text = [f"[Stts][empty_spk]{i}[uv_break][Ptts]" for i in text] 31 | 32 | text_token = models["tokenizer"]( 33 | text, return_tensors="pt", add_special_tokens=False, padding=True 34 | ).to(device) 35 | input_ids = text_token["input_ids"][..., None].expand(-1, -1, models["gpt"].num_vq) 36 | text_mask = torch.ones(text_token["input_ids"].shape, dtype=bool, device=device) 37 | 38 | inputs = { 39 | "input_ids": input_ids, 40 | "text_mask": text_mask, 41 | "attention_mask": text_token["attention_mask"], 42 | } 43 | 44 | emb = models["gpt"].get_emb(**inputs) 45 | if spk_emb is not None: 46 | emb[ 47 | inputs["input_ids"][..., 0] 48 | == models["tokenizer"].convert_tokens_to_ids("[spk_emb]") 49 | ] = F.normalize( 50 | spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), 51 | p=2.0, 52 | dim=1, 53 | eps=1e-12, 54 | ) 55 | 56 | num_code = models["gpt"].emb_code[0].num_embeddings - 1 57 | 58 | LogitsWarpers = [] 59 | if top_P is not None: 60 | LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) 61 | if top_K is not None: 62 | LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) 63 | 64 | LogitsProcessors = [] 65 | if repetition_penalty is not 
None and repetition_penalty != 1: 66 | LogitsProcessors.append( 67 | CustomRepetitionPenaltyLogitsProcessorRepeat( 68 | repetition_penalty, num_code, 16 69 | ) 70 | ) 71 | 72 | result = models["gpt"].generate( 73 | emb, 74 | inputs["input_ids"], 75 | temperature=torch.tensor(temperature, device=device), 76 | attention_mask=inputs["attention_mask"], 77 | LogitsWarpers=LogitsWarpers, 78 | LogitsProcessors=LogitsProcessors, 79 | eos_token=num_code, 80 | max_new_token=max_new_token, 81 | infer_text=False, 82 | **kwargs, 83 | ) 84 | 85 | return result 86 | 87 | 88 | def refine_text( 89 | models, 90 | text, 91 | top_P=0.7, 92 | top_K=20, 93 | temperature=0.7, 94 | repetition_penalty=1.0, 95 | max_new_token=384, 96 | prompt="", 97 | **kwargs, 98 | ): 99 | 100 | device = next(models["gpt"].parameters()).device 101 | 102 | if not isinstance(text, list): 103 | text = [text] 104 | 105 | assert len(text), "text should not be empty" 106 | 107 | text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text] 108 | text_token = models["tokenizer"]( 109 | text, return_tensors="pt", add_special_tokens=False, padding=True 110 | ).to(device) 111 | text_mask = torch.ones(text_token["input_ids"].shape, dtype=bool, device=device) 112 | 113 | inputs = { 114 | "input_ids": text_token["input_ids"][..., None].expand( 115 | -1, -1, models["gpt"].num_vq 116 | ), 117 | "text_mask": text_mask, 118 | "attention_mask": text_token["attention_mask"], 119 | } 120 | 121 | LogitsWarpers = [] 122 | if top_P is not None: 123 | LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) 124 | if top_K is not None: 125 | LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) 126 | 127 | LogitsProcessors = [] 128 | if repetition_penalty is not None and repetition_penalty != 1: 129 | LogitsProcessors.append( 130 | CustomRepetitionPenaltyLogitsProcessorRepeat( 131 | repetition_penalty, len(models["tokenizer"]), 16 132 | ) 133 | ) 134 | 135 | result = models["gpt"].generate( 136 | models["gpt"].get_emb(**inputs), 137 | inputs["input_ids"], 138 | temperature=torch.tensor( 139 | [ 140 | temperature, 141 | ], 142 | device=device, 143 | ), 144 | attention_mask=inputs["attention_mask"], 145 | LogitsWarpers=LogitsWarpers, 146 | LogitsProcessors=LogitsProcessors, 147 | eos_token=torch.tensor( 148 | models["tokenizer"].convert_tokens_to_ids("[Ebreak]"), device=device 149 | )[None], 150 | max_new_token=max_new_token, 151 | infer_text=True, 152 | **kwargs, 153 | ) 154 | return result 155 | -------------------------------------------------------------------------------- /fastapi/ChatTTS/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/model/__init__.py -------------------------------------------------------------------------------- /fastapi/ChatTTS/model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/ChatTTS/model/__pycache__/dvae.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/model/__pycache__/dvae.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/ChatTTS/model/__pycache__/gpt.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/model/__pycache__/gpt.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/ChatTTS/model/dvae.py: -------------------------------------------------------------------------------- 1 | import math 2 | from einops import rearrange 3 | from vector_quantize_pytorch import GroupedResidualFSQ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ConvNeXtBlock(nn.Module): 11 | def __init__( 12 | self, 13 | dim: int, 14 | intermediate_dim: int, 15 | kernel, 16 | dilation, 17 | layer_scale_init_value: float = 1e-6, 18 | ): 19 | # ConvNeXt Block copied from Vocos. 20 | super().__init__() 21 | self.dwconv = nn.Conv1d( 22 | dim, 23 | dim, 24 | kernel_size=kernel, 25 | padding=dilation * (kernel // 2), 26 | dilation=dilation, 27 | groups=dim, 28 | ) # depthwise conv 29 | 30 | self.norm = nn.LayerNorm(dim, eps=1e-6) 31 | self.pwconv1 = nn.Linear( 32 | dim, intermediate_dim 33 | ) # pointwise/1x1 convs, implemented with linear layers 34 | self.act = nn.GELU() 35 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 36 | self.gamma = ( 37 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 38 | if layer_scale_init_value > 0 39 | else None 40 | ) 41 | 42 | def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor: 43 | residual = x 44 | x = self.dwconv(x) 45 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 46 | x = self.norm(x) 47 | x = self.pwconv1(x) 48 | x = self.act(x) 49 | x = self.pwconv2(x) 50 | if self.gamma is not None: 51 | x = self.gamma * x 52 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 53 | 54 | x = residual + x 55 | return x 56 | 57 | 58 | class GFSQ(nn.Module): 59 | 60 | def __init__(self, dim, levels, G, R, eps=1e-5, transpose=True): 61 | super(GFSQ, self).__init__() 62 | self.quantizer = GroupedResidualFSQ( 63 | dim=dim, 64 | levels=levels, 65 | num_quantizers=R, 66 | groups=G, 67 | ) 68 | self.n_ind = math.prod(levels) 69 | self.eps = eps 70 | self.transpose = transpose 71 | self.G = G 72 | self.R = R 73 | 74 | def _embed(self, x): 75 | if self.transpose: 76 | x = x.transpose(1, 2) 77 | x = rearrange( 78 | x, 79 | "b t (g r) -> g b t r", 80 | g=self.G, 81 | r=self.R, 82 | ) 83 | feat = self.quantizer.get_output_from_indices(x) 84 | return feat.transpose(1, 2) if self.transpose else feat 85 | 86 | def forward( 87 | self, 88 | x, 89 | ): 90 | if self.transpose: 91 | x = x.transpose(1, 2) 92 | feat, ind = self.quantizer(x) 93 | ind = rearrange( 94 | ind, 95 | "g b t r ->b t (g r)", 96 | ) 97 | embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype) 98 | e_mean = torch.mean(embed_onehot, dim=[0, 1]) 99 | e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1) 100 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1)) 101 | 102 | return ( 103 | torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device), 104 | feat.transpose(1, 2) if self.transpose else feat, 105 | perplexity, 106 | 
None, 107 | ind.transpose(1, 2) if self.transpose else ind, 108 | ) 109 | 110 | 111 | class DVAEDecoder(nn.Module): 112 | def __init__( 113 | self, 114 | idim, 115 | odim, 116 | n_layer=12, 117 | bn_dim=64, 118 | hidden=256, 119 | kernel=7, 120 | dilation=2, 121 | up=False, 122 | ): 123 | super().__init__() 124 | self.up = up 125 | self.conv_in = nn.Sequential( 126 | nn.Conv1d(idim, bn_dim, 3, 1, 1), 127 | nn.GELU(), 128 | nn.Conv1d(bn_dim, hidden, 3, 1, 1), 129 | ) 130 | self.decoder_block = nn.ModuleList( 131 | [ 132 | ConvNeXtBlock( 133 | hidden, 134 | hidden * 4, 135 | kernel, 136 | dilation, 137 | ) 138 | for _ in range(n_layer) 139 | ] 140 | ) 141 | self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) 142 | 143 | def forward(self, input, conditioning=None): 144 | # B, T, C 145 | x = input.transpose(1, 2) 146 | x = self.conv_in(x) 147 | for f in self.decoder_block: 148 | x = f(x, conditioning) 149 | 150 | x = self.conv_out(x) 151 | return x.transpose(1, 2) 152 | 153 | 154 | class DVAE(nn.Module): 155 | def __init__(self, decoder_config, vq_config, dim=512): 156 | super().__init__() 157 | self.register_buffer("coef", torch.randn(1, 100, 1)) 158 | 159 | self.decoder = DVAEDecoder(**decoder_config) 160 | self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False) 161 | if vq_config is not None: 162 | self.vq_layer = GFSQ(**vq_config) 163 | else: 164 | self.vq_layer = None 165 | 166 | def forward(self, inp): 167 | 168 | if self.vq_layer is not None: 169 | vq_feats = self.vq_layer._embed(inp) 170 | else: 171 | vq_feats = inp.detach().clone() 172 | 173 | temp = torch.chunk(vq_feats, 2, dim=1) # flatten trick :) 174 | temp = torch.stack(temp, -1) 175 | vq_feats = temp.reshape(*temp.shape[:2], -1) 176 | 177 | vq_feats = vq_feats.transpose(1, 2) 178 | dec_out = self.decoder(input=vq_feats) 179 | dec_out = self.out_conv(dec_out.transpose(1, 2)) 180 | mel = dec_out * self.coef 181 | 182 | return mel 183 | -------------------------------------------------------------------------------- /fastapi/ChatTTS/model/gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 4 | 5 | import logging 6 | from tqdm import tqdm 7 | from einops import rearrange 8 | from transformers.cache_utils import Cache 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.nn.utils.parametrize as P 14 | from torch.nn.utils.parametrizations import weight_norm 15 | from transformers import LlamaModel, LlamaConfig 16 | 17 | 18 | class LlamaMLP(nn.Module): 19 | def __init__(self, hidden_size, intermediate_size): 20 | super().__init__() 21 | self.hidden_size = hidden_size 22 | self.intermediate_size = intermediate_size 23 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 24 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 25 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 26 | self.act_fn = F.silu 27 | 28 | def forward(self, x): 29 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 30 | return down_proj 31 | 32 | 33 | class GPT_warpper(nn.Module): 34 | def __init__( 35 | self, 36 | gpt_config, 37 | num_audio_tokens, 38 | num_text_tokens, 39 | num_vq=4, 40 | **kwargs, 41 | ): 42 | super().__init__() 43 | 44 | self.logger = logging.getLogger(__name__) 45 | self.gpt = self.build_model(gpt_config) 46 | self.model_dim = self.gpt.config.hidden_size 47 
| 48 | self.num_vq = num_vq 49 | self.emb_code = nn.ModuleList( 50 | [nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)] 51 | ) 52 | self.emb_text = nn.Embedding(num_text_tokens, self.model_dim) 53 | self.head_text = weight_norm( 54 | nn.Linear(self.model_dim, num_text_tokens, bias=False), name="weight" 55 | ) 56 | self.head_code = nn.ModuleList( 57 | [ 58 | weight_norm( 59 | nn.Linear(self.model_dim, num_audio_tokens, bias=False), 60 | name="weight", 61 | ) 62 | for i in range(self.num_vq) 63 | ] 64 | ) 65 | 66 | def build_model(self, config): 67 | 68 | configuration = LlamaConfig(**config) 69 | model = LlamaModel(configuration) 70 | del model.embed_tokens 71 | 72 | return model 73 | 74 | def get_emb(self, input_ids, text_mask, **kwargs): 75 | 76 | emb_text = self.emb_text(input_ids[text_mask][:, 0]) 77 | 78 | emb_code = [ 79 | self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq) 80 | ] 81 | emb_code = torch.stack(emb_code, 2).sum(2) 82 | 83 | emb = torch.zeros( 84 | (input_ids.shape[:-1]) + (emb_text.shape[-1],), 85 | device=emb_text.device, 86 | dtype=emb_text.dtype, 87 | ) 88 | emb[text_mask] = emb_text 89 | emb[~text_mask] = emb_code.to(emb.dtype) 90 | 91 | return emb 92 | 93 | def prepare_inputs_for_generation( 94 | self, 95 | input_ids, 96 | past_key_values=None, 97 | attention_mask=None, 98 | inputs_embeds=None, 99 | cache_position=None, 100 | **kwargs, 101 | ): 102 | # With static cache, the `past_key_values` is None 103 | # TODO joao: standardize interface for the different Cache classes and remove of this if 104 | has_static_cache = False 105 | if past_key_values is None: 106 | past_key_values = getattr( 107 | self.gpt.layers[0].self_attn, "past_key_value", None 108 | ) 109 | has_static_cache = past_key_values is not None 110 | 111 | past_length = 0 112 | if past_key_values is not None: 113 | if isinstance(past_key_values, Cache): 114 | past_length = ( 115 | cache_position[0] 116 | if cache_position is not None 117 | else past_key_values.get_seq_length() 118 | ) 119 | max_cache_length = ( 120 | torch.tensor( 121 | past_key_values.get_max_length(), device=input_ids.device 122 | ) 123 | if past_key_values.get_max_length() is not None 124 | else None 125 | ) 126 | cache_length = ( 127 | past_length 128 | if max_cache_length is None 129 | else torch.min(max_cache_length, past_length) 130 | ) 131 | # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects 132 | else: 133 | cache_length = past_length = past_key_values[0][0].shape[2] 134 | max_cache_length = None 135 | 136 | # Keep only the unprocessed tokens: 137 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 138 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 139 | # input) 140 | if ( 141 | attention_mask is not None 142 | and attention_mask.shape[1] > input_ids.shape[1] 143 | ): 144 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 145 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 146 | # input_ids based on the past_length. 147 | elif past_length < input_ids.shape[1]: 148 | input_ids = input_ids[:, past_length:] 149 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 150 | 151 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
152 | if ( 153 | max_cache_length is not None 154 | and attention_mask is not None 155 | and cache_length + input_ids.shape[1] > max_cache_length 156 | ): 157 | attention_mask = attention_mask[:, -max_cache_length:] 158 | 159 | position_ids = kwargs.get("position_ids", None) 160 | if attention_mask is not None and position_ids is None: 161 | # create position_ids on the fly for batch generation 162 | position_ids = attention_mask.long().cumsum(-1) - 1 163 | position_ids.masked_fill_(attention_mask == 0, 1) 164 | if past_key_values: 165 | position_ids = position_ids[:, -input_ids.shape[1] :] 166 | 167 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 168 | if inputs_embeds is not None and past_key_values is None: 169 | model_inputs = {"inputs_embeds": inputs_embeds} 170 | else: 171 | # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise 172 | # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 173 | # TODO: use `next_tokens` directly instead. 174 | model_inputs = {"input_ids": input_ids.contiguous()} 175 | 176 | input_length = ( 177 | position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] 178 | ) 179 | if cache_position is None: 180 | cache_position = torch.arange( 181 | past_length, past_length + input_length, device=input_ids.device 182 | ) 183 | else: 184 | cache_position = cache_position[-input_length:] 185 | 186 | if has_static_cache: 187 | past_key_values = None 188 | 189 | model_inputs.update( 190 | { 191 | "position_ids": position_ids, 192 | "cache_position": cache_position, 193 | "past_key_values": past_key_values, 194 | "use_cache": kwargs.get("use_cache"), 195 | "attention_mask": attention_mask, 196 | } 197 | ) 198 | return model_inputs 199 | 200 | def generate( 201 | self, 202 | emb, 203 | inputs_ids, 204 | temperature, 205 | eos_token, 206 | attention_mask=None, 207 | max_new_token=2048, 208 | min_new_token=0, 209 | LogitsWarpers=[], 210 | LogitsProcessors=[], 211 | infer_text=False, 212 | return_attn=False, 213 | return_hidden=False, 214 | ): 215 | 216 | with torch.no_grad(): 217 | 218 | attentions = [] 219 | hiddens = [] 220 | 221 | start_idx, end_idx = inputs_ids.shape[1], torch.zeros( 222 | inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long 223 | ) 224 | finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool() 225 | 226 | temperature = temperature[None].expand(inputs_ids.shape[0], -1) 227 | temperature = rearrange(temperature, "b n -> (b n) 1") 228 | 229 | attention_mask_cache = torch.ones( 230 | ( 231 | inputs_ids.shape[0], 232 | inputs_ids.shape[1] + max_new_token, 233 | ), 234 | dtype=torch.bool, 235 | device=inputs_ids.device, 236 | ) 237 | if attention_mask is not None: 238 | attention_mask_cache[:, : attention_mask.shape[1]] = attention_mask 239 | 240 | for i in tqdm(range(max_new_token)): 241 | 242 | model_input = self.prepare_inputs_for_generation( 243 | inputs_ids, 244 | outputs.past_key_values if i != 0 else None, 245 | attention_mask_cache[:, : inputs_ids.shape[1]], 246 | use_cache=True, 247 | ) 248 | 249 | if i == 0: 250 | model_input["inputs_embeds"] = emb 251 | else: 252 | if infer_text: 253 | model_input["inputs_embeds"] = self.emb_text( 254 | model_input["input_ids"][:, :, 0] 255 | ) 256 | else: 257 | code_emb = [ 258 | self.emb_code[i](model_input["input_ids"][:, :, i]) 259 | for i in range(self.num_vq) 260 | ] 261 | model_input["inputs_embeds"] = 
torch.stack(code_emb, 3).sum(3) 262 | 263 | model_input["input_ids"] = None 264 | outputs = self.gpt.forward(**model_input, output_attentions=return_attn) 265 | attentions.append(outputs.attentions) 266 | hidden_states = outputs[0] # 🐻 267 | if return_hidden: 268 | hiddens.append(hidden_states[:, -1]) 269 | 270 | with P.cached(): 271 | if infer_text: 272 | logits = self.head_text(hidden_states) 273 | else: 274 | logits = torch.stack( 275 | [ 276 | self.head_code[i](hidden_states) 277 | for i in range(self.num_vq) 278 | ], 279 | 3, 280 | ) 281 | 282 | logits = logits[:, -1].float() 283 | 284 | if not infer_text: 285 | logits = rearrange(logits, "b c n -> (b n) c") 286 | logits_token = rearrange( 287 | inputs_ids[:, start_idx:], "b c n -> (b n) c" 288 | ) 289 | else: 290 | logits_token = inputs_ids[:, start_idx:, 0] 291 | 292 | logits = logits / temperature 293 | 294 | for logitsProcessors in LogitsProcessors: 295 | logits = logitsProcessors(logits_token, logits) 296 | 297 | for logitsWarpers in LogitsWarpers: 298 | logits = logitsWarpers(logits_token, logits) 299 | 300 | if i < min_new_token: 301 | logits[:, eos_token] = -torch.inf 302 | 303 | scores = F.softmax(logits, dim=-1) 304 | 305 | idx_next = torch.multinomial(scores, num_samples=1) 306 | 307 | if not infer_text: 308 | idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) 309 | finish = finish | (idx_next == eos_token).any(1) 310 | inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1) 311 | else: 312 | finish = finish | (idx_next == eos_token).any(1) 313 | inputs_ids = torch.cat( 314 | [ 315 | inputs_ids, 316 | idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq), 317 | ], 318 | 1, 319 | ) 320 | 321 | end_idx = end_idx + (~finish).int() 322 | 323 | if finish.all(): 324 | break 325 | 326 | inputs_ids = [ 327 | inputs_ids[idx, start_idx : start_idx + i] 328 | for idx, i in enumerate(end_idx.int()) 329 | ] 330 | inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids 331 | 332 | if return_hidden: 333 | hiddens = torch.stack(hiddens, 1) 334 | hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())] 335 | 336 | if not finish.all(): 337 | self.logger.warn( 338 | f"Incomplete result. 
hit max_new_token: {max_new_token}" 339 | ) 340 | 341 | return { 342 | "ids": inputs_ids, 343 | "attentions": attentions, 344 | "hiddens": hiddens, 345 | } 346 | -------------------------------------------------------------------------------- /fastapi/ChatTTS/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/utils/__init__.py -------------------------------------------------------------------------------- /fastapi/ChatTTS/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/ChatTTS/utils/__pycache__/gpu_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/utils/__pycache__/gpu_utils.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/ChatTTS/utils/__pycache__/infer_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/utils/__pycache__/infer_utils.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/ChatTTS/utils/__pycache__/io_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/ChatTTS/utils/__pycache__/io_utils.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/ChatTTS/utils/gpu_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | 5 | def select_device(min_memory=2048): 6 | logger = logging.getLogger(__name__) 7 | if torch.cuda.is_available(): 8 | available_gpus = [] 9 | for i in range(torch.cuda.device_count()): 10 | props = torch.cuda.get_device_properties(i) 11 | free_memory = props.total_memory - torch.cuda.memory_reserved(i) 12 | available_gpus.append((i, free_memory)) 13 | selected_gpu, max_free_memory = max(available_gpus, key=lambda x: x[1]) 14 | device = torch.device(f"cuda:{selected_gpu}") 15 | free_memory_mb = max_free_memory / (1024 * 1024) 16 | if free_memory_mb < min_memory: 17 | logger.log( 18 | logging.WARNING, 19 | f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.", 20 | ) 21 | device = torch.device("cpu") 22 | # future support 23 | # elif torch.backends.mps.is_available(): 24 | # device = torch.device("mps") 25 | else: 26 | logger.log(logging.WARNING, f"No GPU found, use CPU instead") 27 | device = torch.device("cpu") 28 | 29 | return device 30 | -------------------------------------------------------------------------------- /fastapi/ChatTTS/utils/infer_utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class CustomRepetitionPenaltyLogitsProcessorRepeat: 6 | 7 | def __init__(self, penalty: float, max_input_ids, past_window): 8 | if not isinstance(penalty, float) or not (penalty > 0): 9 | raise ValueError( 10 | f"`penalty` has to be a strictly positive float, but is {penalty}" 11 | ) 12 | 13 | self.penalty = penalty 14 | self.max_input_ids = max_input_ids 15 | self.past_window = past_window 16 | 17 | def __call__( 18 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 19 | ) -> torch.FloatTensor: 20 | 21 | input_ids = input_ids[:, -self.past_window :] 22 | freq = F.one_hot(input_ids, scores.size(1)).sum(1) 23 | freq[self.max_input_ids :] = 0 24 | alpha = self.penalty**freq 25 | scores = torch.where(scores < 0, scores * alpha, scores / alpha) 26 | 27 | return scores 28 | 29 | 30 | class CustomRepetitionPenaltyLogitsProcessor: 31 | 32 | def __init__(self, penalty: float, max_input_ids, past_window): 33 | if not isinstance(penalty, float) or not (penalty > 0): 34 | raise ValueError( 35 | f"`penalty` has to be a strictly positive float, but is {penalty}" 36 | ) 37 | 38 | self.penalty = penalty 39 | self.max_input_ids = max_input_ids 40 | self.past_window = past_window 41 | 42 | def __call__( 43 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 44 | ) -> torch.FloatTensor: 45 | 46 | input_ids = input_ids[:, -self.past_window :] 47 | score = torch.gather(scores, 1, input_ids) 48 | _score = score.detach().clone() 49 | score = torch.where(score < 0, score * self.penalty, score / self.penalty) 50 | score[input_ids >= self.max_input_ids] = _score[input_ids >= self.max_input_ids] 51 | scores.scatter_(1, input_ids, score) 52 | 53 | return scores 54 | -------------------------------------------------------------------------------- /fastapi/ChatTTS/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | def get_latest_modified_file(directory): 6 | logger = logging.getLogger(__name__) 7 | 8 | files = [os.path.join(directory, f) for f in os.listdir(directory)] 9 | if not files: 10 | logger.log(logging.WARNING, f"No files found in the directory: {directory}") 11 | return None 12 | latest_file = max(files, key=os.path.getmtime) 13 | 14 | return latest_file 15 | -------------------------------------------------------------------------------- /fastapi/Dockerfile: -------------------------------------------------------------------------------- 1 | # 使用官方Python运行时作为基础镜像 2 | FROM python:3.9-slim 3 | 4 | # 设置工作目录 5 | WORKDIR /fastapi 6 | 7 | # 将当前目录的内容复制到容器中的/fastapi目录 8 | COPY . 
/fastapi 9 | COPY requirements.txt /fastapi 10 | 11 | # 安装requirements.txt中指定的所有依赖包 12 | RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple 13 | 14 | # 暴露8000端口 15 | EXPOSE 8000 16 | 17 | # 运行FastAPI应用的uvicorn服务器 18 | CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"] -------------------------------------------------------------------------------- /fastapi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/__init__.py -------------------------------------------------------------------------------- /fastapi/__pycache__/server.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/__pycache__/server.cpython-38.pyc -------------------------------------------------------------------------------- /fastapi/output.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanquderzi/ChatTTS-Deployment-using-FastAPI-and-Streamlit/a39fecac6f49535a5d2a376ea9f61adbc3b8871e/fastapi/output.wav -------------------------------------------------------------------------------- /fastapi/requirements.txt: -------------------------------------------------------------------------------- 1 | omegaconf==2.3.0 2 | torch==2.1.2 3 | tqdm==4.66.4 4 | einops==0.8.0 5 | vector_quantize_pytorch==1.14.8 6 | transformers==4.41.2 7 | vocos==0.1.0 8 | fastapi==0.111.0 9 | uvicorn==0.30.0 10 | starlette==0.37.2 11 | pydantic==2.7.2 12 | modelscope==1.14.0 13 | -------------------------------------------------------------------------------- /fastapi/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import wave 3 | import numpy as np 4 | from fastapi import FastAPI, HTTPException, Depends 5 | import pydantic 6 | import ChatTTS 7 | #模型下载 8 | from modelscope import snapshot_download 9 | model_dir = snapshot_download('mirror013/ChatTTS') 10 | 11 | app = FastAPI() 12 | 13 | 14 | class TTSInput(pydantic.BaseModel): 15 | text: str 16 | output_path: str 17 | seed: int = 697 18 | 19 | 20 | def get_chat_model() -> ChatTTS.Chat: 21 | chat = ChatTTS.Chat() 22 | chat.load_models( 23 | source="local", 24 | local_path=model_dir, 25 | ) 26 | return chat 27 | 28 | 29 | @app.post("/tts") 30 | def tts(input: TTSInput, chat: ChatTTS.Chat = Depends(get_chat_model)): 31 | try: 32 | texts = [input.text] 33 | r = chat.sample_random_speaker(seed=input.seed) 34 | 35 | params_infer_code = { 36 | 'spk_emb': r, # add sampled speaker 37 | 'temperature': .3, # using customtemperature 38 | 'top_P': 0.7, # top P decode 39 | 'top_K': 20, # top K decode 40 | } 41 | 42 | params_refine_text = { 43 | 'prompt': '[oral_2][laugh_0][break_6]' 44 | } 45 | 46 | wavs = chat.infer(texts, 47 | params_infer_code=params_infer_code, 48 | params_refine_text=params_refine_text, use_decoder=True) 49 | 50 | audio_data = np.array(wavs[0], dtype=np.float32) 51 | sample_rate = 24000 52 | audio_data = (audio_data * 32767).astype(np.int16) 53 | 54 | with wave.open(input.output_path, "w") as wf: 55 | wf.setnchannels(1) 56 | wf.setsampwidth(2) 57 | wf.setframerate(sample_rate) 58 | wf.writeframes(audio_data.tobytes()) 59 | return 
{"output_path": input.output_path} 60 | 61 | except Exception as e: 62 | raise HTTPException(status_code=500, detail=str(e)) 63 | -------------------------------------------------------------------------------- /streamlit/Dockerfile: -------------------------------------------------------------------------------- 1 | # 使用官方Python运行时作为基础镜像 2 | FROM python:3.9-slim 3 | 4 | # 设置工作目录 5 | WORKDIR /streamlit 6 | 7 | # 将当前目录的内容复制到容器中的/streamlit目录 8 | COPY ui.py /streamlit 9 | COPY requirements.txt /streamlit 10 | 11 | # 安装requirements.txt中指定的所有依赖包 12 | RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple 13 | 14 | # 暴露Streamlit的默认端口 15 | EXPOSE 8501 16 | 17 | # 运行Streamlit应用 18 | CMD ["streamlit", "run", "ui.py"] 19 | -------------------------------------------------------------------------------- /streamlit/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.30.0 2 | requests==2.31.0 3 | cn2an==0.5.22 4 | -------------------------------------------------------------------------------- /streamlit/ui.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import requests 3 | import json 4 | import re 5 | from cn2an import an2cn 6 | import os 7 | 8 | def convert_arabic_to_chinese_in_string(s): 9 | def replace_func(match): 10 | number = match.group(0) 11 | return an2cn(number) 12 | 13 | # 匹配所有的阿拉伯数字 14 | converted_str = re.sub(r'\d+', replace_func, s) 15 | return converted_str 16 | 17 | def synthesize_speech(text, output_path, seed, url="http://fastapi:8001/tts"): # 本地部署fastapi 改为localhost 18 | payload = { 19 | "text": text, 20 | "output_path": output_path, 21 | "seed": seed 22 | } 23 | 24 | headers = { 25 | "content-type": "application/json" 26 | } 27 | 28 | try: 29 | response = requests.post(url, headers=headers, data=json.dumps(payload)) 30 | 31 | if response.status_code == 200: 32 | st.success(f"语音合成成功,输出保存到 {output_path}.") 33 | return True 34 | else: 35 | st.error(f"语音合成失败,状态码: {response.status_code}") 36 | st.error(f"响应: {response.text}") 37 | return False 38 | 39 | except Exception as e: 40 | st.error(f"发生错误: {e}") 41 | return False 42 | 43 | # 初始化session state 44 | if 'synthesized' not in st.session_state: 45 | st.session_state.synthesized = False 46 | if 'output_path' not in st.session_state: 47 | st.session_state.output_path = "" 48 | 49 | st.title("ChatTTS语音合成接口") 50 | st.write("输入文本、输出文件路径和种子值,生成语音文件。") 51 | 52 | text = st.text_area("输入文本") 53 | text = convert_arabic_to_chinese_in_string(text) 54 | output_path = st.text_input("输出文件路径", "output.wav") 55 | seed = st.number_input("种子值", value=0) 56 | 57 | if st.button("合成语音"): 58 | if synthesize_speech(text, output_path, seed): 59 | st.session_state.synthesized = True 60 | st.session_state.output_path = os.path.join('/fastapi/', output_path) 61 | 62 | if st.session_state.synthesized: 63 | output_file_path = st.session_state.output_path 64 | 65 | # 检查文件是否存在 66 | if os.path.exists(output_file_path): 67 | audio_file = open(output_file_path, 'rb') 68 | audio_bytes = audio_file.read() 69 | st.audio(audio_bytes, format='audio/wav') 70 | 71 | # 提供下载链接 72 | with open(output_file_path, 'rb') as f: 73 | st.download_button( 74 | label="下载生成的语音文件", 75 | data=f, 76 | file_name=os.path.basename(output_file_path), 77 | mime='audio/wav' 78 | ) 79 | else: 80 | st.error(f"文件不存在: {output_file_path}") 81 | 82 | --------------------------------------------------------------------------------