├── .env
├── .idea
│ ├── .gitignore
│ ├── inspectionProfiles
│ │ └── profiles_settings.xml
│ ├── misc.xml
│ ├── modules.xml
│ ├── robot_by_qkr_v1.iml
│ └── workspace.xml
├── README.md
├── VITS
│ ├── VC_inference.py
│ ├── __pycache__
│ │ ├── VC_inference.cpython-310.pyc
│ │ ├── attentions.cpython-310.pyc
│ │ ├── commons.cpython-310.pyc
│ │ ├── mel_processing.cpython-310.pyc
│ │ ├── models.cpython-310.pyc
│ │ ├── modules.cpython-310.pyc
│ │ ├── transforms.cpython-310.pyc
│ │ └── utils.cpython-310.pyc
│ ├── attentions.py
│ ├── commons.py
│ ├── configs
│ │ ├── modified_finetune_speaker.json
│ │ └── uma_trilingual.json
│ ├── mel_processing.py
│ ├── models.py
│ ├── models_infer.py
│ ├── modules.py
│ ├── monotonic_align
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ └── __init__.cpython-310.pyc
│ │ ├── build
│ │ │ ├── lib.win-amd64-3.10
│ │ │ │ └── monotonic_align
│ │ │ │ │ └── core.cp310-win_amd64.pyd
│ │ │ └── temp.win-amd64-3.10
│ │ │ │ └── Release
│ │ │ │ │ ├── core.cp310-win_amd64.exp
│ │ │ │ │ ├── core.cp310-win_amd64.lib
│ │ │ │ │ └── core.obj
│ │ ├── core.c
│ │ ├── core.pyx
│ │ ├── monotonic_align
│ │ │ └── core.cp310-win_amd64.pyd
│ │ └── setup.py
│ ├── text
│ │ ├── LICENSE
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-310.pyc
│ │ │ ├── cleaners.cpython-310.pyc
│ │ │ ├── english.cpython-310.pyc
│ │ │ ├── japanese.cpython-310.pyc
│ │ │ ├── korean.cpython-310.pyc
│ │ │ ├── mandarin.cpython-310.pyc
│ │ │ ├── sanskrit.cpython-310.pyc
│ │ │ ├── symbols.cpython-310.pyc
│ │ │ └── thai.cpython-310.pyc
│ │ ├── cantonese.py
│ │ ├── cleaners.py
│ │ ├── english.py
│ │ ├── japanese.py
│ │ ├── korean.py
│ │ ├── mandarin.py
│ │ ├── ngu_dialect.py
│ │ ├── sanskrit.py
│ │ ├── shanghainese.py
│ │ ├── symbols.py
│ │ └── thai.py
│ ├── transforms.py
│ └── utils.py
├── baiduApi.py
├── gptApi.py
├── gptApi_duolun.py
├── input
│ └── input.wav
├── output
│ └── output.wav
├── requirements.txt
└── robot.py
/.env:
--------------------------------------------------------------------------------
1 | api_id=****
2 | api_key=*****
3 | api_secert=*****
4 | openai_api=*****
--------------------------------------------------------------------------------
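The values above are redacted. As a minimal sketch of how these keys could be read at runtime (an assumption; the actual loading code lives in baiduApi.py / gptApi.py, which are not included in this dump), using python-dotenv and keeping the key names exactly as spelled in .env:

```python
# Sketch only: assumes python-dotenv is installed; the real loading code
# (in baiduApi.py / gptApi.py) is not shown in this dump.
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory

api_id = os.getenv("api_id")          # Baidu Cloud app id
api_key = os.getenv("api_key")        # Baidu Cloud API key
api_secert = os.getenv("api_secert")  # key name kept as spelled in .env
openai_api = os.getenv("openai_api")  # OpenAI API key
```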
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/robot_by_qkr_v1.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### About the project
2 | This is a voice chat bot written in Python: you speak to it, it sends your words to ChatGPT, and it answers back with synthesized speech.
3 |
4 | ### How to use
5 | * Install the dependencies: `pip install -r requirements.txt`
6 | * Download a fine-tuned VITS model and its config file and place them in the VITS directory. I provide my own Amiya model here, for learning and exchange only.
7 | * https://pan.baidu.com/s/1OIqAHzItdxgYkAma942VvQ?pwd=fuou
8 | * Apply for a Baidu Cloud API key and an OpenAI API key and fill them into the .env file
9 | * The entry point is robot.py; enjoy the conversation
10 |
11 | PS: the reverse-proxy server has expired QAQ. If you want to use this, go through your own proxy instead and change the URL in the API files.
12 |
--------------------------------------------------------------------------------
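To make the README's flow concrete, here is a hedged sketch of the conversation loop that robot.py presumably implements. `speech_to_text` and `ask_chatgpt` are hypothetical stand-ins for whatever baiduApi.py and gptApi.py actually expose; only `getVoice` and the input/output wav paths are confirmed by the code and tree above.

```python
# Hedged sketch of the voice-chat loop described in the README; not the actual robot.py.
# speech_to_text() and ask_chatgpt() are placeholder stubs for the helpers in
# baiduApi.py and gptApi.py, which are not shown in this dump.
from playsound import playsound
from VITS.VC_inference import getVoice

def speech_to_text(wav_path: str) -> str:
    """Stand-in for the Baidu speech-recognition call (baiduApi.py)."""
    return "你好"

def ask_chatgpt(prompt: str) -> str:
    """Stand-in for the OpenAI chat call (gptApi.py)."""
    return "博士,很高兴见到你。"

if __name__ == "__main__":
    question = speech_to_text("input/input.wav")   # 1. recognize the recorded question
    answer = ask_chatgpt(question)                 # 2. ask ChatGPT
    getVoice(answer, "output/output.wav")          # 3. synthesize the reply with VITS
    playsound("output/output.wav")                 # 4. play it back
```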
/VITS/VC_inference.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import no_grad, LongTensor
4 | import argparse
5 | from VITS import commons  # project-local module
6 | #from mel_processing import spectrogram_torch  # project-local module
7 | from VITS import utils  # project-local module
8 | from VITS.models import SynthesizerTrn  # project-local module
9 | import librosa
10 | from playsound import playsound
11 | import logging
12 |
13 |
14 | from VITS.text import text_to_sequence, _clean_text
15 |
16 | from scipy.io.wavfile import write
17 |
18 | def to_wav(audio,path):
19 | # Convert the ndarray to int16 (must be 16-bit signed integers)
20 | # data = audio * 32767
21 | # data = data.astype(np.int16)
22 | data=audio
23 | # Set the sampling rate (note: the bundled configs use 22050)
24 | sample_rate = 24100
25 |
26 | # Write the data to a .wav file
27 | write(path, sample_rate, data)
28 |
29 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
30 | language_marks = {
31 | "Japanese": "",
32 | "日本語": "[JA]",
33 | "简体中文": "[ZH]",
34 | "English": "[EN]",
35 | "Mix": "",
36 | }
37 | lang = ['日本語', '简体中文', 'English', 'Mix']
38 | def get_text(text, hps, is_symbol):
39 | text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
40 | if hps.data.add_blank:
41 | text_norm = commons.intersperse(text_norm, 0)
42 | text_norm = LongTensor(text_norm)
43 | return text_norm
44 |
45 | def create_tts_fn(model, hps, speaker_ids):
46 | def tts_fn(text, speaker, language, speed,path):
47 | if language is not None:
48 | text = language_marks[language] + text + language_marks[language]  # wrap the text with its language marker
49 | speaker_id = speaker_ids[speaker]  # look up the id of the selected speaker
50 | stn_tst = get_text(text, hps, False)  # convert the text to a symbol-id tensor using the model config (hps)
51 | with no_grad():
52 | x_tst = stn_tst.unsqueeze(0).to(device)
53 | x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
54 | sid = LongTensor([speaker_id]).to(device)
55 | #src="http://127.0.0.1:7860/file=C:\Users\15093289086\AppData\Local\Temp\tmpjz4wxooq.wav"
56 | audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
57 | length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()#
58 | ##print(type(audio))
59 | to_wav(audio,path)
60 |
61 | del stn_tst, x_tst, x_tst_lengths, sid
62 | ##print("Success", (hps.data.sampling_rate, audio))
63 | return "Success", (hps.data.sampling_rate, audio)
64 |
65 | return tts_fn
66 |
67 | # def create_vc_fn(model, hps, speaker_ids):
68 | # def vc_fn(original_speaker, target_speaker, record_audio, upload_audio):
69 | # input_audio = record_audio if record_audio is not None else upload_audio
70 | # if input_audio is None:
71 | # return "You need to record or upload an audio", None
72 | # sampling_rate, audio = input_audio
73 | # original_speaker_id = speaker_ids[original_speaker]
74 | # target_speaker_id = speaker_ids[target_speaker]
75 | #
76 | # audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
77 | # if len(audio.shape) > 1:
78 | # audio = librosa.to_mono(audio.transpose(1, 0))
79 | # if sampling_rate != hps.data.sampling_rate:
80 | # audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=hps.data.sampling_rate)
81 | # with no_grad():
82 | # y = torch.FloatTensor(audio)
83 | # y = y / max(-y.min(), y.max()) / 0.99
84 | # y = y.to(device)
85 | # y = y.unsqueeze(0)
86 | # spec = spectrogram_torch(y, hps.data.filter_length,
87 | # hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
88 | # center=False).to(device)
89 | # spec_lengths = LongTensor([spec.size(-1)]).to(device)
90 | # sid_src = LongTensor([original_speaker_id]).to(device)
91 | # sid_tgt = LongTensor([target_speaker_id]).to(device)
92 | # audio = model.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt)[0][
93 | # 0, 0].data.cpu().float().numpy()
94 | # del y, spec, spec_lengths, sid_src, sid_tgt
95 | # return "Success", (hps.data.sampling_rate, audio)
96 | #
97 | # return vc_fn
98 |
99 |
100 |
101 | def getVoice(text,path):
102 | logging.getLogger('jieba').disabled = True  # disable jieba's log output
103 | parser = argparse.ArgumentParser()
104 | parser.add_argument("--model_dir", default="VITS/G_latest.pth", help="directory to your fine-tuned model")
105 | parser.add_argument("--config_dir", default="VITS/finetune_speaker.json", help="directory to your model config file")
106 | parser.add_argument("--share", default=False, help="make link public (used in colab)")
107 |
108 | args = parser.parse_args()
109 | hps = utils.get_hparams_from_file(args.config_dir)
110 |
111 | net_g = SynthesizerTrn(
112 | len(hps.symbols),
113 | hps.data.filter_length // 2 + 1,
114 | hps.train.segment_size // hps.data.hop_length,
115 | n_speakers=hps.data.n_speakers,
116 | **hps.model).to(device)
117 | _ = net_g.eval()
118 |
119 | _ = utils.load_checkpoint(args.model_dir, net_g, None)
120 | speaker_ids = hps.speakers
121 | speakers = list(hps.speakers.keys())
122 | tts_fn = create_tts_fn(net_g, hps, speaker_ids)
123 | # vc_fn = create_vc_fn(net_g, hps, speaker_ids)
124 |
125 | # Run locally, without the web UI
126 | textbox = text
127 | char_dropdown = speakers[0]
128 | language_dropdown = lang[1]
129 | duration_slider = 0.8
130 |
131 | tts_fn(textbox, char_dropdown, language_dropdown, duration_slider,path)
132 |
133 | #print('okk...')
134 | # playsound('../output/output.wav')
135 |
136 | if __name__ == "__main__":
137 |
138 | #logging.getLogger().setLevel(logging.ERROR)
139 | logging.getLogger('jieba').disabled = True
140 | parser = argparse.ArgumentParser()
141 | parser.add_argument("--model_dir", default="./G_latest.pth", help="directory to your fine-tuned model")
142 | parser.add_argument("--config_dir", default="./finetune_speaker.json", help="directory to your model config file")
143 | parser.add_argument("--share", default=False, help="make link public (used in colab)")
144 |
145 | args = parser.parse_args()
146 | hps = utils.get_hparams_from_file(args.config_dir)
147 |
148 |
149 | net_g = SynthesizerTrn(
150 | len(hps.symbols),
151 | hps.data.filter_length // 2 + 1,
152 | hps.train.segment_size // hps.data.hop_length,
153 | n_speakers=hps.data.n_speakers,
154 | **hps.model).to(device)
155 | _ = net_g.eval()
156 |
157 | _ = utils.load_checkpoint(args.model_dir, net_g, None)
158 | speaker_ids = hps.speakers
159 | speakers = list(hps.speakers.keys())
160 | tts_fn = create_tts_fn(net_g, hps, speaker_ids)
161 | #vc_fn = create_vc_fn(net_g, hps, speaker_ids)
162 |
163 | # Run locally, without the web UI
164 | textbox = "你好,博士。"
165 | char_dropdown = speakers[0]
166 | language_dropdown = lang[1]
167 | duration_slider = 0.8
168 |
169 | tts_fn(textbox, char_dropdown, language_dropdown, duration_slider, './output.wav')
170 |
171 | print('okk...')
172 | playsound('../output/output.wav')
173 |
174 |
175 | # app = gr.Blocks()
176 | # with app:
177 | # with gr.Tab("Text-to-Speech"):
178 | # with gr.Row():
179 | # with gr.Column():
180 | # textbox = gr.TextArea(label="Text",
181 | # placeholder="Type your sentence here",
182 | # value="博士,很高兴见到你,欢迎回来。", elem_id=f"tts-input")
183 | # # select character
184 | # char_dropdown = gr.Dropdown(choices=speakers, value=speakers[0], label='character')
185 | # #print(char_dropdown)
186 | # language_dropdown = gr.Dropdown(choices=lang, value=lang[1], label='language')
187 | # #print(language_dropdown)
188 | # duration_slider = gr.Slider(minimum=0.1, maximum=5, value=1, step=0.1,
189 | # label='速度 Speed')
190 | # with gr.Column():
191 | # text_output = gr.Textbox(label="Message")
192 | # audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
193 | # btn = gr.Button("Generate!")
194 | # #res=tts_fn(textbox, char_dropdown, language_dropdown, duration_slider)
195 | # ##print(res)
196 | # btn.click(tts_fn,  # the synthesis callback
197 | # inputs=[textbox, char_dropdown, language_dropdown, duration_slider,],
198 | # outputs=[text_output, audio_output])  # binds the function to the button
199 | #
200 | # # equivalent to calling tts_fn(textbox, char_dropdown, language_dropdown, duration_slider)
201 | #
202 | # with gr.Tab("Voice Conversion"):  # this tab converts one voice into another
203 | # gr.Markdown("""
204 | # Record or upload audio and choose the voice to convert to. The "User" voice is your own.
205 | # """)
206 | # with gr.Column():
207 | # record_audio = gr.Audio(label="record your voice", source="microphone")
208 | # upload_audio = gr.Audio(label="or upload audio here", source="upload")
209 | # source_speaker = gr.Dropdown(choices=speakers, value="User", label="source speaker")
210 | # target_speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="target speaker")
211 | # with gr.Column():
212 | # message_box = gr.Textbox(label="Message")
213 | # converted_audio = gr.Audio(label='converted audio')
214 | # btn = gr.Button("Convert!")
215 | # btn.click(vc_fn, inputs=[source_speaker, target_speaker, record_audio, upload_audio],
216 | # outputs=[message_box, converted_audio])
217 | # webbrowser.open("http://127.0.0.1:7860")
218 | # app.launch(share=args.share)
219 | #
220 |
--------------------------------------------------------------------------------
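Two observations on the file above: `getVoice(text, path)` is the function the rest of the project calls, and `to_wav` currently writes float32 samples at a hardcoded 24100 Hz even though the bundled configs use 22050 Hz. A hedged variant that finishes the commented-out int16 conversion and takes the rate from the loaded config might look like this (a sketch, not the author's code):

```python
import numpy as np
from scipy.io.wavfile import write

def to_wav_int16(audio: np.ndarray, path: str, sample_rate: int = 22050) -> None:
    # Scale the float waveform (roughly in [-1, 1]) to 16-bit signed integers,
    # as the commented-out lines in to_wav() intended, then write the file.
    data = np.clip(audio, -1.0, 1.0)
    data = (data * 32767).astype(np.int16)
    write(path, sample_rate, data)

# Usage: the rate should come from the model config rather than a constant,
# e.g. to_wav_int16(audio, "output/output.wav", hps.data.sampling_rate)
```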
/VITS/__pycache__/VC_inference.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/VC_inference.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/__pycache__/attentions.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/attentions.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/__pycache__/commons.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/commons.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/__pycache__/mel_processing.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/mel_processing.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/__pycache__/models.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/models.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/__pycache__/modules.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/modules.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/__pycache__/transforms.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/transforms.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/attentions.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 | from VITS import commons
9 | #from VITS import modules
10 | from VITS.modules import LayerNorm
11 |
12 |
13 | class Encoder(nn.Module):
14 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
15 | super().__init__()
16 | self.hidden_channels = hidden_channels
17 | self.filter_channels = filter_channels
18 | self.n_heads = n_heads
19 | self.n_layers = n_layers
20 | self.kernel_size = kernel_size
21 | self.p_dropout = p_dropout
22 | self.window_size = window_size
23 |
24 | self.drop = nn.Dropout(p_dropout)
25 | self.attn_layers = nn.ModuleList()
26 | self.norm_layers_1 = nn.ModuleList()
27 | self.ffn_layers = nn.ModuleList()
28 | self.norm_layers_2 = nn.ModuleList()
29 | for i in range(self.n_layers):
30 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
31 | self.norm_layers_1.append(LayerNorm(hidden_channels))
32 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
33 | self.norm_layers_2.append(LayerNorm(hidden_channels))
34 |
35 | def forward(self, x, x_mask):
36 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
37 | x = x * x_mask
38 | for i in range(self.n_layers):
39 | y = self.attn_layers[i](x, x, attn_mask)
40 | y = self.drop(y)
41 | x = self.norm_layers_1[i](x + y)
42 |
43 | y = self.ffn_layers[i](x, x_mask)
44 | y = self.drop(y)
45 | x = self.norm_layers_2[i](x + y)
46 | x = x * x_mask
47 | return x
48 |
49 |
50 | class Decoder(nn.Module):
51 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
52 | super().__init__()
53 | self.hidden_channels = hidden_channels
54 | self.filter_channels = filter_channels
55 | self.n_heads = n_heads
56 | self.n_layers = n_layers
57 | self.kernel_size = kernel_size
58 | self.p_dropout = p_dropout
59 | self.proximal_bias = proximal_bias
60 | self.proximal_init = proximal_init
61 |
62 | self.drop = nn.Dropout(p_dropout)
63 | self.self_attn_layers = nn.ModuleList()
64 | self.norm_layers_0 = nn.ModuleList()
65 | self.encdec_attn_layers = nn.ModuleList()
66 | self.norm_layers_1 = nn.ModuleList()
67 | self.ffn_layers = nn.ModuleList()
68 | self.norm_layers_2 = nn.ModuleList()
69 | for i in range(self.n_layers):
70 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
71 | self.norm_layers_0.append(LayerNorm(hidden_channels))
72 | self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
73 | self.norm_layers_1.append(LayerNorm(hidden_channels))
74 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
75 | self.norm_layers_2.append(LayerNorm(hidden_channels))
76 |
77 | def forward(self, x, x_mask, h, h_mask):
78 | """
79 | x: decoder input
80 | h: encoder output
81 | """
82 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
83 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
84 | x = x * x_mask
85 | for i in range(self.n_layers):
86 | y = self.self_attn_layers[i](x, x, self_attn_mask)
87 | y = self.drop(y)
88 | x = self.norm_layers_0[i](x + y)
89 |
90 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
91 | y = self.drop(y)
92 | x = self.norm_layers_1[i](x + y)
93 |
94 | y = self.ffn_layers[i](x, x_mask)
95 | y = self.drop(y)
96 | x = self.norm_layers_2[i](x + y)
97 | x = x * x_mask
98 | return x
99 |
100 |
101 | class MultiHeadAttention(nn.Module):
102 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
103 | super().__init__()
104 | assert channels % n_heads == 0
105 |
106 | self.channels = channels
107 | self.out_channels = out_channels
108 | self.n_heads = n_heads
109 | self.p_dropout = p_dropout
110 | self.window_size = window_size
111 | self.heads_share = heads_share
112 | self.block_length = block_length
113 | self.proximal_bias = proximal_bias
114 | self.proximal_init = proximal_init
115 | self.attn = None
116 |
117 | self.k_channels = channels // n_heads
118 | self.conv_q = nn.Conv1d(channels, channels, 1)
119 | self.conv_k = nn.Conv1d(channels, channels, 1)
120 | self.conv_v = nn.Conv1d(channels, channels, 1)
121 | self.conv_o = nn.Conv1d(channels, out_channels, 1)
122 | self.drop = nn.Dropout(p_dropout)
123 |
124 | if window_size is not None:
125 | n_heads_rel = 1 if heads_share else n_heads
126 | rel_stddev = self.k_channels**-0.5
127 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
128 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
129 |
130 | nn.init.xavier_uniform_(self.conv_q.weight)
131 | nn.init.xavier_uniform_(self.conv_k.weight)
132 | nn.init.xavier_uniform_(self.conv_v.weight)
133 | if proximal_init:
134 | with torch.no_grad():
135 | self.conv_k.weight.copy_(self.conv_q.weight)
136 | self.conv_k.bias.copy_(self.conv_q.bias)
137 |
138 | def forward(self, x, c, attn_mask=None):
139 | q = self.conv_q(x)
140 | k = self.conv_k(c)
141 | v = self.conv_v(c)
142 |
143 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
144 |
145 | x = self.conv_o(x)
146 | return x
147 |
148 | def attention(self, query, key, value, mask=None):
149 | # reshape [b, d, t] -> [b, n_h, t, d_k]
150 | b, d, t_s, t_t = (*key.size(), query.size(2))
151 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
152 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
153 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
154 |
155 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
156 | if self.window_size is not None:
157 | assert t_s == t_t, "Relative attention is only available for self-attention."
158 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
159 | rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings)
160 | scores_local = self._relative_position_to_absolute_position(rel_logits)
161 | scores = scores + scores_local
162 | if self.proximal_bias:
163 | assert t_s == t_t, "Proximal bias is only available for self-attention."
164 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
165 | if mask is not None:
166 | scores = scores.masked_fill(mask == 0, -1e4)
167 | if self.block_length is not None:
168 | assert t_s == t_t, "Local attention is only available for self-attention."
169 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
170 | scores = scores.masked_fill(block_mask == 0, -1e4)
171 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
172 | p_attn = self.drop(p_attn)
173 | output = torch.matmul(p_attn, value)
174 | if self.window_size is not None:
175 | relative_weights = self._absolute_position_to_relative_position(p_attn)
176 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
177 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
178 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
179 | return output, p_attn
180 |
181 | def _matmul_with_relative_values(self, x, y):
182 | """
183 | x: [b, h, l, m]
184 | y: [h or 1, m, d]
185 | ret: [b, h, l, d]
186 | """
187 | ret = torch.matmul(x, y.unsqueeze(0))
188 | return ret
189 |
190 | def _matmul_with_relative_keys(self, x, y):
191 | """
192 | x: [b, h, l, d]
193 | y: [h or 1, m, d]
194 | ret: [b, h, l, m]
195 | """
196 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
197 | return ret
198 |
199 | def _get_relative_embeddings(self, relative_embeddings, length):
200 | max_relative_position = 2 * self.window_size + 1
201 | # Pad first before slice to avoid using cond ops.
202 | pad_length = max(length - (self.window_size + 1), 0)
203 | slice_start_position = max((self.window_size + 1) - length, 0)
204 | slice_end_position = slice_start_position + 2 * length - 1
205 | if pad_length > 0:
206 | padded_relative_embeddings = F.pad(
207 | relative_embeddings,
208 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
209 | else:
210 | padded_relative_embeddings = relative_embeddings
211 | used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
212 | return used_relative_embeddings
213 |
214 | def _relative_position_to_absolute_position(self, x):
215 | """
216 | x: [b, h, l, 2*l-1]
217 | ret: [b, h, l, l]
218 | """
219 | batch, heads, length, _ = x.size()
220 | # Concat columns of pad to shift from relative to absolute indexing.
221 | x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
222 |
223 | # Concat extra elements so to add up to shape (len+1, 2*len-1).
224 | x_flat = x.view([batch, heads, length * 2 * length])
225 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
226 |
227 | # Reshape and slice out the padded elements.
228 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
229 | return x_final
230 |
231 | def _absolute_position_to_relative_position(self, x):
232 | """
233 | x: [b, h, l, l]
234 | ret: [b, h, l, 2*l-1]
235 | """
236 | batch, heads, length, _ = x.size()
237 | # pad along column
238 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
239 | x_flat = x.view([batch, heads, length**2 + length*(length -1)])
240 | # add 0's in the beginning that will skew the elements after reshape
241 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
242 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
243 | return x_final
244 |
245 | def _attention_bias_proximal(self, length):
246 | """Bias for self-attention to encourage attention to close positions.
247 | Args:
248 | length: an integer scalar.
249 | Returns:
250 | a Tensor with shape [1, 1, length, length]
251 | """
252 | r = torch.arange(length, dtype=torch.float32)
253 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
254 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
255 |
256 |
257 | class FFN(nn.Module):
258 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
259 | super().__init__()
260 | self.in_channels = in_channels
261 | self.out_channels = out_channels
262 | self.filter_channels = filter_channels
263 | self.kernel_size = kernel_size
264 | self.p_dropout = p_dropout
265 | self.activation = activation
266 | self.causal = causal
267 |
268 | if causal:
269 | self.padding = self._causal_padding
270 | else:
271 | self.padding = self._same_padding
272 |
273 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
274 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
275 | self.drop = nn.Dropout(p_dropout)
276 |
277 | def forward(self, x, x_mask):
278 | x = self.conv_1(self.padding(x * x_mask))
279 | if self.activation == "gelu":
280 | x = x * torch.sigmoid(1.702 * x)
281 | else:
282 | x = torch.relu(x)
283 | x = self.drop(x)
284 | x = self.conv_2(self.padding(x * x_mask))
285 | return x * x_mask
286 |
287 | def _causal_padding(self, x):
288 | if self.kernel_size == 1:
289 | return x
290 | pad_l = self.kernel_size - 1
291 | pad_r = 0
292 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
293 | x = F.pad(x, commons.convert_pad_shape(padding))
294 | return x
295 |
296 | def _same_padding(self, x):
297 | if self.kernel_size == 1:
298 | return x
299 | pad_l = (self.kernel_size - 1) // 2
300 | pad_r = self.kernel_size // 2
301 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
302 | x = F.pad(x, commons.convert_pad_shape(padding))
303 | return x
304 |
--------------------------------------------------------------------------------
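A quick, hedged smoke test for the relative-position Encoder above (not part of the repository), using the layer sizes from the bundled configs:

```python
import torch
from VITS import commons
from VITS.attentions import Encoder

# Hyperparameters taken from VITS/configs/*.json (hidden 192, filter 768, 2 heads, 6 layers).
enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2,
              n_layers=6, kernel_size=3, p_dropout=0.1)

x = torch.randn(1, 192, 50)                                    # [batch, channels, time]
x_lengths = torch.LongTensor([50])
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, 50), 1).to(x.dtype)

y = enc(x, x_mask)                                             # same shape as x
print(y.shape)  # torch.Size([1, 192, 50])
```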
/VITS/commons.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 |
8 | def init_weights(m, mean=0.0, std=0.01):
9 | classname = m.__class__.__name__
10 | if classname.find("Conv") != -1:
11 | m.weight.data.normal_(mean, std)
12 |
13 |
14 | def get_padding(kernel_size, dilation=1):
15 | return int((kernel_size*dilation - dilation)/2)
16 |
17 |
18 | def convert_pad_shape(pad_shape):
19 | l = pad_shape[::-1]
20 | pad_shape = [item for sublist in l for item in sublist]
21 | return pad_shape
22 |
23 |
24 | def intersperse(lst, item):
25 | result = [item] * (len(lst) * 2 + 1)
26 | result[1::2] = lst
27 | return result
28 |
29 |
30 | def kl_divergence(m_p, logs_p, m_q, logs_q):
31 | """KL(P||Q)"""
32 | kl = (logs_q - logs_p) - 0.5
33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
34 | return kl
35 |
36 |
37 | def rand_gumbel(shape):
38 | """Sample from the Gumbel distribution, protect from overflows."""
39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40 | return -torch.log(-torch.log(uniform_samples))
41 |
42 |
43 | def rand_gumbel_like(x):
44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45 | return g
46 |
47 |
48 | def slice_segments(x, ids_str, segment_size=4):
49 | ret = torch.zeros_like(x[:, :, :segment_size])
50 | for i in range(x.size(0)):
51 | idx_str = ids_str[i]
52 | idx_end = idx_str + segment_size
53 | try:
54 | ret[i] = x[i, :, idx_str:idx_end]
55 | except RuntimeError:
56 | print("?")
57 | return ret
58 |
59 |
60 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
61 | b, d, t = x.size()
62 | if x_lengths is None:
63 | x_lengths = t
64 | ids_str_max = x_lengths - segment_size + 1
65 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
66 | ret = slice_segments(x, ids_str, segment_size)
67 | return ret, ids_str
68 |
69 |
70 | def get_timing_signal_1d(
71 | length, channels, min_timescale=1.0, max_timescale=1.0e4):
72 | position = torch.arange(length, dtype=torch.float)
73 | num_timescales = channels // 2
74 | log_timescale_increment = (
75 | math.log(float(max_timescale) / float(min_timescale)) /
76 | (num_timescales - 1))
77 | inv_timescales = min_timescale * torch.exp(
78 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
79 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
80 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
81 | signal = F.pad(signal, [0, 0, 0, channels % 2])
82 | signal = signal.view(1, channels, length)
83 | return signal
84 |
85 |
86 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
87 | b, channels, length = x.size()
88 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
89 | return x + signal.to(dtype=x.dtype, device=x.device)
90 |
91 |
92 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
93 | b, channels, length = x.size()
94 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
95 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
96 |
97 |
98 | def subsequent_mask(length):
99 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
100 | return mask
101 |
102 |
103 | @torch.jit.script
104 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
105 | n_channels_int = n_channels[0]
106 | in_act = input_a + input_b
107 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
108 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
109 | acts = t_act * s_act
110 | return acts
111 |
112 |
113 | def convert_pad_shape(pad_shape):
114 | l = pad_shape[::-1]
115 | pad_shape = [item for sublist in l for item in sublist]
116 | return pad_shape
117 |
118 |
119 | def shift_1d(x):
120 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
121 | return x
122 |
123 |
124 | def sequence_mask(length, max_length=None):
125 | if max_length is None:
126 | max_length = length.max()
127 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
128 | return x.unsqueeze(0) < length.unsqueeze(1)
129 |
130 |
131 | def generate_path(duration, mask):
132 | """
133 | duration: [b, 1, t_x]
134 | mask: [b, 1, t_y, t_x]
135 | """
136 | device = duration.device
137 |
138 | b, _, t_y, t_x = mask.shape
139 | cum_duration = torch.cumsum(duration, -1)
140 |
141 | cum_duration_flat = cum_duration.view(b * t_x)
142 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
143 | path = path.view(b, t_x, t_y)
144 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
145 | path = path.unsqueeze(1).transpose(2,3) * mask
146 | return path
147 |
148 |
149 | def clip_grad_value_(parameters, clip_value, norm_type=2):
150 | if isinstance(parameters, torch.Tensor):
151 | parameters = [parameters]
152 | parameters = list(filter(lambda p: p.grad is not None, parameters))
153 | norm_type = float(norm_type)
154 | if clip_value is not None:
155 | clip_value = float(clip_value)
156 |
157 | total_norm = 0
158 | for p in parameters:
159 | param_norm = p.grad.data.norm(norm_type)
160 | total_norm += param_norm.item() ** norm_type
161 | if clip_value is not None:
162 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
163 | total_norm = total_norm ** (1. / norm_type)
164 | return total_norm
165 |
--------------------------------------------------------------------------------
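Two of the helpers above are easiest to understand with tiny inputs; a short illustrative snippet (not part of the repository):

```python
import torch
from VITS import commons

# intersperse() inserts a blank token (id 0) between symbol ids; VC_inference.py
# uses it when hps.data.add_blank is true.
print(commons.intersperse([5, 6, 7], 0))   # [0, 5, 0, 6, 0, 7, 0]

# sequence_mask() turns per-item lengths into a boolean mask over time steps.
lengths = torch.LongTensor([3, 5])
print(commons.sequence_mask(lengths, 5))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```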
/VITS/configs/modified_finetune_speaker.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 10,
4 | "eval_interval": 100,
5 | "seed": 1234,
6 | "epochs": 10000,
7 | "learning_rate": 0.0002,
8 | "betas": [
9 | 0.8,
10 | 0.99
11 | ],
12 | "eps": 1e-09,
13 | "batch_size": 16,
14 | "fp16_run": true,
15 | "lr_decay": 0.999875,
16 | "segment_size": 8192,
17 | "init_lr_ratio": 1,
18 | "warmup_epochs": 0,
19 | "c_mel": 45,
20 | "c_kl": 1.0
21 | },
22 | "data": {
23 | "training_files": "final_annotation_train.txt",
24 | "validation_files": "final_annotation_val.txt",
25 | "text_cleaners": [
26 | "chinese_cleaners"
27 | ],
28 | "max_wav_value": 32768.0,
29 | "sampling_rate": 22050,
30 | "filter_length": 1024,
31 | "hop_length": 256,
32 | "win_length": 1024,
33 | "n_mel_channels": 80,
34 | "mel_fmin": 0.0,
35 | "mel_fmax": null,
36 | "add_blank": true,
37 | "n_speakers": 2,
38 | "cleaned_text": true
39 | },
40 | "model": {
41 | "inter_channels": 192,
42 | "hidden_channels": 192,
43 | "filter_channels": 768,
44 | "n_heads": 2,
45 | "n_layers": 6,
46 | "kernel_size": 3,
47 | "p_dropout": 0.1,
48 | "resblock": "1",
49 | "resblock_kernel_sizes": [
50 | 3,
51 | 7,
52 | 11
53 | ],
54 | "resblock_dilation_sizes": [
55 | [
56 | 1,
57 | 3,
58 | 5
59 | ],
60 | [
61 | 1,
62 | 3,
63 | 5
64 | ],
65 | [
66 | 1,
67 | 3,
68 | 5
69 | ]
70 | ],
71 | "upsample_rates": [
72 | 8,
73 | 8,
74 | 2,
75 | 2
76 | ],
77 | "upsample_initial_channel": 512,
78 | "upsample_kernel_sizes": [
79 | 16,
80 | 16,
81 | 4,
82 | 4
83 | ],
84 | "n_layers_q": 3,
85 | "use_spectral_norm": false,
86 | "gin_channels": 256
87 | },
88 | "symbols": [
89 | "_",
90 | "\uff1b",
91 | "\uff1a",
92 | "\uff0c",
93 | "\u3002",
94 | "\uff01",
95 | "\uff1f",
96 | "-",
97 | "\u201c",
98 | "\u201d",
99 | "\u300a",
100 | "\u300b",
101 | "\u3001",
102 | "\uff08",
103 | "\uff09",
104 | "\u2026",
105 | "\u2014",
106 | " ",
107 | "A",
108 | "B",
109 | "C",
110 | "D",
111 | "E",
112 | "F",
113 | "G",
114 | "H",
115 | "I",
116 | "J",
117 | "K",
118 | "L",
119 | "M",
120 | "N",
121 | "O",
122 | "P",
123 | "Q",
124 | "R",
125 | "S",
126 | "T",
127 | "U",
128 | "V",
129 | "W",
130 | "X",
131 | "Y",
132 | "Z",
133 | "a",
134 | "b",
135 | "c",
136 | "d",
137 | "e",
138 | "f",
139 | "g",
140 | "h",
141 | "i",
142 | "j",
143 | "k",
144 | "l",
145 | "m",
146 | "n",
147 | "o",
148 | "p",
149 | "q",
150 | "r",
151 | "s",
152 | "t",
153 | "u",
154 | "v",
155 | "w",
156 | "x",
157 | "y",
158 | "z",
159 | "1",
160 | "2",
161 | "3",
162 | "4",
163 | "5",
164 | "0",
165 | "\uff22",
166 | "\uff30"
167 | ],
168 | "speakers": {
169 | "dingzhen": 0,
170 | "taffy": 1
171 | }
172 | }
--------------------------------------------------------------------------------
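This config is what `utils.get_hparams_from_file` parses in VC_inference.py; a short sketch of how its fields are consumed when building the synthesizer (mirroring getVoice(), shown for illustration only):

```python
from VITS import utils
from VITS.models import SynthesizerTrn

hps = utils.get_hparams_from_file("VITS/configs/modified_finetune_speaker.json")
print(hps.data.sampling_rate)   # 22050
print(len(hps.symbols))         # vocabulary size fed to the text encoder
print(hps.speakers)             # speaker-name -> id mapping: dingzhen=0, taffy=1

net_g = SynthesizerTrn(
    len(hps.symbols),
    hps.data.filter_length // 2 + 1,              # spectrogram bins: 1024 // 2 + 1 = 513
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model)
```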
/VITS/configs/uma_trilingual.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 10000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 16,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "training_files":"../CH_JA_EN_mix_voice/clipped_3_vits_trilingual_annotations.train.txt.cleaned",
21 | "validation_files":"../CH_JA_EN_mix_voice/clipped_3_vits_trilingual_annotations.val.txt.cleaned",
22 | "text_cleaners":["cjke_cleaners2"],
23 | "max_wav_value": 32768.0,
24 | "sampling_rate": 22050,
25 | "filter_length": 1024,
26 | "hop_length": 256,
27 | "win_length": 1024,
28 | "n_mel_channels": 80,
29 | "mel_fmin": 0.0,
30 | "mel_fmax": null,
31 | "add_blank": true,
32 | "n_speakers": 999,
33 | "cleaned_text": true
34 | },
35 | "model": {
36 | "inter_channels": 192,
37 | "hidden_channels": 192,
38 | "filter_channels": 768,
39 | "n_heads": 2,
40 | "n_layers": 6,
41 | "kernel_size": 3,
42 | "p_dropout": 0.1,
43 | "resblock": "1",
44 | "resblock_kernel_sizes": [3,7,11],
45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46 | "upsample_rates": [8,8,2,2],
47 | "upsample_initial_channel": 512,
48 | "upsample_kernel_sizes": [16,16,4,4],
49 | "n_layers_q": 3,
50 | "use_spectral_norm": false,
51 | "gin_channels": 256
52 | },
53 | "symbols": ["_", ",", ".", "!", "?", "-", "~", "\u2026", "N", "Q", "a", "b", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "s", "t", "u", "v", "w", "x", "y", "z", "\u0251", "\u00e6", "\u0283", "\u0291", "\u00e7", "\u026f", "\u026a", "\u0254", "\u025b", "\u0279", "\u00f0", "\u0259", "\u026b", "\u0265", "\u0278", "\u028a", "\u027e", "\u0292", "\u03b8", "\u03b2", "\u014b", "\u0266", "\u207c", "\u02b0", "`", "^", "#", "*", "=", "\u02c8", "\u02cc", "\u2192", "\u2193", "\u2191", " "]
54 | }
--------------------------------------------------------------------------------
/VITS/mel_processing.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | import numpy as np
9 | import librosa
10 | import librosa.util as librosa_util
11 | from librosa.util import normalize, pad_center, tiny
12 | from scipy.signal import get_window
13 | from scipy.io.wavfile import read
14 | from librosa.filters import mel as librosa_mel_fn
15 |
16 | MAX_WAV_VALUE = 32768.0
17 |
18 |
19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20 | """
21 | PARAMS
22 | ------
23 | C: compression factor
24 | """
25 | return torch.log(torch.clamp(x, min=clip_val) * C)
26 |
27 |
28 | def dynamic_range_decompression_torch(x, C=1):
29 | """
30 | PARAMS
31 | ------
32 | C: compression factor used to compress
33 | """
34 | return torch.exp(x) / C
35 |
36 |
37 | def spectral_normalize_torch(magnitudes):
38 | output = dynamic_range_compression_torch(magnitudes)
39 | return output
40 |
41 |
42 | def spectral_de_normalize_torch(magnitudes):
43 | output = dynamic_range_decompression_torch(magnitudes)
44 | return output
45 |
46 |
47 | mel_basis = {}
48 | hann_window = {}
49 |
50 |
51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52 | if torch.min(y) < -1.:
53 | print('min value is ', torch.min(y))
54 | if torch.max(y) > 1.:
55 | print('max value is ', torch.max(y))
56 |
57 | global hann_window
58 | dtype_device = str(y.dtype) + '_' + str(y.device)
59 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
60 | if wnsize_dtype_device not in hann_window:
61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
62 |
63 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
64 | y = y.squeeze(1)
65 |
66 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
67 | center=center, pad_mode='reflect', normalized=False, onesided=True)
68 |
69 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
70 | return spec
71 |
72 |
73 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
74 | global mel_basis
75 | dtype_device = str(spec.dtype) + '_' + str(spec.device)
76 | fmax_dtype_device = str(fmax) + '_' + dtype_device
77 | if fmax_dtype_device not in mel_basis:
78 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
79 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
80 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
81 | spec = spectral_normalize_torch(spec)
82 | return spec
83 |
84 |
85 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
86 | if torch.min(y) < -1.:
87 | print('min value is ', torch.min(y))
88 | if torch.max(y) > 1.:
89 | print('max value is ', torch.max(y))
90 |
91 | global mel_basis, hann_window
92 | dtype_device = str(y.dtype) + '_' + str(y.device)
93 | fmax_dtype_device = str(fmax) + '_' + dtype_device
94 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
95 | if fmax_dtype_device not in mel_basis:
96 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
97 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
98 | if wnsize_dtype_device not in hann_window:
99 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
100 |
101 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
102 | y = y.squeeze(1)
103 |
104 | spec = torch.stft(y.float(), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
105 | center=center, pad_mode='reflect', normalized=False, onesided=True)
106 |
107 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
108 |
109 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
110 | spec = spectral_normalize_torch(spec)
111 |
112 | return spec
113 |
--------------------------------------------------------------------------------
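A hedged example (not in the repository) of calling `spectrogram_torch` with the STFT parameters from the bundled configs:

```python
import torch
from VITS.mel_processing import spectrogram_torch

# One second of placeholder audio at the configs' 22050 Hz sampling rate.
y = torch.randn(1, 22050)

spec = spectrogram_torch(y, n_fft=1024, sampling_rate=22050,
                         hop_size=256, win_size=1024, center=False)
print(spec.shape)  # [1, 513, n_frames]; 513 == filter_length // 2 + 1
```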
/VITS/models.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from VITS import commons
8 | from VITS import modules  # project-local module
9 | from VITS import attentions  # project-local module
10 | from VITS import monotonic_align
11 |
12 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
13 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14 | from VITS.commons import init_weights, get_padding
15 |
16 |
17 | class StochasticDurationPredictor(nn.Module):
18 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
19 | super().__init__()
20 | filter_channels = in_channels # this needs to be removed in a future version.
21 | self.in_channels = in_channels
22 | self.filter_channels = filter_channels
23 | self.kernel_size = kernel_size
24 | self.p_dropout = p_dropout
25 | self.n_flows = n_flows
26 | self.gin_channels = gin_channels
27 |
28 | self.log_flow = modules.Log()
29 | self.flows = nn.ModuleList()
30 | self.flows.append(modules.ElementwiseAffine(2))
31 | for i in range(n_flows):
32 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
33 | self.flows.append(modules.Flip())
34 |
35 | self.post_pre = nn.Conv1d(1, filter_channels, 1)
36 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
37 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
38 | self.post_flows = nn.ModuleList()
39 | self.post_flows.append(modules.ElementwiseAffine(2))
40 | for i in range(4):
41 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
42 | self.post_flows.append(modules.Flip())
43 |
44 | self.pre = nn.Conv1d(in_channels, filter_channels, 1)
45 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
46 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
47 | if gin_channels != 0:
48 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
49 |
50 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
51 | x = torch.detach(x)
52 | x = self.pre(x)
53 | if g is not None:
54 | g = torch.detach(g)
55 | x = x + self.cond(g)
56 | x = self.convs(x, x_mask)
57 | x = self.proj(x) * x_mask
58 |
59 | if not reverse:
60 | flows = self.flows
61 | assert w is not None
62 |
63 | logdet_tot_q = 0
64 | h_w = self.post_pre(w)
65 | h_w = self.post_convs(h_w, x_mask)
66 | h_w = self.post_proj(h_w) * x_mask
67 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
68 | z_q = e_q
69 | for flow in self.post_flows:
70 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
71 | logdet_tot_q += logdet_q
72 | z_u, z1 = torch.split(z_q, [1, 1], 1)
73 | u = torch.sigmoid(z_u) * x_mask
74 | z0 = (w - u) * x_mask
75 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
76 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
77 |
78 | logdet_tot = 0
79 | z0, logdet = self.log_flow(z0, x_mask)
80 | logdet_tot += logdet
81 | z = torch.cat([z0, z1], 1)
82 | for flow in flows:
83 | z, logdet = flow(z, x_mask, g=x, reverse=reverse)
84 | logdet_tot = logdet_tot + logdet
85 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
86 | return nll + logq # [b]
87 | else:
88 | flows = list(reversed(self.flows))
89 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow
90 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
91 | for flow in flows:
92 | z = flow(z, x_mask, g=x, reverse=reverse)
93 | z0, z1 = torch.split(z, [1, 1], 1)
94 | logw = z0
95 | return logw
96 |
97 |
98 | class DurationPredictor(nn.Module):
99 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
100 | super().__init__()
101 |
102 | self.in_channels = in_channels
103 | self.filter_channels = filter_channels
104 | self.kernel_size = kernel_size
105 | self.p_dropout = p_dropout
106 | self.gin_channels = gin_channels
107 |
108 | self.drop = nn.Dropout(p_dropout)
109 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
110 | self.norm_1 = modules.LayerNorm(filter_channels)
111 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
112 | self.norm_2 = modules.LayerNorm(filter_channels)
113 | self.proj = nn.Conv1d(filter_channels, 1, 1)
114 |
115 | if gin_channels != 0:
116 | self.cond = nn.Conv1d(gin_channels, in_channels, 1)
117 |
118 | def forward(self, x, x_mask, g=None):
119 | x = torch.detach(x)
120 | if g is not None:
121 | g = torch.detach(g)
122 | x = x + self.cond(g)
123 | x = self.conv_1(x * x_mask)
124 | x = torch.relu(x)
125 | x = self.norm_1(x)
126 | x = self.drop(x)
127 | x = self.conv_2(x * x_mask)
128 | x = torch.relu(x)
129 | x = self.norm_2(x)
130 | x = self.drop(x)
131 | x = self.proj(x * x_mask)
132 | return x * x_mask
133 |
134 |
135 | class TextEncoder(nn.Module):
136 | def __init__(self,
137 | n_vocab,
138 | out_channels,
139 | hidden_channels,
140 | filter_channels,
141 | n_heads,
142 | n_layers,
143 | kernel_size,
144 | p_dropout):
145 | super().__init__()
146 | self.n_vocab = n_vocab
147 | self.out_channels = out_channels
148 | self.hidden_channels = hidden_channels
149 | self.filter_channels = filter_channels
150 | self.n_heads = n_heads
151 | self.n_layers = n_layers
152 | self.kernel_size = kernel_size
153 | self.p_dropout = p_dropout
154 |
155 | self.emb = nn.Embedding(n_vocab, hidden_channels)
156 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
157 |
158 | self.encoder = attentions.Encoder(
159 | hidden_channels,
160 | filter_channels,
161 | n_heads,
162 | n_layers,
163 | kernel_size,
164 | p_dropout)
165 | self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
166 |
167 | def forward(self, x, x_lengths):
168 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
169 | x = torch.transpose(x, 1, -1) # [b, h, t]
170 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
171 |
172 | x = self.encoder(x * x_mask, x_mask)
173 | stats = self.proj(x) * x_mask
174 |
175 | m, logs = torch.split(stats, self.out_channels, dim=1)
176 | return x, m, logs, x_mask
177 |
178 |
179 | class ResidualCouplingBlock(nn.Module):
180 | def __init__(self,
181 | channels,
182 | hidden_channels,
183 | kernel_size,
184 | dilation_rate,
185 | n_layers,
186 | n_flows=4,
187 | gin_channels=0):
188 | super().__init__()
189 | self.channels = channels
190 | self.hidden_channels = hidden_channels
191 | self.kernel_size = kernel_size
192 | self.dilation_rate = dilation_rate
193 | self.n_layers = n_layers
194 | self.n_flows = n_flows
195 | self.gin_channels = gin_channels
196 |
197 | self.flows = nn.ModuleList()
198 | for i in range(n_flows):
199 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
200 | self.flows.append(modules.Flip())
201 |
202 | def forward(self, x, x_mask, g=None, reverse=False):
203 | if not reverse:
204 | for flow in self.flows:
205 | x, _ = flow(x, x_mask, g=g, reverse=reverse)
206 | else:
207 | for flow in reversed(self.flows):
208 | x = flow(x, x_mask, g=g, reverse=reverse)
209 | return x
210 |
211 |
212 | class PosteriorEncoder(nn.Module):
213 | def __init__(self,
214 | in_channels,
215 | out_channels,
216 | hidden_channels,
217 | kernel_size,
218 | dilation_rate,
219 | n_layers,
220 | gin_channels=0):
221 | super().__init__()
222 | self.in_channels = in_channels
223 | self.out_channels = out_channels
224 | self.hidden_channels = hidden_channels
225 | self.kernel_size = kernel_size
226 | self.dilation_rate = dilation_rate
227 | self.n_layers = n_layers
228 | self.gin_channels = gin_channels
229 |
230 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
231 | self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
232 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
233 |
234 | def forward(self, x, x_lengths, g=None):
235 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
236 | x = self.pre(x) * x_mask
237 | x = self.enc(x, x_mask, g=g)
238 | stats = self.proj(x) * x_mask
239 | m, logs = torch.split(stats, self.out_channels, dim=1)
240 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
241 | return z, m, logs, x_mask
242 |
243 |
244 | class Generator(torch.nn.Module):
245 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
246 | super(Generator, self).__init__()
247 | self.num_kernels = len(resblock_kernel_sizes)
248 | self.num_upsamples = len(upsample_rates)
249 | self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
250 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
251 |
252 | self.ups = nn.ModuleList()
253 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
254 | self.ups.append(weight_norm(
255 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
256 | k, u, padding=(k-u)//2)))
257 |
258 | self.resblocks = nn.ModuleList()
259 | for i in range(len(self.ups)):
260 | ch = upsample_initial_channel//(2**(i+1))
261 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
262 | self.resblocks.append(resblock(ch, k, d))
263 |
264 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
265 | self.ups.apply(init_weights)
266 |
267 | if gin_channels != 0:
268 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
269 |
270 | def forward(self, x, g=None):
271 | x = self.conv_pre(x)
272 | if g is not None:
273 | x = x + self.cond(g)
274 |
275 | for i in range(self.num_upsamples):
276 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
277 | x = self.ups[i](x)
278 | xs = None
279 | for j in range(self.num_kernels):
280 | if xs is None:
281 | xs = self.resblocks[i*self.num_kernels+j](x)
282 | else:
283 | xs += self.resblocks[i*self.num_kernels+j](x)
284 | x = xs / self.num_kernels
285 | x = F.leaky_relu(x)
286 | x = self.conv_post(x)
287 | x = torch.tanh(x)
288 |
289 | return x
290 |
291 | def remove_weight_norm(self):
292 | #print('Removing weight norm...')
293 | for l in self.ups:
294 | remove_weight_norm(l)
295 | for l in self.resblocks:
296 | l.remove_weight_norm()
297 |
298 |
299 | class DiscriminatorP(torch.nn.Module):
300 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
301 | super(DiscriminatorP, self).__init__()
302 | self.period = period
303 | self.use_spectral_norm = use_spectral_norm
304 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
305 | self.convs = nn.ModuleList([
306 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
307 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
308 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
309 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
310 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
311 | ])
312 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
313 |
314 | def forward(self, x):
315 | fmap = []
316 |
317 | # 1d to 2d
318 | b, c, t = x.shape
319 | if t % self.period != 0: # pad first
320 | n_pad = self.period - (t % self.period)
321 | x = F.pad(x, (0, n_pad), "reflect")
322 | t = t + n_pad
323 | x = x.view(b, c, t // self.period, self.period)
324 |
325 | for l in self.convs:
326 | x = l(x)
327 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
328 | fmap.append(x)
329 | x = self.conv_post(x)
330 | fmap.append(x)
331 | x = torch.flatten(x, 1, -1)
332 |
333 | return x, fmap
334 |
335 |
336 | class DiscriminatorS(torch.nn.Module):
337 | def __init__(self, use_spectral_norm=False):
338 | super(DiscriminatorS, self).__init__()
339 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
340 | self.convs = nn.ModuleList([
341 | norm_f(Conv1d(1, 16, 15, 1, padding=7)),
342 | norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
343 | norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
344 | norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
345 | norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
346 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
347 | ])
348 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
349 |
350 | def forward(self, x):
351 | fmap = []
352 |
353 | for l in self.convs:
354 | x = l(x)
355 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
356 | fmap.append(x)
357 | x = self.conv_post(x)
358 | fmap.append(x)
359 | x = torch.flatten(x, 1, -1)
360 |
361 | return x, fmap
362 |
363 |
364 | class MultiPeriodDiscriminator(torch.nn.Module):
365 | def __init__(self, use_spectral_norm=False):
366 | super(MultiPeriodDiscriminator, self).__init__()
367 | periods = [2,3,5,7,11]
368 |
369 | discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
370 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
371 | self.discriminators = nn.ModuleList(discs)
372 |
373 | def forward(self, y, y_hat):
374 | y_d_rs = []
375 | y_d_gs = []
376 | fmap_rs = []
377 | fmap_gs = []
378 | for i, d in enumerate(self.discriminators):
379 | y_d_r, fmap_r = d(y)
380 | y_d_g, fmap_g = d(y_hat)
381 | y_d_rs.append(y_d_r)
382 | y_d_gs.append(y_d_g)
383 | fmap_rs.append(fmap_r)
384 | fmap_gs.append(fmap_g)
385 |
386 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
387 |
388 |
389 |
390 | class SynthesizerTrn(nn.Module):
391 | """
392 | Synthesizer for Training
393 | """
394 |
395 | def __init__(self,
396 | n_vocab,
397 | spec_channels,
398 | segment_size,
399 | inter_channels,
400 | hidden_channels,
401 | filter_channels,
402 | n_heads,
403 | n_layers,
404 | kernel_size,
405 | p_dropout,
406 | resblock,
407 | resblock_kernel_sizes,
408 | resblock_dilation_sizes,
409 | upsample_rates,
410 | upsample_initial_channel,
411 | upsample_kernel_sizes,
412 | n_speakers=0,
413 | gin_channels=0,
414 | use_sdp=True,
415 | **kwargs):
416 |
417 | super().__init__()
418 | self.n_vocab = n_vocab
419 | self.spec_channels = spec_channels
420 | self.inter_channels = inter_channels
421 | self.hidden_channels = hidden_channels
422 | self.filter_channels = filter_channels
423 | self.n_heads = n_heads
424 | self.n_layers = n_layers
425 | self.kernel_size = kernel_size
426 | self.p_dropout = p_dropout
427 | self.resblock = resblock
428 | self.resblock_kernel_sizes = resblock_kernel_sizes
429 | self.resblock_dilation_sizes = resblock_dilation_sizes
430 | self.upsample_rates = upsample_rates
431 | self.upsample_initial_channel = upsample_initial_channel
432 | self.upsample_kernel_sizes = upsample_kernel_sizes
433 | self.segment_size = segment_size
434 | self.n_speakers = n_speakers
435 | self.gin_channels = gin_channels
436 |
437 | self.use_sdp = use_sdp
438 |
439 | self.enc_p = TextEncoder(n_vocab,
440 | inter_channels,
441 | hidden_channels,
442 | filter_channels,
443 | n_heads,
444 | n_layers,
445 | kernel_size,
446 | p_dropout)
447 | self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
448 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
449 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
450 |
451 | if use_sdp:
452 | self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
453 | else:
454 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
455 |
456 | if n_speakers >= 1:
457 | self.emb_g = nn.Embedding(n_speakers, gin_channels)
458 |
459 | def forward(self, x, x_lengths, y, y_lengths, sid=None):
460 |
461 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
462 | if self.n_speakers > 0:
463 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
464 | else:
465 | g = None
466 |
467 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
468 | z_p = self.flow(z, y_mask, g=g)
469 |
470 | with torch.no_grad():
471 | # negative cross-entropy
472 | s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
473 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s]
474 | neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
475 | neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
476 | neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
477 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
478 |
479 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
480 | attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
481 |
482 | w = attn.sum(2)
483 | if self.use_sdp:
484 | l_length = self.dp(x, x_mask, w, g=g)
485 | l_length = l_length / torch.sum(x_mask)
486 | else:
487 | logw_ = torch.log(w + 1e-6) * x_mask
488 | logw = self.dp(x, x_mask, g=g)
489 | l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
490 |
491 | # expand prior
492 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
493 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
494 |
495 | z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
496 | o = self.dec(z_slice, g=g)
497 | return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
498 |
499 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
500 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
501 | if self.n_speakers > 0:
502 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
503 | else:
504 | g = None
505 |
506 | if self.use_sdp:
507 | logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
508 | else:
509 | logw = self.dp(x, x_mask, g=g)
510 | w = torch.exp(logw) * x_mask * length_scale
511 | w_ceil = torch.ceil(w)
512 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
513 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
514 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
515 | attn = commons.generate_path(w_ceil, attn_mask)
516 |
517 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
518 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
519 |
520 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
521 | z = self.flow(z_p, y_mask, g=g, reverse=True)
522 | o = self.dec((z * y_mask)[:,:,:max_len], g=g)
523 | return o, attn, y_mask, (z, z_p, m_p, logs_p)
524 |
525 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
526 |         assert self.n_speakers > 0, "n_speakers has to be larger than 0."
527 | g_src = self.emb_g(sid_src).unsqueeze(-1)
528 | g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
529 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
530 | z_p = self.flow(z, y_mask, g=g_src)
531 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
532 | o_hat = self.dec(z_hat * y_mask, g=g_tgt)
533 | return o_hat, y_mask, (z, z_p, z_hat)
534 |
--------------------------------------------------------------------------------
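
A reading note on the alignment score computed in SynthesizerTrn.forward above: the four neg_cent terms are the channel-wise expansion of the prior log-density that monotonic alignment search maximizes. With \sigma_p = \exp(\text{logs\_p}) and s_p_sq_r = \sigma_p^{-2},

\log\mathcal{N}(z_p;\,m_p,\sigma_p)
  = \big[-\tfrac12\log(2\pi)-\log\sigma_p\big]_{\text{neg\_cent1}}
  + \big[-\tfrac12\,z_p^{2}\,\sigma_p^{-2}\big]_{\text{neg\_cent2}}
  + \big[z_p\,m_p\,\sigma_p^{-2}\big]_{\text{neg\_cent3}}
  + \big[-\tfrac12\,m_p^{2}\,\sigma_p^{-2}\big]_{\text{neg\_cent4}},

each bracket summed over the channel dimension; monotonic_align.maximum_path then picks the monotonic path with the largest total score.
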
/VITS/models_infer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | import commons
7 | import modules
8 | import attentions
9 |
10 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12 | from commons import init_weights, get_padding
13 |
14 |
15 | class StochasticDurationPredictor(nn.Module):
16 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
17 | super().__init__()
18 |         filter_channels = in_channels  # it needs to be removed in a future version.
19 | self.in_channels = in_channels
20 | self.filter_channels = filter_channels
21 | self.kernel_size = kernel_size
22 | self.p_dropout = p_dropout
23 | self.n_flows = n_flows
24 | self.gin_channels = gin_channels
25 |
26 | self.log_flow = modules.Log()
27 | self.flows = nn.ModuleList()
28 | self.flows.append(modules.ElementwiseAffine(2))
29 | for i in range(n_flows):
30 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
31 | self.flows.append(modules.Flip())
32 |
33 | self.post_pre = nn.Conv1d(1, filter_channels, 1)
34 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
35 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
36 | self.post_flows = nn.ModuleList()
37 | self.post_flows.append(modules.ElementwiseAffine(2))
38 | for i in range(4):
39 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
40 | self.post_flows.append(modules.Flip())
41 |
42 | self.pre = nn.Conv1d(in_channels, filter_channels, 1)
43 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
44 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
45 | if gin_channels != 0:
46 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
47 |
48 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
49 | x = torch.detach(x)
50 | x = self.pre(x)
51 | if g is not None:
52 | g = torch.detach(g)
53 | x = x + self.cond(g)
54 | x = self.convs(x, x_mask)
55 | x = self.proj(x) * x_mask
56 |
57 | if not reverse:
58 | flows = self.flows
59 | assert w is not None
60 |
61 | logdet_tot_q = 0
62 | h_w = self.post_pre(w)
63 | h_w = self.post_convs(h_w, x_mask)
64 | h_w = self.post_proj(h_w) * x_mask
65 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
66 | z_q = e_q
67 | for flow in self.post_flows:
68 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
69 | logdet_tot_q += logdet_q
70 | z_u, z1 = torch.split(z_q, [1, 1], 1)
71 | u = torch.sigmoid(z_u) * x_mask
72 | z0 = (w - u) * x_mask
73 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
74 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
75 |
76 | logdet_tot = 0
77 | z0, logdet = self.log_flow(z0, x_mask)
78 | logdet_tot += logdet
79 | z = torch.cat([z0, z1], 1)
80 | for flow in flows:
81 | z, logdet = flow(z, x_mask, g=x, reverse=reverse)
82 | logdet_tot = logdet_tot + logdet
83 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
84 | return nll + logq # [b]
85 | else:
86 | flows = list(reversed(self.flows))
87 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow
88 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
89 | for flow in flows:
90 | z = flow(z, x_mask, g=x, reverse=reverse)
91 | z0, z1 = torch.split(z, [1, 1], 1)
92 | logw = z0
93 | return logw
94 |
95 |
96 | class DurationPredictor(nn.Module):
97 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
98 | super().__init__()
99 |
100 | self.in_channels = in_channels
101 | self.filter_channels = filter_channels
102 | self.kernel_size = kernel_size
103 | self.p_dropout = p_dropout
104 | self.gin_channels = gin_channels
105 |
106 | self.drop = nn.Dropout(p_dropout)
107 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
108 | self.norm_1 = modules.LayerNorm(filter_channels)
109 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
110 | self.norm_2 = modules.LayerNorm(filter_channels)
111 | self.proj = nn.Conv1d(filter_channels, 1, 1)
112 |
113 | if gin_channels != 0:
114 | self.cond = nn.Conv1d(gin_channels, in_channels, 1)
115 |
116 | def forward(self, x, x_mask, g=None):
117 | x = torch.detach(x)
118 | if g is not None:
119 | g = torch.detach(g)
120 | x = x + self.cond(g)
121 | x = self.conv_1(x * x_mask)
122 | x = torch.relu(x)
123 | x = self.norm_1(x)
124 | x = self.drop(x)
125 | x = self.conv_2(x * x_mask)
126 | x = torch.relu(x)
127 | x = self.norm_2(x)
128 | x = self.drop(x)
129 | x = self.proj(x * x_mask)
130 | return x * x_mask
131 |
132 |
133 | class TextEncoder(nn.Module):
134 | def __init__(self,
135 | n_vocab,
136 | out_channels,
137 | hidden_channels,
138 | filter_channels,
139 | n_heads,
140 | n_layers,
141 | kernel_size,
142 | p_dropout):
143 | super().__init__()
144 | self.n_vocab = n_vocab
145 | self.out_channels = out_channels
146 | self.hidden_channels = hidden_channels
147 | self.filter_channels = filter_channels
148 | self.n_heads = n_heads
149 | self.n_layers = n_layers
150 | self.kernel_size = kernel_size
151 | self.p_dropout = p_dropout
152 |
153 | self.emb = nn.Embedding(n_vocab, hidden_channels)
154 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
155 |
156 | self.encoder = attentions.Encoder(
157 | hidden_channels,
158 | filter_channels,
159 | n_heads,
160 | n_layers,
161 | kernel_size,
162 | p_dropout)
163 | self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
164 |
165 | def forward(self, x, x_lengths):
166 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
167 | x = torch.transpose(x, 1, -1) # [b, h, t]
168 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
169 |
170 | x = self.encoder(x * x_mask, x_mask)
171 | stats = self.proj(x) * x_mask
172 |
173 | m, logs = torch.split(stats, self.out_channels, dim=1)
174 | return x, m, logs, x_mask
175 |
176 |
177 | class ResidualCouplingBlock(nn.Module):
178 | def __init__(self,
179 | channels,
180 | hidden_channels,
181 | kernel_size,
182 | dilation_rate,
183 | n_layers,
184 | n_flows=4,
185 | gin_channels=0):
186 | super().__init__()
187 | self.channels = channels
188 | self.hidden_channels = hidden_channels
189 | self.kernel_size = kernel_size
190 | self.dilation_rate = dilation_rate
191 | self.n_layers = n_layers
192 | self.n_flows = n_flows
193 | self.gin_channels = gin_channels
194 |
195 | self.flows = nn.ModuleList()
196 | for i in range(n_flows):
197 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
198 | self.flows.append(modules.Flip())
199 |
200 | def forward(self, x, x_mask, g=None, reverse=False):
201 | if not reverse:
202 | for flow in self.flows:
203 | x, _ = flow(x, x_mask, g=g, reverse=reverse)
204 | else:
205 | for flow in reversed(self.flows):
206 | x = flow(x, x_mask, g=g, reverse=reverse)
207 | return x
208 |
209 |
210 | class PosteriorEncoder(nn.Module):
211 | def __init__(self,
212 | in_channels,
213 | out_channels,
214 | hidden_channels,
215 | kernel_size,
216 | dilation_rate,
217 | n_layers,
218 | gin_channels=0):
219 | super().__init__()
220 | self.in_channels = in_channels
221 | self.out_channels = out_channels
222 | self.hidden_channels = hidden_channels
223 | self.kernel_size = kernel_size
224 | self.dilation_rate = dilation_rate
225 | self.n_layers = n_layers
226 | self.gin_channels = gin_channels
227 |
228 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
229 | self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
230 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
231 |
232 | def forward(self, x, x_lengths, g=None):
233 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
234 | x = self.pre(x) * x_mask
235 | x = self.enc(x, x_mask, g=g)
236 | stats = self.proj(x) * x_mask
237 | m, logs = torch.split(stats, self.out_channels, dim=1)
238 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
239 | return z, m, logs, x_mask
240 |
241 |
242 | class Generator(torch.nn.Module):
243 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
244 | super(Generator, self).__init__()
245 | self.num_kernels = len(resblock_kernel_sizes)
246 | self.num_upsamples = len(upsample_rates)
247 | self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
248 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
249 |
250 | self.ups = nn.ModuleList()
251 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
252 | self.ups.append(weight_norm(
253 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
254 | k, u, padding=(k-u)//2)))
255 |
256 | self.resblocks = nn.ModuleList()
257 | for i in range(len(self.ups)):
258 | ch = upsample_initial_channel//(2**(i+1))
259 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
260 | self.resblocks.append(resblock(ch, k, d))
261 |
262 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
263 | self.ups.apply(init_weights)
264 |
265 | if gin_channels != 0:
266 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
267 |
268 | def forward(self, x, g=None):
269 | x = self.conv_pre(x)
270 | if g is not None:
271 | x = x + self.cond(g)
272 |
273 | for i in range(self.num_upsamples):
274 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
275 | x = self.ups[i](x)
276 | xs = None
277 | for j in range(self.num_kernels):
278 | if xs is None:
279 | xs = self.resblocks[i*self.num_kernels+j](x)
280 | else:
281 | xs += self.resblocks[i*self.num_kernels+j](x)
282 | x = xs / self.num_kernels
283 | x = F.leaky_relu(x)
284 | x = self.conv_post(x)
285 | x = torch.tanh(x)
286 |
287 | return x
288 |
289 | def remove_weight_norm(self):
290 | #print('Removing weight norm...')
291 | for l in self.ups:
292 | remove_weight_norm(l)
293 | for l in self.resblocks:
294 | l.remove_weight_norm()
295 |
296 |
297 |
298 | class SynthesizerTrn(nn.Module):
299 | """
300 | Synthesizer for Training
301 | """
302 |
303 | def __init__(self,
304 | n_vocab,
305 | spec_channels,
306 | segment_size,
307 | inter_channels,
308 | hidden_channels,
309 | filter_channels,
310 | n_heads,
311 | n_layers,
312 | kernel_size,
313 | p_dropout,
314 | resblock,
315 | resblock_kernel_sizes,
316 | resblock_dilation_sizes,
317 | upsample_rates,
318 | upsample_initial_channel,
319 | upsample_kernel_sizes,
320 | n_speakers=0,
321 | gin_channels=0,
322 | use_sdp=True,
323 | **kwargs):
324 |
325 | super().__init__()
326 | self.n_vocab = n_vocab
327 | self.spec_channels = spec_channels
328 | self.inter_channels = inter_channels
329 | self.hidden_channels = hidden_channels
330 | self.filter_channels = filter_channels
331 | self.n_heads = n_heads
332 | self.n_layers = n_layers
333 | self.kernel_size = kernel_size
334 | self.p_dropout = p_dropout
335 | self.resblock = resblock
336 | self.resblock_kernel_sizes = resblock_kernel_sizes
337 | self.resblock_dilation_sizes = resblock_dilation_sizes
338 | self.upsample_rates = upsample_rates
339 | self.upsample_initial_channel = upsample_initial_channel
340 | self.upsample_kernel_sizes = upsample_kernel_sizes
341 | self.segment_size = segment_size
342 | self.n_speakers = n_speakers
343 | self.gin_channels = gin_channels
344 |
345 | self.use_sdp = use_sdp
346 |
347 | self.enc_p = TextEncoder(n_vocab,
348 | inter_channels,
349 | hidden_channels,
350 | filter_channels,
351 | n_heads,
352 | n_layers,
353 | kernel_size,
354 | p_dropout)
355 | self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
356 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
357 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
358 |
359 | if use_sdp:
360 | self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
361 | else:
362 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
363 |
364 |         if n_speakers >= 1:  # match models.py; infer() uses emb_g whenever n_speakers > 0
365 | self.emb_g = nn.Embedding(n_speakers, gin_channels)
366 |
367 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
368 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
369 | if self.n_speakers > 0:
370 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
371 | else:
372 | g = None
373 |
374 | if self.use_sdp:
375 | logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
376 | else:
377 | logw = self.dp(x, x_mask, g=g)
378 | w = torch.exp(logw) * x_mask * length_scale
379 | w_ceil = torch.ceil(w)
380 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
381 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
382 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
383 | attn = commons.generate_path(w_ceil, attn_mask)
384 |
385 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
386 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
387 |
388 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
389 | z = self.flow(z_p, y_mask, g=g, reverse=True)
390 | o = self.dec((z * y_mask)[:,:,:max_len], g=g)
391 | return o, attn, y_mask, (z, z_p, m_p, logs_p)
392 |
393 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
394 |         assert self.n_speakers > 0, "n_speakers has to be larger than 0."
395 | g_src = self.emb_g(sid_src).unsqueeze(-1)
396 | g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
397 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
398 | z_p = self.flow(z, y_mask, g=g_src)
399 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
400 | o_hat = self.dec(z_hat * y_mask, g=g_tgt)
401 | return o_hat, y_mask, (z, z_p, z_hat)
402 |
403 |
--------------------------------------------------------------------------------
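
A minimal, hypothetical inference sketch for the SynthesizerTrn defined above. The hyperparameter values mirror a common VITS fine-tune configuration and the symbol ids are placeholders, not values taken from this repository's configs; it also assumes the VITS directory itself is on sys.path so that the bare `import commons` / `import modules` inside models_infer.py resolve.

import torch
from VITS.models_infer import SynthesizerTrn   # assumes VITS/ is also on sys.path

net_g = SynthesizerTrn(
    n_vocab=100,                      # placeholder symbol-table size
    spec_channels=513,
    segment_size=32,
    inter_channels=192,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    resblock='1',
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 8, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4],
    n_speakers=2,
    gin_channels=256,
).eval()

x = torch.LongTensor([[1, 5, 9, 3]])            # cleaned-text symbol ids (placeholder)
x_lengths = torch.LongTensor([x.size(1)])
sid = torch.LongTensor([0])                     # target speaker id
with torch.no_grad():
    audio = net_g.infer(x, x_lengths, sid=sid, noise_scale=0.667,
                        noise_scale_w=0.8, length_scale=1.0)[0][0, 0]
# `audio` is a 1-D waveform tensor; a trained checkpoint would normally be loaded
# into net_g before calling infer.
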
/VITS/modules.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import scipy
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10 | from torch.nn.utils import weight_norm, remove_weight_norm
11 |
12 | from VITS import commons
13 | from VITS.commons import init_weights, get_padding
14 | from VITS.transforms import piecewise_rational_quadratic_transform
15 |
16 |
17 | LRELU_SLOPE = 0.1
18 |
19 |
20 | class LayerNorm(nn.Module):
21 | def __init__(self, channels, eps=1e-5):
22 | super().__init__()
23 | self.channels = channels
24 | self.eps = eps
25 |
26 | self.gamma = nn.Parameter(torch.ones(channels))
27 | self.beta = nn.Parameter(torch.zeros(channels))
28 |
29 | def forward(self, x):
30 | x = x.transpose(1, -1)
31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32 | return x.transpose(1, -1)
33 |
34 |
35 | class ConvReluNorm(nn.Module):
36 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
37 | super().__init__()
38 | self.in_channels = in_channels
39 | self.hidden_channels = hidden_channels
40 | self.out_channels = out_channels
41 | self.kernel_size = kernel_size
42 | self.n_layers = n_layers
43 | self.p_dropout = p_dropout
44 |     assert n_layers > 1, "Number of layers should be larger than 1."
45 |
46 | self.conv_layers = nn.ModuleList()
47 | self.norm_layers = nn.ModuleList()
48 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
49 | self.norm_layers.append(LayerNorm(hidden_channels))
50 | self.relu_drop = nn.Sequential(
51 | nn.ReLU(),
52 | nn.Dropout(p_dropout))
53 | for _ in range(n_layers-1):
54 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
55 | self.norm_layers.append(LayerNorm(hidden_channels))
56 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
57 | self.proj.weight.data.zero_()
58 | self.proj.bias.data.zero_()
59 |
60 | def forward(self, x, x_mask):
61 | x_org = x
62 | for i in range(self.n_layers):
63 | x = self.conv_layers[i](x * x_mask)
64 | x = self.norm_layers[i](x)
65 | x = self.relu_drop(x)
66 | x = x_org + self.proj(x)
67 | return x * x_mask
68 |
69 |
70 | class DDSConv(nn.Module):
71 | """
72 |   Dilated and Depth-Separable Convolution
73 | """
74 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
75 | super().__init__()
76 | self.channels = channels
77 | self.kernel_size = kernel_size
78 | self.n_layers = n_layers
79 | self.p_dropout = p_dropout
80 |
81 | self.drop = nn.Dropout(p_dropout)
82 | self.convs_sep = nn.ModuleList()
83 | self.convs_1x1 = nn.ModuleList()
84 | self.norms_1 = nn.ModuleList()
85 | self.norms_2 = nn.ModuleList()
86 | for i in range(n_layers):
87 | dilation = kernel_size ** i
88 | padding = (kernel_size * dilation - dilation) // 2
89 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
90 | groups=channels, dilation=dilation, padding=padding
91 | ))
92 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
93 | self.norms_1.append(LayerNorm(channels))
94 | self.norms_2.append(LayerNorm(channels))
95 |
96 | def forward(self, x, x_mask, g=None):
97 | if g is not None:
98 | x = x + g
99 | for i in range(self.n_layers):
100 | y = self.convs_sep[i](x * x_mask)
101 | y = self.norms_1[i](y)
102 | y = F.gelu(y)
103 | y = self.convs_1x1[i](y)
104 | y = self.norms_2[i](y)
105 | y = F.gelu(y)
106 | y = self.drop(y)
107 | x = x + y
108 | return x * x_mask
109 |
110 |
111 | class WN(torch.nn.Module):
112 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
113 | super(WN, self).__init__()
114 | assert(kernel_size % 2 == 1)
115 |     self.hidden_channels = hidden_channels
116 |     self.kernel_size = kernel_size
117 | self.dilation_rate = dilation_rate
118 | self.n_layers = n_layers
119 | self.gin_channels = gin_channels
120 | self.p_dropout = p_dropout
121 |
122 | self.in_layers = torch.nn.ModuleList()
123 | self.res_skip_layers = torch.nn.ModuleList()
124 | self.drop = nn.Dropout(p_dropout)
125 |
126 | if gin_channels != 0:
127 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
128 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
129 |
130 | for i in range(n_layers):
131 | dilation = dilation_rate ** i
132 | padding = int((kernel_size * dilation - dilation) / 2)
133 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
134 | dilation=dilation, padding=padding)
135 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
136 | self.in_layers.append(in_layer)
137 |
138 | # last one is not necessary
139 | if i < n_layers - 1:
140 | res_skip_channels = 2 * hidden_channels
141 | else:
142 | res_skip_channels = hidden_channels
143 |
144 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
145 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
146 | self.res_skip_layers.append(res_skip_layer)
147 |
148 | def forward(self, x, x_mask, g=None, **kwargs):
149 | output = torch.zeros_like(x)
150 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
151 |
152 | if g is not None:
153 | g = self.cond_layer(g)
154 |
155 | for i in range(self.n_layers):
156 | x_in = self.in_layers[i](x)
157 | if g is not None:
158 | cond_offset = i * 2 * self.hidden_channels
159 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
160 | else:
161 | g_l = torch.zeros_like(x_in)
162 |
163 | acts = commons.fused_add_tanh_sigmoid_multiply(
164 | x_in,
165 | g_l,
166 | n_channels_tensor)
167 | acts = self.drop(acts)
168 |
169 | res_skip_acts = self.res_skip_layers[i](acts)
170 | if i < self.n_layers - 1:
171 | res_acts = res_skip_acts[:,:self.hidden_channels,:]
172 | x = (x + res_acts) * x_mask
173 | output = output + res_skip_acts[:,self.hidden_channels:,:]
174 | else:
175 | output = output + res_skip_acts
176 | return output * x_mask
177 |
178 | def remove_weight_norm(self):
179 | if self.gin_channels != 0:
180 | torch.nn.utils.remove_weight_norm(self.cond_layer)
181 | for l in self.in_layers:
182 | torch.nn.utils.remove_weight_norm(l)
183 | for l in self.res_skip_layers:
184 | torch.nn.utils.remove_weight_norm(l)
185 |
186 |
187 | class ResBlock1(torch.nn.Module):
188 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
189 | super(ResBlock1, self).__init__()
190 | self.convs1 = nn.ModuleList([
191 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
192 | padding=get_padding(kernel_size, dilation[0]))),
193 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
194 | padding=get_padding(kernel_size, dilation[1]))),
195 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
196 | padding=get_padding(kernel_size, dilation[2])))
197 | ])
198 | self.convs1.apply(init_weights)
199 |
200 | self.convs2 = nn.ModuleList([
201 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
202 | padding=get_padding(kernel_size, 1))),
203 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
204 | padding=get_padding(kernel_size, 1))),
205 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
206 | padding=get_padding(kernel_size, 1)))
207 | ])
208 | self.convs2.apply(init_weights)
209 |
210 | def forward(self, x, x_mask=None):
211 | for c1, c2 in zip(self.convs1, self.convs2):
212 | xt = F.leaky_relu(x, LRELU_SLOPE)
213 | if x_mask is not None:
214 | xt = xt * x_mask
215 | xt = c1(xt)
216 | xt = F.leaky_relu(xt, LRELU_SLOPE)
217 | if x_mask is not None:
218 | xt = xt * x_mask
219 | xt = c2(xt)
220 | x = xt + x
221 | if x_mask is not None:
222 | x = x * x_mask
223 | return x
224 |
225 | def remove_weight_norm(self):
226 | for l in self.convs1:
227 | remove_weight_norm(l)
228 | for l in self.convs2:
229 | remove_weight_norm(l)
230 |
231 |
232 | class ResBlock2(torch.nn.Module):
233 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
234 | super(ResBlock2, self).__init__()
235 | self.convs = nn.ModuleList([
236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
237 | padding=get_padding(kernel_size, dilation[0]))),
238 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
239 | padding=get_padding(kernel_size, dilation[1])))
240 | ])
241 | self.convs.apply(init_weights)
242 |
243 | def forward(self, x, x_mask=None):
244 | for c in self.convs:
245 | xt = F.leaky_relu(x, LRELU_SLOPE)
246 | if x_mask is not None:
247 | xt = xt * x_mask
248 | xt = c(xt)
249 | x = xt + x
250 | if x_mask is not None:
251 | x = x * x_mask
252 | return x
253 |
254 | def remove_weight_norm(self):
255 | for l in self.convs:
256 | remove_weight_norm(l)
257 |
258 |
259 | class Log(nn.Module):
260 | def forward(self, x, x_mask, reverse=False, **kwargs):
261 | if not reverse:
262 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
263 | logdet = torch.sum(-y, [1, 2])
264 | return y, logdet
265 | else:
266 | x = torch.exp(x) * x_mask
267 | return x
268 |
269 |
270 | class Flip(nn.Module):
271 | def forward(self, x, *args, reverse=False, **kwargs):
272 | x = torch.flip(x, [1])
273 | if not reverse:
274 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
275 | return x, logdet
276 | else:
277 | return x
278 |
279 |
280 | class ElementwiseAffine(nn.Module):
281 | def __init__(self, channels):
282 | super().__init__()
283 | self.channels = channels
284 | self.m = nn.Parameter(torch.zeros(channels,1))
285 | self.logs = nn.Parameter(torch.zeros(channels,1))
286 |
287 | def forward(self, x, x_mask, reverse=False, **kwargs):
288 | if not reverse:
289 | y = self.m + torch.exp(self.logs) * x
290 | y = y * x_mask
291 | logdet = torch.sum(self.logs * x_mask, [1,2])
292 | return y, logdet
293 | else:
294 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
295 | return x
296 |
297 |
298 | class ResidualCouplingLayer(nn.Module):
299 | def __init__(self,
300 | channels,
301 | hidden_channels,
302 | kernel_size,
303 | dilation_rate,
304 | n_layers,
305 | p_dropout=0,
306 | gin_channels=0,
307 | mean_only=False):
308 | assert channels % 2 == 0, "channels should be divisible by 2"
309 | super().__init__()
310 | self.channels = channels
311 | self.hidden_channels = hidden_channels
312 | self.kernel_size = kernel_size
313 | self.dilation_rate = dilation_rate
314 | self.n_layers = n_layers
315 | self.half_channels = channels // 2
316 | self.mean_only = mean_only
317 |
318 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
319 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
320 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
321 | self.post.weight.data.zero_()
322 | self.post.bias.data.zero_()
323 |
324 | def forward(self, x, x_mask, g=None, reverse=False):
325 | x0, x1 = torch.split(x, [self.half_channels]*2, 1)
326 | h = self.pre(x0) * x_mask
327 | h = self.enc(h, x_mask, g=g)
328 | stats = self.post(h) * x_mask
329 | if not self.mean_only:
330 | m, logs = torch.split(stats, [self.half_channels]*2, 1)
331 | else:
332 | m = stats
333 | logs = torch.zeros_like(m)
334 |
335 | if not reverse:
336 | x1 = m + x1 * torch.exp(logs) * x_mask
337 | x = torch.cat([x0, x1], 1)
338 | logdet = torch.sum(logs, [1,2])
339 | return x, logdet
340 | else:
341 | x1 = (x1 - m) * torch.exp(-logs) * x_mask
342 | x = torch.cat([x0, x1], 1)
343 | return x
344 |
345 |
346 | class ConvFlow(nn.Module):
347 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
348 | super().__init__()
349 | self.in_channels = in_channels
350 | self.filter_channels = filter_channels
351 | self.kernel_size = kernel_size
352 | self.n_layers = n_layers
353 | self.num_bins = num_bins
354 | self.tail_bound = tail_bound
355 | self.half_channels = in_channels // 2
356 |
357 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
358 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
359 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
360 | self.proj.weight.data.zero_()
361 | self.proj.bias.data.zero_()
362 |
363 | def forward(self, x, x_mask, g=None, reverse=False):
364 | x0, x1 = torch.split(x, [self.half_channels]*2, 1)
365 | h = self.pre(x0)
366 | h = self.convs(h, x_mask, g=g)
367 | h = self.proj(h) * x_mask
368 |
369 | b, c, t = x0.shape
370 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
371 |
372 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
373 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
374 | unnormalized_derivatives = h[..., 2 * self.num_bins:]
375 |
376 | x1, logabsdet = piecewise_rational_quadratic_transform(x1,
377 | unnormalized_widths,
378 | unnormalized_heights,
379 | unnormalized_derivatives,
380 | inverse=reverse,
381 | tails='linear',
382 | tail_bound=self.tail_bound
383 | )
384 |
385 | x = torch.cat([x0, x1], 1) * x_mask
386 | logdet = torch.sum(logabsdet * x_mask, [1,2])
387 | if not reverse:
388 | return x, logdet
389 | else:
390 | return x
391 |
--------------------------------------------------------------------------------
/VITS/monotonic_align/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from .monotonic_align.core import maximum_path_c
4 |
5 |
6 | def maximum_path(neg_cent, mask):
7 | """ Cython optimized version.
8 | neg_cent: [b, t_t, t_s]
9 | mask: [b, t_t, t_s]
10 | """
11 | device = neg_cent.device
12 | dtype = neg_cent.dtype
13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
14 | path = np.zeros(neg_cent.shape, dtype=np.int32)
15 |
16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max)
19 | return torch.from_numpy(path).to(device=device, dtype=dtype)
20 |
--------------------------------------------------------------------------------
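
A shape-only usage sketch for maximum_path, following the docstring above (assumes the compiled core extension is importable, e.g. the checked-in Windows/Python 3.10 build or one produced with the setup.py below):

import torch
from VITS.monotonic_align import maximum_path

b, t_t, t_s = 2, 50, 17                      # batch, spectrogram frames, text tokens
neg_cent = torch.randn(b, t_t, t_s)          # per-pair alignment scores
mask = torch.ones(b, t_t, t_s)               # 1 where both frame and token are valid
path = maximum_path(neg_cent, mask)          # hard monotonic alignment, same shape
assert path.shape == (b, t_t, t_s)
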
/VITS/monotonic_align/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/monotonic_align/build/lib.win-amd64-3.10/monotonic_align/core.cp310-win_amd64.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/build/lib.win-amd64-3.10/monotonic_align/core.cp310-win_amd64.pyd
--------------------------------------------------------------------------------
/VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.cp310-win_amd64.exp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.cp310-win_amd64.exp
--------------------------------------------------------------------------------
/VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.cp310-win_amd64.lib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.cp310-win_amd64.lib
--------------------------------------------------------------------------------
/VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.obj:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/build/temp.win-amd64-3.10/Release/core.obj
--------------------------------------------------------------------------------
/VITS/monotonic_align/core.pyx:
--------------------------------------------------------------------------------
1 | cimport cython
2 | from cython.parallel import prange
3 |
4 |
5 | @cython.boundscheck(False)
6 | @cython.wraparound(False)
7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
8 | cdef int x
9 | cdef int y
10 | cdef float v_prev
11 | cdef float v_cur
12 | cdef float tmp
13 | cdef int index = t_x - 1
14 |
15 | for y in range(t_y):
16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
17 | if x == y:
18 | v_cur = max_neg_val
19 | else:
20 | v_cur = value[y-1, x]
21 | if x == 0:
22 | if y == 0:
23 | v_prev = 0.
24 | else:
25 | v_prev = max_neg_val
26 | else:
27 | v_prev = value[y-1, x-1]
28 | value[y, x] += max(v_prev, v_cur)
29 |
30 | for y in range(t_y - 1, -1, -1):
31 | path[y, index] = 1
32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
33 | index = index - 1
34 |
35 |
36 | @cython.boundscheck(False)
37 | @cython.wraparound(False)
38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
39 | cdef int b = paths.shape[0]
40 | cdef int i
41 | for i in prange(b, nogil=True):
42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
43 |
--------------------------------------------------------------------------------
/VITS/monotonic_align/monotonic_align/core.cp310-win_amd64.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/monotonic_align/monotonic_align/core.cp310-win_amd64.pyd
--------------------------------------------------------------------------------
/VITS/monotonic_align/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from Cython.Build import cythonize
3 | import numpy
4 |
5 | setup(
6 | name = 'monotonic_align',
7 | ext_modules = cythonize("core.pyx"),
8 | include_dirs=[numpy.get_include()]
9 | )
10 |
--------------------------------------------------------------------------------
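
Build note (an assumption about the standard Cython workflow, not something stated in the repository): the extension is compiled in place from the monotonic_align directory with `python setup.py build_ext --inplace`, which produces the core.*.pyd / core.*.so module imported by __init__.py; the Windows artifacts checked in above are such a build for Python 3.10 on win-amd64.
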
/VITS/text/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Keith Ito
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/VITS/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | from VITS.text import cleaners
3 | from VITS.text.symbols import symbols
4 |
5 |
6 | # Mappings from symbol to numeric ID and vice versa:
7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9 |
10 |
11 | def text_to_sequence(text, symbols, cleaner_names):
12 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
13 | Args:
14 | text: string to convert to a sequence
15 | cleaner_names: names of the cleaner functions to run the text through
16 | Returns:
17 | List of integers corresponding to the symbols in the text
18 | '''
19 | sequence = []
20 | symbol_to_id = {s: i for i, s in enumerate(symbols)}
21 | clean_text = _clean_text(text, cleaner_names)
22 | ##print(clean_text)
23 | ##print(f" length:{len(clean_text)}")
24 | for symbol in clean_text:
25 | if symbol not in symbol_to_id.keys():
26 | continue
27 | symbol_id = symbol_to_id[symbol]
28 | sequence += [symbol_id]
29 | ##print(f" length:{len(sequence)}")
30 | return sequence
31 |
32 |
33 | def cleaned_text_to_sequence(cleaned_text, symbols):
34 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
35 | Args:
36 |       cleaned_text: string of already-cleaned symbols to convert to a sequence
37 | Returns:
38 | List of integers corresponding to the symbols in the text
39 | '''
40 | symbol_to_id = {s: i for i, s in enumerate(symbols)}
41 | sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()]
42 | return sequence
43 |
44 |
45 | def sequence_to_text(sequence):
46 | '''Converts a sequence of IDs back to a string'''
47 | result = ''
48 | for symbol_id in sequence:
49 | s = _id_to_symbol[symbol_id]
50 | result += s
51 | return result
52 |
53 |
54 | def _clean_text(text, cleaner_names):
55 | for name in cleaner_names:
56 | cleaner = getattr(cleaners, name)
57 | if not cleaner:
58 | raise Exception('Unknown cleaner: %s' % name)
59 | text = cleaner(text)
60 | return text
61 |
--------------------------------------------------------------------------------
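
A minimal sketch of how the helpers above are called; the toy symbol table is a placeholder (real runs use the symbols defined in the model config), and importing VITS.text pulls in cleaners.py together with its language-specific dependencies:

from VITS.text import cleaned_text_to_sequence, text_to_sequence

symbols = ['_', ',', '.', ' ', 'a', 'b', 'c']        # toy symbol table (placeholder)
print(cleaned_text_to_sequence('ab c', symbols))      # -> [4, 5, 3, 6]

# text_to_sequence additionally runs the named cleaners first, e.g.
# text_to_sequence('[EN]hello[EN]', symbols, ['cjke_cleaners2']) with a real symbol table.
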
/VITS/text/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/text/__pycache__/cleaners.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/cleaners.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/text/__pycache__/english.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/english.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/text/__pycache__/japanese.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/japanese.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/text/__pycache__/korean.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/korean.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/text/__pycache__/mandarin.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/mandarin.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/text/__pycache__/sanskrit.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/sanskrit.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/text/__pycache__/symbols.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/symbols.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/text/__pycache__/thai.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/VITS/text/__pycache__/thai.cpython-310.pyc
--------------------------------------------------------------------------------
/VITS/text/cantonese.py:
--------------------------------------------------------------------------------
1 | import re
2 | import cn2an
3 | import opencc
4 |
5 |
6 | converter = opencc.OpenCC('jyutjyu')
7 |
8 | # List of (Latin alphabet, ipa) pairs:
9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
10 | ('A', 'ei˥'),
11 | ('B', 'biː˥'),
12 | ('C', 'siː˥'),
13 | ('D', 'tiː˥'),
14 | ('E', 'iː˥'),
15 | ('F', 'e˥fuː˨˩'),
16 | ('G', 'tsiː˥'),
17 | ('H', 'ɪk̚˥tsʰyː˨˩'),
18 | ('I', 'ɐi˥'),
19 | ('J', 'tsei˥'),
20 | ('K', 'kʰei˥'),
21 | ('L', 'e˥llou˨˩'),
22 | ('M', 'ɛːm˥'),
23 | ('N', 'ɛːn˥'),
24 | ('O', 'ou˥'),
25 | ('P', 'pʰiː˥'),
26 | ('Q', 'kʰiːu˥'),
27 | ('R', 'aː˥lou˨˩'),
28 | ('S', 'ɛː˥siː˨˩'),
29 | ('T', 'tʰiː˥'),
30 | ('U', 'juː˥'),
31 | ('V', 'wiː˥'),
32 | ('W', 'tʊk̚˥piː˥juː˥'),
33 | ('X', 'ɪk̚˥siː˨˩'),
34 | ('Y', 'waːi˥'),
35 | ('Z', 'iː˨sɛːt̚˥')
36 | ]]
37 |
38 |
39 | def number_to_cantonese(text):
40 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: cn2an.an2cn(x.group()), text)
41 |
42 |
43 | def latin_to_ipa(text):
44 | for regex, replacement in _latin_to_ipa:
45 | text = re.sub(regex, replacement, text)
46 | return text
47 |
48 |
49 | def cantonese_to_ipa(text):
50 | text = number_to_cantonese(text.upper())
51 | text = converter.convert(text).replace('-','').replace('$',' ')
52 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text)
53 |     text = re.sub(r'[、；：]', '，', text)
54 |     text = re.sub(r'\s*，\s*', ', ', text)
55 |     text = re.sub(r'\s*。\s*', '. ', text)
56 |     text = re.sub(r'\s*？\s*', '? ', text)
57 |     text = re.sub(r'\s*！\s*', '! ', text)
58 | text = re.sub(r'\s*$', '', text)
59 | return text
60 |
--------------------------------------------------------------------------------
/VITS/text/cleaners.py:
--------------------------------------------------------------------------------
1 | import re
2 | from VITS.text.japanese import japanese_to_romaji_with_accent, japanese_to_ipa, japanese_to_ipa2, japanese_to_ipa3
3 | from VITS.text.korean import latin_to_hangul, number_to_hangul, divide_hangul, korean_to_lazy_ipa, korean_to_ipa
4 | from VITS.text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2
5 | from VITS.text.sanskrit import devanagari_to_ipa
6 | from VITS.text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
7 | from VITS.text.thai import num_to_thai, latin_to_thai
8 | # from text.shanghainese import shanghainese_to_ipa
9 | # from text.cantonese import cantonese_to_ipa
10 | # from text.ngu_dialect import ngu_dialect_to_ipa
11 |
12 |
13 | def japanese_cleaners(text):
14 | text = japanese_to_romaji_with_accent(text)
15 | text = re.sub(r'([A-Za-z])$', r'\1.', text)
16 | return text
17 |
18 |
19 | def japanese_cleaners2(text):
20 | return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')
21 |
22 |
23 | def korean_cleaners(text):
24 | '''Pipeline for Korean text'''
25 | text = latin_to_hangul(text)
26 | text = number_to_hangul(text)
27 | text = divide_hangul(text)
28 | text = re.sub(r'([\u3131-\u3163])$', r'\1.', text)
29 | return text
30 |
31 |
32 | # def chinese_cleaners(text):
33 | # '''Pipeline for Chinese text'''
34 | # text = number_to_chinese(text)
35 | # text = chinese_to_bopomofo(text)
36 | # text = latin_to_bopomofo(text)
37 | # text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text)
38 | # return text
39 |
40 | def chinese_cleaners(text):
41 | from pypinyin import Style, pinyin
42 | text = text.replace("[ZH]", "")
43 | phones = [phone[0] for phone in pinyin(text, style=Style.TONE3)]
44 | return ' '.join(phones)
45 |
46 |
47 | def zh_ja_mixture_cleaners(text):
48 | text = re.sub(r'\[ZH\](.*?)\[ZH\]',
49 | lambda x: chinese_to_romaji(x.group(1))+' ', text)
50 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_romaji_with_accent(
51 | x.group(1)).replace('ts', 'ʦ').replace('u', 'ɯ').replace('...', '…')+' ', text)
52 | text = re.sub(r'\s+$', '', text)
53 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
54 | return text
55 |
56 |
57 | def sanskrit_cleaners(text):
58 | text = text.replace('॥', '।').replace('ॐ', 'ओम्')
59 | text = re.sub(r'([^।])$', r'\1।', text)
60 | return text
61 |
62 |
63 | def cjks_cleaners(text):
64 | text = re.sub(r'\[ZH\](.*?)\[ZH\]',
65 | lambda x: chinese_to_lazy_ipa(x.group(1))+' ', text)
66 | text = re.sub(r'\[JA\](.*?)\[JA\]',
67 | lambda x: japanese_to_ipa(x.group(1))+' ', text)
68 | text = re.sub(r'\[KO\](.*?)\[KO\]',
69 | lambda x: korean_to_lazy_ipa(x.group(1))+' ', text)
70 | text = re.sub(r'\[SA\](.*?)\[SA\]',
71 | lambda x: devanagari_to_ipa(x.group(1))+' ', text)
72 | text = re.sub(r'\[EN\](.*?)\[EN\]',
73 | lambda x: english_to_lazy_ipa(x.group(1))+' ', text)
74 | text = re.sub(r'\s+$', '', text)
75 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
76 | return text
77 |
78 |
79 | def cjke_cleaners(text):
80 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', lambda x: chinese_to_lazy_ipa(x.group(1)).replace(
81 | 'ʧ', 'tʃ').replace('ʦ', 'ts').replace('ɥan', 'ɥæn')+' ', text)
82 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_ipa(x.group(1)).replace('ʧ', 'tʃ').replace(
83 | 'ʦ', 'ts').replace('ɥan', 'ɥæn').replace('ʥ', 'dz')+' ', text)
84 | text = re.sub(r'\[KO\](.*?)\[KO\]',
85 | lambda x: korean_to_ipa(x.group(1))+' ', text)
86 | text = re.sub(r'\[EN\](.*?)\[EN\]', lambda x: english_to_ipa2(x.group(1)).replace('ɑ', 'a').replace(
87 | 'ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u')+' ', text)
88 | text = re.sub(r'\s+$', '', text)
89 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
90 | return text
91 |
92 |
93 | def cjke_cleaners2(text):
94 | text = re.sub(r'\[ZH\](.*?)\[ZH\]',
95 | lambda x: chinese_to_ipa(x.group(1))+' ', text)
96 | text = re.sub(r'\[JA\](.*?)\[JA\]',
97 | lambda x: japanese_to_ipa2(x.group(1))+' ', text)
98 | text = re.sub(r'\[KO\](.*?)\[KO\]',
99 | lambda x: korean_to_ipa(x.group(1))+' ', text)
100 | text = re.sub(r'\[EN\](.*?)\[EN\]',
101 | lambda x: english_to_ipa2(x.group(1))+' ', text)
102 | text = re.sub(r'\s+$', '', text)
103 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
104 | return text
105 |
106 |
107 | def thai_cleaners(text):
108 | text = num_to_thai(text)
109 | text = latin_to_thai(text)
110 | return text
111 |
112 |
113 | # def shanghainese_cleaners(text):
114 | # text = shanghainese_to_ipa(text)
115 | # text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
116 | # return text
117 |
118 |
119 | # def chinese_dialect_cleaners(text):
120 | # text = re.sub(r'\[ZH\](.*?)\[ZH\]',
121 | # lambda x: chinese_to_ipa2(x.group(1))+' ', text)
122 | # text = re.sub(r'\[JA\](.*?)\[JA\]',
123 | # lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ')+' ', text)
124 | # text = re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5',
125 | # '˧˧˦').replace('6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e')+' ', text)
126 | # text = re.sub(r'\[GD\](.*?)\[GD\]',
127 | # lambda x: cantonese_to_ipa(x.group(1))+' ', text)
128 | # text = re.sub(r'\[EN\](.*?)\[EN\]',
129 | # lambda x: english_to_lazy_ipa2(x.group(1))+' ', text)
130 | # text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group(
131 | # 1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ')+' ', text)
132 | # text = re.sub(r'\s+$', '', text)
133 | # text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
134 | # return text
135 |
--------------------------------------------------------------------------------
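
The multilingual cleaners above expect input already wrapped in language tags such as [ZH]...[ZH], [JA]...[JA], [KO]...[KO] and [EN]...[EN]. A hypothetical call (requires the IPA/pinyin dependencies imported at the top of the file):

from VITS.text.cleaners import cjke_cleaners2

mixed = '[ZH]你好[ZH][EN]hello[EN]'
print(cjke_cleaners2(mixed))   # per-language IPA; a trailing '.' is appended when the
                               # result does not already end in punctuation
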
/VITS/text/english.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 |
16 | # Regular expression matching whitespace:
17 |
18 |
19 | import re
20 | import inflect
21 | from unidecode import unidecode
22 | import eng_to_ipa as ipa
23 | _inflect = inflect.engine()
24 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
25 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
26 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
27 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
28 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
29 | _number_re = re.compile(r'[0-9]+')
30 |
31 | # List of (regular expression, replacement) pairs for abbreviations:
32 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
33 | ('mrs', 'misess'),
34 | ('mr', 'mister'),
35 | ('dr', 'doctor'),
36 | ('st', 'saint'),
37 | ('co', 'company'),
38 | ('jr', 'junior'),
39 | ('maj', 'major'),
40 | ('gen', 'general'),
41 | ('drs', 'doctors'),
42 | ('rev', 'reverend'),
43 | ('lt', 'lieutenant'),
44 | ('hon', 'honorable'),
45 | ('sgt', 'sergeant'),
46 | ('capt', 'captain'),
47 | ('esq', 'esquire'),
48 | ('ltd', 'limited'),
49 | ('col', 'colonel'),
50 | ('ft', 'fort'),
51 | ]]
52 |
53 |
54 | # List of (ipa, lazy ipa) pairs:
55 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
56 | ('r', 'ɹ'),
57 | ('æ', 'e'),
58 | ('ɑ', 'a'),
59 | ('ɔ', 'o'),
60 | ('ð', 'z'),
61 | ('θ', 's'),
62 | ('ɛ', 'e'),
63 | ('ɪ', 'i'),
64 | ('ʊ', 'u'),
65 | ('ʒ', 'ʥ'),
66 | ('ʤ', 'ʥ'),
67 | ('ˈ', '↓'),
68 | ]]
69 |
70 | # List of (ipa, lazy ipa2) pairs:
71 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
72 | ('r', 'ɹ'),
73 | ('ð', 'z'),
74 | ('θ', 's'),
75 | ('ʒ', 'ʑ'),
76 | ('ʤ', 'dʑ'),
77 | ('ˈ', '↓'),
78 | ]]
79 |
80 | # List of (ipa, ipa2) pairs
81 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
82 | ('r', 'ɹ'),
83 | ('ʤ', 'dʒ'),
84 | ('ʧ', 'tʃ')
85 | ]]
86 |
87 |
88 | def expand_abbreviations(text):
89 | for regex, replacement in _abbreviations:
90 | text = re.sub(regex, replacement, text)
91 | return text
92 |
93 |
94 | def collapse_whitespace(text):
95 | return re.sub(r'\s+', ' ', text)
96 |
97 |
98 | def _remove_commas(m):
99 | return m.group(1).replace(',', '')
100 |
101 |
102 | def _expand_decimal_point(m):
103 | return m.group(1).replace('.', ' point ')
104 |
105 |
106 | def _expand_dollars(m):
107 | match = m.group(1)
108 | parts = match.split('.')
109 | if len(parts) > 2:
110 | return match + ' dollars' # Unexpected format
111 | dollars = int(parts[0]) if parts[0] else 0
112 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
113 | if dollars and cents:
114 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
115 | cent_unit = 'cent' if cents == 1 else 'cents'
116 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
117 | elif dollars:
118 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
119 | return '%s %s' % (dollars, dollar_unit)
120 | elif cents:
121 | cent_unit = 'cent' if cents == 1 else 'cents'
122 | return '%s %s' % (cents, cent_unit)
123 | else:
124 | return 'zero dollars'
125 |
126 |
127 | def _expand_ordinal(m):
128 | return _inflect.number_to_words(m.group(0))
129 |
130 |
131 | def _expand_number(m):
132 | num = int(m.group(0))
133 | if num > 1000 and num < 3000:
134 | if num == 2000:
135 | return 'two thousand'
136 | elif num > 2000 and num < 2010:
137 | return 'two thousand ' + _inflect.number_to_words(num % 100)
138 | elif num % 100 == 0:
139 | return _inflect.number_to_words(num // 100) + ' hundred'
140 | else:
141 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
142 | else:
143 | return _inflect.number_to_words(num, andword='')
144 |
145 |
146 | def normalize_numbers(text):
147 | text = re.sub(_comma_number_re, _remove_commas, text)
148 | text = re.sub(_pounds_re, r'\1 pounds', text)
149 | text = re.sub(_dollars_re, _expand_dollars, text)
150 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
151 | text = re.sub(_ordinal_re, _expand_ordinal, text)
152 | text = re.sub(_number_re, _expand_number, text)
153 | return text
154 |
155 |
156 | def mark_dark_l(text):
157 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
158 |
159 |
160 | def english_to_ipa(text):
161 | text = unidecode(text).lower()
162 | text = expand_abbreviations(text)
163 | text = normalize_numbers(text)
164 | phonemes = ipa.convert(text)
165 | phonemes = collapse_whitespace(phonemes)
166 | return phonemes
167 |
168 |
169 | def english_to_lazy_ipa(text):
170 | text = english_to_ipa(text)
171 | for regex, replacement in _lazy_ipa:
172 | text = re.sub(regex, replacement, text)
173 | return text
174 |
175 |
176 | def english_to_ipa2(text):
177 | text = english_to_ipa(text)
178 | text = mark_dark_l(text)
179 | for regex, replacement in _ipa_to_ipa2:
180 | text = re.sub(regex, replacement, text)
181 | return text.replace('...', '…')
182 |
183 |
184 | def english_to_lazy_ipa2(text):
185 | text = english_to_ipa(text)
186 | for regex, replacement in _lazy_ipa2:
187 | text = re.sub(regex, replacement, text)
188 | return text
189 |
--------------------------------------------------------------------------------
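A minimal standalone sketch of exercising the English cleaners above (assumptions: it is run from inside the VITS directory so the text package resolves, the eng-to-ipa, inflect and Unidecode packages from requirements.txt are installed, and the sample sentences are invented for illustration):

    from text.english import normalize_numbers, english_to_ipa2

    # Currency and ordinals are spelled out: '$2.50' -> '2 dollars, 50 cents', '3rd' -> 'third'.
    print(normalize_numbers("It cost $2.50 on March 3rd"))
    # Full pipeline: unidecode + lowercase, expand abbreviations and numbers,
    # eng_to_ipa conversion, then dark-l marking and the ipa -> ipa2 table.
    print(english_to_ipa2("Hello world, it is 25 degrees outside."))
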
/VITS/text/japanese.py:
--------------------------------------------------------------------------------
1 | import re
2 | from unidecode import unidecode
3 | import pyopenjtalk
4 |
5 |
6 | # Regular expression matching Japanese without punctuation marks:
7 | _japanese_characters = re.compile(
8 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
9 |
10 | # Regular expression matching non-Japanese characters or punctuation marks:
11 | _japanese_marks = re.compile(
12 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
13 |
14 | # List of (symbol, Japanese) pairs for marks:
15 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
16 | ('%', 'パーセント')
17 | ]]
18 |
19 | # List of (romaji, ipa) pairs for marks:
20 | _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
21 | ('ts', 'ʦ'),
22 | ('u', 'ɯ'),
23 | ('j', 'ʥ'),
24 | ('y', 'j'),
25 | ('ni', 'n^i'),
26 | ('nj', 'n^'),
27 | ('hi', 'çi'),
28 | ('hj', 'ç'),
29 | ('f', 'ɸ'),
30 | ('I', 'i*'),
31 | ('U', 'ɯ*'),
32 | ('r', 'ɾ')
33 | ]]
34 |
35 | # List of (romaji, ipa2) pairs for marks:
36 | _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
37 | ('u', 'ɯ'),
38 | ('ʧ', 'tʃ'),
39 | ('j', 'dʑ'),
40 | ('y', 'j'),
41 | ('ni', 'n^i'),
42 | ('nj', 'n^'),
43 | ('hi', 'çi'),
44 | ('hj', 'ç'),
45 | ('f', 'ɸ'),
46 | ('I', 'i*'),
47 | ('U', 'ɯ*'),
48 | ('r', 'ɾ')
49 | ]]
50 |
51 | # List of (consonant, sokuon) pairs:
52 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
53 | (r'Q([↑↓]*[kg])', r'k#\1'),
54 | (r'Q([↑↓]*[tdjʧ])', r't#\1'),
55 | (r'Q([↑↓]*[sʃ])', r's\1'),
56 | (r'Q([↑↓]*[pb])', r'p#\1')
57 | ]]
58 |
59 | # List of (consonant, hatsuon) pairs:
60 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
61 | (r'N([↑↓]*[pbm])', r'm\1'),
62 | (r'N([↑↓]*[ʧʥj])', r'n^\1'),
63 | (r'N([↑↓]*[tdn])', r'n\1'),
64 | (r'N([↑↓]*[kg])', r'ŋ\1')
65 | ]]
66 |
67 |
68 | def symbols_to_japanese(text):
69 | for regex, replacement in _symbols_to_japanese:
70 | text = re.sub(regex, replacement, text)
71 | return text
72 |
73 |
74 | def japanese_to_romaji_with_accent(text):
75 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
76 | text = symbols_to_japanese(text)
77 | sentences = re.split(_japanese_marks, text)
78 | marks = re.findall(_japanese_marks, text)
79 | text = ''
80 | for i, sentence in enumerate(sentences):
81 | if re.match(_japanese_characters, sentence):
82 | if text != '':
83 | text += ' '
84 | labels = pyopenjtalk.extract_fullcontext(sentence)
85 | for n, label in enumerate(labels):
86 | phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
87 | if phoneme not in ['sil', 'pau']:
88 | text += phoneme.replace('ch', 'ʧ').replace('sh',
89 | 'ʃ').replace('cl', 'Q')
90 | else:
91 | continue
92 | # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
93 | a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
94 | a2 = int(re.search(r"\+(\d+)\+", label).group(1))
95 | a3 = int(re.search(r"\+(\d+)/", label).group(1))
96 | if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
97 | a2_next = -1
98 | else:
99 | a2_next = int(
100 | re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
101 | # Accent phrase boundary
102 | if a3 == 1 and a2_next == 1:
103 | text += ' '
104 | # Falling
105 | elif a1 == 0 and a2_next == a2 + 1:
106 | text += '↓'
107 | # Rising
108 | elif a2 == 1 and a2_next == 2:
109 | text += '↑'
110 | if i < len(marks):
111 | text += unidecode(marks[i]).replace(' ', '')
112 | return text
113 |
114 |
115 | def get_real_sokuon(text):
116 | for regex, replacement in _real_sokuon:
117 | text = re.sub(regex, replacement, text)
118 | return text
119 |
120 |
121 | def get_real_hatsuon(text):
122 | for regex, replacement in _real_hatsuon:
123 | text = re.sub(regex, replacement, text)
124 | return text
125 |
126 |
127 | def japanese_to_ipa(text):
128 | text = japanese_to_romaji_with_accent(text).replace('...', '…')
129 | text = re.sub(
130 | r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
131 | text = get_real_sokuon(text)
132 | text = get_real_hatsuon(text)
133 | for regex, replacement in _romaji_to_ipa:
134 | text = re.sub(regex, replacement, text)
135 | return text
136 |
137 |
138 | def japanese_to_ipa2(text):
139 | text = japanese_to_romaji_with_accent(text).replace('...', '…')
140 | text = get_real_sokuon(text)
141 | text = get_real_hatsuon(text)
142 | for regex, replacement in _romaji_to_ipa2:
143 | text = re.sub(regex, replacement, text)
144 | return text
145 |
146 |
147 | def japanese_to_ipa3(text):
148 | text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
149 | 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
150 | text = re.sub(
151 | r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
152 | text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
153 | return text
154 |
--------------------------------------------------------------------------------
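A minimal sketch of calling the Japanese front end above (assumptions: run from inside the VITS directory, pyopenjtalk from requirements.txt installed with its dictionary available, sample sentence invented):

    from text.japanese import japanese_to_romaji_with_accent, japanese_to_ipa2

    # Pitch accent is marked with ↑/↓; sokuon (Q) and hatsuon (N) are rewritten
    # by get_real_sokuon / get_real_hatsuon before the romaji-to-IPA tables apply.
    print(japanese_to_romaji_with_accent("こんにちは、世界。"))
    print(japanese_to_ipa2("こんにちは、世界。"))
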
/VITS/text/korean.py:
--------------------------------------------------------------------------------
1 | import re
2 | from jamo import h2j, j2hcj
3 | import ko_pron
4 |
5 |
6 | # This is a list of Korean classifiers preceded by pure Korean numerals.
7 | _korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
8 |
9 | # List of (hangul, hangul divided) pairs:
10 | _hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
11 | ('ㄳ', 'ㄱㅅ'),
12 | ('ㄵ', 'ㄴㅈ'),
13 | ('ㄶ', 'ㄴㅎ'),
14 | ('ㄺ', 'ㄹㄱ'),
15 | ('ㄻ', 'ㄹㅁ'),
16 | ('ㄼ', 'ㄹㅂ'),
17 | ('ㄽ', 'ㄹㅅ'),
18 | ('ㄾ', 'ㄹㅌ'),
19 | ('ㄿ', 'ㄹㅍ'),
20 | ('ㅀ', 'ㄹㅎ'),
21 | ('ㅄ', 'ㅂㅅ'),
22 | ('ㅘ', 'ㅗㅏ'),
23 | ('ㅙ', 'ㅗㅐ'),
24 | ('ㅚ', 'ㅗㅣ'),
25 | ('ㅝ', 'ㅜㅓ'),
26 | ('ㅞ', 'ㅜㅔ'),
27 | ('ㅟ', 'ㅜㅣ'),
28 | ('ㅢ', 'ㅡㅣ'),
29 | ('ㅑ', 'ㅣㅏ'),
30 | ('ㅒ', 'ㅣㅐ'),
31 | ('ㅕ', 'ㅣㅓ'),
32 | ('ㅖ', 'ㅣㅔ'),
33 | ('ㅛ', 'ㅣㅗ'),
34 | ('ㅠ', 'ㅣㅜ')
35 | ]]
36 |
37 | # List of (Latin alphabet, hangul) pairs:
38 | _latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
39 | ('a', '에이'),
40 | ('b', '비'),
41 | ('c', '시'),
42 | ('d', '디'),
43 | ('e', '이'),
44 | ('f', '에프'),
45 | ('g', '지'),
46 | ('h', '에이치'),
47 | ('i', '아이'),
48 | ('j', '제이'),
49 | ('k', '케이'),
50 | ('l', '엘'),
51 | ('m', '엠'),
52 | ('n', '엔'),
53 | ('o', '오'),
54 | ('p', '피'),
55 | ('q', '큐'),
56 | ('r', '아르'),
57 | ('s', '에스'),
58 | ('t', '티'),
59 | ('u', '유'),
60 | ('v', '브이'),
61 | ('w', '더블유'),
62 | ('x', '엑스'),
63 | ('y', '와이'),
64 | ('z', '제트')
65 | ]]
66 |
67 | # List of (ipa, lazy ipa) pairs:
68 | _ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
69 | ('t͡ɕ','ʧ'),
70 | ('d͡ʑ','ʥ'),
71 | ('ɲ','n^'),
72 | ('ɕ','ʃ'),
73 | ('ʷ','w'),
74 | ('ɭ','l`'),
75 | ('ʎ','ɾ'),
76 | ('ɣ','ŋ'),
77 | ('ɰ','ɯ'),
78 | ('ʝ','j'),
79 | ('ʌ','ə'),
80 | ('ɡ','g'),
81 | ('\u031a','#'),
82 | ('\u0348','='),
83 | ('\u031e',''),
84 | ('\u0320',''),
85 | ('\u0339','')
86 | ]]
87 |
88 |
89 | def latin_to_hangul(text):
90 | for regex, replacement in _latin_to_hangul:
91 | text = re.sub(regex, replacement, text)
92 | return text
93 |
94 |
95 | def divide_hangul(text):
96 | text = j2hcj(h2j(text))
97 | for regex, replacement in _hangul_divided:
98 | text = re.sub(regex, replacement, text)
99 | return text
100 |
101 |
102 | def hangul_number(num, sino=True):
103 | '''Reference https://github.com/Kyubyong/g2pK'''
104 | num = re.sub(',', '', num)
105 |
106 | if num == '0':
107 | return '영'
108 | if not sino and num == '20':
109 | return '스무'
110 |
111 | digits = '123456789'
112 | names = '일이삼사오육칠팔구'
113 | digit2name = {d: n for d, n in zip(digits, names)}
114 |
115 | modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
116 | decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
117 | digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
118 | digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
119 |
120 | spelledout = []
121 | for i, digit in enumerate(num):
122 | i = len(num) - i - 1
123 | if sino:
124 | if i == 0:
125 | name = digit2name.get(digit, '')
126 | elif i == 1:
127 | name = digit2name.get(digit, '') + '십'
128 | name = name.replace('일십', '십')
129 | else:
130 | if i == 0:
131 | name = digit2mod.get(digit, '')
132 | elif i == 1:
133 | name = digit2dec.get(digit, '')
134 | if digit == '0':
135 | if i % 4 == 0:
136 | last_three = spelledout[-min(3, len(spelledout)):]
137 | if ''.join(last_three) == '':
138 | spelledout.append('')
139 | continue
140 | else:
141 | spelledout.append('')
142 | continue
143 | if i == 2:
144 | name = digit2name.get(digit, '') + '백'
145 | name = name.replace('일백', '백')
146 | elif i == 3:
147 | name = digit2name.get(digit, '') + '천'
148 | name = name.replace('일천', '천')
149 | elif i == 4:
150 | name = digit2name.get(digit, '') + '만'
151 | name = name.replace('일만', '만')
152 | elif i == 5:
153 | name = digit2name.get(digit, '') + '십'
154 | name = name.replace('일십', '십')
155 | elif i == 6:
156 | name = digit2name.get(digit, '') + '백'
157 | name = name.replace('일백', '백')
158 | elif i == 7:
159 | name = digit2name.get(digit, '') + '천'
160 | name = name.replace('일천', '천')
161 | elif i == 8:
162 | name = digit2name.get(digit, '') + '억'
163 | elif i == 9:
164 | name = digit2name.get(digit, '') + '십'
165 | elif i == 10:
166 | name = digit2name.get(digit, '') + '백'
167 | elif i == 11:
168 | name = digit2name.get(digit, '') + '천'
169 | elif i == 12:
170 | name = digit2name.get(digit, '') + '조'
171 | elif i == 13:
172 | name = digit2name.get(digit, '') + '십'
173 | elif i == 14:
174 | name = digit2name.get(digit, '') + '백'
175 | elif i == 15:
176 | name = digit2name.get(digit, '') + '천'
177 | spelledout.append(name)
178 | return ''.join(elem for elem in spelledout)
179 |
180 |
181 | def number_to_hangul(text):
182 | '''Reference https://github.com/Kyubyong/g2pK'''
183 | tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
184 | for token in tokens:
185 | num, classifier = token
186 | if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
187 | spelledout = hangul_number(num, sino=False)
188 | else:
189 | spelledout = hangul_number(num, sino=True)
190 | text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
191 | # digit by digit for remaining digits
192 | digits = '0123456789'
193 | names = '영일이삼사오육칠팔구'
194 | for d, n in zip(digits, names):
195 | text = text.replace(d, n)
196 | return text
197 |
198 |
199 | def korean_to_lazy_ipa(text):
200 | text = latin_to_hangul(text)
201 | text = number_to_hangul(text)
202 | text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text)
203 | for regex, replacement in _ipa_to_lazy_ipa:
204 | text = re.sub(regex, replacement, text)
205 | return text
206 |
207 |
208 | def korean_to_ipa(text):
209 | text = korean_to_lazy_ipa(text)
210 | return text.replace('ʧ','tʃ').replace('ʥ','dʑ')
211 |
--------------------------------------------------------------------------------
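A minimal sketch of the Korean pipeline above (assumptions: run from inside the VITS directory, jamo and ko-pron from requirements.txt installed, examples invented):

    from text.korean import number_to_hangul, korean_to_ipa

    # '3개' is read with the native numeral because '개' is in _korean_classifiers;
    # ko_pron then romanises the hangul and the lazy-IPA table is applied.
    print(number_to_hangul("사과 3개"))
    print(korean_to_ipa("안녕하세요"))
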
/VITS/text/mandarin.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import re
4 | from pypinyin import lazy_pinyin, BOPOMOFO
5 | import jieba
6 | import cn2an
7 | import logging
8 |
9 |
10 | # List of (Latin alphabet, bopomofo) pairs:
11 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
12 | ('a', 'ㄟˉ'),
13 | ('b', 'ㄅㄧˋ'),
14 | ('c', 'ㄙㄧˉ'),
15 | ('d', 'ㄉㄧˋ'),
16 | ('e', 'ㄧˋ'),
17 | ('f', 'ㄝˊㄈㄨˋ'),
18 | ('g', 'ㄐㄧˋ'),
19 | ('h', 'ㄝˇㄑㄩˋ'),
20 | ('i', 'ㄞˋ'),
21 | ('j', 'ㄐㄟˋ'),
22 | ('k', 'ㄎㄟˋ'),
23 | ('l', 'ㄝˊㄛˋ'),
24 | ('m', 'ㄝˊㄇㄨˋ'),
25 | ('n', 'ㄣˉ'),
26 | ('o', 'ㄡˉ'),
27 | ('p', 'ㄆㄧˉ'),
28 | ('q', 'ㄎㄧㄡˉ'),
29 | ('r', 'ㄚˋ'),
30 | ('s', 'ㄝˊㄙˋ'),
31 | ('t', 'ㄊㄧˋ'),
32 | ('u', 'ㄧㄡˉ'),
33 | ('v', 'ㄨㄧˉ'),
34 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
35 | ('x', 'ㄝˉㄎㄨˋㄙˋ'),
36 | ('y', 'ㄨㄞˋ'),
37 | ('z', 'ㄗㄟˋ')
38 | ]]
39 |
40 | # List of (bopomofo, romaji) pairs:
41 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
42 | ('ㄅㄛ', 'p⁼wo'),
43 | ('ㄆㄛ', 'pʰwo'),
44 | ('ㄇㄛ', 'mwo'),
45 | ('ㄈㄛ', 'fwo'),
46 | ('ㄅ', 'p⁼'),
47 | ('ㄆ', 'pʰ'),
48 | ('ㄇ', 'm'),
49 | ('ㄈ', 'f'),
50 | ('ㄉ', 't⁼'),
51 | ('ㄊ', 'tʰ'),
52 | ('ㄋ', 'n'),
53 | ('ㄌ', 'l'),
54 | ('ㄍ', 'k⁼'),
55 | ('ㄎ', 'kʰ'),
56 | ('ㄏ', 'h'),
57 | ('ㄐ', 'ʧ⁼'),
58 | ('ㄑ', 'ʧʰ'),
59 | ('ㄒ', 'ʃ'),
60 | ('ㄓ', 'ʦ`⁼'),
61 | ('ㄔ', 'ʦ`ʰ'),
62 | ('ㄕ', 's`'),
63 | ('ㄖ', 'ɹ`'),
64 | ('ㄗ', 'ʦ⁼'),
65 | ('ㄘ', 'ʦʰ'),
66 | ('ㄙ', 's'),
67 | ('ㄚ', 'a'),
68 | ('ㄛ', 'o'),
69 | ('ㄜ', 'ə'),
70 | ('ㄝ', 'e'),
71 | ('ㄞ', 'ai'),
72 | ('ㄟ', 'ei'),
73 | ('ㄠ', 'au'),
74 | ('ㄡ', 'ou'),
75 | ('ㄧㄢ', 'yeNN'),
76 | ('ㄢ', 'aNN'),
77 | ('ㄧㄣ', 'iNN'),
78 | ('ㄣ', 'əNN'),
79 | ('ㄤ', 'aNg'),
80 | ('ㄧㄥ', 'iNg'),
81 | ('ㄨㄥ', 'uNg'),
82 | ('ㄩㄥ', 'yuNg'),
83 | ('ㄥ', 'əNg'),
84 | ('ㄦ', 'əɻ'),
85 | ('ㄧ', 'i'),
86 | ('ㄨ', 'u'),
87 | ('ㄩ', 'ɥ'),
88 | ('ˉ', '→'),
89 | ('ˊ', '↑'),
90 | ('ˇ', '↓↑'),
91 | ('ˋ', '↓'),
92 | ('˙', ''),
93 | (',', ','),
94 | ('。', '.'),
95 | ('!', '!'),
96 | ('?', '?'),
97 | ('—', '-')
98 | ]]
99 |
100 | # List of (romaji, ipa) pairs:
101 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
102 | ('ʃy', 'ʃ'),
103 | ('ʧʰy', 'ʧʰ'),
104 | ('ʧ⁼y', 'ʧ⁼'),
105 | ('NN', 'n'),
106 | ('Ng', 'ŋ'),
107 | ('y', 'j'),
108 | ('h', 'x')
109 | ]]
110 |
111 | # List of (bopomofo, ipa) pairs:
112 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
113 | ('ㄅㄛ', 'p⁼wo'),
114 | ('ㄆㄛ', 'pʰwo'),
115 | ('ㄇㄛ', 'mwo'),
116 | ('ㄈㄛ', 'fwo'),
117 | ('ㄅ', 'p⁼'),
118 | ('ㄆ', 'pʰ'),
119 | ('ㄇ', 'm'),
120 | ('ㄈ', 'f'),
121 | ('ㄉ', 't⁼'),
122 | ('ㄊ', 'tʰ'),
123 | ('ㄋ', 'n'),
124 | ('ㄌ', 'l'),
125 | ('ㄍ', 'k⁼'),
126 | ('ㄎ', 'kʰ'),
127 | ('ㄏ', 'x'),
128 | ('ㄐ', 'tʃ⁼'),
129 | ('ㄑ', 'tʃʰ'),
130 | ('ㄒ', 'ʃ'),
131 | ('ㄓ', 'ts`⁼'),
132 | ('ㄔ', 'ts`ʰ'),
133 | ('ㄕ', 's`'),
134 | ('ㄖ', 'ɹ`'),
135 | ('ㄗ', 'ts⁼'),
136 | ('ㄘ', 'tsʰ'),
137 | ('ㄙ', 's'),
138 | ('ㄚ', 'a'),
139 | ('ㄛ', 'o'),
140 | ('ㄜ', 'ə'),
141 | ('ㄝ', 'ɛ'),
142 | ('ㄞ', 'aɪ'),
143 | ('ㄟ', 'eɪ'),
144 | ('ㄠ', 'ɑʊ'),
145 | ('ㄡ', 'oʊ'),
146 | ('ㄧㄢ', 'jɛn'),
147 | ('ㄩㄢ', 'ɥæn'),
148 | ('ㄢ', 'an'),
149 | ('ㄧㄣ', 'in'),
150 | ('ㄩㄣ', 'ɥn'),
151 | ('ㄣ', 'ən'),
152 | ('ㄤ', 'ɑŋ'),
153 | ('ㄧㄥ', 'iŋ'),
154 | ('ㄨㄥ', 'ʊŋ'),
155 | ('ㄩㄥ', 'jʊŋ'),
156 | ('ㄥ', 'əŋ'),
157 | ('ㄦ', 'əɻ'),
158 | ('ㄧ', 'i'),
159 | ('ㄨ', 'u'),
160 | ('ㄩ', 'ɥ'),
161 | ('ˉ', '→'),
162 | ('ˊ', '↑'),
163 | ('ˇ', '↓↑'),
164 | ('ˋ', '↓'),
165 | ('˙', ''),
166 | (',', ','),
167 | ('。', '.'),
168 | ('!', '!'),
169 | ('?', '?'),
170 | ('—', '-')
171 | ]]
172 |
173 | # List of (bopomofo, ipa2) pairs:
174 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
175 | ('ㄅㄛ', 'pwo'),
176 | ('ㄆㄛ', 'pʰwo'),
177 | ('ㄇㄛ', 'mwo'),
178 | ('ㄈㄛ', 'fwo'),
179 | ('ㄅ', 'p'),
180 | ('ㄆ', 'pʰ'),
181 | ('ㄇ', 'm'),
182 | ('ㄈ', 'f'),
183 | ('ㄉ', 't'),
184 | ('ㄊ', 'tʰ'),
185 | ('ㄋ', 'n'),
186 | ('ㄌ', 'l'),
187 | ('ㄍ', 'k'),
188 | ('ㄎ', 'kʰ'),
189 | ('ㄏ', 'h'),
190 | ('ㄐ', 'tɕ'),
191 | ('ㄑ', 'tɕʰ'),
192 | ('ㄒ', 'ɕ'),
193 | ('ㄓ', 'tʂ'),
194 | ('ㄔ', 'tʂʰ'),
195 | ('ㄕ', 'ʂ'),
196 | ('ㄖ', 'ɻ'),
197 | ('ㄗ', 'ts'),
198 | ('ㄘ', 'tsʰ'),
199 | ('ㄙ', 's'),
200 | ('ㄚ', 'a'),
201 | ('ㄛ', 'o'),
202 | ('ㄜ', 'ɤ'),
203 | ('ㄝ', 'ɛ'),
204 | ('ㄞ', 'aɪ'),
205 | ('ㄟ', 'eɪ'),
206 | ('ㄠ', 'ɑʊ'),
207 | ('ㄡ', 'oʊ'),
208 | ('ㄧㄢ', 'jɛn'),
209 | ('ㄩㄢ', 'yæn'),
210 | ('ㄢ', 'an'),
211 | ('ㄧㄣ', 'in'),
212 | ('ㄩㄣ', 'yn'),
213 | ('ㄣ', 'ən'),
214 | ('ㄤ', 'ɑŋ'),
215 | ('ㄧㄥ', 'iŋ'),
216 | ('ㄨㄥ', 'ʊŋ'),
217 | ('ㄩㄥ', 'jʊŋ'),
218 | ('ㄥ', 'ɤŋ'),
219 | ('ㄦ', 'əɻ'),
220 | ('ㄧ', 'i'),
221 | ('ㄨ', 'u'),
222 | ('ㄩ', 'y'),
223 | ('ˉ', '˥'),
224 | ('ˊ', '˧˥'),
225 | ('ˇ', '˨˩˦'),
226 | ('ˋ', '˥˩'),
227 | ('˙', ''),
228 | (',', ','),
229 | ('。', '.'),
230 | ('!', '!'),
231 | ('?', '?'),
232 | ('—', '-')
233 | ]]
234 |
235 |
236 | def number_to_chinese(text):
237 | numbers = re.findall(r'\d+(?:\.?\d+)?', text)
238 | for number in numbers:
239 | text = text.replace(number, cn2an.an2cn(number), 1)
240 | return text
241 |
242 |
243 | def chinese_to_bopomofo(text):
244 | text = text.replace('、', ',').replace(';', ',').replace(':', ',')
245 | words = jieba.lcut(text, cut_all=False)
246 | text = ''
247 | for word in words:
248 | bopomofos = lazy_pinyin(word, BOPOMOFO)
249 | if not re.search('[\u4e00-\u9fff]', word):
250 | text += word
251 | continue
252 | for i in range(len(bopomofos)):
253 | bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
254 | if text != '':
255 | text += ' '
256 | text += ''.join(bopomofos)
257 | return text
258 |
259 |
260 | def latin_to_bopomofo(text):
261 | for regex, replacement in _latin_to_bopomofo:
262 | text = re.sub(regex, replacement, text)
263 | return text
264 |
265 |
266 | def bopomofo_to_romaji(text):
267 | for regex, replacement in _bopomofo_to_romaji:
268 | text = re.sub(regex, replacement, text)
269 | return text
270 |
271 |
272 | def bopomofo_to_ipa(text):
273 | for regex, replacement in _bopomofo_to_ipa:
274 | text = re.sub(regex, replacement, text)
275 | return text
276 |
277 |
278 | def bopomofo_to_ipa2(text):
279 | for regex, replacement in _bopomofo_to_ipa2:
280 | text = re.sub(regex, replacement, text)
281 | return text
282 |
283 |
284 | def chinese_to_romaji(text):
285 | text = number_to_chinese(text)
286 | text = chinese_to_bopomofo(text)
287 | text = latin_to_bopomofo(text)
288 | text = bopomofo_to_romaji(text)
289 | text = re.sub('i([aoe])', r'y\1', text)
290 | text = re.sub('u([aoəe])', r'w\1', text)
291 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
292 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
293 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
294 | return text
295 |
296 |
297 | def chinese_to_lazy_ipa(text):
298 | text = chinese_to_romaji(text)
299 | for regex, replacement in _romaji_to_ipa:
300 | text = re.sub(regex, replacement, text)
301 | return text
302 |
303 |
304 | def chinese_to_ipa(text):
305 | text = number_to_chinese(text)
306 | text = chinese_to_bopomofo(text)
307 | text = latin_to_bopomofo(text)
308 | text = bopomofo_to_ipa(text)
309 | text = re.sub('i([aoe])', r'j\1', text)
310 | text = re.sub('u([aoəe])', r'w\1', text)
311 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
312 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
313 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
314 | return text
315 |
316 |
317 | def chinese_to_ipa2(text):
318 | text = number_to_chinese(text)
319 | text = chinese_to_bopomofo(text)
320 | text = latin_to_bopomofo(text)
321 | text = bopomofo_to_ipa2(text)
322 | text = re.sub(r'i([aoe])', r'j\1', text)
323 | text = re.sub(r'u([aoəe])', r'w\1', text)
324 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
325 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
326 | return text
327 |
--------------------------------------------------------------------------------
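A minimal sketch of the Mandarin pipeline above (assumptions: run from inside the VITS directory, cn2an, jieba and pypinyin from requirements.txt installed, sentence invented):

    from text.mandarin import chinese_to_ipa, chinese_to_ipa2

    # cn2an spells out the digits, jieba segments the words, pypinyin emits
    # bopomofo, and the tables above map bopomofo plus tone marks to IPA.
    print(chinese_to_ipa("今天气温25度"))
    print(chinese_to_ipa2("今天气温25度"))
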
/VITS/text/ngu_dialect.py:
--------------------------------------------------------------------------------
1 | import re
2 | import opencc
3 |
4 |
5 | dialects = {'SZ': 'suzhou', 'WX': 'wuxi', 'CZ': 'changzhou', 'HZ': 'hangzhou',
6 | 'SX': 'shaoxing', 'NB': 'ningbo', 'JJ': 'jingjiang', 'YX': 'yixing',
7 | 'JD': 'jiading', 'ZR': 'zhenru', 'PH': 'pinghu', 'TX': 'tongxiang',
8 | 'JS': 'jiashan', 'HN': 'xiashi', 'LP': 'linping', 'XS': 'xiaoshan',
9 | 'FY': 'fuyang', 'RA': 'ruao', 'CX': 'cixi', 'SM': 'sanmen',
10 | 'TT': 'tiantai', 'WZ': 'wenzhou', 'SC': 'suichang', 'YB': 'youbu'}
11 |
12 | converters = {}
13 |
14 | for dialect in dialects.values():
15 | try:
16 | converters[dialect] = opencc.OpenCC(dialect)
17 | except:
18 | pass
19 |
20 |
21 | def ngu_dialect_to_ipa(text, dialect):
22 | dialect = dialects[dialect]
23 | text = converters[dialect].convert(text).replace('-','').replace('$',' ')
24 | text = re.sub(r'[、;:]', ',', text)
25 | text = re.sub(r'\s*,\s*', ', ', text)
26 | text = re.sub(r'\s*。\s*', '. ', text)
27 | text = re.sub(r'\s*?\s*', '? ', text)
28 | text = re.sub(r'\s*!\s*', '! ', text)
29 | text = re.sub(r'\s*$', '', text)
30 | return text
31 |
--------------------------------------------------------------------------------
/VITS/text/sanskrit.py:
--------------------------------------------------------------------------------
1 | import re
2 | from indic_transliteration import sanscript
3 |
4 |
5 | # List of (iast, ipa) pairs:
6 | _iast_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
7 | ('a', 'ə'),
8 | ('ā', 'aː'),
9 | ('ī', 'iː'),
10 | ('ū', 'uː'),
11 | ('ṛ', 'ɹ`'),
12 | ('ṝ', 'ɹ`ː'),
13 | ('ḷ', 'l`'),
14 | ('ḹ', 'l`ː'),
15 | ('e', 'eː'),
16 | ('o', 'oː'),
17 | ('k', 'k⁼'),
18 | ('k⁼h', 'kʰ'),
19 | ('g', 'g⁼'),
20 | ('g⁼h', 'gʰ'),
21 | ('ṅ', 'ŋ'),
22 | ('c', 'ʧ⁼'),
23 | ('ʧ⁼h', 'ʧʰ'),
24 | ('j', 'ʥ⁼'),
25 | ('ʥ⁼h', 'ʥʰ'),
26 | ('ñ', 'n^'),
27 | ('ṭ', 't`⁼'),
28 | ('t`⁼h', 't`ʰ'),
29 | ('ḍ', 'd`⁼'),
30 | ('d`⁼h', 'd`ʰ'),
31 | ('ṇ', 'n`'),
32 | ('t', 't⁼'),
33 | ('t⁼h', 'tʰ'),
34 | ('d', 'd⁼'),
35 | ('d⁼h', 'dʰ'),
36 | ('p', 'p⁼'),
37 | ('p⁼h', 'pʰ'),
38 | ('b', 'b⁼'),
39 | ('b⁼h', 'bʰ'),
40 | ('y', 'j'),
41 | ('ś', 'ʃ'),
42 | ('ṣ', 's`'),
43 | ('r', 'ɾ'),
44 | ('l̤', 'l`'),
45 | ('h', 'ɦ'),
46 | ("'", ''),
47 | ('~', '^'),
48 | ('ṃ', '^')
49 | ]]
50 |
51 |
52 | def devanagari_to_ipa(text):
53 | text = text.replace('ॐ', 'ओम्')
54 | text = re.sub(r'\s*।\s*$', '.', text)
55 | text = re.sub(r'\s*।\s*', ', ', text)
56 | text = re.sub(r'\s*॥', '.', text)
57 | text = sanscript.transliterate(text, sanscript.DEVANAGARI, sanscript.IAST)
58 | for regex, replacement in _iast_to_ipa:
59 | text = re.sub(regex, replacement, text)
60 | text = re.sub('(.)[`ː]*ḥ', lambda x: x.group(0)
61 | [:-1]+'h'+x.group(1)+'*', text)
62 | return text
63 |
--------------------------------------------------------------------------------
/VITS/text/shanghainese.py:
--------------------------------------------------------------------------------
1 | import re
2 | import cn2an
3 | import opencc
4 |
5 |
6 | converter = opencc.OpenCC('zaonhe')
7 |
8 | # List of (Latin alphabet, ipa) pairs:
9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
10 | ('A', 'ᴇ'),
11 | ('B', 'bi'),
12 | ('C', 'si'),
13 | ('D', 'di'),
14 | ('E', 'i'),
15 | ('F', 'ᴇf'),
16 | ('G', 'dʑi'),
17 | ('H', 'ᴇtɕʰ'),
18 | ('I', 'ᴀi'),
19 | ('J', 'dʑᴇ'),
20 | ('K', 'kʰᴇ'),
21 | ('L', 'ᴇl'),
22 | ('M', 'ᴇm'),
23 | ('N', 'ᴇn'),
24 | ('O', 'o'),
25 | ('P', 'pʰi'),
26 | ('Q', 'kʰiu'),
27 | ('R', 'ᴀl'),
28 | ('S', 'ᴇs'),
29 | ('T', 'tʰi'),
30 | ('U', 'ɦiu'),
31 | ('V', 'vi'),
32 | ('W', 'dᴀbɤliu'),
33 | ('X', 'ᴇks'),
34 | ('Y', 'uᴀi'),
35 | ('Z', 'zᴇ')
36 | ]]
37 |
38 |
39 | def _number_to_shanghainese(num):
40 | num = cn2an.an2cn(num).replace('一十','十').replace('二十', '廿').replace('二', '两')
41 | return re.sub(r'((?:^|[^三四五六七八九])十|廿)两', r'\1二', num)
42 |
43 |
44 | def number_to_shanghainese(text):
45 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: _number_to_shanghainese(x.group()), text)
46 |
47 |
48 | def latin_to_ipa(text):
49 | for regex, replacement in _latin_to_ipa:
50 | text = re.sub(regex, replacement, text)
51 | return text
52 |
53 |
54 | def shanghainese_to_ipa(text):
55 | text = number_to_shanghainese(text.upper())
56 | text = converter.convert(text).replace('-','').replace('$',' ')
57 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text)
58 | text = re.sub(r'[、;:]', ',', text)
59 | text = re.sub(r'\s*,\s*', ', ', text)
60 | text = re.sub(r'\s*。\s*', '. ', text)
61 | text = re.sub(r'\s*?\s*', '? ', text)
62 | text = re.sub(r'\s*!\s*', '! ', text)
63 | text = re.sub(r'\s*$', '', text)
64 | return text
65 |
--------------------------------------------------------------------------------
/VITS/text/symbols.py:
--------------------------------------------------------------------------------
1 | '''
2 | Defines the set of symbols used in text input to the model.
3 | '''
4 |
5 | # japanese_cleaners
6 | # _pad = '_'
7 | # _punctuation = ',.!?-'
8 | # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
9 |
10 |
11 | '''# japanese_cleaners2
12 | _pad = '_'
13 | _punctuation = ',.!?-~…'
14 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
15 | '''
16 |
17 |
18 | '''# korean_cleaners
19 | _pad = '_'
20 | _punctuation = ',.!?…~'
21 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
22 | '''
23 |
24 | '''# chinese_cleaners
25 | _pad = '_'
26 | _punctuation = ',。!?—…'
27 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
28 | '''
29 |
30 | # # zh_ja_mixture_cleaners
31 | # _pad = '_'
32 | # _punctuation = ',.!?-~…'
33 | # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
34 |
35 |
36 | '''# sanskrit_cleaners
37 | _pad = '_'
38 | _punctuation = '।'
39 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
40 | '''
41 |
42 | '''# cjks_cleaners
43 | _pad = '_'
44 | _punctuation = ',.!?-~…'
45 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
46 | '''
47 |
48 | '''# thai_cleaners
49 | _pad = '_'
50 | _punctuation = '.!? '
51 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
52 | '''
53 |
54 | # # cjke_cleaners2
55 | _pad = '_'
56 | _punctuation = ',.!?-~…'
57 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
58 |
59 |
60 | '''# shanghainese_cleaners
61 | _pad = '_'
62 | _punctuation = ',.!?…'
63 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
64 | '''
65 |
66 | '''# chinese_dialect_cleaners
67 | _pad = '_'
68 | _punctuation = ',.!?~…─'
69 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
70 | '''
71 |
72 | # Export all symbols:
73 | symbols = [_pad] + list(_punctuation) + list(_letters)
74 |
75 | # Special symbol ids
76 | SPACE_ID = symbols.index(" ")
77 |
--------------------------------------------------------------------------------
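The symbols list exported above is what the rest of the text front end indexes into. The actual text-to-id mapping lives in VITS/text/__init__.py, which is not reproduced in this listing, so the helper below is only an illustrative sketch of that lookup (run from inside the VITS directory):

    from text.symbols import symbols, SPACE_ID

    # Illustrative only: the real project keeps an equivalent dictionary
    # in VITS/text/__init__.py, which may differ in detail.
    _symbol_to_id = {s: i for i, s in enumerate(symbols)}

    def cleaned_text_to_ids(cleaned_text):
        # Characters outside the symbol set are silently skipped here.
        return [_symbol_to_id[ch] for ch in cleaned_text if ch in _symbol_to_id]

    print(SPACE_ID, cleaned_text_to_ids("konnitʃiwa, world."))
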
/VITS/text/thai.py:
--------------------------------------------------------------------------------
1 | import re
2 | from num_thai.thainumbers import NumThai
3 |
4 |
5 | num = NumThai()
6 |
7 | # List of (Latin alphabet, Thai) pairs:
8 | _latin_to_thai = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
9 | ('a', 'เอ'),
10 | ('b','บี'),
11 | ('c','ซี'),
12 | ('d','ดี'),
13 | ('e','อี'),
14 | ('f','เอฟ'),
15 | ('g','จี'),
16 | ('h','เอช'),
17 | ('i','ไอ'),
18 | ('j','เจ'),
19 | ('k','เค'),
20 | ('l','แอล'),
21 | ('m','เอ็ม'),
22 | ('n','เอ็น'),
23 | ('o','โอ'),
24 | ('p','พี'),
25 | ('q','คิว'),
26 | ('r','แอร์'),
27 | ('s','เอส'),
28 | ('t','ที'),
29 | ('u','ยู'),
30 | ('v','วี'),
31 | ('w','ดับเบิลยู'),
32 | ('x','เอ็กซ์'),
33 | ('y','วาย'),
34 | ('z','ซี')
35 | ]]
36 |
37 |
38 | def num_to_thai(text):
39 | return re.sub(r'(?:\d+(?:,?\d+)?)+(?:\.\d+(?:,?\d+)?)?', lambda x: ''.join(num.NumberToTextThai(float(x.group(0).replace(',', '')))), text)
40 |
41 | def latin_to_thai(text):
42 | for regex, replacement in _latin_to_thai:
43 | text = re.sub(regex, replacement, text)
44 | return text
45 |
--------------------------------------------------------------------------------
/VITS/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import numpy as np
5 |
6 |
7 | DEFAULT_MIN_BIN_WIDTH = 1e-3
8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
9 | DEFAULT_MIN_DERIVATIVE = 1e-3
10 |
11 |
12 | def piecewise_rational_quadratic_transform(inputs,
13 | unnormalized_widths,
14 | unnormalized_heights,
15 | unnormalized_derivatives,
16 | inverse=False,
17 | tails=None,
18 | tail_bound=1.,
19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
21 | min_derivative=DEFAULT_MIN_DERIVATIVE):
22 |
23 | if tails is None:
24 | spline_fn = rational_quadratic_spline
25 | spline_kwargs = {}
26 | else:
27 | spline_fn = unconstrained_rational_quadratic_spline
28 | spline_kwargs = {
29 | 'tails': tails,
30 | 'tail_bound': tail_bound
31 | }
32 |
33 | outputs, logabsdet = spline_fn(
34 | inputs=inputs,
35 | unnormalized_widths=unnormalized_widths,
36 | unnormalized_heights=unnormalized_heights,
37 | unnormalized_derivatives=unnormalized_derivatives,
38 | inverse=inverse,
39 | min_bin_width=min_bin_width,
40 | min_bin_height=min_bin_height,
41 | min_derivative=min_derivative,
42 | **spline_kwargs
43 | )
44 | return outputs, logabsdet
45 |
46 |
47 | def searchsorted(bin_locations, inputs, eps=1e-6):
48 | bin_locations[..., -1] += eps
49 | return torch.sum(
50 | inputs[..., None] >= bin_locations,
51 | dim=-1
52 | ) - 1
53 |
54 |
55 | def unconstrained_rational_quadratic_spline(inputs,
56 | unnormalized_widths,
57 | unnormalized_heights,
58 | unnormalized_derivatives,
59 | inverse=False,
60 | tails='linear',
61 | tail_bound=1.,
62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
64 | min_derivative=DEFAULT_MIN_DERIVATIVE):
65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
66 | outside_interval_mask = ~inside_interval_mask
67 |
68 | outputs = torch.zeros_like(inputs)
69 | logabsdet = torch.zeros_like(inputs)
70 |
71 | if tails == 'linear':
72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
73 | constant = np.log(np.exp(1 - min_derivative) - 1)
74 | unnormalized_derivatives[..., 0] = constant
75 | unnormalized_derivatives[..., -1] = constant
76 |
77 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
78 | logabsdet[outside_interval_mask] = 0
79 | else:
80 | raise RuntimeError('{} tails are not implemented.'.format(tails))
81 |
82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
83 | inputs=inputs[inside_interval_mask],
84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87 | inverse=inverse,
88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
89 | min_bin_width=min_bin_width,
90 | min_bin_height=min_bin_height,
91 | min_derivative=min_derivative
92 | )
93 |
94 | return outputs, logabsdet
95 |
96 | def rational_quadratic_spline(inputs,
97 | unnormalized_widths,
98 | unnormalized_heights,
99 | unnormalized_derivatives,
100 | inverse=False,
101 | left=0., right=1., bottom=0., top=1.,
102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
104 | min_derivative=DEFAULT_MIN_DERIVATIVE):
105 | if torch.min(inputs) < left or torch.max(inputs) > right:
106 | raise ValueError('Input to a transform is not within its domain')
107 |
108 | num_bins = unnormalized_widths.shape[-1]
109 |
110 | if min_bin_width * num_bins > 1.0:
111 | raise ValueError('Minimal bin width too large for the number of bins')
112 | if min_bin_height * num_bins > 1.0:
113 | raise ValueError('Minimal bin height too large for the number of bins')
114 |
115 | widths = F.softmax(unnormalized_widths, dim=-1)
116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
117 | cumwidths = torch.cumsum(widths, dim=-1)
118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
119 | cumwidths = (right - left) * cumwidths + left
120 | cumwidths[..., 0] = left
121 | cumwidths[..., -1] = right
122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1]
123 |
124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives)
125 |
126 | heights = F.softmax(unnormalized_heights, dim=-1)
127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
128 | cumheights = torch.cumsum(heights, dim=-1)
129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
130 | cumheights = (top - bottom) * cumheights + bottom
131 | cumheights[..., 0] = bottom
132 | cumheights[..., -1] = top
133 | heights = cumheights[..., 1:] - cumheights[..., :-1]
134 |
135 | if inverse:
136 | bin_idx = searchsorted(cumheights, inputs)[..., None]
137 | else:
138 | bin_idx = searchsorted(cumwidths, inputs)[..., None]
139 |
140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
142 |
143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
144 | delta = heights / widths
145 | input_delta = delta.gather(-1, bin_idx)[..., 0]
146 |
147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
149 |
150 | input_heights = heights.gather(-1, bin_idx)[..., 0]
151 |
152 | if inverse:
153 | a = (((inputs - input_cumheights) * (input_derivatives
154 | + input_derivatives_plus_one
155 | - 2 * input_delta)
156 | + input_heights * (input_delta - input_derivatives)))
157 | b = (input_heights * input_derivatives
158 | - (inputs - input_cumheights) * (input_derivatives
159 | + input_derivatives_plus_one
160 | - 2 * input_delta))
161 | c = - input_delta * (inputs - input_cumheights)
162 |
163 | discriminant = b.pow(2) - 4 * a * c
164 | assert (discriminant >= 0).all()
165 |
166 | root = (2 * c) / (-b - torch.sqrt(discriminant))
167 | outputs = root * input_bin_widths + input_cumwidths
168 |
169 | theta_one_minus_theta = root * (1 - root)
170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
171 | * theta_one_minus_theta)
172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
173 | + 2 * input_delta * theta_one_minus_theta
174 | + input_derivatives * (1 - root).pow(2))
175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
176 |
177 | return outputs, -logabsdet
178 | else:
179 | theta = (inputs - input_cumwidths) / input_bin_widths
180 | theta_one_minus_theta = theta * (1 - theta)
181 |
182 | numerator = input_heights * (input_delta * theta.pow(2)
183 | + input_derivatives * theta_one_minus_theta)
184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
185 | * theta_one_minus_theta)
186 | outputs = input_cumheights + numerator / denominator
187 |
188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
189 | + 2 * input_delta * theta_one_minus_theta
190 | + input_derivatives * (1 - theta).pow(2))
191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
192 |
193 | return outputs, logabsdet
194 |
--------------------------------------------------------------------------------
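The spline functions above work over arbitrary leading dimensions; a minimal invented usage sketch (not taken from the repository) of calling them on a flat batch. With tails='linear', a K-bin spline takes K widths, K heights and K-1 unnormalised derivatives per element, because the derivative tensor is padded to the K+1 knots internally:

    import torch
    from transforms import piecewise_rational_quadratic_transform  # run from the VITS directory

    batch, num_bins = 4, 10
    inputs = torch.rand(batch) * 2 - 1            # values inside [-1, 1]
    widths = torch.randn(batch, num_bins)
    heights = torch.randn(batch, num_bins)
    derivs = torch.randn(batch, num_bins - 1)     # padded to num_bins + 1 knots

    outputs, logabsdet = piecewise_rational_quadratic_transform(
        inputs, widths, heights, derivs,
        inverse=False, tails='linear', tail_bound=1.0)
    print(outputs.shape, logabsdet.shape)         # both torch.Size([4])
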
/VITS/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 | import argparse
5 | import logging
6 | import json
7 | import subprocess
8 | import numpy as np
9 | from scipy.io.wavfile import read
10 | import torch
11 |
12 | MATPLOTLIB_FLAG = False
13 |
14 | #logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
15 | logger = logging
16 |
17 |
18 | def load_checkpoint(checkpoint_path, model, optimizer=None, drop_speaker_emb=False):
19 | assert os.path.isfile(checkpoint_path)
20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
21 | iteration = checkpoint_dict['iteration']
22 | learning_rate = checkpoint_dict['learning_rate']
23 | if optimizer is not None:
24 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
25 | saved_state_dict = checkpoint_dict['model']
26 | if hasattr(model, 'module'):
27 | state_dict = model.module.state_dict()
28 | else:
29 | state_dict = model.state_dict()
30 | new_state_dict = {}
31 | for k, v in state_dict.items():
32 | try:
33 | if k == 'emb_g.weight':
34 | if drop_speaker_emb:
35 | new_state_dict[k] = v
36 | continue
37 | v[:saved_state_dict[k].shape[0], :] = saved_state_dict[k]
38 | new_state_dict[k] = v
39 | else:
40 | new_state_dict[k] = saved_state_dict[k]
41 | except:
42 | logger.info("%s is not in the checkpoint" % k)
43 | new_state_dict[k] = v
44 | if hasattr(model, 'module'):
45 | model.module.load_state_dict(new_state_dict)
46 | else:
47 | model.load_state_dict(new_state_dict)
48 | logger.info("Loaded checkpoint '{}' (iteration {})".format(
49 | checkpoint_path, iteration))
50 | return model, optimizer, learning_rate, iteration
51 |
52 |
53 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
54 | logger.info("Saving model and optimizer state at iteration {} to {}".format(
55 | iteration, checkpoint_path))
56 | if hasattr(model, 'module'):
57 | state_dict = model.module.state_dict()
58 | else:
59 | state_dict = model.state_dict()
60 | torch.save({'model': state_dict,
61 | 'iteration': iteration,
62 | 'optimizer': optimizer.state_dict() if optimizer is not None else None,
63 | 'learning_rate': learning_rate}, checkpoint_path)
64 |
65 |
66 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
67 | for k, v in scalars.items():
68 | writer.add_scalar(k, v, global_step)
69 | for k, v in histograms.items():
70 | writer.add_histogram(k, v, global_step)
71 | for k, v in images.items():
72 | writer.add_image(k, v, global_step, dataformats='HWC')
73 | for k, v in audios.items():
74 | writer.add_audio(k, v, global_step, audio_sampling_rate)
75 |
76 |
77 | def latest_checkpoint_path(dir_path, regex="G_*.pth"):
78 | f_list = glob.glob(os.path.join(dir_path, regex))
79 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
80 | x = f_list[-1]
81 | #print(x)
82 | return x
83 |
84 |
85 | def plot_spectrogram_to_numpy(spectrogram):
86 | global MATPLOTLIB_FLAG
87 | if not MATPLOTLIB_FLAG:
88 | import matplotlib
89 | matplotlib.use("Agg")
90 | MATPLOTLIB_FLAG = True
91 | #mpl_logger = logging.getLogger('matplotlib')
92 | #mpl_logger.setLevel(logging.WARNING)
93 | import matplotlib.pylab as plt
94 | import numpy as np
95 |
96 | fig, ax = plt.subplots(figsize=(10, 2))
97 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
98 | interpolation='none')
99 | plt.colorbar(im, ax=ax)
100 | plt.xlabel("Frames")
101 | plt.ylabel("Channels")
102 | plt.tight_layout()
103 |
104 | fig.canvas.draw()
105 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
106 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
107 | plt.close()
108 | return data
109 |
110 |
111 | def plot_alignment_to_numpy(alignment, info=None):
112 | global MATPLOTLIB_FLAG
113 | if not MATPLOTLIB_FLAG:
114 | import matplotlib
115 | matplotlib.use("Agg")
116 | MATPLOTLIB_FLAG = True
117 | #mpl_logger = logging.getLogger('matplotlib')
118 | #mpl_logger.setLevel(logging.WARNING)
119 | import matplotlib.pylab as plt
120 | import numpy as np
121 |
122 | fig, ax = plt.subplots(figsize=(6, 4))
123 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
124 | interpolation='none')
125 | fig.colorbar(im, ax=ax)
126 | xlabel = 'Decoder timestep'
127 | if info is not None:
128 | xlabel += '\n\n' + info
129 | plt.xlabel(xlabel)
130 | plt.ylabel('Encoder timestep')
131 | plt.tight_layout()
132 |
133 | fig.canvas.draw()
134 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
135 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
136 | plt.close()
137 | return data
138 |
139 |
140 | def load_wav_to_torch(full_path):
141 | sampling_rate, data = read(full_path)
142 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate
143 |
144 |
145 | def load_filepaths_and_text(filename, split="|"):
146 | with open(filename, encoding='utf-8') as f:
147 | filepaths_and_text = [line.strip().split(split) for line in f]
148 | return filepaths_and_text
149 |
150 |
151 | def get_hparams(init=True):
152 | parser = argparse.ArgumentParser()
153 | parser.add_argument('-c', '--config', type=str, default="./configs/modified_finetune_speaker.json",
154 | help='JSON file for configuration')
155 | parser.add_argument('-m', '--model', type=str, default="pretrained_models",
156 | help='Model name')
157 | parser.add_argument('-n', '--max_epochs', type=int, default=50,
158 | help='finetune epochs')
159 | parser.add_argument('--drop_speaker_embed', type=bool, default=False, help='whether to drop existing characters')
160 |
161 | args = parser.parse_args()
162 | model_dir = os.path.join("./", args.model)
163 |
164 | if not os.path.exists(model_dir):
165 | os.makedirs(model_dir)
166 |
167 | config_path = args.config
168 | config_save_path = os.path.join(model_dir, "config.json")
169 | if init:
170 | with open(config_path, "r") as f:
171 | data = f.read()
172 | with open(config_save_path, "w") as f:
173 | f.write(data)
174 | else:
175 | with open(config_save_path, "r") as f:
176 | data = f.read()
177 | config = json.loads(data)
178 |
179 | hparams = HParams(**config)
180 | hparams.model_dir = model_dir
181 | hparams.max_epochs = args.max_epochs
182 | hparams.drop_speaker_embed = args.drop_speaker_embed
183 | return hparams
184 |
185 |
186 | def get_hparams_from_dir(model_dir):
187 | config_save_path = os.path.join(model_dir, "config.json")
188 | with open(config_save_path, "r") as f:
189 | data = f.read()
190 | config = json.loads(data)
191 |
192 | hparams = HParams(**config)
193 | hparams.model_dir = model_dir
194 | return hparams
195 |
196 |
197 | def get_hparams_from_file(config_path):
198 | with open(config_path, "r", encoding="utf-8") as f:
199 | data = f.read()
200 | config = json.loads(data)
201 |
202 | hparams = HParams(**config)
203 | return hparams
204 |
205 |
206 | def check_git_hash(model_dir):
207 | source_dir = os.path.dirname(os.path.realpath(__file__))
208 | if not os.path.exists(os.path.join(source_dir, ".git")):
209 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
210 | source_dir
211 | ))
212 | return
213 |
214 | cur_hash = subprocess.getoutput("git rev-parse HEAD")
215 |
216 | path = os.path.join(model_dir, "githash")
217 | if os.path.exists(path):
218 | saved_hash = open(path).read()
219 | if saved_hash != cur_hash:
220 | logger.warn("git hash values are different. {}(saved) != {}(current)".format(
221 | saved_hash[:8], cur_hash[:8]))
222 | else:
223 | open(path, "w").write(cur_hash)
224 |
225 |
226 | def get_logger(model_dir, filename="train.log"):
227 | global logger
228 | logger = logging.getLogger(os.path.basename(model_dir))
229 | #logger.setLevel(logging.DEBUG)
230 |
231 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
232 | if not os.path.exists(model_dir):
233 | os.makedirs(model_dir)
234 | h = logging.FileHandler(os.path.join(model_dir, filename))
235 | #h.setLevel(logging.DEBUG)
236 | h.setFormatter(formatter)
237 | logger.addHandler(h)
238 | return logger
239 |
240 |
241 | class HParams():
242 | def __init__(self, **kwargs):
243 | for k, v in kwargs.items():
244 | if type(v) == dict:
245 | v = HParams(**v)
246 | self[k] = v
247 |
248 | def keys(self):
249 | return self.__dict__.keys()
250 |
251 | def items(self):
252 | return self.__dict__.items()
253 |
254 | def values(self):
255 | return self.__dict__.values()
256 |
257 | def __len__(self):
258 | return len(self.__dict__)
259 |
260 | def __getitem__(self, key):
261 | return getattr(self, key)
262 |
263 | def __setitem__(self, key, value):
264 | return setattr(self, key, value)
265 |
266 | def __contains__(self, key):
267 | return key in self.__dict__
268 |
269 | def __repr__(self):
270 | return self.__dict__.__repr__()
--------------------------------------------------------------------------------
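HParams above wraps nested dicts so that config values can be read with attribute access; a minimal sketch with an invented config (real configs live in VITS/configs/ and are normally loaded through get_hparams_from_file):

    from utils import HParams  # run from the VITS directory

    cfg = {"data": {"sampling_rate": 22050}, "train": {"batch_size": 16}}
    hps = HParams(**cfg)
    # Nested dicts become nested HParams, so attribute and item access both work.
    print(hps.data.sampling_rate, hps["train"].batch_size, "data" in hps)
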
/baiduApi.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*- #
2 | from sys import byteorder
3 | from array import array
4 | from struct import pack
5 | import time
6 |
7 | import pyaudio
8 | import wave
9 | import subprocess
10 | import urllib.request
11 | import urllib
12 | import json
13 | import base64
14 | import requests
15 |
16 | class BaiduRest:
17 | def __init__(self, cu_id, api_key, api_secert):
18 | self.token_url = "https://openapi.baidu.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s"
19 | self.getvoice_url = "http://tsn.baidu.com/text2audio?tex=%s&lan=zh&cuid=%s&ctp=1&tok=%s"
20 | self.upvoice_url = 'http://vop.baidu.com/server_api'
21 | self.cu_id = cu_id
22 | self.getToken(api_key, api_secert)
23 | self.THRESHOLD = 500
24 | self.CHUNK_SIZE = 1024
25 | self.FORMAT = pyaudio.paInt16
26 | self.RATE = 8000
27 |
28 | return
29 |
30 | def getToken(self, api_key, api_secert):
31 | token_url = self.token_url % (api_key, api_secert)
32 |
33 | r_str = urllib.request.urlopen(token_url).read()
34 | token_data = json.loads(r_str)
35 | self.token_str = token_data['access_token']
36 | pass
37 |
38 | def getVoice(self, text, filename):
39 | get_url = self.getvoice_url % (
40 | urllib.parse.quote(text), self.cu_id, self.token_str)
41 |
42 | voice_data = requests.post(get_url).content
43 | voice_fp = open(filename, 'wb+')
44 | voice_fp.write(voice_data)
45 | voice_fp.close()
46 | pass
47 |
48 | def getText(self, filename):
49 | data = {}
50 | data['format'] = 'wav'
51 | data['rate'] = 8000
52 | data['channel'] = 1
53 | data['cuid'] = self.cu_id
54 | data['token'] = self.token_str
55 | wav_fp = open(filename, 'rb')
56 | voice_data = wav_fp.read()
57 | data['len'] = len(voice_data)
58 | data['speech'] = base64.b64encode(voice_data).decode('utf-8')
59 | post_data = json.dumps(data)
60 | r_data = requests.post(
61 | self.upvoice_url, data=bytes(
62 | post_data, encoding="utf-8")).text
63 | # 3. Process the response data returned by the API
64 |
65 | return json.loads(r_data)['result'][0]
66 |
67 | # def speakMac(self, audio_file):
68 | # return_code = subprocess.call(["afplay", audio_file])
69 | # return return_code
70 |
71 | # def speak(self, audio_file):
72 | # import mp3play
73 | # clip = mp3play.load(audio_file)
74 | # clip.play()
75 | # time.sleep(10)
76 | # clip.stop()
77 |
78 | def is_silent(self, snd_data):
79 | return max(snd_data) < self.THRESHOLD
80 |
81 | def normalize(self, snd_data):
82 | MAXIMUM = 16384
83 | times = float(MAXIMUM) / max(abs(i) for i in snd_data)
84 |
85 | r = array('h')
86 | for i in snd_data:
87 | r.append(int(i * times))
88 | return r
89 |
90 | def trim(self, snd_data):
91 | def _trim(snd_data):
92 | snd_started = False
93 | r = array('h')
94 |
95 | for i in snd_data:
96 | if not snd_started and abs(i) > self.THRESHOLD:
97 | snd_started = True
98 | r.append(i)
99 |
100 | elif snd_started:
101 | r.append(i)
102 | return r
103 |
104 | # Trim to the left
105 | snd_data = _trim(snd_data)
106 |
107 | # Trim to the right
108 | snd_data.reverse()
109 | snd_data = _trim(snd_data)
110 | snd_data.reverse()
111 | return snd_data
112 |
113 | def add_silence(self, snd_data, seconds):
114 | r = array('h', [0 for i in range(int(seconds * self.RATE))])
115 | r.extend(snd_data)
116 | r.extend([0 for i in range(int(seconds * self.RATE))])
117 | return r
118 |
119 | def record(self):  # record audio from the microphone
120 | p = pyaudio.PyAudio()
121 | stream = p.open(format=self.FORMAT, channels=1, rate=self.RATE,
122 | input=True, output=True,
123 | frames_per_buffer=self.CHUNK_SIZE)
124 |
125 | num_silent = 0
126 | snd_started = False
127 |
128 | r = array('h')
129 |
130 | while True:
131 | snd_data = array('h', stream.read(self.CHUNK_SIZE))
132 | if byteorder == 'big':
133 | snd_data.byteswap()
134 | r.extend(snd_data)
135 |
136 | silent = self.is_silent(snd_data)
137 |
138 | if silent and snd_started:
139 | num_silent += 1
140 | elif not silent and not snd_started:
141 | snd_started = True
142 |
143 | if snd_started and num_silent > 10:
144 | break
145 |
146 | sample_width = p.get_sample_size(self.FORMAT)
147 | stream.stop_stream()
148 | stream.close()
149 | p.terminate()
150 |
151 | r = self.normalize(r)
152 | r = self.trim(r)
153 | r = self.add_silence(r, 0.5)
154 | return sample_width, r
155 |
156 | def record_to_file(self, path):
157 |
158 | sample_width, data = self.record()
159 | data = pack('<' + ('h' * len(data)), *data)
160 |
161 | wf = wave.open(path, 'wb')
162 | wf.setnchannels(1)
163 | wf.setsampwidth(sample_width)
164 | wf.setframerate(self.RATE)
165 | wf.writeframes(data)
166 | wf.close()
167 |
168 | def recorder(self, filename):
169 | self.record_to_file(filename)
170 |
171 |
--------------------------------------------------------------------------------
/gptApi.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import os
3 | import ssl
4 |
5 | import warnings
6 | warnings.filterwarnings("ignore")
7 | # It is enough to set the proxy inside Python
8 | #os.environ['HTTP_PROXY'] = 'http://202.79.168.14:6666'
9 | #os.environ['HTTPS_PROXY'] = 'http://202.79.168.14:6666'
10 | # openai cannot use the global proxy; add the proxy in lib\site-packages\openai\api_requestor.py instead
11 |
12 | # ssl._create_default_https_context = ssl._create_unverified_context
13 | # openai.api_key = '*********'
14 | # # openai.proxy = 'http://202.79.168.46:6666'
15 | # # openai.verify_ssl_certs = False
16 | # openai.api_base='https://202.79.168.46/v1'
17 | # def test_openai(string):
18 | # completion = openai.ChatCompletion.create(
19 | # model="gpt-3.5-turbo",
20 | # messages=[{"role": "user", "content": string}]
21 | # )
22 | # return completion['choices'][0]['message']['content'].strip()
23 | #
24 | # res=test_openai("巩义在哪里?")
25 | #
26 | # print(res)
27 |
28 | class GptApi:
29 | def __init__(self):
30 | self.api_key='***********'
31 | openai.api_key = self.api_key
32 | openai.api_base='https://202.79.168.46/v1'
33 |
34 | def test_openai(self,string):
35 | completion = openai.ChatCompletion.create(
36 | model="gpt-3.5-turbo",
37 | messages=[{"role": "user", "content": string}]
38 | )
39 | return completion['choices'][0]['message']['content'].strip()
40 |
41 |
42 |
43 | if __name__ == '__main__':
44 | test=GptApi()
45 | res=test.test_openai('你是谁?')
46 | print(res)
47 |
--------------------------------------------------------------------------------
/gptApi_duolun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import openai
3 | from dotenv.main import load_dotenv
4 |
5 | load_dotenv()
6 | openai.api_key = os.environ['openai_api']
7 | openai.api_base='https://202.79.168.46/v1'
8 |
9 | # Single chat-completion call (multi-turn context is carried in the messages list)
10 | # model can be "gpt-3.5-turbo" or "gpt-3.5-turbo-0301"
11 | def generate_answer(messages):
12 | completion = openai.ChatCompletion.create(
13 | model="gpt-3.5-turbo",
14 | messages=messages,
15 | temperature=0.7
16 | )
17 | res_msg = completion.choices[0].message
18 | return res_msg["content"].strip()
19 |
20 |
21 | if __name__ == '__main__':
22 | # Maintain a list that stores the multi-turn conversation history
23 | messages = [{"role": "system", "content": "你现在是很有用的助手!"}]
24 | while True:
25 | prompt = input("请输入你的问题:")
26 | messages.append({"role": "user", "content": prompt})
27 | res_msg = generate_answer(messages)
28 | messages.append({"role": "assistant", "content": res_msg})
29 | print(res_msg)
--------------------------------------------------------------------------------
/input/input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/input/input.wav
--------------------------------------------------------------------------------
/output/output.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanwuyu-web/chat_robot_for_openai/30800dedef102bfc8cf473f334c3cc9aa570f88d/output/output.wav
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | aiofiles==23.1.0
3 | aiohttp==3.8.4
4 | aiosignal==1.3.1
5 | altair==4.2.2
6 | antlr4-python3-runtime==4.9.3
7 | anyio==3.6.2
8 | async-timeout==4.0.2
9 | attrs==22.2.0
10 | audioread==3.0.0
11 | backports.functools-lru-cache==1.6.4
12 | cachetools==5.3.0
13 | certifi==2022.12.7
14 | cffi==1.15.1
15 | charset-normalizer==3.1.0
16 | click==8.1.3
17 | cloudpickle==2.2.1
18 | cn2an==0.5.19
19 | colorama==0.4.6
20 | contourpy==1.0.7
21 | cycler==0.11.0
22 | Cython==0.29.33
23 | decorator==5.1.1
24 | demucs==4.0.0
25 | diffq==0.2.3
26 | dora-search==0.1.11
27 | einops==0.6.0
28 | eng-to-ipa==0.0.2
29 | entrypoints==0.4
30 | fastapi==0.95.0
31 | ffmpeg-python==0.2.0
32 | ffmpy==0.3.0
33 | filelock==3.10.7
34 | fonttools==4.39.3
35 | frozenlist==1.3.3
36 | fsspec==2023.3.0
37 | future==0.18.3
38 | google-auth==2.17.0
39 | google-auth-oauthlib==0.4.6
40 | gradio==3.23.0
41 | grpcio==1.53.0
42 | h11==0.14.0
43 | httpcore==0.16.3
44 | httpx==0.23.3
45 | huggingface-hub==0.13.3
46 | idna==3.4
47 | indic-transliteration==2.3.37
48 | inflect==6.0.2
49 | jamo==0.4.1
50 | jieba==0.42.1
51 | Jinja2==3.1.2
52 | joblib==1.2.0
53 | jsonschema==4.17.3
54 | julius==0.2.7
55 | kiwisolver==1.4.4
56 | ko-pron==1.3
57 | lameenc==1.4.2
58 | librosa==0.9.1
59 | linkify-it-py==2.0.0
60 | llvmlite==0.39.1
61 | Markdown==3.4.3
62 | markdown-it-py==2.2.0
63 | MarkupSafe==2.1.2
64 | matplotlib==3.7.1
65 | mdit-py-plugins==0.3.3
66 | mdurl==0.1.2
67 | more-itertools==9.1.0
68 | mp3play==0.1.15
69 | mpmath==1.3.0
70 | multidict==6.0.4
71 | networkx==3.0
72 | num-thai==0.0.5
73 | numba==0.56.4
74 | numpy==1.23.5
75 | oauthlib==3.2.2
76 | omegaconf==2.3.0
77 | openai==0.27.2
78 | openai-whisper==20230314
79 | OpenCC==1.1.1
80 | openunmix==1.2.1
81 | orjson==3.8.9
82 | packaging==23.0
83 | pandas==1.5.3
84 | Pillow==9.4.0
85 | platformdirs==3.2.0
86 | playsound==1.3.0
87 | pooch==1.7.0
88 | proces==0.1.4
89 | protobuf==4.22.1
90 | pyasn1==0.4.8
91 | pyasn1-modules==0.2.8
92 | PyAudio==0.2.13
93 | pycparser==2.21
94 | pydantic==1.10.7
95 | pydub==0.25.1
96 | pyopenjtalk==0.3.0
97 | pyparsing==3.0.9
98 | pypinyin==0.48.0
99 | pyrsistent==0.19.3
100 | python-dateutil==2.8.2
101 | python-dotenv==1.0.0
102 | python-multipart==0.0.6
103 | pytz==2023.3
104 | PyYAML==6.0
105 | regex==2023.3.23
106 | requests==2.28.2
107 | requests-oauthlib==1.3.1
108 | resampy==0.4.2
109 | retrying==1.3.4
110 | rfc3986==1.5.0
111 | roman==4.0
112 | rsa==4.9
113 | scikit-learn==1.2.2
114 | scipy==1.10.1
115 | semantic-version==2.10.0
116 | six==1.16.0
117 | sniffio==1.3.0
118 | soundfile==0.12.1
119 | starlette==0.26.1
120 | submitit==1.4.5
121 | sympy==1.11.1
122 | tensorboard==2.12.0
123 | tensorboard-data-server==0.7.0
124 | tensorboard-plugin-wit==1.8.1
125 | threadpoolctl==3.1.0
126 | tiktoken==0.3.1
127 | toml==0.10.2
128 | toolz==0.12.0
129 | torch==2.0.0
130 | torchaudio==2.0.1
131 | tqdm==4.65.0
132 | treetable==0.2.5
133 | typer==0.7.0
134 | typing_extensions==4.5.0
135 | uc-micro-py==1.0.1
136 | Unidecode==1.3.6
137 | urllib3==1.26.15
138 | uvicorn==0.21.1
139 | websockets==10.4
140 | Werkzeug==2.2.3
141 | yarl==1.8.2
142 |
--------------------------------------------------------------------------------
/robot.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import baiduApi
3 | import gptApi
4 | from playsound import playsound
5 | import gptApi_duolun
6 | from VITS.VC_inference import getVoice
7 | from dotenv.main import load_dotenv
8 | import os
9 |
10 | if __name__ == "__main__":
11 | load_dotenv()
12 |
13 | api_id = os.environ['api_id']
14 | api_key = os.environ['api_key']
15 | api_secert = os.environ['api_secert']
16 | bdr = baiduApi.BaiduRest(api_id, api_key, api_secert)
17 | # robot = gptApi.GptApi()
18 | messages = [{"role": "system", "content": "你现在是很有用的助手!"}]
19 | while True:
20 | input("按下回车开始说话,自动停止")
21 | print('开始录音')
22 | bdr.recorder("./input/input.wav")
23 | print("结束")
24 | ask = bdr.getText('./input/input.wav')
25 | print('你:', ask)
26 |
27 | # robot=gptApi.GptApi()
28 | # ans = robot.test_openai(ask)
29 | messages.append({"role": "user", "content": ask})
30 | ans = gptApi_duolun.generate_answer(messages)
31 | messages.append({"role": "assistant", "content": ans})
32 | print('机器人:', ans)
33 |
34 | #bdr.getVoice(ans, "./output.mp3")  # synthesize speech via the Baidu API (deprecated)
35 | getVoice(ans.lstrip(),'./output/output.wav')  # default path: ./output/output.mp3; strip leading whitespace so speech synthesis does not fail
36 | #bdr.speakMac("./output.mp3")
37 | playsound('./output/output.wav')
38 |
--------------------------------------------------------------------------------