├── 5.wav ├── README.md └── api2.py /5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/gptsovits-api/511bee4a47e2425e3daaf6be70310df41fa67bd7/5.wav -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # api2.py 适用于 GPT-SoVITS 的api调用接口 2 | 3 | > [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/docs/cn/README.md) 一个非常棒的零(少)样本中文声音克隆项目,之前有一篇文章详细介绍过如何部署和训练自己的模型([点击查看](https://juejin.cn/post/7341210909070000168)),并使用该模型在web界面中合成声音,可惜它自带的 api.py 在调用方面支持比较差,比如不能中英混合、无法按标点切分句子等,因此对api.py做了修改,详细使用说明如下。 4 | > 5 | > 修改后代码开源地址:https://github.com/jianchang512/gptsovits-api 6 | 7 | 8 | 9 | 10 | 下载api2.py,复制到GPT-SoVITS软件目录下,执行命令同自带api.py一样,只需要将名字 api.py 改成 api2.py。默认端口也是 9880,默认绑定 127.0.0.1 11 | 12 | ---- 13 | 14 | 15 | ## 使用默认模型启动并指定默认参考音频 -dr -dt -dl 16 | 17 | 假设参考音频要使用根目录下的 **123.wav** ,音频文字是 **“一二三四五六七。”** ,音频语言是中文,那么命令如下: 18 | 19 | 20 | ` .\runtime\python api2.py -dr "123.wav" -dt "一二三四五六七。" -dl "zh" ` 21 | 22 | 23 | ![image.png](https://p1-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/a131add45289448391eeec4ebcd9c2d1~tplv-k3u1fbpfcp-jj-mark:0:0:0:0:q75.image#?w=1018&h=82&s=9175&e=png&b=0c0c0c) 24 | 25 | Linux下命令去掉 `.\runtime\` 即可 26 | 27 | 如上述命令这样在启动后,指定的参考音频将作为默认配置,当请求api数据中如果未指定参考音频,将使用。 28 | 29 | 30 | ## 默认模型启动并绑定ip地址和端口 -a -p 31 | 32 | 假设要指定绑定内网ip **192.168.0.120** ,端口要使用 **9001**,不指定默认参考音频,那么执行如下命令: 33 | 34 | `.\runtime\python api2.py -a "127.0.0.1" -p 9001 ` 35 | 36 | 37 | ![image.png](https://p1-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/4d57b23134b5436f825f04944e4bc250~tplv-k3u1fbpfcp-jj-mark:0:0:0:0:q75.image#?w=700&h=70&s=4954&e=png&b=0c0c0c) 38 | 39 | 40 | ## 启动自己训练好的模型 -s -g 41 | 42 | 在指定自己的模型时,必须确保同时**指定参考音频**。训练好的模型可分别在软件目录下的 **GPT_weights** 和 **SoVITS_weights** 目录下寻找,以你训练时命名的模型名称开头,后跟`e数字`最大的那个即可。 43 | 44 | `.\runtime\python api2.py -s "SoVITS_weights/你的模型名" -g "GPT_weights/你的模型名" -dr "参考音频路径和名称" -dt "参考音频的文字内容,使用双引号括起来,确保文字内容里没有双引号" -dl zh|ja|en三者选一 ` 45 | 46 | 47 | ![image.png](https://p6-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/81051eca33694c50a40ddfd67828eda4~tplv-k3u1fbpfcp-jj-mark:0:0:0:0:q75.image#?w=1666&h=54&s=10806&e=png&b=0c0c0c) 48 | 49 | 50 | ## 强制在CPU上推理 -d cpu 51 | 52 | 默认将优先使用 CUDA 或 mps(Mac), 如果你想指定在CPU上运行,可以通过 `-d cpu `指定 53 | 54 | `.\runtime\python api2.py -d cpu` 55 | 56 | 注意 `-d` 后只能是 `cpu 或 cuda 或 mps`,并且只有在正确配置 cuda 后才能指定 cuda,只有Apple CPU Mac上才能指定 mps 57 | 58 | 59 | ## 全部按照默认运行 60 | 61 | `.\runtime\python api2.py` 62 | 63 | 这种方式将使用默认模型,并且在 api 请求时必须指定参考音频、参考音频文字内容、参考音频语言代码,api 监听 9880 端口 64 | 65 | 66 | ## 可使用的语言代码 -dl 67 | 68 | 仅支持 **中文、日语、英语** 三种语言,对应只可使用 `zh(代表中文或中英混合)`、`ja(代表日语或日英混合)`、`en(代表英语)`,使用 -dl 指定,如 `-dl zh`,`-dl ja`,`-dl en` 69 | 70 | ## 参考音频路径 -dr 71 | 72 | 参考音频填写以软件根目录为起点的相对目录,假如你的参考音频是直接放在软件根目录下,那么只需要填写带后缀的完整名字即可,比如 `-dr 123.wav`,如果是在子目录下,比如在 `wavs` 文件夹下,那么填写 `-dr "wavs/123.wav"` 73 | 74 | ## 参考音频的文字内容 -dt 75 | 76 | 参考音频的文字内容就是音频里的说话文字,需要正确填写标点符号,并使用英文双引号括起来。请注意,文字中不要再有英文双引号。 77 | 78 | `-dt "这里填写参考音频的文字内容,不要含有英文双引号"` 79 | 80 | 81 | ![image.png](https://p1-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/075dcfcb088f45b0982420638f1005b0~tplv-k3u1fbpfcp-jj-mark:0:0:0:0:q75.image#?w=1448&h=750&s=148467&e=png&b=0c0c0c) 82 | 83 | 84 | ## 可用的命令行参数: 85 | 86 | **模型相关参数** 87 | 88 | `-s` SoVITS模型路径, 默认模型无需填写,自训练模型在 SoVITS_weights 目录下 89 | 90 | `-g` GPT模型路径, 默认模型无需填写,自训练模型在 GPT_weights 目录下 91 | 92 | **参考音频相关参数** 93 | 94 | `-dr` 默认参考音频路径,如果在根目录下,直接填写带后缀名字,否则加上 路径/名字 95 | 96 | `-dt` 默认参考音频文本,音频的文字内容,以英文双引号括起来 97 | 98 | `-dl` 默认参考音频内容的语种, "zh"或"en"或"ja" 99 | 100 | **设备和地址相关参数** 101 | 102 | `-d` 推理设备, "cuda","cpu","mps" 只有配置好了cuda环境才可指定cuda,只有Apple CPU上才可指定mps 103 | 104 | `-a` 绑定地址, 默认"127.0.0.1" 105 | 106 | `-p` 绑定端口, 默认9880 107 | 108 | **不常用参数,新手可忽略不必设置** 109 | 110 | `-fp` 使用全精度 111 | 112 | `-hp` 使用半精度 113 | 114 | `-hb` cnhubert路径 115 | 116 | `-b` bert路径 117 | 118 | `-c` 1-5, 默认5,代表按标点符号切分。 1=凑四句一切 2=凑50字一切 3=按中文句号。切 4=按英文句号.切 5=按标点符号切 119 | 120 | 121 | 122 | ## API调用示例: 123 | 124 | 调用地址url: `http://你指定的ip:指定的端口`,默认是 `http://127.0.0.1:9880` 125 | 126 | 127 | **调用时不指定参考音频** 128 | 129 | 启动 api2.py 时必须指定默认参考音频,才可在调用api时不指定,否则会失败: 130 | 131 | GET方式调用,可直接浏览器中打开: 132 | 133 | `http://127.0.0.1:9880?text=亲爱的朋友你好啊,希望你的每一天都充满快乐。&text_language=zh` 134 | 135 | ![image.png](https://p6-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/dcc86a39958047fda782393c664729a0~tplv-k3u1fbpfcp-jj-mark:0:0:0:0:q75.image#?w=1137&h=713&s=31693&e=png&b=000000) 136 | 137 | 138 | POST方式调用,以json格式传参: 139 | 140 | ```json 141 | { 142 | "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", 143 | "text_language": "zh" 144 | } 145 | ``` 146 | 147 | ![image.png](https://p1-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/740bcf5d828c4a528273af2ff8d04471~tplv-k3u1fbpfcp-jj-mark:0:0:0:0:q75.image#?w=877&h=356&s=31347&e=png&b=ffffff) 148 | 149 | 150 | 151 | ### 手动指定当次所使用的参考音频: 152 | 153 | GET方式: 154 | 155 | `http://127.0.0.1:9880?refer_wav_path=wavs/5.wav&prompt_text=为什么御弟哥哥,甘愿守孤灯。&prompt_language=zh&text=亲爱的朋友你好啊,希望你的每一天都充满快乐。&text_language=zh` 156 | 157 | POST方式: 158 | 159 | ```json 160 | { 161 | "refer_wav_path": "wavs/5.wav", 162 | "prompt_text": "为什么御弟哥哥,甘愿守孤灯。", 163 | "prompt_language": "zh", 164 | "text": "亲爱的朋友你好啊,希望你的每一天都充满快乐。", 165 | "text_language": "zh" 166 | } 167 | ``` 168 | 169 | ![image.png](https://p1-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/1d715a9b3d1c4851a6dfd7f2e724d68b~tplv-k3u1fbpfcp-jj-mark:0:0:0:0:q75.image#?w=835&h=447&s=39473&e=png&b=fffefe) 170 | 171 | 172 | 173 | ## Api调用返回信息: 174 | 175 | 成功时: 返回 wav 音频流,可直接播放或保存到 wav文件中,http 状态码 200 176 | 177 | 失败时: 返回包含错误信息的 json, http 状态码 400 178 | 179 | ``` 180 | {"code": 400, "message": "未指定参考音频且接口无预设"} 181 | ``` 182 | 183 | 184 | 185 | 186 | 187 | ## 问题:想切换模型怎么办 188 | 189 | api2.py和官方原版api.py 一样都不支持动态模型切换,也不建议这样做,因为动态启动加载模型很慢,而且在失败时也不方便处理。 190 | 191 | **解决方法是:** 一个模型起一个api服务器,绑定不同的端口,在启动api2.py时,指定当前服务所要使用的模型和绑定的端口。 192 | 193 | 比如起2个服务,一个使用默认模型,绑定 9880 端口,一个绑定自己训练的模型,绑定 9881 端口,命令如下 194 | 195 | **默认模型 9880 端口**: http://127.0.0.1:9880 196 | 197 | `.\runtime\python api2.py -dr "5.wav" -dt "今天好开心" -dl zh ` 198 | 199 | 200 | **自己训练的模型**: http://127.0.0.1:9881 201 | 202 | `.\runtime\python api2.py -p 9881 -s "SoVITS_weights/mymode-e200.pth" -g "GPT_weights/mymode-e200.ckpt" -dr "wavs/10.wav" -dt "御弟哥哥,为什么甘愿守孤灯" -dl zh ` 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | -------------------------------------------------------------------------------- /api2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author jiangchang512 3 | 针对GPT-SoVITS 自带的api.py存在不支持中英混合、无法按标点切分句子,进行修改后的api文件 4 | 5 | github.com/jianchang512/gptsovits-api 6 | 7 | 8 | 使用方法同自带api.py一样,只需要将名字 api.py 改成 api2.py 9 | 10 | 默认端口也是 9880,默认绑定本机所有可用ip地址 11 | 12 | ################################################### 13 | 14 | # api2.py 使用示例 15 | 16 | 17 | ## 使用默认模型启动并指定默认参考音频 -dr -dt -dl 18 | 19 | 假设参考音频要使用根目录下的 123.wav,音频文字是“一二三四五六七。” ,音频语言是中文,那么命令如下 20 | 21 | windows 预打包版命令: 22 | 23 | ` .\runtime\python api2.py -dr "123.wav" -dt "一二三四五六七。" -dl "zh" ` 24 | 25 | Linux下命令去掉 .\runtime\即可 26 | 27 | 如上述命令这样在启动后,指定的参考音频将作为默认配置,当请求api数据中如果未指定参考音频,将使用。 28 | 29 | ## 默认模型启动并绑定ip地址和端口 -a -p 30 | 31 | 假设要指定绑定内网ip 192.168.0.120 ,端口要使用 9001,不指定默认参考音频,那么执行如下命令 32 | 33 | `.\runtime\python api2.py -a "127.0.0.1" -p 9001 ` 34 | 35 | ## 启动自己训练好的模型 -s -g 36 | 37 | 在指定自己的模型时,必须确保同时指定参考音频。训练好的模型可分别在软件目录下的 GPT_weights 和 SoVITS_weights 目录下寻找,以你训练时命名的模型名称开头,后跟`e数字`最大的那个即可。 38 | 39 | 命令 `.\runtime\python api2.py -s SoVITS_weights/你的模型名 -g GPT_weights/你的模型名 -dr "参考音频路径和名称" -dt "参考音频的文字内容,使用双引号括起来,确保文字内容里没有双引号" -dl zh|ja|en三者选一 ` 40 | 41 | 42 | ## 强制在CPU上运算 -d cpu 43 | 44 | 默认将优先使用 CUDA或mps(Mac),如果你只想在CPU上运行,可以通过 -d cpu指定 45 | 46 | 命令 `.\runtime\python api2.py -d cpu` 47 | 注意 -d后只能是 `cpu或 cuda或mps`,并且只有在正确配置cuda后才能指定cuda,只要apple芯片的mac上才能指定mps 48 | 49 | 50 | ## 全部按照默认运行 51 | 52 | `.\runtime\python api2.py` 53 | 54 | 这种方式将使用默认模型,并且在api请求时必须指定参考音频、参考音频文字内容、参考音频语言代码,api监听 9880 端口 55 | 56 | 57 | 58 | ## 可使用的语言代码 -dl 59 | 60 | 仅支持 中文、日语、英语三种语言,对应只可使用 zh(代表中文或中英混合)、ja(代表日语或日英混合)、en(代表英语),使用 -dl 指定,如 `-dl zh`,`-dl ja`,`-dl en` 61 | 62 | ## 参考音频路径 -dr 63 | 64 | 参考音频填写以软件根目录为起点的相对目录,加入你的参考音频直接放在软件根目录下,那么只需要填写带后缀的完整名字即可,比如 `-dr 123.wav`,如果是在子目录下,比如在 wavs 文件夹下,那么填写 `-dr "wavs/123.wav"` 65 | 66 | ## 参考音频的文字内容 -dt 67 | 68 | 参考音频的文字内容就是音频里的说话文字,需要正确填写标点符号。但请注意,不要使用英文双引号。 69 | 70 | `-dt "这里填写参考音频的文字内容,不要含有英文双引号"` 71 | 72 | 73 | 74 | 75 | ## 可用的命令行参数: 76 | 77 | 模型相关参数 78 | `-s` `SoVITS模型路径, 默认模型无需填写,自训练模型在 SoVITS_weights 目录下` 79 | `-g` `GPT模型路径, 默认模型无需填写,自训练模型在 GPT_weights 目录下` 80 | 81 | 参考音频相关参数 82 | `-dr` - `默认参考音频路径,如果在根目录下,直接填写带后缀名字,否则加上 路径/名字` 83 | `-dt` - `默认参考音频文本,音频的文字内容,以英文双引号括起来` 84 | `-dl` - `默认参考音频内容的语种, "zh"或"en"或"ja"` 85 | 86 | 设备和地址相关参数 87 | `-d` - `推理设备, "cuda","cpu","mps" 只有配置好了cuda环境才可指定cuda,只有Apple CPU上才可指定mps` 88 | `-a` - `绑定地址, 默认"127.0.0.1"` 89 | `-p` - `绑定端口, 默认9880` 90 | 91 | # 以下不常用,新手可忽略不必设置 92 | `-fp` `使用全精度` 93 | `-hp` `使用半精度` 94 | `-hb` `cnhubert路径` 95 | `-b` `bert路径` 96 | `-c` 1-5, 默认5,代表按标点符号切分。 1=凑四句一切 2=凑50字一切 3=按中文句号。切 4=按英文句号.切 5=按标点符号切 97 | 98 | 99 | ## API调用: 100 | 101 | 调用url: `http://你指定的ip:指定的端口/` 102 | 103 | 104 | ### 调用时不指定参考音频,使用 启动 api2.py 时指定的默认参考音频,在启动时必须指定了默认,否则会失败: 105 | 106 | GET方式调用,可直接浏览器中打开 107 | 108 | `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` 109 | 110 | 111 | POST方式调用,以json格式传参: 112 | ```json 113 | { 114 | "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", 115 | "text_language": "zh" 116 | } 117 | ``` 118 | 119 | 120 | 121 | ### 手动指定当次推理所使用的参考音频: 122 | 123 | GET方式: 124 | `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` 125 | 126 | POST方式: 127 | ```json 128 | { 129 | "refer_wav_path": "123.wav", 130 | "prompt_text": "一二三。", 131 | "prompt_language": "zh", 132 | "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", 133 | "text_language": "zh" 134 | } 135 | ``` 136 | 137 | 138 | 139 | ### 返回信息: 140 | 141 | 成功时: 直接返回 wav 音频流, http code 200 142 | 失败时: 返回包含错误信息的 json, http code 400 143 | 144 | 145 | 146 | ''' 147 | 148 | import os, re, logging,argparse,torch,sys 149 | if torch.cuda.is_available(): 150 | device = "cuda" 151 | elif torch.backends.mps.is_available(): 152 | device = "mps" 153 | else: 154 | device = "cpu" 155 | 156 | # 处理参数 157 | parser = argparse.ArgumentParser(description="GPT-SoVITS api") 158 | 159 | parser.add_argument("-s", "--sovits_path", type=str, default="GPT_SoVITS/pretrained_models/s2G488k.pth", help="SoVITS模型路径") 160 | parser.add_argument("-g", "--gpt_path", type=str, default="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", help="GPT模型路径") 161 | 162 | parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径") 163 | parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本") 164 | parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种") 165 | 166 | parser.add_argument("-d", "--device", type=str, default=device, help="cuda / cpu / mps") 167 | parser.add_argument("-a", "--bind_addr", type=str, default='127.0.0.1', help="default: 127.0.0.1") 168 | parser.add_argument("-p", "--port", type=int, default='9880', help="default: 9880") 169 | parser.add_argument("-c", "--cut", type=int, default=5, help="default: 5 按标点符号切分") 170 | parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度") 171 | parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度") 172 | 173 | parser.add_argument("-hb", "--hubert_path", type=str, default='GPT_SoVITS/pretrained_models/chinese-hubert-base', help="覆盖config.cnhubert_path") 174 | parser.add_argument("-b", "--bert_path", type=str, default='GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large', help="覆盖config.bert_path") 175 | 176 | args = parser.parse_args() 177 | 178 | 179 | default_wav=args.default_refer_path 180 | default_text=args.default_refer_text 181 | default_language=args.default_refer_language 182 | 183 | 184 | splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } 185 | root_dir=os.getcwd() 186 | SoVITS_weight_root = "SoVITS_weights" 187 | GPT_weight_root = "GPT_weights" 188 | os.makedirs(SoVITS_weight_root, exist_ok=True) 189 | os.makedirs(GPT_weight_root, exist_ok=True) 190 | 191 | 192 | host=args.bind_addr 193 | port = args.port 194 | is_half=bool(args.half_precision) 195 | sys.path.append(root_dir) 196 | sys.path.append(os.path.join(root_dir,"GPT_SoVITS")) 197 | gpt_path = args.gpt_path 198 | sovits_path = args.sovits_path 199 | bert_path = args.bert_path 200 | 201 | 202 | os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。 203 | 204 | if "_CUDA_VISIBLE_DEVICES" in os.environ: 205 | os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] 206 | 207 | 208 | 209 | 210 | import LangSegment 211 | import pdb 212 | from fastapi.responses import StreamingResponse, JSONResponse 213 | from fastapi import FastAPI, Request, HTTPException 214 | import signal 215 | from io import BytesIO 216 | import uvicorn 217 | import soundfile as sf 218 | import gradio as gr 219 | from transformers import AutoModelForMaskedLM, AutoTokenizer 220 | import numpy as np 221 | import librosa, torch 222 | from GPT_SoVITS.feature_extractor import cnhubert 223 | from GPT_SoVITS.module.models import SynthesizerTrn 224 | from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule 225 | from GPT_SoVITS.text import cleaned_text_to_sequence 226 | from GPT_SoVITS.text.cleaner import clean_text 227 | from time import time as ttime 228 | from GPT_SoVITS.module.mel_processing import spectrogram_torch 229 | from GPT_SoVITS.my_utils import load_audio 230 | cnhubert.cnhubert_base_path = args.hubert_path 231 | 232 | logging.getLogger("markdown_it").setLevel(logging.ERROR) 233 | logging.getLogger("urllib3").setLevel(logging.ERROR) 234 | logging.getLogger("httpcore").setLevel(logging.ERROR) 235 | logging.getLogger("httpx").setLevel(logging.ERROR) 236 | logging.getLogger("asyncio").setLevel(logging.ERROR) 237 | logging.getLogger("charset_normalizer").setLevel(logging.ERROR) 238 | logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) 239 | 240 | 241 | 242 | tokenizer = AutoTokenizer.from_pretrained(bert_path) 243 | bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) 244 | if is_half == True: 245 | bert_model = bert_model.half().to(device) 246 | else: 247 | bert_model = bert_model.to(device) 248 | 249 | 250 | def get_bert_feature(text, word2ph): 251 | with torch.no_grad(): 252 | inputs = tokenizer(text, return_tensors="pt") 253 | for i in inputs: 254 | inputs[i] = inputs[i].to(device) 255 | res = bert_model(**inputs, output_hidden_states=True) 256 | res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] 257 | assert len(word2ph) == len(text) 258 | phone_level_feature = [] 259 | for i in range(len(word2ph)): 260 | repeat_feature = res[i].repeat(word2ph[i], 1) 261 | phone_level_feature.append(repeat_feature) 262 | phone_level_feature = torch.cat(phone_level_feature, dim=0) 263 | return phone_level_feature.T 264 | 265 | 266 | class DictToAttrRecursive(dict): 267 | def __init__(self, input_dict): 268 | super().__init__(input_dict) 269 | for key, value in input_dict.items(): 270 | if isinstance(value, dict): 271 | value = DictToAttrRecursive(value) 272 | self[key] = value 273 | setattr(self, key, value) 274 | 275 | def __getattr__(self, item): 276 | try: 277 | return self[item] 278 | except KeyError: 279 | raise AttributeError(f"Attribute {item} not found") 280 | 281 | def __setattr__(self, key, value): 282 | if isinstance(value, dict): 283 | value = DictToAttrRecursive(value) 284 | super(DictToAttrRecursive, self).__setitem__(key, value) 285 | super().__setattr__(key, value) 286 | 287 | def __delattr__(self, item): 288 | try: 289 | del self[item] 290 | except KeyError: 291 | raise AttributeError(f"Attribute {item} not found") 292 | 293 | 294 | ssl_model = cnhubert.get_model() 295 | if is_half == True: 296 | ssl_model = ssl_model.half().to(device) 297 | else: 298 | ssl_model = ssl_model.to(device) 299 | 300 | 301 | def change_sovits_weights(sovits_path): 302 | global vq_model, hps 303 | dict_s2 = torch.load(sovits_path, map_location="cpu") 304 | hps = dict_s2["config"] 305 | hps = DictToAttrRecursive(hps) 306 | hps.model.semantic_frame_rate = "25hz" 307 | vq_model = SynthesizerTrn( 308 | hps.data.filter_length // 2 + 1, 309 | hps.train.segment_size // hps.data.hop_length, 310 | n_speakers=hps.data.n_speakers, 311 | **hps.model 312 | ) 313 | if ("pretrained" not in sovits_path): 314 | del vq_model.enc_q 315 | if is_half == True: 316 | vq_model = vq_model.half().to(device) 317 | else: 318 | vq_model = vq_model.to(device) 319 | vq_model.eval() 320 | print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) 321 | with open("./sweight.txt", "w", encoding="utf-8") as f: 322 | f.write(sovits_path) 323 | 324 | 325 | change_sovits_weights(sovits_path) 326 | 327 | 328 | def change_gpt_weights(gpt_path): 329 | global hz, max_sec, t2s_model, config 330 | hz = 50 331 | dict_s1 = torch.load(gpt_path, map_location="cpu") 332 | config = dict_s1["config"] 333 | max_sec = config["data"]["max_sec"] 334 | t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) 335 | t2s_model.load_state_dict(dict_s1["weight"]) 336 | if is_half == True: 337 | t2s_model = t2s_model.half() 338 | t2s_model = t2s_model.to(device) 339 | t2s_model.eval() 340 | total = sum([param.nelement() for param in t2s_model.parameters()]) 341 | print("Number of parameter: %.2fM" % (total / 1e6)) 342 | with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path) 343 | 344 | 345 | change_gpt_weights(gpt_path) 346 | 347 | 348 | def get_spepc(hps, filename): 349 | audio = load_audio(filename, int(hps.data.sampling_rate)) 350 | audio = torch.FloatTensor(audio) 351 | audio_norm = audio 352 | audio_norm = audio_norm.unsqueeze(0) 353 | spec = spectrogram_torch( 354 | audio_norm, 355 | hps.data.filter_length, 356 | hps.data.sampling_rate, 357 | hps.data.hop_length, 358 | hps.data.win_length, 359 | center=False, 360 | ) 361 | return spec 362 | 363 | 364 | 365 | 366 | def splite_en_inf(sentence, language): 367 | pattern = re.compile(r'[a-zA-Z ]+') 368 | textlist = [] 369 | langlist = [] 370 | pos = 0 371 | for match in pattern.finditer(sentence): 372 | start, end = match.span() 373 | if start > pos: 374 | textlist.append(sentence[pos:start]) 375 | langlist.append(language) 376 | textlist.append(sentence[start:end]) 377 | langlist.append("en") 378 | pos = end 379 | if pos < len(sentence): 380 | textlist.append(sentence[pos:]) 381 | langlist.append(language) 382 | # Merge punctuation into previous word 383 | for i in range(len(textlist)-1, 0, -1): 384 | if re.match(r'^[\W_]+$', textlist[i]): 385 | textlist[i-1] += textlist[i] 386 | del textlist[i] 387 | del langlist[i] 388 | # Merge consecutive words with the same language tag 389 | i = 0 390 | while i < len(langlist) - 1: 391 | if langlist[i] == langlist[i+1]: 392 | textlist[i] += textlist[i+1] 393 | del textlist[i+1] 394 | del langlist[i+1] 395 | else: 396 | i += 1 397 | 398 | return textlist, langlist 399 | 400 | 401 | def clean_text_inf(text, language): 402 | formattext = "" 403 | language = language.replace("all_","") 404 | for tmp in LangSegment.getTexts(text): 405 | if language == "ja": 406 | if tmp["lang"] == language or tmp["lang"] == "zh": 407 | formattext += tmp["text"] + " " 408 | continue 409 | if tmp["lang"] == language: 410 | formattext += tmp["text"] + " " 411 | while " " in formattext: 412 | formattext = formattext.replace(" ", " ") 413 | phones, word2ph, norm_text = clean_text(formattext, language) 414 | phones = cleaned_text_to_sequence(phones) 415 | return phones, word2ph, norm_text 416 | 417 | dtype=torch.float16 if is_half == True else torch.float32 418 | def get_bert_inf(phones, word2ph, norm_text, language): 419 | language=language.replace("all_","") 420 | if language == "zh": 421 | bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype) 422 | else: 423 | bert = torch.zeros( 424 | (1024, len(phones)), 425 | dtype=torch.float16 if is_half == True else torch.float32, 426 | ).to(device) 427 | 428 | return bert 429 | 430 | 431 | def nonen_clean_text_inf(text, language): 432 | if(language!="auto"): 433 | textlist, langlist = splite_en_inf(text, language) 434 | else: 435 | textlist=[] 436 | langlist=[] 437 | for tmp in LangSegment.getTexts(text): 438 | langlist.append(tmp["lang"]) 439 | textlist.append(tmp["text"]) 440 | phones_list = [] 441 | word2ph_list = [] 442 | norm_text_list = [] 443 | for i in range(len(textlist)): 444 | lang = langlist[i] 445 | phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) 446 | phones_list.append(phones) 447 | if lang == "zh": 448 | word2ph_list.append(word2ph) 449 | norm_text_list.append(norm_text) 450 | print(word2ph_list) 451 | phones = sum(phones_list, []) 452 | word2ph = sum(word2ph_list, []) 453 | norm_text = ' '.join(norm_text_list) 454 | 455 | return phones, word2ph, norm_text 456 | 457 | 458 | def nonen_get_bert_inf(text, language): 459 | if(language!="auto"): 460 | textlist, langlist = splite_en_inf(text, language) 461 | else: 462 | textlist=[] 463 | langlist=[] 464 | for tmp in LangSegment.getTexts(text): 465 | langlist.append(tmp["lang"]) 466 | textlist.append(tmp["text"]) 467 | print(textlist) 468 | print(langlist) 469 | bert_list = [] 470 | for i in range(len(textlist)): 471 | lang = langlist[i] 472 | phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) 473 | bert = get_bert_inf(phones, word2ph, norm_text, lang) 474 | bert_list.append(bert) 475 | bert = torch.cat(bert_list, dim=1) 476 | 477 | return bert 478 | 479 | 480 | 481 | 482 | def get_first(text): 483 | pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]" 484 | text = re.split(pattern, text)[0].strip() 485 | return text 486 | 487 | 488 | def get_cleaned_text_final(text,language): 489 | if language in {"en","all_zh","all_ja"}: 490 | phones, word2ph, norm_text = clean_text_inf(text, language) 491 | elif language in {"zh", "ja","auto"}: 492 | phones, word2ph, norm_text = nonen_clean_text_inf(text, language) 493 | return phones, word2ph, norm_text 494 | 495 | def get_bert_final(phones, word2ph, text,language,device): 496 | if language == "en": 497 | bert = get_bert_inf(phones, word2ph, text, language) 498 | elif language in {"zh", "ja","auto"}: 499 | bert = nonen_get_bert_inf(text, language) 500 | elif language == "all_zh": 501 | bert = get_bert_feature(text, word2ph).to(device) 502 | else: 503 | bert = torch.zeros((1024, len(phones))).to(device) 504 | return bert 505 | 506 | def merge_short_text_in_array(texts, threshold): 507 | if (len(texts)) < 2: 508 | return texts 509 | result = [] 510 | text = "" 511 | for ele in texts: 512 | text += ele 513 | if len(text) >= threshold: 514 | result.append(text) 515 | text = "" 516 | if (len(text) > 0): 517 | if len(result) == 0: 518 | result.append(text) 519 | else: 520 | result[len(result) - 1] += text 521 | return result 522 | 523 | def get_tts_wav(*,refer_wav_path, prompt_text, prompt_language="zh", text="", text_language="zh", top_k=5, top_p=1, temperature=1, ref_free = False): 524 | text+='.' 525 | print(f'{refer_wav_path=},{prompt_text=},{prompt_language=},{text=},{text_language=}') 526 | if prompt_text is None or len(prompt_text) == 0: 527 | ref_free = True 528 | t0 = ttime() 529 | #prompt_language = dict_language[prompt_language] 530 | #text_language = dict_language[text_language] 531 | if not ref_free: 532 | prompt_text = prompt_text.strip("\n") 533 | if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "." 534 | print(("实际输入的参考文本:"), prompt_text) 535 | text = text.strip("\n") 536 | for t in splits: 537 | text= text if not re.search(fr'\{t}',text) else re.sub(fr'\{t}+',t,text) 538 | 539 | if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text 540 | if len(text)<4: 541 | raise Exception('有效文字数太少,至少输入4个字符') 542 | print(("实际输入的目标文本:"), text) 543 | zero_wav = np.zeros( 544 | int(hps.data.sampling_rate * 0.3), 545 | dtype=np.float16 if is_half == True else np.float32, 546 | ) 547 | with torch.no_grad(): 548 | wav16k, sr = librosa.load(refer_wav_path, sr=16000) 549 | if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000): 550 | raise OSError(("参考音频在3~10秒范围外,请更换!")) 551 | wav16k = torch.from_numpy(wav16k) 552 | zero_wav_torch = torch.from_numpy(zero_wav) 553 | if is_half == True: 554 | wav16k = wav16k.half().to(device) 555 | zero_wav_torch = zero_wav_torch.half().to(device) 556 | else: 557 | wav16k = wav16k.to(device) 558 | zero_wav_torch = zero_wav_torch.to(device) 559 | wav16k = torch.cat([wav16k, zero_wav_torch]) 560 | ssl_content = ssl_model.model(wav16k.unsqueeze(0))[ 561 | "last_hidden_state" 562 | ].transpose( 563 | 1, 2 564 | ) # .float() 565 | codes = vq_model.extract_latent(ssl_content) 566 | 567 | prompt_semantic = codes[0, 0] 568 | t1 = ttime() 569 | print(f'{args.cut=}') 570 | if (args.cut == 1): 571 | text = cut1(text) 572 | elif (args.cut == 2): 573 | text = cut2(text) 574 | elif (args.cut == 3): 575 | text = cut3(text) 576 | elif (args.cut == 4): 577 | text = cut4(text) 578 | elif (args.cut == 5): 579 | text = cut5(text) 580 | 581 | while "\n\n" in text: 582 | text = text.replace("\n\n", "\n") 583 | print(("实际输入的目标文本(切句后):"), text) 584 | texts = text.split("\n") 585 | texts = merge_short_text_in_array(texts, 5) 586 | audio_opt = [] 587 | if not ref_free: 588 | phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language) 589 | bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype) 590 | 591 | for text in texts: 592 | # 解决输入目标文本的空行导致报错的问题 593 | if (len(text.strip()) == 0): 594 | continue 595 | if (text[-1] not in splits): text += "。" if text_language != "en" else "." 596 | print(("实际输入的目标文本(每句):"), text) 597 | phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language) 598 | bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype) 599 | if not ref_free: 600 | bert = torch.cat([bert1, bert2], 1) 601 | all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0) 602 | else: 603 | bert = bert2 604 | all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0) 605 | 606 | bert = bert.to(device).unsqueeze(0) 607 | all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) 608 | prompt = prompt_semantic.unsqueeze(0).to(device) 609 | t2 = ttime() 610 | with torch.no_grad(): 611 | # pred_semantic = t2s_model.model.infer( 612 | pred_semantic, idx = t2s_model.model.infer_panel( 613 | all_phoneme_ids, 614 | all_phoneme_len, 615 | None if ref_free else prompt, 616 | bert, 617 | # prompt_phone_len=ph_offset, 618 | top_k=top_k, 619 | top_p=top_p, 620 | temperature=temperature, 621 | early_stop_num=hz * max_sec, 622 | ) 623 | t3 = ttime() 624 | # print(pred_semantic.shape,idx) 625 | pred_semantic = pred_semantic[:, -idx:].unsqueeze( 626 | 0 627 | ) # .unsqueeze(0)#mq要多unsqueeze一次 628 | refer = get_spepc(hps, refer_wav_path) # .to(device) 629 | if is_half == True: 630 | refer = refer.half().to(device) 631 | else: 632 | refer = refer.to(device) 633 | # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] 634 | audio = ( 635 | vq_model.decode( 636 | pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer 637 | ) 638 | .detach() 639 | .cpu() 640 | .numpy()[0, 0] 641 | ) ###试试重建不带上prompt部分 642 | max_audio=np.abs(audio).max()#简单防止16bit爆音 643 | if max_audio>1:audio/=max_audio 644 | audio_opt.append(audio) 645 | audio_opt.append(zero_wav) 646 | t4 = ttime() 647 | print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) 648 | yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype( 649 | np.int16 650 | ) 651 | 652 | 653 | def split(todo_text): 654 | todo_text = todo_text.replace("……", "。").replace("——", ",") 655 | if todo_text[-1] not in splits: 656 | todo_text += "。" 657 | i_split_head = i_split_tail = 0 658 | len_text = len(todo_text) 659 | todo_texts = [] 660 | while 1: 661 | if i_split_head >= len_text: 662 | break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入 663 | if todo_text[i_split_head] in splits: 664 | i_split_head += 1 665 | todo_texts.append(todo_text[i_split_tail:i_split_head]) 666 | i_split_tail = i_split_head 667 | else: 668 | i_split_head += 1 669 | return todo_texts 670 | 671 | 672 | def cut1(inp): 673 | inp = inp.strip("\n") 674 | inps = split(inp) 675 | split_idx = list(range(0, len(inps), 4)) 676 | split_idx[-1] = None 677 | if len(split_idx) > 1: 678 | opts = [] 679 | for idx in range(len(split_idx) - 1): 680 | opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]])) 681 | else: 682 | opts = [inp] 683 | return "\n".join(opts) 684 | 685 | 686 | def cut2(inp): 687 | inp = inp.strip("\n") 688 | inps = split(inp) 689 | if len(inps) < 2: 690 | return inp 691 | opts = [] 692 | summ = 0 693 | tmp_str = "" 694 | for i in range(len(inps)): 695 | summ += len(inps[i]) 696 | tmp_str += inps[i] 697 | if summ > 50: 698 | summ = 0 699 | opts.append(tmp_str) 700 | tmp_str = "" 701 | if tmp_str != "": 702 | opts.append(tmp_str) 703 | # print(opts) 704 | if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起 705 | opts[-2] = opts[-2] + opts[-1] 706 | opts = opts[:-1] 707 | return "\n".join(opts) 708 | 709 | 710 | def cut3(inp): 711 | inp = inp.strip("\n") 712 | return "\n".join(["%s" % item for item in inp.strip("。").split("。")]) 713 | 714 | 715 | def cut4(inp): 716 | inp = inp.strip("\n") 717 | return "\n".join(["%s" % item for item in inp.strip(".").split(".")]) 718 | 719 | 720 | # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py 721 | def cut5(inp): 722 | # if not re.search(r'[^\w\s]', inp[-1]): 723 | # inp += '。' 724 | inp = inp.strip("\n") 725 | punds = r'[,.;?!、,。?!;:]' 726 | items = re.split(f'({punds})', inp) 727 | items = ["".join(group) for group in zip(items[::2], items[1::2])] 728 | opt = "\n".join(items) 729 | return opt 730 | 731 | 732 | def custom_sort_key(s): 733 | # 使用正则表达式提取字符串中的数字部分和非数字部分 734 | parts = re.split('(\d+)', s) 735 | # 将数字部分转换为整数,非数字部分保持不变 736 | parts = [int(part) if part.isdigit() else part for part in parts] 737 | return parts 738 | 739 | 740 | def change_choices(): 741 | SoVITS_names, GPT_names = get_weights_names() 742 | return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"} 743 | 744 | 745 | 746 | 747 | 748 | def get_weights_names(): 749 | SoVITS_names = [sovits_path] 750 | for name in os.listdir(SoVITS_weight_root): 751 | if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (SoVITS_weight_root, name)) 752 | GPT_names = [gpt_path] 753 | for name in os.listdir(GPT_weight_root): 754 | if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (GPT_weight_root, name)) 755 | return SoVITS_names, GPT_names 756 | 757 | 758 | SoVITS_names, GPT_names = get_weights_names() 759 | 760 | 761 | app = FastAPI() 762 | 763 | def handle(refer_wav_path, prompt_text, prompt_language, text, text_language): 764 | if ( 765 | refer_wav_path == "" or refer_wav_path is None 766 | or prompt_text == "" or prompt_text is None 767 | or prompt_language == "" or prompt_language is None 768 | ): 769 | refer_wav_path,prompt_text,prompt_language=default_wav,default_text,default_language 770 | if not refer_wav_path or not prompt_text or not prompt_language: 771 | return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) 772 | 773 | with torch.no_grad(): 774 | gen = get_tts_wav( 775 | refer_wav_path=refer_wav_path, prompt_text=prompt_text, prompt_language=prompt_language, text=text, text_language=text_language 776 | ) 777 | sampling_rate, audio_data = next(gen) 778 | 779 | wav = BytesIO() 780 | sf.write(wav, audio_data, sampling_rate, format="wav") 781 | wav.seek(0) 782 | 783 | torch.cuda.empty_cache() 784 | if device == "mps": 785 | print('executed torch.mps.empty_cache()') 786 | torch.mps.empty_cache() 787 | return StreamingResponse(wav, media_type="audio/wav") 788 | 789 | @app.post("/") 790 | async def tts_endpoint(request: Request): 791 | json_post_raw = await request.json() 792 | return handle( 793 | json_post_raw.get("refer_wav_path"), 794 | json_post_raw.get("prompt_text"), 795 | json_post_raw.get("prompt_language"), 796 | json_post_raw.get("text"), 797 | json_post_raw.get("text_language"), 798 | ) 799 | 800 | 801 | @app.get("/") 802 | async def tts_endpoint( 803 | refer_wav_path: str = None, 804 | prompt_text: str = None, 805 | prompt_language: str = None, 806 | text: str = None, 807 | text_language: str = None, 808 | ): 809 | return handle(refer_wav_path, prompt_text, prompt_language, text, text_language) 810 | 811 | 812 | 813 | if __name__ == "__main__": 814 | uvicorn.run(app, host=host, port=port, workers=1) 815 | 816 | --------------------------------------------------------------------------------